├── .gitignore ├── README.md ├── exps └── pre-extract-feature.py ├── fast-exps ├── lib │ ├── config_utils │ │ ├── __init__.py │ │ ├── flop_benchmark.py │ │ ├── logger.py │ │ └── utils.py │ ├── data │ │ ├── create_features_db.py │ │ ├── lmdb_dataset.py │ │ ├── meta_dataset_config.gin │ │ ├── meta_dataset_processing.py │ │ └── meta_dataset_reader.py │ ├── datasets │ │ ├── EpisodeMetadata.py │ │ └── __init__.py │ ├── models │ │ ├── losses.py │ │ ├── model_utils.py │ │ ├── models_dict.py │ │ ├── new_model_helpers.py │ │ ├── new_prop_prototype.py │ │ ├── resnet18.py │ │ └── resnet18_film.py │ ├── paths.py │ └── utils.py └── urt-avg-head.py ├── fast-scripts └── urt-avg-head.sh └── scripts └── pre-extract-feature.sh /.gitignore: -------------------------------------------------------------------------------- 1 | core.* 2 | weights 3 | *.swo 4 | */*.swo 5 | */*.swp 6 | outputs 7 | logs 8 | *.pyc 9 | *.swp 10 | core.* 11 | *.pth 12 | save 13 | NEWsave 14 | __pycache__ 15 | output 16 | fast-outputs 17 | fast-outputs* 18 | *.tar 19 | *.pth.back 20 | *.zip 21 | */.DS_Store 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICLR 2021] A Universal Representation Transformer Layer for Few-Shot Image Classification 2 | 3 | 4 | ## Dependencies 5 | This code requires the following: 6 | * Python 3.6 or greater 7 | * PyTorch 1.0 or greater 8 | * TensorFlow 1.14 or greater 9 | 10 | 11 | ## Data Preparation 12 | 1. Meta-Dataset: 13 | 14 | Follow the the "User instructions" in the [Meta-Dataset repository](https://github.com/google-research/meta-dataset#user-instructions) for "Installation" and "Downloading and converting datasets". 15 | 16 | 2. Additional Test Datasets: 17 | 18 | If you want to test on additional datasets, i.e., MNIST, CIFAR10, CIFAR100, follow the installation instructions in the [CNAPs repository](https://github.com/cambridge-mlg/cnaps) to get these datasets. 19 | 20 | ## Getting the Feature Extractors 21 | 22 | URT can be built on top of backbones pretrained in any ways. 23 | 24 | The easiest way is to download SUR's pre-trained models and use them to obtain a universal set of features directly. If that is what you want, execute the following command in the root directory of this project:```wget http://thoth.inrialpes.fr/research/SUR/all_weights.zip && unzip all_weights.zip && rm all_weights.zip``` 25 | It will donwnload all the weights and place them in the `./weights` directory. 26 | Or pretrain the backbone by yourself on the training sets of Meta-Dataset and put the model weights under the directory of `./weights`. 27 | 28 | ## Train and evaluate URT 29 | 30 | 31 | ### Dumping features (for efficient training and evaluation) 32 | We found that the bottleneck of training URT is extracting features from CNN. Since we freeze the CNN when training the URT, we found dumping the extracted feature episodes can significantly speed up the training procedure from days to ~2 hours. The easiest way is to download all the extracted features from [HERE](https://drive.google.com/drive/folders/1Z3gsa4TSSiH2wTZj1Jp5bD7UEKPOVzx5?usp=sharing) and put it in the ${cache_dir}. 33 | Or you can extract by your own via ```bash ./scripts/pre-extract-feature.sh resnet18 ${cache_dir}``` 34 | 35 | ### Train and evaluate 36 | run command from the dir of this repo: ```bash ./fast-scripts/urt-avg-head.sh ${log_dir} ${num_head} ${penalty_coef} ${cache_dir}```, where the ${num_head}=2 and ${penalty_coef}=0.1 in our paper. 37 | -------------------------------------------------------------------------------- /exps/pre-extract-feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os, sys, time, json, random, argparse 3 | import collections 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import tensorflow as tf 8 | from tqdm import tqdm 9 | from pathlib import Path 10 | lib_dir = (Path(__file__).parent / '..' / 'fast-exps' / 'lib').resolve() 11 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) 12 | 13 | from data.meta_dataset_reader import MetaDatasetEpisodeReader, TRAIN_METADATASET_NAMES, ALL_METADATASET_NAMES 14 | from models.new_model_helpers import get_extractors, extract_features 15 | # from xmodels.new_model_helpers import get_extractors, extract_features 16 | from models.models_dict import DATASET_MODELS_DICT 17 | from utils import convert_secs2time, time_string, AverageMeter 18 | from paths import META_RECORDS_ROOT 19 | 20 | from config_utils import Logger 21 | 22 | 23 | def load_config(): 24 | 25 | parser = argparse.ArgumentParser(description='Train prototypical networks') 26 | parser.add_argument('--save_dir', type=str, help="The saved path in dir.") 27 | 28 | # model args 29 | parser.add_argument('--model.backbone', default='resnet18', help="Use ResNet18 for experiments (default: False)") 30 | parser.add_argument('--model.classifier', type=str, default='cosine', choices=['none', 'linear', 'cosine'], help="Do classification using cosine similatity between activations and weights") 31 | 32 | # train args 33 | parser.add_argument('--train.max_iter', type=int, default=10000, help='number of epochs to train (default: 10000)') 34 | parser.add_argument('--eval.max_iter', type=int, default=600, help='number of epochs to train (default: 10000)') 35 | 36 | xargs = vars(parser.parse_args()) 37 | return xargs 38 | 39 | 40 | def extract_eval_dataset(backbone, mode, extractors, all_test_datasets, test_loader, num_iters, logger, save_dir): 41 | # dataset_models = DATASET_MODELS_DICT[backbone] 42 | 43 | logger.print('\n{:} starting extract the {:} mode by {:} iters.'.format(time_string(), mode, save_dir, num_iters)) 44 | config = tf.compat.v1.ConfigProto() 45 | config.gpu_options.allow_growth = True 46 | with tf.compat.v1.Session(config=config) as session: 47 | for idata, test_dataset in enumerate(all_test_datasets): 48 | logger.print('===>>> {:} --->>> {:02d}/{:02d} --->>> {:}'.format(time_string(), idata, len(all_test_datasets), test_dataset)) 49 | x_save_dir = save_dir / '{:}-{:}'.format(mode, num_iters) / '{:}'.format(test_dataset) 50 | x_save_dir.mkdir(parents=True, exist_ok=True) 51 | for idx in tqdm(range(num_iters)): 52 | # extract image features and labels 53 | if mode == "val": 54 | sample = test_loader.get_validation_task(session, test_dataset) 55 | elif mode == "test": 56 | sample = test_loader.get_test_task(session, test_dataset) 57 | else: 58 | raise ValueError("invalid mode:{}".format(mode)) 59 | 60 | with torch.no_grad(): 61 | context_labels = sample['context_labels'] 62 | target_labels = sample['target_labels'] 63 | # batch x #extractors x #features 64 | context_features = extract_features(extractors, sample['context_images']) 65 | target_features = extract_features(extractors, sample['target_images']) 66 | to_save_info = {'context_features': context_features.cpu(), 67 | 'context_labels': context_labels.cpu(), 68 | 'target_features': target_features.cpu(), 69 | 'target_labels': target_labels.cpu()} 70 | save_name = x_save_dir / '{:06d}.pth'.format(idx) 71 | torch.save(to_save_info, save_name) 72 | 73 | 74 | def main(xargs): 75 | 76 | # set up logger 77 | log_dir = Path(xargs['save_dir']).resolve() 78 | log_dir.mkdir(parents=True, exist_ok=True) 79 | #log_dir = "./NEWsave/{}_{}_allcache_{}_{}_{}_{}_{}_{}".format(args['train.optimizer'], args['train.scheduler'], args['prop.n_hop'], args['prop.temp'], args['prop.nonlinear'], args['prop.transform'], args['prop.layer_type'], args['prop.layer_type.att_space']) 80 | 81 | logger = Logger(str(log_dir), 888) 82 | logger.print('{:} --- args ---'.format(time_string())) 83 | for key, value in xargs.items(): 84 | logger.print(' [{:10s}] : {:}'.format(key, value)) 85 | logger.print('{:} --- args ---'.format(time_string())) 86 | 87 | # Setting up datasets 88 | extractor_domains = TRAIN_METADATASET_NAMES 89 | all_val_datasets = TRAIN_METADATASET_NAMES 90 | all_test_datasets = ALL_METADATASET_NAMES 91 | train_loader_lst = [MetaDatasetEpisodeReader('train', [d], [d], all_test_datasets) for d in extractor_domains] 92 | val_loader = MetaDatasetEpisodeReader('val' , extractor_domains, extractor_domains, all_test_datasets) 93 | test_loader = MetaDatasetEpisodeReader('test', extractor_domains, extractor_domains, all_test_datasets) 94 | class_name_dict = collections.OrderedDict() 95 | for d in extractor_domains: 96 | with open("{:}/{:}/dataset_spec.json".format(META_RECORDS_ROOT, d)) as f: 97 | data = json.load(f) 98 | class_name_dict[d] = data['class_names'] 99 | 100 | # initialize the feature extractors 101 | dataset_models = DATASET_MODELS_DICT[xargs['model.backbone']] 102 | extractors = get_extractors(extractor_domains, dataset_models, xargs['model.backbone'], xargs['model.classifier'], False) 103 | 104 | extract_eval_dataset(xargs['model.backbone'], 'test', extractors, all_test_datasets, test_loader, xargs['eval.max_iter'], logger, log_dir) 105 | # stop at here 106 | extract_eval_dataset(xargs['model.backbone'], 'val' , extractors, all_val_datasets , val_loader , xargs['eval.max_iter'], logger, log_dir) 107 | 108 | config = tf.compat.v1.ConfigProto() 109 | config.gpu_options.allow_growth = True 110 | 111 | xsave_dir = log_dir / 'train-{:}'.format(xargs['train.max_iter']) 112 | xsave_dir.mkdir(parents=True, exist_ok=True) 113 | logger.print('{:} save into {:}'.format(time_string(), xsave_dir)) 114 | with tf.compat.v1.Session(config=config) as session: 115 | for idx in tqdm(range(xargs['train.max_iter'])): 116 | if random.random() > 0.5: 117 | ep_domain = extractor_domains[0] 118 | else: 119 | ep_domain = random.choice(extractor_domains[1:]) 120 | domain_idx = extractor_domains.index(ep_domain) 121 | train_loader = train_loader_lst[domain_idx] 122 | samples = train_loader.get_train_task(session) 123 | # import pdb; pdb.set_trace() 124 | domain_extractor = extractors[ep_domain] 125 | 126 | with torch.no_grad(): 127 | # batch x #extractors x #features 128 | context_labels = samples['context_labels'].cpu() 129 | target_labels = samples['target_labels'].cpu() 130 | context_features = extract_features(extractors, samples['context_images']) 131 | target_features = extract_features(extractors, samples['target_images']) 132 | to_save_info = {'context_features': context_features.cpu(), 133 | 'context_labels': context_labels.cpu(), 134 | 'target_features': target_features.cpu(), 135 | 'target_labels': target_labels.cpu(), 136 | 'ep_domain': ep_domain, 137 | 'domain_idx': domain_idx} 138 | save_name = xsave_dir / '{:06d}.pth'.format(idx) 139 | torch.save(to_save_info, save_name) 140 | 141 | 142 | if __name__ == '__main__': 143 | xargs = load_config() 144 | main(xargs) 145 | -------------------------------------------------------------------------------- /fast-exps/lib/config_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .utils import AverageMeter, obtain_accuracy, convert_secs2time, time_string 3 | from .flop_benchmark import count_parameters_in_MB 4 | -------------------------------------------------------------------------------- /fast-exps/lib/config_utils/flop_benchmark.py: -------------------------------------------------------------------------------- 1 | import copy, torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def count_parameters_in_MB(model): 7 | if isinstance(model, nn.Module): 8 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6 9 | else: 10 | return np.sum(np.prod(v.size()) for v in model)/1e6 11 | 12 | 13 | def get_model_infos(model, shape): 14 | #model = copy.deepcopy( model ) 15 | 16 | model = add_flops_counting_methods(model) 17 | #model = model.cuda() 18 | model.eval() 19 | 20 | #cache_inputs = torch.zeros(*shape).cuda() 21 | cache_inputs = torch.zeros(*shape) 22 | if next(model.parameters()).is_cuda: cache_inputs = cache_inputs.cuda() 23 | #print_log('In the calculating function : cache input size : {:}'.format(cache_inputs.size()), log) 24 | with torch.no_grad(): 25 | _____ = model(cache_inputs) 26 | FLOPs = compute_average_flops_cost( model ) / 1e6 27 | Param = count_parameters_in_MB(model) 28 | 29 | if hasattr(model, 'auxiliary_param'): 30 | aux_params = count_parameters_in_MB(model.auxiliary_param()) 31 | print ('The auxiliary params of this model is : {:}'.format(aux_params)) 32 | print ('We remove the auxiliary params from the total params ({:}) when counting'.format(Param)) 33 | Param = Param - aux_params 34 | 35 | #print_log('FLOPs : {:} MB'.format(FLOPs), log) 36 | torch.cuda.empty_cache() 37 | model.apply( remove_hook_function ) 38 | return FLOPs, Param 39 | 40 | 41 | # ---- Public functions 42 | def add_flops_counting_methods( model ): 43 | model.__batch_counter__ = 0 44 | add_batch_counter_hook_function( model ) 45 | model.apply( add_flops_counter_variable_or_reset ) 46 | model.apply( add_flops_counter_hook_function ) 47 | return model 48 | 49 | 50 | 51 | def compute_average_flops_cost(model): 52 | """ 53 | A method that will be available after add_flops_counting_methods() is called on a desired net object. 54 | Returns current mean flops consumption per image. 55 | """ 56 | batches_count = model.__batch_counter__ 57 | flops_sum = 0 58 | #or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 59 | for module in model.modules(): 60 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 61 | or isinstance(module, torch.nn.Conv1d) \ 62 | or hasattr(module, 'calculate_flop_self'): 63 | flops_sum += module.__flops__ 64 | return flops_sum / batches_count 65 | 66 | 67 | # ---- Internal functions 68 | def pool_flops_counter_hook(pool_module, inputs, output): 69 | batch_size = inputs[0].size(0) 70 | kernel_size = pool_module.kernel_size 71 | out_C, output_height, output_width = output.shape[1:] 72 | assert out_C == inputs[0].size(1), '{:} vs. {:}'.format(out_C, inputs[0].size()) 73 | 74 | overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size 75 | pool_module.__flops__ += overall_flops 76 | 77 | 78 | def self_calculate_flops_counter_hook(self_module, inputs, output): 79 | overall_flops = self_module.calculate_flop_self(inputs[0].shape, output.shape) 80 | self_module.__flops__ += overall_flops 81 | 82 | 83 | def fc_flops_counter_hook(fc_module, inputs, output): 84 | batch_size = inputs[0].size(0) 85 | xin, xout = fc_module.in_features, fc_module.out_features 86 | assert xin == inputs[0].size(1) and xout == output.size(1), 'IO=({:}, {:})'.format(xin, xout) 87 | overall_flops = batch_size * xin * xout 88 | if fc_module.bias is not None: 89 | overall_flops += batch_size * xout 90 | fc_module.__flops__ += overall_flops 91 | 92 | 93 | def conv1d_flops_counter_hook(conv_module, inputs, outputs): 94 | batch_size = inputs[0].size(0) 95 | outL = outputs.shape[-1] 96 | [kernel] = conv_module.kernel_size 97 | in_channels = conv_module.in_channels 98 | out_channels = conv_module.out_channels 99 | groups = conv_module.groups 100 | conv_per_position_flops = kernel * in_channels * out_channels / groups 101 | 102 | active_elements_count = batch_size * outL 103 | overall_flops = conv_per_position_flops * active_elements_count 104 | 105 | if conv_module.bias is not None: 106 | overall_flops += out_channels * active_elements_count 107 | conv_module.__flops__ += overall_flops 108 | 109 | 110 | def conv2d_flops_counter_hook(conv_module, inputs, output): 111 | batch_size = inputs[0].size(0) 112 | output_height, output_width = output.shape[2:] 113 | 114 | kernel_height, kernel_width = conv_module.kernel_size 115 | in_channels = conv_module.in_channels 116 | out_channels = conv_module.out_channels 117 | groups = conv_module.groups 118 | conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups 119 | 120 | active_elements_count = batch_size * output_height * output_width 121 | overall_flops = conv_per_position_flops * active_elements_count 122 | 123 | if conv_module.bias is not None: 124 | overall_flops += out_channels * active_elements_count 125 | conv_module.__flops__ += overall_flops 126 | 127 | 128 | def batch_counter_hook(module, inputs, output): 129 | # Can have multiple inputs, getting the first one 130 | inputs = inputs[0] 131 | batch_size = inputs.shape[0] 132 | module.__batch_counter__ += batch_size 133 | 134 | 135 | def add_batch_counter_hook_function(module): 136 | if not hasattr(module, '__batch_counter_handle__'): 137 | handle = module.register_forward_hook(batch_counter_hook) 138 | module.__batch_counter_handle__ = handle 139 | 140 | 141 | def add_flops_counter_variable_or_reset(module): 142 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear) \ 143 | or isinstance(module, torch.nn.Conv1d) \ 144 | or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d) \ 145 | or hasattr(module, 'calculate_flop_self'): 146 | module.__flops__ = 0 147 | 148 | 149 | def add_flops_counter_hook_function(module): 150 | if isinstance(module, torch.nn.Conv2d): 151 | if not hasattr(module, '__flops_handle__'): 152 | handle = module.register_forward_hook(conv2d_flops_counter_hook) 153 | module.__flops_handle__ = handle 154 | elif isinstance(module, torch.nn.Conv1d): 155 | if not hasattr(module, '__flops_handle__'): 156 | handle = module.register_forward_hook(conv1d_flops_counter_hook) 157 | module.__flops_handle__ = handle 158 | elif isinstance(module, torch.nn.Linear): 159 | if not hasattr(module, '__flops_handle__'): 160 | handle = module.register_forward_hook(fc_flops_counter_hook) 161 | module.__flops_handle__ = handle 162 | elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d): 163 | if not hasattr(module, '__flops_handle__'): 164 | handle = module.register_forward_hook(pool_flops_counter_hook) 165 | module.__flops_handle__ = handle 166 | elif hasattr(module, 'calculate_flop_self'): # self-defined module 167 | if not hasattr(module, '__flops_handle__'): 168 | handle = module.register_forward_hook(self_calculate_flops_counter_hook) 169 | module.__flops_handle__ = handle 170 | 171 | 172 | def remove_hook_function(module): 173 | hookers = ['__batch_counter_handle__', '__flops_handle__'] 174 | for hooker in hookers: 175 | if hasattr(module, hooker): 176 | handle = getattr(module, hooker) 177 | handle.remove() 178 | keys = ['__flops__', '__batch_counter__', '__flops__'] + hookers 179 | for ckey in keys: 180 | if hasattr(module, ckey): delattr(module, ckey) 181 | -------------------------------------------------------------------------------- /fast-exps/lib/config_utils/logger.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | import numpy as np 4 | import scipy.misc 5 | import pprint 6 | 7 | pp = pprint.PrettyPrinter(indent=4) 8 | 9 | try: 10 | from StringIO import StringIO # Python 2.7 11 | except ImportError: 12 | from io import BytesIO # Python 3.x 13 | 14 | from .utils import time_for_file 15 | 16 | class Logger(object): 17 | 18 | def __init__(self, log_dir, seed): 19 | """Create a summary writer logging to log_dir.""" 20 | self.log_dir = Path("{:}".format(str(log_dir))) 21 | if not self.log_dir.exists(): os.makedirs(str(self.log_dir)) 22 | 23 | self.log_file = '{:}/log-{:}-date-{:}.txt'.format(self.log_dir, seed, time_for_file()) 24 | self.file_writer = open(self.log_file, 'w') 25 | 26 | def checkpoint(self, name): 27 | return self.log_dir / name 28 | 29 | def print(self, string, fprint=True, is_pp=False): 30 | if is_pp: pp.pprint (string) 31 | else: print(string) 32 | if fprint: 33 | self.file_writer.write('{:}\n'.format(string)) 34 | self.file_writer.flush() 35 | 36 | def close(self): 37 | self.file_writer.close() 38 | -------------------------------------------------------------------------------- /fast-exps/lib/config_utils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from torch.optim.optimizer import Optimizer 6 | import math 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def obtain_accuracy(output, target, topk=(1,)): 25 | with torch.no_grad(): 26 | """Computes the precision@k for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | 30 | _, pred = output.topk(maxk, 1, True, True) # bs*k 31 | pred = pred.t() # t: transpose, k*bs 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) # 1*bs --> k*bs 33 | 34 | res = [] 35 | for k in topk: 36 | correct_k = correct[:k].view(-1).float().sum(0) 37 | res.append(correct_k.mul_(100.0 / batch_size)) 38 | return res 39 | 40 | 41 | def obtain_per_class_accuracy(predictions, xtargets): 42 | top1 = torch.argmax(predictions, dim=1) 43 | cls2accs = [] 44 | for cls in sorted(list(set(xtargets.tolist()))): 45 | selects = xtargets == cls 46 | accuracy = (top1[selects] == xtargets[selects]).float().mean() * 100 47 | cls2accs.append( accuracy.item() ) 48 | return sum(cls2accs) / len(cls2accs) 49 | 50 | 51 | def convert_secs2time(epoch_time, string=True): 52 | need_hour = int(epoch_time / 3600) 53 | need_mins = int((epoch_time - 3600*need_hour) / 60) 54 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 55 | if string: 56 | need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 57 | return need_time 58 | else: 59 | return need_hour, need_mins, need_secs 60 | 61 | def time_string(): 62 | ISOTIMEFORMAT='%Y-%m-%d-%X' 63 | string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 64 | return string 65 | 66 | def time_for_file(): 67 | ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' 68 | string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 69 | return string 70 | 71 | def spm_to_tensor(sparse_mx): 72 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 73 | indices = torch.from_numpy(np.vstack( 74 | (sparse_mx.row, sparse_mx.col))).long() 75 | values = torch.from_numpy(sparse_mx.data) 76 | shape = torch.Size(sparse_mx.shape) 77 | return torch.sparse.FloatTensor(indices, values, shape) 78 | -------------------------------------------------------------------------------- /fast-exps/lib/data/create_features_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import numpy as np 5 | import tensorflow as tf 6 | from tqdm import tqdm 7 | 8 | import json 9 | import lmdb 10 | import pickle as pkl 11 | 12 | sys.path.insert(0, '/'.join(os.path.realpath(__file__).split('/')[:-2])) 13 | 14 | from data.meta_dataset_reader import MetaDatasetEpisodeReader, MetaDatasetBatchReader 15 | from models.model_utils import CheckPointer 16 | from models.models_dict import DATASET_MODELS_DICT 17 | from models.model_helpers import get_domain_extractors, get_model 18 | from config import args 19 | from paths import META_DATA_ROOT 20 | from utils import check_dir, SerializableArray 21 | 22 | 23 | 24 | class DatasetWriter(object): 25 | def __init__(self, args, rewrite=True, write_frequency=10): 26 | self._mode = args['dump.mode'] 27 | self._write_frequency = write_frequency 28 | self._db = None 29 | self.args = args 30 | self.dataset_models = DATASET_MODELS_DICT[args['model.backbone']] 31 | print(self.dataset_models) 32 | 33 | trainsets, valsets, testsets = args['data.train'], args['data.val'], args['data.test'] 34 | loader = MetaDatasetEpisodeReader(self._mode, trainsets, valsets, testsets) 35 | self._map_size = 50000 * 100 ** 2 * 512 * 8 36 | self.trainsets = trainsets 37 | 38 | if self._mode == 'train': 39 | evalset = "allcat" 40 | self.load_sample = lambda sess: loader.get_train_task(sess) 41 | elif self._mode == 'test': 42 | evalset = testsets[0] 43 | self.load_sample = lambda sess: loader.get_test_task(sess, evalset) 44 | elif self._mode == 'val': 45 | evalset = valsets[0] 46 | self.load_sample = lambda sess: loader.get_validation_task(sess, evalset) 47 | 48 | dump_name = mode + '_dump' if not args['dump.name'] else args['dump.name'] 49 | path = check_dir(os.path.join(META_DATA_ROOT, 'Dumps', self.args['model.backbone'], 50 | self._mode, evalset, dump_name)) 51 | self._path = path 52 | if not (os.path.exists(path)): 53 | os.mkdir(path) 54 | self._keys_file = os.path.join(path, 'keys') 55 | self._keys = [] 56 | if os.path.exists(path) and not rewrite: 57 | raise NameError("Dataset {} already exists.".format(self._path)) 58 | 59 | # do not initialize during __init__ to avoid pickling error when using MPI 60 | def init(self): 61 | self._db = lmdb.open(self._path, map_size=self._map_size, map_async=True) 62 | self.embed_many = get_domain_extractors(self.trainsets, self.dataset_models, self.args) 63 | 64 | def close(self): 65 | keys = tuple(self._keys) 66 | pkl.dump(keys, open(self._keys_file, 'wb')) 67 | 68 | if self._db is not None: 69 | self._db.sync() 70 | self._db.close() 71 | self._db = None 72 | 73 | def encode_dataset(self, n_tasks=1000): 74 | if self._db is None: 75 | self.init() 76 | 77 | txn = self._db.begin(write=True) 78 | config = tf.compat.v1.ConfigProto() 79 | config.gpu_options.allow_growth = True 80 | with tf.compat.v1.Session(config=config) as session: 81 | for idx in tqdm(range(n_tasks)): 82 | # compressing image 83 | sample = self.load_sample(session) 84 | support_embed_dict = self.embed_many(sample['context_images']) 85 | query_embed_dict = self.embed_many(sample['target_images']) 86 | support_labels = SerializableArray(sample['context_labels'].detach().cpu().numpy()) 87 | query_labels = SerializableArray(sample['target_labels'].detach().cpu().numpy()) 88 | SerializableArray.__module__ = 'utils' 89 | 90 | # writing 91 | for dataset in support_embed_dict.keys(): 92 | support_batch = SerializableArray( 93 | support_embed_dict[dataset].detach().cpu().numpy()) 94 | query_batch = SerializableArray( 95 | query_embed_dict[dataset].detach().cpu().numpy()) 96 | SerializableArray.__module__ = 'utils' 97 | txn.put(f"{idx}_{dataset}_support".encode("ascii"), pkl.dumps(support_batch)) 98 | txn.put(f"{idx}_{dataset}_query".encode("ascii"), pkl.dumps(query_batch)) 99 | self._keys.extend([f"{idx}_{dataset}_support", f"{idx}_{dataset}_query"]) 100 | txn.put(f"{idx}_labels_support".encode("ascii"), pkl.dumps(support_labels)) 101 | txn.put(f"{idx}_labels_query".encode("ascii"), pkl.dumps(query_labels)) 102 | self._keys.extend([f"{idx}_labels_support", f"{idx}_labels_query"]) 103 | 104 | # flushing into lmdb 105 | if idx > 0 and idx % self._write_frequency == 0: 106 | txn.commit() 107 | txn = self._db.begin(write=True) 108 | txn.commit() 109 | 110 | 111 | if __name__ == '__main__': 112 | dr = DatasetWriter(args) 113 | dr.init() 114 | dr.encode_dataset(args['dump.size']) 115 | dr.close() 116 | print('Done') 117 | -------------------------------------------------------------------------------- /fast-exps/lib/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import lmdb 8 | import pickle as pkl 9 | from utils import SerializableArray, device 10 | 11 | from paths import META_DATA_ROOT 12 | 13 | class LMDBDataset: 14 | """ 15 | Opens several LMDB readers and loads data from there 16 | """ 17 | def __init__(self, extractor_domains, datasets, backbone, mode, dump_name, limit_len=None): 18 | self.mode = mode 19 | self.datasets = datasets 20 | 21 | self.dataset_readers = dict() 22 | for evalset in datasets: 23 | all_names = os.listdir(os.path.join(META_DATA_ROOT, 'Dumps', 24 | backbone, mode, evalset)) 25 | self.dataset_readers[evalset] = [ 26 | DatasetReader(extractor_domains, evalset, backbone, mode, name) 27 | for name in all_names if dump_name in name] 28 | self._current_sampling_dataset = datasets[0] 29 | self.full_len = sum([len(ds) for ds in self.dataset_readers[self._current_sampling_dataset]]) 30 | if limit_len is not None: 31 | self.full_len = min(self.full_len, limit_len) 32 | 33 | def __len__(self): 34 | return self.full_len 35 | 36 | def __getitem__(self, idx): 37 | if self.mode == 'train': 38 | random_lmdb_subset = random.sample(self.dataset_readers[self._current_sampling_dataset], 1)[0] 39 | idx = random.sample(range(len(random_lmdb_subset)), 1)[0] 40 | sample = random_lmdb_subset[idx] 41 | else: 42 | sample = self.dataset_readers[self._current_sampling_dataset][0][idx] 43 | 44 | for key, val in sample.items(): 45 | if isinstance(val, str): 46 | pass 47 | if 'label' in key: 48 | sample[key] = torch.from_numpy(val).long() 49 | elif 'feature_dict' in key: 50 | for fkey, fval in sample[key].items(): 51 | sample[key][fkey] = torch.from_numpy(fval) 52 | 53 | return sample 54 | 55 | def set_sampling_dataset(self, sampling_dataset): 56 | self._current_sampling_dataset = sampling_dataset 57 | 58 | # open lmdb environment and transaction 59 | # load keys from cache 60 | def _load_db(self, info, class_id): 61 | path = self._path 62 | 63 | self._env = lmdb.open( 64 | self._path, 65 | readonly=True, 66 | lock=False, 67 | readahead=False, 68 | meminit=False) 69 | self._txn = self._env.begin(write=False) 70 | 71 | if class_id is None: 72 | cache_file = os.path.join(path, 'keys') 73 | if os.path.isfile(cache_file): 74 | self.keys = pkl.load(open(cache_file, 'rb')) 75 | else: 76 | print('Loading dataset keys...') 77 | with self._env.begin(write=False) as txn: 78 | self.keys = [key.decode('ascii') 79 | for key, _ in tqdm(txn.cursor())] 80 | pkl.dump(self.keys, open(cache_file, 'wb')) 81 | else: 82 | self.keys = [str(k).encode() for k in info['labels2keys'][str(class_id)]] 83 | 84 | if not self.keys: 85 | raise ValueError('Empty dataset.') 86 | 87 | def eval(self): 88 | self.mode = 'eval' 89 | 90 | def train(self): 91 | self.mode = 'train' 92 | 93 | def transform(self, x): 94 | if self.mode == 'train': 95 | out = self.train_transform(x) if self.train_transform else x 96 | return out 97 | else: 98 | out = self.test_transform(x) if self.test_transform else x 99 | return out 100 | 101 | 102 | class DatasetReader(object): 103 | """ 104 | Opens a single LMDB file, containing dumped activations for a dataset, 105 | and samples data from it. 106 | """ 107 | def __init__(self, extractor_domains, evalset, backbone, mode, name): 108 | self._mode = mode 109 | self._env = None 110 | self._txn = None 111 | self.keys = None 112 | 113 | self.trainsets = extractor_domains 114 | path = os.path.join(META_DATA_ROOT, 'Dumps', backbone, mode, evalset, name) 115 | self._path = path 116 | 117 | self._load_db() 118 | 119 | def __len__(self): 120 | return self.full_len 121 | 122 | def _load_db(self): 123 | path = self._path 124 | 125 | self._env = lmdb.open( 126 | self._path, 127 | readonly=True, 128 | lock=False, 129 | readahead=False, 130 | meminit=False) 131 | self._txn = self._env.begin(write=False) 132 | 133 | cache_file = os.path.join(path, 'keys') 134 | if os.path.isfile(cache_file): 135 | self.keys = pkl.load(open(cache_file, 'rb')) 136 | else: 137 | print('Loading dataset keys...') 138 | with self._env.begin(write=False) as txn: 139 | self.keys = [key.decode('ascii') 140 | for key, _ in tqdm(txn.cursor())] 141 | pkl.dump(self.keys, open(cache_file, 'wb')) 142 | self.full_len = len(self.keys) // 18 143 | 144 | def __getitem__(self, idx): 145 | sample = dict() 146 | support_labels = pkl.loads(self._txn.get(f"{idx}_labels_support".encode("ascii"))) 147 | query_labels = pkl.loads(self._txn.get(f"{idx}_labels_query".encode("ascii"))) 148 | sample['context_labels'] = support_labels.get() 149 | sample['target_labels'] = query_labels.get() 150 | 151 | sample['context_feature_dict'] = dict() 152 | sample['target_feature_dict'] = dict() 153 | for dataset in self.trainsets: 154 | support_batch = pkl.loads(self._txn.get(f"{idx}_{dataset}_support".encode("ascii"))) 155 | query_batch = pkl.loads(self._txn.get(f"{idx}_{dataset}_query".encode("ascii"))) 156 | sample['context_feature_dict'][dataset] = support_batch.get() 157 | sample['target_feature_dict'][dataset] = query_batch.get() 158 | return sample 159 | -------------------------------------------------------------------------------- /fast-exps/lib/data/meta_dataset_config.gin: -------------------------------------------------------------------------------- 1 | import data.meta_dataset_processing 2 | import meta_dataset.data.decoder 3 | 4 | # Default values for sampling variable shots / ways. 5 | EpisodeDescriptionConfig.min_ways = 5 6 | EpisodeDescriptionConfig.max_ways_upper_bound = 50 7 | EpisodeDescriptionConfig.max_num_query = 10 8 | EpisodeDescriptionConfig.max_support_set_size = 500 9 | EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100 10 | EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529 # np.log(0.5) 11 | EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529 # np.log(2) 12 | EpisodeDescriptionConfig.ignore_dag_ontology = False 13 | EpisodeDescriptionConfig.ignore_bilevel_ontology = False 14 | 15 | # Other default values for the data pipeline. 16 | DataConfig.image_height = 84 17 | DataConfig.shuffle_buffer_size = 1000 18 | DataConfig.read_buffer_size_bytes = 1048576 # 1 MB (1024**2) 19 | DataConfig.num_prefetch = 400 20 | meta_dataset_processing.ImageDecoder.image_size = 84 21 | 22 | # If we decode features then change the lines below to use FeatureDecoder. 23 | process_episode.support_decoder = @support/meta_dataset_processing.ImageDecoder() 24 | support/meta_dataset_processing.ImageDecoder.data_augmentation = @support/meta_dataset_processing.DataAugmentation() 25 | support/meta_dataset_processing.DataAugmentation.enable_jitter = True 26 | support/meta_dataset_processing.DataAugmentation.jitter_amount = 0 27 | support/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = True 28 | support/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 29 | support/meta_dataset_processing.DataAugmentation.enable_random_flip = False 30 | support/meta_dataset_processing.DataAugmentation.enable_random_brightness = False 31 | support/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0 32 | support/meta_dataset_processing.DataAugmentation.enable_random_contrast = False 33 | support/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0 34 | support/meta_dataset_processing.DataAugmentation.enable_random_hue = False 35 | support/meta_dataset_processing.DataAugmentation.random_hue_delta = 0 36 | support/meta_dataset_processing.DataAugmentation.enable_random_saturation = False 37 | support/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0 38 | 39 | process_episode.query_decoder = @query/meta_dataset_processing.ImageDecoder() 40 | query/meta_dataset_processing.ImageDecoder.data_augmentation = @query/meta_dataset_processing.DataAugmentation() 41 | query/meta_dataset_processing.DataAugmentation.enable_jitter = False 42 | query/meta_dataset_processing.DataAugmentation.jitter_amount = 0 43 | query/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = False 44 | query/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 45 | query/meta_dataset_processing.DataAugmentation.enable_random_flip = False 46 | query/meta_dataset_processing.DataAugmentation.enable_random_brightness = False 47 | query/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0 48 | query/meta_dataset_processing.DataAugmentation.enable_random_contrast = False 49 | query/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0 50 | query/meta_dataset_processing.DataAugmentation.enable_random_hue = False 51 | query/meta_dataset_processing.DataAugmentation.random_hue_delta = 0 52 | query/meta_dataset_processing.DataAugmentation.enable_random_saturation = False 53 | query/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0 54 | 55 | process_batch.batch_decoder = @batch/meta_dataset_processing.ImageDecoder() 56 | batch/meta_dataset_processing.ImageDecoder.data_augmentation = @batch/meta_dataset_processing.DataAugmentation() 57 | batch/meta_dataset_processing.DataAugmentation.enable_jitter = True 58 | batch/meta_dataset_processing.DataAugmentation.jitter_amount = 8 59 | batch/meta_dataset_processing.DataAugmentation.enable_gaussian_noise = True 60 | batch/meta_dataset_processing.DataAugmentation.gaussian_noise_std = 0.0 61 | batch/meta_dataset_processing.DataAugmentation.enable_random_flip = False 62 | batch/meta_dataset_processing.DataAugmentation.enable_random_brightness = True 63 | batch/meta_dataset_processing.DataAugmentation.random_brightness_delta = 0.125 64 | batch/meta_dataset_processing.DataAugmentation.enable_random_contrast = True 65 | batch/meta_dataset_processing.DataAugmentation.random_contrast_delta = 0.2 66 | batch/meta_dataset_processing.DataAugmentation.enable_random_hue = True 67 | batch/meta_dataset_processing.DataAugmentation.random_hue_delta = 0.03 68 | batch/meta_dataset_processing.DataAugmentation.enable_random_saturation = True 69 | batch/meta_dataset_processing.DataAugmentation.random_saturation_delta = 0.2 70 | -------------------------------------------------------------------------------- /fast-exps/lib/data/meta_dataset_processing.py: -------------------------------------------------------------------------------- 1 | import gin.tf 2 | import tensorflow.compat.v1 as tf 3 | 4 | 5 | @gin.configurable 6 | class DataAugmentation(object): 7 | """Configurations for performing data augmentation.""" 8 | 9 | def __init__(self, enable_jitter, jitter_amount, enable_gaussian_noise, 10 | gaussian_noise_std, enable_random_flip, 11 | enable_random_brightness, random_brightness_delta, 12 | enable_random_contrast, random_contrast_delta, 13 | enable_random_hue, random_hue_delta, enable_random_saturation, 14 | random_saturation_delta): 15 | """Initialize a DataAugmentation. 16 | 17 | Args: 18 | enable_jitter: bool whether to use image jitter (pad each image using 19 | reflection along x and y axes and then random crop). 20 | jitter_amount: amount (in pixels) to pad on all sides of the image. 21 | enable_gaussian_noise: bool whether to use additive Gaussian noise. 22 | gaussian_noise_std: Standard deviation of the Gaussian distribution. 23 | """ 24 | self.enable_jitter = enable_jitter 25 | self.jitter_amount = jitter_amount 26 | self.enable_gaussian_noise = enable_gaussian_noise 27 | self.gaussian_noise_std = gaussian_noise_std 28 | self.enable_random_flip = enable_random_flip 29 | self.enable_random_brightness = enable_random_brightness 30 | self.random_brightness_delta = random_brightness_delta 31 | self.enable_random_contrast = enable_random_contrast 32 | self.random_contrast_delta = random_contrast_delta 33 | self.enable_random_hue = enable_random_hue 34 | self.random_hue_delta = random_hue_delta 35 | self.enable_random_saturation = enable_random_saturation 36 | self.random_saturation_delta = random_saturation_delta 37 | 38 | 39 | @gin.configurable 40 | class ImageDecoder(object): 41 | """Image decoder.""" 42 | 43 | def __init__(self, image_size=None, data_augmentation=None): 44 | """Class constructor. 45 | 46 | Args: 47 | image_size: int, desired image size. The extracted image will be resized 48 | to `[image_size, image_size]`. 49 | data_augmentation: A DataAugmentation object with parameters for 50 | perturbing the images. 51 | """ 52 | 53 | self.image_size = image_size 54 | self.data_augmentation = data_augmentation 55 | 56 | def __call__(self, example_string): 57 | """Processes a single example string. 58 | 59 | Extracts and processes the image, and ignores the label. We assume that the 60 | image has three channels. 61 | 62 | Args: 63 | example_string: str, an Example protocol buffer. 64 | 65 | Returns: 66 | image_rescaled: the image, resized to `image_size x image_size` and 67 | rescaled to [-1, 1]. Note that Gaussian data augmentation may cause values 68 | to go beyond this range. 69 | """ 70 | image_string = tf.parse_single_example( 71 | example_string, 72 | features={ 73 | 'image': tf.FixedLenFeature([], dtype=tf.string), 74 | 'label': tf.FixedLenFeature([], tf.int64) 75 | })['image'] 76 | image_decoded = tf.image.decode_image(image_string, channels=3) 77 | image_decoded.set_shape([None, None, 3]) 78 | image_resized = tf.image.resize_images( 79 | image_decoded, [self.image_size, self.image_size], 80 | method=tf.image.ResizeMethod.BILINEAR, 81 | align_corners=True) 82 | image = tf.cast(image_resized, tf.float32) 83 | 84 | if self.data_augmentation is not None: 85 | if self.data_augmentation.enable_random_brightness: 86 | delta = self.data_augmentation.random_brightness_delta 87 | image = tf.image.random_brightness(image, delta) 88 | 89 | if self.data_augmentation.enable_random_saturation: 90 | delta = self.data_augmentation.random_saturation_delta 91 | image = tf.image.random_saturation(image, 1 - delta, 1 + delta) 92 | 93 | if self.data_augmentation.enable_random_contrast: 94 | delta = self.data_augmentation.random_contrast_delta 95 | image = tf.image.random_contrast(image, 1 - delta, 1 + delta) 96 | 97 | if self.data_augmentation.enable_random_hue: 98 | delta = self.data_augmentation.random_hue_delta 99 | image = tf.image.random_hue(image, delta) 100 | 101 | if self.data_augmentation.enable_random_flip: 102 | image = tf.image.random_flip_left_right(image) 103 | 104 | image = 2 * (image / 255.0 - 0.5) # Rescale to [-1, 1]. 105 | 106 | if self.data_augmentation is not None: 107 | if self.data_augmentation.enable_gaussian_noise: 108 | image = image + tf.random_normal( 109 | tf.shape(image)) * self.data_augmentation.gaussian_noise_std 110 | 111 | if self.data_augmentation.enable_jitter: 112 | j = self.data_augmentation.jitter_amount 113 | paddings = tf.constant([[j, j], [j, j], [0, 0]]) 114 | image = tf.pad(image, paddings, 'REFLECT') 115 | image = tf.image.random_crop(image, 116 | [self.image_size, self.image_size, 3]) 117 | 118 | return image 119 | -------------------------------------------------------------------------------- /fast-exps/lib/data/meta_dataset_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import sys 4 | import torch 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | # from utils import device 9 | from paths import META_DATASET_ROOT, META_RECORDS_ROOT, PROJECT_ROOT 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Quiet the TensorFlow warnings 11 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) # Quiet the TensorFlow warnings 12 | 13 | sys.path.append(os.path.abspath(META_DATASET_ROOT)) 14 | from meta_dataset.data import dataset_spec as dataset_spec_lib 15 | from meta_dataset.data import learning_spec 16 | from meta_dataset.data import pipeline 17 | from meta_dataset.data import config 18 | 19 | 20 | ALL_METADATASET_NAMES = "ilsvrc_2012 omniglot aircraft cu_birds dtd quickdraw fungi vgg_flower traffic_sign mscoco mnist cifar10 cifar100".split(' ') 21 | TRAIN_METADATASET_NAMES = ALL_METADATASET_NAMES[:8] 22 | TEST_METADATASET_NAMES = ALL_METADATASET_NAMES[8:] 23 | 24 | SPLIT_NAME_TO_SPLIT = {'train': learning_spec.Split.TRAIN, 25 | 'val': learning_spec.Split.VALID, 26 | 'test': learning_spec.Split.TEST} 27 | 28 | 29 | class MetaDatasetReader(object): 30 | def __init__(self, mode, train_set, validation_set, test_set): 31 | assert (train_set is not None or validation_set is not None or test_set is not None) 32 | 33 | self.data_path = META_RECORDS_ROOT 34 | self.train_dataset_next_task = None 35 | self.validation_set_dict = {} 36 | self.test_set_dict = {} 37 | self.specs_dict = {} 38 | gin.parse_config_file(f"{PROJECT_ROOT}/lib/data/meta_dataset_config.gin") 39 | 40 | def _get_dataset_spec(self, items): 41 | if isinstance(items, list): 42 | dataset_specs = [] 43 | for dataset_name in items: 44 | dataset_records_path = os.path.join(self.data_path, dataset_name) 45 | dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path) 46 | dataset_specs.append(dataset_spec) 47 | return dataset_specs 48 | else: 49 | dataset_name = items 50 | dataset_records_path = os.path.join(self.data_path, dataset_name) 51 | dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path) 52 | return dataset_spec 53 | 54 | def _to_torch(self, sample): 55 | for key, val in sample.items(): 56 | if isinstance(val, str): 57 | continue 58 | val = torch.from_numpy(val) 59 | if 'image' in key: 60 | val = val.permute(0, 3, 2, 1) 61 | else: 62 | val = val.long() 63 | sample[key] = val.to('cpu') # device 64 | return sample 65 | 66 | def num_classes(self, split_name): 67 | split = SPLIT_NAME_TO_SPLIT[split_name] 68 | all_split_specs = self.specs_dict[SPLIT_NAME_TO_SPLIT['train']] 69 | 70 | if not isinstance(all_split_specs, list): 71 | all_split_specs = [all_split_specs] 72 | 73 | total_n_classes = 0 74 | for specs in all_split_specs: 75 | total_n_classes += len(specs.get_classes(split)) 76 | return total_n_classes 77 | 78 | def build_class_to_identity(self): 79 | split = SPLIT_NAME_TO_SPLIT['train'] 80 | all_split_specs = self.specs_dict[SPLIT_NAME_TO_SPLIT['train']] 81 | 82 | if not isinstance(all_split_specs, list): 83 | all_split_specs = [all_split_specs] 84 | 85 | self.cls_to_identity = dict() 86 | self.dataset_id_to_dataset_name = dict() 87 | self.dataset_to_n_cats = dict() 88 | offset = 0 89 | for dataset_id, specs in enumerate(all_split_specs): 90 | dataset_name = specs.name 91 | self.dataset_id_to_dataset_name[dataset_id] = dataset_name 92 | n_cats = len(specs.get_classes(split)) 93 | self.dataset_to_n_cats[dataset_name] = n_cats 94 | for cat in range(n_cats): 95 | self.cls_to_identity[offset + cat] = (cat, dataset_id) 96 | offset += n_cats 97 | 98 | self.dataset_name_to_dataset_id = {v: k for k, v in 99 | self.dataset_id_to_dataset_name.items()} 100 | 101 | 102 | class MetaDatasetEpisodeReader(MetaDatasetReader): 103 | """ 104 | Class that wraps the Meta-Dataset episode readers. 105 | """ 106 | def __init__(self, mode, train_set=None, validation_set=None, test_set=None): 107 | super(MetaDatasetEpisodeReader, self).__init__(mode, train_set, validation_set, test_set) 108 | 109 | if mode == 'train': 110 | train_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 111 | self.train_dataset_next_task = self._init_multi_source_dataset( 112 | train_set, SPLIT_NAME_TO_SPLIT['train'], train_episode_desscription) 113 | 114 | if mode == 'val': 115 | test_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 116 | for item in validation_set: 117 | next_task = self._init_single_source_dataset( 118 | item, SPLIT_NAME_TO_SPLIT['val'], test_episode_desscription) 119 | self.validation_set_dict[item] = next_task 120 | 121 | if mode == 'test': 122 | test_episode_desscription = config.EpisodeDescriptionConfig(None, None, None) 123 | for item in test_set: 124 | next_task = self._init_single_source_dataset( 125 | item, SPLIT_NAME_TO_SPLIT['test'], test_episode_desscription) 126 | self.test_set_dict[item] = next_task 127 | 128 | def _init_multi_source_dataset(self, items, split, episode_description): 129 | dataset_specs = self._get_dataset_spec(items) 130 | self.specs_dict[split] = dataset_specs 131 | 132 | use_bilevel_ontology_list = [False] * len(items) 133 | use_dag_ontology_list = [False] * len(items) 134 | # Enable ontology aware sampling for Omniglot and ImageNet. 135 | if 'omniglot' in items: 136 | use_bilevel_ontology_list[items.index('omniglot')] = True 137 | if 'ilsvrc_2012' in items: 138 | use_dag_ontology_list[items.index('ilsvrc_2012')] = True 139 | 140 | multi_source_pipeline = pipeline.make_multisource_episode_pipeline( 141 | dataset_spec_list=dataset_specs, 142 | use_dag_ontology_list=use_dag_ontology_list, 143 | use_bilevel_ontology_list=use_bilevel_ontology_list, 144 | split=split, 145 | episode_descr_config = episode_description, 146 | image_size=84, shuffle_buffer_size=1000) 147 | 148 | iterator = multi_source_pipeline.make_one_shot_iterator() 149 | return iterator.get_next() 150 | 151 | def _init_single_source_dataset(self, dataset_name, split, episode_description): 152 | dataset_spec = self._get_dataset_spec(dataset_name) 153 | self.specs_dict[split] = dataset_spec 154 | 155 | # Enable ontology aware sampling for Omniglot and ImageNet. 156 | use_bilevel_ontology = False 157 | if 'omniglot' in dataset_name: 158 | use_bilevel_ontology = True 159 | 160 | use_dag_ontology = False 161 | if 'ilsvrc_2012' in dataset_name: 162 | use_dag_ontology = True 163 | 164 | single_source_pipeline = pipeline.make_one_source_episode_pipeline( 165 | dataset_spec=dataset_spec, 166 | use_dag_ontology=use_dag_ontology, 167 | use_bilevel_ontology=use_bilevel_ontology, 168 | split=split, 169 | episode_descr_config=episode_description, 170 | image_size=84, shuffle_buffer_size=1000) 171 | 172 | iterator = single_source_pipeline.make_one_shot_iterator() 173 | return iterator.get_next() 174 | 175 | def _get_task(self, next_task, session): 176 | episode = session.run(next_task)[0] 177 | task_dict = { 178 | 'context_images': episode[0], 179 | 'context_labels': episode[1], 180 | 'target_images': episode[3], 181 | 'target_labels': episode[4] 182 | } 183 | return self._to_torch(task_dict) 184 | 185 | def get_train_task(self, session): 186 | return self._get_task(self.train_dataset_next_task, session) 187 | 188 | def get_validation_task(self, session, item=None): 189 | item = item if item else list(self.validation_set_dict.keys())[0] 190 | return self._get_task(self.validation_set_dict[item], session) 191 | 192 | def get_test_task(self, session, item=None): 193 | item = item if item else list(self.test_set_dict.keys())[0] 194 | return self._get_task(self.test_set_dict[item], session) 195 | 196 | 197 | class MetaDatasetBatchReader(MetaDatasetReader): 198 | """ 199 | Class that wraps the Meta-Dataset episode readers. 200 | """ 201 | def __init__(self, mode, train_set, validation_set, test_set, batch_size): 202 | super(MetaDatasetBatchReader, self).__init__(mode, train_set, validation_set, test_set) 203 | self.batch_size = batch_size 204 | 205 | if mode == 'train': 206 | self.train_dataset_next_task = self._init_multi_source_dataset( 207 | train_set, SPLIT_NAME_TO_SPLIT['train']) 208 | 209 | elif mode == 'val': 210 | for item in validation_set: 211 | next_task = self.validation_dataset = self._init_single_source_dataset( 212 | item, SPLIT_NAME_TO_SPLIT['val']) 213 | self.validation_set_dict[item] = next_task 214 | 215 | elif mode == 'test': 216 | for item in test_set: 217 | next_task = self._init_single_source_dataset( 218 | item, SPLIT_NAME_TO_SPLIT['test']) 219 | self.test_set_dict[item] = next_task 220 | else: 221 | raise ValueError('Invalid mode : {:}'.format(mode)) 222 | 223 | self.build_class_to_identity() 224 | 225 | def _init_multi_source_dataset(self, items, split): 226 | dataset_specs = self._get_dataset_spec(items) 227 | self.specs_dict[split] = dataset_specs 228 | multi_source_pipeline = pipeline.make_multisource_batch_pipeline( 229 | dataset_spec_list=dataset_specs, batch_size=self.batch_size, 230 | split=split, image_size=84, add_dataset_offset=True, shuffle_buffer_size=1000) 231 | 232 | iterator = multi_source_pipeline.make_one_shot_iterator() 233 | return iterator.get_next() 234 | 235 | def _init_single_source_dataset(self, dataset_name, split): 236 | dataset_specs = self._get_dataset_spec(dataset_name) 237 | self.specs_dict[split] = dataset_specs 238 | multi_source_pipeline = pipeline.make_one_source_batch_pipeline( 239 | dataset_spec=dataset_specs, batch_size=self.batch_size, 240 | split=split, image_size=84, shuffle_buffer_size=1000) 241 | 242 | iterator = multi_source_pipeline.make_one_shot_iterator() 243 | return iterator.get_next() 244 | 245 | def _get_batch(self, next_task, session): 246 | episode = session.run(next_task)[0] 247 | images, labels = episode[0], episode[1] 248 | local_classes, dataset_ids = [], [] 249 | for label in labels: 250 | local_class, dataset_id = self.cls_to_identity[label] 251 | local_classes.append(local_class) 252 | dataset_ids.append(dataset_id) 253 | task_dict = { 254 | 'images': images, 255 | 'labels': labels, 256 | 'local_classes': np.array(local_classes), 257 | 'dataset_ids': np.array(dataset_ids), 258 | 'dataset_name': self.dataset_id_to_dataset_name[dataset_ids[-1]] 259 | } 260 | return self._to_torch(task_dict) 261 | 262 | def get_train_batch(self, session): 263 | return self._get_batch(self.train_dataset_next_task, session) 264 | 265 | def get_validation_batch(self, item, session): 266 | return self._get_batch(self.validation_set_dict[item], session) 267 | 268 | def get_test_batch(self, item, session): 269 | return self._get_batch(self.test_set_dict[item], session) 270 | -------------------------------------------------------------------------------- /fast-exps/lib/datasets/EpisodeMetadata.py: -------------------------------------------------------------------------------- 1 | import os, sys, torch 2 | import numpy as np 3 | import torch.utils.data as data 4 | 5 | 6 | class EpisodeMetadata(data.Dataset): 7 | 8 | def __init__(self, root, name, total): 9 | self.name = name 10 | if name is None: 11 | self.root_dir = root 12 | else: 13 | self.root_dir = os.path.join(root, name) 14 | self.total = total 15 | self.files = [] 16 | for index in range(total): 17 | xfile = os.path.join(self.root_dir, '{:06d}.pth'.format(index)) 18 | assert os.path.exists(xfile), '{:}'.format(xfile) 19 | self.files.append(xfile) 20 | 21 | def __getitem__(self, index): 22 | xfile = self.files[index] 23 | xdata = torch.load(xfile, map_location='cpu') 24 | context_features = xdata['context_features'] 25 | context_labels = xdata['context_labels'] 26 | target_features = xdata['target_features'] 27 | target_labels = xdata['target_labels'] 28 | return torch.IntTensor([index]), context_features, context_labels, target_features, target_labels 29 | 30 | def __len__(self): 31 | return len(self.files) 32 | 33 | -------------------------------------------------------------------------------- /fast-exps/lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from .EpisodeMetadata import EpisodeMetadata 3 | 4 | 5 | def get_eval_datasets(root, dataset_names, num=600): 6 | #eval_dataset_names = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower'] 7 | datasets = OrderedDict() 8 | for name in dataset_names: 9 | dataset = EpisodeMetadata(root, name, num) 10 | datasets[name] = dataset 11 | return datasets 12 | 13 | 14 | def get_train_dataset(root, num=10000): 15 | return EpisodeMetadata(root, 'train-{:}'.format(num), num) 16 | -------------------------------------------------------------------------------- /fast-exps/lib/models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gin 3 | import numpy as np 4 | 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def cross_entropy_loss(logits, targets): 10 | log_p_y = F.log_softmax(logits, dim=1) 11 | preds = log_p_y.argmax(1) 12 | labels = targets.type(torch.long) 13 | loss = F.nll_loss(log_p_y, labels, reduction='mean') 14 | acc = torch.eq(preds, labels).float().mean() 15 | stats_dict = {'loss': loss.item(), 'acc': acc.item()} 16 | pred_dict = {'preds': preds.cpu().numpy(), 'labels': labels.cpu().numpy()} 17 | return loss, stats_dict, pred_dict 18 | 19 | 20 | def prototype_loss(support_embeddings, support_labels, 21 | query_embeddings, query_labels, distance='cos'): 22 | n_way = len(query_labels.unique()) 23 | 24 | prots = compute_prototypes(support_embeddings, support_labels, n_way).unsqueeze(0) 25 | embeds = query_embeddings.unsqueeze(1) 26 | if distance == 'l2': 27 | logits = -torch.pow(embeds - prots, 2).sum(-1) # shape [n_query, n_way] 28 | elif distance == 'cos': 29 | logits = F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) * 10 30 | elif distance == 'lin': 31 | logits = torch.einsum('izd,zjd->ij', embeds, prots) 32 | 33 | return cross_entropy_loss(logits, query_labels) 34 | 35 | 36 | def compute_prototypes(embeddings, labels, n_way): 37 | prots = torch.zeros(n_way, embeddings.shape[-1]).type( 38 | embeddings.dtype).to(embeddings.device) 39 | for i in range(n_way): 40 | prots[i] = embeddings[(labels == i).nonzero(), :].mean(0) 41 | return prots 42 | 43 | 44 | class AdaptiveCosineNCC(nn.Module): 45 | def __init__(self): 46 | super(AdaptiveCosineNCC, self).__init__() 47 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 48 | 49 | def forward(self, support_embeddings, support_labels, 50 | query_embeddings, query_labels, return_logits=False): 51 | n_way = len(query_labels.unique()) 52 | 53 | prots = compute_prototypes(support_embeddings, support_labels, n_way).unsqueeze(0) 54 | embeds = query_embeddings.unsqueeze(1) 55 | logits = F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) * self.scale 56 | 57 | if return_logits: 58 | return logits 59 | 60 | return cross_entropy_loss(logits, query_labels) 61 | 62 | -------------------------------------------------------------------------------- /fast-exps/lib/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | from torch import nn 6 | from torch.optim.lr_scheduler import (MultiStepLR, ExponentialLR, 7 | CosineAnnealingWarmRestarts, 8 | CosineAnnealingLR) 9 | from utils import check_dir, device 10 | from paths import PROJECT_ROOT 11 | 12 | 13 | 14 | def cosine_sim(embeds, prots): 15 | prots = prots.unsqueeze(0) 16 | embeds = embeds.unsqueeze(1) 17 | return F.cosine_similarity(embeds, prots, dim=-1, eps=1e-30) 18 | 19 | 20 | 21 | class CosineClassifier(nn.Module): 22 | def __init__(self, n_feat, num_classes): 23 | super(CosineClassifier, self).__init__() 24 | self.num_classes = num_classes 25 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 26 | weight = torch.FloatTensor(n_feat, num_classes).normal_( 27 | 0.0, np.sqrt(2.0 / num_classes)) 28 | self.weight = nn.Parameter(weight, requires_grad=True) 29 | 30 | def forward(self, x): 31 | x_norm = torch.nn.functional.normalize(x, p=2, dim=-1, eps=1e-12) 32 | weight = torch.nn.functional.normalize(self.weight, p=2, dim=0, eps=1e-12) 33 | cos_dist = x_norm @ weight 34 | scores = self.scale * cos_dist 35 | return scores 36 | 37 | def extra_repr(self): 38 | s = 'CosineClassifier: input_channels={}, num_classes={}; learned_scale: {}'.format( 39 | self.weight.shape[0], self.weight.shape[1], self.scale.item()) 40 | return s 41 | 42 | 43 | class CosineConv(nn.Module): 44 | def __init__(self, n_feat, num_classes, kernel_size=1): 45 | super(CosineConv, self).__init__() 46 | self.scale = nn.Parameter(torch.tensor(10.0), requires_grad=True) 47 | weight = torch.FloatTensor(num_classes, n_feat, 1, 1).normal_( 48 | 0.0, np.sqrt(2.0/num_classes)) 49 | self.weight = nn.Parameter(weight, requires_grad=True) 50 | 51 | def forward(self, x): 52 | x_normalized = torch.nn.functional.normalize( 53 | x, p=2, dim=1, eps=1e-12) 54 | weight = torch.nn.functional.normalize( 55 | self.weight, p=2, dim=1, eps=1e-12) 56 | 57 | cos_dist = torch.nn.functional.conv2d(x_normalized, weight) 58 | scores = self.scale * cos_dist 59 | return scores 60 | 61 | def extra_repr(self): 62 | s = 'CosineConv: num_inputs={}, num_classes={}, kernel_size=1; scale_value: {}'.format( 63 | self.weight.shape[0], self.weight.shape[1], self.scale.item()) 64 | return s 65 | -------------------------------------------------------------------------------- /fast-exps/lib/models/models_dict.py: -------------------------------------------------------------------------------- 1 | DATASET_MODELS_RESNET18 = { 2 | 'ilsvrc_2012': 'imagenet-net', 3 | 'omniglot': 'omniglot-net', 4 | 'aircraft': 'aircraft-net', 5 | 'cu_birds': 'birds-net', 6 | 'dtd': 'textures-net', 7 | 'quickdraw': 'quickdraw-net', 8 | 'fungi': 'fungi-net', 9 | 'vgg_flower': 'vgg_flower-net' 10 | } 11 | 12 | 13 | DATASET_MODELS_RESNET18_PNF = { 14 | 'ilsvrc_2012': 'imagenet-net', 15 | 'omniglot': 'omniglot-film', 16 | 'aircraft': 'aircraft-film', 17 | 'cu_birds': 'birds-film', 18 | 'dtd': 'textures-film', 19 | 'quickdraw': 'quickdraw-film', 20 | 'fungi': 'fungi-film', 21 | 'vgg_flower': 'vgg_flower-film' 22 | } 23 | 24 | DATASET_MODELS_DICT = {'resnet18': DATASET_MODELS_RESNET18, 25 | 'resnet18_pnf': DATASET_MODELS_RESNET18_PNF} 26 | -------------------------------------------------------------------------------- /fast-exps/lib/models/new_model_helpers.py: -------------------------------------------------------------------------------- 1 | import os, sys, copy, torch 2 | import collections 3 | 4 | from typing import List, Dict, Text 5 | from .resnet18 import resnet18 6 | from .resnet18_film import resnet18 as resnet18_film 7 | 8 | MODEL_DICT = {'resnet18': resnet18, 9 | 'resnet18_pnf': resnet18_film} 10 | 11 | 12 | class CheckPointer(object): 13 | def __init__(self, model_name, model=None, optimizer=None): 14 | self.model = model 15 | self.optimizer = optimizer 16 | self.model_name = model_name 17 | TORCH_HOME = 'TORCH_HOME' 18 | if TORCH_HOME in os.environ: 19 | TORCH_HOME = os.environ[TORCH_HOME] 20 | else: 21 | TORCH_HOME = os.path.join(os.environ['HOME'], '.torch') 22 | self.model_path = os.path.join(TORCH_HOME, 'sur-weights', model_name) 23 | self.last_ckpt = os.path.join(self.model_path, 'checkpoint.pth.tar') 24 | self.best_ckpt = os.path.join(self.model_path, 'model_best.pth.tar') 25 | 26 | def restore_model(self, ckpt='last', model=True, optimizer=True, strict=True): 27 | if not os.path.exists(self.model_path): 28 | assert False, "Model is not found at {}".format(self.model_path) 29 | self.last_ckpt = os.path.join(self.model_path, 'checkpoint.pth.tar') 30 | self.best_ckpt = os.path.join(self.model_path, 'model_best.pth.tar') 31 | ckpt_path = self.last_ckpt if ckpt == 'last' else self.best_ckpt 32 | 33 | if os.path.isfile(ckpt_path): 34 | print("=> loading {} checkpoint '{}'".format(ckpt, ckpt_path)) 35 | ch = torch.load(ckpt_path, map_location='cpu') 36 | if self.model is not None and model: 37 | self.model.load_state_dict(ch['state_dict'], strict=strict) 38 | if self.optimizer is not None and optimizer: 39 | self.optimizer.load_state_dict(ch['optimizer']) 40 | else: 41 | assert False, "No checkpoint! %s" % ckpt_path 42 | 43 | return ch.get('epoch', None), ch.get('best_val_loss', None), ch.get('best_val_acc', None) 44 | 45 | 46 | def get_extractors(trainsets: List[Text], 47 | dataset_models: Dict[Text, Text], 48 | backbone: Text, classifier: Text, bn_train_mode: bool, use_cuda: bool = True): 49 | extractors = collections.OrderedDict() 50 | for dataset_name in trainsets: 51 | if dataset_name not in dataset_models: 52 | continue 53 | if dataset_name == 'ilsvrc_2012': 54 | extractor = MODEL_DICT['resnet18'](classifier=classifier, num_classes=None, global_pool=False, dropout=0.0) 55 | else: 56 | extractor = MODEL_DICT[backbone](classifier=classifier, num_classes=None, global_pool=False, dropout=0.0) 57 | extractor.train(bn_train_mode) 58 | print('Create {:}\'s network with BN={:}'.format(dataset_models[dataset_name], bn_train_mode)) 59 | if backbone == 'resnet18_pnf' and dataset_name != 'ilsvrc_2012': 60 | weights = copy.deepcopy(extractors['ilsvrc_2012'].module.state_dict()) 61 | extractor.load_state_dict(weights, strict=False) 62 | checkpointer = CheckPointer(dataset_models[dataset_name], extractor, optimizer=None) 63 | checkpointer.restore_model(ckpt='best', strict=False) 64 | if use_cuda: 65 | extractor = extractor.cuda() 66 | extractors[dataset_name] = torch.nn.DataParallel(extractor) 67 | return extractors 68 | 69 | 70 | def extract_features(extractors, images): 71 | all_features = [] 72 | for name, extractor in extractors.items(): 73 | features = extractor(images) 74 | all_features.append(features) 75 | return torch.stack(all_features, dim=1) # batch x #extractors x #features 76 | -------------------------------------------------------------------------------- /fast-exps/lib/models/new_prop_prototype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import random, math 6 | 7 | 8 | # TODO: integrate the two functions into the following codes 9 | def get_dotproduct_score(proto, cache, model): 10 | proto_emb = model['linear_q'](proto) 11 | s_cache_emb = model['linear_k'](cache) 12 | raw_score = F.cosine_similarity(proto_emb.unsqueeze(1), s_cache_emb.unsqueeze(0), dim=-1) 13 | return raw_score 14 | 15 | 16 | def get_mlp_score(proto, cache, model): 17 | n_proto, fea_dim = proto.shape 18 | n_cache, fea_dim = cache.shape 19 | raw_score = model['w']( model['nonlinear'](model['w1'](proto).view(n_proto, 1, fea_dim) + model['w2'](cache).view(1, n_cache, fea_dim) ) ) 20 | return raw_score.squeeze(-1) 21 | 22 | 23 | # this model does not need query, only key and value 24 | class MultiHeadURT_value(nn.Module): 25 | def __init__(self, fea_dim, hid_dim, temp=1, n_head=1): 26 | super(MultiHeadURT_value, self).__init__() 27 | self.w1 = nn.Linear(fea_dim, hid_dim) 28 | self.w2 = nn.Linear(hid_dim, n_head) 29 | self.temp = temp 30 | 31 | def forward(self, cat_proto): 32 | # cat_proto n_class*8*512 33 | n_class, n_extractors, fea_dim = cat_proto.shape 34 | raw_score = self.w2(self.w1(cat_proto)) # n_class*8*n_head 35 | score = F.softmax(self.temp * raw_score, dim=1) 36 | return score 37 | 38 | 39 | class URTPropagation(nn.Module): 40 | 41 | def __init__(self, key_dim, query_dim, hid_dim, temp=1, att="cosine"): 42 | super(URTPropagation, self).__init__() 43 | self.linear_q = nn.Linear(query_dim, hid_dim, bias=True) 44 | self.linear_k = nn.Linear(key_dim, hid_dim, bias=True) 45 | #self.linear_v_w = nn.Parameter(torch.rand(8, key_dim, key_dim)) 46 | self.linear_v_w = nn.Parameter( torch.eye(key_dim).unsqueeze(0).repeat(8,1,1)) 47 | self.temp = temp 48 | self.att = att 49 | # how different the init is 50 | for m in self.modules(): 51 | if isinstance(m, nn.Linear): 52 | m.weight.data.normal_(0, 0.001) 53 | 54 | def forward_transform(self, samples): 55 | bs, n_extractors, fea_dim = samples.shape 56 | ''' 57 | if self.training: 58 | w_trans = torch.nn.functional.gumbel_softmax(self.linear_v_w, tau=10, hard=True) 59 | else: 60 | # y_soft = torch.softmax(self.linear_v_w, -1) 61 | # index = y_soft.max(-1, keepdim=True)[1] 62 | index = self.linear_v_w.max(-1, keepdim=True)[1] 63 | y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0) 64 | w_trans = y_hard 65 | # w_trans = y_hard - y_soft.detach() + y_soft 66 | ''' 67 | w_trans = self.linear_v_w 68 | # compute regularization 69 | regularization = w_trans @ torch.transpose(w_trans, 1, 2) 70 | samples = samples.view(bs, n_extractors, fea_dim, 1) 71 | w_trans = w_trans.view(1, 8, fea_dim, fea_dim) 72 | return torch.matmul(w_trans, samples).view(bs, n_extractors, fea_dim), (regularization**2).sum() 73 | 74 | def forward(self, cat_proto): 75 | # cat_proto n_class*8*512 76 | # return: n_class*8 77 | n_class, n_extractors, fea_dim = cat_proto.shape 78 | q = cat_proto.view(n_class, -1) # n_class * 8_512 79 | k = cat_proto # n_class * 8 * 512 80 | q_emb = self.linear_q(q) # n_class * hid_dim 81 | k_emb = self.linear_k(k) # n_class * 8 * hid_dim | 8 * hid_dim 82 | if self.att == "cosine": 83 | raw_score = F.cosine_similarity(q_emb.view(n_class, 1, -1), k_emb.view(n_class, n_extractors, -1), dim=-1) 84 | elif self.att == "dotproduct": 85 | raw_score = torch.sum( q_emb.view(n_class, 1, -1) * k_emb.view(n_class, n_extractors, -1), dim=-1 ) / (math.sqrt(fea_dim)) 86 | else: 87 | raise ValueError('invalid att type : {:}'.format(self.att)) 88 | score = F.softmax(self.temp * raw_score, dim=1) 89 | return score 90 | 91 | 92 | class MultiHeadURT(nn.Module): 93 | def __init__(self, key_dim, query_dim, hid_dim, temp=1, att="cosine", n_head=1): 94 | super(MultiHeadURT, self).__init__() 95 | layers = [] 96 | for _ in range(n_head): 97 | layer = URTPropagation(key_dim, query_dim, hid_dim, temp, att) 98 | layers.append(layer) 99 | self.layers = nn.ModuleList(layers) 100 | 101 | def forward(self, cat_proto): 102 | score_lst = [] 103 | for i, layer in enumerate(self.layers): 104 | score = layer(cat_proto) 105 | score_lst.append(score) 106 | # n_class * n_extractor * n_head 107 | return torch.stack(score_lst, dim=-1) 108 | 109 | def get_lambda_urt_sample(context_features, context_labels, target_features, num_labels, model, normalize=True): 110 | if normalize: 111 | context_features = F.normalize(context_features, dim=-1) 112 | target_features = F.normalize(target_features, dim=-1) 113 | score_context, urt_context = model(context_features) 114 | score_target, urt_target = model(target_features) 115 | proto_list = [] 116 | for label in range(num_labels): 117 | proto = urt_context[context_labels == label].mean(dim=0) 118 | proto_list.append(proto) 119 | urt_proto = torch.stack(proto_list) 120 | # n_samples*8*512 121 | return score_context, urt_proto, score_target, urt_target 122 | 123 | def get_lambda_urt_avg(context_features, context_labels, num_labels, model, normalize=True): 124 | if normalize: 125 | context_features = F.normalize(context_features, dim=-1) 126 | proto_list = [] 127 | for label in range(num_labels): 128 | proto = context_features[context_labels == label].mean(dim=0) 129 | proto_list.append(proto) 130 | proto = torch.stack(proto_list) 131 | # n_class*8*512 132 | score_proto = model(proto) 133 | # n_extractors * n_head 134 | return torch.mean(score_proto, dim=0) 135 | 136 | def apply_urt_avg_selection(context_features, selection_params, normalize, value="sum", transform=None): 137 | selection_params = torch.transpose(selection_params, 0, 1) # n_head * 8 138 | n_samples, n_extractors, fea_dim = context_features.shape 139 | urt_fea_lst = [] 140 | if normalize: 141 | context_features = F.normalize(context_features, dim=-1) 142 | regularization_losses = [] 143 | for i, params in enumerate(selection_params): 144 | # class-wise lambda 145 | if transform: 146 | trans_features, reg_loss = transform.module.layers[i].forward_transform(context_features) 147 | regularization_losses.append(reg_loss) 148 | else: 149 | trans_features = context_features 150 | if value == "sum": 151 | urt_features = torch.sum(params.view(1,n_extractors,1) * trans_features, dim=1) # n_sample * 512 152 | elif value == "cat": 153 | urt_features = params.view(1,n_extractors,1) * trans_features # n_sample * 8 * 512 154 | urt_fea_lst.append(urt_features) 155 | if len(regularization_losses) == 0: 156 | return torch.stack( urt_fea_lst, dim=1 ).view(n_samples, -1) # n_sample * (n_head * 512) or n_sample * (8 * 512) 157 | else: 158 | return torch.stack( urt_fea_lst, dim=1 ).view(n_samples, -1), sum(regularization_losses) 159 | 160 | 161 | def apply_urt_selection(context_features, context_labels, selection_params, normalize): 162 | # class-wise lambda 163 | if normalize: 164 | context_features = F.normalize(context_features, dim=-1) 165 | lambda_lst = [] 166 | for lab in context_labels: 167 | lambda_lst.append(selection_params[lab]) 168 | lambda_tensor = torch.stack(lambda_lst, dim=0) 169 | n_sample, n_extractors = lambda_tensor.shape 170 | urt_features = torch.sum(lambda_tensor.view(n_sample, n_extractors, 1) * context_features, dim=1) 171 | return urt_features 172 | 173 | class PropagationLayer(nn.Module): 174 | 175 | def __init__(self, input_dim=512, hid_dim=128, temp=1, transform=False): 176 | super(PropagationLayer, self).__init__() 177 | self.linear_q = nn.Linear(input_dim, hid_dim, bias=False) 178 | self.linear_k = nn.Linear(input_dim, hid_dim, bias=False) 179 | self.temp = temp 180 | if transform: 181 | self.transform = nn.Linear(input_dim, input_dim) 182 | 183 | for m in self.modules(): 184 | if isinstance(m, nn.Linear): 185 | m.weight.data.normal_(0, 0.001) 186 | 187 | def forward(self, proto, s_cache, data2nclss, use_topk): 188 | if 'transform' in self.__dict__: 189 | proto = self.transform(proto) 190 | s_cache = self.transform(s_cache) 191 | proto_emb = self.linear_q(proto) 192 | s_cache_emb = self.linear_k(s_cache) 193 | raw_score = F.cosine_similarity(proto_emb.unsqueeze(1), s_cache_emb.unsqueeze(0), dim=-1) 194 | score = F.softmax(self.temp * raw_score, dim=1) 195 | prop_proto = torch.matmul( score, s_cache ) # n_class * n_cache @ n_cache * n_dim 196 | if random.random() > 0.99: 197 | print("top_1_idx: {} in {} cache".format(torch.topk(raw_score, 1)[1], len(s_cache))) 198 | print("score: {}".format(score)) 199 | print("mean:{}, var:{}, min:{}, max:{}".format(torch.mean(score, dim=1).data, torch.var(score, dim=1).data, torch.min(score, dim=1)[0].data, torch.max(score, dim=1)[0].data)) 200 | return raw_score, prop_proto 201 | 202 | 203 | class MultiHeadPropagationLayer(nn.Module): 204 | 205 | def __init__(self, input_dim, hid_dim, temp, transform, n_head): 206 | super(MultiHeadPropagationLayer, self).__init__() 207 | layers = [] 208 | for _ in range(n_head): 209 | layer = PropagationLayer(input_dim, hid_dim, temp, transform) 210 | layers.append(layer) 211 | self.layers = nn.ModuleList(layers) 212 | 213 | def forward(self, proto, s_cache, data2nclss, use_topk): 214 | raw_score_lst, prop_proto_lst = [], [] 215 | for i, layer in enumerate(self.layers): 216 | raw_score, prop_proto = layer(proto, s_cache, data2nclss, use_topk) 217 | if torch.isnan(raw_score).any() or torch.isnan(prop_proto).any(): import pdb; pdb.set_trace() 218 | raw_score_lst.append(raw_score) 219 | prop_proto_lst.append(prop_proto) 220 | return torch.stack(raw_score_lst, dim=0).mean(0), torch.stack(prop_proto_lst, dim=0).mean(0) 221 | 222 | def get_prototypes(features, labels, num_labels, model, cache): 223 | proto_list = [] 224 | for label in range(num_labels): 225 | proto = features[labels == label].mean(dim=0) 226 | proto_list.append(proto) 227 | proto = torch.stack(proto_list) 228 | num_devices = torch.cuda.device_count() 229 | num_slots, feature_dim = cache.shape 230 | cache_for_parallel = cache.view(1, num_slots, feature_dim).expand(num_devices, num_slots, feature_dim) 231 | raw_score, prop_proto = model(proto, cache_for_parallel) 232 | return raw_score, proto, prop_proto 233 | -------------------------------------------------------------------------------- /fast-exps/lib/models/resnet18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from models.model_utils import CosineClassifier 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | 52 | def __init__(self, block, layers, classifier=None, num_classes=64, 53 | dropout=0.0, global_pool=True): 54 | super(ResNet, self).__init__() 55 | self.initial_pool = False 56 | inplanes = self.inplanes = 64 57 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=2, 58 | padding=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(self.inplanes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 62 | self.layer1 = self._make_layer(block, inplanes, layers[0]) 63 | self.layer2 = self._make_layer(block, inplanes * 2, layers[1], stride=2) 64 | self.layer3 = self._make_layer(block, inplanes * 4, layers[2], stride=2) 65 | self.layer4 = self._make_layer(block, inplanes * 8, layers[3], stride=2) 66 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 67 | self.dropout = nn.Dropout(dropout) 68 | self.outplanes = 512 69 | 70 | # handle classifier creation 71 | if num_classes is not None: 72 | if classifier == 'linear': 73 | self.cls_fn = nn.Linear(self.outplanes, num_classes) 74 | elif classifier == 'cosine': 75 | self.cls_fn = CosineClassifier(self.outplanes, num_classes) 76 | else: 77 | self.cls_fn = None 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 82 | elif isinstance(m, nn.BatchNorm2d): 83 | nn.init.constant_(m.weight, 1) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | def _make_layer(self, block, planes, blocks, stride=1): 87 | downsample = None 88 | if stride != 1 or self.inplanes != planes * block.expansion: 89 | downsample = nn.Sequential( 90 | conv1x1(self.inplanes, planes * block.expansion, stride), 91 | nn.BatchNorm2d(planes * block.expansion), 92 | ) 93 | 94 | layers = [] 95 | layers.append(block(self.inplanes, planes, stride, downsample)) 96 | self.inplanes = planes * block.expansion 97 | for _ in range(1, blocks): 98 | layers.append(block(self.inplanes, planes)) 99 | 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | x = self.embed(x) 104 | if self.cls_fn: 105 | x = self.dropout(x) 106 | x = self.cls_fn(x) 107 | return x 108 | 109 | def embed(self, x, param_dict=None): 110 | x = self.conv1(x) 111 | x = self.bn1(x) 112 | x = self.relu(x) 113 | if self.initial_pool: 114 | x = self.maxpool(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | 121 | x = self.avgpool(x) 122 | return x.squeeze() 123 | 124 | def get_state_dict(self): 125 | """Outputs all the state elements""" 126 | return self.state_dict() 127 | 128 | def get_parameters(self): 129 | """Outputs all the parameters""" 130 | return [v for k, v in self.named_parameters()] 131 | 132 | 133 | def resnet18(pretrained=False, pretrained_model_path=None, **kwargs): 134 | """ 135 | Constructs a ResNet-18 model. 136 | """ 137 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 138 | if pretrained: 139 | device = model.get_state_dict()[0].device 140 | ckpt_dict = torch.load(pretrained_model_path, map_location=device) 141 | model.load_parameters(ckpt_dict['state_dict'], strict=False) 142 | print('Loaded shared weights from {}'.format(pretrained_model_path)) 143 | return model 144 | -------------------------------------------------------------------------------- /fast-exps/lib/models/resnet18_film.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from models.model_utils import CosineClassifier 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class CatFilm(nn.Module): 19 | """Film layer that performs per-channel affine transformation.""" 20 | def __init__(self, planes): 21 | super(CatFilm, self).__init__() 22 | self.gamma = nn.Parameter(torch.ones(1, planes)) 23 | self.beta = nn.Parameter(torch.zeros(1, planes)) 24 | 25 | def forward(self, x): 26 | gamma = self.gamma.view(1, -1, 1, 1) 27 | beta = self.beta.view(1, -1, 1, 1) 28 | return gamma * x + beta 29 | 30 | 31 | class BasicBlockFilm(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlockFilm, self).__init__() 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.downsample = downsample 42 | self.stride = stride 43 | self.film1 = CatFilm(planes) 44 | self.film2 = CatFilm(planes) 45 | 46 | def forward(self, x): 47 | identity = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.film1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.film2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | 69 | def __init__(self, block, layers, classifier=None, num_classes=None, 70 | dropout=0.0, global_pool=True): 71 | super(ResNet, self).__init__() 72 | self.initial_pool = False 73 | self.film_normalize = CatFilm(3) 74 | inplanes = self.inplanes = 64 75 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=5, stride=2, 76 | padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 80 | self.layer1 = self._make_layer(block, inplanes, layers[0]) 81 | self.layer2 = self._make_layer(block, inplanes * 2, layers[1], stride=2) 82 | self.layer3 = self._make_layer(block, inplanes * 4, layers[2], stride=2) 83 | self.layer4 = self._make_layer(block, inplanes * 8, layers[3], stride=2) 84 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 85 | self.dropout = nn.Dropout(dropout) 86 | self.outplanes = 512 87 | 88 | # handle classifier creation 89 | if num_classes is not None: 90 | if classifier == 'linear': 91 | self.cls_fn = nn.Linear(self.outplanes, num_classes) 92 | elif classifier == 'cosine': 93 | self.cls_fn = CosineClassifier(self.outplanes, num_classes) 94 | else: 95 | self.cls_fn = None 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | elif isinstance(m, nn.BatchNorm2d): 101 | nn.init.constant_(m.weight, 1) 102 | nn.init.constant_(m.bias, 0) 103 | 104 | def _make_layer(self, block, planes, blocks, stride=1): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | conv1x1(self.inplanes, planes * block.expansion, stride), 109 | nn.BatchNorm2d(planes * block.expansion)) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride=stride, downsample=downsample)) 113 | self.inplanes = planes * block.expansion 114 | for _ in range(1, blocks): 115 | layers.append(block(self.inplanes, planes)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.embed(x) 121 | if self.cls_fn: 122 | x = self.cls_fn(x) 123 | return x 124 | 125 | def embed(self, x, param_dict=None): 126 | """Computing the features""" 127 | x = self.film_normalize(x) 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.relu(x) 131 | if self.initial_pool: 132 | x = self.maxpool(x) 133 | 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x = self.layer3(x) 137 | x = self.layer4(x) 138 | 139 | x = self.avgpool(x) 140 | return x.squeeze() 141 | 142 | def get_state_dict(self): 143 | """Outputs the state elements that are domain-specific""" 144 | return {k: v for k, v in self.state_dict().items() 145 | if 'film' in k or 'cls' in k or 'running' in k} 146 | 147 | def get_parameters(self): 148 | """Outputs only the parameters that are domain-specific""" 149 | return [v for k, v in self.named_parameters() 150 | if 'film' in k or 'cls' in k] 151 | 152 | 153 | def resnet18(pretrained=False, pretrained_model_path=None, **kwargs): 154 | """ 155 | Constructs a FiLM adapted ResNet-18 model. 156 | """ 157 | model = ResNet(BasicBlockFilm, [2, 2, 2, 2], **kwargs) 158 | 159 | # loading shared convolutional weights 160 | if pretrained_model_path is not None: 161 | device = model.get_parameters()[0].device 162 | ckpt_dict = torch.load(pretrained_model_path, map_location=device)['state_dict'] 163 | shared_state = {k: v for k, v in ckpt_dict.items() if 'cls' not in k} 164 | model.load_state_dict(shared_state, strict=False) 165 | print('Loaded shared weights from {}'.format(pretrained_model_path)) 166 | return model 167 | -------------------------------------------------------------------------------- /fast-exps/lib/paths.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | 4 | PROJECT_ROOT = str((Path(__file__).parent / '..').resolve()) 5 | META_DATASET_ROOT = os.environ['META_DATASET_ROOT'] 6 | META_RECORDS_ROOT = os.environ['RECORDS'] 7 | META_DATA_ROOT = '/'.join(META_RECORDS_ROOT.split('/')[:-1]) 8 | -------------------------------------------------------------------------------- /fast-exps/lib/utils.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | import random 3 | import torch 4 | import numpy as np 5 | from tabulate import tabulate 6 | 7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def time_string(): 11 | ISOTIMEFORMAT='%Y-%m-%d %X' 12 | string = '[{:}]'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) 13 | return string 14 | 15 | 16 | def convert_secs2time(epoch_time, string=True, xneed=True): 17 | need_hour = int(epoch_time / 3600) 18 | need_mins = int((epoch_time - 3600*need_hour) / 60) 19 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 20 | if string: 21 | if xneed: 22 | need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) 23 | else: 24 | need_time = '{:02d}:{:02d}:{:02d}'.format(need_hour, need_mins, need_secs) 25 | return need_time 26 | else: 27 | return need_hour, need_mins, need_secs 28 | 29 | 30 | class AverageMeter(object): 31 | """Computes and stores the average and current value""" 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0.0 37 | self.avg = 0.0 38 | self.sum = 0.0 39 | self.count = 0.0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count 46 | 47 | def __repr__(self): 48 | return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__)) 49 | 50 | 51 | class ConfusionMatrix(): 52 | def __init__(self, n_classes): 53 | self.n_classes = n_classes 54 | self.mat = np.zeros([n_classes, n_classes]) 55 | 56 | def update_mat(self, preds, labels, idxs): 57 | idxs = np.array(idxs) 58 | real_pred = idxs[preds] 59 | real_labels = idxs[labels] 60 | self.mat[real_pred, real_labels] += 1 61 | 62 | def get_mat(self): 63 | return self.mat 64 | 65 | 66 | class Accumulator(): 67 | def __init__(self, max_size=2000): 68 | self.max_size = max_size 69 | self.ac = np.empty(0) 70 | 71 | def append(self, v): 72 | self.ac = np.append(self.ac[-self.max_size:], v) 73 | 74 | def reset(self): 75 | self.ac = np.empty(0) 76 | 77 | def mean(self, last=None): 78 | last = last if last else self.max_size 79 | return self.ac[-last:].mean() 80 | 81 | 82 | class IterBeat(): 83 | def __init__(self, freq, length=None): 84 | self.length = length 85 | self.freq = freq 86 | 87 | def step(self, i): 88 | if i == 0: 89 | self.t = time.time() 90 | self.lastcall = 0 91 | else: 92 | if ((i % self.freq) == 0) or ((i + 1) == self.length): 93 | t = time.time() 94 | print('{0} / {1} ---- {2:.2f} it/sec'.format( 95 | i, self.length, (i - self.lastcall) / (t - self.t))) 96 | self.lastcall = i 97 | self.t = t 98 | 99 | 100 | class SerializableArray(object): 101 | def __init__(self, array): 102 | self.shape = array.shape 103 | self.data = array.tobytes() 104 | self.dtype = array.dtype 105 | 106 | def get(self): 107 | array = np.frombuffer(self.data, self.dtype) 108 | return np.reshape(array, self.shape) 109 | 110 | 111 | def print_res(array, name, file=None, prec=4, mult=1): 112 | array = np.array(array) * mult 113 | mean, std = np.mean(array), np.std(array) 114 | conf = 1.96 * std / np.sqrt(len(array)) 115 | stat_string = ("test {:s}: {:0.%df} +/- {:0.%df}" 116 | % (prec, prec)).format(name, mean, conf) 117 | print(stat_string) 118 | if file is not None: 119 | with open(file, 'a+') as f: 120 | f.write(stat_string + '\n') 121 | 122 | 123 | def process_copies(embeddings, labels, args): 124 | n_copy = args['test.n_copy'] 125 | test_embeddings = embeddings.view( 126 | args['data.test_query'] * args['data.test_way'], 127 | n_copy, -1).mean(dim=1) 128 | return test_embeddings, labels[0::n_copy] 129 | 130 | 131 | def set_determ(seed=1234): 132 | random.seed(seed) 133 | np.random.seed(seed) 134 | torch.manual_seed(seed) 135 | torch.cuda.manual_seed(seed) 136 | torch.cuda.manual_seed_all(seed) 137 | 138 | 139 | def merge_dicts(dicts, torch_stack=True): 140 | def stack_fn(l): 141 | if isinstance(l[0], torch.Tensor): 142 | return torch.stack(l) 143 | elif isinstance(l[0], str): 144 | return l 145 | else: 146 | return torch.tensor(l) 147 | 148 | keys = dicts[0].keys() 149 | new_dict = {key: [] for key in keys} 150 | for key in keys: 151 | for d in dicts: 152 | new_dict[key].append(d[key]) 153 | if torch_stack: 154 | for key in keys: 155 | new_dict[key] = stack_fn(new_dict[key]) 156 | return new_dict 157 | 158 | 159 | def voting(preds, pref_ind=0): 160 | n_models = len(preds) 161 | n_test = len(preds[0]) 162 | final_preds = [] 163 | for i in range(n_test): 164 | cur_preds = [preds[k][i] for k in range(n_models)] 165 | classes, counts = np.unique(cur_preds, return_counts=True) 166 | if (counts == max(counts)).sum() > 1: 167 | final_preds.append(preds[pref_ind][i]) 168 | else: 169 | final_preds.append(classes[np.argmax(counts)]) 170 | return final_preds 171 | 172 | 173 | def agreement(preds): 174 | n_preds = preds.shape[0] 175 | mat = np.zeros((n_preds, n_preds)) 176 | for i in range(n_preds): 177 | for j in range(i, n_preds): 178 | mat[i, j] = mat[j, i] = ( 179 | preds[i] == preds[j]).astype('float').mean() 180 | return mat 181 | 182 | 183 | def read_textfile(filename, skip_last_line=True): 184 | with open(filename, 'r') as f: 185 | container = f.read().split('\n') 186 | if skip_last_line: 187 | container = container[:-1] 188 | return container 189 | 190 | 191 | def check_dir(dirname, verbose=True): 192 | """This function creates a directory 193 | in case it doesn't exist""" 194 | try: 195 | # Create target Directory 196 | os.makedirs(dirname) 197 | if verbose: 198 | print("Directory ", dirname, " was created") 199 | except FileExistsError: 200 | if verbose: 201 | print("Directory ", dirname, " already exists") 202 | return dirname 203 | 204 | 205 | def pre_load_results(): 206 | sur_paper = {"ilsvrc_2012":[.563], "omniglot":[.931], "aircraft":[.854], "cu_birds":[.714], "dtd":[.715], "quickdraw":[.813], "fungi":[.631], "vgg_flower":[.828], "traffic_sign":[.704], "mscoco":[.524], "mnist":[.943], "cifar10":[.668], "cifar100":[.566]} 207 | sur_exp = {"ilsvrc_2012":[.563], "omniglot":[.931], "aircraft":[.854], "cu_birds":[.714], "dtd":[.715], "quickdraw":[.813], "fungi":[.631], "vgg_flower":[.828], "traffic_sign":[.704], "mscoco":[.524], "mnist":[.943], "cifar10":[.668], "cifar100":[.566]} 208 | return sur_paper, sur_exp 209 | 210 | 211 | def show_results(dataset_names, alg2data2accuracy, compares, print_func): 212 | assert isinstance(compares, tuple) and len(compares) == 2 213 | rows, better = [], 0 214 | for dataset_name in dataset_names: 215 | row = [dataset_name] 216 | xname2acc = {} 217 | for model_name, data2accs in alg2data2accuracy.items(): 218 | acc = np.array(data2accs[dataset_name]) * 100 219 | mean_acc = acc.mean() 220 | conf = (1.96 * acc.std()) / np.sqrt(len(acc)) 221 | row.append(f"{mean_acc:0.2f} +- {conf:0.2f}") 222 | xname2acc[model_name] = mean_acc 223 | row.append("{:.2f}".format(xname2acc[compares[1]]-xname2acc[compares[0]])) 224 | better += xname2acc[compares[1]] > xname2acc[compares[0]] 225 | rows.append(row) 226 | alg_names = list(alg2data2accuracy.keys()) + ['ok-{:02d}/{:02d}'.format(better, len(dataset_names))] 227 | table = tabulate(rows, headers=['model \\ data'] + alg_names, floatfmt=".2f") 228 | print_func(table) 229 | -------------------------------------------------------------------------------- /fast-exps/urt-avg-head.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os, sys, time, argparse 3 | import collections 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | from torch.utils.tensorboard import SummaryWriter 8 | import numpy as np 9 | from tabulate import tabulate 10 | import random, json 11 | from pathlib import Path 12 | lib_dir = (Path(__file__).parent / 'lib').resolve() 13 | if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir)) 14 | 15 | from datasets import get_eval_datasets, get_train_dataset 16 | from data.meta_dataset_reader import TRAIN_METADATASET_NAMES, ALL_METADATASET_NAMES 17 | from models.model_utils import cosine_sim 18 | from models.new_model_helpers import extract_features 19 | from models.losses import prototype_loss 20 | from models.models_dict import DATASET_MODELS_DICT 21 | from models.new_prop_prototype import MultiHeadURT, MultiHeadURT_value, get_lambda_urt_avg, apply_urt_avg_selection 22 | from utils import convert_secs2time, time_string, AverageMeter, show_results, pre_load_results 23 | from paths import META_RECORDS_ROOT 24 | 25 | from config_utils import Logger 26 | 27 | 28 | def load_config(): 29 | 30 | parser = argparse.ArgumentParser(description='Train URT networks') 31 | parser.add_argument('--save_dir', type=str, help="The saved path in dir.") 32 | parser.add_argument('--cache_dir', type=str, help="The saved path in dir.") 33 | parser.add_argument('--seed', type=int, help="The random seed.") 34 | parser.add_argument('--interval.train', type=int, default=100, help='The number to log training information') 35 | parser.add_argument('--interval.test', type=int, default=2000, help='The number to log training information') 36 | parser.add_argument('--interval.train.reset', type=int, default=500, help='The number to log training information') 37 | 38 | # model args 39 | parser.add_argument('--model.backbone', default='resnet18', help="Use ResNet18 for experiments (default: False)") 40 | parser.add_argument('--model.classifier', type=str, default='cosine', choices=['none', 'linear', 'cosine'], help="Do classification using cosine similatity between activations and weights") 41 | 42 | # urt model 43 | parser.add_argument('--urt.variant', type=str) 44 | parser.add_argument('--urt.temp', type=str) 45 | parser.add_argument('--urt.head', type=int) 46 | parser.add_argument('--urt.penalty_coef', type=float) 47 | # train args 48 | parser.add_argument('--train.max_iter', type=int, help='number of epochs to train (default: 10000)') 49 | parser.add_argument('--train.weight_decay', type=float, help="weight decay coef") 50 | parser.add_argument('--train.optimizer', type=str, help='optimization method (default: momentum)') 51 | 52 | parser.add_argument('--train.scheduler', type=str, help='optimization method (default: momentum)') 53 | parser.add_argument('--train.learning_rate', type=float, help='learning rate (default: 0.0001)') 54 | parser.add_argument('--train.lr_decay_step_gamma', type=float, metavar='DECAY_GAMMA') 55 | parser.add_argument('--train.lr_step', type=int, help='the value to divide learning rate by when decayin lr') 56 | 57 | xargs = vars(parser.parse_args()) 58 | return xargs 59 | 60 | 61 | def get_cosine_logits(selected_target, proto, temp): 62 | n_query, feat_dim = selected_target.shape 63 | n_classes, feat_dim = proto.shape 64 | logits = temp * F.cosine_similarity(selected_target.view(n_query, 1, feat_dim), proto.view(1, n_classes, feat_dim), dim=-1) 65 | return logits 66 | 67 | 68 | def test_all_dataset(xargs, test_loaders, URT_model, logger, writter, mode, training_iter, cosine_temp): 69 | URT_model.eval() 70 | our_name = 'urt' 71 | accs_names = [our_name] 72 | alg2data2accuracy = collections.OrderedDict() 73 | alg2data2accuracy['sur-paper'], alg2data2accuracy['sur-exp'] = pre_load_results() 74 | alg2data2accuracy[our_name] = {name: [] for name in test_loaders.keys()} 75 | 76 | logger.print('\n{:} starting evaluate the {:} set at the {:}-th iteration.'.format(time_string(), mode, training_iter)) 77 | for idata, (test_dataset, loader) in enumerate(test_loaders.items()): 78 | logger.print('===>>> {:} --->>> {:02d}/{:02d} --->>> {:}'.format(time_string(), idata, len(test_loaders), test_dataset)) 79 | our_losses = AverageMeter() 80 | for idx, (_, context_features, context_labels, target_features, target_labels) in enumerate(loader): 81 | context_features, context_labels = context_features.squeeze(0).cuda(), context_labels.squeeze(0).cuda() 82 | target_features, target_labels = target_features.squeeze(0).cuda(), target_labels.squeeze(0).cuda() 83 | n_classes = len(np.unique(context_labels.cpu().numpy())) 84 | # optimize selection parameters and perform feature selection 85 | avg_urt_params = get_lambda_urt_avg(context_features, context_labels, n_classes, URT_model, normalize=True) 86 | 87 | urt_context_features = apply_urt_avg_selection(context_features, avg_urt_params, normalize=True) 88 | urt_target_features = apply_urt_avg_selection(target_features, avg_urt_params, normalize=True) 89 | proto_list = [] 90 | for label in range(n_classes): 91 | proto = urt_context_features[context_labels == label].mean(dim=0) 92 | proto_list.append(proto) 93 | urt_proto = torch.stack(proto_list) 94 | 95 | #if random.random() > 0.99: 96 | # print("urt avg score {}".format(avg_urt_params)) 97 | # print("-"*20) 98 | with torch.no_grad(): 99 | logits = get_cosine_logits(urt_target_features, urt_proto, cosine_temp) 100 | loss = F.cross_entropy(logits, target_labels) 101 | our_losses.update(loss.item()) 102 | predicts = torch.argmax(logits, dim=-1) 103 | final_acc = torch.eq(target_labels, predicts).float().mean().item() 104 | alg2data2accuracy[our_name][test_dataset].append(final_acc) 105 | base_name = '{:}-{:}'.format(test_dataset, mode) 106 | writter.add_scalar("{:}-our-loss".format(base_name), our_losses.avg, training_iter) 107 | writter.add_scalar("{:}-our-acc".format(base_name) , np.mean(alg2data2accuracy[our_name][test_dataset]), training_iter) 108 | 109 | 110 | dataset_names = list(test_loaders.keys()) 111 | show_results(dataset_names, alg2data2accuracy, ('sur-paper', our_name), logger.print) 112 | logger.print("\n") 113 | 114 | def main(xargs): 115 | 116 | # set up logger 117 | log_dir = Path(xargs['save_dir']).resolve() 118 | log_dir.mkdir(parents=True, exist_ok=True) 119 | 120 | if xargs['seed'] is None or xargs['seed'] < 0: 121 | seed = len(list(Path(log_dir).glob("*.txt"))) 122 | else: 123 | seed = xargs['seed'] 124 | random.seed(seed) 125 | torch.manual_seed(seed) 126 | logger = Logger(str(log_dir), seed) 127 | logger.print('{:} --- args ---'.format(time_string())) 128 | for key, value in xargs.items(): 129 | logger.print(' [{:10s}] : {:}'.format(key, value)) 130 | logger.print('{:} --- args ---'.format(time_string())) 131 | writter = SummaryWriter(log_dir) 132 | 133 | # Setting up datasets 134 | extractor_domains = TRAIN_METADATASET_NAMES 135 | train_dataset = get_train_dataset(xargs['cache_dir'], xargs['train.max_iter']) 136 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True) 137 | # The validation loaders. 138 | val_datasets = get_eval_datasets(os.path.join(xargs['cache_dir'], 'val-600'), TRAIN_METADATASET_NAMES) 139 | val_loaders = collections.OrderedDict() 140 | for name, dataset in val_datasets.items(): 141 | val_loaders[name] = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) 142 | # The test loaders 143 | test_datasets = get_eval_datasets(os.path.join(xargs['cache_dir'], 'test-600'), ALL_METADATASET_NAMES) 144 | test_loaders = collections.OrderedDict() 145 | for name, dataset in test_datasets.items(): 146 | test_loaders[name] = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) 147 | 148 | class_name_dict = collections.OrderedDict() 149 | for d in extractor_domains: 150 | with open("{:}/{:}/dataset_spec.json".format(META_RECORDS_ROOT, d)) as f: 151 | data = json.load(f) 152 | class_name_dict[d] = data['class_names'] 153 | 154 | # init prop model 155 | URT_model = MultiHeadURT(key_dim=512, query_dim=8*512, hid_dim=1024, temp=1, att="dotproduct", n_head=xargs['urt.head']) 156 | URT_model = torch.nn.DataParallel(URT_model) 157 | URT_model = URT_model.cuda() 158 | cosine_temp = nn.Parameter(torch.tensor(10.0).cuda()) 159 | params = [p for p in URT_model.parameters()] + [cosine_temp] 160 | 161 | optimizer = torch.optim.Adam(params, lr=xargs['train.learning_rate'], weight_decay=xargs['train.weight_decay']) 162 | logger.print(optimizer) 163 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=xargs['train.max_iter']) 164 | logger.print(lr_scheduler) 165 | 166 | # load checkpoint optional 167 | last_ckp_path = log_dir / 'last-ckp-seed-{:}.pth'.format(seed) 168 | if last_ckp_path.exists(): 169 | checkpoint = torch.load(last_ckp_path) 170 | start_iter = checkpoint['train_iter'] + 1 171 | URT_model.load_state_dict(checkpoint['URT_model']) 172 | optimizer.load_state_dict(checkpoint['optimizer']) 173 | lr_scheduler.load_state_dict(checkpoint['scheduler']) 174 | logger.print ('load checkpoint from {:}'.format(last_ckp_path)) 175 | else: 176 | logger.print ('randomly initialiization') 177 | start_iter = 0 178 | max_iter = xargs['train.max_iter'] 179 | 180 | our_losses, our_accuracies = AverageMeter(), AverageMeter() 181 | iter_time, timestamp = AverageMeter(), time.time() 182 | 183 | for index, (_, context_features, context_labels, target_features, target_labels) in enumerate(train_loader): 184 | context_features, context_labels = context_features.squeeze(0).cuda(), context_labels.squeeze(0).cuda() 185 | target_features, target_labels = target_features.squeeze(0).cuda(), target_labels.squeeze(0).cuda() 186 | URT_model.train() 187 | n_classes = len(np.unique(context_labels.cpu().numpy())) 188 | # optimize selection parameters and perform feature selection 189 | avg_urt_params = get_lambda_urt_avg(context_features, context_labels, n_classes, URT_model, normalize=True) 190 | # identity matrix panelize to be sparse, only focus on one aspect 191 | penalty = torch.pow( torch.norm( torch.transpose(avg_urt_params, 0, 1) @ avg_urt_params - torch.eye(xargs['urt.head']).cuda() ), 2) 192 | # n_samples * (n_head * 512) 193 | urt_context_features = apply_urt_avg_selection(context_features, avg_urt_params, normalize=True) 194 | urt_target_features = apply_urt_avg_selection(target_features, avg_urt_params, normalize=True) 195 | proto_list = [] 196 | for label in range(n_classes): 197 | proto = urt_context_features[context_labels == label].mean(dim=0) 198 | proto_list.append(proto) 199 | urt_proto = torch.stack(proto_list) 200 | logits = get_cosine_logits(urt_target_features, urt_proto, cosine_temp) 201 | loss = F.cross_entropy(logits, target_labels) + xargs['urt.penalty_coef']*penalty 202 | optimizer.zero_grad() 203 | loss.backward() 204 | optimizer.step() 205 | lr_scheduler.step() 206 | 207 | with torch.no_grad(): 208 | predicts = torch.argmax(logits, dim=-1) 209 | final_acc = torch.eq(target_labels, predicts).float().mean().item() 210 | our_losses.update(loss.item()) 211 | our_accuracies.update(final_acc * 100) 212 | 213 | if index % xargs['interval.train'] == 0 or index+1 == max_iter: 214 | logger.print("{:} [{:5d}/{:5d}] [OUR] lr: {:}, loss: {:.5f}, accuracy: {:.4f}".format(time_string(), index, max_iter, lr_scheduler.get_last_lr(), our_losses.avg, our_accuracies.avg)) 215 | writter.add_scalar("lr", lr_scheduler.get_last_lr()[0], index) 216 | writter.add_scalar("train_loss", our_losses.avg, index) 217 | writter.add_scalar("train_acc", our_accuracies.avg, index) 218 | if index+1 == max_iter: 219 | with torch.no_grad(): 220 | info = {'args' : xargs, 221 | 'train_iter': index, 222 | 'optimizer' : optimizer.state_dict(), 223 | 'scheduler' : lr_scheduler.state_dict(), 224 | 'URT_model' : URT_model.state_dict()} 225 | torch.save(info, "{:}/ckp-seed-{:}-iter-{:}.pth".format(log_dir, seed, index)) 226 | torch.save(info, last_ckp_path) 227 | 228 | # Reset the count 229 | if index % xargs['interval.train.reset'] == 0: 230 | our_losses.reset() 231 | our_accuracies.reset() 232 | time_str = convert_secs2time(iter_time.avg * (max_iter - index), True) 233 | logger.print("iteration [{:5d}/{:5d}], still need {:}".format(index, max_iter, time_str)) 234 | 235 | # measure time 236 | iter_time.update(time.time() - timestamp) 237 | timestamp = time.time() 238 | 239 | if (index+1) % xargs['interval.test'] == 0 or index+1 == max_iter: 240 | test_all_dataset(xargs, val_loaders, URT_model, logger, writter, "eval", index, cosine_temp) 241 | test_all_dataset(xargs, test_loaders, URT_model, logger, writter, "test", index, cosine_temp) 242 | 243 | 244 | if __name__ == '__main__': 245 | xargs = load_config() 246 | main(xargs) 247 | -------------------------------------------------------------------------------- /fast-scripts/urt-avg-head.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # bash ./fast-scripts/urt-avg-head.sh ./fast-outputs/urt-avg-head 2 ${cache_dir} 3 | 4 | echo script name: $0 5 | echo $# arguments 6 | if [ "$#" -ne 4 ] ;then 7 | echo "Input illegal number of parameters " $# 8 | echo "Need 4 args: path" 9 | exit 1 10 | fi 11 | if [ "$TORCH_HOME" = "" ]; then 12 | export TORCH_HOME="${HOME}/.torch" 13 | else 14 | echo "TORCH_HOME : $TORCH_HOME" 15 | fi 16 | if [ "$DATASET_DIR" = "" ]; then 17 | export DATASET_DIR="${HOME}/scratch/meta-dataset-x" 18 | else 19 | echo "DATASET_DIR : $DATASET_DIR" 20 | fi 21 | echo "DATASET_DIR : $DATASET_DIR" 22 | 23 | 24 | save_dir=$1 25 | n_head=$2 26 | penalty_coef=$3 27 | cache_dir=$4 28 | temp=1 29 | optimizer=adam 30 | scheduler=cosine 31 | test_interval=10000 32 | 33 | export META_DATASET_ROOT="${HOME}/scratch/git/meta-dataset-v1" 34 | export RECORDS="${HOME}/scratch/meta-dataset-records" 35 | echo "ROOT: $(pwd)" 36 | 37 | python fast-exps/urt-avg-head.py --save_dir ${save_dir} --cache_dir ${cache_dir} --urt.head ${n_head} --urt.penalty_coef ${penalty_coef} \ 38 | --train.max_iter=10000 --train.weight_decay=1e-5 --interval.train 100 --interval.test ${test_interval} \ 39 | --train.learning_rate=1e-2 --train.lr_decay_step_gamma=0.9 \ 40 | --urt.temp=${temp} --train.optimizer=${optimizer} --train.scheduler=${scheduler} 41 | -------------------------------------------------------------------------------- /scripts/pre-extract-feature.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # bash ./scripts/pre-extract-feature.sh resnet18 ./outputs/extract-feature 3 | echo script name: $0 4 | echo $# arguments 5 | if [ "$#" -ne 2 ] ;then 6 | echo "Input illegal number of parameters " $# 7 | echo "Need 2 args: path" 8 | exit 1 9 | fi 10 | if [ "$TORCH_HOME" = "" ]; then 11 | export TORCH_HOME="${HOME}/.torch" 12 | else 13 | echo "TORCH_HOME : $TORCH_HOME" 14 | fi 15 | if [ "$DATASET_DIR" = "" ]; then 16 | export DATASET_DIR="${HOME}/scratch/meta-dataset-x" 17 | else 18 | echo "DATASET_DIR : $DATASET_DIR" 19 | fi 20 | echo "DATASET_DIR : $DATASET_DIR" 21 | 22 | export META_DATASET_ROOT="${HOME}/scratch/git/meta-dataset-v1" 23 | export RECORDS="${HOME}/scratch/meta-dataset-records" 24 | 25 | backbone=$1 26 | save_dir=$2 27 | 28 | echo "ROOT: $(pwd)" 29 | 30 | ulimit -n 100000 31 | 32 | python exps/pre-extract-feature.py --save_dir ${save_dir} \ 33 | --model.backbone=${backbone} 34 | --------------------------------------------------------------------------------