├── .gitignore
├── README.md
├── datasets
│   ├── CIFAR.py
│   ├── LSUN.py
│   ├── SVHN.py
│   └── __init__.py
├── models
│   ├── Densenet.py
│   ├── Densenet_BC.py
│   ├── WideResnet.py
│   ├── __init__.py
│   └── classifiers.py
├── setup
│   └── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    └── args.py

/.gitignore:
--------------------------------------------------------------------------------
1 | workspace/
2 | lightning_logs/
3 | *.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Train CIFAR10, CIFAR100 with Pytorch-lightning
2 | Measure the accuracy of VGG, Resnet, WideResnet, Densenet and Densenet-BC models on the CIFAR10 and CIFAR100 datasets using [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning).
3 | 
4 | ## Requirements
5 | - setup/requirements.txt
6 | ```bash
7 | torch 1.5.1
8 | torchvision 0.6.1
9 | pytorch-lightning 0.9.0rc5
10 | tqdm
11 | argparse
12 | pytablewriter
13 | seaborn
14 | enum34
15 | scipy
16 | cffi
17 | sklearn
18 | ```
19 | 
20 | - install requirements using pip
21 | ```bash
22 | pip3 install -r setup/requirements.txt
23 | ```
24 | 
25 | ## How to run
26 | After you have cloned the repository, you can train each model on the CIFAR10 and CIFAR100 datasets. Trainable models are [VGG](https://arxiv.org/abs/1409.1556), [Resnet](https://arxiv.org/abs/1512.03385), [WideResnet](https://arxiv.org/pdf/1605.07146.pdf), [Densenet-BC](https://arxiv.org/pdf/1608.06993.pdf), [Densenet](https://arxiv.org/abs/1608.06993).
27 | 
28 | ```bash
29 | python train.py
30 | ```
31 | 
32 | ## Implementation Details
33 | - CIFAR10
34 | 
35 | | epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov |
36 | |:---------:|:-------------:|:-------------:|:---------:|:--------:|:---------:|
37 | | 0 ~ 20 | 0.1 | 0.0005 | SGD | 0.9 | False |
38 | | 21 ~ 40 | 0.01 | 0.0005 | SGD | 0.9 | False |
39 | | 41 ~ 60 | 0.001 | 0.0005 | SGD | 0.9 | False |
40 | 
41 | - CIFAR100
42 | 
43 | | epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov |
44 | |:---------:|:-------------:|:-------------:|:---------:|:--------:|:---------:|
45 | | 0 ~ 60 | 0.1 | 0.0005 | SGD | 0.9 | False |
46 | | 61 ~ 120 | 0.01 | 0.0005 | SGD | 0.9 | False |
47 | | 121 ~ 180 | 0.001 | 0.0005 | SGD | 0.9 | False |
48 | 
49 | ## Accuracy
50 | Below are the test-set accuracies of the models trained on CIFAR-10 and CIFAR-100.
51 | 
52 | **Accuracy of models trained on CIFAR10**
53 | | network | dropout | preprocess | parameters | accuracy(%) |
54 | |:-----------------:|:-------:|:----------:|:----------:|:-----------:|
55 | | VGG16 | 0 | meanstd | 14M | 91.09 |
56 | | Resnet-50 | 0 | meanstd | 23M | 92.11 |
57 | | WideResnet 28x10 | 0.3 | meanstd | 36M | 93.61 |
58 | | Densenet-BC | 0 | meanstd | 769K | 92.85 |
59 | | Densenet | 0 | meanstd | 769K | 93.06 |
60 | 
61 | 
62 | **Accuracy of models trained on CIFAR100**
63 | | network | dropout | preprocess | parameters | accuracy(%) |
64 | |:-----------------:|:-------:|:----------:|:----------:|:-----------:|
65 | | VGG16 | 0 | meanstd | 14M | 72.79 |
66 | | Resnet-50 | 0 | meanstd | 23M | 75.80 |
67 | | WideResnet 28x20 | 0.3 | meanstd | 145M | 75.46 |
68 | | Densenet-BC | 0 | meanstd | 800K | 72.23 |
69 | | Densenet | 0 | meanstd | 800K | 75.58 |
70 | 
--------------------------------------------------------------------------------
/datasets/CIFAR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data 
import DataLoader 3 | import torchvision.transforms as transforms 4 | from torchvision import datasets 5 | 6 | import pytorch_lightning as pl 7 | 8 | class CIFAR10DataModule(pl.LightningDataModule): 9 | def __init__(self): 10 | super().__init__() 11 | self.mean = [0.4914, 0.4822, 0.4465] 12 | self.std = [0.2023, 0.1994, 0.2010] 13 | self.transform = transforms.Compose([transforms.RandomCrop(32,padding=4), 14 | transforms.RandomHorizontalFlip(), 15 | transforms.ToTensor(), 16 | transforms.Normalize(self.mean, self.std)]) 17 | 18 | def prepare_data(self): 19 | datasets.CIFAR10(root='./workspace/datasets/cifar10',train=True,download=True, transform=self.transform) 20 | datasets.CIFAR10(root='./workspace/datasets/cifar10',train=False,download=True, transform=self.transform) 21 | 22 | def setup(self, stage): 23 | cifar_train = datasets.CIFAR10(root='./workspace/datasets/cifar10',train=True,download=True, transform=self.transform) 24 | self.cifar_test = datasets.CIFAR10(root='./workspace/datasets/cifar10',train=False,download=True, transform=self.transform) 25 | self.cifar_train = cifar_train 26 | 27 | def train_dataloader(self): 28 | cifar_train = DataLoader(self.cifar_train, batch_size=64, shuffle=True, num_workers=8) 29 | return cifar_train 30 | 31 | def val_dataloader(self): 32 | cifar_val = DataLoader(self.cifar_test, batch_size=64, shuffle=False, num_workers=8) 33 | return cifar_val 34 | 35 | def test_dataloader(self): 36 | return DataLoader(self.cifar_test, batch_size=64, shuffle=False, num_workers=8) 37 | 38 | 39 | 40 | class CIFAR100DataModule(pl.LightningDataModule): 41 | def __init__(self): 42 | super().__init__() 43 | self.mean = [0.4914, 0.4822, 0.4465] 44 | self.std = [0.2023, 0.1994, 0.2010] 45 | self.transform = transforms.Compose([transforms.RandomCrop(32,padding=4), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | transforms.Normalize(self.mean, self.std)]) 49 | 50 | def prepare_data(self): 51 | datasets.CIFAR100(root='./workspace/datasets/cifar100',train=True,download=True, transform=self.transform) 52 | datasets.CIFAR100(root='./workspace/datasets/cifar100',train=False,download=True, transform=self.transform) 53 | 54 | def setup(self, stage): 55 | cifar_train = datasets.CIFAR100(root='./workspace/datasets/cifar100',train=True,download=True, transform=self.transform) 56 | self.cifar_test = datasets.CIFAR100(root='./workspace/datasets/cifar100',train=False,download=True, transform=self.transform) 57 | self.cifar_train = cifar_train 58 | 59 | def train_dataloader(self): 60 | cifar_train = DataLoader(self.cifar_train, batch_size=128, shuffle=True, num_workers=8) 61 | return cifar_train 62 | 63 | def val_dataloader(self): 64 | cifar_val = DataLoader(self.cifar_test, batch_size=128, shuffle=False, num_workers=8) 65 | return cifar_val 66 | 67 | def test_dataloader(self): 68 | return DataLoader(self.cifar_test, batch_size=128, shuffle=False, num_workers=8) 69 | -------------------------------------------------------------------------------- /datasets/LSUN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import torchvision.transforms as transforms 4 | from torchvision import datasets 5 | 6 | import pytorch_lightning as pl 7 | 8 | class LSUNDataModule(pl.LightningDataModule): 9 | # This DataModule is usded only for out-of-distribution dataset. 
10 |     # Therefore there are no separate train/val/test datasets.
11 |     def __init__(self,batch_size=64):
12 |         super().__init__()
13 |         # these mean and std are not LSUN mean/std
14 |         self.mean = [125.3/255, 123.0/255, 113.9/255]
15 |         self.std = [63.0/255, 62.1/255.0, 66.7/255.0]
16 | 
17 |         self.transform = transforms.Compose([transforms.ToTensor(),
18 |                                              transforms.Normalize(self.mean, self.std)])
19 |         self.batch_size=batch_size
20 | 
21 |     def prepare_data(self):
22 |         pass
23 | 
24 |     def setup(self, stage=None):
25 |         self.lsun_dataset = datasets.ImageFolder(root='./workspace/datasets/LSUN', transform=self.transform)
26 | 
27 |     def train_dataloader(self):
28 |         pass
29 | 
30 |     def val_dataloader(self):
31 |         pass
32 | 
33 |     def test_dataloader(self):
34 |         return DataLoader(self.lsun_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8)
35 | 
36 | 
37 | class LSUN_resizeDataModule(pl.LightningDataModule):
38 |     # This DataModule is used only as an out-of-distribution dataset.
39 |     # Therefore there are no separate train/val/test datasets.
40 |     def __init__(self,batch_size=64):
41 |         super().__init__()
42 |         # these mean and std are not LSUN mean/std
43 |         self.mean = [125.3/255, 123.0/255, 113.9/255]
44 |         self.std = [63.0/255, 62.1/255.0, 66.7/255.0]
45 | 
46 |         self.transform = transforms.Compose([transforms.ToTensor(),
47 |                                              transforms.Normalize(self.mean, self.std)])
48 |         self.batch_size=batch_size
49 | 
50 |     def prepare_data(self):
51 |         pass
52 | 
53 |     def setup(self, stage=None):
54 |         self.lsun_resize_dataset = datasets.ImageFolder(root='./workspace/datasets/LSUN_resize', transform=self.transform)
55 | 
56 |     def train_dataloader(self):
57 |         pass
58 | 
59 |     def val_dataloader(self):
60 |         pass
61 | 
62 |     def test_dataloader(self):
63 |         return DataLoader(self.lsun_resize_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8)
--------------------------------------------------------------------------------
/datasets/SVHN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | import torchvision.transforms as transforms
4 | from torchvision import datasets
5 | 
6 | import pytorch_lightning as pl
7 | 
8 | class SVHNDataModule(pl.LightningDataModule):
9 |     def __init__(self,batch_size=64):
10 |         super().__init__()
11 |         self.mean = [129.3/255, 124.1/255, 112.4/255]
12 |         self.std = [68.2/255, 65.4/255.0, 70.4/255.0]
13 | 
14 |         self.transform = transforms.Compose([transforms.RandomCrop(32,padding=4),
15 |                                              transforms.RandomHorizontalFlip(),
16 |                                              transforms.ToTensor(),
17 |                                              transforms.Normalize(self.mean, self.std)])
18 |         self.batch_size=batch_size
19 | 
20 |     def prepare_data(self):
21 |         datasets.SVHN(root='./workspace/datasets/SVHN', split='train', transform=self.transform, download=True)
22 |         datasets.SVHN(root='./workspace/datasets/SVHN', split ='extra',transform=self.transform, download=True)
23 |         datasets.SVHN(root='./workspace/datasets/SVHN', split ='test',transform=self.transform, download=True)
24 | 
25 |     def setup(self, stage=None):
26 |         self.SVHN_test = datasets.SVHN(root='./workspace/datasets/SVHN', split='test',transform=self.transform, download=True)
27 |         self.SVHN_train = datasets.SVHN(root='./workspace/datasets/SVHN', split='train',transform=self.transform, download=True)
28 |         self.SVHN_val = datasets.SVHN(root='./workspace/datasets/SVHN', split='extra',transform=self.transform, download=True)
29 | 
30 |     def train_dataloader(self):
31 |         SVHN_train = DataLoader(self.SVHN_train, 
batch_size=self.batch_size, shuffle=True, num_workers=8) 32 | return SVHN_train 33 | 34 | def val_dataloader(self): 35 | SVHN_val = DataLoader(self.SVHN_val, batch_size=self.batch_size, shuffle=False, num_workers=8) 36 | return SVHN_val 37 | 38 | def test_dataloader(self): 39 | return DataLoader(self.SVHN_test, batch_size=self.batch_size, shuffle=False, num_workers=8) 40 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LJY-HY/cifar_pytorch-lightning/868ed76339f4d4d0c607b9c3d09907b70b29ceb5/datasets/__init__.py -------------------------------------------------------------------------------- /models/Densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | class _DenseLayer(nn.Sequential): 9 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 10 | super(_DenseLayer, self).__init__() 11 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 12 | self.add_module('relu1', nn.ReLU(inplace=True)), 13 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 14 | growth_rate, kernel_size=1, stride=1, bias=False)), 15 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 16 | self.add_module('relu2', nn.ReLU(inplace=True)), 17 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 18 | kernel_size=3, stride=1, padding=1, bias=False)), 19 | self.drop_rate = drop_rate 20 | 21 | def forward(self, x): 22 | new_features = super(_DenseLayer, self).forward(x) 23 | if self.drop_rate > 0: 24 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 25 | return torch.cat([x, new_features], 1) 26 | 27 | 28 | class _DenseBlock(nn.Sequential): 29 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 30 | super(_DenseBlock, self).__init__() 31 | for i in range(num_layers): 32 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 33 | self.add_module('denselayer%d' % (i + 1), layer) 34 | 35 | 36 | class _Transition(nn.Sequential): 37 | def __init__(self, num_input_features, num_output_features): 38 | super(_Transition, self).__init__() 39 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 40 | self.add_module('relu', nn.ReLU(inplace=True)) 41 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 42 | kernel_size=1, stride=1, bias=False)) 43 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 44 | 45 | 46 | class _DenseNet(nn.Module): 47 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 48 | num_init_features=64, bn_size=4, drop_rate=0.0, num_classes=10): 49 | 50 | super(_DenseNet, self).__init__() 51 | 52 | # First convolution 53 | self.features = nn.Sequential(OrderedDict([ 54 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)) 55 | ])) 56 | 57 | # Each denseblock 58 | num_features = num_init_features 59 | for i, num_layers in enumerate(block_config): 60 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 61 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 62 | self.features.add_module('denseblock%d' % (i + 1), 
block)
63 |             num_features = num_features + num_layers * growth_rate
64 |             if i != len(block_config) - 1:
65 |                 trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
66 |                 self.features.add_module('transition%d' % (i + 1), trans)
67 |                 num_features = num_features // 2
68 | 
69 |         # Final batch norm
70 |         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
71 | 
72 |         # Linear layer
73 |         self.classifier = nn.Linear(num_features, num_classes)
74 | 
75 |         # Official init from torch repo.
76 |         for m in self.modules():
77 |             if isinstance(m, nn.Conv2d):
78 |                 nn.init.kaiming_normal_(m.weight)
79 |             elif isinstance(m, nn.BatchNorm2d):
80 |                 nn.init.constant_(m.weight, 1)
81 |                 nn.init.constant_(m.bias, 0)
82 |             elif isinstance(m, nn.Linear):
83 |                 nn.init.constant_(m.bias, 0)
84 | 
85 |     def forward(self, x):
86 |         features = self.features(x)
87 |         out = F.relu(features, inplace=True)
88 |         out = F.avg_pool2d(out, kernel_size=8, stride=1).view(features.size(0), -1)
89 |         out = self.classifier(out)
90 |         return out
91 | 
92 | def DenseNet(**kwargs):
93 |     #default values are for cifar10
94 |     model = _DenseNet(num_init_features=24, growth_rate=12, block_config=(16, 16, 16),
95 |                       **kwargs)
96 |     return model
--------------------------------------------------------------------------------
/models/Densenet_BC.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | 
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | 
13 | class BasicBlock(nn.Module):
14 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
15 |         super(BasicBlock, self).__init__()
16 |         self.bn1 = nn.BatchNorm2d(in_planes)
17 |         self.relu = nn.ReLU(inplace=True)
18 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
19 |                                padding=1, bias=False)
20 |         self.droprate = dropRate
21 |     def forward(self, x):
22 |         out = self.conv1(self.relu(self.bn1(x)))
23 |         if self.droprate > 0:
24 |             out = F.dropout(out, p=self.droprate, training=self.training)
25 |         return torch.cat([x, out], 1)
26 | 
27 | class BottleneckBlock(nn.Module):
28 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
29 |         super(BottleneckBlock, self).__init__()
30 |         inter_planes = out_planes * 4
31 |         self.bn1 = nn.BatchNorm2d(in_planes)
32 |         self.relu = nn.ReLU(inplace=True)
33 |         self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
34 |                                padding=0, bias=False)
35 |         self.bn2 = nn.BatchNorm2d(inter_planes)
36 |         self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
37 |                                padding=1, bias=False)
38 |         self.droprate = dropRate
39 |     def forward(self, x):
40 |         out = self.conv1(self.relu(self.bn1(x)))
41 |         if self.droprate > 0:
42 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
43 |         out = self.conv2(self.relu(self.bn2(out)))
44 |         if self.droprate > 0:
45 |             out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
46 |         return torch.cat([x, out], 1)
47 | 
48 | class TransitionBlock(nn.Module):
49 |     def __init__(self, in_planes, out_planes, dropRate=0.0):
50 |         super(TransitionBlock, self).__init__()
51 |         self.bn1 = nn.BatchNorm2d(in_planes)
52 |         self.relu = nn.ReLU(inplace=True)
53 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
54 |                                padding=0, 
bias=False) 55 | self.droprate = dropRate 56 | def forward(self, x): 57 | out = self.conv1(self.relu(self.bn1(x))) 58 | if self.droprate > 0: 59 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 60 | return F.avg_pool2d(out, 2) 61 | 62 | class DenseBlock(nn.Module): 63 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0): 64 | super(DenseBlock, self).__init__() 65 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate) 66 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate): 67 | layers = [] 68 | for i in range(int(nb_layers)): 69 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate)) 70 | return nn.Sequential(*layers) 71 | def forward(self, x): 72 | return self.layer(x) 73 | 74 | class DenseNet3(nn.Module): 75 | def __init__(self, depth, num_classes, growth_rate=12, 76 | reduction=0.5, bottleneck=True, dropRate=0.0): 77 | super(DenseNet3, self).__init__() 78 | in_planes = 2 * growth_rate 79 | n = (depth - 4) / 3 80 | if bottleneck == True: 81 | n = n/2 82 | block = BottleneckBlock 83 | else: 84 | block = BasicBlock 85 | # 1st conv before any dense block 86 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, 87 | padding=1, bias=False) 88 | # 1st block 89 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 90 | in_planes = int(in_planes+n*growth_rate) 91 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 92 | in_planes = int(math.floor(in_planes*reduction)) 93 | # 2nd block 94 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 95 | in_planes = int(in_planes+n*growth_rate) 96 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 97 | in_planes = int(math.floor(in_planes*reduction)) 98 | # 3rd block 99 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 100 | in_planes = int(in_planes+n*growth_rate) 101 | # global average pooling and classifier 102 | self.bn1 = nn.BatchNorm2d(in_planes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.fc = nn.Linear(in_planes, num_classes) 105 | self.in_planes = in_planes 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | elif isinstance(m, nn.Linear): 115 | m.bias.data.zero_() 116 | def forward(self, x): 117 | out = self.conv1(x) 118 | out = self.trans1(self.block1(out)) 119 | out = self.trans2(self.block2(out)) 120 | out = self.block3(out) 121 | out = self.relu(self.bn1(out)) 122 | out = F.avg_pool2d(out, 8) 123 | out = out.view(-1, self.in_planes) 124 | # TODO : check the shape of out is (1,*) 125 | return self.fc(out) 126 | 127 | def DenseNet_BC(**kwargs): 128 | model = DenseNet3(depth=100, growth_rate=12, reduction=0.5, bottleneck=True, dropRate=0.0, **kwargs) 129 | 130 | return model -------------------------------------------------------------------------------- /models/WideResnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)/6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = 
self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LJY-HY/cifar_pytorch-lightning/868ed76339f4d4d0c607b9c3d09907b70b29ceb5/models/__init__.py -------------------------------------------------------------------------------- /models/classifiers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.optim.lr_scheduler 6 | import torch.nn.functional as F 7 | 8 | import torchvision.transforms as transforms 9 | import torch.utils.data 10 | import torchvision.models.vgg as VGG 11 | import torchvision.models.resnet as Resnet 12 | 13 | from models.WideResnet import Wide_ResNet 14 | from models.Densenet import DenseNet 15 | from models.Densenet_BC import DenseNet_BC 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.metrics.functional import accuracy 18 | import pdb 19 | ''' 20 | CIFAR10_model skeleton 21 | ''' 22 | class CIFAR10_LIGHTNING(pl.LightningModule): 23 | # Base model is VGG-16 24 | def __init__(self): 25 | super(CIFAR10_LIGHTNING, self).__init__() 26 | self.model = VGG.vgg16_bn() 27 | 28 | def forward(self, x): 29 | output = self.model(x) 30 | return output 31 | 32 | def training_step(self,batch,batch_idx): 33 | data, target = batch 34 | loss = F.cross_entropy(self.forward(data), target) 35 | tensorboard_logs = {'train_loss':loss} 36 | return {'loss':loss, 'log':tensorboard_logs} 37 | 38 | def configure_optimizers(self): 39 | optimizer = optim.SGD(self.parameters(), lr=1e-1, momentum=0.9, weight_decay=5e-4) 40 | lr_scheduler = {'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,40], gamma=0.1), 'interval': 'epoch'} 41 | return [optimizer], [lr_scheduler] 42 | 43 | def validation_step(self,batch,batch_idx): 44 | data, target = batch 45 | output = self.forward(data) 46 | loss = F.cross_entropy(output, target) 47 | pred = output.argmax(dim=1,keepdim=True) 48 | correct = pred.eq(target.view_as(pred)).sum().item() 49 | return {'val_loss':loss,'correct':correct} 50 | 51 | def validation_epoch_end(self,outputs): 52 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 53 | sum_correct = sum([x['correct'] for x in outputs]) 54 | tensorboard_logs = {'val_loss':avg_loss} 55 | print('Validation accuracy : ',sum_correct/10000,'\n\n') # self.arg.validation_size 56 | return {'avg_val_loss':avg_loss, 'log':tensorboard_logs} 57 | 58 | def test_step(self,batch,batch_idx): 59 | data, target = batch 60 | output = self.forward(data) 61 | pred = output.argmax(dim=1,keepdim=True) 62 | correct = pred.eq(target.view_as(pred)).sum().item() 63 | return {'test_loss':F.cross_entropy(output,target), 'correct':correct} 64 | 65 | def test_epoch_end(self, outputs): 66 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 67 | sum_correct = sum([x['correct'] for x in outputs]) 68 | tensorboard_logs = {'test_loss': avg_loss} 69 | print('Test accuracy :',sum_correct/10000,'\n') 70 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs} 71 | 72 | class CIFAR10_VGG(CIFAR10_LIGHTNING): 73 | # This Module is based on VGG-16 for dataset CIFAR10 
74 | def __init__(self): 75 | super(CIFAR10_LIGHTNING, self).__init__() 76 | self.model = VGG.vgg16_bn() 77 | self.model.avgpool = nn.AdaptiveAvgPool2d((1,1)) 78 | self.model.classifier = nn.Sequential( 79 | nn.Linear(512,10) 80 | ) 81 | 82 | class CIFAR10_Resnet(CIFAR10_LIGHTNING): 83 | # This Module is based on Resnet-50 for dataset CIFAR10 84 | def __init__(self): 85 | super(CIFAR10_LIGHTNING, self).__init__() 86 | self.model = Resnet.ResNet(Resnet.Bottleneck,[3,4,6,3],num_classes=10) 87 | self.model.inplanes=64 88 | self.model.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False) 89 | self.model.bn1 = nn.BatchNorm2d(64) 90 | self.model.linear = nn.Linear(512*Resnet.Bottleneck.expansion, 10) 91 | del self.model.maxpool 92 | self.model.maxpool = lambda x : x 93 | 94 | class CIFAR10_WideResnet(CIFAR10_LIGHTNING): 95 | # This Module is based on WideResNet28-10 for dataset CIFAR10 96 | def __init__(self): 97 | super(CIFAR10_LIGHTNING, self).__init__() 98 | self.model = Wide_ResNet(28,10,0.3,10) 99 | 100 | class CIFAR10_Densenet(CIFAR10_LIGHTNING): 101 | # This Module is based on Densenet for dataset CIFAR10 102 | def __init__(self): 103 | super(CIFAR10_LIGHTNING, self).__init__() 104 | self.model = DenseNet() 105 | 106 | class CIFAR10_Densenet_BC(CIFAR10_LIGHTNING): 107 | def __init__(self): 108 | super(CIFAR10_Densenet_BC,self).__init__() 109 | self.model = DenseNet_BC(num_classes=10) 110 | 111 | 112 | ''' 113 | CIFAR100_model skeleton 114 | ''' 115 | class CIFAR100_LIGHTNING(pl.LightningModule): 116 | # This Module is based on VGG-16 117 | def __init__(self): 118 | super(CIFAR100_LIGHTNING, self).__init__() 119 | 120 | def forward(self, x): 121 | output = self.model(x) 122 | return output 123 | 124 | def training_step(self,batch,batch_idx): 125 | data, target = batch 126 | output = self.forward(data) 127 | loss = F.cross_entropy(output,target) 128 | tensorboard_logs = {'train_loss':loss} 129 | return {'loss':loss, 'log':tensorboard_logs} 130 | 131 | def configure_optimizers(self): 132 | optimizer = optim.SGD(self.parameters(), lr=1e-1, momentum=0.9, weight_decay=5e-4) 133 | lr_scheduler = {'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60,120], gamma=0.1), 'interval': 'epoch'} 134 | return [optimizer], [lr_scheduler] 135 | 136 | def validation_step(self,batch,batch_idx): 137 | data, target = batch 138 | output = self.forward(data) 139 | loss = F.cross_entropy(output,target) 140 | pred = output.argmax(dim=1,keepdim=True) 141 | correct = pred.eq(target.view_as(pred)).sum().item() 142 | return {'val_loss':loss,'correct':correct} 143 | 144 | def validation_epoch_end(self,outputs): 145 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 146 | sum_correct = sum([x['correct'] for x in outputs]) 147 | tensorboard_logs = {'val_loss':avg_loss} 148 | print('Validation accuracy : ',sum_correct/10000,'\n\n') 149 | return {'avg_val_loss':avg_loss, 'log':tensorboard_logs} 150 | 151 | def test_step(self,batch,batch_idx): 152 | data, target = batch 153 | output = self.forward(data) 154 | pred = output.argmax(dim=1,keepdim=True) 155 | correct = pred.eq(target.view_as(pred)).sum().item() 156 | return {'test_loss':F.cross_entropy(output,target), 'correct':correct} 157 | 158 | def test_epoch_end(self, outputs): 159 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 160 | sum_correct = sum([x['correct'] for x in outputs]) 161 | tensorboard_logs = {'test_loss': avg_loss} 162 | print('Test accuracy :',sum_correct/10000,'\n') 163 | return 
{'avg_test_loss': avg_loss, 'log': tensorboard_logs}
164 | 
165 | class CIFAR100_VGG(CIFAR100_LIGHTNING):
166 |     # This Module is based on VGG-16 for dataset CIFAR100
167 |     def __init__(self):
168 |         super(CIFAR100_VGG, self).__init__()
169 |         self.model = VGG.vgg16_bn()
170 |         self.model.avgpool = nn.AdaptiveAvgPool2d((1,1))
171 |         self.model.classifier = nn.Sequential(
172 |             nn.Linear(512,100)
173 |         )
174 | 
175 | class CIFAR100_Resnet(CIFAR100_LIGHTNING):
176 |     # This Module is based on Resnet-50 for dataset CIFAR100
177 |     def __init__(self):
178 |         super(CIFAR100_Resnet, self).__init__()
179 |         self.model = Resnet.ResNet(Resnet.Bottleneck,[3,4,6,3],num_classes=100)
180 |         self.model.inplanes=64
181 |         self.model.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False)
182 |         self.model.bn1 = nn.BatchNorm2d(64)
183 |         self.model.linear = nn.Linear(512*Resnet.Bottleneck.expansion, 100)
184 |         del self.model.maxpool
185 |         self.model.maxpool = lambda x : x
186 | 
187 | class CIFAR100_WideResnet(CIFAR100_LIGHTNING):
188 |     # This Module is based on WideResNet 28-20 for dataset CIFAR-100
189 |     def __init__(self):
190 |         super(CIFAR100_WideResnet, self).__init__()
191 |         self.model = Wide_ResNet(28,20,0.3,100)
192 | 
193 | class CIFAR100_Densenet(CIFAR100_LIGHTNING):
194 |     # This Module is based on Densenet for dataset CIFAR100
195 |     def __init__(self):
196 |         super(CIFAR100_Densenet, self).__init__()
197 |         self.model = DenseNet(num_classes=100)
198 | 
199 | class CIFAR100_Densenet_BC(CIFAR100_LIGHTNING):
200 |     def __init__(self):
201 |         super(CIFAR100_Densenet_BC,self).__init__()
202 |         self.model = DenseNet_BC(num_classes=100)
--------------------------------------------------------------------------------
/setup/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.5.1
2 | torchvision==0.6.1
3 | pytorch-lightning==0.9.0rc5
4 | tqdm
5 | argparse
6 | pytablewriter
7 | seaborn
8 | enum34
9 | scipy
10 | cffi
11 | sklearn
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | from pytorch_lightning import Trainer
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | from pytorch_lightning.loggers import TensorBoardLogger
7 | 
8 | from models.classifiers import *
9 | from datasets.CIFAR import *
10 | from datasets.LSUN import *
11 | from datasets.SVHN import *
12 | from utils.args import *
13 | 
14 | if __name__ == '__main__':
15 |     datasets = ['CIFAR10','CIFAR100']
16 |     NNModels = ['VGG','Resnet','WideResnet','Densenet_BC','Densenet']
17 |     for dataset in datasets:
18 |         if dataset == 'CIFAR10':
19 |             dm = CIFAR10DataModule()
20 |             max_epochs = 60
21 |         elif dataset == 'CIFAR100':
22 |             dm = CIFAR100DataModule()
23 |             max_epochs = 180
24 |         for NNModel in NNModels:
25 |             model_name = dataset + '_' + NNModel
26 |             model = globals()[model_name]()
27 |             modelpath = './workspace/model_ckpts/' + model_name + '/'
28 |             os.makedirs(modelpath, exist_ok=True)
29 |             checkpoint_callback=ModelCheckpoint(filepath=modelpath)
30 |             trainer=Trainer(checkpoint_callback=checkpoint_callback, gpus=1, num_nodes=1, max_epochs = max_epochs)
31 |             if os.path.isfile(modelpath + 'final.ckpt'):
32 |                 model = model.load_from_checkpoint(checkpoint_path=modelpath + 'final.ckpt')
33 |             else:
34 |                 trainer.fit(model, dm)
35 |                 trainer.save_checkpoint(modelpath + 'final.ckpt')
36 |             trainer.test(model, datamodule = dm)
37 | 
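Note: train.py only exercises the CIFAR DataModules, while the LSUN and SVHN modules are marked in their own comments as out-of-distribution data. A minimal sketch of how a checkpoint written by train.py might be run against them is shown below; the checkpoint path and the choice of `CIFAR10_WideResnet` are illustrative, and the printed "accuracy" is not meaningful on OOD inputs because `test_step` compares predictions against the foreign labels.

```python
# eval_ood.py -- hypothetical helper, not part of the repository.
# Reloads a checkpoint saved by train.py and runs its test loop on the
# out-of-distribution DataModules (the LSUN folders must already exist
# under ./workspace/datasets/LSUN*).
from pytorch_lightning import Trainer

from models.classifiers import CIFAR10_WideResnet
from datasets.SVHN import SVHNDataModule
from datasets.LSUN import LSUN_resizeDataModule

if __name__ == '__main__':
    ckpt = './workspace/model_ckpts/CIFAR10_WideResnet/final.ckpt'
    model = CIFAR10_WideResnet.load_from_checkpoint(checkpoint_path=ckpt)

    trainer = Trainer(gpus=1)
    # test_step only reports cross-entropy and label accuracy, so a real OOD
    # score would need the softmax confidences instead; this just runs the loop.
    trainer.test(model, datamodule=SVHNDataModule(batch_size=64))
    trainer.test(model, datamodule=LSUN_resizeDataModule(batch_size=64))
```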
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LJY-HY/cifar_pytorch-lightning/868ed76339f4d4d0c607b9c3d09907b70b29ceb5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | 
3 | def get_args():
4 |     parser = ArgumentParser(description='Implementation of many methods for detecting OOD samples.')
5 | 
6 |     parser.add_argument('--validation-size', default=5000, type=int, help='Size of the validation set (default: 5000).')
7 | 
8 |     args = parser.parse_args()
9 | 
10 |     return args
11 | 
--------------------------------------------------------------------------------
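Note: get_args() is defined but never called from train.py, which hard-codes the dataset/model loop and divides by a fixed 10000 when printing accuracy in models/classifiers.py. If command-line control is wanted, the parser could be extended along the lines of the sketch below; every flag except --validation-size is an illustrative addition, not an option the repository currently defines, and train.py would still need to call get_args() and act on the values.

```python
# Sketch of an extended utils/args.py -- the extra flags are hypothetical.
from argparse import ArgumentParser

def get_args():
    parser = ArgumentParser(description='Training options for the CIFAR classifiers.')
    parser.add_argument('--validation-size', default=5000, type=int,
                        help='Size of the validation set (default: 5000).')
    # Illustrative additions, not present in the repository:
    parser.add_argument('--dataset', default='CIFAR10',
                        choices=['CIFAR10', 'CIFAR100'])
    parser.add_argument('--model', default='VGG',
                        choices=['VGG', 'Resnet', 'WideResnet', 'Densenet_BC', 'Densenet'])
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--max_epochs', default=60, type=int)
    return parser.parse_args()
```

With such flags, a training entry point could build a single classifier with `globals()[args.dataset + '_' + args.model]()`, exactly the lookup train.py already performs inside its nested loops, but for one user-selected dataset/model pair.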