├── requirements.txt ├── models ├── __init__.py ├── mnist_f1.py ├── cifar_shallow.py ├── alexnet.py ├── resnet.py └── wide_resnet.py ├── LICENSE ├── run_experiments.sh ├── data.py ├── README.md ├── preprocess.py ├── utils.py ├── main_normal.py └── main_gbn.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .resnet import * 3 | from .mnist_f1 import * 4 | from .wide_resnet import * 5 | from .cifar_shallow import * 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Elad Hoffer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/mnist_f1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | __all__ = ['mnist_f1'] 4 | 5 | 6 | class mnist_model(nn.Module): 7 | 8 | def __init__(self): 9 | super(mnist_model, self).__init__() 10 | self.layers = nn.Sequential( 11 | nn.Linear(28 * 28, 512), 12 | nn.BatchNorm1d(512), 13 | nn.ReLU(True), 14 | nn.Linear(512, 512), 15 | nn.BatchNorm1d(512), 16 | nn.ReLU(True), 17 | nn.Linear(512, 512), 18 | nn.BatchNorm1d(512), 19 | nn.ReLU(True), 20 | nn.Linear(512, 512), 21 | nn.BatchNorm1d(512), 22 | nn.ReLU(True), 23 | nn.Linear(512, 512), 24 | nn.BatchNorm1d(512), 25 | nn.ReLU(True), 26 | nn.Linear(512, 10), 27 | ) 28 | self.regime = { 29 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 30 | 'weight_decay': 1e-4, 'momentum': 0.9}, 31 | 10: {'lr': 1e-2}, 32 | 20: {'lr': 1e-3}, 33 | 30: {'lr': 1e-4} 34 | } 35 | 36 | def forward(self, inputs): 37 | return self.layers(inputs.view(inputs.size(0), -1)) 38 | 39 | 40 | def mnist_f1(**kwargs): 41 | return mnist_model() 42 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_baseline --epochs 100 --b 2048 --no-lr_bb_fix; 2 | python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix; 3 | python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 
2048 --lr_bb_fix --regime_bb_fix; 4 | python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256; 5 | python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix; 6 | python main_normal.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs2048_no_lr_fix --epochs 50 --b 2048 --no-lr_bb_fix; 7 | python main_normal.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs2048 --epochs 50 --b 2048 --lr_bb_fix; 8 | python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128; 9 | python main_gbn.py --model cifar100_shallow --dataset cifar100 --save shallow_cifar100_baseline_bs4096_gbn --epochs 200 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128; 10 | python main_gbn.py --model cifar10_shallow --dataset cifar10 --save shallow_cifar10_baseline_bs4096_gbn --epochs 200 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128; 11 | -------------------------------------------------------------------------------- /models/cifar_shallow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | __all__ = ['cifar10_shallow', 'cifar100_shallow'] 4 | 5 | 6 | class AlexNet(nn.Module): 7 | 8 | def __init__(self, num_classes=10): 9 | super(AlexNet, self).__init__() 10 | self.features = nn.Sequential( 11 | nn.Conv2d(3, 64, kernel_size=5, padding=2, 12 | bias=False), 13 | nn.MaxPool2d(kernel_size=3, stride=2), 14 | nn.BatchNorm2d(64), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(64, 64, kernel_size=5, padding=2, bias=False), 17 | nn.MaxPool2d(kernel_size=3, stride=2), 18 | nn.BatchNorm2d(64), 19 | nn.ReLU(inplace=True), 20 | ) 21 | self.classifier = nn.Sequential( 22 | nn.Linear(64 * 7 * 7, 384, 
bias=False), 23 | nn.BatchNorm1d(384), 24 | nn.ReLU(inplace=True), 25 | nn.Dropout(0.5), 26 | nn.Linear(384, 192, bias=False), 27 | nn.BatchNorm1d(192), 28 | nn.ReLU(inplace=True), 29 | nn.Dropout(0.5), 30 | nn.Linear(192, num_classes) 31 | ) 32 | self.regime = { 33 | 0: {'optimizer': 'SGD', 'lr': 1e-3, 34 | 'weight_decay': 5e-4}, 35 | 60: {'lr': 1e-2}, 36 | 120: {'lr': 1e-3}, 37 | 180: {'lr': 1e-4} 38 | } 39 | 40 | def forward(self, x): 41 | x = self.features(x) 42 | x = x.view(-1, 64 * 7 * 7) 43 | x = self.classifier(x) 44 | return x 45 | 46 | 47 | def cifar10_shallow(**kwargs): 48 | num_classes = getattr(kwargs, 'num_classes', 10) 49 | return AlexNet(num_classes) 50 | 51 | 52 | def cifar100_shallow(**kwargs): 53 | num_classes = getattr(kwargs, 'num_classes', 100) 54 | return AlexNet(num_classes) 55 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets as datasets 3 | 4 | __DATASETS_DEFAULT_PATH = '/media/ssd/Datasets/' 5 | 6 | 7 | def get_dataset(name, split='train', transform=None, 8 | target_transform=None, download=True, datasets_path=__DATASETS_DEFAULT_PATH): 9 | train = (split == 'train') 10 | root = os.path.join(datasets_path, name) 11 | if name == 'cifar10': 12 | return datasets.CIFAR10(root=root, 13 | train=train, 14 | transform=transform, 15 | target_transform=target_transform, 16 | download=download) 17 | elif name == 'cifar100': 18 | return datasets.CIFAR100(root=root, 19 | train=train, 20 | transform=transform, 21 | target_transform=target_transform, 22 | download=download) 23 | elif name == 'mnist': 24 | return datasets.MNIST(root=root, 25 | train=train, 26 | transform=transform, 27 | target_transform=target_transform, 28 | download=download) 29 | elif name == 'stl10': 30 | return datasets.STL10(root=root, 31 | split=split, 32 | transform=transform, 33 | 
target_transform=target_transform, 34 | download=download) 35 | elif name == 'imagenet': 36 | if train: 37 | root = os.path.join(root, 'train') 38 | else: 39 | root = os.path.join(root, 'val') 40 | return datasets.ImageFolder(root=root, 41 | transform=transform, 42 | target_transform=target_transform) 43 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | 4 | __all__ = ['alexnet'] 5 | 6 | class AlexNetOWT_BN(nn.Module): 7 | 8 | def __init__(self, num_classes=1000): 9 | super(AlexNetOWT_BN, self).__init__() 10 | self.features = nn.Sequential( 11 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2, 12 | bias=False), 13 | nn.MaxPool2d(kernel_size=3, stride=2), 14 | nn.BatchNorm2d(64), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(64, 192, kernel_size=5, padding=2, bias=False), 17 | nn.MaxPool2d(kernel_size=3, stride=2), 18 | nn.ReLU(inplace=True), 19 | nn.BatchNorm2d(192), 20 | nn.Conv2d(192, 384, kernel_size=3, padding=1, bias=False), 21 | nn.ReLU(inplace=True), 22 | nn.BatchNorm2d(384), 23 | nn.Conv2d(384, 256, kernel_size=3, padding=1, bias=False), 24 | nn.ReLU(inplace=True), 25 | nn.BatchNorm2d(256), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 27 | nn.MaxPool2d(kernel_size=3, stride=2), 28 | nn.ReLU(inplace=True), 29 | nn.BatchNorm2d(256) 30 | ) 31 | self.classifier = nn.Sequential( 32 | nn.Linear(256 * 6 * 6, 4096, bias=False), 33 | nn.BatchNorm1d(4096), 34 | nn.ReLU(inplace=True), 35 | nn.Dropout(0.5), 36 | nn.Linear(4096, 4096, bias=False), 37 | nn.BatchNorm1d(4096), 38 | nn.ReLU(inplace=True), 39 | nn.Dropout(0.5), 40 | nn.Linear(4096, num_classes) 41 | ) 42 | 43 | self.regime = { 44 | 0: {'optimizer': 'SGD', 'lr': 1e-2, 45 | 'weight_decay': 5e-4, 'momentum': 0.9}, 46 | 10: {'lr': 5e-3}, 47 | 15: {'lr': 1e-3, 'weight_decay': 0}, 48 | 
20: {'lr': 5e-4}, 49 | 25: {'lr': 1e-4} 50 | } 51 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 52 | std=[0.229, 0.224, 0.225]) 53 | self.input_transform = { 54 | 'train': transforms.Compose([ 55 | transforms.Scale(256), 56 | transforms.RandomCrop(224), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | normalize 60 | ]), 61 | 'eval': transforms.Compose([ 62 | transforms.Scale(256), 63 | transforms.CenterCrop(224), 64 | transforms.ToTensor(), 65 | normalize 66 | ]) 67 | } 68 | 69 | def forward(self, x): 70 | x = self.features(x) 71 | x = x.view(-1, 256 * 6 * 6) 72 | x = self.classifier(x) 73 | return x 74 | 75 | 76 | def alexnet(**kwargs): 77 | num_classes = getattr(kwargs, 'num_classes', 1000) 78 | return AlexNetOWT_BN(num_classes) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Train longer, generalize better - Big batch training 2 | 3 | This is a code repository used to generate the results appearing in ["Train longer, generalize better: closing the generalization gap in large batch training of neural networks"](https://arxiv.org/abs/1705.08741) By Elad Hoffer, Itay Hubara and Daniel Soudry. 4 | 5 | It is based off [convNet.pytorch](https://github.com/eladhoffer/convNet.pytorch) with some helpful options such as: 6 | - Training on several datasets 7 | - Complete logging of trained experiment 8 | - Graph visualization of the training/validation loss and accuracy 9 | - Definition of preprocessing and optimization regime for each model 10 | 11 | ## Dependencies 12 | 13 | - [pytorch]() 14 | - [torchvision]() to load the datasets, perform image transforms 15 | - [pandas]() for logging to csv 16 | - [bokeh]() for training visualization 17 | 18 | 19 | ## Data 20 | - Configure your dataset path at **data.py**. 
21 | - To get the ILSVRC data, register for access on the ImageNet website (http://www.image-net.org/). 22 | 23 | ## Experiment examples 24 | ```bash 25 | python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix; 26 | python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 2048 --lr_bb_fix --regime_bb_fix; 27 | python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256; 28 | python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix; 29 | python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128; 30 | ``` 31 | - See *run_experiments.sh* for more examples. 32 | ## Model configuration 33 | 34 | A network model is defined by writing a .py file in the models folder and selecting it with the `--model` flag. The model function must be registered in models/\_\_init\_\_.py 35 | The model function must return a trainable network. It can also specify additional training options, such as an optimization regime (either a dictionary or a function) and input transform modifications. 36 | 37 | e.g., for a model definition: 38 | 39 | ```python 40 | class Model(nn.Module): 41 | 42 | def __init__(self, num_classes=1000): 43 | super(Model, self).__init__() 44 | self.model = nn.Sequential(...)
45 | 46 | self.regime = { 47 | 0: {'optimizer': 'SGD', 'lr': 1e-2, 48 | 'weight_decay': 5e-4, 'momentum': 0.9}, 49 | 15: {'lr': 1e-3, 'weight_decay': 0} 50 | } 51 | 52 | self.input_transform = { 53 | 'train': transforms.Compose([...]), 54 | 'eval': transforms.Compose([...]) 55 | } 56 | def forward(self, inputs): 57 | return self.model(inputs) 58 | 59 | def model(**kwargs): 60 | return Model() 61 | ``` 62 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import random 4 | 5 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 6 | 'std': [0.229, 0.224, 0.225]} 7 | 8 | __imagenet_pca = { 9 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 10 | 'eigvec': torch.Tensor([ 11 | [-0.5675, 0.7192, 0.4009], 12 | [-0.5808, -0.0045, -0.8140], 13 | [-0.5836, -0.6948, 0.4203], 14 | ]) 15 | } 16 | 17 | 18 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 19 | t_list = [ 20 | transforms.CenterCrop(input_size), 21 | transforms.ToTensor(), 22 | transforms.Normalize(**normalize), 23 | ] 24 | if scale_size != input_size: 25 | t_list = [transforms.Scale(scale_size)] + t_list 26 | 27 | return transforms.Compose(t_list) 28 | 29 | 30 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 31 | t_list = [ 32 | transforms.RandomCrop(input_size), 33 | transforms.RandomHorizontalFlip(), 34 | transforms.ToTensor(), 35 | transforms.Normalize(**normalize), 36 | ] 37 | if scale_size != input_size: 38 | t_list = [transforms.Scale(scale_size)] + t_list 39 | 40 | return transforms.Compose(t_list) 41 | 42 | 43 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats, fill=0): 44 | padding = int((scale_size - input_size) / 2) 45 | return transforms.Compose([ 46 | transforms.Pad(padding, fill=fill), 47 | 
transforms.RandomCrop(input_size), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize(**normalize), 51 | ]) 52 | 53 | 54 | def inception_preproccess(input_size, normalize=__imagenet_stats): 55 | return transforms.Compose([ 56 | transforms.RandomSizedCrop(input_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize(**normalize) 60 | ]) 61 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 62 | return transforms.Compose([ 63 | transforms.RandomSizedCrop(input_size), 64 | transforms.RandomHorizontalFlip(), 65 | transforms.ToTensor(), 66 | ColorJitter( 67 | brightness=0.4, 68 | contrast=0.4, 69 | saturation=0.4, 70 | ), 71 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 72 | transforms.Normalize(**normalize) 73 | ]) 74 | 75 | 76 | def get_transform(name='imagenet', input_size=None, 77 | scale_size=None, normalize=None, augment=True): 78 | normalize = normalize or __imagenet_stats 79 | if name == 'imagenet': 80 | scale_size = scale_size or 256 81 | input_size = input_size or 224 82 | if augment: 83 | return inception_preproccess(input_size, normalize=normalize) 84 | else: 85 | return scale_crop(input_size=input_size, 86 | scale_size=scale_size, normalize=normalize) 87 | elif 'cifar' in name: 88 | input_size = input_size or 32 89 | if augment: 90 | scale_size = scale_size or 40 91 | return pad_random_crop(input_size, scale_size=scale_size, 92 | normalize=normalize, fill=127) 93 | else: 94 | scale_size = scale_size or 32 95 | return scale_crop(input_size=input_size, 96 | scale_size=scale_size, normalize=normalize) 97 | elif name == 'mnist': 98 | normalize = {'mean': [0.5], 'std': [0.5]} 99 | input_size = input_size or 28 100 | if augment: 101 | scale_size = scale_size or 32 102 | return pad_random_crop(input_size, scale_size=scale_size, 103 | normalize=normalize) 104 | else: 105 | scale_size = scale_size or 28 106 | return 
scale_crop(input_size=input_size, 107 | scale_size=scale_size, normalize=normalize) 108 | 109 | 110 | class Lighting(object): 111 | """Lighting noise(AlexNet - style PCA - based noise)""" 112 | 113 | def __init__(self, alphastd, eigval, eigvec): 114 | self.alphastd = alphastd 115 | self.eigval = eigval 116 | self.eigvec = eigvec 117 | 118 | def __call__(self, img): 119 | if self.alphastd == 0: 120 | return img 121 | 122 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 123 | rgb = self.eigvec.type_as(img).clone()\ 124 | .mul(alpha.view(1, 3).expand(3, 3))\ 125 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 126 | .sum(1).squeeze() 127 | 128 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 129 | 130 | 131 | class Grayscale(object): 132 | 133 | def __call__(self, img): 134 | gs = img.clone() 135 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 136 | gs[1].copy_(gs[0]) 137 | gs[2].copy_(gs[0]) 138 | return gs 139 | 140 | 141 | class Saturation(object): 142 | 143 | def __init__(self, var): 144 | self.var = var 145 | 146 | def __call__(self, img): 147 | gs = Grayscale()(img) 148 | alpha = random.uniform(0, self.var) 149 | return img.lerp(gs, alpha) 150 | 151 | 152 | class Brightness(object): 153 | 154 | def __init__(self, var): 155 | self.var = var 156 | 157 | def __call__(self, img): 158 | gs = img.new().resize_as_(img).zero_() 159 | alpha = random.uniform(0, self.var) 160 | return img.lerp(gs, alpha) 161 | 162 | 163 | class Contrast(object): 164 | 165 | def __init__(self, var): 166 | self.var = var 167 | 168 | def __call__(self, img): 169 | gs = Grayscale()(img) 170 | gs.fill_(gs.mean()) 171 | alpha = random.uniform(0, self.var) 172 | return img.lerp(gs, alpha) 173 | 174 | 175 | class RandomOrder(object): 176 | """ Composes several transforms together in random order. 
177 | """ 178 | 179 | def __init__(self, transforms): 180 | self.transforms = transforms 181 | 182 | def __call__(self, img): 183 | if self.transforms is None: 184 | return img 185 | order = torch.randperm(len(self.transforms)) 186 | for i in order: 187 | img = self.transforms[i](img) 188 | return img 189 | 190 | 191 | class ColorJitter(RandomOrder): 192 | 193 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 194 | self.transforms = [] 195 | if brightness != 0: 196 | self.transforms.append(Brightness(brightness)) 197 | if contrast != 0: 198 | self.transforms.append(Contrast(contrast)) 199 | if saturation != 0: 200 | self.transforms.append(Saturation(saturation)) 201 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import logging.config 5 | import shutil 6 | import pandas as pd 7 | from bokeh.io import output_file, save, show 8 | from bokeh.plotting import figure 9 | from bokeh.layouts import column 10 | from bokeh.charts import Line, defaults 11 | from numpy.random import choice 12 | 13 | defaults.width = 800 14 | defaults.height = 400 15 | defaults.tools = 'pan,box_zoom,wheel_zoom,box_select,hover,resize,reset,save' 16 | 17 | 18 | def setup_logging(log_file='log.txt'): 19 | """Setup logging configuration 20 | """ 21 | logging.basicConfig(level=logging.DEBUG, 22 | format="%(asctime)s - %(levelname)s - %(message)s", 23 | datefmt="%Y-%m-%d %H:%M:%S", 24 | filename=log_file, 25 | filemode='w') 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.INFO) 28 | formatter = logging.Formatter('%(message)s') 29 | console.setFormatter(formatter) 30 | logging.getLogger('').addHandler(console) 31 | 32 | 33 | class ResultsLog(object): 34 | 35 | def __init__(self, path='results.csv', plot_path=None): 36 | self.path = path 37 | self.plot_path = plot_path 
or (self.path + '.html') 38 | self.figures = [] 39 | self.results = None 40 | 41 | def add(self, **kwargs): 42 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys()) 43 | if self.results is None: 44 | self.results = df 45 | else: 46 | self.results = self.results.append(df, ignore_index=True) 47 | 48 | def save(self, title='Training Results'): 49 | if len(self.figures) > 0: 50 | if os.path.isfile(self.plot_path): 51 | os.remove(self.plot_path) 52 | output_file(self.plot_path, title=title) 53 | plot = column(*self.figures) 54 | save(plot) 55 | self.figures = [] 56 | self.results.to_csv(self.path, index=False, index_label=False) 57 | 58 | def load(self, path=None): 59 | path = path or self.path 60 | if os.path.isfile(path): 61 | self.results.read_csv(path) 62 | 63 | def show(self): 64 | if len(self.figures) > 0: 65 | plot = column(*self.figures) 66 | show(plot) 67 | 68 | def plot(self, *kargs, **kwargs): 69 | line = Line(data=self.results, *kargs, **kwargs) 70 | self.figures.append(line) 71 | 72 | def image(self, *kargs, **kwargs): 73 | fig = figure() 74 | fig.image(*kargs, **kwargs) 75 | self.figures.append(fig) 76 | 77 | 78 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 79 | filename = os.path.join(path, filename) 80 | torch.save(state, filename) 81 | if is_best: 82 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 83 | if save_all: 84 | shutil.copyfile(filename, os.path.join( 85 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 86 | 87 | 88 | class AverageMeter(object): 89 | """Computes and stores the average and current value""" 90 | 91 | def __init__(self): 92 | self.reset() 93 | 94 | def reset(self): 95 | self.val = 0 96 | self.avg = 0 97 | self.sum = 0 98 | self.count = 0 99 | 100 | def update(self, val, n=1): 101 | self.val = val 102 | self.sum += val * n 103 | self.count += n 104 | self.avg = self.sum / self.count 105 | 106 | 107 | class OnlineMeasure(object): 108 | 109 | def 
__init__(self): 110 | self.mean = torch.FloatTensor(1).fill_(-1) 111 | self.M2 = torch.FloatTensor(1).zero_() 112 | self.count = 0. 113 | self.needs_init = True 114 | 115 | def reset(self, x): 116 | self.mean = x.new(x.size()).zero_() 117 | self.M2 = x.new(x.size()).zero_() 118 | self.count = 0. 119 | self.needs_init = False 120 | 121 | def update(self, x): 122 | self.val = x 123 | if self.needs_init: 124 | self.reset(x) 125 | self.count += 1 126 | delta = x - self.mean 127 | self.mean.add_(delta / self.count) 128 | delta2 = x - self.mean 129 | self.M2.add_(delta * delta2) 130 | 131 | def var(self): 132 | if self.count < 2: 133 | return self.M2.clone().zero_() 134 | return self.M2 / (self.count - 1) 135 | 136 | def std(self): 137 | return self.var().sqrt() 138 | 139 | __optimizers = { 140 | 'SGD': torch.optim.SGD, 141 | 'ASGD': torch.optim.ASGD, 142 | 'Adam': torch.optim.Adam, 143 | 'Adamax': torch.optim.Adamax, 144 | 'Adagrad': torch.optim.Adagrad, 145 | 'Adadelta': torch.optim.Adadelta, 146 | 'Rprop': torch.optim.Rprop, 147 | 'RMSprop': torch.optim.RMSprop 148 | } 149 | 150 | 151 | def adjust_optimizer(optimizer, epoch, config): 152 | """Reconfigures the optimizer according to epoch and config dict""" 153 | def modify_optimizer(optimizer, setting): 154 | if 'optimizer' in setting: 155 | optimizer = __optimizers[setting['optimizer']]( 156 | optimizer.param_groups) 157 | logging.debug('OPTIMIZER - setting method = %s' % 158 | setting['optimizer']) 159 | for param_group in optimizer.param_groups: 160 | for key in param_group.keys(): 161 | if key in setting: 162 | new_val = setting[key] 163 | logging.debug('OPTIMIZER - setting %s = %s' % 164 | (key, new_val)) 165 | param_group[key] = new_val 166 | return optimizer 167 | 168 | if callable(config): 169 | optimizer = modify_optimizer(optimizer, config(epoch)) 170 | else: 171 | for e in range(epoch + 1): # run over all epochs - sticky setting 172 | if e in config: 173 | optimizer = modify_optimizer(optimizer, config[e]) 
174 | 175 | return optimizer 176 | 177 | 178 | def accuracy(output, target, topk=(1,)): 179 | """Computes the precision@k for the specified values of k""" 180 | maxk = max(topk) 181 | batch_size = target.size(0) 182 | 183 | _, pred = output.float().topk(maxk, 1, True, True) 184 | pred = pred.t() 185 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 186 | 187 | res = [] 188 | for k in topk: 189 | correct_k = correct[:k].view(-1).float().sum(0) 190 | res.append(correct_k.mul_(100.0 / batch_size)) 191 | return res 192 | 193 | 194 | 195 | class RandomSamplerReplacment(torch.utils.data.sampler.Sampler): 196 | """Samples elements randomly, with replacement. 197 | Arguments: 198 | data_source (Dataset): dataset to sample from 199 | """ 200 | 201 | def __init__(self, data_source): 202 | self.num_samples = len(data_source) 203 | 204 | def __iter__(self): 205 | return iter(torch.from_numpy(choice(self.num_samples, self.num_samples, replace=True))) 206 | 207 | def __len__(self): 208 | return self.num_samples 209 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | 5 | __all__ = ['resnet'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def init_model(model): 15 | for m in model.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 
| 97 | def __init__(self): 98 | super(ResNet, self).__init__() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes * block.expansion), 107 | ) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.feats(x) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | 122 | return x 123 | 124 | 125 | class ResNet_imagenet(ResNet): 126 | 127 | def __init__(self, num_classes=1000, 128 | block=Bottleneck, layers=[3, 4, 23, 3]): 129 | super(ResNet_imagenet, self).__init__() 130 | self.inplanes = 64 131 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = nn.BatchNorm2d(64) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 140 | self.avgpool = nn.AvgPool2d(7) 141 | self.feats = nn.Sequential(self.conv1, 142 | self.bn1, 143 | self.relu, 144 | self.maxpool, 145 | 146 | self.layer1, 147 | self.layer2, 148 | self.layer3, 149 | self.layer4, 150 | 151 | self.avgpool) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | init_model(self) 155 | self.regime = { 156 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 157 | 'weight_decay': 1e-4, 'momentum': 0.9}, 158 | 30: {'lr': 
1e-2}, 159 | 60: {'lr': 1e-3}, 160 | 90: {'lr': 1e-4} 161 | } 162 | 163 | 164 | class ResNet_cifar10(ResNet): 165 | 166 | def __init__(self, num_classes=10, 167 | block=BasicBlock, depth=18): 168 | super(ResNet_cifar10, self).__init__() 169 | self.inplanes = 16 170 | n = int((depth - 2) / 6) 171 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, 172 | bias=False) 173 | self.bn1 = nn.BatchNorm2d(16) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = lambda x: x 176 | self.layer1 = self._make_layer(block, 16, n) 177 | self.layer2 = self._make_layer(block, 32, n, stride=2) 178 | self.layer3 = self._make_layer(block, 64, n, stride=2) 179 | self.layer4 = lambda x: x 180 | self.avgpool = nn.AvgPool2d(8) 181 | self.fc = nn.Linear(64, num_classes) 182 | self.feats = nn.Sequential(self.conv1, 183 | self.bn1, 184 | self.relu, 185 | self.layer1, 186 | self.layer2, 187 | self.layer3, 188 | self.avgpool) 189 | init_model(self) 190 | 191 | self.regime = { 192 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 193 | 'weight_decay': 1e-4, 'momentum': 0.9}, 194 | 81: {'lr': 1e-2}, 195 | 122: {'lr': 1e-3, 'optimizer': 'SGD'}, 196 | 164: {'lr': 1e-4} 197 | } 198 | 199 | 200 | def resnet(**kwargs): 201 | num_classes, depth, dataset = map( 202 | kwargs.get, ['num_classes', 'depth', 'dataset']) 203 | if dataset == 'imagenet': 204 | num_classes = num_classes or 1000 205 | depth = depth or 50 206 | if depth == 18: 207 | return ResNet_imagenet(num_classes=num_classes, 208 | block=BasicBlock, layers=[2, 2, 2, 2]) 209 | if depth == 34: 210 | return ResNet_imagenet(num_classes=num_classes, 211 | block=BasicBlock, layers=[3, 4, 6, 3]) 212 | if depth == 50: 213 | return ResNet_imagenet(num_classes=num_classes, 214 | block=Bottleneck, layers=[3, 4, 6, 3]) 215 | if depth == 101: 216 | return ResNet_imagenet(num_classes=num_classes, 217 | block=Bottleneck, layers=[3, 4, 23, 3]) 218 | if depth == 152: 219 | return ResNet_imagenet(num_classes=num_classes, 220 | block=Bottleneck, 
layers=[3, 8, 36, 3]) 221 | 222 | elif dataset == 'cifar10': 223 | num_classes = num_classes or 10 224 | depth = depth or 44 225 | return ResNet_cifar10(num_classes=num_classes, 226 | block=BasicBlock, depth=depth) 227 | elif dataset == 'cifar100': 228 | num_classes = num_classes or 100 229 | depth = depth or 44 230 | return ResNet_cifar10(num_classes=num_classes, 231 | block=BasicBlock, depth=depth) 232 | -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.transforms as transforms 3 | import math 4 | 5 | __all__ = ['wide_WResNet'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | def init_model(model): 15 | for m in model.modules(): 16 | if isinstance(m, nn.Conv2d): 17 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 18 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 19 | elif isinstance(m, nn.BatchNorm2d): 20 | m.weight.data.fill_(1) 21 | m.bias.data.zero_() 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class WResNet(nn.Module): 96 
| 97 | def __init__(self): 98 | super(WResNet, self).__init__() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = nn.Sequential( 104 | nn.Conv2d(self.inplanes, planes * block.expansion, 105 | kernel_size=1, stride=stride, bias=False), 106 | nn.BatchNorm2d(planes * block.expansion), 107 | ) 108 | 109 | layers = [] 110 | layers.append(block(self.inplanes, planes, stride, downsample)) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append(block(self.inplanes, planes)) 114 | 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x): 118 | x = self.feats(x) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | 122 | return x 123 | 124 | 125 | class WResNet_imagenet(WResNet): 126 | 127 | def __init__(self, num_classes=1000, 128 | block=Bottleneck, layers=[3, 4, 23, 3]): 129 | super(WResNet_imagenet, self).__init__() 130 | self.inplanes = 64 131 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 132 | bias=False) 133 | self.bn1 = nn.BatchNorm2d(64) 134 | self.relu = nn.ReLU(inplace=True) 135 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 136 | self.layer1 = self._make_layer(block, 64, layers[0]) 137 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 138 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 139 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 140 | self.avgpool = nn.AvgPool2d(7) 141 | self.feats = nn.Sequential(self.conv1, 142 | self.bn1, 143 | self.relu, 144 | self.maxpool, 145 | 146 | self.layer1, 147 | self.layer2, 148 | self.layer3, 149 | self.layer4, 150 | 151 | self.avgpool) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | init_model(self) 155 | self.regime = { 156 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 'weight_decay': 1e-4, 'momentum': 0.9}, 157 | 30: {'lr': 
1e-2}, 158 | 60: {'lr': 1e-3}, 159 | 90: {'lr': 1e-4} 160 | } 161 | 162 | 163 | class WResNet_cifar10(WResNet): 164 | 165 | def __init__(self, num_classes=10, multiplier=1, 166 | block=BasicBlock, depth=18): 167 | super(WResNet_cifar10, self).__init__() 168 | self.inplanes = 16 * multiplier 169 | n = int((depth - 2) / 6) 170 | self.conv1 = nn.Conv2d(3, 16 * multiplier, kernel_size=3, stride=1, padding=1, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(16 * multiplier) 173 | self.relu = nn.ReLU(inplace=True) 174 | self.maxpool = lambda x: x 175 | self.layer1 = self._make_layer(block, 16 * multiplier, n) 176 | self.layer2 = self._make_layer(block, 32 * multiplier, n, stride=2) 177 | self.layer3 = self._make_layer(block, 64 * multiplier, n, stride=2) 178 | self.layer4 = lambda x: x 179 | self.avgpool = nn.AvgPool2d(8) 180 | self.fc = nn.Linear(64 * multiplier, num_classes) 181 | self.feats = nn.Sequential(self.conv1, 182 | self.bn1, 183 | self.relu, 184 | self.layer1, 185 | self.layer2, 186 | self.layer3, 187 | self.avgpool) 188 | init_model(self) 189 | 190 | self.regime = { 191 | 0: {'optimizer': 'SGD', 'lr': 1e-1, 192 | 'weight_decay': 1e-4, 'momentum': 0.9}, 193 | 60: {'lr': 2e-2}, 194 | 120: {'lr': 4e-3}, 195 | 140: {'lr': 1e-4} 196 | } 197 | 198 | 199 | def wide_WResNet(**kwargs): 200 | num_classes, depth, dataset = map( 201 | kwargs.get, ['num_classes', 'depth', 'dataset']) 202 | if dataset == 'imagenet': 203 | num_classes = num_classes or 1000 204 | depth = depth or 50 205 | if depth == 18: 206 | return WResNet_imagenet(num_classes=num_classes, 207 | block=BasicBlock, layers=[2, 2, 2, 2]) 208 | if depth == 34: 209 | return WResNet_imagenet(num_classes=num_classes, 210 | block=BasicBlock, layers=[3, 4, 6, 3]) 211 | if depth == 50: 212 | return WResNet_imagenet(num_classes=num_classes, 213 | block=Bottleneck, layers=[3, 4, 6, 3]) 214 | if depth == 101: 215 | return WResNet_imagenet(num_classes=num_classes, 216 | block=Bottleneck, layers=[3, 4, 23, 3]) 217 | if 
depth == 152: 218 | return WResNet_imagenet(num_classes=num_classes, 219 | block=Bottleneck, layers=[3, 8, 36, 3]) 220 | 221 | elif dataset == 'cifar10': 222 | num_classes = num_classes or 10 223 | depth = depth or 16 224 | return WResNet_cifar10(num_classes=num_classes, 225 | block=BasicBlock, depth=depth, multiplier=4) 226 | elif dataset == 'cifar100': 227 | num_classes = num_classes or 100 228 | depth = depth or 16 229 | return WResNet_cifar10(num_classes=num_classes, 230 | block=BasicBlock, depth=depth, multiplier=4) 231 | -------------------------------------------------------------------------------- /main_normal.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import logging 5 | from datetime import datetime 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim 11 | import torch.utils.data 12 | import models 13 | from torch.autograd import Variable 14 | from data import get_dataset 15 | from preprocess import get_transform 16 | from utils import * 17 | from ast import literal_eval 18 | from torch.nn.utils import clip_grad_norm 19 | from math import ceil 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith("__") 23 | and callable(models.__dict__[name])) 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 26 | 27 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', 28 | default='./TrainingResults', help='results dir') 29 | parser.add_argument('--save', metavar='SAVE', default='', 30 | help='saved folder') 31 | parser.add_argument('--dataset', metavar='DATASET', default='cifar10', 32 | help='dataset name or folder') 33 | parser.add_argument('--model', '-a', metavar='MODEL', default='resnet', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' 
(default: alexnet)') 38 | parser.add_argument('--input_size', type=int, default=None, 39 | help='image input size') 40 | parser.add_argument('--model_config', default='', 41 | help='additional architecture configuration') 42 | parser.add_argument('--type', default='torch.cuda.FloatTensor', 43 | help='type of tensor - e.g torch.cuda.HalfTensor') 44 | parser.add_argument('--gpus', default='0', 45 | help='gpus used for training - e.g 0,1,3') 46 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 47 | help='number of data loading workers (default: 8)') 48 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 49 | help='number of total epochs to run') 50 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 51 | help='manual epoch number (useful on restarts)') 52 | parser.add_argument('-b', '--batch-size', default=256, type=int, 53 | metavar='N', help='mini-batch size (default: 256)') 54 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 55 | help='optimizer function used') 56 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 57 | metavar='LR', help='initial learning rate') 58 | parser.add_argument('--lr_bb_fix', dest='lr_bb_fix', action='store_true', 59 | help='learning rate fix for big batch lr = lr0*(batch_size/128)**0.5') 60 | parser.add_argument('--no-lr_bb_fix', dest='lr_bb_fix', action='store_false', 61 | help='learning rate fix for big batch lr = lr0*(batch_size/128)**0.5') 62 | parser.set_defaults(lr_bb_fix=True) 63 | parser.add_argument('--regime_bb_fix', dest='regime_bb_fix', action='store_true', 64 | help='regime fix for big batch e = e0*(batch_size/128)') 65 | parser.add_argument('--no-regime_bb_fix', dest='regime_bb_fix', action='store_false', 66 | help='regime fix for big batch e = e0*(batch_size/128)') 67 | parser.set_defaults(regime_bb_fix=False) 68 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 69 | help='momentum') 70 
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')


def main():
    """Train (or, with --evaluate, only validate) a model on ``args.dataset``.

    Parses the command line into the module-level ``args``, builds the model
    from ``models`` and the train/val data loaders, then runs one train and
    one validate pass per epoch.  Each epoch saves a checkpoint through
    ``save_checkpoint`` and appends metrics/plots to a ``ResultsLog`` in
    ``results_dir/save``.

    Big-batch adaptations, both relative to a batch size of 256:
      * ``--lr_bb_fix``: every regime 'lr' is scaled by sqrt(batch_size / 256).
      * ``--regime_bb_fix``: the epoch budget and every regime epoch key are
        multiplied by ceil(batch_size / 256).
    """
    torch.manual_seed(123)
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    if args.regime_bb_fix:
        args.epochs *= ceil(args.batch_size / 256.)

    if args.evaluate:
        args.results_dir = '/tmp'
    # `args.save is ''` tested identity, not equality; test emptiness instead.
    if not args.save:
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        torch.cuda.manual_seed(123)
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config:
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            # The training loop saves 'epoch': epoch + 1, i.e. the next epoch
            # to run, so resume exactly there (the old `- 1` re-ran an
            # already finished epoch).
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
    adapted_regime = {}
    for e, v in regime.items():
        # Copy each entry so the model's own `regime` attribute is not
        # rescaled in place.
        v = dict(v)
        if args.lr_bb_fix and 'lr' in v:
            v['lr'] *= (args.batch_size / 256.) ** 0.5
        if args.regime_bb_fix:
            e *= ceil(args.batch_size / 256.)
        adapted_regime[e] = v
    regime = adapted_regime

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)
    # Snapshot of the initial weights, used to measure how far each tracked
    # parameter tensor has moved after every epoch.
    init_weights = [w.data.cpu().clone() for w in model.parameters()]

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))

        # Indices into model.parameters() whose distance from their initial
        # value is tracked; extend (e.g. [0, 2, 4, 6, 7, 8, 9, 10] or
        # [0, 12, 45, 63]) to measure more layers.
        idxs = [0]
        # float() keeps the logged value a plain number whether .norm()
        # returns a Python float or a 0-dim tensor.
        step_dist_epoch = {'step_dist_n%s' % k: float((w.data.cpu() - init_weights[k]).norm())
                           for (k, w) in enumerate(model.parameters()) if k in idxs}
        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5,
                    **step_dist_epoch)
        results.plot(x='epoch', y=['train_loss', 'val_loss'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['train_error1', 'val_error1'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['train_error5', 'val_error5'],
                     title='Error@5', ylabel='error %')
        for k in idxs:
            results.plot(x='epoch', y=['step_dist_n%s' % k],
                         title='step distance per epoch %s' % k,
                         ylabel='val')

        results.save()


def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one pass over ``data_loader`` and return averaged metrics.

    When ``training`` is true, each batch is followed by ``zero_grad``, a
    backward pass, gradient-norm clipping to 5 and ``optimizer.step()``.
    Otherwise autograd is disabled for the forward pass.

    Returns:
        ``(loss, prec@1, prec@5)`` averages from the ``AverageMeter``
        instances, each update weighted by the batch size.
    """
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            # `async` is a reserved word since Python 3.7 (the old keyword
            # argument no longer parses); `non_blocking` is PyTorch's name
            # for the same asynchronous copy.
            target = target.cuda(non_blocking=True)
        input_var = Variable(inputs.type(args.type))
        target_var = Variable(target)

        # compute output; `volatile=` is ignored by current PyTorch, so
        # switch autograd off explicitly when not training.
        with torch.set_grad_enabled(training):
            output = model(input_var)
            loss = criterion(output, target_var)
        if isinstance(output, list):
            output = output[0]

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        # NOTE(review): `prec1[0]` assumes utils.accuracy returns one
        # 1-element tensor per k -- confirm against utils.py.
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm(model.parameters(), 5.)
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg


def train(data_loader, model, criterion, epoch, optimizer):
    """Put ``model`` in train mode and run one optimizing pass over the data."""
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)


def validate(data_loader, model, criterion, epoch):
    """Put ``model`` in eval mode and run one non-optimizing pass over the data."""
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /main_gbn.py:
# --------------------------------------------------------------------------------
import pdb
import argparse
import os
import time
import logging
from random import uniform
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from ast import literal_eval
from torch.nn.utils import clip_grad_norm
from math import ceil

# Public, lower-case callables exported by the `models` package are the
# selectable architectures.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')

parser.add_argument('--results_dir', metavar='RESULTS_DIR',
                    default='./TrainingResults', help='results dir')
parser.add_argument('--save', metavar='SAVE', default='',
                    help='saved folder')
parser.add_argument('--dataset', metavar='DATASET', default='cifar10',
                    help='dataset name or folder')
parser.add_argument('--model', '-a', metavar='MODEL', default='resnet',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: resnet)')
parser.add_argument('--input_size', type=int, default=None,
                    help='image input size')
parser.add_argument('--model_config', default='',
                    help='additional architecture configuration')
parser.add_argument('--type', default='torch.cuda.FloatTensor',
                    help='type of tensor - e.g torch.cuda.HalfTensor')
parser.add_argument('--gpus', default='0',
                    help='gpus used for training - e.g 0,1,3')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=2048, type=int,
                    metavar='N', help='mini-batch size (default: 2048)')
parser.add_argument('-mb', '--mini-batch-size', default=128, type=int,
                    help='mini-mini-batch size (default: 128)')
parser.add_argument('--lr_bb_fix', dest='lr_bb_fix', action='store_true',
                    help='learning rate fix for big batch lr = lr0*(batch_size/mini_batch_size)**0.5')
parser.add_argument('--no-lr_bb_fix', dest='lr_bb_fix', action='store_false',
                    help='learning rate fix for big batch lr = lr0*(batch_size/mini_batch_size)**0.5')
parser.set_defaults(lr_bb_fix=True)
parser.add_argument('--regime_bb_fix', dest='regime_bb_fix', action='store_true',
                    help='regime fix for big batch e = e0*ceil(batch_size/mini_batch_size)')
parser.add_argument('--no-regime_bb_fix', dest='regime_bb_fix', action='store_false',
                    help='regime fix for big batch e = e0*ceil(batch_size/mini_batch_size)')
parser.set_defaults(regime_bb_fix=False)
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
                    help='optimizer function used')
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')


def main():
    """Train (or, with --evaluate, only validate) with split big batches.

    Parses the command line into the module-level ``args``, builds the model
    and data loaders, then runs one train/validate pass per epoch.  Each
    loader batch holds ``batch_size`` samples; training splits it into
    chunks of about ``mini_batch_size`` (see ``forward``).

    Big-batch adaptations, both relative to ``mini_batch_size``:
      * ``--lr_bb_fix``: every regime 'lr' is scaled by
        sqrt(batch_size / mini_batch_size).
      * ``--regime_bb_fix``: the epoch budget and every regime epoch key are
        multiplied by ceil(batch_size / mini_batch_size).
    """
    torch.manual_seed(123)
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    if args.regime_bb_fix:
        args.epochs *= ceil(args.batch_size / args.mini_batch_size)

    if args.evaluate:
        args.results_dir = '/tmp'
    # `args.save is ''` tested identity, not equality; test emptiness instead.
    if not args.save:
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        torch.cuda.manual_seed(123)
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    if args.model_config:
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            # The training loop saves 'epoch': epoch + 1, i.e. the next epoch
            # to run, so resume exactly there (the old `+ 1` skipped an
            # epoch).
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
    adapted_regime = {}
    for e, v in regime.items():
        # Copy each entry so the model's own `regime` attribute is not
        # rescaled in place.
        v = dict(v)
        if args.lr_bb_fix and 'lr' in v:
            v['lr'] *= (args.batch_size / args.mini_batch_size) ** 0.5
        if args.regime_bb_fix:
            e *= ceil(args.batch_size / args.mini_batch_size)
        adapted_regime[e] = v
    regime = adapted_regime

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)
    print({i: list(w.size())
           for (i, w) in enumerate(model.parameters())})
    # Snapshot of the initial weights, used to measure how far each tracked
    # parameter tensor has moved after every epoch.
    init_weights = [w.data.cpu().clone() for w in model.parameters()]

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_result = train(train_loader, model, criterion, epoch, optimizer)

        train_loss, train_prec1, train_prec5 = [
            train_result[r] for r in ['loss', 'prec1', 'prec5']]

        # evaluate on validation set
        val_result = validate(val_loader, model, criterion, epoch)
        val_loss, val_prec1, val_prec5 = [val_result[r]
                                          for r in ['loss', 'prec1', 'prec5']]

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))

        # Indices into model.parameters() whose distance from their initial
        # value is tracked; extend (e.g. [0, 2, 4, 6, 7, 8, 9, 10] or
        # [0, 12, 45, 63]) to measure more layers.
        idxs = [0]

        # float() keeps the logged value a plain number whether .norm()
        # returns a Python float or a 0-dim tensor.
        step_dist_epoch = {'step_dist_n%s' % k: float((w.data.cpu() - init_weights[k]).norm())
                           for (k, w) in enumerate(model.parameters()) if k in idxs}

        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5,
                    **step_dist_epoch)

        results.plot(x='epoch', y=['train_loss', 'val_loss'],
                     title='Loss', ylabel='loss')
        results.plot(x='epoch', y=['train_error1', 'val_error1'],
                     title='Error@1', ylabel='error %')
        results.plot(x='epoch', y=['train_error5', 'val_error5'],
                     title='Error@5', ylabel='error %')

        for k in idxs:
            results.plot(x='epoch', y=['step_dist_n%s' % k],
                         title='step distance per epoch %s' % k,
                         ylabel='val')

        results.save()


def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    """Run one pass over ``data_loader`` and return averaged metrics.

    Evaluation forwards each loader batch whole, with autograd disabled.

    Training splits every loader batch into ``batch_size // mini_batch_size``
    chunks (at least 1) along dim 0 and forwards each chunk separately.  Any
    BatchNorm layers therefore normalize with chunk statistics (the
    "ghost batch" scheme the script name refers to).  Each chunk's loss is
    backpropagated, gradients accumulate, the sum is divided by the number
    of chunks, clipped to norm 5, and one ``optimizer.step()`` is taken per
    loader batch.

    Returns:
        dict with keys 'loss', 'prec1', 'prec5' holding ``AverageMeter``
        averages, each update weighted by the (chunk) batch size.
    """
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Never ask for 0 chunks, even if mini_batch_size > batch_size.
    num_chunks = max(1, args.batch_size // args.mini_batch_size)

    end = time.time()

    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            # `async` is a reserved word since Python 3.7 (the old keyword
            # argument no longer parses); `non_blocking` is PyTorch's name
            # for the same asynchronous copy.
            target = target.cuda(non_blocking=True)
        input_var = Variable(inputs.type(args.type))
        target_var = Variable(target)

        # NOTE(review): `prec1[0]` below assumes utils.accuracy returns one
        # 1-element tensor per k -- confirm against utils.py.
        if not training:
            # `volatile=` is ignored by current PyTorch; disable autograd
            # explicitly for evaluation.
            with torch.no_grad():
                output = model(input_var)
                loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target_var.data, topk=(1, 5))
            losses.update(loss.item(), input_var.size(0))
            top1.update(prec1[0], input_var.size(0))
            top5.update(prec5[0], input_var.size(0))

        else:
            mini_inputs = input_var.chunk(num_chunks)
            mini_targets = target_var.chunk(num_chunks)

            optimizer.zero_grad()

            for mini_input_var, mini_target_var in zip(mini_inputs, mini_targets):
                output = model(mini_input_var)
                loss = criterion(output, mini_target_var)

                prec1, prec5 = accuracy(output.data, mini_target_var.data, topk=(1, 5))
                losses.update(loss.item(), mini_input_var.size(0))
                top1.update(prec1[0], mini_input_var.size(0))
                top5.update(prec5[0], mini_input_var.size(0))

                # compute gradient; gradients accumulate across chunks
                loss.backward()

            # Average the accumulated per-chunk gradients; parameters that
            # received no gradient are skipped.
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.data.div_(len(mini_inputs))
            clip_grad_norm(model.parameters(), 5.)
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))

    return {'loss': losses.avg,
            'prec1': top1.avg,
            'prec5': top5.avg}


def train(data_loader, model, criterion, epoch, optimizer):
    """Put ``model`` in train mode and run one optimizing pass over the data."""
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)


def validate(data_loader, model, criterion, epoch):
    """Put ``model`` in eval mode and run one non-optimizing pass over the data."""
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------