├── README.md
├── compare.png
├── models
│   ├── __init__.py
│   ├── resnet.py
│   ├── vgg.py
│   └── wide_resnet.py
├── recipes
│   ├── run_rwp_ddp.sh
│   └── run_rwp_imagenet.sh
├── train_rwp_imagenet.py
├── train_rwp_parallel.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Generalization Improvement Guided by Random Weight Perturbation
2 | 
3 | This repository contains a PyTorch implementation of the paper **Efficient Generalization Improvement Guided by Random Weight Perturbation**.
4 | 
5 | 
6 | ## Abstract
7 | To fully uncover the great potential of deep neural networks (DNNs), various learning algorithms have been developed to improve the model's generalization ability. Recently, sharpness-aware minimization (SAM) established a generic scheme for improving generalization by minimizing a sharpness measure within a small neighborhood, achieving state-of-the-art performance. However, SAM requires two consecutive gradient evaluations to solve its min-max problem and thus inevitably doubles the training time. In this paper, we resort to filter-wise random weight perturbations (RWP) to decouple the nested gradients in SAM. Unlike the small adversarial perturbations in SAM, RWP is softer and admits a much larger perturbation magnitude. Specifically, we jointly optimize the loss function under random perturbations and the original loss function: the former guides the network towards a wider flat region, while the latter helps recover the necessary local information. Because these two loss terms are complementary and mutually independent, the corresponding gradients can be computed efficiently in parallel, enabling nearly the same training speed as regular training. As a result, we achieve very competitive performance on CIFAR and remarkably better performance on ImageNet (e.g., $+1.1\%$) compared with SAM, while always requiring only half of the training time.
8 | 
9 | 
10 | 11 | 12 |
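## Method sketch

The abstract above describes one RWP iteration: draw a filter-wise random perturbation, evaluate one gradient at the perturbed weights and one at the original weights, and apply their mixture. The snippet below is a minimal single-process sketch of that update; the function name `rwp_step` and the two-batch calling convention are illustrative, not part of this codebase (the scripts in this repository instead compute the two gradients simultaneously on different GPU ranks):

```
import torch

def rwp_step(model, criterion, optimizer, batch1, batch2, alpha=0.5, gamma=0.01):
    # 1) filter-wise Gaussian perturbation: std = gamma * per-filter weight norm
    noise = []
    with torch.no_grad():
        for p in model.parameters():
            if p.dim() > 1:
                std = gamma * p.view(p.size(0), -1).norm(dim=1, keepdim=True)
                n = (torch.randn_like(p.view(p.size(0), -1)) * std).view_as(p)
            else:
                n = torch.randn_like(p) * gamma * (p.norm() + 1e-16)
            noise.append(n)
            p.add_(n)

    # 2) gradient of the perturbed loss, weighted by alpha
    x, y = batch1
    optimizer.zero_grad()
    (alpha * criterion(model(x), y)).backward()
    with torch.no_grad():
        for p, n in zip(model.parameters(), noise):
            p.sub_(n)  # restore the original weights
    g_perturbed = [p.grad.clone() for p in model.parameters()]

    # 3) gradient of the original loss, weighted by (1 - alpha)
    x, y = batch2
    optimizer.zero_grad()
    ((1 - alpha) * criterion(model(x), y)).backward()

    # 4) apply the mixed gradient m = alpha * g_perturbed + (1 - alpha) * g
    with torch.no_grad():
        for p, g in zip(model.parameters(), g_perturbed):
            p.grad.add_(g)
    optimizer.step()
```

In `train_rwp_parallel.py` and `train_rwp_imagenet.py` these two steps run in parallel rather than in sequence: odd ranks minimize the perturbed loss scaled by `2 * alpha`, even ranks the original loss scaled by `2 * (1 - alpha)`, and DDP's gradient averaging then yields exactly the mixture above at nearly the speed of regular training.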
13 | 
14 | ## Example Usage
15 | 
16 | We provide example usages in `/recipes/`.
17 | For parallelized RWP training, run
18 | 
19 | ```
20 | bash recipes/run_rwp_ddp.sh
21 | ```
--------------------------------------------------------------------------------
/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/RWP/cb0acb0708720a40c441915b275fd2d5e70c734c/compare.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .vgg import *
3 | from .wide_resnet import *
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """resnet in pytorch
2 | 
3 | 
4 | 
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
6 | 
7 |     Deep Residual Learning for Image Recognition
8 |     https://arxiv.org/abs/1512.03385v1
9 | """
10 | 
11 | import torch
12 | import torch.nn as nn
13 | 
14 | class BasicBlock(nn.Module):
15 |     """Basic Block for resnet 18 and resnet 34
16 | 
17 |     """
18 | 
19 |     # BasicBlock and BottleNeck blocks
20 |     # have different output sizes;
21 |     # we use the class attribute expansion
22 |     # to distinguish them
23 |     expansion = 1
24 | 
25 |     def __init__(self, in_channels, out_channels, stride=1):
26 |         super().__init__()
27 | 
28 |         # residual function
29 |         self.residual_function = nn.Sequential(
30 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
31 |             nn.BatchNorm2d(out_channels),
32 |             nn.ReLU(inplace=True),
33 |             nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
34 |             nn.BatchNorm2d(out_channels * BasicBlock.expansion)
35 |         )
36 | 
37 |         # shortcut
38 |         self.shortcut = nn.Sequential()
39 | 
40 |         # if the shortcut output dimension is not the same as the residual function's,
41 |         # use a 1x1 convolution to match the dimensions
42 |         if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
43 |             self.shortcut = nn.Sequential(
44 |                 nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
45 |                 nn.BatchNorm2d(out_channels * BasicBlock.expansion)
46 |             )
47 | 
48 |     def forward(self, x):
49 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
50 | 
51 | class BottleNeck(nn.Module):
52 |     """Residual block for resnet with more than 50 layers
53 | 
54 |     """
55 |     expansion = 4
56 |     def __init__(self, in_channels, out_channels, stride=1):
57 |         super().__init__()
58 |         self.residual_function = nn.Sequential(
59 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
60 |             nn.BatchNorm2d(out_channels),
61 |             nn.ReLU(inplace=True),
62 |             nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
63 |             nn.BatchNorm2d(out_channels),
64 |             nn.ReLU(inplace=True),
65 |             nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
66 |             nn.BatchNorm2d(out_channels * BottleNeck.expansion),
67 |         )
68 | 
69 |         self.shortcut = nn.Sequential()
70 | 
71 |         if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
72 |             self.shortcut = nn.Sequential(
73 |                 nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
74 |                 nn.BatchNorm2d(out_channels * BottleNeck.expansion)
75 |             )
76 | 
77 |     def forward(self, x):
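        # as in BasicBlock: add the shortcut to the residual branch, then apply the final ReLU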
78 |         return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
79 | 
80 | class ResNet(nn.Module):
81 | 
82 |     def __init__(self, block, num_block, num_classes=100):
83 |         super().__init__()
84 | 
85 |         self.in_channels = 64
86 | 
87 |         self.conv1 = nn.Sequential(
88 |             nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
89 |             nn.BatchNorm2d(64),
90 |             nn.ReLU(inplace=True))
91 |         # we use a different input size than the original paper,
92 |         # so conv2_x's stride is 1
93 |         self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
94 |         self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
95 |         self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
96 |         self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
97 |         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
98 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
99 | 
100 |     def _make_layer(self, block, out_channels, num_blocks, stride):
101 |         """make resnet layers (here a 'layer' does not mean a single
102 |         network layer, e.g. a conv layer); one layer may
103 |         contain more than one residual block
104 | 
105 |         Args:
106 |             block: block type, basic block or bottle neck block
107 |             out_channels: output depth channel number of this layer
108 |             num_blocks: how many blocks per layer
109 |             stride: the stride of the first block of this layer
110 | 
111 |         Return:
112 |             return a resnet layer
113 |         """
114 | 
115 |         # we have num_blocks blocks per layer; the stride of the first block
116 |         # could be 1 or 2, the other blocks always use stride 1
117 |         strides = [stride] + [1] * (num_blocks - 1)
118 |         layers = []
119 |         for stride in strides:
120 |             layers.append(block(self.in_channels, out_channels, stride))
121 |             self.in_channels = out_channels * block.expansion
122 | 
123 |         return nn.Sequential(*layers)
124 | 
125 |     def forward(self, x):
126 |         output = self.conv1(x)
127 |         output = self.conv2_x(output)
128 |         output = self.conv3_x(output)
129 |         output = self.conv4_x(output)
130 |         output = self.conv5_x(output)
131 |         output = self.avg_pool(output)
132 |         output = output.view(output.size(0), -1)
133 |         output = self.fc(output)
134 | 
135 |         return output
136 | 
137 | class resnet18:
138 |     base = ResNet
139 |     args = list()
140 |     kwargs = {'block': BasicBlock, 'num_block': [2, 2, 2, 2]}
141 | 
142 | # def resnet18():
143 | #     """ return a ResNet 18 object
144 | #     """
145 | #     kwargs = {}
146 | #     return ResNet(BasicBlock, [2, 2, 2, 2])
147 | 
148 | def resnet34():
149 |     """ return a ResNet 34 object
150 |     """
151 |     return ResNet(BasicBlock, [3, 4, 6, 3])
152 | 
153 | def resnet50():
154 |     """ return a ResNet 50 object
155 |     """
156 |     return ResNet(BottleNeck, [3, 4, 6, 3])
157 | 
158 | def resnet101():
159 |     """ return a ResNet 101 object
160 |     """
161 |     return ResNet(BottleNeck, [3, 4, 23, 3])
162 | 
163 | def resnet152():
164 |     """ return a ResNet 152 object
165 |     """
166 |     return ResNet(BottleNeck, [3, 8, 36, 3])
167 | 
168 | 
169 | 
170 | 
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG model definition
3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
4 | """
5 | 
6 | import math
7 | import torch.nn as nn
8 | import torchvision.transforms as transforms
9 | 
10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN']
11 | 
12 | 
13 | def make_layers(cfg, batch_norm=False):
14 |     layers = list()
15 |     in_channels = 3
16 |     for v in cfg:
17 |         if v == 'M':
18 |             layers +=
[nn.MaxPool2d(kernel_size=2, stride=2)] 19 | else: 20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 21 | if batch_norm: 22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 23 | else: 24 | layers += [conv2d, nn.ReLU(inplace=True)] 25 | in_channels = v 26 | return nn.Sequential(*layers) 27 | 28 | 29 | cfg = { 30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 32 | 512, 512, 512, 512, 'M'], 33 | } 34 | 35 | 36 | class VGG(nn.Module): 37 | def __init__(self, num_classes=10, depth=16, batch_norm=False): 38 | super(VGG, self).__init__() 39 | self.features = make_layers(cfg[depth], batch_norm) 40 | self.classifier = nn.Sequential( 41 | nn.Dropout(), 42 | nn.Linear(512, 512), 43 | nn.ReLU(True), 44 | nn.Dropout(), 45 | nn.Linear(512, 512), 46 | nn.ReLU(True), 47 | nn.Linear(512, num_classes), 48 | ) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | m.bias.data.zero_() 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.classifier(x) 60 | return x 61 | 62 | 63 | class Base: 64 | base = VGG 65 | args = list() 66 | kwargs = dict() 67 | transform_train = transforms.Compose([ 68 | transforms.RandomHorizontalFlip(), 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | ]) 73 | 74 | transform_test = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 77 | ]) 78 | 79 | 80 | class VGG16(Base): 81 | pass 82 | 83 | 84 | class VGG16BN(Base): 85 | kwargs = {'batch_norm': True} 86 | 87 | 88 | class VGG19(Base): 89 | kwargs = {'depth': 19} 90 | 91 | 92 | class VGG19BN(Base): 93 | kwargs = {'depth': 19, 'batch_norm': True} -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | WideResNet model definition 3 | ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py 4 | """ 5 | 6 | import torchvision.transforms as transforms 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | import math 11 | 12 | __all__ = ['WideResNet28x10', 'WideResNet16x8'] 13 | 14 | from collections import OrderedDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class BasicUnit(nn.Module): 22 | def __init__(self, channels: int, dropout: float): 23 | super(BasicUnit, self).__init__() 24 | self.block = nn.Sequential(OrderedDict([ 25 | ("0_normalization", nn.BatchNorm2d(channels)), 26 | ("1_activation", nn.ReLU(inplace=True)), 27 | ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 28 | ("3_normalization", nn.BatchNorm2d(channels)), 29 | ("4_activation", nn.ReLU(inplace=True)), 30 | ("5_dropout", nn.Dropout(dropout, inplace=True)), 31 | ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)), 32 | ])) 33 | 34 | def forward(self, x): 35 | return x + self.block(x) 36 | 37 | 38 | class DownsampleUnit(nn.Module): 39 | def __init__(self, in_channels: int, out_channels: int, stride: int, 
dropout: float): 40 | super(DownsampleUnit, self).__init__() 41 | self.norm_act = nn.Sequential(OrderedDict([ 42 | ("0_normalization", nn.BatchNorm2d(in_channels)), 43 | ("1_activation", nn.ReLU(inplace=True)), 44 | ])) 45 | self.block = nn.Sequential(OrderedDict([ 46 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)), 47 | ("1_normalization", nn.BatchNorm2d(out_channels)), 48 | ("2_activation", nn.ReLU(inplace=True)), 49 | ("3_dropout", nn.Dropout(dropout, inplace=True)), 50 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)), 51 | ])) 52 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False) 53 | 54 | def forward(self, x): 55 | x = self.norm_act(x) 56 | return self.block(x) + self.downsample(x) 57 | 58 | 59 | class Block(nn.Module): 60 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float): 61 | super(Block, self).__init__() 62 | self.block = nn.Sequential( 63 | DownsampleUnit(in_channels, out_channels, stride, dropout), 64 | *(BasicUnit(out_channels, dropout) for _ in range(depth)) 65 | ) 66 | 67 | def forward(self, x): 68 | return self.block(x) 69 | 70 | 71 | class WideResNet(nn.Module): 72 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int): 73 | super(WideResNet, self).__init__() 74 | 75 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor] 76 | self.block_depth = (depth - 4) // (3 * 2) 77 | 78 | self.f = nn.Sequential(OrderedDict([ 79 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)), 80 | ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)), 81 | ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)), 82 | ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)), 83 | ("4_normalization", nn.BatchNorm2d(self.filters[3])), 84 | ("5_activation", nn.ReLU(inplace=True)), 85 | ("6_pooling", nn.AvgPool2d(kernel_size=8)), 86 | ("7_flattening", nn.Flatten()), 87 | ("8_classification", nn.Linear(in_features=self.filters[3], out_features=num_classes)), 88 | ])) 89 | 90 | self._initialize() 91 | 92 | def _initialize(self): 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu") 96 | if m.bias is not None: 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.weight.data.zero_() 103 | m.bias.data.zero_() 104 | 105 | def forward(self, x): 106 | return self.f(x) 107 | 108 | class WideResNet28x10: 109 | base = WideResNet 110 | args = list() 111 | kwargs = {'depth': 28, 'width_factor': 10, 'dropout': 0, 'in_channels': 3} 112 | transform_train = transforms.Compose([ 113 | transforms.RandomCrop(32, padding=4), 114 | transforms.RandomHorizontalFlip(), 115 | transforms.ToTensor(), 116 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 117 | ]) 118 | transform_test = transforms.Compose([ 119 | transforms.ToTensor(), 120 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 121 | ]) 122 | 123 | class WideResNet16x8: 124 | base = WideResNet 125 | args = list() 126 | kwargs = {'depth': 16, 'width_factor': 8, 'dropout': 0, 'in_channels': 3} 127 | 
transform_train = transforms.Compose([ 128 | transforms.RandomCrop(32, padding=4), 129 | transforms.RandomHorizontalFlip(), 130 | transforms.ToTensor(), 131 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 132 | ]) 133 | transform_test = transforms.Compose([ 134 | transforms.ToTensor(), 135 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 136 | ]) -------------------------------------------------------------------------------- /recipes/run_rwp_ddp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################ CIFAR ################################### 4 | datasets=CIFAR100 5 | device=0,1 # use two GPUs for parallel computing 6 | model=resnet18 # resnet18 VGG16BN WideResNet16x8 WideResNet28x10 7 | schedule=cosine 8 | wd=0.001 9 | epoch=200 10 | bz=256 11 | lr=0.10 12 | port=1234 13 | seed=0 14 | alpha=0.5 15 | gamma=0.01 16 | 17 | DST=results/rwp_ddp_cutout_gamma$gamma\_alpha$alpha\_$epoch\_$bz\_$lr\_$model\_$wd\_$datasets\_$schedule\_seed$seed 18 | CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node 2 --master_port $port train_rwp_parallel.py --datasets $datasets \ 19 | --arch=$model --epochs=$epoch --wd=$wd --randomseed $seed --lr $lr --gamma $gamma --cutout -b $bz --alpha $alpha --workers 8 \ 20 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100 --schedule $schedule 21 | 22 | -------------------------------------------------------------------------------- /recipes/run_rwp_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=ImageNet 4 | device=0,1,2,3 # parallelized training for RWP 5 | 6 | model=resnet18 7 | path=... 
# dir for ImageNet datasets
8 | DST=save_resnet18
9 | CUDA_VISIBLE_DEVICES=$device python3 train_rwp_imagenet.py -a $model \
10 |     --epochs 90 --workers 16 --dist-url 'tcp://127.0.0.1:4234' --lr 0.1 -b 256 \
11 |     --dist-backend 'nccl' --multiprocessing-distributed --gamma 0.005 --alpha 0.5 \
12 |     --save-dir=$DST/checkpoints --log-dir=$DST \
13 |     --world-size 1 --rank 0 $path
14 | 
--------------------------------------------------------------------------------
/train_rwp_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | import os
8 | 
9 | import numpy as np
10 | import pickle
11 | 
12 | from PIL import Image, ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | from utils import *
15 | 
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | import torchvision.models as models
28 | 
29 | model_names = sorted(name for name in models.__dict__
30 |     if name.islower() and not name.startswith("__")
31 |     and callable(models.__dict__[name]))
32 | 
33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
34 | parser.add_argument('data', metavar='DIR',
35 |                     help='path to dataset')
36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
37 |                     choices=model_names,
38 |                     help='model architecture: ' +
39 |                         ' | '.join(model_names) +
40 |                         ' (default: resnet18)')
41 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
42 |                     help='number of data loading workers (default: 4)')
43 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
44 |                     help='number of total epochs to run')
45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
46 |                     help='manual epoch number (useful on restarts)')
47 | parser.add_argument('-b', '--batch-size', default=256, type=int,
48 |                     metavar='N',
49 |                     help='mini-batch size (default: 256), this is the total '
50 |                          'batch size of all GPUs on the current node when '
51 |                          'using Data Parallel or Distributed Data Parallel')
52 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
53 |                     metavar='LR', help='initial learning rate', dest='lr')
54 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
55 |                     help='momentum')
56 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
57 |                     metavar='W', help='weight decay (default: 1e-4)',
58 |                     dest='weight_decay')
59 | parser.add_argument('--alpha', default=0.5, type=float,
60 |                     metavar='AA', help='alpha for mixing gradients')
61 | parser.add_argument('--gamma', default=0.01, type=float,
62 |                     metavar='GAMMA', help='gamma for noise')
63 | 
64 | parser.add_argument('-p', '--print-freq', default=1000, type=int,
65 |                     metavar='N', help='print frequency (default: 1000)')
66 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
67 |                     help='path to latest checkpoint (default: none)')
68 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
69 |                     help='evaluate model on validation set')
70 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
71 |                     help='use pre-trained
model') 72 | parser.add_argument('--world-size', default=-1, type=int, 73 | help='number of nodes for distributed training') 74 | parser.add_argument('--rank', default=-1, type=int, 75 | help='node rank for distributed training') 76 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 77 | help='url used to set up distributed training') 78 | parser.add_argument('--dist-backend', default='nccl', type=str, 79 | help='distributed backend') 80 | parser.add_argument('--seed', default=42, type=int, 81 | help='seed for initializing training. ') 82 | parser.add_argument('--gpu', default=None, type=int, 83 | help='GPU id to use.') 84 | parser.add_argument('--save-dir', dest='save_dir', 85 | help='The directory used to save the trained models', 86 | default='save_temp', type=str) 87 | parser.add_argument('--log-dir', dest='log_dir', 88 | help='The directory used to save the log', 89 | default='save_temp', type=str) 90 | parser.add_argument('--log-name', dest='log_name', 91 | help='The log file name', 92 | default='log', type=str) 93 | parser.add_argument('--multiprocessing-distributed', action='store_true', 94 | help='Use multi-processing distributed training to launch ' 95 | 'N processes per node, which has N GPUs. This is the ' 96 | 'fastest way to use PyTorch for either single node or ' 97 | 'multi node data parallel training') 98 | 99 | 100 | best_acc1 = 0 101 | 102 | 103 | param_vec = [] 104 | # Record training statistics 105 | train_loss = [] 106 | train_acc = [] 107 | test_loss = [] 108 | test_acc = [] 109 | arr_time = [] 110 | 111 | 112 | def get_model_grad_vec(model): 113 | # Return the model gradient as a vector 114 | 115 | vec = [] 116 | for name,param in model.named_parameters(): 117 | vec.append(param.grad.detach().reshape(-1)) 118 | return torch.cat(vec, 0) 119 | 120 | def update_grad(model, grad_vec): 121 | idx = 0 122 | for name,param in model.named_parameters(): 123 | arr_shape = param.grad.shape 124 | size = 1 125 | for i in range(len(list(arr_shape))): 126 | size *= arr_shape[i] 127 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone() 128 | idx += size 129 | 130 | 131 | iters = 0 132 | def get_model_param_vec(model): 133 | # Return the model parameters as a vector 134 | 135 | vec = [] 136 | for name,param in model.named_parameters(): 137 | vec.append(param.detach().cpu().reshape(-1).numpy()) 138 | return np.concatenate(vec, 0) 139 | 140 | def main(): 141 | global train_loss, train_acc, test_loss, test_acc, arr_time 142 | 143 | args = parser.parse_args() 144 | 145 | print ('gamma:', args.gamma) 146 | save_dir = 'save_' + args.arch 147 | if not os.path.exists(save_dir): 148 | os.makedirs(save_dir) 149 | args.save_dir = save_dir 150 | 151 | 152 | # Check the log_dir exists or not 153 | # if args.rank == 0: 154 | print ('log dir:', args.log_dir) 155 | if not os.path.exists(args.log_dir): 156 | os.makedirs(args.log_dir) 157 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name)) 158 | print ('log dir:', args.log_dir) 159 | 160 | if args.seed is not None: 161 | random.seed(args.seed) 162 | torch.manual_seed(args.seed) 163 | cudnn.deterministic = True 164 | warnings.warn('You have chosen to seed training. ' 165 | 'This will turn on the CUDNN deterministic setting, ' 166 | 'which can slow down your training considerably! ' 167 | 'You may see unexpected behavior when restarting ' 168 | 'from checkpoints.') 169 | 170 | if args.gpu is not None: 171 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 172 | 'disable data parallelism.') 173 | 174 | if args.dist_url == "env://" and args.world_size == -1: 175 | args.world_size = int(os.environ["WORLD_SIZE"]) 176 | 177 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 178 | 179 | ngpus_per_node = torch.cuda.device_count() 180 | if args.multiprocessing_distributed: 181 | # Since we have ngpus_per_node processes per node, the total world_size 182 | # needs to be adjusted accordingly 183 | args.world_size = ngpus_per_node * args.world_size 184 | # Use torch.multiprocessing.spawn to launch distributed processes: the 185 | # main_worker process function 186 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 187 | else: 188 | # Simply call main_worker function 189 | main_worker(args.gpu, ngpus_per_node, args) 190 | 191 | sample_idx = 0 192 | 193 | def main_worker(gpu, ngpus_per_node, args): 194 | global train_loss, train_acc, test_loss, test_acc, arr_time 195 | global best_acc1, param_vec, sample_idx 196 | args.gpu = gpu 197 | 198 | if args.gpu is not None: 199 | print("Use GPU: {} for training".format(args.gpu)) 200 | 201 | if args.distributed: 202 | if args.dist_url == "env://" and args.rank == -1: 203 | args.rank = int(os.environ["RANK"]) 204 | if args.multiprocessing_distributed: 205 | # For multiprocessing distributed training, rank needs to be the 206 | # global rank among all the processes 207 | args.rank = args.rank * ngpus_per_node + gpu 208 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 209 | world_size=args.world_size, rank=args.rank) 210 | # create model 211 | if args.pretrained: 212 | print("=> using pre-trained model '{}'".format(args.arch)) 213 | model = models.__dict__[args.arch](pretrained=True) 214 | else: 215 | print("=> creating model '{}'".format(args.arch)) 216 | model = models.__dict__[args.arch]() 217 | 218 | 219 | # Double the training epochs since each iteration will consume two batches of data for calculating g and g_s 220 | args.epochs = args.epochs * 2 221 | args.batch_size = args.batch_size * 2 222 | 223 | 224 | if not torch.cuda.is_available(): 225 | print('using CPU, this will be slow') 226 | elif args.distributed: 227 | # For multiprocessing distributed, DistributedDataParallel constructor 228 | # should always set the single device scope, otherwise, 229 | # DistributedDataParallel will use all available devices. 
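        # Note: the RWP rank split happens later in train(): odd ranks evaluate the
        # randomly perturbed loss (g_s) and even ranks the original loss (g), so
        # DDP's gradient all-reduce directly yields the mixture of the two terms.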
230 | if args.gpu is not None: 231 | torch.cuda.set_device(args.gpu) 232 | model.cuda(args.gpu) 233 | # When using a single GPU per process and per 234 | # DistributedDataParallel, we need to divide the batch size 235 | # ourselves based on the total number of GPUs we have 236 | args.batch_size = int(args.batch_size / ngpus_per_node) 237 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 238 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 239 | else: 240 | model.cuda() 241 | # DistributedDataParallel will divide and allocate batch_size to all 242 | # available GPUs if device_ids are not set 243 | model = torch.nn.parallel.DistributedDataParallel(model) 244 | elif args.gpu is not None: 245 | torch.cuda.set_device(args.gpu) 246 | model = model.cuda(args.gpu) 247 | else: 248 | # DataParallel will divide and allocate batch_size to all available GPUs 249 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 250 | model.features = torch.nn.DataParallel(model.features) 251 | model.cuda() 252 | else: 253 | model = torch.nn.DataParallel(model).cuda() 254 | 255 | # define loss function (criterion) and optimizer 256 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 257 | 258 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 259 | momentum=args.momentum, 260 | weight_decay=args.weight_decay) 261 | 262 | # optionally resume from a checkpoint 263 | if args.resume: 264 | if os.path.isfile(args.resume): 265 | print("=> loading checkpoint '{}'".format(args.resume)) 266 | if args.gpu is None: 267 | checkpoint = torch.load(args.resume) 268 | else: 269 | # Map model to be loaded to specified single gpu. 270 | loc = 'cuda:{}'.format(args.gpu) 271 | checkpoint = torch.load(args.resume, map_location=loc) 272 | args.start_epoch = checkpoint['epoch'] 273 | best_acc1 = checkpoint['best_acc1'] 274 | if args.gpu is not None: 275 | # best_acc1 may be from a checkpoint from a different GPU 276 | best_acc1 = best_acc1.to(args.gpu) 277 | model.load_state_dict(checkpoint['state_dict']) 278 | optimizer.load_state_dict(checkpoint['optimizer']) 279 | print("=> loaded checkpoint '{}' (epoch {})" 280 | .format(args.resume, checkpoint['epoch'])) 281 | else: 282 | print("=> no checkpoint found at '{}'".format(args.resume)) 283 | 284 | cudnn.benchmark = True 285 | 286 | # Data loading code 287 | traindir = os.path.join(args.data, 'train') 288 | valdir = os.path.join(args.data, 'val') 289 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 290 | std=[0.229, 0.224, 0.225]) 291 | 292 | train_dataset = datasets.ImageFolder( 293 | traindir, 294 | transforms.Compose([ 295 | transforms.RandomResizedCrop(224), 296 | transforms.RandomHorizontalFlip(), 297 | transforms.ToTensor(), 298 | normalize, 299 | ])) 300 | 301 | if args.distributed: 302 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 303 | else: 304 | train_sampler = None 305 | 306 | train_loader = torch.utils.data.DataLoader( 307 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 308 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 309 | 310 | val_loader = torch.utils.data.DataLoader( 311 | datasets.ImageFolder(valdir, transforms.Compose([ 312 | transforms.Resize(256), 313 | transforms.CenterCrop(224), 314 | transforms.ToTensor(), 315 | normalize, 316 | ])), 317 | batch_size=args.batch_size, shuffle=False, 318 | num_workers=args.workers, pin_memory=True) 319 | 320 | if args.evaluate: 321 | validate(val_loader, 
model, criterion, args) 322 | return 323 | 324 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): 325 | torch.save(model.state_dict(), 'save_' + args.arch + '/' + str(sample_idx)+'.pt') 326 | 327 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) 328 | 329 | for epoch in range(args.start_epoch, args.epochs): 330 | if args.distributed: 331 | train_sampler.set_epoch(epoch) 332 | 333 | # train for one epoch 334 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 335 | train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node) 336 | lr_scheduler.step() 337 | 338 | # evaluate on validation set 339 | acc1 = validate(val_loader, model, criterion, args) 340 | 341 | # remember best acc@1 and save checkpoint 342 | is_best = acc1 > best_acc1 343 | best_acc1 = max(acc1, best_acc1) 344 | 345 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 346 | and args.rank % ngpus_per_node == 0): 347 | save_checkpoint({ 348 | 'epoch': epoch + 1, 349 | 'arch': args.arch, 350 | 'state_dict': model.state_dict(), 351 | 'best_acc1': best_acc1, 352 | 'optimizer' : optimizer.state_dict(), 353 | }, is_best) 354 | 355 | torch.save(model, os.path.join(args.save_dir, 'model.pt')) 356 | 357 | print ('train loss: ', train_loss) 358 | print ('train acc: ', train_acc) 359 | print ('test loss: ', test_loss) 360 | print ('test acc: ', test_acc) 361 | 362 | print ('time: ', arr_time) 363 | 364 | 365 | def train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node): 366 | global iters, param_vec, sample_idx 367 | global train_loss, train_acc, test_loss, test_acc, arr_time 368 | 369 | batch_time = AverageMeter('Time', ':6.3f') 370 | data_time = AverageMeter('Data', ':6.3f') 371 | losses = AverageMeter('Loss', ':.4e') 372 | top1 = AverageMeter('Acc@1', ':6.2f') 373 | top5 = AverageMeter('Acc@5', ':6.2f') 374 | progress = ProgressMeter( 375 | len(train_loader), 376 | [batch_time, data_time, losses, top1, top5], 377 | prefix="Epoch: [{}]".format(epoch)) 378 | 379 | # switch to train mode 380 | model.train() 381 | 382 | end = time.time() 383 | epoch_start = end 384 | for i, (images, target) in enumerate(train_loader): 385 | # measure data loading time 386 | data_time.update(time.time() - end) 387 | 388 | if args.gpu is not None: 389 | images = images.cuda(args.gpu, non_blocking=True) 390 | if torch.cuda.is_available(): 391 | target = target.cuda(args.gpu, non_blocking=True) 392 | 393 | 394 | if args.rank % 2 == 1: 395 | weight = args.alpha * 2 396 | ##################### grw ############################# 397 | noise = [] 398 | for mp in model.parameters(): 399 | if len(mp.shape) > 1: 400 | sh = mp.shape 401 | sh_mul = np.prod(sh[1:]) 402 | temp = mp.view(sh[0], -1).norm(dim=1, keepdim=True).repeat(1, sh_mul).view(mp.shape) 403 | temp = torch.normal(0, args.gamma*temp).to(mp.data.device) 404 | else: 405 | temp = torch.empty_like(mp, device=mp.data.device) 406 | temp.normal_(0, args.gamma*(mp.view(-1).norm().item() + 1e-16)) 407 | noise.append(temp) 408 | mp.data.add_(noise[-1]) 409 | else: 410 | weight = (1 - args.alpha) * 2 411 | 412 | # compute output 413 | output = model(images) 414 | loss = criterion(output, target) * weight 415 | optimizer.zero_grad() 416 | loss.backward() 417 | 418 | if args.rank % 2 == 1: 419 | # going back to without theta 420 | with torch.no_grad(): 421 | for mp, n in zip(model.parameters(), noise): 422 | mp.data.sub_(n) 423 | 
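        # By this point DDP has already averaged the gradients across ranks during
        # backward(): odd ranks contributed 2 * alpha * grad(L(w + noise)) and even
        # ranks 2 * (1 - alpha) * grad(L(w)), so the step below applies the RWP
        # mixture m = alpha * g_s + (1 - alpha) * g.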
424 | optimizer.step() 425 | 426 | # measure accuracy and record loss 427 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 428 | losses.update(loss.item() / weight, images.size(0)) 429 | top1.update(acc1[0], images.size(0)) 430 | top5.update(acc5[0], images.size(0)) 431 | 432 | # compute gradient and do SGD step 433 | # optimizer.zero_grad() 434 | # loss.backward() 435 | # optimizer.step() 436 | 437 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 438 | and args.rank % ngpus_per_node == 0): 439 | 440 | if i % args.print_freq == 0: 441 | progress.display(i) 442 | 443 | if i > 0 and i % 1000 == 0 and i < 5000: 444 | sample_idx += 1 445 | # torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt') 446 | 447 | # measure elapsed time 448 | batch_time.update(time.time() - end) 449 | end = time.time() 450 | 451 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 452 | and args.rank % ngpus_per_node == 0): 453 | sample_idx += 1 454 | # torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt') 455 | 456 | arr_time.append(time.time() - epoch_start) 457 | train_loss.append(losses.avg) 458 | train_acc.append(top1.avg) 459 | 460 | 461 | def validate(val_loader, model, criterion, args): 462 | global train_loss, train_acc, test_loss, test_acc, arr_time 463 | batch_time = AverageMeter('Time', ':6.3f') 464 | losses = AverageMeter('Loss', ':.4e') 465 | top1 = AverageMeter('Acc@1', ':6.2f') 466 | top5 = AverageMeter('Acc@5', ':6.2f') 467 | progress = ProgressMeter( 468 | len(val_loader), 469 | [batch_time, losses, top1, top5], 470 | prefix='Test: ') 471 | 472 | # switch to evaluate mode 473 | model.eval() 474 | 475 | with torch.no_grad(): 476 | end = time.time() 477 | for i, (images, target) in enumerate(val_loader): 478 | if args.gpu is not None: 479 | images = images.cuda(args.gpu, non_blocking=True) 480 | if torch.cuda.is_available(): 481 | target = target.cuda(args.gpu, non_blocking=True) 482 | 483 | # compute output 484 | output = model(images) 485 | loss = criterion(output, target) 486 | 487 | # measure accuracy and record loss 488 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 489 | losses.update(loss.item(), images.size(0)) 490 | top1.update(acc1[0], images.size(0)) 491 | top5.update(acc5[0], images.size(0)) 492 | 493 | # measure elapsed time 494 | batch_time.update(time.time() - end) 495 | end = time.time() 496 | 497 | if i % args.print_freq == 0: 498 | progress.display(i) 499 | 500 | # TODO: this should also be done with the ProgressMeter 501 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 502 | .format(top1=top1, top5=top5)) 503 | test_acc.append(top1.avg) 504 | test_loss.append(losses.avg) 505 | return top1.avg 506 | 507 | 508 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 509 | torch.save(state, filename) 510 | if is_best: 511 | shutil.copyfile(filename, 'model_best.pth.tar') 512 | 513 | 514 | class AverageMeter(object): 515 | """Computes and stores the average and current value""" 516 | def __init__(self, name, fmt=':f'): 517 | self.name = name 518 | self.fmt = fmt 519 | self.reset() 520 | 521 | def reset(self): 522 | self.val = 0 523 | self.avg = 0 524 | self.sum = 0 525 | self.count = 0 526 | 527 | def update(self, val, n=1): 528 | self.val = val 529 | self.sum += val * n 530 | self.count += n 531 | self.avg = self.sum / self.count 532 | 533 | def __str__(self): 534 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 535 | return 
fmtstr.format(**self.__dict__) 536 | 537 | 538 | class ProgressMeter(object): 539 | def __init__(self, num_batches, meters, prefix=""): 540 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 541 | self.meters = meters 542 | self.prefix = prefix 543 | 544 | def display(self, batch): 545 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 546 | entries += [str(meter) for meter in self.meters] 547 | print('\t'.join(entries)) 548 | 549 | def _get_batch_fmtstr(self, num_batches): 550 | num_digits = len(str(num_batches // 1)) 551 | fmt = '{:' + str(num_digits) + 'd}' 552 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 553 | 554 | 555 | def accuracy(output, target, topk=(1,)): 556 | """Computes the accuracy over the k top predictions for the specified values of k""" 557 | with torch.no_grad(): 558 | maxk = max(topk) 559 | batch_size = target.size(0) 560 | 561 | _, pred = output.topk(maxk, 1, True, True) 562 | pred = pred.t() 563 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 564 | 565 | res = [] 566 | for k in topk: 567 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 568 | res.append(correct_k.mul_(100.0 / batch_size)) 569 | return res 570 | 571 | 572 | if __name__ == '__main__': 573 | main() -------------------------------------------------------------------------------- /train_rwp_parallel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | import os 4 | import time 5 | import numpy as np 6 | import random 7 | import sys 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim 14 | import torch.utils.data 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | 18 | import torch.distributed as dist 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | 21 | from utils import * 22 | 23 | 24 | # Parse arguments 25 | parser = argparse.ArgumentParser(description='DDP RWP training') 26 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='SGD') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', 28 | help='The architecture of the model') 29 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 30 | help='The training datasets') 31 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str, 32 | help='The optimizer for training') 33 | parser.add_argument('--schedule', metavar='SCHEDULE', default='step', type=str, 34 | help='The schedule for training') 35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 36 | help='number of data loading workers (default: 4)') 37 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=128, type=int, 42 | metavar='N', help='mini-batch size (default: 128)') 43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | metavar='LR', help='initial learning rate') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)') 49 | 
parser.add_argument('--print-freq', '-p', default=100, type=int,
50 |                     metavar='N', help='print frequency (default: 100 iterations)')
51 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
52 |                     help='path to latest checkpoint (default: none)')
53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
54 |                     help='evaluate model on validation set')
55 | parser.add_argument('--wandb', dest='wandb', action='store_true',
56 |                     help='use wandb to monitor statistics')
57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
58 |                     help='use pre-trained model')
59 | parser.add_argument('--half', dest='half', action='store_true',
60 |                     help='use half-precision (16-bit)')
61 | parser.add_argument('--save-dir', dest='save_dir',
62 |                     help='The directory used to save the trained models',
63 |                     default='save_temp', type=str)
64 | parser.add_argument('--log-dir', dest='log_dir',
65 |                     help='The directory used to save the log',
66 |                     default='save_temp', type=str)
67 | parser.add_argument('--log-name', dest='log_name',
68 |                     help='The log file name',
69 |                     default='log', type=str)
70 | parser.add_argument('--randomseed',
71 |                     help='Randomseed for training and initialization',
72 |                     type=int, default=1)
73 | parser.add_argument('--cutout', dest='cutout', action='store_true',
74 |                     help='use cutout data augmentation')
75 | parser.add_argument('--alpha', default=0.5, type=float,
76 |                     metavar='A', help='alpha for mixing gradients')
77 | parser.add_argument('--gamma', default=0.01, type=float,
78 |                     metavar='gamma', help='Perturbation magnitude gamma for RWP')
79 | 
80 | parser.add_argument("--local_rank", default=-1, type=int)
81 | 
82 | best_prec1 = 0
83 | 
84 | # Record training statistics
85 | train_loss = []
86 | train_err = []
87 | test_loss = []
88 | test_err = []
89 | arr_time = []
90 | 
91 | args = parser.parse_args()
92 | 
93 | local_rank = args.local_rank
94 | torch.cuda.set_device(local_rank)
95 | dist.init_process_group(backend='nccl')
96 | args.world_size = torch.distributed.get_world_size()
97 | args.workers = int((args.workers + args.world_size - 1) / args.world_size)
98 | if args.local_rank == 0:
99 |     print ('world size: {} workers per GPU: {}'.format(args.world_size, args.workers))
100 | device = torch.device("cuda", local_rank)
101 | 
102 | if args.wandb:
103 |     import wandb
104 |     wandb.init(project="TWA", entity="nblt")
105 |     date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
106 |     wandb.run.name = args.EXP + date
107 | 
108 | 
109 | def get_model_param_vec(model):
110 |     # Return the model parameters as a vector
111 | 
112 |     vec = []
113 |     for name,param in model.named_parameters():
114 |         vec.append(param.data.detach().reshape(-1))
115 |     return torch.cat(vec, 0)
116 | 
117 | 
118 | def get_model_grad_vec(model):
119 |     # Return the model gradient as a vector
120 | 
121 |     vec = []
122 |     for name,param in model.named_parameters():
123 |         vec.append(param.grad.detach().reshape(-1))
124 |     return torch.cat(vec, 0)
125 | 
126 | def update_grad(model, grad_vec):
127 |     idx = 0
128 |     for name,param in model.named_parameters():
129 |         arr_shape = param.grad.shape
130 |         size = param.grad.numel()
131 |         param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
132 |         idx += size
133 | 
134 | def update_param(model, param_vec):
135 |     idx = 0
136 |     for name,param in model.named_parameters():
137 |         arr_shape = param.data.shape
138 |         size = param.data.numel()
139 |         param.data = param_vec[idx:idx+size].reshape(arr_shape).clone()
140 |         idx += size
141 | 
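# Note: the RWP gradient mixing never appears explicitly in this file; DDP's
# all-reduce in train() below averages the two loss terms across ranks. A
# hypothetical single-process equivalent using the helpers above would look
# like this (illustrative only, not called anywhere in this repository):
#
#   g_s = get_model_grad_vec(model)  # after backward() at the perturbed weights
#   g = get_model_grad_vec(model)    # after backward() at the original weights
#   update_grad(model, args.alpha * g_s + (1 - args.alpha) * g)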
142 | def print_param_shape(model):
143 |     for name,param in model.named_parameters():
144 |         print (name, param.data.shape)
145 | 
146 | def main():
147 | 
148 |     global args, best_prec1, p0
149 |     global train_loss, train_err, test_loss, test_err, arr_time, running_weight
150 | 
151 |     set_seed(args.randomseed)
152 | 
153 |     # Check the save_dir exists or not
154 |     if args.local_rank == 0:
155 |         print ('save dir:', args.save_dir)
156 |     if not os.path.exists(args.save_dir):
157 |         os.makedirs(args.save_dir)
158 | 
159 |     # Check the log_dir exists or not
160 |     if args.local_rank == 0:
161 |         print ('log dir:', args.log_dir)
162 |     if not os.path.exists(args.log_dir):
163 |         os.makedirs(args.log_dir)
164 | 
165 |     sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
166 | 
167 |     # Define model
168 |     # model = torch.nn.DataParallel(get_model(args))
169 |     model = get_model(args).to(device)
170 |     model = DDP(model, device_ids=[local_rank], output_device=local_rank)
171 | 
172 |     # print_param_shape(model)
173 | 
174 |     # Optionally resume from a checkpoint
175 |     if args.resume:
176 |         # if os.path.isfile(args.resume):
177 |         if os.path.isfile(os.path.join(args.save_dir, args.resume)):
178 | 
179 |             # model.load_state_dict(torch.load(os.path.join(args.save_dir, args.resume)))
180 | 
181 |             print ("=> loading checkpoint '{}'".format(args.resume))
182 |             checkpoint = torch.load(args.resume)
183 |             args.start_epoch = checkpoint['epoch']
184 |             print ('from ', args.start_epoch)
185 |             best_prec1 = checkpoint['best_prec1']
186 |             model.load_state_dict(checkpoint['state_dict'])
187 |             print ("=> loaded checkpoint '{}' (epoch {})"
188 |                 .format(args.resume, checkpoint['epoch']))
189 |         else:
190 |             print ("=> no checkpoint found at '{}'".format(args.resume))
191 | 
192 |     cudnn.benchmark = True
193 | 
194 |     # Prepare Dataloader
195 |     print ('cutout:', args.cutout)
196 |     if args.cutout:
197 |         train_loader, val_loader = get_datasets_cutout_ddp(args)
198 |     else:
199 |         train_loader, val_loader = get_datasets_ddp(args)
200 | 
201 |     # define loss function (criterion) and optimizer
202 |     criterion = nn.CrossEntropyLoss().to(device)
203 | 
204 |     if args.half:
205 |         model.half()
206 |         criterion.half()
207 | 
208 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
209 | 
210 |     # Double the training epochs since each iteration will consume two batches of data for calculating g and g_s
211 |     args.epochs = args.epochs * 2
212 | 
213 |     if args.schedule == 'step':
214 |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs * 0.5), int(args.epochs * 0.75)], last_epoch=args.start_epoch - 1)
215 |     elif args.schedule == 'cosine':
216 |         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
217 | 
218 |     if args.evaluate:
219 |         validate(val_loader, model, criterion)
220 |         return
221 | 
222 | 
223 |     is_best = 0
224 |     print ('Start training: ', args.start_epoch, '->', args.epochs)
225 |     print ('gamma:', args.gamma)
226 |     print ('len(train_loader):', len(train_loader))
227 | 
228 |     for epoch in range(args.start_epoch, args.epochs):
229 |         train_loader.sampler.set_epoch(epoch)
230 | 
231 |         # train for one epoch
232 |         if args.local_rank == 0:
233 |             print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
234 |         train(train_loader, model, criterion, optimizer, epoch)
235 |         lr_scheduler.step()
236 | 
237 |         if epoch % 2 == 0: continue
238 | 
239 |         # evaluate on validation set
240 |         prec1 = validate(val_loader, model,
criterion) 241 | 242 | # remember best prec@1 and save checkpoint 243 | is_best = prec1 > best_prec1 244 | best_prec1 = max(prec1, best_prec1) 245 | 246 | if args.local_rank == 0: 247 | save_checkpoint({ 248 | 'state_dict': model.state_dict(), 249 | 'best_prec1': best_prec1, 250 | }, is_best, filename=os.path.join(args.save_dir, 'model.th')) 251 | 252 | if args.local_rank == 0: 253 | print ('train loss: ', train_loss) 254 | print ('train err: ', train_err) 255 | print ('test loss: ', test_loss) 256 | print ('test err: ', test_err) 257 | print ('time: ', arr_time) 258 | 259 | 260 | def train(train_loader, model, criterion, optimizer, epoch): 261 | """ 262 | Run one train epoch 263 | """ 264 | global train_loss, train_err, arr_time 265 | 266 | batch_time = AverageMeter() 267 | data_time = AverageMeter() 268 | losses = AverageMeter() 269 | top1 = AverageMeter() 270 | 271 | # switch to train mode 272 | model.train() 273 | 274 | total_loss, total_err = 0, 0 275 | end = time.time() 276 | for i, (input, target) in enumerate(train_loader): 277 | 278 | # measure data loading time 279 | data_time.update(time.time() - end) 280 | 281 | target = target.to(device) 282 | input_var = input.to(device) 283 | target_var = target 284 | if args.half: 285 | input_var = input_var.half() 286 | 287 | if args.local_rank % 2 == 1: 288 | weight = args.alpha * 2 289 | with torch.no_grad(): 290 | noise = [] 291 | for mp in model.parameters(): 292 | if len(mp.shape) > 1: 293 | sh = mp.shape 294 | sh_mul = np.prod(sh[1:]) 295 | temp = mp.view(sh[0], -1).norm(dim=1, keepdim=True).repeat(1, sh_mul).view(mp.shape) 296 | temp = torch.normal(0, args.gamma*temp).to(mp.data.device) 297 | else: 298 | temp = torch.empty_like(mp, device=mp.data.device) 299 | temp.normal_(0, args.gamma*(mp.view(-1).norm().item() + 1e-16)) 300 | noise.append(temp) 301 | mp.data.add_(noise[-1]) 302 | else: 303 | weight = (1 - args.alpha) * 2 304 | 305 | # compute output 306 | output = model(input_var) 307 | loss = criterion(output, target_var) * weight 308 | 309 | optimizer.zero_grad() 310 | loss.backward() 311 | 312 | if args.local_rank % 2 == 1: 313 | # going back to without theta 314 | with torch.no_grad(): 315 | for mp, n in zip(model.parameters(), noise): 316 | mp.data.sub_(n) 317 | 318 | optimizer.step() 319 | 320 | total_loss += loss.item() * input_var.shape[0] / weight 321 | total_err += (output.max(dim=1)[1] != target_var).sum().item() 322 | 323 | output = output.float() 324 | loss = loss.float() 325 | 326 | # measure accuracy and record loss 327 | prec1 = accuracy(output.data, target)[0] 328 | losses.update(loss.item(), input.size(0)) 329 | top1.update(prec1.item(), input.size(0)) 330 | 331 | # measure elapsed time 332 | batch_time.update(time.time() - end) 333 | end = time.time() 334 | 335 | if args.local_rank == 0 and (i % args.print_freq == 0 or i == len(train_loader) - 1): 336 | print('Epoch: [{0}][{1}/{2}]\t' 337 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 338 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 339 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 340 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 341 | epoch, i, len(train_loader), batch_time=batch_time, 342 | data_time=data_time, loss=losses, top1=top1)) 343 | 344 | if args.local_rank == 0: 345 | print ('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum)) 346 | 347 | tloss = total_loss / len(train_loader.dataset) * args.world_size 348 | terr = total_err / len(train_loader.dataset) * args.world_size 349 | train_loss.append(tloss) 350 | 
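    # (tloss and terr are multiplied by world_size above because each DDP rank
    # only iterates over len(dataset) / world_size samples per epoch)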
train_err.append(terr) 351 | print ('train loss | acc', tloss, 1 - terr) 352 | 353 | if args.wandb: 354 | wandb.log({"train loss": total_loss / len(train_loader.dataset)}) 355 | wandb.log({"train acc": 1 - total_err / len(train_loader.dataset)}) 356 | 357 | arr_time.append(batch_time.sum) 358 | 359 | def validate(val_loader, model, criterion, add=True): 360 | """ 361 | Run evaluation 362 | """ 363 | global test_err, test_loss 364 | 365 | total_loss = 0 366 | total_err = 0 367 | 368 | batch_time = AverageMeter() 369 | losses = AverageMeter() 370 | top1 = AverageMeter() 371 | 372 | # switch to evaluate mode 373 | model.eval() 374 | 375 | end = time.time() 376 | with torch.no_grad(): 377 | for i, (input, target) in enumerate(val_loader): 378 | target = target.to(device) 379 | input_var = input.to(device) 380 | target_var = target.to(device) 381 | 382 | if args.half: 383 | input_var = input_var.half() 384 | 385 | # compute output 386 | output = model(input_var) 387 | loss = criterion(output, target_var) 388 | 389 | output = output.float() 390 | loss = loss.float() 391 | 392 | total_loss += loss.item() * input_var.shape[0] 393 | total_err += (output.max(dim=1)[1] != target_var).sum().item() 394 | 395 | # measure accuracy and record loss 396 | prec1 = accuracy(output.data, target)[0] 397 | losses.update(loss.item(), input.size(0)) 398 | top1.update(prec1.item(), input.size(0)) 399 | 400 | # measure elapsed time 401 | batch_time.update(time.time() - end) 402 | end = time.time() 403 | 404 | if i % args.print_freq == 0 and add: 405 | print('Test: [{0}/{1}]\t' 406 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 407 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 408 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 409 | i, len(val_loader), batch_time=batch_time, loss=losses, 410 | top1=top1)) 411 | 412 | if add: 413 | print(' * Prec@1 {top1.avg:.3f}' 414 | .format(top1=top1)) 415 | 416 | test_loss.append(total_loss / len(val_loader.dataset)) 417 | test_err.append(total_err / len(val_loader.dataset)) 418 | 419 | if args.wandb: 420 | wandb.log({"test loss": total_loss / len(val_loader.dataset)}) 421 | wandb.log({"test acc": 1 - total_err / len(val_loader.dataset)}) 422 | 423 | return top1.avg 424 | 425 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 426 | """ 427 | Save the training model 428 | """ 429 | torch.save(state, filename) 430 | 431 | class AverageMeter(object): 432 | """Computes and stores the average and current value""" 433 | def __init__(self): 434 | self.reset() 435 | 436 | def reset(self): 437 | self.val = 0 438 | self.avg = 0 439 | self.sum = 0 440 | self.count = 0 441 | 442 | def update(self, val, n=1): 443 | self.val = val 444 | self.sum += val * n 445 | self.count += n 446 | self.avg = self.sum / self.count 447 | 448 | 449 | def accuracy(output, target, topk=(1,)): 450 | """Computes the precision@k for the specified values of k""" 451 | maxk = max(topk) 452 | batch_size = target.size(0) 453 | 454 | _, pred = output.topk(maxk, 1, True, True) 455 | pred = pred.t() 456 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 457 | 458 | res = [] 459 | for k in topk: 460 | correct_k = correct[:k].view(-1).float().sum(0) 461 | res.append(correct_k.mul_(100.0 / batch_size)) 462 | return res 463 | 464 | 465 | if __name__ == '__main__': 466 | main() 467 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import torch.nn.parallel 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim as optim 6 | import torch.utils.data 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | import torchvision.datasets as datasets 10 | import torchvision.models as models_imagenet 11 | 12 | import numpy as np 13 | import random 14 | import os 15 | import time 16 | import models 17 | import sys 18 | import torch.utils.data as data 19 | from torchvision.datasets.utils import download_url, check_integrity 20 | import os.path 21 | import pickle 22 | from PIL import Image 23 | 24 | def set_seed(seed=1): 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | class Logger(object): 33 | def __init__(self,fileN ="Default.log"): 34 | self.terminal = sys.stdout 35 | self.log = open(fileN,"a") 36 | 37 | def write(self,message): 38 | self.terminal.write(message) 39 | self.log.write(message) 40 | 41 | def flush(self): 42 | self.terminal.flush() 43 | self.log.flush() 44 | 45 | ################################ datasets ####################################### 46 | 47 | import torchvision.transforms as transforms 48 | import torchvision.datasets as datasets 49 | from torch.utils.data import DataLoader, Subset 50 | from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder 51 | 52 | class Cutout: 53 | def __init__(self, size=16, p=0.5): 54 | self.size = size 55 | self.half_size = size // 2 56 | self.p = p 57 | 58 | def __call__(self, image): 59 | if torch.rand([1]).item() > self.p: 60 | return image 61 | 62 | left = torch.randint(-self.half_size, image.size(1) - self.half_size, [1]).item() 63 | top = torch.randint(-self.half_size, image.size(2) - self.half_size, [1]).item() 64 | right = min(image.size(1), left + self.size) 65 | bottom = min(image.size(2), top + self.size) 66 | 67 | image[:, max(0, left): right, max(0, top): bottom] = 0 68 | return image 69 | 70 | def get_datasets(args): 71 | if args.datasets == 'CIFAR10': 72 | print ('cifar10 dataset!') 73 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 74 | 75 | train_loader = torch.utils.data.DataLoader( 76 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([ 77 | transforms.RandomHorizontalFlip(), 78 | transforms.RandomCrop(32, 4), 79 | transforms.ToTensor(), 80 | normalize, 81 | ]), download=True), 82 | batch_size=args.batch_size, shuffle=True, 83 | num_workers=args.workers, pin_memory=True) 84 | 85 | val_loader = torch.utils.data.DataLoader( 86 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([ 87 | transforms.ToTensor(), 88 | normalize, 89 | ])), 90 | batch_size=128, shuffle=False, 91 | num_workers=args.workers, pin_memory=True) 92 | 93 | elif args.datasets == 'CIFAR100': 94 | print ('cifar100 dataset!') 95 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 96 | 97 | train_loader = torch.utils.data.DataLoader( 98 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([ 99 | transforms.RandomHorizontalFlip(), 100 | transforms.RandomCrop(32, 4), 101 | transforms.ToTensor(), 102 | normalize, 103 | ]), download=True), 104 | batch_size=args.batch_size, shuffle=True, 105 | num_workers=args.workers, pin_memory=True) 106 | 107 | val_loader = torch.utils.data.DataLoader( 108 | datasets.CIFAR100(root='./datasets/', 
def get_datasets(args):
    """Return (train_loader, val_loader) for CIFAR10 / CIFAR100 / ImageNet."""
    if args.datasets == 'CIFAR10':
        print('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print('cifar100 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'ImageNet':
        traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
        valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers)

    return train_loader, val_loader

def get_datasets_ddp(args):
    """CIFAR loaders for distributed training: the train set is sharded across ranks by a DistributedSampler."""
    if args.datasets == 'CIFAR10':
        print('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        my_trainset = datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True)

        train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
        train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print('cifar100 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        my_trainset = datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True)
        train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
        train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader
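
# Minimal epoch-loop sketch for the DDP loaders above (assumes the process
# group has already been initialized via torch.distributed.init_process_group).
# DistributedSampler shards the data across ranks, and set_epoch() must be
# called every epoch so each epoch reshuffles with a different seed; the actual
# forward/backward/step logic lives in train_rwp_parallel.py.
def _ddp_epoch_sketch(train_loader, epochs):
    for epoch in range(epochs):
        train_loader.sampler.set_epoch(epoch)  # reshuffle shards for this epoch
        for images, targets in train_loader:
            pass  # placeholder for the per-batch training step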
def get_datasets_cutout(args):
    print('cutout!')
    if args.datasets == 'CIFAR10':
        print('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
                Cutout()
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print('cifar100 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
                Cutout()
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader

def get_datasets_cutout_ddp(args):
    if args.datasets == 'CIFAR10':
        print('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        my_trainset = datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
            Cutout()
        ]), download=True)
        train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
        train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, drop_last=True, num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print('cifar100 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        my_trainset = datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
            Cutout()
        ]), download=True)
        train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
        train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, drop_last=True, num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    return train_loader, val_loader
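
# Usage sketch (hypothetical values): the dataset getters expect an
# argparse-style namespace with `datasets`, `batch_size`, and `workers`
# attributes, which the training scripts normally supply via their CLI flags.
def _loader_demo():
    from argparse import Namespace
    args = Namespace(datasets='CIFAR100', batch_size=128, workers=4)
    train_loader, val_loader = get_datasets_cutout(args)
    return train_loader, val_loader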
def get_model(args):
    """Instantiate the architecture named by args.arch for the chosen dataset."""
    print('Model: {}'.format(args.arch))

    # ImageNet models come from torchvision; CIFAR models from the local package
    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets == 'CIFAR10':
        num_classes = 10
    elif args.datasets == 'CIFAR100':
        num_classes = 100

    model_cfg = getattr(models, args.arch)

    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization: wraps a base optimizer with a perturb-then-restore two-step update."""
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
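
# Usage sketch for the SAM wrapper above (the standard two-step protocol;
# `model`, `criterion`, `images`, and `targets` are placeholders for objects
# created by a training script, not names defined in this file):
def _sam_update_sketch(model, criterion, images, targets, optimizer):
    # first forward-backward pass: gradients at w, then perturb to w + e(w)
    criterion(model(images), targets).backward()
    optimizer.first_step(zero_grad=True)
    # second forward-backward pass: gradients at w + e(w) drive the actual update
    criterion(model(images), targets).backward()
    optimizer.second_step(zero_grad=True)

--------------------------------------------------------------------------------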