├── .gitignore ├── README.md ├── data └── README.md ├── data_utils.py ├── eval.py ├── metrics.py ├── models ├── __init__.py ├── resnet_20.py └── simple_cnn.py ├── plot_log.py ├── train.py └── training_curve.png /.gitignore: -------------------------------------------------------------------------------- 1 | log 2 | **/__pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | A simple demo for training a CNN with PyTorch. It is a PyTorch clone of [tf-simple-example](https://github.com/blaueck/tf-simple-example). It supports training CNNs on mnist, cifar10, cifar100 and svhn. 3 | 4 | # Requirements 5 | * python==3.6 6 | * pytorch>=1.0.0 7 | * torchvision 8 | * matplotlib (optional, for plotting training logs) 9 | 10 | 11 | # Run 12 | ```bash 13 | # train 14 | python train.py 15 | 16 | # eval 17 | python eval.py log/resnet_20/checkpoint_69.pk 18 | 19 | # plot training logs 20 | python plot_log.py log/ -m true 21 | ``` 22 | 23 | # Training Curve 24 | ![training curve](./training_curve.png) -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Download datasets 2 | * Download mnist from [here](http://yann.lecun.com/exdb/mnist/). 3 | * Download cifar10 (binary version) or cifar100 (binary version) from [here](http://www.cs.toronto.edu/~kriz/cifar.html). 4 | * Download svhn (Cropped Digits) from [here](http://ufldl.stanford.edu/housenumbers/).
5 | 6 | # mnist 7 | mnist directory should contain the following files: 8 | * train-images-idx3-ubyte.gz 9 | * train-labels-idx1-ubyte.gz 10 | * t10k-images-idx3-ubyte.gz 11 | * t10k-labels-idx1-ubyte.gz 12 | 13 | # cifar10 14 | cifar10 directory should contain at least the following files: 15 | * data_batch_[1-5].bin 16 | * test_batch.bin 17 | 18 | # cifar100 19 | cifar100 directory should contain at least the following files: 20 | * train.bin 21 | * test.bin 22 | 23 | # svhn 24 | svhn directory should contain at least the following files: 25 | * train_32x32.mat 26 | * test_32x32.mat -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import gzip 4 | 5 | import numpy as np 6 | from scipy.io import loadmat 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | def _glob(pattern): 12 | # Expand one glob pattern (str) or a list of patterns into the list of matching file paths. 13 | if isinstance(pattern, str): 14 | files = glob.glob(pattern) 15 | elif isinstance(pattern, list): 16 | files = [] 17 | for p in pattern: 18 | files.extend(glob.glob(p))  # glob each pattern `p`; globbing the whole list raised TypeError 19 | else: 20 | raise TypeError('wrong argument type.') 21 | 22 | return files 23 | 24 | 25 | def get_cifar10(files): 26 | # Parse CIFAR-10 binary files: each record is 1 label byte + 3*32*32 pixel bytes (CHW); returns uint8 images as (N, 32, 32, 3) and labels as (N,). 27 | images_splits = [] 28 | labels_splits = [] 29 | n_pixel = 32 * 32 * 3 30 | 31 | for f in files: 32 | buffer = np.fromfile(f, dtype='uint8') 33 | buffer = buffer.reshape(-1, n_pixel+1) 34 | 35 | images = buffer[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) 36 | labels = buffer[:, 0] 37 | 38 | images_splits.append(images) 39 | labels_splits.append(labels) 40 | 41 | images = np.concatenate(images_splits) 42 | labels = np.concatenate(labels_splits) 43 | 44 | return images, labels 45 | 46 | 47 | def get_cifar100(files): 48 | # Parse CIFAR-100 binary files: each record is 2 label bytes + 3*32*32 pixel bytes (CHW); the second label byte (buffer[:, 1]) is used as the label. 49 | images_splits = [] 50 | labels_splits = [] 51 | n_pixel = 32 * 32 * 3 52 | 53 | for f in files: 54 | buffer = np.fromfile(f, dtype='uint8') 55 | buffer = buffer.reshape(-1, n_pixel+2) 56 | 57 |
images = buffer[:, 2:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) 58 | labels = buffer[:, 1] 59 | 60 | images_splits.append(images) 61 | labels_splits.append(labels) 62 | 63 | images = np.concatenate(images_splits) 64 | labels = np.concatenate(labels_splits) 65 | 66 | return images, labels 67 | 68 | 69 | def get_mnist(image_files, label_files): 70 | # Decode gzipped IDX files pairwise: skip the 16-byte image / 8-byte label headers; images become (N, 28, 28, 1) uint8. 71 | images_splits = [] 72 | labels_splits = [] 73 | 74 | for i_f, l_f in zip(image_files, label_files): 75 | 76 | with gzip.open(i_f, 'rb') as f: 77 | images = np.frombuffer(f.read(), dtype='uint8', offset=16) 78 | images = images.reshape(-1, 28, 28, 1) 79 | images_splits.append(images) 80 | 81 | with gzip.open(l_f, 'rb') as f: 82 | labels = np.frombuffer(f.read(), dtype='uint8', offset=8) 83 | labels_splits.append(labels) 84 | 85 | images = np.concatenate(images_splits) 86 | labels = np.concatenate(labels_splits) 87 | 88 | return images, labels 89 | 90 | 91 | def get_svhn(data_file): 92 | # Load an SVHN .mat file, moving the last axis of 'X' to the front so images are sample-first. 93 | data = loadmat(data_file) 94 | images = data['X'].transpose(3, 0, 1, 2) 95 | labels = data['y'].reshape(-1) 96 | 97 | # map label 10 to 0 98 | labels[labels == 10] = 0 99 | 100 | return images, labels 101 | 102 | 103 | def load_data(root, dataset, is_training): 104 | # Load the train (is_training=True) or test split of `dataset` from `root`; returns ((images, labels), meta) with meta keys 'n_class' and 'name'. 105 | if dataset == 'cifar10': 106 | 107 | if is_training: 108 | pattern = os.path.join(root, 'data_batch*.bin') 109 | else: 110 | pattern = os.path.join(root, 'test_batch.bin') 111 | 112 | files = _glob(pattern) 113 | assert files, 'no file is matched.' 114 | 115 | data = get_cifar10(files) 116 | meta = {'n_class': 10} 117 | 118 | elif dataset == 'cifar100': 119 | 120 | if is_training: 121 | pattern = os.path.join(root, 'train.bin') 122 | else: 123 | pattern = os.path.join(root, 'test.bin') 124 | 125 | files = _glob(pattern) 126 | assert files, 'no file is matched.'
127 | 128 | data = get_cifar100(files)  # pass the matched paths; passing the pattern string made get_cifar100 iterate over its characters 129 | meta = {'n_class': 100} 130 | 131 | elif dataset == 'mnist': 132 | 133 | if is_training: 134 | img_pattern, label_pattern = [ 135 | os.path.join(root, fn) 136 | for fn in ['train-images-idx3-ubyte.gz', 137 | 'train-labels-idx1-ubyte.gz']] 138 | else: 139 | img_pattern, label_pattern = [ 140 | os.path.join(root, fn) 141 | for fn in ['t10k-images-idx3-ubyte.gz', 142 | 't10k-labels-idx1-ubyte.gz']] 143 | 144 | img_files, label_files = _glob(img_pattern), _glob(label_pattern) 145 | assert img_files, 'no image file is matched.' 146 | assert label_files, 'no label file is matched.' 147 | 148 | data = get_mnist(img_files, label_files) 149 | meta = {'n_class': 10} 150 | 151 | elif dataset == 'svhn': 152 | 153 | if is_training: 154 | pattern = os.path.join(root, 'train_32x32.mat') 155 | else: 156 | pattern = os.path.join(root, 'test_32x32.mat') 157 | 158 | files = _glob(pattern) 159 | assert files, 'no file is matched.' 160 | data_file = files[0] 161 | 162 | data = get_svhn(data_file) 163 | meta = {'n_class': 10} 164 | 165 | else: 166 | raise ValueError('%s is not supported.'
% dataset) 167 | 168 | meta['name'] = dataset 169 | return data, meta 170 | 171 | 172 | def split_data(data, rate, shuffle=True): 173 | # Optionally shuffle, then hold out the first int(N * rate) samples as validation; returns (train_data, val_data), each an (images, labels) pair. 174 | images, labels = data 175 | N = images.shape[0] 176 | split_point = int(N * rate) 177 | 178 | if shuffle: 179 | idx = np.random.permutation(N) 180 | images, labels = images[idx], labels[idx] 181 | 182 | train_data = images[split_point:], labels[split_point:] 183 | val_data = images[:split_point], labels[:split_point] 184 | 185 | return train_data, val_data 186 | 187 | 188 | 189 | class ImageDataset(Dataset): 190 | # Map-style dataset over (N, H, W, C) image arrays; self.shape is the per-sample (C, H, W) shape returned by __getitem__. 191 | def __init__(self, images, labels, is_training=True, is_flip=False): 192 | self.images = images 193 | self.labels = labels 194 | self.is_training = is_training 195 | self.is_flip = is_flip 196 | h, w, c = images.shape[1:] 197 | self.shape = (c, h, w) 198 | 199 | def _preprocess(self, img, label): 200 | # Training only: zero-pad 4 px per side, randomly crop back to H x W, and (if is_flip) flip horizontally with p=0.5. Always: cast to float32, divide by 255, HWC -> CHW. 201 | if self.is_training: 202 | h, w = img.shape[:2] 203 | 204 | # padding 205 | img = np.pad(img, ((4, 4), (4, 4), (0, 0)), 'constant', constant_values=0.) 206 | 207 | # random crop 208 | dx, dy = np.random.randint(9), np.random.randint(9) 209 | img = img[dy:dy+h, dx:dx+w] 210 | 211 | # random flip 212 | if self.is_flip and np.random.rand() < 0.5: 213 | img = img[:, ::-1] 214 | 215 | img = img.astype('float32') 216 | img = img / 255.
217 | img = img.transpose(2, 0, 1).copy() 218 | return img, int(label) 219 | 220 | def __len__(self): 221 | return len(self.images) 222 | 223 | def __getitem__(self, index): 224 | img, label = self._preprocess(self.images[index], self.labels[index]) 225 | return img, label 226 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.utils.data 7 | 8 | import data_utils 9 | import models 10 | from metrics import MeanValue, Accuracy 11 | 12 | 13 | def main(args): 14 | # Evaluate checkpoint args.cpt on the test split of args.dataset and print the mean loss and accuracy. 15 | # load data 16 | data, meta = data_utils.load_data( 17 | args.dataset_root, args.dataset, is_training=False) 18 | 19 | # build val dataloader 20 | dataset = data_utils.ImageDataset(*data, is_training=False) 21 | dataloader = torch.utils.data.DataLoader( 22 | dataset, args.batch_size, shuffle=False, num_workers=2, pin_memory=True) 23 | 24 | # remove temp dataset variables to reduce memory usage 25 | del data 26 | 27 | device = torch.device(args.device) 28 | 29 | # build model 30 | if args.model == 'resnet_20': 31 | model = models.Resnet20 32 | else: 33 | model = models.SimpleCNN 34 | net = model(dataset.shape, meta['n_class']).to(device=device) 35 | 36 | criterion = torch.nn.CrossEntropyLoss() 37 | 38 | state = torch.load(args.cpt, map_location=device)  # map tensors onto --device so a GPU-saved checkpoint also loads on CPU 39 | net.load_state_dict(state['net']) 40 | 41 | net.eval() 42 | mean_loss, acc = MeanValue(), Accuracy() 43 | for x, y in dataloader: 44 | 45 | if device.type == 'cuda': 46 | x = x.cuda(device, non_blocking=True) 47 | y = y.cuda(device, non_blocking=True) 48 | 49 | with torch.no_grad(): logits = net(x)  # inference only: do not build an autograd graph 50 | loss = criterion(logits, y) 51 | 52 | loss = loss.item() 53 | predicts = torch.argmax(logits, dim=1).detach().cpu().numpy() 54 | y = y.detach().cpu().numpy() 55 | 56 | mean_loss.add(loss) 57 | acc.add(predicts, y) 58 | 59 | print('loss:
{:.4f}, acc: {:.2%}'.format( 60 | mean_loss.get(), acc.get())) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = ArgumentParser(description='pytorch for small image dataset') 65 | parser.add_argument('cpt', default='', help='checkpoint path') 66 | parser.add_argument('--dataset', default='cifar10', 67 | help='the eval dataset') 68 | parser.add_argument( 69 | '--dataset_root', default='data/cifar-10-batches-bin/', help='dataset root') 70 | 71 | parser.add_argument('--model', default='resnet_20', 72 | choices=['resnet_20', 'simple_cnn'], help='model name') 73 | 74 | parser.add_argument('--device', default='cuda:0', help='device') 75 | 76 | parser.add_argument('--batch_size', default=64, 77 | type=int, help='batch size') 78 | args = parser.parse_args() 79 | main(args) 80 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | 6 | class MeanValue: 7 | # Running arithmetic mean of the values passed to add(). 8 | def __init__(self): 9 | self.value = 0 10 | self.counter = 0 11 | 12 | def add(self, value): 13 | self.value += value 14 | self.counter += 1 15 | 16 | def get(self): 17 | return self.value / self.counter if self.counter else float('nan')  # NaN instead of ZeroDivisionError when nothing was added (e.g. an empty validation split) 18 | 19 | def reset(self): 20 | self.value = 0 21 | self.counter = 0 22 | 23 | 24 | class Accuracy: 25 | # Running classification accuracy: fraction of predicts that equal labels. 26 | def __init__(self): 27 | self.n_correct = 0. 28 | self.n_sample = 0. 29 | 30 | def add(self, predicts, labels): 31 | self.n_sample += len(labels) 32 | self.n_correct += int(np.sum(predicts == labels))  # keep a plain Python number so logged accuracies are floats, not numpy scalars 33 | 34 | def get(self): 35 | return self.n_correct / self.n_sample if self.n_sample else float('nan') 36 | 37 | def reset(self): 38 | self.n_correct = 0. 39 | self.n_sample = 0. 40 | 41 | 42 | class TimeMeter: 43 | # Accumulates wall-clock time between start() and stop(); get() divides the total by the add_counter() count. 44 | def __init__(self): 45 | self.start_time, self.duration, self.counter = 0., 0., 0.
46 | 47 | def add_counter(self): 48 | self.counter += 1 49 | 50 | def start(self): 51 | self.start_time = time.perf_counter() 52 | 53 | def stop(self): 54 | self.duration += time.perf_counter() - self.start_time 55 | 56 | def get(self): 57 | return self.duration / self.counter 58 | 59 | def reset(self): 60 | self.start_time, self.duration, self.counter = 0., 0., 0. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .simple_cnn import SimpleCNN 2 | from .resnet_20 import Resnet20 -------------------------------------------------------------------------------- /models/resnet_20.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torchvision.models.resnet import BasicBlock 8 | 9 | class Resnet20(nn.Module): 10 | # CIFAR-style ResNet: 3x3 conv stem (16 ch) + BN + ReLU, three stages of `block` (16/32/64 ch, the last two with stride 2), global average pooling, linear classifier. 11 | def __init__(self, in_shape, n_class, *args, **kwargs): 12 | super(Resnet20, self).__init__(*args, **kwargs) 13 | self.in_shape = in_shape 14 | 15 | self.conv1 = nn.Conv2d(in_shape[0], 16, 3, stride=1, padding=1) 16 | self.bn1 = nn.BatchNorm2d(16) 17 | 18 | self.in_planes = 16 19 | # NOTE(review): 2 BasicBlocks per stage gives 1 + 3*2*2 + 1 = 14 weighted layers on the main path, not the 20 of a canonical ResNet-20 (3 blocks per stage); left unchanged so existing checkpoints still load. 20 | block = BasicBlock 21 | self.layer1 = self._make_layers(block, 16, 2, 1) 22 | self.layer2 = self._make_layers(block, 32, 2, 2) 23 | self.layer3 = self._make_layers(block, 64, 2, 2) 24 | 25 | # compute the height of layer3's feature map 26 | fh = in_shape[1] 27 | for _ in range(2): 28 | fh = math.ceil(fh / 2) 29 | self.fh = int(fh) 30 | 31 | self.fc = nn.Linear(64, n_class) 32 | 33 | def _make_layers(self, block, planes, n_block, stride=1): 34 | layers = [] 35 | # Build one stage of n_block blocks; only the first may stride, and when it does its shortcut is a 1x1 conv + BN. 36 | downsample = None 37 | if stride != 1: 38 | downsample = nn.Sequential( 39 | nn.Conv2d(self.in_planes, planes * block.expansion, 40 | kernel_size=1, stride=stride, bias=False), 41 | nn.BatchNorm2d(planes * block.expansion)) 42 | 43 |
layers.append(block(self.in_planes, planes, stride, downsample)) 44 | for _ in range(1, n_block): 45 | layers.append(block(planes, planes)) 46 | 47 | self.in_planes = planes 48 | return nn.Sequential(*layers) 49 | 50 | 51 | def forward(self, images): 52 | out = self.conv1(images) 53 | out = self.bn1(out) 54 | out = F.relu(out) 55 | 56 | out = self.layer1(out) 57 | out = self.layer2(out) 58 | out = self.layer3(out) 59 | out = F.avg_pool2d(out, self.fh) 60 | 61 | shape = out.shape 62 | out = torch.reshape(out, (shape[0], -1)) 63 | out = self.fc(out) 64 | return out 65 | 66 | -------------------------------------------------------------------------------- /models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class SimpleCNN(nn.Module): 8 | # Three stride-2 3x3 convs (32/64/128 ch) with ReLU, flattened into a linear classifier. 9 | def __init__(self, in_shape, n_class, *args, **kwargs): 10 | super(SimpleCNN, self).__init__(*args, **kwargs) 11 | self.in_shape = in_shape 12 | 13 | self.conv1 = nn.Conv2d(in_shape[0], 32, 3, stride=2, padding=1) 14 | self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1) 15 | self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1) 16 | # each stride-2, padding-1, 3x3 conv maps size s to ceil(s / 2); the height is used for both dims, so square inputs are assumed 17 | fh = in_shape[1] 18 | for _ in range(3): 19 | fh = math.ceil(fh / 2) 20 | fh = int(fh) 21 | self.fc = nn.Linear(fh * fh * 128, n_class) 22 | 23 | def forward(self, images): 24 | out = self.conv1(images) 25 | out = F.relu(out) 26 | out = self.conv2(out) 27 | out = F.relu(out) 28 | out = self.conv3(out) 29 | out = F.relu(out) 30 | shape = out.shape 31 | out = torch.reshape(out, (shape[0], -1)) 32 | out = self.fc(out) 33 | return out 34 | 35 | -------------------------------------------------------------------------------- /plot_log.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from argparse import ArgumentParser 4 | 5 | import matplotlib.pyplot as plt 6
| import torch 7 | 8 | # Interpret a CLI value as a boolean: bools pass through; 'true', 'ok' and '1' (any case) mean True. 9 | def parse_bool(value): 10 | if isinstance(value, bool): 11 | return value 12 | elif isinstance(value, str): 13 | value = value.lower() 14 | if value in ['true', 'ok', '1']: 15 | return True 16 | return False 17 | 18 | # Plot loss (left subplot) and accuracy (right subplot) curves from every log_* file in logdir, prefixing each curve label with `prefix`. 19 | def plot_dir(logdir, prefix): 20 | logfiles = list(logdir.glob('log_*')) 21 | 22 | if len(logfiles) == 0: 23 | return 24 | 25 | # collect log 26 | train_log, val_log = {}, {} 27 | for log_f in logfiles: 28 | log = torch.load(str(log_f)) 29 | train_log.update(log['train']) 30 | val_log.update(log['val']) 31 | 32 | # prepare logs for plot 33 | train_step, train_loss, train_acc = [], [], [] 34 | for k in sorted(train_log.keys()): 35 | train_step.append(k) 36 | data = train_log[k] 37 | train_loss.append(data['loss']) 38 | train_acc.append(data['acc']) 39 | 40 | val_step, val_loss, val_acc = [], [], [] 41 | for k in sorted(val_log.keys()): 42 | val_step.append(k) 43 | data = val_log[k] 44 | val_loss.append(data['loss']) 45 | val_acc.append(data['acc']) 46 | 47 | # plot logs 48 | plt.subplot(121) 49 | plt.title('Loss') 50 | plt.plot(train_step, train_loss, '-', label=prefix+'train') 51 | plt.plot(val_step, val_loss, '-', label=prefix+'val') 52 | plt.xlabel('step') 53 | plt.legend() 54 | 55 | plt.subplot(122) 56 | plt.title('Accuracy') 57 | plt.plot(train_step, train_acc, '-', label=prefix+'train') 58 | plt.plot(val_step, val_acc, '-', label=prefix+'val') 59 | plt.xlabel('step') 60 | plt.legend() 61 | 62 | # Plot args.logdir (or, with --multi_dir, each of its sub-directories) into the shared figure, then show it. 63 | def main(args): 64 | logdir = Path(args.logdir) 65 | 66 | dirs = [] 67 | if args.multi_dir: 68 | dirs = [d for d in logdir.iterdir() if d.is_dir()] 69 | else: 70 | dirs = [logdir] 71 | 72 | for d in dirs: 73 | subfix = d.name + '-' if args.multi_dir else '' 74 | plot_dir(d, subfix) 75 | 76 | plt.show() 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = ArgumentParser(description='plot training log') 81 | parser.add_argument('logdir', help='log directory') 82 | parser.add_argument('-m', '--multi_dir',
type=parse_bool, default=False, 83 | help='whether logdir contain multi sub-directory that contain logs') 84 | args = parser.parse_args() 85 | sys.exit(main(args)) 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | 9 | import data_utils 10 | import models 11 | from metrics import TimeMeter, MeanValue, Accuracy 12 | 13 | 14 | class LRManager: 15 | # Piecewise-constant learning-rate schedule: the first values[i] whose boundaries[i] > epoch, else values[-1]. 16 | def __init__(self, boundaries, values): 17 | self.boundaries = boundaries 18 | self.values = values 19 | 20 | def get(self, epoch): 21 | for b, v in zip(self.boundaries, self.values): 22 | if epoch < b: 23 | return v 24 | return self.values[-1] 25 | 26 | def set_lr_for_optim(self, epoch, optim): 27 | lr = self.get(epoch) 28 | for group in optim.param_groups: 29 | group['lr'] = lr 30 | 31 | # Initializer for net.apply(): Kaiming-normal weights and zero bias for exact Conv2d/Linear modules; BatchNorm2d weight 1, bias 0. 32 | def init_param(m): 33 | if type(m) == torch.nn.Conv2d: 34 | torch.nn.init.kaiming_normal_(m.weight) 35 | if m.bias is not None: 36 | torch.nn.init.constant_(m.bias, 0) 37 | elif type(m) == torch.nn.Linear: 38 | torch.nn.init.kaiming_normal_(m.weight) 39 | if m.bias is not None: 40 | torch.nn.init.constant_(m.bias, 0) 41 | elif type(m) == torch.nn.BatchNorm2d: 42 | torch.nn.init.constant_(m.weight, 1) 43 | torch.nn.init.constant_(m.bias, 0) 44 | 45 | # Split parameters into two optimizer groups: names containing 'bias' get no weight decay, all others get weight_decay. 46 | def make_param_groups(net, weight_decay): 47 | weight_params = [] 48 | bias_params = [] 49 | for name, value in net.named_parameters(): 50 | if 'bias' in name: 51 | bias_params.append(value) 52 | else: 53 | weight_params.append(value) 54 | param_groups = [ 55 | {'params': weight_params, 'weight_decay': weight_decay}, 56 | {'params': bias_params, 'weight_decay': 0.} 57 | ] 58 | return param_groups 59 | 60 | 61 | def main(args): 62 | # Train, validate once per epoch, and save checkpoint_{e}.pk / log_{e}.pk under args.logdir. 63 | # set random seed 64 | np.random.seed(args.seed) 65 |
torch.manual_seed(args.seed) 66 | 67 | # load data 68 | data, meta = data_utils.load_data( 69 | args.dataset_root, args.dataset, is_training=True) 70 | train_data, val_data = data_utils.split_data( 71 | data, args.validate_rate, shuffle=True) 72 | 73 | # build train dataloader 74 | train_dataset = data_utils.ImageDataset( 75 | *train_data, is_training=True, is_flip=args.dataset not in ['mnist', 'svhn']) 76 | train_dataloader = torch.utils.data.DataLoader( 77 | train_dataset, args.batch_size, shuffle=True, num_workers=2, pin_memory=True) 78 | 79 | # build val dataloader 80 | val_dataset = data_utils.ImageDataset(*val_data, is_training=False) 81 | val_dataloader = torch.utils.data.DataLoader( 82 | val_dataset, args.batch_size, shuffle=False, num_workers=2, pin_memory=True) 83 | 84 | # remove temp dataset variables to reduce memory usage 85 | del data, train_data, val_data 86 | 87 | device = torch.device(args.device) 88 | 89 | # build model 90 | if args.model == 'resnet_20': 91 | model = models.Resnet20 92 | else: 93 | model = models.SimpleCNN 94 | net = model(train_dataset.shape, meta['n_class']).to(device=device) 95 | net.apply(init_param) 96 | 97 | criterion = torch.nn.CrossEntropyLoss() 98 | 99 | # build optim 100 | optim = torch.optim.SGD(make_param_groups( 101 | net, args.weight_decay), 0.1, momentum=0.9) 102 | 103 | # make log directory 104 | logdir = Path(args.logdir) 105 | if not logdir.exists(): 106 | os.makedirs(str(logdir)) 107 | 108 | global_step = 0 109 | start_epoch = 0 110 | if args.restore: 111 | # restore checkpoint 112 | state = torch.load(args.restore, map_location=device)  # map tensors onto --device so a GPU-saved checkpoint also restores on CPU 113 | start_epoch = state['epoch'] + 1 114 | global_step = state['global_step'] 115 | net.load_state_dict(state['net']) 116 | optim.load_state_dict(state['optim']) 117 | 118 | # lr strategy 119 | lr_boundaries = list(map(int, args.boundaries.split(','))) 120 | lr_values = list(map(float, args.values.split(','))) 121 | lr_manager = LRManager(lr_boundaries, lr_values) 122 | 123 | for e in
range(start_epoch, args.n_epoch): 124 | print('-------epoch: {:d}-------'.format(e)) 125 | 126 | # training phase 127 | net.train() 128 | mean_loss, acc = MeanValue(), Accuracy() 129 | lr_manager.set_lr_for_optim(e, optim) 130 | tm = TimeMeter() 131 | tm.start() 132 | train_log = {} 133 | for i, (x, y) in enumerate(train_dataloader): 134 | tm.add_counter() 135 | 136 | if device.type == 'cuda': 137 | x = x.cuda(device, non_blocking=True) 138 | y = y.cuda(device, non_blocking=True) 139 | 140 | optim.zero_grad() 141 | logits = net(x) 142 | loss = criterion(logits, y) 143 | 144 | loss.backward() 145 | optim.step() 146 | global_step += 1 147 | 148 | loss = loss.item()  # plain Python float, so the saved logs hold no numpy scalars 149 | predicts = torch.argmax(logits, dim=1).detach().cpu().numpy() 150 | y = y.detach().cpu().numpy() 151 | 152 | mean_loss.add(loss) 153 | acc.add(predicts, y) 154 | 155 | if i % args.log_every == 0: 156 | if device.type == 'cuda': torch.cuda.synchronize(device)  # flush queued GPU work so timing is accurate; calling it unconditionally fails on machines without CUDA 157 | tm.stop() 158 | 159 | print('step: {:d}, lr: {:g}, loss: {:.4f}, acc: {:.2%}, speed: {:.2f} i/s.'
160 | .format(i, lr_manager.get(e), mean_loss.get(), acc.get(), args.batch_size / tm.get())) 161 | train_log[global_step] = { 162 | 'loss': mean_loss.get(), 'acc': acc.get()} 163 | tm.reset() 164 | tm.start() 165 | mean_loss.reset() 166 | acc.reset() 167 | 168 | # val phase 169 | net.eval() 170 | mean_loss, acc = MeanValue(), Accuracy() 171 | for x, y in val_dataloader: 172 | 173 | if device.type == 'cuda': 174 | x = x.cuda(device, non_blocking=True) 175 | y = y.cuda(device, non_blocking=True) 176 | 177 | with torch.no_grad(): logits = net(x)  # validation needs no autograd graph 178 | loss = criterion(logits, y) 179 | 180 | loss = loss.item() 181 | predicts = torch.argmax(logits, dim=1).detach().cpu().numpy() 182 | y = y.detach().cpu().numpy() 183 | 184 | mean_loss.add(loss) 185 | acc.add(predicts, y) 186 | 187 | print('val_loss: {:.4f}, val_acc: {:.2%}'.format( 188 | mean_loss.get(), acc.get())) 189 | val_log = {global_step: {'loss': mean_loss.get(), 'acc': acc.get()}} 190 | 191 | # save checkpoint 192 | vars_to_saver = { 193 | 'net': net.state_dict(), 'optim': optim.state_dict(), 194 | 'epoch': e, 'global_step': global_step} 195 | cpt_file = logdir / 'checkpoint_{:d}.pk'.format(e) 196 | torch.save(vars_to_saver, str(cpt_file)) 197 | 198 | log_file = logdir / 'log_{:d}.pk'.format(e) 199 | torch.save({'train': train_log, 'val': val_log}, str(log_file)) 200 | 201 | 202 | if __name__ == '__main__': 203 | parser = ArgumentParser(description='pytorch for small image dataset') 204 | parser.add_argument('--dataset', default='cifar10', 205 | help='the training dataset') 206 | parser.add_argument( 207 | '--dataset_root', default='data/cifar-10-batches-bin/', help='dataset root') 208 | parser.add_argument( 209 | '--logdir', default='log/resnet_20', help='log directory') 210 | parser.add_argument('--restore', default='', help='snapshot path') 211 | parser.add_argument('--validate_rate', default=0.1, 212 | type=float, help='validate split rate') 213 | 214 | parser.add_argument('--device', default='cuda:0',
help='device') 215 | 216 | parser.add_argument('--model', default='resnet_20', 217 | choices=['resnet_20', 'simple_cnn'], help='model name') 218 | parser.add_argument('--n_epoch', default=70, 219 | type=int, help='number of epoch') 220 | parser.add_argument('--weight_decay', default=0.0001, 221 | type=float, help='weight decay rate') 222 | parser.add_argument('--boundaries', default='30,60', 223 | help='learning rate boundaries') 224 | parser.add_argument( 225 | '--values', default='1e-1,1e-2,1e-3', help='learning rate values') 226 | 227 | parser.add_argument('--log_every', default=100, type=int, 228 | help='display and log frequency') 229 | parser.add_argument('--seed', default=0, type=int, help='random seed')  # int: np.random.seed rejects float seeds such as 0.0 230 | 231 | parser.add_argument('--batch_size', default=64, 232 | type=int, help='batch size') 233 | args = parser.parse_args() 234 | main(args) 235 | -------------------------------------------------------------------------------- /training_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blaueck/pytorch-demo/ec42dffa65406b2e63df69828851299e43a38cda/training_curve.png --------------------------------------------------------------------------------