├── LICENSE
├── README.md
├── __init__.py
├── algorithm.png
├── data
│   ├── CIFAR_FS.py
│   ├── FC100.py
│   ├── __init__.py
│   ├── mini_imagenet.py
│   └── tiered_imagenet.py
├── models
│   ├── R2D2_embedding.py
│   ├── ResNet12_embedding.py
│   ├── __init__.py
│   ├── classification_heads.py
│   ├── dropblock.py
│   └── protonet_embedding.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning with Differentiable Convex Optimization
2 | This repository contains the code for the paper:
3 |
4 | [**Meta-Learning with Differentiable Convex Optimization**](https://arxiv.org/pdf/1904.03758.pdf)
5 |
6 | Kwonjoon Lee, [Subhransu Maji](https://people.cs.umass.edu/~smaji/), Avinash Ravichandran, [Stefano Soatto](http://web.cs.ucla.edu/~soatto/)
7 | CVPR 2019 (**Oral**)
8 |
9 |
10 |
11 |
12 | ### Abstract
13 |
14 | Many meta-learning approaches for few-shot learning rely on simple base learners such as nearest-neighbor classifiers. However, even in the few-shot regime, discriminatively trained linear predictors can offer better generalization. We propose to use these predictors as base learners to learn representations for few-shot learning and show they offer better tradeoffs between feature size and performance across a range of few-shot recognition benchmarks. Our objective is to learn feature embeddings that generalize well under a linear classification rule for novel categories. To efficiently solve the objective, we exploit two properties of linear classifiers: implicit differentiation of the optimality conditions of the convex problem and the dual formulation of the optimization problem. This allows us to use high-dimensional embeddings with improved generalization at a modest increase in computational overhead. Our approach, named MetaOptNet, achieves state-of-the-art performance on miniImageNet, tieredImageNet, CIFAR-FS and FC100 few-shot learning benchmarks.
15 |
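The abstract's mention of the dual formulation is the key computational trick: for the ridge-regression (`Ridge`) head, the base-learner solve can be carried out in the dual, so its cost scales with the episode size rather than the embedding dimension. Below is a minimal sketch of that idea, written against a modern PyTorch for illustration only; the repository's actual heads live in `models/classification_heads.py`, and the SVM head relies on the `qpth` QP solver.

```python
import torch

def dual_ridge_head(X_support, Y_onehot, lam=50.0):
    """Sketch of a ridge-regression base learner solved in the dual.

    X_support: (n, d) support embeddings; Y_onehot: (n, c) one-hot labels.
    The dual solve inverts an (n, n) system instead of a (d, d) one, so
    high-dimensional embeddings add little overhead:
        W = X^T (X X^T + lam * I)^{-1} Y
    """
    n = X_support.size(0)
    gram = X_support @ X_support.t()  # (n, n) Gram matrix of support embeddings
    eye = torch.eye(n, dtype=X_support.dtype, device=X_support.device)
    alpha = torch.linalg.solve(gram + lam * eye, Y_onehot)  # (n, c) dual variables
    return X_support.t() @ alpha  # (d, c) linear classifier weights
```

Query scores are then `X_query @ W` with an argmax over classes; the function name and the regularizer `lam` here are illustrative, not the repository's defaults.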
16 | ### Citation
17 |
18 | If you use this code for your research, please cite our paper:
19 | ```
20 | @inproceedings{lee2019meta,
21 | title={Meta-Learning with Differentiable Convex Optimization},
22 | author={Kwonjoon Lee and Subhransu Maji and Avinash Ravichandran and Stefano Soatto},
23 | booktitle={CVPR},
24 | year={2019}
25 | }
26 | ```
27 |
28 | ## Dependencies
29 | * Python 2.7+ (not tested on Python 3)
30 | * [PyTorch 0.4.0+](http://pytorch.org)
31 | * [qpth 0.0.11+](https://github.com/locuslab/qpth)
32 | * [tqdm](https://github.com/tqdm/tqdm)
33 |
34 | ## Usage
35 |
36 | ### Installation
37 |
38 | 1. Clone this repository:
39 | ```bash
40 | git clone https://github.com/kjunelee/MetaOptNet.git
41 | cd MetaOptNet
42 | ```
43 | 2. Download and decompress dataset files: [**miniImageNet**](https://drive.google.com/file/d/12V7qi-AjrYi6OoJdYcN_k502BM_jcP8D/view?usp=sharing) (courtesy of [**Spyros Gidaris**](https://github.com/gidariss/FewShotWithoutForgetting)), [**tieredImageNet**](https://drive.google.com/open?id=1nVGCTd9ttULRXFezh4xILQ9lUkg0WZCG), [**FC100**](https://drive.google.com/file/d/1_ZsLyqI487NRDQhwvI7rg86FK3YAZvz1/view?usp=sharing), [**CIFAR-FS**](https://drive.google.com/file/d/1GjGMI0q3bgcpcB_CjI40fX54WgLPuTpS/view?usp=sharing)
44 |
 45 | 3. In each dataset loader, set the path to the corresponding dataset directory. For example, in MetaOptNet/data/mini_imagenet.py line 30:
46 | ```python
47 | _MINI_IMAGENET_DATASET_DIR = 'path/to/miniImageNet'
48 | ```
49 |
50 | ### Meta-training
51 | 1. To train MetaOptNet-SVM on 5-way miniImageNet benchmark:
52 | ```bash
53 | python train.py --gpu 0,1,2,3 --save-path "./experiments/miniImageNet_MetaOptNet_SVM" --train-shot 15 \
54 | --head SVM --network ResNet --dataset miniImageNet --eps 0.1
55 | ```
 56 | As shown in Figure 2 of our paper, we can meta-train the embedding once with a high shot and reuse it for all meta-testing shots; unlike Prototypical Networks, we do not need to meta-train separately for every possible meta-test shot.
 57 | 2. You can experiment with other base learners by changing the `--head` argument to `ProtoNet` or `Ridge`. You can also switch the backbone to a vanilla 4-layer conv net by setting the `--network` argument to `ProtoNet`; see the example below. For other arguments, see MetaOptNet/train.py, lines 85 to 114.
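For example, a hypothetical command (the save path here is arbitrary) that trains with the ProtoNet head on the 4-layer conv backbone:
```bash
python train.py --gpu 0 --save-path "./experiments/miniImageNet_ProtoNet" --train-shot 15 \
--head ProtoNet --network ProtoNet --dataset miniImageNet
```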
58 | 3. To train MetaOptNet-SVM on 5-way tieredImageNet benchmark:
59 | ```bash
60 | python train.py --gpu 0,1,2,3 --save-path "./experiments/tieredImageNet_MetaOptNet_SVM" --train-shot 10 \
61 | --head SVM --network ResNet --dataset tieredImageNet
62 | ```
 63 | 4. To train MetaOptNet-RR on 5-way CIFAR-FS benchmark:
64 | ```bash
65 | python train.py --gpu 0 --save-path "./experiments/CIFAR_FS_MetaOptNet_RR" --train-shot 5 \
66 | --head Ridge --network ResNet --dataset CIFAR_FS
67 | ```
 68 | 5. To train MetaOptNet-RR on 5-way FC100 benchmark:
69 | ```bash
70 | python train.py --gpu 0 --save-path "./experiments/FC100_MetaOptNet_RR" --train-shot 15 \
71 | --head Ridge --network ResNet --dataset FC100
72 | ```
73 | ### Meta-testing
74 | 1. To test MetaOptNet-SVM on 5-way miniImageNet 1-shot benchmark:
 75 | ```bash
76 | python test.py --gpu 0,1,2,3 --load ./experiments/miniImageNet_MetaOptNet_SVM/best_model.pth --episode 1000 \
77 | --way 5 --shot 1 --query 15 --head SVM --network ResNet --dataset miniImageNet
78 | ```
79 | 2. Similarly, to test MetaOptNet-SVM on 5-way miniImageNet 5-shot benchmark:
 80 | ```bash
81 | python test.py --gpu 0,1,2,3 --load ./experiments/miniImageNet_MetaOptNet_SVM/best_model.pth --episode 1000 \
82 | --way 5 --shot 5 --query 15 --head SVM --network ResNet --dataset miniImageNet
83 | ```
84 |
85 | ## Acknowledgments
86 |
87 | This code is based on the implementations of [**Prototypical Networks**](https://github.com/cyvius96/prototypical-network-pytorch), [**Dynamic Few-Shot Visual Learning without Forgetting**](https://github.com/gidariss/FewShotWithoutForgetting), and [**DropBlock**](https://github.com/miguelvr/dropblock).
88 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Implement your code here.
2 |
--------------------------------------------------------------------------------
/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kjunelee/MetaOptNet/19601591da8090734e2e69113c9510831da5528a/algorithm.png
--------------------------------------------------------------------------------
/data/CIFAR_FS.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 |
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 |
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 |
21 | import h5py
22 |
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 |
26 | from pdb import set_trace as breakpoint
27 |
28 |
29 | # Set the appropriate paths of the datasets here.
30 | _CIFAR_FS_DATASET_DIR = '/mnt/cube/datasets/few-shot/CIFAR_FS'
31 |
32 | def buildLabelIndex(labels):
33 | label2inds = {}
34 | for idx, label in enumerate(labels):
35 | if label not in label2inds:
36 | label2inds[label] = []
37 | label2inds[label].append(idx)
38 |
39 | return label2inds
40 |
41 | def load_data(file):
42 | try:
43 | with open(file, 'rb') as fo:
44 | data = pickle.load(fo)
45 | return data
 46 |     except UnicodeDecodeError:  # Python 2 pickle loaded under Python 3
47 | with open(file, 'rb') as f:
48 | u = pickle._Unpickler(f)
49 | u.encoding = 'latin1'
50 | data = u.load()
51 | return data
52 |
53 | class CIFAR_FS(data.Dataset):
54 | def __init__(self, phase='train', do_not_use_random_transf=False):
55 |
56 | assert(phase=='train' or phase=='val' or phase=='test')
57 | self.phase = phase
58 | self.name = 'CIFAR_FS_' + phase
59 |
60 | print('Loading CIFAR-FS dataset - phase {0}'.format(phase))
61 | file_train_categories_train_phase = os.path.join(
62 | _CIFAR_FS_DATASET_DIR,
63 | 'CIFAR_FS_train.pickle')
64 | file_train_categories_val_phase = os.path.join(
65 | _CIFAR_FS_DATASET_DIR,
66 | 'CIFAR_FS_train.pickle')
67 | file_train_categories_test_phase = os.path.join(
68 | _CIFAR_FS_DATASET_DIR,
69 | 'CIFAR_FS_train.pickle')
70 | file_val_categories_val_phase = os.path.join(
71 | _CIFAR_FS_DATASET_DIR,
72 | 'CIFAR_FS_val.pickle')
73 | file_test_categories_test_phase = os.path.join(
74 | _CIFAR_FS_DATASET_DIR,
75 | 'CIFAR_FS_test.pickle')
76 |
77 | if self.phase=='train':
78 | # During training phase we only load the training phase images
79 | # of the training categories (aka base categories).
80 | data_train = load_data(file_train_categories_train_phase)
81 | self.data = data_train['data']
82 | self.labels = data_train['labels']
83 |
84 | self.label2ind = buildLabelIndex(self.labels)
85 | self.labelIds = sorted(self.label2ind.keys())
86 | self.num_cats = len(self.labelIds)
87 | self.labelIds_base = self.labelIds
88 | self.num_cats_base = len(self.labelIds_base)
89 |
90 | elif self.phase=='val' or self.phase=='test':
91 | if self.phase=='test':
92 | # load data that will be used for evaluating the recognition
93 | # accuracy of the base categories.
94 | data_base = load_data(file_train_categories_test_phase)
 95 |                 # load data that will be used for evaluating the few-shot recognition
96 | # accuracy on the novel categories.
97 | data_novel = load_data(file_test_categories_test_phase)
98 | else: # phase=='val'
99 | # load data that will be used for evaluating the recognition
100 | # accuracy of the base categories.
101 | data_base = load_data(file_train_categories_val_phase)
102 |                 # load data that will be used for evaluating the few-shot recognition
103 | # accuracy on the novel categories.
104 | data_novel = load_data(file_val_categories_val_phase)
105 |
106 | self.data = np.concatenate(
107 | [data_base['data'], data_novel['data']], axis=0)
108 | self.labels = data_base['labels'] + data_novel['labels']
109 |
110 | self.label2ind = buildLabelIndex(self.labels)
111 | self.labelIds = sorted(self.label2ind.keys())
112 | self.num_cats = len(self.labelIds)
113 |
114 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
115 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
116 | self.num_cats_base = len(self.labelIds_base)
117 | self.num_cats_novel = len(self.labelIds_novel)
118 | intersection = set(self.labelIds_base) & set(self.labelIds_novel)
119 | assert(len(intersection) == 0)
120 | else:
121 |             raise ValueError('Invalid phase {0}'.format(self.phase))
122 |
123 | mean_pix = [x/255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
124 |
125 | std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
126 |
127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
128 |
129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
130 |
131 | self.transform = transforms.Compose([
132 | lambda x: np.asarray(x),
133 | transforms.ToTensor(),
134 | normalize
135 | ])
136 | else:
137 |
138 | self.transform = transforms.Compose([
139 | transforms.RandomCrop(32, padding=4),
140 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
141 | transforms.RandomHorizontalFlip(),
142 | lambda x: np.asarray(x),
143 | transforms.ToTensor(),
144 | normalize
145 | ])
146 |
147 | def __getitem__(self, index):
148 | img, label = self.data[index], self.labels[index]
149 | # doing this so that it is consistent with all other datasets
150 | # to return a PIL Image
151 | img = Image.fromarray(img)
152 | if self.transform is not None:
153 | img = self.transform(img)
154 | return img, label
155 |
156 | def __len__(self):
157 | return len(self.data)
158 |
159 |
160 | class FewShotDataloader():
161 | def __init__(self,
162 | dataset,
163 | nKnovel=5, # number of novel categories.
164 | nKbase=-1, # number of base categories.
165 | nExemplars=1, # number of training examples per novel category.
166 | nTestNovel=15*5, # number of test examples for all the novel categories.
167 | nTestBase=15*5, # number of test examples for all the base categories.
168 | batch_size=1, # number of training episodes per batch.
169 | num_workers=4,
170 | epoch_size=2000, # number of batches per epoch.
171 | ):
172 |
173 | self.dataset = dataset
174 | self.phase = self.dataset.phase
175 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
176 | else self.dataset.num_cats_novel)
177 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
178 | self.nKnovel = nKnovel
179 |
180 | max_possible_nKbase = self.dataset.num_cats_base
181 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
182 | if self.phase=='train' and nKbase > 0:
183 | nKbase -= self.nKnovel
184 | max_possible_nKbase -= self.nKnovel
185 |
186 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
187 | self.nKbase = nKbase
188 |
189 | self.nExemplars = nExemplars
190 | self.nTestNovel = nTestNovel
191 | self.nTestBase = nTestBase
192 | self.batch_size = batch_size
193 | self.epoch_size = epoch_size
194 | self.num_workers = num_workers
195 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
196 |
197 | def sampleImageIdsFrom(self, cat_id, sample_size=1):
198 | """
199 | Samples `sample_size` number of unique image ids picked from the
200 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
201 |
202 | Args:
203 | cat_id: a scalar with the id of the category from which images will
204 | be sampled.
205 | sample_size: number of images that will be sampled.
206 |
207 | Returns:
208 | image_ids: a list of length `sample_size` with unique image ids.
209 | """
210 | assert(cat_id in self.dataset.label2ind)
211 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
212 | # Note: random.sample samples elements without replacement.
213 | return random.sample(self.dataset.label2ind[cat_id], sample_size)
214 |
215 | def sampleCategories(self, cat_set, sample_size=1):
216 | """
217 | Samples `sample_size` number of unique categories picked from the
218 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
219 |
220 | Args:
221 | cat_set: string that specifies the set of categories from which
222 | categories will be sampled.
223 | sample_size: number of categories that will be sampled.
224 |
225 | Returns:
226 | cat_ids: a list of length `sample_size` with unique category ids.
227 | """
228 | if cat_set=='base':
229 | labelIds = self.dataset.labelIds_base
230 | elif cat_set=='novel':
231 | labelIds = self.dataset.labelIds_novel
232 | else:
233 |             raise ValueError('Unrecognized category set {}'.format(cat_set))
234 |
235 | assert(len(labelIds) >= sample_size)
236 | # return sample_size unique categories chosen from labelIds set of
237 | # categories (that can be either self.labelIds_base or self.labelIds_novel)
238 | # Note: random.sample samples elements without replacement.
239 | return random.sample(labelIds, sample_size)
240 |
241 | def sample_base_and_novel_categories(self, nKbase, nKnovel):
242 | """
243 | Samples `nKbase` number of base categories and `nKnovel` number of novel
244 | categories.
245 |
246 | Args:
247 | nKbase: number of base categories
248 | nKnovel: number of novel categories
249 |
250 | Returns:
251 | Kbase: a list of length 'nKbase' with the ids of the sampled base
252 | categories.
253 |             Knovel: a list of length 'nKnovel' with the ids of the sampled novel
254 | categories.
255 | """
256 | if self.is_eval_mode:
257 | assert(nKnovel <= self.dataset.num_cats_novel)
258 | # sample from the set of base categories 'nKbase' number of base
259 | # categories.
260 | Kbase = sorted(self.sampleCategories('base', nKbase))
261 | # sample from the set of novel categories 'nKnovel' number of novel
262 | # categories.
263 | Knovel = sorted(self.sampleCategories('novel', nKnovel))
264 | else:
265 | # sample from the set of base categories 'nKnovel' + 'nKbase' number
266 | # of categories.
267 | cats_ids = self.sampleCategories('base', nKnovel+nKbase)
268 | assert(len(cats_ids) == (nKnovel+nKbase))
269 | # Randomly pick 'nKnovel' number of fake novel categories and keep
270 | # the rest as base categories.
271 | random.shuffle(cats_ids)
272 | Knovel = sorted(cats_ids[:nKnovel])
273 | Kbase = sorted(cats_ids[nKnovel:])
274 |
275 | return Kbase, Knovel
276 |
277 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
278 | """
279 | Sample `nTestBase` number of images from the `Kbase` categories.
280 |
281 | Args:
282 | Kbase: a list of length `nKbase` with the ids of the categories from
283 | where the images will be sampled.
284 | nTestBase: the total number of images that will be sampled.
285 |
286 | Returns:
287 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
288 | element of each tuple is the image id that was sampled and the
289 |             2nd element is its category label (which is in the range
290 | [0, len(Kbase)-1]).
291 | """
292 | Tbase = []
293 | if len(Kbase) > 0:
294 |             # For each base category, sample a number of images such that the
295 |             # total number of sampled images across all categories equals `nTestBase`.
296 | KbaseIndices = np.random.choice(
297 | np.arange(len(Kbase)), size=nTestBase, replace=True)
298 | KbaseIndices, NumImagesPerCategory = np.unique(
299 | KbaseIndices, return_counts=True)
300 |
301 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
302 | imd_ids = self.sampleImageIdsFrom(
303 | Kbase[Kbase_idx], sample_size=NumImages)
304 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
305 |
306 | assert(len(Tbase) == nTestBase)
307 |
308 | return Tbase
309 |
310 | def sample_train_and_test_examples_for_novel_categories(
311 | self, Knovel, nTestNovel, nExemplars, nKbase):
312 | """Samples train and test examples of the novel categories.
313 |
314 | Args:
315 | Knovel: a list with the ids of the novel categories.
316 | nTestNovel: the total number of test images that will be sampled
317 | from all the novel categories.
318 | nExemplars: the number of training examples per novel category that
319 | will be sampled.
320 | nKbase: the number of base categories. It is used as offset of the
321 | category index of each sampled image.
322 |
323 | Returns:
324 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The
325 | 1st element of each tuple is the image id that was sampled and
326 | the 2nd element is its category label (which is in the range
327 | [nKbase, nKbase + len(Knovel) - 1]).
328 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element
329 | tuples. The 1st element of each tuple is the image id that was
330 | sampled and the 2nd element is its category label (which is in
331 |             the range [nKbase, nKbase + len(Knovel) - 1]).
332 | """
333 |
334 | if len(Knovel) == 0:
335 | return [], []
336 |
337 | nKnovel = len(Knovel)
338 | Tnovel = []
339 | Exemplars = []
340 | assert((nTestNovel % nKnovel) == 0)
341 | nEvalExamplesPerClass = int(nTestNovel / nKnovel)
342 |
343 | for Knovel_idx in range(len(Knovel)):
344 | imd_ids = self.sampleImageIdsFrom(
345 | Knovel[Knovel_idx],
346 | sample_size=(nEvalExamplesPerClass + nExemplars))
347 |
348 | imds_tnovel = imd_ids[:nEvalExamplesPerClass]
349 |             imds_exemplars = imd_ids[nEvalExamplesPerClass:]
350 | 
351 |             Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
352 |             Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_exemplars]
353 | assert(len(Tnovel) == nTestNovel)
354 | assert(len(Exemplars) == len(Knovel) * nExemplars)
355 | random.shuffle(Exemplars)
356 |
357 | return Tnovel, Exemplars
358 |
359 | def sample_episode(self):
360 | """Samples a training episode."""
361 | nKnovel = self.nKnovel
362 | nKbase = self.nKbase
363 | nTestNovel = self.nTestNovel
364 | nTestBase = self.nTestBase
365 | nExemplars = self.nExemplars
366 |
367 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
368 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
369 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
370 | Knovel, nTestNovel, nExemplars, nKbase)
371 |
372 | # concatenate the base and novel category examples.
373 | Test = Tbase + Tnovel
374 | random.shuffle(Test)
375 | Kall = Kbase + Knovel
376 |
377 | return Exemplars, Test, Kall, nKbase
378 |
379 | def createExamplesTensorData(self, examples):
380 | """
381 | Creates the examples image and label tensor data.
382 |
383 | Args:
384 | examples: a list of 2-element tuples, each representing a
385 | train or test example. The 1st element of each tuple
386 | is the image id of the example and 2nd element is the
387 | category label of the example, which is in the range
388 | [0, nK - 1], where nK is the total number of categories
389 | (both novel and base).
390 |
391 | Returns:
392 | images: a tensor of shape [nExamples, Height, Width, 3] with the
393 | example images, where nExamples is the number of examples
394 | (i.e., nExamples = len(examples)).
395 | labels: a tensor of shape [nExamples] with the category label
396 | of each example.
397 | """
398 | images = torch.stack(
399 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
400 | labels = torch.LongTensor([label for _, label in examples])
401 | return images, labels
402 |
403 | def get_iterator(self, epoch=0):
404 | rand_seed = epoch
405 | random.seed(rand_seed)
406 | np.random.seed(rand_seed)
407 | def load_function(iter_idx):
408 | Exemplars, Test, Kall, nKbase = self.sample_episode()
409 | Xt, Yt = self.createExamplesTensorData(Test)
410 | Kall = torch.LongTensor(Kall)
411 | if len(Exemplars) > 0:
412 | Xe, Ye = self.createExamplesTensorData(Exemplars)
413 | return Xe, Ye, Xt, Yt, Kall, nKbase
414 | else:
415 | return Xt, Yt, Kall, nKbase
416 |
417 | tnt_dataset = tnt.dataset.ListDataset(
418 | elem_list=range(self.epoch_size), load=load_function)
419 | data_loader = tnt_dataset.parallel(
420 | batch_size=self.batch_size,
421 | num_workers=(0 if self.is_eval_mode else self.num_workers),
422 | shuffle=(False if self.is_eval_mode else True))
423 |
424 | return data_loader
425 |
426 | def __call__(self, epoch=0):
427 | return self.get_iterator(epoch)
428 |
429 | def __len__(self):
430 | return int(self.epoch_size / self.batch_size)
431 |
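As a usage sketch (an editorial addition, not part of the repository), this is roughly how test.py-style evaluation would drive the loader above. The episode parameters mirror the README's 5-way 1-shot test command, and the unpacking assumes `nExemplars > 0` so that support tensors are returned:

```python
# Hypothetical example; assumes _CIFAR_FS_DATASET_DIR points at the
# decompressed CIFAR-FS pickles.
from data.CIFAR_FS import CIFAR_FS, FewShotDataloader

dataset_val = CIFAR_FS(phase='val')
loader = FewShotDataloader(
    dataset_val,
    nKnovel=5,          # 5-way episodes
    nKbase=0,           # no base categories during few-shot evaluation
    nExemplars=1,       # 1 support image per novel category (1-shot)
    nTestNovel=15 * 5,  # 15 query images per novel category
    nTestBase=0,
    batch_size=1,
    epoch_size=600)     # number of episodes drawn per epoch

for batch in loader(epoch=0):
    # With exemplars present, each batch unpacks to support images/labels,
    # query images/labels, the sampled category ids, and nKbase.
    data_support, labels_support, data_query, labels_query, k_all, n_kbase = batch
    break
```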
--------------------------------------------------------------------------------
/data/FC100.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 |
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 |
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 |
21 | import h5py
22 |
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 |
26 | from pdb import set_trace as breakpoint
27 |
28 |
29 | # Set the appropriate paths of the datasets here.
30 | _FC100_DATASET_DIR = '/mnt/cube/datasets/few-shot/FC100'
31 |
32 | def buildLabelIndex(labels):
33 | label2inds = {}
34 | for idx, label in enumerate(labels):
35 | if label not in label2inds:
36 | label2inds[label] = []
37 | label2inds[label].append(idx)
38 |
39 | return label2inds
40 |
41 | def load_data(file):
42 | try:
43 | with open(file, 'rb') as fo:
44 | data = pickle.load(fo)
45 | return data
 46 |     except UnicodeDecodeError:  # Python 2 pickle loaded under Python 3
47 | with open(file, 'rb') as f:
48 | u = pickle._Unpickler(f)
49 | u.encoding = 'latin1'
50 | data = u.load()
51 | return data
52 |
53 | class FC100(data.Dataset):
54 | def __init__(self, phase='train', do_not_use_random_transf=False):
55 |
56 | assert(phase=='train' or phase=='val' or phase=='test')
57 | self.phase = phase
58 | self.name = 'FC100_' + phase
59 |
60 | print('Loading FC100 dataset - phase {0}'.format(phase))
61 | file_train_categories_train_phase = os.path.join(
62 | _FC100_DATASET_DIR,
63 | 'FC100_train.pickle')
64 | file_train_categories_val_phase = os.path.join(
65 | _FC100_DATASET_DIR,
66 | 'FC100_train.pickle')
67 | file_train_categories_test_phase = os.path.join(
68 | _FC100_DATASET_DIR,
69 | 'FC100_train.pickle')
70 | file_val_categories_val_phase = os.path.join(
71 | _FC100_DATASET_DIR,
72 | 'FC100_val.pickle')
73 | file_test_categories_test_phase = os.path.join(
74 | _FC100_DATASET_DIR,
75 | 'FC100_test.pickle')
76 |
77 | if self.phase=='train':
78 | # During training phase we only load the training phase images
79 | # of the training categories (aka base categories).
80 | data_train = load_data(file_train_categories_train_phase)
81 | self.data = data_train['data']
82 | self.labels = data_train['labels']
 83 | 
84 | self.label2ind = buildLabelIndex(self.labels)
85 | self.labelIds = sorted(self.label2ind.keys())
86 | self.num_cats = len(self.labelIds)
87 | self.labelIds_base = self.labelIds
88 | self.num_cats_base = len(self.labelIds_base)
 89 | 
90 | elif self.phase=='val' or self.phase=='test':
91 | if self.phase=='test':
92 | # load data that will be used for evaluating the recognition
93 | # accuracy of the base categories.
94 | data_base = load_data(file_train_categories_test_phase)
 95 |                 # load data that will be used for evaluating the few-shot recognition
96 | # accuracy on the novel categories.
97 | data_novel = load_data(file_test_categories_test_phase)
98 | else: # phase=='val'
99 | # load data that will be used for evaluating the recognition
100 | # accuracy of the base categories.
101 | data_base = load_data(file_train_categories_val_phase)
102 |                 # load data that will be used for evaluating the few-shot recognition
103 | # accuracy on the novel categories.
104 | data_novel = load_data(file_val_categories_val_phase)
105 |
106 | self.data = np.concatenate(
107 | [data_base['data'], data_novel['data']], axis=0)
108 | self.labels = data_base['labels'] + data_novel['labels']
109 |
110 | self.label2ind = buildLabelIndex(self.labels)
111 | self.labelIds = sorted(self.label2ind.keys())
112 | self.num_cats = len(self.labelIds)
113 |
114 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
115 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
116 | self.num_cats_base = len(self.labelIds_base)
117 | self.num_cats_novel = len(self.labelIds_novel)
118 | intersection = set(self.labelIds_base) & set(self.labelIds_novel)
119 | assert(len(intersection) == 0)
120 | else:
121 |             raise ValueError('Invalid phase {0}'.format(self.phase))
122 |
123 | mean_pix = [x/255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
124 |
125 | std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
126 |
127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
128 |
129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
130 | self.transform = transforms.Compose([
131 | lambda x: np.asarray(x),
132 | transforms.ToTensor(),
133 | normalize
134 | ])
135 | else:
136 | self.transform = transforms.Compose([
137 | transforms.RandomCrop(32, padding=4),
138 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
139 | transforms.RandomHorizontalFlip(),
140 | lambda x: np.asarray(x),
141 | transforms.ToTensor(),
142 | normalize
143 | ])
144 |
145 | def __getitem__(self, index):
146 | img, label = self.data[index], self.labels[index]
147 |         # convert to a PIL Image so that this dataset, like all the
148 |         # others, returns PIL Images to the transform pipeline
149 | img = Image.fromarray(img)
150 | if self.transform is not None:
151 | img = self.transform(img)
152 | return img, label
153 |
154 | def __len__(self):
155 | return len(self.data)
156 |
157 |
158 | class FewShotDataloader():
159 | def __init__(self,
160 | dataset,
161 | nKnovel=5, # number of novel categories.
162 | nKbase=-1, # number of base categories.
163 | nExemplars=1, # number of training examples per novel category.
164 | nTestNovel=15*5, # number of test examples for all the novel categories.
165 | nTestBase=15*5, # number of test examples for all the base categories.
166 | batch_size=1, # number of training episodes per batch.
167 | num_workers=4,
168 | epoch_size=2000, # number of batches per epoch.
169 | ):
170 |
171 | self.dataset = dataset
172 | self.phase = self.dataset.phase
173 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
174 | else self.dataset.num_cats_novel)
175 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
176 | self.nKnovel = nKnovel
177 |
178 | max_possible_nKbase = self.dataset.num_cats_base
179 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
180 | if self.phase=='train' and nKbase > 0:
181 | nKbase -= self.nKnovel
182 | max_possible_nKbase -= self.nKnovel
183 |
184 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
185 | self.nKbase = nKbase
186 |
187 | self.nExemplars = nExemplars
188 | self.nTestNovel = nTestNovel
189 | self.nTestBase = nTestBase
190 | self.batch_size = batch_size
191 | self.epoch_size = epoch_size
192 | self.num_workers = num_workers
193 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
194 |
195 | def sampleImageIdsFrom(self, cat_id, sample_size=1):
196 | """
197 | Samples `sample_size` number of unique image ids picked from the
198 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
199 |
200 | Args:
201 | cat_id: a scalar with the id of the category from which images will
202 | be sampled.
203 | sample_size: number of images that will be sampled.
204 |
205 | Returns:
206 | image_ids: a list of length `sample_size` with unique image ids.
207 | """
208 | assert(cat_id in self.dataset.label2ind)
209 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
210 | # Note: random.sample samples elements without replacement.
211 | return random.sample(self.dataset.label2ind[cat_id], sample_size)
212 |
213 | def sampleCategories(self, cat_set, sample_size=1):
214 | """
215 | Samples `sample_size` number of unique categories picked from the
216 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
217 |
218 | Args:
219 | cat_set: string that specifies the set of categories from which
220 | categories will be sampled.
221 | sample_size: number of categories that will be sampled.
222 |
223 | Returns:
224 | cat_ids: a list of length `sample_size` with unique category ids.
225 | """
226 | if cat_set=='base':
227 | labelIds = self.dataset.labelIds_base
228 | elif cat_set=='novel':
229 | labelIds = self.dataset.labelIds_novel
230 | else:
231 |             raise ValueError('Unrecognized category set {}'.format(cat_set))
232 |
233 | assert(len(labelIds) >= sample_size)
234 | # return sample_size unique categories chosen from labelIds set of
235 | # categories (that can be either self.labelIds_base or self.labelIds_novel)
236 | # Note: random.sample samples elements without replacement.
237 | return random.sample(labelIds, sample_size)
238 |
239 | def sample_base_and_novel_categories(self, nKbase, nKnovel):
240 | """
241 | Samples `nKbase` number of base categories and `nKnovel` number of novel
242 | categories.
243 |
244 | Args:
245 | nKbase: number of base categories
246 | nKnovel: number of novel categories
247 |
248 | Returns:
249 | Kbase: a list of length 'nKbase' with the ids of the sampled base
250 | categories.
251 |             Knovel: a list of length 'nKnovel' with the ids of the sampled novel
252 | categories.
253 | """
254 | if self.is_eval_mode:
255 | assert(nKnovel <= self.dataset.num_cats_novel)
256 | # sample from the set of base categories 'nKbase' number of base
257 | # categories.
258 | Kbase = sorted(self.sampleCategories('base', nKbase))
259 | # sample from the set of novel categories 'nKnovel' number of novel
260 | # categories.
261 | Knovel = sorted(self.sampleCategories('novel', nKnovel))
262 | else:
263 | # sample from the set of base categories 'nKnovel' + 'nKbase' number
264 | # of categories.
265 | cats_ids = self.sampleCategories('base', nKnovel+nKbase)
266 | assert(len(cats_ids) == (nKnovel+nKbase))
267 | # Randomly pick 'nKnovel' number of fake novel categories and keep
268 | # the rest as base categories.
269 | random.shuffle(cats_ids)
270 | Knovel = sorted(cats_ids[:nKnovel])
271 | Kbase = sorted(cats_ids[nKnovel:])
272 |
273 | return Kbase, Knovel
274 |
275 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
276 | """
277 | Sample `nTestBase` number of images from the `Kbase` categories.
278 |
279 | Args:
280 | Kbase: a list of length `nKbase` with the ids of the categories from
281 | where the images will be sampled.
282 | nTestBase: the total number of images that will be sampled.
283 |
284 | Returns:
285 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
286 | element of each tuple is the image id that was sampled and the
287 |             2nd element is its category label (which is in the range
288 | [0, len(Kbase)-1]).
289 | """
290 | Tbase = []
291 | if len(Kbase) > 0:
292 |             # For each base category, sample a number of images such that the
293 |             # total number of sampled images across all categories equals `nTestBase`.
294 | KbaseIndices = np.random.choice(
295 | np.arange(len(Kbase)), size=nTestBase, replace=True)
296 | KbaseIndices, NumImagesPerCategory = np.unique(
297 | KbaseIndices, return_counts=True)
298 |
299 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
300 | imd_ids = self.sampleImageIdsFrom(
301 | Kbase[Kbase_idx], sample_size=NumImages)
302 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
303 |
304 | assert(len(Tbase) == nTestBase)
305 |
306 | return Tbase
307 |
308 | def sample_train_and_test_examples_for_novel_categories(
309 | self, Knovel, nTestNovel, nExemplars, nKbase):
310 | """Samples train and test examples of the novel categories.
311 |
312 | Args:
313 | Knovel: a list with the ids of the novel categories.
314 | nTestNovel: the total number of test images that will be sampled
315 | from all the novel categories.
316 | nExemplars: the number of training examples per novel category that
317 | will be sampled.
318 | nKbase: the number of base categories. It is used as offset of the
319 | category index of each sampled image.
320 |
321 | Returns:
322 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The
323 | 1st element of each tuple is the image id that was sampled and
324 | the 2nd element is its category label (which is in the range
325 | [nKbase, nKbase + len(Knovel) - 1]).
326 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element
327 | tuples. The 1st element of each tuple is the image id that was
328 | sampled and the 2nd element is its category label (which is in
329 |             the range [nKbase, nKbase + len(Knovel) - 1]).
330 | """
331 |
332 | if len(Knovel) == 0:
333 | return [], []
334 |
335 | nKnovel = len(Knovel)
336 | Tnovel = []
337 | Exemplars = []
338 | assert((nTestNovel % nKnovel) == 0)
339 | nEvalExamplesPerClass = int(nTestNovel / nKnovel)
340 |
341 | for Knovel_idx in range(len(Knovel)):
342 | imd_ids = self.sampleImageIdsFrom(
343 | Knovel[Knovel_idx],
344 | sample_size=(nEvalExamplesPerClass + nExemplars))
345 |
346 | imds_tnovel = imd_ids[:nEvalExamplesPerClass]
347 |             imds_exemplars = imd_ids[nEvalExamplesPerClass:]
348 | 
349 |             Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
350 |             Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_exemplars]
351 | assert(len(Tnovel) == nTestNovel)
352 | assert(len(Exemplars) == len(Knovel) * nExemplars)
353 | random.shuffle(Exemplars)
354 |
355 | return Tnovel, Exemplars
356 |
357 | def sample_episode(self):
358 | """Samples a training episode."""
359 | nKnovel = self.nKnovel
360 | nKbase = self.nKbase
361 | nTestNovel = self.nTestNovel
362 | nTestBase = self.nTestBase
363 | nExemplars = self.nExemplars
364 |
365 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
366 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
367 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
368 | Knovel, nTestNovel, nExemplars, nKbase)
369 |
370 | # concatenate the base and novel category examples.
371 | Test = Tbase + Tnovel
372 | random.shuffle(Test)
373 | Kall = Kbase + Knovel
374 |
375 | return Exemplars, Test, Kall, nKbase
376 |
377 | def createExamplesTensorData(self, examples):
378 | """
379 | Creates the examples image and label tensor data.
380 |
381 | Args:
382 | examples: a list of 2-element tuples, each representing a
383 | train or test example. The 1st element of each tuple
384 | is the image id of the example and 2nd element is the
385 | category label of the example, which is in the range
386 | [0, nK - 1], where nK is the total number of categories
387 | (both novel and base).
388 |
389 | Returns:
390 | images: a tensor of shape [nExamples, Height, Width, 3] with the
391 | example images, where nExamples is the number of examples
392 | (i.e., nExamples = len(examples)).
393 | labels: a tensor of shape [nExamples] with the category label
394 | of each example.
395 | """
396 | images = torch.stack(
397 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
398 | labels = torch.LongTensor([label for _, label in examples])
399 | return images, labels
400 |
401 | def get_iterator(self, epoch=0):
402 | rand_seed = epoch
403 | random.seed(rand_seed)
404 | np.random.seed(rand_seed)
405 | def load_function(iter_idx):
406 | Exemplars, Test, Kall, nKbase = self.sample_episode()
407 | Xt, Yt = self.createExamplesTensorData(Test)
408 | Kall = torch.LongTensor(Kall)
409 | if len(Exemplars) > 0:
410 | Xe, Ye = self.createExamplesTensorData(Exemplars)
411 | return Xe, Ye, Xt, Yt, Kall, nKbase
412 | else:
413 | return Xt, Yt, Kall, nKbase
414 |
415 | tnt_dataset = tnt.dataset.ListDataset(
416 | elem_list=range(self.epoch_size), load=load_function)
417 | data_loader = tnt_dataset.parallel(
418 | batch_size=self.batch_size,
419 | num_workers=(0 if self.is_eval_mode else self.num_workers),
420 | shuffle=(False if self.is_eval_mode else True))
421 |
422 | return data_loader
423 |
424 | def __call__(self, epoch=0):
425 | return self.get_iterator(epoch)
426 |
427 | def __len__(self):
428 | return int(self.epoch_size / self.batch_size)
429 |
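A quick illustration (editorial, not in the file) of the `buildLabelIndex` helper shared by these loaders: it maps each label to the list of positions at which it occurs, which is exactly the pool `sampleImageIdsFrom` draws from.

```python
from data.FC100 import buildLabelIndex

# Each label maps to the ordered list of indices at which it appears.
assert buildLabelIndex([3, 1, 3, 2]) == {3: [0, 2], 1: [1], 2: [3]}
```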
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 |
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 |
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 |
21 | import h5py
22 |
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 |
26 | from pdb import set_trace as breakpoint
27 |
28 |
29 | # Set the appropriate paths of the datasets here.
30 | _MINI_IMAGENET_DATASET_DIR = '/efs/data/miniimagenet/kwonl/data/miniImageNet_numpy'
31 |
32 | def buildLabelIndex(labels):
33 | label2inds = {}
34 | for idx, label in enumerate(labels):
35 | if label not in label2inds:
36 | label2inds[label] = []
37 | label2inds[label].append(idx)
38 |
39 | return label2inds
40 |
41 |
42 | def load_data(file):
43 | try:
44 | with open(file, 'rb') as fo:
45 | data = pickle.load(fo)
46 | return data
 47 |     except UnicodeDecodeError:  # Python 2 pickle loaded under Python 3
48 | with open(file, 'rb') as f:
49 | u = pickle._Unpickler(f)
50 | u.encoding = 'latin1'
51 | data = u.load()
52 | return data
53 |
54 | class MiniImageNet(data.Dataset):
55 | def __init__(self, phase='train', do_not_use_random_transf=False):
56 |
57 | self.base_folder = 'miniImagenet'
58 | assert(phase=='train' or phase=='val' or phase=='test')
59 | self.phase = phase
60 | self.name = 'MiniImageNet_' + phase
61 |
62 | print('Loading mini ImageNet dataset - phase {0}'.format(phase))
63 | file_train_categories_train_phase = os.path.join(
64 | _MINI_IMAGENET_DATASET_DIR,
65 | 'miniImageNet_category_split_train_phase_train.pickle')
66 | file_train_categories_val_phase = os.path.join(
67 | _MINI_IMAGENET_DATASET_DIR,
68 | 'miniImageNet_category_split_train_phase_val.pickle')
69 | file_train_categories_test_phase = os.path.join(
70 | _MINI_IMAGENET_DATASET_DIR,
71 | 'miniImageNet_category_split_train_phase_test.pickle')
72 | file_val_categories_val_phase = os.path.join(
73 | _MINI_IMAGENET_DATASET_DIR,
74 | 'miniImageNet_category_split_val.pickle')
75 | file_test_categories_test_phase = os.path.join(
76 | _MINI_IMAGENET_DATASET_DIR,
77 | 'miniImageNet_category_split_test.pickle')
78 |
79 | if self.phase=='train':
80 | # During training phase we only load the training phase images
81 | # of the training categories (aka base categories).
82 | data_train = load_data(file_train_categories_train_phase)
83 | self.data = data_train['data']
84 | self.labels = data_train['labels']
85 |
86 | self.label2ind = buildLabelIndex(self.labels)
87 | self.labelIds = sorted(self.label2ind.keys())
88 | self.num_cats = len(self.labelIds)
89 | self.labelIds_base = self.labelIds
90 | self.num_cats_base = len(self.labelIds_base)
91 |
92 | elif self.phase=='val' or self.phase=='test':
93 | if self.phase=='test':
94 | # load data that will be used for evaluating the recognition
95 | # accuracy of the base categories.
96 | data_base = load_data(file_train_categories_test_phase)
 97 |                 # load data that will be used for evaluating the few-shot recognition
98 | # accuracy on the novel categories.
99 | data_novel = load_data(file_test_categories_test_phase)
100 | else: # phase=='val'
101 | # load data that will be used for evaluating the recognition
102 | # accuracy of the base categories.
103 | data_base = load_data(file_train_categories_val_phase)
104 |                 # load data that will be used for evaluating the few-shot recognition
105 | # accuracy on the novel categories.
106 | data_novel = load_data(file_val_categories_val_phase)
107 |
108 | self.data = np.concatenate(
109 | [data_base['data'], data_novel['data']], axis=0)
110 | self.labels = data_base['labels'] + data_novel['labels']
111 |
112 | self.label2ind = buildLabelIndex(self.labels)
113 | self.labelIds = sorted(self.label2ind.keys())
114 | self.num_cats = len(self.labelIds)
115 |
116 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
117 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
118 | self.num_cats_base = len(self.labelIds_base)
119 | self.num_cats_novel = len(self.labelIds_novel)
120 | intersection = set(self.labelIds_base) & set(self.labelIds_novel)
121 | assert(len(intersection) == 0)
122 | else:
123 |             raise ValueError('Invalid phase {0}'.format(self.phase))
124 |
125 | mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
126 | std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
128 |
129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
130 | self.transform = transforms.Compose([
131 | lambda x: np.asarray(x),
132 | transforms.ToTensor(),
133 | normalize
134 | ])
135 | else:
136 | self.transform = transforms.Compose([
137 | transforms.RandomCrop(84, padding=8),
138 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
139 | transforms.RandomHorizontalFlip(),
140 | lambda x: np.asarray(x),
141 | transforms.ToTensor(),
142 | normalize
143 | ])
144 |
145 | def __getitem__(self, index):
146 | img, label = self.data[index], self.labels[index]
147 |         # convert to a PIL Image so that this dataset, like all the
148 |         # others, returns PIL Images to the transform pipeline
149 | img = Image.fromarray(img)
150 | if self.transform is not None:
151 | img = self.transform(img)
152 | return img, label
153 |
154 | def __len__(self):
155 | return len(self.data)
156 |
157 |
158 | class FewShotDataloader():
159 | def __init__(self,
160 | dataset,
161 | nKnovel=5, # number of novel categories.
162 | nKbase=-1, # number of base categories.
163 | nExemplars=1, # number of training examples per novel category.
164 | nTestNovel=15*5, # number of test examples for all the novel categories.
165 | nTestBase=15*5, # number of test examples for all the base categories.
166 | batch_size=1, # number of training episodes per batch.
167 | num_workers=4,
168 | epoch_size=2000, # number of batches per epoch.
169 | ):
170 |
171 | self.dataset = dataset
172 | self.phase = self.dataset.phase
173 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
174 | else self.dataset.num_cats_novel)
175 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
176 | self.nKnovel = nKnovel
177 |
178 | max_possible_nKbase = self.dataset.num_cats_base
179 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
180 | if self.phase=='train' and nKbase > 0:
181 | nKbase -= self.nKnovel
182 | max_possible_nKbase -= self.nKnovel
183 |
184 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
185 | self.nKbase = nKbase
186 |
187 | self.nExemplars = nExemplars
188 | self.nTestNovel = nTestNovel
189 | self.nTestBase = nTestBase
190 | self.batch_size = batch_size
191 | self.epoch_size = epoch_size
192 | self.num_workers = num_workers
193 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
194 |
195 | def sampleImageIdsFrom(self, cat_id, sample_size=1):
196 | """
197 | Samples `sample_size` number of unique image ids picked from the
198 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
199 |
200 | Args:
201 | cat_id: a scalar with the id of the category from which images will
202 | be sampled.
203 | sample_size: number of images that will be sampled.
204 |
205 | Returns:
206 | image_ids: a list of length `sample_size` with unique image ids.
207 | """
208 | assert(cat_id in self.dataset.label2ind)
209 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
210 | # Note: random.sample samples elements without replacement.
211 | return random.sample(self.dataset.label2ind[cat_id], sample_size)
212 |
213 | def sampleCategories(self, cat_set, sample_size=1):
214 | """
215 | Samples `sample_size` number of unique categories picked from the
216 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
217 |
218 | Args:
219 | cat_set: string that specifies the set of categories from which
220 | categories will be sampled.
221 | sample_size: number of categories that will be sampled.
222 |
223 | Returns:
224 | cat_ids: a list of length `sample_size` with unique category ids.
225 | """
226 | if cat_set=='base':
227 | labelIds = self.dataset.labelIds_base
228 | elif cat_set=='novel':
229 | labelIds = self.dataset.labelIds_novel
230 | else:
231 | raise ValueError('Unrecognized category set {}'.format(cat_set))
232 |
233 | assert(len(labelIds) >= sample_size)
234 | # return sample_size unique categories chosen from labelIds set of
235 | # categories (that can be either self.labelIds_base or self.labelIds_novel)
236 | # Note: random.sample samples elements without replacement.
237 | return random.sample(labelIds, sample_size)
238 |
239 | def sample_base_and_novel_categories(self, nKbase, nKnovel):
240 | """
241 | Samples `nKbase` number of base categories and `nKnovel` number of novel
242 | categories.
243 |
244 | Args:
245 | nKbase: number of base categories
246 | nKnovel: number of novel categories
247 |
248 | Returns:
249 | Kbase: a list of length 'nKbase' with the ids of the sampled base
250 | categories.
251 | Knovel: a list of length 'nKnovel' with the ids of the sampled novel
252 | categories.
253 | """
254 | if self.is_eval_mode:
255 | assert(nKnovel <= self.dataset.num_cats_novel)
256 | # sample from the set of base categories 'nKbase' number of base
257 | # categories.
258 | Kbase = sorted(self.sampleCategories('base', nKbase))
259 | # sample from the set of novel categories 'nKnovel' number of novel
260 | # categories.
261 | Knovel = sorted(self.sampleCategories('novel', nKnovel))
262 | else:
263 | # sample from the set of base categories 'nKnovel' + 'nKbase' number
264 | # of categories.
265 | cats_ids = self.sampleCategories('base', nKnovel+nKbase)
266 | assert(len(cats_ids) == (nKnovel+nKbase))
267 | # Randomly pick 'nKnovel' number of fake novel categories and keep
268 | # the rest as base categories.
269 | random.shuffle(cats_ids)
270 | Knovel = sorted(cats_ids[:nKnovel])
271 | Kbase = sorted(cats_ids[nKnovel:])
272 |
273 | return Kbase, Knovel
274 |
275 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
276 | """
277 | Sample `nTestBase` number of images from the `Kbase` categories.
278 |
279 | Args:
280 | Kbase: a list of length `nKbase` with the ids of the categories from
281 | where the images will be sampled.
282 | nTestBase: the total number of images that will be sampled.
283 |
284 | Returns:
285 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
286 | element of each tuple is the image id that was sampled and the
287 | 2nd element is its category label (which is in the range
288 | [0, len(Kbase)-1]).
289 | """
290 | Tbase = []
291 | if len(Kbase) > 0:
292 | # Sample a number of images from each base category such that the
293 | # total number of sampled images across all categories equals `nTestBase`.
294 | KbaseIndices = np.random.choice(
295 | np.arange(len(Kbase)), size=nTestBase, replace=True)
296 | KbaseIndices, NumImagesPerCategory = np.unique(
297 | KbaseIndices, return_counts=True)
298 |
299 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
300 | imd_ids = self.sampleImageIdsFrom(
301 | Kbase[Kbase_idx], sample_size=NumImages)
302 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
303 |
304 | assert(len(Tbase) == nTestBase)
305 |
306 | return Tbase
307 |
308 | def sample_train_and_test_examples_for_novel_categories(
309 | self, Knovel, nTestNovel, nExemplars, nKbase):
310 | """Samples train and test examples of the novel categories.
311 |
312 | Args:
313 | Knovel: a list with the ids of the novel categories.
314 | nTestNovel: the total number of test images that will be sampled
315 | from all the novel categories.
316 | nExemplars: the number of training examples per novel category that
317 | will be sampled.
318 | nKbase: the number of base categories. It is used as offset of the
319 | category index of each sampled image.
320 |
321 | Returns:
322 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The
323 | 1st element of each tuple is the image id that was sampled and
324 | the 2nd element is its category label (which is in the range
325 | [nKbase, nKbase + len(Knovel) - 1]).
326 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element
327 | tuples. The 1st element of each tuple is the image id that was
328 | sampled and the 2nd element is its category label (which is in
329 | the range [nKbase, nKbase + len(Knovel) - 1]).
330 | """
331 |
332 | if len(Knovel) == 0:
333 | return [], []
334 |
335 | nKnovel = len(Knovel)
336 | Tnovel = []
337 | Exemplars = []
338 | assert((nTestNovel % nKnovel) == 0)
339 | nEvalExamplesPerClass = nTestNovel // nKnovel
340 |
341 | for Knovel_idx in range(len(Knovel)):
342 | imd_ids = self.sampleImageIdsFrom(
343 | Knovel[Knovel_idx],
344 | sample_size=(nEvalExamplesPerClass + nExemplars))
345 |
346 | imds_tnovel = imd_ids[:nEvalExamplesPerClass]
347 | imds_exemplars = imd_ids[nEvalExamplesPerClass:]
348 |
349 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
350 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_exemplars]
351 | assert(len(Tnovel) == nTestNovel)
352 | assert(len(Exemplars) == len(Knovel) * nExemplars)
353 | random.shuffle(Exemplars)
354 |
355 | return Tnovel, Exemplars
356 |
357 | def sample_episode(self):
358 | """Samples a training episode."""
359 | nKnovel = self.nKnovel
360 | nKbase = self.nKbase
361 | nTestNovel = self.nTestNovel
362 | nTestBase = self.nTestBase
363 | nExemplars = self.nExemplars
364 |
365 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
366 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
367 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
368 | Knovel, nTestNovel, nExemplars, nKbase)
369 |
370 | # concatenate the base and novel category examples.
371 | Test = Tbase + Tnovel
372 | random.shuffle(Test)
373 | Kall = Kbase + Knovel
374 |
375 | return Exemplars, Test, Kall, nKbase
376 |
377 | def createExamplesTensorData(self, examples):
378 | """
379 | Creates the examples image and label tensor data.
380 |
381 | Args:
382 | examples: a list of 2-element tuples, each representing a
383 | train or test example. The 1st element of each tuple
384 | is the image id of the example and 2nd element is the
385 | category label of the example, which is in the range
386 | [0, nK - 1], where nK is the total number of categories
387 | (both novel and base).
388 |
389 | Returns:
390 | images: a tensor of shape [nExamples, Height, Width, 3] with the
391 | example images, where nExamples is the number of examples
392 | (i.e., nExamples = len(examples)).
393 | labels: a tensor of shape [nExamples] with the category label
394 | of each example.
395 | """
396 | images = torch.stack(
397 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
398 | labels = torch.LongTensor([label for _, label in examples])
399 | return images, labels
400 |
401 | def get_iterator(self, epoch=0):
402 | rand_seed = epoch
403 | random.seed(rand_seed)
404 | np.random.seed(rand_seed)
405 | def load_function(iter_idx):
406 | Exemplars, Test, Kall, nKbase = self.sample_episode()
407 | Xt, Yt = self.createExamplesTensorData(Test)
408 | Kall = torch.LongTensor(Kall)
409 | if len(Exemplars) > 0:
410 | Xe, Ye = self.createExamplesTensorData(Exemplars)
411 | return Xe, Ye, Xt, Yt, Kall, nKbase
412 | else:
413 | return Xt, Yt, Kall, nKbase
414 |
415 | tnt_dataset = tnt.dataset.ListDataset(
416 | elem_list=range(self.epoch_size), load=load_function)
417 | data_loader = tnt_dataset.parallel(
418 | batch_size=self.batch_size,
419 | num_workers=(0 if self.is_eval_mode else self.num_workers),
420 | shuffle=(not self.is_eval_mode))
421 |
422 | return data_loader
423 |
424 | def __call__(self, epoch=0):
425 | return self.get_iterator(epoch)
426 |
427 | def __len__(self):
428 | return self.epoch_size // self.batch_size
429 |
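To make the episode bookkeeping concrete, here is an illustrative example (all ids invented, not drawn from real data) of what `sample_episode()` can return for nKbase=2, nKnovel=2, nExemplars=1, nTestBase=2, nTestNovel=2:

# Illustrative values only; image and category ids are made up.
Exemplars = [(8123, 2), (441, 3)]                    # (image_id, episode_label); novel labels start at nKbase
Test      = [(77, 0), (902, 3), (13, 1), (5610, 2)]  # base and novel test examples, shuffled together
Kall      = [17, 54, 3, 92]                          # episode label -> original category id (Kbase first, then Knovel)
nKbase    = 2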
--------------------------------------------------------------------------------
/data/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 |
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 |
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 |
21 | import h5py
22 |
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 |
26 | from pdb import set_trace as breakpoint
27 |
28 |
29 | # Set the appropriate paths of the datasets here.
30 | _TIERED_IMAGENET_DATASET_DIR = '/mnt/cube/datasets/tiered-imagenet/'
31 |
32 | def buildLabelIndex(labels):
33 | label2inds = {}
34 | for idx, label in enumerate(labels):
35 | if label not in label2inds:
36 | label2inds[label] = []
37 | label2inds[label].append(idx)
38 |
39 | return label2inds
40 |
41 |
42 | def load_data(file):
43 | try:
44 | with open(file, 'rb') as fo:
45 | data = pickle.load(fo)
46 | return data
47 | except Exception:
48 | with open(file, 'rb') as f:
49 | u = pickle._Unpickler(f)
50 | u.encoding = 'latin1'
51 | data = u.load()
52 | return data
53 |
54 | class tieredImageNet(data.Dataset):
55 | def __init__(self, phase='train', do_not_use_random_transf=False):
56 |
57 | assert(phase=='train' or phase=='val' or phase=='test')
58 | self.phase = phase
59 | self.name = 'tieredImageNet_' + phase
60 |
61 | print('Loading tiered ImageNet dataset - phase {0}'.format(phase))
62 | file_train_categories_train_phase = os.path.join(
63 | _TIERED_IMAGENET_DATASET_DIR,
64 | 'train_images.npz')
65 | label_train_categories_train_phase = os.path.join(
66 | _TIERED_IMAGENET_DATASET_DIR,
67 | 'train_labels.pkl')
68 | file_train_categories_val_phase = os.path.join(
69 | _TIERED_IMAGENET_DATASET_DIR,
70 | 'train_images.npz')
71 | label_train_categories_val_phase = os.path.join(
72 | _TIERED_IMAGENET_DATASET_DIR,
73 | 'train_labels.pkl')
74 | file_train_categories_test_phase = os.path.join(
75 | _TIERED_IMAGENET_DATASET_DIR,
76 | 'train_images.npz')
77 | label_train_categories_test_phase = os.path.join(
78 | _TIERED_IMAGENET_DATASET_DIR,
79 | 'train_labels.pkl')
80 |
81 | file_val_categories_val_phase = os.path.join(
82 | _TIERED_IMAGENET_DATASET_DIR,
83 | 'val_images.npz')
84 | label_val_categories_val_phase = os.path.join(
85 | _TIERED_IMAGENET_DATASET_DIR,
86 | 'val_labels.pkl')
87 | file_test_categories_test_phase = os.path.join(
88 | _TIERED_IMAGENET_DATASET_DIR,
89 | 'test_images.npz')
90 | label_test_categories_test_phase = os.path.join(
91 | _TIERED_IMAGENET_DATASET_DIR,
92 | 'test_labels.pkl')
93 |
94 | if self.phase=='train':
95 | # During training phase we only load the training phase images
96 | # of the training categories (aka base categories).
97 | data_train = load_data(label_train_categories_train_phase)
98 |
99 | self.labels = data_train['labels']
100 | self.data = np.load(file_train_categories_train_phase)['images']
101 |
102 |
103 | self.label2ind = buildLabelIndex(self.labels)
104 | self.labelIds = sorted(self.label2ind.keys())
105 | self.num_cats = len(self.labelIds)
106 | self.labelIds_base = self.labelIds
107 | self.num_cats_base = len(self.labelIds_base)
108 |
109 | elif self.phase=='val' or self.phase=='test':
110 | if self.phase=='test':
111 | # load data that will be used for evaluating the recognition
112 | # accuracy of the base categories.
113 | data_base = load_data(label_train_categories_test_phase)
114 | data_base_images = np.load(file_train_categories_test_phase)['images']
115 |
116 | # load data that will be used for evaluating the few-shot recognition
117 | # accuracy on the novel categories.
118 | data_novel = load_data(label_test_categories_test_phase)
119 | data_novel_images = np.load(file_test_categories_test_phase)['images']
120 | else: # phase=='val'
121 | # load data that will be used for evaluating the recognition
122 | # accuracy of the base categories.
123 | data_base = load_data(label_train_categories_val_phase)
124 | data_base_images = np.load(file_train_categories_val_phase)['images']
125 |
126 |
127 | # load data that will be used for evaluating the few-shot recognition
128 | # accuracy on the novel categories.
129 | data_novel = load_data(label_val_categories_val_phase)
130 | data_novel_images = np.load(file_val_categories_val_phase)['images']
131 |
132 | self.data = np.concatenate(
133 | [data_base_images, data_novel_images], axis=0)
134 | self.labels = data_base['labels'] + data_novel['labels']
135 |
136 | self.label2ind = buildLabelIndex(self.labels)
137 | self.labelIds = sorted(self.label2ind.keys())
138 | self.num_cats = len(self.labelIds)
139 |
140 | self.labelIds_base = sorted(buildLabelIndex(data_base['labels']).keys())
141 | self.labelIds_novel = sorted(buildLabelIndex(data_novel['labels']).keys())
142 | self.num_cats_base = len(self.labelIds_base)
143 | self.num_cats_novel = len(self.labelIds_novel)
144 | intersection = set(self.labelIds_base) & set(self.labelIds_novel)
145 |
146 | assert(len(intersection) == 0)
147 | else:
148 | raise ValueError('Invalid phase {0}'.format(self.phase))
149 |
150 | mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
151 | std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
152 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
153 |
154 | if self.phase in ('test', 'val') or do_not_use_random_transf:
155 | self.transform = transforms.Compose([
156 | lambda x: np.asarray(x),
157 | transforms.ToTensor(),
158 | normalize
159 | ])
160 | else:
161 | self.transform = transforms.Compose([
162 | transforms.RandomCrop(84, padding=8),
163 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
164 | transforms.RandomHorizontalFlip(),
165 | lambda x: np.asarray(x),
166 | transforms.ToTensor(),
167 | normalize
168 | ])
169 |
170 | def __getitem__(self, index):
171 | img, label = self.data[index], self.labels[index]
172 | # convert to a PIL Image so that this is consistent with
173 | # all the other datasets, which also return PIL Images
174 | img = Image.fromarray(img)
175 | if self.transform is not None:
176 | img = self.transform(img)
177 | return img, label
178 |
179 | def __len__(self):
180 | return len(self.data)
181 |
182 |
183 | class FewShotDataloader():
184 | def __init__(self,
185 | dataset,
186 | nKnovel=5, # number of novel categories.
187 | nKbase=-1, # number of base categories.
188 | nExemplars=1, # number of training examples per novel category.
189 | nTestNovel=15*5, # number of test examples for all the novel categories.
190 | nTestBase=15*5, # number of test examples for all the base categories.
191 | batch_size=1, # number of training episodes per batch.
192 | num_workers=4,
193 | epoch_size=2000, # number of batches per epoch.
194 | ):
195 |
196 | self.dataset = dataset
197 | self.phase = self.dataset.phase
198 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train'
199 | else self.dataset.num_cats_novel)
200 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
201 | self.nKnovel = nKnovel
202 |
203 | max_possible_nKbase = self.dataset.num_cats_base
204 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
205 | if self.phase=='train' and nKbase > 0:
206 | nKbase -= self.nKnovel
207 | max_possible_nKbase -= self.nKnovel
208 |
209 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
210 | self.nKbase = nKbase
211 |
212 | self.nExemplars = nExemplars
213 | self.nTestNovel = nTestNovel
214 | self.nTestBase = nTestBase
215 | self.batch_size = batch_size
216 | self.epoch_size = epoch_size
217 | self.num_workers = num_workers
218 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
219 |
220 | def sampleImageIdsFrom(self, cat_id, sample_size=1):
221 | """
222 | Samples `sample_size` number of unique image ids picked from the
223 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
224 |
225 | Args:
226 | cat_id: a scalar with the id of the category from which images will
227 | be sampled.
228 | sample_size: number of images that will be sampled.
229 |
230 | Returns:
231 | image_ids: a list of length `sample_size` with unique image ids.
232 | """
233 | assert(cat_id in self.dataset.label2ind)
234 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
235 | # Note: random.sample samples elements without replacement.
236 | return random.sample(self.dataset.label2ind[cat_id], sample_size)
237 |
238 | def sampleCategories(self, cat_set, sample_size=1):
239 | """
240 | Samples `sample_size` number of unique categories picked from the
241 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
242 |
243 | Args:
244 | cat_set: string that specifies the set of categories from which
245 | categories will be sampled.
246 | sample_size: number of categories that will be sampled.
247 |
248 | Returns:
249 | cat_ids: a list of length `sample_size` with unique category ids.
250 | """
251 | if cat_set=='base':
252 | labelIds = self.dataset.labelIds_base
253 | elif cat_set=='novel':
254 | labelIds = self.dataset.labelIds_novel
255 | else:
256 | raise ValueError('Unrecognized category set {}'.format(cat_set))
257 |
258 | assert(len(labelIds) >= sample_size)
259 | # return sample_size unique categories chosen from labelIds set of
260 | # categories (that can be either self.labelIds_base or self.labelIds_novel)
261 | # Note: random.sample samples elements without replacement.
262 | return random.sample(labelIds, sample_size)
263 |
264 | def sample_base_and_novel_categories(self, nKbase, nKnovel):
265 | """
266 | Samples `nKbase` number of base categories and `nKnovel` number of novel
267 | categories.
268 |
269 | Args:
270 | nKbase: number of base categories
271 | nKnovel: number of novel categories
272 |
273 | Returns:
274 | Kbase: a list of length 'nKbase' with the ids of the sampled base
275 | categories.
276 | Knovel: a list of length 'nKnovel' with the ids of the sampled novel
277 | categories.
278 | """
279 | if self.is_eval_mode:
280 | assert(nKnovel <= self.dataset.num_cats_novel)
281 | # sample from the set of base categories 'nKbase' number of base
282 | # categories.
283 | Kbase = sorted(self.sampleCategories('base', nKbase))
284 | # sample from the set of novel categories 'nKnovel' number of novel
285 | # categories.
286 | Knovel = sorted(self.sampleCategories('novel', nKnovel))
287 | else:
288 | # sample from the set of base categories 'nKnovel' + 'nKbase' number
289 | # of categories.
290 | cats_ids = self.sampleCategories('base', nKnovel+nKbase)
291 | assert(len(cats_ids) == (nKnovel+nKbase))
292 | # Randomly pick 'nKnovel' number of fake novel categories and keep
293 | # the rest as base categories.
294 | random.shuffle(cats_ids)
295 | Knovel = sorted(cats_ids[:nKnovel])
296 | Kbase = sorted(cats_ids[nKnovel:])
297 |
298 | return Kbase, Knovel
299 |
300 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
301 | """
302 | Sample `nTestBase` number of images from the `Kbase` categories.
303 |
304 | Args:
305 | Kbase: a list of length `nKbase` with the ids of the categories from
306 | where the images will be sampled.
307 | nTestBase: the total number of images that will be sampled.
308 |
309 | Returns:
310 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
311 | element of each tuple is the image id that was sampled and the
312 | 2nd element is its category label (which is in the range
313 | [0, len(Kbase)-1]).
314 | """
315 | Tbase = []
316 | if len(Kbase) > 0:
317 | # Sample a number of images from each base category such that the
318 | # total number of sampled images across all categories equals `nTestBase`.
319 | KbaseIndices = np.random.choice(
320 | np.arange(len(Kbase)), size=nTestBase, replace=True)
321 | KbaseIndices, NumImagesPerCategory = np.unique(
322 | KbaseIndices, return_counts=True)
323 |
324 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
325 | imd_ids = self.sampleImageIdsFrom(
326 | Kbase[Kbase_idx], sample_size=NumImages)
327 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
328 |
329 | assert(len(Tbase) == nTestBase)
330 |
331 | return Tbase
332 |
333 | def sample_train_and_test_examples_for_novel_categories(
334 | self, Knovel, nTestNovel, nExemplars, nKbase):
335 | """Samples train and test examples of the novel categories.
336 |
337 | Args:
338 | Knovel: a list with the ids of the novel categories.
339 | nTestNovel: the total number of test images that will be sampled
340 | from all the novel categories.
341 | nExemplars: the number of training examples per novel category that
342 | will be sampled.
343 | nKbase: the number of base categories. It is used as offset of the
344 | category index of each sampled image.
345 |
346 | Returns:
347 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The
348 | 1st element of each tuple is the image id that was sampled and
349 | the 2nd element is its category label (which is in the range
350 | [nKbase, nKbase + len(Knovel) - 1]).
351 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element
352 | tuples. The 1st element of each tuple is the image id that was
353 | sampled and the 2nd element is its category label (which is in
354 | the range [nKbase, nKbase + len(Knovel) - 1]).
355 | """
356 |
357 | if len(Knovel) == 0:
358 | return [], []
359 |
360 | nKnovel = len(Knovel)
361 | Tnovel = []
362 | Exemplars = []
363 | assert((nTestNovel % nKnovel) == 0)
364 | nEvalExamplesPerClass = nTestNovel // nKnovel
365 |
366 | for Knovel_idx in range(len(Knovel)):
367 | imd_ids = self.sampleImageIdsFrom(
368 | Knovel[Knovel_idx],
369 | sample_size=(nEvalExamplesPerClass + nExemplars))
370 |
371 | imds_tnovel = imd_ids[:nEvalExamplesPerClass]
372 | imds_exemplars = imd_ids[nEvalExamplesPerClass:]
373 |
374 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
375 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_exemplars]
376 | assert(len(Tnovel) == nTestNovel)
377 | assert(len(Exemplars) == len(Knovel) * nExemplars)
378 | random.shuffle(Exemplars)
379 |
380 | return Tnovel, Exemplars
381 |
382 | def sample_episode(self):
383 | """Samples a training episode."""
384 | nKnovel = self.nKnovel
385 | nKbase = self.nKbase
386 | nTestNovel = self.nTestNovel
387 | nTestBase = self.nTestBase
388 | nExemplars = self.nExemplars
389 |
390 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
391 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
392 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
393 | Knovel, nTestNovel, nExemplars, nKbase)
394 |
395 | # concatenate the base and novel category examples.
396 | Test = Tbase + Tnovel
397 | random.shuffle(Test)
398 | Kall = Kbase + Knovel
399 |
400 | return Exemplars, Test, Kall, nKbase
401 |
402 | def createExamplesTensorData(self, examples):
403 | """
404 | Creates the examples image and label tensor data.
405 |
406 | Args:
407 | examples: a list of 2-element tuples, each representing a
408 | train or test example. The 1st element of each tuple
409 | is the image id of the example and 2nd element is the
410 | category label of the example, which is in the range
411 | [0, nK - 1], where nK is the total number of categories
412 | (both novel and base).
413 |
414 | Returns:
415 | images: a tensor of shape [nExamples, Height, Width, 3] with the
416 | example images, where nExamples is the number of examples
417 | (i.e., nExamples = len(examples)).
418 | labels: a tensor of shape [nExamples] with the category label
419 | of each example.
420 | """
421 | images = torch.stack(
422 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
423 | labels = torch.LongTensor([label for _, label in examples])
424 | return images, labels
425 |
426 | def get_iterator(self, epoch=0):
427 | rand_seed = epoch
428 | random.seed(rand_seed)
429 | np.random.seed(rand_seed)
430 | def load_function(iter_idx):
431 | Exemplars, Test, Kall, nKbase = self.sample_episode()
432 | Xt, Yt = self.createExamplesTensorData(Test)
433 | Kall = torch.LongTensor(Kall)
434 | if len(Exemplars) > 0:
435 | Xe, Ye = self.createExamplesTensorData(Exemplars)
436 | return Xe, Ye, Xt, Yt, Kall, nKbase
437 | else:
438 | return Xt, Yt, Kall, nKbase
439 |
440 | tnt_dataset = tnt.dataset.ListDataset(
441 | elem_list=range(self.epoch_size), load=load_function)
442 | data_loader = tnt_dataset.parallel(
443 | batch_size=self.batch_size,
444 | num_workers=(0 if self.is_eval_mode else self.num_workers),
445 | shuffle=(not self.is_eval_mode))
446 |
447 | return data_loader
448 |
449 | def __call__(self, epoch=0):
450 | return self.get_iterator(epoch)
451 |
452 | def __len__(self):
453 | return self.epoch_size // self.batch_size
454 |
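A minimal sketch of driving the loader above for 5-way, 1-shot validation episodes (the batch size and episode count here are illustrative, not the repository's training defaults; it assumes the tiered ImageNet files exist under _TIERED_IMAGENET_DATASET_DIR):

# Sketch: build an episodic loader over the novel (validation) categories.
dataset_val = tieredImageNet(phase='val')
dloader_val = FewShotDataloader(
    dataset=dataset_val,
    nKnovel=5,          # 5-way
    nKbase=0,           # no base categories inside the episode
    nExemplars=1,       # 1-shot support set
    nTestNovel=15 * 5,  # 15 query images per novel category
    nTestBase=0,
    batch_size=1,
    epoch_size=600,     # 600 episodes per epoch
)
for batch in dloader_val(epoch=0):
    # Exemplars is non-empty here, so load_function returns six items.
    Xe, Ye, Xt, Yt, Kall, nKbase = batch
    # Xe: [1, 5, 3, 84, 84] support images; Xt: [1, 75, 3, 84, 84] query images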
--------------------------------------------------------------------------------
/models/R2D2_embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | # Embedding network used in Meta-learning with differentiable closed-form solvers
5 | # (Bertinetto et al., in submission to NIPS 2018).
6 | # They call the ridge regression version the "Ridge Regression Differentiable Discriminator (R2D2)."
7 |
8 | # Note that they use a peculiar ordering of functions, namely conv-BN-pooling-lrelu,
9 | # as opposed to the conventional one (conv-BN-lrelu-pooling).
10 |
11 | def R2D2_conv_block(in_channels, out_channels, retain_activation=True, keep_prob=1.0):
12 | block = nn.Sequential(
13 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
14 | nn.BatchNorm2d(out_channels),
15 | nn.MaxPool2d(2)
16 | )
17 | if retain_activation:
18 | block.add_module("LeakyReLU", nn.LeakyReLU(0.1))
19 |
20 | if keep_prob < 1.0:
21 | block.add_module("Dropout", nn.Dropout(p=1 - keep_prob, inplace=False))
22 |
23 | return block
24 |
25 | class R2D2Embedding(nn.Module):
26 | def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512, \
27 | retain_last_activation=False):
28 | super(R2D2Embedding, self).__init__()
29 |
30 | self.block1 = R2D2_conv_block(x_dim, h1_dim)
31 | self.block2 = R2D2_conv_block(h1_dim, h2_dim)
32 | self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9)
33 | # In the last conv block, we disable activation function to boost the classification accuracy.
34 | # This trick was proposed by Gidaris et al. (CVPR 2018).
35 | # With this trick, the accuracy goes up from 50% to 51%.
36 | # Although the authors of R2D2 did not mention this trick in the paper,
37 | # we were unable to reproduce the result of Bertinetto et al. without resorting to this trick.
38 | self.block4 = R2D2_conv_block(h3_dim, z_dim, retain_activation=retain_last_activation, keep_prob=0.7)
39 |
40 | def forward(self, x):
41 | b1 = self.block1(x)
42 | b2 = self.block2(b1)
43 | b3 = self.block3(b2)
44 | b4 = self.block4(b3)
45 | # Flatten and concatenate the output of the 3rd and 4th conv blocks as proposed in R2D2 paper.
46 | return torch.cat((b3.view(b3.size(0), -1), b4.view(b4.size(0), -1)), 1)
--------------------------------------------------------------------------------
/models/ResNet12_embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from models.dropblock import DropBlock
5 |
6 | # This ResNet network was designed following the practice of the following papers:
7 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
8 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).
9 |
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = conv3x3(inplanes, planes)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.relu = nn.LeakyReLU(0.1)
24 | self.conv2 = conv3x3(planes, planes)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 | self.conv3 = conv3x3(planes, planes)
27 | self.bn3 = nn.BatchNorm2d(planes)
28 | self.maxpool = nn.MaxPool2d(stride)
29 | self.downsample = downsample
30 | self.stride = stride
31 | self.drop_rate = drop_rate
32 | self.num_batches_tracked = 0
33 | self.drop_block = drop_block
34 | self.block_size = block_size
35 | self.DropBlock = DropBlock(block_size=self.block_size)
36 |
37 | def forward(self, x):
38 | self.num_batches_tracked += 1
39 |
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv3(out)
51 | out = self.bn3(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 | out += residual
56 | out = self.relu(out)
57 | out = self.maxpool(out)
58 |
59 | if self.drop_rate > 0:
60 | if self.drop_block == True:
61 | feat_size = out.size()[2]
62 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
63 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
64 | out = self.DropBlock(out, gamma=gamma)
65 | else:
66 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
67 |
68 | return out
69 |
70 |
71 | class ResNet(nn.Module):
72 |
73 | def __init__(self, block, keep_prob=1.0, avg_pool=False, drop_rate=0.0, dropblock_size=5):
74 | self.inplanes = 3
75 | super(ResNet, self).__init__()
76 |
77 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate)
78 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate)
79 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
80 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
81 | if avg_pool:
82 | self.avgpool = nn.AvgPool2d(5, stride=1)
83 | self.keep_prob = keep_prob
84 | self.keep_avg_pool = avg_pool
85 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
86 | self.drop_rate = drop_rate
87 |
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
91 | elif isinstance(m, nn.BatchNorm2d):
92 | nn.init.constant_(m.weight, 1)
93 | nn.init.constant_(m.bias, 0)
94 |
95 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
96 | downsample = None
97 | if stride != 1 or self.inplanes != planes * block.expansion:
98 | downsample = nn.Sequential(
99 | nn.Conv2d(self.inplanes, planes * block.expansion,
100 | kernel_size=1, stride=1, bias=False),
101 | nn.BatchNorm2d(planes * block.expansion),
102 | )
103 |
104 | layers = []
105 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size))
106 | self.inplanes = planes * block.expansion
107 |
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | x = self.layer1(x)
112 | x = self.layer2(x)
113 | x = self.layer3(x)
114 | x = self.layer4(x)
115 | if self.keep_avg_pool:
116 | x = self.avgpool(x)
117 | x = x.view(x.size(0), -1)
118 | return x
119 |
120 |
121 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs):
122 | """Constructs a ResNet-12 model.
123 | """
124 | model = ResNet(BasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
125 | return model
126 |
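A shape sketch under the defaults above, assuming the usual 84x84 few-shot inputs (each of the four layers halves the spatial size via MaxPool2d(2): 84 -> 42 -> 21 -> 10 -> 5, and the optional 5x5 average pooling reduces that to 1x1):

# Sketch: check the embedding size with and without average pooling.
import torch
net = resnet12(avg_pool=True)
x = torch.randn(2, 3, 84, 84)
assert net(x).shape == (2, 640)                 # 640-d embedding
net_flat = resnet12(avg_pool=False)
assert net_flat(x).shape == (2, 640 * 5 * 5)    # 16000-d if pooling is skipped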
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/classification_heads.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.nn as nn
7 | from qpth.qp import QPFunction
8 |
9 |
10 | def computeGramMatrix(A, B):
11 | """
12 | Constructs a linear kernel matrix between A and B.
13 | We assume that each row in A and B represents a d-dimensional feature vector.
14 |
15 | Parameters:
16 | A: a (n_batch, n, d) Tensor.
17 | B: a (n_batch, m, d) Tensor.
18 | Returns: a (n_batch, n, m) Tensor.
19 | """
20 |
21 | assert(A.dim() == 3)
22 | assert(B.dim() == 3)
23 | assert(A.size(0) == B.size(0) and A.size(2) == B.size(2))
24 |
25 | return torch.bmm(A, B.transpose(1,2))
26 |
27 |
28 | def binv(b_mat):
29 | """
30 | Computes an inverse of each matrix in the batch.
31 | Pytorch 0.4.1 does not support batched matrix inverse.
32 | Hence, we are solving AX=I.
33 |
34 | Parameters:
35 | b_mat: a (n_batch, n, n) Tensor.
36 | Returns: a (n_batch, n, n) Tensor.
37 | """
38 |
39 | id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda()
40 | b_inv, _ = torch.gesv(id_matrix, b_mat)
41 |
42 | return b_inv
43 |
44 |
45 | def one_hot(indices, depth):
46 | """
47 | Returns a one-hot tensor.
48 | This is a PyTorch equivalent of Tensorflow's tf.one_hot.
49 |
50 | Parameters:
51 | indices: a (n_batch, m) Tensor or (m) Tensor.
52 | depth: a scalar. Represents the depth of the one hot dimension.
53 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
54 | """
55 |
56 | encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
57 | index = indices.view(indices.size() + torch.Size([1]))
58 | encoded_indices = encoded_indices.scatter_(-1, index, 1)
59 |
60 | return encoded_indices
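# Example (illustrative): one_hot(torch.LongTensor([0, 2, 1]).cuda(), 3)
# returns [[1., 0., 0.],
#          [0., 0., 1.],
#          [0., 1., 0.]]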
61 |
62 | def batched_kronecker(matrix1, matrix2):
63 | matrix1_flatten = matrix1.reshape(matrix1.size()[0], -1)
64 | matrix2_flatten = matrix2.reshape(matrix2.size()[0], -1)
65 | return torch.bmm(matrix1_flatten.unsqueeze(2), matrix2_flatten.unsqueeze(1)).reshape([matrix1.size()[0]] + list(matrix1.size()[1:]) + list(matrix2.size()[1:])).permute([0, 1, 3, 2, 4]).reshape(matrix1.size(0), matrix1.size(1) * matrix2.size(1), matrix1.size(2) * matrix2.size(2))
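# Shape note (illustrative): for matrix1 of shape (B, n, m) and matrix2 of
# shape (B, p, q), batched_kronecker returns the batchwise Kronecker product
# of shape (B, n*p, m*q), e.g. (4, 25, 25) kron (4, 5, 5) -> (4, 125, 125).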
66 |
67 | def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False):
68 | """
69 | Fits the support set with ridge regression and
70 | returns the classification score on the query set.
71 |
72 | Parameters:
73 | query: a (tasks_per_batch, n_query, d) Tensor.
74 | support: a (tasks_per_batch, n_support, d) Tensor.
75 | support_labels: a (tasks_per_batch, n_support) Tensor.
76 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
77 | n_shot: a scalar. Represents the number of support examples given per class.
78 | lambda_reg: a scalar. Represents the strength of L2 regularization.
79 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
80 | """
81 |
82 | tasks_per_batch = query.size(0)
83 | n_support = support.size(1)
84 | n_query = query.size(1)
85 |
86 | assert(query.dim() == 3)
87 | assert(support.dim() == 3)
88 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
89 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
90 |
91 | #Here we solve the dual problem:
92 | #Note that the classes are indexed by m & samples are indexed by i.
93 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
94 |
95 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i,
96 |
97 | #\alpha is an (n_support, n_way) matrix
98 | kernel_matrix = computeGramMatrix(support, support)
99 | kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
100 |
101 | block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1) #(n_way * tasks_per_batch, n_support, n_support)
102 |
103 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way)
104 | support_labels_one_hot = support_labels_one_hot.transpose(0, 1) # (n_way, tasks_per_batch * n_support)
105 | support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support) # (n_way*tasks_per_batch, n_support)
106 |
107 | G = block_kernel_matrix
108 | e = -2.0 * support_labels_one_hot
109 |
110 | #This is a fake inequality constraint as qpth does not support QP without an inequality constraint.
111 | zero_matrix = torch.zeros(tasks_per_batch*n_way, n_support, n_support)
112 | C = Variable(zero_matrix)
113 | h = Variable(torch.zeros((tasks_per_batch*n_way, n_support)))
114 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
115 |
116 | if double_precision:
117 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
118 |
119 | else:
120 | G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]
121 |
122 | # Solve the following QP to fit SVM:
123 | # \hat z = argmin_z 1/2 z^T G z + e^T z
124 | # subject to Cz <= h
125 | # We use detach() to prevent backpropagation to fixed variables.
126 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
127 |
128 |
129 | #qp_sol (n_way*tasks_per_batch, n_support)
130 | qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
131 | #qp_sol (n_way, tasks_per_batch, n_support)
132 | qp_sol = qp_sol.permute(1, 2, 0)
133 | #qp_sol (tasks_per_batch, n_support, n_way)
134 |
135 | # Compute the classification score.
136 | compatibility = computeGramMatrix(support, query)
137 | compatibility = compatibility.float()
138 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
139 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
140 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
141 | logits = logits * compatibility
142 | logits = torch.sum(logits, 1)
143 |
144 | return logits
145 |
146 | def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0):
147 | """
148 | Fits the support set with ridge regression and
149 | returns the classification score on the query set.
150 |
151 | This model is the classification head described in:
152 | Meta-learning with differentiable closed-form solvers
153 | (Bertinetto et al., in submission to NIPS 2018).
154 |
155 | Parameters:
156 | query: a (tasks_per_batch, n_query, d) Tensor.
157 | support: a (tasks_per_batch, n_support, d) Tensor.
158 | support_labels: a (tasks_per_batch, n_support) Tensor.
159 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
160 | n_shot: a scalar. Represents the number of support examples given per class.
161 | l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization.
162 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
163 | """
164 |
165 | tasks_per_batch = query.size(0)
166 | n_support = support.size(1)
167 |
168 | assert(query.dim() == 3)
169 | assert(support.dim() == 3)
170 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
171 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
172 |
173 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
174 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
175 |
176 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
177 |
178 | # Compute the dual form solution of the ridge regression.
179 | # W = X^T(X X^T + lambda * I)^(-1) Y
180 | ridge_sol = computeGramMatrix(support, support) + l2_regularizer_lambda * id_matrix
181 | ridge_sol = binv(ridge_sol)
182 | ridge_sol = torch.bmm(support.transpose(1,2), ridge_sol)
183 | ridge_sol = torch.bmm(ridge_sol, support_labels_one_hot)
184 |
185 | # Compute the classification score.
186 | # score = W X
187 | logits = torch.bmm(query, ridge_sol)
188 |
189 | return logits
190 |
191 |
192 | def MetaOptNetHead_SVM_He(query, support, support_labels, n_way, n_shot, C_reg=0.01, double_precision=False):
193 | """
194 | Fits the support set with multi-class SVM and
195 | returns the classification score on the query set.
196 |
197 | This is the multi-class SVM presented in:
198 | A simplified multi-class support vector machine with reduced dual optimization
199 | (He et al., Pattern Recognition Letter 2012).
200 |
201 | This SVM is desirable because the dual variables are of size n_support
202 | (as opposed to n_way*n_support in the Weston&Watkins or Crammer&Singer multi-class SVM).
203 | This is the classification head that we initially used for this project.
204 | It was dropped because it performed suboptimally in the meta-learning scenarios.
205 |
206 | Parameters:
207 | query: a (tasks_per_batch, n_query, d) Tensor.
208 | support: a (tasks_per_batch, n_support, d) Tensor.
209 | support_labels: a (tasks_per_batch, n_support) Tensor.
210 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
211 | n_shot: a scalar. Represents the number of support examples given per class.
212 | C_reg: a scalar. Represents the cost parameter C in SVM.
213 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
214 | """
215 |
216 | tasks_per_batch = query.size(0)
217 | n_support = support.size(1)
218 | n_query = query.size(1)
219 |
220 | assert(query.dim() == 3)
221 | assert(support.dim() == 3)
222 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
223 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
224 |
225 |
226 | kernel_matrix = computeGramMatrix(support, support)
227 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way).view(tasks_per_batch, n_support, n_way)
228 | V = (support_labels_one_hot * n_way - torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1)
229 | G = computeGramMatrix(V, V).detach()
230 | G = kernel_matrix * G
231 |
232 | e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support))
233 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support)
234 | C = Variable(torch.cat((id_matrix, -id_matrix), 1))
235 | h = Variable(torch.cat((C_reg * torch.ones(tasks_per_batch, n_support), torch.zeros(tasks_per_batch, n_support)), 1))
236 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
237 |
238 | if double_precision:
239 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
240 | else:
241 | G, e, C, h = [x.cuda() for x in [G, e, C, h]]
242 |
243 | # Solve the following QP to fit SVM:
244 | # \hat z = argmin_z 1/2 z^T G z + e^T z
245 | # subject to Cz <= h
246 | # We use detach() to prevent backpropagation to fixed variables.
247 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
248 |
249 | # Compute the classification score.
250 | compatibility = computeGramMatrix(query, support)
251 | compatibility = compatibility.float()
252 |
253 | logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query, n_support)
254 | logits = logits * compatibility
255 | logits = logits.view(tasks_per_batch, n_query, n_shot, n_way)
256 | logits = torch.sum(logits, 2)
257 |
258 | return logits
259 |
260 | def ProtoNetHead(query, support, support_labels, n_way, n_shot, normalize=True):
261 | """
262 | Constructs the prototype representation of each class (= mean of support vectors of each class) and
263 | returns the classification score (=L2 distance to each class prototype) on the query set.
264 |
265 | This model is the classification head described in:
266 | Prototypical Networks for Few-shot Learning
267 | (Snell et al., NIPS 2017).
268 |
269 | Parameters:
270 | query: a (tasks_per_batch, n_query, d) Tensor.
271 | support: a (tasks_per_batch, n_support, d) Tensor.
272 | support_labels: a (tasks_per_batch, n_support) Tensor.
273 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
274 | n_shot: a scalar. Represents the number of support examples given per class.
275 | normalize: a boolean. Represents whether we want to normalize the distances by the embedding dimension.
276 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
277 | """
278 |
279 | tasks_per_batch = query.size(0)
280 | n_support = support.size(1)
281 | n_query = query.size(1)
282 | d = query.size(2)
283 |
284 | assert(query.dim() == 3)
285 | assert(support.dim() == 3)
286 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
287 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
288 |
289 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
290 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
291 |
292 | # From:
293 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py
294 | #************************* Compute Prototypes **************************
295 | labels_train_transposed = support_labels_one_hot.transpose(1,2)
296 | # Batch matrix multiplication:
297 | # prototypes = labels_train_transposed * features_train ==>
298 | # [batch_size x nKnovel x num_channels] =
299 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels]
300 | prototypes = torch.bmm(labels_train_transposed, support)
301 | # Divide with the number of examples per novel category.
302 | prototypes = prototypes.div(
303 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes)
304 | )
305 |
306 | # Distance Matrix Vectorization Trick
307 | AB = computeGramMatrix(query, prototypes)
308 | AA = (query * query).sum(dim=2, keepdim=True)
309 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way)
310 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB)
311 | logits = -logits
312 |
313 | if normalize:
314 | logits = logits / d
315 |
316 | return logits
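
# The vectorization above uses the identity ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2,
# evaluated for all (query, prototype) pairs at once through the Gram matrix AB.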
317 |
318 | def MetaOptNetHead_SVM_CS(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15):
319 | """
320 | Fits the support set with multi-class SVM and
321 | returns the classification score on the query set.
322 |
323 | This is the multi-class SVM presented in:
324 | On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
325 | (Crammer and Singer, Journal of Machine Learning Research 2001).
326 |
327 | This model is the classification head that we use for the final version.
328 | Parameters:
329 | query: a (tasks_per_batch, n_query, d) Tensor.
330 | support: a (tasks_per_batch, n_support, d) Tensor.
331 | support_labels: a (tasks_per_batch, n_support) Tensor.
332 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
333 | n_shot: a scalar. Represents the number of support examples given per class.
334 | C_reg: a scalar. Represents the cost parameter C in SVM.
335 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
336 | """
337 |
338 | tasks_per_batch = query.size(0)
339 | n_support = support.size(1)
340 | n_query = query.size(1)
341 |
342 | assert(query.dim() == 3)
343 | assert(support.dim() == 3)
344 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
345 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
346 |
347 | #Here we solve the dual problem:
348 | #Note that the classes are indexed by m & samples are indexed by i.
349 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i \alpha^m_i
350 | #s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i
351 |
352 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i,
353 | #and C^m_i = C if m = y_i,
354 | #C^m_i = 0 if m != y_i.
355 | #This borrows the notation of liblinear.
356 |
357 | #\alpha is an (n_support, n_way) matrix
358 | kernel_matrix = computeGramMatrix(support, support)
359 |
360 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
361 | block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0)
362 | #This seems to help avoid PSD error from the QP solver.
363 | block_kernel_matrix += 1.0 * torch.eye(n_way*n_support).expand(tasks_per_batch, n_way*n_support, n_way*n_support).cuda()
364 |
365 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way)
366 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
367 | support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way)
368 |
369 | G = block_kernel_matrix
370 | e = -1.0 * support_labels_one_hot
371 |
372 | #This part is for the inequality constraints:
373 | #\alpha^m_i <= C^m_i \forall m,i
374 | #where C^m_i = C if m = y_i,
375 | #C^m_i = 0 if m != y_i.
376 | id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
377 | C = Variable(id_matrix_1)
378 | h = Variable(C_reg * support_labels_one_hot)
379 |
380 | #This part is for the equality constraints:
381 | #\sum_m \alpha^m_i=0 \forall i
382 | id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
383 |
384 | A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
385 | b = Variable(torch.zeros(tasks_per_batch, n_support))
386 |
387 | if double_precision:
388 | G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
389 | else:
390 | G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]
391 |
392 | # Solve the following QP to fit SVM:
393 | # \hat z = argmin_z 1/2 z^T G z + e^T z
394 | # subject to Cz <= h
395 | # We use detach() to prevent backpropagation to fixed variables.
396 | qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
397 |
398 | # Compute the classification score.
399 | compatibility = computeGramMatrix(support, query)
400 | compatibility = compatibility.float()
401 | compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
402 | qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
403 | logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
404 | logits = logits * compatibility
405 | logits = torch.sum(logits, 1)
406 |
407 | return logits
408 |
409 | def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False):
410 | """
411 | Fits the support set with multi-class SVM and
412 | returns the classification score on the query set.
413 |
414 | This is the multi-class SVM presented in:
415 | Support Vector Machines for Multi Class Pattern Recognition
416 | (Weston and Watkins, ESANN 1999).
417 |
418 | Parameters:
419 | query: a (tasks_per_batch, n_query, d) Tensor.
420 | support: a (tasks_per_batch, n_support, d) Tensor.
421 | support_labels: a (tasks_per_batch, n_support) Tensor.
422 | n_way: a scalar. Represents the number of classes in a few-shot classification task.
423 | n_shot: a scalar. Represents the number of support examples given per class.
424 | C_reg: a scalar. Represents the cost parameter C in SVM.
425 | Returns: a (tasks_per_batch, n_query, n_way) Tensor.
426 | """
444 | tasks_per_batch = query.size(0)
445 | n_support = support.size(1)
446 | n_query = query.size(1)
447 |
448 | assert(query.dim() == 3)
449 | assert(support.dim() == 3)
450 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
451 | assert(n_support == n_way * n_shot) # n_support must equal n_way * n_shot
452 |
453 | #In theory, \alpha is an (n_support, n_way) matrix
454 | #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support)
455 | #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix
456 | #then transpose it, resulting in (n_support, n_way) matrix
457 | kernel_matrix = computeGramMatrix(support, support) + torch.ones(tasks_per_batch, n_support, n_support).cuda()
458 |
459 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda()
460 | block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix)
461 |
462 | kernel_matrix_mask_x = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support)
463 | kernel_matrix_mask_y = support_labels.reshape(tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support)
464 | kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float()
465 |
466 | block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix
467 | block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way)
468 |
469 | kernel_matrix_mask_second_term = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way)
470 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(1, 0).reshape(1, -1).repeat(n_support, 1).cuda()
471 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float()
472 |
473 | block_kernel_matrix -= (2.0 - 1e-4) * (kernel_matrix_mask_second_term * kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1)
474 |
475 | Y_support = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
476 | Y_support = Y_support.view(tasks_per_batch, n_support, n_way)
477 | Y_support = Y_support.transpose(1, 2) # (tasks_per_batch, n_way, n_support)
478 | Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support)
479 |
480 | G = block_kernel_matrix
481 |
482 | e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support)
483 | id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
484 |
485 | C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support
486 |
487 | C = Variable(torch.cat((id_matrix, -id_matrix), 1))
489 | zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda()
490 |
491 | h = Variable(torch.cat((C_mat, zer), 1))
492 |
493 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint.
494 |
495 | if double_precision:
496 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
497 | else:
498 | G, e, C, h = [x.cuda() for x in [G, e, C, h]]
499 |
500 | # Solve the following QP to fit SVM:
501 | # \hat z = argmin_z 1/2 z^T G z + e^T z
502 | # subject to Cz <= h
503 | # We use detach() to prevent backpropagation to fixed variables.
505 | qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(), dummy.detach())
506 |
507 | # Compute the classification score.
508 | compatibility = computeGramMatrix(support, query) + torch.ones(tasks_per_batch, n_support, n_query).cuda()
509 | compatibility = compatibility.float()
510 | compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way, n_support, n_query)
511 | qp_sol = qp_sol.float()
512 | qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support)
513 | A_i = torch.sum(qp_sol, 1) # (tasks_per_batch, n_support)
514 | A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support)
515 | qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)
516 | Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support)
517 | Y_support_reshaped = A_i * Y_support_reshaped
518 | Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query)
519 | logits = (Y_support_reshaped - qp_sol) * compatibility
520 |
521 | logits = torch.sum(logits, 2)
522 |
523 | return logits.transpose(1, 2)
524 |
525 | class ClassificationHead(nn.Module):
526 | def __init__(self, base_learner='MetaOptNet', enable_scale=True):
527 | super(ClassificationHead, self).__init__()
528 | if ('SVM-CS' in base_learner):
529 | self.head = MetaOptNetHead_SVM_CS
530 | elif ('Ridge' in base_learner):
531 | self.head = MetaOptNetHead_Ridge
532 | elif ('R2D2' in base_learner):
533 | self.head = R2D2Head
534 | elif ('Proto' in base_learner):
535 | self.head = ProtoNetHead
536 | elif ('SVM-He' in base_learner):
537 | self.head = MetaOptNetHead_SVM_He
538 | elif ('SVM-WW' in base_learner):
539 | self.head = MetaOptNetHead_SVM_WW
540 | else:
541 | print ("Cannot recognize the base learner type")
542 | assert(False)
543 |
544 | # Add a learnable scale
545 | self.enable_scale = enable_scale
546 | self.scale = nn.Parameter(torch.FloatTensor([1.0]))
547 |
548 | def forward(self, query, support, support_labels, n_way, n_shot, **kwargs):
549 | if self.enable_scale:
550 | return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs)
551 | else:
552 | return self.head(query, support, support_labels, n_way, n_shot, **kwargs)
553 |
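554 | # Illustrative usage sketch: every head shares the same batched-episode
555 | # interface. The shapes below are hypothetical (8 episodes of 5-way 1-shot
556 | # classification with 75 queries and 512-d features):
557 | #
558 | #   head = ClassificationHead(base_learner='SVM-CS').cuda()
559 | #   query = torch.randn(8, 75, 512).cuda()                # (tasks_per_batch, n_query, d)
560 | #   support = torch.randn(8, 5, 512).cuda()               # (tasks_per_batch, n_way * n_shot, d)
561 | #   support_labels = torch.arange(5).expand(8, 5).cuda()  # (tasks_per_batch, n_support)
562 | #   logits = head(query, support, support_labels, n_way=5, n_shot=1)
563 | #   assert logits.shape == (8, 75, 5)                     # (tasks_per_batch, n_query, n_way)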
--------------------------------------------------------------------------------
/models/dropblock.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.distributions import Bernoulli
5 |
6 |
7 | class DropBlock(nn.Module):
8 | def __init__(self, block_size):
9 | super(DropBlock, self).__init__()
10 |
11 | self.block_size = block_size
14 |
15 | def forward(self, x, gamma):
16 | # shape: (bsize, channels, height, width)
17 |
18 | if self.training:
19 | batch_size, channels, height, width = x.shape
20 |
21 | bernoulli = Bernoulli(gamma)
22 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
24 | block_mask = self._compute_block_mask(mask)
27 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
28 | count_ones = block_mask.sum()
29 |
30 | return block_mask * x * (countM / count_ones)
31 | else:
32 | return x
33 |
34 | def _compute_block_mask(self, mask):
35 | left_padding = int((self.block_size-1) / 2)
36 | right_padding = int(self.block_size / 2)
37 |
38 | batch_size, channels, height, width = mask.shape
39 | #print ("mask", mask[0][0])
40 | non_zero_idxs = mask.nonzero()
41 | nr_blocks = non_zero_idxs.shape[0]
42 |
43 | offsets = torch.stack(
44 | [
45 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1),
46 | torch.arange(self.block_size).repeat(self.block_size),
47 | ]
48 | ).t().cuda()
49 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
50 |
51 | if nr_blocks > 0:
52 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
53 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
54 | offsets = offsets.long()
55 |
56 | block_idxs = non_zero_idxs + offsets
58 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
59 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
60 | else:
61 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
62 |
63 | block_mask = 1 - padded_mask
64 | return block_mask
65 |
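66 | # Illustrative usage sketch (hypothetical values): gamma is the Bernoulli rate
67 | # for sampling block centres, so a larger gamma drops more of the feature map;
68 | # in eval mode the module is the identity.
69 | #
70 | #   drop = DropBlock(block_size=5)
71 | #   drop.train()
72 | #   x = torch.randn(4, 64, 21, 21).cuda()  # forward() assumes CUDA tensors
73 | #   out = drop(x, gamma=0.1)               # same shape as x, rescaled by countM / count_ones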
--------------------------------------------------------------------------------
/models/protonet_embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | class ConvBlock(nn.Module):
5 | def __init__(self, in_channels, out_channels, retain_activation=True):
6 | super(ConvBlock, self).__init__()
7 |
8 | self.block = nn.Sequential(
9 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
10 | nn.BatchNorm2d(out_channels)
11 | )
12 |
13 | if retain_activation:
14 | self.block.add_module("ReLU", nn.ReLU(inplace=True))
15 | self.block.add_module("MaxPool2d", nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
16 |
17 | def forward(self, x):
18 | out = self.block(x)
19 | return out
20 |
21 | # Embedding network used in Matching Networks (Vinyals et al., NIPS 2016), Meta-LSTM (Ravi & Larochelle, ICLR 2017),
22 | # MAML (w/ h_dim=z_dim=32) (Finn et al., ICML 2017), Prototypical Networks (Snell et al. NIPS 2017).
23 |
24 | class ProtoNetEmbedding(nn.Module):
25 | def __init__(self, x_dim=3, h_dim=64, z_dim=64, retain_last_activation=True):
26 | super(ProtoNetEmbedding, self).__init__()
27 | self.encoder = nn.Sequential(
28 | ConvBlock(x_dim, h_dim),
29 | ConvBlock(h_dim, h_dim),
30 | ConvBlock(h_dim, h_dim),
31 | ConvBlock(h_dim, z_dim, retain_activation=retain_last_activation),
32 | )
33 | for m in self.modules():
34 | if isinstance(m, nn.Conv2d):
35 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
36 | m.weight.data.normal_(0, math.sqrt(2. / n))
37 | elif isinstance(m, nn.BatchNorm2d):
38 | m.weight.data.fill_(1)
39 | m.bias.data.zero_()
40 |
41 | def forward(self, x):
42 | x = self.encoder(x)
43 | return x.view(x.size(0), -1)
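44 |
45 | # Shape walk-through (illustrative): each ConvBlock halves the spatial size with
46 | # its 2x2 max-pool, so an 84x84 miniImageNet image shrinks 84 -> 42 -> 21 -> 10 -> 5
47 | # and the flattened embedding has 64 * 5 * 5 = 1600 dimensions:
48 | #
49 | #   net = ProtoNetEmbedding()              # assumes `import torch` is available
50 | #   feat = net(torch.randn(2, 3, 84, 84))
51 | #   assert feat.shape == (2, 1600)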
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.1.post2
2 | torchvision==0.2.2.post2
3 | qpth==0.0.13
4 | torchnet
5 | tqdm
6 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from torch.autograd import Variable
9 |
10 | from tqdm import tqdm
11 |
12 | from models.protonet_embedding import ProtoNetEmbedding
13 | from models.R2D2_embedding import R2D2Embedding
14 | from models.ResNet12_embedding import resnet12
15 |
16 | from models.classification_heads import ClassificationHead
17 |
18 | from utils import pprint, set_gpu, Timer, count_accuracy, log
19 |
20 | import numpy as np
21 | import os
22 |
23 | def get_model(options):
24 | # Choose the embedding network
25 | if options.network == 'ProtoNet':
26 | network = ProtoNetEmbedding().cuda()
27 | elif options.network == 'R2D2':
28 | network = R2D2Embedding().cuda()
29 | elif options.network == 'ResNet':
30 | if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
31 | network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5).cuda()
32 | network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3])
33 | else:
34 | network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
35 | else:
36 | print ("Cannot recognize the network type")
37 | assert(False)
38 |
39 | # Choose the classification head
40 | if options.head == 'ProtoNet':
41 | cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
42 | elif options.head == 'Ridge':
43 | cls_head = ClassificationHead(base_learner='Ridge').cuda()
44 | elif options.head == 'R2D2':
45 | cls_head = ClassificationHead(base_learner='R2D2').cuda()
46 | elif options.head == 'SVM':
47 | cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
48 | else:
49 | print ("Cannot recognize the classification head type")
50 | assert(False)
51 |
52 | return (network, cls_head)
53 |
54 | def get_dataset(options):
55 | # Choose the dataset
56 | if options.dataset == 'miniImageNet':
57 | from data.mini_imagenet import MiniImageNet, FewShotDataloader
58 | dataset_test = MiniImageNet(phase='test')
59 | data_loader = FewShotDataloader
60 | elif options.dataset == 'tieredImageNet':
61 | from data.tiered_imagenet import tieredImageNet, FewShotDataloader
62 | dataset_test = tieredImageNet(phase='test')
63 | data_loader = FewShotDataloader
64 | elif options.dataset == 'CIFAR_FS':
65 | from data.CIFAR_FS import CIFAR_FS, FewShotDataloader
66 | dataset_test = CIFAR_FS(phase='test')
67 | data_loader = FewShotDataloader
68 | elif options.dataset == 'FC100':
69 | from data.FC100 import FC100, FewShotDataloader
70 | dataset_test = FC100(phase='test')
71 | data_loader = FewShotDataloader
72 | else:
73 | print ("Cannot recognize the dataset type")
74 | assert(False)
75 |
76 | return (dataset_test, data_loader)
77 |
78 | if __name__ == '__main__':
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument('--gpu', default='0')
81 | parser.add_argument('--load', default='./experiments/exp_1/best_model.pth',
82 | help='path of the checkpoint file')
83 | parser.add_argument('--episode', type=int, default=1000,
84 | help='number of episodes to test')
85 | parser.add_argument('--way', type=int, default=5,
86 | help='number of classes in one test episode')
87 | parser.add_argument('--shot', type=int, default=1,
88 | help='number of support examples per class in each test episode')
89 | parser.add_argument('--query', type=int, default=15,
90 | help='number of query examples per class in each test episode')
91 | parser.add_argument('--network', type=str, default='ProtoNet',
92 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
93 | parser.add_argument('--head', type=str, default='ProtoNet',
94 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
95 | parser.add_argument('--dataset', type=str, default='miniImageNet',
96 | help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
97 |
98 | opt = parser.parse_args()
99 | (dataset_test, data_loader) = get_dataset(opt)
100 |
101 | dloader_test = data_loader(
102 | dataset=dataset_test,
103 | nKnovel=opt.way,
104 | nKbase=0,
105 | nExemplars=opt.shot, # num training examples per novel category
106 | nTestNovel=opt.query * opt.way, # num test examples for all the novel categories
107 | nTestBase=0, # num test examples for all the base categories
108 | batch_size=1,
109 | num_workers=1,
110 | epoch_size=opt.episode, # num of batches per epoch
111 | )
112 |
113 | set_gpu(opt.gpu)
114 |
115 | log_file_path = os.path.join(os.path.dirname(opt.load), "test_log.txt")
116 | log(log_file_path, str(vars(opt)))
117 |
118 | # Define the models
119 | (embedding_net, cls_head) = get_model(opt)
120 |
121 | # Load saved model checkpoints
122 | saved_models = torch.load(opt.load)
123 | embedding_net.load_state_dict(saved_models['embedding'])
124 | embedding_net.eval()
125 | cls_head.load_state_dict(saved_models['head'])
126 | cls_head.eval()
127 |
128 | # Evaluate on test set
129 | test_accuracies = []
130 | for i, batch in enumerate(tqdm(dloader_test()), 1):
131 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
132 |
133 | n_support = opt.way * opt.shot
134 | n_query = opt.way * opt.query
135 |
136 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
137 | emb_support = emb_support.reshape(1, n_support, -1)
138 |
139 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
140 | emb_query = emb_query.reshape(1, n_query, -1)
141 |
142 | if opt.head == 'SVM':
143 | logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot, maxIter=3)
144 | else:
145 | logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot)
146 |
147 | acc = count_accuracy(logits.reshape(-1, opt.way), labels_query.reshape(-1))
148 | test_accuracies.append(acc.item())
149 |
150 | avg = np.mean(np.array(test_accuracies))
151 | std = np.std(np.array(test_accuracies))
152 | ci95 = 1.96 * std / np.sqrt(i)  # i episodes evaluated so far
153 |
154 | if i % 50 == 0:
155 | print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} % ({:.2f} %)'\
156 | .format(i, opt.episode, avg, ci95, acc))
157 |
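158 | # Note on the reported interval: the 95% confidence half-width is
159 | # 1.96 * std / sqrt(n) over the n episodes evaluated so far; for example, a
160 | # standard deviation of 10 % over 1000 episodes gives roughly ±0.62 %.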
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from torch.autograd import Variable
11 |
12 | from models.classification_heads import ClassificationHead
13 | from models.R2D2_embedding import R2D2Embedding
14 | from models.protonet_embedding import ProtoNetEmbedding
15 | from models.ResNet12_embedding import resnet12
16 |
17 | from utils import set_gpu, Timer, count_accuracy, check_dir, log
18 |
19 | def one_hot(indices, depth):
20 | """
21 | Returns a one-hot tensor.
22 | This is a PyTorch equivalent of TensorFlow's tf.one_hot.
23 |
24 | Parameters:
25 | indices: a (n_batch, m) Tensor or (m) Tensor.
26 | depth: a scalar. Represents the depth of the one hot dimension.
27 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
28 | """
29 |
30 | encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
31 | index = indices.view(indices.size() + torch.Size([1]))
32 | encoded_indices = encoded_indices.scatter_(-1, index, 1)  # scatter on the last dim so both (m) and (n_batch, m) inputs work
33 |
34 | return encoded_indices
35 |
36 | def get_model(options):
37 | # Choose the embedding network
38 | if options.network == 'ProtoNet':
39 | network = ProtoNetEmbedding().cuda()
40 | elif options.network == 'R2D2':
41 | network = R2D2Embedding().cuda()
42 | elif options.network == 'ResNet':
43 | if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
44 | network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5).cuda()
45 | network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3])
46 | else:
47 | network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
48 | else:
49 | print ("Cannot recognize the network type")
50 | assert(False)
51 |
52 | # Choose the classification head
53 | if options.head == 'ProtoNet':
54 | cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
55 | elif options.head == 'Ridge':
56 | cls_head = ClassificationHead(base_learner='Ridge').cuda()
57 | elif options.head == 'R2D2':
58 | cls_head = ClassificationHead(base_learner='R2D2').cuda()
59 | elif options.head == 'SVM':
60 | cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
61 | else:
62 | print ("Cannot recognize the dataset type")
63 | assert(False)
64 |
65 | return (network, cls_head)
66 |
67 | def get_dataset(options):
68 | # Choose the dataset
69 | if options.dataset == 'miniImageNet':
70 | from data.mini_imagenet import MiniImageNet, FewShotDataloader
71 | dataset_train = MiniImageNet(phase='train')
72 | dataset_val = MiniImageNet(phase='val')
73 | data_loader = FewShotDataloader
74 | elif options.dataset == 'tieredImageNet':
75 | from data.tiered_imagenet import tieredImageNet, FewShotDataloader
76 | dataset_train = tieredImageNet(phase='train')
77 | dataset_val = tieredImageNet(phase='val')
78 | data_loader = FewShotDataloader
79 | elif options.dataset == 'CIFAR_FS':
80 | from data.CIFAR_FS import CIFAR_FS, FewShotDataloader
81 | dataset_train = CIFAR_FS(phase='train')
82 | dataset_val = CIFAR_FS(phase='val')
83 | data_loader = FewShotDataloader
84 | elif options.dataset == 'FC100':
85 | from data.FC100 import FC100, FewShotDataloader
86 | dataset_train = FC100(phase='train')
87 | dataset_val = FC100(phase='val')
88 | data_loader = FewShotDataloader
89 | else:
90 | print ("Cannot recognize the dataset type")
91 | assert(False)
92 |
93 | return (dataset_train, dataset_val, data_loader)
94 |
95 | if __name__ == '__main__':
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument('--num-epoch', type=int, default=60,
98 | help='number of training epochs')
99 | parser.add_argument('--save-epoch', type=int, default=10,
100 | help='frequency of model saving')
101 | parser.add_argument('--train-shot', type=int, default=15,
102 | help='number of support examples per training class')
103 | parser.add_argument('--val-shot', type=int, default=5,
104 | help='number of support examples per validation class')
105 | parser.add_argument('--train-query', type=int, default=6,
106 | help='number of query examples per training class')
107 | parser.add_argument('--val-episode', type=int, default=2000,
108 | help='number of episodes per validation')
109 | parser.add_argument('--val-query', type=int, default=15,
110 | help='number of query examples per validation class')
111 | parser.add_argument('--train-way', type=int, default=5,
112 | help='number of classes in one training episode')
113 | parser.add_argument('--test-way', type=int, default=5,
114 | help='number of classes in one test (or validation) episode')
115 | parser.add_argument('--save-path', default='./experiments/exp_1')
116 | parser.add_argument('--gpu', default='0, 1, 2, 3')
117 | parser.add_argument('--network', type=str, default='ProtoNet',
118 | help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
119 | parser.add_argument('--head', type=str, default='ProtoNet',
120 | help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
121 | parser.add_argument('--dataset', type=str, default='miniImageNet',
122 | help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
123 | parser.add_argument('--episodes-per-batch', type=int, default=8,
124 | help='number of episodes per batch')
125 | parser.add_argument('--eps', type=float, default=0.0,
126 | help='epsilon of label smoothing')
127 |
128 | opt = parser.parse_args()
129 |
130 | (dataset_train, dataset_val, data_loader) = get_dataset(opt)
131 |
132 | # Dataloader of Gidaris & Komodakis (CVPR 2018)
133 | dloader_train = data_loader(
134 | dataset=dataset_train,
135 | nKnovel=opt.train_way,
136 | nKbase=0,
137 | nExemplars=opt.train_shot, # num training examples per novel category
138 | nTestNovel=opt.train_way * opt.train_query, # num test examples for all the novel categories
139 | nTestBase=0, # num test examples for all the base categories
140 | batch_size=opt.episodes_per_batch,
141 | num_workers=4,
142 | epoch_size=opt.episodes_per_batch * 1000, # num of batches per epoch
143 | )
144 |
145 | dloader_val = data_loader(
146 | dataset=dataset_val,
147 | nKnovel=opt.test_way,
148 | nKbase=0,
149 | nExemplars=opt.val_shot, # num training examples per novel category
150 | nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories
151 | nTestBase=0, # num test examples for all the base categories
152 | batch_size=1,
153 | num_workers=0,
154 | epoch_size=1 * opt.val_episode, # num of batches per epoch
155 | )
156 |
157 | set_gpu(opt.gpu)
158 | check_dir('./experiments/')
159 | check_dir(opt.save_path)
160 |
161 | log_file_path = os.path.join(opt.save_path, "train_log.txt")
162 | log(log_file_path, str(vars(opt)))
163 |
164 | (embedding_net, cls_head) = get_model(opt)
165 |
166 | optimizer = torch.optim.SGD([{'params': embedding_net.parameters()},
167 | {'params': cls_head.parameters()}], lr=0.1, momentum=0.9, \
168 | weight_decay=5e-4, nesterov=True)
169 |
170 | lambda_epoch = lambda e: 1.0 if e < 20 else (0.06 if e < 40 else (0.012 if e < 50 else 0.0024))
171 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
172 |
173 | max_val_acc = 0.0
174 |
175 | timer = Timer()
176 | x_entropy = torch.nn.CrossEntropyLoss()
177 |
178 | for epoch in range(1, opt.num_epoch + 1):
179 | # Train on the training split
180 | lr_scheduler.step()
181 |
182 | # Fetch the current epoch's learning rate
183 | epoch_learning_rate = 0.1
184 | for param_group in optimizer.param_groups:
185 | epoch_learning_rate = param_group['lr']
186 |
187 | log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
188 | epoch, epoch_learning_rate))
189 |
190 | _, _ = [x.train() for x in (embedding_net, cls_head)]
191 |
192 | train_accuracies = []
193 | train_losses = []
194 |
195 | for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
196 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
197 |
198 | train_n_support = opt.train_way * opt.train_shot
199 | train_n_query = opt.train_way * opt.train_query
200 |
201 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
202 | emb_support = emb_support.reshape(opt.episodes_per_batch, train_n_support, -1)
203 |
204 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
205 | emb_query = emb_query.reshape(opt.episodes_per_batch, train_n_query, -1)
206 |
207 | logit_query = cls_head(emb_query, emb_support, labels_support, opt.train_way, opt.train_shot)
208 |
209 | smoothed_one_hot = one_hot(labels_query.reshape(-1), opt.train_way)
210 | smoothed_one_hot = smoothed_one_hot * (1 - opt.eps) + (1 - smoothed_one_hot) * opt.eps / (opt.train_way - 1)
211 |
212 | log_prb = F.log_softmax(logit_query.reshape(-1, opt.train_way), dim=1)
213 | loss = -(smoothed_one_hot * log_prb).sum(dim=1)
214 | loss = loss.mean()
215 |
216 | acc = count_accuracy(logit_query.reshape(-1, opt.train_way), labels_query.reshape(-1))
217 |
218 | train_accuracies.append(acc.item())
219 | train_losses.append(loss.item())
220 |
221 | if (i % 100 == 0):
222 | train_acc_avg = np.mean(np.array(train_accuracies))
223 | log(log_file_path, 'Train Epoch: {}\tBatch: [{}/{}]\tLoss: {:.4f}\tAccuracy: {:.2f} % ({:.2f} %)'.format(
224 | epoch, i, len(dloader_train), loss.item(), train_acc_avg, acc))
225 |
226 | optimizer.zero_grad()
227 | loss.backward()
228 | optimizer.step()
229 |
230 | # Evaluate on the validation split
231 | _, _ = [x.eval() for x in (embedding_net, cls_head)]
232 |
233 | val_accuracies = []
234 | val_losses = []
235 |
236 | for i, batch in enumerate(tqdm(dloader_val(epoch)), 1):
237 | data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
238 |
239 | test_n_support = opt.test_way * opt.val_shot
240 | test_n_query = opt.test_way * opt.val_query
241 |
242 | emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
243 | emb_support = emb_support.reshape(1, test_n_support, -1)
244 | emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
245 | emb_query = emb_query.reshape(1, test_n_query, -1)
246 |
247 | logit_query = cls_head(emb_query, emb_support, labels_support, opt.test_way, opt.val_shot)
248 |
249 | loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
250 | acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
251 |
252 | val_accuracies.append(acc.item())
253 | val_losses.append(loss.item())
254 |
255 | val_acc_avg = np.mean(np.array(val_accuracies))
256 | val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)
257 |
258 | val_loss_avg = np.mean(np.array(val_losses))
259 |
260 | if val_acc_avg > max_val_acc:
261 | max_val_acc = val_acc_avg
262 | torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},\
263 | os.path.join(opt.save_path, 'best_model.pth'))
264 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'\
265 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
266 | else:
267 | log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'\
268 | .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
269 |
270 | torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
271 | , os.path.join(opt.save_path, 'last_epoch.pth'))
272 |
273 | if epoch % opt.save_epoch == 0:
274 | torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
275 | , os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))
276 |
277 | log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(), timer.measure(epoch / float(opt.num_epoch))))
278 |
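279 | # Worked examples for the pieces above (hypothetical values):
280 | # - one_hot: one_hot(torch.tensor([0, 2]).cuda(), 3) yields [[1, 0, 0], [0, 0, 1]].
281 | # - Label smoothing: with eps = 0.1 in a 5-way episode, the true-class target
282 | #   becomes 0.9 and each of the 4 wrong classes gets 0.1 / 4 = 0.025.
283 | # - Learning-rate schedule: with base lr 0.1, lambda_epoch gives lr = 0.1
284 | #   (epochs < 20), 0.006 (20-39), 0.0012 (40-49), and 0.00024 afterwards.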
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pprint
4 | import torch
5 |
6 | def set_gpu(x):
7 | os.environ['CUDA_VISIBLE_DEVICES'] = x
8 | print('using gpu:', x)
9 |
10 | def check_dir(path):
11 | '''
12 | Create directory if it does not exist.
13 | path: Path of directory.
14 | '''
15 | if not os.path.exists(path):
16 | os.mkdir(path)
17 |
18 | def count_accuracy(logits, label):
19 | pred = torch.argmax(logits, dim=1).view(-1)
20 | label = label.view(-1)
21 | accuracy = 100 * pred.eq(label).float().mean()
22 | return accuracy
23 |
24 | class Timer():
25 | def __init__(self):
26 | self.o = time.time()
27 |
28 | def measure(self, p=1):
29 | x = (time.time() - self.o) / float(p)
30 | x = int(x)
31 | if x >= 3600:
32 | return '{:.1f}h'.format(x / 3600)
33 | if x >= 60:
34 | return '{}m'.format(round(x / 60))
35 | return '{}s'.format(x)
36 |
37 | def log(log_file_path, string):
38 | '''
39 | Write one line of log into screen and file.
40 | log_file_path: Path of log file.
41 | string: String to write in log file.
42 | '''
43 | with open(log_file_path, 'a+') as f:
44 | f.write(string + '\n')
45 | f.flush()
46 | print(string)
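47 |
48 | # Illustrative example (hypothetical values): count_accuracy returns the
49 | # accuracy as a percentage tensor, e.g.
50 | #
51 | #   logits = torch.tensor([[2.0, 0.1], [0.3, 0.9], [1.5, 0.2]])
52 | #   labels = torch.tensor([0, 1, 1])
53 | #   count_accuracy(logits, labels)  # tensor(66.6667)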
--------------------------------------------------------------------------------