├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── config.json ├── data_loader └── data_loaders.py ├── logger ├── __init__.py ├── logger.py ├── logger_config.json └── visualization.py ├── model ├── loss.py ├── metric.py └── model.py ├── new_project.py ├── parse_config.py ├── test.py ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py └── util.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = F401, F403 3 | max-line-length = 120 4 | exclude = 5 | .git, 6 | __pycache__, 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints 104 | data/ 105 | input/ 106 | saved/ 107 | datasets/ 108 | 109 | # editor, os cache directory 110 | .vscode/ 111 | .idea/ 112 | __MACOSX/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Victor Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Template Project 2 | PyTorch deep learning project made easy. 3 | 4 | 5 | 6 | 7 | 8 | * [PyTorch Template Project](#pytorch-template-project) 9 | * [Requirements](#requirements) 10 | * [Features](#features) 11 | * [Folder Structure](#folder-structure) 12 | * [Usage](#usage) 13 | * [Config file format](#config-file-format) 14 | * [Using config files](#using-config-files) 15 | * [Resuming from checkpoints](#resuming-from-checkpoints) 16 | * [Using Multiple GPU](#using-multiple-gpu) 17 | * [Customization](#customization) 18 | * [Custom CLI options](#custom-cli-options) 19 | * [Data Loader](#data-loader) 20 | * [Trainer](#trainer) 21 | * [Model](#model) 22 | * [Loss and metrics](#loss-and-metrics) 23 | * [Multiple metrics](#multiple-metrics) 24 | * [Additional logging](#additional-logging) 25 | * [Validation data](#validation-data) 26 | * [Checkpoints](#checkpoints) 27 | * [Tensorboard Visualization](#tensorboard-visualization) 28 | * [Contribution](#contribution) 29 | * [TODOs](#todos) 30 | * [License](#license) 31 | * [Acknowledgements](#acknowledgements) 32 | 33 | 34 | 35 | ## Requirements 36 | * Python >= 3.5 (3.6 recommended) 37 | * PyTorch >= 0.4 (1.2 recommended) 38 | * tqdm (Optional for `test.py`) 39 | * tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization)) 40 | 41 | ## Features 42 | * Clear folder structure which is suitable for many deep learning projects. 43 | * `.json` config file support for convenient parameter tuning. 44 | * Customizable command line options for more convenient parameter tuning. 45 | * Checkpoint saving and resuming. 46 | * Abstract base classes for faster development: 47 | * `BaseTrainer` handles checkpoint saving/resuming, training process logging, and more. 48 | * `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting. 49 | * `BaseModel` provides basic model summary.
50 | 51 | ## Folder Structure 52 | ``` 53 | pytorch-template/ 54 | │ 55 | ├── train.py - main script to start training 56 | ├── test.py - evaluation of trained model 57 | │ 58 | ├── config.json - holds configuration for training 59 | ├── parse_config.py - class to handle config file and cli options 60 | │ 61 | ├── new_project.py - initialize new project with template files 62 | │ 63 | ├── base/ - abstract base classes 64 | │ ├── base_data_loader.py 65 | │ ├── base_model.py 66 | │ └── base_trainer.py 67 | │ 68 | ├── data_loader/ - anything about data loading goes here 69 | │ └── data_loaders.py 70 | │ 71 | ├── data/ - default directory for storing input data 72 | │ 73 | ├── model/ - models, losses, and metrics 74 | │ ├── model.py 75 | │ ├── metric.py 76 | │ └── loss.py 77 | │ 78 | ├── saved/ 79 | │ ├── models/ - trained models are saved here 80 | │ └── log/ - default logdir for tensorboard and logging output 81 | │ 82 | ├── trainer/ - trainers 83 | │ └── trainer.py 84 | │ 85 | ├── logger/ - module for tensorboard visualization and logging 86 | │ ├── visualization.py 87 | │ ├── logger.py 88 | │ └── logger_config.json 89 | │ 90 | └── utils/ - small utility functions 91 | ├── util.py 92 | └── ... 93 | ``` 94 | 95 | ## Usage 96 | The code in this repo is an MNIST example of the template. 97 | Try `python train.py -c config.json` to run the code. 98 | 99 | ### Config file format 100 | Config files are in `.json` format: 101 | ```javascript 102 | { 103 | "name": "Mnist_LeNet", // training session name 104 | "n_gpu": 1, // number of GPUs to use for training. 105 | 106 | "arch": { 107 | "type": "MnistModel", // name of model architecture to train 108 | "args": { 109 | 110 | } 111 | }, 112 | "data_loader": { 113 | "type": "MnistDataLoader", // selecting data loader 114 | "args":{ 115 | "data_dir": "data/", // dataset path 116 | "batch_size": 64, // batch size 117 | "shuffle": true, // shuffle training data before splitting 118 | "validation_split": 0.1, // size of validation dataset. float(portion) or int(number of samples) 119 | "num_workers": 2 // number of cpu processes to be used for data loading 120 | } 121 | }, 122 | "optimizer": { 123 | "type": "Adam", 124 | "args":{ 125 | "lr": 0.001, // learning rate 126 | "weight_decay": 0, // (optional) weight decay 127 | "amsgrad": true 128 | } 129 | }, 130 | "loss": "nll_loss", // loss 131 | "metrics": [ 132 | "accuracy", "top_k_acc" // list of metrics to evaluate 133 | ], 134 | "lr_scheduler": { 135 | "type": "StepLR", // learning rate scheduler 136 | "args":{ 137 | "step_size": 50, 138 | "gamma": 0.1 139 | } 140 | }, 141 | "trainer": { 142 | "epochs": 100, // number of training epochs 143 | "save_dir": "saved/", // checkpoints are saved in save_dir/models/name 144 | "save_period": 1, // save checkpoints every save_period epochs 145 | "verbosity": 2, // 0: quiet, 1: per epoch, 2: full 146 | 147 | "monitor": "min val_loss", // mode and metric for model performance monitoring. set 'off' to disable. 148 | "early_stop": 10, // number of epochs to wait before early stop. set 0 to disable. 149 | 150 | "tensorboard": true // enable tensorboard visualization 151 | } 152 | } 153 | ``` 154 | 155 | Add additional configurations as needed.
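Each `"type"`/`"args"` pair in the config is resolved at runtime by `ConfigParser.init_obj` (see `parse_config.py`), which looks the class up by name in the given module and instantiates it with the listed arguments. Below is a minimal, self-contained sketch of that mechanism, not the template's exact code; the dummy parameter list merely stands in for `model.parameters()`:

```python
import torch
import torch.optim as optim

# excerpt of a parsed config, as it would be read from config.json
config = {
    "optimizer": {"type": "Adam", "args": {"lr": 0.001, "weight_decay": 0, "amsgrad": True}}
}

def init_obj(cfg, name, module, *args, **kwargs):
    """Look up the class named in 'type' inside `module` and build it with 'args'."""
    entry = cfg[name]
    return getattr(module, entry["type"])(*args, **entry["args"], **kwargs)

# train.py passes model.parameters() here; a dummy parameter keeps the sketch runnable
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = init_obj(config, "optimizer", optim, params)
print(type(optimizer).__name__)  # -> Adam, built entirely from the config entry
```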
156 | 157 | ### Using config files 158 | Modify the configurations in `.json` config files, then run: 159 | 160 | ``` 161 | python train.py --config config.json 162 | ``` 163 | 164 | ### Resuming from checkpoints 165 | You can resume from a previously saved checkpoint by: 166 | 167 | ``` 168 | python train.py --resume path/to/checkpoint 169 | ``` 170 | 171 | ### Using Multiple GPU 172 | You can enable multi-GPU training by setting the `n_gpu` argument of the config file to a number larger than one. 173 | If configured to use fewer GPUs than are available, the first n devices will be used by default. 174 | Specify the indices of available GPUs via the CUDA environment variable. 175 | ``` 176 | python train.py --device 2,3 -c config.json 177 | ``` 178 | This is equivalent to 179 | ``` 180 | CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.json 181 | ``` 182 | 183 | ## Customization 184 | 185 | ### Project initialization 186 | Use the `new_project.py` script to make your new project directory with template files. 187 | Run `python new_project.py ../NewProject` and a new project folder named 'NewProject' will be created. 188 | This script will filter out unnecessary files like caches, git files, and the readme. 189 | 190 | ### Custom CLI options 191 | 192 | Changing values in the config file is a clean, safe and easy way of tuning hyperparameters. However, sometimes 193 | it is better to have command line options if some values need to be changed too often or quickly. 194 | 195 | This template uses the configurations stored in the json file by default, but by registering custom options as follows 196 | you can change some of them using CLI flags. 197 | 198 | ```python 199 | # simple class-like object having 3 attributes, `flags`, `type`, `target`. 200 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 201 | options = [ 202 | CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')), 203 | CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')) 204 | # options added here can be modified by command line flags. 205 | ] 206 | ``` 207 | The `target` argument should be a sequence of keys, which are used to access that option in the config dict. In this example, `target` 208 | for the learning rate option is `('optimizer', 'args', 'lr')` because `config['optimizer']['args']['lr']` points to the learning rate. 209 | `python train.py -c config.json --bs 256` runs training with the options given in `config.json`, except for the `batch size`, 210 | which is increased to 256 by the command line option. 211 | 212 | 213 | ### Data Loader 214 | * **Writing your own data loader** 215 | 216 | 1. **Inherit ```BaseDataLoader```** 217 | 218 | `BaseDataLoader` is a subclass of `torch.utils.data.DataLoader`; you can use either of them. 219 | 220 | `BaseDataLoader` handles: 221 | * Generating next batch 222 | * Data shuffling 223 | * Generating validation data loader by calling 224 | `BaseDataLoader.split_validation()` 225 | 226 | * **DataLoader Usage** 227 | 228 | `BaseDataLoader` is an iterator; to iterate through batches: 229 | ```python 230 | for batch_idx, (x_batch, y_batch) in enumerate(data_loader): 231 | pass 232 | ``` 233 | * **Example** 234 | 235 | Please refer to `data_loader/data_loaders.py` for an MNIST data loading example. 236 | 237 | ### Trainer 238 | * **Writing your own trainer** 239 | 240 | 1. 
**Inherit ```BaseTrainer```** 241 | 242 | `BaseTrainer` handles: 243 | * Training process logging 244 | * Checkpoint saving 245 | * Checkpoint resuming 246 | * Reconfigurable performance monitoring for saving the current best model, and early stopping of training. 247 | * If config `monitor` is set to `max val_accuracy`, the trainer will save a checkpoint `model_best.pth` whenever the epoch's `validation accuracy` exceeds the current `maximum`. 248 | * If config `early_stop` is set, training will be automatically terminated when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or by deleting that line from the config. 249 | 250 | 2. **Implementing abstract methods** 251 | 252 | You need to implement `_train_epoch()` for your training process; if you need validation, implement `_valid_epoch()` as in `trainer/trainer.py`. 253 | 254 | * **Example** 255 | 256 | Please refer to `trainer/trainer.py` for MNIST training. 257 | 258 | * **Iteration-based training** 259 | 260 | `Trainer.__init__` takes an optional argument, `len_epoch`, which controls the number of batches (steps) in each epoch. 261 | 262 | ### Model 263 | * **Writing your own model** 264 | 265 | 1. **Inherit `BaseModel`** 266 | 267 | `BaseModel` handles: 268 | * Inheriting from `torch.nn.Module` 269 | * `__str__`: modifies the native `print` output to include the number of trainable parameters. 270 | 271 | 2. **Implementing abstract methods** 272 | 273 | Implement the forward pass method `forward()`. 274 | 275 | * **Example** 276 | 277 | Please refer to `model/model.py` for a LeNet example. 278 | 279 | ### Loss 280 | Custom loss functions can be implemented in 'model/loss.py'. Use them by setting the "loss" entry in the config file to the corresponding function name. 281 | 282 | #### Metrics 283 | Metric functions are located in 'model/metric.py'. 284 | 285 | You can monitor multiple metrics by providing a list in the configuration file, e.g.: 286 | ```json 287 | "metrics": ["accuracy", "top_k_acc"], 288 | ``` 289 | 290 | ### Additional logging 291 | If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class as shown below before returning: 292 | 293 | ```python 294 | additional_log = {"gradient_norm": g, "sensitivity": s} 295 | log.update(additional_log) 296 | return log 297 | ``` 298 | 299 | ### Testing 300 | You can test a trained model by running `test.py`, passing the path to the trained checkpoint with the `--resume` argument. 301 | 302 | ### Validation data 303 | To split validation data from a data loader, call `BaseDataLoader.split_validation()`; it will return a validation data loader whose size is specified in your config file. 304 | The `validation_split` can be a ratio of the validation set to the total data (0.0 <= float < 1.0), or the number of samples (0 <= int < `n_total_samples`). 305 | 306 | **Note**: the `split_validation()` method will modify the original data loader. 307 | **Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0`. 308 | 309 | ### Checkpoints 310 | You can specify the name of the training session in config files: 311 | ```json 312 | "name": "MNIST_LeNet", 313 | ``` 314 | 315 | The checkpoints will be saved in `save_dir/models/name/timestamp/checkpoint-epoch{n}.pth`, with the timestamp in mmdd_HHMMSS format. 316 | 317 | A copy of the config file will be saved in the same folder.
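Saved checkpoints are ordinary `torch.save` dictionaries (the note below lists their keys), so they can also be inspected or loaded manually outside the template. A minimal sketch, with an illustrative checkpoint path standing in for one of your own runs:

```python
import torch

# path is illustrative; substitute a checkpoint produced by one of your runs
ckpt_path = 'saved/models/Mnist_LeNet/0101_120000/checkpoint-epoch10.pth'
checkpoint = torch.load(ckpt_path, map_location='cpu')

print(checkpoint['arch'])          # architecture class name, e.g. 'MnistModel'
print(checkpoint['epoch'])         # epoch at which the checkpoint was written
print(checkpoint['monitor_best'])  # best monitored metric value so far

# restore the weights into a freshly constructed model of the same architecture
# model = MnistModel()
# model.load_state_dict(checkpoint['state_dict'])
```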
318 | 319 | **Note**: checkpoints contain: 320 | ```python 321 | { 322 | 'arch': arch, 323 | 'epoch': epoch, 324 | 'state_dict': self.model.state_dict(), 325 | 'optimizer': self.optimizer.state_dict(), 326 | 'monitor_best': self.mnt_best, 327 | 'config': self.config 328 | } 329 | ``` 330 | 331 | ### Tensorboard Visualization 332 | This template supports Tensorboard visualization by using either `torch.utils.tensorboard` or [TensorboardX](https://github.com/lanpa/tensorboardX). 333 | 334 | 1. **Install** 335 | 336 | If you are using pytorch 1.1 or higher, install tensorboard by 'pip install tensorboard>=1.14.0'. 337 | 338 | Otherwise, you should install tensorboardx. Follow installation guide in [TensorboardX](https://github.com/lanpa/tensorboardX). 339 | 340 | 2. **Run training** 341 | 342 | Make sure that `tensorboard` option in the config file is turned on. 343 | 344 | ``` 345 | "tensorboard" : true 346 | ``` 347 | 348 | 3. **Open Tensorboard server** 349 | 350 | Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006` 351 | 352 | By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged. 353 | If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method. 354 | `add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules. 355 | 356 | **Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps. 357 | 358 | ## Contribution 359 | Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8 360 | 361 | Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing. 362 | 363 | ## TODOs 364 | 365 | - [ ] Multiple optimizers 366 | - [ ] Support more tensorboard functions 367 | - [x] Using fixed random seed 368 | - [x] Support pytorch native tensorboard 369 | - [x] `tensorboardX` logger support 370 | - [x] Configurable logging layout, checkpoint naming 371 | - [x] Iteration-based training (instead of epoch-based) 372 | - [x] Adding command line option for fine-tuning 373 | 374 | ## License 375 | This project is licensed under the MIT License. 
See LICENSE for more details 376 | 377 | ## Acknowledgements 378 | This project is inspired by the project [Tensorflow-Project-Template](https://github.com/MrGemy95/Tensorflow-Project-Template) by [Mahmoud Gemy](https://github.com/MrGemy95) 379 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 
41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | 6 | 7 | class BaseTrainer: 8 | """ 9 | Base class for all trainers 10 | """ 11 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 12 | self.config = config 13 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 14 | 15 | # setup GPU device if available, move model into configured device 16 | self.device, device_ids = self._prepare_device(config['n_gpu']) 17 | self.model = model.to(self.device) 18 | if len(device_ids) > 1: 19 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 20 | 21 | self.criterion = criterion 22 | self.metric_ftns = metric_ftns 23 | self.optimizer = optimizer 24 | 25 | cfg_trainer = config['trainer'] 26 | self.epochs = cfg_trainer['epochs'] 27 | self.save_period = cfg_trainer['save_period'] 28 | self.monitor = cfg_trainer.get('monitor', 'off') 29 | 30 | # configuration to monitor model performance and save best 31 | if self.monitor == 'off': 32 | self.mnt_mode = 'off' 33 | self.mnt_best = 0 34 | else: 35 | self.mnt_mode, self.mnt_metric = self.monitor.split() 36 | assert self.mnt_mode in ['min', 'max'] 37 | 38 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 39 | self.early_stop = cfg_trainer.get('early_stop', inf) 40 | 41 | self.start_epoch = 1 42 | 43 | self.checkpoint_dir = config.save_dir 44 | 45 | # setup visualization writer instance 46 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 47 | 48 | if config.resume is not None: 49 | self._resume_checkpoint(config.resume) 50 | 51 | @abstractmethod 52 | def _train_epoch(self, epoch): 53 | """ 54 | Training logic for an epoch 55 | 56 | :param epoch: Current epoch number 57 | """ 58 | raise NotImplementedError 59 | 60 | def train(self): 61 | """ 
62 | Full training logic 63 | """ 64 | not_improved_count = 0 65 | for epoch in range(self.start_epoch, self.epochs + 1): 66 | result = self._train_epoch(epoch) 67 | 68 | # save logged informations into log dict 69 | log = {'epoch': epoch} 70 | log.update(result) 71 | 72 | # print logged informations to the screen 73 | for key, value in log.items(): 74 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 75 | 76 | # evaluate model performance according to configured metric, save best checkpoint as model_best 77 | best = False 78 | if self.mnt_mode != 'off': 79 | try: 80 | # check whether model performance improved or not, according to specified metric(mnt_metric) 81 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 82 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 83 | except KeyError: 84 | self.logger.warning("Warning: Metric '{}' is not found. " 85 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 86 | self.mnt_mode = 'off' 87 | improved = False 88 | 89 | if improved: 90 | self.mnt_best = log[self.mnt_metric] 91 | not_improved_count = 0 92 | best = True 93 | else: 94 | not_improved_count += 1 95 | 96 | if not_improved_count > self.early_stop: 97 | self.logger.info("Validation performance didn\'t improve for {} epochs. " 98 | "Training stops.".format(self.early_stop)) 99 | break 100 | 101 | if epoch % self.save_period == 0: 102 | self._save_checkpoint(epoch, save_best=best) 103 | 104 | def _prepare_device(self, n_gpu_use): 105 | """ 106 | setup GPU device if available, move model into configured device 107 | """ 108 | n_gpu = torch.cuda.device_count() 109 | if n_gpu_use > 0 and n_gpu == 0: 110 | self.logger.warning("Warning: There\'s no GPU available on this machine," 111 | "training will be performed on CPU.") 112 | n_gpu_use = 0 113 | if n_gpu_use > n_gpu: 114 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 115 | "on this machine.".format(n_gpu_use, n_gpu)) 116 | n_gpu_use = n_gpu 117 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 118 | list_ids = list(range(n_gpu_use)) 119 | return device, list_ids 120 | 121 | def _save_checkpoint(self, epoch, save_best=False): 122 | """ 123 | Saving checkpoints 124 | 125 | :param epoch: current epoch number 126 | :param log: logging information of the epoch 127 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 128 | """ 129 | arch = type(self.model).__name__ 130 | state = { 131 | 'arch': arch, 132 | 'epoch': epoch, 133 | 'state_dict': self.model.state_dict(), 134 | 'optimizer': self.optimizer.state_dict(), 135 | 'monitor_best': self.mnt_best, 136 | 'config': self.config 137 | } 138 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 139 | torch.save(state, filename) 140 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 141 | if save_best: 142 | best_path = str(self.checkpoint_dir / 'model_best.pth') 143 | torch.save(state, best_path) 144 | self.logger.info("Saving current best: model_best.pth ...") 145 | 146 | def _resume_checkpoint(self, resume_path): 147 | """ 148 | Resume from saved checkpoints 149 | 150 | :param resume_path: Checkpoint path to be resumed 151 | """ 152 | resume_path = str(resume_path) 153 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 154 | checkpoint = torch.load(resume_path) 155 | self.start_epoch = checkpoint['epoch'] + 1 156 | self.mnt_best = checkpoint['monitor_best'] 157 | 158 | # 
load architecture params from checkpoint. 159 | if checkpoint['config']['arch'] != self.config['arch']: 160 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 161 | "checkpoint. This may yield an exception while state_dict is being loaded.") 162 | self.model.load_state_dict(checkpoint['state_dict']) 163 | 164 | # load optimizer state from checkpoint only when optimizer type is not changed. 165 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 166 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 167 | "Optimizer parameters not being resumed.") 168 | else: 169 | self.optimizer.load_state_dict(checkpoint['optimizer']) 170 | 171 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 172 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Mnist_LeNet", 3 | "n_gpu": 1, 4 | 5 | "arch": { 6 | "type": "MnistModel", 7 | "args": {} 8 | }, 9 | "data_loader": { 10 | "type": "MnistDataLoader", 11 | "args":{ 12 | "data_dir": "data/", 13 | "batch_size": 128, 14 | "shuffle": true, 15 | "validation_split": 0.1, 16 | "num_workers": 2 17 | } 18 | }, 19 | "optimizer": { 20 | "type": "Adam", 21 | "args":{ 22 | "lr": 0.001, 23 | "weight_decay": 0, 24 | "amsgrad": true 25 | } 26 | }, 27 | "loss": "nll_loss", 28 | "metrics": [ 29 | "accuracy", "top_k_acc" 30 | ], 31 | "lr_scheduler": { 32 | "type": "StepLR", 33 | "args": { 34 | "step_size": 50, 35 | "gamma": 0.1 36 | } 37 | }, 38 | "trainer": { 39 | "epochs": 100, 40 | 41 | "save_dir": "saved/", 42 | "save_period": 1, 43 | "verbosity": 2, 44 | 45 | "monitor": "min val_loss", 46 | "early_stop": 10, 47 | 48 | "tensorboard": true 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from base import BaseDataLoader 3 | 4 | 5 | class MnistDataLoader(BaseDataLoader): 6 | """ 7 | MNIST data loading demo using BaseDataLoader 8 | """ 9 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 10 | trsfm = transforms.Compose([ 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.1307,), (0.3081,)) 13 | ]) 14 | self.data_dir = data_dir 15 | self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm) 16 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 17 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = 
read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 27 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 28 | logger.warning(message) 29 | 30 | self.step = 0 31 | self.mode = '' 32 | 33 | self.tb_writer_ftns = { 34 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 35 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 36 | } 37 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 38 | self.timer = datetime.now() 39 | 40 | def set_step(self, step, mode='train'): 41 | self.mode = mode 42 | self.step = step 43 | if step == 0: 44 | self.timer = datetime.now() 45 | else: 46 | duration = datetime.now() - self.timer 47 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 48 | self.timer = datetime.now() 49 | 50 | def __getattr__(self, name): 51 | """ 52 | If visualization is configured to use: 53 | return add_data() methods of tensorboard with additional information (step, tag) added. 
54 | Otherwise: 55 | return a blank function handle that does nothing 56 | """ 57 | if name in self.tb_writer_ftns: 58 | add_data = getattr(self.writer, name, None) 59 | 60 | def wrapper(tag, data, *args, **kwargs): 61 | if add_data is not None: 62 | # add mode(train/valid) tag 63 | if name not in self.tag_mode_exceptions: 64 | tag = '{}/{}'.format(tag, self.mode) 65 | add_data(tag, data, self.step, *args, **kwargs) 66 | return wrapper 67 | else: 68 | # default action for returning methods defined in this class, set_step() for instance. 69 | try: 70 | attr = object.__getattr__(name) 71 | except AttributeError: 72 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 73 | return attr 74 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def nll_loss(output, target): 5 | return F.nll_loss(output, target) 6 | -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target): 5 | with torch.no_grad(): 6 | pred = torch.argmax(output, dim=1) 7 | assert pred.shape[0] == len(target) 8 | correct = 0 9 | correct += torch.sum(pred == target).item() 10 | return correct / len(target) 11 | 12 | 13 | def top_k_acc(output, target, k=3): 14 | with torch.no_grad(): 15 | pred = torch.topk(output, k, dim=1)[1] 16 | assert pred.shape[0] == len(target) 17 | correct = 0 18 | for i in range(k): 19 | correct += torch.sum(pred[:, i] == target).item() 20 | return correct / len(target) 21 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from base import BaseModel 4 | 5 | 6 | class MnistModel(BaseModel): 7 | def __init__(self, num_classes=10): 8 | super().__init__() 9 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 10 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 11 | self.conv2_drop = nn.Dropout2d() 12 | self.fc1 = nn.Linear(320, 50) 13 | self.fc2 = nn.Linear(50, num_classes) 14 | 15 | def forward(self, x): 16 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 17 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 18 | x = x.view(-1, 320) 19 | x = F.relu(self.fc1(x)) 20 | x = F.dropout(x, training=self.training) 21 | x = self.fc2(x) 22 | return F.log_softmax(x, dim=1) 23 | -------------------------------------------------------------------------------- /new_project.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from shutil import copytree, ignore_patterns 4 | 5 | 6 | # This script initializes new pytorch project with the template files. 7 | # Run `python3 new_project.py ../MyNewProject` then new project named 8 | # MyNewProject will be made 9 | current_dir = Path() 10 | assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory' 11 | assert len(sys.argv) == 2, 'Specify a name for the new project. 
Example: python3 new_project.py MyNewProject' 12 | 13 | project_name = Path(sys.argv[1]) 14 | target_dir = current_dir / project_name 15 | 16 | ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"] 17 | copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore)) 18 | print('New project initialized at', target_dir.absolute().resolve()) -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, run_id=None): 13 | """ 14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 15 | and logging module. 16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 19 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | 25 | # set save_dir where trained model and log will be saved. 26 | save_dir = Path(self.config['trainer']['save_dir']) 27 | 28 | exper_name = self.config['name'] 29 | if run_id is None: # use timestamp as default run-id 30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 31 | self._save_dir = save_dir / 'models' / exper_name / run_id 32 | self._log_dir = save_dir / 'log' / exper_name / run_id 33 | 34 | # make directory for saving checkpoints and log. 35 | exist_ok = run_id == '' 36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 38 | 39 | # save updated config file to the checkpoint dir 40 | write_json(self.config, self.save_dir / 'config.json') 41 | 42 | # configure logging module 43 | setup_logging(self.log_dir) 44 | self.log_levels = { 45 | 0: logging.WARNING, 46 | 1: logging.INFO, 47 | 2: logging.DEBUG 48 | } 49 | 50 | @classmethod 51 | def from_args(cls, args, options=''): 52 | """ 53 | Initialize this class from some cli arguments. Used in train, test. 54 | """ 55 | for opt in options: 56 | args.add_argument(*opt.flags, default=None, type=opt.type) 57 | if not isinstance(args, tuple): 58 | args = args.parse_args() 59 | 60 | if args.device is not None: 61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 62 | if args.resume is not None: 63 | resume = Path(args.resume) 64 | cfg_fname = resume.parent / 'config.json' 65 | else: 66 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 
67 | assert args.config is not None, msg_no_cfg 68 | resume = None 69 | cfg_fname = Path(args.config) 70 | 71 | config = read_json(cfg_fname) 72 | if args.config and resume: 73 | # update new config for fine-tuning 74 | config.update(read_json(args.config)) 75 | 76 | # parse custom cli options into dictionary 77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} 78 | return cls(config, resume, modification) 79 | 80 | def init_obj(self, name, module, *args, **kwargs): 81 | """ 82 | Finds a function handle with the name given as 'type' in config, and returns the 83 | instance initialized with corresponding arguments given. 84 | 85 | `object = config.init_obj('name', module, a, b=1)` 86 | is equivalent to 87 | `object = module.name(a, b=1)` 88 | """ 89 | module_name = self[name]['type'] 90 | module_args = dict(self[name]['args']) 91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 92 | module_args.update(kwargs) 93 | return getattr(module, module_name)(*args, **module_args) 94 | 95 | def init_ftn(self, name, module, *args, **kwargs): 96 | """ 97 | Finds a function handle with the name given as 'type' in config, and returns the 98 | function with given arguments fixed with functools.partial. 99 | 100 | `function = config.init_ftn('name', module, a, b=1)` 101 | is equivalent to 102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 103 | """ 104 | module_name = self[name]['type'] 105 | module_args = dict(self[name]['args']) 106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 107 | module_args.update(kwargs) 108 | return partial(getattr(module, module_name), *args, **module_args) 109 | 110 | def __getitem__(self, name): 111 | """Access items like ordinary dict.""" 112 | return self.config[name] 113 | 114 | def get_logger(self, name, verbosity=2): 115 | msg_verbosity = 'verbosity option {} is invalid. 
Valid options are {}.'.format(verbosity, self.log_levels.keys()) 116 | assert verbosity in self.log_levels, msg_verbosity 117 | logger = logging.getLogger(name) 118 | logger.setLevel(self.log_levels[verbosity]) 119 | return logger 120 | 121 | # setting read-only attributes 122 | @property 123 | def config(self): 124 | return self._config 125 | 126 | @property 127 | def save_dir(self): 128 | return self._save_dir 129 | 130 | @property 131 | def log_dir(self): 132 | return self._log_dir 133 | 134 | # helper functions to update config dict with custom cli options 135 | def _update_config(config, modification): 136 | if modification is None: 137 | return config 138 | 139 | for k, v in modification.items(): 140 | if v is not None: 141 | _set_by_path(config, k, v) 142 | return config 143 | 144 | def _get_opt_name(flags): 145 | for flg in flags: 146 | if flg.startswith('--'): 147 | return flg.replace('--', '') 148 | return flags[0].replace('--', '') 149 | 150 | def _set_by_path(tree, keys, value): 151 | """Set a value in a nested object in tree by sequence of keys.""" 152 | keys = keys.split(';') 153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 154 | 155 | def _get_by_path(tree, keys): 156 | """Access a nested object in tree by sequence of keys.""" 157 | return reduce(getitem, keys, tree) 158 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import data_loader.data_loaders as module_data 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | import model.model as module_arch 8 | from parse_config import ConfigParser 9 | 10 | 11 | def main(config): 12 | logger = config.get_logger('test') 13 | 14 | # setup data_loader instances 15 | data_loader = getattr(module_data, config['data_loader']['type'])( 16 | config['data_loader']['args']['data_dir'], 17 | batch_size=512, 18 | shuffle=False, 19 | validation_split=0.0, 20 | training=False, 21 | num_workers=2 22 | ) 23 | 24 | # build model architecture 25 | model = config.init_obj('arch', module_arch) 26 | logger.info(model) 27 | 28 | # get function handles of loss and metrics 29 | loss_fn = getattr(module_loss, config['loss']) 30 | metric_fns = [getattr(module_metric, met) for met in config['metrics']] 31 | 32 | logger.info('Loading checkpoint: {} ...'.format(config.resume)) 33 | checkpoint = torch.load(config.resume) 34 | state_dict = checkpoint['state_dict'] 35 | if config['n_gpu'] > 1: 36 | model = torch.nn.DataParallel(model) 37 | model.load_state_dict(state_dict) 38 | 39 | # prepare model for testing 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | model = model.to(device) 42 | model.eval() 43 | 44 | total_loss = 0.0 45 | total_metrics = torch.zeros(len(metric_fns)) 46 | 47 | with torch.no_grad(): 48 | for i, (data, target) in enumerate(tqdm(data_loader)): 49 | data, target = data.to(device), target.to(device) 50 | output = model(data) 51 | 52 | # 53 | # save sample images, or do something with output here 54 | # 55 | 56 | # computing loss, metrics on test set 57 | loss = loss_fn(output, target) 58 | batch_size = data.shape[0] 59 | total_loss += loss.item() * batch_size 60 | for i, metric in enumerate(metric_fns): 61 | total_metrics[i] += metric(output, target) * batch_size 62 | 63 | n_samples = len(data_loader.sampler) 64 | log = {'loss': total_loss / n_samples} 65 | log.update({ 66 | met.__name__: 
total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns) 67 | }) 68 | logger.info(log) 69 | 70 | 71 | if __name__ == '__main__': 72 | args = argparse.ArgumentParser(description='PyTorch Template') 73 | args.add_argument('-c', '--config', default=None, type=str, 74 | help='config file path (default: None)') 75 | args.add_argument('-r', '--resume', default=None, type=str, 76 | help='path to latest checkpoint (default: None)') 77 | args.add_argument('-d', '--device', default=None, type=str, 78 | help='indices of GPUs to enable (default: all)') 79 | 80 | config = ConfigParser.from_args(args) 81 | main(config) 82 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import numpy as np 5 | import data_loader.data_loaders as module_data 6 | import model.loss as module_loss 7 | import model.metric as module_metric 8 | import model.model as module_arch 9 | from parse_config import ConfigParser 10 | from trainer import Trainer 11 | 12 | 13 | # fix random seeds for reproducibility 14 | SEED = 123 15 | torch.manual_seed(SEED) 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.benchmark = False 18 | np.random.seed(SEED) 19 | 20 | def main(config): 21 | logger = config.get_logger('train') 22 | 23 | # setup data_loader instances 24 | data_loader = config.init_obj('data_loader', module_data) 25 | valid_data_loader = data_loader.split_validation() 26 | 27 | # build model architecture, then print to console 28 | model = config.init_obj('arch', module_arch) 29 | logger.info(model) 30 | 31 | # get function handles of loss and metrics 32 | criterion = getattr(module_loss, config['loss']) 33 | metrics = [getattr(module_metric, met) for met in config['metrics']] 34 | 35 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 36 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 37 | optimizer = config.init_obj('optimizer', torch.optim, trainable_params) 38 | 39 | lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) 40 | 41 | trainer = Trainer(model, criterion, metrics, optimizer, 42 | config=config, 43 | data_loader=data_loader, 44 | valid_data_loader=valid_data_loader, 45 | lr_scheduler=lr_scheduler) 46 | 47 | trainer.train() 48 | 49 | 50 | if __name__ == '__main__': 51 | args = argparse.ArgumentParser(description='PyTorch Template') 52 | args.add_argument('-c', '--config', default=None, type=str, 53 | help='config file path (default: None)') 54 | args.add_argument('-r', '--resume', default=None, type=str, 55 | help='path to latest checkpoint (default: None)') 56 | args.add_argument('-d', '--device', default=None, type=str, 57 | help='indices of GPUs to enable (default: all)') 58 | 59 | # custom cli options to modify configuration from default values given in json file. 
60 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 61 | options = [ 62 | CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'), 63 | CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size') 64 | ] 65 | config = ConfigParser.from_args(args, options) 66 | main(config) 67 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from utils import inf_loop, MetricTracker 6 | 7 | 8 | class Trainer(BaseTrainer): 9 | """ 10 | Trainer class 11 | """ 12 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, 13 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 14 | super().__init__(model, criterion, metric_ftns, optimizer, config) 15 | self.config = config 16 | self.data_loader = data_loader 17 | if len_epoch is None: 18 | # epoch-based training 19 | self.len_epoch = len(self.data_loader) 20 | else: 21 | # iteration-based training 22 | self.data_loader = inf_loop(data_loader) 23 | self.len_epoch = len_epoch 24 | self.valid_data_loader = valid_data_loader 25 | self.do_validation = self.valid_data_loader is not None 26 | self.lr_scheduler = lr_scheduler 27 | self.log_step = int(np.sqrt(data_loader.batch_size)) 28 | 29 | self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 30 | self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) 31 | 32 | def _train_epoch(self, epoch): 33 | """ 34 | Training logic for an epoch 35 | 36 | :param epoch: Integer, current training epoch. 37 | :return: A log that contains average loss and metric in this epoch. 38 | """ 39 | self.model.train() 40 | self.train_metrics.reset() 41 | for batch_idx, (data, target) in enumerate(self.data_loader): 42 | data, target = data.to(self.device), target.to(self.device) 43 | 44 | self.optimizer.zero_grad() 45 | output = self.model(data) 46 | loss = self.criterion(output, target) 47 | loss.backward() 48 | self.optimizer.step() 49 | 50 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 51 | self.train_metrics.update('loss', loss.item()) 52 | for met in self.metric_ftns: 53 | self.train_metrics.update(met.__name__, met(output, target)) 54 | 55 | if batch_idx % self.log_step == 0: 56 | self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( 57 | epoch, 58 | self._progress(batch_idx), 59 | loss.item())) 60 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 61 | 62 | if batch_idx == self.len_epoch: 63 | break 64 | log = self.train_metrics.result() 65 | 66 | if self.do_validation: 67 | val_log = self._valid_epoch(epoch) 68 | log.update(**{'val_'+k : v for k, v in val_log.items()}) 69 | 70 | if self.lr_scheduler is not None: 71 | self.lr_scheduler.step() 72 | return log 73 | 74 | def _valid_epoch(self, epoch): 75 | """ 76 | Validate after training an epoch 77 | 78 | :param epoch: Integer, current training epoch. 
79 | :return: A log that contains information about validation 80 | """ 81 | self.model.eval() 82 | self.valid_metrics.reset() 83 | with torch.no_grad(): 84 | for batch_idx, (data, target) in enumerate(self.valid_data_loader): 85 | data, target = data.to(self.device), target.to(self.device) 86 | 87 | output = self.model(data) 88 | loss = self.criterion(output, target) 89 | 90 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 91 | self.valid_metrics.update('loss', loss.item()) 92 | for met in self.metric_ftns: 93 | self.valid_metrics.update(met.__name__, met(output, target)) 94 | self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) 95 | 96 | # add histogram of model parameters to the tensorboard 97 | for name, p in self.model.named_parameters(): 98 | self.writer.add_histogram(name, p, bins='auto') 99 | return self.valid_metrics.result() 100 | 101 | def _progress(self, batch_idx): 102 | base = '[{}/{} ({:.0f}%)]' 103 | if hasattr(self.data_loader, 'n_samples'): 104 | current = batch_idx * self.data_loader.batch_size 105 | total = self.data_loader.n_samples 106 | else: 107 | current = batch_idx 108 | total = self.len_epoch 109 | return base.format(current, total, 100.0 * current / total) 110 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from pathlib import Path 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | 7 | 8 | def ensure_dir(dirname): 9 | dirname = Path(dirname) 10 | if not dirname.is_dir(): 11 | dirname.mkdir(parents=True, exist_ok=False) 12 | 13 | def read_json(fname): 14 | fname = Path(fname) 15 | with fname.open('rt') as handle: 16 | return json.load(handle, object_hook=OrderedDict) 17 | 18 | def write_json(content, fname): 19 | fname = Path(fname) 20 | with fname.open('wt') as handle: 21 | json.dump(content, handle, indent=4, sort_keys=False) 22 | 23 | def inf_loop(data_loader): 24 | ''' wrapper function for endless data loader. ''' 25 | for loader in repeat(data_loader): 26 | yield from loader 27 | 28 | class MetricTracker: 29 | def __init__(self, *keys, writer=None): 30 | self.writer = writer 31 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 32 | self.reset() 33 | 34 | def reset(self): 35 | for col in self._data.columns: 36 | self._data[col].values[:] = 0 37 | 38 | def update(self, key, value, n=1): 39 | if self.writer is not None: 40 | self.writer.add_scalar(key, value) 41 | self._data.total[key] += value * n 42 | self._data.counts[key] += n 43 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 44 | 45 | def avg(self, key): 46 | return self._data.average[key] 47 | 48 | def result(self): 49 | return dict(self._data.average) 50 | --------------------------------------------------------------------------------