├── BMS.png
├── requirements.txt
├── README.md
├── FSLTask.py
└── test_standard_bms.py

/BMS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhu01/BMS/HEAD/BMS.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.19.3
Pillow==8.4.0
torch==1.6.0
torchvision==0.7.0
tqdm==4.51.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Squeezing Backbone Feature Distributions to the Max for Efficient Few-Shot Learning

This repository is the official implementation of [Squeezing Backbone Feature Distributions to the Max for Efficient Few-Shot Learning](https://arxiv.org/pdf/2110.09446.pdf).

![](BMS.png)

## Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

## Pre-trained Models

You can download the extracted features here:

- [Extracted novel class features](https://drive.google.com/file/d/1dSKfO0mMz0KzObXIU930JFbhb7OK1qUU/view?usp=sharing) on miniImageNet, tieredImageNet, CUB and CIFAR-FS.

- Create a 'checkpoint' folder.
- Untar the downloaded file and move it into the 'checkpoint' folder.


## Boosted Min-size Sinkhorn


> 📋 To launch the BMS algorithm, run:
```
python test_standard_bms.py --dataset [mini/tiered/cub/cifar-fs] --model wrn --method [BMS/BMS_] --preprocess PEME --shot [1/5] --epoch [0/20/40]
```

## Results

Our model achieves the following performance on the benchmarks below:


| Dataset        | 1-shot Accuracy | 5-shot Accuracy |
| -------------- | --------------- | --------------- |
| miniImageNet   | 83.35 ± 0.25%   | 89.53 ± 0.13%   |
| tieredImageNet | 86.07 ± 0.25%   | 91.09 ± 0.14%   |
| CUB            | 91.91 ± 0.18%   | 94.62 ± 0.09%   |
| CIFAR-FS       | 87.83 ± 0.22%   | 91.20 ± 0.15%   |

## References

[Leveraging the Feature Distribution in Transfer-based Few-Shot Learning](https://arxiv.org/pdf/2006.03806.pdf)

[Sinkhorn Distances: Lightspeed Computation of Optimal Transport](https://papers.nips.cc/paper/4927-sinkhorn-distances-lightspeed-computation-of-optimal-transport.pdf)

[SimpleShot: Revisiting Nearest-Neighbor Classification for Few-Shot Learning](https://arxiv.org/pdf/1911.04623.pdf)

[Notes on optimal transport](https://github.com/MichielStock/Teaching/tree/master/Optimal_transport)
--------------------------------------------------------------------------------
/FSLTask.py:
--------------------------------------------------------------------------------
import os
import pickle
import numpy as np
import torch
# from tqdm import tqdm

# ========================================================
# Useful paths
_datasetFeaturesFiles = {
    "cub_wrn": "./checkpoint/cub/output.plk",
    "tiered_wrn": "./checkpoint/tieredImagenet/output.plk",
    "mini_wrn": "./checkpoint/miniImagenet/output.plk",
    "cifar-fs_wrn": "./checkpoint/cifar-fs/output.plk",
}
_cacheDir = "./cache"
_maxRuns = 10000
_min_examples = -1

# ========================================================
# Module internal functions and variables

_randStates = None
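# _randStates holds one cached numpy RandomState per run so that the same episodes
# can be regenerated deterministically; _rsCfg remembers the configuration
# (shot/ways/queries/runs) those states were generated for.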
_rsCfg = None


def _load_pickle(file):
    with open(file, 'rb') as f:
        data = pickle.load(f)
        labels = [np.full(shape=len(data[key]), fill_value=key)
                  for key in data]
        data = [features for key in data for features in data[key]]
        dataset = dict()
        dataset['data'] = torch.FloatTensor(np.stack(data, axis=0))
        dataset['labels'] = torch.LongTensor(np.concatenate(labels))
        return dataset


# =========================================================
# Callable variables and functions from outside the module

data = None
labels = None
dsName = None


def loadDataSet(dsname):
    if dsname not in _datasetFeaturesFiles:
        raise NameError('Unknown dataset: {}'.format(dsname))

    global dsName, data, labels, _randStates, _rsCfg, _min_examples
    dsName = dsname
    _randStates = None
    _rsCfg = None

    # Loading data from files on computer
    # home = expanduser("~")
    dataset = _load_pickle(_datasetFeaturesFiles[dsname])

    # Computing the number of items per class in the dataset
    _min_examples = dataset["labels"].shape[0]
    for i in range(dataset["labels"].shape[0]):
        if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0:
            _min_examples = min(_min_examples, torch.where(
                dataset["labels"] == dataset["labels"][i])[0].shape[0])
    print("Guaranteed number of items per class: {:d}\n".format(_min_examples))

    # Generating data tensors
    data = torch.zeros((0, _min_examples, dataset["data"].shape[1]))
    labels = dataset["labels"].clone()
    while labels.shape[0] > 0:
        indices = torch.where(dataset["labels"] == labels[0])[0]
        data = torch.cat([data, dataset["data"][indices, :]
                          [:_min_examples].view(1, _min_examples, -1)], dim=0)
        indices = torch.where(labels != labels[0])[0]
        labels = labels[indices]
    print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(
        data.shape[0], data.shape[1], data.shape[2]))


def GenerateRun(iRun, cfg, regenRState=False, generate=True):
    global _randStates, data, _min_examples
    if not regenRState:
        np.random.set_state(_randStates[iRun])

    classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
    shuffle_indices = np.arange(_min_examples)
    dataset = None
    if generate:
        dataset = torch.zeros(
            (cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
    for i in range(cfg['ways']):
        shuffle_indices = np.random.permutation(shuffle_indices)
        if generate:
            dataset[i] = data[classes[i], shuffle_indices,
                              :][:cfg['shot']+cfg['queries']]

    return dataset


def ClassesInRun(iRun, cfg):
    global _randStates, data
    np.random.set_state(_randStates[iRun])

    classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
    return classes


def setRandomStates(cfg):
    global _randStates, _rsCfg
    if _rsCfg == cfg:
        return

    rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}_r{}".format(
        dsName, cfg['shot'], cfg['queries'], cfg['ways'], cfg['runs']))
    if not os.path.exists(rsFile):
        print("{} does not exist, regenerating it...".format(rsFile))
        np.random.seed(0)
        _randStates = []
        for iRun in range(cfg['runs']):
            _randStates.append(np.random.get_state())
            GenerateRun(iRun, cfg, regenRState=True, generate=False)
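        # The loop above only advances numpy's RNG (generate=False), so each saved
        # state corresponds to one episode. Make sure the cache directory exists
        # before persisting (assumes ./cache may be absent on a fresh clone).
        os.makedirs(_cacheDir, exist_ok=True)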
        torch.save(_randStates, rsFile)
    else:
        print("reloading random states from file....")
        _randStates = torch.load(rsFile)
    _rsCfg = cfg


def GenerateRunSet(cfg=None):
    global dataset, _maxRuns
    if cfg is None:
        cfg = {"shot": 1, "ways": 5, "queries": 15, "runs": _maxRuns}

    start = 0
    end = cfg['runs']

    setRandomStates(cfg)
    print("generating task from {} to {}".format(start, end))

    dataset = torch.zeros(
        (end-start, cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
    for iRun in range(end-start):
        dataset[iRun] = GenerateRun(iRun, cfg)

    return dataset


# define a main code to test this module
if __name__ == "__main__":

    print("Testing Task loader for Few Shot Learning")
    loadDataSet('mini_wrn')  # keys are of the form '<dataset>_<model>'

    cfg = {"shot": 1, "ways": 5, "queries": 15, "runs": 10}
    setRandomStates(cfg)

    run = GenerateRun(9, cfg)  # runs are indexed 0..cfg['runs']-1
    print("First call:", run[:2, :2, :2])
    print(run.size())
--------------------------------------------------------------------------------
/test_standard_bms.py:
--------------------------------------------------------------------------------
import collections
import pickle
import random
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
from numpy import linalg as LA
import argparse

use_gpu = torch.cuda.is_available()

def parse_args():
    parser = argparse.ArgumentParser(description='few-shot script')
    parser.add_argument('--dataset', default='mini', help='dataset: mini/tiered/cub/cifar-fs')
    parser.add_argument('--model', default='wrn', help='model: wrn')
    parser.add_argument('--shot', default=1, type=int, help='1/5')
    parser.add_argument('--run', default=10000, type=int, help='600/1000/10000')
    parser.add_argument('--way', default=5, type=int)
    parser.add_argument('--query', default=15, type=int)
    parser.add_argument('--method', default='BMS', help='BMS/BMS_')
    parser.add_argument('--preprocess', default='PEME')
    parser.add_argument('--step', default=20, type=int)
    parser.add_argument('--mmt', default=0.8, type=float)
    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--lam', default=8.5, type=float)
    parser.add_argument('--epoch', default=0, type=int)
    return parser.parse_args()

class DataSet:
    data = None
    labels = None

    def __init__(self, data=None, n_shot=1, n_ways=5, n_queries=15):
        self.data = data
        self.n_shot = n_shot
        self.n_ways = n_ways
        if self.data is not None:
            self.n_runs = data.size(0)
            self.n_samples = data.size(1)
            self.n_feat = data.size(2)
            self.n_lsamples = n_ways*n_shot
            self.n_queries = n_queries
            self.n_usamples = n_ways*n_queries
            self.generateLabels()
            if self.n_samples != self.n_lsamples + self.n_usamples:
                print("Invalid settings: queries incorrect wrt size")
                sys.exit()

    def cuda(self):
        self.data = self.data.cuda()
        self.labels = self.labels.cuda()

    def cpu(self):
        self.data = self.data.cpu()
        self.labels = self.labels.cpu()

    def generateLabels(self):
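        # Labels are 0..n_ways-1 tiled once per (shot + query) block, which matches
        # the permute/reshape layout produced by getRunSet() below.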
        self.labels = torch.arange(self.n_ways)\
                           .view(1, 1, self.n_ways)\
                           .expand(self.n_runs, self.n_shot+self.n_queries, self.n_ways)\
                           .clone().view(self.n_runs, self.n_samples)

    def printState(self):
        print("DataSet: {}-shot, {}-ways, {}-queries, {}-runs, {}-feats".format(
            self.n_shot, self.n_ways, self.n_queries, self.n_runs, self.n_feat))
        print("\t {}-labelled {}-unlabelled {}-tot".format(
            self.n_lsamples, self.n_usamples, self.n_samples))

class BaseModel:
    def __init__(self, ds):
        self.ds = ds

    # should not be overridden!
    # customise behaviour by overriding getScores() instead
    def getProbas(self, scoresRescale=1, forceLabelledToOne=True):
        scores = self.getScores()
        p_xj = F.softmax(-scores*scoresRescale, dim=2)
        if forceLabelledToOne:
            p_xj[:, :self.ds.n_lsamples].fill_(0)
            p_xj[:, :self.ds.n_lsamples].scatter_(2, self.ds.labels[:, :self.ds.n_lsamples].unsqueeze(2), 1)

        return p_xj

class TrainedLinRegModel(BaseModel):
    def __init__(self, ds, useBias=False):
        super(TrainedLinRegModel, self).__init__(ds)
        self.mus = None  # shape [n_runs][n_ways][n_feat]
        self.ds = ds
        self.weights = torch.Tensor(ds.n_runs, ds.n_feat, ds.n_ways)
        self.scalers = torch.Tensor(ds.n_runs)

    def cuda(self):
        self.mus = self.mus.cuda()
        self.weights = self.weights.cuda()
        self.scalers = self.scalers.cuda()

    def cpu(self):
        self.mus = self.mus.cpu()
        self.weights = self.weights.cpu()
        self.scalers = self.scalers.cpu()

    # initialise params of linReg from pre-computed mus
    def initParams(self):
        self.scalers.fill_(1)
        self.mus = self.mus/self.mus.norm(dim=2, keepdim=True)
        cdata = self.mus.permute(0, 2, 1)
        self.weights.copy_(cdata)

    def initFromLabelledDatas(self):
        ds = self.ds
        self.mus = ds.data\
            .reshape(ds.n_runs, ds.n_shot+ds.n_queries, ds.n_ways, ds.n_feat)[:, :ds.n_shot]\
            .mean(1)

        self.initParams()

    def getParameters(self, asNNParameter=False):
        params = [self.weights, self.scalers]
        if asNNParameter:
            pp = [nn.Parameter(t.clone()) for t in params]
            params = pp

        return params

    def getScores(self):
        params = self.getParameters()
        ds = self.ds
        scores = 1 - ds.data.matmul(self.weights)

        return scores

    def train(self, wsamples, trainCfg, wmasker=None, updateWeights=True):
        ds = self.ds

        # computing emus
        emus = wsamples.permute(0, 2, 1).matmul(ds.data).div(wsamples.sum(dim=1).unsqueeze(2))
        emus = emus/emus.norm(dim=2, keepdim=True)
        cdata = emus.permute(0, 2, 1)  # [10000, 640, 5]

        mparameters = self.getParameters(asNNParameter=True)

        # get initialisation from centroids estimate
        mparameters[0].data.copy_(cdata)
        mparameters[1].data.fill_(1)
        optimizer = torch.optim.SGD(mparameters, lr=trainCfg['lr'], momentum=trainCfg['mmt'])

        for epoch in range(trainCfg['epochs']):
            optimizer.zero_grad()
            scores = ds.data.matmul(mparameters[0])
            scores = scores / mparameters[0].norm(dim=1, keepdim=True)
            scores = scores * mparameters[1].unsqueeze(1).unsqueeze(1)

            output = F.log_softmax(scores, dim=2)
            loss_train = -output.mul(wsamples).sum(2).mean(1).sum(0)

            loss_train.backward()
            optimizer.step()
            mparameters[0].data.div_(mparameters[0].data.norm(dim=1, keepdim=True))
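
        # After the short SGD loop the re-normalised class weights can optionally be
        # copied back into the model for the next outer BMS iteration.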
        if updateWeights:
            self.weights.copy_(mparameters[0].data)


# =========================================
# Optimization routines
# =========================================

class Optimizer:
    def __init__(self, ds, wmasker=None):
        self.ds = ds
        self.wmasker = wmasker
        if self.wmasker is None:
            self.wmasker = SimpleWMask(ds)

    def getAccuracy(self, probas):
        olabels = probas.argmax(dim=2)
        matches = self.ds.labels.eq(olabels).float()
        acc_test = matches[:, self.ds.n_lsamples:].mean(1)

        m = acc_test.mean().item()
        pm = acc_test.std().item() * 1.96 / math.sqrt(acc_test.size(0))
        return m, pm

# =========================================
# Class to define the sample mask for centroid computations
# =========================================

class SimpleWMask:
    """ Class that selects which samples are used for centroid computation.
    The default implementation uses probas as the wmask.
    """
    def __init__(self, ds):
        self.ds = ds
        self.doSinkhorn = True
        self.nIter = 50

    def BMS_(self, p, nIter=None):
        global epoch
        ds = self.ds
        target = ds.n_queries
        if nIter is None:
            nIter = self.nIter
        op = p[:, ds.n_lsamples:]

        for iter in range(nIter):
            wp = op.div(op.sum(1, keepdim=True)/target)
            op = wp.div(wp.sum(2, keepdim=True))

        wm = p.clone()
        wm[:, ds.n_lsamples:] = op
        wm[:, :ds.n_lsamples] = 0
        wm[:, :ds.n_lsamples].scatter_(2, ds.labels[:, :ds.n_lsamples].unsqueeze(2), 1)
        return wm

    def BMS(self, p, minSize, nIter=None):
        global epoch
        ds = self.ds
        wm_total = []
        n_runs, n, m = p.shape

        if nIter is None:
            nIter = self.nIter
        op = p[:, ds.n_lsamples:]

        for iter in range(nIter):
            op = op.div(op.sum(2, keepdim=True))
            mask = (op.sum(1, keepdim=True) < minSize.unsqueeze(1).unsqueeze(1)).all(dim=1).int()  # [10000, 5]
            mask_inv = (op.sum(1, keepdim=True) >= minSize.unsqueeze(1).unsqueeze(1)).all(dim=1).int()
            temp = op * mask.unsqueeze(1)
            temp = temp.div((temp+1e-20).sum(1, keepdim=True)/minSize.unsqueeze(1).unsqueeze(1))
            op = op * mask_inv.unsqueeze(1) + temp * mask.unsqueeze(1)

        wm = p.clone()
        wm[:, ds.n_lsamples:] = op
        wm[:, :ds.n_lsamples] = 0
        wm[:, :ds.n_lsamples].scatter_(2, ds.labels[:, :ds.n_lsamples].unsqueeze(2), 1)
        return wm

    def fastOT(self, p, minSize):
        if params.method == 'BMS':
            q = self.BMS(p, minSize)
        elif params.method == 'BMS_':
            q = self.BMS_(p)
        return q

    def getMinsize(self, p):
        mask = p.new_zeros(p.shape)
        mask.scatter_(2, p.argmax(dim=2, keepdim=True), 1)
        nsamplesInCluster = mask.sum(1)
        minSize = nsamplesInCluster.min(1)[0]

        return minSize

    def getWMask(self, probas, minSize, epochInfo=None):

        if self.doSinkhorn:
            return self.fastOT(probas, minSize)
        else:
            return probas


# ========================================
# loading data

def reloadRuns(shot=1, n_ways=5, n_queries=15, n_runs=10000):
    (datas, labels) = torch.load("cache/runs{}_s{}_w{}_q{}_r{}".format(n_runs, shot, n_ways, n_queries, n_runs))
    print("-- loaded datas and labels size:")
    print(datas.size())
    print(labels.size())

    return datas, labels
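
# Illustrative sketch only (not called by the script): the core of the BMS_ variant
# above is a Sinkhorn-like alternation on the query part of the soft assignments,
# alternately scaling each class column to a target mass (n_queries) and
# re-normalising each sample row to sum to one.
def _sinkhorn_balance_sketch(op, target, n_iter=50):
    # op: [n_runs, n_ways*n_queries, n_ways] soft assignments of the query samples
    for _ in range(n_iter):
        op = op.div(op.sum(1, keepdim=True) / target)  # column (class) mass -> target
        op = op.div(op.sum(2, keepdim=True))           # row (sample) mass -> 1
    return op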

def save_pickle(file, data):
    with open(file, 'wb') as f:
        pickle.dump(data, f)

def load_pickle(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

def centerDatas(datas):

    datas = datas - datas.mean(1, keepdim=True)
    return datas

def scaleEachUnitaryDatas(datas):

    norms = datas.norm(dim=2, keepdim=True)
    return datas/norms

def QRreduction(datas):

    # ndatas = torch.linalg.qr(datas.permute(0, 2, 1), mode='reduced').R
    ndatas = torch.qr(datas.permute(0, 2, 1)).R
    ndatas = ndatas.permute(0, 2, 1)
    return ndatas

def getRunSet(n_shot, n_ways, n_queries, n_runs, preprocess='PEME', dataset='mini', model='wrn'):
    import FSLTask
    cfg = {'shot': n_shot, 'ways': n_ways, 'queries': n_queries, 'runs': n_runs}
    load = dataset + '_' + model
    FSLTask.loadDataSet(load)
    ndatas = FSLTask.GenerateRunSet(cfg=cfg)
    ds = DataSet(ndatas.permute(0, 2, 1, 3).reshape(n_runs, -1, ndatas.size(3)),
                 n_shot=n_shot, n_ways=n_ways, n_queries=n_queries)

    if preprocess == 'R':
        print("--- preprocess: QR decomposition")
        ds.data = QRreduction(ds.data)
        ds.n_feat = ds.data.size(2)
        return ds

    if 'P' not in preprocess:
        print("--- preprocess: QR decomposition")
        ds.data = QRreduction(ds.data)
        ds.n_feat = ds.data.size(2)

    for p in preprocess:
        if p == 'P':
            print("--- preprocess: Power transform")
            ds.data = torch.sqrt(ds.data+1e-6)

        elif p == "M":
            print("--- preprocess: Mean subtraction")
            ds.data = centerDatas(ds.data)
            print("--- preprocess: QR decomposition")
            ds.data = QRreduction(ds.data)
            ds.n_feat = ds.data.size(2)
        elif p == "E":
            print("--- preprocess: Euclidean normalization")
            ds.data = scaleEachUnitaryDatas(ds.data)
        else:
            print("unknown preprocessing!!")
            sys.exit()

    return ds

if __name__ == '__main__':
    # ---- data loading
    params = parse_args()
    n_shots = params.shot
    n_runs = params.run
    n_ways = params.way
    n_queries = params.query
    dataset = getRunSet(n_shot=n_shots, n_ways=n_ways, n_queries=n_queries, n_runs=n_runs, preprocess=params.preprocess, dataset=params.dataset, model=params.model)
    dataset.printState()
    dataset.cuda()

    wmasker = SimpleWMask(dataset)
    optim = Optimizer(dataset, wmasker=wmasker)
    mm = TrainedLinRegModel(dataset)
    mm.initFromLabelledDatas()
    mm.cuda()

    probas = mm.getProbas()
    init_acc = optim.getAccuracy(probas)
    print("initialisation model accuracy", init_acc)

    minSize = (torch.ones(n_runs) * n_shots).cuda()
    for iter in range(params.step):
        probas = mm.getProbas(scoresRescale=params.lam)
        probas = wmasker.getWMask(probas, minSize)
        minSize = wmasker.getMinsize(probas)

        trainCfg = {'lr': params.lr, 'mmt': params.mmt, 'epochs': params.epoch}
        mm.train(probas, trainCfg, wmasker=wmasker, updateWeights=True)
        sink_acc = optim.getAccuracy(probas)
        print(iter, sink_acc)

--------------------------------------------------------------------------------