├── LICENSE
├── README.md
├── __init__.py
├── algorithm.png
├── data
│   ├── CIFAR_FS.py
│   ├── FC100.py
│   ├── __init__.py
│   ├── mini_imagenet.py
│   └── tiered_imagenet.py
├── models
│   ├── R2D2_embedding.py
│   ├── ResNet12_embedding.py
│   ├── __init__.py
│   ├── classification_heads.py
│   ├── dropblock.py
│   └── protonet_embedding.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-Learning with Differentiable Convex Optimization 2 | This repository contains the code for the paper: 3 |
4 | [**Meta-Learning with Differentiable Convex Optimization**](https://arxiv.org/pdf/1904.03758.pdf) 5 |
6 | Kwonjoon Lee, [Subhransu Maji](https://people.cs.umass.edu/~smaji/), Avinash Ravichandran, [Stefano Soatto](http://web.cs.ucla.edu/~soatto/) 7 | CVPR 2019 (**Oral**) 8 |

9 | 10 |

11 | 
12 | ### Abstract
13 | 
14 | Many meta-learning approaches for few-shot learning rely on simple base learners such as nearest-neighbor classifiers. However, even in the few-shot regime, discriminatively trained linear predictors can offer better generalization. We propose to use these predictors as base learners to learn representations for few-shot learning and show they offer better tradeoffs between feature size and performance across a range of few-shot recognition benchmarks. Our objective is to learn feature embeddings that generalize well under a linear classification rule for novel categories. To efficiently solve the objective, we exploit two properties of linear classifiers: implicit differentiation of the optimality conditions of the convex problem and the dual formulation of the optimization problem. This allows us to use high-dimensional embeddings with improved generalization at a modest increase in computational overhead. Our approach, named MetaOptNet, achieves state-of-the-art performance on miniImageNet, tieredImageNet, CIFAR-FS and FC100 few-shot learning benchmarks.
15 | 
16 | ### Citation
17 | 
18 | If you use this code for your research, please cite our paper:
19 | ```
20 | @inproceedings{lee2019meta,
21 |   title={Meta-Learning with Differentiable Convex Optimization},
22 |   author={Kwonjoon Lee and Subhransu Maji and Avinash Ravichandran and Stefano Soatto},
23 |   booktitle={CVPR},
24 |   year={2019}
25 | }
26 | ```
27 | 
28 | ## Dependencies
29 | * Python 2.7+ (not tested on Python 3)
30 | * [PyTorch 0.4.0+](http://pytorch.org)
31 | * [qpth 0.0.11+](https://github.com/locuslab/qpth)
32 | * [tqdm](https://github.com/tqdm/tqdm)
33 | 
34 | ## Usage
35 | 
36 | ### Installation
37 | 
38 | 1. Clone this repository:
39 | ```bash
40 | git clone https://github.com/kjunelee/MetaOptNet.git
41 | cd MetaOptNet
42 | ```
43 | 2. Download and decompress the dataset files: [**miniImageNet**](https://drive.google.com/file/d/12V7qi-AjrYi6OoJdYcN_k502BM_jcP8D/view?usp=sharing) (courtesy of [**Spyros Gidaris**](https://github.com/gidariss/FewShotWithoutForgetting)), [**tieredImageNet**](https://drive.google.com/open?id=1nVGCTd9ttULRXFezh4xILQ9lUkg0WZCG), [**FC100**](https://drive.google.com/file/d/1_ZsLyqI487NRDQhwvI7rg86FK3YAZvz1/view?usp=sharing), [**CIFAR-FS**](https://drive.google.com/file/d/1GjGMI0q3bgcpcB_CjI40fX54WgLPuTpS/view?usp=sharing)
44 | 
45 | 3. For each dataset loader, specify the path to its dataset directory. For example, in MetaOptNet/data/mini_imagenet.py line 30:
46 | ```python
47 | _MINI_IMAGENET_DATASET_DIR = 'path/to/miniImageNet'
48 | ```
49 | 
50 | ### Meta-training
51 | 1. To train MetaOptNet-SVM on the 5-way miniImageNet benchmark:
52 | ```bash
53 | python train.py --gpu 0,1,2,3 --save-path "./experiments/miniImageNet_MetaOptNet_SVM" --train-shot 15 \
54 | --head SVM --network ResNet --dataset miniImageNet --eps 0.1
55 | ```
56 | As shown in Figure 2 of our paper, we can meta-train the embedding once with a high shot count and use it for all meta-testing shots; unlike Prototypical Networks, we do not need to meta-train separately for every possible meta-test shot.
57 | 2. You can experiment with other base learners by changing the '--head' argument to ProtoNet or Ridge (a minimal sketch of the ridge head's closed-form solve follows right after this step). You can also switch the backbone to a vanilla 4-layer conv net by setting the '--network' argument to ProtoNet. For the remaining arguments, see lines 85 to 114 of MetaOptNet/train.py.
58 | 
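   The Ridge head is where the differentiable convex optimization from the abstract shows up most simply: ridge regression has a closed-form solution, so the base learner can be solved and backpropagated through in a single step. The sketch below only illustrates that idea; it is not the repository's models/classification_heads.py implementation, and it uses the modern `torch.linalg` API rather than the PyTorch 0.4 calls this code base targets (dtype/device handling omitted for brevity):
```python
import torch

def ridge_head(support, support_onehot, query, lam=50.0):
    # support:        [n_support, d]   embedded support examples
    # support_onehot: [n_support, way] one-hot support labels (float)
    # query:          [n_query, d]     embedded query examples
    n_support = support.size(0)
    # Dual (Woodbury) form W = X^T (X X^T + lam*I)^{-1} Y: this solves an
    # [n_support x n_support] system instead of a [d x d] one, so the cost
    # scales with way*shot rather than the embedding dimension d.
    gram = support @ support.t() + lam * torch.eye(n_support)
    alpha = torch.linalg.solve(gram, support_onehot)  # [n_support, way]
    weights = support.t() @ alpha                     # [d, way]
    # Every operation above is differentiable, so meta-training can
    # backpropagate through the base learner into the embedding network.
    return query @ weights                            # [n_query, way] logits
```
   The SVM head rests on the same two properties (the dual formulation plus implicit differentiation of the optimality conditions), but solves a small dual QP with [qpth](https://github.com/locuslab/qpth) instead of a single linear solve.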
3. To train MetaOptNet-SVM on the 5-way tieredImageNet benchmark:
59 | ```bash
60 | python train.py --gpu 0,1,2,3 --save-path "./experiments/tieredImageNet_MetaOptNet_SVM" --train-shot 10 \
61 | --head SVM --network ResNet --dataset tieredImageNet
62 | ```
63 | 4. To train MetaOptNet-RR on the 5-way CIFAR-FS benchmark:
64 | ```bash
65 | python train.py --gpu 0 --save-path "./experiments/CIFAR_FS_MetaOptNet_RR" --train-shot 5 \
66 | --head Ridge --network ResNet --dataset CIFAR_FS
67 | ```
68 | 5. To train MetaOptNet-RR on the 5-way FC100 benchmark:
69 | ```bash
70 | python train.py --gpu 0 --save-path "./experiments/FC100_MetaOptNet_RR" --train-shot 15 \
71 | --head Ridge --network ResNet --dataset FC100
72 | ```
73 | ### Meta-testing
74 | 1. To test MetaOptNet-SVM on the 5-way miniImageNet 1-shot benchmark:
75 | ```bash
76 | python test.py --gpu 0,1,2,3 --load ./experiments/miniImageNet_MetaOptNet_SVM/best_model.pth --episode 1000 \
77 | --way 5 --shot 1 --query 15 --head SVM --network ResNet --dataset miniImageNet
78 | ```
79 | 2. Similarly, to test MetaOptNet-SVM on the 5-way miniImageNet 5-shot benchmark:
80 | ```bash
81 | python test.py --gpu 0,1,2,3 --load ./experiments/miniImageNet_MetaOptNet_SVM/best_model.pth --episode 1000 \
82 | --way 5 --shot 5 --query 15 --head SVM --network ResNet --dataset miniImageNet
83 | ```
84 | 
85 | ## Acknowledgments
86 | 
87 | This code is based on the implementations of [**Prototypical Networks**](https://github.com/cyvius96/prototypical-network-pytorch), [**Dynamic Few-Shot Visual Learning without Forgetting**](https://github.com/gidariss/FewShotWithoutForgetting), and [**DropBlock**](https://github.com/miguelvr/dropblock).
88 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Implement your code here.
2 | 
--------------------------------------------------------------------------------
/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kjunelee/MetaOptNet/19601591da8090734e2e69113c9510831da5528a/algorithm.png
--------------------------------------------------------------------------------
/data/CIFAR_FS.py:
--------------------------------------------------------------------------------
1 | # Dataloader of Gidaris & Komodakis, CVPR 2018
2 | # Adapted from:
3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4 | from __future__ import print_function
5 | 
6 | import os
7 | import os.path
8 | import numpy as np
9 | import random
10 | import pickle
11 | import json
12 | import math
13 | 
14 | import torch
15 | import torch.utils.data as data
16 | import torchvision
17 | import torchvision.datasets as datasets
18 | import torchvision.transforms as transforms
19 | import torchnet as tnt
20 | 
21 | import h5py
22 | 
23 | from PIL import Image
24 | from PIL import ImageEnhance
25 | 
26 | from pdb import set_trace as breakpoint
27 | 
28 | 
29 | # Set the appropriate paths of the datasets here. 
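# (Per step 3 of the README's installation instructions, point the constant
# below at your own download location, e.g. a hypothetical '/path/to/CIFAR_FS'.)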
30 | _CIFAR_FS_DATASET_DIR = '/mnt/cube/datasets/few-shot/CIFAR_FS' 31 | 32 | def buildLabelIndex(labels): 33 | label2inds = {} 34 | for idx, label in enumerate(labels): 35 | if label not in label2inds: 36 | label2inds[label] = [] 37 | label2inds[label].append(idx) 38 | 39 | return label2inds 40 | 41 | def load_data(file): 42 | try: 43 | with open(file, 'rb') as fo: 44 | data = pickle.load(fo) 45 | return data 46 | except: 47 | with open(file, 'rb') as f: 48 | u = pickle._Unpickler(f) 49 | u.encoding = 'latin1' 50 | data = u.load() 51 | return data 52 | 53 | class CIFAR_FS(data.Dataset): 54 | def __init__(self, phase='train', do_not_use_random_transf=False): 55 | 56 | assert(phase=='train' or phase=='val' or phase=='test') 57 | self.phase = phase 58 | self.name = 'CIFAR_FS_' + phase 59 | 60 | print('Loading CIFAR-FS dataset - phase {0}'.format(phase)) 61 | file_train_categories_train_phase = os.path.join( 62 | _CIFAR_FS_DATASET_DIR, 63 | 'CIFAR_FS_train.pickle') 64 | file_train_categories_val_phase = os.path.join( 65 | _CIFAR_FS_DATASET_DIR, 66 | 'CIFAR_FS_train.pickle') 67 | file_train_categories_test_phase = os.path.join( 68 | _CIFAR_FS_DATASET_DIR, 69 | 'CIFAR_FS_train.pickle') 70 | file_val_categories_val_phase = os.path.join( 71 | _CIFAR_FS_DATASET_DIR, 72 | 'CIFAR_FS_val.pickle') 73 | file_test_categories_test_phase = os.path.join( 74 | _CIFAR_FS_DATASET_DIR, 75 | 'CIFAR_FS_test.pickle') 76 | 77 | if self.phase=='train': 78 | # During training phase we only load the training phase images 79 | # of the training categories (aka base categories). 80 | data_train = load_data(file_train_categories_train_phase) 81 | self.data = data_train['data'] 82 | self.labels = data_train['labels'] 83 | 84 | self.label2ind = buildLabelIndex(self.labels) 85 | self.labelIds = sorted(self.label2ind.keys()) 86 | self.num_cats = len(self.labelIds) 87 | self.labelIds_base = self.labelIds 88 | self.num_cats_base = len(self.labelIds_base) 89 | 90 | elif self.phase=='val' or self.phase=='test': 91 | if self.phase=='test': 92 | # load data that will be used for evaluating the recognition 93 | # accuracy of the base categories. 94 | data_base = load_data(file_train_categories_test_phase) 95 | # load data that will be use for evaluating the few-shot recogniton 96 | # accuracy on the novel categories. 97 | data_novel = load_data(file_test_categories_test_phase) 98 | else: # phase=='val' 99 | # load data that will be used for evaluating the recognition 100 | # accuracy of the base categories. 101 | data_base = load_data(file_train_categories_val_phase) 102 | # load data that will be use for evaluating the few-shot recogniton 103 | # accuracy on the novel categories. 
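# (For CIFAR-FS this 'val' novel split holds 16 of the 100 CIFAR-100
# categories; the 64 training categories act as base classes and the
# remaining 20 are reserved for meta-testing.)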
104 | data_novel = load_data(file_val_categories_val_phase) 105 | 106 | self.data = np.concatenate( 107 | [data_base['data'], data_novel['data']], axis=0) 108 | self.labels = data_base['labels'] + data_novel['labels'] 109 | 110 | self.label2ind = buildLabelIndex(self.labels) 111 | self.labelIds = sorted(self.label2ind.keys()) 112 | self.num_cats = len(self.labelIds) 113 | 114 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys() 115 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys() 116 | self.num_cats_base = len(self.labelIds_base) 117 | self.num_cats_novel = len(self.labelIds_novel) 118 | intersection = set(self.labelIds_base) & set(self.labelIds_novel) 119 | assert(len(intersection) == 0) 120 | else: 121 | raise ValueError('Not valid phase {0}'.format(self.phase)) 122 | 123 | mean_pix = [x/255.0 for x in [129.37731888, 124.10583864, 112.47758569]] 124 | 125 | std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]] 126 | 127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix) 128 | 129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True): 130 | 131 | self.transform = transforms.Compose([ 132 | lambda x: np.asarray(x), 133 | transforms.ToTensor(), 134 | normalize 135 | ]) 136 | else: 137 | 138 | self.transform = transforms.Compose([ 139 | transforms.RandomCrop(32, padding=4), 140 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 141 | transforms.RandomHorizontalFlip(), 142 | lambda x: np.asarray(x), 143 | transforms.ToTensor(), 144 | normalize 145 | ]) 146 | 147 | def __getitem__(self, index): 148 | img, label = self.data[index], self.labels[index] 149 | # doing this so that it is consistent with all other datasets 150 | # to return a PIL Image 151 | img = Image.fromarray(img) 152 | if self.transform is not None: 153 | img = self.transform(img) 154 | return img, label 155 | 156 | def __len__(self): 157 | return len(self.data) 158 | 159 | 160 | class FewShotDataloader(): 161 | def __init__(self, 162 | dataset, 163 | nKnovel=5, # number of novel categories. 164 | nKbase=-1, # number of base categories. 165 | nExemplars=1, # number of training examples per novel category. 166 | nTestNovel=15*5, # number of test examples for all the novel categories. 167 | nTestBase=15*5, # number of test examples for all the base categories. 168 | batch_size=1, # number of training episodes per batch. 169 | num_workers=4, 170 | epoch_size=2000, # number of batches per epoch. 
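                 # (An episode therefore consists of nKnovel * nExemplars
                 # support images plus nTestNovel + nTestBase query images;
                 # with the defaults above, that is 5 support and 150 query
                 # examples per episode.)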
171 | ): 172 | 173 | self.dataset = dataset 174 | self.phase = self.dataset.phase 175 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' 176 | else self.dataset.num_cats_novel) 177 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel) 178 | self.nKnovel = nKnovel 179 | 180 | max_possible_nKbase = self.dataset.num_cats_base 181 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase 182 | if self.phase=='train' and nKbase > 0: 183 | nKbase -= self.nKnovel 184 | max_possible_nKbase -= self.nKnovel 185 | 186 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase) 187 | self.nKbase = nKbase 188 | 189 | self.nExemplars = nExemplars 190 | self.nTestNovel = nTestNovel 191 | self.nTestBase = nTestBase 192 | self.batch_size = batch_size 193 | self.epoch_size = epoch_size 194 | self.num_workers = num_workers 195 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val') 196 | 197 | def sampleImageIdsFrom(self, cat_id, sample_size=1): 198 | """ 199 | Samples `sample_size` number of unique image ids picked from the 200 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]). 201 | 202 | Args: 203 | cat_id: a scalar with the id of the category from which images will 204 | be sampled. 205 | sample_size: number of images that will be sampled. 206 | 207 | Returns: 208 | image_ids: a list of length `sample_size` with unique image ids. 209 | """ 210 | assert(cat_id in self.dataset.label2ind) 211 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size) 212 | # Note: random.sample samples elements without replacement. 213 | return random.sample(self.dataset.label2ind[cat_id], sample_size) 214 | 215 | def sampleCategories(self, cat_set, sample_size=1): 216 | """ 217 | Samples `sample_size` number of unique categories picked from the 218 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. 219 | 220 | Args: 221 | cat_set: string that specifies the set of categories from which 222 | categories will be sampled. 223 | sample_size: number of categories that will be sampled. 224 | 225 | Returns: 226 | cat_ids: a list of length `sample_size` with unique category ids. 227 | """ 228 | if cat_set=='base': 229 | labelIds = self.dataset.labelIds_base 230 | elif cat_set=='novel': 231 | labelIds = self.dataset.labelIds_novel 232 | else: 233 | raise ValueError('Not recognized category set {}'.format(cat_set)) 234 | 235 | assert(len(labelIds) >= sample_size) 236 | # return sample_size unique categories chosen from labelIds set of 237 | # categories (that can be either self.labelIds_base or self.labelIds_novel) 238 | # Note: random.sample samples elements without replacement. 239 | return random.sample(labelIds, sample_size) 240 | 241 | def sample_base_and_novel_categories(self, nKbase, nKnovel): 242 | """ 243 | Samples `nKbase` number of base categories and `nKnovel` number of novel 244 | categories. 245 | 246 | Args: 247 | nKbase: number of base categories 248 | nKnovel: number of novel categories 249 | 250 | Returns: 251 | Kbase: a list of length 'nKbase' with the ids of the sampled base 252 | categories. 253 | Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel 254 | categories. 255 | """ 256 | if self.is_eval_mode: 257 | assert(nKnovel <= self.dataset.num_cats_novel) 258 | # sample from the set of base categories 'nKbase' number of base 259 | # categories. 260 | Kbase = sorted(self.sampleCategories('base', nKbase)) 261 | # sample from the set of novel categories 'nKnovel' number of novel 262 | # categories. 
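            # (The sampled base and novel ids are guaranteed disjoint: the
            # dataset constructor asserts that the intersection of
            # labelIds_base and labelIds_novel is empty.)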
263 | Knovel = sorted(self.sampleCategories('novel', nKnovel)) 264 | else: 265 | # sample from the set of base categories 'nKnovel' + 'nKbase' number 266 | # of categories. 267 | cats_ids = self.sampleCategories('base', nKnovel+nKbase) 268 | assert(len(cats_ids) == (nKnovel+nKbase)) 269 | # Randomly pick 'nKnovel' number of fake novel categories and keep 270 | # the rest as base categories. 271 | random.shuffle(cats_ids) 272 | Knovel = sorted(cats_ids[:nKnovel]) 273 | Kbase = sorted(cats_ids[nKnovel:]) 274 | 275 | return Kbase, Knovel 276 | 277 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase): 278 | """ 279 | Sample `nTestBase` number of images from the `Kbase` categories. 280 | 281 | Args: 282 | Kbase: a list of length `nKbase` with the ids of the categories from 283 | where the images will be sampled. 284 | nTestBase: the total number of images that will be sampled. 285 | 286 | Returns: 287 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st 288 | element of each tuple is the image id that was sampled and the 289 | 2nd elemend is its category label (which is in the range 290 | [0, len(Kbase)-1]). 291 | """ 292 | Tbase = [] 293 | if len(Kbase) > 0: 294 | # Sample for each base category a number images such that the total 295 | # number sampled images of all categories to be equal to `nTestBase`. 296 | KbaseIndices = np.random.choice( 297 | np.arange(len(Kbase)), size=nTestBase, replace=True) 298 | KbaseIndices, NumImagesPerCategory = np.unique( 299 | KbaseIndices, return_counts=True) 300 | 301 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): 302 | imd_ids = self.sampleImageIdsFrom( 303 | Kbase[Kbase_idx], sample_size=NumImages) 304 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] 305 | 306 | assert(len(Tbase) == nTestBase) 307 | 308 | return Tbase 309 | 310 | def sample_train_and_test_examples_for_novel_categories( 311 | self, Knovel, nTestNovel, nExemplars, nKbase): 312 | """Samples train and test examples of the novel categories. 313 | 314 | Args: 315 | Knovel: a list with the ids of the novel categories. 316 | nTestNovel: the total number of test images that will be sampled 317 | from all the novel categories. 318 | nExemplars: the number of training examples per novel category that 319 | will be sampled. 320 | nKbase: the number of base categories. It is used as offset of the 321 | category index of each sampled image. 322 | 323 | Returns: 324 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The 325 | 1st element of each tuple is the image id that was sampled and 326 | the 2nd element is its category label (which is in the range 327 | [nKbase, nKbase + len(Knovel) - 1]). 328 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element 329 | tuples. The 1st element of each tuple is the image id that was 330 | sampled and the 2nd element is its category label (which is in 331 | the ragne [nKbase, nKbase + len(Knovel) - 1]). 
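        Example (illustrative numbers only): with Knovel=[3, 8],
        nTestNovel=4, nExemplars=1 and nKbase=10, each novel class
        contributes 4/2 = 2 test images and one exemplar, and every
        returned tuple carries a relabeled category id of 10 or 11.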
332 | """ 333 | 334 | if len(Knovel) == 0: 335 | return [], [] 336 | 337 | nKnovel = len(Knovel) 338 | Tnovel = [] 339 | Exemplars = [] 340 | assert((nTestNovel % nKnovel) == 0) 341 | nEvalExamplesPerClass = int(nTestNovel / nKnovel) 342 | 343 | for Knovel_idx in range(len(Knovel)): 344 | imd_ids = self.sampleImageIdsFrom( 345 | Knovel[Knovel_idx], 346 | sample_size=(nEvalExamplesPerClass + nExemplars)) 347 | 348 | imds_tnovel = imd_ids[:nEvalExamplesPerClass] 349 | imds_ememplars = imd_ids[nEvalExamplesPerClass:] 350 | 351 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel] 352 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars] 353 | assert(len(Tnovel) == nTestNovel) 354 | assert(len(Exemplars) == len(Knovel) * nExemplars) 355 | random.shuffle(Exemplars) 356 | 357 | return Tnovel, Exemplars 358 | 359 | def sample_episode(self): 360 | """Samples a training episode.""" 361 | nKnovel = self.nKnovel 362 | nKbase = self.nKbase 363 | nTestNovel = self.nTestNovel 364 | nTestBase = self.nTestBase 365 | nExemplars = self.nExemplars 366 | 367 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) 368 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) 369 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( 370 | Knovel, nTestNovel, nExemplars, nKbase) 371 | 372 | # concatenate the base and novel category examples. 373 | Test = Tbase + Tnovel 374 | random.shuffle(Test) 375 | Kall = Kbase + Knovel 376 | 377 | return Exemplars, Test, Kall, nKbase 378 | 379 | def createExamplesTensorData(self, examples): 380 | """ 381 | Creates the examples image and label tensor data. 382 | 383 | Args: 384 | examples: a list of 2-element tuples, each representing a 385 | train or test example. The 1st element of each tuple 386 | is the image id of the example and 2nd element is the 387 | category label of the example, which is in the range 388 | [0, nK - 1], where nK is the total number of categories 389 | (both novel and base). 390 | 391 | Returns: 392 | images: a tensor of shape [nExamples, Height, Width, 3] with the 393 | example images, where nExamples is the number of examples 394 | (i.e., nExamples = len(examples)). 395 | labels: a tensor of shape [nExamples] with the category label 396 | of each example. 
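        Note: because both dataset transforms end with transforms.ToTensor,
        the stacked images tensor is channels-first in practice, i.e.
        [nExamples, 3, Height, Width] rather than the [nExamples, Height,
        Width, 3] stated above.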
397 | """ 398 | images = torch.stack( 399 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0) 400 | labels = torch.LongTensor([label for _, label in examples]) 401 | return images, labels 402 | 403 | def get_iterator(self, epoch=0): 404 | rand_seed = epoch 405 | random.seed(rand_seed) 406 | np.random.seed(rand_seed) 407 | def load_function(iter_idx): 408 | Exemplars, Test, Kall, nKbase = self.sample_episode() 409 | Xt, Yt = self.createExamplesTensorData(Test) 410 | Kall = torch.LongTensor(Kall) 411 | if len(Exemplars) > 0: 412 | Xe, Ye = self.createExamplesTensorData(Exemplars) 413 | return Xe, Ye, Xt, Yt, Kall, nKbase 414 | else: 415 | return Xt, Yt, Kall, nKbase 416 | 417 | tnt_dataset = tnt.dataset.ListDataset( 418 | elem_list=range(self.epoch_size), load=load_function) 419 | data_loader = tnt_dataset.parallel( 420 | batch_size=self.batch_size, 421 | num_workers=(0 if self.is_eval_mode else self.num_workers), 422 | shuffle=(False if self.is_eval_mode else True)) 423 | 424 | return data_loader 425 | 426 | def __call__(self, epoch=0): 427 | return self.get_iterator(epoch) 428 | 429 | def __len__(self): 430 | return int(self.epoch_size / self.batch_size) 431 | -------------------------------------------------------------------------------- /data/FC100.py: -------------------------------------------------------------------------------- 1 | # Dataloader of Gidaris & Komodakis, CVPR 2018 2 | # Adapted from: 3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path 8 | import numpy as np 9 | import random 10 | import pickle 11 | import json 12 | import math 13 | 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision 17 | import torchvision.datasets as datasets 18 | import torchvision.transforms as transforms 19 | import torchnet as tnt 20 | 21 | import h5py 22 | 23 | from PIL import Image 24 | from PIL import ImageEnhance 25 | 26 | from pdb import set_trace as breakpoint 27 | 28 | 29 | # Set the appropriate paths of the datasets here. 
30 | _FC100_DATASET_DIR = '/mnt/cube/datasets/few-shot/FC100' 31 | 32 | def buildLabelIndex(labels): 33 | label2inds = {} 34 | for idx, label in enumerate(labels): 35 | if label not in label2inds: 36 | label2inds[label] = [] 37 | label2inds[label].append(idx) 38 | 39 | return label2inds 40 | 41 | def load_data(file): 42 | try: 43 | with open(file, 'rb') as fo: 44 | data = pickle.load(fo) 45 | return data 46 | except: 47 | with open(file, 'rb') as f: 48 | u = pickle._Unpickler(f) 49 | u.encoding = 'latin1' 50 | data = u.load() 51 | return data 52 | 53 | class FC100(data.Dataset): 54 | def __init__(self, phase='train', do_not_use_random_transf=False): 55 | 56 | assert(phase=='train' or phase=='val' or phase=='test') 57 | self.phase = phase 58 | self.name = 'FC100_' + phase 59 | 60 | print('Loading FC100 dataset - phase {0}'.format(phase)) 61 | file_train_categories_train_phase = os.path.join( 62 | _FC100_DATASET_DIR, 63 | 'FC100_train.pickle') 64 | file_train_categories_val_phase = os.path.join( 65 | _FC100_DATASET_DIR, 66 | 'FC100_train.pickle') 67 | file_train_categories_test_phase = os.path.join( 68 | _FC100_DATASET_DIR, 69 | 'FC100_train.pickle') 70 | file_val_categories_val_phase = os.path.join( 71 | _FC100_DATASET_DIR, 72 | 'FC100_val.pickle') 73 | file_test_categories_test_phase = os.path.join( 74 | _FC100_DATASET_DIR, 75 | 'FC100_test.pickle') 76 | 77 | if self.phase=='train': 78 | # During training phase we only load the training phase images 79 | # of the training categories (aka base categories). 80 | data_train = load_data(file_train_categories_train_phase) 81 | self.data = data_train['data'] 82 | self.labels = data_train['labels'] 83 | #print (self.labels) 84 | self.label2ind = buildLabelIndex(self.labels) 85 | self.labelIds = sorted(self.label2ind.keys()) 86 | self.num_cats = len(self.labelIds) 87 | self.labelIds_base = self.labelIds 88 | self.num_cats_base = len(self.labelIds_base) 89 | #print (self.data.shape) 90 | elif self.phase=='val' or self.phase=='test': 91 | if self.phase=='test': 92 | # load data that will be used for evaluating the recognition 93 | # accuracy of the base categories. 94 | data_base = load_data(file_train_categories_test_phase) 95 | # load data that will be use for evaluating the few-shot recogniton 96 | # accuracy on the novel categories. 97 | data_novel = load_data(file_test_categories_test_phase) 98 | else: # phase=='val' 99 | # load data that will be used for evaluating the recognition 100 | # accuracy of the base categories. 101 | data_base = load_data(file_train_categories_val_phase) 102 | # load data that will be use for evaluating the few-shot recogniton 103 | # accuracy on the novel categories. 
104 | data_novel = load_data(file_val_categories_val_phase) 105 | 106 | self.data = np.concatenate( 107 | [data_base['data'], data_novel['data']], axis=0) 108 | self.labels = data_base['labels'] + data_novel['labels'] 109 | 110 | self.label2ind = buildLabelIndex(self.labels) 111 | self.labelIds = sorted(self.label2ind.keys()) 112 | self.num_cats = len(self.labelIds) 113 | 114 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys() 115 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys() 116 | self.num_cats_base = len(self.labelIds_base) 117 | self.num_cats_novel = len(self.labelIds_novel) 118 | intersection = set(self.labelIds_base) & set(self.labelIds_novel) 119 | assert(len(intersection) == 0) 120 | else: 121 | raise ValueError('Not valid phase {0}'.format(self.phase)) 122 | 123 | mean_pix = [x/255.0 for x in [129.37731888, 124.10583864, 112.47758569]] 124 | 125 | std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]] 126 | 127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix) 128 | 129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True): 130 | self.transform = transforms.Compose([ 131 | lambda x: np.asarray(x), 132 | transforms.ToTensor(), 133 | normalize 134 | ]) 135 | else: 136 | self.transform = transforms.Compose([ 137 | transforms.RandomCrop(32, padding=4), 138 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 139 | transforms.RandomHorizontalFlip(), 140 | lambda x: np.asarray(x), 141 | transforms.ToTensor(), 142 | normalize 143 | ]) 144 | 145 | def __getitem__(self, index): 146 | img, label = self.data[index], self.labels[index] 147 | # doing this so that it is consistent with all other datasets 148 | # to return a PIL Image 149 | img = Image.fromarray(img) 150 | if self.transform is not None: 151 | img = self.transform(img) 152 | return img, label 153 | 154 | def __len__(self): 155 | return len(self.data) 156 | 157 | 158 | class FewShotDataloader(): 159 | def __init__(self, 160 | dataset, 161 | nKnovel=5, # number of novel categories. 162 | nKbase=-1, # number of base categories. 163 | nExemplars=1, # number of training examples per novel category. 164 | nTestNovel=15*5, # number of test examples for all the novel categories. 165 | nTestBase=15*5, # number of test examples for all the base categories. 166 | batch_size=1, # number of training episodes per batch. 167 | num_workers=4, 168 | epoch_size=2000, # number of batches per epoch. 
169 | ): 170 | 171 | self.dataset = dataset 172 | self.phase = self.dataset.phase 173 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' 174 | else self.dataset.num_cats_novel) 175 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel) 176 | self.nKnovel = nKnovel 177 | 178 | max_possible_nKbase = self.dataset.num_cats_base 179 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase 180 | if self.phase=='train' and nKbase > 0: 181 | nKbase -= self.nKnovel 182 | max_possible_nKbase -= self.nKnovel 183 | 184 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase) 185 | self.nKbase = nKbase 186 | 187 | self.nExemplars = nExemplars 188 | self.nTestNovel = nTestNovel 189 | self.nTestBase = nTestBase 190 | self.batch_size = batch_size 191 | self.epoch_size = epoch_size 192 | self.num_workers = num_workers 193 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val') 194 | 195 | def sampleImageIdsFrom(self, cat_id, sample_size=1): 196 | """ 197 | Samples `sample_size` number of unique image ids picked from the 198 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]). 199 | 200 | Args: 201 | cat_id: a scalar with the id of the category from which images will 202 | be sampled. 203 | sample_size: number of images that will be sampled. 204 | 205 | Returns: 206 | image_ids: a list of length `sample_size` with unique image ids. 207 | """ 208 | assert(cat_id in self.dataset.label2ind) 209 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size) 210 | # Note: random.sample samples elements without replacement. 211 | return random.sample(self.dataset.label2ind[cat_id], sample_size) 212 | 213 | def sampleCategories(self, cat_set, sample_size=1): 214 | """ 215 | Samples `sample_size` number of unique categories picked from the 216 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. 217 | 218 | Args: 219 | cat_set: string that specifies the set of categories from which 220 | categories will be sampled. 221 | sample_size: number of categories that will be sampled. 222 | 223 | Returns: 224 | cat_ids: a list of length `sample_size` with unique category ids. 225 | """ 226 | if cat_set=='base': 227 | labelIds = self.dataset.labelIds_base 228 | elif cat_set=='novel': 229 | labelIds = self.dataset.labelIds_novel 230 | else: 231 | raise ValueError('Not recognized category set {}'.format(cat_set)) 232 | 233 | assert(len(labelIds) >= sample_size) 234 | # return sample_size unique categories chosen from labelIds set of 235 | # categories (that can be either self.labelIds_base or self.labelIds_novel) 236 | # Note: random.sample samples elements without replacement. 237 | return random.sample(labelIds, sample_size) 238 | 239 | def sample_base_and_novel_categories(self, nKbase, nKnovel): 240 | """ 241 | Samples `nKbase` number of base categories and `nKnovel` number of novel 242 | categories. 243 | 244 | Args: 245 | nKbase: number of base categories 246 | nKnovel: number of novel categories 247 | 248 | Returns: 249 | Kbase: a list of length 'nKbase' with the ids of the sampled base 250 | categories. 251 | Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel 252 | categories. 253 | """ 254 | if self.is_eval_mode: 255 | assert(nKnovel <= self.dataset.num_cats_novel) 256 | # sample from the set of base categories 'nKbase' number of base 257 | # categories. 258 | Kbase = sorted(self.sampleCategories('base', nKbase)) 259 | # sample from the set of novel categories 'nKnovel' number of novel 260 | # categories. 
261 | Knovel = sorted(self.sampleCategories('novel', nKnovel)) 262 | else: 263 | # sample from the set of base categories 'nKnovel' + 'nKbase' number 264 | # of categories. 265 | cats_ids = self.sampleCategories('base', nKnovel+nKbase) 266 | assert(len(cats_ids) == (nKnovel+nKbase)) 267 | # Randomly pick 'nKnovel' number of fake novel categories and keep 268 | # the rest as base categories. 269 | random.shuffle(cats_ids) 270 | Knovel = sorted(cats_ids[:nKnovel]) 271 | Kbase = sorted(cats_ids[nKnovel:]) 272 | 273 | return Kbase, Knovel 274 | 275 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase): 276 | """ 277 | Sample `nTestBase` number of images from the `Kbase` categories. 278 | 279 | Args: 280 | Kbase: a list of length `nKbase` with the ids of the categories from 281 | where the images will be sampled. 282 | nTestBase: the total number of images that will be sampled. 283 | 284 | Returns: 285 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st 286 | element of each tuple is the image id that was sampled and the 287 | 2nd elemend is its category label (which is in the range 288 | [0, len(Kbase)-1]). 289 | """ 290 | Tbase = [] 291 | if len(Kbase) > 0: 292 | # Sample for each base category a number images such that the total 293 | # number sampled images of all categories to be equal to `nTestBase`. 294 | KbaseIndices = np.random.choice( 295 | np.arange(len(Kbase)), size=nTestBase, replace=True) 296 | KbaseIndices, NumImagesPerCategory = np.unique( 297 | KbaseIndices, return_counts=True) 298 | 299 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): 300 | imd_ids = self.sampleImageIdsFrom( 301 | Kbase[Kbase_idx], sample_size=NumImages) 302 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] 303 | 304 | assert(len(Tbase) == nTestBase) 305 | 306 | return Tbase 307 | 308 | def sample_train_and_test_examples_for_novel_categories( 309 | self, Knovel, nTestNovel, nExemplars, nKbase): 310 | """Samples train and test examples of the novel categories. 311 | 312 | Args: 313 | Knovel: a list with the ids of the novel categories. 314 | nTestNovel: the total number of test images that will be sampled 315 | from all the novel categories. 316 | nExemplars: the number of training examples per novel category that 317 | will be sampled. 318 | nKbase: the number of base categories. It is used as offset of the 319 | category index of each sampled image. 320 | 321 | Returns: 322 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The 323 | 1st element of each tuple is the image id that was sampled and 324 | the 2nd element is its category label (which is in the range 325 | [nKbase, nKbase + len(Knovel) - 1]). 326 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element 327 | tuples. The 1st element of each tuple is the image id that was 328 | sampled and the 2nd element is its category label (which is in 329 | the ragne [nKbase, nKbase + len(Knovel) - 1]). 
330 | """ 331 | 332 | if len(Knovel) == 0: 333 | return [], [] 334 | 335 | nKnovel = len(Knovel) 336 | Tnovel = [] 337 | Exemplars = [] 338 | assert((nTestNovel % nKnovel) == 0) 339 | nEvalExamplesPerClass = int(nTestNovel / nKnovel) 340 | 341 | for Knovel_idx in range(len(Knovel)): 342 | imd_ids = self.sampleImageIdsFrom( 343 | Knovel[Knovel_idx], 344 | sample_size=(nEvalExamplesPerClass + nExemplars)) 345 | 346 | imds_tnovel = imd_ids[:nEvalExamplesPerClass] 347 | imds_ememplars = imd_ids[nEvalExamplesPerClass:] 348 | 349 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel] 350 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars] 351 | assert(len(Tnovel) == nTestNovel) 352 | assert(len(Exemplars) == len(Knovel) * nExemplars) 353 | random.shuffle(Exemplars) 354 | 355 | return Tnovel, Exemplars 356 | 357 | def sample_episode(self): 358 | """Samples a training episode.""" 359 | nKnovel = self.nKnovel 360 | nKbase = self.nKbase 361 | nTestNovel = self.nTestNovel 362 | nTestBase = self.nTestBase 363 | nExemplars = self.nExemplars 364 | 365 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) 366 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) 367 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( 368 | Knovel, nTestNovel, nExemplars, nKbase) 369 | 370 | # concatenate the base and novel category examples. 371 | Test = Tbase + Tnovel 372 | random.shuffle(Test) 373 | Kall = Kbase + Knovel 374 | 375 | return Exemplars, Test, Kall, nKbase 376 | 377 | def createExamplesTensorData(self, examples): 378 | """ 379 | Creates the examples image and label tensor data. 380 | 381 | Args: 382 | examples: a list of 2-element tuples, each representing a 383 | train or test example. The 1st element of each tuple 384 | is the image id of the example and 2nd element is the 385 | category label of the example, which is in the range 386 | [0, nK - 1], where nK is the total number of categories 387 | (both novel and base). 388 | 389 | Returns: 390 | images: a tensor of shape [nExamples, Height, Width, 3] with the 391 | example images, where nExamples is the number of examples 392 | (i.e., nExamples = len(examples)). 393 | labels: a tensor of shape [nExamples] with the category label 394 | of each example. 
395 | """ 396 | images = torch.stack( 397 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0) 398 | labels = torch.LongTensor([label for _, label in examples]) 399 | return images, labels 400 | 401 | def get_iterator(self, epoch=0): 402 | rand_seed = epoch 403 | random.seed(rand_seed) 404 | np.random.seed(rand_seed) 405 | def load_function(iter_idx): 406 | Exemplars, Test, Kall, nKbase = self.sample_episode() 407 | Xt, Yt = self.createExamplesTensorData(Test) 408 | Kall = torch.LongTensor(Kall) 409 | if len(Exemplars) > 0: 410 | Xe, Ye = self.createExamplesTensorData(Exemplars) 411 | return Xe, Ye, Xt, Yt, Kall, nKbase 412 | else: 413 | return Xt, Yt, Kall, nKbase 414 | 415 | tnt_dataset = tnt.dataset.ListDataset( 416 | elem_list=range(self.epoch_size), load=load_function) 417 | data_loader = tnt_dataset.parallel( 418 | batch_size=self.batch_size, 419 | num_workers=(0 if self.is_eval_mode else self.num_workers), 420 | shuffle=(False if self.is_eval_mode else True)) 421 | 422 | return data_loader 423 | 424 | def __call__(self, epoch=0): 425 | return self.get_iterator(epoch) 426 | 427 | def __len__(self): 428 | return int(self.epoch_size / self.batch_size) 429 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | # Dataloader of Gidaris & Komodakis, CVPR 2018 2 | # Adapted from: 3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path 8 | import numpy as np 9 | import random 10 | import pickle 11 | import json 12 | import math 13 | 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision 17 | import torchvision.datasets as datasets 18 | import torchvision.transforms as transforms 19 | import torchnet as tnt 20 | 21 | import h5py 22 | 23 | from PIL import Image 24 | from PIL import ImageEnhance 25 | 26 | from pdb import set_trace as breakpoint 27 | 28 | 29 | # Set the appropriate paths of the datasets here. 
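# (This is the line that step 3 of the README's installation instructions
# refers to: replace the path below with the directory where you
# decompressed the miniImageNet download.)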
30 | _MINI_IMAGENET_DATASET_DIR = '/efs/data/miniimagenet/kwonl/data/miniImageNet_numpy' 31 | 32 | def buildLabelIndex(labels): 33 | label2inds = {} 34 | for idx, label in enumerate(labels): 35 | if label not in label2inds: 36 | label2inds[label] = [] 37 | label2inds[label].append(idx) 38 | 39 | return label2inds 40 | 41 | 42 | def load_data(file): 43 | try: 44 | with open(file, 'rb') as fo: 45 | data = pickle.load(fo) 46 | return data 47 | except: 48 | with open(file, 'rb') as f: 49 | u = pickle._Unpickler(f) 50 | u.encoding = 'latin1' 51 | data = u.load() 52 | return data 53 | 54 | class MiniImageNet(data.Dataset): 55 | def __init__(self, phase='train', do_not_use_random_transf=False): 56 | 57 | self.base_folder = 'miniImagenet' 58 | assert(phase=='train' or phase=='val' or phase=='test') 59 | self.phase = phase 60 | self.name = 'MiniImageNet_' + phase 61 | 62 | print('Loading mini ImageNet dataset - phase {0}'.format(phase)) 63 | file_train_categories_train_phase = os.path.join( 64 | _MINI_IMAGENET_DATASET_DIR, 65 | 'miniImageNet_category_split_train_phase_train.pickle') 66 | file_train_categories_val_phase = os.path.join( 67 | _MINI_IMAGENET_DATASET_DIR, 68 | 'miniImageNet_category_split_train_phase_val.pickle') 69 | file_train_categories_test_phase = os.path.join( 70 | _MINI_IMAGENET_DATASET_DIR, 71 | 'miniImageNet_category_split_train_phase_test.pickle') 72 | file_val_categories_val_phase = os.path.join( 73 | _MINI_IMAGENET_DATASET_DIR, 74 | 'miniImageNet_category_split_val.pickle') 75 | file_test_categories_test_phase = os.path.join( 76 | _MINI_IMAGENET_DATASET_DIR, 77 | 'miniImageNet_category_split_test.pickle') 78 | 79 | if self.phase=='train': 80 | # During training phase we only load the training phase images 81 | # of the training categories (aka base categories). 82 | data_train = load_data(file_train_categories_train_phase) 83 | self.data = data_train['data'] 84 | self.labels = data_train['labels'] 85 | 86 | self.label2ind = buildLabelIndex(self.labels) 87 | self.labelIds = sorted(self.label2ind.keys()) 88 | self.num_cats = len(self.labelIds) 89 | self.labelIds_base = self.labelIds 90 | self.num_cats_base = len(self.labelIds_base) 91 | 92 | elif self.phase=='val' or self.phase=='test': 93 | if self.phase=='test': 94 | # load data that will be used for evaluating the recognition 95 | # accuracy of the base categories. 96 | data_base = load_data(file_train_categories_test_phase) 97 | # load data that will be use for evaluating the few-shot recogniton 98 | # accuracy on the novel categories. 99 | data_novel = load_data(file_test_categories_test_phase) 100 | else: # phase=='val' 101 | # load data that will be used for evaluating the recognition 102 | # accuracy of the base categories. 103 | data_base = load_data(file_train_categories_val_phase) 104 | # load data that will be use for evaluating the few-shot recogniton 105 | # accuracy on the novel categories. 
106 | data_novel = load_data(file_val_categories_val_phase) 107 | 108 | self.data = np.concatenate( 109 | [data_base['data'], data_novel['data']], axis=0) 110 | self.labels = data_base['labels'] + data_novel['labels'] 111 | 112 | self.label2ind = buildLabelIndex(self.labels) 113 | self.labelIds = sorted(self.label2ind.keys()) 114 | self.num_cats = len(self.labelIds) 115 | 116 | self.labelIds_base = buildLabelIndex(data_base['labels']).keys() 117 | self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys() 118 | self.num_cats_base = len(self.labelIds_base) 119 | self.num_cats_novel = len(self.labelIds_novel) 120 | intersection = set(self.labelIds_base) & set(self.labelIds_novel) 121 | assert(len(intersection) == 0) 122 | else: 123 | raise ValueError('Not valid phase {0}'.format(self.phase)) 124 | 125 | mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]] 126 | std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]] 127 | normalize = transforms.Normalize(mean=mean_pix, std=std_pix) 128 | 129 | if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True): 130 | self.transform = transforms.Compose([ 131 | lambda x: np.asarray(x), 132 | transforms.ToTensor(), 133 | normalize 134 | ]) 135 | else: 136 | self.transform = transforms.Compose([ 137 | transforms.RandomCrop(84, padding=8), 138 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 139 | transforms.RandomHorizontalFlip(), 140 | lambda x: np.asarray(x), 141 | transforms.ToTensor(), 142 | normalize 143 | ]) 144 | 145 | def __getitem__(self, index): 146 | img, label = self.data[index], self.labels[index] 147 | # doing this so that it is consistent with all other datasets 148 | # to return a PIL Image 149 | img = Image.fromarray(img) 150 | if self.transform is not None: 151 | img = self.transform(img) 152 | return img, label 153 | 154 | def __len__(self): 155 | return len(self.data) 156 | 157 | 158 | class FewShotDataloader(): 159 | def __init__(self, 160 | dataset, 161 | nKnovel=5, # number of novel categories. 162 | nKbase=-1, # number of base categories. 163 | nExemplars=1, # number of training examples per novel category. 164 | nTestNovel=15*5, # number of test examples for all the novel categories. 165 | nTestBase=15*5, # number of test examples for all the base categories. 166 | batch_size=1, # number of training episodes per batch. 167 | num_workers=4, 168 | epoch_size=2000, # number of batches per epoch. 
169 | ): 170 | 171 | self.dataset = dataset 172 | self.phase = self.dataset.phase 173 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' 174 | else self.dataset.num_cats_novel) 175 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel) 176 | self.nKnovel = nKnovel 177 | 178 | max_possible_nKbase = self.dataset.num_cats_base 179 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase 180 | if self.phase=='train' and nKbase > 0: 181 | nKbase -= self.nKnovel 182 | max_possible_nKbase -= self.nKnovel 183 | 184 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase) 185 | self.nKbase = nKbase 186 | 187 | self.nExemplars = nExemplars 188 | self.nTestNovel = nTestNovel 189 | self.nTestBase = nTestBase 190 | self.batch_size = batch_size 191 | self.epoch_size = epoch_size 192 | self.num_workers = num_workers 193 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val') 194 | 195 | def sampleImageIdsFrom(self, cat_id, sample_size=1): 196 | """ 197 | Samples `sample_size` number of unique image ids picked from the 198 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]). 199 | 200 | Args: 201 | cat_id: a scalar with the id of the category from which images will 202 | be sampled. 203 | sample_size: number of images that will be sampled. 204 | 205 | Returns: 206 | image_ids: a list of length `sample_size` with unique image ids. 207 | """ 208 | assert(cat_id in self.dataset.label2ind) 209 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size) 210 | # Note: random.sample samples elements without replacement. 211 | return random.sample(self.dataset.label2ind[cat_id], sample_size) 212 | 213 | def sampleCategories(self, cat_set, sample_size=1): 214 | """ 215 | Samples `sample_size` number of unique categories picked from the 216 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. 217 | 218 | Args: 219 | cat_set: string that specifies the set of categories from which 220 | categories will be sampled. 221 | sample_size: number of categories that will be sampled. 222 | 223 | Returns: 224 | cat_ids: a list of length `sample_size` with unique category ids. 225 | """ 226 | if cat_set=='base': 227 | labelIds = self.dataset.labelIds_base 228 | elif cat_set=='novel': 229 | labelIds = self.dataset.labelIds_novel 230 | else: 231 | raise ValueError('Not recognized category set {}'.format(cat_set)) 232 | 233 | assert(len(labelIds) >= sample_size) 234 | # return sample_size unique categories chosen from labelIds set of 235 | # categories (that can be either self.labelIds_base or self.labelIds_novel) 236 | # Note: random.sample samples elements without replacement. 237 | return random.sample(labelIds, sample_size) 238 | 239 | def sample_base_and_novel_categories(self, nKbase, nKnovel): 240 | """ 241 | Samples `nKbase` number of base categories and `nKnovel` number of novel 242 | categories. 243 | 244 | Args: 245 | nKbase: number of base categories 246 | nKnovel: number of novel categories 247 | 248 | Returns: 249 | Kbase: a list of length 'nKbase' with the ids of the sampled base 250 | categories. 251 | Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel 252 | categories. 253 | """ 254 | if self.is_eval_mode: 255 | assert(nKnovel <= self.dataset.num_cats_novel) 256 | # sample from the set of base categories 'nKbase' number of base 257 | # categories. 258 | Kbase = sorted(self.sampleCategories('base', nKbase)) 259 | # sample from the set of novel categories 'nKnovel' number of novel 260 | # categories. 
261 | Knovel = sorted(self.sampleCategories('novel', nKnovel)) 262 | else: 263 | # sample from the set of base categories 'nKnovel' + 'nKbase' number 264 | # of categories. 265 | cats_ids = self.sampleCategories('base', nKnovel+nKbase) 266 | assert(len(cats_ids) == (nKnovel+nKbase)) 267 | # Randomly pick 'nKnovel' number of fake novel categories and keep 268 | # the rest as base categories. 269 | random.shuffle(cats_ids) 270 | Knovel = sorted(cats_ids[:nKnovel]) 271 | Kbase = sorted(cats_ids[nKnovel:]) 272 | 273 | return Kbase, Knovel 274 | 275 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase): 276 | """ 277 | Sample `nTestBase` number of images from the `Kbase` categories. 278 | 279 | Args: 280 | Kbase: a list of length `nKbase` with the ids of the categories from 281 | where the images will be sampled. 282 | nTestBase: the total number of images that will be sampled. 283 | 284 | Returns: 285 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st 286 | element of each tuple is the image id that was sampled and the 287 | 2nd elemend is its category label (which is in the range 288 | [0, len(Kbase)-1]). 289 | """ 290 | Tbase = [] 291 | if len(Kbase) > 0: 292 | # Sample for each base category a number images such that the total 293 | # number sampled images of all categories to be equal to `nTestBase`. 294 | KbaseIndices = np.random.choice( 295 | np.arange(len(Kbase)), size=nTestBase, replace=True) 296 | KbaseIndices, NumImagesPerCategory = np.unique( 297 | KbaseIndices, return_counts=True) 298 | 299 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): 300 | imd_ids = self.sampleImageIdsFrom( 301 | Kbase[Kbase_idx], sample_size=NumImages) 302 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] 303 | 304 | assert(len(Tbase) == nTestBase) 305 | 306 | return Tbase 307 | 308 | def sample_train_and_test_examples_for_novel_categories( 309 | self, Knovel, nTestNovel, nExemplars, nKbase): 310 | """Samples train and test examples of the novel categories. 311 | 312 | Args: 313 | Knovel: a list with the ids of the novel categories. 314 | nTestNovel: the total number of test images that will be sampled 315 | from all the novel categories. 316 | nExemplars: the number of training examples per novel category that 317 | will be sampled. 318 | nKbase: the number of base categories. It is used as offset of the 319 | category index of each sampled image. 320 | 321 | Returns: 322 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The 323 | 1st element of each tuple is the image id that was sampled and 324 | the 2nd element is its category label (which is in the range 325 | [nKbase, nKbase + len(Knovel) - 1]). 326 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element 327 | tuples. The 1st element of each tuple is the image id that was 328 | sampled and the 2nd element is its category label (which is in 329 | the ragne [nKbase, nKbase + len(Knovel) - 1]). 
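        Example (illustrative only, not executed): with Knovel=[70, 71],
        nTestNovel=4, nExemplars=1 and nKbase=59, each novel class contributes
        4/2 = 2 test images and 1 exemplar, so Tnovel holds 4 tuples labeled
        59 or 60 and Exemplars holds 2 tuples, shuffled before being returned.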
330 | """ 331 | 332 | if len(Knovel) == 0: 333 | return [], [] 334 | 335 | nKnovel = len(Knovel) 336 | Tnovel = [] 337 | Exemplars = [] 338 | assert((nTestNovel % nKnovel) == 0) 339 | nEvalExamplesPerClass = int(nTestNovel / nKnovel) 340 | 341 | for Knovel_idx in range(len(Knovel)): 342 | imd_ids = self.sampleImageIdsFrom( 343 | Knovel[Knovel_idx], 344 | sample_size=(nEvalExamplesPerClass + nExemplars)) 345 | 346 | imds_tnovel = imd_ids[:nEvalExamplesPerClass] 347 | imds_ememplars = imd_ids[nEvalExamplesPerClass:] 348 | 349 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel] 350 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars] 351 | assert(len(Tnovel) == nTestNovel) 352 | assert(len(Exemplars) == len(Knovel) * nExemplars) 353 | random.shuffle(Exemplars) 354 | 355 | return Tnovel, Exemplars 356 | 357 | def sample_episode(self): 358 | """Samples a training episode.""" 359 | nKnovel = self.nKnovel 360 | nKbase = self.nKbase 361 | nTestNovel = self.nTestNovel 362 | nTestBase = self.nTestBase 363 | nExemplars = self.nExemplars 364 | 365 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) 366 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) 367 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( 368 | Knovel, nTestNovel, nExemplars, nKbase) 369 | 370 | # concatenate the base and novel category examples. 371 | Test = Tbase + Tnovel 372 | random.shuffle(Test) 373 | Kall = Kbase + Knovel 374 | 375 | return Exemplars, Test, Kall, nKbase 376 | 377 | def createExamplesTensorData(self, examples): 378 | """ 379 | Creates the examples image and label tensor data. 380 | 381 | Args: 382 | examples: a list of 2-element tuples, each representing a 383 | train or test example. The 1st element of each tuple 384 | is the image id of the example and 2nd element is the 385 | category label of the example, which is in the range 386 | [0, nK - 1], where nK is the total number of categories 387 | (both novel and base). 388 | 389 | Returns: 390 | images: a tensor of shape [nExamples, Height, Width, 3] with the 391 | example images, where nExamples is the number of examples 392 | (i.e., nExamples = len(examples)). 393 | labels: a tensor of shape [nExamples] with the category label 394 | of each example. 
395 | """ 396 | images = torch.stack( 397 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0) 398 | labels = torch.LongTensor([label for _, label in examples]) 399 | return images, labels 400 | 401 | def get_iterator(self, epoch=0): 402 | rand_seed = epoch 403 | random.seed(rand_seed) 404 | np.random.seed(rand_seed) 405 | def load_function(iter_idx): 406 | Exemplars, Test, Kall, nKbase = self.sample_episode() 407 | Xt, Yt = self.createExamplesTensorData(Test) 408 | Kall = torch.LongTensor(Kall) 409 | if len(Exemplars) > 0: 410 | Xe, Ye = self.createExamplesTensorData(Exemplars) 411 | return Xe, Ye, Xt, Yt, Kall, nKbase 412 | else: 413 | return Xt, Yt, Kall, nKbase 414 | 415 | tnt_dataset = tnt.dataset.ListDataset( 416 | elem_list=range(self.epoch_size), load=load_function) 417 | data_loader = tnt_dataset.parallel( 418 | batch_size=self.batch_size, 419 | num_workers=(0 if self.is_eval_mode else self.num_workers), 420 | shuffle=(False if self.is_eval_mode else True)) 421 | 422 | return data_loader 423 | 424 | def __call__(self, epoch=0): 425 | return self.get_iterator(epoch) 426 | 427 | def __len__(self): 428 | return int(self.epoch_size / self.batch_size) 429 | -------------------------------------------------------------------------------- /data/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | # Dataloader of Gidaris & Komodakis, CVPR 2018 2 | # Adapted from: 3 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py 4 | from __future__ import print_function 5 | 6 | import os 7 | import os.path 8 | import numpy as np 9 | import random 10 | import pickle 11 | import json 12 | import math 13 | 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision 17 | import torchvision.datasets as datasets 18 | import torchvision.transforms as transforms 19 | import torchnet as tnt 20 | 21 | import h5py 22 | 23 | from PIL import Image 24 | from PIL import ImageEnhance 25 | 26 | from pdb import set_trace as breakpoint 27 | 28 | 29 | # Set the appropriate paths of the datasets here. 
30 | _TIERED_IMAGENET_DATASET_DIR = '/mnt/cube/datasets/tiered-imagenet/' 31 | 32 | def buildLabelIndex(labels): 33 | label2inds = {} 34 | for idx, label in enumerate(labels): 35 | if label not in label2inds: 36 | label2inds[label] = [] 37 | label2inds[label].append(idx) 38 | 39 | return label2inds 40 | 41 | 42 | def load_data(file): 43 | try: 44 | with open(file, 'rb') as fo: 45 | data = pickle.load(fo) 46 | return data 47 | except: 48 | with open(file, 'rb') as f: 49 | u = pickle._Unpickler(f) 50 | u.encoding = 'latin1' 51 | data = u.load() 52 | return data 53 | 54 | class tieredImageNet(data.Dataset): 55 | def __init__(self, phase='train', do_not_use_random_transf=False): 56 | 57 | assert(phase=='train' or phase=='val' or phase=='test') 58 | self.phase = phase 59 | self.name = 'tieredImageNet_' + phase 60 | 61 | print('Loading tiered ImageNet dataset - phase {0}'.format(phase)) 62 | file_train_categories_train_phase = os.path.join( 63 | _TIERED_IMAGENET_DATASET_DIR, 64 | 'train_images.npz') 65 | label_train_categories_train_phase = os.path.join( 66 | _TIERED_IMAGENET_DATASET_DIR, 67 | 'train_labels.pkl') 68 | file_train_categories_val_phase = os.path.join( 69 | _TIERED_IMAGENET_DATASET_DIR, 70 | 'train_images.npz') 71 | label_train_categories_val_phase = os.path.join( 72 | _TIERED_IMAGENET_DATASET_DIR, 73 | 'train_labels.pkl') 74 | file_train_categories_test_phase = os.path.join( 75 | _TIERED_IMAGENET_DATASET_DIR, 76 | 'train_images.npz') 77 | label_train_categories_test_phase = os.path.join( 78 | _TIERED_IMAGENET_DATASET_DIR, 79 | 'train_labels.pkl') 80 | 81 | file_val_categories_val_phase = os.path.join( 82 | _TIERED_IMAGENET_DATASET_DIR, 83 | 'val_images.npz') 84 | label_val_categories_val_phase = os.path.join( 85 | _TIERED_IMAGENET_DATASET_DIR, 86 | 'val_labels.pkl') 87 | file_test_categories_test_phase = os.path.join( 88 | _TIERED_IMAGENET_DATASET_DIR, 89 | 'test_images.npz') 90 | label_test_categories_test_phase = os.path.join( 91 | _TIERED_IMAGENET_DATASET_DIR, 92 | 'test_labels.pkl') 93 | 94 | if self.phase=='train': 95 | # During training phase we only load the training phase images 96 | # of the training categories (aka base categories). 97 | data_train = load_data(label_train_categories_train_phase) 98 | #self.data = data_train['data'] 99 | self.labels = data_train['labels'] 100 | self.data = np.load(file_train_categories_train_phase)['images']#np.array(load_data(file_train_categories_train_phase)) 101 | #self.labels = load_data(file_train_categories_train_phase)#data_train['labels'] 102 | 103 | self.label2ind = buildLabelIndex(self.labels) 104 | self.labelIds = sorted(self.label2ind.keys()) 105 | self.num_cats = len(self.labelIds) 106 | self.labelIds_base = self.labelIds 107 | self.num_cats_base = len(self.labelIds_base) 108 | 109 | elif self.phase=='val' or self.phase=='test': 110 | if self.phase=='test': 111 | # load data that will be used for evaluating the recognition 112 | # accuracy of the base categories. 113 | data_base = load_data(label_train_categories_test_phase) 114 | data_base_images = np.load(file_train_categories_test_phase)['images'] 115 | 116 | # load data that will be use for evaluating the few-shot recogniton 117 | # accuracy on the novel categories. 118 | data_novel = load_data(label_test_categories_test_phase) 119 | data_novel_images = np.load(file_test_categories_test_phase)['images'] 120 | else: # phase=='val' 121 | # load data that will be used for evaluating the recognition 122 | # accuracy of the base categories. 
123 |             data_base = load_data(label_train_categories_val_phase)
124 |             data_base_images = np.load(file_train_categories_val_phase)['images']
127 |             # load data that will be used for evaluating the few-shot recognition
128 |             # accuracy on the novel categories.
129 |             data_novel = load_data(label_val_categories_val_phase)
130 |             data_novel_images = np.load(file_val_categories_val_phase)['images']
131 |
132 |             self.data = np.concatenate(
133 |                 [data_base_images, data_novel_images], axis=0)
134 |             self.labels = data_base['labels'] + data_novel['labels']
135 |
136 |             self.label2ind = buildLabelIndex(self.labels)
137 |             self.labelIds = sorted(self.label2ind.keys())
138 |             self.num_cats = len(self.labelIds)
139 |
140 |             self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
141 |             self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
142 |             self.num_cats_base = len(self.labelIds_base)
143 |             self.num_cats_novel = len(self.labelIds_novel)
144 |             intersection = set(self.labelIds_base) & set(self.labelIds_novel)
146 |             assert(len(intersection) == 0)
147 |         else:
148 |             raise ValueError('Not valid phase {0}'.format(self.phase))
149 |
150 |         mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
151 |         std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
152 |         normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
153 |
154 |         if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
155 |             self.transform = transforms.Compose([
156 |                 lambda x: np.asarray(x),
157 |                 transforms.ToTensor(),
158 |                 normalize
159 |             ])
160 |         else:
161 |             self.transform = transforms.Compose([
162 |                 transforms.RandomCrop(84, padding=8),
163 |                 transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
164 |                 transforms.RandomHorizontalFlip(),
165 |                 lambda x: np.asarray(x),
166 |                 transforms.ToTensor(),
167 |                 normalize
168 |             ])
169 |
170 |     def __getitem__(self, index):
171 |         img, label = self.data[index], self.labels[index]
172 |         # doing this so that it is consistent with all other datasets
173 |         # to return a PIL Image
174 |         img = Image.fromarray(img)
175 |         if self.transform is not None:
176 |             img = self.transform(img)
177 |         return img, label
178 |
179 |     def __len__(self):
180 |         return len(self.data)
181 |
182 |
183 | class FewShotDataloader():
184 |     def __init__(self,
185 |                  dataset,
186 |                  nKnovel=5, # number of novel categories.
187 |                  nKbase=-1, # number of base categories.
188 |                  nExemplars=1, # number of training examples per novel category.
189 |                  nTestNovel=15*5, # number of test examples for all the novel categories.
190 |                  nTestBase=15*5, # number of test examples for all the base categories.
191 |                  batch_size=1, # number of training episodes per batch.
192 |                  num_workers=4,
193 |                  epoch_size=2000, # number of batches per epoch.
194 | ): 195 | 196 | self.dataset = dataset 197 | self.phase = self.dataset.phase 198 | max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' 199 | else self.dataset.num_cats_novel) 200 | assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel) 201 | self.nKnovel = nKnovel 202 | 203 | max_possible_nKbase = self.dataset.num_cats_base 204 | nKbase = nKbase if nKbase >= 0 else max_possible_nKbase 205 | if self.phase=='train' and nKbase > 0: 206 | nKbase -= self.nKnovel 207 | max_possible_nKbase -= self.nKnovel 208 | 209 | assert(nKbase >= 0 and nKbase <= max_possible_nKbase) 210 | self.nKbase = nKbase 211 | 212 | self.nExemplars = nExemplars 213 | self.nTestNovel = nTestNovel 214 | self.nTestBase = nTestBase 215 | self.batch_size = batch_size 216 | self.epoch_size = epoch_size 217 | self.num_workers = num_workers 218 | self.is_eval_mode = (self.phase=='test') or (self.phase=='val') 219 | 220 | def sampleImageIdsFrom(self, cat_id, sample_size=1): 221 | """ 222 | Samples `sample_size` number of unique image ids picked from the 223 | category `cat_id` (i.e., self.dataset.label2ind[cat_id]). 224 | 225 | Args: 226 | cat_id: a scalar with the id of the category from which images will 227 | be sampled. 228 | sample_size: number of images that will be sampled. 229 | 230 | Returns: 231 | image_ids: a list of length `sample_size` with unique image ids. 232 | """ 233 | assert(cat_id in self.dataset.label2ind) 234 | assert(len(self.dataset.label2ind[cat_id]) >= sample_size) 235 | # Note: random.sample samples elements without replacement. 236 | return random.sample(self.dataset.label2ind[cat_id], sample_size) 237 | 238 | def sampleCategories(self, cat_set, sample_size=1): 239 | """ 240 | Samples `sample_size` number of unique categories picked from the 241 | `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'. 242 | 243 | Args: 244 | cat_set: string that specifies the set of categories from which 245 | categories will be sampled. 246 | sample_size: number of categories that will be sampled. 247 | 248 | Returns: 249 | cat_ids: a list of length `sample_size` with unique category ids. 250 | """ 251 | if cat_set=='base': 252 | labelIds = self.dataset.labelIds_base 253 | elif cat_set=='novel': 254 | labelIds = self.dataset.labelIds_novel 255 | else: 256 | raise ValueError('Not recognized category set {}'.format(cat_set)) 257 | 258 | assert(len(labelIds) >= sample_size) 259 | # return sample_size unique categories chosen from labelIds set of 260 | # categories (that can be either self.labelIds_base or self.labelIds_novel) 261 | # Note: random.sample samples elements without replacement. 262 | return random.sample(labelIds, sample_size) 263 | 264 | def sample_base_and_novel_categories(self, nKbase, nKnovel): 265 | """ 266 | Samples `nKbase` number of base categories and `nKnovel` number of novel 267 | categories. 268 | 269 | Args: 270 | nKbase: number of base categories 271 | nKnovel: number of novel categories 272 | 273 | Returns: 274 | Kbase: a list of length 'nKbase' with the ids of the sampled base 275 | categories. 276 | Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel 277 | categories. 278 | """ 279 | if self.is_eval_mode: 280 | assert(nKnovel <= self.dataset.num_cats_novel) 281 | # sample from the set of base categories 'nKbase' number of base 282 | # categories. 283 | Kbase = sorted(self.sampleCategories('base', nKbase)) 284 | # sample from the set of novel categories 'nKnovel' number of novel 285 | # categories. 
286 | Knovel = sorted(self.sampleCategories('novel', nKnovel)) 287 | else: 288 | # sample from the set of base categories 'nKnovel' + 'nKbase' number 289 | # of categories. 290 | cats_ids = self.sampleCategories('base', nKnovel+nKbase) 291 | assert(len(cats_ids) == (nKnovel+nKbase)) 292 | # Randomly pick 'nKnovel' number of fake novel categories and keep 293 | # the rest as base categories. 294 | random.shuffle(cats_ids) 295 | Knovel = sorted(cats_ids[:nKnovel]) 296 | Kbase = sorted(cats_ids[nKnovel:]) 297 | 298 | return Kbase, Knovel 299 | 300 | def sample_test_examples_for_base_categories(self, Kbase, nTestBase): 301 | """ 302 | Sample `nTestBase` number of images from the `Kbase` categories. 303 | 304 | Args: 305 | Kbase: a list of length `nKbase` with the ids of the categories from 306 | where the images will be sampled. 307 | nTestBase: the total number of images that will be sampled. 308 | 309 | Returns: 310 | Tbase: a list of length `nTestBase` with 2-element tuples. The 1st 311 | element of each tuple is the image id that was sampled and the 312 | 2nd elemend is its category label (which is in the range 313 | [0, len(Kbase)-1]). 314 | """ 315 | Tbase = [] 316 | if len(Kbase) > 0: 317 | # Sample for each base category a number images such that the total 318 | # number sampled images of all categories to be equal to `nTestBase`. 319 | KbaseIndices = np.random.choice( 320 | np.arange(len(Kbase)), size=nTestBase, replace=True) 321 | KbaseIndices, NumImagesPerCategory = np.unique( 322 | KbaseIndices, return_counts=True) 323 | 324 | for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory): 325 | imd_ids = self.sampleImageIdsFrom( 326 | Kbase[Kbase_idx], sample_size=NumImages) 327 | Tbase += [(img_id, Kbase_idx) for img_id in imd_ids] 328 | 329 | assert(len(Tbase) == nTestBase) 330 | 331 | return Tbase 332 | 333 | def sample_train_and_test_examples_for_novel_categories( 334 | self, Knovel, nTestNovel, nExemplars, nKbase): 335 | """Samples train and test examples of the novel categories. 336 | 337 | Args: 338 | Knovel: a list with the ids of the novel categories. 339 | nTestNovel: the total number of test images that will be sampled 340 | from all the novel categories. 341 | nExemplars: the number of training examples per novel category that 342 | will be sampled. 343 | nKbase: the number of base categories. It is used as offset of the 344 | category index of each sampled image. 345 | 346 | Returns: 347 | Tnovel: a list of length `nTestNovel` with 2-element tuples. The 348 | 1st element of each tuple is the image id that was sampled and 349 | the 2nd element is its category label (which is in the range 350 | [nKbase, nKbase + len(Knovel) - 1]). 351 | Exemplars: a list of length len(Knovel) * nExemplars of 2-element 352 | tuples. The 1st element of each tuple is the image id that was 353 | sampled and the 2nd element is its category label (which is in 354 | the ragne [nKbase, nKbase + len(Knovel) - 1]). 
355 | """ 356 | 357 | if len(Knovel) == 0: 358 | return [], [] 359 | 360 | nKnovel = len(Knovel) 361 | Tnovel = [] 362 | Exemplars = [] 363 | assert((nTestNovel % nKnovel) == 0) 364 | nEvalExamplesPerClass = int(nTestNovel / nKnovel) 365 | 366 | for Knovel_idx in range(len(Knovel)): 367 | imd_ids = self.sampleImageIdsFrom( 368 | Knovel[Knovel_idx], 369 | sample_size=(nEvalExamplesPerClass + nExemplars)) 370 | 371 | imds_tnovel = imd_ids[:nEvalExamplesPerClass] 372 | imds_ememplars = imd_ids[nEvalExamplesPerClass:] 373 | 374 | Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel] 375 | Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars] 376 | assert(len(Tnovel) == nTestNovel) 377 | assert(len(Exemplars) == len(Knovel) * nExemplars) 378 | random.shuffle(Exemplars) 379 | 380 | return Tnovel, Exemplars 381 | 382 | def sample_episode(self): 383 | """Samples a training episode.""" 384 | nKnovel = self.nKnovel 385 | nKbase = self.nKbase 386 | nTestNovel = self.nTestNovel 387 | nTestBase = self.nTestBase 388 | nExemplars = self.nExemplars 389 | 390 | Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel) 391 | Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase) 392 | Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories( 393 | Knovel, nTestNovel, nExemplars, nKbase) 394 | 395 | # concatenate the base and novel category examples. 396 | Test = Tbase + Tnovel 397 | random.shuffle(Test) 398 | Kall = Kbase + Knovel 399 | 400 | return Exemplars, Test, Kall, nKbase 401 | 402 | def createExamplesTensorData(self, examples): 403 | """ 404 | Creates the examples image and label tensor data. 405 | 406 | Args: 407 | examples: a list of 2-element tuples, each representing a 408 | train or test example. The 1st element of each tuple 409 | is the image id of the example and 2nd element is the 410 | category label of the example, which is in the range 411 | [0, nK - 1], where nK is the total number of categories 412 | (both novel and base). 413 | 414 | Returns: 415 | images: a tensor of shape [nExamples, Height, Width, 3] with the 416 | example images, where nExamples is the number of examples 417 | (i.e., nExamples = len(examples)). 418 | labels: a tensor of shape [nExamples] with the category label 419 | of each example. 
420 | """ 421 | images = torch.stack( 422 | [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0) 423 | labels = torch.LongTensor([label for _, label in examples]) 424 | return images, labels 425 | 426 | def get_iterator(self, epoch=0): 427 | rand_seed = epoch 428 | random.seed(rand_seed) 429 | np.random.seed(rand_seed) 430 | def load_function(iter_idx): 431 | Exemplars, Test, Kall, nKbase = self.sample_episode() 432 | Xt, Yt = self.createExamplesTensorData(Test) 433 | Kall = torch.LongTensor(Kall) 434 | if len(Exemplars) > 0: 435 | Xe, Ye = self.createExamplesTensorData(Exemplars) 436 | return Xe, Ye, Xt, Yt, Kall, nKbase 437 | else: 438 | return Xt, Yt, Kall, nKbase 439 | 440 | tnt_dataset = tnt.dataset.ListDataset( 441 | elem_list=range(self.epoch_size), load=load_function) 442 | data_loader = tnt_dataset.parallel( 443 | batch_size=self.batch_size, 444 | num_workers=(0 if self.is_eval_mode else self.num_workers), 445 | shuffle=(False if self.is_eval_mode else True)) 446 | 447 | return data_loader 448 | 449 | def __call__(self, epoch=0): 450 | return self.get_iterator(epoch) 451 | 452 | def __len__(self): 453 | return int(self.epoch_size / self.batch_size) 454 | -------------------------------------------------------------------------------- /models/R2D2_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | # Embedding network used in Meta-learning with differentiable closed-form solvers 5 | # (Bertinetto et al., in submission to NIPS 2018). 6 | # They call the ridge rigressor version as "Ridge Regression Differentiable Discriminator (R2D2)." 7 | 8 | # Note that they use a peculiar ordering of functions, namely conv-BN-pooling-lrelu, 9 | # as opposed to the conventional one (conv-BN-lrelu-pooling). 10 | 11 | def R2D2_conv_block(in_channels, out_channels, retain_activation=True, keep_prob=1.0): 12 | block = nn.Sequential( 13 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 14 | nn.BatchNorm2d(out_channels), 15 | nn.MaxPool2d(2) 16 | ) 17 | if retain_activation: 18 | block.add_module("LeakyReLU", nn.LeakyReLU(0.1)) 19 | 20 | if keep_prob < 1.0: 21 | block.add_module("Dropout", nn.Dropout(p=1 - keep_prob, inplace=False)) 22 | 23 | return block 24 | 25 | class R2D2Embedding(nn.Module): 26 | def __init__(self, x_dim=3, h1_dim=96, h2_dim=192, h3_dim=384, z_dim=512, \ 27 | retain_last_activation=False): 28 | super(R2D2Embedding, self).__init__() 29 | 30 | self.block1 = R2D2_conv_block(x_dim, h1_dim) 31 | self.block2 = R2D2_conv_block(h1_dim, h2_dim) 32 | self.block3 = R2D2_conv_block(h2_dim, h3_dim, keep_prob=0.9) 33 | # In the last conv block, we disable activation function to boost the classification accuracy. 34 | # This trick was proposed by Gidaris et al. (CVPR 2018). 35 | # With this trick, the accuracy goes up from 50% to 51%. 36 | # Although the authors of R2D2 did not mention this trick in the paper, 37 | # we were unable to reproduce the result of Bertinetto et al. without resorting to this trick. 38 | self.block4 = R2D2_conv_block(h3_dim, z_dim, retain_activation=retain_last_activation, keep_prob=0.7) 39 | 40 | def forward(self, x): 41 | b1 = self.block1(x) 42 | b2 = self.block2(b1) 43 | b3 = self.block3(b2) 44 | b4 = self.block4(b3) 45 | # Flatten and concatenate the output of the 3rd and 4th conv blocks as proposed in R2D2 paper. 
46 | return torch.cat((b3.view(b3.size(0), -1), b4.view(b4.size(0), -1)), 1) -------------------------------------------------------------------------------- /models/ResNet12_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from models.dropblock import DropBlock 5 | 6 | # This ResNet network was designed following the practice of the following papers: 7 | # TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and 8 | # A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018). 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, block_size=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.LeakyReLU(0.1) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.conv3 = conv3x3(planes, planes) 27 | self.bn3 = nn.BatchNorm2d(planes) 28 | self.maxpool = nn.MaxPool2d(stride) 29 | self.downsample = downsample 30 | self.stride = stride 31 | self.drop_rate = drop_rate 32 | self.num_batches_tracked = 0 33 | self.drop_block = drop_block 34 | self.block_size = block_size 35 | self.DropBlock = DropBlock(block_size=self.block_size) 36 | 37 | def forward(self, x): 38 | self.num_batches_tracked += 1 39 | 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | out += residual 56 | out = self.relu(out) 57 | out = self.maxpool(out) 58 | 59 | if self.drop_rate > 0: 60 | if self.drop_block == True: 61 | feat_size = out.size()[2] 62 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 63 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 64 | out = self.DropBlock(out, gamma=gamma) 65 | else: 66 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 67 | 68 | return out 69 | 70 | 71 | class ResNet(nn.Module): 72 | 73 | def __init__(self, block, keep_prob=1.0, avg_pool=False, drop_rate=0.0, dropblock_size=5): 74 | self.inplanes = 3 75 | super(ResNet, self).__init__() 76 | 77 | self.layer1 = self._make_layer(block, 64, stride=2, drop_rate=drop_rate) 78 | self.layer2 = self._make_layer(block, 160, stride=2, drop_rate=drop_rate) 79 | self.layer3 = self._make_layer(block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 80 | self.layer4 = self._make_layer(block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 81 | if avg_pool: 82 | self.avgpool = nn.AvgPool2d(5, stride=1) 83 | self.keep_prob = keep_prob 84 | self.keep_avg_pool = avg_pool 85 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 86 | self.drop_rate = drop_rate 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_(m.weight, 
mode='fan_out', nonlinearity='leaky_relu') 91 | elif isinstance(m, nn.BatchNorm2d): 92 | nn.init.constant_(m.weight, 1) 93 | nn.init.constant_(m.bias, 0) 94 | 95 | def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 96 | downsample = None 97 | if stride != 1 or self.inplanes != planes * block.expansion: 98 | downsample = nn.Sequential( 99 | nn.Conv2d(self.inplanes, planes * block.expansion, 100 | kernel_size=1, stride=1, bias=False), 101 | nn.BatchNorm2d(planes * block.expansion), 102 | ) 103 | 104 | layers = [] 105 | layers.append(block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size)) 106 | self.inplanes = planes * block.expansion 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.layer1(x) 112 | x = self.layer2(x) 113 | x = self.layer3(x) 114 | x = self.layer4(x) 115 | if self.keep_avg_pool: 116 | x = self.avgpool(x) 117 | x = x.view(x.size(0), -1) 118 | return x 119 | 120 | 121 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs): 122 | """Constructs a ResNet-12 model. 123 | """ 124 | model = ResNet(BasicBlock, keep_prob=keep_prob, avg_pool=avg_pool, **kwargs) 125 | return model 126 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/classification_heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | from qpth.qp import QPFunction 8 | 9 | 10 | def computeGramMatrix(A, B): 11 | """ 12 | Constructs a linear kernel matrix between A and B. 13 | We assume that each row in A and B represents a d-dimensional feature vector. 14 | 15 | Parameters: 16 | A: a (n_batch, n, d) Tensor. 17 | B: a (n_batch, m, d) Tensor. 18 | Returns: a (n_batch, n, m) Tensor. 19 | """ 20 | 21 | assert(A.dim() == 3) 22 | assert(B.dim() == 3) 23 | assert(A.size(0) == B.size(0) and A.size(2) == B.size(2)) 24 | 25 | return torch.bmm(A, B.transpose(1,2)) 26 | 27 | 28 | def binv(b_mat): 29 | """ 30 | Computes an inverse of each matrix in the batch. 31 | Pytorch 0.4.1 does not support batched matrix inverse. 32 | Hence, we are solving AX=I. 33 | 34 | Parameters: 35 | b_mat: a (n_batch, n, n) Tensor. 36 | Returns: a (n_batch, n, n) Tensor. 37 | """ 38 | 39 | id_matrix = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat).cuda() 40 | b_inv, _ = torch.gesv(id_matrix, b_mat) 41 | 42 | return b_inv 43 | 44 | 45 | def one_hot(indices, depth): 46 | """ 47 | Returns a one-hot tensor. 48 | This is a PyTorch equivalent of Tensorflow's tf.one_hot. 49 | 50 | Parameters: 51 | indices: a (n_batch, m) Tensor or (m) Tensor. 52 | depth: a scalar. Represents the depth of the one hot dimension. 53 | Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor. 
54 | """ 55 | 56 | encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda() 57 | index = indices.view(indices.size()+torch.Size([1])) 58 | encoded_indicies = encoded_indicies.scatter_(1,index,1) 59 | 60 | return encoded_indicies 61 | 62 | def batched_kronecker(matrix1, matrix2): 63 | matrix1_flatten = matrix1.reshape(matrix1.size()[0], -1) 64 | matrix2_flatten = matrix2.reshape(matrix2.size()[0], -1) 65 | return torch.bmm(matrix1_flatten.unsqueeze(2), matrix2_flatten.unsqueeze(1)).reshape([matrix1.size()[0]] + list(matrix1.size()[1:]) + list(matrix2.size()[1:])).permute([0, 1, 3, 2, 4]).reshape(matrix1.size(0), matrix1.size(1) * matrix2.size(1), matrix1.size(2) * matrix2.size(2)) 66 | 67 | def MetaOptNetHead_Ridge(query, support, support_labels, n_way, n_shot, lambda_reg=50.0, double_precision=False): 68 | """ 69 | Fits the support set with ridge regression and 70 | returns the classification score on the query set. 71 | 72 | Parameters: 73 | query: a (tasks_per_batch, n_query, d) Tensor. 74 | support: a (tasks_per_batch, n_support, d) Tensor. 75 | support_labels: a (tasks_per_batch, n_support) Tensor. 76 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 77 | n_shot: a scalar. Represents the number of support examples given per class. 78 | lambda_reg: a scalar. Represents the strength of L2 regularization. 79 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 80 | """ 81 | 82 | tasks_per_batch = query.size(0) 83 | n_support = support.size(1) 84 | n_query = query.size(1) 85 | 86 | assert(query.dim() == 3) 87 | assert(support.dim() == 3) 88 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 89 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 90 | 91 | #Here we solve the dual problem: 92 | #Note that the classes are indexed by m & samples are indexed by i. 93 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i 94 | 95 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i, 96 | 97 | #\alpha is an (n_support, n_way) matrix 98 | kernel_matrix = computeGramMatrix(support, support) 99 | kernel_matrix += lambda_reg * torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda() 100 | 101 | block_kernel_matrix = kernel_matrix.repeat(n_way, 1, 1) #(n_way * tasks_per_batch, n_support, n_support) 102 | 103 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_way) 104 | support_labels_one_hot = support_labels_one_hot.transpose(0, 1) # (n_way, tasks_per_batch * n_support) 105 | support_labels_one_hot = support_labels_one_hot.reshape(n_way * tasks_per_batch, n_support) # (n_way*tasks_per_batch, n_support) 106 | 107 | G = block_kernel_matrix 108 | e = -2.0 * support_labels_one_hot 109 | 110 | #This is a fake inequlity constraint as qpth does not support QP without an inequality constraint. 111 | id_matrix_1 = torch.zeros(tasks_per_batch*n_way, n_support, n_support) 112 | C = Variable(id_matrix_1) 113 | h = Variable(torch.zeros((tasks_per_batch*n_way, n_support))) 114 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 
115 |
116 |     if double_precision:
117 |         G, e, C, h = [x.double().cuda() for x in [G, e, C, h]]
118 |
119 |     else:
120 |         G, e, C, h = [x.float().cuda() for x in [G, e, C, h]]
121 |
122 |     # Solve the following QP to fit the ridge regression:
123 |     #        \hat z =   argmin_z 1/2 z^T G z + e^T z
124 |     #                 subject to Cz <= h
125 |     # We use detach() to prevent backpropagation to fixed variables.
126 |     qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach())
128 |
129 |     #qp_sol (n_way*tasks_per_batch, n_support)
130 |     qp_sol = qp_sol.reshape(n_way, tasks_per_batch, n_support)
131 |     #qp_sol (n_way, tasks_per_batch, n_support)
132 |     qp_sol = qp_sol.permute(1, 2, 0)
133 |     #qp_sol (tasks_per_batch, n_support, n_way)
134 |
135 |     # Compute the classification score.
136 |     compatibility = computeGramMatrix(support, query)
137 |     compatibility = compatibility.float()
138 |     compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
139 |     qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
140 |     logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
141 |     logits = logits * compatibility
142 |     logits = torch.sum(logits, 1)
143 |
144 |     return logits
145 |
146 | def R2D2Head(query, support, support_labels, n_way, n_shot, l2_regularizer_lambda=50.0):
147 |     """
148 |     Fits the support set with ridge regression and
149 |     returns the classification score on the query set.
150 |
151 |     This model is the classification head described in:
152 |     Meta-learning with differentiable closed-form solvers
153 |     (Bertinetto et al., in submission to NIPS 2018).
154 |
155 |     Parameters:
156 |       query: a (tasks_per_batch, n_query, d) Tensor.
157 |       support: a (tasks_per_batch, n_support, d) Tensor.
158 |       support_labels: a (tasks_per_batch, n_support) Tensor.
159 |       n_way: a scalar. Represents the number of classes in a few-shot classification task.
160 |       n_shot: a scalar. Represents the number of support examples given per class.
161 |       l2_regularizer_lambda: a scalar. Represents the strength of L2 regularization.
162 |     Returns: a (tasks_per_batch, n_query, n_way) Tensor.
163 |     """
164 |
165 |     tasks_per_batch = query.size(0)
166 |     n_support = support.size(1)
167 |
168 |     assert(query.dim() == 3)
169 |     assert(support.dim() == 3)
170 |     assert(query.size(0) == support.size(0) and query.size(2) == support.size(2))
171 |     assert(n_support == n_way * n_shot)      # n_support must equal n_way * n_shot
172 |
173 |     support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way)
174 |     support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way)
175 |
176 |     id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
177 |
178 |     # Compute the dual form solution of the ridge regression.
179 |     # W = X^T(X X^T + lambda * I)^(-1) Y
180 |     ridge_sol = computeGramMatrix(support, support) + l2_regularizer_lambda * id_matrix
181 |     ridge_sol = binv(ridge_sol)
182 |     ridge_sol = torch.bmm(support.transpose(1,2), ridge_sol)
183 |     ridge_sol = torch.bmm(ridge_sol, support_labels_one_hot)
184 |
185 |     # Compute the classification score.
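    # (Why the dual form above: with embeddings of dimension d, the primal
    # solution W = (X^T X + lambda*I)^{-1} X^T Y inverts a (d x d) matrix,
    # whereas the dual only inverts an (n_support x n_support) one -- e.g.
    # 25 x 25 for a 5-way 5-shot episode versus d in the tens of thousands for
    # the R2D2 embedding. A one-line sketch of the equivalence, assuming the
    # inverses exist:
    #   X^T (X X^T + a*I)^{-1} = (X^T X + a*I)^{-1} X^T   [push-through identity])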
186 | # score = W X 187 | logits = torch.bmm(query, ridge_sol) 188 | 189 | return logits 190 | 191 | 192 | def MetaOptNetHead_SVM_He(query, support, support_labels, n_way, n_shot, C_reg=0.01, double_precision=False): 193 | """ 194 | Fits the support set with multi-class SVM and 195 | returns the classification score on the query set. 196 | 197 | This is the multi-class SVM presented in: 198 | A simplified multi-class support vector machine with reduced dual optimization 199 | (He et al., Pattern Recognition Letter 2012). 200 | 201 | This SVM is desirable because the dual variable of size is n_support 202 | (as opposed to n_way*n_support in the Weston&Watkins or Crammer&Singer multi-class SVM). 203 | This model is the classification head that we have initially used for our project. 204 | This was dropped since it turned out that it performs suboptimally on the meta-learning scenarios. 205 | 206 | Parameters: 207 | query: a (tasks_per_batch, n_query, d) Tensor. 208 | support: a (tasks_per_batch, n_support, d) Tensor. 209 | support_labels: a (tasks_per_batch, n_support) Tensor. 210 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 211 | n_shot: a scalar. Represents the number of support examples given per class. 212 | C_reg: a scalar. Represents the cost parameter C in SVM. 213 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 214 | """ 215 | 216 | tasks_per_batch = query.size(0) 217 | n_support = support.size(1) 218 | n_query = query.size(1) 219 | 220 | assert(query.dim() == 3) 221 | assert(support.dim() == 3) 222 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 223 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 224 | 225 | 226 | kernel_matrix = computeGramMatrix(support, support) 227 | 228 | V = (support_labels * n_way - torch.ones(tasks_per_batch, n_support, n_way).cuda()) / (n_way - 1) 229 | G = computeGramMatrix(V, V).detach() 230 | G = kernel_matrix * G 231 | 232 | e = Variable(-1.0 * torch.ones(tasks_per_batch, n_support)) 233 | id_matrix = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support) 234 | C = Variable(torch.cat((id_matrix, -id_matrix), 1)) 235 | h = Variable(torch.cat((C_reg * torch.ones(tasks_per_batch, n_support), torch.zeros(tasks_per_batch, n_support)), 1)) 236 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 237 | 238 | if double_precision: 239 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]] 240 | else: 241 | G, e, C, h = [x.cuda() for x in [G, e, C, h]] 242 | 243 | # Solve the following QP to fit SVM: 244 | # \hat z = argmin_z 1/2 z^T G z + e^T z 245 | # subject to Cz <= h 246 | # We use detach() to prevent backpropagation to fixed variables. 247 | qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach()) 248 | 249 | # Compute the classification score. 
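    # A hedged reading of the scoring below: each query's logit for class m is
    # assembled from the dual solution as roughly sum_i alpha_i * k(x_i, q),
    # with the view(..., n_shot, n_way) split assuming the support examples are
    # ordered way-fastest (the class index cycling within each shot). That
    # ordering is an inference from the code as written, not asserted anywhere
    # above.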
250 | compatibility = computeGramMatrix(query, support) 251 | compatibility = compatibility.float() 252 | 253 | logits = qp_sol.float().unsqueeze(1).expand(tasks_per_batch, n_query, n_support) 254 | logits = logits * compatibility 255 | logits = logits.view(tasks_per_batch, n_query, n_shot, n_way) 256 | logits = torch.sum(logits, 2) 257 | 258 | return logits 259 | 260 | def ProtoNetHead(query, support, support_labels, n_way, n_shot, normalize=True): 261 | """ 262 | Constructs the prototype representation of each class(=mean of support vectors of each class) and 263 | returns the classification score (=L2 distance to each class prototype) on the query set. 264 | 265 | This model is the classification head described in: 266 | Prototypical Networks for Few-shot Learning 267 | (Snell et al., NIPS 2017). 268 | 269 | Parameters: 270 | query: a (tasks_per_batch, n_query, d) Tensor. 271 | support: a (tasks_per_batch, n_support, d) Tensor. 272 | support_labels: a (tasks_per_batch, n_support) Tensor. 273 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 274 | n_shot: a scalar. Represents the number of support examples given per class. 275 | normalize: a boolean. Represents whether if we want to normalize the distances by the embedding dimension. 276 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 277 | """ 278 | 279 | tasks_per_batch = query.size(0) 280 | n_support = support.size(1) 281 | n_query = query.size(1) 282 | d = query.size(2) 283 | 284 | assert(query.dim() == 3) 285 | assert(support.dim() == 3) 286 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 287 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 288 | 289 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 290 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 291 | 292 | # From: 293 | # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/architectures/PrototypicalNetworksHead.py 294 | #************************* Compute Prototypes ************************** 295 | labels_train_transposed = support_labels_one_hot.transpose(1,2) 296 | # Batch matrix multiplication: 297 | # prototypes = labels_train_transposed * features_train ==> 298 | # [batch_size x nKnovel x num_channels] = 299 | # [batch_size x nKnovel x num_train_examples] * [batch_size * num_train_examples * num_channels] 300 | prototypes = torch.bmm(labels_train_transposed, support) 301 | # Divide with the number of examples per novel category. 302 | prototypes = prototypes.div( 303 | labels_train_transposed.sum(dim=2, keepdim=True).expand_as(prototypes) 304 | ) 305 | 306 | # Distance Matrix Vectorization Trick 307 | AB = computeGramMatrix(query, prototypes) 308 | AA = (query * query).sum(dim=2, keepdim=True) 309 | BB = (prototypes * prototypes).sum(dim=2, keepdim=True).reshape(tasks_per_batch, 1, n_way) 310 | logits = AA.expand_as(AB) - 2 * AB + BB.expand_as(AB) 311 | logits = -logits 312 | 313 | if normalize: 314 | logits = logits / d 315 | 316 | return logits 317 | 318 | def MetaOptNetHead_SVM_CS(query, support, support_labels, n_way, n_shot, C_reg=0.1, double_precision=False, maxIter=15): 319 | """ 320 | Fits the support set with multi-class SVM and 321 | returns the classification score on the query set. 
322 | 323 | This is the multi-class SVM presented in: 324 | On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines 325 | (Crammer and Singer, Journal of Machine Learning Research 2001). 326 | 327 | This model is the classification head that we use for the final version. 328 | Parameters: 329 | query: a (tasks_per_batch, n_query, d) Tensor. 330 | support: a (tasks_per_batch, n_support, d) Tensor. 331 | support_labels: a (tasks_per_batch, n_support) Tensor. 332 | n_way: a scalar. Represents the number of classes in a few-shot classification task. 333 | n_shot: a scalar. Represents the number of support examples given per class. 334 | C_reg: a scalar. Represents the cost parameter C in SVM. 335 | Returns: a (tasks_per_batch, n_query, n_way) Tensor. 336 | """ 337 | 338 | tasks_per_batch = query.size(0) 339 | n_support = support.size(1) 340 | n_query = query.size(1) 341 | 342 | assert(query.dim() == 3) 343 | assert(support.dim() == 3) 344 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 345 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 346 | 347 | #Here we solve the dual problem: 348 | #Note that the classes are indexed by m & samples are indexed by i. 349 | #min_{\alpha} 0.5 \sum_m ||w_m(\alpha)||^2 + \sum_i \sum_m e^m_i alpha^m_i 350 | #s.t. \alpha^m_i <= C^m_i \forall m,i , \sum_m \alpha^m_i=0 \forall i 351 | 352 | #where w_m(\alpha) = \sum_i \alpha^m_i x_i, 353 | #and C^m_i = C if m = y_i, 354 | #C^m_i = 0 if m != y_i. 355 | #This borrows the notation of liblinear. 356 | 357 | #\alpha is an (n_support, n_way) matrix 358 | kernel_matrix = computeGramMatrix(support, support) 359 | 360 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda() 361 | block_kernel_matrix = batched_kronecker(kernel_matrix, id_matrix_0) 362 | #This seems to help avoid PSD error from the QP solver. 363 | block_kernel_matrix += 1.0 * torch.eye(n_way*n_support).expand(tasks_per_batch, n_way*n_support, n_way*n_support).cuda() 364 | 365 | support_labels_one_hot = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) # (tasks_per_batch * n_support, n_support) 366 | support_labels_one_hot = support_labels_one_hot.view(tasks_per_batch, n_support, n_way) 367 | support_labels_one_hot = support_labels_one_hot.reshape(tasks_per_batch, n_support * n_way) 368 | 369 | G = block_kernel_matrix 370 | e = -1.0 * support_labels_one_hot 371 | #print (G.size()) 372 | #This part is for the inequality constraints: 373 | #\alpha^m_i <= C^m_i \forall m,i 374 | #where C^m_i = C if m = y_i, 375 | #C^m_i = 0 if m != y_i. 
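    # A hedged shape walk-through for one 2-way, 1-shot task (so n_way = 2,
    # n_support = 2, and the flattened dual variable z has n_way*n_support = 4
    # entries, ordered way-fastest because of the kronecker products used here):
    #   G: (1, 4, 4)  the block kernel matrix built above
    #   C: (1, 4, 4)  identity, giving the elementwise constraint z <= h
    #   h: (1, 4)     C_reg where m == y_i and 0 elsewhere (built just below)
    #   A: (1, 2, 4)  one row per support example i enforcing sum_m z^m_i = 0
    #   b: (1, 2)     zeros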
376 |     id_matrix_1 = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support)
377 |     C = Variable(id_matrix_1)
378 |     h = Variable(C_reg * support_labels_one_hot)
379 |     #print (C.size(), h.size())
380 |     #This part is for the equality constraints:
381 |     #\sum_m \alpha^m_i=0 \forall i
382 |     id_matrix_2 = torch.eye(n_support).expand(tasks_per_batch, n_support, n_support).cuda()
383 |
384 |     A = Variable(batched_kronecker(id_matrix_2, torch.ones(tasks_per_batch, 1, n_way).cuda()))
385 |     b = Variable(torch.zeros(tasks_per_batch, n_support))
386 |     #print (A.size(), b.size())
387 |     if double_precision:
388 |         G, e, C, h, A, b = [x.double().cuda() for x in [G, e, C, h, A, b]]
389 |     else:
390 |         G, e, C, h, A, b = [x.float().cuda() for x in [G, e, C, h, A, b]]
391 |
392 |     # Solve the following QP to fit SVM:
393 |     #        \hat z =   argmin_z 1/2 z^T G z + e^T z
394 |     #                 subject to Cz <= h
395 |     # We use detach() to prevent backpropagation to fixed variables.
396 |     qp_sol = QPFunction(verbose=False, maxIter=maxIter)(G, e.detach(), C.detach(), h.detach(), A.detach(), b.detach())
397 |
398 |     # Compute the classification score.
399 |     compatibility = computeGramMatrix(support, query)
400 |     compatibility = compatibility.float()
401 |     compatibility = compatibility.unsqueeze(3).expand(tasks_per_batch, n_support, n_query, n_way)
402 |     qp_sol = qp_sol.reshape(tasks_per_batch, n_support, n_way)
403 |     logits = qp_sol.float().unsqueeze(2).expand(tasks_per_batch, n_support, n_query, n_way)
404 |     logits = logits * compatibility
405 |     logits = torch.sum(logits, 1)
406 |
407 |     return logits
408 |
409 | def MetaOptNetHead_SVM_WW(query, support, support_labels, n_way, n_shot, C_reg=0.00001, double_precision=False):
410 |     """
411 |     Fits the support set with multi-class SVM and
412 |     returns the classification score on the query set.
413 |
414 |     This is the multi-class SVM presented in:
415 |     Support Vector Machines for Multi Class Pattern Recognition
416 |     (Weston and Watkins, ESANN 1999).
417 |
418 |     Parameters:
419 |       query: a (tasks_per_batch, n_query, d) Tensor.
420 |       support: a (tasks_per_batch, n_support, d) Tensor.
421 |       support_labels: a (tasks_per_batch, n_support) Tensor.
422 |       n_way: a scalar. Represents the number of classes in a few-shot classification task.
423 |       n_shot: a scalar. Represents the number of support examples given per class.
424 |       C_reg: a scalar. Represents the cost parameter C in SVM.
425 |     Returns: a (tasks_per_batch, n_query, n_way) Tensor.
426 |     """
443 | """ 444 | tasks_per_batch = query.size(0) 445 | n_support = support.size(1) 446 | n_query = query.size(1) 447 | 448 | assert(query.dim() == 3) 449 | assert(support.dim() == 3) 450 | assert(query.size(0) == support.size(0) and query.size(2) == support.size(2)) 451 | assert(n_support == n_way * n_shot) # n_support must equal to n_way * n_shot 452 | 453 | #In theory, \alpha is an (n_support, n_way) matrix 454 | #NOTE: In this implementation, we solve for a flattened vector of size (n_way*n_support) 455 | #In order to turn it into a matrix, you must first reshape it into an (n_way, n_support) matrix 456 | #then transpose it, resulting in (n_support, n_way) matrix 457 | kernel_matrix = computeGramMatrix(support, support) + torch.ones(tasks_per_batch, n_support, n_support).cuda() 458 | 459 | id_matrix_0 = torch.eye(n_way).expand(tasks_per_batch, n_way, n_way).cuda() 460 | block_kernel_matrix = batched_kronecker(id_matrix_0, kernel_matrix) 461 | 462 | kernel_matrix_mask_x = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support) 463 | kernel_matrix_mask_y = support_labels.reshape(tasks_per_batch, 1, n_support).expand(tasks_per_batch, n_support, n_support) 464 | kernel_matrix_mask = (kernel_matrix_mask_x == kernel_matrix_mask_y).float() 465 | 466 | block_kernel_matrix_inter = kernel_matrix_mask * kernel_matrix 467 | block_kernel_matrix += block_kernel_matrix_inter.repeat(1, n_way, n_way) 468 | 469 | kernel_matrix_mask_second_term = support_labels.reshape(tasks_per_batch, n_support, 1).expand(tasks_per_batch, n_support, n_support * n_way) 470 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term == torch.arange(n_way).long().repeat(n_support).reshape(n_support, n_way).transpose(1, 0).reshape(1, -1).repeat(n_support, 1).cuda() 471 | kernel_matrix_mask_second_term = kernel_matrix_mask_second_term.float() 472 | 473 | block_kernel_matrix -= (2.0 - 1e-4) * (kernel_matrix_mask_second_term * kernel_matrix.repeat(1, 1, n_way)).repeat(1, n_way, 1) 474 | 475 | Y_support = one_hot(support_labels.view(tasks_per_batch * n_support), n_way) 476 | Y_support = Y_support.view(tasks_per_batch, n_support, n_way) 477 | Y_support = Y_support.transpose(1, 2) # (tasks_per_batch, n_way, n_support) 478 | Y_support = Y_support.reshape(tasks_per_batch, n_way * n_support) 479 | 480 | G = block_kernel_matrix 481 | 482 | e = -2.0 * torch.ones(tasks_per_batch, n_way * n_support) 483 | id_matrix = torch.eye(n_way * n_support).expand(tasks_per_batch, n_way * n_support, n_way * n_support) 484 | 485 | C_mat = C_reg * torch.ones(tasks_per_batch, n_way * n_support).cuda() - C_reg * Y_support 486 | 487 | C = Variable(torch.cat((id_matrix, -id_matrix), 1)) 488 | #C = Variable(torch.cat((id_matrix_masked, -id_matrix_masked), 1)) 489 | zer = torch.zeros(tasks_per_batch, n_way * n_support).cuda() 490 | 491 | h = Variable(torch.cat((C_mat, zer), 1)) 492 | 493 | dummy = Variable(torch.Tensor()).cuda() # We want to ignore the equality constraint. 494 | 495 | if double_precision: 496 | G, e, C, h = [x.double().cuda() for x in [G, e, C, h]] 497 | else: 498 | G, e, C, h = [x.cuda() for x in [G, e, C, h]] 499 | 500 | # Solve the following QP to fit SVM: 501 | # \hat z = argmin_z 1/2 z^T G z + e^T z 502 | # subject to Cz <= h 503 | # We use detach() to prevent backpropagation to fixed variables. 
504 | #qp_sol = QPFunction(verbose=False)(G, e.detach(), C.detach(), h.detach(), dummy.detach(), dummy.detach()) 505 | qp_sol = QPFunction(verbose=False)(G, e, C, h, dummy.detach(), dummy.detach()) 506 | 507 | # Compute the classification score. 508 | compatibility = computeGramMatrix(support, query) + torch.ones(tasks_per_batch, n_support, n_query).cuda() 509 | compatibility = compatibility.float() 510 | compatibility = compatibility.unsqueeze(1).expand(tasks_per_batch, n_way, n_support, n_query) 511 | qp_sol = qp_sol.float() 512 | qp_sol = qp_sol.reshape(tasks_per_batch, n_way, n_support) 513 | A_i = torch.sum(qp_sol, 1) # (tasks_per_batch, n_support) 514 | A_i = A_i.unsqueeze(1).expand(tasks_per_batch, n_way, n_support) 515 | qp_sol = qp_sol.float().unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query) 516 | Y_support_reshaped = Y_support.reshape(tasks_per_batch, n_way, n_support) 517 | Y_support_reshaped = A_i * Y_support_reshaped 518 | Y_support_reshaped = Y_support_reshaped.unsqueeze(3).expand(tasks_per_batch, n_way, n_support, n_query) 519 | logits = (Y_support_reshaped - qp_sol) * compatibility 520 | 521 | logits = torch.sum(logits, 2) 522 | 523 | return logits.transpose(1, 2) 524 | 525 | class ClassificationHead(nn.Module): 526 | def __init__(self, base_learner='MetaOptNet', enable_scale=True): 527 | super(ClassificationHead, self).__init__() 528 | if ('SVM-CS' in base_learner): 529 | self.head = MetaOptNetHead_SVM_CS 530 | elif ('Ridge' in base_learner): 531 | self.head = MetaOptNetHead_Ridge 532 | elif ('R2D2' in base_learner): 533 | self.head = R2D2Head 534 | elif ('Proto' in base_learner): 535 | self.head = ProtoNetHead 536 | elif ('SVM-He' in base_learner): 537 | self.head = MetaOptNetHead_SVM_He 538 | elif ('SVM-WW' in base_learner): 539 | self.head = MetaOptNetHead_SVM_WW 540 | else: 541 | print ("Cannot recognize the base learner type") 542 | assert(False) 543 | 544 | # Add a learnable scale 545 | self.enable_scale = enable_scale 546 | self.scale = nn.Parameter(torch.FloatTensor([1.0])) 547 | 548 | def forward(self, query, support, support_labels, n_way, n_shot, **kwargs): 549 | if self.enable_scale: 550 | return self.scale * self.head(query, support, support_labels, n_way, n_shot, **kwargs) 551 | else: 552 | return self.head(query, support, support_labels, n_way, n_shot, **kwargs) 553 | -------------------------------------------------------------------------------- /models/dropblock.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.distributions import Bernoulli 5 | 6 | 7 | class DropBlock(nn.Module): 8 | def __init__(self, block_size): 9 | super(DropBlock, self).__init__() 10 | 11 | self.block_size = block_size 12 | #self.gamma = gamma 13 | #self.bernouli = Bernoulli(gamma) 14 | 15 | def forward(self, x, gamma): 16 | # shape: (bsize, channels, height, width) 17 | 18 | if self.training: 19 | batch_size, channels, height, width = x.shape 20 | 21 | bernoulli = Bernoulli(gamma) 22 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 23 | #print((x.sample[-2], x.sample[-1])) 24 | block_mask = self._compute_block_mask(mask) 25 | #print (block_mask.size()) 26 | #print (x.size()) 27 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 28 | count_ones = block_mask.sum() 29 | 30 | return block_mask * x * (countM / count_ones) 31 | 
/models/dropblock.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.distributions import Bernoulli
5 | 
6 | 
7 | class DropBlock(nn.Module):
8 |     def __init__(self, block_size):
9 |         super(DropBlock, self).__init__()
10 | 
11 |         self.block_size = block_size
12 |         #self.gamma = gamma
13 |         #self.bernouli = Bernoulli(gamma)
14 | 
15 |     def forward(self, x, gamma):
16 |         # shape: (bsize, channels, height, width)
17 | 
18 |         if self.training:
19 |             batch_size, channels, height, width = x.shape
20 | 
21 |             bernoulli = Bernoulli(gamma)
22 |             mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
23 |             #print((x.sample[-2], x.sample[-1]))
24 |             block_mask = self._compute_block_mask(mask)
25 |             #print (block_mask.size())
26 |             #print (x.size())
27 |             countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
28 |             count_ones = block_mask.sum()
29 | 
30 |             return block_mask * x * (countM / count_ones)
31 |         else:
32 |             return x
33 | 
34 |     def _compute_block_mask(self, mask):
35 |         left_padding = int((self.block_size-1) / 2)
36 |         right_padding = int(self.block_size / 2)
37 | 
38 |         batch_size, channels, height, width = mask.shape
39 |         #print ("mask", mask[0][0])
40 |         non_zero_idxs = mask.nonzero()
41 |         nr_blocks = non_zero_idxs.shape[0]
42 | 
43 |         offsets = torch.stack(
44 |             [
45 |                 torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
46 |                 torch.arange(self.block_size).repeat(self.block_size), #- left_padding
47 |             ]
48 |         ).t().cuda()
49 |         offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
50 | 
51 |         if nr_blocks > 0:
52 |             non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
53 |             offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
54 |             offsets = offsets.long()
55 | 
56 |             block_idxs = non_zero_idxs + offsets
57 |             #block_idxs += left_padding
58 |             padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
59 |             padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
60 |         else:
61 |             padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
62 | 
63 |         block_mask = 1 - padded_mask#[:height, :width]
64 |         return block_mask
65 | 
--------------------------------------------------------------------------------
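The gamma passed to forward is the per-location probability of sampling a block seed. A rough usage sketch follows; it is not from the repository, the gamma formula is the one suggested in the DropBlock paper (Ghiasi et al., NeurIPS 2018) rather than necessarily what this repo's ResNet-12 computes, and the module's hard-coded .cuda() calls mean a GPU is required.

    import torch
    from models.dropblock import DropBlock

    drop_block = DropBlock(block_size=5)
    drop_block.train()                     # the mask is only applied in training mode
    x = torch.randn(8, 64, 10, 10).cuda()  # (batch, channels, height, width)

    # Paper-style gamma for a target drop probability on a feat_size x feat_size map.
    drop_prob, block_size, feat_size = 0.1, 5, 10
    gamma = (drop_prob / block_size ** 2) * feat_size ** 2 / (feat_size - block_size + 1) ** 2
    out = drop_block(x, gamma)             # same shape as x, rescaled by the kept fraction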
/models/protonet_embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | 
4 | class ConvBlock(nn.Module):
5 |     def __init__(self, in_channels, out_channels, retain_activation=True):
6 |         super(ConvBlock, self).__init__()
7 | 
8 |         self.block = nn.Sequential(
9 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
10 |             nn.BatchNorm2d(out_channels)
11 |         )
12 | 
13 |         if retain_activation:
14 |             self.block.add_module("ReLU", nn.ReLU(inplace=True))
15 |             self.block.add_module("MaxPool2d", nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
16 | 
17 |     def forward(self, x):
18 |         out = self.block(x)
19 |         return out
20 | 
21 | # Embedding network used in Matching Networks (Vinyals et al., NIPS 2016), Meta-LSTM (Ravi & Larochelle, ICLR 2017),
22 | # MAML (w/ h_dim=z_dim=32) (Finn et al., ICML 2017), Prototypical Networks (Snell et al., NIPS 2017).
23 | 
24 | class ProtoNetEmbedding(nn.Module):
25 |     def __init__(self, x_dim=3, h_dim=64, z_dim=64, retain_last_activation=True):
26 |         super(ProtoNetEmbedding, self).__init__()
27 |         self.encoder = nn.Sequential(
28 |             ConvBlock(x_dim, h_dim),
29 |             ConvBlock(h_dim, h_dim),
30 |             ConvBlock(h_dim, h_dim),
31 |             ConvBlock(h_dim, z_dim, retain_activation=retain_last_activation),
32 |         )
33 |         for m in self.modules():
34 |             if isinstance(m, nn.Conv2d):
35 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
36 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
37 |             elif isinstance(m, nn.BatchNorm2d):
38 |                 m.weight.data.fill_(1)
39 |                 m.bias.data.zero_()
40 | 
41 |     def forward(self, x):
42 |         x = self.encoder(x)
43 |         return x.view(x.size(0), -1)
--------------------------------------------------------------------------------
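A quick shape check for the four-block encoder (illustrative, not part of the repository; assumes 84x84 mini-ImageNet inputs, the default constructor, and a CUDA device):

    import torch
    from models.protonet_embedding import ProtoNetEmbedding

    net = ProtoNetEmbedding().cuda()
    images = torch.randn(4, 3, 84, 84).cuda()
    emb = net(images)
    print(emb.shape)  # torch.Size([4, 1600]): 64 channels x 5 x 5 after four 2x2 max-pools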
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.1.post2
2 | torchvision==0.2.2.post2
3 | qpth==0.0.13
4 | torchnet
5 | tqdm
6 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 | 
8 | from torch.autograd import Variable
9 | 
10 | from tqdm import tqdm
11 | 
12 | from models.protonet_embedding import ProtoNetEmbedding
13 | from models.R2D2_embedding import R2D2Embedding
14 | from models.ResNet12_embedding import resnet12
15 | 
16 | from models.classification_heads import ClassificationHead
17 | 
18 | from utils import pprint, set_gpu, Timer, count_accuracy, log
19 | 
20 | def get_model(options):
21 |     pass
22 | 
23 | def get_model(options):
24 |     # Choose the embedding network
25 |     if options.network == 'ProtoNet':
26 |         network = ProtoNetEmbedding().cuda()
27 |     elif options.network == 'R2D2':
28 |         network = R2D2Embedding().cuda()
29 |     elif options.network == 'ResNet':
30 |         if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
31 |             network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5).cuda()
32 |             network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3])
33 |         else:
34 |             network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
35 |     else:
36 |         print ("Cannot recognize the network type")
37 |         assert(False)
38 | 
39 |     # Choose the classification head
40 |     if options.head == 'ProtoNet':  # use the options argument, not the global opt
41 |         cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
42 |     elif options.head == 'Ridge':
43 |         cls_head = ClassificationHead(base_learner='Ridge').cuda()
44 |     elif options.head == 'R2D2':
45 |         cls_head = ClassificationHead(base_learner='R2D2').cuda()
46 |     elif options.head == 'SVM':
47 |         cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
48 |     else:
49 |         print ("Cannot recognize the classification head type")
50 |         assert(False)
51 | 
52 |     return (network, cls_head)
53 | 
54 | def get_dataset(options):
55 |     # Choose the dataset
56 |     if options.dataset == 'miniImageNet':
57 |         from data.mini_imagenet import MiniImageNet, FewShotDataloader
58 |         dataset_test = MiniImageNet(phase='test')
59 |         data_loader = FewShotDataloader
60 |     elif options.dataset == 'tieredImageNet':
61 |         from data.tiered_imagenet import tieredImageNet, FewShotDataloader
62 |         dataset_test = tieredImageNet(phase='test')
63 |         data_loader = FewShotDataloader
64 |     elif options.dataset == 'CIFAR_FS':
65 |         from data.CIFAR_FS import CIFAR_FS, FewShotDataloader
66 |         dataset_test = CIFAR_FS(phase='test')
67 |         data_loader = FewShotDataloader
68 |     elif options.dataset == 'FC100':
69 |         from data.FC100 import FC100, FewShotDataloader
70 |         dataset_test = FC100(phase='test')
71 |         data_loader = FewShotDataloader
72 |     else:
73 |         print ("Cannot recognize the dataset type")
74 |         assert(False)
75 | 
76 |     return (dataset_test, data_loader)
77 | 
78 | if __name__ == '__main__':
79 |     parser = argparse.ArgumentParser()
80 |     parser.add_argument('--gpu', default='0')
81 |     parser.add_argument('--load', default='./experiments/exp_1/best_model.pth',
82 |                         help='path of the checkpoint file')
83 |     parser.add_argument('--episode', type=int, default=1000,
84 |                         help='number of episodes to test')
85 |     parser.add_argument('--way', type=int, default=5,
86 |                         help='number of classes in one test episode')
87 |     parser.add_argument('--shot', type=int, default=1,
88 |                         help='number of support examples per class')
89 |     parser.add_argument('--query', type=int, default=15,
90 |                         help='number of query examples per class')
91 |     parser.add_argument('--network', type=str, default='ProtoNet',
92 |                         help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
93 |     parser.add_argument('--head', type=str, default='ProtoNet',
94 |                         help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
95 |     parser.add_argument('--dataset', type=str, default='miniImageNet',
96 |                         help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
97 | 
98 |     opt = parser.parse_args()
99 |     (dataset_test, data_loader) = get_dataset(opt)
100 | 
101 |     dloader_test = data_loader(
102 |         dataset=dataset_test,
103 |         nKnovel=opt.way,
104 |         nKbase=0,
105 |         nExemplars=opt.shot, # num training examples per novel category
106 |         nTestNovel=opt.query * opt.way, # num test examples for all the novel categories
107 |         nTestBase=0, # num test examples for all the base categories
108 |         batch_size=1,
109 |         num_workers=1,
110 |         epoch_size=opt.episode, # num of batches per epoch
111 |     )
112 | 
113 |     set_gpu(opt.gpu)
114 | 
115 |     log_file_path = os.path.join(os.path.dirname(opt.load), "test_log.txt")
116 |     log(log_file_path, str(vars(opt)))
117 | 
118 |     # Define the models
119 |     (embedding_net, cls_head) = get_model(opt)
120 | 
121 |     # Load saved model checkpoints
122 |     saved_models = torch.load(opt.load)
123 |     embedding_net.load_state_dict(saved_models['embedding'])
124 |     embedding_net.eval()
125 |     cls_head.load_state_dict(saved_models['head'])
126 |     cls_head.eval()
127 | 
128 |     # Evaluate on test set
129 |     test_accuracies = []
130 |     for i, batch in enumerate(tqdm(dloader_test()), 1):
131 |         data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
132 | 
133 |         n_support = opt.way * opt.shot
134 |         n_query = opt.way * opt.query
135 | 
136 |         emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
137 |         emb_support = emb_support.reshape(1, n_support, -1)
138 | 
139 |         emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
140 |         emb_query = emb_query.reshape(1, n_query, -1)
141 | 
142 |         if opt.head == 'SVM':
143 |             logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot, maxIter=3)
144 |         else:
145 |             logits = cls_head(emb_query, emb_support, labels_support, opt.way, opt.shot)
146 | 
147 |         acc = count_accuracy(logits.reshape(-1, opt.way), labels_query.reshape(-1))
148 |         test_accuracies.append(acc.item())
149 | 
150 |         avg = np.mean(np.array(test_accuracies))
151 |         std = np.std(np.array(test_accuracies))
152 |         ci95 = 1.96 * std / np.sqrt(i) # i episodes have been scored so far (enumerate starts at 1)
153 | 
154 |         if i % 50 == 0:
155 |             print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} % ({:.2f} %)'\
156 |                 .format(i, opt.episode, avg, ci95, acc))
157 | 
--------------------------------------------------------------------------------
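The interval printed above is the usual normal-approximation 95% confidence interval, 1.96 * std / sqrt(n), over the n per-episode accuracies. A standalone check with made-up numbers (illustrative only):

    import numpy as np

    accs = np.array([52.0, 55.5, 49.0, 60.2])      # per-episode accuracies in %, made-up values
    ci95 = 1.96 * accs.std() / np.sqrt(len(accs))  # half-width of the 95% CI
    print('{:.2f} ± {:.2f} %'.format(accs.mean(), ci95))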
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 | from torch.autograd import Variable
11 | 
12 | from models.classification_heads import ClassificationHead
13 | from models.R2D2_embedding import R2D2Embedding
14 | from models.protonet_embedding import ProtoNetEmbedding
15 | from models.ResNet12_embedding import resnet12
16 | 
17 | from utils import set_gpu, Timer, count_accuracy, check_dir, log
18 | 
19 | def one_hot(indices, depth):
20 |     """
21 |     Returns a one-hot tensor.
22 |     This is a PyTorch equivalent of Tensorflow's tf.one_hot.
23 | 
24 |     Parameters:
25 |       indices: an (m,) Tensor of class indices (the call sites below pass a flattened label vector).
26 |       depth: a scalar. Represents the depth of the one hot dimension.
27 |     Returns: an (m, depth) Tensor.
28 |     """
29 | 
30 |     encoded_indices = torch.zeros(indices.size() + torch.Size([depth])).cuda()
31 |     index = indices.view(indices.size() + torch.Size([1]))
32 |     encoded_indices = encoded_indices.scatter_(1, index, 1)
33 | 
34 |     return encoded_indices
35 | 
36 | def get_model(options):
37 |     # Choose the embedding network
38 |     if options.network == 'ProtoNet':
39 |         network = ProtoNetEmbedding().cuda()
40 |     elif options.network == 'R2D2':
41 |         network = R2D2Embedding().cuda()
42 |     elif options.network == 'ResNet':
43 |         if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
44 |             network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5).cuda()
45 |             network = torch.nn.DataParallel(network, device_ids=[0, 1, 2, 3])
46 |         else:
47 |             network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2).cuda()
48 |     else:
49 |         print ("Cannot recognize the network type")
50 |         assert(False)
51 | 
52 |     # Choose the classification head
53 |     if options.head == 'ProtoNet':
54 |         cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
55 |     elif options.head == 'Ridge':
56 |         cls_head = ClassificationHead(base_learner='Ridge').cuda()
57 |     elif options.head == 'R2D2':
58 |         cls_head = ClassificationHead(base_learner='R2D2').cuda()
59 |     elif options.head == 'SVM':
60 |         cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
61 |     else:
62 |         print ("Cannot recognize the classification head type")
63 |         assert(False)
64 | 
65 |     return (network, cls_head)
66 | 
67 | def get_dataset(options):
68 |     # Choose the dataset
69 |     if options.dataset == 'miniImageNet':
70 |         from data.mini_imagenet import MiniImageNet, FewShotDataloader
71 |         dataset_train = MiniImageNet(phase='train')
72 |         dataset_val = MiniImageNet(phase='val')
73 |         data_loader = FewShotDataloader
74 |     elif options.dataset == 'tieredImageNet':
75 |         from data.tiered_imagenet import tieredImageNet, FewShotDataloader
76 |         dataset_train = tieredImageNet(phase='train')
77 |         dataset_val = tieredImageNet(phase='val')
78 |         data_loader = FewShotDataloader
79 |     elif options.dataset == 'CIFAR_FS':
80 |         from data.CIFAR_FS import CIFAR_FS, FewShotDataloader
81 |         dataset_train = CIFAR_FS(phase='train')
82 |         dataset_val = CIFAR_FS(phase='val')
83 |         data_loader = FewShotDataloader
84 |     elif options.dataset == 'FC100':
85 |         from data.FC100 import FC100, FewShotDataloader
86 |         dataset_train = FC100(phase='train')
87 |         dataset_val = FC100(phase='val')
88 |         data_loader = FewShotDataloader
89 |     else:
90 |         print ("Cannot recognize the dataset type")
91 |         assert(False)
92 | 
93 |     return (dataset_train, dataset_val, data_loader)
94 | 
95 | if __name__ == '__main__':
96 |     parser = argparse.ArgumentParser()
97 |     parser.add_argument('--num-epoch', type=int, default=60,
98 |                         help='number of training epochs')
99 |     parser.add_argument('--save-epoch', type=int, default=10,
100 |                         help='frequency of model saving')
101 |     parser.add_argument('--train-shot', type=int, default=15,
102 |                         help='number of support examples per training class')
103 |     parser.add_argument('--val-shot', type=int, default=5,
104 |                         help='number of support examples per validation class')
105 |     parser.add_argument('--train-query', type=int, default=6,
106 |                         help='number of query examples per training class')
107 |     parser.add_argument('--val-episode', type=int, default=2000,
108 |                         help='number of episodes per validation')
109 |     parser.add_argument('--val-query', type=int, default=15,
110 |                         help='number of query examples per validation class')
111 |     parser.add_argument('--train-way', type=int, default=5,
112 |                         help='number of classes in one training episode')
113 |     parser.add_argument('--test-way', type=int, default=5,
114 |                         help='number of classes in one test (or validation) episode')
115 |     parser.add_argument('--save-path', default='./experiments/exp_1')
116 |     parser.add_argument('--gpu', default='0, 1, 2, 3')
117 |     parser.add_argument('--network', type=str, default='ProtoNet',
118 |                         help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
119 |     parser.add_argument('--head', type=str, default='ProtoNet',
120 |                         help='choose which classification head to use. ProtoNet, Ridge, R2D2, SVM')
121 |     parser.add_argument('--dataset', type=str, default='miniImageNet',
122 |                         help='choose which dataset to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
123 |     parser.add_argument('--episodes-per-batch', type=int, default=8,
124 |                         help='number of episodes per batch')
125 |     parser.add_argument('--eps', type=float, default=0.0,
126 |                         help='epsilon of label smoothing')
127 | 
128 |     opt = parser.parse_args()
129 | 
130 |     (dataset_train, dataset_val, data_loader) = get_dataset(opt)
131 | 
132 |     # Dataloader of Gidaris & Komodakis (CVPR 2018)
133 |     dloader_train = data_loader(
134 |         dataset=dataset_train,
135 |         nKnovel=opt.train_way,
136 |         nKbase=0,
137 |         nExemplars=opt.train_shot, # num training examples per novel category
138 |         nTestNovel=opt.train_way * opt.train_query, # num test examples for all the novel categories
139 |         nTestBase=0, # num test examples for all the base categories
140 |         batch_size=opt.episodes_per_batch,
141 |         num_workers=4,
142 |         epoch_size=opt.episodes_per_batch * 1000, # num of batches per epoch
143 |     )
144 | 
145 |     dloader_val = data_loader(
146 |         dataset=dataset_val,
147 |         nKnovel=opt.test_way,
148 |         nKbase=0,
149 |         nExemplars=opt.val_shot, # num training examples per novel category
150 |         nTestNovel=opt.val_query * opt.test_way, # num test examples for all the novel categories
151 |         nTestBase=0, # num test examples for all the base categories
152 |         batch_size=1,
153 |         num_workers=0,
154 |         epoch_size=1 * opt.val_episode, # num of batches per epoch
155 |     )
156 | 
157 |     set_gpu(opt.gpu)
158 |     check_dir('./experiments/')
159 |     check_dir(opt.save_path)
160 | 
161 |     log_file_path = os.path.join(opt.save_path, "train_log.txt")
162 |     log(log_file_path, str(vars(opt)))
163 | 
164 |     (embedding_net, cls_head) = get_model(opt)
165 | 
166 |     optimizer = torch.optim.SGD([{'params': embedding_net.parameters()},
167 |                                  {'params': cls_head.parameters()}], lr=0.1, momentum=0.9, \
168 |                                  weight_decay=5e-4, nesterov=True)
169 | 
170 |     lambda_epoch = lambda e: 1.0 if e < 20 else (0.06 if e < 40 else 0.012 if e < 50 else (0.0024))
171 |     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
172 | 
173 |     max_val_acc = 0.0
174 | 
175 |     timer = Timer()
176 |     x_entropy = torch.nn.CrossEntropyLoss()
177 | 
178 |     for epoch in range(1, opt.num_epoch + 1):
179 |         # Train on the training split
180 |         lr_scheduler.step()  # torch 1.0.x convention: advance the schedule at the start of each epoch
181 | 
182 |         # Fetch the current epoch's learning rate
183 |         epoch_learning_rate = 0.1
184 |         for param_group in optimizer.param_groups:
185 |             epoch_learning_rate = param_group['lr']
186 | 
187 |         log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
188 |             epoch, epoch_learning_rate))
189 | 
190 |         _, _ = [x.train() for x in (embedding_net, cls_head)]
191 | 
192 |         train_accuracies = []
193 |         train_losses = []
194 | 
195 |         for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
196 |             data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
197 | 
198 |             train_n_support = opt.train_way * opt.train_shot
199 |             train_n_query = opt.train_way * opt.train_query
200 | 
201 |             emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
202 |             emb_support = emb_support.reshape(opt.episodes_per_batch, train_n_support, -1)
203 | 
204 |             emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
205 |             emb_query = emb_query.reshape(opt.episodes_per_batch, train_n_query, -1)
206 | 
207 |             logit_query = cls_head(emb_query, emb_support, labels_support, opt.train_way, opt.train_shot)
208 | 
209 |             smoothed_one_hot = one_hot(labels_query.reshape(-1), opt.train_way)
210 |             smoothed_one_hot = smoothed_one_hot * (1 - opt.eps) + (1 - smoothed_one_hot) * opt.eps / (opt.train_way - 1) # see the numeric check after this listing
211 | 
212 |             log_prb = F.log_softmax(logit_query.reshape(-1, opt.train_way), dim=1)
213 |             loss = -(smoothed_one_hot * log_prb).sum(dim=1)
214 |             loss = loss.mean()
215 | 
216 |             acc = count_accuracy(logit_query.reshape(-1, opt.train_way), labels_query.reshape(-1))
217 | 
218 |             train_accuracies.append(acc.item())
219 |             train_losses.append(loss.item())
220 | 
221 |             if (i % 100 == 0):
222 |                 train_acc_avg = np.mean(np.array(train_accuracies))
223 |                 log(log_file_path, 'Train Epoch: {}\tBatch: [{}/{}]\tLoss: {:.4f}\tAccuracy: {:.2f} % ({:.2f} %)'.format(
224 |                     epoch, i, len(dloader_train), loss.item(), train_acc_avg, acc))
225 | 
226 |             optimizer.zero_grad()
227 |             loss.backward()
228 |             optimizer.step()
229 | 
230 |         # Evaluate on the validation split
231 |         _, _ = [x.eval() for x in (embedding_net, cls_head)]
232 | 
233 |         val_accuracies = []
234 |         val_losses = []
235 | 
236 |         for i, batch in enumerate(tqdm(dloader_val(epoch)), 1):
237 |             data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
238 | 
239 |             test_n_support = opt.test_way * opt.val_shot
240 |             test_n_query = opt.test_way * opt.val_query
241 | 
242 |             emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
243 |             emb_support = emb_support.reshape(1, test_n_support, -1)
244 |             emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
245 |             emb_query = emb_query.reshape(1, test_n_query, -1)
246 | 
247 |             logit_query = cls_head(emb_query, emb_support, labels_support, opt.test_way, opt.val_shot)
248 | 
249 |             loss = x_entropy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
250 |             acc = count_accuracy(logit_query.reshape(-1, opt.test_way), labels_query.reshape(-1))
251 | 
252 |             val_accuracies.append(acc.item())
253 |             val_losses.append(loss.item())
254 | 
255 |         val_acc_avg = np.mean(np.array(val_accuracies))
256 |         val_acc_ci95 = 1.96 * np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)
257 | 
258 |         val_loss_avg = np.mean(np.array(val_losses))
259 | 
260 |         if val_acc_avg > max_val_acc:
261 |             max_val_acc = val_acc_avg
262 |             torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},\
263 |                        os.path.join(opt.save_path, 'best_model.pth'))
264 |             log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'\
265 |                 .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
266 |         else:
267 |             log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'\
268 |                 .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
269 | 
270 |         torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
271 |                    , os.path.join(opt.save_path, 'last_epoch.pth'))
272 | 
273 |         if epoch % opt.save_epoch == 0:
274 |             torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()}\
275 |                        , os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))
276 | 
277 |         log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(), timer.measure(epoch / float(opt.num_epoch))))
--------------------------------------------------------------------------------
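A numeric check of the label-smoothing step in the training loop above (illustrative, not from the repository): with eps=0.1 and train_way=5, the true class gets 1 - eps = 0.9 and each of the other four classes gets eps / 4 = 0.025, so every row still sums to 1.

    import torch

    eps, n_way = 0.1, 5
    one_hot_row = torch.tensor([[0., 0., 1., 0., 0.]])
    smoothed = one_hot_row * (1 - eps) + (1 - one_hot_row) * eps / (n_way - 1)
    print(smoothed)  # tensor([[0.0250, 0.0250, 0.9000, 0.0250, 0.0250]]); rows still sum to 1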
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import pprint
4 | import torch
5 | 
6 | def set_gpu(x):
7 |     os.environ['CUDA_VISIBLE_DEVICES'] = x
8 |     print('using gpu:', x)
9 | 
10 | def check_dir(path):
11 |     '''
12 |     Create directory if it does not exist.
13 |     path: Path of directory.
14 |     '''
15 |     if not os.path.exists(path):
16 |         os.mkdir(path)
17 | 
18 | def count_accuracy(logits, label):
19 |     pred = torch.argmax(logits, dim=1).view(-1)
20 |     label = label.view(-1)
21 |     accuracy = 100 * pred.eq(label).float().mean()
22 |     return accuracy
23 | 
24 | class Timer():
25 |     def __init__(self):
26 |         self.o = time.time()
27 | 
28 |     def measure(self, p=1):
29 |         x = (time.time() - self.o) / float(p)
30 |         x = int(x)
31 |         if x >= 3600:
32 |             return '{:.1f}h'.format(x / 3600)
33 |         if x >= 60:
34 |             return '{}m'.format(round(x / 60))
35 |         return '{}s'.format(x)
36 | 
37 | def log(log_file_path, string):
38 |     '''
39 |     Write one line of log into screen and file.
40 |     log_file_path: Path of log file.
41 |     string: String to write in log file.
42 |     '''
43 |     with open(log_file_path, 'a+') as f:
44 |         f.write(string + '\n')
45 |         f.flush()
46 |     print(string)
--------------------------------------------------------------------------------
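Finally, an illustrative snippet (not part of the repository) exercising the helpers in utils.py:

    import torch
    from utils import count_accuracy, Timer

    logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
    labels = torch.tensor([0, 1])
    print(count_accuracy(logits, labels))  # tensor(100.), the percentage of correct argmax predictions

    timer = Timer()
    # ... do some work ...
    print(timer.measure())     # elapsed time so far, e.g. '0s'
    print(timer.measure(0.5))  # projected total time if half of the work is done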