├── twa.png ├── swa ├── models │ ├── __init__.py │ ├── vgg.py │ └── preresnet.py ├── readme.md ├── run.sh ├── run_twa.sh ├── utils_swa.py ├── utils.py ├── train.py └── train_twa.py ├── requirements.txt ├── .gitattributes ├── models ├── __pycache__ │ ├── vgg.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── resnet.cpython-37.pyc │ ├── preresnet.cpython-37.pyc │ └── wide_resnet.cpython-37.pyc ├── __init__.py ├── vgg.py ├── preresnet.py ├── resnet.py └── wide_resnet.py ├── LICENSE ├── run.sh ├── README.md ├── utils.py ├── train_sgd_cifar.py ├── train_twa.py ├── train_twa_ddp.py └── train_sgd_imagenet.py /twa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/twa.png -------------------------------------------------------------------------------- /swa/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .preresnet import * 2 | from .vgg import * 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | torchvision>=0.6 3 | numpy>=1.21 4 | wandb==0.12.7 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/vgg.cpython-37.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .preresnet 
import * 2 | from .resnet import * 3 | from .vgg import * 4 | from .wide_resnet import * -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/preresnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/preresnet.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/wide_resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nblt/TWA/HEAD/models/__pycache__/wide_resnet.cpython-37.pyc -------------------------------------------------------------------------------- /swa/readme.md: -------------------------------------------------------------------------------- 1 | ### TWA in tail stage training 2 | We show that TWA could improve the performance of SWA in the original SWA setting, where the improvements are more significant when the tail learning rate `swa_lr` is larger. 3 | 4 | First, run SWA using original [code](https://github.com/timgaripov/swa): 5 | ``` 6 | bash run.sh 7 | ``` 8 | Then, we could perform TWA using: 9 | ``` 10 | bash run_twa.sh 11 | ``` 12 | The training configuration is easy to set as you need in the scripts. 
-------------------------------------------------------------------------------- /swa/run.sh: -------------------------------------------------------------------------------- 1 | device=0 2 | data_dir=../datasets/ 3 | 4 | ############################### VGG16 ################################### 5 | dataset=CIFAR100 6 | model=VGG16BN 7 | seed=0 8 | swa_lr=0.05 9 | dir=swa_$model\_$dataset\_$seed\_$swa_lr 10 | CUDA_VISIBLE_DEVICES=$device python3 train.py --dir=$dir --dataset=$dataset --data_path=$data_dir \ 11 | --model=$model --epochs=300 --lr_init=0.1 --wd=5e-4 --seed $seed \ 12 | --swa --swa_start=161 --swa_lr=$swa_lr |& tee -a $dir/log # SWA 1.5 Budgets 13 | 14 | 15 | ############################### PreResNet ################################### 16 | dataset=CIFAR100 # CIFAR10 CIFAR100 17 | model=PreResNet164 18 | seed=0 19 | swa_lr=0.05 20 | 21 | dir=swa_$model\_$dataset\_$seed\_$swa_lr 22 | CUDA_VISIBLE_DEVICES=$device python3 train.py --dir=$dir --seed $seed \
23 | --dataset=$dataset --data_path=$data_dir --model=PreResNet164 --epochs=225 \ 24 | --lr_init=0.1 --wd=3e-4 --swa --swa_start=126 --swa_lr=$swa_lr |& tee -a $dir/log # SWA 1.5 Budgets 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 nblt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /swa/run_twa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CIFAR experiments 4 | 5 | datasets=CIFAR100 6 | device=0 7 | 8 | ####################################### swa ######################################## 9 | 10 | seed=0 11 | swa_lr=0.05 # 0.05 / 0.10 12 | model=PreResNet164 13 | wd_psgd=0.00005 14 | DST=swa_$model\_$datasets\_$seed\_$swa_lr 15 | 16 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \ 17 | --opt SGD --extract Schmidt --schedule step --accumulate 1 \ 18 | --lr 2 --params_start 126 --params_end 226 --train_start 225 --wd $wd_psgd \ 19 | --batch-size 128 --arch=$model \ 20 | --save-dir=$DST/checkpoints --log-dir=$DST --log-name=from_last 21 | 22 | seed=0 23 | swa_lr=0.05 # 0.05 / 0.10 24 | model=VGG16BN 25 | wd_psgd=0.00005 26 | DST=swa_$model\_$datasets\_$seed\_$swa_lr 27 | 28 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \ 29 | --opt SGD --extract Schmidt --schedule step --accumulate 1 \ 30 | --lr 2 --params_start 161 --params_end 301 --train_start 300 --wd $wd_psgd \ 31 | --batch-size 128 --arch=$model \ 32 | --save-dir=$DST/checkpoints --log-dir=$DST --log-name=from_last -------------------------------------------------------------------------------- /run.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################ CIFAR ################################### 4 | datasets=CIFAR100 5 | device=0 6 | model=VGG16BN # PreResNet164 7 | DST=results/$model\_$datasets 8 | 9 | CUDA_VISIBLE_DEVICES=$device python -u train_sgd_cifar.py --datasets $datasets \ 10 | --arch=$model --epochs=200 --lr 0.1 \ 11 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100 12 | 13 | lr=2 14 | end=101 15 | wd_psgd=0.00001 16 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 10 --datasets $datasets \ 17 | --opt SGD --extract Schmidt --schedule step \ 18 | --lr $lr --params_start 0 --params_end $end --train_start -1 --wd $wd_psgd \ 19 | --batch-size 128 --arch=$model \ 20 | --save-dir=$DST/checkpoints --log-dir=$DST 21 | 22 | 23 | ################################ ImageNet ################################ 24 | datasets=ImageNet 25 | device=0,1,2,3 26 | 27 | model=resnet18 28 | path=/home/datasets/ILSVRC2012/ 29 | CUDA_VISIBLE_DEVICES=$device python3 train_sgd_imagenet.py -a $model \ 30 | --epochs 90 --workers 8 --dist-url 'tcp://127.0.0.1:1234' \ 31 | --dist-backend 'nccl' --multiprocessing-distributed \ 32 | --world-size 1 --rank 0 $path 33 | 34 | # TWA 60+2 35 | wd_psgd=0.00001 36 | lr=0.3 37 | DST=save_resnet18 38 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 2 --datasets $datasets \ 39 | --opt SGD --extract Schmidt --schedule step --worker 8 \ 40 | --lr $lr --params_start 0 --params_end 301 --train_start -1 --wd $wd_psgd \ 41 | --batch-size 256 --arch=$model \ 42 | --save-dir=$DST --log-dir=$DST 43 | 44 | # TWA (DDP version) 60+2 45 | datasets=ImageNet 46 | device=0,1,2,3 47 | 48 | model=resnet18 49 | wd_psgd=0.00001 50 | lr=0.3 51 | DST=save_resnet18 52 | CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node 4 train_twa_ddp.py \ 53 | --epochs 2 --datasets $datasets --opt SGD --schedule step --worker 8 \ 54 | --lr $lr 
--params_start 0 --params_end 301 --train_start -1 --wd $wd_psgd \ 55 | --batch-size 256 --arch $model --save-dir $DST --log-dir $DST 56 | 57 | # TWA 90+1 58 | wd_psgd=0.00001 59 | lr=0.03 60 | DST=save_resnet18 61 | CUDA_VISIBLE_DEVICES=$device python -u train_twa.py --epochs 1 --datasets $datasets \ 62 | --opt SGD --extract Schmidt --schedule linear --worker 8 \ 63 | --lr $lr --params_start 301 --params_end 451 --train_start -1 --wd $wd_psgd \ 64 | --batch-size 256 --arch=$model \ 65 | --save-dir=$DST --log-dir=$DST 66 | 67 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG model definition 3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 4 | """ 5 | 6 | import math 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | 10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN'] 11 | 12 | 13 | def make_layers(cfg, batch_norm=False): 14 | layers = list() 15 | in_channels = 3 16 | for v in cfg: 17 | if v == 'M': 18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 19 | else: 20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 21 | if batch_norm: 22 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 23 | else: 24 | layers += [conv2d, nn.ReLU(inplace=True)] 25 | in_channels = v 26 | return nn.Sequential(*layers) 27 | 28 | 29 | cfg = { 30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 32 | 512, 512, 512, 512, 'M'], 33 | } 34 | 35 | 36 | class VGG(nn.Module): 37 | def __init__(self, num_classes=10, depth=16, batch_norm=False): 38 | super(VGG, self).__init__() 39 | self.features = make_layers(cfg[depth], batch_norm) 40 | self.classifier = nn.Sequential( 41 | nn.Dropout(), 42 | nn.Linear(512, 512), 43 | 
nn.ReLU(True), 44 | nn.Dropout(), 45 | nn.Linear(512, 512), 46 | nn.ReLU(True), 47 | nn.Linear(512, num_classes), 48 | ) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | m.bias.data.zero_() 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.classifier(x) 60 | return x 61 | 62 | 63 | class Base: 64 | base = VGG 65 | args = list() 66 | kwargs = dict() 67 | transform_train = transforms.Compose([ 68 | transforms.RandomHorizontalFlip(), 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | ]) 73 | 74 | transform_test = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 77 | ]) 78 | 79 | 80 | class VGG16(Base): 81 | pass 82 | 83 | 84 | class VGG16BN(Base): 85 | kwargs = {'batch_norm': True} 86 | 87 | 88 | class VGG19(Base): 89 | kwargs = {'depth': 19} 90 | 91 | 92 | class VGG19BN(Base): 93 | kwargs = {'depth': 19, 'batch_norm': True} -------------------------------------------------------------------------------- /swa/models/vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG model definition 3 | ported from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 4 | """ 5 | 6 | import math 7 | import torch.nn as nn 8 | import torchvision.transforms as transforms 9 | 10 | __all__ = ['VGG16', 'VGG16BN', 'VGG19', 'VGG19BN'] 11 | 12 | 13 | def make_layers(cfg, batch_norm=False): 14 | layers = list() 15 | in_channels = 3 16 | for v in cfg: 17 | if v == 'M': 18 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 19 | else: 20 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 21 | if batch_norm: 22 | layers += [conv2d, nn.BatchNorm2d(v), 
nn.ReLU(inplace=True)] 23 | else: 24 | layers += [conv2d, nn.ReLU(inplace=True)] 25 | in_channels = v 26 | return nn.Sequential(*layers) 27 | 28 | 29 | cfg = { 30 | 16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 31 | 19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 32 | 512, 512, 512, 512, 'M'], 33 | } 34 | 35 | 36 | class VGG(nn.Module): 37 | def __init__(self, num_classes=10, depth=16, batch_norm=False): 38 | super(VGG, self).__init__() 39 | self.features = make_layers(cfg[depth], batch_norm) 40 | self.classifier = nn.Sequential( 41 | nn.Dropout(), 42 | nn.Linear(512, 512), 43 | nn.ReLU(True), 44 | nn.Dropout(), 45 | nn.Linear(512, 512), 46 | nn.ReLU(True), 47 | nn.Linear(512, num_classes), 48 | ) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 53 | m.weight.data.normal_(0, math.sqrt(2. / n)) 54 | m.bias.data.zero_() 55 | 56 | def forward(self, x): 57 | x = self.features(x) 58 | x = x.view(x.size(0), -1) 59 | x = self.classifier(x) 60 | return x 61 | 62 | 63 | class Base: 64 | base = VGG 65 | args = list() 66 | kwargs = dict() 67 | transform_train = transforms.Compose([ 68 | transforms.RandomHorizontalFlip(), 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.ToTensor(), 71 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 72 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 73 | ]) 74 | 75 | transform_test = transforms.Compose([ 76 | transforms.ToTensor(), 77 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 78 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 79 | ]) 80 | 81 | 82 | class VGG16(Base): 83 | pass 84 | 85 | 86 | class VGG16BN(Base): 87 | kwargs = {'batch_norm': True} 88 | 89 | 90 | class VGG19(Base): 91 | kwargs = {'depth': 19} 92 | 93 | 94 | class VGG19BN(Base): 95 | kwargs = {'depth': 
19, 'batch_norm': True} 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TWA 2 | The code is the official implementation of our ICLR paper 3 | [Trainable Weight Averaging: Efficient Training by Optimizing Historical Solutions](https://openreview.net/pdf?id=8wbnpOJY-f). For the journal version, please refer to this [branch](https://github.com/nblt/TWA/tree/journal). 4 | 5 | We propose to conduct neural network training in a tiny subspace spanned by historical solutions. Such optimization is equivalent to performing weight averaging on these solutions with trainable coefficients (TWA), in contrast with the equal averaging coefficients as in [SWA](https://github.com/timgaripov/swa). We show that a good solution can emerge early in DNN's training by properly averaging historical solutions with TWA. In this way, we are able to achieve great training efficiency (e.g. saving over **30%** training epochs on CIFAR / ImageNet) by optimizing these historical solutions. We also provide an efficient and scalable framework for multi-node training. Besides, TWA is also able to improve finetune results from multiple training configurations, which we are currently focusing on. This [colab](https://colab.research.google.com/drive/1fxUJ0K8dd7V3gsozmKsHhfdYHhYVB-WZ?usp=sharing) provides an exploratory example we adapt from [Model Soups](https://github.com/mlfoundations/model-soups). 6 | 7 | 8 |
9 | 10 | 11 |
12 | 13 | 14 | ## Dependencies 15 | 16 | Install required dependencies: 17 | 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## How to run 23 | 24 | ### TWA in tail stage training 25 | We first show that TWA could improve the performance of SWA in the original SWA setting, where the improvements are more significant when the tail learning rate `swa_lr` is larger. 26 | ``` 27 | cd swa 28 | ``` 29 | First, run SWA using original [code](https://github.com/timgaripov/swa): 30 | ``` 31 | bash run.sh 32 | ``` 33 | Then, we could perform TWA using: 34 | ``` 35 | bash run_twa.sh 36 | ``` 37 | The training configuration is easy to set as you need in the scripts. 38 | 39 | ### TWA in head stage training 40 | In this part, we conduct TWA in the head training stage, where we achieve considerably **30%-40%** epochs saving on CIFAR-10/100 and ImageNet, with a comparable or even better performance against regular training. 41 | We show sample usages in `run.sh`. 42 | 43 | For the first step, we conduct regular training for generating the historical solutions (for ImageNet training, the dataset needs to be prepared at folder `path`).
For example, 44 | 45 | ``` 46 | datasets=CIFAR100 47 | model=VGG16BN 48 | DST=results/$model\_$datasets 49 | 50 | CUDA_VISIBLE_DEVICES=0 python -u train_sgd_cifar.py --datasets $datasets \ 51 | --arch=$model --epochs=200 --lr 0.1 \ 52 | --save-dir=$DST/checkpoints --log-dir=$DST -p 100 53 | ``` 54 | Then, we conduct TWA training for quickly composing a good solution utilizing historical solutions (note that here we only utilize the first 100 epoch checkpoints): 55 | ``` 56 | CUDA_VISIBLE_DEVICES=0 python -u train_twa.py --epochs 10 --datasets $datasets \ 57 | --opt SGD --extract Schmidt --schedule step \ 58 | --lr 2 --params_start 0 --params_end 101 --train_start -1 --wd 0.00001 \ 59 | --batch-size 128 --arch=$model \ 60 | --save-dir=$DST/checkpoints --log-dir=$DST 61 | ``` 62 | 63 | ## Citation 64 | If you find this work helpful, please cite: 65 | ``` 66 | @inproceedings{ 67 | li2023trainable, 68 | title={Trainable Weight Averaging: Efficient Training by Optimizing Historical Solutions}, 69 | author={Tao Li and Zhehao Huang and Qinghua Tao and Yingwen Wu and Xiaolin Huang}, 70 | booktitle={The Eleventh International Conference on Learning Representations}, 71 | year={2023}, 72 | url={https://openreview.net/forum?id=8wbnpOJY-f} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /swa/utils_swa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | num_classes_dict = { 5 | "CIFAR10":10, 6 | "CIFAR100":100, 7 | } 8 | 9 | 10 | def adjust_learning_rate(optimizer, lr): 11 | for param_group in optimizer.param_groups: 12 | param_group['lr'] = lr 13 | return lr 14 | 15 | 16 | def save_checkpoint(dir, epoch, **kwargs): 17 | state = { 18 | 'epoch': epoch, 19 | } 20 | state.update(kwargs) 21 | filepath = os.path.join(dir, 'checkpoint-%d.pt' % epoch) 22 | torch.save(state, filepath) 23 | 24 | 25 | def train_epoch(loader, model, criterion, optimizer): 26 | 
loss_sum = 0.0 27 | correct = 0.0 28 | 29 | model.train() 30 | 31 | for i, (input, target) in enumerate(loader): 32 | # input = input.cuda(async=True) 33 | # target = target.cuda(async=True) 34 | input = input.cuda() 35 | target = target.cuda() 36 | 37 | input_var = torch.autograd.Variable(input) 38 | target_var = torch.autograd.Variable(target) 39 | 40 | output = model(input_var) 41 | loss = criterion(output, target_var) 42 | 43 | optimizer.zero_grad() 44 | loss.backward() 45 | optimizer.step() 46 | 47 | loss_sum += loss.item() * input.size(0) 48 | pred = output.data.max(1, keepdim=True)[1] 49 | correct += pred.eq(target_var.data.view_as(pred)).sum().item() 50 | 51 | return { 52 | 'loss': loss_sum / len(loader.dataset), 53 | 'accuracy': correct / len(loader.dataset) * 100.0, 54 | } 55 | 56 | 57 | def eval(loader, model, criterion): 58 | loss_sum = 0.0 59 | correct = 0.0 60 | 61 | model.eval() 62 | 63 | for i, (input, target) in enumerate(loader): 64 | input = input.cuda() 65 | target = target.cuda() 66 | input_var = torch.autograd.Variable(input) 67 | target_var = torch.autograd.Variable(target) 68 | 69 | output = model(input_var) 70 | loss = criterion(output, target_var) 71 | 72 | loss_sum += loss.item() * input.size(0) 73 | pred = output.data.max(1, keepdim=True)[1] 74 | correct += pred.eq(target_var.data.view_as(pred)).sum().item() 75 | 76 | return { 77 | 'loss': loss_sum / len(loader.dataset), 78 | 'accuracy': correct / len(loader.dataset) * 100.0, 79 | } 80 | 81 | 82 | def moving_average(net1, net2, alpha=1): 83 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 84 | param1.data *= (1.0 - alpha) 85 | param1.data += param2.data * alpha 86 | 87 | 88 | def _check_bn(module, flag): 89 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 90 | flag[0] = True 91 | 92 | 93 | def check_bn(model): 94 | flag = [False] 95 | model.apply(lambda module: _check_bn(module, flag)) 96 | return flag[0] 97 | 98 | 99 | def reset_bn(module): 100 
| if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 101 | module.running_mean = torch.zeros_like(module.running_mean) 102 | module.running_var = torch.ones_like(module.running_var) 103 | 104 | 105 | def _get_momenta(module, momenta): 106 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 107 | momenta[module] = module.momentum 108 | 109 | 110 | def _set_momenta(module, momenta): 111 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 112 | module.momentum = momenta[module] 113 | 114 | 115 | def bn_update(loader, model): 116 | """ 117 | BatchNorm buffers update (if any). 118 | Performs 1 epochs to estimate buffers average using train dataset. 119 | 120 | :param loader: train dataset loader for buffers average estimation. 121 | :param model: model being update 122 | :return: None 123 | """ 124 | if not check_bn(model): 125 | return 126 | model.train() 127 | momenta = {} 128 | model.apply(reset_bn) 129 | model.apply(lambda module: _get_momenta(module, momenta)) 130 | n = 0 131 | for input, _ in loader: 132 | input = input.cuda() 133 | input_var = torch.autograd.Variable(input) 134 | b = input_var.data.size(0) 135 | 136 | momentum = b / (n + b) 137 | for module in momenta.keys(): 138 | module.momentum = momentum 139 | 140 | model(input_var) 141 | n += b 142 | 143 | model.apply(lambda module: _set_momenta(module, momenta)) 144 | -------------------------------------------------------------------------------- /models/preresnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreResNet model definition 3 | ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py 4 | """ 5 | 6 | import torch.nn as nn 7 | import torchvision.transforms as transforms 8 | import math 9 | 10 | __all__ = ['PreResNet18', 'PreResNet110', 'PreResNet164'] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | return 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.bn1 = nn.BatchNorm2d(inplanes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.bn1(x) 35 | out = self.relu(out) 36 | out = self.conv1(out) 37 | 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | out = self.conv2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.bn1 = nn.BatchNorm2d(inplanes) 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn3 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.bn1(x) 70 | out = self.relu(out) 71 | out = self.conv1(out) 72 | 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | out = self.conv2(out) 76 | 77 | out = self.bn3(out) 78 | out = self.relu(out) 79 | out = self.conv3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | 86 | return out 87 | 88 | 89 | class PreResNet(nn.Module): 90 | 91 | 
def __init__(self, num_classes=10, depth=110): 92 | super(PreResNet, self).__init__() 93 | if depth >= 44: 94 | assert (depth - 2) % 9 == 0, 'depth should be 9n+2' 95 | n = (depth - 2) // 9 96 | block = Bottleneck 97 | else: 98 | print ('depth:', (depth - 2) % 6) 99 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 100 | n = (depth - 2) // 6 101 | block = BasicBlock 102 | 103 | 104 | self.inplanes = 16 105 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 106 | bias=False) 107 | self.layer1 = self._make_layer(block, 16, n) 108 | self.layer2 = self._make_layer(block, 32, n, stride=2) 109 | self.layer3 = self._make_layer(block, 64, n, stride=2) 110 | self.bn = nn.BatchNorm2d(64 * block.expansion) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.avgpool = nn.AvgPool2d(8) 113 | self.fc = nn.Linear(64 * block.expansion, num_classes) 114 | 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 119 | elif isinstance(m, nn.BatchNorm2d): 120 | m.weight.data.fill_(1) 121 | m.bias.data.zero_() 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | nn.Conv2d(self.inplanes, planes * block.expansion, 128 | kernel_size=1, stride=stride, bias=False), 129 | ) 130 | 131 | layers = list() 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | 142 | x = self.layer1(x) # 32x32 143 | x = self.layer2(x) # 16x16 144 | x = self.layer3(x) # 8x8 145 | x = self.bn(x) 146 | x = self.relu(x) 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | class PreResNet18: 156 | base = PreResNet 157 | args = list() 158 | kwargs = {'depth': 18} 159 | 160 | class PreResNet110: 161 | base = PreResNet 162 | args = list() 163 | kwargs = {'depth': 110} 164 | 165 | class PreResNet164: 166 | base = PreResNet 167 | args = list() 168 | kwargs = {'depth': 164} -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """resnet in pytorch 2 | 3 | 4 | 5 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 
"""
Deep Residual Learning for Image Recognition
https://arxiv.org/abs/1512.03385v1
"""

import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Basic residual block for ResNet-18 and ResNet-34.

    BasicBlock and BottleNeck produce different output widths; the class
    attribute ``expansion`` distinguishes them (a block emits
    ``out_channels * expansion`` channels).
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # residual branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )

        # identity shortcut by default; 1x1 projection when the output
        # dimension does not match the input
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        # Fix: use the functional ReLU instead of constructing a fresh
        # nn.ReLU module object on every forward pass.
        return nn.functional.relu(self.residual_function(x) + self.shortcut(x), inplace=True)


class BottleNeck(nn.Module):
    """Residual bottleneck block for ResNets deeper than 50 layers."""

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 1x1 reduce -> 3x3 (strided) -> 1x1 expand, each followed by BN
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        # Fix: functional ReLU instead of a per-call nn.ReLU module.
        return nn.functional.relu(self.residual_function(x) + self.shortcut(x), inplace=True)


class ResNet(nn.Module):
    """ResNet backbone for CIFAR-sized inputs (3x3 stem, stride-1 conv2_x)."""

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # we use a different input size than the original paper,
        # so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one residual stage (a stage may hold several blocks).

        Args:
            block: block type, BasicBlock or BottleNeck
            out_channels: output channel count of this stage
            num_blocks: how many blocks in this stage
            stride: stride of the first block (the rest use stride 1)

        Returns:
            an nn.Sequential containing the stage's blocks
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output


class resnet18:
    # Model config consumed by utils.get_model: base class + ctor args/kwargs.
    base = ResNet
    args = list()
    kwargs = {'block': BasicBlock, 'num_block': [2, 2, 2, 2]}


def resnet34():
    """ return a ResNet 34 object
    """
    return ResNet(BasicBlock, [3, 4, 6, 3])

def resnet50():
    """ return a ResNet 50 object
    """
    return ResNet(BottleNeck, [3, 4, 6, 3])

def resnet101():
    """ return a ResNet 101 object
    """
    return ResNet(BottleNeck, [3, 4, 23, 3])

def resnet152():
    """ return a ResNet 152 object
    """
    return ResNet(BottleNeck, [3, 8, 36, 3])
"""
PreResNet model definition
ported from https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
"""

import torch.nn as nn
import torchvision.transforms as transforms
import math

__all__ = ['PreResNet110', 'PreResNet164']


def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution with padding; bias-free because a BatchNorm follows
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation basic block: (BN -> ReLU -> 3x3 conv) twice, plus a
    shortcut added without a trailing ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        # optional projection applied to the *input* when shapes differ
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: BN-ReLU before each of the
    1x1 -> 3x3 -> 1x1 convolutions; output width is planes * 4."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class PreResNet(nn.Module):
    """Pre-activation ResNet for 32x32 inputs (CIFAR).

    depth >= 44 uses Bottleneck blocks (depth must be 9n+2); shallower
    nets use BasicBlock (depth must be 6n+2).
    """

    def __init__(self, num_classes=10, depth=110):
        super(PreResNet, self).__init__()
        if depth >= 44:
            assert (depth - 2) % 9 == 0, 'depth should be 9n+2'
            n = (depth - 2) // 9
            block = Bottleneck
        else:
            assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
            n = (depth - 2) // 6
            block = BasicBlock


        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        # three stages of n blocks each; spatial size 32 -> 16 -> 8
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # fan-out scaled normal init for convs; BN to identity (weight 1, bias 0)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: first block may downsample via a 1x1 conv
        projection, remaining blocks keep shape."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = list()
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)  # 32x32
        x = self.layer2(x)  # 16x16
        x = self.layer3(x)  # 8x8
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


class PreResNet110:
    # Model config consumed by the training scripts: base class,
    # constructor args/kwargs, and train/test data transforms.
    base = PreResNet
    args = list()
    kwargs = {'depth': 110}
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

class PreResNet164:
    # Same config scheme as PreResNet110, with depth 164.
    base = PreResNet
    args = list()
    kwargs = {'depth': 164}
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
# ===== utils.py =====
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models_imagenet

import numpy as np
import random
import os
import time
import models
import sys


def set_seed(seed=1):
    """Seed the python, numpy and torch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class Logger(object):
    """Tee writes to stdout and to a log file (install as sys.stdout)."""

    def __init__(self, fileN="Default.log"):
        self.terminal = sys.stdout
        self.log = open(fileN, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Fix: flush both underlying streams instead of silently doing
        # nothing, so buffered log output is not lost on a crash.
        self.terminal.flush()
        self.log.flush()


def adjust_learning_rate(optimizer, lr):
    """Set `lr` on every parameter group; returns lr for convenience."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

################################ datasets #######################################

import torchvision.transforms as transforms  # NOTE(review): duplicates the imports above; kept to preserve the file layout
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder


def get_datasets(args):
    """Return (train_loader, val_loader) for CIFAR10 / CIFAR100 / ImageNet.

    Raises:
        ValueError: for an unknown ``args.datasets`` (previously this fell
            through to a NameError on the return statement).
    """
    if args.datasets == 'CIFAR10':
        print ('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print ('cifar100 dataset!')
        # NOTE(review): these are the CIFAR-10 channel statistics; confirm
        # that reusing them for CIFAR-100 is intended.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'ImageNet':
        traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
        valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers)

    else:
        # Fix: explicit failure instead of a NameError on the return below.
        raise ValueError('unknown dataset: {}'.format(args.datasets))

    return train_loader, val_loader


def get_imagenet_dataset():
    """Return (train_dataset, val_dataset) ImageFolders for ImageNet."""
    traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
    valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_dataset, val_dataset

################################ training & evaluation #######################################

def eval_model(loader, model, criterion):
    """Evaluate `model` on `loader`; returns dataset-averaged loss and
    accuracy in percent."""
    loss_sum = 0.0
    correct = 0.0

    model.eval()

    # Fix: disable autograd during evaluation — no graphs are needed and
    # building them wastes memory.
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            input = input.cuda()
            target = target.cuda()

            output = model(input)
            loss = criterion(output, target)

            loss_sum += loss.item() * input.size(0)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    return {
        'loss': loss_sum / len(loader.dataset),
        'accuracy': correct / len(loader.dataset) * 100.0,
    }

def bn_update(loader, model):
    """Recompute BatchNorm running statistics with one pass over `loader`
    (forward passes in train mode; outputs are discarded)."""
    model.train()
    # Fix: forward passes only update BN buffers; no gradients needed.
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input_var = input.cuda()
            target_var = target

            # compute output (side effect: BN running stats update)
            output = model(input_var)

def get_model(args):
    """Instantiate the architecture named by args.arch for args.datasets."""
    print('Model: {}'.format(args.arch))

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets == 'CIFAR10':
        num_classes = 10
    elif args.datasets == 'CIFAR100':
        num_classes = 100
    else:
        # Fix: fail with a clear message instead of an unbound-local
        # NameError two lines below.
        raise ValueError('unknown dataset: {}'.format(args.datasets))

    model_cfg = getattr(models, args.arch)

    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)


# ===== swa/utils.py =====
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models_imagenet

import numpy as np
import random
import os
import time
import models
import sys


def set_seed(seed=1):
    """Seed the python, numpy and torch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class Logger(object):
    """Tee writes to stdout and to a log file (install as sys.stdout)."""

    def __init__(self, fileN="Default.log"):
        self.terminal = sys.stdout
        self.log = open(fileN, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Fix: flush both underlying streams instead of doing nothing.
        self.terminal.flush()
        self.log.flush()


def adjust_learning_rate(optimizer, lr):
    """Set `lr` on every parameter group; returns lr for convenience."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

################################ datasets #######################################
import torchvision.transforms as transforms  # NOTE(review): duplicates the top-of-file imports; kept to preserve layout
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder


def get_datasets(args):
    """Return (train_loader, val_loader) for CIFAR10 / CIFAR100 / ImageNet.

    Raises:
        ValueError: for an unknown ``args.datasets`` (previously this fell
            through to a NameError on the return statement).
    """
    if args.datasets == 'CIFAR10':
        print ('cifar10 dataset!')
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'CIFAR100':
        print ('cifar100 dataset!')
        # NOTE(review): CIFAR-10 channel statistics reused here — confirm intended.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=True, transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]), download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(root='./datasets/', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=128, shuffle=False,
            num_workers=args.workers, pin_memory=True)

    elif args.datasets == 'ImageNet':
        traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
        valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers)

    else:
        # Fix: explicit failure instead of a NameError on the return below.
        raise ValueError('unknown dataset: {}'.format(args.datasets))

    return train_loader, val_loader


def get_imagenet_dataset():
    """Return (train_dataset, val_dataset) ImageFolders for ImageNet."""
    traindir = os.path.join('/home/datasets/ILSVRC2012/', 'train')
    valdir = os.path.join('/home/datasets/ILSVRC2012/', 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_dataset, val_dataset

################################ training & evaluation #######################################

def eval_model(loader, model, criterion):
    """Evaluate `model` on `loader`; returns dataset-averaged loss and
    accuracy in percent."""
    loss_sum = 0.0
    correct = 0.0

    model.eval()

    # Fix: disable autograd during evaluation (no graphs needed).
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            input = input.cuda()
            target = target.cuda()

            output = model(input)
            loss = criterion(output, target)

            loss_sum += loss.item() * input.size(0)
            pred = output.data.max(1, keepdim=True)[1]
            # Fix: this line referenced the undefined name `target_var`
            # (a NameError at runtime — `target_var` only exists in
            # bn_update); compare against `target`, matching utils.py.
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    return {
        'loss': loss_sum / len(loader.dataset),
        'accuracy': correct / len(loader.dataset) * 100.0,
    }

def bn_update(loader, model):
    """Recompute BatchNorm running statistics with one pass over `loader`
    (forward passes in train mode; outputs are discarded)."""
    model.train()
    # Fix: forward passes only update BN buffers; no gradients needed.
    with torch.no_grad():
        for i, (input, target) in enumerate(loader):
            target = target.cuda()
            input_var = input.cuda()
            target_var = target

            # compute output (side effect: BN running stats update)
            output = model(input_var)

def get_model(args):
    """Instantiate the architecture named by args.arch for args.datasets."""
    print('Model: {}'.format(args.arch))

    if args.datasets == 'ImageNet':
        return models_imagenet.__dict__[args.arch]()

    if args.datasets == 'CIFAR10':
        num_classes = 10
    elif args.datasets == 'CIFAR100':
        num_classes = 100
    else:
        # Fix: fail clearly instead of an unbound-local NameError below.
        raise ValueError('unknown dataset: {}'.format(args.datasets))

    model_cfg = getattr(models, args.arch)

    return model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)


# ===== swa/train.py =====
import argparse
import os
import sys
import time
import torch
import torch.nn.functional as F
import torchvision
import models
import utils_swa as utils
import tabulate


parser = argparse.ArgumentParser(description='SGD/SWA training')
parser.add_argument('--dir', type=str, default=None, required=True, help='training directory (default: None)')

parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset name (default: CIFAR10)')
parser.add_argument('--data_path', type=str, default=None, required=True, metavar='PATH',
                    help='path to datasets location (default: None)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size (default: 128)')
parser.add_argument('--num_workers', type=int, default=4, metavar='N', help='number of workers (default: 4)')
parser.add_argument('--model', type=str, default=None, required=True, metavar='MODEL',
                    help='model name (default: None)')

parser.add_argument('--resume', type=str, default=None, metavar='CKPT',
                    help='checkpoint to resume training from (default: None)')

parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 200)')
parser.add_argument('--save_freq', type=int, default=25, metavar='N', help='save frequency (default: 25)')
parser.add_argument('--eval_freq', type=int, default=5, metavar='N', help='evaluation frequency (default: 5)')
# Fix: help text previously claimed "default: 0.01" while the actual default is 0.1.
parser.add_argument('--lr_init', type=float, default=0.1, metavar='LR', help='initial learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--wd', type=float, default=1e-4, help='weight decay (default: 1e-4)')

parser.add_argument('--swa', action='store_true', help='swa usage flag (default: off)')
parser.add_argument('--swa_start', type=float, default=161, metavar='N', help='SWA start epoch number (default: 161)')
parser.add_argument('--swa_lr', type=float, default=0.05, metavar='LR', help='SWA LR (default: 0.05)')
parser.add_argument('--swa_c_epochs', type=int, default=1, metavar='N',
                    help='SWA model collection frequency/cycle length in epochs (default: 1)')

parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')

args = parser.parse_args()


# Fix: args.dir was previously created twice (an exists-check here plus a
# second os.makedirs further down); create each directory exactly once.
print('Preparing directory %s' % args.dir)
os.makedirs(args.dir, exist_ok=True)

save_dir = os.path.join(args.dir, 'checkpoints')
os.makedirs(save_dir, exist_ok=True)

# record the exact command line used for this run
with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
    f.write(' '.join(sys.argv))
    f.write('\n')

torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print('Using model %s' % args.model)
model_cfg = getattr(models, args.model)

print('Loading dataset %s from %s' % (args.dataset, args.data_path))
ds = getattr(torchvision.datasets, args.dataset)
path = os.path.join(args.data_path, args.dataset.lower())
train_set = ds(path, train=True, download=True, transform=model_cfg.transform_train)
test_set = ds(path, train=False, download=True, transform=model_cfg.transform_test)
loaders = {
    'train': torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    ),
    'test': torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
}
print (train_set)
num_classes = utils.num_classes_dict[args.dataset]

print('Preparing model')
model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
model.cuda()


if args.swa:
    print('SWA training')
    # separate model instance that accumulates the SWA running average
    swa_model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
    swa_model.cuda()
    swa_n = 0
else:
    print('SGD training')


def schedule(epoch):
    """Piecewise learning-rate schedule.

    Flat at lr_init for the first half of the budget, linear decay from
    50% to 90%, then flat at the final LR (the SWA LR when --swa is on,
    otherwise 1% of lr_init).
    """
    t = (epoch) / (args.swa_start if args.swa else args.epochs)
    lr_ratio = args.swa_lr / args.lr_init if args.swa else 0.01
    if t <= 0.5:
        factor = 1.0
    elif t <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    else:
        factor = lr_ratio
    return args.lr_init * factor


criterion = F.cross_entropy
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=args.lr_init,
    momentum=args.momentum,
    weight_decay=args.wd
)

start_epoch = 0
if args.resume is not None:
    print('Resume training from %s' % args.resume)
    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if args.swa:
        swa_state_dict = checkpoint['swa_state_dict']
        if swa_state_dict is not None:
            swa_model.load_state_dict(swa_state_dict)
        swa_n_ckpt = checkpoint['swa_n']
        if swa_n_ckpt is not None:
            swa_n = swa_n_ckpt

columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
if args.swa:
    columns = columns[:-1] + ['swa_te_loss', 'swa_te_acc'] + columns[-1:]
    swa_res = {'loss': None, 'accuracy': None}

utils.save_checkpoint(
    args.dir,
    start_epoch,
    state_dict=model.state_dict(),
    swa_state_dict=swa_model.state_dict() if args.swa else None,
    swa_n=swa_n if args.swa else None,
    optimizer=optimizer.state_dict()
)

# DLDR sampling: snapshot the model weights once per epoch
sample_idx = 0
torch.save(model.state_dict(), os.path.join(save_dir, str(sample_idx) + '.pt'))

for epoch in range(start_epoch, args.epochs):
    time_ep = time.time()

    lr = schedule(epoch)
    utils.adjust_learning_rate(optimizer, lr)
    train_res = utils.train_epoch(loaders['train'], model, criterion, optimizer)
    # evaluate on the first epoch, every eval_freq epochs, and the last epoch
    if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
        test_res = utils.eval(loaders['test'], model, criterion)
    else:
        test_res = {'loss': None, 'accuracy': None}

    if args.swa and (epoch + 1) >= args.swa_start and (epoch + 1 - args.swa_start) % args.swa_c_epochs == 0:
        # fold the current weights into the SWA running average
        utils.moving_average(swa_model, model, 1.0 / (swa_n + 1))
        swa_n += 1
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            # BN statistics must be recomputed for the averaged weights
            utils.bn_update(loaders['train'], swa_model)
            swa_res = utils.eval(loaders['test'], swa_model, criterion)
        else:
            swa_res = {'loss': None, 'accuracy': None}

    if (epoch + 1) % args.save_freq == 0:
        utils.save_checkpoint(
            args.dir,
            epoch + 1,
            state_dict=model.state_dict(),
            swa_state_dict=swa_model.state_dict() if args.swa else None,
            swa_n=swa_n if args.swa else None,
            optimizer=optimizer.state_dict()
        )

    # DLDR sampling
    sample_idx += 1
    torch.save(model.state_dict(), os.path.join(save_dir, str(sample_idx) + '.pt'))

    time_ep = time.time() - time_ep
    values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'], test_res['loss'], test_res['accuracy'], time_ep]
    if args.swa:
        values = values[:-1] + [swa_res['loss'], swa_res['accuracy']] + values[-1:]
    table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
    # reprint the header every 40 epochs so long logs stay readable
    if epoch % 40 == 0:
        table = table.split('\n')
        table = '\n'.join([table[1]] + table)
    else:
        table = table.split('\n')[2]
    print(table)

if args.epochs % args.save_freq != 0:
    utils.save_checkpoint(
        args.dir,
        args.epochs,
        state_dict=model.state_dict(),
        swa_state_dict=swa_model.state_dict() if args.swa else None,
        swa_n=swa_n if args.swa else None,
        optimizer=optimizer.state_dict()
    )
"""
WideResNet model definition
ported from https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
"""

import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.init as init  # NOTE(review): unused in this chunk after dead-code removal; kept (may be used later in the file)
import torch.nn.functional as F
import math                   # NOTE(review): unused in this chunk after dead-code removal; kept (may be used later in the file)

__all__ = ['WideResNet28x10', 'WideResNet16x8']


# Removed: a ~75-line commented-out legacy WideResNet implementation
# (conv3x3 / conv_init / WideBasic / WideResNet) that was superseded by
# the OrderedDict-based classes below.

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicUnit(nn.Module):
    """Constant-width pre-activation residual unit: returns x + f(x)."""

    def __init__(self, channels: int, dropout: float):
        super(BasicUnit, self).__init__()
        # submodule names are numbered/prefixed for readable state_dict keys
        self.block = nn.Sequential(OrderedDict([
            ("0_normalization", nn.BatchNorm2d(channels)),
            ("1_activation", nn.ReLU(inplace=True)),
            ("2_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
            ("3_normalization", nn.BatchNorm2d(channels)),
            ("4_activation", nn.ReLU(inplace=True)),
            ("5_dropout", nn.Dropout(dropout, inplace=True)),
            ("6_convolution", nn.Conv2d(channels, channels, (3, 3), stride=1, padding=1, bias=False)),
        ]))

    def forward(self, x):
        return x + self.block(x)
channels, (3, 3), stride=1, padding=1, bias=False)), 110 | ])) 111 | 112 | def forward(self, x): 113 | return x + self.block(x) 114 | 115 | 116 | class DownsampleUnit(nn.Module): 117 | def __init__(self, in_channels: int, out_channels: int, stride: int, dropout: float): 118 | super(DownsampleUnit, self).__init__() 119 | self.norm_act = nn.Sequential(OrderedDict([ 120 | ("0_normalization", nn.BatchNorm2d(in_channels)), 121 | ("1_activation", nn.ReLU(inplace=True)), 122 | ])) 123 | self.block = nn.Sequential(OrderedDict([ 124 | ("0_convolution", nn.Conv2d(in_channels, out_channels, (3, 3), stride=stride, padding=1, bias=False)), 125 | ("1_normalization", nn.BatchNorm2d(out_channels)), 126 | ("2_activation", nn.ReLU(inplace=True)), 127 | ("3_dropout", nn.Dropout(dropout, inplace=True)), 128 | ("4_convolution", nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1, bias=False)), 129 | ])) 130 | self.downsample = nn.Conv2d(in_channels, out_channels, (1, 1), stride=stride, padding=0, bias=False) 131 | 132 | def forward(self, x): 133 | x = self.norm_act(x) 134 | return self.block(x) + self.downsample(x) 135 | 136 | 137 | class Block(nn.Module): 138 | def __init__(self, in_channels: int, out_channels: int, stride: int, depth: int, dropout: float): 139 | super(Block, self).__init__() 140 | self.block = nn.Sequential( 141 | DownsampleUnit(in_channels, out_channels, stride, dropout), 142 | *(BasicUnit(out_channels, dropout) for _ in range(depth)) 143 | ) 144 | 145 | def forward(self, x): 146 | return self.block(x) 147 | 148 | 149 | class WideResNet(nn.Module): 150 | def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int): 151 | super(WideResNet, self).__init__() 152 | 153 | self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor] 154 | self.block_depth = (depth - 4) // (3 * 2) 155 | 156 | self.f = nn.Sequential(OrderedDict([ 157 | ("0_convolution", nn.Conv2d(in_channels, self.filters[0], 
class WideResNet(nn.Module):
    """Wide ResNet (Zagoruyko & Komodakis) for 32x32 inputs.

    Args:
        depth: total depth; must satisfy depth = 6n + 4.
        width_factor: channel multiplier for the three residual stages.
        dropout: dropout rate inside residual units.
        in_channels: number of input image channels.
        num_classes: size of the final classification layer.
    """

    def __init__(self, depth: int, width_factor: int, dropout: float, in_channels: int, num_classes: int):
        super(WideResNet, self).__init__()

        self.filters = [16, 1 * 16 * width_factor, 2 * 16 * width_factor, 4 * 16 * width_factor]
        # Each of the 3 stages holds (depth - 4) / 6 BasicUnits plus its DownsampleUnit.
        self.block_depth = (depth - 4) // (3 * 2)

        self.f = nn.Sequential(OrderedDict([
            ("0_convolution", nn.Conv2d(in_channels, self.filters[0], (3, 3), stride=1, padding=1, bias=False)),
            ("1_block", Block(self.filters[0], self.filters[1], 1, self.block_depth, dropout)),
            ("2_block", Block(self.filters[1], self.filters[2], 2, self.block_depth, dropout)),
            ("3_block", Block(self.filters[2], self.filters[3], 2, self.block_depth, dropout)),
            ("4_normalization", nn.BatchNorm2d(self.filters[3])),
            ("5_activation", nn.ReLU(inplace=True)),
            ("6_pooling", nn.AvgPool2d(kernel_size=8)),  # 32x32 input -> 8x8 feature map here
            ("7_flattening", nn.Flatten()),
            ("8_classification", nn.Linear(in_features=self.filters[3], out_features=num_classes)),
        ]))

        self._initialize()

    def _initialize(self):
        """He-init convolutions, unit-init BatchNorm, zero-init the classifier."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # NOTE(review): the classifier weight is deliberately zero-initialized
                # in the original code; preserved as-is.
                m.weight.data.zero_()
                m.bias.data.zero_()

    def forward(self, x):
        return self.f(x)


class WideResNet28x10:
    """WRN-28-10 config: model factory arguments plus CIFAR transforms."""
    base = WideResNet
    args = list()
    # Fix: WideResNet.__init__ has no defaults for `dropout`/`in_channels`, so the
    # previous kwargs ({'depth': 28, 'width_factor': 10}) made base(**kwargs)
    # raise TypeError. Bring in line with WideResNet16x8 below.
    kwargs = {'depth': 28, 'width_factor': 10, 'dropout': 0, 'in_channels': 3}
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])


class WideResNet16x8:
    """WRN-16-8 config: model factory arguments plus CIFAR transforms."""
    base = WideResNet
    args = list()
    kwargs = {'depth': 16, 'width_factor': 8, 'dropout': 0, 'in_channels': 3}
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 210 | ]) 211 | transform_test = transforms.Compose([ 212 | transforms.ToTensor(), 213 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 214 | ]) -------------------------------------------------------------------------------- /train_sgd_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import random 6 | import sys 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | 17 | from utils import get_datasets, get_model, adjust_learning_rate, set_seed, Logger 18 | 19 | # Parse arguments 20 | parser = argparse.ArgumentParser(description='Regular SGD training') 21 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='SGD') 22 | parser.add_argument('--arch', '-a', metavar='ARCH', 23 | help='The architecture of the model') 24 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 25 | help='The training datasets') 26 | parser.add_argument('--optimizer', metavar='OPTIMIZER', default='sgd', type=str, 27 | help='The optimizer for training') 28 | parser.add_argument('--schedule', metavar='SCHEDULE', default='step', type=str, 29 | help='The schedule for training') 30 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 31 | help='number of data loading workers (default: 4)') 32 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 33 | help='number of total epochs to run') 34 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 35 | help='manual epoch number (useful on restarts)') 36 | parser.add_argument('-b', '--batch-size', default=128, 
type=int, 37 | metavar='N', help='mini-batch size (default: 128)') 38 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 39 | metavar='LR', help='initial learning rate') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 43 | metavar='W', help='weight decay (default: 1e-4)') 44 | parser.add_argument('--print-freq', '-p', default=50, type=int, 45 | metavar='N', help='print frequency (default: 50 iterations)') 46 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 47 | help='path to latest checkpoint (default: none)') 48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 49 | help='evaluate model on validation set') 50 | parser.add_argument('--wandb', dest='wandb', action='store_true', 51 | help='use wandb to monitor statisitcs') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--half', dest='half', action='store_true', 55 | help='use half-precision(16-bit) ') 56 | parser.add_argument('--save-dir', dest='save_dir', 57 | help='The directory used to save the trained models', 58 | default='save_temp', type=str) 59 | parser.add_argument('--log-dir', dest='log_dir', 60 | help='The directory used to save the log', 61 | default='save_temp', type=str) 62 | parser.add_argument('--log-name', dest='log_name', 63 | help='The log file name', 64 | default='log', type=str) 65 | parser.add_argument('--randomseed', 66 | help='Randomseed for training and initialization', 67 | type=int, default=1) 68 | 69 | best_prec1 = 0 70 | 71 | 72 | # Record training statistics 73 | train_loss = [] 74 | train_err = [] 75 | test_loss = [] 76 | test_err = [] 77 | arr_time = [] 78 | 79 | p0 = None 80 | 81 | args = parser.parse_args() 82 | 83 | if args.wandb: 84 | import wandb 85 | wandb.init(project="TWA", 
def get_model_param_vec(model):
    """Flatten all learnable parameters of `model` into a single 1-D tensor."""
    pieces = [param.data.detach().reshape(-1) for _, param in model.named_parameters()]
    return torch.cat(pieces, 0)


def get_model_grad_vec(model):
    """Flatten all parameter gradients of `model` into a single 1-D tensor."""
    pieces = [param.grad.detach().reshape(-1) for _, param in model.named_parameters()]
    return torch.cat(pieces, 0)


def update_grad(model, grad_vec):
    """Scatter the flat vector `grad_vec` back into per-parameter gradients."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.grad.shape
        count = shape.numel()
        param.grad.data = grad_vec[offset:offset + count].reshape(shape)
        offset += count


def update_param(model, param_vec):
    """Scatter the flat vector `param_vec` back into the model parameters."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.data.shape
        count = shape.numel()
        param.data = param_vec[offset:offset + count].reshape(shape)
        offset += count


# Index of the next DLDR checkpoint snapshot written by main()/train().
sample_idx = 0
def main():
    """Entry point: set up data/model/optimizer, run plain SGD training, and
    snapshot the weights after every epoch for later DLDR/TWA extraction."""
    global args, best_prec1, p0, sample_idx
    global param_avg, train_loss, train_err, test_loss, test_err, arr_time, running_weight

    set_seed(args.randomseed)

    # Create output directories on first run.
    print('save dir:', args.save_dir)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print('log dir:', args.log_dir)
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # Mirror stdout into the log file.
    sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))

    # Define model.
    model = get_model(args)
    model.cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        # Fix: the existence check used save_dir/resume while torch.load() was
        # given the bare args.resume; check and load the same path.
        ckpt_path = os.path.join(args.save_dir, args.resume)
        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(ckpt_path)
            args.start_epoch = checkpoint['epoch']
            print('from ', args.start_epoch)
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            # Fix: this message previously formatted args.evaluate by mistake.
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Prepare Dataloader.
    train_loader, val_loader = get_datasets(args)

    # Loss function (criterion).
    criterion = nn.CrossEntropyLoss().cuda()

    if args.half:
        model.half()
        criterion.half()

    print('optimizer:', args.optimizer)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        # Fail fast instead of hitting a NameError at the first optimizer use.
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))

    if args.schedule == 'step':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
    else:
        # Fix: lr_scheduler was undefined for any non-'step' schedule, raising
        # NameError in the training loop. A milestone past the final epoch
        # yields a constant schedule, consistent with train_twa.py.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    is_best = 0
    print('Start training: ', args.start_epoch, '->', args.epochs)

    # DLDR sampling: snapshot the initial weights as 0.pt.
    torch.save(model.state_dict(), os.path.join(args.save_dir, str(0) + '.pt'))

    p0 = get_model_param_vec(model)
    running_weight = p0

    for epoch in range(args.start_epoch, args.epochs):

        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)

        lr_scheduler.step()

        # Evaluate on the validation set.
        prec1 = validate(val_loader, model, criterion)

        # Remember best prec@1 and save checkpoint.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint({
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        # DLDR sampling: one snapshot per epoch.
        sample_idx += 1
        torch.save(model.state_dict(), os.path.join(args.save_dir, str(sample_idx) + '.pt'))

    print('train loss: ', train_loss)
    print('train err: ', train_err)
    print('test loss: ', test_loss)
    print('test err: ', test_err)
    print('time: ', arr_time)

    running_weight = None


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one SGD training epoch and record epoch-level loss/error stats."""
    global train_loss, train_err, arr_time, p0, sample_idx, running_weight

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Switch to train mode.
    model.train()

    total_loss, total_err = 0, 0
    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # Measure data loading time.
        data_time.update(time.time() - end)

        target = target.cuda()
        input_var = input.cuda()
        target_var = target
        if args.half:
            input_var = input_var.half()

        # Forward pass.
        output = model(input_var)
        loss = criterion(output, target_var)

        # Backward pass and SGD step.
        optimizer.zero_grad()
        loss.backward()
        total_loss += loss.item() * input_var.shape[0]
        total_err += (output.max(dim=1)[1] != target_var).sum().item()

        optimizer.step()

        output = output.float()
        loss = loss.float()

        # Measure accuracy and record loss.
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

    print('Total time for epoch [{0}] : {1:.3f}'.format(epoch, batch_time.sum))

    train_loss.append(total_loss / len(train_loader.dataset))
    train_err.append(total_err / len(train_loader.dataset))
    if args.wandb:
        wandb.log({"train loss": total_loss / len(train_loader.dataset)})
        wandb.log({"train acc": 1 - total_err / len(train_loader.dataset)})

    arr_time.append(batch_time.sum)
def validate(val_loader, model, criterion):
    """Evaluate `model` on the validation set; returns average top-1 accuracy."""
    global test_err, test_loss

    total_loss = 0
    total_err = 0

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Switch to evaluate mode.
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda()
            input_var = input.cuda()
            # Fix: target is already on the GPU; the second .cuda() was redundant.
            target_var = target

            if args.half:
                input_var = input_var.half()

            # Forward pass.
            output = model(input_var)
            loss = criterion(output, target_var)

            output = output.float()
            loss = loss.float()

            total_loss += loss.item() * input_var.shape[0]
            total_err += (output.max(dim=1)[1] != target_var).sum().item()

            # Measure accuracy and record loss.
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

    test_loss.append(total_loss / len(val_loader.dataset))
    test_err.append(total_err / len(val_loader.dataset))

    if args.wandb:
        wandb.log({"test loss": total_loss / len(val_loader.dataset)})
        wandb.log({"test acc": 1 - total_err / len(val_loader.dataset)})

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist a training checkpoint dict to `filename`."""
    torch.save(state, filename)


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in `topk`, as percentages of the batch."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Fix: reshape instead of view — the sliced mask is not guaranteed to be
        # contiguous on recent PyTorch versions (same fix as the upstream
        # pytorch/examples ImageNet script).
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
if __name__ == '__main__':
    main()


# ===================================================== train_twa.py =====
from random import choices
import argparse
import _osx_support  # NOTE(review): private CPython build-support module; almost
                     # certainly an accidental IDE auto-import -- confirm and drop.
import time
import os
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data

import matplotlib.pyplot as plt
import numpy as np
import pickle
import random
import utils
from utils import get_datasets, get_model, set_seed, adjust_learning_rate, bn_update, eval_model, Logger

########################## parse arguments ##########################
parser = argparse.ArgumentParser(description='TWA')
parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD')
parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN',
                    help='model architecture (default: VGG16BN)')
parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str,
                    help='The training datasets')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('-acc', '--accumulate', default=1, type=int,
                    metavar='A', help='accumulate times for batch gradient (default: 1)')
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--print-freq', '-p', default=200, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--half', dest='half', action='store_true',
                    help='use half-precision(16-bit) ')
parser.add_argument('--randomseed',
                    help='Randomseed for training and initialization',
                    type=int, default=1)
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='save_temp', type=str)
parser.add_argument('--log-dir', dest='log_dir',
                    help='The directory used to save the log',
                    default='save_temp', type=str)
parser.add_argument('--log-name', dest='log_name',
                    help='The log file name',
                    default='log', type=str)

########################## P-SGD setting ##########################
parser.add_argument('--extract', metavar='EXTRACT', help='method for extracting subspace',
                    default='Schmidt', choices=['Schmidt'])
parser.add_argument('--params_start', default=0, type=int, metavar='N',
                    help='which idx start for TWA')
parser.add_argument('--params_end', default=101, type=int, metavar='N',
                    help='which idx end for TWA')
parser.add_argument('--train_start', default=0, type=int, metavar='N',
                    help='which idx start for training')
parser.add_argument('--opt', metavar='OPT', help='optimization method for TWA',
                    default='SGD', choices=['SGD'])
parser.add_argument('--schedule', metavar='SCHE', help='learning rate schedule for P-SGD',
                    default='step', choices=['step', 'constant', 'linear'])
parser.add_argument('--lr', default=1, type=float, metavar='N',
                    help='lr for PSGD')

args = parser.parse_args()
set_seed(args.randomseed)
best_prec1 = 0
P = None  # orthonormal basis of the checkpoint subspace; set in main()
train_acc, test_acc, train_loss, test_loss = [], [], [], []


def get_model_param_vec(model):
    """Return all model parameters flattened into one 1-D numpy array."""
    chunks = [param.detach().cpu().numpy().reshape(-1)
              for _, param in model.named_parameters()]
    return np.concatenate(chunks, 0)


def get_model_param_vec_torch(model):
    """Return all model parameters flattened into one 1-D torch tensor."""
    chunks = [param.data.detach().reshape(-1)
              for _, param in model.named_parameters()]
    return torch.cat(chunks, 0)


def get_model_grad_vec(model):
    """Return all parameter gradients flattened into one 1-D torch tensor."""
    chunks = [param.grad.detach().reshape(-1)
              for _, param in model.named_parameters()]
    return torch.cat(chunks, 0)


def update_grad(model, grad_vec):
    """Write the flat vector `grad_vec` back into the per-parameter .grad fields."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.grad.shape
        count = shape.numel()
        # clone() so later in-place edits of grad_vec cannot alias the grads.
        param.grad.data = grad_vec[offset:offset + count].reshape(shape).clone()
        offset += count


def update_param(model, param_vec):
    """Write the flat vector `param_vec` back into the model parameters."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.data.shape
        count = shape.numel()
        param.data = param_vec[offset:offset + count].reshape(shape).clone()
        offset += count
def main():
    """TWA/P-SGD entry point: build an orthonormal subspace from saved SGD
    checkpoints, then train the model with gradients projected into it."""
    # (dropped unused `Bk` from the original global list)
    global args, best_prec1, P, coeff, coeff_inv

    # Create output directories on first run.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # Mirror stdout into the log file.
    sys.stdout = Logger(os.path.join(args.log_dir, args.log_name))
    print('twa-psgd')
    print('save dir:', args.save_dir)
    print('log dir:', args.log_dir)

    # Define model.
    if args.datasets == 'ImageNet':
        model = torch.nn.DataParallel(get_model(args))
    else:
        model = get_model(args)
    model.cuda()
    cudnn.benchmark = True

    # Loss function (criterion) and optimizer.
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.schedule == 'step':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(args.epochs * 0.5), int(args.epochs * 0.75 + 0.9)],
            last_epoch=args.start_epoch - 1)
    elif args.schedule == 'constant' or args.schedule == 'linear':
        # Milestone past the last epoch => effectively constant; the 'linear'
        # decay is applied per-iteration inside train().
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1)

    optimizer.zero_grad()

    # Prepare Dataloader.
    train_loader, val_loader = get_datasets(args)

    args.total_iters = len(train_loader) * args.epochs
    args.current_iters = 0

    ########################## extract subspaces ##########################
    # Load the sampled model parameters: one flat vector per checkpoint.
    print('weight decay:', args.weight_decay)
    print('params: from', args.params_start, 'to', args.params_end)
    W = []
    for i in range(args.params_start, args.params_end):
        model.load_state_dict(torch.load(os.path.join(args.save_dir, str(i) + '.pt')))
        W.append(get_model_param_vec(model))
    W = np.array(W)
    print('W:', W.shape)

    # Evaluate SWA (plain average of the sampled weights) as a baseline.
    center = torch.from_numpy(np.mean(W, axis=0)).cuda()
    update_param(model, center)
    bn_update(train_loader, model)
    print(utils.eval_model(val_loader, model, criterion))

    if args.extract == 'Schmidt':
        # In-place Gram-Schmidt orthonormalization of the checkpoint vectors;
        # `coeff` tracks the change of basis so it can be inverted afterwards.
        # Fix: W is already an ndarray here — the second np.array(W) copy was
        # redundant.
        P = torch.from_numpy(W).cuda()
        n_dim = P.shape[0]
        args.n_components = n_dim
        coeff = torch.eye(n_dim).cuda()
        for i in range(n_dim):
            if i > 0:
                tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1))
                P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1)
                coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1)
            # NOTE(review): no guard against a (near-)zero norm; two identical
            # checkpoints would produce NaNs here -- confirm inputs are distinct.
            tmp = torch.norm(P[i])
            P[i] /= tmp
            coeff[i] /= tmp
        coeff_inv = coeff.T.inverse()

        print(P.shape)

    # Set the start point for subspace training.
    if args.train_start >= 0:
        model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt')))
        print('train start:', args.train_start)

    if args.half:
        model.half()
        criterion.half()

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    print('Train:', (args.start_epoch, args.epochs))
    end = time.time()
    p0 = get_model_param_vec(model)

    for epoch in range(args.start_epoch, args.epochs):
        # Train for one epoch.
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, args, epoch, center)

        if args.schedule != 'linear':
            lr_scheduler.step()

        # Evaluate on the validation set.
        prec1 = validate(val_loader, model, criterion)

        # Remember best prec@1.
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

    print('Save final model')
    torch.save(model.state_dict(), os.path.join(args.save_dir, 'PSGD.pt'))

    bn_update(train_loader, model)
    print(utils.eval_model(val_loader, model, criterion))

    print('total time:', time.time() - end)
    print('train loss: ', train_loss)
    print('train acc: ', train_acc)
    print('test loss: ', test_loss)
    print('test acc: ', test_acc)
    print('best_prec1:', best_prec1)
def train(train_loader, model, criterion, optimizer, args, epoch, center):
    """Run one P-SGD epoch: each gradient is projected onto span(P) before the
    optimizer step."""
    # (trimmed the original global list to the names actually rebound/used here)
    global train_loss, train_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Switch to train mode.
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # Measure data loading time.
        data_time.update(time.time() - end)

        # Load batch data to cuda.
        target = target.cuda()
        input_var = input.cuda()
        target_var = target
        if args.half:
            input_var = input_var.half()

        # Forward pass.
        output = model(input_var)
        loss = criterion(output, target_var)

        # Backward pass; the projected step happens in P_SGD below.
        optimizer.zero_grad()
        loss.backward()
        gk = get_model_grad_vec(model)

        if args.schedule == 'linear':
            # Per-iteration linear decay from args.lr down to 0.
            adjust_learning_rate(optimizer, (1 - args.current_iters / args.total_iters) * args.lr)
            args.current_iters += 1

        if args.opt == 'SGD':
            P_SGD(model, optimizer, gk, center)

        # Measure accuracy and record loss.
        prec1 = accuracy(output.data, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0 or i == len(train_loader) - 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

    train_loss.append(losses.avg)
    train_acc.append(top1.avg)


def P_SGD(model, optimizer, grad, center):
    """Project `grad` onto the subspace spanned by the rows of P, write the
    projected gradient back into the model, and take an optimizer step.

    grad_proj = P^T (P grad). `center` is currently unused but kept for
    interface compatibility.
    """
    coords = torch.mm(P, grad.reshape(-1, 1))
    grad_proj = torch.mm(P.transpose(0, 1), coords)
    update_grad(model, grad_proj.reshape(-1))
    optimizer.step()


def validate(val_loader, model, criterion):
    """Evaluate on the validation set; returns average top-1 accuracy."""
    global test_acc, test_loss

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Switch to evaluate mode.
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            target = target.cuda()
            input_var = input.cuda()
            target_var = target.cuda()

            if args.half:
                input_var = input_var.half()

            # Forward pass.
            output = model(input_var)
            loss = criterion(output, target_var)

            output = output.float()
            loss = loss.float()

            # Measure accuracy and record loss.
            prec1 = accuracy(output.data, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # Measure elapsed time.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

    # Store the test loss and test accuracy.
    test_loss.append(losses.avg)
    test_acc.append(top1.avg)

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist a training checkpoint dict to `filename`."""
    torch.save(state, filename)


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in `topk`, as percentages of the batch."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Fix: reshape instead of view — the sliced mask is not guaranteed to be
        # contiguous on recent PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
argparse.ArgumentParser(description='SGD in Projected Subspace') 24 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD') 25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN', 26 | help='model architecture (default: VGG16BN)') 27 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 28 | help='The training datasets') 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=128, type=int, 36 | metavar='N', help='mini-batch size (default: 128)') 37 | parser.add_argument('-acc', '--accumulate', default=1, type=int, 38 | metavar='A', help='accumulate times for batch gradient (default: 1)') 39 | parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 40 | metavar='W', help='weight decay (default: 1e-4)') 41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 42 | help='momentum') 43 | parser.add_argument('--print-freq', '-p', default=200, type=int, 44 | metavar='N', help='print frequency (default: 50)') 45 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 46 | help='path to latest checkpoint (default: none)') 47 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 48 | help='evaluate model on validation set') 49 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 50 | help='use pre-trained model') 51 | parser.add_argument('--half', dest='half', action='store_true', 52 | help='use half-precision(16-bit) ') 53 | parser.add_argument('--randomseed', 54 | help='Randomseed for training and initialization', 55 
def get_model_param_vec(model):
    """Flatten all trainable parameters of ``model`` into one 1-D numpy array.

    Parameters are traversed in ``named_parameters()`` order, so the layout
    is stable for a fixed architecture.
    """
    flat_parts = [p.detach().cpu().numpy().reshape(-1)
                  for _, p in model.named_parameters()]
    return np.concatenate(flat_parts, 0)
def get_model_grad_vec(model):
    """Flatten all parameter gradients of ``model`` into one 1-D tensor.

    Assumes every parameter already has a populated ``.grad``.
    """
    pieces = [p.grad.detach().reshape(-1) for _, p in model.named_parameters()]
    return torch.cat(pieces, 0)

def update_grad(model, grad_vec):
    """Scatter the flat vector ``grad_vec`` back into each parameter's ``.grad``."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.grad.shape
        n = shape.numel()
        param.grad.data = grad_vec[offset:offset + n].reshape(shape).clone()
        offset += n

def update_param(model, param_vec):
    """Scatter the flat vector ``param_vec`` back into the model's parameters."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.data.shape
        n = shape.numel()
        param.data = param_vec[offset:offset + n].reshape(shape).clone()
        offset += n
milestones=[int(args.epochs*0.5), int(args.epochs*0.75+0.9)], last_epoch=args.start_epoch - 1) 169 | 170 | elif args.schedule == 'constant' or args.schedule == 'linear': 171 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \ 172 | milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1) 173 | 174 | optimizer.zero_grad() 175 | 176 | # Prepare Dataloader 177 | train_loader, val_loader = get_datasets(args) 178 | 179 | args.total_iters = len(train_loader) * args.epochs 180 | args.current_iters = 0 181 | 182 | ########################## extract subspaces ########################## 183 | # Load sampled model parameters 184 | print ('weight decay:', args.weight_decay) 185 | print ('params: from', args.params_start, 'to', args.params_end) 186 | W = [] 187 | for i in range(args.params_start, args.params_end): 188 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(i) + '.pt'))) 189 | W.append(get_model_param_vec(model)) 190 | W = np.array(W) 191 | print ('W:', W.shape) 192 | 193 | # Evaluate swa performance 194 | center = torch.from_numpy(np.mean(W, axis=0)).cuda() 195 | 196 | update_param(model, center) 197 | bn_update(train_loader, model) 198 | print ('SWA:', utils.eval_model(val_loader, model, criterion)) 199 | 200 | if args.extract == 'Schmidt': 201 | P = torch.from_numpy(np.array(W)).cuda() 202 | n_dim = P.shape[0] 203 | args.n_components = n_dim 204 | coeff = torch.eye(n_dim).cuda() 205 | for i in range(n_dim): 206 | if i > 0: 207 | tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1)) 208 | P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1) 209 | coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1) 210 | tmp = torch.norm(P[i]) 211 | P[i] /= tmp 212 | coeff[i] /= tmp 213 | coeff_inv = coeff.T.inverse() 214 | 215 | print (P.shape) 216 | 217 | # set the start point 218 | if args.train_start >= 0: 219 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt'))) 220 | print ('train start:', args.train_start) 221 
| 222 | if args.half: 223 | model.half() 224 | criterion.half() 225 | 226 | if args.evaluate: 227 | validate(val_loader, model, criterion) 228 | return 229 | 230 | print ('Train:', (args.start_epoch, args.epochs)) 231 | end = time.time() 232 | p0 = get_model_param_vec(model) 233 | 234 | for epoch in range(args.start_epoch, args.epochs): 235 | # Train for one epoch 236 | 237 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 238 | train(train_loader, model, criterion, optimizer, args, epoch, center) 239 | 240 | if args.schedule != 'linear': 241 | lr_scheduler.step() 242 | 243 | # Evaluate on validation set 244 | prec1 = validate(val_loader, model, criterion) 245 | 246 | # Remember best prec@1 and save checkpoint 247 | is_best = prec1 > best_prec1 248 | best_prec1 = max(prec1, best_prec1) 249 | 250 | print ('Save final model') 251 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'PSGD.pt')) 252 | 253 | bn_update(train_loader, model) 254 | print (utils.eval_model(val_loader, model, criterion)) 255 | 256 | print ('total time:', time.time() - end) 257 | print ('train loss: ', train_loss) 258 | print ('train acc: ', train_acc) 259 | print ('test loss: ', test_loss) 260 | print ('test acc: ', test_acc) 261 | print ('best_prec1:', best_prec1) 262 | 263 | 264 | def train(train_loader, model, criterion, optimizer, args, epoch, center): 265 | # Run one train epoch 266 | 267 | global P, W, iters, T, train_loss, train_acc, search_times, coeff 268 | 269 | batch_time = AverageMeter() 270 | data_time = AverageMeter() 271 | losses = AverageMeter() 272 | top1 = AverageMeter() 273 | 274 | # Switch to train mode 275 | model.train() 276 | 277 | end = time.time() 278 | for i, (input, target) in enumerate(train_loader): 279 | 280 | # Measure data loading time 281 | data_time.update(time.time() - end) 282 | 283 | # Load batch data to cuda 284 | target = target.cuda() 285 | input_var = input.cuda() 286 | target_var = target 287 | if args.half: 288 | input_var 
def P_SGD(model, optimizer, grad, center):
    """Projected SGD step: keep only the component of ``grad`` lying in the
    subspace spanned by the rows of the global basis ``P``, write it into the
    model's ``.grad`` buffers, and apply the optimizer update.

    ``center`` is part of the signature for compatibility; it is not read.
    """
    subspace_coords = torch.mm(P, grad.reshape(-1, 1))             # P @ g
    back_projected = torch.mm(P.transpose(0, 1), subspace_coords)  # P^T P g
    update_grad(model, back_projected.reshape(-1))
    optimizer.step()
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist the training state dict to ``filename``.

    NOTE(review): ``is_best`` is ignored; kept only so existing call sites
    keep working.
    """
    torch.save(state, filename)

class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold in ``val`` observed ``n`` times and recompute the mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
res = [] 427 | for k in topk: 428 | correct_k = correct[:k].view(-1).float().sum(0) 429 | res.append(correct_k.mul_(100.0 / batch_size)) 430 | return res 431 | 432 | if __name__ == '__main__': 433 | main() -------------------------------------------------------------------------------- /train_twa_ddp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | import numpy as np 16 | from utils import get_imagenet_dataset, get_model, set_seed, adjust_learning_rate, bn_update, eval_model, Logger 17 | 18 | from PIL import Image, ImageFile 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | 21 | ########################## parse arguments ########################## 22 | parser = argparse.ArgumentParser(description='TWA ddp') 23 | parser.add_argument('--EXP', metavar='EXP', help='experiment name', default='P-SGD') 24 | parser.add_argument('--arch', '-a', metavar='ARCH', default='VGG16BN', 25 | help='model architecture (default: VGG16BN)') 26 | parser.add_argument('--datasets', metavar='DATASETS', default='CIFAR10', type=str, 27 | help='The training datasets') 28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 29 | help='number of data loading workers (default: 4)') 30 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 31 | help='number of total epochs to run') 32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 33 | help='manual epoch number (useful on restarts)') 34 | parser.add_argument('-b', '--batch-size', default=128, type=int, 35 | metavar='N', help='mini-batch size (default: 128)') 36 | 
parser.add_argument('--weight-decay', '--wd', default=1e-5, type=float, 37 | metavar='W', help='weight decay (default: 1e-4)') 38 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 39 | help='momentum') 40 | parser.add_argument('--print-freq', '-p', default=200, type=int, 41 | metavar='N', help='print frequency (default: 50)') 42 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 43 | help='evaluate model on validation set') 44 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 45 | help='use pre-trained model') 46 | # env 47 | parser.add_argument('--randomseed', 48 | help='Randomseed for training and initialization', 49 | type=int, default=1) 50 | parser.add_argument('--save-dir', dest='save_dir', 51 | help='The directory used to save the trained models', 52 | default='save_temp', type=str) 53 | parser.add_argument('--log-dir', dest='log_dir', 54 | help='The directory used to save the log', 55 | default='save_temp', type=str) 56 | parser.add_argument('--log-name', dest='log_name', 57 | help='The log file name', 58 | default='log', type=str) 59 | # project subspace setting 60 | parser.add_argument('--params_start', default=0, type=int, metavar='N', 61 | help='which idx start for project subspace') 62 | parser.add_argument('--params_end', default=51, type=int, metavar='N', 63 | help='which idx end for project subspace') 64 | parser.add_argument('--train_start', default=0, type=int, metavar='N', 65 | help='which idx start for training') 66 | # optimizer and scheduler 67 | parser.add_argument('--opt', metavar='OPT', help='optimization method for TWA', 68 | default='SGD', choices=['SGD']) 69 | parser.add_argument('--schedule', metavar='SCHE', help='learning rate schedule for P-SGD', 70 | default='step', choices=['step', 'constant', 'linear']) 71 | parser.add_argument('--lr', default=1, type=float, metavar='N', 72 | help='lr for PSGD') 73 | # ddp 74 | parser.add_argument("--local_rank", 
def reduce_value(value, op=dist.ReduceOp.SUM):
    """All-reduce ``value`` across every process (in place for tensors).

    With fewer than two processes there is nothing to combine, so the input
    is returned untouched.
    """
    if dist.get_world_size() < 2:  # single GPU
        return value

    with torch.no_grad():
        dist.all_reduce(value, op)
    return value

def get_model_param_vec_torch(model):
    """Concatenate all model parameters into one flat 1-D tensor
    (``named_parameters()`` order)."""
    return torch.cat([p.data.detach().reshape(-1)
                      for _, p in model.named_parameters()], 0)

def get_model_grad_vec_torch(model):
    """Concatenate all parameter gradients into one flat 1-D tensor.

    Assumes every parameter already has a populated ``.grad``.
    """
    return torch.cat([p.grad.detach().reshape(-1)
                      for _, p in model.named_parameters()], 0)

def update_grad(model, grad_vec):
    """Write the flat vector ``grad_vec`` back into each parameter's ``.grad``."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.grad.shape
        n = shape.numel()
        param.grad.data = grad_vec[offset:offset + n].reshape(shape).clone()
        offset += n

def update_param(model, param_vec):
    """Write the flat vector ``param_vec`` back into the model's parameters."""
    offset = 0
    for _, param in model.named_parameters():
        shape = param.data.shape
        n = shape.numel()
        param.data = param_vec[offset:offset + n].reshape(shape).clone()
        offset += n
Logger(os.path.join(args.log_dir, args.log_name)) 143 | print('twa-ddp') 144 | print('save dir:', args.save_dir) 145 | print('log dir:', args.log_dir) 146 | 147 | # Define model 148 | model = get_model(args).to(device) 149 | model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank) 150 | cudnn.benchmark = True 151 | 152 | # Define loss function (criterion) and optimizer 153 | criterion = nn.CrossEntropyLoss().to(device) 154 | 155 | optimizer = optim.SGD(model.parameters(), lr=args.lr, \ 156 | momentum=args.momentum, \ 157 | weight_decay=args.weight_decay) 158 | optimizer.zero_grad() 159 | 160 | if args.schedule == 'step': 161 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \ 162 | milestones=[int(args.epochs*0.5), int(args.epochs*0.75+0.9)], last_epoch=args.start_epoch - 1) 163 | elif args.schedule == 'constant' or args.schedule == 'linear': 164 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \ 165 | milestones=[args.epochs + 1], last_epoch=args.start_epoch - 1) 166 | 167 | # Prepare Dataloader 168 | train_dataset, val_dataset = get_imagenet_dataset() 169 | assert args.batch_size % world_size == 0, f"Batch size {args.batch_size} cannot be divided evenly by world size {world_size}" 170 | batch_size_per_GPU = args.batch_size // world_size 171 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 172 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 173 | 174 | train_loader = torch.utils.data.DataLoader( 175 | train_dataset, batch_size=batch_size_per_GPU, sampler=train_sampler, 176 | num_workers=args.workers, pin_memory=True) 177 | 178 | val_loader = torch.utils.data.DataLoader( 179 | val_dataset, batch_size=batch_size_per_GPU, sampler=val_sampler, 180 | num_workers=args.workers) 181 | 182 | args.total_iters = len(train_loader) * args.epochs 183 | args.current_iters = 0 184 | 185 | ########################## extract subspaces ########################## 186 | # 
Load sampled model parameters 187 | if dist.get_rank() == 0: 188 | print('weight decay:', args.weight_decay) 189 | print('params: from', args.params_start, 'to', args.params_end) 190 | W = [] 191 | for i in range(args.params_start, args.params_end): 192 | if i%2==1: continue 193 | model.load_state_dict(torch.load(os.path.join(args.save_dir, f'{i}.pt'))) 194 | W.append(get_model_param_vec_torch(model)) 195 | W = torch.stack(W, dim=0) 196 | 197 | # Schmidt 198 | P = W 199 | n_dim = P.shape[0] 200 | coeff = torch.eye(n_dim).to(device) 201 | for i in range(n_dim): 202 | if i > 0: 203 | tmp = torch.mm(P[:i, :], P[i].reshape(-1, 1)) 204 | P[i] -= torch.mm(P[:i, :].T, tmp).reshape(-1) 205 | coeff[i] -= torch.mm(coeff[:i, :].T, tmp).reshape(-1) 206 | tmp = torch.norm(P[i]) 207 | P[i] /= tmp 208 | coeff[i] /= tmp 209 | coeff_inv = coeff.T.inverse() 210 | 211 | # Slice P 212 | slice_start = (n_dim//world_size)*dist.get_rank() 213 | if dist.get_rank() == world_size-1: 214 | slice_P = P[slice_start:,:].clone() 215 | else: 216 | slice_end = (n_dim//world_size)*(dist.get_rank()+1) 217 | slice_P = P[slice_start:slice_end,:].clone() 218 | if dist.get_rank() == 0: 219 | print(f'W: {W.shape} {W.device}') 220 | print(f'P: {P.shape} {P.device}') 221 | print(f'Sliced P: {slice_P.shape} {slice_P.device}') 222 | del P 223 | torch.cuda.empty_cache() 224 | dist.barrier() # Synchronizes all processes 225 | 226 | # set the start point 227 | if args.train_start >= 0: 228 | model.load_state_dict(torch.load(os.path.join(args.save_dir, str(args.train_start) + '.pt'))) 229 | if dist.get_rank() == 0: 230 | print('train start:', args.train_start) 231 | 232 | if args.evaluate: 233 | validate(val_loader, model, criterion) 234 | return 235 | 236 | if dist.get_rank() == 0: 237 | print('Train:', (args.start_epoch, args.epochs)) 238 | end = time.time() 239 | his_train_acc, his_test_acc, his_train_loss, his_test_loss = [], [], [], [] 240 | best_prec1 = 0 241 | for epoch in range(args.start_epoch, 
args.epochs): 242 | # Train for one epoch 243 | if dist.get_rank() == 0: 244 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 245 | train_loss, train_prec1 = train(train_loader, model, criterion, optimizer, 246 | args, epoch, slice_P, device, world_size) 247 | his_train_loss.append(train_loss) 248 | his_train_acc.append(train_prec1) 249 | 250 | if args.schedule != 'linear': 251 | lr_scheduler.step() 252 | 253 | # Evaluate on validation set 254 | test_loss, test_prec1 = validate(val_loader, model, criterion, device, world_size) 255 | his_test_loss.append(test_loss) 256 | his_test_acc.append(test_prec1) 257 | 258 | # Remember best prec@1 and save checkpoint 259 | best_prec1 = max(test_prec1, best_prec1) 260 | if dist.get_rank() == 0: 261 | print(f'Epoch: [{epoch}] * Best Prec@1 {best_prec1:.3f}') 262 | torch.save(model.state_dict(), os.path.join(args.save_dir, f'ddp{epoch}.pt')) 263 | 264 | if dist.get_rank() == 0: 265 | print('total time:', time.time() - end) 266 | print('train loss: ', his_train_loss) 267 | print('train acc: ', his_train_acc) 268 | print('test loss: ', his_test_loss) 269 | print('test acc: ', his_test_acc) 270 | print('best_prec1:', best_prec1) 271 | 272 | 273 | def train(train_loader, model, criterion, optimizer, args, epoch, P, device, world_size=1): 274 | # Run one train epoch 275 | 276 | batch_time = AverageMeter() 277 | data_time = AverageMeter() 278 | losses = AverageMeter() 279 | correctes = 0 280 | count = 0 281 | 282 | # Switch to train mode 283 | model.train() 284 | 285 | end = time.time() 286 | for i, (input, target) in enumerate(train_loader): 287 | # Measure data loading time 288 | data_time.update(time.time() - end) 289 | 290 | # Load batch data to cuda 291 | target = target.to(device) 292 | input = input.to(device) 293 | 294 | batch_size = torch.tensor(target.size(0)).to(device) 295 | reduce_value(batch_size) 296 | count += batch_size 297 | 298 | # Compute output 299 | output = model(input) 300 | loss = 
def project_gradient(model, P):
    """Replace the model's gradient with its projection onto the row space of
    the (per-rank slice of the) basis ``P``.

    Each rank projects onto its own slice of the basis; the partial
    projections are sum-reduced across ranks so every GPU ends up applying
    the full-subspace projection.
    """
    flat_grad = get_model_grad_vec_torch(model)
    coords = torch.mm(P, flat_grad.reshape(-1, 1))
    proj = torch.mm(P.transpose(0, 1), coords)
    reduce_value(proj)  # Sum partial projections from all ranks
    update_grad(model, proj.reshape(-1))
class AverageMeter(object):
    """Keeps the latest value alongside a running sum, count and average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.count = 0
        self.sum = 0
        self.avg = 0

    def update(self, val, n=1):
        """Incorporate ``val`` observed ``n`` times into the running stats."""
        self.val = val
        total = self.sum + val * n
        seen = self.count + n
        self.sum = total
        self.count = seen
        self.avg = total / seen
import argparse
import os
import random
import shutil
import time
import warnings

import numpy as np
import pickle

from PIL import Image, ImageFile
# Some ImageNet JPEGs are truncated; let PIL load them instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# Every lowercase, non-dunder, callable name in torchvision.models is a
# selectable architecture (resnet18, vgg16, ...).
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
# FIX: help text used to claim "(default: 10)" while the actual default is 1000.
parser.add_argument('-p', '--print-freq', default=1000, type=int,
                    metavar='N', help='print frequency (default: 1000)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

best_acc1 = 0


# Flattened parameter snapshots; populated via get_model_param_vec for the
# TWA post-processing stage (unused during plain SGD training itself).
param_vec = []
# Per-epoch training/validation statistics, printed at the end of main_worker.
train_loss = []
train_acc = []
test_loss = []
test_acc = []
arr_time = []

iters = 0


def get_model_param_vec(model):
    """Return all of *model*'s parameters flattened into one 1-D numpy array.

    Parameters are detached and moved to CPU, so this is safe to call on a
    CUDA model during training.
    """
    # named_parameters() was iterated before but the names were never used;
    # parameters() is the direct way to get the same tensors in the same order.
    vec = [param.detach().cpu().reshape(-1).numpy()
           for param in model.parameters()]
    return np.concatenate(vec, 0)


def main():
    """Parse CLI arguments and launch training workers.

    Spawns one `main_worker` process per GPU under
    `--multiprocessing-distributed`, otherwise runs a single worker inline.
    Also creates the `save_<arch>/` directory used for parameter snapshots.
    """
    global train_loss, train_acc, test_loss, test_acc, arr_time

    args = parser.parse_args()

    # Directory where intermediate model snapshots are collected.
    save_dir = 'save_' + args.arch
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total
        # world_size needs to be adjusted accordingly.
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes:
        # the main_worker process function.
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function.
        main_worker(args.gpu, ngpus_per_node, args)


# Index of the next parameter snapshot file written under save_<arch>/.
sample_idx = 0
def main_worker(gpu, ngpus_per_node, args):
    """Run the full training loop on one worker (one process / GPU).

    Builds the model (optionally pretrained), wraps it for DDP/DataParallel
    as requested, optionally resumes from a checkpoint, constructs the
    ImageNet train/val loaders, then trains for `args.epochs` epochs,
    validating and checkpointing after each one.  Only rank-0 (or the single
    non-distributed process) writes snapshots and checkpoints.
    """
    global train_loss, train_acc, test_loss, test_acc, arr_time
    global best_acc1, param_vec, sample_idx
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            # alexnet/vgg: parallelize only the conv features, as in the
            # upstream PyTorch ImageNet example.
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code: expects the standard ImageNet layout
    # <args.data>/train and <args.data>/val.
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    # Snapshot the initial weights (sample 0) before any training step;
    # only rank 0 writes when multiprocessing-distributed.
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
        torch.save(model.state_dict(), 'save_' + args.arch + '/' + str(sample_idx)+'.pt')


    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)

    # Dump the accumulated per-epoch statistics for offline inspection.
    print ('train loss: ', train_loss)
    print ('train acc: ', train_acc)
    print ('test loss: ', test_loss)
    print ('test acc: ', test_acc)

    print ('time: ', arr_time)


def train(train_loader, model, criterion, optimizer, epoch, args, ngpus_per_node):
    """Train *model* for one epoch over *train_loader*.

    Besides the usual SGD steps, the rank-0 process also saves weight
    snapshots to save_<arch>/<sample_idx>.pt every 1000 iterations during the
    first 5000, and once more at the end of the epoch (for later TWA use).
    Appends this epoch's average loss/acc and wall time to the module-level
    statistics lists.
    """
    global iters, param_vec, sample_idx
    global train_loss, train_acc, test_loss, test_acc, arr_time

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    epoch_start = end
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Printing and snapshotting happen only on rank 0 (or the single
        # non-distributed process).
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):

            if i % args.print_freq == 0:
                progress.display(i)

            # Early-epoch snapshots: one every 1000 iterations within [1000, 5000).
            if i > 0 and i % 1000 == 0 and i < 5000:
                sample_idx += 1
                torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    # End-of-epoch snapshot (rank 0 only).
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        sample_idx += 1
        torch.save(model.state_dict(), 'save_' + args.arch + '/'+str(sample_idx)+'.pt')

    arr_time.append(time.time() - epoch_start)
    train_loss.append(losses.avg)
    train_acc.append(top1.avg)


def validate(val_loader, model, criterion, args):
    """Evaluate *model* on *val_loader*; return average top-1 accuracy.

    Also appends the epoch's average loss/accuracy to the module-level
    test_loss/test_acc lists.  NOTE(review): in distributed mode this
    evaluates only the local process's view — results are not all-reduced
    across workers.
    """
    global train_loss, train_acc, test_loss, test_acc, arr_time
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    test_acc.append(top1.avg)
    test_loss.append(losses.avg)
    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist a training checkpoint; duplicate it as the best model so far."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold in a new observation `val` counted with weight `n`."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. "Loss 0.1234 (0.2345)" — current value, then running average.
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)


class ProgressMeter(object):
    """Prints a "[batch/total]" prefix followed by each attached meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Right-align the batch index to the width of the total batch count.
        width = len(str(num_batches))
        cell = '{:' + str(width) + 'd}'
        return '[' + cell + '/' + cell.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # (maxk, batch) matrix of predicted class indices, best first.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        hits = pred.eq(target.view(1, -1).expand_as(pred))

        results = []
        for k in topk:
            # A sample counts as correct@k if the label is anywhere in its top k.
            k_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            results.append(k_correct.mul_(100.0 / batch_size))
        return results


if __name__ == '__main__':
    main()