├── LICENSE ├── README.md ├── config.py ├── data ├── create_features_db.py ├── lmdb_dataset.py ├── meta_dataset_config.gin ├── meta_dataset_processing.py └── meta_dataset_reader.py ├── models ├── losses.py ├── model_helpers.py ├── model_utils.py ├── models_dict.py ├── resnet18.py ├── resnet18_pnf.py └── sur.py ├── paths.py ├── scripts ├── dump_test_episodes.sh ├── train_networks.sh └── train_pnf.sh ├── test.py ├── test_extractor.py ├── test_offline.py ├── train_net.py ├── train_pnf.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Nikita Dvornik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SUR: Selecting from Universal Representations 2 | This repository contains the code to reproduce the few-shot classification experiments on Meta-Dataset carried out in [Selecting Relevant Features from a Universal Representation for Few-shot Learning](https://arxiv.org/abs/2003.09338). 3 | 4 | ## Dependencies 5 | This code requires the following: 6 | * Python 3.6 or greater 7 | * PyTorch 1.0 or greater 8 | * TensorFlow 1.14 or greater 9 | 10 | 11 | ## Installation 12 | 1. Clone or download this repository. 13 | 2. Configure Meta-Dataset: 14 | * Follow the "User instructions" in the Meta-Dataset repository (https://github.com/google-research/meta-dataset) for "Installation" and "Downloading and converting datasets". Brace yourself: the full process takes around a day. 15 | 16 | **NOTE:** the Meta-Dataset codebase has changed significantly since the release of this code and will not work as is. Please run `git checkout 056ccac` in the Meta-Dataset root folder to check out the version used in this project. 17 | * If you want to test out-of-domain behavior on additional datasets, namely MNIST, CIFAR10 and CIFAR100, follow the installation instructions in the [CNAPs repository](https://github.com/cambridge-mlg/cnaps) to get these datasets. This step takes little time and we recommend doing it. 18 | 19 | ## Usage 20 | Here is how to initialize, train and test our method: 21 | #### Initialization 22 | 23 | 1. Before doing anything else, run the following commands.
24 | 25 | ```ulimit -n 50000``` 26 | ```export META_DATASET_ROOT=<path to the cloned Meta-Dataset repository>``` 27 | ```export RECORDS=<path to the converted Meta-Dataset records>``` 28 | 29 | Note that the above commands need to be run every time you open a new command shell. 30 | 2. Enter the root directory of this project, i.e. the directory where this project was cloned or downloaded. 31 | 32 | #### Getting the Feature Extractors 33 | 1. The easiest way is to download our pre-trained models and use them to obtain a universal set of features directly. 34 | If that is what you want, execute the following command in the root directory of this project: 35 | 36 | ```wget http://thoth.inrialpes.fr/research/SUR/all_weights.zip && unzip all_weights.zip && rm all_weights.zip``` 37 | 38 | This will download all the weights and place them in the `./weights` directory. 39 | 40 | 2. Alternatively, instead of using the pre-trained models, you can train the models from scratch. 41 | To train 8 independent feature extractors, run: 42 | 43 | ```./scripts/train_networks.sh``` 44 | 45 | And/or, to train a parametric network family, run: 46 | 47 | ```./scripts/train_pnf.sh``` 48 | 49 | 50 | #### Testing 51 | 1. This step runs our SUR procedure to select appropriate features from a universal feature set. 52 | To select from features obtained with different networks, run: 53 | 54 | ```python test.py --model.backbone=resnet18``` 55 | 56 | To select from features obtained with a parametric network family, run: 57 | 58 | ```python test.py --model.backbone=resnet18_pnf``` 59 | 60 | Note: if you train the models yourself, make sure you have trained the extractors that correspond to the chosen backbone. 61 | 62 | #### Offline Testing (optional) 63 | To speed up the testing procedure, you can first dump the features to your hard drive and then run selection on them directly, without running a CNN. To do so, follow these steps: 64 | 1. Dump the features extracted from the test episodes to your hard drive by running 65 | 66 | ```./scripts/dump_test_episodes.sh``` 67 | 68 | 2. Test SUR offline. Depending on your desired feature extractor, run: 69 | 70 | ```python test_offline.py --model.backbone=resnet18``` or ```python test_offline.py --model.backbone=resnet18_pnf``` 71 | 72 | This step is useful if you want to experiment with SUR selection without recomputing the same features on every run. 73 | 74 | ## Expected Results 75 | Below are the results reported in our paper. The results will vary from run to run by a percent or two up or 76 | down, because the Meta-Dataset reader generates different tasks each run and because of randomness in training the networks and in SUR optimization. 77 | SUR selects from 8 independently trained feature extractors, while SUR-pnf selects from the outputs of a parametric 78 | network family, which has fewer parameters. More details can be found in the original paper.
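The table entries below report average accuracy over the sampled test episodes together with a confidence interval. As a rough illustration only (this is a sketch, not the code used in `test.py`; the helper name `summarize_episodes` and the use of a 95% interval are assumptions), per-episode accuracies are typically aggregated like this:

```python
import numpy as np

def summarize_episodes(episode_accs):
    """Aggregate per-episode accuracies (in %) into a mean and 95% confidence interval."""
    accs = np.asarray(episode_accs, dtype=np.float64)
    mean = accs.mean()
    ci95 = 1.96 * accs.std() / np.sqrt(len(accs))  # 1.96 * standard error of the mean
    return mean, ci95

# e.g. 600 test episodes, matching the default --test.size
episode_accs = np.random.uniform(40.0, 70.0, size=600)  # placeholder accuracies
mean, ci95 = summarize_episodes(episode_accs)
print(f"{mean:.1f}±{ci95:.1f}")
```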
79 | 80 | **Models trained on all datasets** 81 | 82 | | Dataset | SUR | SUR-pnf | 83 | | --- | --- | --- | 84 | | ImageNet | 56.1±1.1 | 56.0±1.1 | 85 | | Omniglot | 93.1±0.5 | 90.0±0.6 | 86 | | Aircraft | 84.6±0.7 | 79.7±0.8 | 87 | | Birds | 70.6±1.0 | 75.9±0.9 | 88 | | Textures | 71.0±0.8 | 72.5±0.7 | 89 | | Quick Draw | 81.3±0.6 | 76.7±0.7 | 90 | | Fungi | 64.2±1.1 | 49.8±1.1 | 91 | | VGG Flower | 82.8±0.8 | 90.0±0.6 | 92 | | Traffic Signs | 53.4±1.0 | 52.2±0.8 | 93 | | MSCOCO | 50.1±1.0 | 50.2±1.0 | 94 | | MNIST | 94.5±0.5 | 93.1±0.4 | 95 | | CIFAR10 | 64.1±1.0 | 65.9±0.8 | 96 | | CIFAR100 | 56.1±1.0 | 57.1±1.0 | 97 | 98 | 99 | 100 | ## Citation 101 | If you use this code, please cite our [Selecting Relevant Features from a Universal Representation for Few-shot Learning](https://arxiv.org/abs/2003.09338) paper: 102 | ``` 103 | @article{dvornik2020selecting, 104 | title={Selecting Relevant Features from a Universal Representation for Few-shot Classification}, 105 | author={Dvornik, Nikita and Schmid, Cordelia and Mairal, Julien}, 106 | journal={arXiv preprint arXiv:2003.09338}, 107 | year={2020} 108 | } 109 | ``` -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description='Train feature extractors and test SUR') 3 | 4 | # data args 5 | parser.add_argument('--data.train', type=str, default='cu_birds', metavar='TRAINSETS', nargs='+', help="Datasets for training extractors") 6 | parser.add_argument('--data.val', type=str, default='cu_birds', metavar='VALSETS', nargs='+', 7 | help="Datasets used for validation") 8 | parser.add_argument('--data.test', type=str, default='cu_birds', metavar='TESTSETS', nargs='+', 9 | help="Datasets used for testing") 10 | parser.add_argument('--data.num_workers', type=int, default=32, metavar='NWORKERS', 11 | help="Number of workers that pre-process images in parallel") 12 | 13 | # model args 14 | default_model_name = 'noname' 15 | parser.add_argument('--model.name', type=str, default=default_model_name, metavar='MODELNAME', 16 | help="A name you give to the extractor (default: {})".format(default_model_name)) 17 | parser.add_argument('--model.backbone', default='resnet18', help="Backbone architecture (default: resnet18)") 18 | parser.add_argument('--model.classifier', type=str, default='cosine', choices=['none', 'linear', 'cosine'], help="Do classification using cosine similarity between activations and weights") 19 | parser.add_argument('--model.dropout', type=float, default=0, help="Dropout rate inside a basic block of the backbone") 20 | 21 | # train args 22 | parser.add_argument('--train.batch_size', type=int, default=16, metavar='BS', 23 | help='number of images in a batch') 24 | parser.add_argument('--train.max_iter', type=int, default=500000, metavar='MAX_ITER', 25 | help='maximum number of training iterations (default: 500000)') 26 | parser.add_argument('--train.weight_decay', type=float, default=7e-4, metavar='WD', 27 | help="weight decay coefficient") 28 | parser.add_argument('--train.optimizer', type=str, default='momentum', metavar='OPTIM', 29 | help='optimization method (default: momentum)') 30 | 31 | parser.add_argument('--train.learning_rate', type=float, default=0.01, metavar='LR', 32 | help='learning rate (default: 0.01)') 33 | parser.add_argument('--train.lr_policy', type=str, default='cosine', metavar='LR_policy', 34 | help='learning rate decay policy')
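# The options below configure the schedule selected by --train.lr_policy: step decay (gamma and step
# frequency), exponential decay (final learning rate and start iteration), or cosine annealing with
# warm restarts (restart frequency); see models/model_utils.py for the corresponding schedulers.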
35 | parser.add_argument('--train.lr_decay_step_gamma', type=float, default=1e-1, metavar='DECAY_GAMMA', 36 | help='multiplicative factor applied to the learning rate at each step decay') 37 | parser.add_argument('--train.lr_decay_step_freq', type=int, default=10000, metavar='DECAY_FREQ', 38 | help='number of iterations between two learning rate decay steps') 39 | parser.add_argument('--train.exp_decay_final_lr', type=float, default=8e-5, metavar='FINAL_LR', 40 | help='final learning rate reached by the exponential decay schedule') 41 | parser.add_argument('--train.exp_decay_start_iter', type=int, default=30000, metavar='START_ITER', 42 | help='iteration at which exponential learning rate decay starts') 43 | parser.add_argument('--train.cosine_anneal_freq', type=int, default=4000, metavar='ANNEAL_FREQ', 44 | help='period (in iterations) of the cosine annealing warm restarts') 45 | parser.add_argument('--train.nesterov_momentum', action='store_true', help="Use Nesterov momentum with the SGD optimizer") 46 | 47 | 48 | # evaluation during training 49 | parser.add_argument('--train.eval_freq', type=int, default=5000, metavar='EVAL_FREQ', 50 | help='How often to evaluate model during training') 51 | parser.add_argument('--train.eval_size', type=int, default=300, metavar='EVAL_SIZE', 52 | help='How many episodes to sample for validation') 53 | parser.add_argument('--train.resume', type=int, default=1, metavar='RESUME_TRAIN', 54 | help="Resume training starting from the last checkpoint (default: True)") 55 | 56 | 57 | # creating a database of features 58 | parser.add_argument('--dump.name', type=str, default='', metavar='DUMP_NAME', 59 | help='Name for the dumped dataset of features') 60 | parser.add_argument('--dump.mode', type=str, default='test', metavar='DUMP_MODE', 61 | help='What split of the original dataset to dump') 62 | parser.add_argument('--dump.size', type=int, default=600, metavar='DUMP_SIZE', 63 | help='How many episodes to dump') 64 | 65 | 66 | # test args 67 | parser.add_argument('--test.size', type=int, default=600, metavar='TEST_SIZE', 68 | help='The number of test episodes sampled') 69 | parser.add_argument('--test.distance', type=str, choices=['cos', 'l2'], default='cos', metavar='DISTANCE_FN', 70 | help="Distance function used to compare embeddings with class prototypes") 71 | 72 | # parse arguments into a dict 73 | args = vars(parser.parse_args()) 74 | -------------------------------------------------------------------------------- /data/create_features_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | import tensorflow as tf 6 | from tqdm import tqdm 7 | 8 | import json 9 | import lmdb 10 | import pickle as pkl 11 | 12 | sys.path.insert(0, '/'.join(os.path.realpath(__file__).split('/')[:-2])) 13 | 14 | from data.meta_dataset_reader import MetaDatasetEpisodeReader, MetaDatasetBatchReader 15 | from models.model_utils import CheckPointer 16 | from models.models_dict import DATASET_MODELS_DICT 17 | from models.model_helpers import get_domain_extractors, get_model 18 | from config import args 19 | from paths import META_DATA_ROOT 20 | from utils import check_dir, SerializableArray 21 | 22 | 23 | 24 | class DatasetWriter(object): 25 | def __init__(self, args, rewrite=True, write_frequency=10): 26 | self._mode = args['dump.mode'] 27 | self._write_frequency = write_frequency 28 | self._db = None 29 | self.args = args 30 | self.dataset_models = DATASET_MODELS_DICT[args['model.backbone']] 31 | print(self.dataset_models) 32 |
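        # self.dataset_models maps each training domain to the name of its pretrained extractor
        # checkpoint under ./weights (see models/models_dict.py), e.g. 'cu_birds' -> 'birds-net'.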
trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test'] 34 | loader = MetaDatasetEpisodeReader(self._mode, trainsets, valsets, testsets) 35 | self._map_size = 50000 * 100 ** 2 * 512 * 8 36 | self.trainsets = trainsets 37 | 38 | if self._mode == 'train': 39 | evalset = "allcat" 40 | self.load_sample = lambda sess: loader.get_train_task(sess) 41 | elif self._mode == 'test': 42 | evalset = testsets[0] 43 | self.load_sample = lambda sess: loader.get_test_task(sess, evalset) 44 | elif self._mode == 'val': 45 | evalset = valsets[0] 46 | self.load_sample = lambda sess: loader.get_validation_task(sess, evalset) 47 | 48 | dump_name = self._mode + '_dump' if not args['dump.name'] else args['dump.name'] 49 | path = check_dir(os.path.join(META_DATA_ROOT, 'Dumps', self.args['model.backbone'], 50 | self._mode, evalset, dump_name)) 51 | self._path = path 52 | if not (os.path.exists(path)): 53 | os.mkdir(path) 54 | self._keys_file = os.path.join(path, 'keys') 55 | self._keys = [] 56 | if os.path.exists(path) and not rewrite: 57 | raise NameError("Dataset {} already exists.".format(self._path)) 58 | 59 | # do not initialize during __init__ to avoid pickling error when using MPI 60 | def init(self): 61 | self._db = lmdb.open(self._path, map_size=self._map_size, map_async=True) 62 | self.embed_many = get_domain_extractors(self.trainsets, self.dataset_models, self.args) 63 | 64 | def close(self): 65 | keys = tuple(self._keys) 66 | pkl.dump(keys, open(self._keys_file, 'wb')) 67 | 68 | if self._db is not None: 69 | self._db.sync() 70 | self._db.close() 71 | self._db = None 72 | 73 | def encode_dataset(self, n_tasks=1000): 74 | if self._db is None: 75 | self.init() 76 | 77 | txn = self._db.begin(write=True) 78 | config = tf.compat.v1.ConfigProto() 79 | config.gpu_options.allow_growth = True 80 | with tf.compat.v1.Session(config=config) as session: 81 | # Sampling task idxs 82 | for idx in tqdm(range(n_tasks)): 83 | # compressing image 84 | sample = self.load_sample(session) 85 | # Embedding task images with my network 86 | support_embed_dict = self.embed_many(sample['context_images']) 87 | query_embed_dict = self.embed_many(sample['target_images']) 88 | # Putting the data into containers 89 | support_labels = SerializableArray(sample['context_labels'].detach().cpu().numpy()) 90 | query_labels = SerializableArray(sample['target_labels'].detach().cpu().numpy()) 91 | SerializableArray.__module__ = 'utils' 92 | 93 | # writing 94 | for dataset in support_embed_dict.keys(): 95 | support_batch = SerializableArray( 96 | support_embed_dict[dataset].detach().cpu().numpy()) 97 | query_batch = SerializableArray( 98 | query_embed_dict[dataset].detach().cpu().numpy()) 99 | SerializableArray.__module__ = 'utils' 100 | txn.put(f"{idx}_{dataset}_support".encode("ascii"), pkl.dumps(support_batch)) 101 | txn.put(f"{idx}_{dataset}_query".encode("ascii"), pkl.dumps(query_batch)) 102 | self._keys.extend([f"{idx}_{dataset}_support", f"{idx}_{dataset}_query"]) 103 | txn.put(f"{idx}_labels_support".encode("ascii"), pkl.dumps(support_labels)) 104 | txn.put(f"{idx}_labels_query".encode("ascii"), pkl.dumps(query_labels)) 105 | self._keys.extend([f"{idx}_labels_support", f"{idx}_labels_query"]) 106 | 107 | # flushing into lmdb (on the disk) 108 | if idx > 0 and idx % self._write_frequency == 0: 109 | txn.commit() 110 | txn = self._db.begin(write=True) 111 | txn.commit() 112 | 113 | 114 | if __name__ == '__main__': 115 | dr = DatasetWriter(args) 116 | dr.init() 117 | dr.encode_dataset(args['dump.size']) 118 | 
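    # encode_dataset() stores, for every episode index and every training domain, LMDB entries keyed
    # '{idx}_{dataset}_support' / '{idx}_{dataset}_query' plus the corresponding label arrays;
    # close() writes the accumulated key list next to the LMDB so that readers can index it.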
dr.close() 119 | print('Done') 120 | -------------------------------------------------------------------------------- /data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import lmdb 8 | import pickle as pkl 9 | from utils import SerializableArray, device 10 | 11 | from paths import META_DATA_ROOT 12 | 13 | class LMDBDataset: 14 | """ 15 | Opens several LMDB readers and loads data from there 16 | """ 17 | def __init__(self, extractor_domains, datasets, backbone, mode, dump_name, limit_len=None): 18 | self.mode = mode 19 | self.datasets = datasets 20 | 21 | # opening lmdbs 22 | self.dataset_readers = dict() 23 | for evalset in datasets: 24 | all_names = os.listdir(os.path.join(META_DATA_ROOT, 'Dumps', 25 | backbone, mode, evalset)) 26 | self.dataset_readers[evalset] = [ 27 | DatasetReader(extractor_domains, evalset, backbone, mode, name) 28 | for name in all_names if dump_name in name] 29 | self._current_sampling_dataset = datasets[0] 30 | self.full_len = sum([len(ds) for ds in self.dataset_readers[self._current_sampling_dataset]]) 31 | if limit_len is not None: 32 | self.full_len = min(self.full_len, limit_len) 33 | 34 | def __len__(self): 35 | return self.full_len 36 | 37 | def __getitem__(self, idx): 38 | if self.mode == 'train': 39 | random_lmdb_subset = random.sample(self.dataset_readers[self._current_sampling_dataset], 1)[0] 40 | idx = random.sample(range(len(random_lmdb_subset)), 1)[0] 41 | sample = random_lmdb_subset[idx] 42 | else: 43 | sample = self.dataset_readers[self._current_sampling_dataset][0][idx] 44 | 45 | for key, val in sample.items(): 46 | if isinstance(val, str): 47 | pass 48 | if 'label' in key: 49 | sample[key] = torch.from_numpy(val).long() 50 | elif 'feature_dict' in key: 51 | for fkey, fval in sample[key].items(): 52 | sample[key][fkey] = torch.from_numpy(fval) 53 | 54 | return sample 55 | 56 | def set_sampling_dataset(self, sampling_dataset): 57 | self._current_sampling_dataset = sampling_dataset 58 | 59 | # open lmdb environment and transaction 60 | # load keys from cache 61 | def _load_db(self, info, class_id): 62 | path = self._path 63 | 64 | self._env = lmdb.open( 65 | self._path, 66 | readonly=True, 67 | lock=False, 68 | readahead=False, 69 | meminit=False) 70 | self._txn = self._env.begin(write=False) 71 | 72 | if class_id is None: 73 | cache_file = os.path.join(path, 'keys') 74 | if os.path.isfile(cache_file): 75 | self.keys = pkl.load(open(cache_file, 'rb')) 76 | else: 77 | print('Loading dataset keys...') 78 | with self._env.begin(write=False) as txn: 79 | self.keys = [key.decode('ascii') 80 | for key, _ in tqdm(txn.cursor())] 81 | pkl.dump(self.keys, open(cache_file, 'wb')) 82 | else: 83 | self.keys = [str(k).encode() for k in info['labels2keys'][str(class_id)]] 84 | 85 | if not self.keys: 86 | raise ValueError('Empty dataset.') 87 | 88 | def eval(self): 89 | self.mode = 'eval' 90 | 91 | def train(self): 92 | self.mode = 'train' 93 | 94 | def transform(self, x): 95 | if self.mode == 'train': 96 | out = self.train_transform(x) if self.train_transform else x 97 | return out 98 | else: 99 | out = self.test_transform(x) if self.test_transform else x 100 | return out 101 | 102 | 103 | class DatasetReader(object): 104 | """ 105 | Opens a single LMDB file, containing dumped activations for a dataset, 106 | and samples data from it. 
107 | """ 108 | def __init__(self, extractor_domains, evalset, backbone, mode, name): 109 | self._mode = mode 110 | self._env = None 111 | self._txn = None 112 | self.keys = None 113 | 114 | self.trainsets = extractor_domains 115 | path = os.path.join(META_DATA_ROOT, 'Dumps', backbone, mode, evalset, name) 116 | self._path = path 117 | 118 | self._load_db() 119 | 120 | def __len__(self): 121 | return self.full_len 122 | 123 | def _load_db(self): 124 | path = self._path 125 | 126 | self._env = lmdb.open( 127 | self._path, 128 | readonly=True, 129 | lock=False, 130 | readahead=False, 131 | meminit=False) 132 | self._txn = self._env.begin(write=False) 133 | 134 | cache_file = os.path.join(path, 'keys') 135 | if os.path.isfile(cache_file): 136 | self.keys = pkl.load(open(cache_file, 'rb')) 137 | else: 138 | print('Loading dataset keys...') 139 | with self._env.begin(write=False) as txn: 140 | self.keys = [key.decode('ascii') 141 | for key, _ in tqdm(txn.cursor())] 142 | pkl.dump(self.keys, open(cache_file, 'wb')) 143 | self.full_len = len(self.keys) // 18 144 | 145 | def __getitem__(self, idx): 146 | sample = dict() 147 | support_labels = pkl.loads(self._txn.get(f"{idx}_labels_support".encode("ascii"))) 148 | query_labels = pkl.loads(self._txn.get(f"{idx}_labels_query".encode("ascii"))) 149 | sample['context_labels'] = support_labels.get() 150 | sample['target_labels'] = query_labels.get() 151 | 152 | sample['context_feature_dict'] = dict() 153 | sample['target_feature_dict'] = dict() 154 | for dataset in self.trainsets: 155 | support_batch = pkl.loads(self._txn.get(f"{idx}_{dataset}_support".encode("ascii"))) 156 | query_batch = pkl.loads(self._txn.get(f"{idx}_{dataset}_query".encode("ascii"))) 157 | sample['context_feature_dict'][dataset] = support_batch.get() 158 | sample['target_feature_dict'][dataset] = query_batch.get() 159 | return sample 160 | -------------------------------------------------------------------------------- /data/meta_dataset_config.gin: -------------------------------------------------------------------------------- 1 | import data.meta_dataset_processing 2 | import meta_dataset.data.decoder 3 | 4 | # Default values for sampling variable shots / ways. 5 | EpisodeDescriptionConfig.min_ways = 5 6 | EpisodeDescriptionConfig.max_ways_upper_bound = 50 7 | EpisodeDescriptionConfig.max_num_query = 10 8 | EpisodeDescriptionConfig.max_support_set_size = 500 9 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100 10 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5) 11 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2) 12 | EpisodeDescriptionConfig.ignore_dag_ontology = False 13 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False 14 | 15 | # Other default values for the data pipeline. 16 | DataConfig.image_height = 84 17 | DataConfig.shuffle_buffer_size = 1000 18 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 19 | DataConfig.num_prefetch = 400 20 | 21 | meta_dataset_processing.ImageDecoder.image_size = 84 22 | 23 | # If we decode features then change the lines below to use FeatureDecoder. 
24 | process_episode.support_decoder = @support/meta_dataset_processing.ImageDecoder() 25 | support/meta_dataset_processing.ImageDecoder.data_augmentation = @support/meta_dataset_processing.DataAugmentation() 26 | support/meta_dataset_processing.DataAugmentation.enable_jitter = True 27 | support/meta_dataset_processing.DataAugmentation.jitter_amount = 0 28 | support/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = True 29 | support/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 30 | support/meta_dataset_processing.DataAugmentation.enable_random_flip = False 31 | support/meta_dataset_processing.DataAugmentation.enable_random_brightness = False 32 | support/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0 33 | support/meta_dataset_processing.DataAugmentation.enable_random_contrast = False 34 | support/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0 35 | support/meta_dataset_processing.DataAugmentation.enable_random_hue = False 36 | support/meta_dataset_processing.DataAugmentation.random_hue_delta = 0 37 | support/meta_dataset_processing.DataAugmentation.enable_random_saturation = False 38 | support/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0 39 | 40 | process_episode.query_decoder = @query/meta_dataset_processing.ImageDecoder() 41 | query/meta_dataset_processing.ImageDecoder.data_augmentation = @query/meta_dataset_processing.DataAugmentation() 42 | query/meta_dataset_processing.DataAugmentation.enable_jitter = False 43 | query/meta_dataset_processing.DataAugmentation.jitter_amount = 0 44 | query/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = False 45 | query/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 46 | query/meta_dataset_processing.DataAugmentation.enable_random_flip = False 47 | query/meta_dataset_processing.DataAugmentation.enable_random_brightness = False 48 | query/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0 49 | query/meta_dataset_processing.DataAugmentation.enable_random_contrast = False 50 | query/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0 51 | query/meta_dataset_processing.DataAugmentation.enable_random_hue = False 52 | query/meta_dataset_processing.DataAugmentation.random_hue_delta = 0 53 | query/meta_dataset_processing.DataAugmentation.enable_random_saturation = False 54 | query/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0 55 | 56 | process_batch.batch_decoder = @batch/meta_dataset_processing.ImageDecoder() 57 | batch/meta_dataset_processing.ImageDecoder.data_augmentation = @batch/meta_dataset_processing.DataAugmentation() 58 | batch/meta_dataset_processing.DataAugmentation.enable_jitter = True 59 | batch/meta_dataset_processing.DataAugmentation.jitter_amount = 8 60 | batch/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = True 61 | batch/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 62 | batch/meta_dataset_processing.DataAugmentation.enable_random_flip = False 63 | batch/meta_dataset_processing.DataAugmentation.enable_random_brightness = True 64 | batch/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0.125 65 | batch/meta_dataset_processing.DataAugmentation.enable_random_contrast = True 66 | batch/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0.2 67 | batch/meta_dataset_processing.DataAugmentation.enable_random_hue = True 68 | batch/meta_dataset_processing.DataAugmentation.random_hue_delta = 0.03 
69 | batch/meta_dataset_processing.DataAugmentation.enable_random_saturation = True 70 | batch/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0.2 -------------------------------------------------------------------------------- /data/meta_dataset_processing.py: -------------------------------------------------------------------------------- 1 | import gin.tf 2 | import tensorflow.compat.v1 as tf 3 | 4 | 5 | @gin.configurable 6 | class DataAugmentation(object): 7 | """Configurations for performing data augmentation.""" 8 | 9 | def __init__(self, enable_jitter, jitter_amount, enable_gaussian_noise, 10 | gaussian_noise_std, enable_random_flip, 11 | enable_random_brightness, random_brightness_delta, 12 | enable_random_contrast, random_contrast_delta, 13 | enable_random_hue, random_hue_delta, enable_random_saturation, 14 | random_saturation_delta): 15 | """Initialize a DataAugmentation. 16 | 17 | Args: 18 | enable_jitter: bool whether to use image jitter (pad each image using 19 | reflection along x and y axes and then random crop). 20 | jitter_amount: amount (in pixels) to pad on all sides of the image. 21 | enable_gaussian_noise: bool whether to use additive Gaussian noise. 22 | gaussian_noise_std: Standard deviation of the Gaussian distribution. 23 | """ 24 | self.enable_jitter = enable_jitter 25 | self.jitter_amount = jitter_amount 26 | self.enable_gaussian_noise = enable_gaussian_noise 27 | self.gaussian_noise_std = gaussian_noise_std 28 | self.enable_random_flip = enable_random_flip 29 | self.enable_random_brightness = enable_random_brightness 30 | self.random_brightness_delta = random_brightness_delta 31 | self.enable_random_contrast = enable_random_contrast 32 | self.random_contrast_delta = random_contrast_delta 33 | self.enable_random_hue = enable_random_hue 34 | self.random_hue_delta = random_hue_delta 35 | self.enable_random_saturation = enable_random_saturation 36 | self.random_saturation_delta = random_saturation_delta 37 | 38 | 39 | @gin.configurable 40 | class ImageDecoder(object): 41 | """Image decoder.""" 42 | 43 | def __init__(self, image_size=None, data_augmentation=None): 44 | """Class constructor. 45 | 46 | Args: 47 | image_size: int, desired image size. The extracted image will be resized 48 | to `[image_size, image_size]`. 49 | data_augmentation: A DataAugmentation object with parameters for 50 | perturbing the images. 51 | """ 52 | 53 | self.image_size = image_size 54 | self.data_augmentation = data_augmentation 55 | 56 | def __call__(self, example_string): 57 | """Processes a single example string. 58 | 59 | Extracts and processes the image, and ignores the label. We assume that the 60 | image has three channels. 61 | 62 | Args: 63 | example_string: str, an Example protocol buffer. 64 | 65 | Returns: 66 | image_rescaled: the image, resized to `image_size x image_size` and 67 | rescaled to [-1, 1]. Note that Gaussian data augmentation may cause values 68 | to go beyond this range. 
69 | """ 70 | image_string = tf.parse_single_example( 71 | example_string, 72 | features={ 73 | 'image': tf.FixedLenFeature([], dtype=tf.string), 74 | 'label': tf.FixedLenFeature([], tf.int64) 75 | })['image'] 76 | image_decoded = tf.image.decode_image(image_string, channels=3) 77 | image_decoded.set_shape([None, None, 3]) 78 | image_resized = tf.image.resize_images( 79 | image_decoded, [self.image_size, self.image_size], 80 | method=tf.image.ResizeMethod.BILINEAR, 81 | align_corners=True) 82 | image = tf.cast(image_resized, tf.float32) 83 | 84 | if self.data_augmentation is not None: 85 | if self.data_augmentation.enable_random_brightness: 86 | delta = self.data_augmentation.random_brightness_delta 87 | image = tf.image.random_brightness(image, delta) 88 | 89 | if self.data_augmentation.enable_random_saturation: 90 | delta = self.data_augmentation.random_saturation_delta 91 | image = tf.image.random_saturation(image, 1 - delta, 1 + delta) 92 | 93 | if self.data_augmentation.enable_random_contrast: 94 | delta = self.data_augmentation.random_contrast_delta 95 | image = tf.image.random_contrast(image, 1 - delta, 1 + delta) 96 | 97 | if self.data_augmentation.enable_random_hue: 98 | delta = self.data_augmentation.random_hue_delta 99 | image = tf.image.random_hue(image, delta) 100 | 101 | if self.data_augmentation.enable_random_flip: 102 | image = tf.image.random_flip_left_right(image) 103 | 104 | image = 2 * (image / 255.0 - 0.5) # Rescale to [-1, 1]. 105 | 106 | if self.data_augmentation is not None: 107 | if self.data_augmentation.enable_gaussian_noise: 108 | image = image + tf.random_normal( 109 | tf.shape(image)) * self.data_augmentation.gaussian_noise_std 110 | 111 | if self.data_augmentation.enable_jitter: 112 | j = self.data_augmentation.jitter_amount 113 | paddings = tf.constant([[j, j], [j, j], [0, 0]]) 114 | image = tf.pad(image, paddings, 'REFLECT') 115 | image = tf.image.random_crop(image, 116 | [self.image_size, self.image_size, 3]) 117 | 118 | return image 119 | -------------------------------------------------------------------------------- /data/meta_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import sys 4 | import torch 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from utils import device 9 | from paths import META_DATASET_ROOT, META_RECORDS_ROOT, PROJECT_ROOT 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Quiet the TensorFlow warnings 11 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # Quiet the TensorFlow warnings 12 | 13 | sys.path.append(os.path.abspath(META_DATASET_ROOT)) 14 | from meta_dataset.data import dataset_spec as dataset_spec_lib 15 | from meta_dataset.data import learning_spec 16 | from meta_dataset.data import pipeline 17 | from meta_dataset.data import config 18 | 19 | 20 | ALL_METADATASET_NAMES = "ilsvrc_2012 omniglot aircraft cu_birds dtd quickdraw fungi vgg_flower traffic_sign mscoco mnist cifar10 cifar100".split(' ') 21 | TRAIN_METADATASET_NAMES = ALL_METADATASET_NAMES[:8] 22 | TEST_METADATASET_NAMES = ALL_METADATASET_NAMES[-5:] 23 | 24 | SPLIT_NAME_TO_SPLIT = {'train': learning_spec.Split.TRAIN, 25 | 'val': learning_spec.Split.VALID, 26 | 'test': learning_spec.Split.TEST} 27 | 28 | 29 | class MetaDatasetReader(object): 30 | def __init__(self, mode, train_set, validation_set, test_set): 31 | assert (train_set is not None or validation_set is not None or test_set is not None) 32 | 33 | self.data_path = META_RECORDS_ROOT 34 | 
self.train_dataset_next_task = None 35 | self.validation_set_dict = {} 36 | self.test_set_dict = {} 37 | self.specs_dict = {} 38 | gin.parse_config_file(f"{PROJECT_ROOT}/data/meta_dataset_config.gin") 39 | 40 | def _get_dataset_spec(self, items): 41 | if isinstance(items, list): 42 | dataset_specs = [] 43 | for dataset_name in items: 44 | dataset_records_path = os.path.join(self.data_path, dataset_name) 45 | dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path) 46 | dataset_specs.append(dataset_spec) 47 | return dataset_specs 48 | else: 49 | dataset_name = items 50 | dataset_records_path = os.path.join(self.data_path, dataset_name) 51 | dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path) 52 | return dataset_spec 53 | 54 | def _to_torch(self, sample): 55 | for key, val in sample.items(): 56 | if isinstance(val, str): 57 | continue 58 | val = torch.from_numpy(val) 59 | if 'image' in key: 60 | val = val.permute(0, 3, 2, 1) 61 | else: 62 | val = val.long() 63 | sample[key] = val.to(device) 64 | return sample 65 | 66 | def num_classes(self, split_name): 67 | split = SPLIT_NAME_TO_SPLIT[split_name] 68 | all_split_specs = self.specs_dict[SPLIT_NAME_TO_SPLIT['train']] 69 | 70 | if not isinstance(all_split_specs, list): 71 | all_split_specs = [all_split_specs] 72 | 73 | total_n_classes = 0 74 | for specs in all_split_specs: 75 | total_n_classes += len(specs.get_classes(split)) 76 | return total_n_classes 77 | 78 | def build_class_to_identity(self): 79 | split = SPLIT_NAME_TO_SPLIT['train'] 80 | all_split_specs = self.specs_dict[SPLIT_NAME_TO_SPLIT['train']] 81 | 82 | if not isinstance(all_split_specs, list): 83 | all_split_specs = [all_split_specs] 84 | 85 | self.cls_to_identity = dict() 86 | self.dataset_id_to_dataset_name = dict() 87 | self.dataset_to_n_cats = dict() 88 | offset = 0 89 | for dataset_id, specs in enumerate(all_split_specs): 90 | dataset_name = specs.name 91 | self.dataset_id_to_dataset_name[dataset_id] = dataset_name 92 | n_cats = len(specs.get_classes(split)) 93 | self.dataset_to_n_cats[dataset_name] = n_cats 94 | for cat in range(n_cats): 95 | self.cls_to_identity[offset + cat] = (cat, dataset_id) 96 | offset += n_cats 97 | 98 | self.dataset_name_to_dataset_id = {v: k for k, v in 99 | self.dataset_id_to_dataset_name.items()} 100 | 101 | 102 | class MetaDatasetEpisodeReader(MetaDatasetReader): 103 | """ 104 | Class that wraps the Meta-Dataset episode readers. 
105 | """ 106 | def __init__(self, mode, train_set=None, validation_set=None, test_set=None): 107 | super(MetaDatasetEpisodeReader, self).__init__(mode, train_set, validation_set, test_set) 108 | 109 | if mode == 'train': 110 | train_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 111 | self.train_dataset_next_task = self._init_multi_source_dataset( 112 | train_set, SPLIT_NAME_TO_SPLIT['train'], train_episode_desscription) 113 | 114 | if mode == 'val': 115 | test_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 116 | for item in validation_set: 117 | next_task = self._init_single_source_dataset( 118 | item, SPLIT_NAME_TO_SPLIT['val'], test_episode_desscription) 119 | self.validation_set_dict[item] = next_task 120 | 121 | if mode == 'test': 122 | test_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 123 | for item in test_set: 124 | next_task = self._init_single_source_dataset( 125 | item, SPLIT_NAME_TO_SPLIT['test'], test_episode_desscription) 126 | self.test_set_dict[item] = next_task 127 | 128 | def _init_multi_source_dataset(self, items, split, episode_description): 129 | dataset_specs = self._get_dataset_spec(items) 130 | self.specs_dict[split] = dataset_specs 131 | 132 | use_bilevel_ontology_list = [False] * len(items) 133 | use_dag_ontology_list = [False] * len(items) 134 | # Enable ontology aware sampling for Omniglot and ImageNet. 135 | if 'omniglot' in items: 136 | use_bilevel_ontology_list[items.index('omniglot')] = True 137 | if 'ilsvrc_2012' in items: 138 | use_dag_ontology_list[items.index('ilsvrc_2012')] = True 139 | 140 | multi_source_pipeline = pipeline.make_multisource_episode_pipeline( 141 | dataset_spec_list=dataset_specs, 142 | use_dag_ontology_list=use_dag_ontology_list, 143 | use_bilevel_ontology_list=use_bilevel_ontology_list, 144 | split=split, 145 | episode_descr_config = episode_description, 146 | image_size=84, 147 | shuffle_buffer_size=1000) 148 | 149 | iterator = multi_source_pipeline.make_one_shot_iterator() 150 | return iterator.get_next() 151 | 152 | def _init_single_source_dataset(self, dataset_name, split, episode_description): 153 | dataset_spec = self._get_dataset_spec(dataset_name) 154 | self.specs_dict[split] = dataset_spec 155 | 156 | # Enable ontology aware sampling for Omniglot and ImageNet. 
157 | use_bilevel_ontology = False 158 | if 'omniglot' in dataset_name: 159 | use_bilevel_ontology = True 160 | 161 | use_dag_ontology = False 162 | if 'ilsvrc_2012' in dataset_name: 163 | use_dag_ontology = True 164 | 165 | single_source_pipeline = pipeline.make_one_source_episode_pipeline( 166 | dataset_spec=dataset_spec, 167 | use_dag_ontology=use_dag_ontology, 168 | use_bilevel_ontology=use_bilevel_ontology, 169 | split=split, 170 | episode_descr_config=episode_description, 171 | image_size=84, 172 | shuffle_buffer_size=1000) 173 | 174 | iterator = single_source_pipeline.make_one_shot_iterator() 175 | return iterator.get_next() 176 | 177 | def _get_task(self, next_task, session): 178 | episode = session.run(next_task)[0] 179 | task_dict = { 180 | 'context_images': episode[0], 181 | 'context_labels': episode[1], 182 | 'target_images': episode[3], 183 | 'target_labels': episode[4] 184 | } 185 | return self._to_torch(task_dict) 186 | 187 | def get_train_task(self, session): 188 | return self._get_task(self.train_dataset_next_task, session) 189 | 190 | def get_validation_task(self, session, item=None): 191 | item = item if item else list(self.validation_set_dict.keys())[0] 192 | return self._get_task(self.validation_set_dict[item], session) 193 | 194 | def get_test_task(self, session, item=None): 195 | item = item if item else list(self.test_set_dict.keys())[0] 196 | return self._get_task(self.test_set_dict[item], session) 197 | 198 | 199 | class MetaDatasetBatchReader(MetaDatasetReader): 200 | """ 201 | Class that wraps the Meta-Dataset episode readers. 202 | """ 203 | def __init__(self, mode, train_set, validation_set, test_set, batch_size): 204 | super(MetaDatasetBatchReader, self).__init__(mode, train_set, validation_set, test_set) 205 | self.batch_size = batch_size 206 | 207 | if mode == 'train': 208 | self.train_dataset_next_task = self._init_multi_source_dataset( 209 | train_set, SPLIT_NAME_TO_SPLIT['train']) 210 | 211 | if mode == 'val': 212 | for item in validation_set: 213 | next_task = self.validation_dataset = self._init_single_source_dataset( 214 | item, SPLIT_NAME_TO_SPLIT['val']) 215 | self.validation_set_dict[item] = next_task 216 | 217 | if mode == 'test': 218 | for item in test_set: 219 | next_task = self._init_single_source_dataset( 220 | item, SPLIT_NAME_TO_SPLIT['test']) 221 | self.test_set_dict[item] = next_task 222 | 223 | self.build_class_to_identity() 224 | 225 | def _init_multi_source_dataset(self, items, split): 226 | dataset_specs = self._get_dataset_spec(items) 227 | self.specs_dict[split] = dataset_specs 228 | multi_source_pipeline = pipeline.make_multisource_batch_pipeline( 229 | dataset_spec_list=dataset_specs, batch_size=self.batch_size, 230 | split=split, image_size=84, add_dataset_offset=True, shuffle_buffer_size=1000) 231 | 232 | iterator = multi_source_pipeline.make_one_shot_iterator() 233 | return iterator.get_next() 234 | 235 | def _init_single_source_dataset(self, dataset_name, split): 236 | dataset_specs = self._get_dataset_spec(dataset_name) 237 | self.specs_dict[split] = dataset_specs 238 | multi_source_pipeline = pipeline.make_one_source_batch_pipeline( 239 | dataset_spec=dataset_specs, batch_size=self.batch_size, 240 | split=split, image_size=84) 241 | 242 | iterator = multi_source_pipeline.make_one_shot_iterator() 243 | return iterator.get_next() 244 | 245 | def _get_batch(self, next_task, session): 246 | episode = session.run(next_task)[0] 247 | images, labels = episode[0], episode[1] 248 | local_classes, dataset_ids = [], [] 249 | for label 
in labels: 250 | local_class, dataset_id = self.cls_to_identity[label] 251 | local_classes.append(local_class) 252 | dataset_ids.append(dataset_id) 253 | task_dict = { 254 | 'images': images, 255 | 'labels': labels, 256 | 'local_classes': np.array(local_classes), 257 | 'dataset_ids': np.array(dataset_ids), 258 | 'dataset_name': self.dataset_id_to_dataset_name[dataset_ids[-1]] 259 | } 260 | return self._to_torch(task_dict) 261 | 262 | def get_train_batch(self, session): 263 | return self._get_batch(self.train_dataset_next_task, session) 264 | 265 | def get_validation_batch(self, item, session): 266 | return self._get_batch(self.validation_set_dict[item], session) 267 | 268 | def get_test_batch(self, item, session): 269 | return self._get_batch(self.test_set_dict[item], session) 270 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gin 3 | import numpy as np 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def cross_entropy_loss(logits, targets): 10 | log_p_y = F.log_softmax(logits, dim=1) 11 | preds = log_p_y.argmax(1) 12 | labels = targets.type(torch.long) 13 | loss = F.nll_loss(log_p_y, labels, reduction='mean') 14 | acc = torch.eq(preds, labels).float().mean() 15 | stats_dict = {'loss': loss.item(), 'acc': acc.item()} 16 | pred_dict = {'preds': preds.cpu().numpy(), 'labels': labels.cpu().numpy()} 17 | return loss, stats_dict, pred_dict 18 | 19 | 20 | def prototype_loss(support_embeddings, support_labels, 21 | query_embeddings, query_labels, distance='cos'): 22 | n_way = len(query_labels.unique()) 23 | 24 | prots = compute_prototypes(support_embeddings, support_labels, n_way).unsqueeze(0) 25 | embeds = query_embeddings.unsqueeze(1) 26 | 27 | if distance == 'l2': 28 | logits = -torch.pow(embeds - prots, 2).sum(-1) # shape [n_query, n_way] 29 | elif distance == 'cos': 30 | logits = F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) * 10 31 | elif distance == 'lin': 32 | logits = torch.einsum('izd,zjd->ij', embeds, prots) 33 | 34 | return cross_entropy_loss(logits, query_labels) 35 | 36 | 37 | def compute_prototypes(embeddings, labels, n_way): 38 | prots = torch.zeros(n_way, embeddings.shape[-1]).type( 39 | embeddings.dtype).to(embeddings.device) 40 | for i in range(n_way): 41 | prots[i] = embeddings[(labels == i).nonzero(), :].mean(0) 42 | return prots 43 | 44 | 45 | class AdaptiveCosineNCC(nn.Module): 46 | def __init__(self): 47 | super(AdaptiveCosineNCC, self).__init__() 48 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 49 | 50 | def forward(self, support_embeddings, support_labels, 51 | query_embeddings, query_labels, return_logits=False): 52 | n_way = len(query_labels.unique()) 53 | 54 | prots = compute_prototypes(support_embeddings, support_labels, n_way).unsqueeze(0) 55 | embeds = query_embeddings.unsqueeze(1) 56 | logits = F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) * self.scale 57 | 58 | if return_logits: 59 | return logits 60 | 61 | return cross_entropy_loss(logits, query_labels) 62 | 63 | -------------------------------------------------------------------------------- /models/model_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import torch 4 | from functools import partial 5 | 6 | from models.model_utils import CheckPointer 7 | from models.models_dict import DATASET_MODELS_RESNET18 8 
| from utils import device 9 | from paths import PROJECT_ROOT 10 | 11 | 12 | def get_model(num_classes, args): 13 | train_classifier = args['model.classifier'] 14 | model_name = args['model.backbone'] 15 | dropout_rate = args.get('model.dropout', 0) 16 | 17 | if 'pnf' in model_name: 18 | from models.resnet18_pnf import resnet18 19 | 20 | base_network_name = DATASET_MODELS_RESNET18['ilsvrc_2012'] 21 | base_network_path = os.path.join(PROJECT_ROOT, 'weights', base_network_name, 'model_best.pth.tar') 22 | model_fn = partial(resnet18, dropout=dropout_rate, 23 | pretrained_model_path=base_network_path) 24 | else: 25 | from models.resnet18 import resnet18 26 | model_fn = partial(resnet18, dropout=dropout_rate) 27 | 28 | model = model_fn(classifier=train_classifier, 29 | num_classes=num_classes, 30 | global_pool=False) 31 | model.to(device) 32 | return model 33 | 34 | 35 | def get_optimizer(model, args, params=None): 36 | learning_rate = args['train.learning_rate'] 37 | weight_decay = args['train.weight_decay'] 38 | optimizer = args['train.optimizer'] 39 | params = model.parameters() if params is None else params 40 | if optimizer == 'adam': 41 | optimizer = torch.optim.Adam(params, 42 | lr=learning_rate, 43 | weight_decay=weight_decay) 44 | elif optimizer == 'momentum': 45 | optimizer = torch.optim.SGD(params, 46 | lr=learning_rate, 47 | momentum=0.9, nesterov=args['train.nesterov_momentum'], 48 | weight_decay=weight_decay) 49 | else: 50 | assert False, 'No such optimizer' 51 | return optimizer 52 | 53 | 54 | def get_domain_extractors(trainset, dataset_models, args): 55 | if 'pnf' in args['model.backbone']: 56 | return get_pnf_extractor(trainset, dataset_models, args) 57 | else: 58 | return get_multinet_extractor(trainset, dataset_models, args) 59 | 60 | 61 | def get_multinet_extractor(trainsets, dataset_models, args): 62 | extractors = dict() 63 | for dataset_name in trainsets: 64 | if dataset_name not in dataset_models: 65 | continue 66 | args['model.name'] = dataset_models[dataset_name] 67 | extractor = get_model(None, args) 68 | checkpointer = CheckPointer(args, extractor, optimizer=None) 69 | extractor.eval() 70 | checkpointer.restore_model(ckpt='best', strict=False) 71 | extractors[dataset_name] = extractor 72 | 73 | def embed_many(images, return_type='dict'): 74 | with torch.no_grad(): 75 | all_features = dict() 76 | for name, extractor in extractors.items(): 77 | all_features[name] = extractor.embed(images) 78 | if return_type == 'list': 79 | return list(all_features.values()) 80 | else: 81 | return all_features 82 | return embed_many 83 | 84 | 85 | def get_pnf_extractor(trainsets, dataset_models, args): 86 | film_layers = dict() 87 | for dataset_name in trainsets: 88 | if dataset_name not in dataset_models or 'ilsvrc' in dataset_name: 89 | continue 90 | ckpt_path = os.path.join(PROJECT_ROOT, 'weights', dataset_models[dataset_name], 91 | 'model_best.pth.tar') 92 | state_dict = torch.load(ckpt_path, map_location=device)['state_dict'] 93 | film_layers[dataset_name] = {k: v for k, v in state_dict.items() 94 | if 'cls' not in k} 95 | print('Loaded FiLM layers from {}'.format(ckpt_path)) 96 | 97 | # define the base extractor 98 | base_extractor = get_model(None, args) 99 | base_extractor.eval() 100 | base_layers = {k: v for k, v in base_extractor.get_state_dict().items() if 'cls' not in k} 101 | 102 | # initialize film layers of base extractor to identity 103 | film_layers['ilsvrc_2012'] = {k: v.clone() for k, v in base_layers.items()} 104 | 105 | def embed_many(images, return_type='dict'): 
106 | with torch.no_grad(): 107 | all_features = dict() 108 | 109 | for domain_name in trainsets: 110 | # setting up domain-specific film layers 111 | domain_layers = film_layers[domain_name] 112 | for layer_name in base_layers.keys(): 113 | base_layers[layer_name].data.copy_(domain_layers[layer_name].data) 114 | 115 | # inference 116 | all_features[domain_name] = base_extractor.embed(images) 117 | if return_type == 'list': 118 | return list(all_features.values()) 119 | else: 120 | return all_features 121 | return embed_many 122 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | from torch import nn 6 | from torch.optim.lr_scheduler import (MultiStepLR, ExponentialLR, 7 | CosineAnnealingWarmRestarts, 8 | CosineAnnealingLR) 9 | from utils import check_dir, device 10 | from paths import PROJECT_ROOT 11 | 12 | 13 | sigmoid = nn.Sigmoid() 14 | 15 | 16 | def cosine_sim(embeds, prots): 17 | prots = prots.unsqueeze(0) 18 | embeds = embeds.unsqueeze(1) 19 | return F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) 20 | 21 | 22 | 23 | class CosineClassifier(nn.Module): 24 | def __init__(self, n_feat, num_classes): 25 | super(CosineClassifier, self).__init__() 26 | self.num_classes = num_classes 27 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 28 | weight = torch.FloatTensor(n_feat, num_classes).normal_( 29 | 0.0, np.sqrt(2.0 / num_classes)) 30 | self.weight = nn.Parameter(weight, requires_grad=True) 31 | 32 | def forward(self, x): 33 | x_norm = torch.nn.functional.normalize(x, p=2, dim=-1, eps=1e-12) 34 | weight = torch.nn.functional.normalize(self.weight, p=2, dim=0, eps=1e-12) 35 | cos_dist = x_norm @ weight 36 | scores = self.scale * cos_dist 37 | return scores 38 | 39 | def extra_repr(self): 40 | s = 'CosineClassifier: input_channels={}, num_classes={}; learned_scale: {}'.format( 41 | self.weight.shape[0], self.weight.shape[1], self.scale.item()) 42 | return s 43 | 44 | 45 | class CosineConv(nn.Module): 46 | def __init__(self, n_feat, num_classes, kernel_size=1): 47 | super(CosineConv, self).__init__() 48 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 49 | weight = torch.FloatTensor(num_classes, n_feat, 1, 1).normal_( 50 | 0.0, np.sqrt(2.0/num_classes)) 51 | self.weight = nn.Parameter(weight, requires_grad=True) 52 | 53 | def forward(self, x): 54 | x_normalized = torch.nn.functional.normalize( 55 | x, p=2, dim=1, eps=1e-12) 56 | weight = torch.nn.functional.normalize( 57 | self.weight, p=2, dim=1, eps=1e-12) 58 | 59 | cos_dist = torch.nn.functional.conv2d(x_normalized, weight) 60 | scores = self.scale * cos_dist 61 | return scores 62 | 63 | def extra_repr(self): 64 | s = 'CosineConv: num_inputs={}, num_classes={}, kernel_size=1; scale_value: {}'.format( 65 | self.weight.shape[0], self.weight.shape[1], self.scale.item()) 66 | return s 67 | 68 | 69 | class CheckPointer(object): 70 | def __init__(self, args, model=None, optimizer=None): 71 | self.model = model 72 | self.optimizer = optimizer 73 | self.args = args 74 | self.model_path = os.path.join(PROJECT_ROOT, 'weights', args['model.name']) 75 | self.last_ckpt = os.path.join(self.model_path, 'checkpoint.pth.tar') 76 | self.best_ckpt = os.path.join(self.model_path, 'model_best.pth.tar') 77 | 78 | def restore_model(self, ckpt='last', model=True, 79 | optimizer=True, strict=True): 80 | if not 
os.path.exists(self.model_path): 81 | assert False, "Model is not found at {}".format(self.model_path) 82 | self.last_ckpt = os.path.join(self.model_path, 'checkpoint.pth.tar') 83 | self.best_ckpt = os.path.join(self.model_path, 'model_best.pth.tar') 84 | ckpt_path = self.last_ckpt if ckpt == 'last' else self.best_ckpt 85 | 86 | if os.path.isfile(ckpt_path): 87 | print("=> loading {} checkpoint '{}'".format(ckpt, ckpt_path)) 88 | ch = torch.load(ckpt_path, map_location=device) 89 | if self.model is not None and model: 90 | self.model.load_state_dict(ch['state_dict'], strict=strict) 91 | if self.optimizer is not None and optimizer: 92 | self.optimizer.load_state_dict(ch['optimizer']) 93 | else: 94 | assert False, "No checkpoint! %s" % ckpt_path 95 | 96 | return ch.get('epoch', None), ch.get('best_val_loss', None), ch.get('best_val_acc', None) 97 | 98 | def save_checkpoint(self, epoch, best_val_acc, best_val_loss, 99 | is_best, filename='checkpoint.pth.tar', 100 | optimizer=None, state_dict=None, extra=None): 101 | state_dict = self.model.state_dict() if state_dict is None else state_dict 102 | state = {'epoch': epoch + 1, 103 | 'args': self.args, 104 | 'state_dict': state_dict, 105 | 'best_val_acc': best_val_acc, 106 | 'best_val_loss': best_val_loss} 107 | 108 | if extra is not None: 109 | state.update(extra) 110 | if optimizer is not None: 111 | state['optimizer'] = optimizer.state_dict() 112 | 113 | model_path = check_dir(self.model_path, True, False) 114 | torch.save(state, os.path.join(model_path, filename)) 115 | if is_best: 116 | shutil.copyfile(os.path.join(model_path, filename), 117 | os.path.join(model_path, 'model_best.pth.tar')) 118 | 119 | 120 | class UniformStepLR(object): 121 | def __init__(self, optimizer, args, start_iter): 122 | self.iter = start_iter 123 | self.max_iter = args['train.max_iter'] 124 | step_iters = self.compute_milestones(args) 125 | self.lr_scheduler = MultiStepLR( 126 | optimizer, milestones=step_iters, last_epoch=start_iter-1, 127 | gamma=args['train.step_decay_gamma']) 128 | 129 | def step(self, _iter): 130 | self.iter += 1 131 | self.lr_scheduler.step() 132 | stop_training = self.iter >= self.max_iter 133 | return stop_training 134 | 135 | def compute_milestones(self, args): 136 | max_iter = args['train.max_iter'] 137 | step_size = max_iter / args['train.decay_step_freq'] 138 | step_iters = [0] 139 | while step_iters[-1] < max_iter: 140 | step_iters.append(step_iters[-1] + step_size) 141 | return self.step_iters[1:] 142 | 143 | 144 | class ExpDecayLR(object): 145 | def __init__(self, optimizer, args, start_iter): 146 | self.iter = start_iter 147 | self.max_iter = args['train.max_iter'] 148 | self.start_decay_iter = args['train.exp_decay_start_iter'] 149 | gamma = self.compute_gamma(args) 150 | schedule_start = max(start_iter - self.start_decay_iter, 0) - 1 151 | self.lr_scheduler = ExponentialLR(optimizer, gamma=gamma, 152 | last_epoch=schedule_start) 153 | 154 | def step(self, _iter): 155 | self.iter += 1 156 | if _iter > self.start_decay_iter: 157 | self.lr_scheduler.step() 158 | stop_training = self.iter >= self.max_iter 159 | return stop_training 160 | 161 | def compute_gamma(self, args): 162 | last_iter, start_iter = args['train.max_iter'], args['train.exp_decay_start_iter'] 163 | start_lr, last_lr = args['train.learning_rate'], args['train.exp_decay_final_lr'] 164 | return np.power(last_lr / start_lr, 1 / (last_iter - start_iter)) 165 | 166 | 167 | class CosineAnnealRestartLR(object): 168 | def __init__(self, optimizer, args, start_iter): 169 | 
self.iter = start_iter 170 | self.max_iter = args['train.max_iter'] 171 | self.lr_scheduler = CosineAnnealingWarmRestarts( 172 | optimizer, args['train.cosine_anneal_freq'], last_epoch=start_iter-1) 173 | # self.lr_scheduler = CosineAnnealingLR( 174 | # optimizer, args['train.cosine_anneal_freq'], last_epoch=start_iter-1) 175 | 176 | def step(self, _iter): 177 | self.iter += 1 178 | self.lr_scheduler.step(_iter) 179 | stop_training = self.iter >= self.max_iter 180 | return stop_training 181 | 182 | -------------------------------------------------------------------------------- /models/models_dict.py: -------------------------------------------------------------------------------- 1 | DATASET_MODELS_RESNET18 = { 2 | 'ilsvrc_2012': 'imagenet-net', 3 | 'omniglot': 'omniglot-net', 4 | 'aircraft': 'aircraft-net', 5 | 'cu_birds': 'birds-net', 6 | 'dtd': 'textures-net', 7 | 'quickdraw': 'quickdraw-net', 8 | 'fungi': 'fungi-net', 9 | 'vgg_flower': 'vgg_flower-net' 10 | } 11 | 12 | 13 | DATASET_MODELS_RESNET18_PNF = { 14 | 'omniglot': 'omniglot-film', 15 | 'aircraft': 'aircraft-film', 16 | 'cu_birds': 'birds-film', 17 | 'dtd': 'textures-film', 18 | 'quickdraw': 'quickdraw-film', 19 | 'fungi': 'fungi-film', 20 | 'vgg_flower': 'vgg_flower-film' 21 | } 22 | 23 | DATASET_MODELS_DICT = {'resnet18': DATASET_MODELS_RESNET18, 24 | 'resnet18_pnf': DATASET_MODELS_RESNET18_PNF} 25 | -------------------------------------------------------------------------------- /models/resnet18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from models.model_utils import CosineClassifier 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | 52 | def __init__(self, block, layers, classifier=None, num_classes=64, 53 | dropout=0.0, global_pool=True): 54 | super(ResNet, self).__init__() 55 | self.initial_pool = False 56 | inplanes = self.inplanes = 64 57 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=2, 58 | padding=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(self.inplanes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 62 | self.layer1 = self._make_layer(block, inplanes, layers[0]) 63 | self.layer2 = self._make_layer(block, inplanes * 2, layers[1], stride=2) 64 | self.layer3 = self._make_layer(block, inplanes * 4, layers[2], 
stride=2) 65 | self.layer4 = self._make_layer(block, inplanes * 8, layers[3], stride=2) 66 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 67 | self.dropout = nn.Dropout(dropout) 68 | self.outplanes = 512 69 | 70 | # handle classifier creation 71 | if num_classes is not None: 72 | if classifier == 'linear': 73 | self.cls_fn = nn.Linear(self.outplanes, num_classes) 74 | elif classifier == 'cosine': 75 | self.cls_fn = CosineClassifier(self.outplanes, num_classes) 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 80 | elif isinstance(m, nn.BatchNorm2d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | def _make_layer(self, block, planes, blocks, stride=1): 85 | downsample = None 86 | if stride != 1 or self.inplanes != planes * block.expansion: 87 | downsample = nn.Sequential( 88 | conv1x1(self.inplanes, planes * block.expansion, stride), 89 | nn.BatchNorm2d(planes * block.expansion), 90 | ) 91 | 92 | layers = [] 93 | layers.append(block(self.inplanes, planes, stride, downsample)) 94 | self.inplanes = planes * block.expansion 95 | for _ in range(1, blocks): 96 | layers.append(block(self.inplanes, planes)) 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | x = self.embed(x) 102 | x = self.dropout(x) 103 | x = self.cls_fn(x) 104 | return x 105 | 106 | def embed(self, x, param_dict=None): 107 | x = self.conv1(x) 108 | x = self.bn1(x) 109 | x = self.relu(x) 110 | if self.initial_pool: 111 | x = self.maxpool(x) 112 | 113 | x = self.layer1(x) 114 | x = self.layer2(x) 115 | x = self.layer3(x) 116 | x = self.layer4(x) 117 | 118 | x = self.avgpool(x) 119 | return x.squeeze() 120 | 121 | def get_state_dict(self): 122 | """Outputs all the state elements""" 123 | return self.state_dict() 124 | 125 | def get_parameters(self): 126 | """Outputs all the parameters""" 127 | return [v for k, v in self.named_parameters()] 128 | 129 | 130 | def resnet18(pretrained=False, pretrained_model_path=None, **kwargs): 131 | """ 132 | Constructs a ResNet-18 model. 
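The `classifier` argument selects the head ('linear' gives an nn.Linear layer, 'cosine' gives a CosineClassifier); with num_classes=None no head is built and the network is used only through embed().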
133 | """ 134 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 135 | if pretrained: 136 | device = model.get_state_dict()[0].device 137 | ckpt_dict = torch.load(pretrained_model_path, map_location=device) 138 | model.load_parameters(ckpt_dict['state_dict'], strict=False) 139 | print('Loaded shared weights from {}'.format(pretrained_model_path)) 140 | return model 141 | -------------------------------------------------------------------------------- /models/resnet18_pnf.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from functools import partial 4 | 5 | from models.model_utils import CosineClassifier 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class CatFilm(nn.Module): 20 | """Film layer that performs per-channel affine transformation.""" 21 | def __init__(self, planes): 22 | super(CatFilm, self).__init__() 23 | self.gamma = nn.Parameter(torch.ones(1, planes)) 24 | self.beta = nn.Parameter(torch.zeros(1, planes)) 25 | 26 | def forward(self, x): 27 | gamma = self.gamma.view(1, -1, 1, 1) 28 | beta = self.beta.view(1, -1, 1, 1) 29 | return gamma * x + beta 30 | 31 | 32 | class BasicBlockFilm(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlockFilm, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = nn.BatchNorm2d(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | self.film1 = CatFilm(planes) 45 | self.film2 = CatFilm(planes) 46 | 47 | def forward(self, x): 48 | identity = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.film1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | out = self.film2(out) 58 | 59 | if self.downsample is not None: 60 | identity = self.downsample(x) 61 | 62 | out += identity 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, layers, classifier=None, num_classes=None, 70 | dropout=0.0, global_pool=True): 71 | super(ResNet, self).__init__() 72 | self.initial_pool = False 73 | self.film_normalize = CatFilm(3) 74 | inplanes = self.inplanes = 64 75 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=2, 76 | padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 80 | self.layer1 = self._make_layer(block, inplanes, layers[0]) 81 | self.layer2 = self._make_layer(block, inplanes * 2, layers[1], stride=2) 82 | self.layer3 = self._make_layer(block, inplanes * 4, layers[2], stride=2) 83 | self.layer4 = self._make_layer(block, inplanes * 8, layers[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | self.dropout = nn.Dropout(dropout) 86 | self.outplanes = 512 87 | 88 | # handle classifier creation 89 | if num_classes is not None: 90 | if classifier == 'linear': 91 | self.cls_fn = nn.Linear(self.outplanes, num_classes) 92 | elif 
classifier == 'cosine': 93 | self.cls_fn = CosineClassifier(self.outplanes, num_classes) 94 | 95 | # initialize everything 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 99 | nonlinearity='relu') 100 | elif isinstance(m, nn.BatchNorm2d): 101 | nn.init.constant_(m.weight, 1) 102 | nn.init.constant_(m.bias, 0) 103 | 104 | def _make_layer(self, block, planes, blocks, stride=1): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | conv1x1(self.inplanes, planes * block.expansion, stride), 109 | nn.BatchNorm2d(planes * block.expansion)) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample)) 113 | self.inplanes = planes * block.expansion 114 | for _ in range(1, blocks): 115 | layers.append(block(self.inplanes, planes)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.embed(x) 121 | x = self.cls_fn(x) 122 | return x 123 | 124 | def embed(self, x, squeeze=True, param_dict=None): 125 | """Computing the features""" 126 | x = self.film_normalize(x) 127 | x = self.conv1(x) 128 | x = self.bn1(x) 129 | x = self.relu(x) 130 | if self.initial_pool: 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | 138 | x = self.avgpool(x) 139 | return x.squeeze() 140 | 141 | def get_state_dict(self): 142 | """Outputs the state elements that are domain-specific""" 143 | return {k: v for k, v in self.state_dict().items() 144 | if 'film' in k or 'cls' in k or 'running' in k} 145 | 146 | def get_parameters(self): 147 | """Outputs only the parameters that are domain-specific""" 148 | return [v for k, v in self.named_parameters() 149 | if 'film' in k or 'cls' in k] 150 | 151 | 152 | def resnet18(pretrained=False, pretrained_model_path=None, **kwargs): 153 | """ 154 | Constructs a FiLM adapted ResNet-18 model. 
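If `pretrained_model_path` is given, the shared convolutional weights are loaded from that checkpoint (classifier weights are skipped), leaving the per-domain FiLM and classifier parameters to be trained.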
155 | """ 156 | model = ResNet(BasicBlockFilm, [2, 2, 2, 2], **kwargs) 157 | 158 | # loading shared convolutional weights 159 | if pretrained_model_path is not None: 160 | device = model.get_parameters()[0].device 161 | ckpt_dict = torch.load(pretrained_model_path, map_location=device)['state_dict'] 162 | shared_state = {k: v for k, v in ckpt_dict.items() if 'cls' not in k} 163 | model.load_state_dict(shared_state, strict=False) 164 | print('Loaded shared weights from {}'.format(pretrained_model_path)) 165 | return model 166 | -------------------------------------------------------------------------------- /models/sur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from models.model_utils import sigmoid, cosine_sim 5 | from models.losses import prototype_loss 6 | from utils import device 7 | 8 | 9 | def apply_selection(features_dict, lambdas, normalize=True): 10 | """ 11 | Performs masking of features via pointwise multiplying by lambda 12 | """ 13 | lambdas_01 = sigmoid(lambdas) 14 | features_list = list(features_dict.values()) 15 | if normalize: 16 | features_list = [f / (f ** 2).sum(-1, keepdim=True).sqrt() 17 | for f in features_list] 18 | n_cont = features_list[0].shape[0] 19 | concat_feat = torch.stack(features_list, -1) 20 | return (concat_feat * lambdas_01).reshape([n_cont, -1]) 21 | 22 | 23 | def sur(context_features_dict, context_labels, max_iter=40): 24 | """ 25 | SUR method: optimizes selection parameters lambda 26 | """ 27 | lambdas = torch.zeros([1, 1, len(context_features_dict)]).to(device) 28 | lambdas.requires_grad_(True) 29 | n_classes = len(np.unique(context_labels.cpu().numpy())) 30 | optimizer = torch.optim.Adadelta([lambdas], lr=(3e+3 / n_classes)) 31 | 32 | for i in range(max_iter): 33 | optimizer.zero_grad() 34 | selected_features = apply_selection(context_features_dict, lambdas) 35 | loss, stat, _ = prototype_loss(selected_features, context_labels, 36 | selected_features, context_labels) 37 | 38 | loss.backward() 39 | optimizer.step() 40 | return lambdas 41 | -------------------------------------------------------------------------------- /paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | PROJECT_ROOT = '/'.join(os.path.realpath(__file__).split('/')[:-1]) 5 | META_DATASET_ROOT = os.environ['META_DATASET_ROOT'] 6 | META_RECORDS_ROOT = os.environ['RECORDS'] 7 | META_DATA_ROOT = '/'.join(META_RECORDS_ROOT.split('/')[:-1]) 8 | -------------------------------------------------------------------------------- /scripts/dump_test_episodes.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='test' 4 | backbone='resnet18_pnf' 5 | for dataset in "ilsvrc_2012" "omniglot" "aircraft" "cu_birds" "dtd" "quickdraw" "fungi" "vgg_flower" " traffic_sign" "mscoco" "mnist" "cifar10" "cifar100"; do 6 | python ./data/create_features_db.py --data.train ilsvrc_2012 omniglot aircraft cu_birds dtd quickdraw fungi vgg_flower --data.val ${dataset} --data.test ${dataset} --model.backbone=${backbone} --dump.mode=${mode} --dump.size=600 7 | done 8 | 9 | -------------------------------------------------------------------------------- /scripts/train_networks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################### Training independent feature extractors ################### 4 | function train_fn { 5 | python 
train_net.py --model.name=$1 --model.backbone=resnet18 --data.train $2 --data.val $2 --data.test $2 --train.batch_size=$3 --train.learning_rate=$4 --train.max_iter=$5 --train.cosine_anneal_freq=$6 --train.eval_freq=$6 6 | } 7 | 8 | # Train an independent feature extractor on every training dataset (the following models could be trained in parallel) 9 | 10 | # ImageNet 11 | NAME="imagenet-net"; TRAINSET="ilsvrc_2012"; BATCH_SIZE=64; LR="3e-2"; MAX_ITER=480000; ANNEAL_FREQ=48000 12 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 13 | 14 | # Omniglot 15 | NAME="omniglot-net"; TRAINSET="omniglot"; BATCH_SIZE=16; LR="3e-2"; MAX_ITER=50000; ANNEAL_FREQ=3000 16 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 17 | 18 | # Aircraft 19 | NAME="aircraft-net"; TRAINSET="aircraft"; BATCH_SIZE=8; LR="3e-2"; MAX_ITER=50000; ANNEAL_FREQ=3000 20 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 21 | 22 | # Birds 23 | NAME="birds-net"; TRAINSET="cu_birds"; BATCH_SIZE=16; LR="3e-2"; MAX_ITER=50000; ANNEAL_FREQ=3000 24 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 25 | 26 | # Textures 27 | NAME="textures-net"; TRAINSET="dtd"; BATCH_SIZE=32; LR="3e-2"; MAX_ITER=50000; ANNEAL_FREQ=1500 28 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 29 | 30 | # Quick Draw 31 | NAME="quickdraw-net"; TRAINSET="quickdraw"; BATCH_SIZE=64; LR="1e-2"; MAX_ITER=480000; ANNEAL_FREQ=48000 32 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 33 | 34 | # Fungi 35 | NAME="fungi-net"; TRAINSET="fungi"; BATCH_SIZE=32; LR="3e-2"; MAX_ITER=480000; ANNEAL_FREQ=15000 36 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 37 | 38 | # VGG Flower 39 | NAME="vgg_flower-net"; TRAINSET="vgg_flower"; BATCH_SIZE=8; LR="3e-2"; MAX_ITER=50000; ANNEAL_FREQ=1500 40 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 41 | 42 | echo "All Feature Extractors are trained!" 
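# Note: the --model.name values above match DATASET_MODELS_RESNET18 in models/models_dict.py, which test.py uses (via get_domain_extractors) to map each training domain to its extractor.
# Hypothetical example: to re-train only the Textures extractor with different settings, train_fn can be called directly, e.g.:
#   train_fn "textures-net" "dtd" 16 "3e-2" 50000 1500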
43 | -------------------------------------------------------------------------------- /scripts/train_pnf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################### Training independent feature extractors ################### 4 | function train_fn { 5 | python train_pnf.py --model.name=$1 --model.backbone=resnet18_pnf --data.train $2 --data.val $2 --data.test $2 --train.batch_size=$3 --train.learning_rate=$4 --train.max_iter=$5 --train.cosine_anneal_freq=$6 --train.eval_freq=$6 6 | } 7 | 8 | # Train base feature extractor on ImageNet 9 | NAME="imagenet-net"; TRAINSET="ilsvrc_2012"; BATCH_SIZE=64; LR="3e-2"; MAX_ITER=480000; ANNEAL_FREQ=48000 10 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 11 | 12 | #Then, train domain specific FiLM layers on every other dataset 13 | # Omniglot 14 | NAME="omniglot-film"; TRAINSET="omniglot"; BATCH_SIZE=16; LR="3e-2"; MAX_ITER=40000; ANNEAL_FREQ=3000 15 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 16 | 17 | # Aircraft 18 | NAME="aircraft-film"; TRAINSET="aircraft"; BATCH_SIZE=32; LR="1e-2"; MAX_ITER=30000; ANNEAL_FREQ=1500 19 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 20 | 21 | # Birds 22 | NAME="birds-film"; TRAINSET="cu_birds"; BATCH_SIZE=16; LR="3e-2"; MAX_ITER=40000; ANNEAL_FREQ=1500 23 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 24 | 25 | # Textures 26 | NAME="textures-film"; TRAINSET="dtd"; BATCH_SIZE=16; LR="3e-2"; MAX_ITER=40000; ANNEAL_FREQ=1500 27 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 28 | 29 | # Quick Draw 30 | NAME="quickdraw-film"; TRAINSET="quickdraw"; BATCH_SIZE=32; LR="1e-2"; MAX_ITER=400000; ANNEAL_FREQ=15000 31 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 32 | 33 | # Fungi 34 | NAME="fungi-film"; TRAINSET="fungi"; BATCH_SIZE=32; LR="1e-2"; MAX_ITER=400000; ANNEAL_FREQ=15000 35 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 36 | 37 | # VGG Flower 38 | NAME="vgg_flower-film"; TRAINSET="vgg_flower"; BATCH_SIZE=16; LR="1e-2"; MAX_ITER=30000; ANNEAL_FREQ=3000 39 | train_fn $NAME $TRAINSET $BATCH_SIZE $LR $MAX_ITER $ANNEAL_FREQ 40 | 41 | echo "All Feature Extractors are trained!" 
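# Note: the FiLM model names above match DATASET_MODELS_RESNET18_PNF in models/models_dict.py. ImageNet is not in that dict: the imagenet-net model trained first provides the shared convolutional weights that the FiLM models load (see models/resnet18_pnf.py).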
42 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import os 5 | import numpy as np 6 | import tensorflow as tf 7 | from tqdm import tqdm 8 | from tabulate import tabulate 9 | 10 | from data.meta_dataset_reader import (MetaDatasetEpisodeReader, 11 | TRAIN_METADATASET_NAMES, 12 | ALL_METADATASET_NAMES) 13 | from models.model_utils import CheckPointer, sigmoid, cosine_sim 14 | from models.model_helpers import get_domain_extractors 15 | from models.losses import prototype_loss 16 | from models.sur import apply_selection, sur 17 | from models.models_dict import DATASET_MODELS_DICT 18 | from utils import device 19 | from config import args 20 | 21 | 22 | def main(): 23 | LIMITER = 600 # has to be 600 for the testing under MetaDataset protocol 24 | config = tf.compat.v1.ConfigProto() 25 | config.gpu_options.allow_growth = True 26 | 27 | # Setting up datasets 28 | extractor_domains = TRAIN_METADATASET_NAMES 29 | all_test_datasets = ALL_METADATASET_NAMES 30 | loader = MetaDatasetEpisodeReader('test', 31 | train_set=extractor_domains, 32 | validation_set=extractor_domains, 33 | test_set=all_test_datasets) 34 | 35 | # define the embedding method 36 | dataset_models = DATASET_MODELS_DICT[args['model.backbone']] 37 | embed_many = get_domain_extractors(extractor_domains, dataset_models, args) 38 | 39 | accs_names = ['SUR'] 40 | 41 | all_accs = dict() 42 | # Go over all test datasets 43 | for test_dataset in all_test_datasets: 44 | print(test_dataset) 45 | all_accs[test_dataset] = {name: [] for name in accs_names} 46 | 47 | with tf.compat.v1.Session(config=config) as session: 48 | for idx in tqdm(range(LIMITER)): 49 | # extract image features and labels 50 | sample = loader.get_test_task(session, test_dataset) 51 | context_features_dict = embed_many(sample['context_images']) 52 | target_features_dict = embed_many(sample['target_images']) 53 | context_labels = sample['context_labels'].to(device) 54 | target_labels = sample['target_labels'].to(device) 55 | 56 | # optimize selection parameters and perform feature selection 57 | selection_params = sur(context_features_dict, context_labels, max_iter=40) 58 | selected_context = apply_selection(context_features_dict, selection_params) 59 | selected_target = apply_selection(target_features_dict, selection_params) 60 | 61 | final_acc = prototype_loss(selected_context, context_labels, 62 | selected_target, target_labels)[1]['acc'] 63 | all_accs[test_dataset]['SUR'].append(final_acc) 64 | 65 | # Make a nice accuracy table 66 | rows = [] 67 | for dataset_name in all_test_datasets: 68 | row = [dataset_name] 69 | for model_name in accs_names: 70 | acc = np.array(all_accs[dataset_name][model_name]) * 100 71 | mean_acc = acc.mean() 72 | conf = (1.96 * acc.std()) / np.sqrt(len(acc)) 73 | row.append(f"{mean_acc:0.2f} +- {conf:0.2f}") 74 | rows.append(row) 75 | 76 | table = tabulate(rows, headers=['model \\ data'] + accs_names, floatfmt=".2f") 77 | print(table) 78 | print("\n") 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /test_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code allows you to evaluate performance of a single feature extractor + NCC 3 | on several dataset. 
4 | 5 | For example, to test a resnet18 feature extractor trained on cu_birds 6 | (that you downloaded) on test splits of ilsrvc_2012, dtd, vgg_flower, quickdraw, run: 7 | python ./test_extractor.py --model.name=birds-net --model.backbone=resnet18 --data.test ilsrvc_2012 dtd vgg_flower quickdraw 8 | """ 9 | 10 | import os 11 | import torch 12 | import tensorflow as tf 13 | import numpy as np 14 | from tqdm import tqdm 15 | from tabulate import tabulate 16 | 17 | from models.losses import prototype_loss 18 | from models.model_utils import CheckPointer 19 | from models.model_helpers import get_model 20 | from data.meta_dataset_reader import (MetaDatasetEpisodeReader, MetaDatasetBatchReader) 21 | from config import args 22 | 23 | 24 | def main(): 25 | TEST_SIZE = 600 26 | 27 | # Setting up datasets 28 | trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test'] 29 | test_loader = MetaDatasetEpisodeReader('test', trainsets, valsets, testsets) 30 | model = get_model(None, args) 31 | checkpointer = CheckPointer(args, model, optimizer=None) 32 | checkpointer.restore_model(ckpt='best', strict=False) 33 | model.eval() 34 | 35 | accs_names = ['NCC'] 36 | var_accs = dict() 37 | 38 | config = tf.compat.v1.ConfigProto() 39 | config.gpu_options.allow_growth = True 40 | with tf.compat.v1.Session(config=config) as session: 41 | # go over each test domain 42 | for dataset in testsets: 43 | print(dataset) 44 | var_accs[dataset] = {name: [] for name in accs_names} 45 | 46 | for i in tqdm(range(TEST_SIZE)): 47 | with torch.no_grad(): 48 | sample = test_loader.get_test_task(session, dataset) 49 | context_features = model.embed(sample['context_images']) 50 | target_features = model.embed(sample['target_images']) 51 | context_labels = sample['context_labels'] 52 | target_labels = sample['target_labels'] 53 | _, stats_dict, _ = prototype_loss( 54 | context_features, context_labels, 55 | target_features, target_labels) 56 | var_accs[dataset]['NCC'].append(stats_dict['acc']) 57 | 58 | # Print nice results table 59 | rows = [] 60 | for dataset_name in testsets: 61 | row = [dataset_name] 62 | for model_name in accs_names: 63 | acc = np.array(var_accs[dataset_name][model_name]) * 100 64 | mean_acc = acc.mean() 65 | conf = (1.96 * acc.std()) / np.sqrt(len(acc)) 66 | row.append(f"{mean_acc:0.2f} +- {conf:0.2f}") 67 | rows.append(row) 68 | 69 | table = tabulate(rows, headers=['model \\ data'] + accs_names, floatfmt=".2f") 70 | print(table) 71 | print("\n") 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /test_offline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | from tabulate import tabulate 7 | from torch.utils.data import DataLoader 8 | 9 | from data.lmdb_dataset import LMDBDataset 10 | from data.meta_dataset_reader import TRAIN_METADATASET_NAMES, ALL_METADATASET_NAMES 11 | from models.model_utils import CheckPointer, sigmoid, cosine_sim 12 | from models.model_helpers import get_domain_extractors 13 | from models.losses import prototype_loss 14 | from models.sur import apply_selection, sur 15 | from models.models_dict import DATASET_MODELS_DICT 16 | from utils import device 17 | from config import args 18 | 19 | 20 | def main(): 21 | LIMITER = 600 22 | 23 | # Setting up datasets 24 | extractor_domains = TRAIN_METADATASET_NAMES 25 | all_test_datasets = 
ALL_METADATASET_NAMES 26 | dump_name = args['dump.name'] if args['dump.name'] else 'test_dump' 27 | testset = LMDBDataset(extractor_domains, all_test_datasets, 28 | args['model.backbone'], 'test', dump_name, LIMITER) 29 | 30 | # define the embedding method 31 | dataset_models = DATASET_MODELS_DICT[args['model.backbone']] 32 | embed_many = get_domain_extractors(extractor_domains, dataset_models, args) 33 | 34 | accs_names = ['SUR'] 35 | all_accs = dict() 36 | # Go over all test datasets 37 | for test_dataset in all_test_datasets: 38 | print(test_dataset) 39 | testset.set_sampling_dataset(test_dataset) 40 | test_loader = DataLoader(testset, batch_size=None, batch_sampler=None, num_workers=16) 41 | all_accs[test_dataset] = {name: [] for name in accs_names} 42 | 43 | for sample in tqdm(test_loader): 44 | context_labels = sample['context_labels'].to(device) 45 | target_labels = sample['target_labels'].to(device) 46 | context_features_dict = {k: v.to(device) for k, v in sample['context_feature_dict'].items()} 47 | target_features_dict = {k: v.to(device) for k, v in sample['target_feature_dict'].items()} 48 | 49 | # optimize selection parameters and perform feature selection 50 | selection_params = sur(context_features_dict, context_labels, max_iter=40) 51 | selected_context = apply_selection(context_features_dict, selection_params) 52 | selected_target = apply_selection(target_features_dict, selection_params) 53 | 54 | final_acc = prototype_loss(selected_context, context_labels, 55 | selected_target, target_labels)[1]['acc'] 56 | all_accs[test_dataset]['SUR'].append(final_acc) 57 | 58 | # Make a nice accuracy table 59 | rows = [] 60 | for dataset_name in all_test_datasets: 61 | row = [dataset_name] 62 | for model_name in accs_names: 63 | acc = np.array(all_accs[dataset_name][model_name]) * 100 64 | mean_acc = acc.mean() 65 | conf = (1.96 * acc.std()) / np.sqrt(len(acc)) 66 | row.append(f"{mean_acc:0.2f} +- {conf:0.2f}") 67 | rows.append(row) 68 | 69 | table = tabulate(rows, headers=['model \\ data'] + accs_names, floatfmt=".2f") 70 | print(table) 71 | print("\n") 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import torch 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from tqdm import tqdm 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from data.meta_dataset_reader import (MetaDatasetBatchReader, 13 | MetaDatasetEpisodeReader) 14 | from models.losses import cross_entropy_loss, prototype_loss 15 | from models.model_utils import (CheckPointer, UniformStepLR, 16 | CosineAnnealRestartLR, ExpDecayLR) 17 | from models.model_helpers import get_model, get_optimizer 18 | from utils import Accumulator 19 | from config import args 20 | 21 | 22 | 23 | def train(): 24 | # initialize datasets and loaders 25 | trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test'] 26 | train_loader = MetaDatasetBatchReader('train', trainsets, valsets, testsets, 27 | batch_size=args['train.batch_size']) 28 | val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets) 29 | 30 | # initialize model and optimizer 31 | num_train_classes = train_loader.num_classes('train') 32 | model = get_model(num_train_classes, args) 33 | optimizer = get_optimizer(model, args, params=model.get_parameters()) 34 | 35 | # 
restoring the last checkpoint 36 | checkpointer = CheckPointer(args, model, optimizer=optimizer) 37 | if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']: 38 | start_iter, best_val_loss, best_val_acc =\ 39 | checkpointer.restore_model(ckpt='last') 40 | else: 41 | print('No checkpoint restoration') 42 | best_val_loss = 999999999 43 | best_val_acc = start_iter = 0 44 | 45 | # define learning rate policy 46 | if args['train.lr_policy'] == "step": 47 | lr_manager = UniformStepLR(optimizer, args, start_iter) 48 | elif "exp_decay" in args['train.lr_policy']: 49 | lr_manager = ExpDecayLR(optimizer, args, start_iter) 50 | elif "cosine" in args['train.lr_policy']: 51 | lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter) 52 | 53 | # defining the summary writer 54 | writer = SummaryWriter(checkpointer.model_path) 55 | 56 | # Training loop 57 | max_iter = args['train.max_iter'] 58 | epoch_loss = {name: [] for name in trainsets} 59 | epoch_acc = {name: [] for name in trainsets} 60 | config = tf.compat.v1.ConfigProto() 61 | config.gpu_options.allow_growth = True 62 | with tf.compat.v1.Session(config=config) as session: 63 | for i in tqdm(range(max_iter)): 64 | if i < start_iter: 65 | continue 66 | 67 | optimizer.zero_grad() 68 | 69 | sample = train_loader.get_train_batch(session) 70 | logits = model.forward(sample['images']) 71 | batch_loss, stats_dict, _ = cross_entropy_loss(logits, sample['labels']) 72 | batch_dataset = sample['dataset_name'] 73 | epoch_loss[batch_dataset].append(stats_dict['loss']) 74 | epoch_acc[batch_dataset].append(stats_dict['acc']) 75 | 76 | batch_loss.backward() 77 | optimizer.step() 78 | lr_manager.step(i) 79 | 80 | if (i + 1) % 200 == 0: 81 | for dataset_name in trainsets: 82 | writer.add_scalar(f"loss/{dataset_name}-train_acc", 83 | np.mean(epoch_loss[dataset_name]), i) 84 | writer.add_scalar(f"accuracy/{dataset_name}-train_acc", 85 | np.mean(epoch_acc[dataset_name]), i) 86 | epoch_loss[dataset_name], epoch_acc[dataset_name] = [], [] 87 | 88 | writer.add_scalar('learning_rate', 89 | optimizer.param_groups[0]['lr'], i) 90 | 91 | # Evaluation inside the training loop 92 | if (i + 1) % args['train.eval_freq'] == 0: 93 | model.eval() 94 | dataset_accs, dataset_losses = [], [] 95 | for valset in valsets: 96 | val_losses, val_accs = [], [] 97 | for j in tqdm(range(args['train.eval_size'])): 98 | with torch.no_grad(): 99 | sample = val_loader.get_validation_task(session, valset) 100 | context_features = model.embed(sample['context_images']) 101 | target_features = model.embed(sample['target_images']) 102 | context_labels = sample['context_labels'] 103 | target_labels = sample['target_labels'] 104 | _, stats_dict, _ = prototype_loss(context_features, context_labels, 105 | target_features, target_labels) 106 | val_losses.append(stats_dict['loss']) 107 | val_accs.append(stats_dict['acc']) 108 | 109 | # write summaries per validation set 110 | dataset_acc, dataset_loss = np.mean(val_accs) * 100, np.mean(val_losses) 111 | dataset_accs.append(dataset_acc) 112 | dataset_losses.append(dataset_loss) 113 | writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss, i) 114 | writer.add_scalar(f"accuracy/{valset}/val_acc", dataset_acc, i) 115 | print(f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}") 116 | 117 | # write summaries averaged over datasets 118 | avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(dataset_accs) 119 | writer.add_scalar(f"loss/avg_val_loss", avg_val_loss, i) 120 | writer.add_scalar(f"accuracy/avg_val_acc", 
avg_val_acc, i) 121 | 122 | # saving checkpoints 123 | if avg_val_acc > best_val_acc: 124 | best_val_loss, best_val_acc = avg_val_loss, avg_val_acc 125 | is_best = True 126 | print('Best model so far!') 127 | else: 128 | is_best = False 129 | checkpointer.save_checkpoint(i, best_val_acc, best_val_loss, 130 | is_best, optimizer=optimizer, 131 | state_dict=model.get_state_dict()) 132 | 133 | model.train() 134 | print(f"Trained and evaluated at {i}") 135 | 136 | writer.close() 137 | if start_iter < max_iter: 138 | print(f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%""") 139 | else: 140 | print(f"""No training happened. Loaded checkpoint at {start_iter}, while max_iter was {max_iter}""") 141 | 142 | 143 | if __name__ == '__main__': 144 | train() 145 | -------------------------------------------------------------------------------- /train_pnf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import torch 6 | import numpy as np 7 | import tensorflow as tf 8 | from tqdm import tqdm 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from models.losses import cross_entropy_loss, prototype_loss 12 | from models.model_utils import (CheckPointer, UniformStepLR, 13 | CosineAnnealRestartLR, ExpDecayLR) 14 | from data.meta_dataset_reader import (MetaDatasetBatchReader, 15 | MetaDatasetEpisodeReader) 16 | from models.model_helpers import get_model, get_optimizer 17 | from utils import Accumulator 18 | from config import args 19 | 20 | 21 | def train(): 22 | # initialize datasets and loaders 23 | trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test'] 24 | train_loader = MetaDatasetBatchReader('train', trainsets, valsets, testsets, 25 | batch_size=args['train.batch_size']) 26 | val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets) 27 | 28 | # initialize model and optimizer 29 | num_train_classes = sum(list(train_loader.dataset_to_n_cats.values())) 30 | model = get_model(num_train_classes, args) 31 | optimizer = get_optimizer(model, args, params=model.get_parameters()) 32 | 33 | # Restoring the last checkpoint 34 | checkpointer = CheckPointer(args, model, optimizer=optimizer) 35 | if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']: 36 | start_iter, best_val_loss, best_val_acc =\ 37 | checkpointer.restore_model(ckpt='last', strict=False) 38 | else: 39 | print('No checkpoint restoration') 40 | best_val_loss = 999999999 41 | best_val_acc = start_iter = 0 42 | 43 | # define learning rate policy 44 | if args['train.lr_policy'] == "step": 45 | lr_manager = UniformStepLR(optimizer, args, start_iter) 46 | elif "exp_decay" in args['train.lr_policy']: 47 | lr_manager = ExpDecayLR(optimizer, args, start_iter) 48 | elif "cosine" in args['train.lr_policy']: 49 | lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter) 50 | 51 | # defining the summary writer 52 | writer = SummaryWriter(checkpointer.model_path) 53 | 54 | # Training loop 55 | max_iter = args['train.max_iter'] 56 | epoch_loss = {name: [] for name in trainsets} 57 | epoch_acc = {name: [] for name in trainsets} 58 | config = tf.compat.v1.ConfigProto() 59 | config.gpu_options.allow_growth = True 60 | with tf.compat.v1.Session(config=config) as session: 61 | for i in tqdm(range(max_iter)): 62 | if i < start_iter: 63 | continue 64 | 65 | optimizer.zero_grad() 66 | sample = train_loader.get_train_batch(session) 67 | 
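# each training batch comes from a single source domain; the dataset name (and id) below identify that domain so per-domain loss/accuracy can be tracked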
batch_dataset = sample['dataset_name'] 68 | dataset_id = sample['dataset_ids'][0].detach().cpu().item() 69 | logits = model.forward(sample['images']) 70 | labels = sample['labels'] 71 | batch_loss, stats_dict, _ = cross_entropy_loss(logits, labels) 72 | epoch_loss[batch_dataset].append(stats_dict['loss']) 73 | epoch_acc[batch_dataset].append(stats_dict['acc']) 74 | 75 | batch_loss.backward() 76 | optimizer.step() 77 | lr_manager.step(i) 78 | 79 | if (i + 1) % 200 == 0: 80 | for dataset_name in trainsets: 81 | writer.add_scalar(f"loss/{dataset_name}-train_acc", 82 | np.mean(epoch_loss[dataset_name]), i) 83 | writer.add_scalar(f"accuracy/{dataset_name}-train_acc", 84 | np.mean(epoch_acc[dataset_name]), i) 85 | epoch_loss[dataset_name], epoch_acc[dataset_name] = [], [] 86 | 87 | writer.add_scalar('learning_rate', 88 | optimizer.param_groups[0]['lr'], i) 89 | 90 | # Evaluation inside the training loop 91 | if (i + 1) % args['train.eval_freq'] == 0: 92 | model.eval() 93 | dataset_accs, dataset_losses = [], [] 94 | for valset in valsets: 95 | dataset_id = train_loader.dataset_name_to_dataset_id[valset] 96 | val_losses, val_accs = [], [] 97 | for j in tqdm(range(args['train.eval_size'])): 98 | with torch.no_grad(): 99 | sample = val_loader.get_validation_task(session, valset) 100 | context_features = model.embed(sample['context_images']) 101 | target_features = model.embed(sample['target_images']) 102 | context_labels = sample['context_labels'] 103 | target_labels = sample['target_labels'] 104 | _, stats_dict, _ = prototype_loss(context_features, context_labels, 105 | target_features, target_labels) 106 | val_losses.append(stats_dict['loss']) 107 | val_accs.append(stats_dict['acc']) 108 | 109 | # write summaries per validation set 110 | dataset_acc, dataset_loss = np.mean(val_accs) * 100, np.mean(val_losses) 111 | dataset_accs.append(dataset_acc) 112 | dataset_losses.append(dataset_loss) 113 | writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss, i) 114 | writer.add_scalar(f"accuracy/{valset}/val_acc", dataset_acc, i) 115 | print(f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}") 116 | 117 | # write summaries averaged over datasets 118 | avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(dataset_accs) 119 | writer.add_scalar(f"loss/avg_val_loss", avg_val_loss, i) 120 | writer.add_scalar(f"accuracy/avg_val_acc", avg_val_acc, i) 121 | 122 | # saving checkpoints 123 | if avg_val_acc > best_val_acc: 124 | best_val_loss, best_val_acc = avg_val_loss, avg_val_acc 125 | is_best = True 126 | print('Best model so far!') 127 | else: 128 | is_best = False 129 | checkpointer.save_checkpoint(i, best_val_acc, best_val_loss, 130 | is_best, optimizer=optimizer, 131 | state_dict=model.get_state_dict()) 132 | 133 | model.train() 134 | print(f"Trained and evaluated at {i}") 135 | 136 | writer.close() 137 | if start_iter < max_iter: 138 | print(f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%""") 139 | else: 140 | print(f"""No training happened. 
Loaded checkpoint at {start_iter}, while max_iter was {max_iter}""") 141 | 142 | 143 | if __name__ == '__main__': 144 | train() 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | from time import time 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | class ConfusionMatrix(): 11 | def __init__(self, n_classes): 12 | self.n_classes = n_classes 13 | self.mat = np.zeros([n_classes, n_classes]) 14 | 15 | def update_mat(self, preds, labels, idxs): 16 | idxs = np.array(idxs) 17 | real_pred = idxs[preds] 18 | real_labels = idxs[labels] 19 | self.mat[real_pred, real_labels] += 1 20 | 21 | def get_mat(self): 22 | return self.mat 23 | 24 | 25 | class Accumulator(): 26 | def __init__(self, max_size=2000): 27 | self.max_size = max_size 28 | self.ac = np.empty(0) 29 | 30 | def append(self, v): 31 | self.ac = np.append(self.ac[-self.max_size:], v) 32 | 33 | def reset(self): 34 | self.ac = np.empty(0) 35 | 36 | def mean(self, last=None): 37 | last = last if last else self.max_size 38 | return self.ac[-last:].mean() 39 | 40 | 41 | class IterBeat(): 42 | def __init__(self, freq, length=None): 43 | self.length = length 44 | self.freq = freq 45 | 46 | def step(self, i): 47 | if i == 0: 48 | self.t = time() 49 | self.lastcall = 0 50 | else: 51 | if ((i % self.freq) == 0) or ((i + 1) == self.length): 52 | t = time() 53 | print('{0} / {1} ---- {2:.2f} it/sec'.format( 54 | i, self.length, (i - self.lastcall) / (t - self.t))) 55 | self.lastcall = i 56 | self.t = t 57 | 58 | 59 | class SerializableArray(object): 60 | def __init__(self, array): 61 | self.shape = array.shape 62 | self.data = array.tobytes() 63 | self.dtype = array.dtype 64 | 65 | def get(self): 66 | array = np.frombuffer(self.data, self.dtype) 67 | return np.reshape(array, self.shape) 68 | 69 | 70 | def print_res(array, name, file=None, prec=4, mult=1): 71 | array = np.array(array) * mult 72 | mean, std = np.mean(array), np.std(array) 73 | conf = 1.96 * std / np.sqrt(len(array)) 74 | stat_string = ("test {:s}: {:0.%df} +/- {:0.%df}" 75 | % (prec, prec)).format(name, mean, conf) 76 | print(stat_string) 77 | if file is not None: 78 | with open(file, 'a+') as f: 79 | f.write(stat_string + '\n') 80 | 81 | 82 | def process_copies(embeddings, labels, args): 83 | n_copy = args['test.n_copy'] 84 | test_embeddings = embeddings.view( 85 | args['data.test_query'] * args['data.test_way'], 86 | n_copy, -1).mean(dim=1) 87 | return test_embeddings, labels[0::n_copy] 88 | 89 | 90 | def set_determ(seed=1234): 91 | random.seed(seed) 92 | np.random.seed(seed) 93 | torch.manual_seed(seed) 94 | torch.cuda.manual_seed(seed) 95 | torch.cuda.manual_seed_all(seed) 96 | 97 | 98 | def merge_dicts(dicts, torch_stack=True): 99 | def stack_fn(l): 100 | if isinstance(l[0], torch.Tensor): 101 | return torch.stack(l) 102 | elif isinstance(l[0], str): 103 | return l 104 | else: 105 | return torch.tensor(l) 106 | 107 | keys = dicts[0].keys() 108 | new_dict = {key: [] for key in keys} 109 | for key in keys: 110 | for d in dicts: 111 | new_dict[key].append(d[key]) 112 | if torch_stack: 113 | for key in keys: 114 | new_dict[key] = stack_fn(new_dict[key]) 115 | return new_dict 116 | 117 | 118 | def voting(preds, pref_ind=0): 119 | n_models = len(preds) 120 | n_test = len(preds[0]) 121 | final_preds = [] 122 | for i in range(n_test): 123 | 
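# collect each model's prediction for test sample i; when several classes tie for the most votes, fall back to the preferred model's prediction (pref_ind)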
cur_preds = [preds[k][i] for k in range(n_models)] 124 | classes, counts = np.unique(cur_preds, return_counts=True) 125 | if (counts == max(counts)).sum() > 1: 126 | final_preds.append(preds[pref_ind][i]) 127 | else: 128 | final_preds.append(classes[np.argmax(counts)]) 129 | return final_preds 130 | 131 | 132 | def agreement(preds): 133 | n_preds = preds.shape[0] 134 | mat = np.zeros((n_preds, n_preds)) 135 | for i in range(n_preds): 136 | for j in range(i, n_preds): 137 | mat[i, j] = mat[j, i] = ( 138 | preds[i] == preds[j]).astype('float').mean() 139 | return mat 140 | 141 | 142 | def read_textfile(filename, skip_last_line=True): 143 | with open(filename, 'r') as f: 144 | container = f.read().split('\n') 145 | if skip_last_line: 146 | container = container[:-1] 147 | return container 148 | 149 | 150 | def check_dir(dirname, verbose=True): 151 | """This function creates a directory 152 | in case it doesn't exist""" 153 | try: 154 | # Create target Directory 155 | os.makedirs(dirname) 156 | if verbose: 157 | print("Directory ", dirname, " was created") 158 | except FileExistsError: 159 | if verbose: 160 | print("Directory ", dirname, " already exists") 161 | return dirname 162 | --------------------------------------------------------------------------------