├── .gitignore
├── README.md
└── src
    ├── datasets.py
    ├── evaluate.py
    ├── main.py
    ├── resnet.py
    ├── train.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
experiments/
data/
src/__pycache__/
*.pyc

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# PyTorch Distributed Training

This is general-purpose PyTorch code for running and logging distributed training experiments.

Using **DistributedDataParallel** is faster than **DataParallel**, even for single-machine multi-GPU training.

Runs are automatically organised into folders, with logs of the architecture and hyperparameters used, as well as the training progress printouts from the terminal (see example below).

Simply drop your own model into `src/main.py`.

* **Author**: Fabio De Sousa Ribeiro
* **Email**: fdesosuaribeiro@lincoln.ac.uk

## Run
You can launch **distributed** training from `src/` using:

    python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py

This trains on a single machine (`nnodes=1`) with one process per GPU, so `nproc_per_node=2` launches 2 processes and trains on 2 GPUs. To train on `N` GPUs, set `nproc_per_node=N`.

The number of CPU threads per process is hard-coded with `torch.set_num_threads(1)` for safety; for better performance it can be raised to (number of CPU threads) / `nproc_per_node`.
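
As a rough sketch of that calculation (not code from this repository), each launched process could size its CPU thread pool as below. The helper name `threads_per_process` is hypothetical, and the example assumes a single node, where `nproc_per_node` equals the `WORLD_SIZE` environment variable that `torch.distributed.launch` exports.

```python
import os
from typing import Optional


def threads_per_process(nproc_per_node: int, total_threads: Optional[int] = None) -> int:
    """Hypothetical helper: share the node's CPU threads evenly among its processes."""
    if total_threads is None:
        # os.cpu_count() can return None on some platforms, so fall back to 1.
        total_threads = os.cpu_count() or 1
    # Integer division, but never fewer than one thread per process.
    return max(1, total_threads // nproc_per_node)


if __name__ == '__main__':
    # Single-node assumption: WORLD_SIZE == nproc_per_node (treated as 1 if unset).
    nproc = int(os.environ.get('WORLD_SIZE', 1))
    print(threads_per_process(nproc))
```

In `main.py` the returned value would take the place of the hard-coded `1` in `torch.set_num_threads(1)`; clamping to at least one thread avoids passing `0` when there are more processes than CPU threads.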

For more info on **multi-node** and **multi-GPU** distributed training, refer to https://github.com/hgrover/pytorchdistr/blob/master/README.md

To train normally using **nn.DataParallel** or using the CPU:

    python main.py --no_distributed


## Example output when launching an experiment:

```
(torch)
Documents/Distributed-Pytorch-Boilerplate/src master ✔ 43m ⍉
▶ python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py

World size: 2 ; Rank: 0 ; LocalRank: 0 ; Master: localhost:port
World size: 2 ; Rank: 1 ; LocalRank: 1 ; Master: localhost:port

----------------------------------------------------------------------
Layer.Parameter     Shape                   Param#
----------------------------------------------------------------------
conv1.weight        [16, 3, 3, 3]              432
conv1.weight        [16, 16, 3, 3]           2,304
conv2.weight        [16, 16, 3, 3]           2,304
conv1.weight        [16, 16, 3, 3]           2,304
conv2.weight        [16, 16, 3, 3]           2,304
conv1.weight        [16, 16, 3, 3]           2,304
conv2.weight        [16, 16, 3, 3]           2,304
conv1.weight        [32, 16, 3, 3]           4,608
conv2.weight        [32, 32, 3, 3]           9,216
shortcut.weight     [32, 16]                   512
conv1.weight        [32, 32, 3, 3]           9,216
conv2.weight        [32, 32, 3, 3]           9,216
conv1.weight        [32, 32, 3, 3]           9,216
conv2.weight        [32, 32, 3, 3]           9,216
conv1.weight        [64, 32, 3, 3]          18,432
conv2.weight        [64, 64, 3, 3]          36,864
shortcut.weight     [64, 32]                 2,048
conv1.weight        [64, 64, 3, 3]          36,864
conv2.weight        [64, 64, 3, 3]          36,864
conv1.weight        [64, 64, 3, 3]          36,864
conv2.weight        [64, 64, 3, 3]          36,864
linear.weight       [10, 64]                   640
linear.bias         [10]                        10
----------------------------------------------------------------------

Total params: 272,474

Summaries dir: .../Distributed-Pytorch-Boilerplate/experiments/Model_17/summaries

--dataset: cifar10
--n_epochs: 1000
--batch_size: 128
--learning_rate: 0.1
--weight_decay: 0.0005
--decay_rate: 0.1
--decay_steps: 0
--optimiser: sgd
--decay_milestones: [0]
--padding: 4
--brightness: 0
--contrast: 0
--patience: 60
--crop_dim: 32
--load_checkpoint_dir: None
--distributed: True
--inference: False
--half_precision: False
--class_names: ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
--n_channels: 3
--n_classes: 10
--summaries_dir: .../Distributed-Pytorch-Boilerplate/experiments/Model_17/summaries
--checkpoint_dir: .../Distributed-Pytorch-Boilerplate/experiments/Model_17/checkpoint.pt

train: 45000 - valid: 5000 - test: 10000

Epoch 1/1000:

100%|████████████████████████████████████████████████████████████████| 175/175 [00:05<00:00, 30.29it/s]

[Train] loss: 1.6505 - acc: 0.3918 | [Valid] loss: 1.4670 - acc: 0.4575 - acc_topk: 0.6883

Epoch 2/1000:

100%|████████████████████████████████████████████████████████████████| 175/175 [00:05<00:00, 31.47it/s]

[Train] loss: 1.1477 - acc: 0.5896 | [Valid] loss: 1.0672 - acc: 0.6206 - acc_topk: 0.8073

Epoch 3/1000:

100%|████████████████████████████████████████████████████████████████| 175/175 [00:05<00:00, 31.60it/s]

[Train] loss: 0.9342 - acc: 0.6696 | [Valid] loss: 1.1729 - acc: 0.6176 - acc_topk: 0.7985

Epoch 4/1000:

100%|████████████████████████████████████████████████████████████████| 175/175 [00:05<00:00, 32.01it/s]

[Train] loss: 0.8061 - acc: 0.7194 | [Valid] loss: 0.9524 - acc: 0.6723 - acc_topk: 0.8504
```


--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Loading of various datasets.
"""
import os
import numpy as np

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST, SVHN, FashionMNIST

from utils import *


def get_dataloaders(args):
    """ Gets the dataloaders for the chosen dataset.
    """

    if args.dataset == 'cifar10':
        dataset = 'CIFAR10'
        working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
        dataset_paths = {'train': os.path.join(working_dir, 'train'),
                         'test': os.path.join(working_dir, 'test')}

        dataloaders = cifar10(args, dataset_paths)

        args.class_names = (
            'plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck'
        )  # 0,1,2,3,4,5,6,7,8,9 labels
        args.n_channels, args.n_classes = 3, 10

    elif args.dataset == 'svhn':
        dataset = 'SVHN'
        working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
        dataset_paths = {'train': os.path.join(working_dir, 'train'),
                         # 'extra': os.path.join(working_dir,'extra'),
                         'test': os.path.join(working_dir, 'test')}

        dataloaders = svhn(args, dataset_paths)

        args.class_names = (
            'zero', 'one', 'two', 'three',
            'four', 'five', 'six', 'seven', 'eight', 'nine'
        )  # 0,1,2,3,4,5,6,7,8,9 labels
        args.n_channels, args.n_classes = 3, 10

    elif args.dataset == 'fashionmnist':
        dataset = 'FashionMNIST'
        working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
        dataset_paths = {'train': os.path.join(working_dir, 'train'),
                         'test': os.path.join(working_dir, 'test')}

        dataloaders = fashionmnist(args, dataset_paths)

        args.class_names = (
            'tshirt', 'trouser', 'pullover', 'dress',
            'coat', 'sandal',
            'shirt', 'sneaker', 'bag', 'ankleboot'
        )  # 0,1,2,3,4,5,6,7,8,9 labels
        args.n_channels, args.n_classes = 1, 10

    elif args.dataset == 'mnist':
        dataset = 'MNIST'
        working_dir = os.path.join(os.path.split(os.getcwd())[0], 'data', dataset)
        dataset_paths = {'train': os.path.join(working_dir, 'train'),
                         'test': os.path.join(working_dir, 'test')}

        dataloaders = mnist(args, dataset_paths)

        args.class_names = (
            'zero', 'one', 'two', 'three', 'four',
            'five', 'six', 'seven', 'eight', 'nine'
        )  # 0,1,2,3,4,5,6,7,8,9 labels
        args.n_channels, args.n_classes = 1, 10

    else:
        # the exception must be raised; merely constructing it would fall through
        # to an UnboundLocalError on `dataloaders` below
        raise NotImplementedError('{} dataset not available.'.format(args.dataset))

    return dataloaders, args


def cifar10(args, dataset_paths):
    """ Loads the CIFAR-10 dataset.
        Returns: train/valid/test set split dataloaders.
    """
    transf = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
            transforms.ToTensor(),
            # Standardize()]),
            transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                                 (0.24703223, 0.24348513, 0.26158784))]),
        'test': transforms.Compose([
            transforms.ToTensor(),
            # Standardize()])}
            transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                                 (0.24703223, 0.24348513, 0.26158784))])
    }

    config = {'train': True, 'test': False}
    datasets = {i: CIFAR10(root=dataset_paths[i], transform=transf[i],
                           train=config[i], download=True) for i in config.keys()}

    # weighted sampler weights for full(f) training set
    f_s_weights = sample_weights(datasets['train'].targets)

    # return data, labels dicts for new train set and class-balanced valid set
    data, labels = random_split(data=datasets['train'].data,
                                labels=datasets['train'].targets,
                                n_classes=10,
                                n_samples_per_class=np.repeat(500, 10).reshape(-1))

    # define transforms for train set (without valid data)
    transf['train_'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                             (0.24703223, 0.24348513, 0.26158784))])

    # define transforms for class-balanced valid set
    transf['valid'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                             (0.24703223, 0.24348513, 0.26158784))])

    # save original full training set
    datasets['train_valid'] = datasets['train']

    # make new training set without validation samples
    datasets['train'] = CustomDataset(data=data['train'],
                                      labels=labels['train'], transform=transf['train_'])

    # make class balanced validation set
    datasets['valid'] = CustomDataset(data=data['valid'],
                                      labels=labels['valid'], transform=transf['valid'])

    # weighted sampler weights for new training set
    s_weights = sample_weights(datasets['train'].labels)

    config = {
        'train': WeightedRandomSampler(s_weights,
                                       num_samples=len(s_weights), replacement=True),
        'train_valid': WeightedRandomSampler(f_s_weights,
                                             num_samples=len(f_s_weights), replacement=True),
        'valid': None, 'test': None
    }

    if args.distributed:
        config = {'train': DistributedSampler(datasets['train']),
                  'train_valid': DistributedSampler(datasets['train_valid']),
                  'valid': None, 'test': None}

    dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
                                 num_workers=8, pin_memory=True, drop_last=True,
                                 batch_size=args.batch_size) for i in config.keys()}

    return dataloaders


def svhn(args, dataset_paths):
    ''' Loads the SVHN dataset.
        Returns: train/valid/test set split dataloaders.
    '''
    transf = {
        'train': transforms.Compose([
            # transforms.RandomApply([
            #     transforms.RandomAffine(30, shear=True)], p=0.5),
            transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
            transforms.ColorJitter(brightness=args.brightness, contrast=args.contrast),
            transforms.ToTensor(),
            # Standardize()]),
            transforms.Normalize((0.4376821, 0.4437697, 0.47280442),
                                 (0.19803012, 0.20101562, 0.19703614))]),
        # 'extra': transforms.Compose([
        #     # transforms.RandomApply([
        #     #     transforms.RandomAffine(30, shear=True)], p=0.5),
        #     transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
        #     transforms.ColorJitter(brightness=args.brightness, contrast=args.contrast),
        #     transforms.ToTensor(),
        #     # Standardize()]),
        #     transforms.Normalize((0.4379, 0.4441, 0.4734), (0.1202, 0.1232, 0.1054))]),
        'test': transforms.Compose([
            transforms.ToTensor(),
            # Standardize()])}
            transforms.Normalize((0.4376821, 0.4437697, 0.47280442),
                                 (0.19803012, 0.20101562, 0.19703614))])
    }

    # config = {'train': True, 'extra': True, 'test': False}
    config = {'train': True, 'test': False}
    datasets = {i: SVHN(root=dataset_paths[i], transform=transf[i],
                        split=i, download=True) for i in config.keys()}

    # weighted sampler weights for full(f) training set
    f_s_weights = sample_weights(datasets['train'].labels)

    # return data, labels dicts for new train set and class-balanced valid set
    data, labels = random_split(data=datasets['train'].data,
                                labels=datasets['train'].labels,
                                n_classes=10,
                                n_samples_per_class=np.unique(
                                    datasets['test'].labels, return_counts=True)[1] // 3)  # fraction of test set per class

    # define transforms for train set (without valid data)
    transf['train_'] = transforms.Compose([
        transforms.ToPILImage(),
        # transforms.RandomApply([
        #     transforms.RandomAffine(30, shear=True)], p=0.5),
        transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
        transforms.ColorJitter(brightness=args.brightness, contrast=args.contrast),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.4376821, 0.4437697, 0.47280442),
                             (0.19803012, 0.20101562, 0.19703614))])

    # define transforms for class-balanced valid set
    transf['valid'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.4376821, 0.4437697, 0.47280442),
                             (0.19803012, 0.20101562, 0.19703614))])

    # save original full training set
    datasets['train_valid'] = datasets['train']

    # make channels last and convert to np arrays
    data['train'] = np.moveaxis(np.array(data['train']), 1, -1)
    data['valid'] = np.moveaxis(np.array(data['valid']), 1, -1)

    # make new training set without validation samples
    datasets['train'] = CustomDataset(data=data['train'],
                                      labels=labels['train'], transform=transf['train_'])

    # make class balanced validation set
    datasets['valid'] = CustomDataset(data=data['valid'],
                                      labels=labels['valid'], transform=transf['valid'])

    # weighted sampler weights for new training set
    s_weights = sample_weights(datasets['train'].labels)

    config = {
        'train': WeightedRandomSampler(s_weights,
                                       num_samples=len(s_weights), replacement=True),
        'train_valid': WeightedRandomSampler(f_s_weights,
                                             num_samples=len(f_s_weights), replacement=True),
        'valid': None, 'test': None}

    if args.distributed:
        config = {'train': DistributedSampler(datasets['train']),
                  'train_valid': DistributedSampler(datasets['train_valid']),
                  'valid': None, 'test': None}

    dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
                                 num_workers=8, pin_memory=True, drop_last=True,
                                 batch_size=args.batch_size) for i in config.keys()}

    return dataloaders


def fashionmnist(args, dataset_paths):
    ''' Loads the Fashion-MNIST dataset.
        Returns: train/valid/test set split dataloaders.
    '''
    transf = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
            transforms.ToTensor(),
            # Standardize()]),
            transforms.Normalize((0.28604059,), (0.35302424,))]),
        'test': transforms.Compose([
            # transforms.Grayscale(num_output_channels=3),
            transforms.Pad(np.maximum(0, (args.crop_dim-28) // 2)),
            # transforms.CenterCrop((args.crop_dim, args.crop_dim)),
            transforms.ToTensor(),
            # Standardize()])}
            transforms.Normalize((0.28604059,), (0.35302424,))])
    }

    config = {'train': True, 'test': False}
    datasets = {i: FashionMNIST(root=dataset_paths[i], transform=transf[i],
                                train=config[i], download=True) for i in config.keys()}

    # weighted sampler weights for full(f) training set
    f_s_weights = sample_weights(datasets['train'].targets)

    # return data, labels dicts for new train set and class-balanced valid set
    data, labels = random_split(data=datasets['train'].data,
                                labels=datasets['train'].targets,
                                n_classes=10,
                                n_samples_per_class=np.unique(
                                    datasets['test'].targets, return_counts=True)[1] // 2)  # half of test set per class

    # define transforms for train set (without valid data)
    transf['train_'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.28604059,), (0.35302424,))])

    # define transforms for class-balanced valid set
    transf['valid'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Pad(np.maximum(0, (args.crop_dim-28) // 2)),
        transforms.ToTensor(),
        # Standardize()])
        transforms.Normalize((0.28604059,), (0.35302424,))])

    # save original full training set
    datasets['train_valid'] = datasets['train']

    # make new training set without validation samples
    datasets['train'] = CustomDataset(data=data['train'],
                                      labels=labels['train'], transform=transf['train_'])

    # make class balanced validation set
    datasets['valid'] = CustomDataset(data=data['valid'],
                                      labels=labels['valid'], transform=transf['valid'])

    # weighted sampler weights for new training set
    s_weights = sample_weights(datasets['train'].labels)

    config = {
        'train': WeightedRandomSampler(s_weights,
                                       num_samples=len(s_weights), replacement=True),
        'train_valid': WeightedRandomSampler(f_s_weights,
                                             num_samples=len(f_s_weights), replacement=True),
        'valid': None, 'test': None}

    if args.distributed:
        config = {'train': DistributedSampler(datasets['train']),
                  'train_valid': DistributedSampler(datasets['train_valid']),
                  'valid': None, 'test': None}

    dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
                                 num_workers=8, pin_memory=True, drop_last=True,
                                 batch_size=args.batch_size) for i in config.keys()}

    return dataloaders


def mnist(args, dataset_paths):
    ''' Loads the MNIST dataset.
        Returns: train/valid/test set split dataloaders.
    '''
    transf = {
        'train': transforms.Compose([
            transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
            transforms.ToTensor(),
            transforms.Normalize((0.13066047,), (0.30810780,))
        ]),
        'test': transforms.Compose([
            transforms.Pad(np.maximum(0, (args.crop_dim-28) // 2)),
            transforms.ToTensor(),
            transforms.Normalize((0.13066047,), (0.30810780,))])
    }

    config = {'train': True, 'test': False}
    datasets = {i: MNIST(root=dataset_paths[i], transform=transf[i],
                         train=config[i], download=True) for i in config.keys()}

    # split train into train and class-balanced valid set
    data, labels = random_split(data=datasets['train'].data,
                                labels=datasets['train'].targets,
                                n_classes=10,
                                n_samples_per_class=np.repeat(500, 10))  # 500 per class

    # define transforms for train set (without valid data)
    transf['train_'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop((args.crop_dim, args.crop_dim), padding=args.padding),
        transforms.ToTensor(),
        transforms.Normalize((0.13066047,), (0.30810780,))])

    # define transforms for class-balanced valid set
    transf['valid'] = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Pad(np.maximum(0, (args.crop_dim-28) // 2)),
        transforms.ToTensor(),
        transforms.Normalize((0.13066047,), (0.30810780,))])

    # save original full training set
    datasets['train_valid'] = datasets['train']

    # make new training set without validation samples
    datasets['train'] = CustomDataset(data=data['train'],
                                      labels=labels['train'], transform=transf['train_'])

    # make class balanced validation set
    datasets['valid'] = CustomDataset(data=data['valid'],
                                      labels=labels['valid'], transform=transf['valid'])

    if args.distributed:
        config = {'train': DistributedSampler(datasets['train']),
                  'train_valid': DistributedSampler(datasets['train_valid']),
                  'valid': None, 'test': None}

        dataloaders = {i: DataLoader(datasets[i], sampler=config[i],
                                     num_workers=8, pin_memory=True, drop_last=True,
                                     batch_size=args.batch_size) for i in config.keys()}
    else:
        config = {'train': True, 'train_valid': True,
                  'valid': False, 'test': False}

        dataloaders = {i: DataLoader(datasets[i], shuffle=config[i],
                                     num_workers=8, pin_memory=True, drop_last=True,
                                     batch_size=args.batch_size) for i in config.keys()}
    return dataloaders


--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Evaluation loop.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def evaluate(model, dataloader, args):
    """ Evaluates a given model and dataset.
    """
    model.eval()
    sample_count = 0
    running_loss = 0
    running_acc = 0
    running_acc_topk = 0
    k = 2

    with torch.no_grad():

        for inputs, labels in dataloader:

            if args.half_precision:
                inputs = inputs.type(torch.HalfTensor).cuda(non_blocking=True)
            else:
                inputs = inputs.type(torch.FloatTensor).cuda(non_blocking=True)
            labels = labels.type(torch.LongTensor).cuda(non_blocking=True)

            yhat = model(inputs)
            # explicit class dimension; calling log_softmax without dim is deprecated
            loss = F.nll_loss(F.log_softmax(yhat, dim=1), labels)

            sample_count += inputs.size(0)
            running_loss += loss.item() * inputs.size(0)  # smaller batches count less
            running_acc += (yhat.argmax(-1) == labels).sum().item()  # num corrects
            _, yhat = yhat.topk(k, 1, True, True)
            running_acc_topk += (yhat == labels.view(-1, 1).expand_as(yhat)
                                 ).sum().item()  # num corrects

    loss = running_loss / sample_count
    acc = running_acc / sample_count
    top_k_acc = running_acc_topk / sample_count

    return loss, (acc, top_k_acc)


--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Distributed training boilerplate using PyTorch.
"""
import os
import logging
import random
import argparse
import warnings
import numpy as np

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

import resnet
from train import train
from evaluate import evaluate
from datasets import get_dataloaders
from utils import experiment_config, print_network

warnings.filterwarnings("ignore")

PARSER = argparse.ArgumentParser()
PARSER.add_argument('--dataset', default='cifar10',
                    help='e.g. cifar10, svhn, fashionmnist, mnist')
PARSER.add_argument('--n_epochs', type=int, default=1000,
                    help='number of epochs to train for.')
PARSER.add_argument('--batch_size', type=int, default=128,
                    help='number of images used to approx. gradient.')
PARSER.add_argument('--learning_rate', type=float, default=.1,
                    help='step size.')
PARSER.add_argument('--weight_decay', type=float, default=5e-4,
                    help='weight decay regularisation factor.')
PARSER.add_argument('--decay_rate', type=float, default=0.1,
                    help='factor to multiply with learning rate.')
PARSER.add_argument('--decay_steps', type=int, default=0,
                    help='decay learning rate every n steps.')
PARSER.add_argument('--optimiser', default='sgd',
                    help='e.g. sgd, adam')
PARSER.add_argument('--decay_milestones', nargs='+', type=int, default=[0],
                    help='epochs at which to multiply learning rate with decay rate.')
PARSER.add_argument('--padding', type=int, default=4,
                    help='padding augmentation factor.')
PARSER.add_argument('--brightness', type=float, default=0,
                    help='brightness augmentation factor.')
PARSER.add_argument('--contrast', type=float, default=0,
                    help='contrast augmentation factor.')
PARSER.add_argument('--patience', type=int, default=60,
                    help='number of epochs to wait for improvement.')
PARSER.add_argument('--crop_dim', type=int, default=32,
                    help='height and width of input cropping.')
PARSER.add_argument('--load_checkpoint_dir', default=None,
                    help='directory to load a checkpoint from.')
PARSER.add_argument('--no_distributed', dest='distributed', action='store_false',
                    help='choose whether or not to use distributed training.')
PARSER.set_defaults(distributed=True)
PARSER.add_argument('--inference', dest='inference', action='store_true',
                    help='infer from checkpoint rather than training.')
PARSER.set_defaults(inference=False)
PARSER.add_argument('--half_precision', dest='half_precision', action='store_true',
                    help='train using fp16.')
PARSER.set_defaults(half_precision=False)


def setup(distributed):
    """ Sets up for optional distributed training.

    For distributed training run as:
        python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=2 --use_env main.py
    To kill zombie processes use:
        kill $(ps aux | grep "main.py" | grep -v grep | awk '{print $2}')

    For data parallel training on GPUs or CPU training run as:
        python main.py --no_distributed
    """
    if distributed:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        local_rank = int(os.environ.get('LOCAL_RANK'))
        device = torch.device(f'cuda:{local_rank}')  # unique on individual node

        print('World size: {} ; Rank: {} ; LocalRank: {} ; Master: {}:{}'.format(
            os.environ.get('WORLD_SIZE'),
            os.environ.get('RANK'),
            os.environ.get('LOCAL_RANK'),
            os.environ.get('MASTER_ADDR'), os.environ.get('MASTER_PORT')))
    else:
        local_rank = None
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    seed = 8  # 666
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # True

    return device, local_rank


def main():
    """ Main method.
    """
    args = PARSER.parse_known_args()[0]

    # sets up the backend for distributed training (optional)
    device, local_rank = setup(distributed=args.distributed)

    # retrieve the dataloaders for the chosen dataset
    dataloaders, args = get_dataloaders(args)

    # make dirs for current experiment logs, summaries etc
    args = experiment_config(args)

    # initialise the model
    model = resnet.resnet20(args)

    # place model onto GPU(s)
    if args.distributed:
        torch.cuda.set_device(device)
        torch.set_num_threads(1)  # n cpu threads / n processes per node
        model = DistributedDataParallel(model.cuda(),
                                        device_ids=[local_rank], output_device=local_rank)
        # only print stuff from process (rank) 0
        args.print_progress = int(os.environ.get('RANK')) == 0
    else:
        if args.half_precision:
            model.half()  # convert to half precision
            for layer in model.modules():
                # keep batchnorm in 32 for convergence reasons
                if isinstance(layer, nn.BatchNorm2d):
                    layer.float()

        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        print('\nUsing', torch.cuda.device_count(), 'GPU(s).\n')
        model.to(device)
        args.print_progress = True

    if args.print_progress:
        print_network(model, args)  # prints out the network architecture etc
        logging.info('\ntrain: {} - valid: {} - test: {}'.format(
            len(dataloaders['train'].dataset), len(dataloaders['valid'].dataset),
            len(dataloaders['test'].dataset)))

    # launch model training or inference
    if not args.inference:
        train(model, dataloaders, args)

        if args.distributed:  # cleanup
            torch.distributed.destroy_process_group()
    else:
        model.load_state_dict(torch.load(args.load_checkpoint_dir))
        # argument order matches the signature evaluate(model, dataloader, args)
        test_loss, test_acc = evaluate(model, dataloaders['test'], args)
        print('[Test] loss {:.4f} - acc {:.4f} - acc_topk {:.4f}'.format(
            test_loss, test_acc[0], test_acc[1]))


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/src/resnet.py:
--------------------------------------------------------------------------------
'''
Taken from: https://github.com/akamaster/pytorch_resnet_cifar10 as an example.
-----------------------------------------------------------------------------
Properly implemented ResNet-s for CIFAR10 as described in paper [1].

The implementation and structure of this file is hugely influenced by [2]
which is implemented for ImageNet and doesn't have option A for identity.
Moreover, most of the implementations on the web are copy-pasted from
torchvision's resnet and have the wrong number of params.

Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
number of layers and parameters:

name      | layers | params
ResNet20  |    20  | 0.27M
ResNet32  |    32  | 0.46M
ResNet44  |    44  | 0.66M
ResNet56  |    56  | 0.85M
ResNet110 |   110  |  1.7M
ResNet1202|  1202  | 19.4M

which this implementation indeed has.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
[2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

If you use this implementation in your work, please don't forget to mention the
author, Yerlan Idelbayev.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import collections as co
from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


def _weights_init(m):
    classname = m.__class__.__name__
    # print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='B', ReZero=True):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
75 | """ 76 | self.shortcut = LambdaLayer(lambda x: 77 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 78 | elif option == 'B': 79 | self.shortcut = nn.Sequential(co.OrderedDict([ 80 | ('shortcut', nn.Conv2d(in_planes, self.expansion * planes, 81 | kernel_size=1, stride=stride, bias=False)), 82 | ('bn', nn.BatchNorm2d(self.expansion * planes)) 83 | ])) 84 | 85 | self.ReZero = ReZero 86 | if self.ReZero: 87 | self.alpha_i = nn.Parameter(torch.zeros(1)) 88 | 89 | def forward(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.bn2(self.conv2(out)) 92 | 93 | if self.ReZero: 94 | out = self.alpha_i * out + self.shortcut(x) 95 | else: 96 | out += self.shortcut(x) 97 | out = F.relu(out) 98 | return out 99 | 100 | class ResNet(nn.Module): 101 | def __init__(self, block, num_blocks, in_channels, num_classes): 102 | super(ResNet, self).__init__() 103 | self.in_planes = 16 104 | 105 | self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(16) 107 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 108 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 109 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 110 | self.linear = nn.Linear(64, num_classes) 111 | 112 | self.apply(_weights_init) 113 | 114 | def _make_layer(self, block, planes, num_blocks, stride): 115 | strides = [stride] + [1]*(num_blocks-1) 116 | layers = [] 117 | for stride in strides: 118 | layers.append(block(self.in_planes, planes, stride)) 119 | self.in_planes = planes * block.expansion 120 | 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | out = F.relu(self.bn1(self.conv1(x))) 125 | out = self.layer1(out) 126 | out = self.layer2(out) 127 | out = self.layer3(out) 128 | out = F.avg_pool2d(out, out.size()[3]) 129 | out = out.view(out.size(0), -1) 130 | out = self.linear(out) 131 | return out 132 | 133 | 134 | def 
resnet20(args): 135 | return ResNet(BasicBlock, [3, 3, 3], args.n_channels, args.n_classes) 136 | 137 | 138 | def resnet32(args): 139 | return ResNet(BasicBlock, [5, 5, 5], args.n_channels, args.n_classes) 140 | 141 | 142 | def resnet44(args): 143 | return ResNet(BasicBlock, [7, 7, 7], args.n_channels, args.n_classes) 144 | 145 | 146 | def resnet56(args): 147 | return ResNet(BasicBlock, [9, 9, 9], args.n_channels, args.n_classes) 148 | 149 | 150 | def resnet110(args): 151 | return ResNet(BasicBlock, [18, 18, 18], args.n_channels, args.n_classes) 152 | 153 | 154 | def resnet1202(args): 155 | return ResNet(BasicBlock, [200, 200, 200], args.n_channels, args.n_classes) 156 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Training loop. 4 | """ 5 | import gc 6 | import time 7 | import logging 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.optim.lr_scheduler as lr_scheduler 15 | import torch.nn.functional as F 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | from evaluate import evaluate 19 | 20 | 21 | def train(model, dataloaders, args): 22 | """ Trains a given model and dataset. 
23 | """ 24 | # optimisers 25 | if args.optimiser == 'adam': 26 | optimiser = optim.Adam(model.parameters(), lr=args.learning_rate, 27 | weight_decay=args.weight_decay) 28 | elif args.optimiser == 'sgd': 29 | optimiser = optim.SGD(model.parameters(), lr=args.learning_rate, 30 | weight_decay=args.weight_decay, momentum=0.9) 31 | else: 32 | raise NotImplementedError('{} not set up.'.format(args.optimiser)) 33 | 34 | # lr schedulers 35 | if args.decay_steps > 0: 36 | lr_decay = lr_scheduler.ExponentialLR(optimiser, gamma=args.decay_rate) 37 | elif args.decay_milestones[0] > 0: 38 | lr_decay = lr_scheduler.MultiStepLR(optimiser, milestones=args.decay_milestones, 39 | gamma=args.decay_rate) 40 | else: 41 | lr_decay = lr_scheduler.ReduceLROnPlateau(optimiser, factor=args.decay_rate, mode='max', 42 | patience=30, cooldown=20, min_lr=1e-6, verbose=True) 43 | 44 | args.writer = SummaryWriter(args.summaries_dir) 45 | best_valid_loss = np.inf 46 | best_valid_acc = 0 47 | patience_counter = 0 48 | epoch_valid_loss = None # initialised here so the first `is None` check never raises NameError 49 | since = time.time() 50 | for epoch in range(args.n_epochs): 51 | 52 | model.train() 53 | sample_count = 0 54 | running_loss = 0 55 | running_acc = 0 56 | 57 | if args.print_progress: 58 | logging.info('\nEpoch {}/{}:\n'.format(epoch+1, args.n_epochs)) 59 | # tqdm for process (rank) 0 only when using distributed training 60 | train_dataloader = tqdm(dataloaders['train']) 61 | else: 62 | train_dataloader = dataloaders['train'] 63 | 64 | for i, (inputs, labels) in enumerate(train_dataloader): 65 | args.step = (epoch * len(dataloaders['train'])) + i + 1 # calc current step 66 | 67 | if args.half_precision: 68 | inputs = inputs.type(torch.HalfTensor).cuda(non_blocking=True) 69 | else: 70 | inputs = inputs.type(torch.FloatTensor).cuda(non_blocking=True) 71 | labels = labels.cuda(non_blocking=True) 72 | 73 | optimiser.zero_grad() 74 | yhat = model(inputs) 75 | loss = F.nll_loss(F.log_softmax(yhat, 
dim=-1), labels) # equivalent to F.cross_entropy(yhat, labels) 76 | loss.backward() 77 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
78 | optimiser.step() 79 | 80 | sample_count += inputs.size(0) 81 | running_loss += loss.item() * inputs.size(0) # loss is a batch mean; scale to a sum so the epoch average is per sample 82 | running_acc += (yhat.argmax(-1) == labels).sum().item() # number of correct predictions 83 | 84 | if args.print_progress: 85 | # inspect gradient L2 norm 86 | total_norm = torch.zeros(1).cuda() 87 | for name, param in model.named_parameters(): 88 | try: 89 | total_norm += param.grad.data.norm(2)**2 90 | except AttributeError: # param.grad is None (parameter received no gradient) 91 | pass 92 | total_norm = total_norm**(1/2) 93 | args.writer.add_scalar('grad_L2_norm', total_norm, args.step) 94 | 95 | epoch_train_loss = running_loss / sample_count 96 | epoch_train_acc = running_acc / sample_count 97 | 98 | # reduce lr 99 | if args.decay_steps > 0 or args.decay_milestones[0] > 0: 100 | lr_decay.step() 101 | else: # reduce on plateau, evaluate to keep track of acc in each process 102 | epoch_valid_loss, epoch_valid_acc = evaluate(model, dataloaders['valid'], args) 103 | lr_decay.step(epoch_valid_acc[0]) 104 | 105 | if args.print_progress: # only validate using process 0 106 | if epoch_valid_loss is None: # check if process 0 already validated 107 | epoch_valid_loss, epoch_valid_acc = evaluate(model, dataloaders['valid'], args) 108 | 109 | logging.info('\n[Train] loss: {:.4f} - acc: {:.4f} | [Valid] loss: {:.4f} - acc: {:.4f} - acc_topk: {:.4f}'.format( 110 | epoch_train_loss, epoch_train_acc, 111 | epoch_valid_loss, epoch_valid_acc[0], epoch_valid_acc[1])) 112 | 113 | epoch_valid_acc = epoch_valid_acc[0] # discard top k acc 114 | args.writer.add_scalars('epoch_loss', {'train': epoch_train_loss, 115 | 'valid': epoch_valid_loss}, epoch+1) 116 | args.writer.add_scalars('epoch_acc', {'train': epoch_train_acc, 117 | 'valid': epoch_valid_acc}, epoch+1) 118 | args.writer.add_scalars('epoch_error', {'train': 1-epoch_train_acc, 119 | 'valid': 1-epoch_valid_acc}, epoch+1) 120 | 121 | # # inspect weights and gradients 122 | # for 
name, p in model.named_parameters(): 123 | # try: 124 | # name =
'.'.join(name.split('.')[1:]) \ 125 | # if name.split('.')[0] == 'module' else name 126 | # args.writer.add_histogram(name, p.data.clone().cpu().numpy(), args.step) 127 | # args.writer.add_histogram(name + '/grad', p.grad.clone().cpu().numpy(), args.step) 128 | # except AttributeError: 129 | # pass 130 | 131 | # save model and early stopping 132 | if epoch_valid_acc >= best_valid_acc: 133 | patience_counter = 0 134 | best_epoch = epoch + 1 135 | best_valid_acc = epoch_valid_acc 136 | best_valid_loss = epoch_valid_loss 137 | # saving using process (rank) 0 only as all processes are in sync 138 | torch.save(model.state_dict(), args.checkpoint_dir) 139 | else: 140 | patience_counter += 1 141 | if patience_counter == (args.patience-10): 142 | logging.info('\nPatience counter {}/{}.'.format( 143 | patience_counter, args.patience)) 144 | elif patience_counter == args.patience: 145 | logging.info('\nEarly stopping... no improvement after {} epochs.'.format( 146 | args.patience)) 147 | break 148 | epoch_valid_loss = None # reset loss 149 | 150 | gc.collect() # release unreferenced memory 151 | 152 | if args.print_progress: 153 | time_elapsed = time.time() - since 154 | logging.info('\nTraining time: {:.0f}m {:.0f}s'.format( 155 | time_elapsed // 60, time_elapsed % 60)) 156 | 157 | model.load_state_dict(torch.load(args.checkpoint_dir)) # load best model 158 | 159 | test_loss, test_acc = evaluate(model, dataloaders['test'], args) 160 | 161 | logging.info('\nBest [Valid] | epoch: {} - loss: {:.4f} - acc: {:.4f}'.format( 162 | best_epoch, best_valid_loss, best_valid_acc)) 163 | logging.info('[Test] loss: {:.4f} - acc: {:.4f} - acc_topk: {:.4f}'.format( 164 | test_loss, test_acc[0], test_acc[1])) 165 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | A few useful utilities.
4 | """ 5 | import os 6 | import logging 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class CustomDataset(Dataset): 14 | """ Creates a custom pytorch dataset, mainly 15 | used for creating validation set splits. 16 | """ 17 | 18 | def __init__(self, data, labels, transform): 19 | # shuffle the dataset 20 | idx = np.random.permutation(data.shape[0]) 21 | if isinstance(data, torch.Tensor): 22 | data = data.numpy() # to work with `ToPILImage' 23 | self.data = data[idx] 24 | self.labels = labels[idx] 25 | self.transform = transform 26 | 27 | def __len__(self): 28 | return self.data.shape[0] 29 | 30 | def __getitem__(self, idx): 31 | return self.transform(self.data[idx]), self.labels[idx] 32 | 33 | 34 | def random_split(data, labels, n_classes, n_samples_per_class): 35 | """ Creates a class-balanced validation set from a training set. 36 | """ 37 | train_x, train_y, valid_x, valid_y = [], [], [], [] 38 | 39 | if isinstance(labels, list): 40 | labels = np.array(labels) 41 | 42 | for i in range(n_classes): 43 | # get indices of all class 'c' samples 44 | c_idx = (np.array(labels) == i).nonzero()[0] 45 | # get n unique class 'c' samples 46 | valid_samples = np.random.choice(c_idx, n_samples_per_class[i], replace=False) 47 | # get remaining samples of class 'c' 48 | train_samples = np.setdiff1d(c_idx, valid_samples) 49 | # assign class c samples to validation, and remaining to training 50 | train_x.extend(data[train_samples]) 51 | train_y.extend(labels[train_samples]) 52 | valid_x.extend(data[valid_samples]) 53 | valid_y.extend(labels[valid_samples]) 54 | 55 | if isinstance(data, torch.Tensor): 56 | # torch.stack transforms list of tensors to tensor 57 | return {'train': torch.stack(train_x), 'valid': torch.stack(valid_x)}, \ 58 | {'train': torch.stack(train_y), 'valid': torch.stack(valid_y)} 59 | # transforms list of np arrays to tensor 60 | return {'train': torch.from_numpy(np.stack(train_x)), 61 | 'valid': 
torch.from_numpy(np.stack(valid_x))}, \ 62 | {'train': torch.from_numpy(np.stack(train_y)), 63 | 'valid': torch.from_numpy(np.stack(valid_y))} 64 | 65 | 66 | def sample_weights(labels): 67 | """ Calculates per-sample weights: 1 / (size of the sample's class), assuming labels 0..K-1 with every class present. """ 68 | class_sample_count = np.unique(labels, return_counts=True)[1] 69 | class_weights = 1. / torch.Tensor(class_sample_count) 70 | return class_weights[list(map(int, labels))] 71 | 72 | 73 | class Standardize(object): 74 | """ Standardizes a (C, H, W) image tensor such that each channel 75 | gets zero mean and unit variance. """ 76 | 77 | def __call__(self, img): 78 | return (img - img.mean(dim=(1, 2), keepdim=True)) \ 79 | / torch.clamp(img.std(dim=(1, 2), keepdim=True), min=1e-8) 80 | 81 | def __repr__(self): 82 | return self.__class__.__name__ + '()' 83 | 84 | 85 | def experiment_config(args): 86 | """ Handles experiment configuration and creates new dirs for model. 87 | """ 88 | # check number of models already saved in 'experiments' dir, add 1 to get new model number 89 | experiments_dir = os.path.join(os.path.split(os.getcwd())[0], 'experiments') 90 | os.makedirs(experiments_dir, exist_ok=True) 91 | model_num = len(os.listdir(experiments_dir)) + 1 92 | 93 | # create all save dirs 94 | model_dir = os.path.join(os.path.split(os.getcwd())[0], 95 | 'experiments', 'Model_'+str(model_num)) 96 | args.summaries_dir = os.path.join(model_dir, 'summaries') 97 | args.checkpoint_dir = os.path.join(model_dir, 'checkpoint.pt') 98 | os.makedirs(model_dir, exist_ok=True) 99 | os.makedirs(args.summaries_dir, exist_ok=True) 100 | 101 | # save hyperparameters in .txt file 102 | with open(os.path.join(model_dir, 'hyperparams.txt'), 'w') as logs: 103 | for key, value in vars(args).items(): 104 | logs.write('--{0}={1} '.format(str(key), str(value))) 105 | 106 | # reset root logger 107 | for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) 108 | # info logger for saving command line 
outputs during training
logging.basicConfig(level=logging.INFO, format='%(message)s', 110 | handlers=[logging.FileHandler(os.path.join(model_dir, 'trainlogs.txt')), 111 | logging.StreamHandler()]) 112 | return args 113 | 114 | 115 | def print_network(model, args): 116 | """ Utility for printing out a model's architecture. 117 | """ 118 | logging.info('-'*70) # print some info on architecture 119 | logging.info('{:>25} {:>27} {:>15}'.format('Layer.Parameter', 'Shape', 'Param#')) 120 | logging.info('-'*70) 121 | 122 | for param in model.state_dict(): 123 | p_name = param.split('.')[-2]+'.'+param.split('.')[-1] 124 | # skip batch norm layers to keep the table readable 125 | if p_name[:2] != 'BN' and p_name[:2] != 'bn': 126 | logging.info( 127 | '{:>25} {:>27} {:>15}'.format( 128 | p_name, 129 | str(list(model.state_dict()[param].squeeze().size())), 130 | '{0:,}'.format(model.state_dict()[param].numel()) # np.product was removed in NumPy 2.0 131 | ) 132 | ) 133 | logging.info('-'*70) 134 | 135 | logging.info('\nTotal params: {:,}\n\nSummaries dir: {}\n'.format( 136 | sum(p.numel() for p in model.parameters()), 137 | args.summaries_dir)) 138 | 139 | for key, value in vars(args).items(): 140 | if str(key) != 'print_progress': 141 | logging.info('--{0}: {1}'.format(str(key), str(value))) 142 | --------------------------------------------------------------------------------
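The parameter table in the `resnet.py` docstring can be reproduced by hand. The sketch below assumes option 'A' shortcuts (parameter-free padding, which the docstring says the ResNet paper uses for CIFAR10) and `ReZero=False`. The helper names (`conv3x3`, `batchnorm`, `resnet_cifar_params`) are hypothetical. BatchNorm running statistics are buffers rather than parameters, so only each BatchNorm's affine weight and bias are counted.

```python
# Sketch: count trainable parameters of the 6n+2 layer CIFAR ResNet in src/resnet.py,
# assuming option 'A' shortcuts (no parameters) and ReZero=False.

def conv3x3(c_in, c_out):
    return c_in * c_out * 3 * 3  # nn.Conv2d(..., bias=False)

def batchnorm(c):
    return 2 * c  # affine weight + bias; running stats are buffers

def resnet_cifar_params(n, in_channels=3, num_classes=10):
    total = conv3x3(in_channels, 16) + batchnorm(16)  # conv1 + bn1 stem
    c_in = 16
    for c in (16, 32, 64):       # layer1, layer2, layer3
        for _ in range(n):       # n BasicBlocks per stage
            total += conv3x3(c_in, c) + batchnorm(c)  # conv1 + bn1
            total += conv3x3(c, c) + batchnorm(c)     # conv2 + bn2
            c_in = c
    total += 64 * num_classes + num_classes  # final nn.Linear(64, num_classes)
    return total

for name, n in [("ResNet20", 3), ("ResNet32", 5), ("ResNet44", 7),
                ("ResNet56", 9), ("ResNet110", 18), ("ResNet1202", 200)]:
    print(f"{name:>10}: {resnet_cifar_params(n):,}")
```

This prints 269,722 for ResNet20 and 19,421,274 for ResNet1202, matching the 0.27M and 19.4M rows. With the file's own defaults (`option='B'`, `ReZero=True`), ResNet20 additionally gets two 1x1 projection shortcuts with BatchNorm (16*32 + 2*32 + 32*64 + 2*64 = 2,752 parameters) and one `alpha_i` scalar per block.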
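Option 'A' in `BasicBlock` matches shapes without adding parameters. It subsamples the input spatially with `x[:, :, ::2, ::2]` and zero-pads the channel dimension by `planes//4` on each side. `F.pad` reads its pad tuple from the last dimension backwards, so `(0, 0, 0, 0, planes//4, planes//4)` leaves W and H alone and pads C. A shape-only sketch in plain Python (the helper name is hypothetical):

```python
import math

def option_a_shortcut_shape(shape, planes):
    """Output shape of the option-A shortcut for an (N, C, H, W) input."""
    n, c, h, w = shape
    pad = planes // 4  # zero channels added before and after the existing ones
    # x[:, :, ::2, ::2] keeps every other row and column: ceil(H/2) x ceil(W/2)
    return (n, c + 2 * pad, math.ceil(h / 2), math.ceil(w / 2))

print(option_a_shortcut_shape((128, 16, 32, 32), planes=32))  # (128, 32, 16, 16)
print(option_a_shortcut_shape((128, 32, 16, 16), planes=64))  # (128, 64, 8, 8)
```

The result only lines up with the residual branch when `planes == 2 * in_planes` and the stride is 2, which is how `ResNet._make_layer` uses it at the start of `layer2` and `layer3`.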
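`random_split` in `utils.py` builds a class-balanced validation set. For each class `c` it draws `n_samples_per_class[c]` indices without replacement and leaves the rest for training. Below is the same idea on plain index lists, as a dependency-free sketch. The function name, the `seed` argument and the index-list return format are assumptions; the original returns stacked tensors.

```python
import random

def class_balanced_split(labels, n_samples_per_class, seed=0):
    """Return (train_idx, valid_idx) with n_samples_per_class[c] validation samples of class c."""
    rng = random.Random(seed)
    train_idx, valid_idx = [], []
    for c in range(len(n_samples_per_class)):  # classes 0..K-1, as in random_split
        c_idx = [i for i, y in enumerate(labels) if y == c]
        chosen = set(rng.sample(c_idx, n_samples_per_class[c]))  # without replacement
        valid_idx.extend(i for i in c_idx if i in chosen)
        train_idx.extend(i for i in c_idx if i not in chosen)
    return train_idx, valid_idx

labels = [0] * 10 + [1] * 6 + [2] * 4
train_idx, valid_idx = class_balanced_split(labels, [2, 2, 2])
```

Every class contributes exactly the requested number of validation samples, however imbalanced the training set is.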
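`sample_weights` in `utils.py` gives every sample the weight `1 / count(its class)`, so the weights of each class sum to 1. A common consumer for such weights is `torch.utils.data.WeightedRandomSampler`, which then draws each class equally often in expectation. The source does not show where the weights are used, so that pairing is an assumption. Here is a plain-Python sketch of the weighting (`inverse_frequency_weights` is a hypothetical name):

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Per-sample weight = 1 / number of samples with the same label."""
    counts = Counter(int(y) for y in labels)
    return [1.0 / counts[int(y)] for y in labels]

labels = [0, 0, 0, 1, 2, 2]
weights = inverse_frequency_weights(labels)
# class totals: 3 * (1/3), 1 * 1.0, 2 * 0.5, i.e. each class sums to 1
```

Keying the counts by label avoids an implicit assumption in the original, which indexes the `np.unique` counts by label value. That indexing is only correct when the labels are exactly `0..K-1` with every class present.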
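The checkpoint and early-stopping rule in `train.py` keeps the epoch with the highest validation accuracy. Because the comparison is `>=`, ties count as an improvement. The rule resets `patience_counter` on every improvement and stops once the counter reaches `args.patience`. Below is a framework-free sketch of just that bookkeeping (`early_stop` and its inputs are hypothetical):

```python
def early_stop(valid_accs, patience):
    """Return (best_epoch, last_epoch_run) for a sequence of validation accuracies."""
    best_acc, best_epoch, counter = 0.0, None, 0  # train.py also starts from best_valid_acc = 0
    for epoch, acc in enumerate(valid_accs, start=1):
        if acc >= best_acc:  # same comparison as train.py; train.py saves a checkpoint here
            best_acc, best_epoch, counter = acc, epoch, 0
        else:
            counter += 1
            if counter == patience:  # no improvement for `patience` epochs
                return best_epoch, epoch
    return best_epoch, len(valid_accs)

print(early_stop([0.50, 0.60, 0.60, 0.55, 0.58, 0.59, 0.70], patience=3))  # (3, 6)
```

Epoch 3 wins the tie with epoch 2. Training stops at epoch 6 and never sees the 0.70 at epoch 7, which is the usual trade-off of a small patience value.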