├── twa.png
├── swa
│   ├── models
│   │   ├── __init__.py
│   │   ├── vgg.py
│   │   └── preresnet.py
│   ├── readme.md
│   ├── run.sh
│   ├── run_twa.sh
│   ├── utils_swa.py
│   ├── utils.py
│   ├── train.py
│   └── train_twa.py
├── requirements.txt
├── .gitattributes
├── models
│   ├── __pycache__
│   │   ├── vgg.cpython-37.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── resnet.cpython-37.pyc
│   │   ├── preresnet.cpython-37.pyc
│   │   └── wide_resnet.cpython-37.pyc
│   ├── __init__.py
│   ├── vgg.py
│   ├── preresnet.py
│   ├── resnet.py
│   └── wide_resnet.py
├── LICENSE
├── run.sh
├── README.md
├── utils.py
├── train_sgd_cifar.py
├── train_twa.py
├── train_twa_ddp.py
└── train_sgd_imagenet.py
/twa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/twa.png
--------------------------------------------------------------------------------
/swa/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .preresnet import *
2 | from .vgg import *
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6.0
2 | torchvision>=0.6
3 | numpy>=1.21
4 | wandb==0.12.7
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/models/__pycache__/vgg.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/vgg.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .preresnet import *
2 | from .resnet import *
3 | from .vgg import *
4 | from .wide_resnet import *
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/preresnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/preresnet.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wide_resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/wide_resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/swa/readme.md:
--------------------------------------------------------------------------------
1 | ### TWA in tail-stage training
2 | We show that TWA can improve the performance of SWA in the original SWA setting; the improvements are more significant when the tail learning rate `swa_lr` is larger.
3 |
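For context: SWA keeps an equal-weight running mean of the tail-stage checkpoints (this is `moving_average` in `utils_swa.py` with `alpha = 1 / (n + 1)`), whereas TWA makes the averaging coefficients trainable. A minimal illustrative sketch of the contrast (the names below are stand-ins, not the repo's API):
```
import torch

# Stand-ins for flattened tail-stage checkpoints.
checkpoints = [torch.randn(1000) for _ in range(5)]
W = torch.stack(checkpoints)          # (k, d)

# SWA: fixed, equal coefficients.
w_swa = W.mean(dim=0)

# TWA: trainable coefficients, optimized on the training loss.
alpha = torch.full((len(checkpoints),), 1.0 / len(checkpoints), requires_grad=True)
w_twa = alpha @ W                     # differentiable in alpha, so alpha can be trained
```
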
4 | First, run SWA using the original [code](https://github.com/timgaripov/swa):
5 | ```
6 | bash run.sh
7 | ```
8 | Then, we can perform TWA using:
9 | ```
10 | bash run_twa.sh
11 | ```
12 | The training configuration can be adjusted as needed in the scripts.
--------------------------------------------------------------------------------
/swa/run.sh:
--------------------------------------------------------------------------------
1 | device=0
2 | data_dir=../datasets/
3 |
4 | ############################### VGG16 ###################################
5 | dataset=CIFAR100
6 | model=VGG16BN
7 | seed=0
8 | swa_lr=0.05
9 | dir=swa_$model\_$dataset\_$seed\_$swa_lr
10 | CUDA_VISIBLE_DEVICES=$device python3 train.py --dir=$dir --dataset=$dataset --data_path=$data_dir \
11 | --model=$model --epochs=300 --lr_init=0.1 --wd=5e-4 --seed $seed \
12 | --swa --swa_start=161 --swa_lr=$swa_lr |& tee -a $dir/log # SWA 1.5 Budgets
13 |
14 |
15 | ############################### PreResNet ###################################
16 | dataset=CIFAR100 # CIFAR10 CIFAR100
17 | model=PreResNet164
18 | seed=0
19 | swa_lr=0.05
20 |
21 | dir=swa_$model\_$dataset\_$seed\_$swa_lr
22 | CUDA_VISIBLE_DEVICES=$device python3 train.py --dir=$dir --seed $seed \
23 | --dataset=$dataset --data_path=$data_dir --model=PreResNet164 --epochs=225 \
24 | --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=$swa_lr |& tee -a $dir/log # SWA 1.5 Budgets
25 |
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 nblt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/swa/run_twa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # CIFAR experiments
4 |
5 | datasets=CIFAR100
6 | device=0
7 |
8 | ####################################### swa ########################################
9 |
10 | seed=0
11 | swa_lr=0.05 # 0.05 / 0.10
12 | model=PreResNet164
13 | wd_psgd=0.00005
14 | DST=swa_$model\_$datasets\_$seed\_$swa_lr
15 |
16 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \
17 | --opt SGD --extract Schmidt --schedule step --accumulate 1 \
18 | --lr 2 --params_start 126 --params_end 226 --train_start 225 --wd $wd_psgd \
19 | --batch-size 128 --arch=$model \
20 | --save-dir=$DST/checkpoints --log-dir=$DST --log-name=from_last
21 |
22 | seed=0
23 | swa_lr=0.05 # 0.05 / 0.10
24 | model=VGG16BN
25 | wd_psgd=0.00005
26 | DST=swa_$model\_$datasets\_$seed\_$swa_lr
27 |
28 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \
29 | --opt SGD --extract Schmidt --schedule step --accumulate 1 \
30 | --lr 2 --params_start 161 --params_end 301 --train_start 300 --wd $wd_psgd \
31 | --batch-size 128 --arch=$model \
32 | --save-dir=$DST/checkpoints --log-dir=$DST --log-name=from_last
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ################################ CIFAR ###################################
4 | datasets=CIFAR100
5 | device=0
6 | model=VGG16BN # PreResNet164
7 | DST=results/$model\_$datasets
8 |
9 | CUDA_VISIBLE_DEVICES=$device python -u train_sgd_cifar.py --datasets $datasets \
10 | --arch=$model --epochs=200 --lr 0.1 \
11 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100
12 |
13 | lr=2
14 | end=101
15 | wd_psgd=0.00001
16 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \
17 | --opt SGD --extract Schmidt --schedule step \
18 | --lr $lr --params_start 0 --params_end $end --train_start -1 --wd $wd_psgd \
19 | --batch-size 128 --arch=$model \
20 | --save-dir=$DST/checkpoints --log-dir=$DST
21 |
22 |
23 | ################################ ImageNet ################################
24 | datasets=ImageNet
25 | device=0,1,2,3
26 |
27 | model=resnet18
28 | path=/home/datasets/ILSVRC2012/
29 | CUDA_VISIBLE_DEVICES=$device python3 train_sgd_imagenet.py -a $model \
30 | --epochs 90 --workers 8 --dist-url 'tcp://127.0.0.1:1234' \
31 | --dist-backend 'nccl' --multiprocessing-distributed \
32 | --world-size 1 --rank 0 $path
33 |
34 | # TWA 60+2
35 | wd_psgd=0.00001
36 | lr=0.3
37 | DST=save_resnet18
38 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 2 --datasets $datasets \
39 | --opt SGD --extract Schmidt --schedule step --worker 8 \
40 | --lr $lr --params_start 0 --params_end 301 --train_start -1 --wd $wd_psgd \
41 | --batch-size 256 --arch=$model \
42 | --save-dir=$DST --log-dir=$DST
43 |
44 | # TWA (DDP version) 60+2
45 | datasets=ImageNet
46 | device=0,1,2,3
47 |
48 | model=resnet18
49 | wd_psgd=0.00001
50 | lr=0.3
51 | DST=save_resnet18
52 | CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node 4 train_twa_ddp.py \
53 | --epochs 2 --datasets $datasets --opt SGD --schedule step --worker 8 \
54 | --lr $lr --params_start 0 --params_end 301 --train_start -1 --wd $wd_psgd \
55 | --batch-size 256 --arch $model --save-dir $DST --log-dir $DST
56 |
57 | # TWA 90+1
58 | wd_psgd=0.00001
59 | lr=0.03
60 | DST=save_resnet18
61 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 1 --datasets $datasets \
62 | --opt SGD --extract Schmidt --schedule linear --worker 8 \
63 | --lr $lr --params_start 301 --params_end 451 --train_start -1 --wd $wd_psgd \
64 | --batch-size 256 --arch=$model \
65 | --save-dir=$DST --log-dir=$DST
66 |
67 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG model definition
3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
4 | """
5 |
6 | import math
7 | import torch.nn as nn
8 | import torchvision.transforms as transforms
9 |
10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN']
11 |
12 |
13 | def make_layers(cfg, batch_norm=False):
14 | layers = list()
15 | in_channels = 3
16 | for v in cfg:
17 | if v == 'M':
18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
19 | else:
20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
21 | if batch_norm:
22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
23 | else:
24 | layers += [conv2d, nn.ReLU(inplace=True)]
25 | in_channels = v
26 | return nn.Sequential(*layers)
27 |
28 |
29 | cfg = {
30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
32 | 512, 512, 512, 512, 'M'],
33 | }
34 |
35 |
36 | class VGG(nn.Module):
37 | def __init__(self, num_classes=10, depth=16, batch_norm=False):
38 | super(VGG, self).__init__()
39 | self.features = make_layers(cfg[depth], batch_norm)
40 | self.classifier = nn.Sequential(
41 | nn.Dropout(),
42 | nn.Linear(512, 512),
43 | nn.ReLU(True),
44 | nn.Dropout(),
45 | nn.Linear(512, 512),
46 | nn.ReLU(True),
47 | nn.Linear(512, num_classes),
48 | )
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
53 | m.weight.data.normal_(0, math.sqrt(2. / n))
54 | m.bias.data.zero_()
55 |
56 | def forward(self, x):
57 | x = self.features(x)
58 | x = x.view(x.size(0), -1)
59 | x = self.classifier(x)
60 | return x
61 |
62 |
63 | class Base:
64 | base = VGG
65 | args = list()
66 | kwargs = dict()
67 | transform_train = transforms.Compose([
68 | transforms.RandomHorizontalFlip(),
69 | transforms.RandomCrop(32, padding=4),
70 | transforms.ToTensor(),
71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
72 | ])
73 |
74 | transform_test = transforms.Compose([
75 | transforms.ToTensor(),
76 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
77 | ])
78 |
79 |
80 | class VGG16(Base):
81 | pass
82 |
83 |
84 | class VGG16BN(Base):
85 | kwargs = {'batch_norm': True}
86 |
87 |
88 | class VGG19(Base):
89 | kwargs = {'depth': 19}
90 |
91 |
92 | class VGG19BN(Base):
93 | kwargs = {'depth': 19, 'batch_norm': True}
--------------------------------------------------------------------------------
/swa/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG model definition
3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
4 | """
5 |
6 | import math
7 | import torch.nn as nn
8 | import torchvision.transforms as transforms
9 |
10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN']
11 |
12 |
13 | def make_layers(cfg, batch_norm=False):
14 | layers = list()
15 | in_channels = 3
16 | for v in cfg:
17 | if v == 'M':
18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
19 | else:
20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
21 | if batch_norm:
22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
23 | else:
24 | layers += [conv2d, nn.ReLU(inplace=True)]
25 | in_channels = v
26 | return nn.Sequential(*layers)
27 |
28 |
29 | cfg = {
30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
32 | 512, 512, 512, 512, 'M'],
33 | }
34 |
35 |
36 | class VGG(nn.Module):
37 | def __init__(self, num_classes=10, depth=16, batch_norm=False):
38 | super(VGG, self).__init__()
39 | self.features = make_layers(cfg[depth], batch_norm)
40 | self.classifier = nn.Sequential(
41 | nn.Dropout(),
42 | nn.Linear(512, 512),
43 | nn.ReLU(True),
44 | nn.Dropout(),
45 | nn.Linear(512, 512),
46 | nn.ReLU(True),
47 | nn.Linear(512, num_classes),
48 | )
49 |
50 | for m in self.modules():
51 | if isinstance(m, nn.Conv2d):
52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
53 | m.weight.data.normal_(0, math.sqrt(2. / n))
54 | m.bias.data.zero_()
55 |
56 | def forward(self, x):
57 | x = self.features(x)
58 | x = x.view(x.size(0), -1)
59 | x = self.classifier(x)
60 | return x
61 |
62 |
63 | class Base:
64 | base = VGG
65 | args = list()
66 | kwargs = dict()
67 | transform_train = transforms.Compose([
68 | transforms.RandomHorizontalFlip(),
69 | transforms.RandomCrop(32, padding=4),
70 | transforms.ToTensor(),
71 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
72 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
73 | ])
74 |
75 | transform_test = transforms.Compose([
76 | transforms.ToTensor(),
77 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
78 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
79 | ])
80 |
81 |
82 | class VGG16(Base):
83 | pass
84 |
85 |
86 | class VGG16BN(Base):
87 | kwargs = {'batch_norm': True}
88 |
89 |
90 | class VGG19(Base):
91 | kwargs = {'depth': 19}
92 |
93 |
94 | class VGG19BN(Base):
95 | kwargs = {'depth': 19, 'batch_norm': True}
96 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TWA
2 | The code is the official implementation of our ICLR paper
3 | [Trainable Weight Averaging: Efficient Training by Optimizing Historical Solutions](https://openreview.net/pdf?id=8wbnpOJY-f). For the journal version, please refer to this [branch](https://github.com/nblt/TWA/tree/journal).
4 |
5 | We propose to conduct neural network training in a tiny subspace spanned by historical solutions. Such optimization is equivalent to performing weight averaging over these solutions with trainable coefficients (TWA), in contrast to the equal averaging coefficients used in [SWA](https://github.com/timgaripov/swa). We show that a good solution can emerge early in a DNN's training by properly averaging the historical solutions with TWA. In this way, we achieve substantial training efficiency (e.g., saving over **30%** of training epochs on CIFAR / ImageNet) by optimizing these historical solutions. We also provide an efficient and scalable framework for multi-node training. Besides, TWA can improve fine-tuning results across multiple training configurations, which is our current focus. This [colab](https://colab.research.google.com/drive/1fxUJ0K8dd7V3gsozmKsHhfdYHhYVB-WZ?usp=sharing) provides an exploratory example adapted from [Model Soups](https://github.com/mlfoundations/model-soups).
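
For intuition, here is a minimal, self-contained sketch of the idea. This is illustrative only, not the repo's actual implementation (for example, the training scripts pass `--extract Schmidt`, which suggests the checkpoint directions are first orthonormalized; that step is omitted here):
```
import torch

# Stand-ins for k flattened historical checkpoints w_1, ..., w_k.
k, d = 5, 1000
W = torch.randn(k, d)  # row i is checkpoint w_i

# Trainable averaging coefficients, initialized to the SWA equal weights.
alpha = torch.full((k,), 1.0 / k, requires_grad=True)
optimizer = torch.optim.SGD([alpha], lr=0.1)

def task_loss(w):
    # Placeholder for the training loss evaluated at weights w.
    return (w ** 2).mean()

for step in range(100):
    w = alpha @ W      # current solution: a weighted average of the checkpoints
    loss = task_loss(w)
    optimizer.zero_grad()
    loss.backward()    # gradients flow only to the k coefficients
    optimizer.step()
```
Optimizing k coefficients instead of all d network parameters is what makes composing a good solution from historical checkpoints cheap.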
6 |
7 |
8 |
9 | ![TWA](twa.png)
10 |
11 |
12 |
13 |
14 | ## Dependencies
15 |
16 | Install required dependencies:
17 |
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ## How to run
23 |
24 | ### TWA in tail-stage training
25 | We first show that TWA can improve the performance of SWA in the original SWA setting; the improvements are more significant when the tail learning rate `swa_lr` is larger.
26 | ```
27 | cd swa
28 | ```
29 | First, run SWA using the original [code](https://github.com/timgaripov/swa):
30 | ```
31 | bash run.sh
32 | ```
33 | Then, we can perform TWA using:
34 | ```
35 | bash run_twa.sh
36 | ```
37 | The training configuration can be adjusted as needed in the scripts.
38 |
39 | ### TWA in head-stage training
40 | In this part, we conduct TWA in the head training stage, saving a considerable **30%-40%** of training epochs on CIFAR-10/100 and ImageNet with comparable or even better performance than regular training.
41 | We show sample usages in `run.sh`.
42 |
43 | For the first step, we conduct regular training to generate the historical solutions (for ImageNet training, the dataset needs to be prepared under the folder `path`). For example,
44 |
45 | ```
46 | datasets=CIFAR100
47 | model=VGG16BN
48 | DST=results/$model\_$datasets
49 |
50 | CUDA_VISIBLE_DEVICES=0 python -u train_sgd_cifar.py --datasets $datasets \
51 | --arch=$model --epochs=200 --lr 0.1 \
52 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100
53 | ```
54 | Then, we conduct TWA training to quickly compose a good solution from the historical solutions (note that here we only use the first 100 epoch checkpoints):
55 | ```
56 | CUDA_VISIBLE_DEVICES=0 python -u train_twa.py --epochs 10 --datasets $datasets \
57 | --opt SGD --extract Schmidt --schedule step \
58 | --lr 2 --params_start 0 --params_end 101 --train_start -1 --wd 0.00001 \
59 | --batch-size 128 --arch=$model \
60 | --save-dir=$DST/checkpoints --log-dir=$DST
61 | ```
62 |
63 | ## Citation
64 | If you find this work helpful, please cite:
65 | ```
66 | @inproceedings{
67 | li2023trainable,
68 | title={Trainable Weight Averaging: Efficient Training by Optimizing Historical Solutions},
69 | author={Tao Li and Zhehao Huang and Qinghua Tao and Yingwen Wu and Xiaolin Huang},
70 | booktitle={The Eleventh International Conference on Learning Representations},
71 | year={2023},
72 | url={https://openreview.net/forum?id=8wbnpOJY-f}
73 | }
74 | ```
75 |
--------------------------------------------------------------------------------
/swa/utils_swa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | num_classes_dict = {
5 | "CIFAR10":10,
6 | "CIFAR100":100,
7 | }
8 |
9 |
10 | def adjust_learning_rate(optimizer, lr):
11 | for param_group in optimizer.param_groups:
12 | param_group['lr'] = lr
13 | return lr
14 |
15 |
16 | def save_checkpoint(dir, epoch, **kwargs):
17 | state = {
18 | 'epoch': epoch,
19 | }
20 | state.update(kwargs)
21 | filepath = os.path.join(dir, 'checkpoint-%d.pt' % epoch)
22 | torch.save(state, filepath)
23 |
24 |
25 | def train_epoch(loader, model, criterion, optimizer):
26 | loss_sum = 0.0
27 | correct = 0.0
28 |
29 | model.train()
30 |
31 | for i, (input, target) in enumerate(loader):
32 | # input = input.cuda(async=True)
33 | # target = target.cuda(async=True)
34 | input = input.cuda()
35 | target = target.cuda()
36 |
37 | input_var = torch.autograd.Variable(input)
38 | target_var = torch.autograd.Variable(target)
39 |
40 | output = model(input_var)
41 | loss = criterion(output, target_var)
42 |
43 | optimizer.zero_grad()
44 | loss.backward()
45 | optimizer.step()
46 |
47 | loss_sum += loss.item() * input.size(0)
48 | pred = output.data.max(1, keepdim=True)[1]
49 | correct += pred.eq(target_var.data.view_as(pred)).sum().item()
50 |
51 | return {
52 | 'loss': loss_sum / len(loader.dataset),
53 | 'accuracy': correct / len(loader.dataset) * 100.0,
54 | }
55 |
56 |
57 | def eval(loader, model, criterion):
58 | loss_sum = 0.0
59 | correct = 0.0
60 |
61 | model.eval()
62 |
63 | for i, (input, target) in enumerate(loader):
64 | input = input.cuda()
65 | target = target.cuda()
66 | input_var = torch.autograd.Variable(input)
67 | target_var = torch.autograd.Variable(target)
68 |
69 | output = model(input_var)
70 | loss = criterion(output, target_var)
71 |
72 | loss_sum += loss.item() * input.size(0)
73 | pred = output.data.max(1, keepdim=True)[1]
74 | correct += pred.eq(target_var.data.view_as(pred)).sum().item()
75 |
76 | return {
77 | 'loss': loss_sum / len(loader.dataset),
78 | 'accuracy': correct / len(loader.dataset) * 100.0,
79 | }
80 |
81 |
82 | def moving_average(net1, net2, alpha=1):  # net1 <- (1 - alpha) * net1 + alpha * net2; alpha = 1 / (n + 1) keeps net1 the running mean
83 | for param1, param2 in zip(net1.parameters(), net2.parameters()):
84 | param1.data *= (1.0 - alpha)
85 | param1.data += param2.data * alpha
86 |
87 |
88 | def _check_bn(module, flag):
89 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
90 | flag[0] = True
91 |
92 |
93 | def check_bn(model):
94 | flag = [False]
95 | model.apply(lambda module: _check_bn(module, flag))
96 | return flag[0]
97 |
98 |
99 | def reset_bn(module):
100 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
101 | module.running_mean = torch.zeros_like(module.running_mean)
102 | module.running_var = torch.ones_like(module.running_var)
103 |
104 |
105 | def _get_momenta(module, momenta):
106 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
107 | momenta[module] = module.momentum
108 |
109 |
110 | def _set_momenta(module, momenta):
111 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
112 | module.momentum = momenta[module]
113 |
114 |
115 | def bn_update(loader, model):
116 | """
117 | BatchNorm buffers update (if any).
118 | Performs 1 epochs to estimate buffers average using train dataset.
119 |
120 | :param loader: train dataset loader for buffers average estimation.
121 | :param model: model being update
122 | :return: None
123 | """
124 | if not check_bn(model):
125 | return
126 | model.train()
127 | momenta = {}
128 | model.apply(reset_bn)
129 | model.apply(lambda module: _get_momenta(module, momenta))
130 | n = 0
131 | for input, _ in loader:
132 | input = input.cuda()
133 | input_var = torch.autograd.Variable(input)
134 | b = input_var.data.size(0)
135 |
136 | momentum = b / (n + b)
137 | for module in momenta.keys():
138 | module.momentum = momentum
139 |
140 | model(input_var)
141 | n += b
142 |
143 | model.apply(lambda module: _set_momenta(module, momenta))
144 |
--------------------------------------------------------------------------------
/models/preresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | PreResNet model definition
3 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
4 | """
5 |
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 | import math
9 |
10 | __all__ = ['PreResNet18', 'PreResNet110', 'PreResNet164']
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15 | padding=1, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None):
22 | super(BasicBlock, self).__init__()
23 | self.bn1 = nn.BatchNorm2d(inplanes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | residual = x
33 |
34 | out = self.bn1(x)
35 | out = self.relu(out)
36 | out = self.conv1(out)
37 |
38 | out = self.bn2(out)
39 | out = self.relu(out)
40 | out = self.conv2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.bn1 = nn.BatchNorm2d(inplanes)
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | padding=1, bias=False)
60 | self.bn3 = nn.BatchNorm2d(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.bn1(x)
70 | out = self.relu(out)
71 | out = self.conv1(out)
72 |
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 | out = self.conv2(out)
76 |
77 | out = self.bn3(out)
78 | out = self.relu(out)
79 | out = self.conv3(out)
80 |
81 | if self.downsample is not None:
82 | residual = self.downsample(x)
83 |
84 | out += residual
85 |
86 | return out
87 |
88 |
89 | class PreResNet(nn.Module):
90 |
91 | def __init__(self, num_classes=10, depth=110):
92 | super(PreResNet, self).__init__()
93 | if depth >= 44:
94 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2'
95 | n = (depth - 2) // 9
96 | block = Bottleneck
97 | else:
99 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
100 | n = (depth - 2) // 6
101 | block = BasicBlock
102 |
103 |
104 | self.inplanes = 16
105 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
106 | bias=False)
107 | self.layer1 = self._make_layer(block, 16, n)
108 | self.layer2 = self._make_layer(block, 32, n, stride=2)
109 | self.layer3 = self._make_layer(block, 64, n, stride=2)
110 | self.bn = nn.BatchNorm2d(64 * block.expansion)
111 | self.relu = nn.ReLU(inplace=True)
112 | self.avgpool = nn.AvgPool2d(8)
113 | self.fc = nn.Linear(64 * block.expansion, num_classes)
114 |
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv2d):
117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118 | m.weight.data.normal_(0, math.sqrt(2. / n))
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
124 | downsample = None
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | nn.Conv2d(self.inplanes, planes * block.expansion,
128 | kernel_size=1, stride=stride, bias=False),
129 | )
130 |
131 | layers = list()
132 | layers.append(block(self.inplanes, planes, stride, downsample))
133 | self.inplanes = planes * block.expansion
134 | for i in range(1, blocks):
135 | layers.append(block(self.inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 |
142 | x = self.layer1(x) # 32x32
143 | x = self.layer2(x) # 16x16
144 | x = self.layer3(x) # 8x8
145 | x = self.bn(x)
146 | x = self.relu(x)
147 |
148 | x = self.avgpool(x)
149 | x = x.view(x.size(0), -1)
150 | x = self.fc(x)
151 |
152 | return x
153 |
154 |
155 | class PreResNet18:
156 | base = PreResNet
157 | args = list()
158 | kwargs = {'depth': 18}
159 |
160 | class PreResNet110:
161 | base = PreResNet
162 | args = list()
163 | kwargs = {'depth': 110}
164 |
165 | class PreResNet164:
166 | base = PreResNet
167 | args = list()
168 | kwargs = {'depth': 164}
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """resnet in pytorch
2 |
3 |
4 |
5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
6 |
7 | Deep Residual Learning for Image Recognition
8 | https://arxiv.org/abs/1512.03385v1
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | class BasicBlock(nn.Module):
15 | """Basic Block for resnet 18 and resnet 34
16 |
17 | """
18 |
19 |     # BasicBlock and BottleNeck blocks
20 |     # have different output sizes;
21 |     # we use the class attribute expansion
22 |     # to distinguish them
23 | expansion = 1
24 |
25 | def __init__(self, in_channels, out_channels, stride=1):
26 | super().__init__()
27 |
28 | #residual function
29 | self.residual_function = nn.Sequential(
30 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
31 | nn.BatchNorm2d(out_channels),
32 | nn.ReLU(inplace=True),
33 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
35 | )
36 |
37 | #shortcut
38 | self.shortcut = nn.Sequential()
39 |
40 |         # if the shortcut output dimension does not match the residual function's,
41 |         # use a 1x1 convolution to match the dimensions
42 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
43 | self.shortcut = nn.Sequential(
44 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
45 | nn.BatchNorm2d(out_channels * BasicBlock.expansion)
46 | )
47 |
48 | def forward(self, x):
49 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
50 |
51 | class BottleNeck(nn.Module):
52 | """Residual block for resnet over 50 layers
53 |
54 | """
55 | expansion = 4
56 | def __init__(self, in_channels, out_channels, stride=1):
57 | super().__init__()
58 | self.residual_function = nn.Sequential(
59 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
60 | nn.BatchNorm2d(out_channels),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
63 | nn.BatchNorm2d(out_channels),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
66 | nn.BatchNorm2d(out_channels * BottleNeck.expansion),
67 | )
68 |
69 | self.shortcut = nn.Sequential()
70 |
71 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
72 | self.shortcut = nn.Sequential(
73 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
74 | nn.BatchNorm2d(out_channels * BottleNeck.expansion)
75 | )
76 |
77 | def forward(self, x):
78 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
79 |
80 | class ResNet(nn.Module):
81 |
82 | def __init__(self, block, num_block, num_classes=100):
83 | super().__init__()
84 |
85 | self.in_channels = 64
86 |
87 | self.conv1 = nn.Sequential(
88 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
89 | nn.BatchNorm2d(64),
90 | nn.ReLU(inplace=True))
91 |         # we use a different input size than the original paper,
92 | #so conv2_x's stride is 1
93 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
94 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
95 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
96 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
97 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
98 | self.fc = nn.Linear(512 * block.expansion, num_classes)
99 |
100 | def _make_layer(self, block, out_channels, num_blocks, stride):
101 |         """make resnet layers (by 'layer' we do not mean a single neural
102 |         network layer such as a conv layer); one layer here may
103 |         contain more than one residual block
104 |
105 | Args:
106 | block: block type, basic block or bottle neck block
107 | out_channels: output depth channel number of this layer
108 | num_blocks: how many blocks per layer
109 | stride: the stride of the first block of this layer
110 |
111 | Return:
112 | return a resnet layer
113 | """
114 |
115 |         # we have num_blocks blocks per layer; the stride of the first block
116 |         # could be 1 or 2, while the other blocks always use stride 1
117 | strides = [stride] + [1] * (num_blocks - 1)
118 | layers = []
119 | for stride in strides:
120 | layers.append(block(self.in_channels, out_channels, stride))
121 | self.in_channels = out_channels * block.expansion
122 |
123 | return nn.Sequential(*layers)
124 |
125 | def forward(self, x):
126 | output = self.conv1(x)
127 | output = self.conv2_x(output)
128 | output = self.conv3_x(output)
129 | output = self.conv4_x(output)
130 | output = self.conv5_x(output)
131 | output = self.avg_pool(output)
132 | output = output.view(output.size(0), -1)
133 | output = self.fc(output)
134 |
135 | return output
136 |
137 | class resnet18:
138 | base = ResNet
139 | args = list()
140 | kwargs = {'block': BasicBlock, 'num_block': [2, 2, 2, 2]}
141 |
142 | # def resnet18():
143 | # """ return a ResNet 18 object
144 | # """
145 | # kwargs = {}
146 | # return ResNet(BasicBlock, [2, 2, 2, 2])
147 |
148 | def resnet34():
149 | """ return a ResNet 34 object
150 | """
151 | return ResNet(BasicBlock, [3, 4, 6, 3])
152 |
153 | def resnet50():
154 | """ return a ResNet 50 object
155 | """
156 | return ResNet(BottleNeck, [3, 4, 6, 3])
157 |
158 | def resnet101():
159 | """ return a ResNet 101 object
160 | """
161 | return ResNet(BottleNeck, [3, 4, 23, 3])
162 |
163 | def resnet152():
164 | """ return a ResNet 152 object
165 | """
166 | return ResNet(BottleNeck, [3, 8, 36, 3])
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
/swa/models/preresnet.py:
--------------------------------------------------------------------------------
1 | """
2 | PreResNet model definition
3 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
4 | """
5 |
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 | import math
9 |
10 | __all__ = ['PreResNet110', 'PreResNet164']
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15 | padding=1, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None):
22 | super(BasicBlock, self).__init__()
23 | self.bn1 = nn.BatchNorm2d(inplanes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | residual = x
33 |
34 | out = self.bn1(x)
35 | out = self.relu(out)
36 | out = self.conv1(out)
37 |
38 | out = self.bn2(out)
39 | out = self.relu(out)
40 | out = self.conv2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | out += residual
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.bn1 = nn.BatchNorm2d(inplanes)
56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
59 | padding=1, bias=False)
60 | self.bn3 = nn.BatchNorm2d(planes)
61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
62 | self.relu = nn.ReLU(inplace=True)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.bn1(x)
70 | out = self.relu(out)
71 | out = self.conv1(out)
72 |
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 | out = self.conv2(out)
76 |
77 | out = self.bn3(out)
78 | out = self.relu(out)
79 | out = self.conv3(out)
80 |
81 | if self.downsample is not None:
82 | residual = self.downsample(x)
83 |
84 | out += residual
85 |
86 | return out
87 |
88 |
89 | class PreResNet(nn.Module):
90 |
91 | def __init__(self, num_classes=10, depth=110):
92 | super(PreResNet, self).__init__()
93 | if depth >= 44:
94 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2'
95 | n = (depth - 2) // 9
96 | block = Bottleneck
97 | else:
98 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
99 | n = (depth - 2) // 6
100 | block = BasicBlock
101 |
102 |
103 | self.inplanes = 16
104 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
105 | bias=False)
106 | self.layer1 = self._make_layer(block, 16, n)
107 | self.layer2 = self._make_layer(block, 32, n, stride=2)
108 | self.layer3 = self._make_layer(block, 64, n, stride=2)
109 | self.bn = nn.BatchNorm2d(64 * block.expansion)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.avgpool = nn.AvgPool2d(8)
112 | self.fc = nn.Linear(64 * block.expansion, num_classes)
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 |
122 | def _make_layer(self, block, planes, blocks, stride=1):
123 | downsample = None
124 | if stride != 1 or self.inplanes != planes * block.expansion:
125 | downsample = nn.Sequential(
126 | nn.Conv2d(self.inplanes, planes * block.expansion,
127 | kernel_size=1, stride=stride, bias=False),
128 | )
129 |
130 | layers = list()
131 | layers.append(block(self.inplanes, planes, stride, downsample))
132 | self.inplanes = planes * block.expansion
133 | for i in range(1, blocks):
134 | layers.append(block(self.inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | x = self.conv1(x)
140 |
141 | x = self.layer1(x) # 32x32
142 | x = self.layer2(x) # 16x16
143 | x = self.layer3(x) # 8x8
144 | x = self.bn(x)
145 | x = self.relu(x)
146 |
147 | x = self.avgpool(x)
148 | x = x.view(x.size(0), -1)
149 | x = self.fc(x)
150 |
151 | return x
152 |
153 |
154 | class PreResNet110:
155 | base = PreResNet
156 | args = list()
157 | kwargs = {'depth': 110}
158 | transform_train = transforms.Compose([
159 | transforms.RandomCrop(32, padding=4),
160 | transforms.RandomHorizontalFlip(),
161 | transforms.ToTensor(),
162 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
163 | ])
164 | transform_test = transforms.Compose([
165 | transforms.ToTensor(),
166 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
167 | ])
168 |
169 | class PreResNet164:
170 | base = PreResNet
171 | args = list()
172 | kwargs = {'depth': 164}
173 | transform_train = transforms.Compose([
174 | transforms.RandomCrop(32, padding=4),
175 | transforms.RandomHorizontalFlip(),
176 | transforms.ToTensor(),
177 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
178 | ])
179 | transform_test = transforms.Compose([
180 | transforms.ToTensor(),
181 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
182 | ])
183 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim as optim
6 | import torch.utils.data
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | import torchvision.datasets as datasets
10 | import torchvision.models as models_imagenet
11 |
12 | import numpy as np
13 | import random
14 | import os
15 | import time
16 | import models
17 | import sys
18 |
19 | def set_seed(seed=1):
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 | class Logger(object):
28 | def __init__(self,fileN ="Default.log"):
29 | self.terminal = sys.stdout
30 | self.log = open(fileN,"a")
31 |
32 | def write(self,message):
33 | self.terminal.write(message)
34 | self.log.write(message)
35 |
36 | def flush(self):
37 | pass
38 |
39 | def adjust_learning_rate(optimizer, lr):
40 | for param_group in optimizer.param_groups:
41 | param_group['lr'] = lr
42 | return lr
43 |
44 | ################################ datasets #######################################
45 |
46 | import torchvision.transforms as transforms
47 | import torchvision.datasets as datasets
48 | from torch.utils.data import DataLoader
49 | from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
50 |
51 | def get_datasets(args):
52 | if args.datasets == 'CIFAR10':
53 | print ('cifar10 dataset!')
54 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
55 |
56 | train_loader = torch.utils.data.DataLoader(
57 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
58 | transforms.RandomHorizontalFlip(),
59 | transforms.RandomCrop(32, 4),
60 | transforms.ToTensor(),
61 | normalize,
62 | ]), download=True),
63 | batch_size=args.batch_size, shuffle=True,
64 | num_workers=args.workers, pin_memory=True)
65 |
66 | val_loader = torch.utils.data.DataLoader(
67 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
68 | transforms.ToTensor(),
69 | normalize,
70 | ])),
71 | batch_size=128, shuffle=False,
72 | num_workers=args.workers, pin_memory=True)
73 |
74 | elif args.datasets == 'CIFAR100':
75 | print ('cifar100 dataset!')
76 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
77 |
78 | train_loader = torch.utils.data.DataLoader(
79 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
80 | transforms.RandomHorizontalFlip(),
81 | transforms.RandomCrop(32, 4),
82 | transforms.ToTensor(),
83 | normalize,
84 | ]), download=True),
85 | batch_size=args.batch_size, shuffle=True,
86 | num_workers=args.workers, pin_memory=True)
87 |
88 | val_loader = torch.utils.data.DataLoader(
89 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
90 | transforms.ToTensor(),
91 | normalize,
92 | ])),
93 | batch_size=128, shuffle=False,
94 | num_workers=args.workers, pin_memory=True)
95 |
96 | elif args.datasets == 'ImageNet':
97 | traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
98 | valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
100 | std=[0.229, 0.224, 0.225])
101 |
102 | train_dataset = datasets.ImageFolder(
103 | traindir,
104 | transforms.Compose([
105 | transforms.RandomResizedCrop(224),
106 | transforms.RandomHorizontalFlip(),
107 | transforms.ToTensor(),
108 | normalize,
109 | ]))
110 |
111 | train_loader = torch.utils.data.DataLoader(
112 | train_dataset, batch_size=args.batch_size, shuffle=True,
113 | num_workers=args.workers, pin_memory=True)
114 |
115 | val_loader = torch.utils.data.DataLoader(
116 | datasets.ImageFolder(valdir, transforms.Compose([
117 | transforms.Resize(256),
118 | transforms.CenterCrop(224),
119 | transforms.ToTensor(),
120 | normalize,
121 | ])),
122 | batch_size=args.batch_size, shuffle=False,
123 | num_workers=args.workers)
124 |
125 | return train_loader, val_loader
126 |
127 |
128 | def get_imagenet_dataset():
129 | traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
130 | valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
131 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
132 | std=[0.229, 0.224, 0.225])
133 |
134 | train_dataset = datasets.ImageFolder(
135 | traindir,
136 | transforms.Compose([
137 | transforms.RandomResizedCrop(224),
138 | transforms.RandomHorizontalFlip(),
139 | transforms.ToTensor(),
140 | normalize,
141 | ]))
142 |
143 | val_dataset = datasets.ImageFolder(
144 | valdir,
145 | transforms.Compose([
146 | transforms.Resize(256),
147 | transforms.CenterCrop(224),
148 | transforms.ToTensor(),
149 | normalize,
150 | ]))
151 | return train_dataset, val_dataset
152 |
153 | ################################ training & evaluation #######################################
154 |
155 | def eval_model(loader, model, criterion):
156 | loss_sum = 0.0
157 | correct = 0.0
158 |
159 | model.eval()
160 |
161 | for i, (input, target) in enumerate(loader):
162 | input = input.cuda()
163 | target = target.cuda()
164 |
165 | output = model(input)
166 | loss = criterion(output, target)
167 |
168 | loss_sum += loss.item() * input.size(0)
169 | pred = output.data.max(1, keepdim=True)[1]
170 | correct += pred.eq(target.data.view_as(pred)).sum().item()
171 |
172 | return {
173 | 'loss': loss_sum / len(loader.dataset),
174 | 'accuracy': correct / len(loader.dataset) * 100.0,
175 | }
176 |
177 | def bn_update(loader, model):
178 |     model.train()  # train mode so that BatchNorm layers recompute running statistics
179 |     for i, (input, target) in enumerate(loader):
180 |         target = target.cuda()
181 |         input_var = input.cuda()
182 |         target_var = target
183 | 
184 |         # compute output (forward pass only; this refreshes the BN buffers)
185 |         output = model(input_var)
186 |
187 | def get_model(args):
188 | print('Model: {}'.format(args.arch))
189 |
190 | if args.datasets == 'ImageNet':
191 | return models_imagenet.__dict__[args.arch]()
192 |
193 | if args.datasets == 'CIFAR10':
194 | num_classes = 10
195 | elif args.datasets == 'CIFAR100':
196 | num_classes = 100
197 |
198 | model_cfg = getattr(models, args.arch)
199 |
200 | return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
201 |
202 |
203 |
--------------------------------------------------------------------------------
/swa/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim as optim
6 | import torch.utils.data
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | import torchvision.datasets as datasets
10 | import torchvision.models as models_imagenet
11 |
12 | import numpy as np
13 | import random
14 | import os
15 | import time
16 | import models
17 | import sys
18 |
19 | def set_seed(seed=1):
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.backends.cudnn.benchmark = False
26 |
27 | class Logger(object):
28 | def __init__(self,fileN ="Default.log"):
29 | self.terminal = sys.stdout
30 | self.log = open(fileN,"a")
31 |
32 | def write(self,message):
33 | self.terminal.write(message)
34 | self.log.write(message)
35 |
36 | def flush(self):
37 | pass
38 |
39 | def adjust_learning_rate(optimizer, lr):
40 | for param_group in optimizer.param_groups:
41 | param_group['lr'] = lr
42 | return lr
43 |
44 | ################################ datasets #######################################
45 |
46 | import torchvision.transforms as transforms
47 | import torchvision.datasets as datasets
48 | from torch.utils.data import DataLoader
49 | from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
50 |
51 | def get_datasets(args):
52 | if args.datasets == 'CIFAR10':
53 | print ('cifar10 dataset!')
54 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
55 |
56 | train_loader = torch.utils.data.DataLoader(
57 | datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
58 | transforms.RandomHorizontalFlip(),
59 | transforms.RandomCrop(32, 4),
60 | transforms.ToTensor(),
61 | normalize,
62 | ]), download=True),
63 | batch_size=args.batch_size, shuffle=True,
64 | num_workers=args.workers, pin_memory=True)
65 |
66 | val_loader = torch.utils.data.DataLoader(
67 | datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
68 | transforms.ToTensor(),
69 | normalize,
70 | ])),
71 | batch_size=128, shuffle=False,
72 | num_workers=args.workers, pin_memory=True)
73 |
74 | elif args.datasets == 'CIFAR100':
75 | print ('cifar100 dataset!')
76 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
77 |
78 | train_loader = torch.utils.data.DataLoader(
79 | datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
80 | transforms.RandomHorizontalFlip(),
81 | transforms.RandomCrop(32, 4),
82 | transforms.ToTensor(),
83 | normalize,
84 | ]), download=True),
85 | batch_size=args.batch_size, shuffle=True,
86 | num_workers=args.workers, pin_memory=True)
87 |
88 | val_loader = torch.utils.data.DataLoader(
89 | datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
90 | transforms.ToTensor(),
91 | normalize,
92 | ])),
93 | batch_size=128, shuffle=False,
94 | num_workers=args.workers, pin_memory=True)
95 |
96 | elif args.datasets == 'ImageNet':
97 | traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
98 | valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
100 | std=[0.229, 0.224, 0.225])
101 |
102 | train_dataset = datasets.ImageFolder(
103 | traindir,
104 | transforms.Compose([
105 | transforms.RandomResizedCrop(224),
106 | transforms.RandomHorizontalFlip(),
107 | transforms.ToTensor(),
108 | normalize,
109 | ]))
110 |
111 | train_loader = torch.utils.data.DataLoader(
112 | train_dataset, batch_size=args.batch_size, shuffle=True,
113 | num_workers=args.workers, pin_memory=True)
114 |
115 | val_loader = torch.utils.data.DataLoader(
116 | datasets.ImageFolder(valdir, transforms.Compose([
117 | transforms.Resize(256),
118 | transforms.CenterCrop(224),
119 | transforms.ToTensor(),
120 | normalize,
121 | ])),
122 | batch_size=args.batch_size, shuffle=False,
123 | num_workers=args.workers)
124 |
125 | return train_loader, val_loader
126 |
127 |
128 | def get_imagenet_dataset():
129 | traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
130 | valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
131 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
132 | std=[0.229, 0.224, 0.225])
133 |
134 | train_dataset = datasets.ImageFolder(
135 | traindir,
136 | transforms.Compose([
137 | transforms.RandomResizedCrop(224),
138 | transforms.RandomHorizontalFlip(),
139 | transforms.ToTensor(),
140 | normalize,
141 | ]))
142 |
143 | val_dataset = datasets.ImageFolder(
144 | valdir,
145 | transforms.Compose([
146 | transforms.Resize(256),
147 | transforms.CenterCrop(224),
148 | transforms.ToTensor(),
149 | normalize,
150 | ]))
151 | return train_dataset, val_dataset
152 |
153 | ################################ training & evaluation #######################################
154 |
155 | def eval_model(loader, model, criterion):
156 | loss_sum = 0.0
157 | correct = 0.0
158 |
159 | model.eval()
160 |
161 | for i, (input, target) in enumerate(loader):
162 | input = input.cuda()
163 | target = target.cuda()
164 |
165 | output = model(input)
166 | loss = criterion(output, target)
167 |
168 | loss_sum += loss.item() * input.size(0)
169 | pred = output.data.max(1, keepdim=True)[1]
170 |         correct += pred.eq(target.data.view_as(pred)).sum().item()
171 |
172 | return {
173 | 'loss': loss_sum / len(loader.dataset),
174 | 'accuracy': correct / len(loader.dataset) * 100.0,
175 | }
176 |
177 | def bn_update(loader, model):
178 | model.train()
179 | for i, (input, target) in enumerate(loader):
180 | target = target.cuda()
181 | input_var = input.cuda()
182 | target_var = target
183 |
184 | # compute output
185 | output = model(input_var)
186 |
187 | def get_model(args):
188 | print('Model: {}'.format(args.arch))
189 |
190 | if args.datasets == 'ImageNet':
191 | return models_imagenet.__dict__[args.arch]()
192 |
193 | if args.datasets == 'CIFAR10':
194 | num_classes = 10
195 | elif args.datasets == 'CIFAR100':
196 | num_classes = 100
197 |
198 | model_cfg = getattr(models, args.arch)
199 |
200 | return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
201 |
202 |
203 |
--------------------------------------------------------------------------------
/swa/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | import models
9 | import utils_swa as utils
10 | import tabulate
11 |
12 |
13 | parser = argparse.ArgumentParser(description='SGD/SWA training')
14 | parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')
15 |
16 | parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
17 | parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
18 | help='path to datasets location (default: None)')
19 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
20 | parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
21 | parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
22 | help='model name (default: None)')
23 |
24 | parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
25 | help='checkpoint to resume training from (default: None)')
26 |
27 | parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
28 | parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
29 | parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
30 | parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.01)')
31 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
32 | parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')
33 |
34 | parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
35 | parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
36 | parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
37 | parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
38 | help='SWA model collection frequency/cycle length in epochs (default: 1)')
39 |
40 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
41 |
42 | args = parser.parse_args()
43 |
44 |
45 | if not os.path.exists(args.dir):
46 | os.makedirs(args.dir)
47 |
48 | save_dir = os.path.join(args.dir, 'checkpoints')
49 | if not os.path.exists(save_dir):
50 | os.makedirs(save_dir)
51 |
52 | print('Preparing directory %s' % args.dir)
53 | os.makedirs(args.dir, exist_ok=True)
54 | with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
55 | f.write(' '.join(sys.argv))
56 | f.write('\n')
57 |
58 | torch.backends.cudnn.benchmark = True
59 | torch.manual_seed(args.seed)
60 | torch.cuda.manual_seed(args.seed)
61 |
62 | print('Using model %s' % args.model)
63 | model_cfg = getattr(models, args.model)
64 |
65 | print('Loading dataset %s from %s' % (args.dataset, args.data_path))
66 | ds = getattr(torchvision.datasets, args.dataset)
67 | path = os.path.join(args.data_path, args.dataset.lower())
68 | train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
69 | test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
70 | loaders = {
71 | 'train': torch.utils.data.DataLoader(
72 | train_set,
73 | batch_size=args.batch_size,
74 | shuffle=True,
75 | num_workers=args.num_workers,
76 | pin_memory=True
77 | ),
78 | 'test': torch.utils.data.DataLoader(
79 | test_set,
80 | batch_size=args.batch_size,
81 | shuffle=False,
82 | num_workers=args.num_workers,
83 | pin_memory=True
84 | )
85 | }
86 | print (train_set)
88 | num_classes = utils.num_classes_dict[args.dataset]
89 |
90 | print('Preparing model')
91 | model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
92 | model.cuda()
93 |
94 |
95 | if args.swa:
96 | print('SWA training')
97 | swa_model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
98 | swa_model.cuda()
99 | swa_n = 0
100 | else:
101 | print('SGD training')
102 |
103 |
104 | def schedule(epoch):
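105 |     # Constant at lr_init for the first half of the budget, linear decay
106 |     # between 50% and 90% of it, then flat at the tail value (swa_lr when
107 |     # --swa is on, otherwise 0.01 * lr_init).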
108 |     t = epoch / (args.swa_start if args.swa else args.epochs)
106 | lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
107 | if t <= 0.5:
108 | factor = 1.0
109 | elif t <= 0.9:
110 | factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
111 | else:
112 | factor = lr_ratio
113 | return args.lr_init * factor
114 |
115 |
116 | criterion = F.cross_entropy
117 | optimizer = torch.optim.SGD(
118 | model.parameters(),
119 | lr=args.lr_init,
120 | momentum=args.momentum,
121 | weight_decay=args.wd
122 | )
123 |
124 | start_epoch = 0
125 | if args.resume is not None:
126 | print('Resume training from %s' % args.resume)
127 | checkpoint = torch.load(args.resume)
128 | start_epoch = checkpoint['epoch']
129 | model.load_state_dict(checkpoint['state_dict'])
130 | optimizer.load_state_dict(checkpoint['optimizer'])
131 | if args.swa:
132 | swa_state_dict = checkpoint['swa_state_dict']
133 | if swa_state_dict is not None:
134 | swa_model.load_state_dict(swa_state_dict)
135 | swa_n_ckpt = checkpoint['swa_n']
136 | if swa_n_ckpt is not None:
137 | swa_n = swa_n_ckpt
138 |
142 |
143 | columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
144 | if args.swa:
145 | columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + columns[-1:]
146 | swa_res = {'loss': None, 'accuracy': None}
147 |
148 | utils.save_checkpoint(
149 | args.dir,
150 | start_epoch,
151 | state_dict=model.state_dict(),
152 | swa_state_dict=swa_model.state_dict() if args.swa else None,
153 | swa_n=swa_n if args.swa else None,
154 | optimizer=optimizer.state_dict()
155 | )
156 |
157 | # DLDR sampling: save a snapshot per epoch; these checkpoints later span the TWA subspace
158 | sample_idx = 0
159 | torch.save(model.state_dict(), os.path.join(save_dir, str(sample_idx) + '.pt'))
160 |
161 | for epoch in range(start_epoch, args.epochs):
162 | time_ep = time.time()
163 |
164 | lr = schedule(epoch)
165 | utils.adjust_learning_rate(optimizer, lr)
166 | train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
167 | if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
168 | test_res = utils.eval(loaders['test'], model, criterion)
169 | else:
170 | test_res = {'loss': None, 'accuracy': None}
171 |
172 | if args.swa and (epoch + 1) >= args.swa_start and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
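173 |         # Fold the current weights into the running equal-weight average:
174 |         # w_swa <- (swa_n * w_swa + w) / (swa_n + 1)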
173 | utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
174 | swa_n += 1
175 | if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
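176 |             # BN running statistics are not averaged; recompute them with a
177 |             # pass over the training data before evaluating the SWA model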
176 | utils.bn_update(loaders['train'], swa_model)
177 | swa_res = utils.eval(loaders['test'], swa_model, criterion)
178 | else:
179 | swa_res = {'loss': None, 'accuracy': None}
180 |
181 | if (epoch + 1) % args.save_freq == 0:
182 | utils.save_checkpoint(
183 | args.dir,
184 | epoch + 1,
185 | state_dict=model.state_dict(),
186 | swa_state_dict=swa_model.state_dict() if args.swa else None,
187 | swa_n=swa_n if args.swa else None,
188 | optimizer=optimizer.state_dict()
189 | )
190 |
191 | # DLDR sampling
192 | sample_idx += 1
193 | torch.save(model.state_dict(), os.path.join(save_dir, str(sample_idx) + '.pt'))
194 |
195 | time_ep = time.time() - time_ep
196 | values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'], test_res['loss'], test_res['accuracy'], time_ep]
197 | if args.swa:
198 | values = values[:-1] + [swa_res['loss'], swa_res['accuracy']] + values[-1:]
199 | table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
200 | if epoch % 40 == 0:
201 | table = table.split('\n')
202 | table = '\n'.join([table[1]] + table)
203 | else:
204 | table = table.split('\n')[2]
205 | print(table)
206 |
207 | if args.epochs % args.save_freq != 0:
208 | utils.save_checkpoint(
209 | args.dir,
210 | args.epochs,
211 | state_dict=model.state_dict(),
212 | swa_state_dict=swa_model.state_dict() if args.swa else None,
213 | swa_n=swa_n if args.swa else None,
214 | optimizer=optimizer.state_dict()
215 | )
216 |
217 |
--------------------------------------------------------------------------------
/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | WideResNet model definition
3 | ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
4 | """
5 |
6 | from collections import OrderedDict
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.transforms as transforms
11 |
12 | __all__ = ['WideResNet28x10', 'WideResNet16x8']
13 |
14 |
98 |
99 | class BasicUnit(nn.Module):
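100 |     # Pre-activation residual unit: two BN-ReLU-conv stacks (dropout before
101 |     # the second conv) plus an identity shortcut in forward()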
100 | def __init__(self, channels: int, dropout: float):
101 | super(BasicUnit, self).__init__()
102 | self.block = nn.Sequential(OrderedDict([
103 | ("0_normalization", nn.BatchNorm2d(channels)),
104 | ("1_activation", nn.ReLU(inplace=True)),
105 | ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
106 | ("3_normalization", nn.BatchNorm2d(channels)),
107 | ("4_activation", nn.ReLU(inplace=True)),
108 | ("5_dropout", nn.Dropout(dropout, inplace=True)),
109 | ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
110 | ]))
111 |
112 | def forward(self, x):
113 | return x + self.block(x)
114 |
115 |
116 | class DownsampleUnit(nn.Module):
117 | def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float):
118 | super(DownsampleUnit, self).__init__()
119 | self.norm_act = nn.Sequential(OrderedDict([
120 | ("0_normalization", nn.BatchNorm2d(in_channels)),
121 | ("1_activation", nn.ReLU(inplace=True)),
122 | ]))
123 | self.block = nn.Sequential(OrderedDict([
124 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)),
125 | ("1_normalization", nn.BatchNorm2d(out_channels)),
126 | ("2_activation", nn.ReLU(inplace=True)),
127 | ("3_dropout", nn.Dropout(dropout, inplace=True)),
128 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)),
129 | ]))
130 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False)
131 |
132 | def forward(self, x):
133 | x = self.norm_act(x)
134 | return self.block(x) + self.downsample(x)
135 |
136 |
137 | class Block(nn.Module):
138 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float):
139 | super(Block, self).__init__()
140 | self.block = nn.Sequential(
141 | DownsampleUnit(in_channels, out_channels, stride, dropout),
142 | *(BasicUnit(out_channels, dropout) for _ in range(depth))
143 | )
144 |
145 | def forward(self, x):
146 | return self.block(x)
147 |
148 |
149 | class WideResNet(nn.Module):
150 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int):
151 | super(WideResNet, self).__init__()
152 |
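153 |         # Three groups at widths 16k, 32k, 64k; each group is one
154 |         # DownsampleUnit followed by (depth - 4) // 6 BasicUnits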
153 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor]
154 | self.block_depth = (depth - 4) // (3 * 2)
155 |
156 | self.f = nn.Sequential(OrderedDict([
157 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)),
158 | ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)),
159 | ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)),
160 | ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)),
161 | ("4_normalization", nn.BatchNorm2d(self.filters[3])),
162 | ("5_activation", nn.ReLU(inplace=True)),
163 | ("6_pooling", nn.AvgPool2d(kernel_size=8)),
164 | ("7_flattening", nn.Flatten()),
165 | ("8_classification", nn.Linear(in_features=self.filters[3], out_features=num_classes)),
166 | ]))
167 |
168 | self._initialize()
169 |
170 | def _initialize(self):
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
174 | if m.bias is not None:
175 | m.bias.data.zero_()
176 | elif isinstance(m, nn.BatchNorm2d):
177 | m.weight.data.fill_(1)
178 | m.bias.data.zero_()
179 | elif isinstance(m, nn.Linear):
180 | m.weight.data.zero_()
181 | m.bias.data.zero_()
182 |
183 | def forward(self, x):
184 | return self.f(x)
185 |
186 | class WideResNet28x10:
187 | base = WideResNet
188 | args = list()
189 |     kwargs = {'depth': 28, 'width_factor': 10, 'dropout': 0, 'in_channels': 3}
190 | transform_train = transforms.Compose([
191 | transforms.RandomCrop(32, padding=4),
192 | transforms.RandomHorizontalFlip(),
193 | transforms.ToTensor(),
194 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
195 | ])
196 | transform_test = transforms.Compose([
197 | transforms.ToTensor(),
198 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
199 | ])
200 |
201 | class WideResNet16x8:
202 | base = WideResNet
203 | args = list()
204 | kwargs = {'depth': 16, 'width_factor': 8, 'dropout': 0, 'in_channels': 3}
205 | transform_train = transforms.Compose([
206 | transforms.RandomCrop(32, padding=4),
207 | transforms.RandomHorizontalFlip(),
208 | transforms.ToTensor(),
209 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
210 | ])
211 | transform_test = transforms.Compose([
212 | transforms.ToTensor(),
213 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
214 | ])
--------------------------------------------------------------------------------
/train_sgd_cifar.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import numpy as np
5 | import random
6 | import sys
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim
13 | import torch.utils.data
14 | import torchvision.transforms as transforms
15 | import torchvision.datasets as datasets
16 |
17 | from utils import get_datasets, get_model, adjust_learning_rate, set_seed, Logger
18 |
19 | # Parse arguments
20 | parser = argparse.ArgumentParser(description='Regular SGD training')
21 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='SGD')
22 | parser.add_argument('--arch', '-a', metavar='ARCH',
23 | help='The architecture of the model')
24 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
25 |                     help='The training dataset (default: CIFAR10)')
26 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str,
27 | help='The optimizer for training')
28 | parser.add_argument('--schedule', metavar='SCHEDULE', default='step', type=str,
29 | help='The schedule for training')
30 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
31 | help='number of data loading workers (default: 4)')
32 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
33 | help='number of total epochs to run')
34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
35 | help='manual epoch number (useful on restarts)')
36 | parser.add_argument('-b', '--batch-size', default=128, type=int,
37 | metavar='N', help='mini-batch size (default: 128)')
38 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
39 | metavar='LR', help='initial learning rate')
40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
41 | help='momentum')
42 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
43 | metavar='W', help='weight decay (default: 1e-4)')
44 | parser.add_argument('--print-freq', '-p', default=50, type=int,
45 | metavar='N', help='print frequency (default: 50 iterations)')
46 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
47 | help='path to latest checkpoint (default: none)')
48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
49 | help='evaluate model on validation set')
50 | parser.add_argument('--wandb', dest='wandb', action='store_true',
51 |                     help='use wandb to monitor statistics')
52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
53 | help='use pre-trained model')
54 | parser.add_argument('--half', dest='half', action='store_true',
55 |                     help='use half-precision (16-bit)')
56 | parser.add_argument('--save-dir', dest='save_dir',
57 | help='The directory used to save the trained models',
58 | default='save_temp', type=str)
59 | parser.add_argument('--log-dir', dest='log_dir',
60 | help='The directory used to save the log',
61 | default='save_temp', type=str)
62 | parser.add_argument('--log-name', dest='log_name',
63 | help='The log file name',
64 | default='log', type=str)
65 | parser.add_argument('--randomseed',
66 | help='Randomseed for training and initialization',
67 | type=int, default=1)
68 |
69 | best_prec1 = 0
70 |
71 |
72 | # Record training statistics
73 | train_loss = []
74 | train_err = []
75 | test_loss = []
76 | test_err = []
77 | arr_time = []
78 |
79 | p0 = None
80 |
81 | args = parser.parse_args()
82 |
83 | if args.wandb:
84 | import wandb
85 | wandb.init(project="TWA", entity="XXX")
86 | date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
87 | wandb.run.name = args.EXP + date
88 |
89 |
90 | def get_model_param_vec(model):
91 | # Return the model parameters as a vector
92 |
93 | vec = []
94 | for name,param in model.named_parameters():
95 | vec.append(param.data.detach().reshape(-1))
96 | return torch.cat(vec, 0)
97 |
98 |
99 | def get_model_grad_vec(model):
100 | # Return the model gradient as a vector
101 |
102 | vec = []
103 | for name,param in model.named_parameters():
104 | vec.append(param.grad.detach().reshape(-1))
105 | return torch.cat(vec, 0)
106 |
107 | def update_grad(model, grad_vec):
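108 |     # Scatter a flat gradient vector back into each parameter's .grad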
108 | idx = 0
109 | for name,param in model.named_parameters():
110 | arr_shape = param.grad.shape
111 | size = arr_shape.numel()
112 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape)
113 | idx += size
114 |
115 | def update_param(model, param_vec):
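116 |     # Scatter a flat parameter vector back into the model's tensors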
116 | idx = 0
117 | for name,param in model.named_parameters():
118 | arr_shape = param.data.shape
119 | size = arr_shape.numel()
120 | param.data = param_vec[idx:idx+size].reshape(arr_shape)
121 | idx += size
122 |
123 | sample_idx = 0
124 |
125 | def main():
126 |
127 | global args, best_prec1, p0, sample_idx
128 | global param_avg, train_loss, train_err, test_loss, test_err, arr_time, running_weight
129 |
130 | set_seed(args.randomseed)
131 |
132 | # Check the save_dir exists or not
133 | print ('save dir:', args.save_dir)
134 | if not os.path.exists(args.save_dir):
135 | os.makedirs(args.save_dir)
136 |
137 | # Check the log_dir exists or not
138 | print ('log dir:', args.log_dir)
139 | if not os.path.exists(args.log_dir):
140 | os.makedirs(args.log_dir)
141 |
142 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
143 |
144 | # Define model
145 | # model = torch.nn.DataParallel(get_model(args))
146 | model = get_model(args)
147 | model.cuda()
148 |
149 | # Optionally resume from a checkpoint
150 | if args.resume:
151 | # if os.path.isfile(args.resume):
152 | if os.path.isfile(os.path.join(args.save_dir, args.resume)):
153 |
154 | # model.load_state_dict(torch.load(os.path.join(args.save_dir, args.resume)))
155 |
156 | print ("=> loading checkpoint '{}'".format(args.resume))
157 |             checkpoint = torch.load(os.path.join(args.save_dir, args.resume))
158 | args.start_epoch = checkpoint['epoch']
159 | print ('from ', args.start_epoch)
160 | best_prec1 = checkpoint['best_prec1']
161 | model.load_state_dict(checkpoint['state_dict'])
162 | print ("=> loaded checkpoint '{}' (epoch {})"
163 |                 .format(args.resume, checkpoint['epoch']))
164 | else:
165 | print ("=> no checkpoint found at '{}'".format(args.resume))
166 |
167 | cudnn.benchmark = True
168 |
169 |
170 | # Prepare Dataloader
171 | train_loader, val_loader = get_datasets(args)
172 |
173 | # define loss function (criterion) and optimizer
174 | criterion = nn.CrossEntropyLoss().cuda()
175 |
176 | if args.half:
177 | model.half()
178 | criterion.half()
179 |
180 | print ('optimizer:', args.optimizer)
181 |
182 | if args.optimizer == 'sgd':
183 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
184 | momentum=args.momentum,
185 | weight_decay=args.weight_decay)
186 | elif args.optimizer == 'adam':
187 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
188 | weight_decay=args.weight_decay)
189 |
190 |     if args.schedule == 'step':
191 |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
192 |     else:
193 |         # keep the lr constant so lr_scheduler is defined for every schedule
194 |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)
192 |
193 | if args.evaluate:
194 | validate(val_loader, model, criterion)
195 | return
196 |
197 | is_best = 0
198 | print ('Start training: ', args.start_epoch, '->', args.epochs)
199 |
200 |     # DLDR sampling: save checkpoint 0 (the starting point); one snapshot per epoch follows for TWA
201 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(0) + '.pt'))
202 |
203 | p0 = get_model_param_vec(model)
204 | running_weight = p0
205 |
206 | for epoch in range(args.start_epoch, args.epochs):
207 |
208 | # train for one epoch
209 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
210 | train(train_loader, model, criterion, optimizer, epoch)
211 |
212 | lr_scheduler.step()
213 |
214 | # evaluate on validation set
215 | prec1 = validate(val_loader, model, criterion)
216 |
217 | # remember best prec@1 and save checkpoint
218 | is_best = prec1 > best_prec1
219 | best_prec1 = max(prec1, best_prec1)
220 |
221 | save_checkpoint({
222 | 'state_dict': model.state_dict(),
223 | 'best_prec1': best_prec1,
224 | }, is_best, filename=os.path.join(args.save_dir, 'model.th'))
225 |
226 | # DLDR sampling
227 | sample_idx += 1
228 | torch.save(model.state_dict(), os.path.join(args.save_dir, str(sample_idx) + '.pt'))
229 |
230 | print ('train loss: ', train_loss)
231 | print ('train err: ', train_err)
232 | print ('test loss: ', test_loss)
233 | print ('test err: ', test_err)
234 |
235 | print ('time: ', arr_time)
236 |
237 | running_weight = None
238 |
239 | def train(train_loader, model, criterion, optimizer, epoch):
240 | """
241 | Run one train epoch
242 | """
243 | global train_loss, train_err, arr_time, p0, sample_idx, running_weight
244 |
245 | batch_time = AverageMeter()
246 | data_time = AverageMeter()
247 | losses = AverageMeter()
248 | top1 = AverageMeter()
249 |
250 | # switch to train mode
251 | model.train()
252 |
255 |
256 | total_loss, total_err = 0, 0
257 | end = time.time()
258 | for i, (input, target) in enumerate(train_loader):
259 |
260 | # measure data loading time
261 | data_time.update(time.time() - end)
262 |
263 | target = target.cuda()
264 | input_var = input.cuda()
265 | target_var = target
266 | if args.half:
267 | input_var = input_var.half()
268 |
269 | # compute output
270 | output = model(input_var)
271 | loss = criterion(output, target_var)
272 |
273 | # compute gradient and do SGD step
274 | optimizer.zero_grad()
275 | loss.backward()
276 | total_loss += loss.item() * input_var.shape[0]
277 | total_err += (output.max(dim=1)[1] != target_var).sum().item()
278 |
279 | optimizer.step()
280 | output = output.float()
281 | loss = loss.float()
282 |
283 | # measure accuracy and record loss
284 | prec1 = accuracy(output.data, target)[0]
285 | losses.update(loss.item(), input.size(0))
286 | top1.update(prec1.item(), input.size(0))
287 |
288 | # measure elapsed time
289 | batch_time.update(time.time() - end)
290 | end = time.time()
291 |
292 | if i % args.print_freq == 0:
293 | print('Epoch: [{0}][{1}/{2}]\t'
294 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
295 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
296 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
297 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
298 | epoch, i, len(train_loader), batch_time=batch_time,
299 | data_time=data_time, loss=losses, top1=top1))
300 |
301 | print ('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum))
302 |
303 | train_loss.append(total_loss / len(train_loader.dataset))
304 | train_err.append(total_err / len(train_loader.dataset))
305 | if args.wandb:
306 | wandb.log({"train loss": total_loss / len(train_loader.dataset)})
307 | wandb.log({"train acc": 1 - total_err / len(train_loader.dataset)})
308 |
309 | arr_time.append(batch_time.sum)
310 |
311 | def validate(val_loader, model, criterion):
312 | """
313 | Run evaluation
314 | """
315 | global test_err, test_loss
316 |
317 | total_loss = 0
318 | total_err = 0
319 |
320 | batch_time = AverageMeter()
321 | losses = AverageMeter()
322 | top1 = AverageMeter()
323 |
324 | # switch to evaluate mode
325 | model.eval()
326 |
327 | end = time.time()
328 | with torch.no_grad():
329 | for i, (input, target) in enumerate(val_loader):
330 | target = target.cuda()
331 | input_var = input.cuda()
332 | target_var = target.cuda()
333 |
334 | if args.half:
335 | input_var = input_var.half()
336 |
337 | # compute output
338 | output = model(input_var)
339 | loss = criterion(output, target_var)
340 |
341 | output = output.float()
342 | loss = loss.float()
343 |
344 | total_loss += loss.item() * input_var.shape[0]
345 | total_err += (output.max(dim=1)[1] != target_var).sum().item()
346 |
347 | # measure accuracy and record loss
348 | prec1 = accuracy(output.data, target)[0]
349 | losses.update(loss.item(), input.size(0))
350 | top1.update(prec1.item(), input.size(0))
351 |
352 | # measure elapsed time
353 | batch_time.update(time.time() - end)
354 | end = time.time()
355 |
356 | if i % args.print_freq == 0:
357 | print('Test: [{0}/{1}]\t'
358 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
359 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
360 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
361 | i, len(val_loader), batch_time=batch_time, loss=losses,
362 | top1=top1))
363 |
364 | print(' * Prec@1 {top1.avg:.3f}'
365 | .format(top1=top1))
366 |
367 | test_loss.append(total_loss / len(val_loader.dataset))
368 | test_err.append(total_err / len(val_loader.dataset))
369 |
370 | if args.wandb:
371 | wandb.log({"test loss": total_loss / len(val_loader.dataset)})
372 | wandb.log({"test acc": 1 - total_err / len(val_loader.dataset)})
373 |
374 | return top1.avg
375 |
376 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
377 | """
378 | Save the training model
379 | """
380 | torch.save(state, filename)
381 |
382 | class AverageMeter(object):
383 | """Computes and stores the average and current value"""
384 | def __init__(self):
385 | self.reset()
386 |
387 | def reset(self):
388 | self.val = 0
389 | self.avg = 0
390 | self.sum = 0
391 | self.count = 0
392 |
393 | def update(self, val, n=1):
394 | self.val = val
395 | self.sum += val * n
396 | self.count += n
397 | self.avg = self.sum / self.count
398 |
399 |
400 | def accuracy(output, target, topk=(1,)):
401 | """Computes the precision@k for the specified values of k"""
402 | maxk = max(topk)
403 | batch_size = target.size(0)
404 |
405 | _, pred = output.topk(maxk, 1, True, True)
406 | pred = pred.t()
407 | correct = pred.eq(target.view(1, -1).expand_as(pred))
408 |
409 | res = []
410 | for k in topk:
411 | correct_k = correct[:k].view(-1).float().sum(0)
412 | res.append(correct_k.mul_(100.0 / batch_size))
413 | return res
414 |
415 |
416 | if __name__ == '__main__':
417 | main()
418 |
--------------------------------------------------------------------------------
/train_twa.py:
--------------------------------------------------------------------------------
1 | import argparse
4 | import time
5 | import os
6 | import sys
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim as optim
13 | import torch.utils.data
14 |
15 | import numpy as np
19 | import utils
20 | from utils import get_datasets, get_model, set_seed, adjust_learning_rate, bn_update, eval_model, Logger
21 |
22 | ########################## parse arguments ##########################
23 | parser = argparse.ArgumentParser(description='TWA')
24 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD')
25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN',
26 | help='model architecture (default: VGG16BN)')
27 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
28 |                     help='The training dataset (default: CIFAR10)')
29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
30 | help='number of data loading workers (default: 4)')
31 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
32 | help='number of total epochs to run')
33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
34 | help='manual epoch number (useful on restarts)')
35 | parser.add_argument('-b', '--batch-size', default=128, type=int,
36 | metavar='N', help='mini-batch size (default: 128)')
37 | parser.add_argument('-acc', '--accumulate', default=1, type=int,
38 | metavar='A', help='accumulate times for batch gradient (default: 1)')
39 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
40 |                     metavar='W', help='weight decay (default: 1e-5)')
41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
42 | help='momentum')
43 | parser.add_argument('--print-freq', '-p', default=200, type=int,
44 |                     metavar='N', help='print frequency (default: 200)')
45 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
46 | help='path to latest checkpoint (default: none)')
47 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
48 | help='evaluate model on validation set')
49 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
50 | help='use pre-trained model')
51 | parser.add_argument('--half', dest='half', action='store_true',
52 |                     help='use half-precision (16-bit)')
53 | parser.add_argument('--randomseed',
54 | help='Randomseed for training and initialization',
55 | type=int, default=1)
56 | parser.add_argument('--save-dir', dest='save_dir',
57 | help='The directory used to save the trained models',
58 | default='save_temp', type=str)
59 | parser.add_argument('--log-dir', dest='log_dir',
60 | help='The directory used to save the log',
61 | default='save_temp', type=str)
62 | parser.add_argument('--log-name', dest='log_name',
63 | help='The log file name',
64 | default='log', type=str)
65 |
66 | ########################## P-SGD setting ##########################
67 | parser.add_argument('--extract', metavar='EXTRACT', help='method for extracting subspace',
68 | default='Schmidt', choices=['Schmidt'])
69 | parser.add_argument('--params_start', default=0, type=int, metavar='N',
70 |                     help='index of the first sampled checkpoint used for TWA')
71 | parser.add_argument('--params_end', default=101, type=int, metavar='N',
72 |                     help='index after the last sampled checkpoint used for TWA (exclusive)')
73 | parser.add_argument('--train_start', default=0, type=int, metavar='N',
74 |                     help='checkpoint index used as the starting point for training')
75 | parser.add_argument('--opt', metavar='OPT', help='optimization method for TWA',
76 | default='SGD', choices=['SGD'])
77 | parser.add_argument('--schedule', metavar='SCHE', help='learning rate schedule for P-SGD',
78 | default='step', choices=['step', 'constant', 'linear'])
79 | parser.add_argument('--lr', default=1, type=float, metavar='N',
80 | help='lr for PSGD')
81 |
82 | args = parser.parse_args()
83 | set_seed(args.randomseed)
84 | best_prec1 = 0
85 | P = None
86 | train_acc, test_acc, train_loss, test_loss = [], [], [], []
87 |
88 | def get_model_param_vec(model):
89 | """
90 | Return model parameters as a vector
91 | """
92 | vec = []
93 | for name,param in model.named_parameters():
94 | vec.append(param.detach().cpu().numpy().reshape(-1))
95 | return np.concatenate(vec, 0)
96 |
97 | def get_model_param_vec_torch(model):
98 | """
99 | Return model parameters as a vector
100 | """
101 | vec = []
102 | for name,param in model.named_parameters():
103 | vec.append(param.data.detach().reshape(-1))
104 | return torch.cat(vec, 0)
105 |
106 | def get_model_grad_vec(model):
107 | """
108 | Return model grad as a vector
109 | """
110 | vec = []
111 | for name,param in model.named_parameters():
112 | vec.append(param.grad.detach().reshape(-1))
113 | return torch.cat(vec, 0)
114 |
115 | def update_grad(model, grad_vec):
116 | """
117 | Update model grad
118 | """
119 | idx = 0
120 | for name,param in model.named_parameters():
121 | arr_shape = param.grad.shape
122 | size = arr_shape.numel()
123 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
124 | idx += size
125 |
126 | def update_param(model, param_vec):
127 | idx = 0
128 | for name,param in model.named_parameters():
129 | arr_shape = param.data.shape
130 | size = arr_shape.numel()
131 | param.data = param_vec[idx:idx+size].reshape(arr_shape).clone()
132 | idx += size
133 |
134 | def main():
135 |
136 | global args, best_prec1, Bk, P, coeff, coeff_inv
137 |
138 | # Check the save_dir exists or not
139 | if not os.path.exists(args.save_dir):
140 | os.makedirs(args.save_dir)
141 |
142 | # Check the log_dir exists or not
143 | if not os.path.exists(args.log_dir):
144 | os.makedirs(args.log_dir)
145 |
146 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
147 | print ('twa-psgd')
148 | print ('save dir:', args.save_dir)
149 | print ('log dir:', args.log_dir)
150 |
151 | # Define model
152 | if args.datasets == 'ImageNet':
153 | model = torch.nn.DataParallel(get_model(args))
154 | else:
155 | model = get_model(args)
156 | model.cuda()
157 | cudnn.benchmark = True
158 |
159 | # Define loss function (criterion) and optimizer
160 | criterion = nn.CrossEntropyLoss().cuda()
161 |
162 | optimizer = optim.SGD(model.parameters(), lr=args.lr, \
163 | momentum=args.momentum, \
164 | weight_decay=args.weight_decay)
165 |
166 | if args.schedule == 'step':
167 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
168 | milestones=[int(args.epochs*0.5), int(args.epochs*0.75+0.9)], last_epoch=args.start_epoch - 1)
169 |
170 | elif args.schedule == 'constant' or args.schedule == 'linear':
171 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
172 | milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)
173 |
174 | optimizer.zero_grad()
175 |
176 | # Prepare Dataloader
177 | train_loader, val_loader = get_datasets(args)
178 |
179 | args.total_iters = len(train_loader) * args.epochs
180 | args.current_iters = 0
181 |
182 | ########################## extract subspaces ##########################
183 | # Load sampled model parameters
184 | print ('weight decay:', args.weight_decay)
185 | print ('params: from', args.params_start, 'to', args.params_end)
186 | W = []
187 | for i in range(args.params_start, args.params_end):
188 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(i) + '.pt')))
189 | W.append(get_model_param_vec(model))
190 | W = np.array(W)
191 | print ('W:', W.shape)
192 |
193 | # Evaluate swa performance
194 | center = torch.from_numpy(np.mean(W, axis=0)).cuda()
195 |
196 | update_param(model, center)
197 | bn_update(train_loader, model)
198 |     print ('SWA:', utils.eval_model(val_loader, model, criterion))
199 |
200 | if args.extract == 'Schmidt':
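201 |         # Modified Gram-Schmidt over the sampled solutions: rows of P become
202 |         # an orthonormal basis of their span; coeff records the same row
203 |         # operations so the new basis can be related back to the checkpoints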
201 | P = torch.from_numpy(np.array(W)).cuda()
202 | n_dim = P.shape[0]
203 | args.n_components = n_dim
204 | coeff = torch.eye(n_dim).cuda()
205 | for i in range(n_dim):
206 | if i > 0:
207 | tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1))
208 | P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1)
209 | coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1)
210 | tmp = torch.norm(P[i])
211 | P[i] /= tmp
212 | coeff[i] /= tmp
213 | coeff_inv = coeff.T.inverse()
214 |
215 | print (P.shape)
216 |
217 | # set the start point
218 | if args.train_start >= 0:
219 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt')))
220 | print ('train start:', args.train_start)
221 |
222 | if args.half:
223 | model.half()
224 | criterion.half()
225 |
226 | if args.evaluate:
227 | validate(val_loader, model, criterion)
228 | return
229 |
230 | print ('Train:', (args.start_epoch, args.epochs))
231 | end = time.time()
232 | p0 = get_model_param_vec(model)
233 |
234 | for epoch in range(args.start_epoch, args.epochs):
235 | # Train for one epoch
236 |
237 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
238 | train(train_loader, model, criterion, optimizer, args, epoch, center)
239 |
240 | if args.schedule != 'linear':
241 | lr_scheduler.step()
242 |
243 | # Evaluate on validation set
244 | prec1 = validate(val_loader, model, criterion)
245 |
246 | # Remember best prec@1 and save checkpoint
247 | is_best = prec1 > best_prec1
248 | best_prec1 = max(prec1, best_prec1)
249 |
250 | print ('Save final model')
251 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'PSGD.pt'))
252 |
253 | bn_update(train_loader, model)
254 | print (utils.eval_model(val_loader, model, criterion))
255 |
256 | print ('total time:', time.time() - end)
257 | print ('train loss: ', train_loss)
258 | print ('train acc: ', train_acc)
259 | print ('test loss: ', test_loss)
260 | print ('test acc: ', test_acc)
261 | print ('best_prec1:', best_prec1)
262 |
263 |
264 | def train(train_loader, model, criterion, optimizer, args, epoch, center):
265 | # Run one train epoch
266 |
267 | global P, W, iters, T, train_loss, train_acc, search_times, coeff
268 |
269 | batch_time = AverageMeter()
270 | data_time = AverageMeter()
271 | losses = AverageMeter()
272 | top1 = AverageMeter()
273 |
274 | # Switch to train mode
275 | model.train()
276 |
277 | end = time.time()
278 | for i, (input, target) in enumerate(train_loader):
279 |
280 | # Measure data loading time
281 | data_time.update(time.time() - end)
282 |
283 | # Load batch data to cuda
284 | target = target.cuda()
285 | input_var = input.cuda()
286 | target_var = target
287 | if args.half:
288 | input_var = input_var.half()
289 |
290 | # Compute output
291 | output = model(input_var)
292 | loss = criterion(output, target_var)
293 |
294 | # Compute gradient and do SGD step
295 | optimizer.zero_grad()
296 | loss.backward()
297 | gk = get_model_grad_vec(model)
298 |
299 | if args.schedule == 'linear':
300 | adjust_learning_rate(optimizer, (1 - args.current_iters / args.total_iters) * args.lr)
301 | args.current_iters += 1
302 |
303 | if args.opt == 'SGD':
304 | P_SGD(model, optimizer, gk, center)
305 |
306 | # Measure accuracy and record loss
307 | prec1 = accuracy(output.data, target)[0]
308 | losses.update(loss.item(), input.size(0))
309 | top1.update(prec1.item(), input.size(0))
310 |
311 | # Measure elapsed time
312 | batch_time.update(time.time() - end)
313 | end = time.time()
314 |
315 | if i % args.print_freq == 0 or i == len(train_loader)-1:
316 | print('Epoch: [{0}][{1}/{2}]\t'
317 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
318 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
319 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
320 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
321 | epoch, i, len(train_loader), batch_time=batch_time,
322 | data_time=data_time, loss=losses, top1=top1))
323 |
324 | train_loss.append(losses.avg)
325 | train_acc.append(top1.avg)
326 |
327 | def P_SGD(model, optimizer, grad, center):
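328 |     # Project the gradient onto the extracted subspace: P @ grad gives its
329 |     # coordinates in the orthonormal basis, P.T maps them back, so the SGD
330 |     # step never leaves the span of the sampled checkpoints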
328 |
329 | # p = get_model_param_vec_torch(model)
330 | gk = torch.mm(P, grad.reshape(-1,1))
331 | grad_proj = torch.mm(P.transpose(0, 1), gk)
332 |
333 | update_grad(model, grad_proj.reshape(-1))
334 |
335 | optimizer.step()
336 |
337 | def validate(val_loader, model, criterion):
338 | # Run evaluation
339 |
340 | global test_acc, test_loss
341 |
342 | batch_time = AverageMeter()
343 | losses = AverageMeter()
344 | top1 = AverageMeter()
345 |
346 | # Switch to evaluate mode
347 | model.eval()
348 |
349 | end = time.time()
350 | with torch.no_grad():
351 | for i, (input, target) in enumerate(val_loader):
352 | target = target.cuda()
353 | input_var = input.cuda()
354 | target_var = target.cuda()
355 |
356 | if args.half:
357 | input_var = input_var.half()
358 |
359 | # Compute output
360 | output = model(input_var)
361 | loss = criterion(output, target_var)
362 |
363 | output = output.float()
364 | loss = loss.float()
365 |
366 | # Measure accuracy and record loss
367 | prec1 = accuracy(output.data, target)[0]
368 | losses.update(loss.item(), input.size(0))
369 | top1.update(prec1.item(), input.size(0))
370 |
371 | # Measure elapsed time
372 | batch_time.update(time.time() - end)
373 | end = time.time()
374 |
375 | if i % args.print_freq == 0:
376 | print('Test: [{0}/{1}]\t'
377 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
378 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
379 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
380 | i, len(val_loader), batch_time=batch_time, loss=losses,
381 | top1=top1))
382 |
383 | print(' * Prec@1 {top1.avg:.3f}'
384 | .format(top1=top1))
385 |
386 | # Store the test loss and test accuracy
387 | test_loss.append(losses.avg)
388 | test_acc.append(top1.avg)
389 |
390 | return top1.avg
391 |
392 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
393 | # Save the training model
394 |
395 | torch.save(state, filename)
396 |
397 | class AverageMeter(object):
398 | # Computes and stores the average and current value
399 |
400 | def __init__(self):
401 | self.reset()
402 |
403 | def reset(self):
404 | self.val = 0
405 | self.avg = 0
406 | self.sum = 0
407 | self.count = 0
408 |
409 | def update(self, val, n=1):
410 | self.val = val
411 | self.sum += val * n
412 | self.count += n
413 | self.avg = self.sum / self.count
414 |
415 |
416 | def accuracy(output, target, topk=(1,)):
417 | # Computes the precision@k for the specified values of k
418 |
419 | maxk = max(topk)
420 | batch_size = target.size(0)
421 |
422 | _, pred = output.topk(maxk, 1, True, True)
423 | pred = pred.t()
424 | correct = pred.eq(target.view(1, -1).expand_as(pred))
425 |
426 | res = []
427 | for k in topk:
428 | correct_k = correct[:k].view(-1).float().sum(0)
429 | res.append(correct_k.mul_(100.0 / batch_size))
430 | return res
431 |
432 | if __name__ == '__main__':
433 | main()
--------------------------------------------------------------------------------
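A minimal, self-contained sketch (not a file from this repository; all names
here are illustrative) of the subspace machinery in `train_twa.py` above: the
same Gram-Schmidt recursion builds an orthonormal basis from a few stand-in
"checkpoint" vectors, and the projection used in `P_SGD` keeps a gradient
inside their span.

```python
import torch

torch.manual_seed(0)

# Three stand-in "checkpoints" in a 10-dimensional parameter space.
W = torch.randn(3, 10)

# Gram-Schmidt, mirroring the loop in train_twa.py: rows of P become an
# orthonormal basis of span{W[0], W[1], W[2]}.
P = W.clone()
for i in range(P.shape[0]):
    if i > 0:
        tmp = P[:i] @ P[i].reshape(-1, 1)    # components along earlier rows
        P[i] -= (P[:i].T @ tmp).reshape(-1)  # subtract them out
    P[i] /= torch.norm(P[i])

# A full-space "gradient" and its projection onto the subspace, exactly as in
# P_SGD: P maps to subspace coordinates, P.T maps them back.
g = torch.randn(10)
g_proj = (P.T @ (P @ g.reshape(-1, 1))).reshape(-1)

# Projecting a second time changes nothing (the projection is idempotent).
g_proj2 = (P.T @ (P @ g_proj.reshape(-1, 1))).reshape(-1)
assert torch.allclose(g_proj, g_proj2, atol=1e-4)
print('fraction of gradient norm kept:', (torch.norm(g_proj) / torch.norm(g)).item())
```
--------------------------------------------------------------------------------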
/swa/train_twa.py:
--------------------------------------------------------------------------------
1 | import argparse
4 | import time
5 | import os
6 | import sys
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim as optim
13 | import torch.utils.data
14 |
15 | import numpy as np
19 | import utils
20 | from utils import get_datasets, get_model, set_seed, adjust_learning_rate, bn_update, eval_model, Logger
21 |
22 | ########################## parse arguments ##########################
23 | parser = argparse.ArgumentParser(description='SGD in Projected Subspace')
24 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD')
25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN',
26 | help='model architecture (default: VGG16BN)')
27 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
28 |                     help='The training dataset (default: CIFAR10)')
29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
30 | help='number of data loading workers (default: 4)')
31 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
32 | help='number of total epochs to run')
33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
34 | help='manual epoch number (useful on restarts)')
35 | parser.add_argument('-b', '--batch-size', default=128, type=int,
36 | metavar='N', help='mini-batch size (default: 128)')
37 | parser.add_argument('-acc', '--accumulate', default=1, type=int,
38 | metavar='A', help='accumulate times for batch gradient (default: 1)')
39 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
40 |                     metavar='W', help='weight decay (default: 1e-5)')
41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
42 | help='momentum')
43 | parser.add_argument('--print-freq', '-p', default=200, type=int,
44 |                     metavar='N', help='print frequency (default: 200)')
45 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
46 | help='path to latest checkpoint (default: none)')
47 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
48 | help='evaluate model on validation set')
49 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
50 | help='use pre-trained model')
51 | parser.add_argument('--half', dest='half', action='store_true',
52 |                     help='use half-precision (16-bit)')
53 | parser.add_argument('--randomseed',
54 | help='Randomseed for training and initialization',
55 | type=int, default=1)
56 | parser.add_argument('--save-dir', dest='save_dir',
57 | help='The directory used to save the trained models',
58 | default='save_temp', type=str)
59 | parser.add_argument('--log-dir', dest='log_dir',
60 | help='The directory used to save the log',
61 | default='save_temp', type=str)
62 | parser.add_argument('--log-name', dest='log_name',
63 | help='The log file name',
64 | default='log', type=str)
65 |
66 | ########################## P-SGD setting ##########################
67 | parser.add_argument('--extract', metavar='EXTRACT', help='method for extracting subspace',
68 | default='Schmidt', choices=['Schmidt'])
69 | parser.add_argument('--params_start', default=0, type=int, metavar='N',
70 |                     help='index of the first sampled checkpoint used for TWA')
71 | parser.add_argument('--params_end', default=51, type=int, metavar='N',
72 |                     help='index after the last sampled checkpoint used for TWA (exclusive)')
73 | parser.add_argument('--train_start', default=0, type=int, metavar='N',
74 |                     help='checkpoint index used as the starting point for training')
75 | parser.add_argument('--opt', metavar='OPT', help='optimization method for TWA',
76 | default='SGD', choices=['SGD'])
77 | parser.add_argument('--schedule', metavar='SCHE', help='learning rate schedule for P-SGD',
78 | default='step', choices=['step', 'constant', 'linear'])
79 | parser.add_argument('--lr', default=1, type=float, metavar='N',
80 | help='lr for PSGD')
81 |
82 | args = parser.parse_args()
83 | set_seed(args.randomseed)
84 | best_prec1 = 0
85 | P = None
86 | train_acc, test_acc, train_loss, test_loss = [], [], [], []
87 |
88 | def get_model_param_vec(model):
89 | """
90 | Return model parameters as a vector
91 | """
92 | vec = []
93 | for name,param in model.named_parameters():
94 | vec.append(param.detach().cpu().numpy().reshape(-1))
95 | return np.concatenate(vec, 0)
96 |
97 | def get_model_param_vec_torch(model):
98 | """
99 | Return model parameters as a vector
100 | """
101 | vec = []
102 | for name,param in model.named_parameters():
103 | vec.append(param.data.detach().reshape(-1))
104 | return torch.cat(vec, 0)
105 |
106 | def get_model_grad_vec(model):
107 | """
108 | Return model grad as a vector
109 | """
110 | vec = []
111 | for name,param in model.named_parameters():
112 | vec.append(param.grad.detach().reshape(-1))
113 | return torch.cat(vec, 0)
114 |
115 | def update_grad(model, grad_vec):
116 | """
117 | Update model grad
118 | """
119 | idx = 0
120 | for name,param in model.named_parameters():
121 | arr_shape = param.grad.shape
122 | size = arr_shape.numel()
123 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
124 | idx += size
125 |
126 | def update_param(model, param_vec):
127 | idx = 0
128 | for name,param in model.named_parameters():
129 | arr_shape = param.data.shape
130 | size = arr_shape.numel()
131 | param.data = param_vec[idx:idx+size].reshape(arr_shape).clone()
132 | idx += size
133 |
134 | def main():
135 |
136 | global args, best_prec1, Bk, P, coeff, coeff_inv
137 |
138 | # Check the save_dir exists or not
139 | if not os.path.exists(args.save_dir):
140 | os.makedirs(args.save_dir)
141 |
142 | # Check the log_dir exists or not
143 | if not os.path.exists(args.log_dir):
144 | os.makedirs(args.log_dir)
145 |
146 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
147 | print ('twa-psgd')
148 | print ('save dir:', args.save_dir)
149 | print ('log dir:', args.log_dir)
150 |
151 | # Define model
152 | if args.datasets == 'ImageNet':
153 | model = torch.nn.DataParallel(get_model(args))
154 | else:
155 | model = get_model(args)
156 | model.cuda()
157 | cudnn.benchmark = True
158 |
159 | # Define loss function (criterion) and optimizer
160 | criterion = nn.CrossEntropyLoss().cuda()
161 |
162 | optimizer = optim.SGD(model.parameters(), lr=args.lr, \
163 | momentum=args.momentum, \
164 | weight_decay=args.weight_decay)
165 |
166 | if args.schedule == 'step':
167 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
168 | milestones=[int(args.epochs*0.5), int(args.epochs*0.75+0.9)], last_epoch=args.start_epoch - 1)
169 |
170 | elif args.schedule == 'constant' or args.schedule == 'linear':
171 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
172 | milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)
173 |
174 | optimizer.zero_grad()
175 |
176 | # Prepare Dataloader
177 | train_loader, val_loader = get_datasets(args)
178 |
179 | args.total_iters = len(train_loader) * args.epochs
180 | args.current_iters = 0
181 |
182 | ########################## extract subspaces ##########################
183 | # Load sampled model parameters
184 | print ('weight decay:', args.weight_decay)
185 | print ('params: from', args.params_start, 'to', args.params_end)
186 | W = []
187 | for i in range(args.params_start, args.params_end):
188 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(i) + '.pt')))
189 | W.append(get_model_param_vec(model))
190 | W = np.array(W)
191 | print ('W:', W.shape)
192 |
193 | # Evaluate swa performance
194 | center = torch.from_numpy(np.mean(W, axis=0)).cuda()
195 |
196 | update_param(model, center)
197 | bn_update(train_loader, model)
198 | print ('SWA:', utils.eval_model(val_loader, model, criterion))
199 |
200 | if args.extract == 'Schmidt':
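201 |         # Modified Gram-Schmidt over the sampled solutions: rows of P become
202 |         # an orthonormal basis of their span; coeff records the same row
203 |         # operations so the new basis can be related back to the checkpoints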
201 | P = torch.from_numpy(np.array(W)).cuda()
202 | n_dim = P.shape[0]
203 | args.n_components = n_dim
204 | coeff = torch.eye(n_dim).cuda()
205 | for i in range(n_dim):
206 | if i > 0:
207 | tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1))
208 | P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1)
209 | coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1)
210 | tmp = torch.norm(P[i])
211 | P[i] /= tmp
212 | coeff[i] /= tmp
213 | coeff_inv = coeff.T.inverse()
214 |
215 | print (P.shape)
216 |
217 | # set the start point
218 | if args.train_start >= 0:
219 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt')))
220 | print ('train start:', args.train_start)
221 |
222 | if args.half:
223 | model.half()
224 | criterion.half()
225 |
226 | if args.evaluate:
227 | validate(val_loader, model, criterion)
228 | return
229 |
230 | print ('Train:', (args.start_epoch, args.epochs))
231 | end = time.time()
232 | p0 = get_model_param_vec(model)
233 |
234 | for epoch in range(args.start_epoch, args.epochs):
235 | # Train for one epoch
236 |
237 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
238 | train(train_loader, model, criterion, optimizer, args, epoch, center)
239 |
240 | if args.schedule != 'linear':
241 | lr_scheduler.step()
242 |
243 | # Evaluate on validation set
244 | prec1 = validate(val_loader, model, criterion)
245 |
246 | # Remember best prec@1 and save checkpoint
247 | is_best = prec1 > best_prec1
248 | best_prec1 = max(prec1, best_prec1)
249 |
250 | print ('Save final model')
251 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'PSGD.pt'))
252 |
253 | bn_update(train_loader, model)
254 | print (utils.eval_model(val_loader, model, criterion))
255 |
256 | print ('total time:', time.time() - end)
257 | print ('train loss: ', train_loss)
258 | print ('train acc: ', train_acc)
259 | print ('test loss: ', test_loss)
260 | print ('test acc: ', test_acc)
261 | print ('best_prec1:', best_prec1)
262 |
263 |
264 | def train(train_loader, model, criterion, optimizer, args, epoch, center):
265 | # Run one train epoch
266 |
267 | global P, W, iters, T, train_loss, train_acc, search_times, coeff
268 |
269 | batch_time = AverageMeter()
270 | data_time = AverageMeter()
271 | losses = AverageMeter()
272 | top1 = AverageMeter()
273 |
274 | # Switch to train mode
275 | model.train()
276 |
277 | end = time.time()
278 | for i, (input, target) in enumerate(train_loader):
279 |
280 | # Measure data loading time
281 | data_time.update(time.time() - end)
282 |
283 | # Load batch data to cuda
284 | target = target.cuda()
285 | input_var = input.cuda()
286 | target_var = target
287 | if args.half:
288 | input_var = input_var.half()
289 |
290 | # Compute output
291 | output = model(input_var)
292 | loss = criterion(output, target_var)
293 |
294 | # Compute gradient and do SGD step
295 | optimizer.zero_grad()
296 | loss.backward()
297 | gk = get_model_grad_vec(model)
298 |
299 | if args.schedule == 'linear':
300 | adjust_learning_rate(optimizer, (1 - args.current_iters / args.total_iters) * args.lr)
301 | args.current_iters += 1
302 |
303 | if args.opt == 'SGD':
304 | P_SGD(model, optimizer, gk, center)
305 |
306 | # Measure accuracy and record loss
307 | prec1 = accuracy(output.data, target)[0]
308 | losses.update(loss.item(), input.size(0))
309 | top1.update(prec1.item(), input.size(0))
310 |
311 | # Measure elapsed time
312 | batch_time.update(time.time() - end)
313 | end = time.time()
314 |
315 | if i % args.print_freq == 0 or i == len(train_loader)-1:
316 | print('Epoch: [{0}][{1}/{2}]\t'
317 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
318 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
319 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
320 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
321 | epoch, i, len(train_loader), batch_time=batch_time,
322 | data_time=data_time, loss=losses, top1=top1))
323 |
324 | train_loss.append(losses.avg)
325 | train_acc.append(top1.avg)
326 |
327 | def P_SGD(model, optimizer, grad, center):
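328 |     # Project the gradient onto the extracted subspace: P @ grad gives its
329 |     # coordinates in the orthonormal basis, P.T maps them back, so the SGD
330 |     # step never leaves the span of the sampled checkpoints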
328 |
329 | # p = get_model_param_vec_torch(model)
330 | gk = torch.mm(P, grad.reshape(-1,1))
331 | grad_proj = torch.mm(P.transpose(0, 1), gk)
332 |
333 | update_grad(model, grad_proj.reshape(-1))
334 |
335 | optimizer.step()
336 |
337 | def validate(val_loader, model, criterion):
338 | # Run evaluation
339 |
340 | global test_acc, test_loss
341 |
342 | batch_time = AverageMeter()
343 | losses = AverageMeter()
344 | top1 = AverageMeter()
345 |
346 | # Switch to evaluate mode
347 | model.eval()
348 |
349 | end = time.time()
350 | with torch.no_grad():
351 | for i, (input, target) in enumerate(val_loader):
352 | target = target.cuda()
353 | input_var = input.cuda()
354 | target_var = target.cuda()
355 |
356 | if args.half:
357 | input_var = input_var.half()
358 |
359 | # Compute output
360 | output = model(input_var)
361 | loss = criterion(output, target_var)
362 |
363 | output = output.float()
364 | loss = loss.float()
365 |
366 | # Measure accuracy and record loss
367 | prec1 = accuracy(output.data, target)[0]
368 | losses.update(loss.item(), input.size(0))
369 | top1.update(prec1.item(), input.size(0))
370 |
371 | # Measure elapsed time
372 | batch_time.update(time.time() - end)
373 | end = time.time()
374 |
375 | if i % args.print_freq == 0:
376 | print('Test: [{0}/{1}]\t'
377 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
378 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
379 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
380 | i, len(val_loader), batch_time=batch_time, loss=losses,
381 | top1=top1))
382 |
383 | print(' * Prec@1 {top1.avg:.3f}'
384 | .format(top1=top1))
385 |
386 | # Store the test loss and test accuracy
387 | test_loss.append(losses.avg)
388 | test_acc.append(top1.avg)
389 |
390 | return top1.avg
391 |
392 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
393 | # Save the training model
394 |
395 | torch.save(state, filename)
396 |
397 | class AverageMeter(object):
398 | # Computes and stores the average and current value
399 |
400 | def __init__(self):
401 | self.reset()
402 |
403 | def reset(self):
404 | self.val = 0
405 | self.avg = 0
406 | self.sum = 0
407 | self.count = 0
408 |
409 | def update(self, val, n=1):
410 | self.val = val
411 | self.sum += val * n
412 | self.count += n
413 | self.avg = self.sum / self.count
414 |
415 |
416 | def accuracy(output, target, topk=(1,)):
417 | # Computes the precision@k for the specified values of k
418 |
419 | maxk = max(topk)
420 | batch_size = target.size(0)
421 |
422 | _, pred = output.topk(maxk, 1, True, True)
423 | pred = pred.t()
424 | correct = pred.eq(target.view(1, -1).expand_as(pred))
425 |
426 | res = []
427 | for k in topk:
428 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice may be non-contiguous
429 | res.append(correct_k.mul_(100.0 / batch_size))
430 | return res
431 |
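# Worked example (illustrative, not part of the original script): when the
# true class is ranked second by the logits, Prec@1 is 0 and Prec@2 is 100.
def _accuracy_sketch():
    output = torch.tensor([[0.1, 0.7, 0.2]])    # ranking: class 1, 2, 0
    target = torch.tensor([2])                  # true class is 2
    top1, top2 = accuracy(output, target, topk=(1, 2))
    # top1 == 0.0 and top2 == 100.0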
432 | if __name__ == '__main__':
433 | main()
--------------------------------------------------------------------------------
/train_twa_ddp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import os
4 | import sys
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim as optim
11 | import torch.utils.data
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 |
15 | import numpy as np
16 | from utils import get_imagenet_dataset, get_model, set_seed, adjust_learning_rate, bn_update, eval_model, Logger
17 |
18 | from PIL import Image, ImageFile
19 | ImageFile.LOAD_TRUNCATED_IMAGES = True
20 |
21 | ########################## parse arguments ##########################
22 | parser = argparse.ArgumentParser(description='TWA ddp')
23 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD')
24 | parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN',
25 | help='model architecture (default: VGG16BN)')
26 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
27 |                     help='The training dataset')
28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
29 | help='number of data loading workers (default: 4)')
30 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
31 | help='number of total epochs to run')
32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
33 | help='manual epoch number (useful on restarts)')
34 | parser.add_argument('-b', '--batch-size', default=128, type=int,
35 | metavar='N', help='mini-batch size (default: 128)')
36 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
37 |                     metavar='W', help='weight decay (default: 1e-5)')
38 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
39 | help='momentum')
40 | parser.add_argument('--print-freq', '-p', default=200, type=int,
41 |                     metavar='N', help='print frequency (default: 200)')
42 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
43 | help='evaluate model on validation set')
44 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
45 | help='use pre-trained model')
46 | # env
47 | parser.add_argument('--randomseed',
48 |                     help='random seed for training and initialization',
49 | type=int, default=1)
50 | parser.add_argument('--save-dir', dest='save_dir',
51 | help='The directory used to save the trained models',
52 | default='save_temp', type=str)
53 | parser.add_argument('--log-dir', dest='log_dir',
54 | help='The directory used to save the log',
55 | default='save_temp', type=str)
56 | parser.add_argument('--log-name', dest='log_name',
57 | help='The log file name',
58 | default='log', type=str)
59 | # project subspace setting
60 | parser.add_argument('--params_start', default=0, type=int, metavar='N',
61 |                     help='index of the first checkpoint used for the projection subspace')
62 | parser.add_argument('--params_end', default=51, type=int, metavar='N',
63 |                     help='index of the last checkpoint (exclusive) used for the projection subspace')
64 | parser.add_argument('--train_start', default=0, type=int, metavar='N',
65 |                     help='index of the checkpoint used as the training start point')
66 | # optimizer and scheduler
67 | parser.add_argument('--opt', metavar='OPT', help='optimization method for TWA',
68 | default='SGD', choices=['SGD'])
69 | parser.add_argument('--schedule', metavar='SCHE', help='learning rate schedule for P-SGD',
70 | default='step', choices=['step', 'constant', 'linear'])
71 | parser.add_argument('--lr', default=1, type=float, metavar='N',
72 |                     help='initial learning rate for P-SGD')
73 | # ddp
74 | parser.add_argument("--local_rank", default=-1, type=int)
75 |
76 | args = parser.parse_args()
77 | set_seed(args.randomseed)
78 |
79 | def reduce_value(value, op=dist.ReduceOp.SUM):
80 | world_size = dist.get_world_size()
81 | if world_size < 2: # single GPU
82 | return value
83 |
84 | with torch.no_grad():
85 | dist.all_reduce(value, op)
86 | return value
87 |
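# Usage sketch: all_reduce works in place, so after the call every rank holds
# the sum over ranks, e.g. for the global batch size used in train/validate:
#
#     n = torch.tensor(target.size(0)).to(device)
#     reduce_value(n)   # n now holds the batch size summed over all ranks
#
# With a single GPU the value is returned unchanged.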
88 | def get_model_param_vec_torch(model):
89 | """
90 | Return model parameters as a vector
91 | """
92 | vec = []
93 | for _, param in model.named_parameters():
94 | vec.append(param.data.detach().reshape(-1))
95 | return torch.cat(vec, 0)
96 |
97 | def get_model_grad_vec_torch(model):
98 | """
99 | Return model grad as a vector
100 | """
101 | vec = []
102 | for _, param in model.named_parameters():
103 | vec.append(param.grad.detach().reshape(-1))
104 | return torch.cat(vec, 0)
105 |
106 | def update_grad(model, grad_vec):
107 | """
108 | Update model grad
109 | """
110 | idx = 0
111 | for _, param in model.named_parameters():
112 | arr_shape = param.grad.shape
113 | size = arr_shape.numel()
114 | param.grad.data = grad_vec[idx:idx+size].reshape(arr_shape).clone()
115 | idx += size
116 |
117 | def update_param(model, param_vec):
118 | idx = 0
119 | for _, param in model.named_parameters():
120 | arr_shape = param.data.shape
121 | size = arr_shape.numel()
122 | param.data = param_vec[idx:idx+size].reshape(arr_shape).clone()
123 | idx += size
124 |
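# Round-trip sketch (illustrative, not part of the original script):
# flattening with get_model_param_vec_torch and writing back with
# update_param is lossless, which is what lets the script treat the whole
# network as a single vector when projecting onto the subspace.
def _roundtrip_sketch():
    net = nn.Linear(4, 2)
    vec = get_model_param_vec_torch(net)    # weight 4*2 + bias 2 = 10 entries
    update_param(net, vec.clone())          # write the same values back
    assert torch.equal(get_model_param_vec_torch(net), vec)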
125 | def main(args):
126 | # DDP initialize backend
127 | torch.cuda.set_device(args.local_rank)
128 | dist.init_process_group(backend='nccl')
129 | world_size = torch.distributed.get_world_size()
130 | device = torch.device("cuda", args.local_rank)
131 | dist.barrier() # Synchronizes all processes
132 |
133 | if dist.get_rank() == 0:
134 | # Check the save_dir exists or not
135 | if not os.path.exists(args.save_dir):
136 | os.makedirs(args.save_dir)
137 |
138 | # Check the log_dir exists or not
139 | if not os.path.exists(args.log_dir):
140 | os.makedirs(args.log_dir)
141 |
142 | sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
143 | print('twa-ddp')
144 | print('save dir:', args.save_dir)
145 | print('log dir:', args.log_dir)
146 |
147 | # Define model
148 | model = get_model(args).to(device)
149 | model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
150 | cudnn.benchmark = True
151 |
152 | # Define loss function (criterion) and optimizer
153 | criterion = nn.CrossEntropyLoss().to(device)
154 |
155 | optimizer = optim.SGD(model.parameters(), lr=args.lr, \
156 | momentum=args.momentum, \
157 | weight_decay=args.weight_decay)
158 | optimizer.zero_grad()
159 |
160 | if args.schedule == 'step':
161 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
162 |             milestones=[int(args.epochs*0.5), int(args.epochs*0.75+0.9)], last_epoch=args.start_epoch - 1)  # the +0.9 rounds the second milestone up
163 | elif args.schedule == 'constant' or args.schedule == 'linear':
164 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \
165 | milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)
166 |
167 | # Prepare Dataloader
168 | train_dataset, val_dataset = get_imagenet_dataset()
169 | assert args.batch_size % world_size == 0, f"Batch size {args.batch_size} cannot be divided evenly by world size {world_size}"
170 | batch_size_per_GPU = args.batch_size // world_size
171 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
172 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
173 |
174 | train_loader = torch.utils.data.DataLoader(
175 | train_dataset, batch_size=batch_size_per_GPU, sampler=train_sampler,
176 | num_workers=args.workers, pin_memory=True)
177 |
178 | val_loader = torch.utils.data.DataLoader(
179 | val_dataset, batch_size=batch_size_per_GPU, sampler=val_sampler,
180 | num_workers=args.workers)
181 |
182 | args.total_iters = len(train_loader) * args.epochs
183 | args.current_iters = 0
184 |
185 | ########################## extract subspaces ##########################
186 | # Load sampled model parameters
187 | if dist.get_rank() == 0:
188 | print('weight decay:', args.weight_decay)
189 | print('params: from', args.params_start, 'to', args.params_end)
190 | W = []
191 | for i in range(args.params_start, args.params_end):
192 |         if i % 2 == 1: continue  # use only the even-indexed checkpoints
193 | model.load_state_dict(torch.load(os.path.join(args.save_dir, f'{i}.pt')))
194 | W.append(get_model_param_vec_torch(model))
195 | W = torch.stack(W, dim=0)
196 |
197 |     # Gram-Schmidt: orthonormalize the rows of W in place
198 | P = W
199 | n_dim = P.shape[0]
200 | coeff = torch.eye(n_dim).to(device)
201 | for i in range(n_dim):
202 | if i > 0:
203 | tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1))
204 | P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1)
205 | coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1)
206 | tmp = torch.norm(P[i])
207 | P[i] /= tmp
208 | coeff[i] /= tmp
209 | coeff_inv = coeff.T.inverse()
210 |
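    # Sanity-check sketch (illustrative, not part of the original training
    # flow): after the Gram-Schmidt pass the rows of P satisfy P P^T = I,
    # which is the property the projection in project_gradient relies on.
    def _orthonormal_check_sketch():
        P_demo = torch.randn(4, 16)
        for j in range(P_demo.shape[0]):
            if j > 0:
                t = torch.mm(P_demo[:j], P_demo[j].reshape(-1, 1))
                P_demo[j] -= torch.mm(P_demo[:j].T, t).reshape(-1)
            P_demo[j] /= torch.norm(P_demo[j])
        assert torch.allclose(torch.mm(P_demo, P_demo.T), torch.eye(4),
                              atol=1e-5)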
211 | # Slice P
212 | slice_start = (n_dim//world_size)*dist.get_rank()
213 | if dist.get_rank() == world_size-1:
214 | slice_P = P[slice_start:,:].clone()
215 | else:
216 | slice_end = (n_dim//world_size)*(dist.get_rank()+1)
217 | slice_P = P[slice_start:slice_end,:].clone()
218 | if dist.get_rank() == 0:
219 | print(f'W: {W.shape} {W.device}')
220 | print(f'P: {P.shape} {P.device}')
221 | print(f'Sliced P: {slice_P.shape} {slice_P.device}')
222 | del P
223 | torch.cuda.empty_cache()
224 | dist.barrier() # Synchronizes all processes
225 |
226 | # set the start point
227 | if args.train_start >= 0:
228 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt')))
229 | if dist.get_rank() == 0:
230 | print('train start:', args.train_start)
231 |
232 | if args.evaluate:
233 | validate(val_loader, model, criterion)
234 | return
235 |
236 | if dist.get_rank() == 0:
237 | print('Train:', (args.start_epoch, args.epochs))
238 | end = time.time()
239 | his_train_acc, his_test_acc, his_train_loss, his_test_loss = [], [], [], []
240 | best_prec1 = 0
241 | for epoch in range(args.start_epoch, args.epochs):
242 | # Train for one epoch
243 | if dist.get_rank() == 0:
244 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
245 | train_loss, train_prec1 = train(train_loader, model, criterion, optimizer,
246 | args, epoch, slice_P, device, world_size)
247 | his_train_loss.append(train_loss)
248 | his_train_acc.append(train_prec1)
249 |
250 | if args.schedule != 'linear':
251 | lr_scheduler.step()
252 |
253 | # Evaluate on validation set
254 | test_loss, test_prec1 = validate(val_loader, model, criterion, device, world_size)
255 | his_test_loss.append(test_loss)
256 | his_test_acc.append(test_prec1)
257 |
258 | # Remember best prec@1 and save checkpoint
259 | best_prec1 = max(test_prec1, best_prec1)
260 | if dist.get_rank() == 0:
261 | print(f'Epoch: [{epoch}] * Best Prec@1 {best_prec1:.3f}')
262 | torch.save(model.state_dict(), os.path.join(args.save_dir, f'ddp{epoch}.pt'))
263 |
264 | if dist.get_rank() == 0:
265 | print('total time:', time.time() - end)
266 | print('train loss: ', his_train_loss)
267 | print('train acc: ', his_train_acc)
268 | print('test loss: ', his_test_loss)
269 | print('test acc: ', his_test_acc)
270 | print('best_prec1:', best_prec1)
271 |
272 |
273 | def train(train_loader, model, criterion, optimizer, args, epoch, P, device, world_size=1):
274 | # Run one train epoch
275 |
276 | batch_time = AverageMeter()
277 | data_time = AverageMeter()
278 | losses = AverageMeter()
279 | correctes = 0
280 | count = 0
281 |
282 | # Switch to train mode
283 | model.train()
284 |
285 | end = time.time()
286 | for i, (input, target) in enumerate(train_loader):
287 | # Measure data loading time
288 | data_time.update(time.time() - end)
289 |
290 | # Load batch data to cuda
291 | target = target.to(device)
292 | input = input.to(device)
293 |
294 | batch_size = torch.tensor(target.size(0)).to(device)
295 | reduce_value(batch_size)
296 | count += batch_size
297 |
298 | # Compute output
299 | output = model(input)
300 | loss = criterion(output, target)
301 |
302 | # Compute gradient and do SGD step
303 | optimizer.zero_grad()
304 | loss.backward()
305 |
306 | if args.schedule == 'linear':
307 | adjust_learning_rate(optimizer, (1 - args.current_iters / args.total_iters) * args.lr)
308 | args.current_iters += 1
309 |
310 | project_gradient(model, P)
311 | optimizer.step()
312 |
313 | # Measure accuracy and record loss
314 | _, pred = output.topk(1, 1, True, True)
315 | pred = pred.t()
316 | correct = pred.eq(target.view(1, -1).expand_as(pred))
317 |         correct_1 = correct[:1].reshape(-1).float().sum(0)
318 | reduce_value(correct_1)
319 | correctes += correct_1
320 |
321 | reduce_value(loss)
322 | loss /= world_size
323 | losses.update(loss.item(), input.size(0))
324 |
325 | # Measure elapsed time
326 | batch_time.update(time.time() - end)
327 | end = time.time()
328 |
329 | if (i % args.print_freq == 0 or i == len(train_loader)-1) and dist.get_rank() == 0:
330 | print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t'
331 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
332 | f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
333 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
334 | f'Prec@1 {correct_1/batch_size*100:.3f} ({correctes/count*100:.3f})')
335 |
336 | return losses.avg, correctes/count*100
337 |
338 | def project_gradient(model, P):
339 | grad = get_model_grad_vec_torch(model)
340 | gk = torch.mm(P, grad.reshape(-1, 1))
341 | grad_proj = torch.mm(P.transpose(0, 1), gk)
342 | reduce_value(grad_proj) # Sum-reduce projected gradients on different GPUs
343 |
344 | update_grad(model, grad_proj.reshape(-1))
345 |
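# Why one all_reduce suffices (sketch, not part of the original script): with
# P split by rows across ranks, P = [P_1; ...; P_R], the projection
# decomposes as P^T (P g) = sum_r P_r^T (P_r g), so summing the per-rank
# partial projections reconstructs the full projected gradient without ever
# materializing all of P on a single GPU.
def _sliced_projection_sketch():
    Q, _ = torch.linalg.qr(torch.randn(16, 4))
    P_full = Q.t()                              # (4, 16): orthonormal rows
    g = torch.randn(16, 1)
    full = torch.mm(P_full.t(), torch.mm(P_full, g))
    parts = sum(torch.mm(P_r.t(), torch.mm(P_r, g))
                for P_r in torch.split(P_full, 2, dim=0))
    assert torch.allclose(full, parts, atol=1e-5)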
346 | def validate(val_loader, model, criterion, device, world_size=1):
347 | # Run evaluation
348 |
349 | batch_time = AverageMeter()
350 | losses = AverageMeter()
351 | correctes = 0
352 | count = 0
353 |
354 | # Switch to evaluate mode
355 | model.eval()
356 |
357 | end = time.time()
358 | with torch.no_grad():
359 | for i, (input, target) in enumerate(val_loader):
360 | target = target.to(device)
361 | input = input.to(device)
362 |
363 | batch_size = torch.tensor(target.size(0)).to(device)
364 | reduce_value(batch_size)
365 | count += batch_size
366 |
367 | # Compute output
368 | output = model(input)
369 | loss = criterion(output, target)
370 |
371 | # Measure accuracy and record loss
372 | _, pred = output.topk(1, 1, True, True)
373 | pred = pred.t()
374 | correct = pred.eq(target.view(1, -1).expand_as(pred))
375 |             correct_1 = correct[:1].reshape(-1).float().sum(0)
376 | reduce_value(correct_1)
377 | correctes += correct_1
378 |
379 | reduce_value(loss)
380 | loss /= world_size
381 | losses.update(loss.item(), input.size(0))
382 |
383 |
384 | # Measure elapsed time
385 | batch_time.update(time.time() - end)
386 | end = time.time()
387 |
388 | if i % args.print_freq == 0 and dist.get_rank() == 0:
389 | print(f'Test: [{i}/{len(val_loader)}]\t'
390 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
391 | f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
392 | f'Prec@1 {correct_1/batch_size*100:.3f} ({correctes/count*100:.3f})')
393 |
394 | print(f' * Prec@1 {correctes/count*100:.3f}')
395 |
396 | return losses.avg, correctes/count*100
397 |
398 |
399 | class AverageMeter(object):
400 | # Computes and stores the average and current value
401 |
402 | def __init__(self):
403 | self.reset()
404 |
405 | def reset(self):
406 | self.val = 0
407 | self.avg = 0
408 | self.sum = 0
409 | self.count = 0
410 |
411 | def update(self, val, n=1):
412 | self.val = val
413 | self.sum += val * n
414 | self.count += n
415 | self.avg = self.sum / self.count
416 |
417 | def accuracy(output, target, topk=(1,)):
418 | # Computes the precision@k for the specified values of k
419 |
420 | maxk = max(topk)
421 | batch_size = target.size(0)
422 |
423 | _, pred = output.topk(maxk, 1, True, True)
424 | pred = pred.t()
425 | correct = pred.eq(target.view(1, -1).expand_as(pred))
426 |
427 | res = []
428 | for k in topk:
429 |         correct_k = correct[:k].reshape(-1).float().sum(0)
430 | res.append(correct_k.mul_(100.0 / batch_size))
431 | return res
432 |
433 | if __name__ == '__main__':
434 | main(args)
--------------------------------------------------------------------------------
/train_sgd_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 |
9 | import numpy as np
10 | import pickle
11 |
12 | from PIL import Image, ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.parallel
18 | import torch.backends.cudnn as cudnn
19 | import torch.distributed as dist
20 | import torch.optim
21 | import torch.multiprocessing as mp
22 | import torch.utils.data
23 | import torch.utils.data.distributed
24 | import torchvision.transforms as transforms
25 | import torchvision.datasets as datasets
26 | import torchvision.models as models
27 |
28 | model_names = sorted(name for name in models.__dict__
29 | if name.islower() and not name.startswith("__")
30 | and callable(models.__dict__[name]))
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
33 | parser.add_argument('data', metavar='DIR',
34 | help='path to dataset')
35 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
36 | choices=model_names,
37 | help='model architecture: ' +
38 | ' | '.join(model_names) +
39 | ' (default: resnet18)')
40 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
41 | help='number of data loading workers (default: 4)')
42 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
43 | help='number of total epochs to run')
44 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
45 | help='manual epoch number (useful on restarts)')
46 | parser.add_argument('-b', '--batch-size', default=256, type=int,
47 | metavar='N',
48 | help='mini-batch size (default: 256), this is the total '
49 | 'batch size of all GPUs on the current node when '
50 | 'using Data Parallel or Distributed Data Parallel')
51 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
52 | metavar='LR', help='initial learning rate', dest='lr')
53 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
54 | help='momentum')
55 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
56 | metavar='W', help='weight decay (default: 1e-4)',
57 | dest='weight_decay')
58 | parser.add_argument('-p', '--print-freq', default=1000, type=int,
59 |                     metavar='N', help='print frequency (default: 1000)')
60 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
61 | help='path to latest checkpoint (default: none)')
62 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
63 | help='evaluate model on validation set')
64 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
65 | help='use pre-trained model')
66 | parser.add_argument('--world-size', default=-1, type=int,
67 | help='number of nodes for distributed training')
68 | parser.add_argument('--rank', default=-1, type=int,
69 | help='node rank for distributed training')
70 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
71 | help='url used to set up distributed training')
72 | parser.add_argument('--dist-backend', default='nccl', type=str,
73 | help='distributed backend')
74 | parser.add_argument('--seed', default=None, type=int,
75 | help='seed for initializing training. ')
76 | parser.add_argument('--gpu', default=None, type=int,
77 | help='GPU id to use.')
78 | parser.add_argument('--multiprocessing-distributed', action='store_true',
79 | help='Use multi-processing distributed training to launch '
80 | 'N processes per node, which has N GPUs. This is the '
81 | 'fastest way to use PyTorch for either single node or '
82 | 'multi node data parallel training')
83 |
84 | best_acc1 = 0
85 |
86 |
87 | param_vec = []
88 | # Record training statistics
89 | train_loss = []
90 | train_acc = []
91 | test_loss = []
92 | test_acc = []
93 | arr_time = []
94 |
95 | iters = 0
96 | def get_model_param_vec(model):
97 | # Return the model parameters as a vector
98 |
99 | vec = []
100 | for name,param in model.named_parameters():
101 | vec.append(param.detach().cpu().reshape(-1).numpy())
102 | return np.concatenate(vec, 0)
103 |
104 | def main():
105 | global train_loss, train_acc, test_loss, test_acc, arr_time
106 |
107 | args = parser.parse_args()
108 |
109 |
110 | save_dir = 'save_' + args.arch
111 | if not os.path.exists(save_dir):
112 | os.makedirs(save_dir)
113 |
114 | if args.seed is not None:
115 | random.seed(args.seed)
116 | torch.manual_seed(args.seed)
117 | cudnn.deterministic = True
118 | warnings.warn('You have chosen to seed training. '
119 | 'This will turn on the CUDNN deterministic setting, '
120 | 'which can slow down your training considerably! '
121 | 'You may see unexpected behavior when restarting '
122 | 'from checkpoints.')
123 |
124 | if args.gpu is not None:
125 | warnings.warn('You have chosen a specific GPU. This will completely '
126 | 'disable data parallelism.')
127 |
128 | if args.dist_url == "env://" and args.world_size == -1:
129 | args.world_size = int(os.environ["WORLD_SIZE"])
130 |
131 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
132 |
133 | ngpus_per_node = torch.cuda.device_count()
134 | if args.multiprocessing_distributed:
135 | # Since we have ngpus_per_node processes per node, the total world_size
136 | # needs to be adjusted accordingly
137 | args.world_size = ngpus_per_node * args.world_size
138 | # Use torch.multiprocessing.spawn to launch distributed processes: the
139 | # main_worker process function
140 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
141 | else:
142 | # Simply call main_worker function
143 | main_worker(args.gpu, ngpus_per_node, args)
144 |
145 | sample_idx = 0
146 |
147 | def main_worker(gpu, ngpus_per_node, args):
148 | global train_loss, train_acc, test_loss, test_acc, arr_time
149 | global best_acc1, param_vec, sample_idx
150 | args.gpu = gpu
151 |
152 | if args.gpu is not None:
153 | print("Use GPU: {} for training".format(args.gpu))
154 |
155 | if args.distributed:
156 | if args.dist_url == "env://" and args.rank == -1:
157 | args.rank = int(os.environ["RANK"])
158 | if args.multiprocessing_distributed:
159 | # For multiprocessing distributed training, rank needs to be the
160 | # global rank among all the processes
161 | args.rank = args.rank * ngpus_per_node + gpu
162 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
163 | world_size=args.world_size, rank=args.rank)
164 | # create model
165 | if args.pretrained:
166 | print("=> using pre-trained model '{}'".format(args.arch))
167 | model = models.__dict__[args.arch](pretrained=True)
168 | else:
169 | print("=> creating model '{}'".format(args.arch))
170 | model = models.__dict__[args.arch]()
171 |
172 | if not torch.cuda.is_available():
173 | print('using CPU, this will be slow')
174 | elif args.distributed:
175 | # For multiprocessing distributed, DistributedDataParallel constructor
176 | # should always set the single device scope, otherwise,
177 | # DistributedDataParallel will use all available devices.
178 | if args.gpu is not None:
179 | torch.cuda.set_device(args.gpu)
180 | model.cuda(args.gpu)
181 | # When using a single GPU per process and per
182 | # DistributedDataParallel, we need to divide the batch size
183 | # ourselves based on the total number of GPUs we have
184 | args.batch_size = int(args.batch_size / ngpus_per_node)
185 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
186 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
187 | else:
188 | model.cuda()
189 | # DistributedDataParallel will divide and allocate batch_size to all
190 | # available GPUs if device_ids are not set
191 | model = torch.nn.parallel.DistributedDataParallel(model)
192 | elif args.gpu is not None:
193 | torch.cuda.set_device(args.gpu)
194 | model = model.cuda(args.gpu)
195 | else:
196 | # DataParallel will divide and allocate batch_size to all available GPUs
197 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
198 | model.features = torch.nn.DataParallel(model.features)
199 | model.cuda()
200 | else:
201 | model = torch.nn.DataParallel(model).cuda()
202 |
203 | # define loss function (criterion) and optimizer
204 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
205 |
206 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
207 | momentum=args.momentum,
208 | weight_decay=args.weight_decay)
209 |
210 | # optionally resume from a checkpoint
211 | if args.resume:
212 | if os.path.isfile(args.resume):
213 | print("=> loading checkpoint '{}'".format(args.resume))
214 | if args.gpu is None:
215 | checkpoint = torch.load(args.resume)
216 | else:
217 | # Map model to be loaded to specified single gpu.
218 | loc = 'cuda:{}'.format(args.gpu)
219 | checkpoint = torch.load(args.resume, map_location=loc)
220 | args.start_epoch = checkpoint['epoch']
221 | best_acc1 = checkpoint['best_acc1']
222 | if args.gpu is not None:
223 | # best_acc1 may be from a checkpoint from a different GPU
224 | best_acc1 = best_acc1.to(args.gpu)
225 | model.load_state_dict(checkpoint['state_dict'])
226 | optimizer.load_state_dict(checkpoint['optimizer'])
227 | print("=> loaded checkpoint '{}' (epoch {})"
228 | .format(args.resume, checkpoint['epoch']))
229 | else:
230 | print("=> no checkpoint found at '{}'".format(args.resume))
231 |
232 | cudnn.benchmark = True
233 |
234 | # Data loading code
235 | traindir = os.path.join(args.data, 'train')
236 | valdir = os.path.join(args.data, 'val')
237 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
238 | std=[0.229, 0.224, 0.225])
239 |
240 | train_dataset = datasets.ImageFolder(
241 | traindir,
242 | transforms.Compose([
243 | transforms.RandomResizedCrop(224),
244 | transforms.RandomHorizontalFlip(),
245 | transforms.ToTensor(),
246 | normalize,
247 | ]))
248 |
249 | if args.distributed:
250 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
251 | else:
252 | train_sampler = None
253 |
254 | train_loader = torch.utils.data.DataLoader(
255 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
256 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
257 |
258 | val_loader = torch.utils.data.DataLoader(
259 | datasets.ImageFolder(valdir, transforms.Compose([
260 | transforms.Resize(256),
261 | transforms.CenterCrop(224),
262 | transforms.ToTensor(),
263 | normalize,
264 | ])),
265 | batch_size=args.batch_size, shuffle=False,
266 | num_workers=args.workers, pin_memory=True)
267 |
268 | if args.evaluate:
269 | validate(val_loader, model, criterion, args)
270 | return
271 |
272 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
273 | torch.save(model.state_dict(), 'save_' + args.arch + '/' + str(sample_idx)+'.pt')
274 |
275 |
276 | for epoch in range(args.start_epoch, args.epochs):
277 | if args.distributed:
278 | train_sampler.set_epoch(epoch)
279 | adjust_learning_rate(optimizer, epoch, args)
280 |
281 | # train for one epoch
282 | train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node)
283 |
284 | # evaluate on validation set
285 | acc1 = validate(val_loader, model, criterion, args)
286 |
287 | # remember best acc@1 and save checkpoint
288 | is_best = acc1 > best_acc1
289 | best_acc1 = max(acc1, best_acc1)
290 |
291 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
292 | and args.rank % ngpus_per_node == 0):
293 | save_checkpoint({
294 | 'epoch': epoch + 1,
295 | 'arch': args.arch,
296 | 'state_dict': model.state_dict(),
297 | 'best_acc1': best_acc1,
298 | 'optimizer' : optimizer.state_dict(),
299 | }, is_best)
300 |
301 |     print('train loss: ', train_loss)
302 |     print('train acc: ', train_acc)
303 |     print('test loss: ', test_loss)
304 |     print('test acc: ', test_acc)
305 |
306 |     print('time: ', arr_time)
307 |
308 |
309 | def train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node):
310 | global iters, param_vec, sample_idx
311 | global train_loss, train_acc, test_loss, test_acc, arr_time
312 |
313 | batch_time = AverageMeter('Time', ':6.3f')
314 | data_time = AverageMeter('Data', ':6.3f')
315 | losses = AverageMeter('Loss', ':.4e')
316 | top1 = AverageMeter('Acc@1', ':6.2f')
317 | top5 = AverageMeter('Acc@5', ':6.2f')
318 | progress = ProgressMeter(
319 | len(train_loader),
320 | [batch_time, data_time, losses, top1, top5],
321 | prefix="Epoch: [{}]".format(epoch))
322 |
323 | # switch to train mode
324 | model.train()
325 |
326 | end = time.time()
327 | epoch_start = end
328 | for i, (images, target) in enumerate(train_loader):
329 | # measure data loading time
330 | data_time.update(time.time() - end)
331 |
332 | if args.gpu is not None:
333 | images = images.cuda(args.gpu, non_blocking=True)
334 | if torch.cuda.is_available():
335 | target = target.cuda(args.gpu, non_blocking=True)
336 |
337 | # compute output
338 | output = model(images)
339 | loss = criterion(output, target)
340 |
341 | # measure accuracy and record loss
342 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
343 | losses.update(loss.item(), images.size(0))
344 | top1.update(acc1[0], images.size(0))
345 | top5.update(acc5[0], images.size(0))
346 |
347 | # compute gradient and do SGD step
348 | optimizer.zero_grad()
349 | loss.backward()
350 | optimizer.step()
351 |
352 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
353 | and args.rank % ngpus_per_node == 0):
354 |
355 | if i % args.print_freq == 0:
356 | progress.display(i)
357 |
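            # Snapshot the weights a few times early in training as well;
            # the numbered .pt files written here are the iterates that the
            # TWA scripts later load, stack into W, and orthogonalize.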
358 | if i > 0 and i % 1000 == 0 and i < 5000:
359 | sample_idx += 1
360 | torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')
361 |
362 | # measure elapsed time
363 | batch_time.update(time.time() - end)
364 | end = time.time()
365 |
366 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
367 | and args.rank % ngpus_per_node == 0):
368 | sample_idx += 1
369 | torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')
370 |
371 | arr_time.append(time.time() - epoch_start)
372 | train_loss.append(losses.avg)
373 | train_acc.append(top1.avg)
374 |
375 |
376 | def validate(val_loader, model, criterion, args):
377 | global train_loss, train_acc, test_loss, test_acc, arr_time
378 | batch_time = AverageMeter('Time', ':6.3f')
379 | losses = AverageMeter('Loss', ':.4e')
380 | top1 = AverageMeter('Acc@1', ':6.2f')
381 | top5 = AverageMeter('Acc@5', ':6.2f')
382 | progress = ProgressMeter(
383 | len(val_loader),
384 | [batch_time, losses, top1, top5],
385 | prefix='Test: ')
386 |
387 | # switch to evaluate mode
388 | model.eval()
389 |
390 | with torch.no_grad():
391 | end = time.time()
392 | for i, (images, target) in enumerate(val_loader):
393 | if args.gpu is not None:
394 | images = images.cuda(args.gpu, non_blocking=True)
395 | if torch.cuda.is_available():
396 | target = target.cuda(args.gpu, non_blocking=True)
397 |
398 | # compute output
399 | output = model(images)
400 | loss = criterion(output, target)
401 |
402 | # measure accuracy and record loss
403 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
404 | losses.update(loss.item(), images.size(0))
405 | top1.update(acc1[0], images.size(0))
406 | top5.update(acc5[0], images.size(0))
407 |
408 | # measure elapsed time
409 | batch_time.update(time.time() - end)
410 | end = time.time()
411 |
412 | if i % args.print_freq == 0:
413 | progress.display(i)
414 |
415 | # TODO: this should also be done with the ProgressMeter
416 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
417 | .format(top1=top1, top5=top5))
418 | test_acc.append(top1.avg)
419 | test_loss.append(losses.avg)
420 | return top1.avg
421 |
422 |
423 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
424 | torch.save(state, filename)
425 | if is_best:
426 | shutil.copyfile(filename, 'model_best.pth.tar')
427 |
428 |
429 | class AverageMeter(object):
430 | """Computes and stores the average and current value"""
431 | def __init__(self, name, fmt=':f'):
432 | self.name = name
433 | self.fmt = fmt
434 | self.reset()
435 |
436 | def reset(self):
437 | self.val = 0
438 | self.avg = 0
439 | self.sum = 0
440 | self.count = 0
441 |
442 | def update(self, val, n=1):
443 | self.val = val
444 | self.sum += val * n
445 | self.count += n
446 | self.avg = self.sum / self.count
447 |
448 | def __str__(self):
449 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
450 | return fmtstr.format(**self.__dict__)
451 |
452 |
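# Usage sketch (illustrative, not part of the original script): the meter
# keeps a batch-size-weighted running sum, so uneven final batches are
# averaged correctly.
def _meter_sketch():
    m = AverageMeter('Loss', ':.4f')
    m.update(0.5, n=128)    # batch of 128 with mean loss 0.5
    m.update(1.0, n=64)     # smaller final batch
    # m.avg == (0.5 * 128 + 1.0 * 64) / 192 == 0.666...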
453 | class ProgressMeter(object):
454 | def __init__(self, num_batches, meters, prefix=""):
455 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
456 | self.meters = meters
457 | self.prefix = prefix
458 |
459 | def display(self, batch):
460 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
461 | entries += [str(meter) for meter in self.meters]
462 | print('\t'.join(entries))
463 |
464 | def _get_batch_fmtstr(self, num_batches):
465 |         num_digits = len(str(num_batches))
466 | fmt = '{:' + str(num_digits) + 'd}'
467 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
468 |
469 |
470 | def adjust_learning_rate(optimizer, epoch, args):
471 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
472 | lr = args.lr * (0.1 ** (epoch // 30))
473 | # lr = 0.1
474 | for param_group in optimizer.param_groups:
475 | param_group['lr'] = lr
476 |
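# Worked values for this schedule with the default --lr 0.1:
# epochs 0-29 train at 0.1, epochs 30-59 at 0.01, epochs 60-89 at 0.001.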
477 |
478 | def accuracy(output, target, topk=(1,)):
479 | """Computes the accuracy over the k top predictions for the specified values of k"""
480 | with torch.no_grad():
481 | maxk = max(topk)
482 | batch_size = target.size(0)
483 |
484 | _, pred = output.topk(maxk, 1, True, True)
485 | pred = pred.t()
486 | correct = pred.eq(target.view(1, -1).expand_as(pred))
487 |
488 | res = []
489 | for k in topk:
490 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
491 | res.append(correct_k.mul_(100.0 / batch_size))
492 | return res
493 |
494 |
495 | if __name__ == '__main__':
496 | main()
--------------------------------------------------------------------------------