├── .gitignore ├── LICENSE ├── README.md ├── algorithms ├── Algorithm.py ├── ClassificationModel.py ├── FeatureClassificationModel.py └── __init__.py ├── architectures ├── AlexNet.py ├── LinearClassifier.py ├── MultipleLinearClassifiers.py ├── MultipleNonLinearClassifiers.py ├── NetworkInNetwork.py └── NonLinearClassifier.py ├── config ├── CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats.py ├── CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K100.py ├── CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K1000.py ├── CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K20.py ├── CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K400.py ├── CIFAR10_MultLayerClassifier_on_RotNet_NIN4blocks_Conv2_feats.py ├── CIFAR10_RotNet_NIN4blocks.py ├── CIFAR10_supervised_NIN.py ├── CIFAR10_supervised_NIN_K100.py ├── CIFAR10_supervised_NIN_K1000.py ├── CIFAR10_supervised_NIN_K20.py ├── CIFAR10_supervised_NIN_K400.py ├── ImageNet_LinearClassifiers_ImageNet_RotNet_AlexNet_Features.py ├── ImageNet_LinearClassifiers_Places205_RotNet_AlexNet_Features.py ├── ImageNet_NonLinearClassifiers_ImageNet_RotNet_AlexNet_Features.py ├── ImageNet_RotNet_AlexNet.py ├── Places205_LinearClassifiers_ImageNet_RotNet_AlexNet_Features.py ├── Places205_LinearClassifiers_Places205_RotNet_AlexNet_Features.py └── Places205_RotNet_AlexNet.py ├── dataloader.py ├── extras ├── AlexNet.prototxt ├── AlexNet_fcn.prototxt ├── AlexNet_rescaled.prototxt ├── AlexNet_without_BN.py ├── cat.jpg ├── convert_alexnet_from_pytorch2caffe.py └── convert_caffe_alexnet_to_fcn.py ├── main.py ├── run_cifar10_based_unsupervised_experiments.sh ├── run_cifar10_semi_supervised_experiments.sh ├── run_imagenet_based_unsupervised_feature_experiments.sh ├── run_places205_based_unsupervised_feature_experiments.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | experiments 2 | datasets 3 | *~ 4 | *.pyc 5 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Spyros 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## *Unsupervised Representation Learning by Predicting Image Rotations* 2 | 3 | ### Introduction 4 | 5 | The current code implements on [pytorch](http://pytorch.org/) the following ICLR2018 paper: 6 | **Title:** "Unsupervised Representation Learning by Predicting Image Rotations" 7 | **Authors:** Spyros Gidaris, Praveer Singh, Nikos Komodakis 8 | **Institution:** Universite Paris Est, Ecole des Ponts ParisTech 9 | **Code:** https://github.com/gidariss/FeatureLearningRotNet 10 | **Link:** https://openreview.net/forum?id=S1v4N2l0- 11 | 12 | **Abstract:** 13 | Over the last years, deep convolutional neural networks (ConvNets) have transformed the field of computer vision thanks to their unparalleled capacity to learn high level semantic image features. However, in order to successfully learn those features, they usually require massive amounts of manually labeled data, which is both expensive and impractical to scale. Therefore, unsupervised semantic feature learning, i.e., learning without requiring manual annotation effort, is of crucial importance in order to successfully harvest the vast amount of visual data that are available today. In our work we propose to learn image features by training ConvNets to recognize the 2d rotation that is applied to the image that it gets as input. We demonstrate both qualitatively and quantitatively that this apparently simple task actually provides a very powerful supervisory signal for semantic feature learning. We exhaustively evaluate our method in various unsupervised feature learning benchmarks and we exhibit in all of them state-of-the-art performance. Specifically, our results on those benchmarks demonstrate dramatic improvements w.r.t. 
prior state-of-the-art approaches in unsupervised representation learning and thus significantly close the gap with supervised feature learning. For instance, in the PASCAL VOC 2007 detection task our unsupervised pre-trained AlexNet model achieves the state-of-the-art (among unsupervised methods) mAP of 54.4%, which is only 2.4 points lower than the supervised case. We get similarly striking results when we transfer our unsupervised learned features to various other tasks, such as ImageNet classification, PASCAL classification, PASCAL segmentation, and CIFAR-10 classification. 14 | 15 | ### Citing FeatureLearningRotNet 16 | 17 | If you find the code useful in your research, please consider citing our ICLR2018 paper: 18 | ``` 19 | @inproceedings{ 20 | gidaris2018unsupervised, 21 | title={Unsupervised Representation Learning by Predicting Image Rotations}, 22 | author={Spyros Gidaris and Praveer Singh and Nikos Komodakis}, 23 | booktitle={International Conference on Learning Representations}, 24 | year={2018}, 25 | url={https://openreview.net/forum?id=S1v4N2l0-}, 26 | } 27 | ``` 28 | 29 | ### Requirements 30 | The code was developed and tested with pytorch version 0.2.0_4. 31 | 32 | ### License 33 | This code is released under the MIT License (refer to the LICENSE file for details). 34 | 35 | ### Before running the experiments 36 | * Inside the *FeatureLearningRotNet* directory with the downloaded code you must create a directory named *experiments* where the experiments-related data will be stored: `mkdir experiments`. 37 | * You must download the desired datasets and set in [dataloader.py](https://github.com/gidariss/FeatureLearningRotNet/blob/master/dataloader.py#L21) the paths to where the datasets reside on your machine. We recommend creating a *datasets* directory (`mkdir datasets`) and placing the downloaded datasets there. 
38 | * Note that all the experiment configuration files are placed in the [./config](https://github.com/gidariss/FeatureLearningRotNet/tree/master/config) directory. 39 | 40 | ### CIFAR-10 experiments 41 | * In order to train (in an unsupervised way) the RotNet model on the CIFAR-10 training images and then evaluate object classifiers on top of the RotNet-based learned features, see the [run_cifar10_based_unsupervised_experiments.sh](https://github.com/gidariss/FeatureLearningRotNet/blob/master/run_cifar10_based_unsupervised_experiments.sh) script. A pre-trained model (in pytorch format) is provided [here](https://github.com/gidariss/FeatureLearningRotNet/releases/download/v1/CIFAR10_RotNet_NIN4blocks.tar.gz) (note that it is not exactly the same model used in the paper). 42 | * In order to run the semi-supervised experiments on CIFAR-10, see the [run_cifar10_semi_supervised_experiments.sh](https://github.com/gidariss/FeatureLearningRotNet/blob/master/run_cifar10_semi_supervised_experiments.sh) script. 43 | 44 | ### ImageNet and Places205 experiments 45 | * In order to train (in an unsupervised way) a RotNet model with AlexNet-like architecture on the **ImageNet** training images and then evaluate object classifiers (for the ImageNet and Places205 classification tasks) on top of the RotNet-based learned features, see the [run_imagenet_based_unsupervised_feature_experiments.sh](https://github.com/gidariss/FeatureLearningRotNet/blob/master/run_imagenet_based_unsupervised_feature_experiments.sh) script. 46 | * In order to train (in an unsupervised way) a RotNet model with AlexNet-like architecture on the **Places205** training images and then evaluate object classifiers (for the ImageNet and Places205 classification tasks) on top of the RotNet-based learned features, see the [run_places205_based_unsupervised_feature_experiments.sh](https://github.com/gidariss/FeatureLearningRotNet/blob/master/run_places205_based_unsupervised_feature_experiments.sh) script. 
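The unsupervised training stage used by the scripts above relies on the rotation-prediction task described in the abstract: each image is rotated, and the network is trained to classify which rotation was applied. The following is only a minimal sketch of how such training pairs can be built, assuming the four rotations 0, 90, 180 and 270 degrees with labels 0 to 3, and using nested lists in place of image tensors. The names `rot90` and `make_rotation_batch` are hypothetical and are not functions from this repository's `dataloader.py`; the counter-clockwise direction is an arbitrary choice made for the sketch.

```python
# Illustrative sketch only: hypothetical helpers, not code from this repository.

def rot90(img):
    """Rotate a 2D grid (list of rows) by 90 degrees counter-clockwise."""
    # Transposing and then reversing the row order yields a
    # counter-clockwise quarter turn.
    return [list(row) for row in zip(*img)][::-1]


def make_rotation_batch(img):
    """Return [(rotated_img, label), ...] for labels 0..3.

    Label k means the image was rotated by k * 90 degrees (assumption).
    """
    batch = []
    rotated = [list(row) for row in img]  # copy so the input is not mutated
    for label in range(4):
        batch.append((rotated, label))
        rotated = rot90(rotated)
    return batch


if __name__ == "__main__":
    for rotated, label in make_rotation_batch([[1, 2], [3, 4]]):
        print(label, rotated)
```

The labels come for free from the transformation itself, which is why no manual annotation is needed for this stage.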
47 | 48 | 49 | ### Download the already trained RotNet model 50 | * In order to download the RotNet model (with AlexNet architecture) trained on the ImageNet training images using the current code, go to: [ImageNet_RotNet_AlexNet_pytorch](https://github.com/gidariss/FeatureLearningRotNet/releases/download/v1/ImageNet_RotNet_AlexNet.tar.gz). Note that: 51 | 1. The model is saved in pytorch format. 52 | 2. It is not the same as the one used in the paper and will probably give (slightly) different outcomes (on the ImageNet and Places205 classification tasks on which it was tested, it gave better results than the paper's model). 53 | 3. It expects RGB images whose pixel values are normalized with the following mean RGB values `mean_rgb = [0.485, 0.456, 0.406]` and std RGB values `std_rgb = [0.229, 0.224, 0.225]`. Prior to normalization the range of the image values must be [0.0, 1.0]. 54 | 55 | 56 | * In order to download the RotNet model (with AlexNet architecture) trained on the ImageNet training images using the current code and converted to caffe format, go to: [ImageNet_RotNet_AlexNet_caffe](https://github.com/gidariss/FeatureLearningRotNet/releases/download/v1/ImageNet_RotNet_Alexnet_caffe.tar.gz). Note that: 57 | 1. The model is saved in caffe format. 58 | 2. It is not the same as the one used in the paper and will probably give (slightly) different outcomes (in the PASCAL segmentation task it gives slightly better results than the paper's model). 59 | 3. It expects BGR images whose pixel values are mean normalized with the following mean BGR values `mean_bgr = [0.406*255.0, 0.456*255.0, 0.485*255.0]`. Prior to normalization the range of the image values must be [0.0, 255.0]. 60 | 4. The weights of the model are rescaled with the approach of [Kraehenbuehl et al, ICLR 2016](https://github.com/philkr/magic_init). 
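The input conventions listed in the notes above for the two released models can be illustrated with a small pure-Python sketch. The constants are the ones quoted in the notes; the helper names `normalize_rgb_pytorch` and `normalize_bgr_caffe` are hypothetical and not part of this repository, and a real pipeline would typically apply the same arithmetic to whole image tensors rather than to single pixels.

```python
# Illustrative sketch only: hypothetical helpers, not code from this repository.

# pytorch model: RGB order, values in [0.0, 1.0], mean/std normalization.
MEAN_RGB = [0.485, 0.456, 0.406]
STD_RGB = [0.229, 0.224, 0.225]

# caffe model: BGR order, values in [0.0, 255.0], mean subtraction only.
MEAN_BGR = [0.406 * 255.0, 0.456 * 255.0, 0.485 * 255.0]


def normalize_rgb_pytorch(pixel_rgb):
    """Normalize one RGB pixel whose values lie in [0.0, 1.0]."""
    return [(v - m) / s for v, m, s in zip(pixel_rgb, MEAN_RGB, STD_RGB)]


def normalize_bgr_caffe(pixel_rgb_255):
    """Reorder one RGB pixel in [0.0, 255.0] to BGR and subtract the BGR mean."""
    r, g, b = pixel_rgb_255
    return [v - m for v, m in zip([b, g, r], MEAN_BGR)]


if __name__ == "__main__":
    # A pixel equal to the mean maps to all zeros.
    print(normalize_rgb_pytorch([0.485, 0.456, 0.406]))
    # A pure red pixel, expressed in the caffe convention.
    print(normalize_bgr_caffe([255.0, 0.0, 0.0]))
```

Note that the two conventions differ in channel order, value range, and whether a standard deviation is applied, so inputs prepared for one model should not be fed to the other.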
61 | 62 | 63 | -------------------------------------------------------------------------------- /algorithms/Algorithm.py: -------------------------------------------------------------------------------- 1 | """Define a generic class for training and testing learning algorithms.""" 2 | from __future__ import print_function 3 | import os 4 | import os.path 5 | import imp 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim 11 | 12 | import utils 13 | import datetime 14 | import logging 15 | 16 | from pdb import set_trace as breakpoint 17 | 18 | class Algorithm(): 19 | def __init__(self, opt): 20 | self.set_experiment_dir(opt['exp_dir']) 21 | self.set_log_file_handler() 22 | 23 | self.logger.info('Algorithm options %s' % opt) 24 | self.opt = opt 25 | self.init_all_networks() 26 | self.init_all_criterions() 27 | self.allocate_tensors() 28 | self.curr_epoch = 0 29 | self.optimizers = {} 30 | 31 | self.keep_best_model_metric_name = opt['best_metric'] if ('best_metric' in opt) else None 32 | 33 | def set_experiment_dir(self,directory_path): 34 | self.exp_dir = directory_path 35 | if (not os.path.isdir(self.exp_dir)): 36 | os.makedirs(self.exp_dir) 37 | 38 | self.vis_dir = os.path.join(directory_path,'visuals') 39 | if (not os.path.isdir(self.vis_dir)): 40 | os.makedirs(self.vis_dir) 41 | 42 | self.preds_dir = os.path.join(directory_path,'preds') 43 | if (not os.path.isdir(self.preds_dir)): 44 | os.makedirs(self.preds_dir) 45 | 46 | def set_log_file_handler(self): 47 | self.logger = logging.getLogger(__name__) 48 | 49 | strHandler = logging.StreamHandler() 50 | formatter = logging.Formatter( 51 | '%(asctime)s - %(name)-8s - %(levelname)-6s - %(message)s') 52 | strHandler.setFormatter(formatter) 53 | self.logger.addHandler(strHandler) 54 | self.logger.setLevel(logging.INFO) 55 | 56 | log_dir = os.path.join(self.exp_dir, 'logs') 57 | if (not os.path.isdir(log_dir)): 58 | os.makedirs(log_dir) 59 | 60 | now_str = 
datetime.datetime.now().__str__().replace(' ','_') 61 | 62 | self.log_file = os.path.join(log_dir, 'LOG_INFO_'+now_str+'.txt') 63 | self.log_fileHandler = logging.FileHandler(self.log_file) 64 | self.log_fileHandler.setFormatter(formatter) 65 | self.logger.addHandler(self.log_fileHandler) 66 | 67 | def init_all_networks(self): 68 | networks_defs = self.opt['networks'] 69 | self.networks = {} 70 | self.optim_params = {} 71 | 72 | for key, val in networks_defs.items(): 73 | self.logger.info('Set network %s' % key) 74 | def_file = val['def_file'] 75 | net_opt = val['opt'] 76 | self.optim_params[key] = val['optim_params'] if ('optim_params' in val) else None 77 | pretrained_path = val['pretrained'] if ('pretrained' in val) else None 78 | self.networks[key] = self.init_network(def_file, net_opt, pretrained_path, key) 79 | 80 | def init_network(self, net_def_file, net_opt, pretrained_path, key): 81 | self.logger.info('==> Initialize network %s from file %s with opts: %s' % (key, net_def_file, net_opt)) 82 | if (not os.path.isfile(net_def_file)): 83 | raise ValueError('Non existing file: {0}'.format(net_def_file)) 84 | 85 | network = imp.load_source("",net_def_file).create_model(net_opt) 86 | if pretrained_path != None: 87 | self.load_pretrained(network, pretrained_path) 88 | 89 | return network 90 | 91 | def load_pretrained(self, network, pretrained_path): 92 | self.logger.info('==> Load pretrained parameters from file %s:' % (pretrained_path)) 93 | 94 | assert(os.path.isfile(pretrained_path)) 95 | pretrained_model = torch.load(pretrained_path) 96 | if pretrained_model['network'].keys() == network.state_dict().keys(): 97 | network.load_state_dict(pretrained_model['network']) 98 | else: 99 | self.logger.info('==> WARNING: network parameters in pre-trained file %s do not strictly match' % (pretrained_path)) 100 | for pname, param in network.named_parameters(): 101 | if pname in pretrained_model['network']: 102 | self.logger.info('==> Copying parameter %s from file %s' % 
(pname, pretrained_path)) 103 | param.data.copy_(pretrained_model['network'][pname]) 104 | 105 | def init_all_optimizers(self): 106 | self.optimizers = {} 107 | 108 | for key, oparams in self.optim_params.items(): 109 | self.optimizers[key] = None 110 | if oparams != None: 111 | self.optimizers[key] = self.init_optimizer( 112 | self.networks[key], oparams, key) 113 | 114 | def init_optimizer(self, net, optim_opts, key): 115 | optim_type = optim_opts['optim_type'] 116 | learning_rate = optim_opts['lr'] 117 | optimizer = None 118 | parameters = filter(lambda p: p.requires_grad, net.parameters()) 119 | self.logger.info('Initialize optimizer: %s with params: %s for network: %s' 120 | % (optim_type, optim_opts, key)) 121 | if optim_type == 'adam': 122 | optimizer = torch.optim.Adam(parameters, lr=learning_rate, 123 | betas=optim_opts['beta']) 124 | elif optim_type == 'sgd': 125 | optimizer = torch.optim.SGD(parameters, lr=learning_rate, 126 | momentum=optim_opts['momentum'], 127 | nesterov=optim_opts['nesterov'] if ('nesterov' in optim_opts) else False, 128 | weight_decay=optim_opts['weight_decay']) 129 | else: 130 | raise ValueError('Not supported or recognized optim_type', optim_type) 131 | 132 | return optimizer 133 | 134 | def init_all_criterions(self): 135 | criterions_defs = self.opt['criterions'] 136 | self.criterions = {} 137 | for key, val in criterions_defs.items(): 138 | crit_type = val['ctype'] 139 | crit_opt = val['opt'] if ('opt' in val) else None 140 | self.logger.info('Initialize criterion[%s]: %s with options: %s' % (key, crit_type, crit_opt)) 141 | self.criterions[key] = self.init_criterion(crit_type, crit_opt) 142 | 143 | def init_criterion(self, ctype, copt): 144 | return getattr(nn, ctype)(copt) 145 | 146 | def load_to_gpu(self): 147 | for key, net in self.networks.items(): 148 | self.networks[key] = net.cuda() 149 | 150 | for key, criterion in self.criterions.items(): 151 | self.criterions[key] = criterion.cuda() 152 | 153 | for key, tensor in 
self.tensors.items(): 154 | self.tensors[key] = tensor.cuda() 155 | 156 | def save_checkpoint(self, epoch, suffix=''): 157 | for key, net in self.networks.items(): 158 | if self.optimizers[key] == None: continue 159 | self.save_network(key, epoch, suffix=suffix) 160 | self.save_optimizer(key, epoch, suffix=suffix) 161 | 162 | def load_checkpoint(self, epoch, train=True, suffix=''): 163 | self.logger.info('Load checkpoint of epoch %d' % (epoch)) 164 | 165 | for key, net in self.networks.items(): # Load networks 166 | if self.optim_params[key] == None: continue 167 | self.load_network(key, epoch,suffix) 168 | 169 | if train: # initialize and load optimizers 170 | self.init_all_optimizers() 171 | for key, net in self.networks.items(): 172 | if self.optim_params[key] == None: continue 173 | self.load_optimizer(key, epoch,suffix) 174 | 175 | self.curr_epoch = epoch 176 | 177 | def delete_checkpoint(self, epoch, suffix=''): 178 | for key, net in self.networks.items(): 179 | if self.optimizers[key] == None: continue 180 | 181 | filename_net = self._get_net_checkpoint_filename(key, epoch)+suffix 182 | if os.path.isfile(filename_net): os.remove(filename_net) 183 | 184 | filename_optim = self._get_optim_checkpoint_filename(key, epoch)+suffix 185 | if os.path.isfile(filename_optim): os.remove(filename_optim) 186 | 187 | def save_network(self, net_key, epoch, suffix=''): 188 | assert(net_key in self.networks) 189 | filename = self._get_net_checkpoint_filename(net_key, epoch)+suffix 190 | state = {'epoch': epoch,'network': self.networks[net_key].state_dict()} 191 | torch.save(state, filename) 192 | 193 | def save_optimizer(self, net_key, epoch, suffix=''): 194 | assert(net_key in self.optimizers) 195 | filename = self._get_optim_checkpoint_filename(net_key, epoch)+suffix 196 | state = {'epoch': epoch,'optimizer': self.optimizers[net_key].state_dict()} 197 | torch.save(state, filename) 198 | 199 | def load_network(self, net_key, epoch,suffix=''): 200 | assert(net_key in 
self.networks) 201 | filename = self._get_net_checkpoint_filename(net_key, epoch)+suffix 202 | assert(os.path.isfile(filename)) 203 | if os.path.isfile(filename): 204 | checkpoint = torch.load(filename) 205 | self.networks[net_key].load_state_dict(checkpoint['network']) 206 | 207 | def load_optimizer(self, net_key, epoch,suffix=''): 208 | assert(net_key in self.optimizers) 209 | filename = self._get_optim_checkpoint_filename(net_key, epoch)+suffix 210 | assert(os.path.isfile(filename)) 211 | if os.path.isfile(filename): 212 | checkpoint = torch.load(filename) 213 | self.optimizers[net_key].load_state_dict(checkpoint['optimizer']) 214 | 215 | def _get_net_checkpoint_filename(self, net_key, epoch): 216 | return os.path.join(self.exp_dir, net_key+'_net_epoch'+str(epoch)) 217 | 218 | def _get_optim_checkpoint_filename(self, net_key, epoch): 219 | return os.path.join(self.exp_dir, net_key+'_optim_epoch'+str(epoch)) 220 | 221 | def solve(self, data_loader_train, data_loader_test): 222 | self.max_num_epochs = self.opt['max_num_epochs'] 223 | start_epoch = self.curr_epoch 224 | if len(self.optimizers) == 0: 225 | self.init_all_optimizers() 226 | 227 | eval_stats = {} 228 | train_stats = {} 229 | self.init_record_of_best_model() 230 | for self.curr_epoch in xrange(start_epoch, self.max_num_epochs): 231 | self.logger.info('Training epoch [%3d / %3d]' % (self.curr_epoch+1, self.max_num_epochs)) 232 | self.adjust_learning_rates(self.curr_epoch) 233 | train_stats = self.run_train_epoch(data_loader_train, self.curr_epoch) 234 | self.logger.info('==> Training stats: %s' % (train_stats)) 235 | 236 | self.save_checkpoint(self.curr_epoch+1) # create a checkpoint in the current epoch 237 | if start_epoch != self.curr_epoch: # delete the checkpoint of the previous epoch 238 | self.delete_checkpoint(self.curr_epoch) 239 | 240 | if data_loader_test is not None: 241 | eval_stats = self.evaluate(data_loader_test) 242 | self.logger.info('==> Evaluation stats: %s' % (eval_stats)) 243 | 
self.keep_record_of_best_model(eval_stats, self.curr_epoch) 244 | 245 | self.print_eval_stats_of_best_model() 246 | 247 | def run_train_epoch(self, data_loader, epoch): 248 | self.logger.info('Training: %s' % os.path.basename(self.exp_dir)) 249 | self.dloader = data_loader 250 | self.dataset_train = data_loader.dataset 251 | 252 | for key, network in self.networks.items(): 253 | if self.optimizers[key] == None: network.eval() 254 | else: network.train() 255 | 256 | disp_step = self.opt['disp_step'] if ('disp_step' in self.opt) else 50 257 | train_stats = utils.DAverageMeter() 258 | self.bnumber = len(data_loader()) 259 | for idx, batch in enumerate(tqdm(data_loader(epoch))): 260 | self.biter = idx 261 | train_stats_this = self.train_step(batch) 262 | train_stats.update(train_stats_this) 263 | if (idx+1) % disp_step == 0: 264 | self.logger.info('==> Iteration [%3d][%4d / %4d]: %s' % (epoch+1, idx+1, len(data_loader), train_stats.average())) 265 | 266 | return train_stats.average() 267 | 268 | def evaluate(self, dloader): 269 | self.logger.info('Evaluating: %s' % os.path.basename(self.exp_dir)) 270 | 271 | self.dloader = dloader 272 | self.dataset_eval = dloader.dataset 273 | self.logger.info('==> Dataset: %s [%d images]' % (dloader.dataset.name, len(dloader))) 274 | for key, network in self.networks.items(): 275 | network.eval() 276 | 277 | eval_stats = utils.DAverageMeter() 278 | self.bnumber = len(dloader()) 279 | for idx, batch in enumerate(tqdm(dloader())): 280 | self.biter = idx 281 | eval_stats_this = self.evaluation_step(batch) 282 | eval_stats.update(eval_stats_this) 283 | 284 | self.logger.info('==> Results: %s' % eval_stats.average()) 285 | 286 | return eval_stats.average() 287 | 288 | def adjust_learning_rates(self, epoch): 289 | # filter out the networks that are not trainable and that do 290 | # not have a learning rate Look Up Table (LUT_lr) in their optim_params 291 | optim_params_filtered = {k:v for k,v in self.optim_params.items() 292 | if (v != 
None and ('LUT_lr' in v))} 293 | 294 | for key, oparams in optim_params_filtered.items(): 295 | LUT = oparams['LUT_lr'] 296 | lr = next((lr for (max_epoch, lr) in LUT if max_epoch>epoch), LUT[-1][1]) 297 | self.logger.info('==> Set to %s optimizer lr = %.10f' % (key, lr)) 298 | for param_group in self.optimizers[key].param_groups: 299 | param_group['lr'] = lr 300 | 301 | def init_record_of_best_model(self): 302 | self.max_metric_val = None 303 | self.best_stats = None 304 | self.best_epoch = None 305 | 306 | def keep_record_of_best_model(self, eval_stats, current_epoch): 307 | if self.keep_best_model_metric_name is not None: 308 | metric_name = self.keep_best_model_metric_name 309 | if (metric_name not in eval_stats): 310 | raise ValueError('The provided metric {0} for keeping the best model is not computed by the evaluation routine.'.format(metric_name)) 311 | metric_val = eval_stats[metric_name] 312 | if self.max_metric_val is None or metric_val > self.max_metric_val: 313 | self.max_metric_val = metric_val 314 | self.best_stats = eval_stats 315 | self.save_checkpoint(self.curr_epoch+1, suffix='.best') 316 | if self.best_epoch is not None: 317 | self.delete_checkpoint(self.best_epoch+1, suffix='.best') 318 | self.best_epoch = current_epoch 319 | self.print_eval_stats_of_best_model() 320 | 321 | def print_eval_stats_of_best_model(self): 322 | if self.best_stats is not None: 323 | metric_name = self.keep_best_model_metric_name 324 | self.logger.info('==> Best results w.r.t. 
%s metric: epoch: %d - %s' % (metric_name, self.best_epoch+1, self.best_stats)) 325 | 326 | 327 | # FROM HERE ON ARE ABSTRACT FUNCTIONS THAT MUST BE IMPLEMENTED BY THE CLASS 328 | # THAT INHERITS THE Algorithm CLASS 329 | def train_step(self, batch): 330 | """Implements a training step that includes: 331 | * Forward a batch through the network(s) 332 | * Compute loss(es) 333 | * Backward propagation through the networks 334 | * Apply optimization step(s) 335 | * Return a dictionary with the computed losses and any other desired 336 | stats. The key names in the dictionary can be arbitrary. 337 | """ 338 | pass 339 | 340 | def evaluation_step(self, batch): 341 | """Implements an evaluation step that includes: 342 | * Forward a batch through the network(s) 343 | * Compute loss(es) or any other evaluation metrics. 344 | * Return a dictionary with the computed losses and the evaluation 345 | metrics for that batch. The key names in the dictionary can be 346 | arbitrary. 347 | """ 348 | pass 349 | 350 | def allocate_tensors(self): 351 | """(Optional) allocate torch tensors that could potentially be used 352 | in the train_step() or evaluation_step() functions. If the 353 | load_to_gpu() function is called then those tensors will be moved to 354 | the gpu device. 355 | """ 356 | self.tensors = {} 357 | -------------------------------------------------------------------------------- /algorithms/ClassificationModel.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.optim 8 | import os 9 | import torchnet as tnt 10 | import utils 11 | import PIL 12 | import pickle 13 | from tqdm import tqdm 14 | import time 15 | 16 | from . 
import Algorithm 17 | from pdb import set_trace as breakpoint 18 | 19 | 20 | def accuracy(output, target, topk=(1,)): 21 | """Computes the precision@k for the specified values of k""" 22 | maxk = max(topk) 23 | batch_size = target.size(0) 24 | 25 | _, pred = output.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].view(-1).float().sum(0) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | 35 | class ClassificationModel(Algorithm): 36 | def __init__(self, opt): 37 | Algorithm.__init__(self, opt) 38 | 39 | def allocate_tensors(self): 40 | self.tensors = {} 41 | self.tensors['dataX'] = torch.FloatTensor() 42 | self.tensors['labels'] = torch.LongTensor() 43 | 44 | def train_step(self, batch): 45 | return self.process_batch(batch, do_train=True) 46 | 47 | def evaluation_step(self, batch): 48 | return self.process_batch(batch, do_train=False) 49 | 50 | def process_batch(self, batch, do_train=True): 51 | #*************** LOAD BATCH (AND MOVE IT TO GPU) ******** 52 | start = time.time() 53 | self.tensors['dataX'].resize_(batch[0].size()).copy_(batch[0]) 54 | self.tensors['labels'].resize_(batch[1].size()).copy_(batch[1]) 55 | dataX = self.tensors['dataX'] 56 | labels = self.tensors['labels'] 57 | batch_load_time = time.time() - start 58 | #******************************************************** 59 | 60 | #******************************************************** 61 | start = time.time() 62 | if do_train: # zero the gradients 63 | self.optimizers['model'].zero_grad() 64 | #******************************************************** 65 | 66 | #***************** SET TORCH VARIABLES ****************** 67 | dataX_var = torch.autograd.Variable(dataX, volatile=(not do_train)) 68 | labels_var = torch.autograd.Variable(labels, requires_grad=False) 69 | #******************************************************** 70 | 71 | #************ FORWARD 
THROUGH NET *********************** 72 | pred_var = self.networks['model'](dataX_var) 73 | #******************************************************** 74 | 75 | #*************** COMPUTE LOSSES ************************* 76 | record = {} 77 | loss_total = self.criterions['loss'](pred_var, labels_var) 78 | record['prec1'] = accuracy(pred_var.data, labels, topk=(1,))[0][0] 79 | record['loss'] = loss_total.data[0] 80 | #******************************************************** 81 | 82 | #****** BACKPROPAGATE AND APPLY OPTIMIZATION STEP ******* 83 | if do_train: 84 | loss_total.backward() 85 | self.optimizers['model'].step() 86 | #******************************************************** 87 | batch_process_time = time.time() - start 88 | total_time = batch_process_time + batch_load_time 89 | record['load_time'] = 100*(batch_load_time/total_time) 90 | record['process_time'] = 100*(batch_process_time/total_time) 91 | 92 | return record 93 | -------------------------------------------------------------------------------- /algorithms/FeatureClassificationModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import utils 4 | import time 5 | 6 | from . 
import Algorithm 7 | from pdb import set_trace as breakpoint 8 | 9 | 10 | def accuracy(output, target, topk=(1,)): 11 | """Computes the precision@k for the specified values of k""" 12 | maxk = max(topk) 13 | batch_size = target.size(0) 14 | 15 | _, pred = output.topk(maxk, 1, True, True) 16 | pred = pred.t() 17 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 18 | 19 | res = [] 20 | for k in topk: 21 | correct_k = correct[:k].view(-1).float().sum(0) 22 | res.append(correct_k.mul_(100.0 / batch_size)) 23 | return res 24 | 25 | class FeatureClassificationModel(Algorithm): 26 | def __init__(self, opt): 27 | self.out_feat_keys = opt['out_feat_keys'] 28 | Algorithm.__init__(self, opt) 29 | 30 | def allocate_tensors(self): 31 | self.tensors = {} 32 | self.tensors['dataX'] = torch.FloatTensor() 33 | self.tensors['labels'] = torch.LongTensor() 34 | 35 | def train_step(self, batch): 36 | return self.process_batch(batch, do_train=True) 37 | 38 | def evaluation_step(self, batch): 39 | return self.process_batch(batch, do_train=False) 40 | 41 | def process_batch(self, batch, do_train=True): 42 | #*************** LOAD BATCH (AND MOVE IT TO GPU) ******** 43 | start = time.time() 44 | self.tensors['dataX'].resize_(batch[0].size()).copy_(batch[0]) 45 | self.tensors['labels'].resize_(batch[1].size()).copy_(batch[1]) 46 | dataX = self.tensors['dataX'] 47 | labels = self.tensors['labels'] 48 | batch_load_time = time.time() - start 49 | #******************************************************** 50 | 51 | #******************************************************** 52 | start = time.time() 53 | out_feat_keys = self.out_feat_keys 54 | finetune_feat_extractor = self.optimizers['feat_extractor'] is not None 55 | if do_train: # zero the gradients 56 | self.optimizers['classifier'].zero_grad() 57 | if finetune_feat_extractor: 58 | self.optimizers['feat_extractor'].zero_grad() 59 | else: 60 | self.networks['feat_extractor'].eval() 61 | 
#******************************************************** 62 | 63 | #***************** SET TORCH VARIABLES ****************** 64 | dataX_var = Variable(dataX, volatile=((not do_train) or (not finetune_feat_extractor))) 65 | labels_var = Variable(labels, requires_grad=False) 66 | #******************************************************** 67 | 68 | #************ FORWARD PROPAGATION *********************** 69 | feat_var = self.networks['feat_extractor'](dataX_var, out_feat_keys=out_feat_keys) 70 | if not finetune_feat_extractor: 71 | if isinstance(feat_var, (list, tuple)): 72 | for i in range(len(feat_var)): 73 | feat_var[i] = Variable(feat_var[i].data, volatile=(not do_train)) 74 | else: 75 | feat_var = Variable(feat_var.data, volatile=(not do_train)) 76 | pred_var = self.networks['classifier'](feat_var) 77 | #******************************************************** 78 | 79 | #*************** COMPUTE LOSSES ************************* 80 | record = {} 81 | if isinstance(pred_var, (list, tuple)): 82 | loss_total = None 83 | for i in range(len(pred_var)): 84 | loss_this = self.criterions['loss'](pred_var[i], labels_var) 85 | loss_total = loss_this if (loss_total is None) else (loss_total + loss_this) 86 | record['prec1_c'+str(1+i)] = accuracy(pred_var[i].data, labels, topk=(1,))[0][0] 87 | record['prec5_c'+str(1+i)] = accuracy(pred_var[i].data, labels, topk=(5,))[0][0] 88 | else: 89 | loss_total = self.criterions['loss'](pred_var, labels_var) 90 | record['prec1'] = accuracy(pred_var.data, labels, topk=(1,))[0][0] 91 | record['prec5'] = accuracy(pred_var.data, labels, topk=(5,))[0][0] 92 | record['loss'] = loss_total.data[0] 93 | #******************************************************** 94 | 95 | #****** BACKPROPAGATE AND APPLY OPTIMIZATION STEP ******* 96 | if do_train: 97 | loss_total.backward() 98 | self.optimizers['classifier'].step() 99 | if finetune_feat_extractor: 100 | self.optimizers['feat_extractor'].step() 101 | 
#******************************************************** 102 | batch_process_time = time.time() - start 103 | total_time = batch_process_time + batch_load_time 104 | record['load_time'] = 100*(batch_load_time/total_time) 105 | record['process_time'] = 100*(batch_process_time/total_time) 106 | 107 | return record 108 | -------------------------------------------------------------------------------- /algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .Algorithm import * 2 | from .ClassificationModel import ClassificationModel 3 | from .FeatureClassificationModel import FeatureClassificationModel 4 | -------------------------------------------------------------------------------- /architectures/AlexNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from pdb import set_trace as breakpoint 7 | 8 | class Flatten(nn.Module): 9 | def __init__(self): 10 | super(Flatten, self).__init__() 11 | 12 | def forward(self, feat): 13 | return feat.view(feat.size(0), -1) 14 | 15 | class AlexNet(nn.Module): 16 | def __init__(self, opt): 17 | super(AlexNet, self).__init__() 18 | num_classes = opt['num_classes'] 19 | 20 | conv1 = nn.Sequential( 21 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(inplace=True), 24 | ) 25 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 26 | conv2 = nn.Sequential( 27 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 28 | nn.BatchNorm2d(192), 29 | nn.ReLU(inplace=True), 30 | ) 31 | pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 32 | conv3 = nn.Sequential( 33 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(384), 35 | nn.ReLU(inplace=True), 36 | ) 37 | conv4 = nn.Sequential( 38 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 39 | nn.BatchNorm2d(256), 40 | nn.ReLU(inplace=True), 41 | ) 42 | conv5 
= nn.Sequential( 43 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 44 | nn.BatchNorm2d(256), 45 | nn.ReLU(inplace=True), 46 | ) 47 | pool5 = nn.MaxPool2d(kernel_size=3, stride=2) 48 | 49 | num_pool5_feats = 6 * 6 * 256 50 | fc_block = nn.Sequential( 51 | Flatten(), 52 | nn.Linear(num_pool5_feats, 4096, bias=False), 53 | nn.BatchNorm1d(4096), 54 | nn.ReLU(inplace=True), 55 | nn.Linear(4096, 4096, bias=False), 56 | nn.BatchNorm1d(4096), 57 | nn.ReLU(inplace=True), 58 | ) 59 | classifier = nn.Sequential( 60 | nn.Linear(4096, num_classes), 61 | ) 62 | 63 | self._feature_blocks = nn.ModuleList([ 64 | conv1, 65 | pool1, 66 | conv2, 67 | pool2, 68 | conv3, 69 | conv4, 70 | conv5, 71 | pool5, 72 | fc_block, 73 | classifier, 74 | ]) 75 | self.all_feat_names = [ 76 | 'conv1', 77 | 'pool1', 78 | 'conv2', 79 | 'pool2', 80 | 'conv3', 81 | 'conv4', 82 | 'conv5', 83 | 'pool5', 84 | 'fc_block', 85 | 'classifier', 86 | ] 87 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 88 | 89 | def _parse_out_keys_arg(self, out_feat_keys): 90 | 91 | # By default return the features of the last layer / module. 92 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 93 | 94 | if len(out_feat_keys) == 0: 95 | raise ValueError('Empty list of output feature keys.') 96 | for f, key in enumerate(out_feat_keys): 97 | if key not in self.all_feat_names: 98 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 99 | elif key in out_feat_keys[:f]: 100 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 101 | 102 | # Find the highest output feature in `out_feat_keys 103 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 104 | 105 | return out_feat_keys, max_out_feat 106 | 107 | def forward(self, x, out_feat_keys=None): 108 | """Forward an image `x` through the network and return the asked output features. 109 | 110 | Args: 111 | x: input image. 
112 | out_feat_keys: a list/tuple with the feature names of the features 113 | that the function should return. By default the last feature of 114 | the network is returned. 115 | 116 | Return: 117 | out_feats: If multiple output features were asked then `out_feats` 118 | is a list with the asked output features placed in the same 119 | order as in `out_feat_keys`. If a single output feature was 120 | asked then `out_feats` is that output feature (and not a list). 121 | """ 122 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 123 | out_feats = [None] * len(out_feat_keys) 124 | 125 | feat = x 126 | for f in range(max_out_feat+1): 127 | feat = self._feature_blocks[f](feat) 128 | key = self.all_feat_names[f] 129 | if key in out_feat_keys: 130 | out_feats[out_feat_keys.index(key)] = feat 131 | 132 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 133 | return out_feats 134 | 135 | def get_L1filters(self): 136 | convlayer = self._feature_blocks[0][0] 137 | batchnorm = self._feature_blocks[0][1] 138 | filters = convlayer.weight.data 139 | scalars = (batchnorm.weight.data / torch.sqrt(batchnorm.running_var + 1e-05)) 140 | filters = (filters * scalars.view(-1, 1, 1, 1).expand_as(filters)).cpu().clone() 141 | 142 | return filters 143 | 144 | def create_model(opt): 145 | return AlexNet(opt) 146 | 147 | if __name__ == '__main__': 148 | size = 224 149 | opt = {'num_classes':4} 150 | 151 | net = create_model(opt) 152 | x = torch.autograd.Variable(torch.FloatTensor(1,3,size,size).uniform_(-1,1)) 153 | 154 | out = net(x, out_feat_keys=net.all_feat_names) 155 | for f in range(len(out)): 156 | print('Output feature {0} - size {1}'.format( 157 | net.all_feat_names[f], out[f].size())) 158 | 159 | filters = net.get_L1filters() 160 | 161 | print('First layer filter shape: {0}'.format(filters.size())) 162 | -------------------------------------------------------------------------------- /architectures/LinearClassifier.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class Flatten(nn.Module): 8 | def __init__(self): 9 | super(Flatten, self).__init__() 10 | 11 | def forward(self, feat): 12 | return feat.view(feat.size(0), -1) 13 | 14 | class Classifier(nn.Module): 15 | def __init__(self, opt): 16 | super(Classifier, self).__init__() 17 | nChannels = opt['nChannels'] 18 | num_classes = opt['num_classes'] 19 | pool_size = opt['pool_size'] 20 | pool_type = opt['pool_type'] if ('pool_type' in opt) else 'max' 21 | nChannelsAll = nChannels * pool_size * pool_size 22 | 23 | self.classifier = nn.Sequential() 24 | if pool_type == 'max': 25 | self.classifier.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) 26 | elif pool_type == 'avg': 27 | self.classifier.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) 28 | self.classifier.add_module('BatchNorm', nn.BatchNorm2d(nChannels, affine=False)) 29 | self.classifier.add_module('Flatten', Flatten()) 30 | self.classifier.add_module('LiniearClassifier', nn.Linear(nChannelsAll, num_classes)) 31 | self.initilize() 32 | 33 | def forward(self, feat): 34 | return self.classifier(feat) 35 | 36 | def initilize(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Linear): 39 | fin = m.in_features 40 | fout = m.out_features 41 | std_val = np.sqrt(2.0/fout) 42 | m.weight.data.normal_(0.0, std_val) 43 | if m.bias is not None: 44 | m.bias.data.fill_(0.0) 45 | 46 | def create_model(opt): 47 | return Classifier(opt) 48 | -------------------------------------------------------------------------------- /architectures/MultipleLinearClassifiers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import imp 4 | import os 5 | 6 | 7 | current_path = os.path.abspath(__file__) 8 | 
filepath_to_linear_classifier_definition = os.path.join(os.path.dirname(current_path), 'LinearClassifier.py') 9 | LinearClassifier = imp.load_source('',filepath_to_linear_classifier_definition).create_model 10 | 11 | 12 | class MClassifier(nn.Module): 13 | def __init__(self, opts): 14 | super(MClassifier, self).__init__() 15 | self.classifiers = nn.ModuleList([LinearClassifier(opt) for opt in opts]) 16 | self.num_classifiers = len(opts) 17 | 18 | 19 | def forward(self, feats): 20 | assert(len(feats) == self.num_classifiers) 21 | return [self.classifiers[i](feat) for i, feat in enumerate(feats)] 22 | 23 | 24 | def create_model(opt): 25 | return MClassifier(opt) 26 | 27 | -------------------------------------------------------------------------------- /architectures/MultipleNonLinearClassifiers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import imp 4 | import os 5 | 6 | 7 | current_path = os.path.abspath(__file__) 8 | filepath_to_classifier_definition = os.path.join(os.path.dirname(current_path), 'NonLinearClassifier.py') 9 | NonLinearClassifier = imp.load_source('',filepath_to_classifier_definition).create_model 10 | 11 | 12 | class MClassifier(nn.Module): 13 | def __init__(self, opts): 14 | super(MClassifier, self).__init__() 15 | self.classifiers = nn.ModuleList([NonLinearClassifier(opt) for opt in opts]) 16 | self.num_classifiers = len(opts) 17 | 18 | 19 | def forward(self, feats): 20 | assert(len(feats) == self.num_classifiers) 21 | return [self.classifiers[i](feat) for i, feat in enumerate(feats)] 22 | 23 | 24 | def create_model(opt): 25 | return MClassifier(opt) 26 | -------------------------------------------------------------------------------- /architectures/NetworkInNetwork.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class 
BasicBlock(nn.Module): 7 | def __init__(self, in_planes, out_planes, kernel_size): 8 | super(BasicBlock, self).__init__() 9 | padding = (kernel_size-1)//2 10 | self.layers = nn.Sequential() 11 | self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes, \ 12 | kernel_size=kernel_size, stride=1, padding=padding, bias=False)) 13 | self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes)) 14 | self.layers.add_module('ReLU', nn.ReLU(inplace=True)) 15 | 16 | def forward(self, x): 17 | return self.layers(x) 18 | 19 | 20 | 21 | class GlobalAveragePooling(nn.Module): 22 | def __init__(self): 23 | super(GlobalAveragePooling, self).__init__() 24 | 25 | def forward(self, feat): 26 | num_channels = feat.size(1) 27 | return F.avg_pool2d(feat, (feat.size(2), feat.size(3))).view(-1, num_channels) 28 | 29 | class NetworkInNetwork(nn.Module): 30 | def __init__(self, opt): 31 | super(NetworkInNetwork, self).__init__() 32 | 33 | num_classes = opt['num_classes'] 34 | num_inchannels = opt['num_inchannels'] if ('num_inchannels' in opt) else 3 35 | num_stages = opt['num_stages'] if ('num_stages' in opt) else 3 36 | use_avg_on_conv3 = opt['use_avg_on_conv3'] if ('use_avg_on_conv3' in opt) else True 37 | 38 | 39 | assert(num_stages >= 3) 40 | nChannels = 192 41 | nChannels2 = 160 42 | nChannels3 = 96 43 | 44 | blocks = [nn.Sequential() for i in range(num_stages)] 45 | # 1st block 46 | blocks[0].add_module('Block1_ConvB1', BasicBlock(num_inchannels, nChannels, 5)) 47 | blocks[0].add_module('Block1_ConvB2', BasicBlock(nChannels, nChannels2, 1)) 48 | blocks[0].add_module('Block1_ConvB3', BasicBlock(nChannels2, nChannels3, 1)) 49 | blocks[0].add_module('Block1_MaxPool', nn.MaxPool2d(kernel_size=3,stride=2,padding=1)) 50 | 51 | # 2nd block 52 | blocks[1].add_module('Block2_ConvB1', BasicBlock(nChannels3, nChannels, 5)) 53 | blocks[1].add_module('Block2_ConvB2', BasicBlock(nChannels, nChannels, 1)) 54 | 
blocks[1].add_module('Block2_ConvB3', BasicBlock(nChannels, nChannels, 1)) 55 | blocks[1].add_module('Block2_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 56 | 57 | # 3rd block 58 | blocks[2].add_module('Block3_ConvB1', BasicBlock(nChannels, nChannels, 3)) 59 | blocks[2].add_module('Block3_ConvB2', BasicBlock(nChannels, nChannels, 1)) 60 | blocks[2].add_module('Block3_ConvB3', BasicBlock(nChannels, nChannels, 1)) 61 | 62 | if num_stages > 3 and use_avg_on_conv3: 63 | blocks[2].add_module('Block3_AvgPool', nn.AvgPool2d(kernel_size=3,stride=2,padding=1)) 64 | for s in range(3, num_stages): 65 | blocks[s].add_module('Block'+str(s+1)+'_ConvB1', BasicBlock(nChannels, nChannels, 3)) 66 | blocks[s].add_module('Block'+str(s+1)+'_ConvB2', BasicBlock(nChannels, nChannels, 1)) 67 | blocks[s].add_module('Block'+str(s+1)+'_ConvB3', BasicBlock(nChannels, nChannels, 1)) 68 | 69 | # global average pooling and classifier 70 | blocks.append(nn.Sequential()) 71 | blocks[-1].add_module('GlobalAveragePooling', GlobalAveragePooling()) 72 | blocks[-1].add_module('Classifier', nn.Linear(nChannels, num_classes)) 73 | 74 | self._feature_blocks = nn.ModuleList(blocks) 75 | self.all_feat_names = ['conv'+str(s+1) for s in range(num_stages)] + ['classifier',] 76 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 77 | 78 | def _parse_out_keys_arg(self, out_feat_keys): 79 | 80 | # By default return the features of the last layer / module. 81 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 82 | 83 | if len(out_feat_keys) == 0: 84 | raise ValueError('Empty list of output feature keys.') 85 | for f, key in enumerate(out_feat_keys): 86 | if key not in self.all_feat_names: 87 | raise ValueError('Feature with name {0} does not exist. 
Existing features: {1}.'.format(key, self.all_feat_names)) 88 | elif key in out_feat_keys[:f]: 89 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 90 | 91 | # Find the highest output feature in `out_feat_keys`. 92 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 93 | 94 | return out_feat_keys, max_out_feat 95 | 96 | def forward(self, x, out_feat_keys=None): 97 | """Forward an image `x` through the network and return the requested output features. 98 | 99 | Args: 100 | x: input image. 101 | out_feat_keys: a list/tuple with the names of the features 102 | that the function should return. By default the last feature of 103 | the network is returned. 104 | 105 | Returns: 106 | out_feats: If multiple output features were requested then `out_feats` 107 | is a list with the requested output features placed in the same 108 | order as in `out_feat_keys`. If a single output feature was 109 | requested then `out_feats` is that output feature (and not a list). 110 | """ 111 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 112 | out_feats = [None] * len(out_feat_keys) 113 | 114 | feat = x 115 | for f in range(max_out_feat+1): 116 | feat = self._feature_blocks[f](feat) 117 | key = self.all_feat_names[f] 118 | if key in out_feat_keys: 119 | out_feats[out_feat_keys.index(key)] = feat 120 | 121 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 122 | return out_feats 123 | 124 | 125 | def weight_initialization(self): 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | if m.weight.requires_grad: 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 131 | elif isinstance(m, nn.BatchNorm2d): 132 | if m.weight.requires_grad: 133 | m.weight.data.fill_(1) 134 | if m.bias.requires_grad: 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.Linear): 137 | if m.bias.requires_grad: 138 | m.bias.data.zero_() 139 | 140 | def create_model(opt): 141 | return NetworkInNetwork(opt) 142 | 143 | if __name__ == '__main__': 144 | size = 32 145 | opt = {'num_classes':4, 'num_stages': 5} 146 | 147 | net = create_model(opt) 148 | x = torch.autograd.Variable(torch.FloatTensor(1,3,size,size).uniform_(-1,1)) 149 | 150 | out = net(x, out_feat_keys=net.all_feat_names) 151 | for f in range(len(out)): 152 | print('Output feature {0} - size {1}'.format( 153 | net.all_feat_names[f], out[f].size())) 154 | 155 | 156 | out = net(x) 157 | print('Final output: {0}'.format(out.size())) 158 | -------------------------------------------------------------------------------- /architectures/NonLinearClassifier.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1): 9 | super(BasicBlock, self).__init__() 10 | padding = (kernel_size-1)//2 11 | self.layers = nn.Sequential() 12 | self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes, \ 13 | kernel_size=kernel_size, stride=stride, padding=padding, bias=False)) 14 | self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes)) 15 | self.layers.add_module('ReLU', nn.ReLU(inplace=True)) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | class GlobalAvgPool(nn.Module): 21 | def __init__(self): 22 | super(GlobalAvgPool, self).__init__() 23 | 24 | def forward(self, feat): 25 | assert(feat.size(2) == feat.size(3)) 26 | feat_avg = F.avg_pool2d(feat, feat.size(2)).view(-1, feat.size(1)) 27 | return feat_avg 28 | 29 | class 
Flatten(nn.Module): 30 | def __init__(self): 31 | super(Flatten, self).__init__() 32 | 33 | def forward(self, feat): 34 | return feat.view(feat.size(0), -1) 35 | 36 | class Classifier(nn.Module): 37 | def __init__(self, opt): 38 | super(Classifier, self).__init__() 39 | nChannels = opt['nChannels'] 40 | num_classes = opt['num_classes'] 41 | self.cls_type = opt['cls_type'] 42 | 43 | self.classifier = nn.Sequential() 44 | 45 | if self.cls_type == 'MultLayer': 46 | nFeats = min(num_classes*20, 2048) 47 | self.classifier.add_module('Flatten', Flatten()) 48 | self.classifier.add_module('Liniear_1', nn.Linear(nChannels, nFeats, bias=False)) 49 | self.classifier.add_module('BatchNorm_1', nn.BatchNorm1d(nFeats)) 50 | self.classifier.add_module('ReLU_1', nn.ReLU(inplace=True)) 51 | self.classifier.add_module('Liniear_2', nn.Linear(nFeats, nFeats, bias=False)) 52 | self.classifier.add_module('BatchNorm2d', nn.BatchNorm1d(nFeats)) 53 | self.classifier.add_module('ReLU_2', nn.ReLU(inplace=True)) 54 | self.classifier.add_module('Liniear_F', nn.Linear(nFeats, num_classes)) 55 | elif self.cls_type == 'NIN_ConvBlock3': 56 | self.classifier.add_module('Block3_ConvB1', BasicBlock(nChannels, 192, 3)) 57 | self.classifier.add_module('Block3_ConvB2', BasicBlock(192, 192, 1)) 58 | self.classifier.add_module('Block3_ConvB3', BasicBlock(192, 192, 1)) 59 | self.classifier.add_module('GlobalAvgPool', GlobalAvgPool()) 60 | self.classifier.add_module('Liniear_F', nn.Linear(192, num_classes)) 61 | elif self.cls_type == 'Alexnet_conv5' or self.cls_type == 'Alexnet_conv4': 62 | if self.cls_type == 'Alexnet_conv4': 63 | block5 = nn.Sequential( 64 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 65 | nn.BatchNorm2d(256), 66 | nn.ReLU(inplace=True), 67 | ) 68 | self.classifier.add_module('ConvB5', block5) 69 | self.classifier.add_module('Pool5', nn.MaxPool2d(kernel_size=3, stride=2)) 70 | self.classifier.add_module('Flatten', Flatten()) 71 | self.classifier.add_module('Linear1', nn.Linear(256*6*6, 
4096, bias=False)) 72 | self.classifier.add_module('BatchNorm1', nn.BatchNorm1d(4096)) 73 | self.classifier.add_module('ReLU1', nn.ReLU(inplace=True)) 74 | self.classifier.add_module('Liniear2', nn.Linear(4096, 4096, bias=False)) 75 | self.classifier.add_module('BatchNorm2', nn.BatchNorm1d(4096)) 76 | self.classifier.add_module('ReLU2', nn.ReLU(inplace=True)) 77 | self.classifier.add_module('LinearF', nn.Linear(4096, num_classes)) 78 | else: 79 | raise ValueError('Not recognized classifier type: %s' % self.cls_type) 80 | 81 | self.initilize() 82 | 83 | def forward(self, feat): 84 | return self.classifier(feat) 85 | 86 | def initilize(self): 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 90 | m.weight.data.normal_(0, math.sqrt(2. / n)) 91 | elif isinstance(m, nn.BatchNorm2d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | elif isinstance(m, nn.Linear): 95 | fin = m.in_features 96 | fout = m.out_features 97 | std_val = np.sqrt(2.0/fout) 98 | m.weight.data.normal_(0.0, std_val) 99 | if m.bias is not None: 100 | m.bias.data.fill_(0.0) 101 | 102 | def create_model(opt): 103 | return Classifier(opt) 104 | -------------------------------------------------------------------------------- /config/CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | 
data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'cifar10' 19 | data_test_opt['split'] = 'test' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 100 24 | 25 | networks = {} 26 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 27 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 28 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 29 | 30 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(35, 0.1),(70, 0.02),(85, 0.004),(100, 0.0008)]} 31 | cls_net_opt = {'num_classes':10, 'nChannels':192, 'cls_type':'NIN_ConvBlock3'} 32 | networks['classifier'] = {'def_file': 'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 33 | config['out_feat_keys'] = ['conv2'] 34 | 35 | config['networks'] = networks 36 | 37 | criterions = {} 38 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 39 | config['criterions'] = criterions 40 | config['algorithm_type'] = 'FeatureClassificationModel' 41 | config['best_metric'] = 'prec1' 42 | -------------------------------------------------------------------------------- /config/CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K100.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | 
data_train_opt['num_imgs_per_cat'] = 100 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 100 25 | 26 | networks = {} 27 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 28 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 29 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 30 | 31 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(35, 0.1),(70, 0.02),(85, 0.004),(100, 0.0008)]} 32 | cls_net_opt = {'num_classes':10, 'nChannels':192, 'cls_type':'NIN_ConvBlock3'} 33 | networks['classifier'] = {'def_file': 'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 34 | config['out_feat_keys'] = ['conv2'] 35 | 36 | config['networks'] = networks 37 | 38 | criterions = {} 39 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 40 | config['criterions'] = criterions 41 | config['algorithm_type'] = 'FeatureClassificationModel' 42 | config['best_metric'] = 'prec1' 43 | -------------------------------------------------------------------------------- /config/CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K1000.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | 
data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 1000 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 100 25 | 26 | networks = {} 27 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 28 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 29 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 30 | 31 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(35, 0.1),(70, 0.02),(85, 0.004),(100, 0.0008)]} 32 | cls_net_opt = {'num_classes':10, 'nChannels':192, 'cls_type':'NIN_ConvBlock3'} 33 | networks['classifier'] = {'def_file': 'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 34 | config['out_feat_keys'] = ['conv2'] 35 | 36 | config['networks'] = networks 37 | 38 | criterions = {} 39 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 40 | config['criterions'] = criterions 41 | config['algorithm_type'] = 'FeatureClassificationModel' 42 | config['best_metric'] = 'prec1' 43 | -------------------------------------------------------------------------------- /config/CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K20.py: 
-------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 20 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 100 25 | 26 | networks = {} 27 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 28 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 29 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 30 | 31 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(35, 0.1),(70, 0.02),(85, 0.004),(100, 0.0008)]} 32 | cls_net_opt = {'num_classes':10, 'nChannels':192, 'cls_type':'NIN_ConvBlock3'} 33 | networks['classifier'] = {'def_file': 'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 34 | config['out_feat_keys'] = ['conv2'] 35 | 36 | config['networks'] = networks 37 | 38 | criterions = {} 39 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 40 | config['criterions'] = criterions 41 | config['algorithm_type'] = 
'FeatureClassificationModel' 42 | config['best_metric'] = 'prec1' 43 | -------------------------------------------------------------------------------- /config/CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K400.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 400 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 100 25 | 26 | networks = {} 27 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 28 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 29 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 30 | 31 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(35, 0.1),(70, 0.02),(85, 0.004),(100, 0.0008)]} 32 | cls_net_opt = {'num_classes':10, 'nChannels':192, 'cls_type':'NIN_ConvBlock3'} 33 | networks['classifier'] = {'def_file': 'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 34 | config['out_feat_keys'] = ['conv2'] 
35 | 36 | config['networks'] = networks 37 | 38 | criterions = {} 39 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 40 | config['criterions'] = criterions 41 | config['algorithm_type'] = 'FeatureClassificationModel' 42 | config['best_metric'] = 'prec1' 43 | -------------------------------------------------------------------------------- /config/CIFAR10_MultLayerClassifier_on_RotNet_NIN4blocks_Conv2_feats.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'cifar10' 19 | data_test_opt['split'] = 'test' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 100 24 | 25 | networks = {} 26 | feat_net_opt = {'num_classes': 4, 'num_stages': 4, 'use_avg_on_conv3': False} 27 | feat_pretrained_file = './experiments/CIFAR10_RotNet_NIN4blocks/model_net_epoch200' 28 | networks['feat_extractor'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': feat_pretrained_file, 'opt': feat_net_opt, 'optim_params': None} 29 | 30 | cls_net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(20, 0.1),(40, 0.02),(45, 0.004),(50, 0.0008)]} 31 | cls_net_opt = {'num_classes':10, 'nChannels':192*8*8, 'cls_type':'MultLayer'} 32 | networks['classifier'] = {'def_file': 
'architectures/NonLinearClassifier.py', 'pretrained': None, 'opt': cls_net_opt, 'optim_params': cls_net_optim_params} 33 | config['out_feat_keys'] = ['conv2'] 34 | 35 | config['networks'] = networks 36 | 37 | criterions = {} 38 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 39 | config['criterions'] = criterions 40 | config['algorithm_type'] = 'FeatureClassificationModel' 41 | config['best_metric'] = 'prec1' 42 | -------------------------------------------------------------------------------- /config/CIFAR10_RotNet_NIN4blocks.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = True 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = True 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'cifar10' 19 | data_test_opt['split'] = 'test' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 200 24 | 25 | net_opt = {} 26 | net_opt['num_classes'] = 4 27 | net_opt['num_stages'] = 4 28 | net_opt['use_avg_on_conv3'] = False 29 | 30 | networks = {} 31 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 32 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 33 | config['networks'] = networks 34 | 35 | criterions = {} 36 | 
criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 37 | config['criterions'] = criterions 38 | config['algorithm_type'] = 'ClassificationModel' 39 | -------------------------------------------------------------------------------- /config/CIFAR10_supervised_NIN.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'cifar10' 19 | data_test_opt['split'] = 'test' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 200 24 | 25 | net_opt = {} 26 | net_opt['num_classes'] = 10 27 | net_opt['num_stages'] = 3 28 | 29 | networks = {} 30 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 31 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 32 | config['networks'] = networks 33 | 34 | criterions = {} 35 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 36 | config['criterions'] = criterions 37 | config['algorithm_type'] = 'ClassificationModel' 38 | -------------------------------------------------------------------------------- /config/CIFAR10_supervised_NIN_K100.py: 
-------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 100 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 200 25 | 26 | net_opt = {} 27 | net_opt['num_classes'] = 10 28 | net_opt['num_stages'] = 3 29 | 30 | networks = {} 31 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 32 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 33 | config['networks'] = networks 34 | 35 | criterions = {} 36 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 37 | config['criterions'] = criterions 38 | config['algorithm_type'] = 'ClassificationModel' 39 | -------------------------------------------------------------------------------- /config/CIFAR10_supervised_NIN_K1000.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | 
data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 1000 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 200 25 | 26 | net_opt = {} 27 | net_opt['num_classes'] = 10 28 | net_opt['num_stages'] = 3 29 | 30 | networks = {} 31 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 32 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 33 | config['networks'] = networks 34 | 35 | criterions = {} 36 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 37 | config['criterions'] = criterions 38 | config['algorithm_type'] = 'ClassificationModel' 39 | -------------------------------------------------------------------------------- /config/CIFAR10_supervised_NIN_K20.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 20 13 | 14 | 
data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 200 25 | 26 | net_opt = {} 27 | net_opt['num_classes'] = 10 28 | net_opt['num_stages'] = 3 29 | 30 | networks = {} 31 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 32 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 33 | config['networks'] = networks 34 | 35 | criterions = {} 36 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 37 | config['criterions'] = criterions 38 | config['algorithm_type'] = 'ClassificationModel' 39 | -------------------------------------------------------------------------------- /config/CIFAR10_supervised_NIN_K400.py: -------------------------------------------------------------------------------- 1 | batch_size = 128 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = 10 * 5000 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'cifar10' 11 | data_train_opt['split'] = 'train' 12 | data_train_opt['num_imgs_per_cat'] = 400 13 | 14 | data_test_opt = {} 15 | data_test_opt['batch_size'] = batch_size 16 | data_test_opt['unsupervised'] = False 17 | data_test_opt['epoch_size'] = None 18 | data_test_opt['random_sized_crop'] = False 19 | data_test_opt['dataset_name'] = 'cifar10' 20 | data_test_opt['split'] = 
'test' 21 | 22 | config['data_train_opt'] = data_train_opt 23 | config['data_test_opt'] = data_test_opt 24 | config['max_num_epochs'] = 200 25 | 26 | net_opt = {} 27 | net_opt['num_classes'] = 10 28 | net_opt['num_stages'] = 3 29 | 30 | networks = {} 31 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(60, 0.1),(120, 0.02),(160, 0.004),(200, 0.0008)]} 32 | networks['model'] = {'def_file': 'architectures/NetworkInNetwork.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 33 | config['networks'] = networks 34 | 35 | criterions = {} 36 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 37 | config['criterions'] = criterions 38 | config['algorithm_type'] = 'ClassificationModel' 39 | -------------------------------------------------------------------------------- /config/ImageNet_LinearClassifiers_ImageNet_RotNet_AlexNet_Features.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'imagenet' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'imagenet' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 35 24 | 25 | 26 | networks = {} 27 | 28 | pretrained = './experiments/ImageNet_RotNet_AlexNet/model_net_epoch50' 29 | networks['feat_extractor'] = 
{'def_file': 'architectures/AlexNet.py', 'pretrained': pretrained, 'opt': {'num_classes': 4}, 'optim_params': None} 30 | 31 | net_opt_cls = [None] * 5 32 | net_opt_cls[0] = {'pool_type':'max', 'nChannels':64, 'pool_size':12, 'num_classes': 1000} 33 | net_opt_cls[1] = {'pool_type':'max', 'nChannels':192, 'pool_size':7, 'num_classes': 1000} 34 | net_opt_cls[2] = {'pool_type':'max', 'nChannels':384, 'pool_size':5, 'num_classes': 1000} 35 | net_opt_cls[3] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 1000} 36 | net_opt_cls[4] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 1000} 37 | out_feat_keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5'] 38 | net_optim_params_cls = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(5, 0.01),(15, 0.002),(25, 0.0004),(35, 0.00008)]} 39 | networks['classifier'] = {'def_file': 'architectures/MultipleLinearClassifiers.py', 'pretrained': None, 'opt': net_opt_cls, 'optim_params': net_optim_params_cls} 40 | 41 | config['networks'] = networks 42 | 43 | criterions = {} 44 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 45 | config['criterions'] = criterions 46 | config['algorithm_type'] = 'FeatureClassificationModel' 47 | config['out_feat_keys'] = out_feat_keys 48 | 49 | -------------------------------------------------------------------------------- /config/ImageNet_LinearClassifiers_Places205_RotNet_AlexNet_Features.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'imagenet' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | 
data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'imagenet' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 35 24 | 25 | networks = {} 26 | 27 | pretrained = './experiments/Places205_RotNet_AlexNet/model_net_epoch50' 28 | networks['feat_extractor'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': pretrained, 'opt': {'num_classes': 4}, 'optim_params': None} 29 | 30 | net_opt_cls = [None] * 5 31 | net_opt_cls[0] = {'pool_type':'max', 'nChannels':64, 'pool_size':12, 'num_classes': 1000} 32 | net_opt_cls[1] = {'pool_type':'max', 'nChannels':192, 'pool_size':7, 'num_classes': 1000} 33 | net_opt_cls[2] = {'pool_type':'max', 'nChannels':384, 'pool_size':5, 'num_classes': 1000} 34 | net_opt_cls[3] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 1000} 35 | net_opt_cls[4] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 1000} 36 | out_feat_keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5'] 37 | net_optim_params_cls = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(5, 0.01),(15, 0.002),(25, 0.0004),(35, 0.00008)]} 38 | networks['classifier'] = {'def_file': 'architectures/MultipleLinearClassifiers.py', 'pretrained': None, 'opt': net_opt_cls, 'optim_params': net_optim_params_cls} 39 | 40 | config['networks'] = networks 41 | 42 | criterions = {} 43 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 44 | config['criterions'] = criterions 45 | config['algorithm_type'] = 'FeatureClassificationModel' 46 | config['out_feat_keys'] = out_feat_keys 47 | 48 | -------------------------------------------------------------------------------- 
/config/ImageNet_NonLinearClassifiers_ImageNet_RotNet_AlexNet_Features.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'imagenet' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'imagenet' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 40 24 | 25 | 26 | networks = {} 27 | 28 | pretrained = './experiments/ImageNet_RotNet_AlexNet/model_net_epoch50' 29 | networks['feat_extractor'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': pretrained, 'opt': {'num_classes': 4}, 'optim_params': None} 30 | 31 | net_opt_cls = [None] * 2 32 | net_opt_cls[0] = {'cls_type':'Alexnet_conv4', 'nChannels':256, 'num_classes':1000} 33 | net_opt_cls[1] = {'cls_type':'Alexnet_conv5', 'nChannels':256, 'num_classes':1000} 34 | out_feat_keys = ['conv4', 'conv5'] 35 | net_optim_params_cls = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(10, 0.01),(20, 0.002),(30, 0.0004),(40, 0.00008)]} 36 | networks['classifier'] = {'def_file': 'architectures/MultipleNonLinearClassifiers.py', 'pretrained': None, 'opt': net_opt_cls, 'optim_params': net_optim_params_cls} 37 | 38 | config['networks'] = networks 39 | 40 | criterions = {} 41 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 42 | config['criterions'] = 
criterions 43 | config['algorithm_type'] = 'FeatureClassificationModel' 44 | config['out_feat_keys'] = out_feat_keys 45 | -------------------------------------------------------------------------------- /config/ImageNet_RotNet_AlexNet.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = True 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = True 10 | data_train_opt['dataset_name'] = 'imagenet' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = True 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'imagenet' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 50 24 | 25 | net_opt = {} 26 | net_opt['num_classes'] = 4 27 | net_opt['num_stages'] = 4 28 | 29 | networks = {} 30 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(15, 0.01),(30, 0.001),(45, 0.0001),(50, 0.00001)]} 31 | networks['model'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 32 | config['networks'] = networks 33 | 34 | criterions = {} 35 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 36 | config['criterions'] = criterions 37 | config['algorithm_type'] = 'ClassificationModel' 38 | -------------------------------------------------------------------------------- /config/Places205_LinearClassifiers_ImageNet_RotNet_AlexNet_Features.py: 
-------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'places205' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'places205' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 35 24 | 25 | 26 | networks = {} 27 | 28 | pretrained = './experiments/ImageNet_RotNet_AlexNet/model_net_epoch50' 29 | networks['feat_extractor'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': pretrained, 'opt': {'num_classes': 4}, 'optim_params': None} 30 | 31 | net_opt_cls = [None] * 5 32 | net_opt_cls[0] = {'pool_type':'max', 'nChannels':64, 'pool_size':12, 'num_classes': 205} 33 | net_opt_cls[1] = {'pool_type':'max', 'nChannels':192, 'pool_size':7, 'num_classes': 205} 34 | net_opt_cls[2] = {'pool_type':'max', 'nChannels':384, 'pool_size':5, 'num_classes': 205} 35 | net_opt_cls[3] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 205} 36 | net_opt_cls[4] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 205} 37 | out_feat_keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5'] 38 | net_optim_params_cls = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(5, 0.01),(15, 0.002),(25, 0.0004),(35, 0.00008)]} 39 | networks['classifier'] = {'def_file': 'architectures/MultipleLinearClassifiers.py', 
'pretrained': None, 'opt': net_opt_cls, 'optim_params': net_optim_params_cls} 40 | 41 | config['networks'] = networks 42 | 43 | criterions = {} 44 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 45 | config['criterions'] = criterions 46 | config['algorithm_type'] = 'FeatureClassificationModel' 47 | config['out_feat_keys'] = out_feat_keys 48 | 49 | -------------------------------------------------------------------------------- /config/Places205_LinearClassifiers_Places205_RotNet_AlexNet_Features.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = False 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = False 10 | data_train_opt['dataset_name'] = 'places205' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = False 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'places205' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 35 24 | 25 | 26 | networks = {} 27 | 28 | pretrained = './experiments/Places205_RotNet_AlexNet/model_net_epoch50' 29 | networks['feat_extractor'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': pretrained, 'opt': {'num_classes': 4}, 'optim_params': None} 30 | 31 | net_opt_cls = [None] * 5 32 | net_opt_cls[0] = {'pool_type':'max', 'nChannels':64, 'pool_size':12, 'num_classes': 205} 33 | net_opt_cls[1] = {'pool_type':'max', 'nChannels':192, 'pool_size':7, 'num_classes': 205} 34 | net_opt_cls[2] = {'pool_type':'max', 'nChannels':384, 'pool_size':5, 'num_classes': 205} 
35 | net_opt_cls[3] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 205} 36 | net_opt_cls[4] = {'pool_type':'max', 'nChannels':256, 'pool_size':6, 'num_classes': 205} 37 | out_feat_keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5'] 38 | net_optim_params_cls = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(5, 0.01),(15, 0.002),(25, 0.0004),(35, 0.00008)]} 39 | networks['classifier'] = {'def_file': 'architectures/MultipleLinearClassifiers.py', 'pretrained': None, 'opt': net_opt_cls, 'optim_params': net_optim_params_cls} 40 | 41 | config['networks'] = networks 42 | 43 | criterions = {} 44 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 45 | config['criterions'] = criterions 46 | config['algorithm_type'] = 'FeatureClassificationModel' 47 | config['out_feat_keys'] = out_feat_keys 48 | 49 | -------------------------------------------------------------------------------- /config/Places205_RotNet_AlexNet.py: -------------------------------------------------------------------------------- 1 | batch_size = 192 2 | 3 | config = {} 4 | # set the parameters related to the training and testing set 5 | data_train_opt = {} 6 | data_train_opt['batch_size'] = batch_size 7 | data_train_opt['unsupervised'] = True 8 | data_train_opt['epoch_size'] = None 9 | data_train_opt['random_sized_crop'] = True 10 | data_train_opt['dataset_name'] = 'places205' 11 | data_train_opt['split'] = 'train' 12 | 13 | data_test_opt = {} 14 | data_test_opt['batch_size'] = batch_size 15 | data_test_opt['unsupervised'] = True 16 | data_test_opt['epoch_size'] = None 17 | data_test_opt['random_sized_crop'] = False 18 | data_test_opt['dataset_name'] = 'places205' 19 | data_test_opt['split'] = 'val' 20 | 21 | config['data_train_opt'] = data_train_opt 22 | config['data_test_opt'] = data_test_opt 23 | config['max_num_epochs'] = 50 24 | 25 | net_opt = {} 26 | net_opt['num_classes'] = 4 27 | net_opt['num_stages'] = 4 28 | 29 | 
networks = {} 30 | net_optim_params = {'optim_type': 'sgd', 'lr': 0.1, 'momentum':0.9, 'weight_decay': 5e-4, 'nesterov': True, 'LUT_lr':[(15, 0.01),(30, 0.001),(45, 0.0001),(50, 0.00001)]} 31 | networks['model'] = {'def_file': 'architectures/AlexNet.py', 'pretrained': None, 'opt': net_opt, 'optim_params': net_optim_params} 32 | config['networks'] = networks 33 | 34 | criterions = {} 35 | criterions['loss'] = {'ctype':'CrossEntropyLoss', 'opt':None} 36 | config['criterions'] = criterions 37 | config['algorithm_type'] = 'ClassificationModel' 38 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | import torchnet as tnt 6 | import torchvision.datasets as datasets 7 | import torchvision.transforms as transforms 8 | # from Places205 import Places205 9 | import numpy as np 10 | import random 11 | from torch.utils.data.dataloader import default_collate 12 | from PIL import Image 13 | import os 14 | import errno 15 | import numpy as np 16 | import sys 17 | import csv 18 | 19 | from pdb import set_trace as breakpoint 20 | 21 | # Set the paths of the datasets here. 
22 | _CIFAR_DATASET_DIR = './datasets/CIFAR' 23 | _IMAGENET_DATASET_DIR = './datasets/IMAGENET/ILSVRC2012' 24 | _PLACES205_DATASET_DIR = './datasets/Places205' 25 | 26 | 27 | def buildLabelIndex(labels): 28 | label2inds = {} 29 | for idx, label in enumerate(labels): 30 | if label not in label2inds: 31 | label2inds[label] = [] 32 | label2inds[label].append(idx) 33 | 34 | return label2inds 35 | 36 | class Places205(data.Dataset): 37 | def __init__(self, root, split, transform=None, target_transform=None): 38 | self.root = os.path.expanduser(root) 39 | self.data_folder = os.path.join(self.root, 'data', 'vision', 'torralba', 'deeplearning', 'images256') 40 | self.split_folder = os.path.join(self.root, 'trainvalsplit_places205') 41 | assert(split=='train' or split=='val') 42 | split_csv_file = os.path.join(self.split_folder, split+'_places205.csv') 43 | 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | with open(split_csv_file, 'r') as f: 47 | reader = csv.reader(f, delimiter=' ') 48 | self.img_files = [] 49 | self.labels = [] 50 | for row in reader: 51 | self.img_files.append(row[0]) 52 | self.labels.append(int(row[1])) 53 | 54 | def __getitem__(self, index): 55 | """ 56 | Args: 57 | index (int): Index 58 | 59 | Returns: 60 | tuple: (image, target) where target is the index of the target class. 
61 | """ 62 | image_path = os.path.join(self.data_folder, self.img_files[index]) 63 | img = Image.open(image_path).convert('RGB') 64 | target = self.labels[index] 65 | 66 | if self.transform is not None: 67 | img = self.transform(img) 68 | if self.target_transform is not None: 69 | target = self.target_transform(target) 70 | return img, target 71 | 72 | def __len__(self): 73 | return len(self.labels) 74 | 75 | class GenericDataset(data.Dataset): 76 | def __init__(self, dataset_name, split, random_sized_crop=False, 77 | num_imgs_per_cat=None): 78 | self.split = split.lower() 79 | self.dataset_name = dataset_name.lower() 80 | self.name = self.dataset_name + '_' + self.split 81 | self.random_sized_crop = random_sized_crop 82 | 83 | # The num_imgs_per_cat input argument specifies the number 84 | # of training examples per category that will be used. 85 | # This input argument was introduced in order to be able 86 | # to use fewer annotated examples than are available 87 | # in a semi-supervised experiment. By default all the 88 | # available training examples per category are used. 
89 | self.num_imgs_per_cat = num_imgs_per_cat 90 | 91 | if self.dataset_name=='imagenet': 92 | assert(self.split=='train' or self.split=='val') 93 | self.mean_pix = [0.485, 0.456, 0.406] 94 | self.std_pix = [0.229, 0.224, 0.225] 95 | 96 | if self.split!='train': 97 | transforms_list = [ 98 | transforms.Scale(256), 99 | transforms.CenterCrop(224), 100 | lambda x: np.asarray(x), 101 | ] 102 | else: 103 | if self.random_sized_crop: 104 | transforms_list = [ 105 | transforms.RandomSizedCrop(224), 106 | transforms.RandomHorizontalFlip(), 107 | lambda x: np.asarray(x), 108 | ] 109 | else: 110 | transforms_list = [ 111 | transforms.Scale(256), 112 | transforms.RandomCrop(224), 113 | transforms.RandomHorizontalFlip(), 114 | lambda x: np.asarray(x), 115 | ] 116 | self.transform = transforms.Compose(transforms_list) 117 | split_data_dir = _IMAGENET_DATASET_DIR + '/' + self.split 118 | self.data = datasets.ImageFolder(split_data_dir, self.transform) 119 | elif self.dataset_name=='places205': 120 | self.mean_pix = [0.485, 0.456, 0.406] 121 | self.std_pix = [0.229, 0.224, 0.225] 122 | if self.split!='train': 123 | transforms_list = [ 124 | transforms.CenterCrop(224), 125 | lambda x: np.asarray(x), 126 | ] 127 | else: 128 | if self.random_sized_crop: 129 | transforms_list = [ 130 | transforms.RandomSizedCrop(224), 131 | transforms.RandomHorizontalFlip(), 132 | lambda x: np.asarray(x), 133 | ] 134 | else: 135 | transforms_list = [ 136 | transforms.RandomCrop(224), 137 | transforms.RandomHorizontalFlip(), 138 | lambda x: np.asarray(x), 139 | ] 140 | self.transform = transforms.Compose(transforms_list) 141 | self.data = Places205(root=_PLACES205_DATASET_DIR, split=self.split, 142 | transform=self.transform) 143 | elif self.dataset_name=='cifar10': 144 | self.mean_pix = [x/255.0 for x in [125.3, 123.0, 113.9]] 145 | self.std_pix = [x/255.0 for x in [63.0, 62.1, 66.7]] 146 | 147 | if self.random_sized_crop: 148 | raise ValueError('The random size crop option is not supported for the 
CIFAR dataset') 149 | 150 | transform = [] 151 | if (self.split != 'test'): 152 | transform.append(transforms.RandomCrop(32, padding=4)) 153 | transform.append(transforms.RandomHorizontalFlip()) 154 | transform.append(lambda x: np.asarray(x)) 155 | self.transform = transforms.Compose(transform) 156 | self.data = datasets.__dict__[self.dataset_name.upper()]( 157 | _CIFAR_DATASET_DIR, train=self.split=='train', 158 | download=True, transform=self.transform) 159 | else: 160 | raise ValueError('Not recognized dataset {0}'.format(self.dataset_name)) 161 | 162 | if num_imgs_per_cat is not None: 163 | self._keep_first_k_examples_per_category(num_imgs_per_cat) 164 | 165 | 166 | def _keep_first_k_examples_per_category(self, num_imgs_per_cat): 167 | print('num_imgs_per_category {0}'.format(num_imgs_per_cat)) 168 | 169 | if self.dataset_name=='cifar10': 170 | labels = self.data.test_labels if (self.split=='test') else self.data.train_labels 171 | data = self.data.test_data if (self.split=='test') else self.data.train_data 172 | label2ind = buildLabelIndex(labels) 173 | all_indices = [] 174 | for cat in label2ind.keys(): 175 | label2ind[cat] = label2ind[cat][:num_imgs_per_cat] 176 | all_indices += label2ind[cat] 177 | all_indices = sorted(all_indices) 178 | data = data[all_indices] 179 | labels = [labels[idx] for idx in all_indices] 180 | if self.split=='test': 181 | self.data.test_labels = labels 182 | self.data.test_data = data 183 | else: 184 | self.data.train_labels = labels 185 | self.data.train_data = data 186 | 187 | label2ind = buildLabelIndex(labels) 188 | for k, v in label2ind.items(): 189 | assert(len(v)==num_imgs_per_cat) 190 | 191 | elif self.dataset_name=='imagenet': 192 | raise ValueError('Keeping k examples per category has not been implemented for the {0}'.format(self.dataset_name)) 193 | elif self.dataset_name=='places205': 194 | raise ValueError('Keeping k examples per category has not been implemented for the {0}'.format(self.dataset_name)) 195 | else: 196 | raise ValueError('Not recognized dataset 
{0}'.format(self.dataset_name)) 197 | 198 | 199 | def __getitem__(self, index): 200 | img, label = self.data[index] 201 | return img, int(label) 202 | 203 | def __len__(self): 204 | return len(self.data) 205 | 206 | class Denormalize(object): 207 | def __init__(self, mean, std): 208 | self.mean = mean 209 | self.std = std 210 | 211 | def __call__(self, tensor): 212 | for t, m, s in zip(tensor, self.mean, self.std): 213 | t.mul_(s).add_(m) 214 | return tensor 215 | 216 | def rotate_img(img, rot): 217 | if rot == 0: # 0 degrees rotation 218 | return img 219 | elif rot == 90: # 90 degrees rotation 220 | return np.flipud(np.transpose(img, (1,0,2))) 221 | elif rot == 180: # 180 degrees rotation 222 | return np.fliplr(np.flipud(img)) 223 | elif rot == 270: # 270 degrees rotation / or -90 224 | return np.transpose(np.flipud(img), (1,0,2)) 225 | else: 226 | raise ValueError('rotation should be 0, 90, 180, or 270 degrees') 227 | 228 | 229 | class DataLoader(object): 230 | def __init__(self, 231 | dataset, 232 | batch_size=1, 233 | unsupervised=True, 234 | epoch_size=None, 235 | num_workers=0, 236 | shuffle=True): 237 | self.dataset = dataset 238 | self.shuffle = shuffle 239 | self.epoch_size = epoch_size if epoch_size is not None else len(dataset) 240 | self.batch_size = batch_size 241 | self.unsupervised = unsupervised 242 | self.num_workers = num_workers 243 | 244 | mean_pix = self.dataset.mean_pix 245 | std_pix = self.dataset.std_pix 246 | self.transform = transforms.Compose([ 247 | transforms.ToTensor(), 248 | transforms.Normalize(mean=mean_pix, std=std_pix) 249 | ]) 250 | self.inv_transform = transforms.Compose([ 251 | Denormalize(mean_pix, std_pix), 252 | lambda x: x.numpy() * 255.0, 253 | lambda x: x.transpose(1,2,0).astype(np.uint8), 254 | ]) 255 | 256 | def get_iterator(self, epoch=0): 257 | rand_seed = epoch * self.epoch_size 258 | random.seed(rand_seed) 259 | if self.unsupervised: 260 | # if in unsupervised mode define a loader function that given the 261 | # index of an
image it returns the 4 rotated copies of the image 262 | # plus the label of the rotation, i.e., 0 for 0 degrees rotation, 263 | # 1 for 90 degrees, 2 for 180 degrees, and 3 for 270 degrees. 264 | def _load_function(idx): 265 | idx = idx % len(self.dataset) 266 | img0, _ = self.dataset[idx] 267 | rotated_imgs = [ 268 | self.transform(img0), 269 | self.transform(rotate_img(img0, 90)), 270 | self.transform(rotate_img(img0, 180)), 271 | self.transform(rotate_img(img0, 270)) 272 | ] 273 | rotation_labels = torch.LongTensor([0, 1, 2, 3]) 274 | return torch.stack(rotated_imgs, dim=0), rotation_labels 275 | def _collate_fun(batch): 276 | batch = default_collate(batch) 277 | assert(len(batch)==2) 278 | batch_size, rotations, channels, height, width = batch[0].size() 279 | batch[0] = batch[0].view([batch_size*rotations, channels, height, width]) 280 | batch[1] = batch[1].view([batch_size*rotations]) 281 | return batch 282 | else: # supervised mode 283 | # if in supervised mode define a loader function that given the 284 | # index of an image it returns the image and its categorical label 285 | def _load_function(idx): 286 | idx = idx % len(self.dataset) 287 | img, categorical_label = self.dataset[idx] 288 | img = self.transform(img) 289 | return img, categorical_label 290 | _collate_fun = default_collate 291 | 292 | tnt_dataset = tnt.dataset.ListDataset(elem_list=range(self.epoch_size), 293 | load=_load_function) 294 | data_loader = tnt_dataset.parallel(batch_size=self.batch_size, 295 | collate_fn=_collate_fun, num_workers=self.num_workers, 296 | shuffle=self.shuffle) 297 | return data_loader 298 | 299 | def __call__(self, epoch=0): 300 | return self.get_iterator(epoch) 301 | 302 | def __len__(self): 303 | return self.epoch_size // self.batch_size 304 | 305 | if __name__ == '__main__': 306 | from matplotlib import pyplot as plt 307 | 308 | dataset = GenericDataset('imagenet','train', random_sized_crop=True) 309 | dataloader = DataLoader(dataset, batch_size=8,
unsupervised=True) 310 | 311 | for b in dataloader(0): 312 | data, label = b 313 | break 314 | 315 | inv_transform = dataloader.inv_transform 316 | for i in range(data.size(0)): 317 | plt.subplot(data.size(0)//4,4,i+1) 318 | fig=plt.imshow(inv_transform(data[i])) 319 | fig.axes.get_xaxis().set_visible(False) 320 | fig.axes.get_yaxis().set_visible(False) 321 | 322 | plt.show() 323 | -------------------------------------------------------------------------------- /extras/AlexNet.prototxt: -------------------------------------------------------------------------------- 1 | name: "AlexNet" 2 | layer { 3 | name: "data" type: "Input" top: "data" 4 | input_param { shape: { dim: 4 dim: 3 dim: 224 dim: 224 } } 5 | } 6 | layer { 7 | name: "conv1" type: "Convolution" bottom: "data" top: "conv1" 8 | param { lr_mult: 1 decay_mult: 1} 9 | param { lr_mult: 2 decay_mult: 0} 10 | convolution_param { num_output: 64 kernel_size: 11 stride: 4 pad: 2 } 11 | } 12 | layer { 13 | name: "relu1" type: "ReLU" bottom: "conv1" top: "conv1" 14 | } 15 | layer { 16 | name: "pool1" type: "Pooling" bottom: "conv1" top: "pool1" 17 | pooling_param {pool: MAX kernel_size: 3 stride: 2} 18 | } 19 | layer { 20 | name: "conv2" type: "Convolution" bottom: "pool1" top: "conv2" 21 | param { lr_mult: 1 decay_mult: 1 } 22 | param { lr_mult: 2 decay_mult: 0 } 23 | convolution_param { num_output: 192 pad: 2 kernel_size: 5} 24 | } 25 | layer { 26 | name: "relu2" type: "ReLU" bottom: "conv2" top: "conv2" 27 | } 28 | layer { 29 | name: "pool2" type: "Pooling" bottom: "conv2" top: "pool2" 30 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 31 | } 32 | layer { 33 | name: "conv3" type: "Convolution" bottom: "pool2" top: "conv3" 34 | param { lr_mult: 1 decay_mult: 1 } 35 | param { lr_mult: 2 decay_mult: 0 } 36 | convolution_param { num_output: 384 pad: 1 kernel_size: 3 } 37 | } 38 | layer { 39 | name: "relu3" type: "ReLU" bottom: "conv3" top: "conv3" 40 | } 41 | layer { 42 | name: "conv4" type: "Convolution"
bottom: "conv3" top: "conv4" 43 | param { lr_mult: 1 decay_mult: 1 } 44 | param { lr_mult: 2 decay_mult: 0 } 45 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 46 | } 47 | layer { 48 | name: "relu4" type: "ReLU" bottom: "conv4" top: "conv4" 49 | } 50 | layer { 51 | name: "conv5" type: "Convolution" bottom: "conv4" top: "conv5" 52 | param { lr_mult: 1 decay_mult: 1 } 53 | param { lr_mult: 2 decay_mult: 0 } 54 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 55 | } 56 | layer { 57 | name: "relu5" type: "ReLU" bottom: "conv5" top: "conv5" 58 | } 59 | layer { 60 | name: "pool5" type: "Pooling" bottom: "conv5" top: "pool5" 61 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 62 | } 63 | layer { 64 | name: "fc6" type: "InnerProduct" bottom: "pool5" top: "fc6" 65 | param { lr_mult: 1 decay_mult: 1 } 66 | param { lr_mult: 2 decay_mult: 0 } 67 | inner_product_param { num_output: 4096 } 68 | } 69 | layer { 70 | name: "relu6" type: "ReLU" bottom: "fc6" top: "fc6" 71 | } 72 | layer { 73 | name: "fc7" type: "InnerProduct" bottom: "fc6" top: "fc7" 74 | param { lr_mult: 1 decay_mult: 1 } 75 | param { lr_mult: 2 decay_mult: 0 } 76 | inner_product_param { num_output: 4096 } 77 | } 78 | layer { 79 | name: "relu7" type: "ReLU" bottom: "fc7" top: "fc7" 80 | } 81 | -------------------------------------------------------------------------------- /extras/AlexNet_fcn.prototxt: -------------------------------------------------------------------------------- 1 | name: "AlexNet" 2 | layer { 3 | name: "data" type: "Input" top: "data" 4 | input_param { shape: { dim: 4 dim: 3 dim: 224 dim: 224 } } 5 | } 6 | layer { 7 | name: "conv1" type: "Convolution" bottom: "data" top: "conv1" 8 | param { lr_mult: 1 decay_mult: 1} 9 | param { lr_mult: 2 decay_mult: 0} 10 | convolution_param { num_output: 64 kernel_size: 11 stride: 4 pad: 2 } 11 | } 12 | layer { 13 | name: "relu1" type: "ReLU" bottom: "conv1" top: "conv1" 14 | } 15 | layer { 16 | name: "pool1" type: "Pooling" 
bottom: "conv1" top: "pool1" 17 | pooling_param {pool: MAX kernel_size: 3 stride: 2} 18 | } 19 | layer { 20 | name: "conv2" type: "Convolution" bottom: "pool1" top: "conv2" 21 | param { lr_mult: 1 decay_mult: 1 } 22 | param { lr_mult: 2 decay_mult: 0 } 23 | convolution_param { num_output: 192 pad: 2 kernel_size: 5} 24 | } 25 | layer { 26 | name: "relu2" type: "ReLU" bottom: "conv2" top: "conv2" 27 | } 28 | layer { 29 | name: "pool2" type: "Pooling" bottom: "conv2" top: "pool2" 30 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 31 | } 32 | layer { 33 | name: "conv3" type: "Convolution" bottom: "pool2" top: "conv3" 34 | param { lr_mult: 1 decay_mult: 1 } 35 | param { lr_mult: 2 decay_mult: 0 } 36 | convolution_param { num_output: 384 pad: 1 kernel_size: 3 } 37 | } 38 | layer { 39 | name: "relu3" type: "ReLU" bottom: "conv3" top: "conv3" 40 | } 41 | layer { 42 | name: "conv4" type: "Convolution" bottom: "conv3" top: "conv4" 43 | param { lr_mult: 1 decay_mult: 1 } 44 | param { lr_mult: 2 decay_mult: 0 } 45 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 46 | } 47 | layer { 48 | name: "relu4" type: "ReLU" bottom: "conv4" top: "conv4" 49 | } 50 | layer { 51 | name: "conv5" type: "Convolution" bottom: "conv4" top: "conv5" 52 | param { lr_mult: 1 decay_mult: 1 } 53 | param { lr_mult: 2 decay_mult: 0 } 54 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 55 | } 56 | layer { 57 | name: "relu5" type: "ReLU" bottom: "conv5" top: "conv5" 58 | } 59 | layer { 60 | name: "pool5" type: "Pooling" bottom: "conv5" top: "pool5" 61 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 62 | } 63 | layer { 64 | name: "fc6" type: "Convolution" bottom: "pool5" top: "fc6" 65 | param { lr_mult: 1 decay_mult: 1 } 66 | param { lr_mult: 2 decay_mult: 0 } 67 | convolution_param { 68 | num_output: 4096 kernel_size: 6 69 | } 70 | } 71 | layer { 72 | name: "relu6" type: "ReLU" bottom: "fc6" top: "fc6" 73 | } 74 | layer { 75 | name: "fc7" type: "Convolution" bottom: 
"fc6" top: "fc7" 76 | param { lr_mult: 1 decay_mult: 1 } 77 | param { lr_mult: 2 decay_mult: 0 } 78 | convolution_param { 79 | num_output: 4096 kernel_size: 1 80 | } 81 | } 82 | layer { 83 | name: "relu7" type: "ReLU" bottom: "fc7" top: "fc7" 84 | } 85 | -------------------------------------------------------------------------------- /extras/AlexNet_rescaled.prototxt: -------------------------------------------------------------------------------- 1 | name: "AlexNet" 2 | # no batch norm layers in this prototxt 3 | # for rescaling 4 | # ************************ 5 | # ***** DATA LAYER ******* 6 | # ************************ 7 | layer { 8 | name: "data" 9 | type: "Data" 10 | top: "data" 11 | include { phase: TRAIN } 12 | transform_param { 13 | mirror: 1 # 1 = on, 0 = off 14 | crop_size: 224 15 | } 16 | data_param { 17 | source: "/home/spyros/D2TBdatasets/imagenet_imdb/ilsvrc12_train_lmdb" 18 | batch_size: 80 19 | backend: LMDB 20 | } 21 | } 22 | layer { 23 | name: "data" 24 | type: "Data" 25 | top: "data" 26 | include { phase: TEST } 27 | transform_param { 28 | mirror: 0 # 1 = on, 0 = off 29 | crop_size: 224 30 | } 31 | data_param { 32 | source: "/home/spyros/D2TBdatasets/imagenet_imdb/ilsvrc12_val_lmdb" 33 | batch_size: 10 34 | backend: LMDB 35 | } 36 | } 37 | layer { 38 | type: "Input" 39 | name: "label" 40 | top: "label" 41 | include { phase: TRAIN } 42 | input_param { shape { dim: 80 dim: 1 } } 43 | } 44 | layer { 45 | type: "Input" 46 | name: "label" 47 | top: "label" 48 | include { phase: TEST } 49 | input_param { shape { dim: 1 dim: 1 } } 50 | } 51 | layer { 52 | name: "conv1" type: "Convolution" bottom: "data" top: "conv1" 53 | param { lr_mult: 1 decay_mult: 1} 54 | param { lr_mult: 2 decay_mult: 0} 55 | convolution_param { num_output: 64 kernel_size: 11 stride: 4 pad: 2 } 56 | } 57 | layer { 58 | name: "relu1" type: "ReLU" bottom: "conv1" top: "conv1" 59 | } 60 | layer { 61 | name: "pool1" type: "Pooling" bottom: "conv1" top: "pool1" 62 | pooling_param {pool: 
MAX kernel_size: 3 stride: 2} 63 | } 64 | layer { 65 | name: "conv2" type: "Convolution" bottom: "pool1" top: "conv2" 66 | param { lr_mult: 1 decay_mult: 1 } 67 | param { lr_mult: 2 decay_mult: 0 } 68 | convolution_param { num_output: 192 pad: 2 kernel_size: 5} 69 | } 70 | layer { 71 | name: "relu2" type: "ReLU" bottom: "conv2" top: "conv2" 72 | } 73 | layer { 74 | name: "pool2" type: "Pooling" bottom: "conv2" top: "pool2" 75 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 76 | } 77 | layer { 78 | name: "conv3" type: "Convolution" bottom: "pool2" top: "conv3" 79 | param { lr_mult: 1 decay_mult: 1 } 80 | param { lr_mult: 2 decay_mult: 0 } 81 | convolution_param { num_output: 384 pad: 1 kernel_size: 3 } 82 | } 83 | layer { 84 | name: "relu3" type: "ReLU" bottom: "conv3" top: "conv3" 85 | } 86 | layer { 87 | name: "conv4" type: "Convolution" bottom: "conv3" top: "conv4" 88 | param { lr_mult: 1 decay_mult: 1 } 89 | param { lr_mult: 2 decay_mult: 0 } 90 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 91 | } 92 | layer { 93 | name: "relu4" type: "ReLU" bottom: "conv4" top: "conv4" 94 | } 95 | layer { 96 | name: "conv5" type: "Convolution" bottom: "conv4" top: "conv5" 97 | param { lr_mult: 1 decay_mult: 1 } 98 | param { lr_mult: 2 decay_mult: 0 } 99 | convolution_param { num_output: 256 pad: 1 kernel_size: 3 } 100 | } 101 | layer { 102 | name: "relu5" type: "ReLU" bottom: "conv5" top: "conv5" 103 | } 104 | layer { 105 | name: "pool5" type: "Pooling" bottom: "conv5" top: "pool5" 106 | pooling_param { pool: MAX kernel_size: 3 stride: 2 } 107 | } 108 | layer { 109 | name: "fc6" type: "InnerProduct" bottom: "pool5" top: "fc6" 110 | param { lr_mult: 1 decay_mult: 1 } 111 | param { lr_mult: 2 decay_mult: 0 } 112 | inner_product_param { num_output: 4096 } 113 | } 114 | layer { 115 | name: "relu6" type: "ReLU" bottom: "fc6" top: "fc6" 116 | } 117 | layer { 118 | name: "fc7" type: "InnerProduct" bottom: "fc6" top: "fc7" 119 | param { lr_mult: 1 decay_mult: 1 } 
120 | param { lr_mult: 2 decay_mult: 0 } 121 | inner_product_param { num_output: 4096 } 122 | } 123 | layer { 124 | name: "relu7" type: "ReLU" bottom: "fc7" top: "fc7" 125 | } 126 | layer { 127 | name: "fc8_" type: "InnerProduct" bottom: "fc7" top: "fc8_" 128 | param { lr_mult: 1 decay_mult: 1 } 129 | param { lr_mult: 2 decay_mult: 0 } 130 | inner_product_param { 131 | num_output: 1000 132 | weight_filler { type: "gaussian" std: 0.01} 133 | bias_filler { type: "constant" value: 1 } 134 | } 135 | } 136 | layer { 137 | name: "Softmax" 138 | type: "SoftmaxWithLoss" 139 | bottom: "fc8_" 140 | bottom: "label" 141 | top: "loss" 142 | loss_weight: 1 143 | } 144 | -------------------------------------------------------------------------------- /extras/AlexNet_without_BN.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from pdb import set_trace as breakpoint 7 | 8 | class Flatten(nn.Module): 9 | def __init__(self): 10 | super(Flatten, self).__init__() 11 | 12 | def forward(self, feat): 13 | return feat.view(feat.size(0), -1) 14 | 15 | class AlexNet(nn.Module): 16 | def __init__(self, opt): 17 | super(AlexNet, self).__init__() 18 | num_classes = opt['num_classes'] 19 | 20 | conv1 = nn.Sequential( 21 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 22 | nn.ReLU(inplace=True), 23 | ) 24 | pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 25 | conv2 = nn.Sequential( 26 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 27 | nn.ReLU(inplace=True), 28 | ) 29 | pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 30 | conv3 = nn.Sequential( 31 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 32 | nn.ReLU(inplace=True), 33 | ) 34 | conv4 = nn.Sequential( 35 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | ) 38 | conv5 = nn.Sequential( 39 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 40 | nn.ReLU(inplace=True), 41 | 
) 42 | pool5 = nn.MaxPool2d(kernel_size=3, stride=2) 43 | 44 | num_pool5_feats = 6 * 6 * 256 45 | fc_block = nn.Sequential( 46 | Flatten(), 47 | nn.Linear(num_pool5_feats, 4096), 48 | nn.ReLU(inplace=True), 49 | nn.Linear(4096, 4096), 50 | nn.ReLU(inplace=True), 51 | ) 52 | classifier = nn.Sequential( 53 | nn.Linear(4096, num_classes), 54 | ) 55 | 56 | self._feature_blocks = nn.ModuleList([ 57 | conv1, 58 | pool1, 59 | conv2, 60 | pool2, 61 | conv3, 62 | conv4, 63 | conv5, 64 | pool5, 65 | fc_block, 66 | classifier, 67 | ]) 68 | self.all_feat_names = [ 69 | 'conv1', 70 | 'pool1', 71 | 'conv2', 72 | 'pool2', 73 | 'conv3', 74 | 'conv4', 75 | 'conv5', 76 | 'pool5', 77 | 'fc_block', 78 | 'classifier', 79 | ] 80 | assert(len(self.all_feat_names) == len(self._feature_blocks)) 81 | 82 | def _parse_out_keys_arg(self, out_feat_keys): 83 | 84 | # By default return the features of the last layer / module. 85 | out_feat_keys = [self.all_feat_names[-1],] if out_feat_keys is None else out_feat_keys 86 | 87 | if len(out_feat_keys) == 0: 88 | raise ValueError('Empty list of output feature keys.') 89 | for f, key in enumerate(out_feat_keys): 90 | if key not in self.all_feat_names: 91 | raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names)) 92 | elif key in out_feat_keys[:f]: 93 | raise ValueError('Duplicate output feature key: {0}.'.format(key)) 94 | 95 | # Find the highest output feature in `out_feat_keys` 96 | max_out_feat = max([self.all_feat_names.index(key) for key in out_feat_keys]) 97 | 98 | return out_feat_keys, max_out_feat 99 | 100 | def forward(self, x, out_feat_keys=None): 101 | """Forward an image `x` through the network and return the requested output features. 102 | 103 | Args: 104 | x: input image. 105 | out_feat_keys: a list/tuple with the feature names of the features 106 | that the function should return. By default the last feature of 107 | the network is returned.
108 | 109 | Return: 110 | out_feats: If multiple output features were requested then `out_feats` 111 | is a list with the requested output features placed in the same 112 | order as in `out_feat_keys`. If a single output feature was 113 | requested then `out_feats` is that output feature (and not a list). 114 | """ 115 | out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys) 116 | out_feats = [None] * len(out_feat_keys) 117 | 118 | feat = x 119 | for f in range(max_out_feat+1): 120 | feat = self._feature_blocks[f](feat) 121 | key = self.all_feat_names[f] 122 | if key in out_feat_keys: 123 | out_feats[out_feat_keys.index(key)] = feat 124 | 125 | out_feats = out_feats[0] if len(out_feats)==1 else out_feats 126 | return out_feats 127 | 128 | def get_L1filters(self): 129 | # This variant has no batch norm layer (index 1 of the first block is 130 | # a ReLU), so the first-layer filters are the raw convolution weights. 131 | convlayer = self._feature_blocks[0][0] 132 | filters = convlayer.weight.data 133 | filters = filters.cpu().clone() 134 | 135 | return filters 136 | 137 | def create_model(opt): 138 | return AlexNet(opt) 139 | 140 | if __name__ == '__main__': 141 | size = 224 142 | opt = {'num_classes':4} 143 | 144 | net = create_model(opt) 145 | x = torch.autograd.Variable(torch.FloatTensor(1,3,size,size).uniform_(-1,1)) 146 | 147 | out = net(x, out_feat_keys=net.all_feat_names) 148 | for f in range(len(out)): 149 | print('Output feature {0} - size {1}'.format( 150 | net.all_feat_names[f], out[f].size())) 151 | 152 | filters = net.get_L1filters() 153 | 154 | print('First layer filter shape: {0}'.format(filters.size())) 155 | -------------------------------------------------------------------------------- /extras/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gidariss/FeatureLearningRotNet/edafaba6561003292017d0a44272414356c959ff/extras/cat.jpg
-------------------------------------------------------------------------------- /extras/convert_alexnet_from_pytorch2caffe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | 5 | import os 6 | import imp 7 | import argparse 8 | import numpy as np 9 | from PIL import Image 10 | import caffe 11 | 12 | def pytorch_init_network(net_def_file, opt): 13 | return imp.load_source("",net_def_file).create_model(opt) 14 | 15 | def pytorch_load_pretrained(network, pretrained_path): 16 | print('==> Load pretrained parameters from file %s:' % (pretrained_path)) 17 | assert(os.path.isfile(pretrained_path)) 18 | pretrained_model = torch.load(pretrained_path) 19 | network.load_state_dict(pretrained_model['network']) 20 | 21 | def merge_linear_with_bn(linearlayer, batchnorm): 22 | filters = linearlayer.weight.data 23 | bias = linearlayer.bias.data if (linearlayer.bias is not None) else 0 24 | 25 | epsilon = batchnorm.eps 26 | gamma = batchnorm.weight.data 27 | beta = batchnorm.bias.data 28 | mu = batchnorm.running_mean 29 | var = batchnorm.running_var 30 | 31 | scale_f = gamma / torch.sqrt(var+epsilon) 32 | print('==> Merge linear filters of size {} with batch norm layer'.format(filters.size())) 33 | if filters.dim() == 4: 34 | mergedfilters = filters.clone() * scale_f.view(-1, 1, 1, 1).expand_as(filters) 35 | else: 36 | assert(filters.dim()==2) 37 | mergedfilters = filters.clone() * scale_f.view(-1, 1).expand_as(filters) 38 | 39 | mergedbias = (bias-mu) * scale_f + beta 40 | 41 | return mergedfilters, mergedbias 42 | 43 | def copy_weights_to_linear_layer(linearlayer, weights, bias): 44 | linearlayer.weight.data.copy_(weights) 45 | linearlayer.bias.data.copy_(bias) 46 | 47 | def copy_params_to_the_no_bn_net(net_with_bn, net_without_bn): 48 | net_with_bn_layers_dict = dict(zip(net_with_bn.all_feat_names, 49 | net_with_bn._feature_blocks)) 50 | 
net_without_bn_layers_dict = dict(zip(net_without_bn.all_feat_names, 51 | net_without_bn._feature_blocks)) 52 | 53 | # 1st conv. block 54 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['conv1'][0], 55 | net_with_bn_layers_dict['conv1'][1]) 56 | copy_weights_to_linear_layer(net_without_bn_layers_dict['conv1'][0], 57 | mfilters, mbias) 58 | # 2nd conv. block 59 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['conv2'][0], 60 | net_with_bn_layers_dict['conv2'][1]) 61 | copy_weights_to_linear_layer(net_without_bn_layers_dict['conv2'][0], 62 | mfilters, mbias) 63 | 64 | # 3rd conv. block 65 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['conv3'][0], 66 | net_with_bn_layers_dict['conv3'][1]) 67 | copy_weights_to_linear_layer(net_without_bn_layers_dict['conv3'][0], 68 | mfilters, mbias) 69 | 70 | # 4th conv. block 71 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['conv4'][0], 72 | net_with_bn_layers_dict['conv4'][1]) 73 | copy_weights_to_linear_layer(net_without_bn_layers_dict['conv4'][0], 74 | mfilters, mbias) 75 | 76 | # 5th conv. 
block 77 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['conv5'][0], 78 | net_with_bn_layers_dict['conv5'][1]) 79 | copy_weights_to_linear_layer(net_without_bn_layers_dict['conv5'][0], 80 | mfilters, mbias) 81 | 82 | # hidden linear layers 83 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['fc_block'][1], 84 | net_with_bn_layers_dict['fc_block'][2]) 85 | copy_weights_to_linear_layer(net_without_bn_layers_dict['fc_block'][1], 86 | mfilters, mbias) 87 | mfilters, mbias = merge_linear_with_bn(net_with_bn_layers_dict['fc_block'][4], 88 | net_with_bn_layers_dict['fc_block'][5]) 89 | copy_weights_to_linear_layer(net_without_bn_layers_dict['fc_block'][3], 90 | mfilters, mbias) 91 | 92 | # classifier 93 | mfilters = net_with_bn_layers_dict['classifier'][0].weight.data 94 | mbias = net_with_bn_layers_dict['classifier'][0].bias.data 95 | copy_weights_to_linear_layer(net_without_bn_layers_dict['classifier'][0], 96 | mfilters, mbias) 97 | 98 | 99 | def printNetParamSizes(net_caffe): 100 | for param_name in net_caffe.params.keys(): 101 | print("Parameter {0}".format(param_name)) 102 | for idx, param in enumerate(net_caffe.params[param_name]): 103 | print("\t Shape[{0}] : {1}".format(idx, param.data.shape)) 104 | 105 | def print_data_stats(adata): 106 | adata = np.abs(adata) 107 | print("Abs: [mu: {0}, std: {1}, min: {2} max: {3}]".format( 108 | adata.mean(), adata.std(), adata.min(), adata.max() )) 109 | 110 | 111 | def copy_params_pytorch2caffe(net_pytorch, net_caffe, std_pixel_values): 112 | 113 | net_pytorch_layers_dict = dict(zip(net_pytorch.all_feat_names, 114 | net_pytorch._feature_blocks)) 115 | 116 | # 1st conv. 
block 117 | filters = net_pytorch_layers_dict['conv1'][0].weight.data.cpu().numpy() 118 | bias = net_pytorch_layers_dict['conv1'][0].bias.data.cpu().numpy() 119 | print("Layer {0} copy params with shape: {1} and {2}".format( 120 | 'conv1', filters.shape, bias.shape)) 121 | net_caffe.params['conv1'][0].data[:,:,:,:] = filters 122 | net_caffe.params['conv1'][1].data[:] = bias 123 | 124 | assert(net_caffe.params['conv1'][0].data.shape[1] == 3) 125 | net_caffe.params['conv1'][0].data[:,0,:,:] /= (std_pixel_values[0]*255.0) 126 | net_caffe.params['conv1'][0].data[:,1,:,:] /= (std_pixel_values[1]*255.0) 127 | net_caffe.params['conv1'][0].data[:,2,:,:] /= (std_pixel_values[2]*255.0) 128 | 129 | filters = net_caffe.params['conv1'][0].data.copy() 130 | net_caffe.params['conv1'][0].data[:,:,:,:] = filters[:,::-1,:,:] 131 | 132 | print("Conv 1 weight:") 133 | print_data_stats(net_caffe.params['conv1'][0].data) 134 | print("Conv 1 bias:") 135 | print_data_stats(net_caffe.params['conv1'][1].data) 136 | 137 | #print(net_caffe.params['conv1'][0].data.mean()) 138 | 139 | # 2nd conv. block 140 | filters = net_pytorch_layers_dict['conv2'][0].weight.data.cpu().numpy() 141 | bias = net_pytorch_layers_dict['conv2'][0].bias.data.cpu().numpy() 142 | print("Layer {0} copy params with shape: {1} and {2}".format( 143 | 'conv2', filters.shape, bias.shape)) 144 | net_caffe.params['conv2'][0].data[:,:,:,:] = filters 145 | net_caffe.params['conv2'][1].data[:] = bias 146 | 147 | print("Conv 2 weight:") 148 | print_data_stats(net_caffe.params['conv2'][0].data) 149 | print("Conv 2 bias:") 150 | print_data_stats(net_caffe.params['conv2'][1].data) 151 | 152 | # 3rd conv. 
block 153 | filters = net_pytorch_layers_dict['conv3'][0].weight.data.cpu().numpy() 154 | bias = net_pytorch_layers_dict['conv3'][0].bias.data.cpu().numpy() 155 | print("Layer {0} copy params with shape: {1} and {2}".format( 156 | 'conv3', filters.shape, bias.shape)) 157 | net_caffe.params['conv3'][0].data[:,:,:,:] = filters 158 | net_caffe.params['conv3'][1].data[:] = bias 159 | 160 | print("Conv 3 weight:") 161 | print_data_stats(net_caffe.params['conv3'][0].data) 162 | print("Conv 3 bias:") 163 | print_data_stats(net_caffe.params['conv3'][1].data) 164 | 165 | # 4th conv. block 166 | filters = net_pytorch_layers_dict['conv4'][0].weight.data.cpu().numpy() 167 | bias = net_pytorch_layers_dict['conv4'][0].bias.data.cpu().numpy() 168 | print("Layer {0} copy params with shape: {1} and {2}".format( 169 | 'conv4', filters.shape, bias.shape)) 170 | net_caffe.params['conv4'][0].data[:,:,:,:] = filters 171 | net_caffe.params['conv4'][1].data[:] = bias 172 | 173 | print("Conv 4 weight:") 174 | print_data_stats(net_caffe.params['conv4'][0].data) 175 | print("Conv 4 bias:") 176 | print_data_stats(net_caffe.params['conv4'][1].data) 177 | 178 | # 5th conv. 
block 179 | filters = net_pytorch_layers_dict['conv5'][0].weight.data.cpu().numpy() 180 | bias = net_pytorch_layers_dict['conv5'][0].bias.data.cpu().numpy() 181 | print("Layer {0} copy params with shape: {1} and {2}".format( 182 | 'conv5', filters.shape, bias.shape)) 183 | net_caffe.params['conv5'][0].data[:,:,:,:] = filters 184 | net_caffe.params['conv5'][1].data[:] = bias 185 | 186 | print("Conv 5 weight:") 187 | print_data_stats(net_caffe.params['conv5'][0].data) 188 | print("Conv 5 bias:") 189 | print_data_stats(net_caffe.params['conv5'][1].data) 190 | 191 | # fully connected: fc6 192 | filters = net_pytorch_layers_dict['fc_block'][1].weight.data.cpu().numpy() 193 | bias = net_pytorch_layers_dict['fc_block'][1].bias.data.cpu().numpy() 194 | print("Layer {0} copy params with shape: {1} and {2}".format( 195 | 'fc6', filters.shape, bias.shape)) 196 | net_caffe.params['fc6'][0].data[:,:] = filters 197 | net_caffe.params['fc6'][1].data[:] = bias 198 | 199 | # fully connected: fc7 200 | filters = net_pytorch_layers_dict['fc_block'][3].weight.data.cpu().numpy() 201 | bias = net_pytorch_layers_dict['fc_block'][3].bias.data.cpu().numpy() 202 | print("Layer {0} copy params with shape: {1} and {2}".format( 203 | 'fc7', filters.shape, bias.shape)) 204 | net_caffe.params['fc7'][0].data[:,:] = filters 205 | net_caffe.params['fc7'][1].data[:] = bias 206 | 207 | 208 | class normalize_numpy(object): 209 | def __init__(self, mean, std): 210 | self.mean = mean 211 | self.std = std 212 | 213 | def __call__(self, array): 214 | assert(len(array.shape)==3) 215 | assert(array.shape[0] == len(self.mean)) 216 | assert(array.shape[0] == len(self.std)) 217 | for c in range(array.shape[0]): 218 | array[c,:,:] -= self.mean[c] 219 | array[c,:,:] /= self.std[c] 220 | return array 221 | 222 | 223 | def prind_data_diffs(data1, data2): 224 | print("Data 1 shape {0} Data 2 shape {1}".format(data1.shape, data2.shape)) 225 | abs_diff = np.abs(data1 - data2) 226 | max_diff = abs_diff.max() 227 | 
mu_diff = abs_diff.mean() 228 | print("Max diff {0} Mu diff {1} Max1 {2} Max2 {3} Mu1 {4} Mu2 {5}".format( 229 | max_diff, mu_diff, data1.max(), data2.max(), data1.mean(), data2.mean())) 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument('--src', type=str, required=True, default='', help='Source location of the Alexnet model saved in pytorch format') 234 | parser.add_argument('--dst', type=str, required=True, default='', help='Destination location where the Alexnet model will be saved in caffe format') 235 | args_opt = parser.parse_args() 236 | 237 | pytorch_alexnet_def = './architectures/AlexNet.py' 238 | src_pytorch_alexnet_params = args_opt.src 239 | pytorch_alexnet_without_bn_def = './extras/AlexNet_without_BN.py' 240 | dst_caffe_alexnet_params = args_opt.dst 241 | 242 | net_opt = {'num_classes': 4} 243 | net_with_bn = pytorch_init_network(pytorch_alexnet_def, net_opt) 244 | net_without_bn = pytorch_init_network(pytorch_alexnet_without_bn_def, net_opt) 245 | pytorch_load_pretrained(net_with_bn, src_pytorch_alexnet_params) 246 | net_with_bn.eval() 247 | net_without_bn.eval() 248 | 249 | # Merge the batch norm layers to the linear layers 250 | copy_params_to_the_no_bn_net(net_with_bn, net_without_bn) 251 | 252 | 253 | caffe.set_mode_cpu() 254 | caffe_model_def = './extras/AlexNet.prototxt' 255 | net_caffe = caffe.Net(caffe_model_def, caffe.TEST) 256 | 257 | mean_pix = [0.485, 0.456, 0.406] 258 | std_pix = [0.229, 0.224, 0.225] 259 | mean_pix_caffe = [0.406*255.0, 0.456*255.0, 0.485*255.0] 260 | 261 | # Copy weights to the caffe model 262 | net_pytorch = net_without_bn 263 | copy_params_pytorch2caffe(net_pytorch, net_caffe, std_pix) 264 | 265 | transform = [] 266 | transform.append(transforms.Scale(256)) 267 | transform.append(transforms.CenterCrop(224)) 268 | transform.append(lambda x: np.asarray(x)) 269 | transform = transforms.Compose(transform) 270 | 271 | torch_transform = [transforms.ToTensor(), 
transforms.Normalize(mean=mean_pix, std=std_pix)] 272 | torch_transform = transforms.Compose([ 273 | transforms.ToTensor(), 274 | transforms.Normalize(mean=mean_pix, std=std_pix), 275 | lambda x: x.view([1,] + list(x.size())), 276 | ]) 277 | 278 | caffe_transform = transforms.Compose([ 279 | lambda x: x.transpose(2,0,1), 280 | lambda x: x.astype(np.float64), 281 | lambda x: x[::-1,:,:], 282 | normalize_numpy(mean=mean_pix_caffe, std=[1.0, 1.0, 1.0]), 283 | lambda x: x.reshape((1,)+x.shape), 284 | ]) 285 | 286 | image_file = './extras/cat.jpg' 287 | img = transform(Image.open(image_file)) 288 | img_torch = torch_transform(img) 289 | img_caffe = caffe_transform(img) 290 | 291 | net_caffe.blobs['data'].data[...] = img_caffe 292 | out_caffe = net_caffe.forward() 293 | 294 | img_torch_var = torch.autograd.Variable(img_torch, volatile=True) 295 | out_pytorch = net_pytorch(img_torch_var, ['fc_block',]) 296 | 297 | data_caffe = out_caffe['fc7'].copy() 298 | data_pytorch = out_pytorch.data.cpu().numpy() 299 | abs_diff = np.abs(data_pytorch - data_caffe) 300 | max_diff = abs_diff.max() 301 | 302 | print('==> Maximum data elements difference between torch and caffe: {}'.format(max_diff)) 303 | 304 | if not os.path.isfile(dst_caffe_alexnet_params): 305 | print('==> Saving caffe alexnet model at {0}'.format(dst_caffe_alexnet_params)) 306 | net_caffe.save(dst_caffe_alexnet_params) 307 | else: 308 | print('==> File {0} already exists'.format(dst_caffe_alexnet_params)) 309 | 310 | -------------------------------------------------------------------------------- /extras/convert_caffe_alexnet_to_fcn.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import argparse 3 | import numpy as np 4 | from PIL import Image 5 | import math 6 | import caffe 7 | import math 8 | import os 9 | 10 | def printNetParamSizes(net): 11 | for param_name in net.params.keys(): 12 | print("Parameter {0}".format(param_name)) 13 | for idx, param in
enumerate(net.params[param_name]): 14 | print("\t Shape[{0}] : {1}".format(idx, param.data.shape)) 15 | 16 | def print_data_stats(adata): 17 | adata = np.abs(adata) 18 | print("Abs: [mu: {0}, std: {1}, min: {2} max: {3}]".format(\ 19 | adata.mean(), adata.std(), adata.min(), adata.max() )) 20 | 21 | def copy_params_fc2fcn(net_fc, net_fcn): 22 | # 1st conv. block 23 | filters = net_fc.params['conv1'][0].data 24 | bias = net_fc.params['conv1'][1].data 25 | print("Layer {0} copy params with shape: {1} and {2}".format('conv1', filters.shape, bias.shape)) 26 | net_fcn.params['conv1'][0].data[:,:,:,:] = filters 27 | net_fcn.params['conv1'][1].data[:] = bias 28 | print("Conv 1 weight:") 29 | print_data_stats(net_fcn.params['conv1'][0].data) 30 | print("Conv 1 bias:") 31 | print_data_stats(net_fcn.params['conv1'][1].data) 32 | 33 | # 2nd conv. block 34 | filters = net_fc.params['conv2'][0].data 35 | bias = net_fc.params['conv2'][1].data 36 | print("Layer {0} copy params with shape: {1} and {2}".format('conv2', filters.shape, bias.shape)) 37 | net_fcn.params['conv2'][0].data[:,:,:,:] = filters 38 | net_fcn.params['conv2'][1].data[:] = bias 39 | print("Conv 2 weight:") 40 | print_data_stats(net_fcn.params['conv2'][0].data) 41 | print("Conv 2 bias:") 42 | print_data_stats(net_fcn.params['conv2'][1].data) 43 | 44 | # 3rd conv. block 45 | filters = net_fc.params['conv3'][0].data 46 | bias = net_fc.params['conv3'][1].data 47 | print("Layer {0} copy params with shape: {1} and {2}".format('conv3', filters.shape, bias.shape)) 48 | net_fcn.params['conv3'][0].data[:,:,:,:] = filters 49 | net_fcn.params['conv3'][1].data[:] = bias 50 | print("Conv 3 weight:") 51 | print_data_stats(net_fcn.params['conv3'][0].data) 52 | print("Conv 3 bias:") 53 | print_data_stats(net_fcn.params['conv3'][1].data) 54 | 55 | # 4th conv. 
block 56 | filters = net_fc.params['conv4'][0].data 57 | bias = net_fc.params['conv4'][1].data 58 | print("Layer {0} copy params with shape: {1} and {2}".format('conv4', filters.shape, bias.shape)) 59 | net_fcn.params['conv4'][0].data[:,:,:,:] = filters 60 | net_fcn.params['conv4'][1].data[:] = bias 61 | print("Conv 4 weight:") 62 | print_data_stats(net_fcn.params['conv4'][0].data) 63 | print("Conv 4 bias:") 64 | print_data_stats(net_fcn.params['conv4'][1].data) 65 | 66 | # 5th conv. block 67 | filters = net_fc.params['conv5'][0].data 68 | bias = net_fc.params['conv5'][1].data 69 | print("Layer {0} copy params with shape: {1} and {2}".format('conv5', filters.shape, bias.shape)) 70 | net_fcn.params['conv5'][0].data[:,:,:,:] = filters 71 | net_fcn.params['conv5'][1].data[:] = bias 72 | print("Conv 5 weight:") 73 | print_data_stats(net_fcn.params['conv5'][0].data) 74 | print("Conv 5 bias:") 75 | print_data_stats(net_fcn.params['conv5'][1].data) 76 | 77 | # 6th conv. block 78 | filters = net_fc.params['fc6'][0].data 79 | bias = net_fc.params['fc6'][1].data 80 | print("Layer {0} copy params with shape: {1} and {2}".format('fc6', filters.shape, bias.shape)) 81 | fcn_shape = net_fcn.params['fc6'][0].data.shape 82 | print(fcn_shape) 83 | net_fcn.params['fc6'][0].data[:,:,:,:] = filters.reshape(fcn_shape) 84 | net_fcn.params['fc6'][1].data[:] = bias 85 | print("fc 6 weight:") 86 | print_data_stats(net_fcn.params['fc6'][0].data) 87 | print("fc 6 bias:") 88 | print_data_stats(net_fcn.params['fc6'][1].data) 89 | 90 | # 7th conv. 
block 91 | filters = net_fc.params['fc7'][0].data 92 | bias = net_fc.params['fc7'][1].data 93 | print("Layer {0} copy params with shape: {1} and {2}".format('fc7', filters.shape, bias.shape)) 94 | fcn_shape = net_fcn.params['fc7'][0].data.shape 95 | print(fcn_shape) 96 | net_fcn.params['fc7'][0].data[:,:,:,:] = filters.reshape(fcn_shape) 97 | net_fcn.params['fc7'][1].data[:] = bias 98 | print("fc 7 weight:") 99 | print_data_stats(net_fcn.params['fc7'][0].data) 100 | print("fc 7 bias:") 101 | print_data_stats(net_fcn.params['fc7'][1].data) 102 | 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('--src', type=str, required=True, default='', help='Source location of the Alexnet model params saved in caffe format') 108 | parser.add_argument('--dst', type=str, required=True, default='', help='Destination location where the Alexnet-fcn model params will be saved in caffe format') 109 | args_opt = parser.parse_args() 110 | 111 | caffe_alexnet_model_def = './extras/AlexNet.prototxt' 112 | caffe_alexnet_fcn_model_def = './extras/AlexNet_fcn.prototxt' 113 | 114 | src_caffe_alexnet_params = args_opt.src 115 | dst_caffe_alexnet_fcn_params = args_opt.dst 116 | 117 | caffe.set_mode_cpu() 118 | caffe_model_def_fcn = './caffe_models/AlexNet_NoGroups_NoLRN_FCN.prototxt' 119 | net_caffe_fcn = caffe.Net(caffe_alexnet_fcn_model_def, caffe.TEST) 120 | net_caffe = caffe.Net(caffe_alexnet_model_def, src_caffe_alexnet_params, caffe.TEST) 121 | 122 | copy_params_fc2fcn(net_caffe, net_caffe_fcn) 123 | 124 | if not os.path.isfile(dst_caffe_alexnet_fcn_params): 125 | print('==> Saving caffe alexnet-fcn model at {0}'.format(dst_caffe_alexnet_fcn_params)) 126 | net_caffe_fcn.save(dst_caffe_alexnet_fcn_params) 127 | else: 128 | print('==> File {0} already exists'.format(dst_caffe_alexnet_fcn_params)) 129 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import imp 5 | import algorithms as alg 6 | from dataloader import DataLoader, GenericDataset 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--exp', type=str, required=True, default='', help='config file with parameters of the experiment') 10 | parser.add_argument('--evaluate', default=False, action='store_true') 11 | parser.add_argument('--checkpoint', type=int, default=0, help='checkpoint (epoch id) that will be loaded') 12 | parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers') 13 | parser.add_argument('--cuda', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=True, help='enables cuda (pass false or 0 to disable)') 14 | parser.add_argument('--disp_step', type=int, default=50, help='display step during training') 15 | args_opt = parser.parse_args() 16 | 17 | exp_config_file = os.path.join('.','config',args_opt.exp+'.py') 18 | # if args_opt.semi == -1: 19 | exp_directory = os.path.join('.','experiments',args_opt.exp) 20 | # else: 21 | # assert(args_opt.semi>0) 22 | # exp_directory = os.path.join('.','experiments/unsupervised',args_opt.exp+'_semi'+str(args_opt.semi)) 23 | 24 | # Load the configuration params of the experiment 25 | print('Launching experiment: %s' % exp_config_file) 26 | config = imp.load_source("",exp_config_file).config 27 | config['exp_dir'] = exp_directory # the place where logs, models, and other stuff will be stored 28 | print("Loading experiment %s from file: %s" % (args_opt.exp, exp_config_file)) 29 | print("Generated logs, snapshots, and model files will be stored in %s" % (config['exp_dir'])) 30 | 31 | # Set train and test datasets and the corresponding data loaders 32 | data_train_opt = config['data_train_opt'] 33 | data_test_opt = config['data_test_opt'] 34 | num_imgs_per_cat = data_train_opt['num_imgs_per_cat'] if ('num_imgs_per_cat' in data_train_opt) else None 35 | 36 | 
37 | 38 | dataset_train = GenericDataset( 39 | dataset_name=data_train_opt['dataset_name'], 40 | split=data_train_opt['split'], 41 | random_sized_crop=data_train_opt['random_sized_crop'], 42 | num_imgs_per_cat=num_imgs_per_cat) 43 | dataset_test = GenericDataset( 44 | dataset_name=data_test_opt['dataset_name'], 45 | split=data_test_opt['split'], 46 | random_sized_crop=data_test_opt['random_sized_crop']) 47 | 48 | dloader_train = DataLoader( 49 | dataset=dataset_train, 50 | batch_size=data_train_opt['batch_size'], 51 | unsupervised=data_train_opt['unsupervised'], 52 | epoch_size=data_train_opt['epoch_size'], 53 | num_workers=args_opt.num_workers, 54 | shuffle=True) 55 | 56 | dloader_test = DataLoader( 57 | dataset=dataset_test, 58 | batch_size=data_test_opt['batch_size'], 59 | unsupervised=data_test_opt['unsupervised'], 60 | epoch_size=data_test_opt['epoch_size'], 61 | num_workers=args_opt.num_workers, 62 | shuffle=False) 63 | 64 | config['disp_step'] = args_opt.disp_step 65 | algorithm = getattr(alg, config['algorithm_type'])(config) 66 | if args_opt.cuda: # enable cuda 67 | algorithm.load_to_gpu() 68 | if args_opt.checkpoint > 0: # load checkpoint 69 | algorithm.load_checkpoint(args_opt.checkpoint, train= (not args_opt.evaluate)) 70 | 71 | if not args_opt.evaluate: # train the algorithm 72 | algorithm.solve(dloader_train, dloader_test) 73 | else: 74 | algorithm.evaluate(dloader_test) # evaluate the algorithm 75 | -------------------------------------------------------------------------------- /run_cifar10_based_unsupervised_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train a RotNet (with a NIN architecture of 4 conv. blocks) on training images of CIFAR10. 4 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_RotNet_NIN4blocks 5 | 6 | # Train & evaluate an object classifier (for the CIFAR10 task) with convolutional 7 | # layers on the feature maps of the 2nd conv. 
block of the RotNet model trained above. 8 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats 9 | 10 | # Train & evaluate an object classifier (for the CIFAR10 task) with 3 fully connected layers 11 | # on the feature maps of the 2nd conv. block of the RotNet model trained above. 12 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_MultLayerClassifier_on_RotNet_NIN4blocks_Conv2_feats 13 | -------------------------------------------------------------------------------- /run_cifar10_semi_supervised_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Run semi supervised experiments" 3 | 4 | # Train a conv-based classifier on top of the feature maps of the 2nd conv. block of a NIN-based RotNet model 5 | # trained on the entire training set of CIFAR10. 6 | 7 | # Use K=5000 training examples per category (which is equal to using the entire training set). 8 | # CUDA_VISIBLE_DEVICES=2 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats 9 | # Use K=1000 training examples per category. 10 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K1000 11 | # Use K=400 training examples per category. 12 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K400 13 | # Use K=100 training examples per category. 14 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K100 15 | # Use K=20 training examples per category. 16 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_ConvClassifier_on_RotNet_NIN4blocks_Conv2_feats_K20 17 | 18 | # Train fully supervised NIN models using subsets of the CIFAR10 training set. 19 | # Use K=5000 training examples per category (which is equal to using the entire training set). 
20 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_supervised_NIN 21 | # Use K=1000 training examples per category. 22 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_supervised_NIN_K1000 23 | # Use K=400 training examples per category. 24 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_supervised_NIN_K400 25 | # Use K=100 training examples per category. 26 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_supervised_NIN_K100 27 | # Use K=20 training examples per category. 28 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=CIFAR10_supervised_NIN_K20 29 | -------------------------------------------------------------------------------- /run_imagenet_based_unsupervised_feature_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train a RotNet model (with AlexNet architecture) on the training images of ImageNet. 4 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=ImageNet_RotNet_AlexNet 5 | 6 | # Train & evaluate non-linear object classifiers for the ImageNet task on the feature maps 7 | # generated by the conv4 and conv5 convolutional layers of the above trained AlexNet-based RotNet model. 8 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=ImageNet_NonLinearClassifiers_ImageNet_RotNet_AlexNet_Features 9 | 10 | # Train & evaluate linear object classifiers for the ImageNet task on the feature maps 11 | # generated by the convolutional layers (i.e., conv1, conv2, conv3, conv4, and conv5) of 12 | # the above trained AlexNet-based RotNet model. 13 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=ImageNet_LinearClassifiers_ImageNet_RotNet_AlexNet_Features 14 | 15 | # Train & evaluate linear object classifiers for the Places205 task on the feature maps 16 | # generated by the convolutional layers (i.e., conv1, conv2, conv3, conv4, and conv5) of 17 | # the above trained AlexNet-based RotNet model. 
18 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=Places205_LinearClassifiers_ImageNet_RotNet_AlexNet_Features 19 | 20 | -------------------------------------------------------------------------------- /run_places205_based_unsupervised_feature_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Train a RotNet model (with AlexNet architecture) on the training images of Places205. 4 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=Places205_RotNet_AlexNet 5 | 6 | # Train & evaluate linear object classifiers for the ImageNet task on the feature maps 7 | # generated by the convolutional layers (i.e., conv1, conv2, conv3, conv4, and conv5) of 8 | # the above trained AlexNet-based RotNet model. 9 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=ImageNet_LinearClassifiers_Places205_RotNet_AlexNet_Features 10 | 11 | # Train & evaluate linear object classifiers for the Places205 task on the feature maps 12 | # generated by the convolutional layers (i.e., conv1, conv2, conv3, conv4, and conv5) of 13 | # the above trained AlexNet-based RotNet model. 
14 | CUDA_VISIBLE_DEVICES=0 python main.py --exp=Places205_LinearClassifiers_Places205_RotNet_AlexNet_Features 15 | 16 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | import imp 8 | from tqdm import tqdm 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.optim 13 | import torchnet as tnt 14 | 15 | 16 | import numbers 17 | 18 | class FastConfusionMeter(object): 19 | def __init__(self, k, normalized = False): 20 | #super(FastConfusionMeter, self).__init__() 21 | self.conf = np.ndarray((k,k), dtype=np.int32) 22 | self.normalized = normalized 23 | self.reset() 24 | 25 | def reset(self): 26 | self.conf.fill(0) 27 | 28 | def add(self, output, target): 29 | output = output.cpu().squeeze().numpy() 30 | target = target.cpu().squeeze().numpy() 31 | 32 | if np.ndim(output) == 1: 33 | output = output[None] 34 | 35 | onehot = np.ndim(target) != 1 36 | assert output.shape[0] == target.shape[0], \ 37 | 'number of targets and outputs do not match' 38 | assert output.shape[1] == self.conf.shape[0], \ 39 | 'number of outputs does not match size of confusion matrix' 40 | assert not onehot or target.shape[1] == output.shape[1], \ 41 | 'target should be 1D Tensor or have size of output (one-hot)' 42 | if onehot: 43 | assert (target >= 0).all() and (target <= 1).all(), \ 44 | 'in one-hot encoding, target values should be 0 or 1' 45 | assert (target.sum(1) == 1).all(), \ 46 | 'multi-label setting is not supported' 47 | 48 | target = target.argmax(1) if onehot else target 49 | pred = output.argmax(1) 50 | 51 | target = target.astype(np.int32) 52 | pred = pred.astype(np.int32) 53 | conf_this = np.bincount(target * self.conf.shape[0] + pred,minlength=np.prod(self.conf.shape)) 54 | 
conf_this = conf_this.astype(self.conf.dtype).reshape(self.conf.shape) 55 | self.conf += conf_this 56 | 57 | def value(self): 58 | if self.normalized: 59 | conf = self.conf.astype(np.float32) 60 | return conf / conf.sum(1).clip(min=1e-12)[:,None] 61 | else: 62 | return self.conf 63 | 64 | def getConfMatrixResults(matrix): 65 | assert(len(matrix.shape)==2 and matrix.shape[0]==matrix.shape[1]) 66 | 67 | count_correct = np.diag(matrix) 68 | count_gts = matrix.sum(1) # rows index the ground-truth class 69 | count_preds = matrix.sum(0) # columns index the predicted class 70 | epsilon = np.finfo(np.float32).eps 71 | accuracies = count_correct / (count_gts + epsilon) 72 | IoUs = count_correct / (count_gts + count_preds - count_correct + epsilon) 73 | totAccuracy = count_correct.sum() / (matrix.sum() + epsilon) 74 | 75 | num_valid = (count_gts > 0).sum() 76 | meanAccuracy = accuracies.sum() / (num_valid + epsilon) 77 | meanIoU = IoUs.sum() / (num_valid + epsilon) 78 | 79 | result = {'totAccuracy': round(totAccuracy,4), 'meanAccuracy': round(meanAccuracy,4), 'meanIoU': round(meanIoU,4)} 80 | if num_valid == 2: 81 | result['IoUs_bg'] = round(IoUs[0],4) 82 | result['IoUs_fg'] = round(IoUs[1],4) 83 | 84 | return result 85 | 86 | class AverageConfMeter(object): 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.val = np.asarray(0, dtype=np.float64) 92 | self.avg = np.asarray(0, dtype=np.float64) 93 | self.sum = np.asarray(0, dtype=np.float64) 94 | self.count = 0 95 | 96 | def update(self, val): 97 | self.val = val 98 | if self.count == 0: 99 | self.sum = val.copy().astype(np.float64) 100 | else: 101 | self.sum += val.astype(np.float64) 102 | 103 | self.count += 1 104 | self.avg = getConfMatrixResults(self.sum) 105 | 106 | class AverageMeter(object): 107 | """Computes and stores the average and current value""" 108 | def __init__(self): 109 | self.reset() 110 | 111 | def reset(self): 112 | self.val = 0 113 | self.avg = 0.0 114 | self.sum = 0.0 115 | self.count = 0 116 | 117 | def update(self, val, n=1): 118 | self.val = 
val 119 | self.sum += float(val * n) 120 | self.count += n 121 | self.avg = round(self.sum / self.count,4) 122 | 123 | class LAverageMeter(object): 124 | """Computes and stores the average and current value""" 125 | def __init__(self): 126 | self.reset() 127 | 128 | def reset(self): 129 | self.val = [] 130 | self.avg = [] 131 | self.sum = [] 132 | self.count = 0 133 | 134 | def update(self, val): 135 | self.val = val 136 | self.count += 1 137 | if len(self.sum) == 0: 138 | assert(self.count == 1) 139 | self.sum = [v for v in val] 140 | self.avg = [round(v,4) for v in val] 141 | else: 142 | assert(len(self.sum) == len(val)) 143 | for i, v in enumerate(val): 144 | self.sum[i] += v 145 | self.avg[i] = round(self.sum[i] / self.count,4) 146 | 147 | class DAverageMeter(object): 148 | def __init__(self): 149 | self.reset() 150 | 151 | def reset(self): 152 | self.values = {} 153 | 154 | def update(self, values): 155 | assert(isinstance(values, dict)) 156 | for key, val in values.items(): 157 | if isinstance(val, (float, int)): 158 | if not (key in self.values): 159 | self.values[key] = AverageMeter() 160 | self.values[key].update(val) 161 | elif isinstance(val, (tnt.meter.ConfusionMeter,FastConfusionMeter)): 162 | if not (key in self.values): 163 | self.values[key] = AverageConfMeter() 164 | self.values[key].update(val.value()) 165 | elif isinstance(val, AverageConfMeter): 166 | if not (key in self.values): 167 | self.values[key] = AverageConfMeter() 168 | self.values[key].update(val.sum) 169 | elif isinstance(val, dict): 170 | if not (key in self.values): 171 | self.values[key] = DAverageMeter() 172 | self.values[key].update(val) 173 | elif isinstance(val, list): 174 | if not (key in self.values): 175 | self.values[key] = LAverageMeter() 176 | self.values[key].update(val) 177 | 178 | def average(self): 179 | average = {} 180 | for key, val in self.values.items(): 181 | if isinstance(val, type(self)): 182 | average[key] = val.average() 183 | else: 184 | average[key] = 
val.avg 185 | 186 | return average 187 | 188 | def __str__(self): 189 | ave_stats = self.average() 190 | return ave_stats.__str__() 191 | --------------------------------------------------------------------------------
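The `FastConfusionMeter.add` method in utils.py accumulates its matrix by encoding each (target, prediction) pair as a single integer and counting with `np.bincount`. The following is a minimal standalone sketch of that idea, not the repository's code: the helper names `confusion_from_labels` and `per_class_metrics` are hypothetical, and the sketch assumes integer class labels in the range 0..k-1 with rows indexing the ground-truth class and columns the predicted class.

```python
import numpy as np

def confusion_from_labels(target, pred, k):
    """Hypothetical helper: k x k confusion matrix, rows = ground truth, columns = prediction."""
    target = np.asarray(target, dtype=np.int64)
    pred = np.asarray(pred, dtype=np.int64)
    # Each pair (t, p) maps to the flat index t * k + p; bincount tallies them in one pass.
    flat = np.bincount(target * k + pred, minlength=k * k)
    return flat.reshape(k, k)

def per_class_metrics(conf):
    """Hypothetical helper: per-class accuracy (recall) and IoU from a confusion matrix."""
    correct = np.diag(conf).astype(np.float64)
    count_gts = conf.sum(1)    # samples whose true class is c
    count_preds = conf.sum(0)  # samples predicted as class c
    eps = np.finfo(np.float32).eps  # guards against division by zero for empty classes
    recall = correct / (count_gts + eps)
    iou = correct / (count_gts + count_preds - correct + eps)
    return recall, iou

# Four samples, three classes; one class-0 sample is misclassified as class 1.
conf = confusion_from_labels([0, 0, 1, 2], [0, 1, 1, 2], k=3)
recall, iou = per_class_metrics(conf)
print(conf)
print(recall, iou)
```

The flat-index trick avoids a Python loop over samples, which is presumably why the meter is called "fast"; the per-class formulas mirror the ones in `getConfMatrixResults`.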