├── .gitignore
├── imagenet
│   ├── models
│   │   ├── __init__.py
│   │   ├── pswitch.py
│   │   ├── switch.py
│   │   └── resnet.py
│   ├── requirements.txt
│   ├── README.md
│   ├── utils.py
│   ├── main.py
│   ├── GPUtil.py
│   └── main_gradual.py
├── README.md
└── cifar10
    ├── models
    │   ├── __init__.py
    │   ├── pswitch.py
    │   ├── lenet.py
    │   ├── switch.py
    │   ├── vgg.py
    │   ├── mobilenet.py
    │   ├── mobilenetv2.py
    │   ├── googlenet.py
    │   ├── resnext.py
    │   ├── dpn.py
    │   ├── densenet.py
    │   ├── shufflenet.py
    │   ├── senet.py
    │   ├── preact_resnet.py
    │   ├── pnasnet.py
    │   └── resnet.py
    ├── README.md
    ├── get_mean_std.py
    ├── profile.py
    ├── test_accu_trend.py
    ├── utils.py
    ├── flops_counter.py
    ├── main_add.py
    └── main_gradual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | checkpoint
3 | results
4 | 
--------------------------------------------------------------------------------
/imagenet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
--------------------------------------------------------------------------------
/imagenet/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # autogrow
2 | AutoGrow: Automatic Layer Growing in Deep Convolutional Networks
3 | 
4 | This is the code for the paper **AutoGrow** (https://arxiv.org/abs/1906.02909), which proposes a method to automatically grow layers to find an optimal depth for deep neural networks.
5 | 
6 | Tutorial to be updated soon...
7 | 
--------------------------------------------------------------------------------
/cifar10/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .dpn import *
3 | from .lenet import *
4 | from .senet import *
5 | from .pnasnet import *
6 | from .densenet import *
7 | from .googlenet import *
8 | from .shufflenet import *
9 | from .resnet import *
10 | from .resnext import *
11 | from .preact_resnet import *
12 | from .mobilenet import *
13 | from .mobilenetv2 import *
14 | 
--------------------------------------------------------------------------------
/cifar10/models/pswitch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class PSwitch(nn.Module):
5 |     """
6 |     This is a learnable switch
7 |     """
8 |     def __init__(self, value=1.0):
9 |         super(PSwitch, self).__init__()
10 |         self.switch = nn.Parameter(torch.Tensor(1))
11 |         self.switch.data.fill_(value)
12 | 
13 |     def forward(self, input):
14 |         return input * self.switch
15 | 
16 |     def get_switch(self):
17 |         return self.switch.data
--------------------------------------------------------------------------------
/imagenet/models/pswitch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class PSwitch(nn.Module):
5 |     """
6 |     This is a learnable switch
7 |     """
8 |     def __init__(self, value=1.0):
9 |         super(PSwitch, self).__init__()
10 |         self.switch = nn.Parameter(torch.Tensor(1))
11 |         self.switch.data.fill_(value)
12 | 
13 |     def forward(self, input):
14 |         return input * self.switch
15 | 
16 |     def get_switch(self):
17 |         return self.switch.data
--------------------------------------------------------------------------------
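For illustration, here is a minimal sketch of how a `PSwitch` lets a newly grown layer start as an identity mapping. This is not a file from this repo: `ToyGrownBlock` is hypothetical, and the snippet assumes it is run from the `cifar10/` (or `imagenet/`) directory so that `models.pswitch` is importable.

```python
import torch
import torch.nn as nn

from models.pswitch import PSwitch  # the module defined above


class ToyGrownBlock(nn.Module):
    """Hypothetical residual block gated by a PSwitch (illustration only)."""
    def __init__(self, planes):
        super(ToyGrownBlock, self).__init__()
        self.conv = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        # Start the switch at 0 so the freshly grown block is an identity mapping.
        self.pswitch = PSwitch(value=0.0)

    def forward(self, x):
        return x + self.pswitch(torch.relu(self.bn(self.conv(x))))


if __name__ == '__main__':
    x = torch.randn(2, 16, 8, 8)
    block = ToyGrownBlock(16)
    print(torch.allclose(block(x), x))  # True: the closed switch zeroes the new branch
```

Because the new branch is multiplied by a learnable scalar initialized to zero, growing a block does not perturb the function the network currently computes; training can then open the switch gradually.

--------------------------------------------------------------------------------
/cifar10/models/lenet.py: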
-------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /cifar10/models/switch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Switch(nn.Module): 5 | """ 6 | This is a param-free switch 7 | """ 8 | def __init__(self, value=1.0, steps=1, start=0.0, stop=1.0, mode='linear'): 9 | super(Switch, self).__init__() 10 | self.value = torch.ones(()) * value 11 | self.steps = steps 12 | self.start = start 13 | self.stop = stop 14 | self.mode = mode 15 | assert (self.steps >= 1) 16 | assert (self.stop >= self.start) 17 | self.register_buffer('switch', self.value) 18 | 19 | def set_params(self, steps, start=0.0, stop=1.0, mode='linear'): 20 | self.steps = steps 21 | self.start = start 22 | self.stop = stop 23 | self.mode = mode 24 | assert (self.steps >= 1) 25 | assert (self.stop >= self.start) 26 | 27 | def forward(self, input): 28 | return input * self.switch 29 | 30 | def increase(self): 31 | if 'linear' == self.mode: 32 | self.switch += (self.stop - self.start) / self.steps 33 | 34 | self.switch.fill_(self.start if self.switch < self.start else self.switch) 35 | self.switch.fill_(self.stop if self.switch > self.stop else self.switch) 36 | 37 | def get_switch(self): 38 | return self.switch 39 | 40 | def extra_repr(self): 41 | switch_str = '%.3f' % self.switch 42 | return switch_str -------------------------------------------------------------------------------- /imagenet/models/switch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Switch(nn.Module): 5 | """ 6 | This is a param-free switch 7 | """ 8 | def __init__(self, value=1.0, steps=1, start=0.0, stop=1.0, mode='linear'): 9 | super(Switch, self).__init__() 10 | self.value = torch.ones(()) * value 11 | self.steps = steps 12 | self.start = start 13 | self.stop = stop 14 | self.mode = mode 15 | assert (self.steps >= 1) 16 | assert (self.stop >= self.start) 17 | self.register_buffer('switch', self.value) 18 | 19 | def set_params(self, steps, start=0.0, stop=1.0, mode='linear'): 20 | self.steps = steps 21 | self.start = start 22 | self.stop = stop 23 | self.mode = mode 24 | assert (self.steps >= 1) 25 | assert (self.stop >= self.start) 26 | 27 | def forward(self, input): 28 | return input * self.switch 29 | 30 | def increase(self): 31 | if 'linear' == self.mode: 32 | self.switch += (self.stop - self.start) / self.steps 33 | 34 | self.switch.fill_(self.start if self.switch < self.start else self.switch) 35 | self.switch.fill_(self.stop if self.switch > self.stop else self.switch) 36 | 37 | def get_switch(self): 38 | return self.switch 39 | 40 | def extra_repr(self): 41 | switch_str = '%.3f' % 
self.switch 42 | return switch_str -------------------------------------------------------------------------------- /cifar10/README.md: -------------------------------------------------------------------------------- 1 | # Train CIFAR10 with PyTorch 2 | Merged from https://github.com/kuangliu/pytorch-cifar 3 | 4 | 5 | I'm playing with [PyTorch](http://pytorch.org/) on the CIFAR10 dataset. 6 | 7 | ## Pros & cons 8 | Pros: 9 | - Built-in data loading and augmentation, very nice! 10 | - Training is fast, maybe even a little bit faster. 11 | - Very memory efficient! 12 | 13 | Cons: 14 | - No progress bar, sad :( 15 | - No built-in log. 16 | 17 | ## Accuracy 18 | | Model | Acc. | 19 | | ----------------- | ----------- | 20 | | [VGG16](https://arxiv.org/abs/1409.1556) | 92.64% | 21 | | [ResNet18](https://arxiv.org/abs/1512.03385) | 93.02% | 22 | | [ResNet50](https://arxiv.org/abs/1512.03385) | 93.62% | 23 | | [ResNet101](https://arxiv.org/abs/1512.03385) | 93.75% | 24 | | [MobileNetV2](https://arxiv.org/abs/1801.04381) | 94.43% | 25 | | [ResNeXt29(32x4d)](https://arxiv.org/abs/1611.05431) | 94.73% | 26 | | [ResNeXt29(2x64d)](https://arxiv.org/abs/1611.05431) | 94.82% | 27 | | [DenseNet121](https://arxiv.org/abs/1608.06993) | 95.04% | 28 | | [PreActResNet18](https://arxiv.org/abs/1603.05027) | 95.11% | 29 | | [DPN92](https://arxiv.org/abs/1707.01629) | 95.16% | 30 | 31 | ## Learning rate adjustment 32 | I manually change the `lr` during training: 33 | - `0.1` for epoch `[0,150)` 34 | - `0.01` for epoch `[150,250)` 35 | - `0.001` for epoch `[250,350)` 36 | 37 | Resume the training with `python main.py --resume --lr=0.01` 38 | -------------------------------------------------------------------------------- /cifar10/models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.classifier = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.classifier(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /cifar10/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 
2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # ImageNet training in PyTorch 2 | 3 | This implements training of popular model architectures, such as ResNet, AlexNet, and VGG on the ImageNet dataset. 4 | 5 | ## Requirements 6 | 7 | - Install PyTorch ([pytorch.org](http://pytorch.org)) 8 | - `pip install -r requirements.txt` 9 | - Download the ImageNet dataset and move validation images to labeled subfolders 10 | - To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh 11 | 12 | ## Training 13 | 14 | To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset: 15 | 16 | ```bash 17 | python main.py -a resnet18 [imagenet-folder with train and val folders] 18 | ``` 19 | 20 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. 
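The step decay above can be written as a small helper. This is a sketch of the rule just described, not necessarily the exact function in `main.py`:

```python
def adjust_learning_rate(optimizer, epoch, base_lr=0.1):
    # Decay the learning rate by a factor of 10 every 30 epochs.
    lr = base_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```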
Use 0.01 as the initial learning rate for AlexNet or VGG:
21 | 
22 | ```bash
23 | python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
24 | ```
25 | 
26 | ## Usage
27 | 
28 | ```
29 | usage: main.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N]
30 |                [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N]
31 |                [--resume PATH] [-e] [--pretrained]
32 |                DIR
33 | 
34 | PyTorch ImageNet Training
35 | 
36 | positional arguments:
37 |   DIR                   path to dataset
38 | 
39 | optional arguments:
40 |   -h, --help            show this help message and exit
41 |   --arch ARCH, -a ARCH  model architecture: alexnet | resnet | resnet101 |
42 |                         resnet152 | resnet18 | resnet34 | resnet50 | vgg |
43 |                         vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn
44 |                         | vgg19 | vgg19_bn (default: resnet18)
45 |   -j N, --workers N     number of data loading workers (default: 4)
46 |   --epochs N            number of total epochs to run
47 |   --start-epoch N       manual epoch number (useful on restarts)
48 |   -b N, --batch-size N  mini-batch size (default: 256)
49 |   --lr LR, --learning-rate LR
50 |                         initial learning rate
51 |   --momentum M          momentum
52 |   --weight-decay W, --wd W
53 |                         weight decay (default: 1e-4)
54 |   --print-freq N, -p N  print frequency (default: 10)
55 |   --resume PATH         path to latest checkpoint (default: none)
56 |   -e, --evaluate        evaluate model on validation set
57 |   --pretrained          use pre-trained model
58 | ```
59 | 
--------------------------------------------------------------------------------
/cifar10/get_mean_std.py:
--------------------------------------------------------------------------------
1 | '''Get dataset mean and std with PyTorch.'''
2 | from __future__ import print_function
3 | 
4 | import matplotlib
5 | matplotlib.use("pdf")
6 | import matplotlib.pyplot as plt
7 | import logging
8 | from datetime import datetime
9 | from copy import deepcopy
10 | import re
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | import torch.nn.functional as F
16 | import torch.backends.cudnn as cudnn
17 | 
18 | import torchvision
19 | import torchvision.transforms as transforms
20 | import models.switch as ms
21 | import models.pswitch as ps
22 | 
23 | import os
24 | import argparse
25 | import numpy as np
26 | import models
27 | import utils
28 | import time
29 | 
30 | # from models import *
31 | parser = argparse.ArgumentParser(description='Get dataset mean and std')
32 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset')
33 | parser.add_argument('--batch_size', default=200, type=int, help='batch size')
34 | 
35 | args = parser.parse_args()
36 | 
37 | # Data
38 | print('==> Preparing data..')
39 | transform_train = transforms.Compose([
40 |     transforms.ToTensor(),
41 | ])
42 | if 'SVHN' == args.dataset:
43 |     trainset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, split='train', download=True,
44 |                                                            transform=transform_train)
45 | else:
46 |     trainset = getattr(torchvision.datasets, args.dataset)(root='./data-'+args.dataset, train=True, download=True, transform=transform_train)
47 | print('%d training samples.'
% len(trainset)) 48 | 49 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, num_workers=2) 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | h, w = 0, 0 52 | for batch_idx, (inputs, targets) in enumerate(trainloader): 53 | inputs = inputs.to(device) 54 | if batch_idx == 0: 55 | h, w = inputs.size(2), inputs.size(3) 56 | print(inputs.min(), inputs.max()) 57 | chsum = inputs.sum(dim=(0, 2, 3), keepdim=True) 58 | else: 59 | chsum += inputs.sum(dim=(0, 2, 3), keepdim=True) 60 | mean = chsum/len(trainset)/h/w 61 | print('mean: %s' % mean.view(-1)) 62 | 63 | chsum = None 64 | for batch_idx, (inputs, targets) in enumerate(trainloader): 65 | inputs = inputs.to(device) 66 | if batch_idx == 0: 67 | chsum = (inputs - mean).pow(2).sum(dim=(0, 2, 3), keepdim=True) 68 | else: 69 | chsum += (inputs - mean).pow(2).sum(dim=(0, 2, 3), keepdim=True) 70 | std = torch.sqrt(chsum/(len(trainset) * h * w - 1)) 71 | print('std: %s' % std.view(-1)) 72 | 73 | print('Done!') -------------------------------------------------------------------------------- /cifar10/profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flops_counter import get_model_complexity_info 3 | import models as mymodels 4 | 5 | nets = [ 6 | # pruning from 96-96-96 7 | ('CifarResNetBasic', [1, 3, 4], 91.09), 8 | ('CifarResNetBasic', [1, 6, 7], 92.52), 9 | ('CifarResNetBasic', [2, 11, 17], 92.96), 10 | ('CifarResNetBasic', [4, 17, 41], 93.59), 11 | # ('CifarResNetBasic', [2, 22, 63], 93.72), 12 | ('CifarResNetBasic', [3, 24, 36], 93.74), 13 | ('CifarResNetBasic', [7, 36, 56], 93.88), 14 | ('CifarResNetBasic', [7, 45, 64], 93.96), 15 | 16 | # pruning from 48-48-48 17 | ('CifarResNetBasic', [1, 4, 8], 91.98), 18 | ('CifarResNetBasic', [4, 13, 10], 93.43), 19 | ('CifarResNetBasic', [5, 14, 26], 93.61), 20 | ('CifarResNetBasic', [13, 26, 46], 94.37), 21 | # pruning from 24-24-24 22 | ('CifarResNetBasic', [1, 2, 4], 91.50), 23 | ('CifarResNetBasic', [4, 3, 5], 92.70), 24 | ('CifarResNetBasic', [4, 5, 7], 93.20), 25 | ('CifarResNetBasic', [10, 5, 16], 93.80), 26 | ('CifarResNetBasic', [13, 11, 23], 93.89), 27 | ('CifarResNetBasic', [17, 10, 24], 93.96), 28 | # manual search 29 | ('CifarResNetBasic', [3, 3, 3], 92.96), 30 | ('CifarResNetBasic', [5, 5, 5], 93.44), 31 | ('CifarResNetBasic', [7, 7, 7], 93.68), 32 | ('CifarResNetBasic', [9, 9, 9], 93.88), 33 | ('CifarResNetBasic', [11, 11, 11], 93.72), 34 | ('CifarResNetBasic', [13, 13, 13], 93.91), 35 | ('CifarResNetBasic', [15, 15, 15], 93.90), 36 | ('CifarResNetBasic', [18, 18, 18], 93.79), 37 | ('CifarResNetBasic', [24, 24, 24], 94.26), 38 | ('CifarResNetBasic', [48, 48, 48], 94.54), 39 | # growing with gaussian 40 | # ('CifarResNetBasic', [5, 5, 4], 92.68), 41 | ('CifarResNetBasic', [4, 6, 3], 93.11), 42 | ('CifarResNetBasic', [10, 10, 10], 93.34), 43 | ('CifarResNetBasic', [11, 8, 11], 93.41), 44 | # ('CifarResNetBasic', [9, 16, 16], 93.34), 45 | ('CifarResNetBasic', [24, 23, 23], 93.48), 46 | ('CifarResNetBasic', [33, 32, 32], 94.15), 47 | ('CifarResNetBasic', [92, 91, 91], 94.41), 48 | 49 | # growing with zero 50 | # ('CifarResNetBasic', [6, 3, 5], 92.28), 51 | # ('CifarResNetBasic', [3, 3, 6], 92.56), 52 | ('CifarResNetBasic', [5, 5, 4], 92.99), 53 | ('CifarResNetBasic', [22, 6, 6], 93.18), 54 | ('CifarResNetBasic', [49, 18, 18], 93.33), 55 | ('CifarResNetBasic', [32, 31, 31], 93.86), 56 | ('CifarResNetBasic', [42, 42, 41], 94.31), 57 | 58 | ] 59 | 60 | with torch.cuda.device(0): 61 | 
print('Net, flops, params, accuracy:') 62 | for net_type, num_blocks, accu in nets: 63 | net = getattr(mymodels, net_type)(num_blocks) 64 | flops, params = get_model_complexity_info(net, (32, 32), as_strings=False, print_per_layer_stat=False) 65 | print('{}-{}\t{}\t{}\t{}'.format(net_type, num_blocks, flops, params, accu)) -------------------------------------------------------------------------------- /cifar10/test_accu_trend.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class ExponentialMovingAverage(object): 4 | def __init__(self, decay=0.9): 5 | self.data = [] 6 | self.decay = decay 7 | self.avg_val = 0.0 8 | 9 | def push(self, current_data): 10 | self.avg_val = self.decay * self.avg_val + (1 - self.decay) * current_data 11 | self.data.append(self.avg_val) 12 | 13 | def get(self): 14 | return self.data 15 | 16 | 17 | accu=np.array([54.950, 71.720, 74.470, 78.800, 76.560, 80.330, 82.150, 81.020, 83.260, 85.580, 84.090, 82.760, 84.610, 86.980, 86.650, 86.510, 86.790, 87.630, 87.690, 87.750, 89.120, 88.060, 87.980, 89.090, 88.600, 88.560, 89.310, 90.090, 88.670, 89.570, 90.150, 89.260, 89.470, 89.710, 90.720, 90.480, 89.520, 88.440, 90.650, 90.600, 90.410, 90.670, 89.070, 90.780, 90.630, 90.100, 90.620, 90.130, 90.330, 91.220, 90.530, 90.830, 90.550, 91.110, 90.380, 89.970, 90.390, 90.050, 90.100, 90.520, 91.680, 91.120, 90.760, 90.400, 91.080, 91.320, 90.330, 90.990, 91.210, 90.780, 90.560, 90.980, 90.980, 91.080, 91.570, 90.900, 90.820, 90.850, 90.650, 91.280, 91.300, 91.690, 91.640, 90.150, 90.790, 91.180, 91.210, 91.420, 91.120, 90.930, 91.180, 91.210, 90.700, 91.510, 90.880, 91.470, 91.100, 91.170, 90.530, 91.190, 91.270, 91.840, 91.530, 91.400, 91.640, 91.640, 91.450, 91.770, 91.220, 90.920, 91.690, 91.040, 91.600, 91.410, 91.610, 91.730, 91.080, 91.160, 91.220, 91.790, 90.200, 91.410, 91.360, 91.460, 91.660, 91.290, 91.060, 92.040, 90.980, 90.930, 91.600, 91.820, 91.040, 91.090, 91.460, 91.660, 91.830, 91.240, 91.270, 91.470, 91.730, 91.060, 91.050, 91.510, 91.820, 91.180, 91.320, 91.310, 91.870, 91.420, 91.640, 91.320, 91.710, 91.660, 91.040, 91.330, 91.590, 91.440, 91.620, 91.220, 91.670, 91.400, 91.160, 92.080, 91.530, 91.630, 91.550, 91.770, 91.000, 92.220, 91.770, 91.200, 91.820, 91.570, 91.770, 91.890, 91.590, 90.920, 92.060, 91.430, 91.800, 91.450, 91.310, 92.090, 91.350, 91.870, 91.660, 91.440, 92.100, 92.140, 91.670, 91.670, 91.310, 91.930, 91.800, 91.710, 91.520, 92.000, 90.990, 91.520, 91.320, 91.610, 91.190, 91.600, 91.640, 91.680, 91.770, 91.620, 91.420, 91.240, 90.980, 91.760, 91.460, 91.530, 91.460, 92.110, 91.110, 91.400, 91.000, 91.100, 92.090, 91.510, 91.980, 91.720, 91.410, 91.660, 91.470, 91.840, 91.580, 91.840, 91.920, 91.020, 91.640, 91.460, 91.820, 92.400, 91.820, 91.820, 91.750, 91.920, 91.800, 91.110, 91.490, 91.410, 92.180, 91.600, 92.140, 91.570, 91.440, 91.930, 92.020, 90.990, 91.940, 91.630, 92.040, 92.150, 92.060, 91.840, 91.880, 91.430, 92.100, 91.430, 91.380, 91.310, 91.970, 91.590, 92.000, 92.000, 92.500, 92.140, 91.510, 91.470, 92.030, 91.500, 92.010, 91.900, 91.900, 91.250, 91.810, 91.480, 91.640, 91.390, 91.760, 92.040, 91.930, 91.800, 91.830, 91.820, 91.750, 92.150, 91.610, 91.520, 92.010, 90.290, 91.900, 91.840, 91.550, 91.500, 92.240, 91.200]) 18 | 19 | emv = ExponentialMovingAverage(0.95) 20 | for val in accu: 21 | emv.push(val) 22 | smoothed_accu = np.array(emv.get()) 23 | smoothed_delta = smoothed_accu[5:smoothed_accu.size]-smoothed_accu[:(smoothed_accu.size-5)] 24 | 25 | 
import matplotlib 26 | matplotlib.use("pdf") 27 | import matplotlib.pyplot as plt 28 | plt.plot(smoothed_delta) 29 | plt.ylim([-.5,.5]) 30 | plt.savefig('smoothed_delta.pdf') 31 | 32 | -------------------------------------------------------------------------------- /cifar10/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /cifar10/models/googlenet.py: 
-------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /cifar10/models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 42 | super(ResNeXt, self).__init__() 43 | self.cardinality = cardinality 44 | self.bottleneck_width = bottleneck_width 45 | self.in_planes = 64 46 | 47 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(64) 49 | self.layer1 = self._make_layer(num_blocks[0], 1) 50 | self.layer2 = self._make_layer(num_blocks[1], 2) 51 | self.layer3 = self._make_layer(num_blocks[2], 2) 52 | # self.layer4 = self._make_layer(num_blocks[3], 2) 53 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 54 | 55 | def _make_layer(self, num_blocks, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for stride in strides: 59 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 60 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 61 | # Increase bottleneck_width by 2 after each stage. 
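            # e.g. for ResNeXt29_2x64d: the per-stage group widths are 2*64=128,
            # 2*128=256, 2*256=512, and with expansion=2 the final stage outputs
            # 2*512 = 1024 channels, matching cardinality*bottleneck_width*8 in self.linear.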
62 | self.bottleneck_width *= 2 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(self.conv1(x))) 67 | out = self.layer1(out) 68 | out = self.layer2(out) 69 | out = self.layer3(out) 70 | # out = self.layer4(out) 71 | out = F.avg_pool2d(out, 8) 72 | out = out.view(out.size(0), -1) 73 | out = self.linear(out) 74 | return out 75 | 76 | 77 | def ResNeXt29_2x64d(): 78 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 79 | 80 | def ResNeXt29_4x64d(): 81 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 82 | 83 | def ResNeXt29_8x64d(): 84 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 85 | 86 | def ResNeXt29_32x4d(): 87 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 88 | 89 | def test_resnext(): 90 | net = ResNeXt29_2x64d() 91 | x = torch.randn(1,3,32,32) 92 | y = net(x) 93 | print(y.size()) 94 | 95 | # test_resnext() 96 | -------------------------------------------------------------------------------- /cifar10/models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers 
= [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /cifar10/models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += 
nblocks[3]*growth_rate
64 | 
65 |         self.bn = nn.BatchNorm2d(num_planes)
66 |         self.linear = nn.Linear(num_planes, num_classes)
67 | 
68 |     def _make_dense_layers(self, block, in_planes, nblock):
69 |         layers = []
70 |         for i in range(nblock):
71 |             layers.append(block(in_planes, self.growth_rate))
72 |             in_planes += self.growth_rate
73 |         return nn.Sequential(*layers)
74 | 
75 |     def forward(self, x):
76 |         out = self.conv1(x)
77 |         out = self.trans1(self.dense1(out))
78 |         out = self.trans2(self.dense2(out))
79 |         out = self.trans3(self.dense3(out))
80 |         out = self.dense4(out)
81 |         out = F.avg_pool2d(F.relu(self.bn(out)), 4)
82 |         out = out.view(out.size(0), -1)
83 |         out = self.linear(out)
84 |         return out
85 | 
86 | def DenseNet121():
87 |     return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
88 | 
89 | def DenseNet169():
90 |     return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
91 | 
92 | def DenseNet201():
93 |     return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
94 | 
95 | def DenseNet161():
96 |     return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
97 | 
98 | def densenet_cifar():
99 |     return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
100 | 
101 | def test():
102 |     net = densenet_cifar()
103 |     x = torch.randn(1,3,32,32)
104 |     y = net(x)
105 |     print(y)
106 | 
107 | # test()
108 | 
--------------------------------------------------------------------------------
/cifar10/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 | 
3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class ShuffleBlock(nn.Module):
11 |     def __init__(self, groups):
12 |         super(ShuffleBlock, self).__init__()
13 |         self.groups = groups
14 | 
15 |     def forward(self, x):
16 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
17 |         N,C,H,W = x.size()
18 |         g = self.groups
19 |         return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W)
20 | 
21 | 
22 | class Bottleneck(nn.Module):
23 |     def __init__(self, in_planes, out_planes, stride, groups):
24 |         super(Bottleneck, self).__init__()
25 |         self.stride = stride
26 | 
27 |         mid_planes = out_planes//4
28 |         g = 1 if in_planes==24 else groups
29 |         self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 |         self.bn1 = nn.BatchNorm2d(mid_planes)
31 |         self.shuffle1 = ShuffleBlock(groups=g)
32 |         self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 |         self.bn2 = nn.BatchNorm2d(mid_planes)
34 |         self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 |         self.bn3 = nn.BatchNorm2d(out_planes)
36 | 
37 |         self.shortcut = nn.Sequential()
38 |         if stride == 2:
39 |             self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 | 
41 |     def forward(self, x):
42 |         out = F.relu(self.bn1(self.conv1(x)))
43 |         out = self.shuffle1(out)
44 |         out = F.relu(self.bn2(self.conv2(out)))
45 |         out = self.bn3(self.conv3(out))
46 |         res = self.shortcut(x)
47 |         out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
48 |         return out
49 | 
50 | 
51 | class ShuffleNet(nn.Module):
52 |     def __init__(self, cfg):
53 |         super(ShuffleNet, self).__init__()
54 |         out_planes = cfg['out_planes']
55 |         num_blocks = cfg['num_blocks']
56 |         groups = cfg['groups']
57 | 
58 |         self.conv1 = nn.Conv2d(3,
24, kernel_size=1, bias=False) 59 | self.bn1 = nn.BatchNorm2d(24) 60 | self.in_planes = 24 61 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 62 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 63 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 64 | self.linear = nn.Linear(out_planes[2], 10) 65 | 66 | def _make_layer(self, out_planes, num_blocks, groups): 67 | layers = [] 68 | for i in range(num_blocks): 69 | stride = 2 if i == 0 else 1 70 | cat_planes = self.in_planes if i == 0 else 0 71 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 72 | self.in_planes = out_planes 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = F.relu(self.bn1(self.conv1(x))) 77 | out = self.layer1(out) 78 | out = self.layer2(out) 79 | out = self.layer3(out) 80 | out = F.avg_pool2d(out, 4) 81 | out = out.view(out.size(0), -1) 82 | out = self.linear(out) 83 | return out 84 | 85 | 86 | def ShuffleNetG2(): 87 | cfg = { 88 | 'out_planes': [200,400,800], 89 | 'num_blocks': [4,8,4], 90 | 'groups': 2 91 | } 92 | return ShuffleNet(cfg) 93 | 94 | def ShuffleNetG3(): 95 | cfg = { 96 | 'out_planes': [240,480,960], 97 | 'num_blocks': [4,8,4], 98 | 'groups': 3 99 | } 100 | return ShuffleNet(cfg) 101 | 102 | 103 | def test(): 104 | net = ShuffleNetG2() 105 | x = torch.randn(1,3,32,32) 106 | y = net(x) 107 | print(y) 108 | 109 | # test() 110 | -------------------------------------------------------------------------------- /cifar10/models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(planes) 23 | ) 24 | 25 | # SE layers 26 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 27 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | 33 | # Squeeze 34 | w = F.avg_pool2d(out, out.size(2)) 35 | w = F.relu(self.fc1(w)) 36 | w = F.sigmoid(self.fc2(w)) 37 | # Excitation 38 | out = out * w # New broadcasting feature from v0.2! 
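        # Squeeze-and-Excitation recap: w is a per-channel gate in (0, 1) computed by
        # global average pooling -> 1x1 conv (16x channel reduction) -> ReLU -> 1x1 conv
        # -> sigmoid; multiplying by w rescales each channel before the residual add below.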
39 | 40 | out += self.shortcut(x) 41 | out = F.relu(out) 42 | return out 43 | 44 | 45 | class PreActBlock(nn.Module): 46 | def __init__(self, in_planes, planes, stride=1): 47 | super(PreActBlock, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(in_planes) 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 52 | 53 | if stride != 1 or in_planes != planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 56 | ) 57 | 58 | # SE layers 59 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 60 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(x)) 64 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 65 | out = self.conv1(out) 66 | out = self.conv2(F.relu(self.bn2(out))) 67 | 68 | # Squeeze 69 | w = F.avg_pool2d(out, out.size(2)) 70 | w = F.relu(self.fc1(w)) 71 | w = F.sigmoid(self.fc2(w)) 72 | # Excitation 73 | out = out * w 74 | 75 | out += shortcut 76 | return out 77 | 78 | 79 | class SENet(nn.Module): 80 | def __init__(self, block, num_blocks, num_classes=10): 81 | super(SENet, self).__init__() 82 | self.in_planes = 64 83 | 84 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(64) 86 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 87 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 88 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 89 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 90 | self.linear = nn.Linear(512, num_classes) 91 | 92 | def _make_layer(self, block, planes, num_blocks, stride): 93 | strides = [stride] + [1]*(num_blocks-1) 94 | layers = [] 95 | for stride in strides: 96 | layers.append(block(self.in_planes, planes, stride)) 97 | self.in_planes = planes 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = F.avg_pool2d(out, 4) 107 | out = out.view(out.size(0), -1) 108 | out = self.linear(out) 109 | return out 110 | 111 | 112 | def SENet18(): 113 | return SENet(PreActBlock, [2,2,2,2]) 114 | 115 | 116 | def test(): 117 | net = SENet18() 118 | y = net(torch.randn(1,3,32,32)) 119 | print(y.size()) 120 | 121 | # test() 122 | -------------------------------------------------------------------------------- /cifar10/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class PreActBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1): 17 | super(PreActBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(x)) 30 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 31 | out = self.conv1(out) 32 | out = self.conv2(F.relu(self.bn2(out))) 33 | out += shortcut 34 | return out 35 | 36 | 37 | class PreActBottleneck(nn.Module): 38 | '''Pre-activation version of the original Bottleneck module.''' 39 | expansion = 4 40 | 41 | def __init__(self, in_planes, planes, stride=1): 42 | super(PreActBottleneck, self).__init__() 43 | self.bn1 = nn.BatchNorm2d(in_planes) 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(x)) 57 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 58 | out = self.conv1(out) 59 | out = self.conv2(F.relu(self.bn2(out))) 60 | out = self.conv3(F.relu(self.bn3(out))) 61 | out += shortcut 62 | return out 63 | 64 | 65 | class PreActResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(PreActResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def PreActResNet18(): 98 | return PreActResNet(PreActBlock, [2,2,2,2]) 99 | 100 | def PreActResNet34(): 101 | return PreActResNet(PreActBlock, [3,4,6,3]) 102 | 103 | def PreActResNet50(): 104 | return 
PreActResNet(PreActBottleneck, [3,4,6,3]) 105 | 106 | def PreActResNet101(): 107 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 108 | 109 | def PreActResNet152(): 110 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 111 | 112 | 113 | def test(): 114 | net = PreActResNet18() 115 | y = net((torch.randn(1,3,32,32))) 116 | print(y.size()) 117 | 118 | # test() 119 | -------------------------------------------------------------------------------- /cifar10/models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 | # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = 
self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /imagenet/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import torch 11 | import copy 12 | import numpy as np 13 | 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | 17 | class MovingMaximum(object): 18 | def __init__(self): 19 | self.data = [] # data[i] is the maximum val in data[0:i+1] 20 | self.max = 0.0 21 | 22 | def push(self, current_data): 23 | if len(self.data) == 0: 24 | self.max = current_data 25 | elif current_data > self.max: 26 | self.max = current_data 27 | self.data.append(self.max) 28 | 29 | def get(self): 30 | return self.data 31 | 32 | def delta(self, start, end): 33 | try: 34 | res = self.data[end] - self.data[start] 35 | except IndexError: 36 | res = self.data[end] 37 | return res 38 | 39 | class ExponentialMovingAverage(object): 40 | def __init__(self, decay=0.95): 41 | self.data = [] 42 | self.decay = decay 43 | self.avg_val = 0.0 44 | 45 | def push(self, current_data): 46 | self.avg_val = self.decay * self.avg_val + (1 - self.decay) * current_data 47 | self.data.append(self.avg_val) 48 | 49 | def get(self): 50 | return self.data 51 | 52 | def delta(self, start, end): 53 | try: 54 | res = self.data[end] - self.data[start] 55 | except IndexError: 56 | res = self.data[end] 57 | return res 58 | 59 | class TorchExponentialMovingAverage(object): 60 | def __init__(self, decay=0.999): 61 | self.decay = decay 62 | self.ema = {} 63 | self.number = {} 64 | 65 | def push(self, current_data): 66 | assert isinstance(current_data, dict), "current_data should be a dict" 67 | for key in current_data: 68 | if key in self.ema: 69 | # in-place 70 | self.ema[key] -= (1.0 - self.decay) * (self.ema[key] - current_data[key]) 71 | self.number[key] += 1 72 | else: 73 | # self.ema[key] = copy.deepcopy(current_data[key]) 74 | self.ema[key] = current_data[key] * (1.0 - self.decay) 75 | self.number[key] = 1 76 | 77 | def average(self): 78 | scaled_ema = {} 79 | for key in self.ema: 80 | scaled_ema[key] = self.ema[key] 
/ (1.0 - self.decay ** self.number[key]) 81 | return scaled_ema 82 | 83 | 84 | # net: pytorch module 85 | # strict: strict matching for set 86 | def set_named_parameters(net, named_params, strict=True): 87 | assert isinstance(named_params, dict), "named_params should be a dict" 88 | orig_params_data = {} 89 | for n, p in net.named_parameters(): 90 | orig_params_data[n] = copy.deepcopy(p.data) 91 | if strict: 92 | assert len(named_params) == len(list(net.named_parameters())), "Unmatched number of params!" 93 | for n, p in net.named_parameters(): 94 | if strict: 95 | assert n in named_params, "Unknown param name!" 96 | if n in named_params: 97 | p.data.copy_(named_params[n]) 98 | 99 | return orig_params_data 100 | 101 | def next_group(g, maxlim, arch, logger): 102 | if g < 0 or g >= len(maxlim): 103 | logger.info('group index %d is out of range.' % g) 104 | return -1 105 | for i in range(len(maxlim)): 106 | idx = (g+i+1)%len(maxlim) 107 | if maxlim[idx] > arch[idx]: 108 | return idx 109 | return -1 110 | 111 | def next_arch(mode, maxlim, arch, logger, sub=None, rate=0.333, group=0): 112 | tmp_arch = [v for v in arch] 113 | if 'all' == mode: 114 | tmp_arch = [v+1 for v in tmp_arch] 115 | elif 'group' == mode and group >= 0 and group < len(arch): 116 | tmp_arch[group] += 1 117 | elif 'rate' == mode: 118 | num = int(round(sum(arch)*rate)) 119 | while num >= len(arch): 120 | tmp_arch = [v + 1 for v in tmp_arch] 121 | num -= len(arch) 122 | if num: 123 | rperm = np.random.permutation(len(arch)) 124 | for idx in rperm[:num]: 125 | tmp_arch[idx] += 1 126 | elif 'sub' == mode and (sub is not None): 127 | for idx, val in enumerate(tmp_arch): 128 | tmp_arch[idx] += sub[idx] 129 | else: 130 | logger.fatal('Unknown mode') 131 | exit() 132 | 133 | res = [] 134 | for r, m in zip(tmp_arch, maxlim): 135 | if r > m: 136 | res.append(m) 137 | else: 138 | res.append(r) 139 | return res -------------------------------------------------------------------------------- /cifar10/utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | import torch 11 | import copy 12 | import numpy as np 13 | 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | 17 | datasets = { 'CIFAR10': {'mean': (0.4914, 0.4822, 0.4465), 'std': (0.2023, 0.1994, 0.2010), 'num_classes': 10, 'image_channels': 3, 'size': 32}, # std is wrong. 
should be [0.2471, 0.2435, 0.2616] 18 | 'CIFAR100': {'mean': (0.5071, 0.4865, 0.4409), 'std': (0.2673, 0.2564, 0.2762), 'num_classes': 100, 'image_channels': 3, 'size': 32}, 19 | 'SVHN': {'mean': (0.4377, 0.4438, 0.4728), 'std': (0.1980, 0.2010, 0.1970), 'num_classes': 10, 'image_channels': 3, 'size': 32}, 20 | 21 | 'FashionMNIST': {'mean': (0.0,), 'std': (1.0,), 'num_classes': 10, 'image_channels': 1, 'size': 28}, 22 | 'MNIST': {'mean': (0.0,), 'std': (1.0,), 'num_classes': 10, 'image_channels': 1, 'size': 28}, 23 | } 24 | 25 | class MovingMaximum(object): 26 | def __init__(self): 27 | self.data = [] # data[i] is the maximum val in data[0:i+1] 28 | self.max = 0.0 29 | 30 | def push(self, current_data): 31 | if len(self.data) == 0: 32 | self.max = current_data 33 | elif current_data > self.max: 34 | self.max = current_data 35 | self.data.append(self.max) 36 | 37 | def get(self): 38 | return self.data 39 | 40 | def delta(self, start, end): 41 | try: 42 | res = self.data[end] - self.data[start] 43 | except IndexError: 44 | res = self.data[end] 45 | return res 46 | 47 | class ExponentialMovingAverage(object): 48 | def __init__(self, decay=0.95): 49 | self.data = [] 50 | self.decay = decay 51 | self.avg_val = 0.0 52 | 53 | def push(self, current_data): 54 | self.avg_val = self.decay * self.avg_val + (1 - self.decay) * current_data 55 | self.data.append(self.avg_val) 56 | 57 | def get(self): 58 | return self.data 59 | 60 | def delta(self, start, end): 61 | try: 62 | res = self.data[end] - self.data[start] 63 | except IndexError: 64 | res = self.data[end] 65 | return res 66 | 67 | class TorchExponentialMovingAverage(object): 68 | def __init__(self, decay=0.999): 69 | self.decay = decay 70 | self.ema = {} 71 | self.number = {} 72 | 73 | def push(self, current_data): 74 | assert isinstance(current_data, dict), "current_data should be a dict" 75 | for key in current_data: 76 | if key in self.ema: 77 | # in-place 78 | self.ema[key] -= (1.0 - self.decay) * (self.ema[key] - current_data[key]) 79 | self.number[key] += 1 80 | else: 81 | # self.ema[key] = copy.deepcopy(current_data[key]) 82 | self.ema[key] = current_data[key] * (1.0 - self.decay) 83 | self.number[key] = 1 84 | 85 | def average(self): 86 | scaled_ema = {} 87 | for key in self.ema: 88 | scaled_ema[key] = self.ema[key] / (1.0 - self.decay ** self.number[key]) 89 | return scaled_ema 90 | 91 | 92 | # net: pytorch module 93 | # strict: strict matching for set 94 | def set_named_parameters(net, named_params, strict=True): 95 | assert isinstance(named_params, dict), "named_params should be a dict" 96 | orig_params_data = {} 97 | for n, p in net.named_parameters(): 98 | orig_params_data[n] = copy.deepcopy(p.data) 99 | if strict: 100 | assert len(named_params) == len(list(net.named_parameters())), "Unmatched number of params!" 101 | for n, p in net.named_parameters(): 102 | if strict: 103 | assert n in named_params, "Unknown param name!" 104 | if n in named_params: 105 | p.data.copy_(named_params[n]) 106 | 107 | return orig_params_data 108 | 109 | def next_group(g, maxlim, arch, logger): 110 | if g < 0 or g >= len(maxlim): 111 | logger.info('group index %d is out of range.' 
% g) 112 | return -1 113 | for i in range(len(maxlim)): 114 | idx = (g+i+1)%len(maxlim) 115 | if maxlim[idx] > arch[idx]: 116 | return idx 117 | return -1 118 | 119 | def next_arch(mode, maxlim, arch, logger, sub=None, rate=0.333, group=0): 120 | tmp_arch = [v for v in arch] 121 | if 'all' == mode: 122 | tmp_arch = [v+1 for v in tmp_arch] 123 | elif 'group' == mode and group >= 0 and group < len(arch): 124 | tmp_arch[group] += 1 125 | elif 'rate' == mode: 126 | num = int(round(sum(arch)*rate)) 127 | while num >= len(arch): 128 | tmp_arch = [v + 1 for v in tmp_arch] 129 | num -= len(arch) 130 | if num: 131 | rperm = np.random.permutation(len(arch)) 132 | for idx in rperm[:num]: 133 | tmp_arch[idx] += 1 134 | elif 'sub' == mode and (sub is not None): 135 | for idx, val in enumerate(tmp_arch): 136 | tmp_arch[idx] += sub[idx] 137 | else: 138 | logger.fatal('Unknown mode') 139 | exit() 140 | 141 | res = [] 142 | for r, m in zip(tmp_arch, maxlim): 143 | if r > m: 144 | res.append(m) 145 | else: 146 | res.append(r) 147 | return res -------------------------------------------------------------------------------- /imagenet/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['PlainNet', 'ResNet', 'ResNetBasic', 'ResNetBottleneck', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | class PlainBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(PlainBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | return out 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None): 48 | super(BasicBlock, self).__init__() 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = nn.BatchNorm2d(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 
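# ('expansion' is the block's output-width multiplier: the closing 1x1 conv maps
# 'planes' channels up to planes * 4, and _make_layer below uses it to size both
# the downsample shortcut and self.inplanes for the next block.)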
78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None): 80 | super(Bottleneck, self).__init__() 81 | self.conv1 = conv1x1(inplanes, planes) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | self.conv2 = conv3x3(planes, planes, stride) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | self.conv3 = conv1x1(planes, planes * self.expansion) 86 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.downsample = downsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 | out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | identity = self.downsample(x) 107 | 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 | 113 | 114 | class ResNet(nn.Module): 115 | 116 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 117 | super(ResNet, self).__init__() 118 | self.inplanes = 64 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 120 | bias=False) 121 | self.bn1 = nn.BatchNorm2d(64) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, 64, layers[0]) 125 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 128 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 129 | self.fc = nn.Linear(512 * block.expansion, num_classes) 130 | 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 134 | elif isinstance(m, nn.BatchNorm2d): 135 | nn.init.constant_(m.weight, 1) 136 | nn.init.constant_(m.bias, 0) 137 | 138 | # Zero-initialize the last BN in each residual branch, 139 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 140 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 141 | if zero_init_residual: 142 | for m in self.modules(): 143 | if isinstance(m, Bottleneck): 144 | nn.init.constant_(m.bn3.weight, 0) 145 | elif isinstance(m, BasicBlock): 146 | nn.init.constant_(m.bn2.weight, 0) 147 | 148 | def _make_layer(self, block, planes, blocks, stride=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | conv1x1(self.inplanes, planes * block.expansion, stride), 153 | nn.BatchNorm2d(planes * block.expansion), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.inplanes, planes, stride, downsample)) 158 | self.inplanes = planes * block.expansion 159 | for _ in range(1, blocks): 160 | layers.append(block(self.inplanes, planes)) 161 | 162 | return nn.Sequential(*layers) 163 | 164 | def forward(self, x): 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | x = self.relu(x) 168 | x = self.maxpool(x) 169 | 170 | x = self.layer1(x) 171 | x = self.layer2(x) 172 | x = self.layer3(x) 173 | x = self.layer4(x) 174 | 175 | x = self.avgpool(x) 176 | x = x.view(x.size(0), -1) 177 | x = self.fc(x) 178 | 179 | return x 180 | 181 | def ResNetBasic(num_blocks): 182 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." 
% len(num_blocks) 183 | return ResNet(BasicBlock, num_blocks) 184 | 185 | def PlainNet(num_blocks): 186 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." % len(num_blocks) 187 | return ResNet(PlainBlock, num_blocks) 188 | 189 | def ResNetBottleneck(num_blocks): 190 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." % len(num_blocks) 191 | return ResNet(Bottleneck, num_blocks) 192 | 193 | def resnet18(pretrained=False, **kwargs): 194 | """Constructs a ResNet-18 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 202 | return model 203 | 204 | 205 | def resnet34(pretrained=False, **kwargs): 206 | """Constructs a ResNet-34 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 214 | return model 215 | 216 | 217 | def resnet50(pretrained=False, **kwargs): 218 | """Constructs a ResNet-50 model. 219 | 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 224 | if pretrained: 225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 226 | return model 227 | 228 | 229 | def resnet101(pretrained=False, **kwargs): 230 | """Constructs a ResNet-101 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 236 | if pretrained: 237 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 238 | return model 239 | 240 | 241 | def resnet152(pretrained=False, **kwargs): 242 | """Constructs a ResNet-152 model. 243 | 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | """ 247 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 248 | if pretrained: 249 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 250 | return model 251 | -------------------------------------------------------------------------------- /cifar10/models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import pswitch 13 | 14 | class BasicSwitchBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicSwitchBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.switch = pswitch.PSwitch() 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out = self.switch(out) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | class PlainBlock(nn.Module): 41 | expansion = 1 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PlainBlock, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | return out 51 | 52 | class PlainNoBNBlock(nn.Module): 53 | expansion = 1 54 | 55 | def __init__(self, in_planes, planes, stride=1): 56 | super(PlainNoBNBlock, self).__init__() 57 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.conv1(x)) 61 | return out 62 | 63 | class BasicBlock(nn.Module): 64 | expansion = 1 65 | 66 | def __init__(self, in_planes, planes, stride=1): 67 | super(BasicBlock, self).__init__() 68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | 73 | self.shortcut = nn.Sequential() 74 | if stride != 1 or in_planes != self.expansion*planes: 75 | self.shortcut = nn.Sequential( 76 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 77 | nn.BatchNorm2d(self.expansion*planes) 78 | ) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | out = self.bn2(self.conv2(out)) 83 | out += self.shortcut(x) 84 | out = F.relu(out) 85 | return out 86 | 87 | 88 | class Bottleneck(nn.Module): 89 | expansion = 4 90 | 91 | def __init__(self, in_planes, planes, stride=1): 92 | super(Bottleneck, self).__init__() 93 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(planes) 95 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 96 | self.bn2 = nn.BatchNorm2d(planes) 97 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 98 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 99 | 100 | self.shortcut = nn.Sequential() 101 | if stride != 1 or in_planes != self.expansion*planes: 102 | self.shortcut = nn.Sequential( 103 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(self.expansion*planes) 105 | ) 106 | 107 | 
def forward(self, x): 108 | out = F.relu(self.bn1(self.conv1(x))) 109 | out = F.relu(self.bn2(self.conv2(out))) 110 | out = self.bn3(self.conv3(out)) 111 | out += self.shortcut(x) 112 | out = F.relu(out) 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | def __init__(self, block, num_blocks, num_classes=10, image_channels=3, batchnorm=True): 118 | super(ResNet, self).__init__() 119 | self.in_planes = 64 120 | if batchnorm: 121 | self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=3, stride=1, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(64) 123 | else: 124 | self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=3, stride=1, padding=1, bias=True) 125 | self.bn1 = nn.Sequential() 126 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 127 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 128 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 129 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 130 | self.linear = nn.Linear(512*block.expansion, num_classes) 131 | 132 | def _make_layer(self, block, planes, num_blocks, stride): 133 | strides = [stride] + [1]*(num_blocks-1) 134 | layers = [] 135 | for stride in strides: 136 | layers.append(block(self.in_planes, planes, stride)) 137 | self.in_planes = planes * block.expansion 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | out = F.relu(self.bn1(self.conv1(x))) 142 | out = self.layer1(out) 143 | out = self.layer2(out) 144 | out = self.layer3(out) 145 | out = self.layer4(out) 146 | out = F.avg_pool2d(out, 4) 147 | out = out.view(out.size(0), -1) 148 | out = self.linear(out) 149 | return out 150 | 151 | class CifarResNet(nn.Module): 152 | def __init__(self, block, num_blocks, num_classes=10, image_channels=3, batchnorm=True): 153 | super(CifarResNet, self).__init__() 154 | self.in_planes = 16 155 | if batchnorm: 156 | self.conv1 = nn.Conv2d(image_channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 157 | self.bn1 = nn.BatchNorm2d(16) 158 | else: 159 | self.conv1 = nn.Conv2d(image_channels, 16, kernel_size=3, stride=1, padding=1, bias=True) 160 | self.bn1 = nn.Sequential() 161 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 162 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 163 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 164 | self.linear = nn.Linear(64*block.expansion, num_classes) 165 | 166 | def _make_layer(self, block, planes, num_blocks, stride): 167 | strides = [stride] + [1]*(num_blocks-1) 168 | layers = [] 169 | for stride in strides: 170 | layers.append(block(self.in_planes, planes, stride)) 171 | self.in_planes = planes * block.expansion 172 | return nn.Sequential(*layers) 173 | 174 | def forward(self, x): 175 | out = F.relu(self.bn1(self.conv1(x))) 176 | out = self.layer1(out) 177 | out = self.layer2(out) 178 | out = self.layer3(out) 179 | out = F.avg_pool2d(out, 8) 180 | out = out.view(out.size(0), -1) 181 | out = self.linear(out) 182 | return out 183 | 184 | def CifarResNetBasic(num_blocks, num_classes=10, image_channels=3): 185 | assert len(num_blocks) == 3, "3 blocks are needed, but %d is found." 
% len(num_blocks) 186 | print ('num_classes=%d, image_channels=%d' % (num_classes, image_channels)) 187 | return CifarResNet(BasicBlock, num_blocks, num_classes=num_classes, image_channels=image_channels) 188 | 189 | def CifarPlainNet(num_blocks, num_classes=10, image_channels=3): 190 | assert len(num_blocks) == 3, "3 blocks are needed, but %d is found." % len(num_blocks) 191 | print ('num_classes=%d, image_channels=%d' % (num_classes, image_channels)) 192 | # CifarResNet is NOT a ResNet, it's just a building func 193 | return CifarResNet(PlainBlock, num_blocks, num_classes=num_classes, image_channels=image_channels) 194 | 195 | def CifarPlainNoBNNet(num_blocks, num_classes=10, image_channels=3): 196 | assert len(num_blocks) == 3, "3 blocks are needed, but %d is found." % len(num_blocks) 197 | print ('num_classes=%d, image_channels=%d' % (num_classes, image_channels)) 198 | # CifarResNet is NOT a ResNet, it's just a building func 199 | return CifarResNet(PlainNoBNBlock, num_blocks, num_classes=num_classes, image_channels=image_channels, batchnorm=False) 200 | 201 | def CifarSwitchResNetBasic(num_blocks, num_classes=10, image_channels=3): 202 | assert len(num_blocks) == 3, "3 blocks are needed, but %d is found." % len(num_blocks) 203 | return CifarResNet(BasicSwitchBlock, num_blocks, num_classes=num_classes, image_channels=image_channels) 204 | 205 | def PlainNet(num_blocks, num_classes=10, image_channels=3): 206 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." % len(num_blocks) 207 | # ResNet is NOT a ResNet, it's just a building func 208 | return ResNet(PlainBlock, num_blocks, num_classes=num_classes, image_channels=image_channels) 209 | 210 | def PlainNoBNNet(num_blocks, num_classes=10, image_channels=3): 211 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." % len(num_blocks) 212 | # ResNet is NOT a ResNet, it's just a building func 213 | return ResNet(PlainNoBNBlock, num_blocks, num_classes=num_classes, image_channels=image_channels, batchnorm=False) 214 | 215 | def ResNetBasic(num_blocks, num_classes=10, image_channels=3): 216 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." % len(num_blocks) 217 | return ResNet(BasicBlock, num_blocks, num_classes=num_classes, image_channels=image_channels) 218 | 219 | def ResNetBottleneck(num_blocks, num_classes=10, image_channels=3): 220 | assert len(num_blocks) == 4, "4 blocks are needed, but %d is found." 
% len(num_blocks) 221 | return ResNet(Bottleneck, num_blocks, num_classes=num_classes, image_channels=image_channels) 222 | 223 | def ResNet18(): 224 | return ResNet(BasicBlock, [2,2,2,2]) 225 | 226 | def ResNet34(): 227 | return ResNet(BasicBlock, [3,4,6,3]) 228 | 229 | def ResNet50(): 230 | return ResNet(Bottleneck, [3,4,6,3]) 231 | 232 | def ResNet101(): 233 | return ResNet(Bottleneck, [3,4,23,3]) 234 | 235 | def ResNet152(): 236 | return ResNet(Bottleneck, [3,8,36,3]) 237 | 238 | 239 | def test(): 240 | net = ResNet18() 241 | y = net(torch.randn(1,3,32,32)) 242 | print(y.size()) 243 | 244 | # test() 245 | -------------------------------------------------------------------------------- /cifar10/flops_counter.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/sovrasov/flops-counter.pytorch 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | 7 | def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True): 8 | assert type(input_res) is tuple 9 | assert len(input_res) == 2 10 | batch = torch.FloatTensor(1, 3, *input_res) 11 | flops_model = add_flops_counting_methods(model) 12 | flops_model.eval().start_flops_count() 13 | out = flops_model(batch) 14 | 15 | if print_per_layer_stat: 16 | print_model_with_flops(flops_model) 17 | flops_count = flops_model.compute_average_flops_cost() 18 | params_count = get_model_parameters_number(flops_model) 19 | flops_model.stop_flops_count() 20 | 21 | if as_strings: 22 | return flops_to_string(flops_count), params_to_string(params_count) 23 | 24 | return flops_count, params_count 25 | 26 | def flops_to_string(flops, units='GMac', precision=3): 27 | if units is None: 28 | if flops // 10**9 > 0: 29 | return str(round(flops / 10.**9, precision)) + ' GMac' 30 | elif flops // 10**6 > 0: 31 | return str(round(flops / 10.**6, precision)) + ' MMac' 32 | elif flops // 10**3 > 0: 33 | return str(round(flops / 10.**3, precision)) + ' KMac' 34 | else: 35 | return str(flops) + ' Mac' 36 | else: 37 | if units == 'GMac': 38 | return str(round(flops / 10.**9, precision)) + ' ' + units 39 | elif units == 'MMac': 40 | return str(round(flops / 10.**6, precision)) + ' ' + units 41 | elif units == 'KMac': 42 | return str(round(flops / 10.**3, precision)) + ' ' + units 43 | else: 44 | return str(flops) + ' Mac' 45 | 46 | def params_to_string(params_num): 47 | if params_num // 10 ** 6 > 0: 48 | return str(round(params_num / 10. ** 6, 3)) + ' M' 49 | elif params_num // 10 ** 3: 50 | return str(round(params_num / 10. 
** 3, 3)) + ' k' 51 | return str(params_num) # fewer than 1e3 params: report the raw count 52 | def print_model_with_flops(model, units='GMac', precision=3): 53 | total_flops = model.compute_average_flops_cost() 54 | 55 | def accumulate_flops(self): 56 | if is_supported_instance(self): 57 | return self.__flops__ / model.__batch_counter__ 58 | else: 59 | sum = 0 60 | for m in self.children(): 61 | sum += m.accumulate_flops() 62 | return sum 63 | 64 | def flops_repr(self): 65 | accumulated_flops_cost = self.accumulate_flops() 66 | return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), 67 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 68 | self.original_extra_repr()]) 69 | 70 | def add_extra_repr(m): 71 | m.accumulate_flops = accumulate_flops.__get__(m) 72 | flops_extra_repr = flops_repr.__get__(m) 73 | if m.extra_repr != flops_extra_repr: 74 | m.original_extra_repr = m.extra_repr 75 | m.extra_repr = flops_extra_repr 76 | assert m.extra_repr != m.original_extra_repr 77 | 78 | def del_extra_repr(m): 79 | if hasattr(m, 'original_extra_repr'): 80 | m.extra_repr = m.original_extra_repr 81 | del m.original_extra_repr 82 | if hasattr(m, 'accumulate_flops'): 83 | del m.accumulate_flops 84 | 85 | model.apply(add_extra_repr) 86 | print(model) 87 | model.apply(del_extra_repr) 88 | 89 | def get_model_parameters_number(model): 90 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 91 | return params_num 92 | 93 | def add_flops_counting_methods(net_main_module): 94 | # adding additional methods to the existing module object, 95 | # this is done this way so that each function has access to self object 96 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 97 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 98 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 99 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 100 | 101 | net_main_module.reset_flops_count() 102 | 103 | # Adding variables necessary for masked flops computation 104 | net_main_module.apply(add_flops_mask_variable_or_reset) 105 | 106 | return net_main_module 107 | 108 | 109 | def compute_average_flops_cost(self): 110 | """ 111 | A method that will be available after add_flops_counting_methods() is called 112 | on a desired net object. 113 | 114 | Returns current mean flops consumption per image. 115 | 116 | """ 117 | 118 | batches_count = self.__batch_counter__ 119 | flops_sum = 0 120 | for module in self.modules(): 121 | if is_supported_instance(module): 122 | flops_sum += module.__flops__ 123 | 124 | return flops_sum / batches_count 125 | 126 | 127 | def start_flops_count(self): 128 | """ 129 | A method that will be available after add_flops_counting_methods() is called 130 | on a desired net object. 131 | 132 | Activates the computation of mean flops consumption per image. 133 | Call it before you run the network. 134 | 135 | """ 136 | add_batch_counter_hook_function(self) 137 | self.apply(add_flops_counter_hook_function) 138 | 139 | 140 | def stop_flops_count(self): 141 | """ 142 | A method that will be available after add_flops_counting_methods() is called 143 | on a desired net object. 144 | 145 | Stops computing the mean flops consumption per image. 146 | Call whenever you want to pause the computation.
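A sketch of the intended lifecycle, mirroring get_model_complexity_info at the top of this file ('batch' stands for any input tensor):

    model = add_flops_counting_methods(model)
    model.eval().start_flops_count()
    out = model(batch)                          # forward hooks accumulate __flops__
    flops = model.compute_average_flops_cost()  # mean MACs per image
    model.stop_flops_count()                    # detach the hooks again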
147 | 148 | """ 149 | remove_batch_counter_hook_function(self) 150 | self.apply(remove_flops_counter_hook_function) 151 | 152 | 153 | def reset_flops_count(self): 154 | """ 155 | A method that will be available after add_flops_counting_methods() is called 156 | on a desired net object. 157 | 158 | Resets statistics computed so far. 159 | 160 | """ 161 | add_batch_counter_variables_or_reset(self) 162 | self.apply(add_flops_counter_variable_or_reset) 163 | 164 | 165 | def add_flops_mask(module, mask): 166 | def add_flops_mask_func(module): 167 | if isinstance(module, torch.nn.Conv2d): 168 | module.__mask__ = mask 169 | module.apply(add_flops_mask_func) 170 | 171 | 172 | def remove_flops_mask(module): 173 | module.apply(add_flops_mask_variable_or_reset) 174 | 175 | 176 | # ---- Internal functions 177 | def is_supported_instance(module): 178 | if isinstance(module, (torch.nn.Conv2d, torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 179 | torch.nn.LeakyReLU, torch.nn.ReLU6, torch.nn.Linear, \ 180 | torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.BatchNorm2d, \ 181 | torch.nn.Upsample, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d)): 182 | return True 183 | 184 | return False 185 | 186 | 187 | def empty_flops_counter_hook(module, input, output): 188 | module.__flops__ += 0 189 | 190 | 191 | def upsample_flops_counter_hook(module, input, output): 192 | output_size = output[0] 193 | batch_size = output_size.shape[0] 194 | output_elements_count = batch_size 195 | for val in output_size.shape[1:]: 196 | output_elements_count *= val 197 | module.__flops__ += output_elements_count 198 | 199 | 200 | def relu_flops_counter_hook(module, input, output): 201 | active_elements_count = output.numel() 202 | module.__flops__ += active_elements_count 203 | 204 | 205 | def linear_flops_counter_hook(module, input, output): 206 | input = input[0] 207 | batch_size = input.shape[0] 208 | module.__flops__ += batch_size * input.shape[1] * output.shape[1] 209 | 210 | 211 | def pool_flops_counter_hook(module, input, output): 212 | input = input[0] 213 | module.__flops__ += np.prod(input.shape) 214 | 215 | def bn_flops_counter_hook(module, input, output): 216 | module.affine 217 | input = input[0] 218 | 219 | batch_flops = np.prod(input.shape) 220 | if module.affine: 221 | batch_flops *= 2 222 | module.__flops__ += batch_flops 223 | 224 | def conv_flops_counter_hook(conv_module, input, output): 225 | # Can have multiple inputs, getting the first one 226 | input = input[0] 227 | 228 | batch_size = input.shape[0] 229 | output_height, output_width = output.shape[2:] 230 | 231 | kernel_height, kernel_width = conv_module.kernel_size 232 | in_channels = conv_module.in_channels 233 | out_channels = conv_module.out_channels 234 | groups = conv_module.groups 235 | 236 | filters_per_channel = out_channels // groups 237 | conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 238 | 239 | active_elements_count = batch_size * output_height * output_width 240 | 241 | if conv_module.__mask__ is not None: 242 | # (b, 1, h, w) 243 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 244 | active_elements_count = flops_mask.sum() 245 | 246 | overall_conv_flops = conv_per_position_flops * active_elements_count 247 | 248 | bias_flops = 0 249 | 250 | if conv_module.bias is not None: 251 | 252 | bias_flops = out_channels * active_elements_count 253 | 254 | overall_flops = overall_conv_flops + bias_flops 255 | 256 | conv_module.__flops__ += overall_flops 257 | 258 | 259 | def 
batch_counter_hook(module, input, output): 260 | # Can have multiple inputs, getting the first one 261 | input = input[0] 262 | batch_size = input.shape[0] 263 | module.__batch_counter__ += batch_size 264 | 265 | 266 | def add_batch_counter_variables_or_reset(module): 267 | 268 | module.__batch_counter__ = 0 269 | 270 | 271 | def add_batch_counter_hook_function(module): 272 | if hasattr(module, '__batch_counter_handle__'): 273 | return 274 | 275 | handle = module.register_forward_hook(batch_counter_hook) 276 | module.__batch_counter_handle__ = handle 277 | 278 | 279 | def remove_batch_counter_hook_function(module): 280 | if hasattr(module, '__batch_counter_handle__'): 281 | module.__batch_counter_handle__.remove() 282 | del module.__batch_counter_handle__ 283 | 284 | 285 | def add_flops_counter_variable_or_reset(module): 286 | if is_supported_instance(module): 287 | module.__flops__ = 0 288 | 289 | 290 | def add_flops_counter_hook_function(module): 291 | if is_supported_instance(module): 292 | if hasattr(module, '__flops_handle__'): 293 | return 294 | 295 | if isinstance(module, torch.nn.Conv2d): 296 | handle = module.register_forward_hook(conv_flops_counter_hook) 297 | elif isinstance(module, (torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 298 | torch.nn.LeakyReLU, torch.nn.ReLU6)): 299 | handle = module.register_forward_hook(relu_flops_counter_hook) 300 | elif isinstance(module, torch.nn.Linear): 301 | handle = module.register_forward_hook(linear_flops_counter_hook) 302 | elif isinstance(module, (torch.nn.AvgPool2d, torch.nn.MaxPool2d, nn.AdaptiveMaxPool2d, \ 303 | nn.AdaptiveAvgPool2d)): 304 | handle = module.register_forward_hook(pool_flops_counter_hook) 305 | elif isinstance(module, torch.nn.BatchNorm2d): 306 | handle = module.register_forward_hook(bn_flops_counter_hook) 307 | elif isinstance(module, torch.nn.Upsample): 308 | handle = module.register_forward_hook(upsample_flops_counter_hook) 309 | else: 310 | handle = module.register_forward_hook(empty_flops_counter_hook) 311 | module.__flops_handle__ = handle 312 | 313 | 314 | def remove_flops_counter_hook_function(module): 315 | if is_supported_instance(module): 316 | if hasattr(module, '__flops_handle__'): 317 | module.__flops_handle__.remove() 318 | del module.__flops_handle__ 319 | # --- Masked flops counting 320 | 321 | 322 | # Also being run in the initialization 323 | def add_flops_mask_variable_or_reset(module): 324 | if is_supported_instance(module): 325 | module.__mask__ = None 326 | -------------------------------------------------------------------------------- /imagenet/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.utils.data 15 | import torch.utils.data.distributed 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torchvision.models as models 19 | 20 | model_names = sorted(name for name in models.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and callable(models.__dict__[name])) 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 25 | parser.add_argument('data', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', 
default='resnet18', 28 | choices=model_names, 29 | help='model architecture: ' + 30 | ' | '.join(model_names) + 31 | ' (default: resnet18)') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 37 | help='manual epoch number (useful on restarts)') 38 | parser.add_argument('-b', '--batch-size', default=256, type=int, 39 | metavar='N', help='mini-batch size (default: 256)') 40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | metavar='LR', help='initial learning rate') 42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 45 | metavar='W', help='weight decay (default: 1e-4)') 46 | parser.add_argument('--print-freq', '-p', default=10, type=int, 47 | metavar='N', help='print frequency (default: 10)') 48 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 51 | help='evaluate model on validation set') 52 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 53 | help='use pre-trained model') 54 | parser.add_argument('--world-size', default=1, type=int, 55 | help='number of distributed processes') 56 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 57 | help='url used to set up distributed training') 58 | parser.add_argument('--dist-backend', default='gloo', type=str, 59 | help='distributed backend') 60 | parser.add_argument('--seed', default=None, type=int, 61 | help='seed for initializing training. ') 62 | parser.add_argument('--gpu', default=None, type=int, 63 | help='GPU id to use.') 64 | 65 | best_prec1 = 0 66 | 67 | 68 | def main(): 69 | global args, best_prec1 70 | args = parser.parse_args() 71 | 72 | if args.seed is not None: 73 | random.seed(args.seed) 74 | torch.manual_seed(args.seed) 75 | cudnn.deterministic = True 76 | warnings.warn('You have chosen to seed training. ' 77 | 'This will turn on the CUDNN deterministic setting, ' 78 | 'which can slow down your training considerably! ' 79 | 'You may see unexpected behavior when restarting ' 80 | 'from checkpoints.') 81 | 82 | if args.gpu is not None: 83 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 84 | 'disable data parallelism.') 85 | 86 | args.distributed = args.world_size > 1 87 | 88 | if args.distributed: 89 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 90 | world_size=args.world_size) 91 | 92 | # create model 93 | if args.pretrained: 94 | print("=> using pre-trained model '{}'".format(args.arch)) 95 | model = models.__dict__[args.arch](pretrained=True) 96 | else: 97 | print("=> creating model '{}'".format(args.arch)) 98 | model = models.__dict__[args.arch]() 99 | 100 | if args.gpu is not None: 101 | model = model.cuda(args.gpu) 102 | elif args.distributed: 103 | model.cuda() 104 | model = torch.nn.parallel.DistributedDataParallel(model) 105 | else: 106 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 107 | model.features = torch.nn.DataParallel(model.features) 108 | model.cuda() 109 | else: 110 | model = torch.nn.DataParallel(model).cuda() 111 | 112 | # define loss function (criterion) and optimizer 113 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 114 | 115 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 116 | momentum=args.momentum, 117 | weight_decay=args.weight_decay) 118 | 119 | # optionally resume from a checkpoint 120 | if args.resume: 121 | if os.path.isfile(args.resume): 122 | print("=> loading checkpoint '{}'".format(args.resume)) 123 | checkpoint = torch.load(args.resume) 124 | args.start_epoch = checkpoint['epoch'] 125 | best_prec1 = checkpoint['best_prec1'] 126 | model.load_state_dict(checkpoint['state_dict']) 127 | optimizer.load_state_dict(checkpoint['optimizer']) 128 | print("=> loaded checkpoint '{}' (epoch {})" 129 | .format(args.resume, checkpoint['epoch'])) 130 | else: 131 | print("=> no checkpoint found at '{}'".format(args.resume)) 132 | 133 | cudnn.benchmark = True 134 | 135 | # Data loading code 136 | traindir = os.path.join(args.data, 'train') 137 | valdir = os.path.join(args.data, 'val') 138 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 139 | std=[0.229, 0.224, 0.225]) 140 | 141 | train_dataset = datasets.ImageFolder( 142 | traindir, 143 | transforms.Compose([ 144 | transforms.RandomResizedCrop(224), 145 | transforms.RandomHorizontalFlip(), 146 | transforms.ToTensor(), 147 | normalize, 148 | ])) 149 | 150 | if args.distributed: 151 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 152 | else: 153 | train_sampler = None 154 | 155 | train_loader = torch.utils.data.DataLoader( 156 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 157 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 158 | 159 | val_loader = torch.utils.data.DataLoader( 160 | datasets.ImageFolder(valdir, transforms.Compose([ 161 | transforms.Resize(256), 162 | transforms.CenterCrop(224), 163 | transforms.ToTensor(), 164 | normalize, 165 | ])), 166 | batch_size=args.batch_size, shuffle=False, 167 | num_workers=args.workers, pin_memory=True) 168 | 169 | if args.evaluate: 170 | validate(val_loader, model, criterion) 171 | return 172 | 173 | for epoch in range(args.start_epoch, args.epochs): 174 | if args.distributed: 175 | train_sampler.set_epoch(epoch) 176 | adjust_learning_rate(optimizer, epoch) 177 | 178 | # train for one epoch 179 | train(train_loader, model, criterion, optimizer, epoch) 180 | 181 | # evaluate on validation set 182 | prec1 = validate(val_loader, model, criterion) 183 | 184 | # remember best prec@1 and save checkpoint 185 | is_best = prec1 > best_prec1 186 | best_prec1 = max(prec1, 
best_prec1) 187 | save_checkpoint({ 188 | 'epoch': epoch + 1, 189 | 'arch': args.arch, 190 | 'state_dict': model.state_dict(), 191 | 'best_prec1': best_prec1, 192 | 'optimizer' : optimizer.state_dict(), 193 | }, is_best) 194 | 195 | 196 | def train(train_loader, model, criterion, optimizer, epoch): 197 | batch_time = AverageMeter() 198 | data_time = AverageMeter() 199 | losses = AverageMeter() 200 | top1 = AverageMeter() 201 | top5 = AverageMeter() 202 | 203 | # switch to train mode 204 | model.train() 205 | 206 | end = time.time() 207 | for i, (input, target) in enumerate(train_loader): 208 | # measure data loading time 209 | data_time.update(time.time() - end) 210 | 211 | if args.gpu is not None: 212 | input = input.cuda(args.gpu, non_blocking=True) 213 | target = target.cuda(args.gpu, non_blocking=True) 214 | 215 | # compute output 216 | output = model(input) 217 | loss = criterion(output, target) 218 | 219 | # measure accuracy and record loss 220 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 221 | losses.update(loss.item(), input.size(0)) 222 | top1.update(prec1[0], input.size(0)) 223 | top5.update(prec5[0], input.size(0)) 224 | 225 | # compute gradient and do SGD step 226 | optimizer.zero_grad() 227 | loss.backward() 228 | optimizer.step() 229 | 230 | # measure elapsed time 231 | batch_time.update(time.time() - end) 232 | end = time.time() 233 | 234 | if i % args.print_freq == 0: 235 | print('Epoch: [{0}][{1}/{2}]\t' 236 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 237 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 238 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 239 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 240 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 241 | epoch, i, len(train_loader), batch_time=batch_time, 242 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 243 | 244 | 245 | def validate(val_loader, model, criterion): 246 | batch_time = AverageMeter() 247 | losses = AverageMeter() 248 | top1 = AverageMeter() 249 | top5 = AverageMeter() 250 | 251 | # switch to evaluate mode 252 | model.eval() 253 | 254 | with torch.no_grad(): 255 | end = time.time() 256 | for i, (input, target) in enumerate(val_loader): 257 | if args.gpu is not None: 258 | input = input.cuda(args.gpu, non_blocking=True) 259 | target = target.cuda(args.gpu, non_blocking=True) 260 | 261 | # compute output 262 | output = model(input) 263 | loss = criterion(output, target) 264 | 265 | # measure accuracy and record loss 266 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 267 | losses.update(loss.item(), input.size(0)) 268 | top1.update(prec1[0], input.size(0)) 269 | top5.update(prec5[0], input.size(0)) 270 | 271 | # measure elapsed time 272 | batch_time.update(time.time() - end) 273 | end = time.time() 274 | 275 | if i % args.print_freq == 0: 276 | print('Test: [{0}/{1}]\t' 277 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 278 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 279 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 280 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 281 | i, len(val_loader), batch_time=batch_time, loss=losses, 282 | top1=top1, top5=top5)) 283 | 284 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 285 | .format(top1=top1, top5=top5)) 286 | 287 | return top1.avg 288 | 289 | 290 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 291 | torch.save(state, filename) 292 | if is_best: 293 | shutil.copyfile(filename, 'model_best.pth.tar') 294 | 295 | 296 | class AverageMeter(object): 297 | """Computes and stores the 
average and current value""" 298 | def __init__(self): 299 | self.reset() 300 | 301 | def reset(self): 302 | self.val = 0 303 | self.avg = 0 304 | self.sum = 0 305 | self.count = 0 306 | 307 | def update(self, val, n=1): 308 | self.val = val 309 | self.sum += val * n 310 | self.count += n 311 | self.avg = self.sum / self.count 312 | 313 | 314 | def adjust_learning_rate(optimizer, epoch): 315 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 316 | lr = args.lr * (0.1 ** (epoch // 30)) 317 | for param_group in optimizer.param_groups: 318 | param_group['lr'] = lr 319 | 320 | 321 | def accuracy(output, target, topk=(1,)): 322 | """Computes the precision@k for the specified values of k""" 323 | with torch.no_grad(): 324 | maxk = max(topk) 325 | batch_size = target.size(0) 326 | 327 | _, pred = output.topk(maxk, 1, True, True) 328 | pred = pred.t() 329 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 330 | 331 | res = [] 332 | for k in topk: 333 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 334 | res.append(correct_k.mul_(100.0 / batch_size)) 335 | return res 336 | 337 | 338 | if __name__ == '__main__': 339 | main() 340 | -------------------------------------------------------------------------------- /imagenet/GPUtil.py: -------------------------------------------------------------------------------- 1 | # GPUtil - GPU utilization 2 | # 3 | # A Python module for programmatically getting the GPU utilization from NVIDIA GPUs using nvidia-smi 4 | # 5 | # Author: Anders Krogh Mortensen (anderskm) 6 | # Date: 16 January 2017 7 | # Web: https://github.com/anderskm/gputil 8 | # 9 | # LICENSE 10 | # 11 | # MIT License 12 | # 13 | # Copyright (c) 2017 anderskm 14 | # 15 | # Permission is hereby granted, free of charge, to any person obtaining a copy 16 | # of this software and associated documentation files (the "Software"), to deal 17 | # in the Software without restriction, including without limitation the rights 18 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 19 | # copies of the Software, and to permit persons to whom the Software is 20 | # furnished to do so, subject to the following conditions: 21 | # 22 | # The above copyright notice and this permission notice shall be included in all 23 | # copies or substantial portions of the Software. 24 | # 25 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 26 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 27 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 28 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 29 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 30 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 | # SOFTWARE.
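#
# Usage sketch (illustrative, not part of the original header): pick the GPU
# with the most free memory and pin the process to it; all functions used here
# are defined below in this module, and 'os' is imported further down.
#
#   import GPUtil
#   ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5)
#   if ids:
#       os.environ['CUDA_VISIBLE_DEVICES'] = str(ids[0])
#   GPUtil.showUtilization()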
32 | 33 | from subprocess import Popen, PIPE 34 | from distutils import spawn 35 | import os 36 | import math 37 | import random 38 | import time 39 | import sys 40 | import platform 41 | 42 | 43 | __version__ = '1.4.0' 44 | 45 | class GPU: 46 | def __init__(self, ID, uuid, load, memoryTotal, memoryUsed, memoryFree, driver, gpu_name, serial, display_mode, display_active, temp_gpu): 47 | self.id = ID 48 | self.uuid = uuid 49 | self.load = load 50 | self.memoryUtil = float(memoryUsed)/float(memoryTotal) 51 | self.memoryTotal = memoryTotal 52 | self.memoryUsed = memoryUsed 53 | self.memoryFree = memoryFree 54 | self.driver = driver 55 | self.name = gpu_name 56 | self.serial = serial 57 | self.display_mode = display_mode 58 | self.display_active = display_active 59 | self.temperature = temp_gpu 60 | 61 | def safeFloatCast(strNumber): 62 | try: 63 | number = float(strNumber) 64 | except ValueError: 65 | number = float('nan') 66 | return number 67 | 68 | def getGPUs(): 69 | if platform.system() == "Windows": 70 | # If the platform is Windows and nvidia-smi 71 | # could not be found from the environment path, 72 | # try to find it from system drive with default installation path 73 | nvidia_smi = spawn.find_executable('nvidia-smi') 74 | if nvidia_smi is None: 75 | nvidia_smi = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ['systemdrive'] 76 | else: 77 | nvidia_smi = "nvidia-smi" 78 | 79 | # Get ID, processing and memory utilization for all GPUs 80 | try: 81 | p = Popen([nvidia_smi,"--query-gpu=index,uuid,utilization.gpu,memory.total,memory.used,memory.free,driver_version,name,gpu_serial,display_active,display_mode,temperature.gpu", "--format=csv,noheader,nounits"], stdout=PIPE) 82 | stdout, stderror = p.communicate() 83 | except: 84 | return [] 85 | output = stdout.decode('UTF-8') 86 | # output = output[2:-1] # Remove b' and ' from string added by python 87 | #print(output) 88 | ## Parse output 89 | # Split on line break 90 | lines = output.split(os.linesep) 91 | #print(lines) 92 | numDevices = len(lines)-1 93 | GPUs = [] 94 | for g in range(numDevices): 95 | line = lines[g] 96 | #print(line) 97 | vals = line.split(', ') 98 | #print(vals) 99 | for i in range(12): 100 | # print(vals[i]) 101 | if (i == 0): 102 | deviceIds = int(vals[i]) 103 | elif (i == 1): 104 | uuid = vals[i] 105 | elif (i == 2): 106 | gpuUtil = safeFloatCast(vals[i])/100 107 | elif (i == 3): 108 | memTotal = safeFloatCast(vals[i]) 109 | elif (i == 4): 110 | memUsed = safeFloatCast(vals[i]) 111 | elif (i == 5): 112 | memFree = safeFloatCast(vals[i]) 113 | elif (i == 6): 114 | driver = vals[i] 115 | elif (i == 7): 116 | gpu_name = vals[i] 117 | elif (i == 8): 118 | serial = vals[i] 119 | elif (i == 9): 120 | display_active = vals[i] 121 | elif (i == 10): 122 | display_mode = vals[i] 123 | elif (i == 11): 124 | temp_gpu = safeFloatCast(vals[i]); 125 | GPUs.append(GPU(deviceIds, uuid, gpuUtil, memTotal, memUsed, memFree, driver, gpu_name, serial, display_mode, display_active, temp_gpu)) 126 | return GPUs # (deviceIds, gpuUtil, memUtil) 127 | 128 | 129 | def getAvailable(order = 'first', limit=1, maxLoad=0.5, maxMemory=0.5, memoryFree=0, includeNan=False, excludeID=[], excludeUUID=[]): 130 | # order = first | last | random | load | memory 131 | # first --> select the GPU with the lowest ID (DEFAULT) 132 | # last --> select the GPU with the highest ID 133 | # random --> select a random available GPU 134 | # load --> select the GPU with the lowest load 135 | # memory --> select the GPU with the most memory 
available 136 | # limit = 1 (DEFAULT), 2, ..., Inf 137 | # Limit sets the upper limit for the number of GPUs to return. E.g. if limit = 2, but only one is available, only one is returned. 138 | 139 | # Get device IDs, load and memory usage 140 | GPUs = getGPUs() 141 | 142 | # Determine, which GPUs are available 143 | GPUavailability = getAvailability(GPUs, maxLoad=maxLoad, maxMemory=maxMemory, memoryFree=memoryFree, includeNan=includeNan, excludeID=excludeID, excludeUUID=excludeUUID) 144 | availAbleGPUindex = [idx for idx in range(0,len(GPUavailability)) if (GPUavailability[idx] == 1)] 145 | # Discard unavailable GPUs 146 | GPUs = [GPUs[g] for g in availAbleGPUindex] 147 | 148 | # Sort available GPUs according to the order argument 149 | if (order == 'first'): 150 | GPUs.sort(key=lambda x: float('inf') if math.isnan(x.id) else x.id, reverse=False) 151 | elif (order == 'last'): 152 | GPUs.sort(key=lambda x: float('-inf') if math.isnan(x.id) else x.id, reverse=True) 153 | elif (order == 'random'): 154 | GPUs = [GPUs[g] for g in random.sample(range(0,len(GPUs)),len(GPUs))] 155 | elif (order == 'load'): 156 | GPUs.sort(key=lambda x: float('inf') if math.isnan(x.load) else x.load, reverse=False) 157 | elif (order == 'memory'): 158 | GPUs.sort(key=lambda x: float('inf') if math.isnan(x.memoryUtil) else x.memoryUtil, reverse=False) 159 | 160 | # Extract the number of desired GPUs, but limited to the total number of available GPUs 161 | GPUs = GPUs[0:min(limit, len(GPUs))] 162 | 163 | # Extract the device IDs from the GPUs and return them 164 | deviceIds = [gpu.id for gpu in GPUs] 165 | 166 | return deviceIds 167 | 168 | #def getAvailability(GPUs, maxLoad = 0.5, maxMemory = 0.5, includeNan = False): 169 | # # Determine, which GPUs are available 170 | # GPUavailability = np.zeros(len(GPUs)) 171 | # for i in range(len(GPUs)): 172 | # if (GPUs[i].load < maxLoad or (includeNan and np.isnan(GPUs[i].load))) and (GPUs[i].memoryUtil < maxMemory or (includeNan and np.isnan(GPUs[i].memoryUtil))): 173 | # GPUavailability[i] = 1 174 | 175 | def getAvailability(GPUs, maxLoad=0.5, maxMemory=0.5, memoryFree=0, includeNan=False, excludeID=[], excludeUUID=[]): 176 | # Determine, which GPUs are available 177 | GPUavailability = [1 if (gpu.memoryFree>=memoryFree) and (gpu.load < maxLoad or (includeNan and math.isnan(gpu.load))) and (gpu.memoryUtil < maxMemory or (includeNan and math.isnan(gpu.memoryUtil))) and ((gpu.id not in excludeID) and (gpu.uuid not in excludeUUID)) else 0 for gpu in GPUs] 178 | return GPUavailability 179 | 180 | def getFirstAvailable(order = 'first', maxLoad=0.5, maxMemory=0.5, attempts=1, interval=900, verbose=False, includeNan=False, excludeID=[], excludeUUID=[]): 181 | #GPUs = getGPUs() 182 | #firstAvailableGPU = np.NaN 183 | #for i in range(len(GPUs)): 184 | # if (GPUs[i].load < maxLoad) & (GPUs[i].memory < maxMemory): 185 | # firstAvailableGPU = GPUs[i].id 186 | # break 187 | #return firstAvailableGPU 188 | for i in range(attempts): 189 | if (verbose): 190 | print('Attempting (' + str(i+1) + '/' + str(attempts) + ') to locate available GPU.') 191 | # Get first available GPU 192 | available = getAvailable(order=order, limit=1, maxLoad=maxLoad, maxMemory=maxMemory, includeNan=includeNan, excludeID=excludeID, excludeUUID=excludeUUID) 193 | # If an available GPU was found, break for loop. 
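# ('available' is the list of device ids returned by getAvailable(limit=1); an
# empty list is falsy, so on failure the loop sleeps 'interval' seconds and
# retries until a GPU frees up or the attempts run out.)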
194 | if (available): 195 | if (verbose): 196 | print('GPU ' + str(available) + ' located!') 197 | break 198 | # If this is not the last attempt, sleep for 'interval' seconds 199 | if (i != attempts-1): 200 | time.sleep(interval) 201 | # Check if an GPU was found, or if the attempts simply ran out. Throw error, if no GPU was found 202 | if (not(available)): 203 | raise RuntimeError('Could not find an available GPU after ' + str(attempts) + ' attempts with ' + str(interval) + ' seconds interval.') 204 | 205 | # Return found GPU 206 | return available 207 | 208 | 209 | def showUtilization(all=False, attrList=None, useOldCode=False): 210 | GPUs = getGPUs() 211 | if (all): 212 | if (useOldCode): 213 | print(' ID | Name | Serial | UUID || GPU util. | Memory util. || Memory total | Memory used | Memory free || Display mode | Display active |') 214 | print('------------------------------------------------------------------------------------------------------------------------------') 215 | for gpu in GPUs: 216 | print(' {0:2d} | {1:s} | {2:s} | {3:s} || {4:3.0f}% | {5:3.0f}% || {6:.0f}MB | {7:.0f}MB | {8:.0f}MB || {9:s} | {10:s}'.format(gpu.id,gpu.name,gpu.serial,gpu.uuid,gpu.load*100,gpu.memoryUtil*100,gpu.memoryTotal,gpu.memoryUsed,gpu.memoryFree,gpu.display_mode,gpu.display_active)) 217 | else: 218 | attrList = [[{'attr':'id','name':'ID'}, 219 | {'attr':'name','name':'Name'}, 220 | {'attr':'serial','name':'Serial'}, 221 | {'attr':'uuid','name':'UUID'}], 222 | [{'attr':'temperature','name':'GPU temp.','suffix':'C','transform': lambda x: x,'precision':0}, 223 | {'attr':'load','name':'GPU util.','suffix':'%','transform': lambda x: x*100,'precision':0}, 224 | {'attr':'memoryUtil','name':'Memory util.','suffix':'%','transform': lambda x: x*100,'precision':0}], 225 | [{'attr':'memoryTotal','name':'Memory total','suffix':'MB','precision':0}, 226 | {'attr':'memoryUsed','name':'Memory used','suffix':'MB','precision':0}, 227 | {'attr':'memoryFree','name':'Memory free','suffix':'MB','precision':0}], 228 | [{'attr':'display_mode','name':'Display mode'}, 229 | {'attr':'display_active','name':'Display active'}]] 230 | 231 | else: 232 | if (useOldCode): 233 | print(' ID GPU MEM') 234 | print('--------------') 235 | for gpu in GPUs: 236 | print(' {0:2d} {1:3.0f}% {2:3.0f}%'.format(gpu.id, gpu.load*100, gpu.memoryUtil*100)) 237 | else: 238 | attrList = [[{'attr':'id','name':'ID'}, 239 | {'attr':'load','name':'GPU','suffix':'%','transform': lambda x: x*100,'precision':0}, 240 | {'attr':'memoryUtil','name':'MEM','suffix':'%','transform': lambda x: x*100,'precision':0}], 241 | ] 242 | 243 | if (not useOldCode): 244 | if (attrList is not None): 245 | headerString = '' 246 | GPUstrings = ['']*len(GPUs) 247 | for attrGroup in attrList: 248 | #print(attrGroup) 249 | for attrDict in attrGroup: 250 | headerString = headerString + '| ' + attrDict['name'] + ' ' 251 | headerWidth = len(attrDict['name']) 252 | minWidth = len(attrDict['name']) 253 | 254 | attrPrecision = '.' 
+ str(attrDict['precision']) if ('precision' in attrDict.keys()) else '' 255 | attrSuffix = str(attrDict['suffix']) if ('suffix' in attrDict.keys()) else '' 256 | attrTransform = attrDict['transform'] if ('transform' in attrDict.keys()) else lambda x : x 257 | for gpu in GPUs: 258 | attr = getattr(gpu,attrDict['attr']) 259 | 260 | attr = attrTransform(attr) 261 | 262 | if (isinstance(attr,float)): 263 | attrStr = ('{0:' + attrPrecision + 'f}').format(attr) 264 | elif (isinstance(attr,int)): 265 | attrStr = ('{0:d}').format(attr) 266 | elif (isinstance(attr,str)): 267 | attrStr = attr; 268 | elif (sys.version_info[0] == 2): 269 | if (isinstance(attr,unicode)): 270 | attrStr = attr.encode('ascii','ignore') 271 | else: 272 | raise TypeError('Unhandled object type (' + str(type(attr)) + ') for attribute \'' + attrDict['name'] + '\'') 273 | 274 | attrStr += attrSuffix 275 | 276 | minWidth = max(minWidth,len(attrStr)) 277 | 278 | headerString += ' '*max(0,minWidth-headerWidth) 279 | 280 | minWidthStr = str(minWidth - len(attrSuffix)) 281 | 282 | for gpuIdx,gpu in enumerate(GPUs): 283 | attr = getattr(gpu,attrDict['attr']) 284 | 285 | attr = attrTransform(attr) 286 | 287 | if (isinstance(attr,float)): 288 | attrStr = ('{0:'+ minWidthStr + attrPrecision + 'f}').format(attr) 289 | elif (isinstance(attr,int)): 290 | attrStr = ('{0:' + minWidthStr + 'd}').format(attr) 291 | elif (isinstance(attr,str)): 292 | attrStr = ('{0:' + minWidthStr + 's}').format(attr); 293 | elif (sys.version_info[0] == 2): 294 | if (isinstance(attr,unicode)): 295 | attrStr = ('{0:' + minWidthStr + 's}').format(attr.encode('ascii','ignore')) 296 | else: 297 | raise TypeError('Unhandled object type (' + str(type(attr)) + ') for attribute \'' + attrDict['name'] + '\'') 298 | 299 | attrStr += attrSuffix 300 | 301 | GPUstrings[gpuIdx] += '| ' + attrStr + ' ' 302 | 303 | headerString = headerString + '|' 304 | for gpuIdx,gpu in enumerate(GPUs): 305 | GPUstrings[gpuIdx] += '|' 306 | 307 | headerSpacingString = '-' * len(headerString) 308 | print(headerString) 309 | print(headerSpacingString) 310 | for GPUstring in GPUstrings: 311 | print(GPUstring) 312 | -------------------------------------------------------------------------------- /cifar10/main_add.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import matplotlib 5 | matplotlib.use("pdf") 6 | import matplotlib.pyplot as plt 7 | import logging 8 | from datetime import datetime 9 | from copy import deepcopy 10 | import re 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | 21 | import os 22 | import argparse 23 | import numpy as np 24 | import models 25 | import utils 26 | import time 27 | 28 | # from models import * 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 31 | parser.add_argument('--optimizer', '--opt', default='sgd', type=str, help='sgd variants (sgd, adam, amsgrad, adagrad, adadelta, rmsprop)') 32 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 33 | parser.add_argument('--epochs', default=4000, type=int, help='the number of epochs') 34 | parser.add_argument('--grow-threshold', '--gt', default=0.05, type=float, help='the accuracy threshold to grow or stop') 35 | 
parser.add_argument('--ema-params', '--ep', action='store_true', help='validate accuracy with an exponential moving average of parameters') 36 | parser.add_argument('--growing-mode', default='group', type=str, help='how new structures are added (rate, all, group)') 37 | 38 | parser.add_argument('--rate', default=0.4, type=float, help='the rate to grow when --growing-mode=rate') 39 | parser.add_argument('--grow-interval', '--gi', default=100, type=int, help='an interval (in epochs) to grow new structures') 40 | parser.add_argument('--net', default='1-1-1', type=str, help='starting net') 41 | parser.add_argument('--max-net', default='200-200-200', type=str, help='the maximum net') 42 | parser.add_argument('--residual', default='CifarResNetBasic', type=str, help='the type of residual block (ResNetBasic or ResNetBottleneck or CifarResNetBasic)') 43 | parser.add_argument('--initializer', '--init', default='gaussian', type=str, help='initializers of new structures (zero, uniform, gaussian, adam)') 44 | 45 | parser.add_argument('--growing-metric', default='max', type=str, help='the metric for growing (max or avg)') 46 | parser.add_argument('--reset-states', '--rs', action='store_true', help='reset optimizer states (such as momentum) or not') 47 | parser.add_argument('--init-meta', default=1.0, type=float, help='a meta parameter for the initializer') 48 | parser.add_argument('--tail-epochs', '--te', default=100, type=int, help='the number of extra epochs after the growing phase (--epochs) for the sgd optimizer') 49 | parser.add_argument('--batch-size', '--bz', default=128, type=int, help='batch size') 50 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset (CIFAR10, CIFAR100, SVHN, FashionMNIST, MNIST)') 51 | parser.add_argument('--pad-net', default='', type=str, metavar='NET', 52 | help='the smallest net to pad saved models to (default: none)') 53 | parser.add_argument('--pad-epochs', default=4, type=int, help='the interval (in epochs) between padded model snapshots for visualization') 54 | 55 | args = parser.parse_args() 56 | 57 | def save_model_with_padding(epoch, train_accu, model, new_arch, path): 58 | new_model = get_module(args.residual, num_blocks=new_arch, 59 | num_classes=utils.datasets[args.dataset]['num_classes'], 60 | image_channels=utils.datasets[args.dataset]['image_channels']) 61 | orig_params_data = {} 62 | for n, p in model.named_parameters(): 63 | orig_params_data[n] = p.data 64 | for n, p in new_model.named_parameters(): 65 | if n not in orig_params_data: 66 | logger.info('%s is set to zeros' % n) 67 | p.data.zero_() 68 | else: 69 | logger.info('%s is kept' % n) 70 | p.data = orig_params_data[n] 71 | 72 | torch.save({ 73 | 'epoch': epoch, 74 | 'train_accu': train_accu, 75 | 'net': new_model.state_dict(), 76 | }, path) 77 | 78 | def list_to_str(l): 79 | # join the values with '-', e.g. [1, 2, 3] -> '1-2-3' 80 | s = '' 81 | for v in l: 82 | s += str(v) + '-' 83 | return s[:-1] 84 | 85 | def get_module(name, *_args, **keywords): # _args avoids shadowing the global argparse args 86 | net = getattr(models, name)(*_args, **keywords) 87 | net = net.to('cuda') 88 | net = torch.nn.DataParallel(net) 89 | cudnn.benchmark = True 90 | return net 91 | 92 | def get_optimizer(net): 93 | if 'sgd' == args.optimizer or 'sgdc' == args.optimizer: 94 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 95 | elif 'adam' == args.optimizer: 96 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4) 97 | elif 'amsgrad' == args.optimizer: 98 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4, amsgrad=True) 99 | elif 'adagrad' == args.optimizer: 100 |
optimizer = optim.Adagrad(net.parameters(), lr=args.lr, weight_decay=5e-4) 101 | elif 'adadelta' == args.optimizer: 102 | optimizer = optim.Adadelta(net.parameters(), weight_decay=5e-4) 103 | elif 'rmsprop' == args.optimizer: 104 | optimizer = optim.RMSprop(net.parameters(), lr=args.lr, alpha=0.99, weight_decay=5e-4) 105 | else: 106 | logger.fatal('Unknown --optimizer') 107 | raise ValueError('Unknown --optimizer') 108 | return optimizer 109 | 110 | def params_id_to_name(net): 111 | themap = {} 112 | for k, v in net.named_parameters(): 113 | themap[id(v)] = k 114 | return themap 115 | 116 | def params_name_to_id(net): 117 | themap = {} 118 | for k, v in net.named_parameters(): 119 | themap[k] = id(v) 120 | return themap 121 | 122 | def save_all(epoch, train_accu, model, optimizer, path): 123 | torch.save({ 124 | 'epoch': epoch, 125 | 'train_accu': train_accu, 126 | 'model_state_dict': model.state_dict(), 127 | 'optimizer_state_dict': optimizer.state_dict(), 128 | 'name_id_map': params_name_to_id(model), 129 | }, path) 130 | 131 | def load_all(model, optimizer, path): 132 | checkpoint = torch.load(path) 133 | old_name_id_map = checkpoint['name_id_map'] 134 | new_id_name_map = params_id_to_name(model) 135 | # load existing params, and initializing missing ones 136 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 137 | new_params = [] 138 | if args.residual == 'ResNetBasic' or args.residual == 'CifarResNetBasic': 139 | reinit_pattern = '.*layer.*bn2\.((weight)|(bias))$' 140 | elif args.residual == 'ResNetBottleneck': 141 | reinit_pattern = '.*layer.*bn3\.((weight)|(bias))$' 142 | else: 143 | logger.fatal('Unknown --residual') 144 | exit() 145 | for n, p in model.named_parameters(): 146 | if n not in old_name_id_map and re.match(reinit_pattern, n): 147 | logger.info('reinitializing param {} ...'.format(n)) 148 | new_params.append(p) 149 | if args.initializer == 'zero': 150 | logger.info('******> Initializing as zeros...') 151 | p.data.zero_() 152 | elif args.initializer == 'uniform': 153 | logger.info('******> Initializing by uniform noises...') 154 | p.data.uniform_(0.0, to=args.init_meta) 155 | elif args.initializer == 'gaussian': 156 | logger.info('******> Initializing by gaussian noises') 157 | p.data.normal_(0.0, std=args.init_meta) 158 | elif args.initializer == 'adam': 159 | logger.info('******> Initializing by adam optimizer') 160 | else: 161 | logger.fatal('Unknown --initializer.') 162 | exit() 163 | if len(new_params) and args.initializer == 'adam': 164 | logger.info('******> Using adam to find a good initialization') 165 | old_train_accu = checkpoint['train_accu'] 166 | local_optimizer = optim.Adam(new_params, lr=0.001, weight_decay=5e-4) 167 | max_epoch = 10 168 | founded = False 169 | for e in range(max_epoch): 170 | _, accu = train(e, model, local_optimizer) 171 | if accu > old_train_accu - 0.5: 172 | logger.info('******> Found a good initial position with training accuracy %.2f (vs. old %.2f) at epoch %d' % ( 173 | accu, old_train_accu, e)) 174 | founded = True 175 | break 176 | if not founded: 177 | logger.info('******> failed to find a good initial position in %d epochs. Continue...' 
% max_epoch) 178 | 179 | # load existing states, and insert missing states as empty dict 180 | if not args.reset_states: 181 | new_checkpoint = deepcopy(optimizer.state_dict()) 182 | old_checkpoint = checkpoint['optimizer_state_dict'] 183 | if len(old_checkpoint['param_groups']) != 1 or len(new_checkpoint['param_groups']) != 1: 184 | logger.fatal('The number of param_groups is not 1.') 185 | exit() 186 | for new_id in new_checkpoint['param_groups'][0]['params']: 187 | name = new_id_name_map[new_id] 188 | if name in old_name_id_map: 189 | logger.info('loading param {} state...'.format(name)) 190 | old_id = old_name_id_map[name] 191 | new_checkpoint['state'][new_id] = old_checkpoint['state'][old_id] 192 | else: 193 | if new_id not in new_checkpoint['state']: 194 | logger.info('initializing param {} state as an empty dict...'.format(name)) 195 | new_checkpoint['state'][new_id] = {} 196 | else: 197 | logger.info('skipping param {} state (initial state exists)...'.format(name)) 198 | optimizer.load_state_dict(new_checkpoint) 199 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 200 | epoch = checkpoint['epoch'] 201 | 202 | return epoch 203 | 204 | def set_learning_rate(optimizer, lr): 205 | """Sets the learning rate """ 206 | logger.info('\nSetting learning rate to %.6f' % lr) 207 | for param_group in optimizer.param_groups: 208 | param_group['lr'] = lr 209 | 210 | def decay_learning_rate(optimizer): 211 | """Sets the learning rate to the initial LR decayed by 10""" 212 | for param_group in optimizer.param_groups: 213 | param_group['lr'] = param_group['lr'] * 0.1 214 | 215 | device = 'cuda' # if torch.cuda.is_available() else 'cpu' 216 | best_acc = 0 # best test accuracy 217 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 218 | 219 | save_path = os.path.join('./results', datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 220 | if not os.path.exists(save_path): 221 | os.makedirs(save_path) 222 | else: 223 | raise OSError('Directory {%s} exists. Use a new one.' % save_path) 224 | logging.basicConfig(filename=os.path.join(save_path, 'log.txt'), level=logging.INFO) 225 | logger = logging.getLogger('main') 226 | logger.addHandler(logging.StreamHandler()) 227 | logger.info("Saving to %s", save_path) 228 | logger.info("Running arguments: %s", args) 229 | 230 | # Data 231 | logger.info('==> Preparing data %s..' 
% args.dataset) 232 | train_padding = 4 if utils.datasets[args.dataset]['size'] == 32 else (32 - utils.datasets[args.dataset]['size']) // 2 # integer division: RandomCrop needs an int padding 233 | test_padding = 0 if utils.datasets[args.dataset]['size'] == 32 else (32 - utils.datasets[args.dataset]['size']) // 2 234 | assert (test_padding * 2 + utils.datasets[args.dataset]['size']) == 32 235 | logger.info('train pad = %d, test pad = %d' % (train_padding, test_padding)) 236 | if 'CIFAR10' == args.dataset or 'CIFAR100' == args.dataset or 'FashionMNIST' == args.dataset: 237 | logger.info('==> RandomHorizontalFlip enabled..') 238 | transform_train = transforms.Compose([ 239 | transforms.RandomCrop(32, padding=train_padding), 240 | transforms.RandomHorizontalFlip(), 241 | transforms.ToTensor(), 242 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 243 | ]) 244 | else: 245 | logger.info('==> RandomHorizontalFlip disabled..') 246 | transform_train = transforms.Compose([ 247 | transforms.RandomCrop(32, padding=train_padding), 248 | transforms.ToTensor(), 249 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 250 | ]) 251 | 252 | transform_test = transforms.Compose([ 253 | transforms.RandomCrop(32, padding=test_padding), # deterministic: padding up to 32 leaves a single possible crop 254 | transforms.ToTensor(), 255 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 256 | ]) 257 | 258 | # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 259 | if 'SVHN' == args.dataset: 260 | trainset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, split='train', download=True, 261 | transform=transform_train) 262 | else: 263 | trainset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, train=True, download=True, 264 | transform=transform_train) 265 | logger.info('%d training samples.' % len(trainset)) 266 | 267 | logger.info('%d training samples are used for training' % len(trainset)) 268 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2) 269 | 270 | # testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 271 | if 'SVHN' == args.dataset: 272 | testset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, split='test', download=True, transform=transform_test) 273 | else: 274 | testset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, train=False, download=True, transform=transform_test) 275 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 276 | logger.info('%d test samples.'
% len(testset)) 277 | 278 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 279 | 280 | # Model 281 | logger.info('==> Building model..') 282 | current_arch = list(map(int, args.net.split('-'))) 283 | max_arch = list(map(int, args.max_net.split('-'))) 284 | pad_arch = list(map(int, args.pad_net.split('-'))) if args.pad_net else None 285 | if len(current_arch) != len(max_arch): 286 | logger.fatal('max_arch has different size.') 287 | exit() 288 | growing_group = -1 289 | for cnt, v in enumerate(current_arch): 290 | if v < max_arch[cnt]: 291 | growing_group = cnt 292 | break 293 | 294 | net = get_module(args.residual, num_blocks=current_arch, 295 | num_classes=utils.datasets[args.dataset]['num_classes'], 296 | image_channels=utils.datasets[args.dataset]['image_channels']) 297 | # net = VGG('VGG19') 298 | # net = ResNet18() 299 | # net = PreActResNet18() 300 | # net = GoogLeNet() 301 | # net = DenseNet121() 302 | # net = ResNeXt29_2x64d() 303 | # net = MobileNet() 304 | # net = MobileNetV2() 305 | # net = DPN92() 306 | # net = ShuffleNetG2() 307 | # net = SENet18() 308 | 309 | criterion = nn.CrossEntropyLoss() 310 | optimizer = get_optimizer(net) 311 | param_ema = utils.TorchExponentialMovingAverage() 312 | # Training 313 | def train(epoch, net, own_optimizer=None): 314 | logger.info('\nTraining epoch %d @ %.1f sec' % (epoch, time.time())) 315 | net.train() 316 | train_loss = 0 317 | correct = 0 318 | total = 0 319 | for batch_idx, (inputs, targets) in enumerate(trainloader): 320 | inputs, targets = inputs.to(device), targets.to(device) 321 | if own_optimizer is not None: 322 | own_optimizer.zero_grad() 323 | else: 324 | optimizer.zero_grad() 325 | outputs = net(inputs) 326 | loss = criterion(outputs, targets) 327 | loss.backward() 328 | if own_optimizer is not None: 329 | own_optimizer.step() 330 | else: 331 | optimizer.step() 332 | # maintain a moving average 333 | params_data_dict = {} 334 | for n, p in net.named_parameters(): 335 | params_data_dict[n] = p.data 336 | param_ema.push(params_data_dict) 337 | 338 | train_loss += loss.item() 339 | _, predicted = outputs.max(1) 340 | total += targets.size(0) 341 | correct += predicted.eq(targets).sum().item() 342 | if 0 == batch_idx % 100 or batch_idx == len(trainloader) - 1: 343 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 344 | % (batch_idx+1, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 345 | return train_loss / len(trainloader), 100.*correct/total 346 | 347 | def test(epoch, net, save=False): 348 | logger.info('Testing epoch %d @ %.1f sec' % (epoch, time.time())) 349 | global best_acc 350 | net.eval() 351 | test_loss = 0 352 | correct = 0 353 | total = 0 354 | with torch.no_grad(): 355 | if args.ema_params: 356 | logger.info('Using average params for test') 357 | orig_params = utils.set_named_parameters(net, param_ema.average(), strict=False) 358 | for batch_idx, (inputs, targets) in enumerate(testloader): 359 | inputs, targets = inputs.to(device), targets.to(device) 360 | outputs = net(inputs) 361 | loss = criterion(outputs, targets) 362 | 363 | test_loss += loss.item() 364 | _, predicted = outputs.max(1) 365 | total += targets.size(0) 366 | correct += predicted.eq(targets).sum().item() 367 | if 0 == batch_idx % 100 or batch_idx == len(testloader) - 1: 368 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 369 | % (batch_idx+1, len(testloader), test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 370 | 371 | # Save checkpoint. 
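# best_acc tracks the best validation accuracy seen so far across all grown architectures; the model is snapshotted to best_ckpt.t7 below only when that maximum improves.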
372 | acc = 100.*correct/total 373 | global current_arch 374 | if acc > best_acc and save: 375 | logger.info('Saving best %.3f @ %d ( resnet-%s )...' %(acc, epoch, list_to_str(current_arch))) 376 | state = { 377 | 'net': net.state_dict(), 378 | 'acc': acc, 379 | 'epoch': epoch, 380 | } 381 | torch.save(state, os.path.join(save_path, 'best_ckpt.t7')) 382 | best_acc = acc if acc > best_acc else best_acc 383 | 384 | with torch.no_grad(): 385 | if args.ema_params: 386 | utils.set_named_parameters(net, orig_params, strict=True) 387 | 388 | return test_loss / len(testloader), acc 389 | 390 | # main func 391 | if args.growing_metric == 'max': 392 | ema = utils.MovingMaximum() 393 | elif args.growing_metric == 'avg': 394 | ema = utils.ExponentialMovingAverage(decay=0.95) 395 | else: 396 | logger.fatal('Unknown --growing-metric') 397 | exit() 398 | 399 | growed = False # if growed in the recent interval 400 | 401 | 402 | def can_grow(maxlim, arch): 403 | for maxv, a in zip(maxlim, arch): 404 | if maxv > a: 405 | return True 406 | return False 407 | 408 | num_tail_epochs = args.tail_epochs if ('sgd' == args.optimizer or 'sgdc' == args.optimizer) else 0 409 | last_epoch = -1 410 | growing_epochs = [] 411 | intervals = (args.epochs - 1) // args.grow_interval + 1 412 | # epoch, train loss, train accu, test loss, test accu, timestamps 413 | curves = np.zeros((intervals*args.grow_interval + num_tail_epochs, 6)) 414 | for interval in range(0, intervals): 415 | # grow or stop 416 | grow_check = interval > 0 417 | if grow_check: # check after every interval 418 | delta_accu = ema.delta(-1 - args.grow_interval, -1) 419 | logger.info( 420 | '******> improved %.3f (ExponentialMovingAverage) in the last %d epochs' % (delta_accu, args.grow_interval)) 421 | if can_grow(max_arch, current_arch) and delta_accu < args.grow_threshold: 422 | # save current model 423 | save_ckpt = os.path.join(save_path, 'resnet-growing_ckpt.t7') 424 | save_all(interval*args.grow_interval - 1, curves[interval*args.grow_interval - 1, 2], net, optimizer, save_ckpt) 425 | # create a new net and optimizer 426 | # current_arch[growing_group] += 1 427 | current_arch = utils.next_arch(args.growing_mode, max_arch, current_arch, logger, rate=args.rate, group=growing_group) 428 | logger.info('******> growing to resnet-%s before epoch %d' % (list_to_str(current_arch), interval*args.grow_interval)) 429 | net = get_module(args.residual, num_blocks=current_arch, 430 | num_classes=utils.datasets[args.dataset]['num_classes'], 431 | image_channels=utils.datasets[args.dataset]['image_channels']) 432 | optimizer = get_optimizer(net) 433 | loaded_epoch = load_all(net, optimizer, save_ckpt) 434 | logger.info('testing new model ...') 435 | test(loaded_epoch, net) 436 | growed = True 437 | growing_epochs.append(interval*args.grow_interval) 438 | else: 439 | growed = False 440 | 441 | # training and testing 442 | for epoch in range(interval*args.grow_interval, (interval+1)*args.grow_interval): 443 | if 'sgdc' == args.optimizer: 444 | e = epoch % args.grow_interval 445 | if e < args.grow_interval // 2: 446 | set_learning_rate(optimizer, args.lr) 447 | elif e < args.grow_interval * 3 // 4: 448 | set_learning_rate(optimizer, args.lr * 0.1) 449 | else: 450 | set_learning_rate(optimizer, args.lr * 0.01) 451 | curves[epoch, 0] = epoch 452 | curves[epoch, 1], curves[epoch, 2] = train(epoch, net) 453 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 454 | curves[epoch, 5] = time.time() / 60.0 455 | ema.push(curves[epoch, 4]) 456 | if args.pad_net and 
(epoch % args.pad_epochs == 0): 457 | save_model_with_padding(epoch, curves[epoch, 2], net, pad_arch, 458 | os.path.join(save_path, 'model_pad_%d.t7' % epoch)) 459 | 460 | if grow_check: # check after every interval 461 | delta_accu = ema.delta(-1 - args.grow_interval, -1) 462 | if growed and delta_accu < args.grow_threshold: # just growed but no improvement 463 | if args.growing_mode == 'group': 464 | max_arch[growing_group] = current_arch[growing_group] 465 | logger.info('******> stop growing group %d permanently. Limited as %s .' % (growing_group, list_to_str(max_arch))) 466 | else: 467 | max_arch[:] = current_arch[:] 468 | logger.info('******> stop growing all permanently. Limited as %s .' % (list_to_str(max_arch))) 469 | if growed: 470 | if can_grow(max_arch, current_arch): 471 | if args.growing_mode == 'group': 472 | growing_group = utils.next_group(growing_group, max_arch, current_arch, logger) 473 | else: 474 | logger.info('******> stop growing all groups') 475 | last_epoch = (interval + 1) * args.grow_interval - 1 476 | logger.info('******> reach limitation. Finished in advance @ epoch %d' % last_epoch) 477 | curves = curves[:last_epoch+1+num_tail_epochs, :] 478 | break 479 | last_epoch = (interval + 1) * args.grow_interval - 1 480 | 481 | set_learning_rate(optimizer, args.lr) 482 | for epoch in range(last_epoch + 1, last_epoch + 1 + num_tail_epochs): 483 | if (epoch == last_epoch + 1) or (epoch == last_epoch + 1 + num_tail_epochs // 2): 484 | logger.info('======> decaying learning rate') 485 | decay_learning_rate(optimizer) 486 | curves[epoch, 0] = epoch 487 | curves[epoch, 1], curves[epoch, 2] = train(epoch, net) 488 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 489 | curves[epoch, 5] = time.time() / 60.0 490 | ema.push(curves[epoch, 4]) 491 | if args.pad_net and (epoch % args.pad_epochs == 0): 492 | save_model_with_padding(epoch, curves[epoch, 2], net, pad_arch, 493 | os.path.join(save_path, 'model_pad_%d.t7' % epoch)) 494 | 495 | # align time 496 | for e in range(curves.shape[0]): 497 | curves[curves.shape[0]-1-e, 5] -= curves[0, 5] 498 | 499 | # plotting 500 | plot_segs = [0] + growing_epochs 501 | if len(growing_epochs) == 0 or growing_epochs[-1] != curves.shape[0]-1: 502 | plot_segs = plot_segs + [curves.shape[0]-1] 503 | logger.info('growing epochs {}'.format(list_to_str(growing_epochs))) 504 | logger.info('curves: \n {}'.format(np.array_str(curves))) 505 | np.savetxt(os.path.join(save_path, 'curves.dat'), curves) 506 | clr1 = (0.5, 0., 0.) 507 | clr2 = (0.0, 0.5, 0.) 508 | fig, ax1 = plt.subplots() 509 | fig2, ax3 = plt.subplots() 510 | ax2 = ax1.twinx() 511 | ax4 = ax3.twinx() 512 | ax1.set_xlabel('epoch') 513 | ax1.set_ylabel('Loss', color=clr1) 514 | ax1.tick_params(axis='y', colors=clr1) 515 | ax2.set_ylabel('Accuracy (%)', color=clr2) 516 | ax2.tick_params(axis='y', colors=clr2) 517 | 518 | ax3.set_xlabel('time (mins)') 519 | ax3.set_ylabel('Loss', color=clr1) 520 | ax3.tick_params(axis='y', colors=clr1) 521 | ax4.set_ylabel('Accuracy (%)', color=clr2) 522 | ax4.tick_params(axis='y', colors=clr2) 523 | 524 | # ax2.set_ylim(80, 100) # no plot if enabled 525 | for idx in range(len(plot_segs)-1): 526 | start = plot_segs[idx] 527 | end = plot_segs[idx+1] + 1 if (plot_segs[idx+1] == curves.shape[0] - 1) else plot_segs[idx+1] 528 | markersize = 12 529 | coef = 2. if idx % 2 else 1. 
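# Each segment between two growing epochs is drawn separately; coef alternates the line brightness of adjacent segments, and only the last segment (the branch below) carries legend labels so every curve appears once in the legend.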
530 | if idx == len(plot_segs)-2: 531 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize) 532 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize) 533 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize) 534 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize) 535 | 536 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize) 537 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize) 538 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize) 539 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize) 540 | else: 541 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 542 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 543 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 544 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 545 | 546 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 547 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 548 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 549 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 550 | 551 | ax2.plot(curves[:, 0], ema.get(), '-', color=[1.0, 0, 1.0]) 552 | logger.info('Val accuracy moving average: \n {}'.format(np.array_str(np.array(ema.get())))) 553 | np.savetxt(os.path.join(save_path, 'ema.dat'), np.array(ema.get())) 554 | ax2.set_ylim(bottom=40, top=100) 555 | ax1.legend(('Train loss', 'Val loss'), loc='lower right') 556 | ax2.legend(('Train accuracy', 'Val accuracy', 'Val moving avg'), loc='lower left') 557 | fig.savefig(os.path.join(save_path, 'curves-vs-epochs.pdf')) 558 | 559 | ax4.plot(curves[:, 5], ema.get(), '-', color=[1.0, 0, 1.0]) 560 | ax4.set_ylim(bottom=40, top=100) 561 | ax3.legend(('Train loss', 'Val loss'), loc='lower right') 562 | ax4.legend(('Train accuracy', 'Val accuracy', 'Val moving avg'), loc='lower left') 563 | fig2.savefig(os.path.join(save_path, 'curves-vs-time.pdf')) 564 | 565 | 566 | logger.info('Done!') -------------------------------------------------------------------------------- /imagenet/main_gradual.py: -------------------------------------------------------------------------------- 1 | '''Train ImageNet with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import matplotlib 5 | matplotlib.use("pdf") 6 | import matplotlib.pyplot as plt 7 | import logging 8 | from datetime import datetime 9 | from copy import deepcopy 10 | import re 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as
optim 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | import models.switch as ms 22 | import models.pswitch as ps 23 | 24 | import os 25 | import argparse 26 | import numpy as np 27 | import models 28 | import utils 29 | import time 30 | import GPUtil 31 | import gc 32 | import pickle 33 | 34 | # from models import * 35 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 36 | parser.add_argument('--optimizer', '--opt', default='sgd', type=str, help='sgd variants (sgdc, sgd, adam, amsgrad, adagrad, adadelta, rmsprop)') 37 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 38 | parser.add_argument('--switch-reg', '--sr', default=None, type=float, help='lambda of the L1 regularization on pswitches') 39 | parser.add_argument('--epochs', default=1000, type=int, help='the number of epochs') 40 | parser.add_argument('--grow-threshold', '--gt', default=0.05, type=float, help='the accuracy threshold to grow or stop') 41 | parser.add_argument('--ema-params', '--ep', action='store_true', help='validate accuracy with an exponential moving average of parameters') 42 | parser.add_argument('--growing-mode', default='group', type=str, help='how new structures are added (rate, all, sub, group)') 43 | parser.add_argument('--tail-epochs', '--te', default=90, type=int, help='the number of extra epochs after the growing phase (--epochs)') 44 | parser.add_argument('--pswitch-thre', '--pt', default=0.005, type=float, help='threshold below which pswitches are zeroed out') 45 | 46 | parser.add_argument('--batch-size', '--bz', default=256, type=int, help='batch size') 47 | parser.add_argument('--switch-off', '--so', action='store_true', help='switch off at initialization') 48 | parser.add_argument('--grow-interval', '--gi', default=1, type=int, help='an interval (in epochs) to grow new structures') 49 | parser.add_argument('--stop-interval', '--si', default=30, type=int, help='the window (in epochs) over which improvement is measured to decide when growing stops') 50 | parser.add_argument('--net', default='1-1-1-1', type=str, help='starting net') 51 | parser.add_argument('--sub-net', default='1-1-1-1', type=str, help='a sub net to grow') 52 | parser.add_argument('--max-net', default='6-16-72-32', type=str, help='the maximum net') 53 | parser.add_argument('--residual', default='ResNetBasic', type=str, 54 | help='the type of residual block (ResNetBasic or ResNetBottleneck)') 55 | parser.add_argument('--initializer', '--init', default='gaussian', type=str, help='initializers of new structures (zero, uniform, gaussian, adam)') 56 | 57 | parser.add_argument('--rate', default=0.4, type=float, help='the rate to grow when --growing-mode=rate') 58 | parser.add_argument('--growing-metric', default='max', type=str, help='the metric for growing (max or avg)') 59 | parser.add_argument('--reset-states', '--rs', action='store_true', help='reset optimizer states (such as momentum) or not') 60 | parser.add_argument('--init-meta', default=1.0, type=float, help='a meta parameter for the initializer') 61 | parser.add_argument('--evaluate', default='', type=str, metavar='PATH', 62 | help='path to checkpoint (default: none)') 63 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 64 | help='path to checkpoint to resume (default: none)') 65 | parser.add_argument('--start-epoch', default=0, type=int, help='start epoch') 66 | parser.add_argument('--start-group', default=0, type=int,
help='start group to grow') 67 | parser.add_argument('--grown-group', default=None, type=int, help='grown group') 68 | parser.add_argument('--start-chunk', default=1, type=int, help='start chunk number to avoid OOM') 69 | parser.add_argument('--data', default='./imagenet', type=str, metavar='PATH', 70 | help='path to imagenet dataset (default: ./imagenet)') 71 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 72 | help='number of data loading workers (default: 16)') 73 | 74 | args = parser.parse_args() 75 | cur_chunk_num = args.start_chunk 76 | 77 | def list_to_str(l): 78 | list(map(str, l)) 79 | s = '' 80 | for v in l: 81 | s += str(v) + '-' 82 | return s[:-1] 83 | 84 | def get_module(name, switch_steps, *_args, **keywords): 85 | net = getattr(models, name)(*_args, **keywords) 86 | net = net.to('cuda') 87 | net = torch.nn.DataParallel(net) 88 | cudnn.benchmark = True 89 | configure_switch_policy(net, switch_steps) 90 | return net 91 | 92 | def get_optimizer(net): 93 | if 'sgd' == args.optimizer or 'sgdc' == args.optimizer: 94 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4) 95 | elif 'adam' == args.optimizer: 96 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4) 97 | elif 'amsgrad' == args.optimizer: 98 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4, amsgrad=True) 99 | elif 'adagrad' == args.optimizer: 100 | optimizer = optim.Adagrad(net.parameters(), lr=args.lr, weight_decay=1e-4) 101 | elif 'adadelta' == args.optimizer: 102 | optimizer = optim.Adadelta(net.parameters(), weight_decay=1e-4) 103 | elif 'rmsprop' == args.optimizer: 104 | optimizer = optim.RMSprop(net.parameters(), lr=args.lr, alpha=0.99, weight_decay=1e-4) 105 | else: 106 | logger.fatal('Unknown --optimizer') 107 | raise ValueError('Unknown --optimizer') 108 | return optimizer 109 | 110 | def params_id_to_name(net): 111 | themap = {} 112 | for k, v in net.named_parameters(): 113 | themap[id(v)] = k 114 | return themap 115 | 116 | def params_name_to_id(net): 117 | themap = {} 118 | for k, v in net.named_parameters(): 119 | themap[k] = id(v) 120 | return themap 121 | 122 | 123 | def save_all(epoch, train_accu, model, optimizer, path): 124 | torch.save({ 125 | 'epoch': epoch, 126 | 'train_accu': train_accu, 127 | 'model_state_dict': model.state_dict(), 128 | 'optimizer_state_dict': optimizer.state_dict(), 129 | 'name_id_map': params_name_to_id(model), 130 | }, path) 131 | 132 | def zerout_pswitchs(model, threshold=0.0001, log=False): 133 | total = 0 134 | zeroed = 0 135 | for idx, m in enumerate(model.named_modules()): 136 | if isinstance(m[1], ps.PSwitch): 137 | total += 1 138 | # logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 139 | if m[1].switch.data.abs() < threshold: 140 | if log: 141 | logger.info('******> switch {} is zeroed out...'.format(m[0])) 142 | m[1].switch.data.fill_(0.0) 143 | zeroed += 1 144 | if log: 145 | logger.info('%d/%d zeroed out' % (zeroed, total)) 146 | 147 | def reg_pswitchs(model): 148 | reg = 0.0 149 | for idx, m in enumerate(model.named_modules()): 150 | if isinstance(m[1], ps.PSwitch): 151 | reg += m[1].switch.norm(p=1) 152 | return reg * args.switch_reg 153 | 154 | def print_switchs(model): 155 | for idx, m in enumerate(model.named_modules()): 156 | if isinstance(m[1], ms.Switch): 157 | logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 158 | 159 | def print_pswitchs(model): 160 | for idx, m in enumerate(model.named_modules()): 161 | 
if isinstance(m[1], ps.PSwitch): 162 | logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 163 | 164 | def increase_switchs(model): 165 | for idx, m in enumerate(model.named_modules()): 166 | if isinstance(m[1], ms.Switch): 167 | m[1].increase() 168 | 169 | def configure_switch_policy(model, steps, start=0.0, stop=1.0, mode='linear'): 170 | for idx, m in enumerate(model.named_modules()): 171 | if isinstance(m[1], ms.Switch): 172 | logger.info('******> configuring switch {}...'.format(m[0])) 173 | m[1].set_params(steps, start, stop, mode) 174 | 175 | 176 | def load_all(model, optimizer, path): 177 | checkpoint = torch.load(path) 178 | old_name_id_map = checkpoint['name_id_map'] 179 | new_id_name_map = params_id_to_name(model) 180 | # load existing params, and initializing missing ones 181 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 182 | new_params = [] 183 | if args.residual == 'ResNetBasic': 184 | reinit_pattern = '.*layer.*bn2\.((weight)|(bias))$' 185 | elif args.residual == 'ResNetBottleneck': 186 | reinit_pattern = '.*layer.*bn3\.((weight)|(bias))$' 187 | elif args.residual == 'PlainNet': 188 | reinit_pattern = 'UseDefaultInitialization' 189 | logger.info('No reinitialization for PlainNet') 190 | else: 191 | logger.fatal('Unknown --residual') 192 | exit() 193 | for n, p in model.named_parameters(): 194 | if n not in old_name_id_map and re.match(reinit_pattern, n): 195 | logger.info('******> reinitializing param {} ...'.format(n)) 196 | new_params.append(p) 197 | if args.initializer == 'zero': 198 | logger.info('******> Initializing as zeros...') 199 | p.data.zero_() 200 | elif args.initializer == 'uniform': 201 | logger.info('******> Initializing by uniform noises...') 202 | p.data.uniform_(0.0, to=args.init_meta) 203 | elif args.initializer == 'gaussian': 204 | logger.info('******> Initializing by gaussian noises') 205 | p.data.normal_(0.0, std=args.init_meta) 206 | elif args.initializer == 'adam': 207 | logger.info('******> Initializing by adam optimizer') 208 | else: 209 | logger.fatal('Unknown --initializer.') 210 | exit() 211 | if args.switch_off: 212 | switch_name = '.'.join(n.split('.')[:-2]+['switch.switch']) 213 | if switch_name in model.state_dict(): 214 | logger.info('******> resetting %s to 0.0 from %.3f' % (switch_name, model.state_dict()[switch_name])) 215 | model.state_dict()[switch_name].zero_() 216 | 217 | if len(new_params) and args.initializer == 'adam': 218 | logger.info('******> Using adam to find a good initialization') 219 | old_train_accu = checkpoint['train_accu'] 220 | local_optimizer = optim.Adam(new_params, lr=0.001, weight_decay=1e-4) 221 | max_epoch = 10 222 | founded = False 223 | for e in range(max_epoch): 224 | _, accu, __ = train(e, model, local_optimizer, chunk_num=cur_chunk_num) 225 | if accu > old_train_accu - 0.5: 226 | logger.info('******> Found a good initial position with training accuracy %.2f (vs. old %.2f) at epoch %d' % ( 227 | accu, old_train_accu, e)) 228 | founded = True 229 | break 230 | if not founded: 231 | logger.info('******> failed to find a good initial position in %d epochs. Continue...' 
% max_epoch) 232 | 233 | 234 | # load existing states, and insert missing states as empty dict 235 | if not args.reset_states: 236 | new_checkpoint = deepcopy(optimizer.state_dict()) 237 | old_checkpoint = checkpoint['optimizer_state_dict'] 238 | if len(old_checkpoint['param_groups']) != 1 or len(new_checkpoint['param_groups']) != 1: 239 | logger.fatal('The number of param_groups is not 1.') 240 | exit() 241 | for new_id in new_checkpoint['param_groups'][0]['params']: 242 | name = new_id_name_map[new_id] 243 | if name in old_name_id_map: 244 | old_id = old_name_id_map[name] 245 | if old_id in old_checkpoint['state']: 246 | logger.info('loading param {} state...'.format(name)) 247 | new_checkpoint['state'][new_id] = old_checkpoint['state'][old_id] 248 | else: 249 | logger.info('initializing param {} state as an empty dict...'.format(name)) 250 | new_checkpoint['state'][new_id] = {} 251 | else: 252 | if new_id not in new_checkpoint['state']: 253 | logger.info('initializing param {} state as an empty dict...'.format(name)) 254 | new_checkpoint['state'][new_id] = {} 255 | else: 256 | logger.info('skipping param {} state (initial state exists)...'.format(name)) 257 | optimizer.load_state_dict(new_checkpoint) 258 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 259 | epoch = checkpoint['epoch'] 260 | 261 | return epoch 262 | 263 | def set_learning_rate(optimizer, lr): 264 | """Sets the learning rate """ 265 | logger.info('\nSetting learning rate to %.6f' % lr) 266 | for param_group in optimizer.param_groups: 267 | param_group['lr'] = lr 268 | 269 | def decay_learning_rate(optimizer): 270 | """Sets the learning rate to the initial LR decayed by 10""" 271 | for param_group in optimizer.param_groups: 272 | param_group['lr'] = param_group['lr'] * 0.1 273 | 274 | device = 'cuda' # if torch.cuda.is_available() else 'cpu' 275 | best_acc = 0 # best test accuracy 276 | 277 | save_path = os.path.join('./results', datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 278 | if not os.path.exists(save_path): 279 | os.makedirs(save_path) 280 | else: 281 | raise OSError('Directory {%s} exists. Use a new one.' 
% save_path) 282 | logging.basicConfig(filename=os.path.join(save_path, 'log.txt'), level=logging.INFO) 283 | logger = logging.getLogger('main') 284 | logger.addHandler(logging.StreamHandler()) 285 | logger.info("Saving to %s", save_path) 286 | logger.info("Running arguments: %s", args) 287 | 288 | # Data 289 | logger.info('==> Preparing data..') 290 | train_sampler = None 291 | traindir = os.path.join(args.data, 'train') 292 | valdir = os.path.join(args.data, 'val') 293 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 294 | std=[0.229, 0.224, 0.225]) 295 | train_dataset = datasets.ImageFolder( 296 | traindir, 297 | transforms.Compose([ 298 | transforms.RandomResizedCrop(224), 299 | transforms.RandomHorizontalFlip(), 300 | transforms.ToTensor(), 301 | normalize, 302 | ])) 303 | trainloader = torch.utils.data.DataLoader( 304 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 305 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 306 | 307 | testloader = torch.utils.data.DataLoader( 308 | datasets.ImageFolder(valdir, transforms.Compose([ 309 | transforms.Resize(256), 310 | transforms.CenterCrop(224), 311 | transforms.ToTensor(), 312 | normalize, 313 | ])), 314 | batch_size=args.batch_size, shuffle=False, 315 | num_workers=args.workers, pin_memory=True) 316 | # Model 317 | logger.info('==> Building model..') 318 | current_arch = list(map(int, args.net.split('-'))) 319 | subnet_arch = list(map(int, args.sub_net.split('-'))) 320 | max_arch = list(map(int, args.max_net.split('-'))) 321 | if len(current_arch) != len(max_arch): 322 | logger.fatal('max_arch has different size.') 323 | exit() 324 | for v1, v2 in zip(current_arch, max_arch): 325 | if v1 > v2: 326 | logger.error('current arch is larger than max arch! Exit!') 327 | exit() 328 | growing_group = -1 329 | grown_group = args.grown_group 330 | 331 | for idx in range(len(current_arch)): 332 | cnt = (idx + args.start_group) % len(current_arch) 333 | # for cnt, v in enumerate(current_arch): 334 | if current_arch[cnt] < max_arch[cnt]: 335 | growing_group = cnt 336 | break 337 | net = get_module(args.residual, args.grow_interval, current_arch) 338 | # net = VGG('VGG19') 339 | # net = ResNet18() 340 | # net = PreActResNet18() 341 | # net = GoogLeNet() 342 | # net = DenseNet121() 343 | # net = ResNeXt29_2x64d() 344 | # net = MobileNet() 345 | # net = MobileNetV2() 346 | # net = DPN92() 347 | # net = ShuffleNetG2() 348 | # net = SENet18() 349 | 350 | criterion = nn.CrossEntropyLoss(reduction='sum') 351 | optimizer = get_optimizer(net) 352 | param_ema = utils.TorchExponentialMovingAverage() 353 | # Training 354 | def train(epoch, net, own_optimizer=None, increase_switch=False, chunk_num=1): 355 | logger.info('\nTraining epoch %d @ %.1f sec' % (epoch, time.time())) 356 | net.train() 357 | new_chunk_num = chunk_num 358 | train_loss = 0 359 | correct = 0 360 | total = 0 361 | print_pswitchs(net) 362 | if increase_switch: 363 | increase_switchs(net) 364 | for batch_idx, (inputs, targets) in enumerate(trainloader): 365 | inputs, targets = inputs.to(device), targets.to(device) 366 | if own_optimizer is not None: 367 | own_optimizer.zero_grad() 368 | else: 369 | optimizer.zero_grad() 370 | 371 | # avoid out of memory by split a batch to chunks 372 | if targets.size(0) < chunk_num: 373 | logger.warning('%d samples cannot be chunked to %d. 
Set the chunk number to %d' % (targets.size(0), chunk_num, targets.size(0))) 374 | chunk_num = targets.size(0) 375 | sub_inputs = inputs.chunk(chunk_num) 376 | sub_targets = targets.chunk(chunk_num) 377 | for chunk_ins, chunk_tgts in zip(sub_inputs, sub_targets): 378 | outputs = net(chunk_ins) 379 | loss = criterion(outputs, chunk_tgts) 380 | loss.backward() 381 | train_loss += loss.item() 382 | _, predicted = outputs.max(1) 383 | total += chunk_tgts.size(0) 384 | correct += predicted.eq(chunk_tgts).sum().item() 385 | for p in net.parameters(): 386 | p.grad.data.div_(len(inputs)) 387 | sreg = 0.0 388 | if args.switch_reg is not None: 389 | sreg = reg_pswitchs(net) 390 | sreg.backward() 391 | 392 | if own_optimizer is not None: 393 | own_optimizer.step() 394 | else: 395 | optimizer.step() 396 | 397 | if args.switch_reg is not None: 398 | zerout_pswitchs(net, args.pswitch_thre) 399 | 400 | # maintain a moving average 401 | if args.ema_params: 402 | params_data_dict = {} 403 | for n, p in net.named_parameters(): 404 | params_data_dict[n] = p.data 405 | param_ema.push(params_data_dict) 406 | 407 | if 0 == batch_idx % 100 or batch_idx == len(trainloader) - 1: 408 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 409 | % (batch_idx+1, len(trainloader), train_loss/total, 100.*correct/total, correct, total)) 410 | if args.switch_reg is not None: 411 | logger.info(' ==> PSwitch L1 Reg.: %.6f ' % sreg) 412 | 413 | # a probe of potential OOM 414 | if batch_idx < 10 and new_chunk_num == chunk_num: 415 | for gpu_stat in GPUtil.getGPUs(): 416 | if gpu_stat.memoryFree < 1000: 417 | logger.info('******> hitting gpu memory limit. Only %d MB / %d MB is free in GPU %d.' % ( 418 | gpu_stat.memoryFree, gpu_stat.memoryTotal, gpu_stat.id)) 419 | new_chunk_num += 1 420 | logger.info('******> may increase the chunk number to %d' % new_chunk_num) 421 | gc.collect() 422 | break 423 | 424 | return train_loss / total, 100.*correct/total, new_chunk_num 425 | 426 | def test(epoch, net, save=False): 427 | logger.info('Testing epoch %d @ %.1f sec' % (epoch, time.time())) 428 | global best_acc 429 | net.eval() 430 | test_loss = 0 431 | correct = 0 432 | total = 0 433 | with torch.no_grad(): 434 | if args.ema_params: 435 | logger.info('Using average params for test') 436 | orig_params = utils.set_named_parameters(net, param_ema.average(), strict=False) 437 | for batch_idx, (inputs, targets) in enumerate(testloader): 438 | inputs, targets = inputs.to(device), targets.to(device) 439 | outputs = net(inputs) 440 | loss = criterion(outputs, targets) 441 | 442 | test_loss += loss.item() 443 | _, predicted = outputs.max(1) 444 | total += targets.size(0) 445 | correct += predicted.eq(targets).sum().item() 446 | if 0 == batch_idx % 100 or batch_idx == len(testloader) - 1: 447 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 448 | % (batch_idx+1, len(testloader), test_loss/total, 100.*correct/total, correct, total)) 449 | 450 | # Save checkpoint. 451 | acc = 100.*correct/total 452 | global current_arch 453 | if acc > best_acc and save: 454 | logger.info('Saving best %.3f @ %d ( resnet-%s )...' 
%(acc, epoch, list_to_str(current_arch))) 455 | state = { 456 | 'net': net.state_dict(), 457 | 'acc': acc, 458 | 'epoch': epoch, 459 | } 460 | torch.save(state, os.path.join(save_path, 'best_ckpt.t7')) 461 | save_all(epoch, None, net, optimizer, 462 | os.path.join(save_path, 'best_model_optimizer_ckpt.t7')) 463 | 464 | best_acc = acc if acc > best_acc else best_acc 465 | 466 | with torch.no_grad(): 467 | if args.ema_params: 468 | utils.set_named_parameters(net, orig_params, strict=True) 469 | 470 | return test_loss / total, acc 471 | 472 | # main func 473 | 474 | # resume and evaluate from a checkpoint 475 | if args.evaluate: 476 | if os.path.isfile(args.evaluate): 477 | # load existing params, and initialize missing ones 478 | print("=> loading checkpoint '{}'".format(args.evaluate)) 479 | checkpoint = torch.load(args.evaluate) 480 | net.load_state_dict(checkpoint['net']) 481 | print("=> loaded checkpoint '{}' (epoch {})" 482 | .format(args.evaluate, checkpoint['epoch'])) 483 | print_pswitchs(net) 484 | logger.info('zeroing out small pswitches...') 485 | zerout_pswitchs(net, args.pswitch_thre, log=True) 486 | test(checkpoint['epoch'], net) 487 | logger.info('evaluation done!') 488 | else: 489 | print("=> no checkpoint found at '{}'".format(args.evaluate)) 490 | exit() 491 | 492 | if args.growing_metric == 'max': 493 | ema = utils.MovingMaximum() 494 | elif args.growing_metric == 'avg': 495 | ema = utils.ExponentialMovingAverage(decay=0.95) 496 | else: 497 | logger.fatal('Unknown --growing-metric') 498 | exit() 499 | 500 | 501 | def can_grow(maxlim, arch): 502 | for maxv, a in zip(maxlim, arch): 503 | if maxv > a: 504 | return True 505 | return False 506 | 507 | num_tail_epochs = args.tail_epochs # if (args.optimizer == 'sgd' or args.optimizer == 'sgdc') else 0 508 | last_epoch = -1 509 | growing_epochs = [] 510 | intervals = (args.epochs - 1) // args.grow_interval + 1 511 | if args.resume: 512 | ckpt_file = None 513 | if intervals > 0 and args.start_epoch == 0: 514 | logger.info('resuming a growing model') 515 | if os.path.isfile(os.path.join(args.resume, 'ema.obj')): 516 | ema = pickle.load(open(os.path.join(args.resume, 'ema.obj'), 'rb')) # binary mode so unpickling also works on Python 3 517 | logger.info('Previous max val accuracy: \n {}'.format(np.array_str(np.array(ema.get())))) 518 | else: 519 | logger.warning('ema.obj does not exist. Growing may be extended.') 520 | ckpt_file = os.path.join(args.resume, 'resnet-growing_ckpt.t7') 521 | saved_epoch = load_all(net, optimizer, ckpt_file) 522 | elif intervals == 0 and args.start_epoch < num_tail_epochs: 523 | logger.info('resuming a post-growing model') 524 | ckpt_file = os.path.join(args.resume, 'best_model_optimizer_ckpt.t7') 525 | if os.path.isfile(ckpt_file): 526 | saved_epoch = load_all(net, optimizer, ckpt_file) 527 | else: 528 | logger.warning('Only model states are resumed (no optimizer states).') 529 | ckpt_file = os.path.join(args.resume, 'best_ckpt.t7') 530 | checkpoint = torch.load(ckpt_file) 531 | saved_epoch = checkpoint['epoch'] 532 | net.load_state_dict(checkpoint['net']) 533 | else: 534 | logger.error('resuming with unexpected args!
Exit!') 535 | exit() 536 | logger.info('resumed from %s (saved at epoch %d)' % (ckpt_file, saved_epoch)) 537 | test(saved_epoch, net) 538 | 539 | # epoch, train loss, train accu, test loss, test accu, timestamps 540 | curves = np.zeros((intervals*args.grow_interval + num_tail_epochs, 6)) 541 | for interval in range(0, intervals): 542 | # training and testing 543 | for epoch in range(interval*args.grow_interval, (interval+1)*args.grow_interval): 544 | if 'sgdc' == args.optimizer: 545 | e = epoch % args.grow_interval 546 | if e < args.grow_interval // 3: 547 | set_learning_rate(optimizer, args.lr) 548 | elif e < args.grow_interval * 2 // 3: 549 | set_learning_rate(optimizer, args.lr * 0.1) 550 | else: 551 | set_learning_rate(optimizer, args.lr * 0.01) 552 | curves[epoch, 0] = epoch 553 | curves[epoch, 1], curves[epoch, 2], cur_chunk_num = train(epoch, net, chunk_num=cur_chunk_num) 554 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 555 | curves[epoch, 5] = time.time() / 60.0 556 | ema.push(curves[epoch, 4]) 557 | 558 | # limit max arch 559 | logger.info('******> improved %.3f (growing metric) in the last %d epochs' % ( 560 | ema.delta(-1 - args.grow_interval, -1), args.grow_interval)) 561 | delta_accu = ema.delta(-1 - args.stop_interval, -1) 562 | logger.info( 563 | '******> improved %.3f (growing metric) in the last %d epochs' % (delta_accu, args.stop_interval)) 564 | if delta_accu < args.grow_threshold: # no improvement 565 | if args.growing_mode == 'group': 566 | if grown_group is not None: 567 | max_arch[grown_group] = current_arch[grown_group] 568 | logger.info('******> stop growing group %d permanently. Limited as %s .' % (grown_group, list_to_str(max_arch))) 569 | else: 570 | max_arch[:] = current_arch[:] 571 | logger.info('******> stop growing all permanently. Limited as %s .' % (list_to_str(max_arch))) 572 | 573 | if can_grow(max_arch, current_arch): 574 | # save current model 575 | save_ckpt = os.path.join(save_path, 'resnet-growing_ckpt.t7') 576 | save_all((interval + 1) * args.grow_interval - 1, 577 | curves[(interval + 1) * args.grow_interval - 1, 2], 578 | net, 579 | optimizer, 580 | save_ckpt) 581 | pickle.dump(ema, open(os.path.join(save_path, 'ema.obj'), 'wb')) # binary mode to match the 'rb' load on resume 582 | pickle.dump(curves, open(os.path.join(save_path, 'curves.obj'), 'wb')) 583 | 584 | # create a new net and optimizer 585 | current_arch = utils.next_arch(args.growing_mode, max_arch, current_arch, logger, sub=subnet_arch, 586 | rate=args.rate, group=growing_group) 587 | logger.info( 588 | '******> growing to resnet-%s before epoch %d' % (list_to_str(current_arch), (interval + 1) * args.grow_interval)) 589 | net = get_module(args.residual, args.grow_interval, num_blocks=current_arch) 590 | optimizer = get_optimizer(net) 591 | loaded_epoch = load_all(net, optimizer, save_ckpt) 592 | # logger.info('testing new model ...') 593 | # test(loaded_epoch, net) 594 | growing_epochs.append((interval + 1) * args.grow_interval) 595 | if args.growing_mode == 'group': 596 | grown_group = growing_group 597 | growing_group = utils.next_group(growing_group, max_arch, current_arch, logger) 598 | else: 599 | logger.info('******> stop growing all groups') 600 | last_epoch = (interval + 1) * args.grow_interval - 1 601 | logger.info('******> reach limitation.
Finished in advance @ epoch %d' % last_epoch) 602 | curves = curves[:last_epoch+1+num_tail_epochs, :] 603 | break 604 | last_epoch = (interval + 1) * args.grow_interval - 1 605 | 606 | set_learning_rate(optimizer, args.lr) 607 | for epoch in range(last_epoch + 1 + args.start_epoch, last_epoch + 1 + num_tail_epochs): 608 | if args.optimizer == 'sgd' or args.optimizer == 'sgdc': 609 | if epoch < last_epoch + 1 + num_tail_epochs // 3: 610 | set_learning_rate(optimizer, args.lr) 611 | elif epoch < last_epoch + 1 + num_tail_epochs * 2 // 3: 612 | set_learning_rate(optimizer, args.lr * 0.1) 613 | else: 614 | set_learning_rate(optimizer, args.lr * 0.01) 615 | curves[epoch, 0] = epoch 616 | curves[epoch, 1], curves[epoch, 2], _ = train(epoch, net, chunk_num=cur_chunk_num) 617 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 618 | curves[epoch, 5] = time.time() / 60.0 619 | ema.push(curves[epoch, 4]) 620 | 621 | # align time 622 | for e in range(curves.shape[0]): 623 | curves[curves.shape[0]-1-e, 5] -= curves[0, 5] 624 | 625 | # plotting 626 | plot_segs = [0] + growing_epochs 627 | if len(growing_epochs) == 0 or growing_epochs[-1] != curves.shape[0]-1: 628 | plot_segs = plot_segs + [curves.shape[0]-1] 629 | logger.info('growing epochs {}'.format(list_to_str(growing_epochs))) 630 | logger.info('curves: \n {}'.format(np.array_str(curves))) 631 | np.savetxt(os.path.join(save_path, 'curves.dat'), curves) 632 | clr1 = (0.5, 0., 0.) 633 | clr2 = (0.0, 0.5, 0.) 634 | fig, ax1 = plt.subplots() 635 | fig2, ax3 = plt.subplots() 636 | ax2 = ax1.twinx() 637 | ax4 = ax3.twinx() 638 | ax1.set_xlabel('epoch') 639 | ax1.set_ylabel('Loss', color=clr1) 640 | ax1.tick_params(axis='y', colors=clr1) 641 | ax2.set_ylabel('Accuracy (%)', color=clr2) 642 | ax2.tick_params(axis='y', colors=clr2) 643 | 644 | ax3.set_xlabel('time (mins)') 645 | ax3.set_ylabel('Loss', color=clr1) 646 | ax3.tick_params(axis='y', colors=clr1) 647 | ax4.set_ylabel('Accuracy (%)', color=clr2) 648 | ax4.tick_params(axis='y', colors=clr2) 649 | 650 | # ax2.set_ylim(80, 100) # no plot if enabled 651 | for idx in range(len(plot_segs)-1): 652 | start = plot_segs[idx] 653 | end = plot_segs[idx+1] + 1 if (plot_segs[idx+1] == curves.shape[0] - 1) else plot_segs[idx+1] 654 | markersize = 12 655 | coef = 2. if idx % 2 else 1. 
656 | if idx == len(plot_segs)-2: 657 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize) 658 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize) 659 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize) 660 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize) 661 | 662 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize) 663 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize) 664 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize) 665 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize) 666 | else: 667 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 668 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 669 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 670 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 671 | 672 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 673 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 674 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 675 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 676 | 677 | if len(ema.get()) == curves.shape[0]: 678 | ax2.plot(curves[:, 0], ema.get(), '-', color=[1.0, 0, 1.0]) 679 | logger.info('Val accuracy moving average: \n {}'.format(np.array_str(np.array(ema.get())))) 680 | np.savetxt(os.path.join(save_path, 'ema.dat'), np.array(ema.get())) 681 | ax2.set_ylim(bottom=20, top=100) 682 | ax1.legend(('Train loss', 'Val loss'), loc='lower right') 683 | ax2.legend(('Train accuracy', 'Val accuracy', 'Val max'), loc='lower left') 684 | fig.savefig(os.path.join(save_path, 'curves-vs-epochs.pdf')) 685 | 686 | if len(ema.get()) == curves.shape[0]: 687 | ax4.plot(curves[:, 5], ema.get(), '-', color=[1.0, 0, 1.0]) 688 | ax4.set_ylim(bottom=20, top=100) 689 | ax3.legend(('Train loss', 'Val loss'), loc='lower right') 690 | ax4.legend(('Train accuracy', 'Val accuracy', 'Val moving avg'), loc='lower left') 691 | fig2.savefig(os.path.join(save_path, 'curves-vs-time.pdf')) 692 | 693 | 694 | logger.info('Done!') -------------------------------------------------------------------------------- /cifar10/main_gradual.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | 4 | import matplotlib 5 | matplotlib.use("pdf") 6 | import matplotlib.pyplot as plt 7 | import logging 8 | from datetime import datetime 9 | from copy import deepcopy 10 | import re 
11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | import torch.backends.cudnn as cudnn 17 | 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | import models.switch as ms 21 | import models.pswitch as ps 22 | 23 | import os 24 | import argparse 25 | import numpy as np 26 | import models 27 | import utils 28 | import time 29 | 30 | # from models import * 31 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 32 | parser.add_argument('--optimizer', '--opt', default='sgd', type=str, help='sgd variants (sgdc, sgd, adam, amsgrad, adagrad, adadelta, rmsprop)') 33 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 34 | parser.add_argument('--switch-reg', '--sr', default=None, type=float, help='lambda of L1 regularization on pswitch') 35 | parser.add_argument('--epochs', default=4000, type=int, help='the number of epochs') 36 | parser.add_argument('--grow-threshold', '--gt', default=0.05, type=float, help='the accuracy improvement threshold that decides growing or stopping') 37 | parser.add_argument('--growing-mode', default='group', type=str, help='how new structures are added (rate, all, sub, group)') 38 | parser.add_argument('--tail-epochs', '--te', default=200, type=int, help='the number of epochs to train after growing finishes (see --epochs)') 39 | parser.add_argument('--pswitch-thre', '--pt', default=0.005, type=float, help='threshold to zero pswitchs') 40 | 41 | parser.add_argument('--grow-interval', '--gi', default=1, type=int, help='an interval (in epochs) to grow new structures') 42 | parser.add_argument('--stop-interval', '--si', default=100, type=int, help='an interval (in epochs) over which accuracy improvement is measured to decide stopping') 43 | parser.add_argument('--net', default='1-1-1', type=str, help='starting net') 44 | parser.add_argument('--sub-net', default='1-1-1', type=str, help='a sub net to grow') 45 | parser.add_argument('--max-net', default='200-200-200', type=str, help='the maximum net') 46 | parser.add_argument('--residual', default='CifarResNetBasic', type=str, 47 | help='the type of residual block (CifarSwitchResNetBasic, CifarPlainNoBNNet, CifarPlainNet, PlainNet, PlainNoBNNet, ResNetBasic, ResNetBottleneck or CifarResNetBasic)') 48 | parser.add_argument('--initializer', '--init', default='gaussian', type=str, help='initializer for new structures (zero, uniform, gaussian, adam)') 49 | 50 | parser.add_argument('--switch-off', '--so', action='store_true', help='switch off at initialization') 51 | parser.add_argument('--ema-params', '--ep', action='store_true', help='validate accuracy with an exponential moving average of parameters') 52 | parser.add_argument('--rate', default=0.4, type=float, help='the rate to grow when --growing-mode=rate') 53 | parser.add_argument('--growing-metric', default='max', type=str, help='the metric for growing (max or avg)') 54 | parser.add_argument('--reset-states', '--rs', action='store_true', help='reset optimizer states (such as momentum) or not') 55 | parser.add_argument('--init-meta', default=1.0, type=float, help='a meta parameter for the initializer') 56 | parser.add_argument('--batch-size', '--bz', default=128, type=int, help='batch size') 57 | parser.add_argument('--evaluate', default='', type=str, metavar='PATH', 58 | help='path to checkpoint (default: none)') 59 | parser.add_argument('--pad-net', default='', type=str, metavar='NET', 60 | help='the smallest net that saved models are padded to (default: none)') 61 | parser.add_argument('--pad-epochs', default=4, type=int, 
help='the interval (in epochs) between saving padded snapshots for visualization') 62 | parser.add_argument('--dataset-ratio', default=1.0, type=float, help='the ratio of the training dataset used for learning') 63 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset (CIFAR10, CIFAR100, SVHN, FashionMNIST, MNIST)') 64 | 65 | args = parser.parse_args() 66 | 67 | def list_to_str(l): 68 | # join values with dashes, e.g. [1, 2, 3] -> '1-2-3' 69 | s = '' 70 | for v in l: 71 | s += str(v) + '-' 72 | return s[:-1] 73 | 74 | def get_module(name, switch_steps, *_args, **keywords): 75 | net = getattr(models, name)(*_args, **keywords) 76 | net = net.to('cuda') 77 | net = torch.nn.DataParallel(net) 78 | cudnn.benchmark = True 79 | configure_switch_policy(net, switch_steps) 80 | return net 81 | 82 | def get_optimizer(net): 83 | if 'sgd' == args.optimizer or 'sgdc' == args.optimizer: 84 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 85 | elif 'adam' == args.optimizer: 86 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4) 87 | elif 'amsgrad' == args.optimizer: 88 | optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4, amsgrad=True) 89 | elif 'adagrad' == args.optimizer: 90 | optimizer = optim.Adagrad(net.parameters(), lr=args.lr, weight_decay=5e-4) 91 | elif 'adadelta' == args.optimizer: 92 | optimizer = optim.Adadelta(net.parameters(), weight_decay=5e-4) 93 | elif 'rmsprop' == args.optimizer: 94 | optimizer = optim.RMSprop(net.parameters(), lr=args.lr, alpha=0.99, weight_decay=5e-4) 95 | else: 96 | logger.fatal('Unknown --optimizer') 97 | raise ValueError('Unknown --optimizer') 98 | return optimizer 99 | 100 | def params_id_to_name(net): 101 | themap = {} 102 | for k, v in net.named_parameters(): 103 | themap[id(v)] = k 104 | return themap 105 | 106 | def params_name_to_id(net): 107 | themap = {} 108 | for k, v in net.named_parameters(): 109 | themap[k] = id(v) 110 | return themap 111 | 112 | def save_model_with_padding(epoch, train_accu, model, new_arch, path): 113 | new_model = get_module(args.residual, args.grow_interval, new_arch, 114 | num_classes=utils.datasets[args.dataset]['num_classes'], 115 | image_channels=utils.datasets[args.dataset]['image_channels']) 116 | new_model.load_state_dict(model.state_dict(), strict=False) 117 | orig_params_data = {} 118 | for n, p in model.named_parameters(): 119 | orig_params_data[n] = p.data 120 | for n, p in new_model.named_parameters(): 121 | if n not in orig_params_data: 122 | logger.info('%s is set to zeros' % n) 123 | p.data.zero_() 124 | else: 125 | logger.info('%s is kept' % n) 126 | 127 | torch.save({ 128 | 'epoch': epoch, 129 | 'train_accu': train_accu, 130 | 'net': new_model.state_dict(), 131 | }, path) 132 | 133 | def save_all(epoch, train_accu, model, optimizer, path): 134 | torch.save({ 135 | 'epoch': epoch, 136 | 'train_accu': train_accu, 137 | 'model_state_dict': model.state_dict(), 138 | 'optimizer_state_dict': optimizer.state_dict(), 139 | 'name_id_map': params_name_to_id(model), 140 | }, path) 141 | 142 | def zerout_pswitchs(model, threshold=0.0001, log=False): 143 | total = 0 144 | zeroed = 0 145 | for idx, m in enumerate(model.named_modules()): 146 | if isinstance(m[1], ps.PSwitch): 147 | total += 1 148 | # logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 149 | if m[1].switch.data.abs() < threshold: 150 | if log: 151 | logger.info('******> switch {} is zeroed out...'.format(m[0])) 152 | m[1].switch.data.fill_(0.0) 153 | zeroed += 1 154 | if log: 155 | logger.info('%d/%d zeroed out' % (zeroed, 
total)) 156 | 157 | def reg_pswitchs(model): 158 | reg = 0.0 159 | for idx, m in enumerate(model.named_modules()): 160 | if isinstance(m[1], ps.PSwitch): 161 | reg += m[1].switch.norm(p=1) 162 | return reg * args.switch_reg 163 | 164 | def print_switchs(model): 165 | for idx, m in enumerate(model.named_modules()): 166 | if isinstance(m[1], ms.Switch): 167 | logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 168 | 169 | def print_pswitchs(model): 170 | for idx, m in enumerate(model.named_modules()): 171 | if isinstance(m[1], ps.PSwitch): 172 | logger.info('******> switch {} is {}...'.format(m[0], m[1].get_switch())) 173 | 174 | def increase_switchs(model): 175 | for idx, m in enumerate(model.named_modules()): 176 | if isinstance(m[1], ms.Switch): 177 | m[1].increase() 178 | 179 | def configure_switch_policy(model, steps, start=0.0, stop=1.0, mode='linear'): 180 | for idx, m in enumerate(model.named_modules()): 181 | if isinstance(m[1], ms.Switch): 182 | logger.info('******> configuring switch {}...'.format(m[0])) 183 | m[1].set_params(steps, start, stop, mode) 184 | 185 | 186 | def load_all(model, optimizer, path): 187 | checkpoint = torch.load(path) 188 | old_name_id_map = checkpoint['name_id_map'] 189 | new_id_name_map = params_id_to_name(model) 190 | # load existing params and initialize missing ones 191 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 192 | new_params = [] 193 | if args.residual == 'ResNetBasic' or args.residual == 'CifarResNetBasic' or args.residual == 'CifarSwitchResNetBasic': 194 | reinit_pattern = r'.*layer.*bn2\.((weight)|(bias))$' # raw string: avoids invalid '\.' escapes 195 | elif args.residual == 'ResNetBottleneck': 196 | reinit_pattern = r'.*layer.*bn3\.((weight)|(bias))$' 197 | elif args.residual == 'PlainNet' or args.residual == 'PlainNoBNNet' or args.residual == 'CifarPlainNet' or args.residual == 'CifarPlainNoBNNet': 198 | reinit_pattern = 'UseDefaultInitialization' 199 | logger.info('No reinitialization for PlainNet or PlainNoBNNet or CifarPlainNet or CifarPlainNoBNNet') 200 | else: 201 | logger.fatal('Unknown --residual') 202 | exit() 203 | for n, p in model.named_parameters(): 204 | if n not in old_name_id_map and re.match(reinit_pattern, n): 205 | logger.info('******> reinitializing param {} ...'.format(n)) 206 | new_params.append(p) 207 | if args.initializer == 'zero': 208 | logger.info('******> Initializing as zeros...') 209 | p.data.zero_() 210 | elif args.initializer == 'uniform': 211 | logger.info('******> Initializing by uniform noise...') 212 | p.data.uniform_(0.0, to=args.init_meta) 213 | elif args.initializer == 'gaussian': 214 | logger.info('******> Initializing by Gaussian noise') 215 | p.data.normal_(0.0, std=args.init_meta) 216 | elif args.initializer == 'adam': 217 | logger.info('******> Initializing by adam optimizer') 218 | else: 219 | logger.fatal('Unknown --initializer.') 220 | exit() 221 | if args.switch_off: 222 | switch_name = '.'.join(n.split('.')[:-2]+['switch.switch']) 223 | if switch_name in model.state_dict(): 224 | logger.info('******> resetting %s to 0.0 from %.3f' % (switch_name, model.state_dict()[switch_name])) 225 | model.state_dict()[switch_name].zero_() 226 | 227 | if len(new_params) and args.initializer == 'adam': 228 | logger.info('******> Using adam to find a good initialization') 229 | old_train_accu = checkpoint['train_accu'] 230 | local_optimizer = optim.Adam(new_params, lr=0.001, weight_decay=5e-4) 231 | max_epoch = 10 232 | found = False 233 | for e in range(max_epoch): 234 | _, accu = train(e, model, 
local_optimizer) 235 | if accu > old_train_accu - 0.5: 236 | logger.info('******> Found a good initial position with training accuracy %.2f (vs. old %.2f) at epoch %d' % ( 237 | accu, old_train_accu, e)) 238 | found = True 239 | break 240 | if not found: 241 | logger.info('******> failed to find a good initial position in %d epochs. Continue...' % max_epoch) 242 | 243 | 244 | # load existing states, and insert missing states as empty dicts 245 | if not args.reset_states: 246 | new_checkpoint = deepcopy(optimizer.state_dict()) 247 | old_checkpoint = checkpoint['optimizer_state_dict'] 248 | if len(old_checkpoint['param_groups']) != 1 or len(new_checkpoint['param_groups']) != 1: 249 | logger.fatal('The number of param_groups is not 1.') 250 | exit() 251 | for new_id in new_checkpoint['param_groups'][0]['params']: 252 | name = new_id_name_map[new_id] 253 | if name in old_name_id_map: 254 | old_id = old_name_id_map[name] 255 | if old_id in old_checkpoint['state']: 256 | logger.info('loading param {} state...'.format(name)) 257 | new_checkpoint['state'][new_id] = old_checkpoint['state'][old_id] 258 | else: 259 | logger.info('initializing param {} state as an empty dict...'.format(name)) 260 | new_checkpoint['state'][new_id] = {} 261 | else: 262 | if new_id not in new_checkpoint['state']: 263 | logger.info('initializing param {} state as an empty dict...'.format(name)) 264 | new_checkpoint['state'][new_id] = {} 265 | else: 266 | logger.info('skipping param {} state (initial state exists)...'.format(name)) 267 | optimizer.load_state_dict(new_checkpoint) 268 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 269 | epoch = checkpoint['epoch'] 270 | 271 | return epoch 272 | 273 | def set_learning_rate(optimizer, lr): 274 | """Set the learning rate.""" 275 | logger.info('\nSetting learning rate to %.6f' % lr) 276 | for param_group in optimizer.param_groups: 277 | param_group['lr'] = lr 278 | 279 | def decay_learning_rate(optimizer): 280 | """Decay the current learning rate by 10x.""" 281 | for param_group in optimizer.param_groups: 282 | param_group['lr'] = param_group['lr'] * 0.1 283 | 284 | device = 'cuda' # if torch.cuda.is_available() else 'cpu' 285 | best_acc = 0 # best test accuracy 286 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 287 | 288 | save_path = os.path.join('./results', datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 289 | if not os.path.exists(save_path): 290 | os.makedirs(save_path) 291 | else: 292 | raise OSError('Directory %s exists. Use a new one.' % save_path) 293 | logging.basicConfig(filename=os.path.join(save_path, 'log.txt'), level=logging.INFO) 294 | logger = logging.getLogger('main') 295 | logger.addHandler(logging.StreamHandler()) 296 | logger.info("Saving to %s", save_path) 297 | logger.info("Running arguments: %s", args) 298 | 299 | # Data 300 | logger.info('==> Preparing data %s..' 
% args.dataset) 301 | train_padding = 4 if utils.datasets[args.dataset]['size'] == 32 else (32 - utils.datasets[args.dataset]['size']) // 2 # integer division: RandomCrop requires an int padding 302 | test_padding = 0 if utils.datasets[args.dataset]['size'] == 32 else (32 - utils.datasets[args.dataset]['size']) // 2 303 | assert (test_padding * 2 + utils.datasets[args.dataset]['size']) == 32 304 | logger.info('train pad = %d, test pad = %d' % (train_padding, test_padding)) 305 | if 'CIFAR10' == args.dataset or 'CIFAR100' == args.dataset or 'FashionMNIST' == args.dataset: 306 | logger.info('==> RandomHorizontalFlip enabled..') 307 | transform_train = transforms.Compose([ 308 | transforms.RandomCrop(32, padding=train_padding), 309 | transforms.RandomHorizontalFlip(), 310 | transforms.ToTensor(), 311 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 312 | ]) 313 | else: 314 | logger.info('==> RandomHorizontalFlip disabled..') 315 | transform_train = transforms.Compose([ 316 | transforms.RandomCrop(32, padding=train_padding), 317 | transforms.ToTensor(), 318 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 319 | ]) 320 | 321 | transform_test = transforms.Compose([ 322 | transforms.RandomCrop(32, padding=test_padding), # deterministic: the crop size equals the padded size 323 | transforms.ToTensor(), 324 | transforms.Normalize(utils.datasets[args.dataset]['mean'], utils.datasets[args.dataset]['std']), 325 | ]) 326 | 327 | # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 328 | if 'SVHN' == args.dataset: 329 | trainset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, split='train', download=True, 330 | transform=transform_train) 331 | else: 332 | trainset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, train=True, download=True, 333 | transform=transform_train) 334 | logger.info('%d training samples.' % len(trainset)) 335 | 336 | train_sample_num = int(len(trainset) * args.dataset_ratio) 337 | trainset, _ = torch.utils.data.random_split(trainset, [train_sample_num, len(trainset) - train_sample_num]) 338 | logger.info('%d training samples are used for training' % len(trainset)) 339 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2) 340 | 341 | # testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 342 | if 'SVHN' == args.dataset: 343 | testset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, split='test', download=True, transform=transform_test) 344 | else: 345 | testset = getattr(torchvision.datasets, args.dataset)(root='./data-' + args.dataset, train=False, download=True, transform=transform_test) 346 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 347 | logger.info('%d test samples.' 
% len(testset)) 348 | 349 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 350 | 351 | # Model 352 | logger.info('==> Building model..') 353 | current_arch = list(map(int, args.net.split('-'))) 354 | subnet_arch = list(map(int, args.sub_net.split('-'))) 355 | max_arch = list(map(int, args.max_net.split('-'))) 356 | pad_arch = list(map(int, args.pad_net.split('-'))) if args.pad_net else None 357 | if len(current_arch) != len(max_arch): 358 | logger.fatal('max_arch has different size.') 359 | exit() 360 | growing_group = -1 361 | grown_group = None 362 | for cnt, v in enumerate(current_arch): 363 | if v < max_arch[cnt]: 364 | growing_group = cnt 365 | break 366 | 367 | net = get_module(args.residual, args.grow_interval, current_arch, 368 | num_classes=utils.datasets[args.dataset]['num_classes'], 369 | image_channels=utils.datasets[args.dataset]['image_channels']) 370 | # net = VGG('VGG19') 371 | # net = ResNet18() 372 | # net = PreActResNet18() 373 | # net = GoogLeNet() 374 | # net = DenseNet121() 375 | # net = ResNeXt29_2x64d() 376 | # net = MobileNet() 377 | # net = MobileNetV2() 378 | # net = DPN92() 379 | # net = ShuffleNetG2() 380 | # net = SENet18() 381 | 382 | criterion = nn.CrossEntropyLoss() 383 | optimizer = get_optimizer(net) 384 | param_ema = utils.TorchExponentialMovingAverage() 385 | # Training 386 | def train(epoch, net, own_optimizer=None, increase_switch=False): 387 | logger.info('\nTraining epoch %d @ %.1f sec' % (epoch, time.time())) 388 | net.train() 389 | train_loss = 0 390 | correct = 0 391 | total = 0 392 | print_pswitchs(net) 393 | if increase_switch: 394 | increase_switchs(net) 395 | for batch_idx, (inputs, targets) in enumerate(trainloader): 396 | inputs, targets = inputs.to(device), targets.to(device) 397 | if own_optimizer is not None: 398 | own_optimizer.zero_grad() 399 | else: 400 | optimizer.zero_grad() 401 | outputs = net(inputs) 402 | loss = criterion(outputs, targets) 403 | total_cost = loss 404 | sreg = 0.0 405 | if args.switch_reg is not None: 406 | sreg = reg_pswitchs(net) 407 | total_cost += sreg 408 | total_cost.backward() 409 | if own_optimizer is not None: 410 | own_optimizer.step() 411 | else: 412 | optimizer.step() 413 | 414 | if args.switch_reg is not None: 415 | zerout_pswitchs(net, args.pswitch_thre) 416 | 417 | # maintain a moving average 418 | if args.ema_params: 419 | params_data_dict = {} 420 | for n, p in net.named_parameters(): 421 | params_data_dict[n] = p.data 422 | param_ema.push(params_data_dict) 423 | 424 | train_loss += loss.item() 425 | _, predicted = outputs.max(1) 426 | total += targets.size(0) 427 | correct += predicted.eq(targets).sum().item() 428 | if 0 == batch_idx % 100 or batch_idx == len(trainloader) - 1: 429 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 430 | % (batch_idx+1, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 431 | if args.switch_reg is not None: 432 | logger.info(' ==> PSwitch L1 Reg.: %.6f ' % sreg) 433 | return train_loss / len(trainloader), 100.*correct/total 434 | 435 | def test(epoch, net, save=False): 436 | logger.info('Testing epoch %d @ %.1f sec' % (epoch, time.time())) 437 | global best_acc 438 | net.eval() 439 | test_loss = 0 440 | correct = 0 441 | total = 0 442 | with torch.no_grad(): 443 | if args.ema_params: 444 | logger.info('Using average params for test') 445 | orig_params = utils.set_named_parameters(net, param_ema.average(), strict=False) 446 | for batch_idx, (inputs, targets) in 
enumerate(testloader): 447 | inputs, targets = inputs.to(device), targets.to(device) 448 | outputs = net(inputs) 449 | loss = criterion(outputs, targets) 450 | 451 | test_loss += loss.item() 452 | _, predicted = outputs.max(1) 453 | total += targets.size(0) 454 | correct += predicted.eq(targets).sum().item() 455 | if 0 == batch_idx % 100 or batch_idx == len(testloader) - 1: 456 | logger.info('(%d/%d) ==> Loss: %.3f | Acc: %.3f%% (%d/%d)' 457 | % (batch_idx+1, len(testloader), test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 458 | 459 | # Save checkpoint. 460 | acc = 100.*correct/total 461 | global current_arch 462 | if acc > best_acc and save: 463 | logger.info('Saving best %.3f @ %d ( resnet-%s )...' %(acc, epoch, list_to_str(current_arch))) 464 | state = { 465 | 'net': net.state_dict(), 466 | 'acc': acc, 467 | 'epoch': epoch, 468 | } 469 | torch.save(state, os.path.join(save_path, 'best_ckpt.t7')) 470 | best_acc = acc if acc > best_acc else best_acc 471 | 472 | with torch.no_grad(): 473 | if args.ema_params: 474 | utils.set_named_parameters(net, orig_params, strict=True) 475 | 476 | return test_loss / len(testloader), acc 477 | 478 | # main func 479 | 480 | # resume and evaluate from a checkpoint 481 | if args.evaluate: 482 | if os.path.isfile(args.evaluate): 483 | # load existing params, and initializing missing ones 484 | print("=> loading checkpoint '{}'".format(args.evaluate)) 485 | checkpoint = torch.load(args.evaluate) 486 | net.load_state_dict(checkpoint['net']) 487 | print("=> loaded checkpoint '{}' (epoch {})" 488 | .format(args.evaluate, checkpoint['epoch'])) 489 | print_pswitchs(net) 490 | logger.info('zeroing out small pswitchs...') 491 | zerout_pswitchs(net, args.pswitch_thre, log=True) 492 | test(checkpoint['epoch'], net) 493 | logger.info('evaluation done!') 494 | else: 495 | print("=> no checkpoint found at '{}'".format(args.evaluate)) 496 | exit() 497 | 498 | if args.growing_metric == 'max': 499 | ema = utils.MovingMaximum() 500 | elif args.growing_metric == 'avg': 501 | ema = utils.ExponentialMovingAverage(decay=0.95) 502 | else: 503 | logger.fatal('Unknown --growing-metric') 504 | exit() 505 | 506 | 507 | def can_grow(maxlim, arch): 508 | for maxv, a in zip(maxlim, arch): 509 | if maxv > a: 510 | return True 511 | return False 512 | 513 | num_tail_epochs = args.tail_epochs # if (args.optimizer == 'sgd' or args.optimizer == 'sgdc') else 0 514 | last_epoch = -1 515 | growing_epochs = [] 516 | intervals = (args.epochs - 1) // args.grow_interval + 1 517 | # epoch, train loss, train accu, test loss, test accu, timestamps 518 | curves = np.zeros((intervals*args.grow_interval + num_tail_epochs, 6)) 519 | for interval in range(0, intervals): 520 | # training and testing 521 | for epoch in range(interval*args.grow_interval, (interval+1)*args.grow_interval): 522 | if 'sgdc' == args.optimizer: 523 | e = epoch % args.grow_interval 524 | if e < args.grow_interval // 2: 525 | set_learning_rate(optimizer, args.lr) 526 | elif e < args.grow_interval * 3 // 4: 527 | set_learning_rate(optimizer, args.lr * 0.1) 528 | else: 529 | set_learning_rate(optimizer, args.lr * 0.01) 530 | curves[epoch, 0] = epoch 531 | curves[epoch, 1], curves[epoch, 2] = train(epoch, net) 532 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 533 | curves[epoch, 5] = time.time() / 60.0 534 | ema.push(curves[epoch, 4]) 535 | if args.pad_net and (epoch % args.pad_epochs == 0): 536 | save_model_with_padding(epoch, curves[epoch, 2], net, pad_arch, 537 | os.path.join(save_path, 
'model_pad_%d.t7' % epoch)) 538 | 539 | # limit max arch 540 | logger.info('******> growing metric improved %.3f in the last %d epochs' % ( 541 | ema.delta(-1 - args.grow_interval, -1), args.grow_interval)) 542 | delta_accu = ema.delta(-1 - args.stop_interval, -1) 543 | logger.info( 544 | '******> growing metric improved %.3f in the last %d epochs' % (delta_accu, args.stop_interval)) 545 | if delta_accu < args.grow_threshold: # no improvement 546 | if args.growing_mode == 'group': 547 | if grown_group is not None: 548 | max_arch[grown_group] = current_arch[grown_group] 549 | logger.info('******> stop growing group %d permanently. Limited to %s.' % (grown_group, list_to_str(max_arch))) 550 | else: 551 | max_arch[:] = current_arch[:] 552 | logger.info('******> stop growing all permanently. Limited to %s.' % (list_to_str(max_arch))) 553 | 554 | if can_grow(max_arch, current_arch): 555 | # save current model 556 | save_ckpt = os.path.join(save_path, 'resnet-growing_ckpt.t7') 557 | save_all((interval + 1) * args.grow_interval - 1, 558 | curves[(interval + 1) * args.grow_interval - 1, 2], 559 | net, 560 | optimizer, 561 | save_ckpt) 562 | # create a new net and optimizer 563 | current_arch = utils.next_arch(args.growing_mode, max_arch, current_arch, logger, sub=subnet_arch, 564 | rate=args.rate, group=growing_group) 565 | logger.info( 566 | '******> growing to resnet-%s before epoch %d' % (list_to_str(current_arch), (interval + 1) * args.grow_interval)) 567 | net = get_module(args.residual, args.grow_interval, num_blocks=current_arch, 568 | num_classes=utils.datasets[args.dataset]['num_classes'], 569 | image_channels=utils.datasets[args.dataset]['image_channels']) 570 | optimizer = get_optimizer(net) 571 | loaded_epoch = load_all(net, optimizer, save_ckpt) 572 | logger.info('testing new model ...') 573 | test(loaded_epoch, net) 574 | growing_epochs.append((interval + 1) * args.grow_interval) 575 | if args.growing_mode == 'group': 576 | grown_group = growing_group 577 | growing_group = utils.next_group(growing_group, max_arch, current_arch, logger) 578 | else: 579 | logger.info('******> stop growing all groups') 580 | last_epoch = (interval + 1) * args.grow_interval - 1 581 | logger.info('******> reach limitation. 
Finished in advance @ epoch %d' % last_epoch) 582 | curves = curves[:last_epoch+1+num_tail_epochs, :] 583 | break 584 | last_epoch = (interval + 1) * args.grow_interval - 1 585 | 586 | set_learning_rate(optimizer, args.lr) 587 | for epoch in range(last_epoch + 1, last_epoch + 1 + num_tail_epochs): 588 | if ((epoch == last_epoch + 1 + num_tail_epochs // 2) or (epoch == last_epoch + 1 + num_tail_epochs * 3 // 4)) and ( 589 | args.optimizer == 'sgd' or args.optimizer == 'sgdc'): 590 | logger.info('======> decaying learning rate') 591 | decay_learning_rate(optimizer) 592 | curves[epoch, 0] = epoch 593 | curves[epoch, 1], curves[epoch, 2] = train(epoch, net) 594 | curves[epoch, 3], curves[epoch, 4] = test(epoch, net, save=True) 595 | curves[epoch, 5] = time.time() / 60.0 596 | ema.push(curves[epoch, 4]) 597 | if args.pad_net and (epoch % args.pad_epochs == 0): 598 | save_model_with_padding(epoch, curves[epoch, 2], net, pad_arch, 599 | os.path.join(save_path, 'model_pad_%d.t7' % epoch)) 600 | 601 | # align time 602 | for e in range(curves.shape[0]): 603 | curves[curves.shape[0]-1-e, 5] -= curves[0, 5] 604 | 605 | # plotting 606 | plot_segs = [0] + growing_epochs 607 | if len(growing_epochs) == 0 or growing_epochs[-1] != curves.shape[0]-1: 608 | plot_segs = plot_segs + [curves.shape[0]-1] 609 | logger.info('growing epochs {}'.format(list_to_str(growing_epochs))) 610 | logger.info('curves: \n {}'.format(np.array_str(curves))) 611 | np.savetxt(os.path.join(save_path, 'curves.dat'), curves) 612 | clr1 = (0.5, 0., 0.) 613 | clr2 = (0.0, 0.5, 0.) 614 | fig, ax1 = plt.subplots() 615 | fig2, ax3 = plt.subplots() 616 | ax2 = ax1.twinx() 617 | ax4 = ax3.twinx() 618 | ax1.set_xlabel('epoch') 619 | ax1.set_ylabel('Loss', color=clr1) 620 | ax1.tick_params(axis='y', colors=clr1) 621 | ax2.set_ylabel('Accuracy (%)', color=clr2) 622 | ax2.tick_params(axis='y', colors=clr2) 623 | 624 | ax3.set_xlabel('time (mins)') 625 | ax3.set_ylabel('Loss', color=clr1) 626 | ax3.tick_params(axis='y', colors=clr1) 627 | ax4.set_ylabel('Accuracy (%)', color=clr2) 628 | ax4.tick_params(axis='y', colors=clr2) 629 | 630 | # ax2.set_ylim(80, 100) # no plot if enabled 631 | for idx in range(len(plot_segs)-1): 632 | start = plot_segs[idx] 633 | end = plot_segs[idx+1] + 1 if (plot_segs[idx+1] == curves.shape[0] - 1) else plot_segs[idx+1] 634 | markersize = 12 635 | coef = 2. if idx % 2 else 1. 
636 | if idx == len(plot_segs)-2: 637 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize) 638 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize) 639 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize) 640 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize) 641 | 642 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize) 643 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize) 644 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize) 645 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize) 646 | else: 647 | ax1.semilogy(curves[start:end, 0], curves[start:end, 1], '--', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 648 | ax1.semilogy(curves[start:end, 0], curves[start:end, 3], '-', color=[c*coef for c in clr1], markersize=markersize, label='_nolegend_') 649 | ax2.plot(curves[start:end, 0], curves[start:end, 2], '--', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 650 | ax2.plot(curves[start:end, 0], curves[start:end, 4], '-', color=[c*coef for c in clr2], markersize=markersize, label='_nolegend_') 651 | 652 | ax3.semilogy(curves[start:end, 5], curves[start:end, 1], '--', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 653 | ax3.semilogy(curves[start:end, 5], curves[start:end, 3], '-', color=[c * coef for c in clr1], markersize=markersize, label='_nolegend_') 654 | ax4.plot(curves[start:end, 5], curves[start:end, 2], '--', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 655 | ax4.plot(curves[start:end, 5], curves[start:end, 4], '-', color=[c * coef for c in clr2], markersize=markersize, label='_nolegend_') 656 | 657 | ax2.plot(curves[:, 0], ema.get(), '-', color=[1.0, 0, 1.0]) 658 | logger.info('Val accuracy moving average: \n {}'.format(np.array_str(np.array(ema.get())))) 659 | np.savetxt(os.path.join(save_path, 'ema.dat'), np.array(ema.get())) 660 | ax2.set_ylim(bottom=40, top=100) 661 | ax1.legend(('Train loss', 'Val loss'), loc='lower right') 662 | ax2.legend(('Train accuracy', 'Val accuracy', 'Val max'), loc='lower left') 663 | fig.savefig(os.path.join(save_path, 'curves-vs-epochs.pdf')) 664 | 665 | ax4.plot(curves[:, 5], ema.get(), '-', color=[1.0, 0, 1.0]) 666 | ax4.set_ylim(bottom=40, top=100) 667 | ax3.legend(('Train loss', 'Val loss'), loc='lower right') 668 | ax4.legend(('Train accuracy', 'Val accuracy', 'Val moving avg'), loc='lower left') 669 | fig2.savefig(os.path.join(save_path, 'curves-vs-time.pdf')) 670 | 671 | 672 | logger.info('Done!') --------------------------------------------------------------------------------
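Two hedged usage notes for the training scripts above. First, an illustrative invocation of cifar10/main_gradual.py; the flag names come from its argparse section, while the values here are examples only, not recommended settings:

python main_gradual.py --residual CifarResNetBasic --dataset CIFAR10 --net 1-1-1 --max-net 200-200-200 --grow-interval 3 --stop-interval 100 --grow-threshold 0.05 --tail-epochs 200

Second, a minimal sketch of the growth/stop rule driving both main loops: growing halts once the chosen metric over validation accuracy improves by less than --grow-threshold over the last --stop-interval epochs. The MovingMaximum class below is an assumption — utils.MovingMaximum is not shown in this listing — so this is a plausible reimplementation for illustration only:

class MovingMaximum:
    """Running maximum of a pushed series (assumed to mirror utils.MovingMaximum)."""
    def __init__(self):
        self.maxima = []

    def push(self, v):
        # Record the best value seen so far.
        self.maxima.append(v if not self.maxima else max(v, self.maxima[-1]))

    def delta(self, start, end):
        # Improvement of the running maximum between two (negative) indices,
        # clamped so early epochs do not index out of range.
        s = max(-len(self.maxima), start)
        return self.maxima[end] - self.maxima[s]

mm = MovingMaximum()
for acc in [60.1, 62.5, 63.0, 63.1, 63.1, 63.2]:  # per-epoch val accuracies
    mm.push(acc)
grow_threshold, stop_interval = 0.05, 4
keep_growing = mm.delta(-1 - stop_interval, -1) >= grow_threshold
print(keep_growing)  # True: the running max still improves by >= 0.05

With the defaults above (--grow-interval 1, --stop-interval 100), the scripts consider growing after every epoch but judge stagnation over a 100-epoch window, which is why stopping is deliberately much slower than growing.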