#### If you find this repository useful in your research, please cite:

    @article{shalam2024balanced,
      title={The Balanced-Pairwise-Affinities Feature Transform},
      author={Shalam, Daniel and Korman, Simon},
      journal={arXiv preprint arXiv:2407.01467},
      year={2024}
    }

---

## Acknowledgment
[Leveraging the Feature Distribution in Transfer-based Few-Shot Learning](https://github.com/yhu01/PT-MAP)

[S2M2 Charting the Right Manifold: Manifold Mixup for Few-shot Learning](https://arxiv.org/pdf/1907.12087.pdf)

[Few-Shot Learning via Embedding Adaptation with Set-to-Set Functions](https://arxiv.org/pdf/1812.03664.pdf)
--------------------------------------------------------------------------------
/methods/pt_map/evaluation/test_standard.py:
--------------------------------------------------------------------------------
import math

import torch
from tqdm import tqdm

from bpa import BPA


use_gpu = torch.cuda.is_available()


# ========================================
# loading data


def centerDatas(datas):
    # center and re-normalize the support (labelled) and query (unlabelled) splits separately
    datas[:, :n_lsamples] = datas[:, :n_lsamples, :] - datas[:, :n_lsamples].mean(1, keepdim=True)
    datas[:, :n_lsamples] = datas[:, :n_lsamples, :] / torch.norm(datas[:, :n_lsamples, :], 2, 2)[:, :, None]
    datas[:, n_lsamples:] = datas[:, n_lsamples:, :] - datas[:, n_lsamples:].mean(1, keepdim=True)
    datas[:, n_lsamples:] = datas[:, n_lsamples:, :] / torch.norm(datas[:, n_lsamples:, :], 2, 2)[:, :, None]
    return datas


def scaleEachUnitaryDatas(datas):
    # scale every feature vector to unit L2 norm
    norms = datas.norm(dim=2, keepdim=True)
    return datas / norms


def QRreduction(datas):
    # reduce the feature dimension to n_samples via the R factor of a QR decomposition
    ndatas = torch.linalg.qr(datas.permute(0, 2, 1)).R
    ndatas = ndatas.permute(0, 2, 1)
    return ndatas


class Model:
    def __init__(self, n_ways):
        self.n_ways = n_ways


# --------- GaussianModel
class GaussianModel(Model):
    def __init__(self, n_ways, lam):
        super(GaussianModel, self).__init__(n_ways)
        self.mus = None  # shape [n_runs][n_ways][n_nfeat]
        self.lam = lam

    def clone(self):
        other = GaussianModel(self.n_ways, self.lam)
        other.mus = self.mus.clone()
        return other

    def cuda(self):
        self.mus = self.mus.cuda()

    def initFromLabelledDatas(self):
        # initialize class centroids as the mean of the n_shot support features per class
        self.mus = ndatas.reshape(n_runs, n_shot + n_queries, n_ways, n_nfeat)[:, :n_shot].mean(1)

    def updateFromEstimate(self, estimate, alpha):
        Dmus = estimate - self.mus
        self.mus = self.mus + alpha * Dmus

    def compute_optimal_transport(self, M, r, c, epsilon=1e-6):
        r = r.cuda()
        c = c.cuda()
        n_runs, n, m = M.shape
        P = torch.exp(- self.lam * M)
        P /= P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1)

        u = torch.zeros(n_runs, n).cuda()
        maxiters = 1000
        iters = 1
        # Sinkhorn iterations: alternately rescale rows towards marginal r and columns towards marginal c
        while torch.max(torch.abs(u - P.sum(2))) > epsilon:
            u = P.sum(2)
            P *= (r / u).view((n_runs, -1, 1))
            P *= (c / P.sum(1)).view((n_runs, 1, -1))
            if iters == maxiters:
                break
            iters = iters + 1
        return P, torch.sum(P * M)

    def getProbas(self):
        # compute squared dist to centroids [n_runs][n_samples][n_ways]
        dist = (ndatas.unsqueeze(2) - self.mus.unsqueeze(1)).norm(dim=3).pow(2)

        p_xj = torch.zeros_like(dist)
        r = torch.ones(n_runs, n_usamples)
        c = torch.ones(n_runs, n_ways) * n_queries

        p_xj_test, _ = self.compute_optimal_transport(dist[:, n_lsamples:], r, c, epsilon=1e-6)
        p_xj[:, n_lsamples:] = p_xj_test

        # support samples keep their known one-hot labels
        p_xj[:, :n_lsamples].fill_(0)
        p_xj[:, :n_lsamples].scatter_(2, labels[:, :n_lsamples].unsqueeze(2), 1)

        return p_xj

    def estimateFromMask(self, mask):
        emus = mask.permute(0, 2, 1).matmul(ndatas).div(mask.sum(dim=1).unsqueeze(2))
        return emus


# =========================================
# MAP
# =========================================

class MAP:
    def __init__(self, alpha=None):
        self.verbose = False
        self.progressBar = False
        self.alpha = alpha

    def getAccuracy(self, probas):
        olabels = probas.argmax(dim=2)
        matches = labels.eq(olabels).float()
        acc_test = matches[:, n_lsamples:].mean(1)

        m = acc_test.mean().item()
        pm = acc_test.std().item() * 1.96 / math.sqrt(n_runs)  # 95% confidence interval
        return m, pm

    def performEpoch(self, model, epochInfo=None):
        p_xj = model.getProbas()
        self.probas = p_xj

        if self.verbose:
            print("accuracy from filtered probas", self.getAccuracy(self.probas))

        m_estimates = model.estimateFromMask(self.probas)

        # update centroids
        model.updateFromEstimate(m_estimates, self.alpha)

        if self.verbose:
            op_xj = model.getProbas()
            acc = self.getAccuracy(op_xj)
            print("output model accuracy", acc)

    def loop(self, model, n_epochs=20):
        self.probas = model.getProbas()
        if self.verbose:
            print("initialisation model accuracy", self.getAccuracy(self.probas))

        pb = None
        if self.progressBar:
            if type(self.progressBar) == bool:
                pb = tqdm(total=n_epochs)
            else:
                pb = self.progressBar

        for epoch in range(1, n_epochs + 1):
            if self.verbose:
                print("----- epoch[{:3d}] alpha: {:0.3f}".format(epoch, self.alpha))
            self.performEpoch(model, epochInfo=(epoch, n_epochs))
            if self.progressBar: pb.update()

        # get final accuracy and return it
        op_xj = model.getProbas()
        acc = self.getAccuracy(op_xj)
        return acc


if __name__ == '__main__':
    # ---- data loading
    n_shot = 5
    n_ways = 5
    n_queries = 15
    n_runs = 10000
    n_lsamples = n_ways * n_shot
    n_usamples = n_ways * n_queries
    n_samples = n_lsamples + n_usamples

    import FSLTask
    cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries}
    FSLTask.loadDataSet("miniimagenet")
    FSLTask.setRandomStates(cfg)
    ndatas = FSLTask.GenerateRunSet(cfg=cfg)
    ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1)
    labels = torch.arange(n_ways).view(1, 1, n_ways).expand(n_runs, n_shot + n_queries, n_ways).clone().view(n_runs, n_samples)

    # Power transform
    beta = 0.5
    ndatas = torch.pow(ndatas + 1e-6, beta)

    ndatas = QRreduction(ndatas)
    n_nfeat = ndatas.size(2)

    ndatas = scaleEachUnitaryDatas(ndatas)
    # trans-mean-sub
    ndatas = centerDatas(ndatas)

    USE_BPA = True  # set to False to run vanilla PT-MAP
    if USE_BPA:
        ndatas = BPA(ot_reg=0.2)(ndatas)  # BPA insertion

    # for PT-MAP, we need to scale and center again
    ndatas = scaleEachUnitaryDatas(ndatas)
    ndatas = centerDatas(ndatas)
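    # Shape note (added): BPA replaces each feature vector by its row of the
    # task's balanced pairwise-affinity matrix, so the feature dimension becomes
    # n_samples; hence the repeated scale/center above and the n_nfeat refresh below.
    # A minimal sanity check (hypothetical, not part of the original script):
    if USE_BPA:
        assert ndatas.size(2) == n_samples, 'BPA features should be n_samples-dimensional'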

    # expected results for 1 and 5 shots | W/o BPA: 1=82.11, 5=88.57 | W/ BPA_p: 1=82.62, 5=89.14 | W/ BPA_t: 1=84.69, 5=90.30 |

    n_nfeat = ndatas.size(2)
    print("size of the data...", ndatas.size())

    # switch to cuda
    ndatas = ndatas.cuda()
    labels = labels.cuda()

    # MAP
    lam = 10
    model = GaussianModel(n_ways, lam)
    model.initFromLabelledDatas()

    alpha = 0.2
    optim = MAP(alpha)

    optim.verbose = False
    optim.progressBar = True

    acc_test = optim.loop(model, n_epochs=20)

    print("final accuracy found {:0.2f} +- {:0.2f}".format(*(100 * x for x in acc_test)))
--------------------------------------------------------------------------------
/models/wrn_mixup_model.py:
--------------------------------------------------------------------------------
### dropout has been removed in this code; the original code had dropout ###
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.weight_norm import WeightNorm


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class distLinear(nn.Module):
    def __init__(self, indim, outdim):
        super(distLinear, self).__init__()
        self.L = nn.Linear(indim, outdim, bias=False)
        self.class_wise_learnable_norm = True  # see issues #4 and #8 on GitHub
        if self.class_wise_learnable_norm:
            WeightNorm.apply(self.L, 'weight', dim=0)  # split the weight update component into direction and norm

        if outdim <= 200:
            self.scale_factor = 2  # a fixed scale factor to scale the cosine values into a reasonably large input for softmax
        else:
            self.scale_factor = 10  # in omniglot, a larger scale factor is required to handle >1000 output classes.
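    # Descriptive note (added): forward() below implements a cosine classifier --
    # the input is L2-normalized and multiplied by the WeightNorm-ed class weight
    # directions, so each logit is scale_factor * cos(x, w_c) for class c.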
58 | 59 | def forward(self, x): 60 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 61 | x_normalized = x.div(x_norm + 0.00001) 62 | if not self.class_wise_learnable_norm: 63 | L_norm = torch.norm(self.L.weight.data, p=2, dim=1).unsqueeze(1).expand_as(self.L.weight.data) 64 | self.L.weight.data = self.L.weight.data.div(L_norm + 0.00001) 65 | cos_dist = self.L( 66 | x_normalized) # matrix product by forward function, but when using WeightNorm, this also multiply the cosine distance by a class-wise learnable norm, see the issue#4&8 in the github 67 | scores = self.scale_factor * (cos_dist) 68 | 69 | return scores 70 | 71 | 72 | class NetworkBlock(nn.Module): 73 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 74 | super(NetworkBlock, self).__init__() 75 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 76 | 77 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 78 | layers = [] 79 | for i in range(int(nb_layers)): 80 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 81 | return nn.Sequential(*layers) 82 | 83 | def forward(self, x): 84 | return self.layer(x) 85 | 86 | 87 | def to_one_hot(inp, num_classes): 88 | y_onehot = torch.FloatTensor(inp.size(0), num_classes) 89 | if torch.cuda.is_available(): 90 | y_onehot = y_onehot.cuda() 91 | 92 | y_onehot.zero_() 93 | x = inp.type(torch.LongTensor) 94 | if torch.cuda.is_available(): 95 | x = x.cuda() 96 | 97 | x = torch.unsqueeze(x, 1) 98 | y_onehot.scatter_(1, x, 1) 99 | 100 | return Variable(y_onehot, requires_grad=False) 101 | # return y_onehot 102 | 103 | 104 | def mixup_data(x, y, lam): 105 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 106 | 107 | batch_size = x.size()[0] 108 | index = torch.randperm(batch_size) 109 | if torch.cuda.is_available(): 110 | index = index.cuda() 111 | mixed_x = lam * x + (1 - lam) * x[index, :] 112 | y_a, y_b = y, y[index] 113 | 114 | return mixed_x, y_a, y_b, lam 115 | 116 | 117 | class WideResNet(nn.Module): 118 | def __init__(self, depth=28, widen_factor=10, num_classes=200, loss_type='dist', per_img_std=False, stride=1, 119 | dropRate=0.5): 120 | flatten = True 121 | super(WideResNet, self).__init__() 122 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 123 | assert ((depth - 4) % 6 == 0) 124 | n = (depth - 4) / 6 125 | block = BasicBlock 126 | # 1st conv before any network block 127 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 128 | padding=1, bias=False) 129 | # 1st block 130 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, stride, dropRate) 131 | # 2nd block 132 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 133 | # 3rd block 134 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 135 | # global average pooling and linear 136 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.nChannels = nChannels[3] 139 | 140 | self.num_classes = num_classes 141 | if flatten: 142 | self.final_feat_dim = 640 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv2d): 145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 146 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | def forward(self, x, target=None, mixup=False, mixup_hidden=True, mixup_alpha=None, lam=0.4, return_logits=True): 152 | if target is not None: 153 | if mixup_hidden: 154 | layer_mix = random.randint(0, 3) 155 | elif mixup: 156 | layer_mix = 0 157 | else: 158 | layer_mix = None 159 | 160 | out = x 161 | 162 | target_a = target_b = target 163 | 164 | if layer_mix == 0: 165 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 166 | 167 | out = self.conv1(out) 168 | out = self.block1(out) 169 | 170 | if layer_mix == 1: 171 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 172 | 173 | out = self.block2(out) 174 | 175 | if layer_mix == 2: 176 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 177 | 178 | out = self.block3(out) 179 | if layer_mix == 3: 180 | out, target_a, target_b, lam = mixup_data(out, target, lam=lam) 181 | 182 | out = self.relu(self.bn1(out)) 183 | out = F.avg_pool2d(out, out.size()[2:]) 184 | out = out.view(out.size(0), -1) 185 | if not return_logits: 186 | return out, target_a, target_b 187 | 188 | out1 = self.linear(out) 189 | return out, out1, target_a, target_b 190 | else: 191 | out = x 192 | out = self.conv1(out) 193 | out = self.block1(out) 194 | out = self.block2(out) 195 | out = self.block3(out) 196 | out = self.relu(self.bn1(out)) 197 | out = F.avg_pool2d(out, out.size()[2:]) 198 | out = out.view(out.size(0), -1) 199 | # if not return_logits: 200 | return out 201 | 202 | # out1 = self.linear(out) 203 | # return out, out1 204 | 205 | 206 | def wrn28_10(num_classes=200, loss_type='dist', dropout=0): 207 | model = WideResNet(depth=28, widen_factor=10, num_classes=num_classes, loss_type=loss_type, per_img_std=False, 208 | stride=1, dropRate=dropout) 209 | return model 210 | -------------------------------------------------------------------------------- /methods/pt_map/test_standard.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | from tqdm import tqdm 5 | import sys 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from methods.pt_map import FSLTask 11 | from bpa import BPA 12 | 13 | 14 | def bool_flag(s): 15 | """ 16 | Parse boolean arguments from the command line. 
17 | """ 18 | FALSY_STRINGS = {"off", "false", "0"} 19 | TRUTHY_STRINGS = {"on", "true", "1"} 20 | if s.lower() in FALSY_STRINGS: 21 | return False 22 | elif s.lower() in TRUTHY_STRINGS: 23 | return True 24 | else: 25 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 26 | 27 | 28 | def centerDatas(datas): 29 | datas[:, :n_lsamples] = datas[:, :n_lsamples, :] - datas[:, :n_lsamples].mean(1, keepdim=True) 30 | datas[:, n_lsamples:] = datas[:, n_lsamples:, :] - datas[:, n_lsamples:].mean(1, keepdim=True) 31 | return datas 32 | 33 | 34 | def scaleEachUnitaryDatas(datas): 35 | norms = datas.norm(dim=-1, keepdim=True) 36 | return datas / norms 37 | 38 | 39 | def QRreduction(datas): 40 | ndatas = torch.linalg.qr(datas.permute(0, 2, 1)).R 41 | ndatas = ndatas.permute(0, 2, 1) 42 | return ndatas 43 | 44 | 45 | class Model: 46 | def __init__(self, n_ways): 47 | self.n_ways = n_ways 48 | 49 | 50 | # --------- GaussianModel 51 | class GaussianModel(Model): 52 | def __init__(self, n_ways, lam, distance_metric: str = 'euclidean'): 53 | super(GaussianModel, self).__init__(n_ways) 54 | self.mus = None # shape [n_runs][n_ways][n_nfeat] 55 | self.lam = lam 56 | self.distance_metric = distance_metric 57 | 58 | def clone(self): 59 | other = GaussianModel(self.n_ways) 60 | other.mus = self.mus.clone() 61 | return self 62 | 63 | def cuda(self): 64 | self.mus = self.mus.cuda() 65 | 66 | def initFromLabelledDatas(self): 67 | self.mus = ndatas.reshape(n_runs, n_shot + n_queries, n_ways, n_nfeat)[:, :n_shot, ].mean(1) 68 | 69 | def updateFromEstimate(self, estimate, alpha): 70 | 71 | Dmus = estimate - self.mus 72 | self.mus = self.mus + alpha * Dmus 73 | 74 | def compute_optimal_transport(self, M, r, c, epsilon=1e-6): 75 | r = r.cuda() 76 | c = c.cuda() 77 | n_runs, n, m = M.shape 78 | P = torch.exp(- self.lam * M) 79 | P /= P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1) 80 | u = torch.zeros(n_runs, n).cuda() 81 | maxiters = 1000 82 | iters = 1 83 | # normalize this matrix 84 | while torch.max(torch.abs(u - P.sum(2))) > epsilon: 85 | u = P.sum(2) 86 | P *= (r / u).view((n_runs, -1, 1)) 87 | P *= (c / P.sum(1)).view((n_runs, 1, -1)) 88 | if iters == maxiters: 89 | break 90 | iters = iters + 1 91 | return P 92 | 93 | @staticmethod 94 | def _pairwise_dist(a, b): 95 | return (a.unsqueeze(2) - b.unsqueeze(1)).norm(dim=3).pow(2) 96 | 97 | def getProbas(self): 98 | global ndatas, n_nfeat 99 | # compute squared dist to centroids [n_runs][n_samples][n_ways] 100 | if self.distance_metric == 'cosine': 101 | dist = 1-torch.bmm(F.normalize(ndatas), F.normalize(self.mus.transpose(1, 2))) 102 | elif self.distance_metric == 'ce': 103 | dist = -torch.bmm(torch.log(ndatas + 1e-5), self.mus.transpose(1, 2)) 104 | else: 105 | dist = self._pairwise_dist(ndatas, self.mus) 106 | 107 | p_xj = torch.zeros_like(dist) 108 | r = torch.ones(n_runs, n_usamples, device='cuda') 109 | c = torch.ones(n_runs, n_ways, device='cuda') * n_queries 110 | p_xj_test = self.compute_optimal_transport(dist[:, n_lsamples:], r, c, epsilon=1e-4) 111 | p_xj[:, n_lsamples:] = p_xj_test 112 | 113 | p_xj[:, :n_lsamples].fill_(0) 114 | p_xj[:, :n_lsamples].scatter_(2, labels[:, :n_lsamples].unsqueeze(2), 1) 115 | 116 | return p_xj 117 | 118 | def estimateFromMask(self, mask): 119 | emus = mask.permute(0, 2, 1).matmul(ndatas).div(mask.sum(dim=1).unsqueeze(2)) 120 | return emus 121 | 122 | 123 | # ========================================= 124 | # MAP 125 | # ========================================= 126 | 127 | class MAP: 128 | def 
__init__(self, alpha=None, verbose: bool = False, progressBar: bool = False): 129 | self.verbose = verbose 130 | self.progressBar = progressBar 131 | self.alpha = alpha 132 | 133 | def getAccuracy(self, probas): 134 | olabels = probas.argmax(dim=2) 135 | matches = labels.eq(olabels).float() 136 | acc_test = matches[:, n_lsamples:].mean(1) 137 | 138 | m = acc_test.mean().item() 139 | pm = acc_test.std().item() * 1.96 / math.sqrt(n_runs) 140 | return m, pm 141 | 142 | def performEpoch(self, model, epochInfo=None): 143 | 144 | p_xj = model.getProbas() 145 | self.probas = p_xj 146 | 147 | m_estimates = model.estimateFromMask(self.probas) 148 | # update centroids 149 | model.updateFromEstimate(m_estimates, self.alpha) 150 | 151 | if self.verbose: 152 | op_xj = model.getProbas() 153 | acc = self.getAccuracy(op_xj) 154 | print("output model accuracy", acc) 155 | 156 | def loop(self, model, n_epochs=20): 157 | self.probas = model.getProbas() 158 | if self.verbose: 159 | print("initialisation model accuracy", self.getAccuracy(self.probas)) 160 | 161 | if self.progressBar: 162 | if type(self.progressBar) == bool: 163 | pb = tqdm(total=n_epochs) 164 | else: 165 | pb = self.progressBar 166 | 167 | for epoch in range(1, n_epochs + 1): 168 | self.performEpoch(model, epochInfo=(epoch, n_epochs)) 169 | if self.progressBar: pb.update() 170 | 171 | # get final accuracy and return it 172 | op_xj = model.getProbas() 173 | acc = self.getAccuracy(op_xj) 174 | return acc 175 | 176 | 177 | def get_args(): 178 | """ Description: Parses arguments at command line. """ 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument('--root', type=str, default='C:/Users/dani3/Documents/GitHub/SOT/') 181 | parser.add_argument('--features_path', type=str, 182 | default='/checkpoints/wrn/miniImagenet/WideResNet28_10_S2M2_R/last/output.plk') 183 | parser.add_argument('--dataset', type=str, default='miniimagenet', choices=['miniimagenet']) 184 | parser.add_argument('--num_way', type=int, default=5) 185 | parser.add_argument('--num_shot', type=int, default=5) 186 | parser.add_argument('--num_query', type=int, default=15) 187 | parser.add_argument('--num_runs', type=int, default=10000) 188 | parser.add_argument('--num_repeat', type=int, default=1, 189 | help='repeat the evaluation n times for averaging purposes.') 190 | parser.add_argument('--verbose', type=bool_flag, default=False) 191 | 192 | # BPA args 193 | parser.add_argument('--ot_reg', type=float, default=0.1) 194 | parser.add_argument('--sink_iters', type=int, default=10) 195 | parser.add_argument('--distance_metric', type=str, default='cosine') 196 | parser.add_argument('--norm_type', type=str, default='sinkhorn') 197 | parser.add_argument('--mask_diag', type=bool_flag, default=True) 198 | return parser.parse_args() 199 | 200 | 201 | if __name__ == '__main__': 202 | # ---- data loading 203 | args = get_args() 204 | n_shot = args.num_shot 205 | n_ways = args.num_way 206 | n_queries = args.num_query 207 | n_runs = args.num_runs 208 | n_lsamples = n_ways * n_shot 209 | n_usamples = n_ways * n_queries 210 | n_samples = n_lsamples + n_usamples 211 | 212 | cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries} 213 | FSLTask.loadDataSet(args.dataset, root=args.root, features_path=args.root + args.features_path) 214 | FSLTask.setRandomStates(cfg) 215 | ndatas = FSLTask.GenerateRunSet(cfg=cfg, end=n_runs) 216 | ndatas = ndatas.permute(0, 2, 1, 3).reshape(n_runs, n_samples, -1) 217 | labels = torch.arange(n_ways).view(1, 1, n_ways).expand(n_runs, n_shot + n_queries, 
n_ways).clone().view(n_runs, 218 | n_samples) 219 | labels = labels.cuda() 220 | ndatas = ndatas.cuda() 221 | 222 | # Power transform + QR + Normalize 223 | ndatas[:, ] = torch.pow(ndatas[:, ] + 1e-6, 0.5) 224 | ndatas = QRreduction(ndatas) 225 | ndatas = scaleEachUnitaryDatas(ndatas) 226 | # trans-mean-sub 227 | ndatas = centerDatas(ndatas) 228 | _ndatas = scaleEachUnitaryDatas(ndatas) 229 | # # transform data 230 | bpa = BPA( 231 | args.distance_metric, 232 | ot_reg=args.ot_reg, 233 | sinkhorn_iterations=args.sink_iters, 234 | mask_diag=args.mask_diag, 235 | ) 236 | 237 | for dm in ['euclidean']: 238 | print(f"DM {dm}") 239 | for mask_diag in [False, True]: 240 | bpa.mask_diag = mask_diag 241 | print(f"sot mask_diag {bpa.mask_diag }") 242 | # for max_temp in [False, True]: 243 | # print(f"sot max_temp {max_temp}") 244 | for reg in [0.1, 0.2, 0.3, 0.4, 0.5]: 245 | bpa.ot_reg = reg 246 | print(f"sot lambda {bpa.ot_reg}") 247 | 248 | ndatas = bpa(_ndatas) 249 | n_nfeat = ndatas.size(2) 250 | print("size of the datas...", ndatas.size()) 251 | 252 | # MAP 253 | model = GaussianModel(n_ways=n_ways, lam=10, distance_metric=dm) 254 | model.initFromLabelledDatas() 255 | 256 | optim = MAP(alpha=0.2, verbose=args.verbose) 257 | acc_test = optim.loop(model, n_epochs=20) 258 | 259 | print("final accuracy found {:0.2f} +- {:0.2f}".format(*(100 * x for x in acc_test))) 260 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | from time import time 5 | 6 | import torch 7 | 8 | import utils 9 | from bpa import BPA 10 | 11 | 12 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument('--seed', type=int, default=1, 19 | help="""Random seed.""") 20 | 21 | parser.add_argument('--root_path', type=str, default='./', 22 | help=""" Path to project root directory. """) 23 | parser.add_argument('--checkpoint_dir', type=str, default=None, 24 | help=""" Where to save model checkpoints. If None, it will automatically created. """) 25 | parser.add_argument('--dataset', type=str, default='miniimagenet', 26 | choices=['miniimagenet', 'cifar']) 27 | parser.add_argument('--data_path', type=str, default='./datasets/few_shot/miniimagenet', 28 | help="""Path to dataset root directory.""") 29 | 30 | parser.add_argument('--backbone', type=str, default='wrn', 31 | help="""Define which backbone network to use. """) 32 | parser.add_argument('--pretrained_path', type=str, default=False, 33 | help=""" Path to pretrained model, used for testing/fine-tuning. """) 34 | 35 | parser.add_argument('--eval', type=utils.bool_flag, default=False, 36 | help=""" If true, make evaluation on the *test set*. 37 | The amount of test episodes controlled by --test_episodes=<>""") 38 | parser.add_argument('--eval_freq', type=int, default=1, 39 | help=""" Evaluate training every n epochs. """) 40 | parser.add_argument('--eval_first', type=utils.bool_flag, default=False, 41 | help=""" Set to true to evaluate the model before training. Useful for fine-tuning. """) 42 | parser.add_argument('--num_workers', type=int, default=8) 43 | 44 | # wandb specific arguments 45 | parser.add_argument('--wandb', type=utils.bool_flag, default=False, 46 | help=""" Log data into wandb. 
""") 47 | parser.add_argument('--project', type=str, default='BPA', 48 | help=""" Project name in wandb. """) 49 | parser.add_argument('--entity', type=str, default='', 50 | help=""" Your wandb entity name. """) 51 | 52 | # few-shot specific arguments 53 | parser.add_argument('--method', type=str, default='pt_map_bpa', 54 | choices=['proto', 'proto_bpa', 'pt_map', 'pt_map_bpa'], 55 | help="""Specify which few-shot classifier to use.""") 56 | parser.add_argument('--train_way', type=int, default=5, 57 | help=""" Number of classes for each training task. """) 58 | parser.add_argument('--val_way', type=int, default=5, 59 | help=""" Number of classes for each validation/test task. """) 60 | parser.add_argument('--num_shot', type=int, default=5, 61 | help=""" Number of (labeled) support samples for each class. """) 62 | parser.add_argument('--num_query', type=int, default=15, 63 | help=""" Number of (un-labeled) query samples for each class. """) 64 | parser.add_argument('--train_episodes', type=int, default=200, 65 | help=""" Number of few-shot tasks for each epoch. """) 66 | parser.add_argument('--eval_episodes', type=int, default=400, 67 | help=""" Number of tasks to evaluate. """) 68 | parser.add_argument('--test_episodes', type=int, default=10000, 69 | help=""" Number of tasks to evaluate. """) 70 | parser.add_argument('--temperature', type=float, default=1., 71 | help=""" Temperature for ProtoNet. """) 72 | 73 | # training specific arguments 74 | parser.add_argument('--max_epochs', type=int, default=25, 75 | help="""Number of training/finetuning epochs. """) 76 | parser.add_argument('--optimizer', type=str, default='adam', 77 | help="""Optimizer""", choices=['adam', 'adamw', 'sgd']) 78 | parser.add_argument('--lr', type=float, default=5e-5, 79 | help="""Learning rate. """) 80 | parser.add_argument('--weight_decay', type=float, default=1e-4, 81 | help="""Weight decay. """) 82 | parser.add_argument('--dropout', type=float, default=0., 83 | help=""" Dropout probability. """) 84 | parser.add_argument('--momentum', type=float, default=0.9, 85 | help="""Momentum of SGD optimizer. """) 86 | parser.add_argument('--scheduler', type=str, default='step', 87 | help="""Learning rate scheduler. To disable the scheduler, use scheduler=''. """) 88 | parser.add_argument('--step_size', type=int, default=5, 89 | help="""Step size (in epochs) of StepLR scheduler. """) 90 | parser.add_argument('--gamma', type=float, default=0.5, 91 | help="""Gamma of StepLR scheduler. """) 92 | parser.add_argument('--augment', type=utils.bool_flag, default=False, 93 | help=""" Apply data augmentation. """) 94 | 95 | # BPA specific arguments 96 | parser.add_argument('--ot_reg', type=float, default=0.1, 97 | help=""" Sinkhorn entropy regularization. 98 | For few-shot methods, 0.1-0.2 seems to work best. 99 | For larger tasks (~10,000) samples, try to increase this value. """) 100 | parser.add_argument('--sink_iters', type=int, default=20, 101 | help=""" Number of Sinkhorn iterations. 102 | Usually small number (~ 5-10) is sufficient. """) 103 | parser.add_argument('--distance_metric', type=str, default='cosine', 104 | help=""" Distance metric for the OT cost matrix. """, 105 | choices=['cosine', 'euclidean']) 106 | parser.add_argument('--mask_diag', type=utils.bool_flag, default=True, 107 | help=""" If true, mask diagonal (self) values before and after the OT. """) 108 | parser.add_argument('--max_scale', type=utils.bool_flag, default=True, 109 | help=""" Scaling range of the BPA values to [0,1]. 
110 | This should always be True. """) 111 | 112 | return parser.parse_args() 113 | 114 | 115 | def main(): 116 | args = get_args() 117 | utils.set_seed(seed=args.seed) 118 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 119 | output_dir = utils.get_output_dir(args=args) 120 | 121 | # define datasets and loaders 122 | args.set_episodes = dict(train=args.train_episodes, val=args.eval_episodes, test=args.test_episodes) 123 | if not args.eval: 124 | train_dataloader = utils.get_dataloader(set_name='train', args=args, constant=False) 125 | val_dataloader = utils.get_dataloader(set_name='val', args=args, constant=True) 126 | else: 127 | val_dataloader = utils.get_dataloader(set_name='test', args=args, constant=False) 128 | train_dataloader = None 129 | 130 | # define model and load pretrained weights if available 131 | model = utils.get_model(args.backbone, args) 132 | model = model.to(device) 133 | utils.load_weights(model, args.pretrained_path) 134 | 135 | # BPA and few-shot classification method (e.g. proto, pt-map...) 136 | bpa = None 137 | if 'bpa' in args.method.lower(): 138 | bpa = BPA( 139 | distance_metric=args.distance_metric, 140 | ot_reg=args.ot_reg, 141 | mask_diag=args.mask_diag, 142 | sinkhorn_iterations=args.sink_iters, 143 | max_scale=args.max_scale 144 | ) 145 | fewshot_method = utils.get_method(args=args, bpa=bpa) 146 | 147 | # few-shot labels 148 | train_labels = utils.get_fs_labels(args.method, args.train_way, args.num_query, args.num_shot) 149 | val_labels = utils.get_fs_labels(args.method, args.val_way, args.num_query, args.num_shot) 150 | 151 | # initialized wandb 152 | if args.wandb: 153 | utils.init_wandb(exp_name=output_dir.split('/')[-1] if output_dir[-1] != '/' else output_dir.split('/')[-2], 154 | args=args) 155 | 156 | # define loss 157 | criterion = utils.get_criterion_by_method(method=args.method) 158 | 159 | # Test-set evaluation 160 | if args.eval: 161 | print(f"Evaluate model for {args.test_episodes} episodes... ") 162 | loss, acc = eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, 0, args, set_name='test') 163 | print("Final evaluation results:\nAccuracy={:.4f}, Loss={:.4f}".format(acc, loss)) 164 | exit(1) 165 | 166 | # define optimizer and scheduler 167 | optimizer, lr_scheduler = utils.get_optimizer_and_lr_scheduler(args=args, params=model.parameters()) 168 | 169 | # evaluate model before training 170 | if args.eval_first: 171 | print("Evaluate model before training... ") 172 | eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, -1, args, set_name='val') 173 | 174 | # train 175 | print("Start training...") 176 | best_acc = 0. 
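    # Checkpoint bookkeeping (descriptive comment, added): min_loss.pth is saved
    # whenever validation loss improves; note the 'elif' below means max_acc.pth
    # is only written on epochs where accuracy improves but loss does not.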
177 | best_loss = math.inf 178 | for epoch in range(args.max_epochs): 179 | print("[Epoch {}/{}]...".format(epoch, args.max_epochs)) 180 | 181 | # train 182 | train_one_epoch(model, train_dataloader, optimizer, fewshot_method, criterion, train_labels, epoch, args) 183 | if lr_scheduler is not None: 184 | lr_scheduler.step() 185 | 186 | # evaluate 187 | if epoch % args.eval_freq == 0: 188 | eval_loss, eval_acc = eval_one_epoch(model, val_dataloader, fewshot_method, criterion, val_labels, 189 | epoch, args, set_name='val') 190 | # save best model 191 | if eval_loss < best_loss: 192 | best_loss = eval_loss 193 | torch.save(model.state_dict(), os.path.join(output_dir, 'min_loss.pth')) 194 | elif eval_acc > best_acc: 195 | best_acc = eval_acc 196 | torch.save(model.state_dict(), os.path.join(output_dir, 'max_acc.pth')) 197 | 198 | # save last checkpoint 199 | torch.save(model.state_dict(), os.path.join(output_dir, 'last.pth')) 200 | 201 | 202 | def train_one_epoch(model, dataloader, optimizer, fewshot_method, criterion, labels, epoch, args): 203 | metric_logger = utils.MetricLogger(delimiter=" ") 204 | header = 'Train Epoch: [{}/{}]'.format(epoch, args.max_epochs) 205 | log_freq = 50 206 | n_batches = len(dataloader) 207 | 208 | model.train() 209 | for batch_idx, (images, _) in enumerate(metric_logger.log_every(dataloader, log_freq, header=header)): 210 | images = images.to(device) 211 | # extract features 212 | features = model(images) 213 | # few-shot classifier 214 | probas, accuracy = fewshot_method(features, labels=labels, mode='train') 215 | q_labels = labels if len(labels) == len(probas) else labels[-len(probas):] 216 | # loss 217 | loss = criterion(probas, q_labels) 218 | 219 | optimizer.zero_grad() 220 | loss.backward() 221 | optimizer.step() 222 | 223 | metric_logger.update(loss=loss.detach().item(), accuracy=accuracy) 224 | 225 | if batch_idx % log_freq == 0: 226 | utils.wandb_log( 227 | { 228 | 'train/step': batch_idx + (epoch * n_batches), 229 | 'train/loss_step': loss.item(), 230 | 'train/accuracy_step': accuracy, 231 | } 232 | ) 233 | 234 | print("Averaged stats:", metric_logger) 235 | utils.wandb_log( 236 | { 237 | 'lr': optimizer.param_groups[0]['lr'], 238 | 'train/epoch': epoch, 239 | 'train/loss': metric_logger.loss.global_avg, 240 | 'train/accuracy': metric_logger.accuracy.global_avg, 241 | } 242 | ) 243 | return metric_logger 244 | 245 | 246 | @torch.no_grad() 247 | def eval_one_epoch(model, dataloader, fewshot_method, criterion, labels, epoch, args, set_name): 248 | metric_logger = utils.MetricLogger(delimiter=" ") 249 | header = 'Validation:' if set_name == "val" else 'Test:' 250 | log_freq = 50 251 | 252 | n_batches = len(dataloader) 253 | model.eval() 254 | for batch_idx, (images, _) in enumerate(metric_logger.log_every(dataloader, log_freq, header=header)): 255 | images = images.to(device) 256 | # extract features 257 | features = model(images) 258 | # few-shot classifier 259 | probas, accuracy = fewshot_method(X=features, labels=labels, mode='val') 260 | q_labels = labels if len(labels) == len(probas) else labels[-len(probas):] 261 | # loss 262 | loss = criterion(probas, q_labels) 263 | metric_logger.update(loss=loss.detach().item(), accuracy=accuracy) 264 | 265 | print("Averaged stats:", metric_logger) 266 | utils.wandb_log( 267 | { 268 | '{}/epoch'.format(set_name): epoch, 269 | '{}/loss'.format(set_name): metric_logger.loss.global_avg, 270 | '{}/accuracy'.format(set_name): metric_logger.accuracy.global_avg, 271 | } 272 | ) 273 | return 
metric_logger.loss.global_avg, metric_logger.accuracy.global_avg 274 | 275 | 276 | if __name__ == '__main__': 277 | main() 278 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import argparse 4 | import random 5 | import time 6 | from collections import defaultdict, deque 7 | 8 | import numpy as np 9 | import torch 10 | from torch import optim 11 | from torch.utils.data import DataLoader 12 | 13 | from models.wrn_mixup_model import wrn28_10 14 | from models.resnet12 import Res12 15 | from datasets import MiniImageNet, CIFAR, CUB 16 | from datasets.samplers import CategoriesSampler 17 | from methods import PTMAPLoss, ProtoLoss 18 | 19 | try: 20 | import wandb 21 | HAS_WANDB = True 22 | except Exception as e: 23 | HAS_WANDB = False 24 | 25 | 26 | MODELS = dict( 27 | wrn=wrn28_10, resnet12=Res12 28 | ) 29 | DATASETS = dict( 30 | miniimagenet=MiniImageNet, cifar=CIFAR 31 | ) 32 | METHODS = dict( 33 | pt_map=PTMAPLoss, pt_map_bpa=PTMAPLoss, proto=ProtoLoss, proto_bpa=ProtoLoss, 34 | ) 35 | 36 | 37 | def get_model(model_name: str, args): 38 | """ 39 | Get the backbone model. 40 | """ 41 | arch = model_name.lower() 42 | if arch in MODELS.keys(): 43 | model = MODELS[arch](dropout=args.dropout) 44 | if torch.cuda.is_available(): 45 | torch.backends.cudnn.benchmark = True 46 | return model 47 | else: 48 | raise ValueError(f'Model {model_name} not implemented. available models are: {list(MODELS.keys())}') 49 | 50 | 51 | def get_dataloader(set_name: str, args: argparse, constant: bool = False): 52 | """ 53 | Get dataloader with categorical sampler for few-shot classification. 54 | """ 55 | num_episodes = args.set_episodes[set_name] 56 | num_way = args.train_way if set_name == 'train' else args.val_way 57 | 58 | # define dataset sampler and data loader 59 | data_set = DATASETS[args.dataset.lower()]( 60 | args.data_path, set_name, args.backbone, 61 | augment=set_name == 'train' and args.augment 62 | ) 63 | args.img_size = data_set.image_size 64 | 65 | data_sampler = CategoriesSampler( 66 | set_name, data_set.label, num_episodes, const_loader=constant, 67 | num_way=num_way, num_shot=args.num_shot, num_query=args.num_query, 68 | replace=set_name == 'train', 69 | ) 70 | return DataLoader( 71 | data_set, batch_sampler=data_sampler, num_workers=args.num_workers, pin_memory=not constant 72 | ) 73 | 74 | 75 | def get_optimizer_and_lr_scheduler(args, params): 76 | optimizer = get_optimizer(args, params) 77 | lr_scheduler = get_scheduler(args, optimizer) 78 | return optimizer, lr_scheduler 79 | 80 | 81 | def get_optimizer(args, params): 82 | """ 83 | Get optimizer. 84 | """ 85 | if args.optimizer == 'adam': 86 | return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 87 | elif args.optimizer == 'adamw': 88 | return optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay) 89 | elif args.optimizer == 'sgd': 90 | return optim.SGD(params, lr=args.lr, momentum=args.momentum, nesterov=True, weight_decay=args.weight_decay) 91 | else: 92 | raise ValueError(f'Optimizer {args.optimizer} not available.') 93 | 94 | 95 | def get_scheduler(args, optimizer: torch.optim): 96 | """ 97 | Get optimizer. 
98 | """ 99 | if not args.scheduler or args.scheduler == '': 100 | return None 101 | elif args.scheduler == 'step': 102 | return optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=args.step_size, gamma=args.gamma) 103 | else: 104 | raise ValueError(f'Error: LR-scheduler {args.scheduler} is not available.') 105 | 106 | 107 | def get_method(args, bpa=None): 108 | """ 109 | Get the few-shot classification method (e.g. pt_map). 110 | """ 111 | 112 | if args.method.lower() in METHODS.keys(): 113 | return METHODS[args.method.lower()](args=vars(args), bpa=bpa) 114 | else: 115 | raise ValueError(f'Not implemented method. available methods are: {METHODS.keys()}') 116 | 117 | 118 | def get_criterion_by_method(method: str): 119 | """ 120 | Get loss function based on the method. 121 | """ 122 | 123 | if 'pt_map' in method: 124 | return torch.nn.NLLLoss() 125 | elif 'proto' in method: 126 | return torch.nn.CrossEntropyLoss() 127 | else: 128 | raise ValueError(f'Not implemented criterion for this method. available methods are: {list(METHODS.keys())}') 129 | 130 | 131 | def init_wandb(exp_name: str, args): 132 | """ 133 | Initialize and returns wandb logger if args.wandb is True. 134 | """ 135 | if not args.wandb: 136 | return None 137 | assert HAS_WANDB, "Install wandb via - 'pip install wandb' in order to use wandb logging. " 138 | logger = wandb.init(project=args.project, entity=args.entity, name=exp_name, config=vars(args)) 139 | # define which metrics will be plotted against it 140 | logger.define_metric("train_loss", step_metric="epoch") 141 | logger.define_metric("train_accuracy", step_metric="epoch") 142 | logger.define_metric("val_loss", step_metric="epoch") 143 | logger.define_metric("val_accuracy", step_metric="epoch") 144 | return logger 145 | 146 | 147 | def wandb_log(results: dict): 148 | """ 149 | Log step to the logger without print. 150 | """ 151 | if HAS_WANDB and wandb.run is not None: 152 | wandb.log(results) 153 | 154 | 155 | def get_output_dir(args: argparse): 156 | """ 157 | Initialize the output dir. 158 | """ 159 | 160 | if args.checkpoint_dir is None: 161 | checkpoint_dir = os.path.join(args.root_path, 'checkpoints', args.dataset.lower(), args.backbone.lower(), args.method.lower()) 162 | 163 | name_str = f'-n_way={args.train_way}' \ 164 | f'-n_shot={args.num_shot}' \ 165 | f'-lr={args.lr}' \ 166 | f'-scheduler={args.scheduler}' \ 167 | f'-dropout={args.dropout}' 168 | 169 | checkpoint_dir = os.path.join(checkpoint_dir, name_str) 170 | else: 171 | checkpoint_dir = args.checkpoint_dir 172 | 173 | if args.eval: 174 | return checkpoint_dir 175 | 176 | while os.path.exists(checkpoint_dir): 177 | checkpoint_dir += f'-{np.random.randint(100)}' 178 | 179 | os.makedirs(checkpoint_dir, exist_ok=True) 180 | 181 | # write args to a file 182 | with open(os.path.join(checkpoint_dir, "args.txt"), 'w') as f: 183 | for key, value in vars(args).items(): 184 | f.write('%s:%s\n' % (key, value)) 185 | 186 | print("=> Checkpoints will be saved at:\n", checkpoint_dir) 187 | 188 | return checkpoint_dir 189 | 190 | 191 | def load_weights(model: torch.nn.Module, pretrained_path: str): 192 | """ 193 | Load pretrained weights from given path. 
194 | """ 195 | if not pretrained_path: 196 | return model 197 | 198 | print(f'Loading weights from {pretrained_path}') 199 | state_dict = torch.load(pretrained_path) 200 | sd_keys = list(state_dict.keys()) 201 | if 'state' in sd_keys: 202 | state_dict = state_dict['state'] 203 | for k in list(state_dict.keys()): 204 | if k.startswith('module.'): 205 | state_dict["{}".format(k[len('module.'):])] = state_dict[k] 206 | del state_dict[k] 207 | 208 | model.load_state_dict(state_dict, strict=False) 209 | 210 | elif 'params' in sd_keys: 211 | state_dict = state_dict['params'] 212 | for k in list(state_dict.keys()): 213 | if k.startswith('encoder.'): 214 | state_dict["{}".format(k[len('encoder.'):])] = state_dict[k] 215 | 216 | del state_dict[k] 217 | 218 | model.load_state_dict(state_dict, strict=True) 219 | else: 220 | model.load_state_dict(state_dict) 221 | 222 | print("Weights loaded successfully ") 223 | return model 224 | 225 | 226 | def get_fs_labels(method: str, num_way: int, num_query: int, num_shot: int): 227 | """ 228 | Prepare few-shot labels. For example for 5-way, 1-shot, 2-query: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, ...] 229 | """ 230 | n_samples = num_shot + num_query if 'map' in method else num_query 231 | labels = torch.arange(num_way, dtype=torch.int16).repeat(n_samples).type(torch.LongTensor) 232 | 233 | if torch.cuda.is_available(): 234 | return labels.cuda() 235 | else: 236 | return labels 237 | 238 | 239 | def bool_flag(s): 240 | """ 241 | Parse boolean arguments from the command line. 242 | """ 243 | FALSY_STRINGS = {"off", "false", "0"} 244 | TRUTHY_STRINGS = {"on", "true", "1"} 245 | if s.lower() in FALSY_STRINGS: 246 | return False 247 | elif s.lower() in TRUTHY_STRINGS: 248 | return True 249 | else: 250 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 251 | 252 | 253 | def print_and_log(results: dict, n: int = 0): 254 | """ 255 | Print and log current results. 256 | """ 257 | for key in results.keys(): 258 | # average by n if needed (n > 0) 259 | if n > 0 and 'time' not in key and '/epoch' not in key: 260 | results[key] = results[key] / n 261 | 262 | # print and log 263 | print(f'{key}: {results[key]:.4f}') 264 | 265 | if wandb.run is not None: 266 | wandb.log(results) 267 | 268 | 269 | def set_seed(seed: int): 270 | """ 271 | seed. 272 | """ 273 | random.seed(seed) 274 | np.random.seed(seed) 275 | torch.random.manual_seed(seed) 276 | torch.cuda.manual_seed(seed) 277 | 278 | 279 | class bcolors: 280 | HEADER = '\033[95m' 281 | OKBLUE = '\033[94m' 282 | OKCYAN = '\033[96m' 283 | OKGREEN = '\033[92m' 284 | WARNING = '\033[93m' 285 | FAIL = '\033[91m' 286 | ENDC = '\033[0m' 287 | BOLD = '\033[1m' 288 | UNDERLINE = '\033[4m' 289 | 290 | 291 | class SmoothedValue(object): 292 | """Track a series of values and provide access to smoothed values over a 293 | window or the global series average. 
294 | """ 295 | 296 | def __init__(self, window_size=20, fmt=None): 297 | if fmt is None: 298 | fmt = "{median:.6f} ({global_avg:.6f})" 299 | self.deque = deque(maxlen=window_size) 300 | self.total = 0.0 301 | self.count = 0 302 | self.fmt = fmt 303 | 304 | def update(self, value, n=1): 305 | self.deque.append(value) 306 | self.count += n 307 | self.total += value * n 308 | 309 | @property 310 | def median(self): 311 | d = torch.tensor(list(self.deque)) 312 | return d.median().item() 313 | 314 | @property 315 | def avg(self): 316 | d = torch.tensor(list(self.deque), dtype=torch.float32) 317 | return d.mean().item() 318 | 319 | @property 320 | def global_avg(self): 321 | return self.total / self.count 322 | 323 | @property 324 | def max(self): 325 | return max(self.deque) 326 | 327 | @property 328 | def value(self): 329 | return self.deque[-1] 330 | 331 | def __str__(self): 332 | return self.fmt.format( 333 | median=self.median, 334 | avg=self.avg, 335 | global_avg=self.global_avg, 336 | max=self.max, 337 | value=self.value) 338 | 339 | 340 | class MetricLogger(object): 341 | def __init__(self, delimiter="\t"): 342 | self.meters = defaultdict(SmoothedValue) 343 | self.delimiter = delimiter 344 | 345 | def update(self, **kwargs): 346 | for k, v in kwargs.items(): 347 | if isinstance(v, torch.Tensor): 348 | v = v.item() 349 | assert isinstance(v, (float, int)) 350 | self.meters[k].update(v) 351 | 352 | def __getattr__(self, attr): 353 | if attr in self.meters: 354 | return self.meters[attr] 355 | if attr in self.__dict__: 356 | return self.__dict__[attr] 357 | raise AttributeError("'{}' object has no attribute '{}'".format( 358 | type(self).__name__, attr)) 359 | 360 | def __str__(self): 361 | loss_str = [] 362 | for name, meter in self.meters.items(): 363 | loss_str.append( 364 | "{}: {}".format(name, str(meter)) 365 | ) 366 | return self.delimiter.join(loss_str) 367 | 368 | def add_meter(self, name, meter): 369 | self.meters[name] = meter 370 | 371 | def log_every(self, iterable, print_freq, header=None): 372 | i = 0 373 | if not header: 374 | header = '' 375 | 376 | start_time = time.time() 377 | end = time.time() 378 | iter_time = SmoothedValue(fmt='{avg:.6f}') 379 | data_time = SmoothedValue(fmt='{avg:.6f}') 380 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 381 | if torch.cuda.is_available(): 382 | log_msg = self.delimiter.join([ 383 | header, 384 | '[{0' + space_fmt + '}/{1}]', 385 | 'eta: {eta}', 386 | '{meters}', 387 | 'time: {time}', 388 | 'data: {data}', 389 | 'mem: {memory:.0f} ' 390 | 'mem reserved: {memory_res:.0f} ' 391 | ]) 392 | else: 393 | log_msg = self.delimiter.join([ 394 | header, 395 | '[{0' + space_fmt + '}/{1}]', 396 | 'eta: {eta}', 397 | '{meters}', 398 | 'time: {time}', 399 | 'data: {data}' 400 | ]) 401 | MB = 1024.0 * 1024.0 402 | for obj in iterable: 403 | data_time.update(time.time() - end) 404 | yield obj 405 | iter_time.update(time.time() - end) 406 | len_iterable = len(iterable) 407 | if i % print_freq == 0 or i == len_iterable - 1: 408 | eta_seconds = iter_time.global_avg * (len_iterable - i) 409 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 410 | if torch.cuda.is_available(): 411 | print(log_msg.format( 412 | i, len_iterable, eta=eta_string, 413 | meters=str(self), 414 | time=str(iter_time), data=str(data_time), 415 | memory=torch.cuda.memory_allocated() / MB, 416 | memory_res=torch.cuda.memory_reserved() / MB)) 417 | else: 418 | print(log_msg.format( 419 | i, len_iterable, eta=eta_string, 420 | meters=str(self), 421 | 
                    time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
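
# Minimal usage sketch (added for illustration; 'demo_logger' is hypothetical,
# not part of the training pipeline): SmoothedValue keeps a sliding window of
# recent values plus a global running average, and MetricLogger aggregates one
# SmoothedValue per metric name.
if __name__ == '__main__':
    demo_logger = MetricLogger(delimiter="  ")
    for step in range(5):
        demo_logger.update(loss=1.0 / (step + 1), accuracy=0.2 * step)
    print(demo_logger)  # e.g. "loss: <median> (<global_avg>)  accuracy: ..."
--------------------------------------------------------------------------------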