├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── conf ├── data │ ├── mnist_test.yaml │ └── mnist_train.yaml ├── evaluate.yaml ├── hparams │ └── lenet_baseline.yaml ├── hydra │ └── job_logging │ │ └── custom.yaml ├── model │ └── mnist_lenet.yaml ├── status │ ├── debug.yaml │ └── train.yaml ├── train.yaml └── working_dir │ ├── job_timestamp.yaml │ └── no_chdir.yaml ├── evaluate.py ├── new_project.py ├── requirements.txt ├── srcs ├── data_loader │ └── data_loaders.py ├── logger.py ├── model │ ├── loss.py │ ├── metric.py │ └── model.py ├── trainer │ ├── __init__.py │ ├── base.py │ └── trainer.py └── utils │ ├── __init__.py │ └── util.py └── train.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = F401, F403 3 | max-line-length = 120 4 | exclude = 5 | .git, 6 | __pycache__, 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints 104 | data/ 105 | outputs 106 | saved/ 107 | datasets/ 108 | 109 | # editor, os cache directory 110 | .vscode/ 111 | .idea/ 112 | __MACOSX/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Victor Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 
11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Template Project 2 | Simple project base template for PyTorch deep Learning project. 3 | 4 | 5 | 6 | - [PyTorch Template Project](#pytorch-template-project) 7 | - [Installation](#installation) 8 | - [Requirements](#requirements) 9 | - [Features](#features) 10 | - [Folder Structure](#folder-structure) 11 | - [Usage](#usage) 12 | - [Hierarchical configurations with Hydra](#hierarchical-configurations-with-hydra) 13 | - [Using config files](#using-config-files) 14 | - [Checkpoints](#checkpoints) 15 | - [Resuming from checkpoints](#resuming-from-checkpoints) 16 | - [Using Multiple GPU](#using-multiple-gpu) 17 | - [Customization](#customization) 18 | - [Project initialization](#project-initialization) 19 | - [Data Loader](#data-loader) 20 | - [Trainer](#trainer) 21 | - [Model](#model) 22 | - [Loss](#loss) 23 | - [Metrics](#metrics) 24 | - [Additional logging](#additional-logging) 25 | - [Testing](#testing) 26 | - [Validation data](#validation-data) 27 | - [Checkpoints](#checkpoints-1) 28 | - [Tensorboard Visualization](#tensorboard-visualization) 29 | - [Contribution](#contribution) 30 | - [TODOs](#todos) 31 | - [License](#license) 32 | 33 | 34 | 35 | ## Installation 36 | ### Requirements 37 | * Python >= 3.6 38 | * PyTorch >= 1.2 39 | * tensorboard >= 1.14 (see [Tensorboard Visualization](#tensorboard-visualization)) 40 | * tqdm 41 | * hydra-core >= 1.0.3 42 | 43 | ### Features 44 | * Simple and clear directory structure, suitable for most of deep learning projects. 45 | * Hierarchical management of project configurations with [Hydra](https://hydra.cc/docs/intro). 46 | * Advanced logging and monitoring for validation metrics. Automatic handling of model checkpoints. 47 | * **Note**: This repository is detached from [victorisque/pytorch-template](https://github.com/victoresque/pytorch-template), in order to introduce advanced features rapidly without concerning much for backward compatibility. 48 | 49 | ### Folder Structure 50 | ```yaml 51 | pytorch-template/ 52 | ├── train.py # main script to start training. 53 | ├── evaluate.py # script to evaluate trained model on testset. 54 | ├── conf # config files. explained in separated section below. 55 | │   └── ... 56 | ├── srcs # source code. 
57 | │   ├── data_loader # data loading, preprocessing 58 | │   │   └── data_loaders.py 59 | │   ├── model 60 | │   │   ├── loss.py 61 | │   │   ├── metric.py 62 | │   │   └── model.py 63 | │   ├── trainer # customized class managing training process 64 | │   │   ├── base.py 65 | │   │   └── trainer.py 66 | │   ├── logger.py # tensorboard, train / validation metric logging 67 | │   └── utils 68 | │   └── util.py 69 | ├── new_project.py # script to initialize new project 70 | ├── requirements.txt 71 | ├── README.md 72 | └── LICENSE 73 | ``` 74 | 75 | ## Usage 76 | This template itself is an working example project which trains a simple model(LeNet) on [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. 77 | Try `python train.py` to run training. 78 | 79 | ### Hierarchical configurations with Hydra 80 | This repository is designed to be used with [Hydra](https://hydra.cc/) framework, which has useful key features as following. 81 | 82 | - Hierarchical configuration composable from multiple sources 83 | - Configuration can be specified or overridden from the command line 84 | - Dynamic command line tab completion 85 | - Run your application locally or launch it to run remotely 86 | - Run multiple jobs with different arguments with a single command 87 | 88 | Check [Hydra documentation](https://hydra.cc/), for more information. 89 | 90 | `conf/` directory contains `.yaml`config files which are structured into multiple **config groups**. 91 | 92 | ```yaml 93 | conf/ # hierarchical, structured config files to be used with 'Hydra' framework 94 | ├── train.yaml # main config file used for train.py 95 | ├── evaluate.yaml # main config file used for evaluate.py 96 | ├── hparams # define global hyper-parameters 97 | │   └── lenet_baseline.yaml 98 | ├── data 99 | │   ├── mnist_test.yaml 100 | │   └── mnist_train.yaml 101 | ├── model # select NN architecture to train 102 | │   └── mnist_lenet.yaml 103 | ├── status # set train/debug mode. 104 | │   ├── debug.yaml # debug mode runs faster, and don't use tensorboard 105 | │   └── train.yaml # train mode is default with full logging 106 | │ 107 | └── hydra # configure hydra framework 108 |    ├── job_logging # config for python logging module 109 | │   └── custom.yaml 110 |    └── run/dir # setup working directory 111 |    ├── job_timestamp.yaml 112 |    └── no_chdir.yaml 113 | ``` 114 | 115 | ### Using config files 116 | Modify the configurations in `.yaml` files in `conf/` dir, then run: 117 | ``` 118 | python train.py 119 | ``` 120 | 121 | At runtime, one file from each config group is selected and combined to be used as one global config. 122 | 123 | ```yaml 124 | name: MnistLeNet # experiment name. 125 | 126 | save_dir: models/ 127 | log_dir: ${name}/ 128 | resume: 129 | 130 | # Global hyper-parameters defined in conf/hparams/ 131 | # you can change the values by either editing yaml file directly, 132 | # or using command line arguments, like `python3 train.py batch_size=128` 133 | batch_size: 256 134 | learning_rate: 0.001 135 | weight_decay: 0 136 | scheduler_step_size: 50 137 | scheduler_gamma: 0.1 138 | 139 | 140 | # configuration for data loading. 
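# values written as ${...} below (e.g. ${batch_size}, ${n_cpu}) are OmegaConf
# interpolations; Hydra resolves them from the hparams group or from
# command-line overrides when the config groups are composed into this file.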
141 | data_loader: 142 | _target_: srcs.data_loader.data_loaders.get_data_loaders 143 | data_dir: data/ 144 | batch_size: ${batch_size} 145 | shuffle: true 146 | validation_split: 0.1 147 | num_workers: ${n_cpu} 148 | 149 | arch: 150 | _target_: srcs.model.model.MnistModel 151 | num_classes: 10 152 | loss: 153 | _target_: srcs.model.loss.nll_loss 154 | optimizer: 155 | _target_: torch.optim.Adam 156 | lr: ${learning_rate} 157 | weight_decay: ${weight_decay} 158 | amsgrad: true 159 | lr_scheduler: 160 | _target_: torch.optim.lr_scheduler.StepLR 161 | step_size: ${scheduler_step_size} 162 | gamma: ${scheduler_gamma} 163 | 164 | metrics: 165 | - _target_: srcs.model.metric.accuracy 166 | - _target_: srcs.model.metric.top_k_acc 167 | 168 | n_gpu: 1 169 | n_cpu: 8 170 | trainer: 171 | epochs: 20 172 | logging_step: 100 173 | verbosity: 2 174 | monitor: min loss/valid 175 | early_stop: 10 176 | tensorboard: true 177 | ``` 178 | 179 | Add addional configurations if you need. 180 | 181 | `conf/hparams/lenet_baseline.yaml` contains 182 | 183 | ```yaml 184 | batch_size: 256 185 | learning_rate: 0.001 186 | weight_decay: 0 187 | scheduler_step_size: 50 188 | scheduler_gamma: 0.1 189 | ``` 190 | 191 | 192 | Those config items containing `_target_` are designed to be used with `instantiate` function of Hydra. For example, 193 | When your config looks like 194 | ```yaml 195 | # @package _global_ 196 | classitem: 197 | _target_: location.to.class.definition 198 | arg1: 123 199 | arg2: 'example' 200 | ``` 201 | 202 | then usage of instantiate as 203 | 204 | ```python 205 | example_object = instantiate(config.classitem) 206 | ``` 207 | 208 | is equivalent to 209 | 210 | ```python 211 | from location.to.class import definition 212 | 213 | example_object = definition(arg1=1, arg2='example') 214 | ``` 215 | 216 | This feature is especially useful, when you switch between multiple models with same interface(input, output), 217 | like choosing ResNet or MobileNet for CNN backbone of detection model. 218 | You can change architecture by simply using different config file, even not needing to importing both in code. 219 | 220 | ### Checkpoints 221 | 222 | ```yaml 223 | # new directory with timestamp will be created automatically. 224 | # if you enable debug mode by status=debug either in command line or main config, 225 | # checkpoints will be saved under separate directory `outputs/debug`. 226 | outputs/train/2020-07-29/12-44-37/ 227 | ├── config.yaml # composed config file 228 | ├── epoch-results.csv # epoch-wise evaluation metrics 229 | ├── MnistLeNet/ # tensorboard log file 230 | ├── model 231 | │   ├── checkpoint-epoch1.pth 232 | │   ├── checkpoint-epoch2.pth 233 | │   ├── ... 234 | │   ├── model_best.pth # checkpoint with best score 235 | │   └── model_latest.pth # checkpoint which is saved last 236 | └── train.log 237 | ``` 238 | 239 | ### Resuming from checkpoints 240 | You can resume from a previously saved checkpoint by: 241 | ``` 242 | python train.py resume=output/train/path/to/checkpoint.pth 243 | ``` 244 | 245 | ### Using Multiple GPU 246 | You can enable multi-GPU training(with DataParallel) by setting `n_gpu` argument of the config file to larger number. If configured to use smaller number of gpu than available, first n devices will be used by default. When you want to run multiple instances of training on larger maching, specify indices of available GPUs by cuda environmental variable. 247 | ```bash 248 | # assume running on a machine with 4 GPUs. 
249 | python train.py n_gpu=2 # This will use the first two GPUs, at indices 0 and 1
250 | CUDA_VISIBLE_DEVICES=2,3 python train.py n_gpu=2 # This will use the remaining two GPUs, at indices 2 and 3
251 | ```
252 | 
253 | ## Customization
254 | 
255 | ### Project initialization
256 | Use the `new_project.py` script to create a new project directory populated with the template files.
257 | Run `python new_project.py ../NewProject` and a new project folder named 'NewProject' will be created.
258 | The script filters out unnecessary files such as caches, git files and the readme.
259 | 
260 | 
261 | ### Data Loader
262 | * **Writing your own data loader**
263 | 
264 | Please refer to `data_loader/data_loaders.py` for an MNIST data loading example.
265 | 
266 | ### Trainer
267 | * **Writing your own trainer**
268 | 
269 | 1. **Inherit `BaseTrainer`**
270 | 
271 | `BaseTrainer` handles:
272 | * Training process logging
273 | * Checkpoint saving
274 | * Checkpoint resuming
275 | * Reconfigurable performance monitoring for saving the current best model and early-stopping training.
276 | * If config `monitor` is set to `max val_accuracy`, the trainer will save a checkpoint `model_best.pth` whenever an epoch's validation accuracy exceeds the current maximum.
277 | * If config `early_stop` is set, training will be terminated automatically when model performance does not improve for the given number of epochs. This feature can be turned off by passing 0 to the `early_stop` option, or by deleting that line from the config.
278 | 
279 | 2. **Implementing abstract methods**
280 | 
281 | You need to implement `_train_epoch()` for your training process; if you need validation, implement `_valid_epoch()` as in `trainer/trainer.py`.
282 | 
283 | * **Example**
284 | 
285 | Please refer to `trainer/trainer.py` for MNIST training.
286 | 
287 | * **Iteration-based training**
288 | 
289 | `Trainer.__init__` takes an optional argument `len_epoch`, which controls the number of batches (steps) in each epoch.
290 | 
291 | ### Model
292 | * **Writing your own model**
293 | 
294 | 1. **Inherit `torch.nn.Module`**
295 | 
296 | Models in this template are plain `torch.nn.Module` subclasses:
297 | * `train.py` prints the model structure and the number of trainable parameters when training starts.
298 | * No extra base class is required.
299 | 
300 | 2. **Implementing the forward pass**
301 | 
302 | Implement the forward pass method `forward()`.
303 | 
304 | * **Example**
305 | 
306 | Please refer to `model/model.py` for a LeNet example.
307 | 
308 | ### Loss
309 | Custom loss functions can be implemented in `model/loss.py`. Use one by pointing the `_target_` of the `loss` entry in the config file to it.
310 | 
311 | ### Metrics
312 | Metric functions are located in `model/metric.py`.
313 | 
314 | You can monitor multiple metrics by providing a list in the configuration file, e.g. as in `conf/evaluate.yaml`:
315 | ```yaml
316 | metrics: [{_target_: srcs.model.metric.accuracy}, {_target_: srcs.model.metric.top_k_acc}]
317 | ```
318 | 
319 | ### Additional logging
320 | If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class, as shown below, before returning:
321 | 
322 | ```python
323 | additional_log = {"gradient_norm": g, "sensitivity": s}
324 | log.update(additional_log)
325 | return log
326 | ```
327 | 
328 | ### Testing
329 | You can evaluate a trained model by running `evaluate.py`, passing the path of a trained checkpoint via the `checkpoint` config option.
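For example (a sketch; the run directory below is illustrative, substitute the output directory of your own training run):
```bash
# evaluate the best checkpoint of a finished run on the test set
python evaluate.py checkpoint=outputs/train/2020-07-29/12-44-37/model/model_best.pth
```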
330 | 331 | ### Validation data 332 | To split validation data from a data loader, call `BaseDataLoader.split_validation()`, then it will return a data loader for validation of size specified in your config file. 333 | The `validation_split` can be a ratio of validation set per total data(0.0 <= float < 1.0), or the number of samples (0 <= int < `n_total_samples`). 334 | 335 | **Note**: the `split_validation()` method will modify the original data loader 336 | **Note**: `split_validation()` will return `None` if `"validation_split"` is set to `0` 337 | 338 | ### Checkpoints 339 | You can specify the name of the training session in config files: 340 | ```yaml 341 | "name": "MNIST_LeNet", 342 | ``` 343 | 344 | The checkpoints will be saved in `save_dir/name/timestamp/checkpoint_epoch_n`, with timestamp in mmdd_HHMMSS format. 345 | 346 | A copy of config file will be saved in the same folder. 347 | 348 | **Note**: checkpoints contain: 349 | ```python 350 | { 351 | 'arch': arch, 352 | 'epoch': epoch, 353 | 'state_dict': self.model.state_dict(), 354 | 'optimizer': self.optimizer.state_dict(), 355 | 'epoch_metrics': self.ep_metrics, 356 | 'config': self.config 357 | } 358 | ``` 359 | 360 | ### Tensorboard Visualization 361 | This template supports Tensorboard visualization with `torch.utils.tensorboard`. 362 | 363 | 1. **Run training** 364 | 365 | Make sure that `tensorboard` option in the config file is turned on. 366 | 367 | ``` 368 | "tensorboard" : true 369 | ``` 370 | 371 | 2. **Open Tensorboard server** 372 | 373 | Type `tensorboard --logdir outputs/train/` at the project root, then server will open at `http://localhost:6006` 374 | 375 | By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged. 376 | If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method. 377 | `add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules. 378 | 379 | **Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `srcs.logger.py` will track current steps. 380 | 381 | ## Contribution 382 | Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8 383 | 384 | Code should pass the [Flake8](http://flake8.pycqa.org/en/latest/) check before committing. 385 | 386 | ## TODOs 387 | - [ ] Support DistributedDataParallel 388 | - [x] Option to keep top-k checkpoints only 389 | - [ ] Simple unittest code for `nn.Module` and others 390 | 391 | ## License 392 | This project is licensed under the MIT License. 
See LICENSE for more details 393 | -------------------------------------------------------------------------------- /conf/data/mnist_test.yaml: -------------------------------------------------------------------------------- 1 | # @package data_loader 2 | _target_: srcs.data_loader.data_loaders.get_data_loaders 3 | data_dir: ${hydra:runtime.cwd}/data/ 4 | batch_size: 512 5 | training: false 6 | shuffle: false 7 | num_workers: 4 -------------------------------------------------------------------------------- /conf/data/mnist_train.yaml: -------------------------------------------------------------------------------- 1 | # @package data_loader 2 | _target_: srcs.data_loader.data_loaders.get_data_loaders 3 | data_dir: ${hydra:runtime.cwd}/data/ 4 | batch_size: ${batch_size} 5 | shuffle: true 6 | validation_split: 0.1 7 | num_workers: ${n_cpu} -------------------------------------------------------------------------------- /conf/evaluate.yaml: -------------------------------------------------------------------------------- 1 | log_dir: ${name}/ 2 | checkpoint: ??? 3 | 4 | metrics: 5 | - _target_: srcs.model.metric.accuracy 6 | - _target_: srcs.model.metric.top_k_acc 7 | 8 | defaults: 9 | - _self_ 10 | - data: mnist_test 11 | 12 | - working_dir: no_chdir 13 | - override hydra/job_logging : custom 14 | -------------------------------------------------------------------------------- /conf/hparams/lenet_baseline.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | n_cpu: 8 3 | 4 | batch_size: 256 5 | learning_rate: 0.001 6 | weight_decay: 0 7 | 8 | scheduler_step_size: 50 9 | scheduler_gamma: 0.1 -------------------------------------------------------------------------------- /conf/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.job_logging 2 | # python logging configuration for tasks 3 | version: 1 4 | formatters: 5 | simple: 6 | format: '%(message)s' 7 | detailed: 8 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 9 | handlers: 10 | console: 11 | class: logging.StreamHandler 12 | formatter: simple 13 | stream: ext://sys.stdout 14 | file: 15 | class: logging.FileHandler 16 | formatter: detailed 17 | # relative to the job log directory 18 | filename: ${hydra.job.name}.log 19 | root: 20 | level: INFO 21 | handlers: [console, file] 22 | 23 | disable_existing_loggers: False -------------------------------------------------------------------------------- /conf/model/mnist_lenet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | name: MnistLeNet 3 | arch: 4 | _target_: srcs.model.model.MnistModel 5 | num_classes: 10 6 | loss: 7 | _target_: srcs.model.loss.nll_loss 8 | _partial_: true 9 | optimizer: 10 | _target_: torch.optim.Adam 11 | lr: ${learning_rate} 12 | weight_decay: ${weight_decay} 13 | amsgrad: true 14 | lr_scheduler: 15 | _target_: torch.optim.lr_scheduler.StepLR 16 | step_size: ${scheduler_step_size} 17 | gamma: ${scheduler_gamma} 18 | -------------------------------------------------------------------------------- /conf/status/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | status: debug 3 | 4 | trainer: 5 | epochs: 5 6 | logging_step: 10 7 | 8 | monitor: min loss/valid 9 | save_topk: 3 10 | early_stop: 5 11 | 12 | tensorboard: false 13 | 
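# usage sketch: select this config group from the command line, e.g.
#   python train.py status=debug
# Hydra composes these values on top of the defaults declared in conf/train.yaml.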
-------------------------------------------------------------------------------- /conf/status/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | status: train 3 | 4 | trainer: 5 | epochs: 20 6 | logging_step: 100 7 | 8 | monitor: min loss/valid 9 | save_topk: 7 10 | early_stop: 10 11 | 12 | tensorboard: true 13 | -------------------------------------------------------------------------------- /conf/train.yaml: -------------------------------------------------------------------------------- 1 | resume: 2 | save_dir: models/ 3 | log_dir: ${name}/ 4 | 5 | metrics: 6 | accuracy: 7 | _target_: srcs.model.metric.accuracy 8 | _partial_: true 9 | top_k_acc: 10 | _target_: srcs.model.metric.top_k_acc 11 | _partial_: true 12 | 13 | defaults: 14 | - _self_ 15 | - data: mnist_train 16 | - model: mnist_lenet 17 | - hparams: lenet_baseline 18 | 19 | - status: train 20 | 21 | - working_dir: job_timestamp 22 | - override hydra/job_logging : custom 23 | -------------------------------------------------------------------------------- /conf/working_dir/job_timestamp.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra 2 | run: 3 | dir: ./outputs/${status}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | job: 5 | chdir: true 6 | -------------------------------------------------------------------------------- /conf/working_dir/no_chdir.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra 2 | run: 3 | dir: ./ 4 | job: 5 | chdir: false 6 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import hydra 4 | from hydra.utils import instantiate 5 | from omegaconf import OmegaConf 6 | from tqdm import tqdm 7 | 8 | 9 | logger = logging.getLogger('evaluate') 10 | 11 | @hydra.main(config_path='conf', config_name='evaluate') 12 | def main(config): 13 | logger.info('Loading checkpoint: {} ...'.format(config.checkpoint)) 14 | checkpoint = torch.load(config.checkpoint, weights_only=False) 15 | 16 | loaded_config = OmegaConf.create(checkpoint['config']) 17 | 18 | # setup data_loader instances 19 | data_loader = instantiate(config.data_loader) 20 | 21 | # restore network architecture 22 | model = instantiate(loaded_config.arch) 23 | logger.info(model) 24 | 25 | # load trained weights 26 | state_dict = checkpoint['state_dict'] 27 | model = torch.nn.DataParallel(model) 28 | model.load_state_dict(state_dict) 29 | 30 | # instantiate loss and metrics 31 | criterion = instantiate(loaded_config.loss) 32 | metrics = { 33 | met_name: instantiate(met) 34 | for met_name, met in loaded_config.metrics.items() 35 | } 36 | 37 | # prepare model for testing 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | model = model.to(device) 40 | model.eval() 41 | 42 | total_loss = 0.0 43 | total_metrics = torch.zeros(len(metrics)) 44 | 45 | with torch.no_grad(): 46 | for i, (data, target) in enumerate(tqdm(data_loader)): 47 | data, target = data.to(device), target.to(device) 48 | output = model(data) 49 | 50 | # 51 | # save sample images, or do something with output here 52 | # 53 | 54 | # computing loss, metrics on test set 55 | loss = criterion(output, target) 56 | batch_size = data.shape[0] 57 | total_loss += loss.item() * batch_size 58 | for i, metric in enumerate(metrics.values()): 59 | 
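                # loss and metrics are accumulated weighted by batch size, so the
                # averages computed below are exact means over all test samples even
                # when the last batch is smaller than batch_size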
total_metrics[i] += metric(output, target) * batch_size 60 | 61 | n_samples = len(data_loader.sampler) 62 | log = {'loss': total_loss / n_samples} 63 | log.update({ 64 | met_name: total_metrics[i].item() / n_samples 65 | for i, met_name in enumerate(metrics.keys()) 66 | }) 67 | logger.info(log) 68 | 69 | 70 | if __name__ == '__main__': 71 | # pylint: disable=no-value-for-parameter 72 | main() 73 | -------------------------------------------------------------------------------- /new_project.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from shutil import copytree, ignore_patterns 4 | 5 | 6 | # This script initializes new pytorch project with the template files. 7 | # Run `python3 new_project.py ../MyNewProject` then new project named 8 | # MyNewProject will be made 9 | current_dir = Path() 10 | assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory' 11 | assert len(sys.argv) == 2, 'Specify a name for the new project. Example: python3 new_project.py MyNewProject' 12 | 13 | project_name = Path(sys.argv[1]) 14 | target_dir = current_dir / project_name 15 | 16 | ignore = [".git", "FashionMNIST", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"] 17 | copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore)) 18 | print('New project initialized at', target_dir.absolute().resolve()) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.1 2 | torchvision 3 | numpy 4 | hydra-core>=1.3.2 5 | tqdm 6 | tensorboard>=1.14 7 | -------------------------------------------------------------------------------- /srcs/data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.utils.data import DataLoader, DistributedSampler, random_split 3 | from torchvision import datasets, transforms 4 | 5 | 6 | def get_data_loaders(data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 7 | trsfm = transforms.Compose([ 8 | transforms.ToTensor(), 9 | transforms.Normalize((0.1307,), (0.3081,)) 10 | ]) 11 | 12 | dataset = datasets.FashionMNIST(data_dir, train=training, download=True, transform=trsfm) 13 | 14 | loader_args = { 15 | 'batch_size': batch_size, 16 | 'shuffle': shuffle, 17 | 'num_workers': num_workers 18 | } 19 | if training: 20 | # split dataset into train and validation set 21 | num_total = len(dataset) 22 | if isinstance(validation_split, int): 23 | assert validation_split > 0 24 | assert validation_split < num_total, "validation set size is configured to be larger than entire dataset." 
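            # an integer validation_split is treated as an absolute number of held-out
            # samples; a float is treated as a fraction of the full dataset (else branch)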
25 | num_valid = validation_split 26 | else: 27 | num_valid = int(num_total * validation_split) 28 | num_train = num_total - num_valid 29 | 30 | train_dataset, valid_dataset = random_split(dataset, [num_train, num_valid]) 31 | 32 | train_sampler, valid_sampler = None, None 33 | if dist.is_initialized(): 34 | loader_args['shuffle']=False 35 | train_sampler = DistributedSampler(train_dataset) 36 | valid_sampler = DistributedSampler(valid_dataset) 37 | return DataLoader(train_dataset, sampler=train_sampler, **loader_args), \ 38 | DataLoader(valid_dataset, sampler=valid_sampler, **loader_args) 39 | else: 40 | return DataLoader(dataset, **loader_args) 41 | 42 | -------------------------------------------------------------------------------- /srcs/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from itertools import product 4 | from torch.utils.tensorboard import SummaryWriter 5 | from datetime import datetime 6 | from srcs.utils import get_logger 7 | 8 | 9 | class TensorboardWriter(): 10 | def __init__(self, log_dir, enabled): 11 | self.logger = get_logger('tensorboard-writer') 12 | self.writer = SummaryWriter(log_dir) if enabled else None 13 | self.selected_module = "" 14 | 15 | if enabled: 16 | log_dir = str(log_dir) 17 | 18 | self.step = 0 19 | 20 | self.tb_writer_ftns = { 21 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 22 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 23 | } 24 | self.timer = datetime.now() 25 | 26 | def set_step(self, step): 27 | self.step = step 28 | if step == 0: 29 | self.timer = datetime.now() 30 | else: 31 | duration = datetime.now() - self.timer 32 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 33 | self.timer = datetime.now() 34 | 35 | def __getattr__(self, name): 36 | """ 37 | If visualization is configured to use: 38 | return add_data() methods of tensorboard with additional information (step, tag) added. 
39 | Otherwise: 40 | return a blank function handle that does nothing 41 | """ 42 | if name in self.tb_writer_ftns: 43 | add_data = getattr(self.writer, name, None) 44 | 45 | def wrapper(tag, data, *args, **kwargs): 46 | if add_data is not None: 47 | add_data(tag, data, self.step, *args, **kwargs) 48 | return wrapper 49 | else: 50 | attr = getattr(self.writer, name) 51 | return attr 52 | 53 | class BatchMetrics: 54 | def __init__(self, *keys, postfix='', writer=None): 55 | self.writer = writer 56 | self.postfix = postfix 57 | if postfix: 58 | keys = [k+postfix for k in keys] 59 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 60 | self.reset() 61 | 62 | def reset(self): 63 | for col in self._data.columns: 64 | self._data[col].values[:] = 0 65 | 66 | def update(self, key, value, n=1): 67 | if self.postfix: 68 | key = key + self.postfix 69 | if self.writer is not None: 70 | self.writer.add_scalar(key, value) 71 | self._data.total[key] += value * n 72 | self._data.counts[key] += n 73 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 74 | 75 | def avg(self, key): 76 | if self.postfix: 77 | key = key + self.postfix 78 | return self._data.average[key] 79 | 80 | def result(self): 81 | return dict(self._data.average) 82 | 83 | class EpochMetrics: 84 | def __init__(self, metric_names, phases=('train', 'valid'), monitoring='off'): 85 | self.logger = get_logger('epoch-metrics') 86 | # setup pandas DataFrame with hierarchical columns 87 | columns = tuple(product(metric_names, phases)) 88 | self._data = pd.DataFrame(columns=columns) # TODO: add epoch duration 89 | self.monitor_mode, self.monitor_metric = self._parse_monitoring_mode(monitoring) 90 | self.topk_idx = [] 91 | 92 | def minimizing_metric(self, idx): 93 | if self.monitor_mode == 'off': 94 | return 0 95 | try: 96 | metric = self._data[self.monitor_metric].loc[idx] 97 | except KeyError: 98 | self.logger.warning("Warning: Metric '{}' is not found. " 99 | "Model performance monitoring is disabled.".format(self.monitor_metric)) 100 | self.monitor_mode = 'off' 101 | return 0 102 | if self.monitor_mode == 'min': 103 | return metric 104 | else: 105 | return - metric 106 | 107 | def _parse_monitoring_mode(self, monitor_mode): 108 | if monitor_mode == 'off': 109 | return 'off', None 110 | else: 111 | monitor_mode, monitor_metric = monitor_mode.split() 112 | monitor_metric = tuple(monitor_metric.split('/')) 113 | assert monitor_mode in ['min', 'max'] 114 | return monitor_mode, monitor_metric 115 | 116 | def is_improved(self): 117 | if self.monitor_mode == 'off': 118 | return True 119 | 120 | last_epoch = self._data.index[-1] 121 | best_epoch = self.topk_idx[0] 122 | return last_epoch == best_epoch 123 | 124 | def keep_topk_checkpt(self, checkpt_dir, k=3): 125 | """ 126 | Keep top-k checkpoints by deleting k+1'th best epoch index from dataframe for every epoch. 
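        The checkpoint file of the epoch that falls out of the top-k ranking
        (according to the monitored metric) is removed from disk.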
127 | """ 128 | if len(self.topk_idx) > k and self.monitor_mode != 'off': 129 | last_epoch = self._data.index[-1] 130 | self.topk_idx = self.topk_idx[:(k+1)] 131 | if last_epoch not in self.topk_idx: 132 | to_delete = last_epoch 133 | else: 134 | to_delete = self.topk_idx[-1] 135 | 136 | # delete checkpoint having out-of topk metric 137 | filename = str(checkpt_dir / 'checkpoint-epoch{}.pth'.format(to_delete.split('-')[1])) 138 | try: 139 | os.remove(filename) 140 | except FileNotFoundError: 141 | # this happens when current model is loaded from checkpoint 142 | # or target file is already removed somehow 143 | pass 144 | 145 | def update(self, epoch, result): 146 | epoch_idx = f'epoch-{epoch}' 147 | self._data.loc[epoch_idx] = {tuple(k.split('/')):v for k, v in result.items()} 148 | 149 | self.topk_idx.append(epoch_idx) 150 | self.topk_idx = sorted(self.topk_idx, key=self.minimizing_metric) 151 | 152 | def latest(self): 153 | return self._data[-1:] 154 | 155 | def to_csv(self, save_path=None): 156 | self._data.to_csv(save_path) 157 | 158 | def __str__(self): 159 | return str(self._data) 160 | -------------------------------------------------------------------------------- /srcs/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def nll_loss(output, target): 5 | return F.nll_loss(output, target) 6 | -------------------------------------------------------------------------------- /srcs/model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target): 5 | with torch.no_grad(): 6 | pred = torch.argmax(output, dim=1) 7 | assert pred.shape[0] == len(target) 8 | correct = 0 9 | correct += torch.sum(pred == target).item() 10 | return correct / len(target) 11 | 12 | 13 | def top_k_acc(output, target, k=3): 14 | with torch.no_grad(): 15 | pred = torch.topk(output, k, dim=1)[1] 16 | assert pred.shape[0] == len(target) 17 | correct = 0 18 | for i in range(k): 19 | correct += torch.sum(pred[:, i] == target).item() 20 | return correct / len(target) 21 | -------------------------------------------------------------------------------- /srcs/model/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class MnistModel(nn.Module): 6 | def __init__(self, num_classes=10): 7 | super().__init__() 8 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 9 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 10 | self.conv2_drop = nn.Dropout2d() 11 | self.fc1 = nn.Linear(320, 50) 12 | self.fc2 = nn.Linear(50, num_classes) 13 | 14 | def forward(self, x): 15 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 | x = x.view(-1, 320) 18 | x = F.relu(self.fc1(x)) 19 | x = F.dropout(x, training=self.training) 20 | x = self.fc2(x) 21 | return F.log_softmax(x, dim=1) 22 | -------------------------------------------------------------------------------- /srcs/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /srcs/trainer/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.parallel import DistributedDataParallel 4 | 
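# note: BaseTrainer assumes torch.distributed has already been initialised by the
# launcher (train.py calls dist.init_process_group in init_worker); the model is
# converted to SyncBatchNorm and wrapped in DistributedDataParallel on the local rank.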
from abc import abstractmethod, ABCMeta 5 | from pathlib import Path 6 | from shutil import copyfile 7 | from numpy import inf 8 | 9 | from srcs.utils import write_conf, is_master, get_logger 10 | from srcs.logger import TensorboardWriter, EpochMetrics 11 | 12 | 13 | class BaseTrainer(metaclass=ABCMeta): 14 | """ 15 | Base class for all trainers 16 | """ 17 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 18 | self.config = config 19 | self.logger = get_logger('trainer') 20 | 21 | self.device = config.local_rank 22 | self.model = model.to(self.device) 23 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 24 | self.model = DistributedDataParallel(model, device_ids=[self.device], output_device=self.device) 25 | 26 | self.criterion = criterion 27 | self.metric_ftns = metric_ftns 28 | self.optimizer = optimizer 29 | 30 | cfg_trainer = config['trainer'] 31 | self.epochs = cfg_trainer['epochs'] 32 | self.log_step = cfg_trainer['logging_step'] 33 | 34 | # setup metric monitoring for monitoring model performance and saving best-checkpoint 35 | self.monitor = cfg_trainer.get('monitor', 'off') 36 | 37 | metric_names = ['loss'] + list(self.metric_ftns.keys()) 38 | self.ep_metrics = EpochMetrics(metric_names, phases=('train', 'valid'), monitoring=self.monitor) 39 | 40 | self.checkpt_top_k = cfg_trainer.get('save_topk', -1) 41 | self.early_stop = cfg_trainer.get('early_stop', inf) 42 | 43 | write_conf(self.config, 'config.yaml') 44 | 45 | self.start_epoch = 1 46 | self.checkpt_dir = Path(self.config.save_dir) 47 | log_dir = Path(self.config.log_dir) 48 | if is_master(): 49 | self.checkpt_dir.mkdir() 50 | # setup visualization writer instance 51 | log_dir.mkdir() 52 | self.writer = TensorboardWriter(log_dir, cfg_trainer['tensorboard']) 53 | else: 54 | self.writer = TensorboardWriter(log_dir, False) 55 | 56 | if config.resume is not None: 57 | self._resume_checkpoint(config.resume) 58 | 59 | @abstractmethod 60 | def _train_epoch(self, epoch): 61 | """ 62 | Training logic for an epoch 63 | 64 | :param epoch: Current epoch number 65 | """ 66 | raise NotImplementedError 67 | 68 | def train(self): 69 | """ 70 | Full training logic 71 | """ 72 | not_improved_count = 0 73 | for epoch in range(self.start_epoch, self.epochs + 1): 74 | result = self._train_epoch(epoch) 75 | self.ep_metrics.update(epoch, result) 76 | 77 | # print result metrics of this epoch 78 | max_line_width = max(len(line) for line in str(self.ep_metrics).splitlines()) 79 | # divider --- 80 | self.logger.info('-' * max_line_width) 81 | self.logger.info(str(self.ep_metrics.latest()) + '\n') 82 | 83 | if is_master(): 84 | # check if model performance improved or not, for early stopping and topk saving 85 | is_best = False 86 | improved = self.ep_metrics.is_improved() 87 | if improved: 88 | not_improved_count = 0 89 | is_best = True 90 | else: 91 | not_improved_count += 1 92 | 93 | if not_improved_count > self.early_stop: 94 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" 95 | "Training stops.".format(self.early_stop)) 96 | exit(1) 97 | 98 | using_topk_save = self.checkpt_top_k > 0 99 | self._save_checkpoint(epoch, save_best=is_best, save_latest=using_topk_save) 100 | # keep top-k checkpoints only, using monitoring metrics 101 | if using_topk_save: 102 | self.ep_metrics.keep_topk_checkpt(self.checkpt_dir, self.checkpt_top_k) 103 | 104 | self.ep_metrics.to_csv('epoch-results.csv') 105 | 106 | # divider === 107 | self.logger.info('=' * max_line_width) 108 | dist.barrier() 109 | 110 | 111 | def _save_checkpoint(self, epoch, save_best=False, save_latest=True): 112 | """ 113 | Saving checkpoints 114 | 115 | :param epoch: current epoch number 116 | :param log: logging information of the epoch 117 | :param save_best: if True, save a copy of current checkpoint file as 'model_best.pth' 118 | :param save_latest: if True, save a copy of current checkpoint file as 'model_latest.pth' 119 | """ 120 | arch = type(self.model).__name__ 121 | state = { 122 | 'arch': arch, 123 | 'epoch': epoch, 124 | 'state_dict': self.model.state_dict(), 125 | 'optimizer': self.optimizer.state_dict(), 126 | 'epoch_metrics': self.ep_metrics, 127 | 'config': self.config 128 | } 129 | 130 | filename = str(self.checkpt_dir / f'checkpoint-epoch{epoch}.pth') 131 | torch.save(state, filename) 132 | self.logger.info(f"Model checkpoint saved at: \n {self.config.cwd}/{filename}") 133 | if save_latest: 134 | latest_path = str(self.checkpt_dir / 'model_latest.pth') 135 | copyfile(filename, latest_path) 136 | if save_best: 137 | best_path = str(self.checkpt_dir / 'model_best.pth') 138 | copyfile(filename, best_path) 139 | self.logger.info(f"Renewing best checkpoint: \n .../{best_path}") 140 | 141 | def _resume_checkpoint(self, resume_path): 142 | """ 143 | Resume from saved checkpoints 144 | 145 | :param resume_path: Checkpoint path to be resumed 146 | """ 147 | resume_path = self.config.resume 148 | self.logger.info(f"Loading checkpoint: {resume_path} ...") 149 | checkpoint = torch.load(resume_path) 150 | self.start_epoch = checkpoint['epoch'] + 1 151 | 152 | # TODO: support overriding monitor-metric config 153 | self.ep_metrics = checkpoint['epoch_metrics'] 154 | 155 | # load architecture params from checkpoint. 156 | if checkpoint['config']['arch'] != self.config['arch']: 157 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 158 | "checkpoint. This may yield an exception while state_dict is being loaded.") 159 | self.model.load_state_dict(checkpoint['state_dict']) 160 | 161 | # load optimizer state from checkpoint only when optimizer type is not changed. 162 | if checkpoint['config']['optimizer']['_target_'] != self.config['optimizer']['_target_']: 163 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 164 | "Optimizer parameters not being resumed.") 165 | else: 166 | self.optimizer.load_state_dict(checkpoint['optimizer']) 167 | 168 | self.logger.info(f"Checkpoint loaded. 
Resume training from epoch {self.start_epoch}") 169 | -------------------------------------------------------------------------------- /srcs/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torchvision.utils import make_grid 4 | from .base import BaseTrainer 5 | from srcs.utils import inf_loop, collect 6 | from srcs.logger import BatchMetrics 7 | 8 | 9 | class Trainer(BaseTrainer): 10 | """ 11 | Trainer class 12 | """ 13 | def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, 14 | valid_data_loader=None, lr_scheduler=None, len_epoch=None): 15 | super().__init__(model, criterion, metric_ftns, optimizer, config) 16 | self.config = config 17 | self.data_loader = data_loader 18 | if len_epoch is None: 19 | # epoch-based training 20 | self.len_epoch = len(self.data_loader) 21 | else: 22 | # iteration-based training 23 | self.data_loader = inf_loop(data_loader) 24 | self.len_epoch = len_epoch 25 | self.valid_data_loader = valid_data_loader 26 | self.lr_scheduler = lr_scheduler 27 | 28 | args = ['loss', *[met_name for met_name in self.metric_ftns.keys()]] 29 | self.train_metrics = BatchMetrics(*args, postfix='/train', writer=self.writer) 30 | self.valid_metrics = BatchMetrics(*args, postfix='/valid', writer=self.writer) 31 | 32 | def _train_epoch(self, epoch): 33 | """ 34 | Training logic for an epoch 35 | 36 | :param epoch: Integer, current training epoch. 37 | :return: A log that contains average loss and metric in this epoch. 38 | """ 39 | self.model.train() 40 | self.train_metrics.reset() 41 | for batch_idx, (data, target) in enumerate(self.data_loader): 42 | data, target = data.to(self.device), target.to(self.device) 43 | 44 | self.optimizer.zero_grad() 45 | output = self.model(data) 46 | loss = self.criterion(output, target) 47 | loss.backward() 48 | self.optimizer.step() 49 | 50 | loss = collect(loss) 51 | self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) 52 | self.train_metrics.update('loss', loss) 53 | 54 | if batch_idx % self.log_step == 0: 55 | self.writer.add_image('train/input', make_grid(data.cpu(), nrow=8, normalize=True)) 56 | for met_name, met in self.metric_ftns.items(): 57 | metric = collect(met(output, target)) # average metric between processes 58 | self.train_metrics.update(met_name, metric) 59 | self.logger.info(f'Train Epoch: {epoch} {self._progress(batch_idx)} Loss: {loss:.6f}') 60 | 61 | if batch_idx == self.len_epoch: 62 | break 63 | log = self.train_metrics.result() 64 | 65 | if self.valid_data_loader is not None: 66 | val_log = self._valid_epoch(epoch) 67 | log.update(**val_log) 68 | 69 | if self.lr_scheduler is not None: 70 | self.lr_scheduler.step() 71 | 72 | # add result metrics on entire epoch to tensorboard 73 | self.writer.set_step(epoch) 74 | for k, v in log.items(): 75 | self.writer.add_scalar(k + '/epoch', v) 76 | return log 77 | 78 | def _valid_epoch(self, epoch): 79 | """ 80 | Validate after training an epoch 81 | 82 | :param epoch: Integer, current training epoch. 
83 | :return: A log that contains information about validation 84 | """ 85 | self.model.eval() 86 | self.valid_metrics.reset() 87 | with torch.no_grad(): 88 | for batch_idx, (data, target) in enumerate(self.valid_data_loader): 89 | data, target = data.to(self.device), target.to(self.device) 90 | 91 | output = self.model(data) 92 | loss = self.criterion(output, target) 93 | 94 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx) 95 | self.writer.add_image('valid/input', make_grid(data.cpu(), nrow=8, normalize=True)) 96 | self.valid_metrics.update('loss', collect(loss)) 97 | for met_name, met in self.metric_ftns.items(): 98 | self.valid_metrics.update(met_name, met(output, target)) 99 | 100 | # add histogram of model parameters to the tensorboard 101 | for name, p in self.model.named_parameters(): 102 | self.writer.add_histogram(name, p, bins='auto') 103 | return self.valid_metrics.result() 104 | 105 | def _progress(self, batch_idx): 106 | base = '[{}/{} ({:.0f}%)]' 107 | try: 108 | # epoch-based training 109 | total = len(self.data_loader.dataset) 110 | current = batch_idx * self.data_loader.batch_size 111 | if dist.is_initialized(): 112 | current *= dist.get_world_size() 113 | except AttributeError: 114 | # iteration-based training 115 | total = self.len_epoch 116 | current = batch_idx 117 | return base.format(current, total, 100.0 * current / total) 118 | -------------------------------------------------------------------------------- /srcs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | -------------------------------------------------------------------------------- /srcs/utils/util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | import torch 4 | import torch.distributed as dist 5 | from omegaconf import OmegaConf 6 | from pathlib import Path 7 | from itertools import repeat 8 | 9 | 10 | def is_master(): 11 | return not dist.is_initialized() or dist.get_rank() == 0 12 | 13 | def get_logger(name=None): 14 | if is_master(): 15 | # TODO: also configure logging for sub-processes(not master) 16 | hydra_conf = OmegaConf.load('.hydra/hydra.yaml') 17 | logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True)) 18 | return logging.getLogger(name) 19 | 20 | 21 | def collect(scalar): 22 | """ 23 | util function for DDP. 24 | syncronize a python scalar or pytorch scalar tensor between GPU processes. 25 | """ 26 | # move data to current device 27 | if not isinstance(scalar, torch.Tensor): 28 | scalar = torch.tensor(scalar) 29 | scalar = scalar.to(dist.get_rank()) 30 | 31 | # average value between devices 32 | dist.reduce(scalar, 0, dist.ReduceOp.SUM) 33 | return scalar.item() / dist.get_world_size() 34 | 35 | def inf_loop(data_loader): 36 | ''' wrapper function for endless data loader. 
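    Re-yields batches from the wrapped DataLoader indefinitely, so iteration-based
    training (the `len_epoch` option of `Trainer`) can draw any number of steps per epoch.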
''' 37 | for loader in repeat(data_loader): 38 | yield from loader 39 | 40 | def write_yaml(content, fname): 41 | with fname.open('wt') as handle: 42 | yaml.dump(content, handle, indent=2, sort_keys=False) 43 | 44 | def write_conf(config, save_path): 45 | save_path = Path(save_path) 46 | save_path.parent.mkdir(parents=True, exist_ok=True) 47 | config_dict = OmegaConf.to_container(config, resolve=True) 48 | write_yaml(config_dict, save_path) 49 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.distributed as dist 4 | import hydra 5 | from hydra.utils import instantiate 6 | from omegaconf import OmegaConf 7 | from pathlib import Path 8 | from srcs.trainer import Trainer 9 | from srcs.utils import get_logger 10 | 11 | 12 | # fix random seeds for reproducibility 13 | SEED = 123 14 | torch.manual_seed(SEED) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | np.random.seed(SEED) 18 | 19 | 20 | def train_worker(config): 21 | logger = get_logger('train') 22 | # setup data_loader instances 23 | data_loader, valid_data_loader = instantiate(config.data_loader) 24 | 25 | # build model. print it's structure and # trainable params. 26 | model = instantiate(config.arch) 27 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 28 | logger.info(model) 29 | logger.info(f'Trainable parameters: {sum([p.numel() for p in trainable_params])}') 30 | 31 | # get function handles of loss and metrics 32 | criterion = instantiate(config.loss) 33 | metrics = { 34 | met_name: instantiate(met) 35 | for met_name, met in config.metrics.items() 36 | } 37 | 38 | # build optimizer, learning rate scheduler. 39 | optimizer = instantiate(config.optimizer, model.parameters()) 40 | lr_scheduler = instantiate(config.lr_scheduler, optimizer) 41 | 42 | trainer = Trainer(model, criterion, metrics, optimizer, 43 | config=config, 44 | data_loader=data_loader, 45 | valid_data_loader=valid_data_loader, 46 | lr_scheduler=lr_scheduler) 47 | trainer.train() 48 | 49 | def init_worker(rank, ngpus, working_dir, config): 50 | # initialize training config 51 | config = OmegaConf.create(config) 52 | config.local_rank = rank 53 | config.cwd = working_dir 54 | # prevent access to non-existing keys 55 | OmegaConf.set_struct(config, True) 56 | 57 | dist.init_process_group( 58 | backend='nccl', 59 | init_method='tcp://127.0.0.1:34567', 60 | world_size=ngpus, 61 | rank=rank) 62 | torch.cuda.set_device(rank) 63 | 64 | # start training processes 65 | train_worker(config) 66 | 67 | @hydra.main(config_path='conf/', config_name='train', version_base='1.1') 68 | def main(config): 69 | n_gpu = torch.cuda.device_count() 70 | assert n_gpu, 'Can\'t find any GPU device on this machine.' 71 | 72 | working_dir = str(Path.cwd().relative_to(hydra.utils.get_original_cwd())) 73 | 74 | if config.resume is not None: 75 | config.resume = hydra.utils.to_absolute_path(config.resume) 76 | config = OmegaConf.to_yaml(config, resolve=True) 77 | torch.multiprocessing.spawn(init_worker, nprocs=n_gpu, args=(n_gpu, working_dir, config)) 78 | 79 | if __name__ == '__main__': 80 | # pylint: disable=no-value-for-parameter 81 | main() 82 | --------------------------------------------------------------------------------