├── __init__.py
├── requirements.txt
├── configs
│   ├── cifar10_KWNG.yml
│   └── cifar100_KWNG.yml
├── LICENSE
├── data_loader.py
├── gaussian.py
├── train.py
├── README.md
├── kwng.py
├── resnet.py
└── trainer.py
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | python==3.6.2
2 | torch==1.2.0
3 | torchvision==0.4.0
4 | numpy==1.17.2
5 | pyyaml  # provides the yaml module used by train.py
6 | tensorboardX==1.8
--------------------------------------------------------------------------------
/configs/cifar10_KWNG.yml:
--------------------------------------------------------------------------------
1 | seed: 0
2 | log_name: 'cifar10'
3 |
4 | dataset: 'cifar10'
5 |
6 | network: 'ResNet18IllCond'
7 | #network: 'ResNet18'
8 |
9 |
10 | estimator: 'KWNG'
11 | num_basis: 5
12 | with_diag_mat : 1
13 | epsilon: 0.00001
14 | log_bandwidth: 0.0
15 | grad_clip: True
16 | kernel: 'gaussian'
17 |
18 |
19 | # dumping
20 | dumping_freq : 5
21 | max_red: 0.75
22 | min_red: 0.25
23 | reduction_coeff : 0.85
24 |
25 | optimizer: 'sgd'
26 | use_scheduler: True
27 | scheduler : 'MultiStepLR'
28 | milestone : '50,100,200,300'
29 | lr_decay : 0.1
30 | lr: 10.
31 | clip_grad : True
32 | b_size: 128
33 |
34 |
--------------------------------------------------------------------------------
/configs/cifar100_KWNG.yml:
--------------------------------------------------------------------------------
1 | seed: 0
2 | log_name: 'cifar100'
3 |
4 | dataset: 'cifar100'
5 |
6 | network: 'ResNet18IllCond'
7 | #network: 'ResNet18'
8 |
9 | estimator: 'KWNG'
10 | num_basis: 5
11 | with_diag_mat : 1
12 | epsilon: 0.00001
13 | log_bandwidth: 0.0
14 | grad_clip: True
15 | kernel: 'gaussian'
16 |
17 |
18 | # dumping
19 | dumping_freq : 5
20 | max_red: 0.75
21 | min_red: 0.25
22 | reduction_coeff : 0.85
23 |
24 | optimizer: 'sgd'
25 | use_scheduler: True
26 | scheduler : 'MultiStepLR'
27 | milestone : '100,200,300'
28 | lr_decay : 0.1
29 | lr: 1.
30 | clip_grad : True
31 | b_size: 128
32 |
33 |
34 |
35 | #python train.py --device=2 --config='configs/cifar100_KWNG.yml'
36 |
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Michael Arbel
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import os 6 | 7 | 8 | def CIFARLoader( data_path, train_batch = 128, test_batch = 100): 9 | 10 | data_path = os.path.join(data_path,'CIFAR10') 11 | 12 | transform_train = transforms.Compose([ 13 | transforms.RandomCrop(32, padding=4), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.ToTensor(), 16 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 17 | ]) 18 | 19 | transform_test = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 22 | ]) 23 | 24 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train) 25 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch, shuffle=True, num_workers=2) 26 | 27 | testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test) 28 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 29 | 30 | data_loaders = {"train": trainloader, "val": testloader, "test":testloader} 31 | return data_loaders 32 | 33 | 34 | 35 | 36 | def CIFAR100Loader( data_path, train_batch = 128, test_batch = 100): 37 | data_path = os.path.join(data_path,'CIFAR100') 38 | 39 | transform_train = transforms.Compose([ 40 | transforms.RandomCrop(32, padding=4), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=[n/255. 44 | for n in [129.3, 124.1, 112.4]], std=[n/255. for n in [68.2, 65.4, 70.4]]), 45 | ]) 46 | 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=[n/255. 50 | for n in [129.3, 124.1, 112.4]], std=[n/255. 
for n in [68.2, 65.4, 70.4]]), 51 | ]) 52 | 53 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transform_train) 54 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch, shuffle=True, num_workers=2) 55 | 56 | testset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transform_test) 57 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 58 | 59 | data_loaders = {"train": trainloader, "val": testloader, "test":testloader} 60 | return data_loaders 61 | 62 | 63 | -------------------------------------------------------------------------------- /gaussian.py: -------------------------------------------------------------------------------- 1 | import torch as tr 2 | 3 | 4 | class Gaussian(object): 5 | def __init__(self, D, log_sigma, dtype = tr.float32, device = 'cpu'): 6 | self.D =D 7 | self.params = log_sigma 8 | self.dtype = dtype 9 | self.device = device 10 | self.adaptive= False 11 | self.params_0 = log_sigma 12 | 13 | def get_exp_params(self): 14 | return pow_10(self.params, dtype= self.dtype, device = self.device) 15 | def update_params(self,log_sigma): 16 | self.params = log_sigma 17 | 18 | 19 | def square_dist(self, X, Y): 20 | # Squared distance matrix of pariwise elements in X and basis 21 | # Inputs: 22 | # X : N by d matrix of data points 23 | # basis : M by d matrix of basis points 24 | # output: N by M matrix 25 | 26 | return self._square_dist( X, Y) 27 | 28 | def kernel(self, X,Y): 29 | 30 | # Gramm matrix between vectors X and basis 31 | # Inputs: 32 | # X : N by d matrix of data points 33 | # basis : M by d matrix of basis points 34 | # output: N by M matrix 35 | 36 | return self._kernel(self.params,X, Y) 37 | 38 | def dkdxdy(self,X,Y,mask=None): 39 | return self._dkdxdy(self.params,X,Y,mask=mask) 40 | # Private functions 41 | 42 | def _square_dist(self,X, Y): 43 | n_x,d = X.shape 44 | n_y,d = Y.shape 45 | dist = -2*tr.einsum('mr,nr->mn',X,Y) + tr.sum(X**2,1).unsqueeze(-1).repeat(1,n_y) + tr.sum(Y**2,1).unsqueeze(0).repeat(n_x,1) # tr.einsum('m,n->mn', tr.ones([ n_x],dtype=self.dtype, device = self.device),tr.sum(Y**2,1)) 46 | 47 | return dist 48 | 49 | def _kernel(self,log_sigma,X,Y): 50 | N,d = X.shape 51 | sigma = pow_10(log_sigma,dtype= self.dtype, device = self.device) 52 | tmp = self._square_dist( X, Y) 53 | dist = tr.max(tmp,tr.zeros_like(tmp)) 54 | if self.adaptive: 55 | ss = tr.mean(dist).clone().detach() 56 | dist = dist/(ss+1e-5) 57 | return tr.exp(-0.5*dist/sigma) 58 | 59 | 60 | def _dkdxdy(self,log_sigma,X,Y,mask=None): 61 | # X : [M,T] 62 | # Y : [N,R] 63 | 64 | # dkdxdy , dkdxdy2 = [M,N,T,R] 65 | # dkdx = [M,N,T] 66 | N,d = X.shape 67 | sigma = pow_10(log_sigma,dtype= self.dtype, device = self.device) 68 | gram = self._kernel(log_sigma,X, Y) 69 | 70 | D = (X.unsqueeze(1) - Y.unsqueeze(0))/sigma 71 | 72 | I = tr.ones( D.shape[-1],dtype=self.dtype, device = self.device)/sigma 73 | 74 | dkdy = tr.einsum('mn,mnr->mnr', gram,D) 75 | dkdx = -dkdy 76 | 77 | 78 | 79 | if mask is None: 80 | D2 = tr.einsum('mnt,mnr->mntr', D, D) 81 | I = tr.eye( D.shape[-1],dtype=self.dtype, device = self.device)/sigma 82 | dkdxdy = I - D2 83 | dkdxdy = tr.einsum('mn, mntr->mntr', gram, dkdxdy) 84 | else: 85 | D_masked = tr.einsum('mnt,mt->mn', D, mask) 86 | D2 = tr.einsum('mn,mnr->mnr', D_masked, D) 87 | 88 | dkdxdy = tr.einsum('mn,mr->mnr', gram, mask)/sigma -tr.einsum('mn, mnr->mnr', gram, D2) 89 | dkdx = tr.einsum('mnt,mt->mn',dkdx,mask) 
90 | 91 | return dkdxdy, dkdx, gram 92 | 93 | 94 | 95 | 96 | def pow_10(x, dtype=tr.float32,device = 'cpu'): 97 | 98 | return tr.pow(tr.tensor(10., dtype=dtype, device = device),x) 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import argparse 5 | import yaml 6 | 7 | from trainer import Trainer 8 | torch.backends.cudnn.deterministic = True 9 | torch.backends.cudnn.benchmark=False 10 | 11 | 12 | 13 | def make_flags(args,config_file): 14 | if config_file: 15 | config = yaml.load(open(config_file)) 16 | dic = vars(args) 17 | all(map(dic.pop, config)) 18 | dic.update(config) 19 | return args 20 | 21 | parser = argparse.ArgumentParser(description='KWNG') 22 | 23 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 24 | parser.add_argument('--log_name', default = '',type= str, help='log name') 25 | parser.add_argument('--log_dir', default = '',type= str, help='log directory for summaries and checkpoints') 26 | parser.add_argument('--dataset', default = 'cifar10',type= str, help='name of the dataset to use cifar10 or cifar100') 27 | parser.add_argument('--data_dir', default = 'data',type= str, help='directory to the dataset') 28 | parser.add_argument('--log_in_file', action = 'store_true' , help='log output in a file') 29 | 30 | parser.add_argument('--device', default = 0 ,type= int, help='gpu device') 31 | parser.add_argument('--seed', default = 0 ,type= int , help='seed for randomness') 32 | parser.add_argument('--dtype', default = '32' ,type= str , help='32 for float32 and 64 for float64') 33 | parser.add_argument('--total_epochs', default=350, type=int, help='total number of epochs') 34 | 35 | parser.add_argument('--network', default = 'ResNet18IllCond' ,type= str, help='classifier network: Ill-conditioned case:ResNet18IllCond and well-conditioned case:ResNet18 ') 36 | parser.add_argument('--num_classes', default = 10 ,type= int , help='number of classes') 37 | parser.add_argument('--criterion', default = 'cross_entropy' ,type= str , help='top level loss') 38 | 39 | # Optimizer parameters 40 | parser.add_argument('--optimizer', default = 'sgd',type= str, help='sgd') 41 | parser.add_argument('--b_size', default = 128 ,type= int, help='batch size') 42 | parser.add_argument('--lr', default=.1, type=float, help='learning rate') 43 | parser.add_argument('--momentum', default=0., type=float, help='momentum') 44 | parser.add_argument('--weight_decay', default=0., type=float, help='weight decay') 45 | 46 | parser.add_argument('--lr_decay', default = 0.1 ,type= float , help='decay factor for lr') 47 | parser.add_argument('--clip_grad', action = 'store_true', help=' clip the gradient by norm ') 48 | 49 | # Scheduler parameters 50 | parser.add_argument('--use_scheduler', default = 'store_true' , help='schedule the lr') 51 | parser.add_argument('--scheduler', default ='MultiStepLR' ,type= str , help=' scheduler ') 52 | parser.add_argument('--milestone', default = '100,200,300' ,type= str , help='decrease schedule for lr ') 53 | 54 | # estimator of the natural gradient 55 | parser.add_argument('--estimator', default = 'KWNG',type= str, help='proposed estimator') 56 | parser.add_argument('--kernel', default = 'gaussian' ,type= 
str, help=' the kernel used in the estimator ')
57 | parser.add_argument('--log_bandwidth', default = 0. ,type= float , help=' log bandwidth of the kernel ')
58 | parser.add_argument('--epsilon', default = 1e-5 ,type= float, help=' Initial value for damping ')
59 | parser.add_argument('--num_basis', default = 5 ,type= int , help='number of basis for KWNG ')
60 |
61 | # Dumping parameters
62 | parser.add_argument('--dumping_freq', default = 5 ,type= int , help=' update epsilon each dumping_freq iterations ')
63 | parser.add_argument('--reduction_coeff', default = 0.85 ,type= float , help=' increase or decrease epsilon by reduction_coeff factor')
64 | parser.add_argument('--min_red', default = 0.25 ,type= float , help=' min threshold for reduction factor')
65 | parser.add_argument('--max_red', default = 0.75 ,type= float , help=' max threshold for reduction factor')
66 | parser.add_argument('--with_diag_mat', default=1, type=int, help=' 1: Use the norm of the jacobian for non isotropic damping')
67 |
68 | parser.add_argument('--config', default ='' ,type= str , help='config file for the run ')
69 | parser.add_argument('--with_sacred', default =False ,type= bool , help=' disabled by default, can only work if sacred is installed')
70 |
71 |
72 |
73 | args = parser.parse_args()
74 | args = make_flags(args,args.config)
75 | exp = Trainer(args)
76 |
77 |
78 | train_acc,val_acc = exp.train()
79 | test_acc = exp.test()
80 | print('Training completed!')
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Table of contents
2 |
3 | * [Introduction](#introduction)
4 | * [Requirements](#requirements)
5 | * [How to use](#how-to-use)
6 |   * [Cifar10](#cifar10)
7 |   * [Cifar100](#cifar100)
8 | * [Resources](#resources)
9 |   * [Data](#data)
10 |   * [Hardware](#hardware)
11 | * [Full documentation](#full-documentation)
12 | * [Reference](#reference)
13 | * [License](#license)
14 |
15 | ## Introduction
16 |
17 | This repository contains an implementation of the Kernelized Wasserstein Natural Gradient estimator and provides scripts to reproduce the results of its [eponymous paper](https://arxiv.org/abs/1910.09652) published at ICLR 2020.
18 |
19 |
20 | ## Requirements
21 |
22 |
23 | This is a PyTorch implementation which requires the following packages:
24 |
25 | ```
26 | python==3.6.2 or newer
27 | torch==1.2.0 or newer
28 | torchvision==0.4.0 or newer
29 | numpy==1.17.2 or newer
30 | ```
31 |
32 | All dependencies can be installed using:
33 |
34 | ```
35 | pip install -r requirements.txt
36 | ```
37 |
38 |
39 |
40 |
41 | ## How to use
42 |
43 |
44 | ### Cifar10
45 | ```
46 | python train.py --device=-1 --config='configs/cifar10_KWNG.yml'
47 | ```
48 |
49 | ### Cifar100
50 |
51 | ```
52 | python train.py --device=-1 --config='configs/cifar100_KWNG.yml'
53 | ```
54 |
55 |
56 |
57 |
58 | ## Resources
59 |
60 | ### Data
61 |
62 | To reproduce the results of the paper on Cifar10 and Cifar100 using the provided scripts, both datasets need to be downloaded. This is done automatically by the training script. By default, a directory named 'data' containing both datasets is created in the working directory.
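
The loaders in `data_loader.py` can also be used on their own, for instance to inspect a batch outside of `train.py`. Below is a minimal sketch; the `'data'` path and the printed shapes are only illustrative, and `CIFAR100Loader` works the same way:

```
from data_loader import CIFARLoader

# Returns a dict with "train", "val" and "test" torch DataLoaders;
# CIFAR10 is downloaded into <data_path>/CIFAR10 on first use.
loaders = CIFARLoader('data', train_batch=128)

images, labels = next(iter(loaders["train"]))
print(images.shape)   # torch.Size([128, 3, 32, 32])
print(labels.shape)   # torch.Size([128])
```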
63 |
64 |
65 | ### Hardware
66 |
67 | To use a particular GPU, set --device=#gpu_id
68 | To use a GPU without specifying a particular one, set --device=-1
69 | To use the CPU, set --device=-2
70 |
71 |
72 | ## Full documentation
73 |
74 | ```
75 | --resume            resume from checkpoint [False]
76 | --log_name          log name ['']
77 | --log_dir           log directory for summaries and checkpoints ['']
78 | --dataset           name of the dataset to use, cifar10 or cifar100 ['cifar10']
79 | --data_dir          directory of the dataset ['data']
80 | --log_in_file       log output in a file [False]
81 |
82 | --device            gpu device [0]
83 | --seed              seed for randomness [0]
84 | --dtype             32 for float32 and 64 for float64 ['32']
85 | --total_epochs      total number of epochs [350]
86 |
87 | --network           classifier network: ill-conditioned case 'ResNet18IllCond', well-conditioned case 'ResNet18' ['ResNet18IllCond']
88 | --num_classes       number of classes [10]
89 | --criterion         top level loss ['cross_entropy']
90 |
91 | # Optimizer parameters
92 | --optimizer         inner optimizer to compute the euclidean gradient ['sgd']
93 | --b_size            batch size [128]
94 | --lr                learning rate [.1]
95 | --momentum          momentum [0.]
96 | --weight_decay      weight decay [0.]
97 |
98 | --lr_decay          decay factor for lr [0.1]
99 | --clip_grad         clip the gradient by norm [False]
100 |
101 | # Scheduler parameters
102 | --use_scheduler     schedule the lr ['store_true']
103 | --scheduler         scheduler ['MultiStepLR']
104 | --milestone         decrease schedule for lr ['100,200,300']
105 |
106 | # estimator of the natural gradient
107 | --estimator         proposed estimator ['KWNG']
108 | --kernel            the kernel used in the estimator ['gaussian']
109 | --log_bandwidth     log bandwidth of the kernel [0.]
110 | --epsilon           initial value for damping [1e-5]
111 | --num_basis         number of basis points for KWNG [5]
112 |
113 | # Dumping parameters
114 | --dumping_freq      update epsilon every dumping_freq iterations [5]
115 | --reduction_coeff   increase or decrease epsilon by the reduction_coeff factor [0.85]
116 | --min_red           min threshold for the reduction factor [0.25]
117 | --max_red           max threshold for the reduction factor [0.75]
118 | --with_diag_mat     1: use the norm of the jacobian for non-isotropic damping [1]
119 |
120 | --config            config file for the run ['']
121 | --with_sacred       disabled by default, can only work if sacred is installed [False]
122 |
123 | ```
124 |
125 | ## Reference
126 |
127 | If using this code for research purposes, please cite:
128 |
129 | [1] M. Arbel, A. Gretton, W. Li, G. Montufar. [*Kernelized Wasserstein Natural Gradient*](https://arxiv.org/abs/1910.09652), ICLR 2020.
130 |
131 | ```
132 | @article{Arbel:2020,
133 |   author = {Michael Arbel and Arthur Gretton and Wuchen Li and Guido Montufar},
134 |   title = {Kernelized Wasserstein Natural Gradient},
135 |   journal = {ICLR},
136 |   year = {2020},
137 |   url = {https://arxiv.org/abs/1910.09652},
138 | }
139 | ```
140 |
141 |
142 | ## License
143 |
144 | This code is under a BSD license.
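
For reference, the flags documented above map directly onto the classes defined in `gaussian.py` and `kwng.py`, so the estimator can also be wired into a custom training loop. The sketch below mirrors `get_wrapped_optimizer` in `trainer.py`; the toy linear model and the random batch are placeholders, not part of the repository, and the snippet is intended for the pinned torch version in `requirements.txt`:

```
import torch
from gaussian import Gaussian
from kwng import KWNG, KWNGWrapper

net = torch.nn.Linear(10, 3)                      # placeholder model
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=1.)

kernel = Gaussian(1, 0.)                          # log_bandwidth = 0.
estimator = KWNG(kernel, num_basis=5, eps=1e-5, with_diag_mat=1)
wrapped = KWNGWrapper(optimizer, criterion, net, clip_grad=True,
                      KWNG_estimator=estimator, dumping_freq=5,
                      reduction_coeff=0.85, min_red=0.25, max_red=0.75)

# One natural-gradient step on a random placeholder batch.
inputs, targets = torch.randn(32, 10), torch.randint(0, 3, (32,))
loss, predicted = wrapped.step(inputs, targets)
```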
145 | -------------------------------------------------------------------------------- /kwng.py: -------------------------------------------------------------------------------- 1 | import torch as tr 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | 7 | def get_flat_grad(net): 8 | grads = [] 9 | for param in net.parameters(): 10 | grads.append(param.grad.view(-1)) 11 | flat_grad = tr.cat(grads) 12 | return flat_grad 13 | 14 | def set_flat_grad(model, flat_grad): 15 | prev_ind = 0 16 | for param in model.parameters(): 17 | flat_size = int(np.prod(list(param.size()))) 18 | param.grad.copy_( 19 | flat_grad[prev_ind:prev_ind + flat_size].view(param.size())) 20 | prev_ind += flat_size 21 | 22 | class OptimizerWrapper(object): 23 | def __init__(self,optimizer, criterion, net,clip_grad): 24 | self.optimizer = optimizer 25 | self.criterion = criterion 26 | self.net = net 27 | self.clip = 1. 28 | self.clip_grad = clip_grad 29 | def step(self,inputs, targets): 30 | self.optimizer.zero_grad() 31 | outputs = self.net(inputs) 32 | _, predicted = outputs.max(1) 33 | loss = self.criterion(outputs, targets) 34 | 35 | loss.backward() 36 | if self.clip_grad: 37 | tr.nn.utils.clip_grad_norm_(self.net.parameters(), self.clip) 38 | self.optimizer.step() 39 | return loss.item(),predicted 40 | 41 | def eval(self,inputs,targets): 42 | self.optimizer.zero_grad() 43 | outputs = self.net(inputs) 44 | run_loss = self.criterion(outputs, targets).item() 45 | _, predicted = outputs.max(1) 46 | return run_loss,predicted 47 | 48 | 49 | 50 | class KWNGWrapper(OptimizerWrapper): 51 | def __init__(self,optimizer,criterion,net,clip_grad,KWNG_estimator,dumping_freq,reduction_coeff,min_red,max_red): 52 | OptimizerWrapper.__init__(self,optimizer,criterion,net,clip_grad) 53 | self.criterion = criterion 54 | self.optimizer = optimizer 55 | self.net= net 56 | self.KWNG_estimator = KWNG_estimator 57 | self.clip_grad = clip_grad 58 | self.clip_value = 1. 59 | self.old_loss = -1. 60 | self.dot_prod = 0. 61 | self.reduction_coeff =reduction_coeff 62 | self.dumping_freq = dumping_freq 63 | self.min_red = min_red 64 | self.max_red = max_red 65 | self.eps_min = 1e-10 66 | self.eps_max = 1e5 67 | self.dumping_counter = 0 68 | self.reduction_factor = 0. 
69 |
70 |     def step(self,inputs,targets):
71 |         lr = self.optimizer.param_groups[0]['lr']
72 |         self.optimizer.zero_grad()
73 |         outputs = self.net(inputs)
74 |         loss = self.criterion(outputs,targets)
75 |
76 |         # Adjust epsilon
77 |         self.dumping(loss,lr)
78 |         self.KWNG_estimator.compute_cond_matrix(self.net,outputs)
79 |         loss.backward()
80 |         g = get_flat_grad(self.net)
81 |         cond_g = self.KWNG_estimator.compute_natural_gradient(g)
82 |
83 |         # If the dot product is negative, just use the euclidean gradient
84 |         self.dot_prod = tr.sum(g*cond_g)
85 |         if self.dot_prod<=0:
86 |             cond_g = g
87 |         # Gradient clipping by norm
88 |         if self.clip_grad:
89 |             cond_g = self.clip_gradient(cond_g)
90 |
91 |         # Saving the current value of the loss
92 |         self.old_loss=loss.item()
93 |
94 |         set_flat_grad(self.net,cond_g)
95 |         self.optimizer.step()
96 |         _, predicted = outputs.max(1)
97 |
98 |         return loss.item(),predicted
99 |     def clip_gradient(self,cond_g):
100 |
101 |         norm_grad = tr.norm(cond_g)
102 |         clip_coef = self.clip_value / (norm_grad + 1e-6)
103 |         if clip_coef<1.:
104 |             self.dot_prod = self.dot_prod/norm_grad
105 |             return cond_g/norm_grad
106 |         else:
107 |             return cond_g
108 |
109 |     def dumping(self,loss,lr):
110 |         if self.old_loss>-1:
111 |             # Compute the reduction ratio
112 |             red = 2.*(self.old_loss-loss)/(lr*self.dot_prod)
113 |             if red > self.reduction_factor:
114 |                 self.reduction_factor = red.item()
115 |         self.dumping_counter +=1
116 |         if self.old_loss>-1 and np.mod(self.dumping_counter,self.dumping_freq)==0:
117 |             if self.reduction_factor< self.min_red and self.KWNG_estimator.eps<self.eps_max:
118 |                 self.KWNG_estimator.eps = self.KWNG_estimator.eps/self.reduction_coeff
119 |             elif self.reduction_factor>self.max_red and self.KWNG_estimator.eps>self.eps_min:
120 |                 self.KWNG_estimator.eps = self.KWNG_estimator.eps*self.reduction_coeff
121 |             print("New epsilon: "+ str(self.KWNG_estimator.eps) + ", Reduction_factor: " + str(self.reduction_factor))
122 |             self.reduction_factor = 0.
123 |
124 | class KWNG(nn.Module):
125 |
126 |     def __init__(self, kernel,num_basis=5,eps = 1e-5, with_diag_mat = True):
127 |         super(KWNG,self).__init__()
128 |         self.kernel = kernel
129 |         self.eps = eps
130 |         self.thresh = 0.
131 |         self.num_basis = num_basis
132 |         self.with_diag_mat= with_diag_mat
133 |         self.K = None
134 |         self.T = None
135 |
136 |     def compute_cond_matrix(self,net,outputs):
137 |
138 |         L,d = outputs.shape
139 |         idx = tr.randperm(outputs.shape[0])
140 |         outputs = outputs.view(outputs.size(0), -1)
141 |         basis= outputs[idx[0: self.num_basis]].clone().detach()
142 |         mask_int = tr.LongTensor(self.num_basis).random_(0,d)
143 |         mask = tr.nn.functional.one_hot(mask_int,d).to(outputs.device)
144 |         mask = mask.type(outputs.dtype)
145 |
146 |         sigma = tr.log(tr.mean(self.kernel.square_dist(basis,outputs))).clone().detach()
147 |         print(" sigma: " + str(tr.exp(sigma).item()))
148 |         sigma /= np.log(10.)
149 | 150 | if hasattr(self.kernel, 'params_0'): 151 | self.kernel.params = self.kernel.params_0 + sigma 152 | 153 | dkdxdy, dkdx, _= self.kernel.dkdxdy(basis,outputs,mask=mask) 154 | self.K = (1./L)*tr.einsum('mni,kni->mk',dkdxdy,dkdxdy) 155 | aux_loss = tr.mean(dkdx,dim = 1) 156 | self.T = self.compute_jacobian(aux_loss,net) 157 | 158 | def compute_natural_gradient(self,g): 159 | 160 | uu,ss,vv = tr.svd(self.K.double()) 161 | ss_inv,mask = self.pseudo_inverse(ss) 162 | ss_inv = tr.sqrt(ss_inv) 163 | vv = tr.einsum('i,ji->ij',ss_inv,vv) 164 | self.T = tr.einsum('ij,jk->ik', vv.float(), self.T) 165 | cond_g, G,D = self.make_system(g,mask) 166 | 167 | try: 168 | U = tr.cholesky(G) 169 | cond_g = tr.cholesky_solve(cond_g.unsqueeze(-1),U).squeeze(-1) 170 | except: 171 | try: 172 | cond_g = tr.solve(cond_g.unsqueeze(-1),G)[0].squeeze(-1) 173 | except: 174 | pinv = tr.pinverse(G) 175 | cond_g = tr.einsum('mk,k',pinv,cond_g) 176 | cond_g = tr.einsum('md,m->d',self.T,cond_g) 177 | cond_g = (g-cond_g)/self.eps 178 | cond_g = D*cond_g 179 | return cond_g 180 | 181 | def make_system(self,g,mask): 182 | if self.with_diag_mat==1: 183 | D = tr.sqrt(tr.sum(self.T*self.T,dim=0)) 184 | D = 1./(D+1e-8) 185 | elif self.with_diag_mat==0: 186 | D = tr.ones( self.T.shape[1],dtype=self.T.dtype,device=self.T.device) 187 | 188 | 189 | cond_g = D*g 190 | cond_g = tr.einsum('md,d->m', self.T,cond_g) 191 | P = tr.zeros_like(cond_g) 192 | P[mask] = 1. 193 | G = tr.einsum('md,d,kd->mk',self.T,D,self.T) + self.eps*tr.diag(P) 194 | return cond_g,G,D 195 | 196 | def pseudo_inverse(self,S): 197 | SS = 1./S 198 | mask = (S<=self.thresh) 199 | SS[mask]=0. 200 | mask = (S>self.thresh) 201 | return SS, mask 202 | 203 | def compute_jacobian(self,loss,net): 204 | J = [] 205 | b_size = loss.shape[0] 206 | for i in range(b_size): 207 | grads = tr.autograd.grad(loss[i], net.parameters(),retain_graph=True) 208 | grads = [x.view(-1) for x in grads] 209 | grads = tr.cat(grads) 210 | J.append(grads) 211 | 212 | return tr.stack(J,dim=0) -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | class VeryBasicBlock(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(VeryBasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,padding=1, bias=False) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.shortcut = nn.Sequential() 16 | if stride != 1 or in_planes != self.expansion*planes: 17 | self.shortcut = nn.Sequential( 18 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 19 | ) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.conv1(x)) 23 | out = self.conv2(out) 24 | out += self.shortcut(x) 25 | out = F.relu(out) 26 | return out 27 | 28 | 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(BasicBlock, self).__init__() 36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | 41 | self.shortcut = nn.Sequential() 42 | if stride != 1 or in_planes != 
self.expansion*planes: 43 | self.shortcut = nn.Sequential( 44 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 45 | nn.BatchNorm2d(self.expansion*planes) 46 | ) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.bn2(self.conv2(out)) 51 | out += self.shortcut(x) 52 | out = F.relu(out) 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, in_planes, planes, stride=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 67 | 68 | self.shortcut = nn.Sequential() 69 | if stride != 1 or in_planes != self.expansion*planes: 70 | self.shortcut = nn.Sequential( 71 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 72 | nn.BatchNorm2d(self.expansion*planes) 73 | ) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = F.relu(self.bn2(self.conv2(out))) 78 | out = self.bn3(self.conv3(out)) 79 | out += self.shortcut(x) 80 | out = F.relu(out) 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, block, num_blocks, num_classes=10): 86 | super(ResNet, self).__init__() 87 | self.in_planes = 64 88 | 89 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 92 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 93 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 94 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 95 | self.linear = nn.Linear(512*block.expansion, num_classes) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1]*(num_blocks-1) 99 | layers = [] 100 | for stride in strides: 101 | layers.append(block(self.in_planes, planes, stride)) 102 | self.in_planes = planes * block.expansion 103 | return nn.Sequential(*layers) 104 | 105 | def _make_block(self,in_planes, planes, stride=1): 106 | return nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False), 109 | nn.ReLU(inplace=True) 110 | ) 111 | 112 | def forward(self, x): 113 | out = F.relu(self.bn1(self.conv1(x))) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = self.layer4(out) 118 | out = F.avg_pool2d(out, 4) 119 | out = out.view(out.size(0), -1) 120 | out = self.linear(out) 121 | return out 122 | 123 | 124 | 125 | class ResNet(nn.Module): 126 | def __init__(self, block, num_blocks, num_classes=10, bad_conditioning = False ): 127 | super(ResNet, self).__init__() 128 | self.in_planes = 64 129 | self.bad_conditioning = bad_conditioning 130 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding = 1, bias=False) 131 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 132 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 133 | self.layer3 = self._make_layer(block, 256, num_blocks[2], 
stride=2) 134 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 135 | self.linear = nn.Linear(512*block.expansion, num_classes) 136 | self.pad_1 = (1,1,1,1,0,0,0,0) 137 | if self.bad_conditioning: 138 | self.non_linearity = torch.nn.Tanhshrink() 139 | 140 | def _make_layer(self, block, planes, num_blocks, stride): 141 | strides = [stride] + [1]*(num_blocks-1) 142 | layers = [] 143 | for stride in strides: 144 | layers.append(block(self.in_planes, planes, stride)) 145 | self.in_planes = planes * block.expansion 146 | return nn.Sequential(*layers) 147 | 148 | def _make_block(self,in_planes, planes, stride=1): 149 | return nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, bias=False), 150 | nn.ReLU(inplace=True), 151 | nn.Conv2d(planes, planes, kernel_size=3, stride=1, bias=False), 152 | nn.ReLU(inplace=True) 153 | ) 154 | 155 | def forward(self, x): 156 | out = F.relu(self.conv1(x)) 157 | 158 | out = self.layer1(out) 159 | 160 | out = self.layer2(out) 161 | out = self.layer3(out) 162 | out = self.layer4(out) 163 | out = F.avg_pool2d(out, 4) 164 | out = out.view(out.size(0), -1) 165 | 166 | out = self.linear(out) 167 | 168 | if self.bad_conditioning: 169 | cond_weights = torch.logspace(start=-6,end=1,steps=out.shape[1],dtype=out.dtype, device=out.device) 170 | out = torch.einsum('bm,m->bm',out,cond_weights) 171 | 172 | 173 | return out 174 | 175 | 176 | def ResNetBasic(num_classes=10): 177 | return ResNet(VeryBasicBlock, [2,2,2,2],num_classes=num_classes) 178 | 179 | def ResNet18(num_classes=10): 180 | return ResNet(BasicBlock, [2,2,2,2],num_classes=num_classes) 181 | 182 | def ResNet18IllCond(num_classes=10): 183 | return ResNet(BasicBlock, [2,2,2,2], bad_conditioning = True,num_classes=num_classes) 184 | 185 | def ResNetBasicIllCond(num_classes=10): 186 | return ResNet(VeryBasicBlock, [2,2,2,2], bad_conditioning = True,num_classes=num_classes) 187 | 188 | 189 | def ResNet34(): 190 | return ResNet(BasicBlock, [3,4,6,3]) 191 | 192 | def ResNet50(): 193 | return ResNet(Bottleneck, [3,4,6,3]) 194 | 195 | def ResNet101(): 196 | return ResNet(Bottleneck, [3,4,23,3]) 197 | 198 | def ResNet152(): 199 | return ResNet(Bottleneck, [3,8,36,3]) 200 | 201 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | 8 | import os, sys 9 | from tensorboardX import SummaryWriter 10 | 11 | import time 12 | import numpy as np 13 | import pprint 14 | import socket 15 | import pickle 16 | 17 | from resnet import * 18 | from kwng import * 19 | from gaussian import * 20 | from data_loader import * 21 | 22 | class Trainer(object): 23 | def __init__(self,args): 24 | torch.manual_seed(args.seed) 25 | self.args = args 26 | self.device = assign_device(args.device) 27 | self.log_dir = make_log_dir(args) 28 | 29 | if args.log_in_file: 30 | self.log_file = open(os.path.join(self.log_dir, 'log.txt'), 'w', buffering=1) 31 | sys.stdout = self.log_file 32 | sys.stderr = self.log_file 33 | print("Process id: " + str(os.getpid()) + " | hostname: " + socket.gethostname()) 34 | pp = pprint.PrettyPrinter(indent=4) 35 | pp.pprint(vars(args)) 36 | 37 | print('Creating writer') 38 | self.writer = SummaryWriter(self.log_dir) 39 | 40 | print('Loading data') 41 | if not os.path.isdir(args.data_dir): 42 | 
os.makedirs(args.data_dir, exist_ok=True) 43 | self.data_loaders = get_data_loader(args) 44 | self.total_epochs = self.args.total_epochs 45 | print('==> Building model..') 46 | self.build_model() 47 | 48 | 49 | def build_model(self): 50 | self.net = get_network(self.args) 51 | self.net = self.net.to(self.device) 52 | if self.args.dtype=='64': 53 | self.net = self.net.double() 54 | if self.device == 'cuda': 55 | self.net = torch.nn.DataParallel(self.net) 56 | cudnn.benchmark = True 57 | self.init_train_values() 58 | self.criterion = get_criterion(self.args) 59 | self.optimizer = get_optimizer(self.args,self.net.parameters(),self.net) 60 | self.scheduler = get_scheduler(self.args,self.optimizer) 61 | self.wrapped_optimizer = get_wrapped_optimizer(self.args,self.optimizer,self.criterion,self.net, device=self.device) 62 | 63 | 64 | def train(self): 65 | 66 | print(' Starting training') 67 | 68 | self.init_train_values() 69 | 70 | for epoch in range(self.start_epoch, self.start_epoch+self.total_epochs): 71 | 72 | train_acc = self.epoch_pass(epoch,'train') 73 | val_acc = self.epoch_pass(epoch,'val') 74 | if self.args.use_scheduler: 75 | self.scheduler.step() 76 | 77 | return train_acc,val_acc 78 | def test(self): 79 | print('Starting test') 80 | test_acc = self.epoch_pass(0,'test') 81 | return test_acc 82 | 83 | def init_train_values(self): 84 | if self.args.resume: 85 | # Load checkpoint. 86 | print('==> Resuming from checkpoint..') 87 | assert os.path.isdir(self.log_dir+'/checkpoint'), 'Error: no checkpoint directory found!' 88 | checkpoint = torch.load(self.log_dir+'/checkpoint/ckpt.t7') 89 | self.net.load_state_dict(checkpoint['net']) 90 | self.best_acc = checkpoint['acc'] 91 | self.best_loss = checkpoint['loss'] 92 | self.start_epoch = checkpoint['epoch'] 93 | self.total_iters = checkpoint['total_iters'] 94 | else: 95 | self.best_acc = 0 # best test accuracy 96 | self.start_epoch = 0 # start from epoch 0 or last checkpoint epoch 97 | self.total_iters = 0 98 | self.best_loss = torch.tensor(np.inf) 99 | 100 | def epoch_pass(self,epoch,phase): 101 | print('Epoch: '+ str(epoch) + ' | ' + phase + ' phase') 102 | if phase == 'train': 103 | self.net.train(True) # Set model to training mode 104 | else: 105 | self.net.train(False) # Set model to evaluate mode 106 | 107 | self.net.train() 108 | loss = 0 109 | correct = 0 110 | total = 0 111 | counts = 0 112 | for batch_idx, (inputs, targets) in enumerate(self.data_loaders[phase]): 113 | tic = time.time() 114 | inputs, targets = inputs.to(self.device), targets.to(self.device) 115 | if self.args.dtype=='64': 116 | inputs=inputs.double() 117 | if phase=="train": 118 | self.total_iters+=1 119 | 120 | loss_step, predicted = self.wrapped_optimizer.step(inputs,targets) 121 | loss_step, predicted = self.wrapped_optimizer.eval(inputs,targets) 122 | loss += loss_step 123 | running_loss = loss/(batch_idx+1) 124 | total += targets.size(0) 125 | correct += predicted.eq(targets).sum().item() 126 | acc= 100.*correct/total 127 | if phase=="train": 128 | self.writer.add_scalars('data/train_loss_step',{"loss_step":loss_step,"loss_averaged":running_loss},self.total_iters) 129 | toc = time.time() 130 | print(' Loss: ' + str(round(running_loss,3))+ ' | Acc: '+ str(acc) + ' ' +'('+str(correct) +'/'+str(total)+')' + ' time: ' + str(toc-tic) + ' iter: '+ str(batch_idx)) 131 | counts += 1 132 | 133 | self.writer.add_scalars('data/total_stats_'+phase, {"loss":loss/(batch_idx+1), "correct":acc},epoch) 134 | 135 | # Save checkpoint. 
136 | if phase == 'val': 137 | avg_loss = loss/(batch_idx+1) 138 | if avg_loss < self.best_loss: 139 | save_checkpoint(self.writer.logdir,acc,avg_loss,epoch,self.total_iters,self.wrapped_optimizer.net) 140 | self.best_loss = avg_loss 141 | 142 | return acc 143 | 144 | def save_checkpoint(checkpoint_dir,acc,loss,epoch,total_iters,net): 145 | 146 | print('Saving..') 147 | state = { 148 | 'net': net.state_dict(), 149 | 'acc': acc, 150 | 'loss':loss, 151 | 'epoch': epoch, 152 | 'total_iters':total_iters, 153 | } 154 | if not os.path.isdir(checkpoint_dir +'/checkpoint'): 155 | os.mkdir(checkpoint_dir + '/checkpoint') 156 | torch.save(state,checkpoint_dir +'/checkpoint/ckpt.t7') 157 | 158 | def assign_device(device): 159 | if device >-1: 160 | device = 'cuda:'+str(device) if torch.cuda.is_available() and device>-1 else 'cpu' 161 | elif device==-1: 162 | device = 'cuda' 163 | elif device==-2: 164 | device = 'cpu' 165 | return device 166 | def make_log_dir(args): 167 | if args.with_sacred: 168 | log_dir = args.log_dir + '_' + args.log_name 169 | else: 170 | log_dir = os.path.join(args.log_dir,args.log_name) 171 | if not os.path.isdir(log_dir): 172 | os.mkdir(log_dir) 173 | return log_dir 174 | 175 | def get_dtype(args): 176 | if args.dtype=='32': 177 | return torch.float32 178 | elif args.dtype=='64': 179 | return torch.float64 180 | 181 | 182 | def get_network(args): 183 | if args.network=='ResNet18': 184 | return ResNet18(num_classes = args.num_classes) 185 | elif args.network=='ResNet18IllCond': 186 | return ResNet18IllCond(num_classes = args.num_classes) 187 | def get_kernel(args,device = 'cuda'): 188 | dtype = get_dtype(args) 189 | if args.kernel=='gaussian': 190 | return Gaussian(1,args.log_bandwidth,dtype=dtype, device = device) 191 | 192 | def get_wrapped_optimizer(args,optimizer,criterion,net,device = 'cuda'): 193 | if args.estimator=='EuclideanGradient': 194 | return OptimizerWrapper(optimizer,criterion,net,args.clip_grad) 195 | elif args.estimator=='KWNG': 196 | kernel = get_kernel(args, device=device) 197 | estimator = KWNG(kernel,eps=args.epsilon, num_basis = args.num_basis,with_diag_mat = args.with_diag_mat) 198 | return KWNGWrapper(optimizer,criterion,net,args.clip_grad,estimator,args.dumping_freq,args.reduction_coeff,args.min_red,args.max_red) 199 | 200 | def get_data_loader(args): 201 | if args.dataset=='cifar10': 202 | args.num_classes = 10 203 | return CIFARLoader(args.data_dir,args.b_size) 204 | elif args.dataset=='cifar100': 205 | args.num_classes = 100 206 | return CIFAR100Loader(args.data_dir,args.b_size) 207 | 208 | def get_optimizer(args,params,net): 209 | if args.optimizer=='sgd': 210 | return optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 211 | 212 | def get_scheduler(args,optimizer): 213 | if args.scheduler=='MultiStepLR': 214 | if args.milestone is None: 215 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.total_epochs*0.5), int(args.total_epochs*0.75)], gamma=args.lr_decay) 216 | else: 217 | milestone = [int(_) for _ in args.milestone.split(',')] 218 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestone, gamma=args.lr_decay) 219 | return lr_scheduler 220 | def get_criterion(args): 221 | if args.criterion=='cross_entropy': 222 | return nn.CrossEntropyLoss() 223 | 224 | 225 | 226 | 227 | --------------------------------------------------------------------------------
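
As a closing note on the damping schedule configured by `dumping_freq`, `reduction_coeff`, `min_red` and `max_red`: the epsilon update performed in `KWNGWrapper.dumping` can be summarized by the standalone sketch below. The function name `update_damping` and its signature are illustrative only; in the repository this state lives on the wrapper, and the update is only applied every `dumping_freq` iterations using the largest reduction factor observed since the last update:

```python
def update_damping(eps, old_loss, new_loss, lr, dot_prod,
                   min_red=0.25, max_red=0.75, reduction_coeff=0.85,
                   eps_min=1e-10, eps_max=1e5):
    # Reduction factor: observed loss decrease relative to the first-order
    # predicted decrease lr * <euclidean gradient, natural gradient> (times 2).
    reduction_factor = 2. * (old_loss - new_loss) / (lr * dot_prod)
    if reduction_factor < min_red and eps < eps_max:
        # Loss decreased less than expected: increase the damping epsilon.
        eps = eps / reduction_coeff
    elif reduction_factor > max_red and eps > eps_min:
        # Loss decreased more than expected: relax the damping.
        eps = eps * reduction_coeff
    return eps
```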