├── README.md ├── algorithm.py ├── config ├── CIFAR_1shot.yaml ├── CIFAR_5shot.yaml ├── miniImageNet_1shot.yaml └── miniImageNet_5shot.yaml ├── data ├── download_cifarfs.sh ├── download_miniimagenet.sh └── get_cifarfs.py ├── dataloader.py ├── dataset.py ├── main.py ├── main_feat.py ├── networks.py ├── requirements.txt ├── sib.py └── utils ├── config.py ├── outils.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # \[Update on 31/03/2020\] This repository has been merged to [Xfer](https://github.com/amzn/xfer/tree/master/synthetic_info_bottleneck). 2 | 3 | # \[ICLR 2020\] Synthetic information bottleneck for transductive meta-learning 4 | This repo contains the implementation of the *synthetic information bottleneck* algorithm for few-shot classification on Mini-ImageNet, 5 | which is used in our ICLR 2020 paper 6 | [Empirical Bayes Transductive Meta-Learning with Synthetic Gradients](https://openreview.net/forum?id=Hkg-xgrYvH). 7 | 8 | If our code is helpful for your research, please consider citing: 9 | ``` Bash 10 | @inproceedings{ 11 | Hu2020Empirical, 12 | title={Empirical Bayes Transductive Meta-Learning with Synthetic Gradients}, 13 | author={Shell Xu Hu and Pablo Moreno and Yang Xiao and Xi Shen and Guillaume Obozinski and Neil Lawrence and Andreas Damianou}, 14 | booktitle={International Conference on Learning Representations (ICLR)}, 15 | year={2020}, 16 | url={https://openreview.net/forum?id=Hkg-xgrYvH} 17 | } 18 | ``` 19 | 20 | ## Authors of the code 21 | [Shell Xu Hu](http://hushell.github.io/), [Xi Shen](https://xishen0220.github.io/) and [Yang Xiao](https://youngxiao13.github.io/) 22 | 23 | 24 | ## Dependencies 25 | The code is tested under **Pytorch > 1.0 + Python 3.6** environment with extra packages: 26 | ``` Bash 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | 31 | ## How to use the code on Mini-ImageNet? 
32 | ### **Step 0**: Download Mini-ImageNet dataset 33 | 34 | ``` Bash 35 | cd data 36 | bash download_miniimagenet.sh 37 | cd .. 38 | ``` 39 | 40 | ### **Step 1** (optional): train a WRN-28-10 feature network (aka backbone) 41 | The weights of the feature network are downloaded in step 0, but you may also train from scratch by running 42 | 43 | ``` Bash 44 | python main_feat.py --outDir miniImageNet_WRN_60Epoch --cuda --dataset miniImageNet --nbEpoch 60 45 | ``` 46 | 47 | ### **Step 2**: Meta-training on Mini-ImageNet, e.g., 5-way-1-shot: 48 | 49 | ``` Bash 50 | python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0 51 | ``` 52 | 53 | ### **Step 3**: Meta-testing on Mini-ImageNet with a checkpoint: 54 | 55 | ``` Bash 56 | python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0 --ckpt cache/miniImageNet_1shot_K3_seed100/outputs_xx.xxx/netSIBBestxx.xxx.pth 57 | ``` 58 | 59 | ## Mini-ImageNet Results (LAST ckpt) 60 | 61 | | Setup | 5-way-1-shot | 5-way-5-shot | 62 | | ------------- | -------------:| ------------:| 63 | | SIB (K=3) | 70.700% ± 0.585% | 80.045% ± 0.363%| 64 | | SIB (K=5) | 70.494% ± 0.619% | 80.192% ± 0.372%| 65 | 66 | ## CIFAR-FS Results (LAST ckpt) 67 | 68 | | Setup | 5-way-1-shot | 5-way-5-shot | 69 | | ------------- | -------------:| ------------:| 70 | | SIB (K=3) | 79.763% ± 0.577% | 85.721% ± 0.369%| 71 | | SIB (K=5) | 79.627% ± 0.593% | 85.590% ± 0.375%| 72 | -------------------------------------------------------------------------------- /algorithm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. 
class Algorithm:
    """
    Algorithm logic is implemented here with training and validation functions etc.

    :param args: experimental configurations; the attributes read here are
                 nbIter, nStep, outDir, nFeat, batchSize, nEpisode, momentum,
                 weightDecay, cuda, resumeFeatPth, test and (when test is set) ckptPth
    :type args: EasyDict
    :param logger: logger
    :param netFeat: feature network
    :type netFeat: class `WideResNet` or `ConvNet_4_64`
    :param netSIB: Classifier/decoder
    :type netSIB: class `ClassifierSIB`
    :param optimizer: optimizer
    :type optimizer: torch.optim.SGD
    :param criterion: loss
    :type criterion: nn.CrossEntropyLoss
    """
    def __init__(self, args, logger, netFeat, netSIB, optimizer, criterion):
        self.netFeat = netFeat
        self.netSIB = netSIB
        self.optimizer = optimizer
        self.criterion = criterion

        self.nbIter = args.nbIter
        self.nStep = args.nStep
        self.outDir = args.outDir
        self.nFeat = args.nFeat
        self.batchSize = args.batchSize
        self.nEpisode = args.nEpisode
        self.momentum = args.momentum
        self.weightDecay = args.weightDecay

        self.logger = logger
        self.device = torch.device('cuda' if args.cuda else 'cpu')

        # Load pretrained feature network
        if args.resumeFeatPth:
            if args.cuda:
                param = torch.load(args.resumeFeatPth)
            else:
                # Remap storages so that weights saved on a GPU load on a CPU-only run.
                param = torch.load(args.resumeFeatPth, map_location='cpu')
            self.netFeat.load_state_dict(param)
            msg = '\nLoading netFeat from {}'.format(args.resumeFeatPth)
            self.logger.info(msg)

        if args.test:
            self.load_ckpt(args.ckptPth)

    def load_ckpt(self, ckptPth):
        """
        Load checkpoint from ckptPth.

        Restores both networks and rebuilds the SGD optimizer over the
        parameters of netSIB with the learning rate stored in the checkpoint.

        :param ckptPth: the path to the ckpt
        :type ckptPth: string
        """
        if self.device.type == 'cuda':
            param = torch.load(ckptPth)
        else:
            # Same CPU remapping as for resumeFeatPth in __init__; without it a
            # checkpoint written on a GPU cannot be deserialized on a CPU-only run.
            param = torch.load(ckptPth, map_location='cpu')
        self.netFeat.load_state_dict(param['netFeat'])
        self.netSIB.load_state_dict(param['SIB'])
        lr = param['lr']
        self.optimizer = torch.optim.SGD(self.netSIB.parameters(),
                                         lr,
                                         momentum=self.momentum,
                                         weight_decay=self.weightDecay,
                                         nesterov=True)
        msg = '\nLoading networks from {}'.format(ckptPth)
        self.logger.info(msg)

    def compute_grad_loss(self, clsScore, QueryLabel):
        """
        Compute the loss between true gradients and synthetic gradients.

        :param clsScore: non-leaf logits produced by netSIB
        :param QueryLabel: target labels passed to the criterion
        :return: (loss, gradLoss) where loss is criterion(clsScore, QueryLabel)
                 and gradLoss is the MSE between netSIB.dni(clsScore) and the
                 gradient of loss w.r.t. clsScore
        """
        # clsScore is not a leaf, so its gradient is captured with a hook.
        def require_nonleaf_grad(v):
            def hook(g):
                v.grad_nonleaf = g
            h = v.register_hook(hook)
            return h
        handle = require_nonleaf_grad(clsScore)

        loss = self.criterion(clsScore, QueryLabel)
        # NOTE(review): this backward also accumulates into the .grad of the
        # parameters; train() calls backward on `loss + coeffGrad * gradLoss`
        # afterwards, so the criterion term appears to be counted twice there.
        # Left unchanged (coeffGrad is documented as deprecated) -- confirm intent.
        loss.backward(retain_graph=True) # need to backward again

        # remove hook
        handle.remove()

        gradLogit = self.netSIB.dni(clsScore) # B * n x nKnovel
        gradLoss = F.mse_loss(gradLogit, clsScore.grad_nonleaf.detach())

        return loss, gradLoss

    def validate(self, valLoader, lr=None, mode='val'):
        """
        Run one epoch on val-set.

        :param valLoader: in 'val' mode an iterable with a length (each item is
                          one pre-defined episode); in 'test' mode an object
                          exposing `getEpisode()`
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD; defaults to the lr of
                         the first param group of the optimizer
        :param string mode: 'val' or 'test'
        :return: (mean, ci95) as returned by `getCi` over per-episode top-1 accuracy
        :raises ValueError: if mode is neither 'val' nor 'test'
        """
        if mode == 'test':
            nEpisode = self.nEpisode
            self.logger.info('\n\nTest mode: randomly sample {:d} episodes...'.format(nEpisode))
        elif mode == 'val':
            nEpisode = len(valLoader)
            self.logger.info('\n\nValidation mode: pre-defined {:d} episodes...'.format(nEpisode))
            valLoader = iter(valLoader)
        else:
            raise ValueError('mode is wrong!')

        episodeAccLog = []
        top1 = AverageMeter()

        self.netFeat.eval()
        # netSIB is deliberately NOT switched to eval():
        # updating bn helps to estimate better gradient

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for batchIdx in range(nEpisode):
            data = valLoader.getEpisode() if mode == 'test' else next(valLoader)
            data = to_device(data, self.device)

            SupportTensor = data['SupportTensor'].squeeze(0)
            SupportLabel = data['SupportLabel'].squeeze(0)
            QueryTensor = data['QueryTensor'].squeeze(0)
            QueryLabel = data['QueryLabel'].squeeze(0)

            # Only the (frozen) feature extractor runs without autograd.
            with torch.no_grad():
                SupportFeat, QueryFeat = self.netFeat(SupportTensor), self.netFeat(QueryTensor)
                SupportFeat, QueryFeat, SupportLabel = \
                        SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
            QueryLabel = QueryLabel.view(-1)
            acc1 = accuracy(clsScore, QueryLabel, topk=(1,))
            top1.update(acc1[0].item(), clsScore.shape[0])

            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, nEpisode, msg)
            episodeAccLog.append(acc1[0].item())

        mean, ci95 = getCi(episodeAccLog)
        self.logger.info('Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.format(mean, ci95))
        return mean, ci95

    def _save_ckpt(self, lr, fileName):
        """Write lr, both network state dicts and nStep to outDir/fileName."""
        torch.save({
            'lr': lr,
            'netFeat': self.netFeat.state_dict(),
            'SIB': self.netSIB.state_dict(),
            'nbStep': self.nStep,
            }, os.path.join(self.outDir, fileName))

    def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0) :
        """
        Run nbIter training iterations, validating and checkpointing every 1000.

        :param trainLoader: the dataloader of train-set (must expose `getBatch()`)
        :type trainLoader: class `TrainLoader`
        :param valLoader: the dataloader of val-set
        :type valLoader: class `ValLoader`
        :param float lr: learning rate for synthetic GD
        :param float coeffGrad: deprecated
        :return: (bestAcc, acc, history) -- best validation accuracy, validation
                 accuracy of the most recent validation run, and a dict with the
                 lists 'trainLoss', 'trainAcc' and 'valAcc' (one entry per
                 1000 iterations)
        """
        bestAcc, ci = self.validate(valLoader, lr)
        self.logger.info('Acc improved over validation set from 0% ---> {:.3f} +- {:.3f}%'.format(bestAcc,ci))
        # The initial validation is also the "latest" one until the first
        # periodic validation; this keeps `acc` defined when nbIter < 1000.
        acc = bestAcc

        self.netSIB.train()
        self.netFeat.eval()

        losses = AverageMeter()
        top1 = AverageMeter()
        history = {'trainLoss' : [], 'trainAcc' : [], 'valAcc' : []}

        if lr is None:
            lr = self.optimizer.param_groups[0]['lr']

        for episode in range(self.nbIter):
            data = trainLoader.getBatch()
            data = to_device(data, self.device)

            # The feature network is frozen: extract features without autograd.
            with torch.no_grad() :
                SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
                        data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
                nC, nH, nW = SupportTensor.shape[2:]

                SupportFeat = self.netFeat(SupportTensor.reshape(-1, nC, nH, nW))
                SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

                QueryFeat = self.netFeat(QueryTensor.reshape(-1, nC, nH, nW))
                QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

            self.optimizer.zero_grad()

            clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
            clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
            QueryLabel = QueryLabel.view(-1)

            if coeffGrad > 0:
                loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
                loss = loss + gradLoss * coeffGrad
            else:
                loss = self.criterion(clsScore, QueryLabel)

            loss.backward()
            self.optimizer.step()

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.shape[0])
            losses.update(loss.item(), QueryFeat.shape[1])
            msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
            if coeffGrad > 0:
                msg = msg + '| gradLoss: {:.3f}%'.format(gradLoss.item())
            progress_bar(episode, self.nbIter, msg)

            if episode % 1000 == 999 :
                acc, _ = self.validate(valLoader, lr)

                if acc > bestAcc :
                    msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(bestAcc , acc)
                    self.logger.info(msg)

                    bestAcc = acc
                    self.logger.info('Saving Best')
                    self._save_ckpt(lr, 'netSIBBest.pth')

                self.logger.info('Saving Last')
                self._save_ckpt(lr, 'netSIBLast.pth')

                msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
                        episode, losses.avg, top1.avg, acc)
                self.logger.info(msg)
                history['trainLoss'].append(losses.avg)
                history['trainAcc'].append(top1.avg)
                history['valAcc'].append(acc)

                # Running averages are reported per 1000-iteration window.
                losses = AverageMeter()
                top1 = AverageMeter()

        return bestAcc, acc, history
nQuery: 15 # number of samples per category in the query set 5 | dataset: 'Cifar' # choices = ['miniImageNet', 'Cifar'] 6 | 7 | # Network 8 | nStep: 3 # number of synthetic gradient steps 9 | architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4'] 10 | batchSize: 1 # number of episodes in each batch 11 | 12 | # Optimizer 13 | lr: 0.001 # lr is fixed 14 | weightDecay: 0.0005 15 | momentum: 0.9 16 | 17 | # Training details 18 | expName: cifar-fs 19 | nbIter: 50000 # number of training iterations 20 | seed: 100 # can be reset with --seed 21 | gpu: '1' # can be reset with --gpu 22 | resumeFeatPth : './ckpts/CIFAR-FS/netFeatBest62.561.pth' # feat ckpt 23 | coeffGrad: 0 # grad loss coeff 24 | 25 | # Testing 26 | nEpisode: 2000 # number of episodes for testing 27 | -------------------------------------------------------------------------------- /config/CIFAR_5shot.yaml: -------------------------------------------------------------------------------- 1 | # Few-shot dataset 2 | nClsEpisode: 5 # number of categories in each episode 3 | nSupport: 5 # number of samples per category in the support set 4 | nQuery: 15 # number of samples per category in the query set 5 | dataset: 'Cifar' # choices = ['miniImageNet', 'Cifar'] 6 | 7 | # Network 8 | nStep: 3 # number of synthetic gradient steps 9 | architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4'] 10 | batchSize: 1 # number of episodes in each batch 11 | 12 | # Optimizer 13 | lr: 0.001 # lr is fixed 14 | weightDecay: 0.0005 15 | momentum: 0.9 16 | 17 | # Training details 18 | expName: cifar-fs 19 | nbIter: 50000 # number of training iterations 20 | seed: 100 # can be reset with --seed 21 | gpu: '1' # can be reset with --gpu 22 | resumeFeatPth : './ckpts/CIFAR-FS/netFeatBest62.561.pth' # feat ckpt 23 | coeffGrad: 0 # grad loss coeff 24 | 25 | # Testing 26 | nEpisode: 2000 # number of episodes for testing 27 | -------------------------------------------------------------------------------- 
/config/miniImageNet_1shot.yaml: -------------------------------------------------------------------------------- 1 | # Few-shot dataset 2 | nClsEpisode: 5 # number of categories in each episode 3 | nSupport: 1 # number of samples per category in the support set 4 | nQuery: 15 # number of samples per category in the query set 5 | dataset: 'miniImageNet' # choices = ['miniImageNet', 'Cifar'] 6 | 7 | # Network 8 | nStep: 3 # number of synthetic gradient steps 9 | architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4'] 10 | batchSize: 1 # number of episodes in each batch 11 | 12 | # Optimizer 13 | lr: 0.001 # lr is fixed 14 | weightDecay: 0.0005 15 | momentum: 0.9 16 | 17 | # Training details 18 | expName: miniImageNet 19 | nbIter: 50000 # number of training iterations 20 | seed: 100 # can be reset with --seed 21 | gpu: '1' # can be reset with --gpu 22 | resumeFeatPth: './ckpts/Mini-ImageNet/netFeatBest64.653.pth' 23 | coeffGrad: 0 # grad loss coeff 24 | 25 | # Testing 26 | nEpisode: 2000 # number of episodes for testing 27 | -------------------------------------------------------------------------------- /config/miniImageNet_5shot.yaml: -------------------------------------------------------------------------------- 1 | # Few-shot dataset 2 | nClsEpisode: 5 # number of categories in each episode 3 | nSupport: 5 # number of samples per category in the support set 4 | nQuery: 15 # number of samples per category in the query set 5 | dataset: 'miniImageNet' # choices = ['miniImageNet', 'Cifar'] 6 | 7 | # Network 8 | nStep: 3 # number of synthetic gradient steps 9 | architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4'] 10 | batchSize: 1 # number of episodes in each batch 11 | 12 | # Optimizer 13 | lr: 0.001 # lr is fixed 14 | weightDecay: 0.0005 15 | momentum: 0.9 16 | 17 | # Training details 18 | expName: miniImageNet 19 | nbIter: 50000 # number of training iterations 20 | seed: 100 # can be reset with --seed 21 | gpu: '1' # can be reset with --gpu 22 | 
resumeFeatPth: './ckpts/Mini-ImageNet/netFeatBest64.653.pth' 23 | coeffGrad: 0 # grad loss coeff 24 | 25 | # Testing 26 | nEpisode: 2000 # number of episodes for testing 27 | -------------------------------------------------------------------------------- /data/download_cifarfs.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/wuxb1wlahado3nq/cifar-fs-splits.zip?dl=0 2 | mv cifar-fs-splits.zip?dl=0 cifar-fs-splits.zip 3 | unzip cifar-fs-splits.zip 4 | rm cifar-fs-splits.zip 5 | 6 | python get_cifarfs.py 7 | mv cifar-fs-splits/val1000* cifar-fs/ 8 | 9 | wget https://www.dropbox.com/s/g9ru5ac5tpupvg6/netFeatBest62.561.pth?dl=0 10 | mv netFeatBest62.561.pth?dl=0 netFeatBest62.561.pth 11 | mkdir ../ckpts 12 | mkdir ../ckpts/CIFAR-FS 13 | mv netFeatBest62.561.pth ../ckpts/CIFAR-FS/ 14 | -------------------------------------------------------------------------------- /data/download_miniimagenet.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/a2a0bll17f5dvhr/Mini-ImageNet.zip?dl=0 2 | mv Mini-ImageNet.zip?dl=0 Mini-ImageNet.zip 3 | unzip Mini-ImageNet.zip 4 | rm Mini-ImageNet.zip 5 | rm -r Mini-ImageNet/train_val Mini-ImageNet/train_test 6 | mv Mini-ImageNet/train_train Mini-ImageNet/train 7 | 8 | wget https://www.dropbox.com/s/2hqpf8cqansm1n7/val1000Episode_5_way_5_shot.json?dl=0 9 | mv val1000Episode_5_way_5_shot.json?dl=0 val1000Episode_5_way_5_shot.json 10 | mv val1000Episode_5_way_5_shot.json Mini-ImageNet/ 11 | 12 | wget https://www.dropbox.com/s/0n99mf5ylh4yefi/val1000Episode_5_way_1_shot.json?dl=0 13 | mv val1000Episode_5_way_1_shot.json?dl=0 val1000Episode_5_way_1_shot.json 14 | mv val1000Episode_5_way_1_shot.json Mini-ImageNet/ 15 | 16 | wget https://www.dropbox.com/s/t36y8ng47wlcxw0/netFeatBest64.653.pth?dl=0 17 | mv netFeatBest64.653.pth?dl=0 netFeatBest64.653.pth 18 | mkdir ../ckpts 19 | mkdir ../ckpts/Mini-ImageNet 20 
"""
@author: Arnout Devos
2018/12/06
MIT License

Script for downloading and reorganizing CIFAR few-shot (CIFAR-FS) from CIFAR-100
according to the split specifications in Luca et al. '18.

The class lists are read from ``cifar-fs-splits/{train,val,test}.txt`` in the
current working directory and the result is written to ``cifar-fs/``.

Run this file as follows:
    python get_cifarfs.py

"""

import pickle
import os
import numpy as np
from tqdm import tqdm
import requests
import math
import tarfile
import sys
from PIL import Image
import glob
import shutil

ARCHIVE_URL = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
ARCHIVE_NAME = 'cifar-100-python.tar.gz'
DATA_PATH = 'cifar-100-python'
IMAGE_ROOT = 'images'
SPLIT_DIR = 'cifar-fs-splits'
OUTPUT_ROOT = 'cifar-fs'


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url` to `filename`. Returns a pointer to `filename`.

    The payload is streamed to ``filename + '.part'`` and renamed once complete,
    so an interrupted download is never mistaken for a finished archive.

    :raises requests.HTTPError: if the server answers with an error status
    """
    chunkSize = 1024
    partial = filename + '.part'
    r = requests.get(url, stream=True)
    try:
        # Fail loudly instead of saving an HTML error page as the archive.
        r.raise_for_status()
        # Content-Length is optional (e.g. chunked transfer); tqdm accepts total=None.
        contentLength = r.headers.get('Content-Length')
        total = int(contentLength) if contentLength is not None else None
        with open(partial, 'wb') as f, tqdm(unit="B", total=total) as pbar:
            for chunk in r.iter_content(chunk_size=chunkSize):
                if chunk: # filter out keep-alive new chunks
                    pbar.update(len(chunk))
                    f.write(chunk)
    finally:
        r.close()
    os.replace(partial, filename)
    return filename


def _load_pickle(path, encoding):
    """Unpickle `path` with the given `encoding`, closing the file afterwards."""
    # NOTE: pickle executes code from the file; only use on the official archive.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding=encoding)


def extract_archive(tarname):
    """Unpack the CIFAR-100 tarball into the current directory."""
    print("Untarring: {}".format(tarname))
    # NOTE: extractall() trusts the member paths stored in the archive.
    with tarfile.open(tarname) as tar:
        tar.extractall()


def export_batch(batch, labels):
    """
    Write every image of one CIFAR-100 pickle (`batch` is 'train' or 'test')
    to ``images/<fine label name>/<name>.jpg``; existing files are kept.

    :param labels: the unpickled 'meta' dictionary (provides 'fine_label_names')
    """
    print("Handling pickle file: {}".format(batch))

    raw = _load_pickle(os.path.join(DATA_PATH, batch), 'bytes')
    # keys are utf8-encoded bytes; decode them into a new dictionary
    d = {k.decode('utf8'): v for k, v in raw.items()}

    fineNames = labels['fine_label_names']
    # Iterate name, label and pixels in lockstep so that skipping an existing
    # file can never shift the label/pixel index relative to the filename.
    records = zip(d['filenames'], d['fine_labels'], d['data'])
    for filename, fineLabel, q in tqdm(records, total=len(d['filenames'])):
        folder = os.path.join(IMAGE_ROOT, fineNames[fineLabel])

        png_path = os.path.join(folder, filename.decode())
        jpg_path = os.path.splitext(png_path)[0]+".jpg"

        if os.path.exists(jpg_path):
            continue

        os.makedirs(folder, exist_ok=True)
        img = Image.fromarray(q.reshape((32, 32, 3), order='F').swapaxes(0,1), 'RGB')
        img.save(jpg_path)


def organize_splits():
    """Move each class folder from ``images/`` into ``cifar-fs/<split>/``."""
    for datatype in ['train', 'val', 'test']:
        cur_dir = os.path.join(OUTPUT_ROOT, datatype)
        os.makedirs(cur_dir, exist_ok=True)
        with open(os.path.join(SPLIT_DIR, datatype + '.txt'), 'r') as f:
            # Remove whitespace characters like `\n` at the end of each line
            classes = [x.strip() for x in f.readlines()]

        for img_class in classes:
            if os.path.exists(os.path.join(cur_dir, img_class)):
                continue
            src = os.path.join(IMAGE_ROOT, img_class)
            if not os.path.isdir(src):
                # Keep going like the former `mv` call did, but say why.
                print("Warning: class folder not found: {}".format(src))
                continue
            # shutil.move needs no shell and copes with unusual class names.
            shutil.move(src, cur_dir)


def main():
    """Download CIFAR-100, export it as jpg files and build the CIFAR-FS splits."""
    if not os.path.exists(ARCHIVE_NAME):
        print("Downloading cifar-100-python.tar.gz\n")
        download_file(ARCHIVE_URL, ARCHIVE_NAME)
        print("Downloading done.\n")
    else:
        print("Dataset already downloaded. Did not download twice.\n")

    extract_archive(ARCHIVE_NAME)

    print("Extracting jpg images and classes from pickle files")
    labels = _load_pickle(os.path.join(DATA_PATH, 'meta'), "ASCII")
    # in CIFAR 100, the files are given in a train and test format
    for batch in ['test','train']:
        export_batch(batch, labels)

    print("Removing pickle files")
    shutil.rmtree(DATA_PATH, ignore_errors=True)

    print("Depending on the split files, organize train, val and test sets")
    organize_splits()

    print("Removing original CIFAR 100 images")
    shutil.rmtree(IMAGE_ROOT, ignore_errors=True)

    print("Removing tar file")
    os.remove(ARCHIVE_NAME)


if __name__ == '__main__':
    main()
def PilLoaderRGB(imgPath) :
    """Open the image at `imgPath` and return it converted to RGB."""
    # The context manager releases the file handle as soon as the pixels are converted.
    with Image.open(imgPath) as img:
        return img.convert('RGB')


def _alloc(useGPU, dtype, *shape):
    """Return an uninitialised tensor of `shape` and `dtype` on 'cuda' if useGPU else on 'cpu'."""
    return torch.empty(*shape, dtype=dtype, device='cuda' if useGPU else 'cpu')


def _fill_labels(labels, nCls, nPerCls):
    """Write 0..nCls-1 into the 1-D tensor `labels`, each repeated nPerCls times."""
    for i in range(nCls):
        labels[i * nPerCls : (i+1) * nPerCls] = i


class EpisodeSampler():
    """
    Dataloader to sample a task/episode.
    In case of 5-way 1-shot: nSupport = 1, nClsEpisode = 5.

    :param string imgDir: image directory, each category is in a sub file;
    :param int nClsEpisode: number of classes in each episode;
    :param int nSupport: number of support examples;
    :param int nQuery: number of query examples;
    :param transform: image transformation/data augmentation;
    :param bool useGPU: whether to use gpu or not;
    :param int inputW: input image size, dimension W;
    :param int inputH: input image size, dimension H;
    """
    def __init__(self, imgDir, nClsEpisode, nSupport, nQuery, transform, useGPU, inputW, inputH):
        self.imgDir = imgDir
        # os.listdir order is arbitrary; sorting makes seeded sampling reproducible.
        self.clsList = sorted(os.listdir(imgDir))
        self.nClsEpisode = nClsEpisode
        self.nSupport = nSupport
        self.nQuery = nQuery
        self.transform = transform

        # Buffers are allocated once and overwritten by every getEpisode() call.
        self.tensorSupport = _alloc(useGPU, torch.float32, nClsEpisode * nSupport, 3, inputW, inputH)
        self.labelSupport = _alloc(useGPU, torch.long, nClsEpisode * nSupport)
        self.tensorQuery = _alloc(useGPU, torch.float32, nClsEpisode * nQuery, 3, inputW, inputH)
        self.labelQuery = _alloc(useGPU, torch.long, nClsEpisode * nQuery)
        # staging tensor used to bring one transformed image onto the buffer's device
        self.imgTensor = _alloc(useGPU, torch.float32, 3, inputW, inputH)

        # labels {0, ..., nClsEpisode-1}; they do not depend on the sampled classes
        _fill_labels(self.labelSupport, nClsEpisode, nSupport)
        _fill_labels(self.labelQuery, nClsEpisode, nQuery)

    def _load(self, imgPath):
        """Load and transform one image into the staging tensor and return it."""
        return self.imgTensor.copy_(self.transform(PilLoaderRGB(imgPath)))

    def getEpisode(self):
        """
        Return an episode; samples and labels are randomly permuted.

        :return dict: {'SupportTensor': (nClsEpisode * nSupport) x 3 x inputW x inputH,
                       'SupportLabel': (nClsEpisode * nSupport),
                       'QueryTensor': (nClsEpisode * nQuery) x 3 x inputW x inputH,
                       'QueryLabel': (nClsEpisode * nQuery)}
        """
        # select nClsEpisode from clsList
        clsEpisode = np.random.choice(self.clsList, self.nClsEpisode, replace=False)
        for i, cls in enumerate(clsEpisode) :
            clsPath = os.path.join(self.imgDir, cls)
            imgList = sorted(os.listdir(clsPath))

            # in total nQuery+nSupport images from each class
            imgCls = np.random.choice(imgList, self.nQuery + self.nSupport, replace=False)

            for j, img in enumerate(imgCls[:self.nSupport]) :
                self.tensorSupport[i * self.nSupport + j] = self._load(os.path.join(clsPath, img))

            for j, img in enumerate(imgCls[self.nSupport:]) :
                self.tensorQuery[i * self.nQuery + j] = self._load(os.path.join(clsPath, img))

        ## Random permutation. Though this is not necessary in our approach
        permSupport = torch.randperm(self.nClsEpisode * self.nSupport)
        permQuery = torch.randperm(self.nClsEpisode * self.nQuery)

        return {'SupportTensor':self.tensorSupport[permSupport],
                'SupportLabel':self.labelSupport[permSupport],
                'QueryTensor':self.tensorQuery[permQuery],
                'QueryLabel':self.labelQuery[permQuery]
                }


class BatchSampler():
    """
    Dataloader to sample a batch of tasks/episodes.
    In case of 5-way 1-shot: nSupport = 1, nClsEpisode = 5.

    :param string imgDir: image directory, each category is in a sub file;
    :param int nClsEpisode: number of classes in each episode;
    :param int nSupport: number of support examples;
    :param int nQuery: number of query examples;
    :param transform: image transformation/data augmentation;
    :param bool useGPU: whether to use gpu or not;
    :param int inputW: input image size, dimension W;
    :param int inputH: input image size, dimension H;
    :param int batchSize: batch size (number of episode in each batch).
    """
    def __init__(self, imgDir, nClsEpisode, nSupport, nQuery, transform, useGPU, inputW, inputH, batchSize):
        self.episodeSampler = EpisodeSampler(imgDir, nClsEpisode, nSupport, nQuery,
                                             transform, useGPU, inputW, inputH)

        self.tensorSupport = _alloc(useGPU, torch.float32, batchSize, nClsEpisode * nSupport, 3, inputW, inputH)
        self.labelSupport = _alloc(useGPU, torch.long, batchSize, nClsEpisode * nSupport)
        self.tensorQuery = _alloc(useGPU, torch.float32, batchSize, nClsEpisode * nQuery, 3, inputW, inputH)
        self.labelQuery = _alloc(useGPU, torch.long, batchSize, nClsEpisode * nQuery)

        self.batchSize = batchSize

    def getBatch(self):
        """
        Return a batch of B = batchSize independently sampled episodes.

        The returned tensors are this object's internal buffers; they are
        overwritten by the next call.

        :return dict: {'SupportTensor': B x (nClsEpisode * nSupport) x 3 x inputW x inputH,
                       'SupportLabel': B x (nClsEpisode * nSupport),
                       'QueryTensor': B x (nClsEpisode * nQuery) x 3 x inputW x inputH,
                       'QueryLabel': B x (nClsEpisode * nQuery)}
        """
        for i in range(self.batchSize) :
            episode = self.episodeSampler.getEpisode()
            self.tensorSupport[i] = episode['SupportTensor']
            self.labelSupport[i] = episode['SupportLabel']
            self.tensorQuery[i] = episode['QueryTensor']
            self.labelQuery[i] = episode['QueryLabel']

        return {'SupportTensor':self.tensorSupport,
                'SupportLabel':self.labelSupport,
                'QueryTensor':self.tensorQuery,
                'QueryLabel':self.labelQuery
                }


class ValImageFolder(data.Dataset):
    """
    To make validation results comparable, validation uses the fixed list of
    episodes stored in `episodeJson`.

    The file is a list of episodes; each episode has a 'Support' and a 'Query'
    entry holding, per class, a list of image paths relative to `imgDir`.

    :param string episodeJson: ./data/Dataset/val1000Episode_K_way_N_shot.json
    :param string imgDir: image directory, each category is in a sub file;
    :param int inputW: input image size, dimension W;
    :param int inputH: input image size, dimension H;
    :param valTransform: image transformation/data augmentation;
    :param bool useGPU: whether to use gpu or not;
    """
    def __init__(self, episodeJson, imgDir, inputW, inputH, valTransform, useGPU):
        with open(episodeJson, 'r') as f :
            self.episodeInfo = json.load(f)

        self.imgDir = imgDir
        self.nEpisode = len(self.episodeInfo)
        # The episode layout is read off the first episode.
        self.nClsEpisode = len(self.episodeInfo[0]['Support'])
        self.nSupport = len(self.episodeInfo[0]['Support'][0])
        self.nQuery = len(self.episodeInfo[0]['Query'][0])
        self.transform = valTransform

        self.tensorSupport = _alloc(useGPU, torch.float32, self.nClsEpisode * self.nSupport, 3, inputW, inputH)
        self.labelSupport = _alloc(useGPU, torch.long, self.nClsEpisode * self.nSupport)
        self.tensorQuery = _alloc(useGPU, torch.float32, self.nClsEpisode * self.nQuery, 3, inputW, inputH)
        self.labelQuery = _alloc(useGPU, torch.long, self.nClsEpisode * self.nQuery)

        self.imgTensor = _alloc(useGPU, torch.float32, 3, inputW, inputH)
        _fill_labels(self.labelSupport, self.nClsEpisode, self.nSupport)
        _fill_labels(self.labelQuery, self.nClsEpisode, self.nQuery)

    def _load(self, relPath):
        """Load and transform the image at imgDir/relPath into the staging tensor."""
        return self.imgTensor.copy_(self.transform(PilLoaderRGB(os.path.join(self.imgDir, relPath))))

    def __getitem__(self, index):
        """
        Return an episode (samples are kept in class order, not permuted)

        :param int index: index of the episode
        :return dict: {'SupportTensor': (nClsEpisode * nSupport) x 3 x inputW x inputH,
                       'SupportLabel': (nClsEpisode * nSupport),
                       'QueryTensor': (nClsEpisode * nQuery) x 3 x inputW x inputH,
                       'QueryLabel': (nClsEpisode * nQuery)}
        """
        episode = self.episodeInfo[index]
        for i in range(self.nClsEpisode) :
            for j in range(self.nSupport) :
                self.tensorSupport[i * self.nSupport + j] = self._load(episode['Support'][i][j])

            for j in range(self.nQuery) :
                self.tensorQuery[i * self.nQuery + j] = self._load(episode['Query'][i][j])

        # NOTE(review): these are the reusable internal buffers; presumably the
        # DataLoader collation copies them before the next item is built -- confirm
        # before indexing this dataset directly and keeping several items around.
        return {'SupportTensor':self.tensorSupport,
                'SupportLabel':self.labelSupport,
                'QueryTensor':self.tensorQuery,
                'QueryLabel':self.labelQuery
                }

    def __len__(self):
        """
        Number of episodes
        """
        return self.nEpisode


def ValLoader(episodeJson, imgDir, inputW, inputH, valTransform, useGPU) :
    """Return an unshuffled DataLoader (default batch size) over `ValImageFolder`."""
    dataset = ValImageFolder(episodeJson, imgDir, inputW, inputH, valTransform, useGPU)
    return data.DataLoader(dataset, shuffle=False)


def TrainLoader(batchSize, imgDir, trainTransform) :
    """Return a shuffled DataLoader over an `ImageFolder`, dropping the last incomplete batch."""
    dataset = ImageFolder(imgDir, trainTransform)
    return data.DataLoader(dataset, batch_size=batchSize, shuffle=True, drop_last=True)


if __name__ == '__main__' :
    import torchvision.transforms as transforms
    mean = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
    std = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
    normalize = transforms.Normalize(mean=mean, std=std)
    trainTransform = transforms.Compose([
                                         transforms.RandomCrop(80, padding=8),
                                         transforms.RandomHorizontalFlip(),
                                         lambda x: np.asarray(x),
                                         transforms.ToTensor(),
                                         normalize
                                        ])

    # NOTE(review): data/download_miniimagenet.sh renames train_train to train;
    # this demo path looks stale -- confirm before relying on it.
    TrainEpisodeSampler = EpisodeSampler(imgDir = '../data/Mini-ImageNet/train_train/',
                                         nClsEpisode = 5,
                                         nSupport = 5,
                                         nQuery = 14,
                                         transform = trainTransform,
                                         useGPU = True,
                                         inputW = 80,
                                         inputH = 80)
    # named `episode` so the imported torch.utils.data module `data` is not rebound
    episode = TrainEpisodeSampler.getEpisode()
    print (episode['SupportLabel'])
# ==============================================================================

import numpy as np
import torchvision.transforms as transforms

# Per-dataset constants. `mean` / `std` are per-channel statistics on the 0-255 scale,
# `size` is the side of the square network input, `padding` the padding of the random
# training crop, `centerCrop` whether validation images are center-cropped to `size`,
# `root` the data directory and `nbCls` the number of base classes.
_DATASET_SPECS = {
    'miniImageNet': {
        'mean': [120.39586422, 115.59361427, 104.54012653],
        'std': [70.68188272, 68.27635443, 72.54505529],
        'size': 80,
        'padding': 8,
        'centerCrop': True,
        'root': './data/Mini-ImageNet/',
        'nbCls': 64,
    },
    'Cifar': {
        'mean': [129.37731888, 124.10583864, 112.47758569],
        'std': [68.20947949, 65.43124043, 70.45866994],
        'size': 32,
        'padding': 4,
        'centerCrop': False,
        'root': './data/cifar-fs/',
        'nbCls': 64,
    },
}


def _build_transforms(spec):
    """
    Build the (train, validation) transform pipelines of one dataset.

    :param dict spec: one entry of _DATASET_SPECS
    :return: trainTransform, valTransform
    """
    normalize = transforms.Normalize(mean=[x/255.0 for x in spec['mean']],
                                     std=[x/255.0 for x in spec['std']])
    # np.asarray is passed directly rather than wrapped in a lambda: lambdas cannot be
    # pickled, which breaks DataLoader worker processes started with `spawn`.
    toTensor = [np.asarray, transforms.ToTensor(), normalize]

    trainTransform = transforms.Compose(
        [transforms.RandomCrop(spec['size'], padding=spec['padding']),
         transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
         transforms.RandomHorizontalFlip()] + toTensor)

    valSteps = [transforms.CenterCrop(spec['size'])] if spec['centerCrop'] else []
    valTransform = transforms.Compose(valSteps + toTensor)
    return trainTransform, valTransform


def dataset_setting(dataset, nSupport):
    """
    Return dataset setting

    :param string dataset: name of dataset, 'miniImageNet' or 'Cifar'
    :param int nSupport: number of support examples; 1 selects the 5-way-1-shot
                         validation episodes, any other value the 5-way-5-shot ones
    :return: trainTransform, valTransform, inputW, inputH,
             trainDir, valDir, testDir, episodeJson, nbCls
    :raises ValueError: if the dataset name is not supported
    """
    try:
        spec = _DATASET_SPECS[dataset]
    except KeyError:
        raise ValueError('Do not support other datasets yet.')

    trainTransform, valTransform = _build_transforms(spec)
    inputW = inputH = spec['size']

    root = spec['root']
    trainDir = root + 'train/'
    valDir = root + 'val/'
    testDir = root + 'test/'
    nShot = 1 if nSupport == 1 else 5
    episodeJson = root + 'val1000Episode_5_way_{:d}_shot.json'.format(nShot)

    return trainTransform, valTransform, inputW, inputH, trainDir, valDir, testDir, episodeJson, spec['nbCls']

# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import torch
import torch.nn as nn
import random
import itertools
import json
import os
import shutil

from algorithm import Algorithm
from networks import get_featnet
from sib import ClassifierSIB
from dataset import dataset_setting
from dataloader import BatchSampler, ValLoader, EpisodeSampler
from utils.config import get_config
from utils.utils import get_logger, set_random_seed

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

#############################################################################################
## Read hyper-parameters
args = get_config()

# Setup logging to file and stdout
logger = get_logger(args.logDir, args.expName)

# Fix random seed to reproduce results
set_random_seed(args.seed)
logger.info('Start experiment with random seed: {:d}'.format(args.seed))
logger.info(args)

# GPU setup
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
if args.gpu != '':
    args.cuda = True
device = torch.device('cuda' if args.cuda else 'cpu')

#############################################################################################
## Datasets
trainTransform, valTransform, inputW, inputH, \
        trainDir, valDir, testDir, episodeJson, nbCls = \
        dataset_setting(args.dataset, args.nSupport)

trainLoader = BatchSampler(imgDir=trainDir,
                           nClsEpisode=args.nClsEpisode,
                           nSupport=args.nSupport,
                           nQuery=args.nQuery,
                           transform=trainTransform,
                           useGPU=args.cuda,
                           inputW=inputW,
                           inputH=inputH,
                           batchSize=args.batchSize)

valLoader = ValLoader(episodeJson,
                      valDir,
                      inputW,
                      inputH,
                      valTransform,
                      args.cuda)

testLoader = EpisodeSampler(imgDir=testDir,
                            nClsEpisode=args.nClsEpisode,
                            nSupport=args.nSupport,
                            nQuery=args.nQuery,
                            transform=valTransform,
                            useGPU=args.cuda,
                            inputW=inputW,
                            inputH=inputH)


#############################################################################################
## Networks
netFeat, args.nFeat = get_featnet(args.architecture, inputW, inputH)
netFeat = netFeat.to(device)
netSIB = ClassifierSIB(args.nClsEpisode, args.nFeat, args.nStep)
netSIB = netSIB.to(device)

## Optimizer
# Only the parameters of netSIB are handed to the optimizer.
optimizer = torch.optim.SGD(netSIB.parameters(),
                            args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weightDecay,
                            nesterov=True)

## Loss
criterion = nn.CrossEntropyLoss()

## Algorithm class
alg = Algorithm(args, logger, netFeat, netSIB, optimizer, criterion)


def _logAndMove(src, dst):
    """
    Log a rename and perform it.

    Uses shutil.move instead of a shell `mv` string, so paths containing spaces or shell
    metacharacters are handled and a failed rename raises instead of being ignored.

    :param string src: existing file or directory
    :param string dst: new path
    """
    logger.info('mv {} {}'.format(src, dst))
    shutil.move(src, dst)


#############################################################################################
## Training
if not args.test:
    bestAcc, lastAcc, history = alg.train(trainLoader, valLoader, coeffGrad=args.coeffGrad)

    ## Finish training!!!
    # Tag the checkpoints, then the whole output directory, with the accuracies reached.
    _logAndMove(os.path.join(args.outDir, 'netSIBBest.pth'),
                os.path.join(args.outDir, 'netSIBBest{:.3f}.pth'.format(bestAcc)))

    _logAndMove(os.path.join(args.outDir, 'netSIBLast.pth'),
                os.path.join(args.outDir, 'netSIBLast{:.3f}.pth'.format(lastAcc)))

    with open(os.path.join(args.outDir, 'history.json'), 'w') as f:
        json.dump(history, f)

    _logAndMove(args.outDir, '{}_{:.3f}'.format(args.outDir, bestAcc))


#############################################################################################
## Testing
logger.info('Testing model {}...'.format(args.ckptPth if args.test else 'LAST'))
mean, ci95 = alg.validate(testLoader, mode='test')

if not args.test:
    logger.info('Testing model BEST...')
    # The output directory was renamed above, so the best checkpoint is read from there.
    alg.load_ckpt(os.path.join('{}_{:.3f}'.format(args.outDir, bestAcc),
                               'netSIBBest{:.3f}.pth'.format(bestAcc)))
    mean, ci95 = alg.validate(testLoader, mode='test')

# --------------------------------------------------------------------------------
# /main_feat.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import itertools
import json
import os
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR

from utils.outils import progress_bar, AverageMeter, accuracy, getCi
from utils.utils import to_device

from dataset import dataset_setting
from networks import get_featnet
from dataloader import TrainLoader, ValLoader

randomSeed = 123
torch.backends.cudnn.deterministic = True
torch.manual_seed(randomSeed)


#############################################################################################
class ClassifierEval(nn.Module):
    '''
    There is nothing to be learned in this classifier
    it is only used to evaluate netFeat episodically

    Class weights are the per-class means of the support features, and scores are scaled
    cosine similarities between query features and those weights.

    :param int nKnovel: number of classes in an episode
    :param int nFeat: feature dimension
    '''
    def __init__(self, nKnovel, nFeat):
        super(ClassifierEval, self).__init__()

        self.nKnovel = nKnovel
        self.nFeat = nFeat

        # bias & scale of classifier (fixed, not trained)
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.scale_cls = nn.Parameter(torch.full((1,), 10.0), requires_grad=False)

    def apply_classification_weights(self, features, cls_weights):
        '''
        (B x n x nFeat, B x nKnovel x nFeat) -> B x n x nKnovel
        (B x n x nFeat, B x nKnovel*nExamplar x nFeat) -> B x n x nKnovel*nExamplar if init_type is nn
        '''
        features = F.normalize(features, p=2, dim=features.dim()-1, eps=1e-12)
        cls_weights = F.normalize(cls_weights, p=2, dim=cls_weights.dim()-1, eps=1e-12)
        # beta / alpha are passed by keyword: the positional (beta, input, alpha, ...)
        # form used previously is not accepted by current PyTorch releases.
        cosine = torch.baddbmm(self.bias.view(1, 1, 1), features, cls_weights.transpose(1, 2),
                               beta=1.0, alpha=1.0)
        return self.scale_cls * cosine

    def forward(self, features_supp, features_query):
        '''
        features_supp: (B, nKnovel * nExamplar, nFeat)
        features_query: (B, nKnovel * nTest, nFeat)
        returns: (B, nKnovel * nTest, nKnovel) classification scores
        '''
        B = features_supp.size(0)

        # Average the support features of each class; this relies on the support
        # features being grouped by class along dimension 1.
        weight = features_supp.view(B, self.nKnovel, -1, self.nFeat).mean(2)
        return self.apply_classification_weights(features_query, weight)


class ClassifierTrain(nn.Module):
    '''
    Cosine classifier over the base classes, trained jointly with the feature network.

    :param int nCls: number of base classes
    :param int nFeat: feature dimension
    :param float scaleCls: initial value of the learnable score scale
    '''
    def __init__(self, nCls, nFeat=640, scaleCls=10.):
        super(ClassifierTrain, self).__init__()

        self.nFeat = nFeat
        self.nCls = nCls

        # weights of base categories, Dimension nFeat * nCls
        weight = torch.FloatTensor(nFeat, nCls).normal_(0.0, np.sqrt(2.0/nFeat))
        self.weight = nn.Parameter(weight, requires_grad=True)

        # bias, Dimension 1 * nCls. It is registered (and saved in checkpoints) but not
        # added to the scores by applyWeightCosine.
        self.bias = nn.Parameter(torch.zeros(1, nCls), requires_grad=True)

        # Scale of cls (Heat Parameter)
        self.scaleCls = nn.Parameter(torch.full((1,), float(scaleCls)), requires_grad=True)

        # Method
        self.applyWeight = self.applyWeightCosine

    def getWeight(self):
        '''Return the (weight, bias, scaleCls) parameters.'''
        return self.weight, self.bias, self.scaleCls

    def applyWeightCosine(self, feature, weight, bias, scaleCls):
        '''
        Scaled cosine similarity between features and class weights.

        feature: (batchSize, nFeat), weight: (nFeat, nCls) -> (batchSize, nCls).
        `bias` is accepted but not used.
        '''
        feature = F.normalize(feature, p=2, dim=1, eps=1e-12)  ## Attention: normalized along 2nd dimension!!!
        weight = F.normalize(weight, p=2, dim=0, eps=1e-12)  ## Attention: normalized along 1st dimension!!!

        return scaleCls * torch.mm(feature, weight)

    def forward(self, feature):
        weight, bias, scaleCls = self.getWeight()
        return self.applyWeight(feature, weight, bias, scaleCls)


class BaseTrainer:
    '''
    Train the WRN-28-10 feature network with a cosine classifier on the base classes and
    evaluate it episodically on the validation episodes.

    :param trainLoader: loader yielding (inputs, targets) batches
    :param valLoader: loader yielding one validation episode per batch
    :param int nbCls: number of base classes
    :param int nClsEpisode: number of classes per validation episode
    :param int nFeat: accepted for compatibility; the feature dimension actually used is
                      the one returned by get_featnet
    :param string outDir: directory receiving the checkpoints (created if missing)
    :param list milestones: epochs at which the learning rate is multiplied by 0.1;
                            defaults to [50]
    :param int inputW: input image size, dimension W
    :param int inputH: input image size, dimension H
    :param bool cuda: whether to use gpu
    '''
    def __init__(self, trainLoader, valLoader, nbCls, nClsEpisode, nFeat,
                 outDir, milestones=None, inputW=80, inputH=80, cuda=False):

        self.trainLoader = trainLoader
        self.valLoader = valLoader
        self.outDir = outDir
        # None stands for the former default [50] without sharing one list between calls.
        self.milestones = [50] if milestones is None else milestones
        # makedirs also creates missing parent directories and tolerates an existing one.
        os.makedirs(self.outDir, exist_ok=True)

        # Define model
        self.netFeat, nFeat = get_featnet('WRN_28_10', inputW, inputH)
        self.netClassifier = ClassifierTrain(nbCls)
        self.netClassifierVal = ClassifierEval(nClsEpisode, nFeat)

        # GPU setting
        self.device = torch.device('cuda' if cuda else 'cpu')
        self.netFeat.to(self.device)
        self.netClassifier.to(self.device)
        self.netClassifierVal.to(self.device)

        self.criterion = nn.CrossEntropyLoss()
        self.bestAcc = 0

    def _newOptimizer(self, lr):
        '''Nesterov SGD over the feature network and the training classifier.'''
        return torch.optim.SGD(
            itertools.chain(self.netFeat.parameters(), self.netClassifier.parameters()),
            lr,
            momentum=0.9,
            weight_decay=5e-4,
            nesterov=True)

    def _trainStep(self, inputs, targets, losses, top1, top5):
        '''Run one optimisation step on a batch and update the running meters.'''
        inputs = to_device(inputs, self.device)
        targets = to_device(targets, self.device)

        self.optimizer.zero_grad()
        outputs = self.netClassifier(self.netFeat(inputs))
        loss = self.criterion(outputs, targets)

        loss.backward()
        self.optimizer.step()

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        batchSize = inputs.size()[0]
        losses.update(loss.item(), batchSize)
        top1.update(acc1[0].item(), batchSize)
        top5.update(acc5[0].item(), batchSize)

    def LrWarmUp(self, totalIter, lr):
        '''
        Warm up by raising the learning rate linearly towards `lr` over `totalIter`
        iterations, validating after each pass over the training loader, then create the
        optimizer and the MultiStepLR scheduler used for the regular epochs.

        :param int totalIter: number of warm-up iterations
        :param float lr: learning rate approached at the end of the warm-up
        :return: last validation top-1 accuracy (0 if no validation was run)
        '''
        msg = '\nLearning rate warming up'
        print(msg)

        self.optimizer = self._newOptimizer(1e-7)

        nbIter = 0
        lrUpdate = lr
        valTop1 = 0

        while nbIter < totalIter:
            # test() switches the networks to eval mode, so restore train mode each pass.
            self.netFeat.train()
            self.netClassifier.train()
            losses = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()

            for batchIdx, (inputs, targets) in enumerate(self.trainLoader):
                nbIter += 1
                if nbIter == totalIter:
                    break

                lrUpdate = nbIter / float(totalIter) * lr
                for g in self.optimizer.param_groups:
                    g['lr'] = lrUpdate

                self._trainStep(inputs, targets, losses, top1, top5)

                msg = 'Loss: {:.3f} | Lr : {:.5f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(
                    losses.avg, lrUpdate, top1.avg, top5.avg)
                progress_bar(batchIdx, len(self.trainLoader), msg)

            with torch.no_grad():
                valTop1 = self.test(0)

        # Fresh optimizer (no warm-up momentum) starting from the last warm-up rate.
        self.optimizer = self._newOptimizer(lrUpdate)

        self.lrScheduler = MultiStepLR(self.optimizer, milestones=self.milestones, gamma=0.1)
        return valTop1

    def train(self, epoch):
        '''
        Train for one epoch.

        :param int epoch: epoch index, used for display only
        :return: average loss, top-1 accuracy and top-5 accuracy over the epoch
        '''
        msg = '\nTrain at Epoch: {:d}'.format(epoch)
        print(msg)

        self.netFeat.train()
        self.netClassifier.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        for batchIdx, (inputs, targets) in enumerate(self.trainLoader):
            self._trainStep(inputs, targets, losses, top1, top5)

            msg = 'Loss: {:.3f} | Top1: {:.3f}% | Top5: {:.3f}%'.format(losses.avg, top1.avg, top5.avg)
            progress_bar(batchIdx, len(self.trainLoader), msg)

        return losses.avg, top1.avg, top5.avg

    def test(self, epoch):
        '''
        Evaluate on the validation episodes and save checkpoints: the "Last" files are
        always written, the "Best" files only when the accuracy improves.

        Callers wrap this method in torch.no_grad(); it does not disable gradients itself.

        :param int epoch: epoch index, used for display only
        :return: average top-1 accuracy over the validation episodes
        '''
        msg = '\nTest at Epoch: {:d}'.format(epoch)
        print(msg)

        self.netFeat.eval()
        self.netClassifierVal.eval()

        top1 = AverageMeter()

        for batchIdx, data in enumerate(self.valLoader):
            data = to_device(data, self.device)

            # The loader adds a leading batch dimension of size 1 to each episode.
            SupportTensor = data['SupportTensor'].squeeze(0)
            QueryTensor = data['QueryTensor'].squeeze(0)
            QueryLabel = data['QueryLabel'].squeeze(0)

            SupportFeat = self.netFeat(SupportTensor).unsqueeze(0)
            QueryFeat = self.netFeat(QueryTensor).unsqueeze(0)

            clsScore = self.netClassifierVal(SupportFeat, QueryFeat)
            clsScore = clsScore.view(QueryFeat.size()[1], -1)

            acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
            top1.update(acc1[0].item(), clsScore.size()[0])
            msg = 'Top1: {:.3f}%'.format(top1.avg)
            progress_bar(batchIdx, len(self.valLoader), msg)

        ## Save checkpoint.
        acc = top1.avg
        if acc > self.bestAcc:
            print('Saving Best')
            torch.save(self.netFeat.state_dict(), os.path.join(self.outDir, 'netFeatBest.pth'))
            torch.save(self.netClassifier.state_dict(), os.path.join(self.outDir, 'netClsBest.pth'))
            self.bestAcc = acc

        print('Saving Last')
        torch.save(self.netFeat.state_dict(), os.path.join(self.outDir, 'netFeatLast.pth'))
        torch.save(self.netClassifier.state_dict(), os.path.join(self.outDir, 'netClsLast.pth'))

        msg = 'Best Performance: {:.3f}'.format(self.bestAcc)
        print(msg)
        return top1.avg


#############################################################################################
## Parameters
parser = argparse.ArgumentParser(description='Base/FeatureNet Classification')
parser.add_argument('--lr', default=1e-1, type=float, help='learning rate')
# NOTE: --weightDecay and --momentum are parsed, but BaseTrainer builds its optimizers
# with fixed values (5e-4 and 0.9).
parser.add_argument('--weightDecay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
# Required: BaseTrainer creates this directory and writes every checkpoint into it.
parser.add_argument('--outDir', type=str, required=True, help='output directory')
parser.add_argument('--batchSize', type=int, default=64, help='batch size')
parser.add_argument('--nbEpoch', type=int, default=120, help='nb epoch')
parser.add_argument('--cuda', action='store_true', help='whether to use gpu')
parser.add_argument('--resumeFeatPth', type=str, help='resume feature Path')
parser.add_argument('--resumeClassifierPth', type=str, help='resume classifier Path')
parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet', 'CUB', 'Cifar', 'tieredImageNet'], help='Which dataset? Should modify normalization parameter')
# Lr WarmUp
parser.add_argument('--totalIter', type=int, default=6000, help='total iterations for learning rate warm')
# Validation
parser.add_argument('--nFeat', type=int, default=640, help='feature dimension')

args = parser.parse_args()
print(args)


#############################################################################################
## datasets
trainTransform, valTransform, inputW, inputH, \
        trainDir, valDir, testDir, episodeJson, nbCls = \
        dataset_setting(args.dataset, 1)

trainLoader = TrainLoader(args.batchSize, trainDir, trainTransform)
valLoader = ValLoader(episodeJson, valDir, inputW, inputH, valTransform, args.cuda)

# The episode geometry is read back from the first validation episode.
with open(episodeJson, 'r') as f:
    episodeInfo = json.load(f)

args.nClsEpisode = len(episodeInfo[0]['Support'])
args.nSupport = len(episodeInfo[0]['Support'][0])
args.nQuery = len(episodeInfo[0]['Query'][0])


#############################################################################################
## model

#milestones=[50, 80, 100]
milestones = [100] if args.dataset == 'CUB' else [50]  # More epochs for CUB since less iterations / epoch

baseModel = BaseTrainer(trainLoader, valLoader, nbCls,
                        args.nClsEpisode, args.nFeat, args.outDir, milestones,
                        inputW, inputH,
                        args.cuda)

## Load pretrained model if there is
# map_location lets a checkpoint saved on a gpu be resumed on a cpu-only machine.
if args.resumeFeatPth:
    baseModel.netFeat.load_state_dict(
        torch.load(args.resumeFeatPth, map_location=baseModel.device))
    msg = 'Loading weight from {}'.format(args.resumeFeatPth)
    print(msg)

if args.resumeClassifierPth:
    baseModel.netClassifier.load_state_dict(
        torch.load(args.resumeClassifierPth, map_location=baseModel.device))
    msg = 'Loading weight from {}'.format(args.resumeClassifierPth)
    print(msg)



#############################################################################################
## main
import shutil  # used only by the final renames below

valTop1 = baseModel.LrWarmUp(args.totalIter, args.lr)

testAccLog = []
trainAccLog = []

history = {'trainTop1': [], 'valTop1': [], 'trainTop5': [], 'trainLoss': []}

for epoch in range(args.nbEpoch):
    trainLoss, trainTop1, trainTop5 = baseModel.train(epoch)
    with torch.no_grad():
        valTop1 = baseModel.test(epoch)
    history['trainTop1'].append(trainTop1)
    history['trainTop5'].append(trainTop5)
    history['trainLoss'].append(trainLoss)
    history['valTop1'].append(valTop1)

    # Rewritten after every epoch so that an interrupted run keeps its history.
    with open(os.path.join(args.outDir, 'history.json'), 'w') as f:
        json.dump(history, f)
    baseModel.lrScheduler.step()

## Finish training!!!
# Tag the checkpoints with the accuracy they reached ("Best" files with the best
# validation accuracy, "Last" files with the final one), then tag the output directory.
# shutil.move replaces the former `os.system('mv ...')` shell strings, so paths with
# spaces work and a failed rename raises instead of being silently ignored.
bestTag = '{:.3f}'.format(baseModel.bestAcc)
lastTag = '{:.3f}'.format(valTop1)
renames = [
    (os.path.join(args.outDir, 'netFeatBest.pth'), os.path.join(args.outDir, 'netFeatBest{}.pth'.format(bestTag))),
    (os.path.join(args.outDir, 'netFeatLast.pth'), os.path.join(args.outDir, 'netFeatLast{}.pth'.format(lastTag))),
    (os.path.join(args.outDir, 'netClsBest.pth'), os.path.join(args.outDir, 'netClsBest{}.pth'.format(bestTag))),
    (os.path.join(args.outDir, 'netClsLast.pth'), os.path.join(args.outDir, 'netClsLast{}.pth'.format(lastTag))),
    (args.outDir, '{}_{}'.format(args.outDir, lastTag)),
]
for src, dst in renames:
    msg = 'mv {} {}'.format(src, dst)
    print(msg)
    shutil.move(src, dst)


# --------------------------------------------------------------------------------
# /networks.py:
# --------------------------------------------------------------------------------
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | # ============================================================================== 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import numpy as np 20 | 21 | class ConvBlock(nn.Module): 22 | def __init__(self, in_planes, out_planes): 23 | super(ConvBlock, self).__init__() 24 | self.layers = nn.Sequential() 25 | self.layers.add_module('Conv', nn.Conv2d(in_planes, out_planes, 26 | kernel_size=3, stride=1, padding=1, bias=False)) 27 | self.layers.add_module('BatchNorm', nn.BatchNorm2d(out_planes)) 28 | 29 | self.layers.add_module('ReLU', nn.ReLU(inplace=True)) 30 | 31 | self.layers.add_module( 32 | 'MaxPool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)) 33 | 34 | def forward(self, x): 35 | out = self.layers(x) 36 | return out 37 | 38 | class ConvNet_4_64(nn.Module): 39 | def __init__(self, inputW=80, inputH=80): 40 | super(ConvNet_4_64, self).__init__() 41 | 42 | conv_blocks = [] 43 | ## 4 blocks, each block conv + bn + relu + maxpool, with filter 64 44 | conv_blocks.append(ConvBlock(3, 64)) 45 | for i in range(3): 46 | conv_blocks.append(ConvBlock(64, 64)) 47 | 48 | self.conv_blocks = nn.Sequential(*conv_blocks) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 54 | elif isinstance(m, nn.BatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 57 | 58 | def forward(self, x): 59 | out = self.conv_blocks(x) 60 | out = out.view(out.size(0),-1) 61 | return out 62 | 63 | 64 | class BasicBlock(nn.Module): 65 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 66 | super(BasicBlock, self).__init__() 67 | self.bn1 = nn.BatchNorm2d(in_planes) 68 | self.relu1 = nn.ReLU(inplace=True) 69 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(out_planes) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 74 | padding=1, bias=False) 75 | 76 | self.droprate = dropRate 77 | if self.droprate > 0: 78 | self.dropoutLayer = nn.Dropout(p=self.droprate) 79 | self.equalInOut = (in_planes == out_planes) 80 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 81 | padding=0, bias=False) or None 82 | def forward(self, x): 83 | if not self.equalInOut: 84 | x = self.relu1(self.bn1(x)) 85 | else: 86 | out = self.relu1(self.bn1(x)) 87 | 88 | out = out if self.equalInOut else x 89 | out = self.conv1(out) 90 | if self.droprate > 0: 91 | out = self.dropoutLayer(out) 92 | #out = F.dropout(out, p=self.droprate, training=self.training) 93 | out = self.conv2(self.relu2(self.bn2(out))) 94 | 95 | if not self.equalInOut: 96 | return self.convShortcut(x) + out 97 | else: 98 | return x + out 99 | 100 | 101 | class NetworkBlock(nn.Module): 102 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 103 | super(NetworkBlock, self).__init__() 104 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 105 | 106 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 107 | layers = [] 108 | for i in range(nb_layers): 109 | in_plances_arg = i == 
0 and in_planes or out_planes 110 | stride_arg = i == 0 and stride or 1 111 | layers.append(block(in_plances_arg, out_planes, stride_arg, dropRate)) 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | return self.layer(x) 116 | 117 | 118 | class WideResNet(nn.Module): 119 | def __init__(self, depth=28, widen_factor=10, dropRate=0.0, userelu=True, isCifar=False): 120 | super(WideResNet, self).__init__() 121 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 122 | assert((depth - 4) % 6 == 0) 123 | n = (depth - 4) // 6 124 | block = BasicBlock 125 | # 1st conv before any network block 126 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) 127 | 128 | # 1st block 129 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) if isCifar \ 130 | else NetworkBlock(n, nChannels[0], nChannels[1], block, 2, dropRate) 131 | # 2nd block 132 | 133 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 134 | # 3rd block 135 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 136 | 137 | # global average pooling and classifier 138 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 139 | self.relu = nn.ReLU(inplace=True) if userelu else None 140 | self.nChannels = nChannels[3] 141 | 142 | for m in self.modules(): 143 | if isinstance(m, nn.Conv2d): 144 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 145 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 146 | elif isinstance(m, nn.BatchNorm2d): 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | 150 | def forward(self, x): 151 | out = self.conv1(x) 152 | out = self.block1(out) 153 | out = self.block2(out) 154 | out = self.block3(out) 155 | out = self.bn1(out) 156 | 157 | if self.relu is not None: 158 | out = self.relu(out) 159 | 160 | out = F.avg_pool2d(out, out.size(3)) 161 | out = out.view(-1, self.nChannels) 162 | 163 | return out 164 | 165 | 166 | def label_to_1hot(label, K): 167 | B, N = label.size() 168 | labels = [] 169 | for i in range(K): 170 | labels.append((label == i).unsqueeze(2)) 171 | return torch.cat(labels, -1).float() 172 | 173 | 174 | class dni_linear(nn.Module): 175 | def __init__(self, input_dims, dni_hidden_size=1024): 176 | super(dni_linear, self).__init__() 177 | self.layer1 = nn.Sequential( 178 | nn.Linear(input_dims, dni_hidden_size), 179 | nn.ReLU(), 180 | nn.BatchNorm1d(dni_hidden_size) 181 | ) 182 | self.layer2 = nn.Sequential( 183 | nn.Linear(dni_hidden_size, dni_hidden_size), 184 | nn.ReLU(), 185 | nn.BatchNorm1d(dni_hidden_size) 186 | ) 187 | self.layer3 = nn.Linear(dni_hidden_size, input_dims) 188 | 189 | def forward(self, x): 190 | out = self.layer1(x) 191 | out = self.layer2(out) 192 | out = self.layer3(out) 193 | return out 194 | 195 | 196 | class LinearDiag(nn.Module): 197 | def __init__(self, num_features, bias=False): 198 | super(LinearDiag, self).__init__() 199 | weight = torch.FloatTensor(num_features).fill_(1) # initialize to the identity transform 200 | self.weight = nn.Parameter(weight, requires_grad=True) 201 | 202 | if bias: 203 | bias = torch.FloatTensor(num_features).fill_(0) 204 | self.bias = nn.Parameter(bias, requires_grad=True) 205 | else: 206 | self.register_parameter('bias', None) 207 | 208 | def forward(self, X): 209 | assert(X.dim()==2 and X.size(1)==self.weight.size(0)) 210 | out = X * self.weight.expand_as(X) 211 | if self.bias is not None: 212 | out = out + self.bias.expand_as(out) 213 | 
class FeatExemplarAvgBlock(nn.Module):
    """
    Class-wise averaging of support features (prototype computation).

    :param int nFeat: feature dimension; accepted for interface
        compatibility, not used by the computation.
    """
    def __init__(self, nFeat):
        super(FeatExemplarAvgBlock, self).__init__()

    def forward(self, features_train, labels_train):
        """
        Average the features of every class.

        :param features_train: support features, B x nT x nC.
        :param labels_train: one-hot labels, B x nT x nK.
        :return: per-class mean feature, B x nK x nC. A class without any
            example is divided by zero and therefore yields NaN.
        """
        labels_train_transposed = labels_train.transpose(1, 2)
        # B x nK x nT @ B x nT x nC = B x nK x nC
        weight_novel = torch.bmm(labels_train_transposed, features_train)
        # divide each class sum by the number of examples of that class
        weight_novel = weight_novel.div(
            labels_train_transposed.sum(dim=2, keepdim=True).expand_as(weight_novel))
        return weight_novel


def get_featnet(architecture, inputW=80, inputH=80):
    """
    Build a feature network and report its output feature dimension.

    :param str architecture: ``'WRN_28_10'`` or ``'ConvNet_4_64'``.
    :param int inputW: input image width.
    :param int inputH: input image height.
    :return: tuple ``(network, feature_dim)`` where ``feature_dim`` is an int.
    :raises ValueError: if ``architecture`` is not one of the names above.
    """
    # 32-pixel inputs are treated as CIFAR; WideResNet receives this as isCifar
    isCifar = (inputW == 32) or (inputH == 32)
    if architecture == 'WRN_28_10':
        net = WideResNet(28, 10, isCifar=isCifar)
        return net, net.nChannels

    elif architecture == 'ConvNet_4_64':
        # Direct name lookup instead of eval(); the class is expected to be
        # defined at module level elsewhere in this file.
        net = ConvNet_4_64(inputW, inputH)
        # Integer division: the feature size is used as a layer dimension,
        # so it must be an int (true division produced a float).
        return net, 64 * (inputH // 2**4) * (inputW // 2**4)

    else:
        raise ValueError('No such feature net available!')
class ClassifierSIB(nn.Module):
    """
    Classifier whose weights are generated dynamically from synthetic gradient descent:
    Objective: E_{q(w | d_t^l, x_t)}[ -log p(y | feat(x), w) ] + KL( q(w|...) || p(w) )
    Note: we use a simple parameterization
        - q(w | d_t^l, x_t) = Dirac_Delta(w - theta^k),
          theta^k = synthetic_gradient_descent(x_t, theta^0)
          theta^0 = init_net(d_t)
        - p(w) = zero-mean Gaussian and implemented by weight decay
        - p(y=k | feat(x), w) = prototypical network

    :param int nKnovel: number of categories in a task/episode.
    :param int nFeat: feature dimension of the input feature.
    :param int q_steps: number of synthetic gradient steps to obtain q(w | d_t^l, x_t).
    """
    def __init__(self, nKnovel, nFeat, q_steps):
        super(ClassifierSIB, self).__init__()

        self.nKnovel = nKnovel
        self.nFeat = nFeat
        self.q_steps = q_steps

        # bias & scale of classifier p(y | x, theta)
        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)
        self.scale_cls = nn.Parameter(torch.full((1,), 10.0), requires_grad=True)

        # init_net lambda(d_t^l)
        self.favgblock = FeatExemplarAvgBlock(self.nFeat)
        self.wnLayerFavg = LinearDiag(self.nFeat)

        # grad_net (aka decoupled network interface) phi(x_t)
        self.dni = dni_linear(self.nKnovel, dni_hidden_size=self.nKnovel*8)

    def apply_classification_weights(self, features, cls_weights):
        """
        Given feature and weights, compute scaled cosine-similarity scores of nKnovel classes
        (B x n x nFeat, B x nKnovel x nFeat) -> B x n x nKnovel

        :param features: features of query set.
        :type features: torch.FloatTensor
        :param cls_weights: generated weights.
        :type cls_weights: torch.FloatTensor
        :return: classification scores, ``scale_cls * (bias + <f, w>)`` on L2-normalised inputs
        :rtype: torch.FloatTensor
        """
        features = F.normalize(features, p=2, dim=features.dim()-1, eps=1e-12)
        cls_weights = F.normalize(cls_weights, p=2, dim=cls_weights.dim()-1, eps=1e-12)

        # baddbmm(input, b1, b2) = input + b1 @ b2 (beta = alpha = 1 by default);
        # the old positional (beta, input, alpha, b1, b2) form is no longer accepted.
        cls_scores = self.scale_cls * torch.baddbmm(self.bias.view(1, 1, 1),
                                                    features, cls_weights.transpose(1, 2))
        return cls_scores

    def init_theta(self, features_supp, labels_supp_1hot):
        """
        Compute theta^0 from support set using classwise feature averaging.

        :param features_supp: support features, B x nSupp x nFeat.
        :type features_supp: torch.FloatTensor
        :param labels_supp_1hot: one-hot representation of labels in support set.
        :return: theta^0, B x nKnovel x nFeat
        """
        theta = self.favgblock(features_supp, labels_supp_1hot)  # B x nKnovel x nFeat
        batch_size, nKnovel, num_channels = theta.size()
        # LinearDiag expects a 2-D input, so flatten the batch and class dims
        theta = theta.view(batch_size * nKnovel, num_channels)
        theta = self.wnLayerFavg(theta)  # weight each feature differently
        theta = theta.view(-1, nKnovel, num_channels)
        return theta

    def refine_theta(self, theta, features_query, lr=1e-3):
        """
        Compute theta^k using synthetic gradient descent on x_t.

        :param theta: theta^0
        :type theta: torch.FloatTensor
        :param features_query: feat(x_t)
        :type features_query: torch.FloatTensor
        :param float lr: learning rate
        :return: theta^k
        :rtype: torch.FloatTensor
        """
        batch_size, num_examples = features_query.size()[:2]
        new_batch_dim = batch_size * num_examples

        for _ in range(self.q_steps):
            cls_scores = self.apply_classification_weights(features_query, theta)
            cls_scores = cls_scores.view(new_batch_dim, -1)  # B * n x nKnovel
            grad_logit = self.dni(cls_scores)  # B * n x nKnovel
            # Back-propagate the synthetic logit gradient to theta; the graph is
            # kept so the outer loss can differentiate through the update.
            grad = torch.autograd.grad([cls_scores], [theta],
                                       grad_outputs=[grad_logit],
                                       create_graph=True,
                                       retain_graph=True)[0]  # B x nKnovel x nFeat

            # perform synthetic GD
            theta = theta - lr * grad

        return theta

    def get_classification_weights(self, features_supp, labels_supp_1hot, features_query, lr):
        """
        Obtain weights for the query set using features_supp, labels_supp and features_query.
        features_supp, labels_supp --> self.init_theta
        features_query --> self.refine_theta

        :param features_supp: feat(x_t^l)
        :type features_supp: torch.FloatTensor
        :param labels_supp_1hot: one-hot representation of support labels
        :type labels_supp_1hot: torch.FloatTensor
        :param features_query: feat(x_t)
        :type features_query: torch.FloatTensor
        :param float lr: learning rate of synthetic GD
        :return: weights for query set
        :rtype: torch.FloatTensor
        """
        features_supp = F.normalize(features_supp, p=2, dim=features_supp.dim()-1, eps=1e-12)

        weight_novel = self.init_theta(features_supp, labels_supp_1hot)
        weight_novel = self.refine_theta(weight_novel, features_query, lr)

        return weight_novel

    def forward(self, features_supp, labels_supp, features_query, lr):
        """
        Compute classification scores.

        :param features_supp: B x nKnovel*nExamplar x nFeat
        :type features_supp: torch.FloatTensor
        :param labels_supp: B x nKnovel*nExamplar in [0, nKnovel-1]
        :type labels_supp: torch.FloatTensor
        :param features_query: B x nKnovel*nTest x nFeat
        :type features_query: torch.FloatTensor
        :param float lr: learning rate of synthetic GD
        :return: classification scores
        :rtype: torch.FloatTensor
        """
        labels_supp_1hot = label_to_1hot(labels_supp, self.nKnovel)
        cls_weights = self.get_classification_weights(features_supp, labels_supp_1hot, features_query, lr)
        cls_scores = self.apply_classification_weights(features_query, cls_weights)

        return cls_scores


if __name__ == "__main__":
    # Smoke test; __init__ takes no `nKall` argument, so it is not passed.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = ClassifierSIB(nKnovel=5, nFeat=512, q_steps=3)
    net = net.to(device)

    features_supp = torch.rand((8, 5 * 1, 512)).to(device)
    features_query = torch.rand((8, 5 * 15, 512)).to(device)
    labels_supp = torch.randint(5, (8, 5 * 1)).to(device)
    lr = 1e-3

    cls_scores = net(features_supp, labels_supp, features_query, lr)
    print(cls_scores.size())
def create_dirs(dirs):
    """
    Create directories given by a list if these directories are not found

    :param list dirs: directories
    :return: 0 on success; on failure the error is printed and the process
        exits with status -1 (``SystemExit``), so no value is returned.
    """
    try:
        for dir_ in dirs:
            # exist_ok avoids the check-then-create race of
            # `if not os.path.exists(...)` followed by `os.makedirs(...)`
            os.makedirs(dir_, exist_ok=True)
        return 0
    except OSError as err:
        print("Creating directories error: {0}".format(err))
        # same effect as the interactive-only builtin `exit(-1)`
        raise SystemExit(-1)


def get_config_from_json(json_file):
    """
    Get the config from a json file

    :param string json_file: json configuration file
    :return: EasyDict config
    """
    # parse the configurations from the config json file provided
    with open(json_file, 'r') as config_file:
        config_dict = json.load(config_file)

    # convert the dictionary to an attribute-accessible EasyDict
    config = EasyDict(config_dict)
    return config
def get_config_from_yaml(yaml_file):
    """
    Get the config from a yaml file

    :param string yaml_file: yaml configuration file
    :return: EasyDict config
    """
    with open(yaml_file) as fp:
        # safe_load: yaml.load() without an explicit Loader can construct
        # arbitrary objects and is rejected by PyYAML >= 6.
        config_dict = yaml.safe_load(fp)

    # convert the dictionary to an attribute-accessible EasyDict
    config = EasyDict(config_dict)
    return config


def get_args(argv=None):
    """
    Create argparser for frequent configurations.

    :param list argv: argument strings to parse; ``None`` (default) parses
        ``sys.argv[1:]`` as before.
    :return: parsed arguments (``argparse.Namespace``)
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument(
        '-c', '--config',
        metavar='C',
        default=None,
        help='The Configuration file')
    argparser.add_argument(
        '-k', '--steps',
        default=3,
        type=int,
        help='The number of SIB steps')
    argparser.add_argument(
        '-s', '--seed',
        default=100,
        type=int,
        help='The random seed')
    argparser.add_argument(
        '--gpu',
        default=0,
        type=int,
        help='GPU id')
    argparser.add_argument(
        '--ckpt',
        default=None,
        help='The path to ckpt')
    args = argparser.parse_args(argv)
    return args


def get_config(argv=None):
    """
    Create experimental config from argparse and config file.

    :param list argv: argument strings forwarded to :func:`get_args`;
        ``None`` (default) uses the process command line.
    :return: Configuration EasyDict
    :raises ValueError: if no configuration file was given.
    :raises Exception: if the configuration file is neither json nor yaml.
    """
    # read manual args
    args = get_args(argv)
    config_file = args.config
    if config_file is None:
        # previously crashed with AttributeError on None.endswith
        raise ValueError("A configuration file is required (use -c/--config)!")

    # load experimental configuration
    if config_file.endswith('json'):
        config = get_config_from_json(config_file)
    elif config_file.endswith(('yaml', 'yml')):
        config = get_config_from_yaml(config_file)
    else:
        raise Exception("Only .json and .yaml are supported!")

    # reset config from args
    config.nStep = args.steps
    config.seed = args.seed
    config.gpu = args.gpu
    config.test = args.ckpt is not None  # a checkpoint switches to test mode
    config.ckptPth = args.ckpt

    # create directories
    config.cacheDir = os.path.join("cache", '{}_{}shot_K{}_seed{}'.format(
        config.expName, config.nSupport, config.nStep, config.seed))
    config.logDir = os.path.join(config.cacheDir, 'logs')
    config.outDir = os.path.join(config.cacheDir, 'outputs')
    create_dirs([config.cacheDir, config.logDir, config.outDir])

    return config
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        """Reset the last value, running sum, count and average to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new value.

        :param val: latest value (e.g. a per-batch mean).
        :param int n: weight of ``val`` (e.g. the batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def getCi(accLog):
    """
    Mean and 95% confidence half-width of a list of accuracies.

    :param accLog: sequence of numbers.
    :return: ``(mean, 1.96 * std / sqrt(len(accLog)))`` using numpy's
        default (population) standard deviation.
    """
    mean = np.mean(accLog)
    std = np.std(accLog)
    ci95 = 1.96*std/np.sqrt(len(accLog))

    return mean, ci95


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k

    :param output: class scores, N x nClasses.
    :param target: ground-truth class indices, N.
    :param tuple topk: values of k to evaluate.
    :return: list with one 1-element tensor (accuracy in percent) per k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size()[0]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # maxk x N
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` comes from a transposed tensor and
            # its first k rows are not guaranteed to be contiguous.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.

    Iterates the dataset one sample at a time and averages the per-sample
    mean/std of each of the first three channels.

    :param dataset: dataset yielding ``(inputs, targets)`` pairs.
    :return: ``(mean, std)``, two tensors of length 3.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
def init_params(net):
    '''Init layer parameters.

    Conv2d: Kaiming-normal (fan_out) weights, zero bias.
    BatchNorm2d: weight 1, bias 0.
    Linear: normal(std=1e-3) weights, zero bias.

    :param net: module whose sub-modules are initialised in place.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` is ambiguous for multi-element tensors and raises
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width for the progress bar. `stty size` prints nothing when stdout
# is not attached to a terminal; fall back to 80 columns instead of crashing
# at import time.
try:
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
except ValueError:
    term_width = 80

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar on stdout.

    :param int current: zero-based index of the current step.
    :param int total: total number of steps.
    :param str msg: optional text appended after the timing information.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    parts = [' Step: %s' % format_time(step_time),
             ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)

    msg = ''.join(parts)
    sys.stdout.write(msg)
    # pad with spaces up to the terminal width
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    '''Format a duration using at most its two largest non-zero units.

    :param float seconds: duration in seconds.
    :return: string such as ``'1h1m'``, ``'1s250ms'`` or ``'0ms'``.
    '''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1  # emitted units + 1; at most two units are printed
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f
def set_random_seed(seed=3):
    """Seed Python's ``random``, numpy and torch (CPU and all CUDA devices)."""
    pyrandom.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def to_device(input, device):
    """
    Recursively move tensors inside nested containers to ``device``.

    :param input: a tensor, a string, a mapping or a sequence of those.
    :param device: target device, passed to ``Tensor.to``.
    :return: tensors moved to ``device``; strings unchanged; mappings are
        rebuilt as ``dict`` and sequences as ``list``.
    :raises TypeError: for any other input type.
    """
    # The ABC aliases `collections.Mapping` / `collections.Sequence` were
    # removed in Python 3.10; they live in `collections.abc`.
    from collections.abc import Mapping, Sequence

    if torch.is_tensor(input):
        return input.to(device=device)
    elif isinstance(input, str):
        return input
    elif isinstance(input, Mapping):
        return {k: to_device(sample, device=device) for k, sample in input.items()}
    elif isinstance(input, Sequence):
        return [to_device(sample, device=device) for sample in input]
    else:
        # the original message was a plain string, so the type was never shown
        raise TypeError(
            "Input must contain tensor, dict or list, found {}".format(type(input)))


def fast_hist(label_pred, label_true, n_class):
    """
    Confusion matrix between two integer label arrays.

    :param label_pred: numpy array of predicted labels.
    :param label_true: numpy array of true labels; entries outside
        ``[0, n_class)`` are ignored.
    :param int n_class: number of classes.
    :return: ``n_class x n_class`` array, rows = true label, columns = prediction.
    """
    mask = (label_true >= 0) & (label_true < n_class)
    return np.bincount(
        n_class * label_true[mask].astype(int) + label_pred[mask],
        minlength=n_class ** 2).reshape(n_class, n_class)


def convert_state_dict(state_dict):
    """
    Converts a state dict saved from a dataParallel module to normal
    module state_dict; returns a new dict, the input is left untouched

    :param dict state_dict: is the loaded DataParallel model_state
    :return: ``OrderedDict`` with a leading ``module.`` removed from each key;
        keys without that prefix are copied unchanged.
    """
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # only strip when the prefix is really there; blindly dropping the
        # first 7 characters corrupts keys of non-DataParallel checkpoints
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict


def get_logger(logdir, name):
    """
    Return a logger writing to stdout and to a timestamped file in ``logdir``.

    :param str logdir: existing directory for ``run_<timestamp>.log``.
    :param str name: logger name.
    :return: ``logging.Logger`` at INFO level.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name; adding
    # handlers again would duplicate every message.
    if logger.handlers:
        return logger

    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
    ts = ts.replace(":", "_").replace("-", "_")
    file_path = os.path.join(logdir, "run_{}.log".format(ts))
    file_hdlr = logging.FileHandler(file_path)
    file_hdlr.setFormatter(formatter)

    strm_hdlr = logging.StreamHandler(sys.stdout)
    strm_hdlr.setFormatter(formatter)

    logger.addHandler(file_hdlr)
    logger.addHandler(strm_hdlr)
    return logger