├── README.md
├── compare.png
├── models
│   ├── __init__.py
│   ├── resnet.py
│   ├── vgg.py
│   └── wide_resnet.py
├── recipes
│   ├── run_rwp_ddp.sh
│   └── run_rwp_imagenet.sh
├── train_rwp_imagenet.py
├── train_rwp_parallel.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Generalization Improvement Guided by Random Weight Perturbation
2 |
3 | This repository contains a PyTorch implementation of the paper: **Efficient Generalization Improvement Guided by Random Weight Perturbation**.
4 |
5 |
6 | ## Abstract
7 | To fully uncover the great potential of deep neural networks (DNNs), various learning algorithms have been developed to improve the model's generalization ability. Recently, sharpness-aware minimization (SAM) has established a generic scheme for improving generalization by minimizing a sharpness measure within a small neighborhood, achieving state-of-the-art performance. However, SAM requires two consecutive gradient evaluations to solve its min-max problem and thus inevitably doubles the training time. In this paper, we resort to filter-wise random weight perturbations (RWP) to decouple the nested gradients in SAM. Unlike the small adversarial perturbations in SAM, RWP is softer and allows a much larger perturbation magnitude. Specifically, we jointly optimize the loss function under random perturbations and the original loss function: the former guides the network towards a wider flat region, while the latter helps recover the necessary local information. These two loss terms are complementary to each other and mutually independent, so the corresponding gradients can be computed efficiently in parallel, enabling nearly the same training speed as regular training. As a result, we achieve very competitive performance on CIFAR and remarkably better performance on ImageNet (e.g. $+1.1\%$) compared with SAM, while requiring only half the training time.
8 |
9 |
10 | ![compare](compare.png)
11 |
12 |
13 |
14 | ## Example Usage
15 |
16 | We provide example usages in `/recipes/`.
17 | For parallelized training with RWP, run
18 |
19 | ```
20 | bash recipes/run_rwp_ddp.sh
21 | ```
22 |
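23 | For ImageNet, `recipes/run_rwp_imagenet.sh` gives a corresponding example (set `path` in the script to your ImageNet directory first).
24 |
25 | ## Method Sketch
26 |
27 | In `train_rwp_parallel.py`, odd local ranks compute the gradient of the randomly perturbed loss weighted by $2\alpha$, while even ranks compute the gradient of the original loss weighted by $2(1-\alpha)$; since DDP averages gradients across ranks, the resulting update direction is $(1-\alpha)\nabla L(w) + \alpha\nabla L(w+\epsilon)$, obtained in a single parallel backward pass. Below is a minimal single-process sketch of one RWP step distilled from that script; the function and variable names are illustrative, not part of the repository.
28 |
29 | ```
30 | import torch
31 |
32 | def filterwise_noise(model, gamma):
33 |     # RWP perturbation: each filter gets Gaussian noise with
34 |     # std = gamma * ||filter||; 1-D tensors (biases, BN parameters)
35 |     # use a single std per tensor.
36 |     noise = []
37 |     with torch.no_grad():
38 |         for p in model.parameters():
39 |             if p.dim() > 1:
40 |                 std = gamma * p.view(p.shape[0], -1).norm(dim=1, keepdim=True)
41 |                 noise.append(torch.normal(0.0, std.repeat(1, p[0].numel()).view(p.shape)))
42 |             else:
43 |                 noise.append(torch.empty_like(p).normal_(0, gamma * (p.norm().item() + 1e-16)))
44 |     return noise
45 |
46 | def rwp_step(model, criterion, optimizer, batch1, batch2, alpha=0.5, gamma=0.01):
47 |     (x1, y1), (x2, y2) = batch1, batch2
48 |     optimizer.zero_grad()
49 |     ((1 - alpha) * criterion(model(x1), y1)).backward()  # clean gradient g
50 |     noise = filterwise_noise(model, gamma)
51 |     with torch.no_grad():
52 |         for p, n in zip(model.parameters(), noise):
53 |             p.add_(n)                                    # w -> w + eps
54 |     (alpha * criterion(model(x2), y2)).backward()        # accumulate perturbed gradient g_s
55 |     with torch.no_grad():
56 |         for p, n in zip(model.parameters(), noise):
57 |             p.sub_(n)                                    # restore w
58 |     optimizer.step()                                     # SGD step on (1-alpha)*g + alpha*g_s
59 | ```
60 |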
--------------------------------------------------------------------------------
/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/RWP/cb0acb0708720a40c441915b275fd2d5e70c734c/compare.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .vgg import *
3 | from .wide_resnet import *
--------------------------------------------------------------------------------
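Note: the modules re-exported above follow a wrapper-class convention (`base`, `args`, `kwargs`, and optional `transform_train`/`transform_test`) rather than exporting bare `nn.Module`s; they are presumably consumed by a `get_model` helper in `utils.py` (truncated in this dump). A minimal sketch of how such a wrapper can be instantiated, under that assumption:

    import models

    def build_model(arch, num_classes):
        cfg = getattr(models, arch)           # e.g. models.VGG16BN
        kwargs = dict(cfg.kwargs)             # copy the class-level kwargs
        kwargs['num_classes'] = num_classes   # every base here accepts num_classes
        return cfg.base(*cfg.args, **kwargs)

    net = build_model('WideResNet16x8', num_classes=100)
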
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """resnet in pytorch
2 |
3 |
4 |
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
6 |
7 | Deep Residual Learning for Image Recognition
8 | https://arxiv.org/abs/1512.03385v1
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | class BasicBlock(nn.Module):
15 | """Basic Block for resnet 18 and resnet 34
16 |
17 | """
18 |
19 | #BasicBlock and BottleNeck blocks
20 | #have different output sizes;
21 | #we use the class attribute expansion
22 | #to distinguish between them
23 | expansion = 1
24 |
25 | def __init__(self, in_channels, out_channels, stride=1):
26 | super().__init__()
27 |
28 | #residual function
29 | self.residual_function = nn.Sequential(
30 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
31 | nn.BatchNorm2d(out_channels),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
35 | )
36 |
37 | #shortcut
38 | self.shortcut = nn.Sequential()
39 |
40 | #the shortcut output dimension is not the same with residual function
41 | #use 1*1 convolution to match the dimension
42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
43 | self.shortcut = nn.Sequential(
44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
46 | )
47 |
48 | def forward(self, x):
49 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
50 |
51 | class BottleNeck(nn.Module):
52 | """Residual block for resnet over 50 layers
53 |
54 | """
55 | expansion = 4
56 | def __init__(self, in_channels, out_channels, stride=1):
57 | super().__init__()
58 | self.residual_function = nn.Sequential(
59 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
60 | nn.BatchNorm2d(out_channels),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
63 | nn.BatchNorm2d(out_channels),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
66 | nn.BatchNorm2d(out_channels * BottleNeck.expansion),
67 | )
68 |
69 | self.shortcut = nn.Sequential()
70 |
71 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
72 | self.shortcut = nn.Sequential(
73 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
74 | nn.BatchNorm2d(out_channels * BottleNeck.expansion)
75 | )
76 |
77 | def forward(self, x):
78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
79 |
80 | class ResNet(nn.Module):
81 |
82 | def __init__(self, block, num_block, num_classes=100):
83 | super().__init__()
84 |
85 | self.in_channels = 64
86 |
87 | self.conv1 = nn.Sequential(
88 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
89 | nn.BatchNorm2d(64),
90 | nn.ReLU(inplace=True))
91 | #we use a different input size than the original paper,
92 | #so conv2_x's stride is 1
93 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
94 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
95 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
96 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
97 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
98 | self.fc = nn.Linear(512 * block.expansion, num_classes)
99 |
100 | def _make_layer(self, block, out_channels, num_blocks, stride):
101 | """make resnet layers(by layer i didnt mean this 'layer' was the
102 | same as a neuron netowork layer, ex. conv layer), one layer may
103 | contain more than one residual block
104 |
105 | Args:
106 | block: block type, basic block or bottle neck block
107 | out_channels: output depth channel number of this layer
108 | num_blocks: how many blocks per layer
109 | stride: the stride of the first block of this layer
110 |
111 | Return:
112 | return a resnet layer
113 | """
114 |
115 | # we have num_blocks blocks per layer; the stride of the first block
116 | # can be 1 or 2, the remaining blocks always use stride 1
117 | strides = [stride] + [1] * (num_blocks - 1)
118 | layers = []
119 | for stride in strides:
120 | layers.append(block(self.in_channels, out_channels, stride))
121 | self.in_channels = out_channels * block.expansion
122 |
123 | return nn.Sequential(*layers)
124 |
125 | def forward(self, x):
126 | output = self.conv1(x)
127 | output = self.conv2_x(output)
128 | output = self.conv3_x(output)
129 | output = self.conv4_x(output)
130 | output = self.conv5_x(output)
131 | output = self.avg_pool(output)
132 | output = output.view(output.size(0), -1)
133 | output = self.fc(output)
134 |
135 | return output
136 |
137 | class resnet18:
138 | base = ResNet
139 | args = list()
140 | kwargs = {'block': BasicBlock, 'num_block': [2, 2, 2, 2]}
141 |
142 | # def resnet18():
143 | # """ return a ResNet 18 object
144 | # """
145 | # kwargs = {}
146 | # return ResNet(BasicBlock, [2, 2, 2, 2])
147 |
148 | def resnet34():
149 | """ return a ResNet 34 object
150 | """
151 | return ResNet(BasicBlock, [3, 4, 6, 3])
152 |
153 | def resnet50():
154 | """ return a ResNet 50 object
155 | """
156 | return ResNet(BottleNeck, [3, 4, 6, 3])
157 |
158 | def resnet101():
159 | """ return a ResNet 101 object
160 | """
161 | return ResNet(BottleNeck, [3, 4, 23, 3])
162 |
163 | def resnet152():
164 | """ return a ResNet 152 object
165 | """
166 | return ResNet(BottleNeck, [3, 8, 36, 3])
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
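Note: only `resnet18` follows the wrapper-class convention (`base`/`args`/`kwargs`) used by `vgg.py` and `wide_resnet.py`; `resnet34` through `resnet152` are plain factory functions that return a `ResNet` directly (with the CIFAR-100 default of `num_classes=100`). For illustration:

    from models.resnet import resnet18, resnet50

    net18 = resnet18.base(**resnet18.kwargs, num_classes=100)  # via the wrapper class
    net50 = resnet50()                                         # via the factory function
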
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG model definition
3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
4 | """
5 |
6 | import math
7 | import torch.nn as nn
8 | import torchvision.transforms as transforms
9 |
10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN']
11 |
12 |
13 | def make_layers(cfg, batch_norm=False):
14 | layers = list()
15 | in_channels = 3
16 | for v in cfg:
17 | if v == 'M':
18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
19 | else:
20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
21 | if batch_norm:
22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
23 | else:
24 | layers += [conv2d, nn.ReLU(inplace=True)]
25 | in_channels = v
26 | return nn.Sequential(*layers)
27 |
28 |
29 | cfg = {
30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
32 | 512, 512, 512, 512, 'M'],
33 | }
34 |
35 |
36 | class VGG(nn.Module):
37 | def __init__(self, num_classes=10, depth=16, batch_norm=False):
38 | super(VGG, self).__init__()
39 | self.features = make_layers(cfg[depth], batch_norm)
40 | self.classifier = nn.Sequential(
41 | nn.Dropout(),
42 | nn.Linear(512, 512),
43 | nn.ReLU(True),
44 | nn.Dropout(),
45 | nn.Linear(512, 512),
46 | nn.ReLU(True),
47 | nn.Linear(512, num_classes),
48 | )
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
53 | m.weight.data.normal_(0, math.sqrt(2. / n))
54 | m.bias.data.zero_()
55 |
56 | def forward(self, x):
57 | x = self.features(x)
58 | x = x.view(x.size(0), -1)
59 | x = self.classifier(x)
60 | return x
61 |
62 |
63 | class Base:
64 | base = VGG
65 | args = list()
66 | kwargs = dict()
67 | transform_train = transforms.Compose([
68 | transforms.RandomHorizontalFlip(),
69 | transforms.RandomCrop(32, padding=4),
70 | transforms.ToTensor(),
71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
72 | ])
73 |
74 | transform_test = transforms.Compose([
75 | transforms.ToTensor(),
76 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
77 | ])
78 |
79 |
80 | class VGG16(Base):
81 | pass
82 |
83 |
84 | class VGG16BN(Base):
85 | kwargs = {'batch_norm': True}
86 |
87 |
88 | class VGG19(Base):
89 | kwargs = {'depth': 19}
90 |
91 |
92 | class VGG19BN(Base):
93 | kwargs = {'depth': 19, 'batch_norm': True}
--------------------------------------------------------------------------------
/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | WideResNet model definition
3 | ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
4 | """
5 |
6 | import math
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | import torch.nn.functional as F
13 | import torchvision.transforms as transforms
14 |
15 |
16 |
17 |
18 | __all__ = ['WideResNet28x10', 'WideResNet16x8']
19 |
20 |
21 | class BasicUnit(nn.Module):
22 | def __init__(self, channels: int, dropout: float):
23 | super(BasicUnit, self).__init__()
24 | self.block = nn.Sequential(OrderedDict([
25 | ("0_normalization", nn.BatchNorm2d(channels)),
26 | ("1_activation", nn.ReLU(inplace=True)),
27 | ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
28 | ("3_normalization", nn.BatchNorm2d(channels)),
29 | ("4_activation", nn.ReLU(inplace=True)),
30 | ("5_dropout", nn.Dropout(dropout, inplace=True)),
31 | ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
32 | ]))
33 |
34 | def forward(self, x):
35 | return x + self.block(x)
36 |
37 |
38 | class DownsampleUnit(nn.Module):
39 | def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float):
40 | super(DownsampleUnit, self).__init__()
41 | self.norm_act = nn.Sequential(OrderedDict([
42 | ("0_normalization", nn.BatchNorm2d(in_channels)),
43 | ("1_activation", nn.ReLU(inplace=True)),
44 | ]))
45 | self.block = nn.Sequential(OrderedDict([
46 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)),
47 | ("1_normalization", nn.BatchNorm2d(out_channels)),
48 | ("2_activation", nn.ReLU(inplace=True)),
49 | ("3_dropout", nn.Dropout(dropout, inplace=True)),
50 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)),
51 | ]))
52 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False)
53 |
54 | def forward(self, x):
55 | x = self.norm_act(x)
56 | return self.block(x) + self.downsample(x)
57 |
58 |
59 | class Block(nn.Module):
60 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float):
61 | super(Block, self).__init__()
62 | self.block = nn.Sequential(
63 | DownsampleUnit(in_channels, out_channels, stride, dropout),
64 | *(BasicUnit(out_channels, dropout) for _ in range(depth))
65 | )
66 |
67 | def forward(self, x):
68 | return self.block(x)
69 |
70 |
71 | class WideResNet(nn.Module):
72 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int):
73 | super(WideResNet, self).__init__()
74 |
75 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor]
76 | self.block_depth = (depth - 4) // (3 * 2)
77 |
78 | self.f = nn.Sequential(OrderedDict([
79 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)),
80 | ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)),
81 | ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)),
82 | ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)),
83 | ("4_normalization", nn.BatchNorm2d(self.filters[3])),
84 | ("5_activation", nn.ReLU(inplace=True)),
85 | ("6_pooling", nn.AvgPool2d(kernel_size=8)),
86 | ("7_flattening", nn.Flatten()),
87 | ("8_classification", nn.Linear(in_features=self.filters[3], out_features=num_classes)),
88 | ]))
89 |
90 | self._initialize()
91 |
92 | def _initialize(self):
93 | for m in self.modules():
94 | if isinstance(m, nn.Conv2d):
95 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
96 | if m.bias is not None:
97 | m.bias.data.zero_()
98 | elif isinstance(m, nn.BatchNorm2d):
99 | m.weight.data.fill_(1)
100 | m.bias.data.zero_()
101 | elif isinstance(m, nn.Linear):
102 | m.weight.data.zero_()
103 | m.bias.data.zero_()
104 |
105 | def forward(self, x):
106 | return self.f(x)
107 |
108 | class WideResNet28x10:
109 | base = WideResNet
110 | args = list()
111 | kwargs = {'depth': 28, 'width_factor': 10, 'dropout': 0, 'in_channels': 3}
112 | transform_train = transforms.Compose([
113 | transforms.RandomCrop(32, padding=4),
114 | transforms.RandomHorizontalFlip(),
115 | transforms.ToTensor(),
116 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
117 | ])
118 | transform_test = transforms.Compose([
119 | transforms.ToTensor(),
120 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
121 | ])
122 |
123 | class WideResNet16x8:
124 | base = WideResNet
125 | args = list()
126 | kwargs = {'depth': 16, 'width_factor': 8, 'dropout': 0, 'in_channels': 3}
127 | transform_train = transforms.Compose([
128 | transforms.RandomCrop(32, padding=4),
129 | transforms.RandomHorizontalFlip(),
130 | transforms.ToTensor(),
131 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
132 | ])
133 | transform_test = transforms.Compose([
134 | transforms.ToTensor(),
135 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
136 | ])
--------------------------------------------------------------------------------
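Note: the fixed `AvgPool2d(kernel_size=8)` assumes CIFAR-sized 32x32 inputs (the two stride-2 blocks reduce 32 -> 16 -> 8, which the 8x8 pooling collapses to 1x1); other resolutions would break the final `Linear`. A quick shape sanity check under that assumption:

    import torch
    from models.wide_resnet import WideResNet16x8

    net = WideResNet16x8.base(**WideResNet16x8.kwargs, num_classes=10)
    assert net(torch.randn(2, 3, 32, 32)).shape == (2, 10)
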
/recipes/run_rwp_ddp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ################################ CIFAR ###################################
4 | datasets=CIFAR100
5 | device=0,1 # two GPUs: the even local rank computes the clean gradient g, the odd rank the perturbed gradient g_s
6 | model=resnet18 # resnet18 VGG16BN WideResNet16x8 WideResNet28x10
7 | schedule=cosine
8 | wd=0.001
9 | epoch=200
10 | bz=256
11 | lr=0.10
12 | port=1234
13 | seed=0
14 | alpha=0.5
15 | gamma=0.01
16 |
17 | DST=results/rwp_ddp_cutout_gamma$gamma\_alpha$alpha\_$epoch\_$bz\_$lr\_$model\_$wd\_$datasets\_$schedule\_seed$seed
18 | CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node 2 --master_port $port train_rwp_parallel.py --datasets $datasets \
19 | --arch=$model --epochs=$epoch --wd=$wd --randomseed $seed --lr $lr --gamma $gamma --cutout -b $bz --alpha $alpha --workers 8 \
20 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100 --schedule $schedule
21 |
22 |
--------------------------------------------------------------------------------
/recipes/run_rwp_imagenet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | datasets=ImageNet
4 | device=0,1,2,3 # parallelized training for RWP
5 |
6 | model=resnet18
7 | path=... # dir for ImageNet datasets
8 | DST=save_resnet18
9 | CUDA_VISIBLE_DEVICES=$device python3 train_rwp_imagenet.py -a $model \
10 | --epochs 90 --workers 16 --dist-url 'tcp://127.0.0.1:4234' --lr 0.1 -b 256 \
11 | --dist-backend 'nccl' --multiprocessing-distributed --gamma 0.005 --alpha 0.5 \
12 | --save-dir=$DST/checkpoints --log-dir=$DST \
13 | --world-size 1 --rank 0 $path
14 |
--------------------------------------------------------------------------------
/train_rwp_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 |
9 | import numpy as np
10 | import pickle
11 |
12 | from PIL import Image, ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | from utils import *
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | import torchvision.models as models
28 |
29 | model_names = sorted(name for name in models.__dict__
30 | if name.islower() and not name.startswith("__")
31 | and callable(models.__dict__[name]))
32 |
33 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
34 | parser.add_argument('data', metavar='DIR',
35 | help='path to dataset')
36 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
37 | choices=model_names,
38 | help='model architecture: ' +
39 | ' | '.join(model_names) +
40 | ' (default: resnet18)')
41 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
42 | help='number of data loading workers (default: 4)')
43 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
44 | help='number of total epochs to run')
45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
46 | help='manual epoch number (useful on restarts)')
47 | parser.add_argument('-b', '--batch-size', default=256, type=int,
48 | metavar='N',
49 | help='mini-batch size (default: 256), this is the total '
50 | 'batch size of all GPUs on the current node when '
51 | 'using Data Parallel or Distributed Data Parallel')
52 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
53 | metavar='LR', help='initial learning rate', dest='lr')
54 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
55 | help='momentum')
56 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
57 | metavar='W', help='weight decay (default: 1e-4)',
58 | dest='weight_decay')
59 | parser.add_argument('--alpha', default=0.5, type=float,
60 | metavar='AA', help='alpha for mixing gradients')
61 | parser.add_argument('--gamma', default=0.01, type=float,
62 | metavar='GAMMA', help='gamma for noise')
63 |
64 | parser.add_argument('-p', '--print-freq', default=1000, type=int,
65 | metavar='N', help='print frequency (default: 1000)')
66 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
67 | help='path to latest checkpoint (default: none)')
68 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
69 | help='evaluate model on validation set')
70 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
71 | help='use pre-trained model')
72 | parser.add_argument('--world-size', default=-1, type=int,
73 | help='number of nodes for distributed training')
74 | parser.add_argument('--rank', default=-1, type=int,
75 | help='node rank for distributed training')
76 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
77 | help='url used to set up distributed training')
78 | parser.add_argument('--dist-backend', default='nccl', type=str,
79 | help='distributed backend')
80 | parser.add_argument('--seed', default=42, type=int,
81 | help='seed for initializing training. ')
82 | parser.add_argument('--gpu', default=None, type=int,
83 | help='GPU id to use.')
84 | parser.add_argument('--save-dir', dest='save_dir',
85 | help='The directory used to save the trained models',
86 | default='save_temp', type=str)
87 | parser.add_argument('--log-dir', dest='log_dir',
88 | help='The directory used to save the log',
89 | default='save_temp', type=str)
90 | parser.add_argument('--log-name', dest='log_name',
91 | help='The log file name',
92 | default='log', type=str)
93 | parser.add_argument('--multiprocessing-distributed', action='store_true',
94 | help='Use multi-processing distributed training to launch '
95 | 'N processes per node, which has N GPUs. This is the '
96 | 'fastest way to use PyTorch for either single node or '
97 | 'multi node data parallel training')
98 |
99 |
100 | best_acc1 = 0
101 |
102 |
103 | param_vec = []
104 | # Record training statistics
105 | train_loss = []
106 | train_acc = []
107 | test_loss = []
108 | test_acc = []
109 | arr_time = []
110 |
111 |
112 | def get_model_grad_vec(model):
113 | # Return the model gradient as a vector
114 |
115 | vec = []
116 | for name,param in model.named_parameters():
117 | vec.append(param.grad.detach().reshape(-1))
118 | return torch.cat(vec, 0)
119 |
120 | def update_grad(model, grad_vec):
121 | idx = 0
122 | for name,param in model.named_parameters():
123 | arr_shape = param.grad.shape
124 | size = 1
125 | for i in range(len(list(arr_shape))):
126 | size *= arr_shape[i]
127 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
128 | idx += size
129 |
130 |
131 | iters = 0
132 | def get_model_param_vec(model):
133 | # Return the model parameters as a vector
134 |
135 | vec = []
136 | for name,param in model.named_parameters():
137 | vec.append(param.detach().cpu().reshape(-1).numpy())
138 | return np.concatenate(vec, 0)
139 |
140 | def main():
141 | global train_loss, train_acc, test_loss, test_acc, arr_time
142 |
143 | args = parser.parse_args()
144 |
145 | print ('gamma:', args.gamma)
146 | save_dir = 'save_' + args.arch
147 | if not os.path.exists(save_dir):
148 | os.makedirs(save_dir)
149 | args.save_dir = save_dir
150 |
151 |
152 | # Check whether the log_dir exists
153 | # if args.rank == 0:
154 | print ('log dir:', args.log_dir)
155 | if not os.path.exists(args.log_dir):
156 | os.makedirs(args.log_dir)
157 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
158 | print ('log dir:', args.log_dir)
159 |
160 | if args.seed is not None:
161 | random.seed(args.seed)
162 | torch.manual_seed(args.seed)
163 | cudnn.deterministic = True
164 | warnings.warn('You have chosen to seed training. '
165 | 'This will turn on the CUDNN deterministic setting, '
166 | 'which can slow down your training considerably! '
167 | 'You may see unexpected behavior when restarting '
168 | 'from checkpoints.')
169 |
170 | if args.gpu is not None:
171 | warnings.warn('You have chosen a specific GPU. This will completely '
172 | 'disable data parallelism.')
173 |
174 | if args.dist_url == "env://" and args.world_size == -1:
175 | args.world_size = int(os.environ["WORLD_SIZE"])
176 |
177 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
178 |
179 | ngpus_per_node = torch.cuda.device_count()
180 | if args.multiprocessing_distributed:
181 | # Since we have ngpus_per_node processes per node, the total world_size
182 | # needs to be adjusted accordingly
183 | args.world_size = ngpus_per_node * args.world_size
184 | # Use torch.multiprocessing.spawn to launch distributed processes: the
185 | # main_worker process function
186 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
187 | else:
188 | # Simply call main_worker function
189 | main_worker(args.gpu, ngpus_per_node, args)
190 |
191 | sample_idx = 0
192 |
193 | def main_worker(gpu, ngpus_per_node, args):
194 | global train_loss, train_acc, test_loss, test_acc, arr_time
195 | global best_acc1, param_vec, sample_idx
196 | args.gpu = gpu
197 |
198 | if args.gpu is not None:
199 | print("Use GPU: {} for training".format(args.gpu))
200 |
201 | if args.distributed:
202 | if args.dist_url == "env://" and args.rank == -1:
203 | args.rank = int(os.environ["RANK"])
204 | if args.multiprocessing_distributed:
205 | # For multiprocessing distributed training, rank needs to be the
206 | # global rank among all the processes
207 | args.rank = args.rank * ngpus_per_node + gpu
208 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
209 | world_size=args.world_size, rank=args.rank)
210 | # create model
211 | if args.pretrained:
212 | print("=> using pre-trained model '{}'".format(args.arch))
213 | model = models.__dict__[args.arch](pretrained=True)
214 | else:
215 | print("=> creating model '{}'".format(args.arch))
216 | model = models.__dict__[args.arch]()
217 |
218 |
219 | # Double the epochs and the nominal batch size: each RWP iteration consumes two batches of data, one for g and one for g_s
220 | args.epochs = args.epochs * 2
221 | args.batch_size = args.batch_size * 2
222 |
223 |
224 | if not torch.cuda.is_available():
225 | print('using CPU, this will be slow')
226 | elif args.distributed:
227 | # For multiprocessing distributed, DistributedDataParallel constructor
228 | # should always set the single device scope, otherwise,
229 | # DistributedDataParallel will use all available devices.
230 | if args.gpu is not None:
231 | torch.cuda.set_device(args.gpu)
232 | model.cuda(args.gpu)
233 | # When using a single GPU per process and per
234 | # DistributedDataParallel, we need to divide the batch size
235 | # ourselves based on the total number of GPUs we have
236 | args.batch_size = int(args.batch_size / ngpus_per_node)
237 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
238 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
239 | else:
240 | model.cuda()
241 | # DistributedDataParallel will divide and allocate batch_size to all
242 | # available GPUs if device_ids are not set
243 | model = torch.nn.parallel.DistributedDataParallel(model)
244 | elif args.gpu is not None:
245 | torch.cuda.set_device(args.gpu)
246 | model = model.cuda(args.gpu)
247 | else:
248 | # DataParallel will divide and allocate batch_size to all available GPUs
249 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
250 | model.features = torch.nn.DataParallel(model.features)
251 | model.cuda()
252 | else:
253 | model = torch.nn.DataParallel(model).cuda()
254 |
255 | # define loss function (criterion) and optimizer
256 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
257 |
258 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
259 | momentum=args.momentum,
260 | weight_decay=args.weight_decay)
261 |
262 | # optionally resume from a checkpoint
263 | if args.resume:
264 | if os.path.isfile(args.resume):
265 | print("=> loading checkpoint '{}'".format(args.resume))
266 | if args.gpu is None:
267 | checkpoint = torch.load(args.resume)
268 | else:
269 | # Map model to be loaded to specified single gpu.
270 | loc = 'cuda:{}'.format(args.gpu)
271 | checkpoint = torch.load(args.resume, map_location=loc)
272 | args.start_epoch = checkpoint['epoch']
273 | best_acc1 = checkpoint['best_acc1']
274 | if args.gpu is not None:
275 | # best_acc1 may be from a checkpoint from a different GPU
276 | best_acc1 = best_acc1.to(args.gpu)
277 | model.load_state_dict(checkpoint['state_dict'])
278 | optimizer.load_state_dict(checkpoint['optimizer'])
279 | print("=> loaded checkpoint '{}' (epoch {})"
280 | .format(args.resume, checkpoint['epoch']))
281 | else:
282 | print("=> no checkpoint found at '{}'".format(args.resume))
283 |
284 | cudnn.benchmark = True
285 |
286 | # Data loading code
287 | traindir = os.path.join(args.data, 'train')
288 | valdir = os.path.join(args.data, 'val')
289 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
290 | std=[0.229, 0.224, 0.225])
291 |
292 | train_dataset = datasets.ImageFolder(
293 | traindir,
294 | transforms.Compose([
295 | transforms.RandomResizedCrop(224),
296 | transforms.RandomHorizontalFlip(),
297 | transforms.ToTensor(),
298 | normalize,
299 | ]))
300 |
301 | if args.distributed:
302 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
303 | else:
304 | train_sampler = None
305 |
306 | train_loader = torch.utils.data.DataLoader(
307 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
308 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
309 |
310 | val_loader = torch.utils.data.DataLoader(
311 | datasets.ImageFolder(valdir, transforms.Compose([
312 | transforms.Resize(256),
313 | transforms.CenterCrop(224),
314 | transforms.ToTensor(),
315 | normalize,
316 | ])),
317 | batch_size=args.batch_size, shuffle=False,
318 | num_workers=args.workers, pin_memory=True)
319 |
320 | if args.evaluate:
321 | validate(val_loader, model, criterion, args)
322 | return
323 |
324 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
325 | torch.save(model.state_dict(), 'save_' + args.arch + '/' + str(sample_idx)+'.pt')
326 |
327 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
328 |
329 | for epoch in range(args.start_epoch, args.epochs):
330 | if args.distributed:
331 | train_sampler.set_epoch(epoch)
332 |
333 | # train for one epoch
334 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
335 | train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node)
336 | lr_scheduler.step()
337 |
338 | # evaluate on validation set
339 | acc1 = validate(val_loader, model, criterion, args)
340 |
341 | # remember best acc@1 and save checkpoint
342 | is_best = acc1 > best_acc1
343 | best_acc1 = max(acc1, best_acc1)
344 |
345 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
346 | and args.rank % ngpus_per_node == 0):
347 | save_checkpoint({
348 | 'epoch': epoch + 1,
349 | 'arch': args.arch,
350 | 'state_dict': model.state_dict(),
351 | 'best_acc1': best_acc1,
352 | 'optimizer' : optimizer.state_dict(),
353 | }, is_best)
354 |
355 | torch.save(model, os.path.join(args.save_dir, 'model.pt'))
356 |
357 | print ('train loss: ', train_loss)
358 | print ('train acc: ', train_acc)
359 | print ('test loss: ', test_loss)
360 | print ('test acc: ', test_acc)
361 |
362 | print ('time: ', arr_time)
363 |
364 |
365 | def train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node):
366 | global iters, param_vec, sample_idx
367 | global train_loss, train_acc, test_loss, test_acc, arr_time
368 |
369 | batch_time = AverageMeter('Time', ':6.3f')
370 | data_time = AverageMeter('Data', ':6.3f')
371 | losses = AverageMeter('Loss', ':.4e')
372 | top1 = AverageMeter('Acc@1', ':6.2f')
373 | top5 = AverageMeter('Acc@5', ':6.2f')
374 | progress = ProgressMeter(
375 | len(train_loader),
376 | [batch_time, data_time, losses, top1, top5],
377 | prefix="Epoch: [{}]".format(epoch))
378 |
379 | # switch to train mode
380 | model.train()
381 |
382 | end = time.time()
383 | epoch_start = end
384 | for i, (images, target) in enumerate(train_loader):
385 | # measure data loading time
386 | data_time.update(time.time() - end)
387 |
388 | if args.gpu is not None:
389 | images = images.cuda(args.gpu, non_blocking=True)
390 | if torch.cuda.is_available():
391 | target = target.cuda(args.gpu, non_blocking=True)
392 |
393 |
394 | if args.rank % 2 == 1:  # odd global ranks compute the perturbed gradient g_s
395 | weight = args.alpha * 2
396 | ############ generate filter-wise RWP noise (std = gamma * filter norm) ############
397 | noise = []
398 | for mp in model.parameters():
399 | if len(mp.shape) > 1:
400 | sh = mp.shape
401 | sh_mul = np.prod(sh[1:])
402 | temp = mp.view(sh[0], -1).norm(dim=1, keepdim=True).repeat(1, sh_mul).view(mp.shape)
403 | temp = torch.normal(0, args.gamma*temp).to(mp.data.device)
404 | else:
405 | temp = torch.empty_like(mp, device=mp.data.device)
406 | temp.normal_(0, args.gamma*(mp.view(-1).norm().item() + 1e-16))
407 | noise.append(temp)
408 | mp.data.add_(noise[-1])
409 | else:
410 | weight = (1 - args.alpha) * 2  # even global ranks compute the clean gradient g
411 |
412 | # compute output
413 | output = model(images)
414 | loss = criterion(output, target) * weight
415 | optimizer.zero_grad()
416 | loss.backward()
417 |
418 | if args.rank % 2 == 1:
419 | # remove the noise, going back to the unperturbed weights
420 | with torch.no_grad():
421 | for mp, n in zip(model.parameters(), noise):
422 | mp.data.sub_(n)
423 |
424 | optimizer.step()
425 |
426 | # measure accuracy and record loss
427 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
428 | losses.update(loss.item() / weight, images.size(0))
429 | top1.update(acc1[0], images.size(0))
430 | top5.update(acc5[0], images.size(0))
431 |
432 | # compute gradient and do SGD step
433 | # optimizer.zero_grad()
434 | # loss.backward()
435 | # optimizer.step()
436 |
437 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
438 | and args.rank % ngpus_per_node == 0):
439 |
440 | if i % args.print_freq == 0:
441 | progress.display(i)
442 |
443 | if i > 0 and i % 1000 == 0 and i < 5000:
444 | sample_idx += 1
445 | # torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')
446 |
447 | # measure elapsed time
448 | batch_time.update(time.time() - end)
449 | end = time.time()
450 |
451 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
452 | and args.rank % ngpus_per_node == 0):
453 | sample_idx += 1
454 | # torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')
455 |
456 | arr_time.append(time.time() - epoch_start)
457 | train_loss.append(losses.avg)
458 | train_acc.append(top1.avg)
459 |
460 |
461 | def validate(val_loader, model, criterion, args):
462 | global train_loss, train_acc, test_loss, test_acc, arr_time
463 | batch_time = AverageMeter('Time', ':6.3f')
464 | losses = AverageMeter('Loss', ':.4e')
465 | top1 = AverageMeter('Acc@1', ':6.2f')
466 | top5 = AverageMeter('Acc@5', ':6.2f')
467 | progress = ProgressMeter(
468 | len(val_loader),
469 | [batch_time, losses, top1, top5],
470 | prefix='Test: ')
471 |
472 | # switch to evaluate mode
473 | model.eval()
474 |
475 | with torch.no_grad():
476 | end = time.time()
477 | for i, (images, target) in enumerate(val_loader):
478 | if args.gpu is not None:
479 | images = images.cuda(args.gpu, non_blocking=True)
480 | if torch.cuda.is_available():
481 | target = target.cuda(args.gpu, non_blocking=True)
482 |
483 | # compute output
484 | output = model(images)
485 | loss = criterion(output, target)
486 |
487 | # measure accuracy and record loss
488 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
489 | losses.update(loss.item(), images.size(0))
490 | top1.update(acc1[0], images.size(0))
491 | top5.update(acc5[0], images.size(0))
492 |
493 | # measure elapsed time
494 | batch_time.update(time.time() - end)
495 | end = time.time()
496 |
497 | if i % args.print_freq == 0:
498 | progress.display(i)
499 |
500 | # TODO: this should also be done with the ProgressMeter
501 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
502 | .format(top1=top1, top5=top5))
503 | test_acc.append(top1.avg)
504 | test_loss.append(losses.avg)
505 | return top1.avg
506 |
507 |
508 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
509 | torch.save(state, filename)
510 | if is_best:
511 | shutil.copyfile(filename, 'model_best.pth.tar')
512 |
513 |
514 | class AverageMeter(object):
515 | """Computes and stores the average and current value"""
516 | def __init__(self, name, fmt=':f'):
517 | self.name = name
518 | self.fmt = fmt
519 | self.reset()
520 |
521 | def reset(self):
522 | self.val = 0
523 | self.avg = 0
524 | self.sum = 0
525 | self.count = 0
526 |
527 | def update(self, val, n=1):
528 | self.val = val
529 | self.sum += val * n
530 | self.count += n
531 | self.avg = self.sum / self.count
532 |
533 | def __str__(self):
534 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
535 | return fmtstr.format(**self.__dict__)
536 |
537 |
538 | class ProgressMeter(object):
539 | def __init__(self, num_batches, meters, prefix=""):
540 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
541 | self.meters = meters
542 | self.prefix = prefix
543 |
544 | def display(self, batch):
545 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
546 | entries += [str(meter) for meter in self.meters]
547 | print('\t'.join(entries))
548 |
549 | def _get_batch_fmtstr(self, num_batches):
550 | num_digits = len(str(num_batches // 1))
551 | fmt = '{:' + str(num_digits) + 'd}'
552 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
553 |
554 |
555 | def accuracy(output, target, topk=(1,)):
556 | """Computes the accuracy over the k top predictions for the specified values of k"""
557 | with torch.no_grad():
558 | maxk = max(topk)
559 | batch_size = target.size(0)
560 |
561 | _, pred = output.topk(maxk, 1, True, True)
562 | pred = pred.t()
563 | correct = pred.eq(target.view(1, -1).expand_as(pred))
564 |
565 | res = []
566 | for k in topk:
567 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
568 | res.append(correct_k.mul_(100.0 / batch_size))
569 | return res
570 |
571 |
572 | if __name__ == '__main__':
573 | main()
--------------------------------------------------------------------------------
/train_rwp_parallel.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from torch.nn.modules.batchnorm import _BatchNorm
3 | import os
4 | import time
5 | import numpy as np
6 | import random
7 | import sys
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.optim
14 | import torch.utils.data
15 | import torchvision.transforms as transforms
16 | import torchvision.datasets as datasets
17 |
18 | import torch.distributed as dist
19 | from torch.nn.parallel import DistributedDataParallel as DDP
20 |
21 | from utils import *
22 |
23 |
24 | # Parse arguments
25 | parser = argparse.ArgumentParser(description='DDP RWP training')
26 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='SGD')
27 | parser.add_argument('--arch', '-a', metavar='ARCH',
28 | help='The architecture of the model')
29 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
30 | help='The training datasets')
31 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str,
32 | help='The optimizer for training')
33 | parser.add_argument('--schedule', metavar='SCHEDULE', default='step', type=str,
34 | help='The schedule for training')
35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
36 | help='number of data loading workers (default: 4)')
37 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
38 | help='number of total epochs to run')
39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
40 | help='manual epoch number (useful on restarts)')
41 | parser.add_argument('-b', '--batch-size', default=128, type=int,
42 | metavar='N', help='mini-batch size (default: 128)')
43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
44 | metavar='LR', help='initial learning rate')
45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
46 | help='momentum')
47 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
48 | metavar='W', help='weight decay (default: 1e-4)')
49 | parser.add_argument('--print-freq', '-p', default=100, type=int,
50 | metavar='N', help='print frequency (default: 100)')
51 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
52 | help='path to latest checkpoint (default: none)')
53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
54 | help='evaluate model on validation set')
55 | parser.add_argument('--wandb', dest='wandb', action='store_true',
56 | help='use wandb to monitor statistics')
57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
58 | help='use pre-trained model')
59 | parser.add_argument('--half', dest='half', action='store_true',
60 | help='use half-precision (16-bit)')
61 | parser.add_argument('--save-dir', dest='save_dir',
62 | help='The directory used to save the trained models',
63 | default='save_temp', type=str)
64 | parser.add_argument('--log-dir', dest='log_dir',
65 | help='The directory used to save the log',
66 | default='save_temp', type=str)
67 | parser.add_argument('--log-name', dest='log_name',
68 | help='The log file name',
69 | default='log', type=str)
70 | parser.add_argument('--randomseed',
71 | help='Random seed for training and initialization',
72 | type=int, default=1)
73 | parser.add_argument('--cutout', dest='cutout', action='store_true',
74 | help='use cutout data augmentation')
75 | parser.add_argument('--alpha', default=0.5, type=float,
76 | metavar='A', help='alpha for mixing gradients')
77 | parser.add_argument('--gamma', default=0.01, type=float,
78 | metavar='gamma', help='Perturbation magnitude gamma for RWP')
79 |
80 | parser.add_argument("--local_rank", default=-1, type=int)
81 |
82 | best_prec1 = 0
83 |
84 | # Record training statistics
85 | train_loss = []
86 | train_err = []
87 | test_loss = []
88 | test_err = []
89 | arr_time = []
90 |
91 | args = parser.parse_args()
92 |
93 | local_rank = args.local_rank
94 | torch.cuda.set_device(local_rank)
95 | dist.init_process_group(backend='nccl')
96 | args.world_size = torch.distributed.get_world_size()
97 | args.workers = int((args.workers + args.world_size - 1) / args.world_size)
98 | if args.local_rank == 0:
99 | print ('world size: {} workers per GPU: {}'.format(args.world_size, args.workers))
100 | device = torch.device("cuda", local_rank)
101 |
102 | if args.wandb:
103 | import wandb
104 | wandb.init(project="TWA", entity="nblt")
105 | date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
106 | wandb.run.name = args.EXP + date
107 |
108 |
109 | def get_model_param_vec(model):
110 | # Return the model parameters as a vector
111 |
112 | vec = []
113 | for name,param in model.named_parameters():
114 | vec.append(param.data.detach().reshape(-1))
115 | return torch.cat(vec, 0)
116 |
117 |
118 | def get_model_grad_vec(model):
119 | # Return the model gradient as a vector
120 |
121 | vec = []
122 | for name,param in model.named_parameters():
123 | vec.append(param.grad.detach().reshape(-1))
124 | return torch.cat(vec, 0)
125 |
126 | def update_grad(model, grad_vec):
127 | idx = 0
128 | for name,param in model.named_parameters():
129 | arr_shape = param.grad.shape
130 | size = param.grad.numel()
131 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
132 | idx += size
133 |
134 | def update_param(model, param_vec):
135 | idx = 0
136 | for name,param in model.named_parameters():
137 | arr_shape = param.data.shape
138 | size = param.data.numel()
139 | param.data = param_vec[idx:idx+size].reshape(arr_shape).clone()
140 | idx += size
141 |
142 | def print_param_shape(model):
143 | for name,param in model.named_parameters():
144 | print (name, param.data.shape)
145 |
146 | def main():
147 |
148 | global args, best_prec1, p0
149 | global train_loss, train_err, test_loss, test_err, arr_time, running_weight
150 |
151 | set_seed(args.randomseed)
152 |
153 | # Check whether the save_dir exists
154 | if args.local_rank == 0:
155 | print ('save dir:', args.save_dir)
156 | if not os.path.exists(args.save_dir):
157 | os.makedirs(args.save_dir)
158 |
159 | # Check whether the log_dir exists
160 | if args.local_rank == 0:
161 | print ('log dir:', args.log_dir)
162 | if not os.path.exists(args.log_dir):
163 | os.makedirs(args.log_dir)
164 |
165 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
166 |
167 | # Define model
168 | # model = torch.nn.DataParallel(get_model(args))
169 | model = get_model(args).to(device)
170 | model = DDP(model, device_ids=[local_rank], output_device=local_rank)
171 |
172 | # print_param_shape(model)
173 |
174 | # Optionally resume from a checkpoint
175 | if args.resume:
176 | # if os.path.isfile(args.resume):
177 | if os.path.isfile(os.path.join(args.save_dir, args.resume)):
178 |
179 | # model.load_state_dict(torch.load(os.path.join(args.save_dir, args.resume)))
180 |
181 | print ("=> loading checkpoint '{}'".format(args.resume))
182 | checkpoint = torch.load(args.resume)
183 | args.start_epoch = checkpoint['epoch']
184 | print ('from ', args.start_epoch)
185 | best_prec1 = checkpoint['best_prec1']
186 | model.load_state_dict(checkpoint['state_dict'])
187 | print ("=> loaded checkpoint '{}' (epoch {})"
188 | .format(args.resume, checkpoint['epoch']))
189 | else:
190 | print ("=> no checkpoint found at '{}'".format(args.resume))
191 |
192 | cudnn.benchmark = True
193 |
194 | # Prepare Dataloader
195 | print ('cutout:', args.cutout)
196 | if args.cutout:
197 | train_loader, val_loader = get_datasets_cutout_ddp(args)
198 | else:
199 | train_loader, val_loader = get_datasets_ddp(args)
200 |
201 | # define loss function (criterion) and optimizer
202 | criterion = nn.CrossEntropyLoss().to(device)
203 |
204 | if args.half:
205 | model.half()
206 | criterion.half()
207 |
208 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
209 |
210 | # Double the training epochs since each iteration will consume two batches of data for calculating g and g_s
211 | args.epochs = args.epochs * 2
212 |
213 | if args.schedule == 'step':
214 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs * 0.5), int(args.epochs * 0.75)], last_epoch=args.start_epoch - 1)
215 | elif args.schedule == 'cosine':
216 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
217 |
218 | if args.evaluate:
219 | validate(val_loader, model, criterion)
220 | return
221 |
222 |
223 | is_best = 0
224 | print ('Start training: ', args.start_epoch, '->', args.epochs)
225 | print ('gamma:', args.gamma)
226 | print ('len(train_loader):', len(train_loader))
227 |
228 | for epoch in range(args.start_epoch, args.epochs):
229 | train_loader.sampler.set_epoch(epoch)
230 |
231 | # train for one epoch
232 | if args.local_rank == 0:
233 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
234 | train(train_loader, model, criterion, optimizer, epoch)
235 | lr_scheduler.step()
236 |
237 | if epoch % 2 == 0: continue  # epochs are doubled, so validate every other pass
238 |
239 | # evaluate on validation set
240 | prec1 = validate(val_loader, model, criterion)
241 |
242 | # remember best prec@1 and save checkpoint
243 | is_best = prec1 > best_prec1
244 | best_prec1 = max(prec1, best_prec1)
245 |
246 | if args.local_rank == 0:
247 | save_checkpoint({
248 | 'state_dict': model.state_dict(),
249 | 'best_prec1': best_prec1,
250 | }, is_best, filename=os.path.join(args.save_dir, 'model.th'))
251 |
252 | if args.local_rank == 0:
253 | print ('train loss: ', train_loss)
254 | print ('train err: ', train_err)
255 | print ('test loss: ', test_loss)
256 | print ('test err: ', test_err)
257 | print ('time: ', arr_time)
258 |
259 |
260 | def train(train_loader, model, criterion, optimizer, epoch):
261 | """
262 | Run one train epoch
263 | """
264 | global train_loss, train_err, arr_time
265 |
266 | batch_time = AverageMeter()
267 | data_time = AverageMeter()
268 | losses = AverageMeter()
269 | top1 = AverageMeter()
270 |
271 | # switch to train mode
272 | model.train()
273 |
274 | total_loss, total_err = 0, 0
275 | end = time.time()
276 | for i, (input, target) in enumerate(train_loader):
277 |
278 | # measure data loading time
279 | data_time.update(time.time() - end)
280 |
281 | target = target.to(device)
282 | input_var = input.to(device)
283 | target_var = target
284 | if args.half:
285 | input_var = input_var.half()
286 |
287 | if args.local_rank % 2 == 1:  # odd local ranks compute the perturbed gradient g_s
288 | weight = args.alpha * 2
289 | with torch.no_grad():
290 | noise = []
291 | for mp in model.parameters():
292 | if len(mp.shape) > 1:
293 | sh = mp.shape
294 | sh_mul = np.prod(sh[1:])
295 | temp = mp.view(sh[0], -1).norm(dim=1, keepdim=True).repeat(1, sh_mul).view(mp.shape)
296 | temp = torch.normal(0, args.gamma*temp).to(mp.data.device)
297 | else:
298 | temp = torch.empty_like(mp, device=mp.data.device)
299 | temp.normal_(0, args.gamma*(mp.view(-1).norm().item() + 1e-16))
300 | noise.append(temp)
301 | mp.data.add_(noise[-1])
302 | else:
303 | weight = (1 - args.alpha) * 2  # even local ranks compute the clean gradient g
304 |
305 | # compute output
306 | output = model(input_var)
307 | loss = criterion(output, target_var) * weight
308 |
309 | optimizer.zero_grad()
310 | loss.backward()
311 |
312 | if args.local_rank % 2 == 1:
313 | # remove the noise, going back to the unperturbed weights
314 | with torch.no_grad():
315 | for mp, n in zip(model.parameters(), noise):
316 | mp.data.sub_(n)
317 |
318 | optimizer.step()
319 |
320 | total_loss += loss.item() * input_var.shape[0] / weight
321 | total_err += (output.max(dim=1)[1] != target_var).sum().item()
322 |
323 | output = output.float()
324 | loss = loss.float()
325 |
326 | # measure accuracy and record loss
327 | prec1 = accuracy(output.data, target)[0]
328 | losses.update(loss.item(), input.size(0))
329 | top1.update(prec1.item(), input.size(0))
330 |
331 | # measure elapsed time
332 | batch_time.update(time.time() - end)
333 | end = time.time()
334 |
335 | if args.local_rank == 0 and (i % args.print_freq == 0 or i == len(train_loader) - 1):
336 | print('Epoch: [{0}][{1}/{2}]\t'
337 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
338 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
339 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
340 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
341 | epoch, i, len(train_loader), batch_time=batch_time,
342 | data_time=data_time, loss=losses, top1=top1))
343 |
344 | if args.local_rank == 0:
345 | print ('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum))
346 |
347 | tloss = total_loss / len(train_loader.dataset) * args.world_size
348 | terr = total_err / len(train_loader.dataset) * args.world_size
349 | train_loss.append(tloss)
350 | train_err.append(terr)
351 | print ('train loss | acc', tloss, 1 - terr)
352 |
353 | if args.wandb:
354 | wandb.log({"train loss": total_loss / len(train_loader.dataset)})
355 | wandb.log({"train acc": 1 - total_err / len(train_loader.dataset)})
356 |
357 | arr_time.append(batch_time.sum)
358 |
359 | def validate(val_loader, model, criterion, add=True):
360 | """
361 | Run evaluation
362 | """
363 | global test_err, test_loss
364 |
365 | total_loss = 0
366 | total_err = 0
367 |
368 | batch_time = AverageMeter()
369 | losses = AverageMeter()
370 | top1 = AverageMeter()
371 |
372 | # switch to evaluate mode
373 | model.eval()
374 |
375 | end = time.time()
376 | with torch.no_grad():
377 | for i, (input, target) in enumerate(val_loader):
378 | target = target.to(device)
379 | input_var = input.to(device)
380 | target_var = target.to(device)
381 |
382 | if args.half:
383 | input_var = input_var.half()
384 |
385 | # compute output
386 | output = model(input_var)
387 | loss = criterion(output, target_var)
388 |
389 | output = output.float()
390 | loss = loss.float()
391 |
392 | total_loss += loss.item() * input_var.shape[0]
393 | total_err += (output.max(dim=1)[1] != target_var).sum().item()
394 |
395 | # measure accuracy and record loss
396 | prec1 = accuracy(output.data, target)[0]
397 | losses.update(loss.item(), input.size(0))
398 | top1.update(prec1.item(), input.size(0))
399 |
400 | # measure elapsed time
401 | batch_time.update(time.time() - end)
402 | end = time.time()
403 |
404 | if i % args.print_freq == 0 and add:
405 | print('Test: [{0}/{1}]\t'
406 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
407 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
408 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
409 | i, len(val_loader), batch_time=batch_time, loss=losses,
410 | top1=top1))
411 |
412 | if add:
413 | print(' * Prec@1 {top1.avg:.3f}'
414 | .format(top1=top1))
415 |
416 | test_loss.append(total_loss / len(val_loader.dataset))
417 | test_err.append(total_err / len(val_loader.dataset))
418 |
419 | if args.wandb:
420 | wandb.log({"test loss": total_loss / len(val_loader.dataset)})
421 | wandb.log({"test acc": 1 - total_err / len(val_loader.dataset)})
422 |
423 | return top1.avg
424 |
425 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
426 | """
427 | Save the training model
428 | """
429 | torch.save(state, filename)
430 |
431 | class AverageMeter(object):
432 | """Computes and stores the average and current value"""
433 | def __init__(self):
434 | self.reset()
435 |
436 | def reset(self):
437 | self.val = 0
438 | self.avg = 0
439 | self.sum = 0
440 | self.count = 0
441 |
442 | def update(self, val, n=1):
443 | self.val = val
444 | self.sum += val * n
445 | self.count += n
446 | self.avg = self.sum / self.count
447 |
448 |
449 | def accuracy(output, target, topk=(1,)):
450 | """Computes the precision@k for the specified values of k"""
451 | maxk = max(topk)
452 | batch_size = target.size(0)
453 |
454 | _, pred = output.topk(maxk, 1, True, True)
455 | pred = pred.t()
456 | correct = pred.eq(target.view(1, -1).expand_as(pred))
457 |
458 | res = []
459 | for k in topk:
460 | correct_k = correct[:k].reshape(-1).float().sum(0)
461 | res.append(correct_k.mul_(100.0 / batch_size))
462 | return res
463 |
464 |
465 | if __name__ == '__main__':
466 | main()
467 |
--------------------------------------------------------------------------------
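Note: `train_rwp_parallel.py` reads the process-local rank from the `--local_rank` argument, matching the (since deprecated) `python -m torch.distributed.launch` used in `recipes/run_rwp_ddp.sh`. The newer `torchrun` exposes the local rank through the `LOCAL_RANK` environment variable instead; a minimal compatibility shim (an assumption, not part of the repository) would be:

    import os
    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
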
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim as optim
6 | import torch.utils.data
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | import torchvision.datasets as datasets
10 | import torchvision.models as models_imagenet
11 |
12 | import numpy as np
13 | import random
14 | import os
15 | import time
16 | import models
17 | import sys
18 | import torch.utils.data as data
19 | from torchvision.datasets.utils import download_url, check_integrity
20 | import os.path
21 | import pickle
22 | from PIL import Image
23 |
24 | def set_seed(seed=1):
25 | random.seed(seed)
26 | np.random.seed(seed)
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed(seed)
29 | torch.backends.cudnn.deterministic = True
30 | torch.backends.cudnn.benchmark = False
31 |
32 | class Logger(object):  # tee stdout to a log file
33 |     def __init__(self, fileN="Default.log"):
34 |         self.terminal = sys.stdout
35 |         self.log = open(fileN, "a")
36 | 
37 |     def write(self, message):
38 |         self.terminal.write(message)
39 |         self.log.write(message)
40 |
41 | def flush(self):
42 | self.terminal.flush()
43 | self.log.flush()
44 |
45 | ################################ datasets #######################################
46 |
47 | # torchvision.transforms / torchvision.datasets are already imported at the
48 | # top of this file; the dataset helpers below build the standard CIFAR-10/100
49 | # and ImageNet train/val DataLoaders from those module-level imports.
50 | 
51 |
52 | class Cutout:  # zero out a random square patch of a CHW image tensor (DeVries & Taylor, 2017)
53 | def __init__(self, size=16, p=0.5):
54 | self.size = size
55 | self.half_size = size // 2
56 | self.p = p
57 |
58 | def __call__(self, image):
59 | if torch.rand([1]).item() > self.p:
60 | return image
61 |
62 | left = torch.randint(-self.half_size, image.size(1) - self.half_size, [1]).item()
63 | top = torch.randint(-self.half_size, image.size(2) - self.half_size, [1]).item()
64 | right = min(image.size(1), left + self.size)
65 | bottom = min(image.size(2), top + self.size)
66 |
67 | image[:, max(0, left): right, max(0, top): bottom] = 0
68 | return image
69 |
70 | def get_datasets(args):
71 | if args.datasets == 'CIFAR10':
72 |         print('cifar10 dataset!')
73 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
74 |
75 | train_loader = torch.utils.data.DataLoader(
76 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
77 | transforms.RandomHorizontalFlip(),
78 | transforms.RandomCrop(32, 4),
79 | transforms.ToTensor(),
80 | normalize,
81 | ]), download=True),
82 | batch_size=args.batch_size, shuffle=True,
83 | num_workers=args.workers, pin_memory=True)
84 |
85 | val_loader = torch.utils.data.DataLoader(
86 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
87 | transforms.ToTensor(),
88 | normalize,
89 | ])),
90 | batch_size=128, shuffle=False,
91 | num_workers=args.workers, pin_memory=True)
92 |
93 | elif args.datasets == 'CIFAR100':
94 |         print('cifar100 dataset!')
95 |         normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # note: CIFAR-10 statistics, reused here for CIFAR-100
96 |
97 | train_loader = torch.utils.data.DataLoader(
98 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
99 | transforms.RandomHorizontalFlip(),
100 | transforms.RandomCrop(32, 4),
101 | transforms.ToTensor(),
102 | normalize,
103 | ]), download=True),
104 | batch_size=args.batch_size, shuffle=True,
105 | num_workers=args.workers, pin_memory=True)
106 |
107 | val_loader = torch.utils.data.DataLoader(
108 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
109 | transforms.ToTensor(),
110 | normalize,
111 | ])),
112 | batch_size=128, shuffle=False,
113 | num_workers=args.workers, pin_memory=True)
114 |
115 | elif args.datasets == 'ImageNet':
116 |         traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')  # hard-coded ImageNet root; adjust to your setup
117 | valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
118 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
119 | std=[0.229, 0.224, 0.225])
120 |
121 | train_dataset = datasets.ImageFolder(
122 | traindir,
123 | transforms.Compose([
124 | transforms.RandomResizedCrop(224),
125 | transforms.RandomHorizontalFlip(),
126 | transforms.ToTensor(),
127 | normalize,
128 | ]))
129 |
130 | train_loader = torch.utils.data.DataLoader(
131 | train_dataset, batch_size=args.batch_size, shuffle=True,
132 | num_workers=args.workers, pin_memory=True)
133 |
134 | val_loader = torch.utils.data.DataLoader(
135 | datasets.ImageFolder(valdir, transforms.Compose([
136 | transforms.Resize(256),
137 | transforms.CenterCrop(224),
138 | transforms.ToTensor(),
139 | normalize,
140 | ])),
141 | batch_size=args.batch_size, shuffle=False,
142 |             num_workers=args.workers, pin_memory=True)
143 |
144 | return train_loader, val_loader
145 |
146 | def get_datasets_ddp(args):
147 | if args.datasets == 'CIFAR10':
148 |         print('cifar10 dataset!')
149 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
150 |
151 | my_trainset = datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
152 | transforms.RandomHorizontalFlip(),
153 | transforms.RandomCrop(32, 4),
154 | transforms.ToTensor(),
155 | normalize,
156 | ]), download=True)
157 |
158 |         train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)  # callers should run train_sampler.set_epoch(epoch) every epoch to reshuffle
159 |         train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True)
160 |
161 | val_loader = torch.utils.data.DataLoader(
162 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
163 | transforms.ToTensor(),
164 | normalize,
165 | ])),
166 | batch_size=128, shuffle=False,
167 | num_workers=args.workers, pin_memory=True)
168 |
169 | elif args.datasets == 'CIFAR100':
170 |         print('cifar100 dataset!')
171 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
172 |
173 | my_trainset = datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
174 | transforms.RandomHorizontalFlip(),
175 | transforms.RandomCrop(32, 4),
176 | transforms.ToTensor(),
177 | normalize,
178 | ]), download=True)
179 | train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
180 |         train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True)
181 |
182 | val_loader = torch.utils.data.DataLoader(
183 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
184 | transforms.ToTensor(),
185 | normalize,
186 | ])),
187 | batch_size=128, shuffle=False,
188 | num_workers=args.workers, pin_memory=True)
189 |
190 | return train_loader, val_loader
191 |
192 | def get_datasets_cutout(args):
193 |     print('cutout!')
194 | if args.datasets == 'CIFAR10':
195 |         print('cifar10 dataset!')
196 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
197 |
198 | train_loader = torch.utils.data.DataLoader(
199 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
200 | transforms.RandomHorizontalFlip(),
201 | transforms.RandomCrop(32, 4),
202 | transforms.ToTensor(),
203 | normalize,
204 | Cutout()
205 | ]), download=True),
206 | batch_size=args.batch_size, shuffle=True,
207 | num_workers=args.workers, pin_memory=True)
208 |
209 | val_loader = torch.utils.data.DataLoader(
210 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
211 | transforms.ToTensor(),
212 | normalize,
213 | ])),
214 | batch_size=128, shuffle=False,
215 | num_workers=args.workers, pin_memory=True)
216 |
217 | elif args.datasets == 'CIFAR100':
218 |         print('cifar100 dataset!')
219 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
220 |
221 | train_loader = torch.utils.data.DataLoader(
222 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
223 | transforms.RandomHorizontalFlip(),
224 | transforms.RandomCrop(32, 4),
225 | transforms.ToTensor(),
226 | normalize,
227 | Cutout()
228 | ]), download=True),
229 | batch_size=args.batch_size, shuffle=True,
230 | num_workers=args.workers, pin_memory=True)
231 |
232 | val_loader = torch.utils.data.DataLoader(
233 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
234 | transforms.ToTensor(),
235 | normalize,
236 | ])),
237 | batch_size=128, shuffle=False,
238 | num_workers=args.workers, pin_memory=True)
239 |
240 | return train_loader, val_loader
241 |
242 | def get_datasets_cutout_ddp(args):
243 | if args.datasets == 'CIFAR10':
244 |         print('cifar10 dataset!')
245 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
246 |
247 | my_trainset = datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
248 | transforms.RandomHorizontalFlip(),
249 | transforms.RandomCrop(32, 4),
250 | transforms.ToTensor(),
251 | normalize,
252 | Cutout()
253 | ]), download=True)
254 | train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
255 | train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, drop_last=True, num_workers=args.workers, pin_memory=True)
256 |
257 | val_loader = torch.utils.data.DataLoader(
258 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
259 | transforms.ToTensor(),
260 | normalize,
261 | ])),
262 | batch_size=128, shuffle=False,
263 | num_workers=args.workers, pin_memory=True)
264 |
265 | elif args.datasets == 'CIFAR100':
266 |         print('cifar100 dataset!')
267 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
268 |
269 | my_trainset = datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
270 | transforms.RandomHorizontalFlip(),
271 | transforms.RandomCrop(32, 4),
272 | transforms.ToTensor(),
273 | normalize,
274 | Cutout()
275 | ]), download=True)
276 | train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
277 | train_loader = torch.utils.data.DataLoader(my_trainset, batch_size=args.batch_size, sampler=train_sampler, drop_last=True, num_workers=args.workers, pin_memory=True)
278 |
279 | val_loader = torch.utils.data.DataLoader(
280 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
281 | transforms.ToTensor(),
282 | normalize,
283 | ])),
284 | batch_size=128, shuffle=False,
285 | num_workers=args.workers, pin_memory=True)
286 |
287 | return train_loader, val_loader
288 |
289 | def get_model(args):
290 | print('Model: {}'.format(args.arch))
291 |
292 | if args.datasets == 'ImageNet':
293 | return models_imagenet.__dict__[args.arch]()
294 |
295 | if args.datasets == 'CIFAR10':
296 | num_classes = 10
297 | elif args.datasets == 'CIFAR100':
298 | num_classes = 100
299 |     else: raise ValueError(f'unsupported dataset: {args.datasets}')  # fail early instead of a NameError below
300 | model_cfg = getattr(models, args.arch)
301 |
302 | return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
303 |
304 | class SAM(torch.optim.Optimizer):  # Sharpness-Aware Minimization wrapper (Foret et al., 2021)
305 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
306 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
307 |
308 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
309 | super(SAM, self).__init__(params, defaults)
310 |
311 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
312 | self.param_groups = self.base_optimizer.param_groups
313 | self.defaults.update(self.base_optimizer.defaults)
314 |
315 | @torch.no_grad()
316 | def first_step(self, zero_grad=False):
317 | grad_norm = self._grad_norm()
318 | for group in self.param_groups:
319 | scale = group["rho"] / (grad_norm + 1e-12)
320 |
321 | for p in group["params"]:
322 | if p.grad is None: continue
323 | self.state[p]["old_p"] = p.data.clone()
324 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
325 | p.add_(e_w) # climb to the local maximum "w + e(w)"
326 |
327 | if zero_grad: self.zero_grad()
328 |
329 | @torch.no_grad()
330 | def second_step(self, zero_grad=False):
331 | for group in self.param_groups:
332 | for p in group["params"]:
333 | if p.grad is None: continue
334 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)"
335 |
336 | self.base_optimizer.step() # do the actual "sharpness-aware" update
337 |
338 | if zero_grad: self.zero_grad()
339 |
340 | @torch.no_grad()
341 | def step(self, closure=None):
342 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
343 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
344 |
345 | self.first_step(zero_grad=True)
346 | closure()
347 | self.second_step()
348 |
349 | def _grad_norm(self):
350 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
351 | norm = torch.norm(
352 | torch.stack([
353 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
354 | for group in self.param_groups for p in group["params"]
355 | if p.grad is not None
356 | ]),
357 | p=2
358 | )
359 | return norm
360 |
361 | def load_state_dict(self, state_dict):
362 | super().load_state_dict(state_dict)
363 | self.base_optimizer.param_groups = self.param_groups
364 |
365 |
366 |
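367 | # --- Illustrative sketch (not part of the original file) ---------------------
368 | # Typical closure-style use of the SAM wrapper above: gradients are first
369 | # computed at w; step(closure) then perturbs to "w + e(w)", re-evaluates the
370 | # loss inside the closure, and applies the base-optimizer update back at w.
371 | # The helper name `sam_train_step` and its arguments are placeholders, e.g.
372 | # with sam = SAM(model.parameters(), optim.SGD, rho=0.05, lr=0.1, momentum=0.9).
373 | def sam_train_step(model, sam_optimizer, x, y):
374 |     def closure():
375 |         loss = F.cross_entropy(model(x), y)
376 |         loss.backward()
377 |         return loss
378 | 
379 |     loss = F.cross_entropy(model(x), y)  # gradients at w define the ascent direction
380 |     loss.backward()
381 |     sam_optimizer.step(closure)          # first_step -> closure() -> second_step
382 |     sam_optimizer.zero_grad()
383 |     return loss.item()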
--------------------------------------------------------------------------------