├── .gitignore ├── README.md ├── config └── cifar100.yaml ├── expts ├── bs1024-opt │ └── tb_logs │ │ └── events.out.tfevents.1604953226.cgpu05.79229.0 ├── bs1024-warmup-opt │ └── tb_logs │ │ └── events.out.tfevents.1604953230.cgpu07.50299.0 ├── bs128-opt │ └── tb_logs │ │ └── events.out.tfevents.1604948513.cgpu15.13890.0 ├── bs2048-opt │ └── tb_logs │ │ └── events.out.tfevents.1604954305.cgpu05.43094.0 ├── bs2048-warmup-opt │ └── tb_logs │ │ └── events.out.tfevents.1604953824.cgpu05.66469.0 └── bs512-opt │ └── tb_logs │ └── events.out.tfevents.1604950792.cgpu09.61987.0 ├── models └── resnet.py ├── print_arch.py ├── sout └── __init__.py ├── submit.slr ├── submit_multinode.slr ├── train.py ├── train_simple.py ├── tutorial_images ├── acc1.png ├── acc2.png ├── acc3.png ├── bs128_learning.png ├── nsys_amp_nhwc_extra_zoomed.png ├── nsys_amp_nhwc_zoomed.png ├── nsys_amp_zoomed.png ├── nsys_amp_zoomed_kernels.png ├── nsys_baseline_full.png ├── nsys_baseline_zoomed.png └── throughputScaling.png └── utils ├── YParams.py └── cifar100_data_loader.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | expts 3 | sout 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SC20 Deep Learning at Scale Tutorial 2 | 3 | This repository contains the example code material for the SC20 tutorial: 4 | *Deep Learning at Scale*. 5 | 6 | The example demonstrates *synchronous data-parallel distributed training* of a 7 | convolutional deep neural network implemented in [PyTorch](https://pytorch.org/) 8 | on a standard computer vision problem. In particular, we are training ResNet50 9 | on the [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html) dataset to 10 | classify images into 100 classes. 11 | 12 | **Contents** 13 | * [Links](#links) 14 | * [Installation](#installation) 15 | * [Model, data, and training code overview](#model-data-and-training-code-overview) 16 | * [Single GPU training](#single-gpu-training) 17 | * [Performance profiling and optimization](#performance-profiling-and-optimization) 18 | * [Profiling with Nsight Systems](#profiling-with-nsight-systems) 19 | * [Enabling Mixed Precision Training](#enabling-mixed-precision-training) 20 | * [Applying additional PyTorch optimizations](#applying-additional-pytorch-optimizations) 21 | * [Distributed GPU training](#distributed-gpu-training) 22 | * [Code basics](#code-basics) 23 | * [Large batch convergence](#large-batch-convergence) 24 | 25 | ## Links 26 | 27 | Presentation slides for the tutorial can be found at: 28 | https://drive.google.com/drive/folders/1-gi1WvfQ6alDOnMwN3JqgNlrQh7MlIQr?usp=sharing 29 | 30 | We have a nersc-dl-tutorial slack you can join. Use this link and join the 31 | `#sc20-dl-tutorial` channel: 32 | https://join.slack.com/t/nersc-dl-tutorial/shared_invite/zt-iwyrkhza-h_Oun~8JO9xDKZynD5hERA 33 | 34 | ## Installation 35 | 36 | If you're running these examples on the Cori GPU system at NERSC, no 37 | installation is needed; you can simply use our provided modules or 38 | shifter containers. 39 | 40 | See the [submit.slr](submit.slr) slurm script for a simple example using 41 | our PyTorch 1.7.0 installation for GPU. 
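For example, a minimal workflow on Cori GPU is to load the same modules used in that script and submit it with Slurm (the module names below are copied from [submit.slr](submit.slr); adjust the account and time limit for your own allocation):
```
$ module load cgpu
$ module load pytorch/1.7.0-gpu
$ sbatch submit.slr
```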
42 | 43 | Otherwise, our package dependencies: 44 | - pytorch 1.7.0 45 | - torchvision 46 | - apex 47 | - ruamel.yaml 48 | - nccl 49 | 50 | ## Model, data, and training code overview 51 | 52 | The network architecture for our ResNet50 model can be found in 53 | [models/resnet.py](models/resnet.py). Here we have copied the ResNet50 54 | implementation from torchvision and made a few minor adjustments for the 55 | CIFAR dataset (e.g. reducing stride and pooling). 56 | 57 | The data pipeline code can be found in 58 | [utils/cifar100\_data\_loader.py](utils/cifar100_data_loader.py). 59 | Note that the dataset code for this example is fairly simple because the 60 | torchvision package provides the dataset class which handles the image 61 | loading for us. Key ingredients: 62 | * We compose a sequence of data transforms for normalization and random 63 | augmentation of the images. 64 | * We construct a `datasets.CIFAR100` dataset which will automatically download 65 | the dataset to a specified directory. We pass it our list of transforms. 66 | * We construct a DataLoader which orchestrates the random sampling and batching 67 | of our images. 68 | 69 | The basic training logic can be found in [train\_simple.py](train_simple.py). 70 | In this training script we have defined a simple Trainer class which 71 | implements methods for training and validation epochs. Key ingredients: 72 | * In the Trainer's `__init__` method, we get the data loaders, construct our 73 | ResNet50 model, the SGD optimizer, and our `CrossEntropyLoss` objective 74 | function. 75 | * In the Trainer's `train_one_epoch` method, we implement the actual logic for 76 | training the model on batches of data. 77 | * Identify where we loop over data batches from our data loader. 78 | * Identify where we apply the forward pass of the model ("Model forward pass") 79 | and compute the loss function. 80 | * Identify where we call `backward()` on the loss value. Note the use of the 81 | `grad_scaler` will be explained below when enabling mixed precision. 82 | * Similarly, in the Trainer's `validate_one_epoch`, we implement the simpler 83 | logic of applying the model to a validation dataset and compute metrics like 84 | accuracy. 85 | * Checkpoint saving and loading are implemented in the Trainer's `save_checkpoint` 86 | and `restore_checkpoint` methods, respectively. 87 | * We construct and use a TensorBoard SummaryWriter for logging metrics to 88 | visualize in TensorBoard. See if you can find where our specific metrics 89 | are logged via the `add_scalar` call. 90 | 91 | Besides the `train_simple.py` script, we have a more complex [train.py](train.py) 92 | script which implements the same functionality but also includes a lot of 93 | additional optimizations which will be covered in the Performance profiling and 94 | optimization section below. 95 | 96 | ## Single GPU training 97 | 98 | To run single GPU training of the baseline training script, use the following command: 99 | ``` 100 | $ python train.py --config=bs128 101 | ``` 102 | This will run the training on a single GPU using batch size of 128 103 | (see `config/cifar100.yaml` for specific configuration details). 104 | Note we will use batch size 256 for the optimization work in the next section 105 | and will push beyond to larger batch sizes in the distributed training section. 
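As a concrete reference for the data pipeline described in the overview above, the essential steps look roughly like the following. This is a minimal sketch rather than the exact code in [utils/cifar100\_data\_loader.py](utils/cifar100_data_loader.py): the normalization statistics, rotation angle, worker count, and data path mirror `config/cifar100.yaml` (values rounded), while the horizontal-flip augmentation is an illustrative assumption.
```
import torch
import torchvision
import torchvision.transforms as transforms

# Normalization statistics and rotation angle as in config/cifar100.yaml (rounded)
mean = (0.5071, 0.4865, 0.4409)
std = (0.2673, 0.2564, 0.2762)

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),   # illustrative augmentation choice
    transforms.RandomRotation(15),       # rnd_rotation_angle from the config
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# datasets.CIFAR100 downloads the data to the given directory if it is not present
train_set = torchvision.datasets.CIFAR100(root='./data', train=True,
                                          download=True, transform=train_transform)

# The DataLoader handles random sampling and batching
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128,
                                           shuffle=True, num_workers=6)
```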
106 | 107 | In the baseline configuration, the model converges to about 75% accuracy on 108 | the validation dataset in about 80 epochs: 109 | 110 | ![bs128 learning curves](tutorial_images/bs128_learning.png) 111 | 112 | ## Performance profiling and optimization 113 | 114 | This is the performance of the baseline script using the NGC PyTorch 20.10 container for the first two epochs on a 16GB V100 card with batch size 256: 115 | ``` 116 | INFO - Starting Training Loop... 117 | INFO - Epoch: 1, Iteration: 0, Avg img/sec: 110.19908073510402 118 | INFO - Epoch: 1, Iteration: 20, Avg img/sec: 680.8613838734273 119 | INFO - Epoch: 1, Iteration: 40, Avg img/sec: 682.4229819820212 120 | INFO - Epoch: 1, Iteration: 60, Avg img/sec: 683.0516710020236 121 | INFO - Epoch: 1, Iteration: 80, Avg img/sec: 681.2955112832597 122 | INFO - Epoch: 1, Iteration: 100, Avg img/sec: 681.7366420029032 123 | INFO - Epoch: 1, Iteration: 120, Avg img/sec: 680.9312458089512 124 | INFO - Epoch: 1, Iteration: 140, Avg img/sec: 680.2227561980723 125 | INFO - Epoch: 1, Iteration: 160, Avg img/sec: 680.6287580660272 126 | INFO - Epoch: 1, Iteration: 180, Avg img/sec: 680.7244649829499 127 | INFO - Time taken for epoch 1 is 79.90803146362305 sec 128 | INFO - Epoch: 2, Iteration: 0, Avg img/sec: 297.1326786725325 129 | INFO - Epoch: 2, Iteration: 20, Avg img/sec: 680.1821654149742 130 | INFO - Epoch: 2, Iteration: 40, Avg img/sec: 679.7391921357676 131 | INFO - Epoch: 2, Iteration: 60, Avg img/sec: 680.29168975637 132 | INFO - Epoch: 2, Iteration: 80, Avg img/sec: 680.2163354650426 133 | INFO - Epoch: 2, Iteration: 100, Avg img/sec: 680.1871635938127 134 | INFO - Epoch: 2, Iteration: 120, Avg img/sec: 679.7543395008651 135 | INFO - Epoch: 2, Iteration: 140, Avg img/sec: 679.708426128615 136 | INFO - Epoch: 2, Iteration: 160, Avg img/sec: 679.2982136487756 137 | INFO - Epoch: 2, Iteration: 180, Avg img/sec: 679.0788730107779 138 | INFO - Time taken for epoch 2 is 78.5151789188385 sec 139 | ``` 140 | 141 | ### Profiling with Nsight Systems 142 | Before generating a profile with Nsight, we can add NVTX ranges to the script to add context to the produced timeline. First, we can enable PyTorch's built-in NVTX annotations by using the `torch.autograd.profiler.emit_nvtx` context manager. 143 | We can also manually add some manually defined NVTX ranges to the code using `torch.cuda.nvtx.range_push` and `torch.cuda.nvtx.range_pop`. Search `train.py` for comments labeled `# PROF` to see where we've added code. 144 | As a quick note, we defined some simple functions to wrap the NVTX range calls in order to add synchronization: 145 | ``` 146 | def nvtx_range_push(name, enabled): 147 | if enabled: 148 | torch.cuda.synchronize() 149 | torch.cuda.nvtx.range_push(name) 150 | 151 | def nvtx_range_pop(enabled): 152 | if enabled: 153 | torch.cuda.synchronize() 154 | torch.cuda.nvtx.range_pop() 155 | ``` 156 | As GPU operations can be asynchronous with respect to the Python thread, these syncs are necessary to create accurate ranges. Without them, the ranges will only contain the time to _launch_ the GPU work. 
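As an illustration of how these wrappers are used, a single training iteration is bracketed like this (a simplified sketch mirroring the `# PROF` annotations in `train.py`; the argument names are placeholders, and the AMP gradient scaler used later is omitted):
```
def train_step(i, model, criterion, optimizer, data, device, profiling):
    """Sketch of one training iteration bracketed with the NVTX helpers above."""
    nvtx_range_push('iteration {}'.format(i), profiling)

    nvtx_range_push('data', profiling)
    images, labels = map(lambda x: x.to(device), data)   # host-to-device copies
    nvtx_range_pop(profiling)

    nvtx_range_push('forward/loss/backward', profiling)
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    nvtx_range_pop(profiling)

    nvtx_range_push('optimizer.step', profiling)
    optimizer.step()
    nvtx_range_pop(profiling)

    nvtx_range_pop(profiling)  # closes the 'iteration' range
    return loss
```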
157 | 158 | 159 | To generate a timeline, run the following: 160 | ``` 161 | $ NSYS_NVTX_PROFILER_REGISTER_ONLY=0 nsys profile -o baseline --trace=cuda,nvtx --capture-range=nvtx --nvtx-capture=PROFILE python -m torch.distributed.launch --nproc_per_node=1 train.py --config=bs256-prof 162 | ``` 163 | This command will run two shortened epochs of the training script (roughly 100 iterations each, as set by `profiling_iters_per_epoch` in the config) and produce a file `baseline.qdrep` that can be opened in the Nsight Systems GUI. The `--trace=cuda,nvtx` argument restricts tracing to CUDA and NVTX, skipping OS runtime tracing to reduce overhead. 164 | The arguments `--capture-range=nvtx --nvtx-capture=PROFILE` together with the environment variable `NSYS_NVTX_PROFILER_REGISTER_ONLY=0` limit profiling to the NVTX range named "PROFILE", which we use to profile only the second epoch. 165 | 166 | Loading this profile in Nsight Systems will look like this: 167 | ![Baseline](tutorial_images/nsys_baseline_full.png) 168 | 169 | With our NVTX ranges, we can easily zoom into a single iteration and get an idea of where compute time is being spent: 170 | ![Baseline Zoomed](tutorial_images/nsys_baseline_zoomed.png) 171 | 172 | 173 | ### Enabling Mixed Precision Training 174 | As a first step to improve the compute performance of this training script, we can enable automatic mixed precision (AMP) in PyTorch. AMP provides a simple way to convert existing FP32 training scripts to mixed FP32/FP16 precision, unlocking 175 | faster computation with Tensor Cores on NVIDIA GPUs. The AMP module in PyTorch consists of two main components: `torch.cuda.amp.GradScaler` and `torch.cuda.amp.autocast`. `torch.cuda.amp.GradScaler` handles automatic loss scaling to control the range of FP16 gradients. 176 | The `torch.cuda.amp.autocast` context manager handles converting model operations to FP16 where appropriate. Search `train.py` for comments labeled `# AMP:` to see where we've added code to enable AMP in this script. 177 | 178 | To run the script on a single GPU with AMP enabled, use the following command: 179 | ``` 180 | $ python -m torch.distributed.launch --nproc_per_node=1 train.py --config=bs256-amp 181 | ``` 182 | With AMP enabled, this is the performance of the script using the NGC PyTorch 20.10 container for the first two epochs on a 16GB V100 card: 183 | ``` 184 | INFO - Starting Training Loop...
185 | INFO - Epoch: 1, Iteration: 0, Avg img/sec: 131.4890829860097 186 | INFO - Epoch: 1, Iteration: 20, Avg img/sec: 1925.8088037080554 187 | INFO - Epoch: 1, Iteration: 40, Avg img/sec: 1884.341731901802 188 | INFO - Epoch: 1, Iteration: 60, Avg img/sec: 1796.3608488557659 189 | INFO - Epoch: 1, Iteration: 80, Avg img/sec: 1797.1991164491794 190 | INFO - Epoch: 1, Iteration: 100, Avg img/sec: 1794.721454602102 191 | INFO - Epoch: 1, Iteration: 120, Avg img/sec: 1800.0616660977953 192 | INFO - Epoch: 1, Iteration: 140, Avg img/sec: 1794.3491050370249 193 | INFO - Epoch: 1, Iteration: 160, Avg img/sec: 1797.8587343614402 194 | INFO - Epoch: 1, Iteration: 180, Avg img/sec: 1794.0956118635277 195 | INFO - Time taken for epoch 1 is 33.888301610946655 sec 196 | INFO - Epoch: 2, Iteration: 0, Avg img/sec: 397.0763949367613 197 | INFO - Epoch: 2, Iteration: 20, Avg img/sec: 1831.3360728112361 198 | INFO - Epoch: 2, Iteration: 40, Avg img/sec: 1804.6830246566537 199 | INFO - Epoch: 2, Iteration: 60, Avg img/sec: 1799.7809136620713 200 | INFO - Epoch: 2, Iteration: 80, Avg img/sec: 1793.427968035233 201 | INFO - Epoch: 2, Iteration: 100, Avg img/sec: 1794.953670200433 202 | INFO - Epoch: 2, Iteration: 120, Avg img/sec: 1795.3373776036665 203 | INFO - Epoch: 2, Iteration: 140, Avg img/sec: 1791.194021111478 204 | INFO - Epoch: 2, Iteration: 160, Avg img/sec: 1825.7166134675574 205 | INFO - Epoch: 2, Iteration: 180, Avg img/sec: 1794.5686271249087 206 | INFO - Time taken for epoch 2 is 33.07876420021057 sec 207 | ``` 208 | 209 | You can run another profile (using `--config=bs256-amp-prof`) with Nsight Systems. Loading this profile and zooming into a single iteration, this is what we see: 210 | ![AMP Zoomed](tutorial_images/nsys_amp_zoomed.png) 211 | 212 | With AMP enabled, we see that the `forward/loss/backward` time is significatly reduced. As this is a CNN, the forward and backward convolution ops are well-suited to benefit from acceleration with tensor cores. 213 | 214 | If we zoom into the forward section of the profile to the GPU kernels, we can see very many calls to `nchwToNhwc` and `nhwcToNCHW` kernels: 215 | ![AMP Zoomed Kernels](tutorial_images/nsys_amp_zoomed_kernels.png) 216 | 217 | These kernels are transposing the data from PyTorch's native data layout (NCHW or channels first) to the NHWC (or channels last) format which cuDNN requires to use tensor cores. Luckily, there is a way to avoid these transposes by using the `torch.channels_last` memory 218 | format. To use this, we need to convert both the model and the input image tensors to this format by using the following lines: 219 | ``` 220 | model = model.to(memory_format=torch.channels_last) 221 | images = images.to(memory_format=torch.channels_last) 222 | ``` 223 | Search `train.py` for comments labeled `# NHWC` to see where we've added these lines to run the model using NHWC format. 224 | 225 | To run the script on a single GPU with AMP enabled using the NHWC memory format, use the following command: 226 | ``` 227 | $ python -m torch.distributed.launch --nproc_per_node=1 train.py --config=bs256-amp-nhwc 228 | ``` 229 | With AMP enabled using the NHWC memory format, this is the performance of the script using the NGC PyTorch 20.10 container for the first two epochs on a 16GB V100 card: 230 | ``` 231 | INFO - Starting Training Loop... 
232 | INFO - Epoch: 1, Iteration: 0, Avg img/sec: 125.35020387731124 233 | INFO - Epoch: 1, Iteration: 20, Avg img/sec: 2089.3251919566933 234 | INFO - Epoch: 1, Iteration: 40, Avg img/sec: 2075.2397782670346 235 | INFO - Epoch: 1, Iteration: 60, Avg img/sec: 2078.1579609491064 236 | INFO - Epoch: 1, Iteration: 80, Avg img/sec: 2114.314909986603 237 | INFO - Epoch: 1, Iteration: 100, Avg img/sec: 2076.3754707171784 238 | INFO - Epoch: 1, Iteration: 120, Avg img/sec: 2066.673609844659 239 | INFO - Epoch: 1, Iteration: 140, Avg img/sec: 2070.3321011509784 240 | INFO - Epoch: 1, Iteration: 160, Avg img/sec: 2107.977617868012 241 | INFO - Epoch: 1, Iteration: 180, Avg img/sec: 2117.288989717637 242 | INFO - Time taken for epoch 1 is 30.756738424301147 sec 243 | INFO - Epoch: 2, Iteration: 0, Avg img/sec: 464.2617647745541 244 | INFO - Epoch: 2, Iteration: 20, Avg img/sec: 2151.947432559358 245 | INFO - Epoch: 2, Iteration: 40, Avg img/sec: 2208.417190923362 246 | INFO - Epoch: 2, Iteration: 60, Avg img/sec: 2177.7232959147427 247 | INFO - Epoch: 2, Iteration: 80, Avg img/sec: 2226.609558578422 248 | INFO - Epoch: 2, Iteration: 100, Avg img/sec: 2253.0767957237485 249 | INFO - Epoch: 2, Iteration: 120, Avg img/sec: 2137.2692109868517 250 | INFO - Epoch: 2, Iteration: 140, Avg img/sec: 2214.0994804791235 251 | INFO - Epoch: 2, Iteration: 160, Avg img/sec: 2195.9345278285564 252 | INFO - Epoch: 2, Iteration: 180, Avg img/sec: 2162.628100059094 253 | INFO - Time taken for epoch 2 is 28.39500093460083 sec 254 | ``` 255 | With the NCHW/NHWC tranposes removed, we see another modest gain in throughput. You can run another profile (using `--config=bs256-amp-nhwc-prof`) with Nsight Systems. Loading this profile and zooming into a single iteration, this is what we see now: 256 | ![AMP NHWC Zoomed](tutorial_images/nsys_amp_nhwc_zoomed.png) 257 | 258 | Using the NHWC memory format with AMP, we see that the `forward/loss/backward` times are reduced further due to no longer calling the transpose kernels. Now we can move onto some other small PyTorch-specific optimizations to deal with the remaining sections that stand out in the profile. 259 | 260 | ### Applying additional PyTorch optimizations 261 | With the forward and backward pass accelerated with AMP and NHWC memory layout, the remaining NVTX ranges we added to the profile stand out, namely the `zero_grad` marker and `optimizer.step`. 262 | 263 | To speed up the `zero_grad`, we can add the following argument to the `zero_grad` call: 264 | ``` 265 | self.model.zero_grad(set_to_none=True) 266 | ``` 267 | This optional argument allows PyTorch to skip memset operations to zero out gradients and also allows PyTorch to set gradients with a single write (`=` operator) instead of a read/write (`+=` operator). 268 | 269 | 270 | If we look closely at the `optimizer.step` range in the profile, we see that there are many indivdual pointwise operation kernels launched. To make this more efficient, we can replace the native PyTorch SGD optimizer with the `FusedSGD` optimizer from the `Apex` package, which fuses many of these pointwise 271 | operations. 272 | 273 | Finally, as a general optimization, we add the line `torch.backends.cudnn.benchmark = True` to the start of training to enable cuDNN autotuning. This will allow cuDNN to test and select algorithms that run fastest on your system/model. 274 | 275 | Search `train.py` for comments labeled `# EXTRA` to see where we've added changes for these additional optimizations. 
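Putting the pieces from this section and the previous one together, the core of an optimized training step looks roughly like this. It is a condensed sketch of what `train.py` does behind its `# AMP`, `# NHWC`, and `# EXTRA` comments: the data loader and distributed wrapper from the earlier sections are assumed to exist, and the hyperparameter values are taken from the bs256 configs.
```
import torch
import apex
import models.resnet

torch.backends.cudnn.benchmark = True                    # EXTRA: enable cuDNN autotuning

device = torch.device('cuda')
model = models.resnet.resnet50(num_classes=100).to(device)
model = model.to(memory_format=torch.channels_last)      # NHWC: channels-last weights
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = apex.optimizers.FusedSGD(model.parameters(), lr=0.14,
                                     momentum=0.9, weight_decay=5e-4)
grad_scaler = torch.cuda.amp.GradScaler()                # AMP: handles loss scaling

for images, labels in train_loader:                      # train_loader as built earlier
    images = images.to(device, memory_format=torch.channels_last)  # NHWC inputs
    labels = labels.to(device)

    model.zero_grad(set_to_none=True)                    # EXTRA: skip gradient memsets

    with torch.cuda.amp.autocast():                      # AMP: FP16 ops where appropriate
        outputs = model(images)
        loss = criterion(outputs, labels)

    grad_scaler.scale(loss).backward()                   # AMP: scaled backward pass
    grad_scaler.step(optimizer)                          # AMP: unscales grads, then steps
    grad_scaler.update()                                 # AMP: adjust the loss scale value
```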
276 | 277 | 278 | To run the script on a single GPU with AMP enabled, NHWC memory format and these additional optimizations, use the following command: 279 | ``` 280 | $ python -m torch.distributed.launch --nproc_per_node=1 train.py --config=bs256-amp-nhwc-extra-opts 281 | ``` 282 | With all these features enabled, this is the performance of the script using the NGC PyTorch 20.10 container for the first two epochs on a 16GB V100 card: 283 | ``` 284 | INFO - Starting Training Loop... 285 | INFO - Epoch: 1, Iteration: 0, Avg img/sec: 51.52879474970972 286 | INFO - Epoch: 1, Iteration: 20, Avg img/sec: 2428.815812361664 287 | INFO - Epoch: 1, Iteration: 40, Avg img/sec: 2471.928460752096 288 | INFO - Epoch: 1, Iteration: 60, Avg img/sec: 2461.6635515925623 289 | INFO - Epoch: 1, Iteration: 80, Avg img/sec: 2461.5230335547976 290 | INFO - Epoch: 1, Iteration: 100, Avg img/sec: 2470.371590429863 291 | INFO - Epoch: 1, Iteration: 120, Avg img/sec: 2462.8998420750218 292 | INFO - Epoch: 1, Iteration: 140, Avg img/sec: 2567.007655538539 293 | INFO - Epoch: 1, Iteration: 160, Avg img/sec: 2531.0173058079126 294 | INFO - Epoch: 1, Iteration: 180, Avg img/sec: 2577.144387068793 295 | INFO - Time taken for epoch 1 is 30.52899408340454 sec 296 | INFO - Epoch: 2, Iteration: 0, Avg img/sec: 410.57308753695185 297 | INFO - Epoch: 2, Iteration: 20, Avg img/sec: 2547.8182536936824 298 | INFO - Epoch: 2, Iteration: 40, Avg img/sec: 2519.104752035505 299 | INFO - Epoch: 2, Iteration: 60, Avg img/sec: 2529.822264348943 300 | INFO - Epoch: 2, Iteration: 80, Avg img/sec: 2539.450348785371 301 | INFO - Epoch: 2, Iteration: 100, Avg img/sec: 2533.167522740291 302 | INFO - Epoch: 2, Iteration: 120, Avg img/sec: 2542.63597641221 303 | INFO - Epoch: 2, Iteration: 140, Avg img/sec: 2502.990963521907 304 | INFO - Epoch: 2, Iteration: 160, Avg img/sec: 2525.3185224087124 305 | INFO - Epoch: 2, Iteration: 180, Avg img/sec: 2501.353650885946 306 | INFO - Time taken for epoch 2 is 25.808385372161865 sec 307 | ``` 308 | 309 | We can run a final profile with all the optimizations enabled (using `--config=bs256-amp-nhwc-extra-opts-prof`) with Nsight Systems. Loading this profile and zooming into a single iteration, this is what we see now: 310 | ![AMP NHWC Extra Zoomed](tutorial_images/nsys_amp_nhwc_extra_zoomed.png) 311 | With these additional optimizations enabled in PyTorch, we see the length of the `zero_grad` and `optimizer.step` ranges are greatly reduced, as well as a small improvement in the `forward/loss/backward` time. 312 | 313 | ## Distributed GPU training 314 | 315 | Now that we have model training code that is optimized for training on a single GPU, 316 | we are ready to utilize multiple GPUs and multiple nodes to accelerate the workflow 317 | with *distributed training*. We will use the recommended `DistributedDataParallel` 318 | wrapper in PyTorch with the NCCL backend for optimized communication operations on 319 | systems with NVIDIA GPUs. Refer to the PyTorch documentation for additional details 320 | on the distributed package: https://pytorch.org/docs/stable/distributed.html 321 | 322 | ### Code basics 323 | 324 | We use the `torch.distributed.launch` utility for launching training processes 325 | on one node, one per GPU. The [submit\_multinode.slr](submit_multinode.slr) 326 | script shows how we use the utility with SLURM to launch the tasks on each node 327 | in our system allocation. 
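Concretely, the per-node launch in that script reduces to a command of roughly this form, where `$node_rank` and `$master_addr` are derived from the SLURM environment by the surrounding script logic in [submit\_multinode.slr](submit_multinode.slr):
```
$ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=$SLURM_JOB_NUM_NODES \
    --node_rank=$node_rank --master_addr=$master_addr \
    train.py --config=bs2048-warmup-opt
```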
328 | 329 | In the [train.py](train.py) script, near the bottom in the main execution block, 330 | we set up the distributed backend. We use the environment variable initialization 331 | method, which is configured automatically for us when we use the `torch.distributed.launch` utility. 332 | 333 | In the `get_data_loader` function in 334 | [utils/cifar100\_data\_loader.py](utils/cifar100_data_loader.py), we use the 335 | DistributedSampler from PyTorch, which takes care of partitioning the dataset 336 | so that each training process sees a unique subset. 337 | 338 | In our Trainer's `__init__` method, after our ResNet50 model is constructed, 339 | we convert it to a distributed data parallel model by wrapping it as: 340 | 341 | self.model = DistributedDataParallel(self.model, ...) 342 | 343 | The DistributedDataParallel (DDP) model wrapper takes care of broadcasting 344 | initial model weights to all workers and performing an all-reduce on the gradients 345 | in the training backward pass to properly synchronize and update the model 346 | weights in the distributed setting. 347 | 348 | ### Large batch convergence 349 | 350 | To speed up training, we use larger batch sizes, spread across more GPUs, 351 | with larger learning rates. In particular, we increase from 1 to 8 and then 16 GPUs, 352 | scaling the batch size accordingly to 1024 and then 2048. 353 | The first thing we demonstrate here is increasing 354 | the learning rate according to the square-root scaling rule. The settings for 355 | batch sizes 512, 1024, and 2048 are in [config/cifar100.yaml](config/cifar100.yaml) 356 | under `bs512-opt`, `bs1024-opt`, and `bs2048-opt`, respectively. 357 | Viewing the accuracy plots in TensorBoard, we notice that convergence 358 | is worse at the larger batch sizes, i.e. we see a generalization gap: 359 | 360 | ![Accuracy for bs128, bs1024, bs2048](tutorial_images/acc1.png) 361 | 362 | Next, as suggested earlier in the presentation, we apply a linear learning rate 363 | warmup for these batch sizes. You can see where we compute the learning rate 364 | in the warmup phase in our Trainer's `train` method in the `train.py` script. 365 | Look for the comment, "Apply learning rate warmup". 366 | As shown in configs `bs1024-warmup-opt` and `bs2048-warmup-opt` in our 367 | `config/cifar100.yaml` file, we use 8 and 16 epochs for the warmup, 368 | respectively. 369 | 370 | Now we can see that the generalization gap closes and 371 | the larger batch sizes converge as well as the original batch size of 128: 372 | 373 | ![Accuracy for bs128, bs1024 warmup, bs2048 warmup](tutorial_images/acc2.png) 374 | 375 | Next, we look at the wallclock time to see that, indeed, using 376 | these tricks together results in much faster convergence: 377 | 378 | ![Accuracy vs time for bs128, bs1024 warmup, bs2048 warmup](tutorial_images/acc3.png) 379 | 380 | In particular, our batch size 128 run on 1 GPU takes about 32 min to converge, 381 | while our batch size 2048 run on 16 GPUs takes around 4 min. 382 | 383 | Finally, we look at the throughput (images/second) of our training runs as 384 | we do this weak scaling of the batch size and GPUs: 385 | 386 | ![Weak scaling training throughput](tutorial_images/throughputScaling.png) 387 | 388 | These plots show 81% scaling efficiency with respect to ideal scaling at 16 GPUs.
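For completeness, the learning rate schedule behind these runs can be summarized in a few lines. This is a sketch with illustrative helper names: the square-root scaling rule reproduces the `lr` values in the scaling configs approximately, and the warmup formula matches the "Apply learning rate warmup" logic in `train.py`.
```
import math

base_lr, base_bs = 0.1, 128          # the bs128 baseline settings

def scaled_lr(batch_size):
    # Square-root scaling rule: bs512 -> 0.2, bs1024 -> ~0.28, bs2048 -> 0.4
    return base_lr * math.sqrt(batch_size / base_bs)

def lr_at_epoch(epoch, target_lr, warmup_epochs):
    # Linear warmup over the first warmup_epochs epochs, as in Trainer.train()
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return target_lr * float(epoch + 1) / float(warmup_epochs)
    return target_lr
```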
389 | -------------------------------------------------------------------------------- /config/cifar100.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | 3 | lr: !!float 0.1 4 | momentum: !!float 0.9 5 | weight_decay: !!float 5e-4 6 | lr_warmup_epochs: 0 7 | 8 | #mean and std of cifar100 dataset 9 | cifar100_mean: [0.5070751592371323, 0.48654887331495095, 0.4409178433670343] 10 | cifar100_std: [0.2673342858792401, 0.2564384629170883, 0.27615047132568404] 11 | data_path: './data' # will be automatically downloaded if not available 12 | rnd_rotation_angle: 15 13 | num_data_workers: 6 14 | num_classes: 100 15 | 16 | max_epochs: 100 17 | batch_size: 128 18 | valid_batch_size_per_gpu: 128 19 | 20 | log_to_screen: !!bool True 21 | log_to_tensorboard: !!bool True 22 | save_checkpoint: !!bool True 23 | log_freq: 20 24 | 25 | enable_amp: !!bool False 26 | enable_nhwc: !!bool False 27 | enable_extra_opts: !!bool False 28 | 29 | enable_profiling: !!bool False 30 | profiling_epoch_start: 2 31 | profiling_iters_per_epoch: 100 32 | 33 | bs128: 34 | <<: *DEFAULT 35 | 36 | # --- Optimization experiments --- # 37 | bs256: 38 | <<: *DEFAULT 39 | lr: !!float 0.14 40 | batch_size: 256 41 | 42 | bs256-prof: 43 | <<: *DEFAULT 44 | lr: !!float 0.14 45 | batch_size: 256 46 | enable_profiling: True 47 | max_epochs: 2 48 | save_checkpoint: False 49 | 50 | bs256-amp: 51 | <<: *DEFAULT 52 | lr: !!float 0.14 53 | batch_size: 256 54 | enable_amp: True 55 | 56 | bs256-amp-prof: 57 | <<: *DEFAULT 58 | lr: !!float 0.14 59 | batch_size: 256 60 | enable_amp: True 61 | enable_profiling: True 62 | max_epochs: 2 63 | save_checkpoint: False 64 | 65 | bs256-amp-nhwc: 66 | <<: *DEFAULT 67 | lr: !!float 0.14 68 | batch_size: 256 69 | enable_amp: True 70 | enable_nhwc: True 71 | 72 | bs256-amp-nhwc-prof: 73 | <<: *DEFAULT 74 | lr: !!float 0.14 75 | batch_size: 256 76 | enable_amp: True 77 | enable_nhwc: True 78 | enable_profiling: True 79 | max_epochs: 2 80 | save_checkpoint: False 81 | 82 | bs256-amp-nhwc-extra-opts: 83 | <<: *DEFAULT 84 | lr: !!float 0.14 85 | batch_size: 256 86 | enable_amp: True 87 | enable_nhwc: True 88 | enable_extra_opts: True 89 | 90 | bs256-amp-nhwc-extra-opts-prof: 91 | <<: *DEFAULT 92 | lr: !!float 0.14 93 | batch_size: 256 94 | enable_amp: True 95 | enable_nhwc: True 96 | enable_extra_opts: True 97 | enable_profiling: True 98 | max_epochs: 2 99 | save_checkpoint: False 100 | 101 | # --- Scaling experiments --- # 102 | bs128-opt: 103 | <<: *DEFAULT 104 | enable_amp: True 105 | enable_nhwc: True 106 | enable_extra_opts: True 107 | 108 | bs512-opt: 109 | <<: *DEFAULT 110 | lr: !!float 0.2 111 | batch_size: 512 112 | max_epochs: 120 113 | enable_amp: True 114 | enable_nhwc: True 115 | enable_extra_opts: True 116 | 117 | bs1024-opt: 118 | <<: *DEFAULT 119 | lr: !!float 0.28 120 | batch_size: 1024 121 | max_epochs: 120 122 | enable_amp: True 123 | enable_nhwc: True 124 | enable_extra_opts: True 125 | 126 | bs2048-opt: 127 | <<: *DEFAULT 128 | lr: !!float 0.4 129 | batch_size: 2048 130 | max_epochs: 120 131 | enable_amp: True 132 | enable_nhwc: True 133 | enable_extra_opts: True 134 | 135 | bs1024-warmup-opt: 136 | <<: *DEFAULT 137 | lr: !!float 0.28 138 | batch_size: 1024 139 | max_epochs: 120 140 | lr_warmup_epochs: 8 141 | enable_amp: True 142 | enable_nhwc: True 143 | enable_extra_opts: True 144 | 145 | bs2048-warmup-opt: 146 | <<: *DEFAULT 147 | lr: !!float 0.4 148 | batch_size: 2048 149 | lr_warmup_epochs: 16 150 | max_epochs: 120 151 | 
enable_amp: True 152 | enable_nhwc: True 153 | enable_extra_opts: True 154 | -------------------------------------------------------------------------------- /expts/bs1024-opt/tb_logs/events.out.tfevents.1604953226.cgpu05.79229.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs1024-opt/tb_logs/events.out.tfevents.1604953226.cgpu05.79229.0 -------------------------------------------------------------------------------- /expts/bs1024-warmup-opt/tb_logs/events.out.tfevents.1604953230.cgpu07.50299.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs1024-warmup-opt/tb_logs/events.out.tfevents.1604953230.cgpu07.50299.0 -------------------------------------------------------------------------------- /expts/bs128-opt/tb_logs/events.out.tfevents.1604948513.cgpu15.13890.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs128-opt/tb_logs/events.out.tfevents.1604948513.cgpu15.13890.0 -------------------------------------------------------------------------------- /expts/bs2048-opt/tb_logs/events.out.tfevents.1604954305.cgpu05.43094.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs2048-opt/tb_logs/events.out.tfevents.1604954305.cgpu05.43094.0 -------------------------------------------------------------------------------- /expts/bs2048-warmup-opt/tb_logs/events.out.tfevents.1604953824.cgpu05.66469.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs2048-warmup-opt/tb_logs/events.out.tfevents.1604953824.cgpu05.66469.0 -------------------------------------------------------------------------------- /expts/bs512-opt/tb_logs/events.out.tfevents.1604950792.cgpu09.61987.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/expts/bs512-opt/tb_logs/events.out.tfevents.1604950792.cgpu09.61987.0 -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Small modification of torchvision.models.resnet.py. 
Copied on 10/21/2020 2 | import torch 3 | import torch.nn as nn 4 | # from .utils import load_state_dict_from_url 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 9 | 'wide_resnet50_2', 'wide_resnet101_2'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 21 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=dilation, groups=groups, bias=False, dilation=dilation) 29 | 30 | 31 | def conv1x1(in_planes, out_planes, stride=1): 32 | """1x1 convolution""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 78 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 79 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 80 | # This variant is also known as ResNet V1.5 and improves accuracy according to 81 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
82 | 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 86 | base_width=64, dilation=1, norm_layer=None): 87 | super(Bottleneck, self).__init__() 88 | if norm_layer is None: 89 | norm_layer = nn.BatchNorm2d 90 | width = int(planes * (base_width / 64.)) * groups 91 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 92 | self.conv1 = conv1x1(inplanes, width) 93 | self.bn1 = norm_layer(width) 94 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 95 | self.bn2 = norm_layer(width) 96 | self.conv3 = conv1x1(width, planes * self.expansion) 97 | self.bn3 = norm_layer(planes * self.expansion) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | identity = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv3(out) 114 | out = self.bn3(out) 115 | 116 | if self.downsample is not None: 117 | identity = self.downsample(x) 118 | 119 | out += identity 120 | out = self.relu(out) 121 | 122 | return out 123 | 124 | 125 | class ResNet(nn.Module): 126 | 127 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 128 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 129 | norm_layer=None): 130 | super(ResNet, self).__init__() 131 | if norm_layer is None: 132 | norm_layer = nn.BatchNorm2d 133 | self._norm_layer = norm_layer 134 | 135 | self.inplanes = 64 136 | self.dilation = 1 137 | if replace_stride_with_dilation is None: 138 | # each element in the tuple indicates if we should replace 139 | # the 2x2 stride with a dilated convolution instead 140 | replace_stride_with_dilation = [False, False, False] 141 | if len(replace_stride_with_dilation) != 3: 142 | raise ValueError("replace_stride_with_dilation should be None " 143 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 144 | self.groups = groups 145 | self.base_width = width_per_group 146 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3, 147 | bias=False) 148 | self.bn1 = norm_layer(self.inplanes) 149 | self.relu = nn.ReLU(inplace=True) 150 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 151 | self.layer1 = self._make_layer(block, 64, layers[0]) 152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 153 | dilate=replace_stride_with_dilation[0]) 154 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 155 | dilate=replace_stride_with_dilation[1]) 156 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 157 | dilate=replace_stride_with_dilation[2]) 158 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 159 | self.fc = nn.Linear(512 * block.expansion, num_classes) 160 | 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 164 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 165 | nn.init.constant_(m.weight, 1) 166 | nn.init.constant_(m.bias, 0) 167 | 168 | # Zero-initialize the last BN in each residual branch, 169 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
170 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 171 | if zero_init_residual: 172 | for m in self.modules(): 173 | if isinstance(m, Bottleneck): 174 | nn.init.constant_(m.bn3.weight, 0) 175 | elif isinstance(m, BasicBlock): 176 | nn.init.constant_(m.bn2.weight, 0) 177 | 178 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 179 | norm_layer = self._norm_layer 180 | downsample = None 181 | previous_dilation = self.dilation 182 | if dilate: 183 | self.dilation *= stride 184 | stride = 1 185 | if stride != 1 or self.inplanes != planes * block.expansion: 186 | downsample = nn.Sequential( 187 | conv1x1(self.inplanes, planes * block.expansion, stride), 188 | norm_layer(planes * block.expansion), 189 | ) 190 | 191 | layers = [] 192 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 193 | self.base_width, previous_dilation, norm_layer)) 194 | self.inplanes = planes * block.expansion 195 | for _ in range(1, blocks): 196 | layers.append(block(self.inplanes, planes, groups=self.groups, 197 | base_width=self.base_width, dilation=self.dilation, 198 | norm_layer=norm_layer)) 199 | 200 | return nn.Sequential(*layers) 201 | 202 | def _forward_impl(self, x): 203 | # See note [TorchScript super()] 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | # x = self.maxpool(x) 208 | 209 | x = self.layer1(x) 210 | x = self.layer2(x) 211 | x = self.layer3(x) 212 | x = self.layer4(x) 213 | 214 | x = self.avgpool(x) 215 | x = torch.flatten(x, 1) 216 | x = self.fc(x) 217 | 218 | return x 219 | 220 | def forward(self, x): 221 | return self._forward_impl(x) 222 | 223 | 224 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 225 | model = ResNet(block, layers, **kwargs) 226 | if pretrained: 227 | state_dict = load_state_dict_from_url(model_urls[arch], 228 | progress=progress) 229 | model.load_state_dict(state_dict) 230 | return model 231 | 232 | 233 | def resnet18(pretrained=False, progress=True, **kwargs): 234 | r"""ResNet-18 model from 235 | `"Deep Residual Learning for Image Recognition" `_ 236 | 237 | Args: 238 | pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | progress (bool): If True, displays a progress bar of the download to stderr 240 | """ 241 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 242 | **kwargs) 243 | 244 | 245 | def resnet34(pretrained=False, progress=True, **kwargs): 246 | r"""ResNet-34 model from 247 | `"Deep Residual Learning for Image Recognition" `_ 248 | 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | progress (bool): If True, displays a progress bar of the download to stderr 252 | """ 253 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 254 | **kwargs) 255 | 256 | 257 | def resnet50(pretrained=False, progress=True, **kwargs): 258 | r"""ResNet-50 model from 259 | `"Deep Residual Learning for Image Recognition" `_ 260 | 261 | Args: 262 | pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | progress (bool): If True, displays a progress bar of the download to stderr 264 | """ 265 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 266 | **kwargs) 267 | 268 | 269 | def resnet101(pretrained=False, progress=True, **kwargs): 270 | r"""ResNet-101 model from 271 | `"Deep Residual Learning for Image Recognition" `_ 272 | 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 
| progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnet152(pretrained=False, progress=True, **kwargs): 282 | r"""ResNet-152 model from 283 | `"Deep Residual Learning for Image Recognition" `_ 284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | """ 289 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 290 | **kwargs) 291 | 292 | 293 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 294 | r"""ResNeXt-50 32x4d model from 295 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | kwargs['groups'] = 32 302 | kwargs['width_per_group'] = 4 303 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 304 | pretrained, progress, **kwargs) 305 | 306 | 307 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 308 | r"""ResNeXt-101 32x8d model from 309 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 310 | 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | """ 315 | kwargs['groups'] = 32 316 | kwargs['width_per_group'] = 8 317 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 318 | pretrained, progress, **kwargs) 319 | 320 | 321 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 322 | r"""Wide ResNet-50-2 model from 323 | `"Wide Residual Networks" `_ 324 | 325 | The model is the same as ResNet except for the bottleneck number of channels 326 | which is twice larger in every block. The number of channels in outer 1x1 327 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 328 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 329 | 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | kwargs['width_per_group'] = 64 * 2 335 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 336 | pretrained, progress, **kwargs) 337 | 338 | 339 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 340 | r"""Wide ResNet-101-2 model from 341 | `"Wide Residual Networks" `_ 342 | 343 | The model is the same as ResNet except for the bottleneck number of channels 344 | which is twice larger in every block. The number of channels in outer 1x1 345 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 346 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
347 | 348 | Args: 349 | pretrained (bool): If True, returns a model pre-trained on ImageNet 350 | progress (bool): If True, displays a progress bar of the download to stderr 351 | """ 352 | kwargs['width_per_group'] = 64 * 2 353 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 354 | pretrained, progress, **kwargs) 355 | -------------------------------------------------------------------------------- /print_arch.py: -------------------------------------------------------------------------------- 1 | from torchsummary import summary 2 | import models.resnet 3 | import torchvision.models 4 | 5 | model = models.resnet.resnet50(num_classes=100) 6 | # model = torchvision.models.resnet50(num_classes=100) 7 | summary(model, (3, 32, 32)) 8 | -------------------------------------------------------------------------------- /sout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/sout/__init__.py -------------------------------------------------------------------------------- /submit.slr: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH -C gpu 3 | #SBATCH -A m1759 4 | #SBATCH --time=4:00:00 5 | #SBATCH --nodes=1 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --gpus-per-task=1 8 | #SBATCH --cpus-per-task=10 9 | #SBATCH -o sout/%j.out 10 | 11 | # Configuration 12 | nproc_per_node=1 13 | config=bs128-opt 14 | 15 | # Load software 16 | module load cgpu 17 | module load pytorch/1.7.0-gpu 18 | 19 | # Launch one SLURM task, and use torch distributed launch utility 20 | # to spawn training worker processes; one per GPU 21 | srun -N 1 -n 1 python -m torch.distributed.launch --nproc_per_node=$nproc_per_node \ 22 | train.py --config=$config 23 | -------------------------------------------------------------------------------- /submit_multinode.slr: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH -C gpu 3 | #SBATCH --time=00:30:00 4 | #SBATCH --nodes=2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-task=8 7 | #SBATCH --cpus-per-task=80 8 | #SBATCH -o sout/%j.out 9 | 10 | # Configuration 11 | nproc_per_node=8 12 | config=bs2048-warmup-opt 13 | 14 | # Load software 15 | module load cgpu 16 | module load pytorch/1.7.0-gpu 17 | 18 | # Setup node list 19 | nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) # Getting the node names 20 | nodes_array=( $nodes ) 21 | master_node=${nodes_array[0]} 22 | master_addr=$(srun --nodes=1 --ntasks=1 -w $master_node hostname --ip-address) 23 | worker_num=$(($SLURM_JOB_NUM_NODES)) 24 | 25 | # Loop over nodes and submit training tasks 26 | for (( node_rank=0; node_rank<$worker_num; node_rank++ )) 27 | do 28 | node=${nodes_array[$node_rank]} 29 | echo "Submitting node # $node_rank, $node" 30 | 31 | # Launch one SLURM task per node, and use torch distributed launch utility 32 | # to spawn training worker processes; one per GPU 33 | srun -N 1 -n 1 -w $node python -m torch.distributed.launch \ 34 | --nproc_per_node=$nproc_per_node --nnodes=$SLURM_JOB_NUM_NODES \ 35 | --node_rank=$node_rank --master_addr=$master_addr \ 36 | train.py --config=$config & 37 | 38 | pids[${node_rank}]=$! 
39 | done 40 | 41 | # Wait for completion 42 | for pid in ${pids[*]}; do 43 | wait $pid 44 | done 45 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.nn.parallel import DistributedDataParallel 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import logging 11 | logging.basicConfig(format='%(levelname)s - %(message)s', level=logging.INFO) 12 | 13 | import models.resnet 14 | from utils.YParams import YParams 15 | from utils.cifar100_data_loader import get_data_loader 16 | 17 | import apex 18 | 19 | # PROF: define wrapped NVTX range routines with device syncs 20 | def nvtx_range_push(name, enabled): 21 | if enabled: 22 | torch.cuda.synchronize() 23 | torch.cuda.nvtx.range_push(name) 24 | 25 | def nvtx_range_pop(enabled): 26 | if enabled: 27 | torch.cuda.synchronize() 28 | torch.cuda.nvtx.range_pop() 29 | 30 | class Trainer(): 31 | 32 | def __init__(self, params): 33 | self.params = params 34 | self.device = torch.cuda.current_device() 35 | # AMP: Construct GradScaler for loss scaling 36 | self.grad_scaler = torch.cuda.amp.GradScaler(enabled=self.params.enable_amp) 37 | self.profiler_running = False 38 | 39 | # first constrcut the dataloader on rank0 in case the data is not downloaded 40 | if params.world_rank == 0: 41 | logging.info('rank %d, begin data loader init'%params.world_rank) 42 | self.train_data_loader, self.train_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=True) 43 | self.valid_data_loader, self.valid_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=False) 44 | logging.info('rank %d, data loader initialized'%params.world_rank) 45 | 46 | # wait for rank0 to finish downloading the data 47 | if dist.is_initialized(): 48 | dist.barrier() 49 | 50 | # now construct the dataloaders on other ranks 51 | if params.world_rank != 0: 52 | logging.info('rank %d, begin data loader init'%params.world_rank) 53 | self.train_data_loader, self.train_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=True) 54 | self.valid_data_loader, self.valid_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=False) 55 | logging.info('rank %d, data loader initialized'%params.world_rank) 56 | 57 | self.model = models.resnet.resnet50(num_classes=params.num_classes).to(self.device) 58 | 59 | if self.params.enable_nhwc: 60 | # NHWC: Convert model to channels_last memory format 61 | self.model = self.model.to(memory_format=torch.channels_last) 62 | 63 | if self.params.enable_extra_opts: 64 | # EXTRA: use Apex FusedSGD optimizer 65 | self.optimizer = apex.optimizers.FusedSGD(self.model.parameters(), lr=params.lr, 66 | momentum=params.momentum, weight_decay=params.weight_decay) 67 | else: 68 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=params.lr, 69 | momentum=params.momentum, weight_decay=params.weight_decay) 70 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=10, mode='min') 71 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device) 72 | 73 | if dist.is_initialized(): 74 | self.model = DistributedDataParallel(self.model, 75 | device_ids=[params.local_rank], 76 | output_device=[params.local_rank]) 77 | self.iters = 0 78 | self.startEpoch = 0 79 | if 
params.resuming: 80 | logging.info("Loading checkpoint %s"%params.checkpoint_path) 81 | self.restore_checkpoint(params.checkpoint_path) 82 | self.epoch = self.startEpoch 83 | 84 | if params.log_to_screen: 85 | logging.info(self.model) 86 | 87 | if params.log_to_tensorboard: 88 | self.writer = SummaryWriter(os.path.join(params.experiment_dir, 'tb_logs')) 89 | 90 | def train(self): 91 | if self.params.log_to_screen: 92 | logging.info("Starting Training Loop...") 93 | 94 | for epoch in range(self.startEpoch, self.params.max_epochs): 95 | if self.params.enable_profiling and epoch + 1 == self.params.profiling_epoch_start: 96 | # PROF: create range to control profiler start and stop 97 | self.profiler_running = True 98 | nvtx_range_push('PROFILE', self.profiler_running) 99 | 100 | if dist.is_initialized(): 101 | self.train_sampler.set_epoch(epoch) 102 | self.valid_sampler.set_epoch(epoch) 103 | 104 | # Apply learning rate warmup 105 | if epoch < params.lr_warmup_epochs: 106 | self.optimizer.param_groups[0]['lr'] = params.lr*float(epoch+1.)/float(params.lr_warmup_epochs) 107 | 108 | start = time.time() 109 | # PROF: Add custom NVTX ranges 110 | nvtx_range_push('epoch {}'.format(self.epoch), self.profiler_running) 111 | # PROF: Enable torch built-in NVTX ranges. Disabled for this example to reduce profiling overhead. 112 | with torch.autograd.profiler.emit_nvtx(enabled=False):#enabled=self.profiler_running): 113 | train_logs = self.train_one_epoch() 114 | nvtx_range_pop(self.profiler_running) 115 | valid_time, valid_logs = self.validate_one_epoch() 116 | if epoch >= params.lr_warmup_epochs: 117 | self.scheduler.step(valid_logs['loss']) 118 | 119 | if self.params.world_rank == 0: 120 | if self.params.save_checkpoint: 121 | #checkpoint at the end of every epoch 122 | self.save_checkpoint(self.params.checkpoint_path) 123 | 124 | if self.params.log_to_tensorboard: 125 | self.writer.add_scalar('loss/train', train_logs['loss'], self.epoch) 126 | self.writer.add_scalar('loss/valid', valid_logs['loss'], self.epoch) 127 | self.writer.add_scalar('acc1/train', train_logs['acc1'], self.epoch) 128 | self.writer.add_scalar('acc1/valid', valid_logs['acc1'], self.epoch) 129 | self.writer.add_scalar('learning_rate', self.optimizer.param_groups[0]['lr'], self.epoch) 130 | 131 | if self.params.log_to_screen: 132 | logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 133 | logging.info('train acc1={}, valid acc1={}'.format(train_logs['acc1'], valid_logs['acc1'])) 134 | 135 | if self.params.enable_profiling: 136 | nvtx_range_pop(self.profiler_running) 137 | self.profiler_running = False 138 | 139 | def train_one_epoch(self): 140 | self.epoch += 1 141 | torch.cuda.synchronize() 142 | report_time = time.time() 143 | report_bs = 0 144 | 145 | # Loop over training data batches 146 | for i, data in enumerate(self.train_data_loader, 0): 147 | # PROF: Add custom NVTX ranges 148 | nvtx_range_push('iteration {}'.format(i), self.profiler_running) 149 | self.iters += 1 150 | 151 | # PROF: Add custom NVTX ranges 152 | nvtx_range_push('data', self.profiler_running) 153 | # Move our images and labels to GPU 154 | images, labels = map(lambda x: x.to(self.device), data) 155 | # NHWC: Convert input images to channels_last memory format 156 | if self.params.enable_nhwc: 157 | images = images.to(memory_format=torch.channels_last) 158 | nvtx_range_pop(self.profiler_running) 159 | 160 | # PROF: Add custom NVTX ranges 161 | nvtx_range_push('zero_grad', self.profiler_running) 162 | if 
self.params.enable_extra_opts: 163 | # EXTRA: Use set_to_none option to avoid slow memsets to zero 164 | self.model.zero_grad(set_to_none=True) 165 | else: 166 | self.model.zero_grad() 167 | nvtx_range_pop(self.profiler_running) 168 | self.model.train() 169 | 170 | # PROF: Add custom NVTX ranges 171 | nvtx_range_push('forward/loss/backward', self.profiler_running) 172 | # AMP: Add autocast context manager 173 | with torch.cuda.amp.autocast(enabled=self.params.enable_amp): 174 | 175 | # Model forward pass and loss computation 176 | outputs = self.model(images) 177 | loss = self.criterion(outputs, labels) 178 | 179 | # AMP: Use GradScaler to scale loss and run backward to produce scaled gradients 180 | self.grad_scaler.scale(loss).backward() 181 | nvtx_range_pop(self.profiler_running) 182 | 183 | # PROF: Add custom NVTX ranges 184 | nvtx_range_push('optimizer.step', self.profiler_running) 185 | # AMP: Run optimizer step through GradScaler (unscales gradients and skips steps if required) 186 | self.grad_scaler.step(self.optimizer) 187 | nvtx_range_pop(self.profiler_running) 188 | 189 | # AMP: Update GradScaler loss scale value 190 | self.grad_scaler.update() 191 | 192 | torch.cuda.synchronize() 193 | nvtx_range_pop(self.profiler_running) 194 | 195 | report_bs += len(images) 196 | 197 | if i % self.params.log_freq == 0: 198 | torch.cuda.synchronize() 199 | logging.info('Epoch: {}, Iteration: {}, Avg img/sec: {}'.format(self.epoch, i, report_bs / (time.time() - report_time))) 200 | report_time = time.time() 201 | report_bs = 0 202 | 203 | if self.params.enable_profiling and i >= self.params.profiling_iters_per_epoch: 204 | break 205 | 206 | # save metrics of last batch 207 | _, preds = outputs.max(1) 208 | acc1 = preds.eq(labels).sum().float()/labels.shape[0] 209 | logs = {'loss': loss, 210 | 'acc1': acc1} 211 | 212 | if dist.is_initialized(): 213 | for key in sorted(logs.keys()): 214 | dist.all_reduce(logs[key].detach()) 215 | logs[key] = float(logs[key]/dist.get_world_size()) 216 | 217 | return logs 218 | 219 | def validate_one_epoch(self): 220 | self.model.eval() 221 | 222 | valid_start = time.time() 223 | loss = 0.0 224 | correct = 0.0 225 | with torch.no_grad(): 226 | for data in self.valid_data_loader: 227 | images, labels = map(lambda x: x.to(self.device), data) 228 | outputs = self.model(images) 229 | loss += self.criterion(outputs, labels) 230 | _, preds = outputs.max(1) 231 | correct += preds.eq(labels).sum().float()/labels.shape[0] 232 | 233 | logs = {'loss': loss/len(self.valid_data_loader), 234 | 'acc1': correct/len(self.valid_data_loader)} 235 | valid_time = time.time() - valid_start 236 | 237 | if dist.is_initialized(): 238 | for key in sorted(logs.keys()): 239 | logs[key] = torch.as_tensor(logs[key]).to(self.device) 240 | dist.all_reduce(logs[key].detach()) 241 | logs[key] = float(logs[key]/dist.get_world_size()) 242 | 243 | return valid_time, logs 244 | 245 | def save_checkpoint(self, checkpoint_path, model=None): 246 | """ We intentionally require a checkpoint_dir to be passed 247 | in order to allow Ray Tune to use this function """ 248 | 249 | if not model: 250 | model = self.model 251 | 252 | torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(), 253 | 'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path) 254 | 255 | def restore_checkpoint(self, checkpoint_path): 256 | """ We intentionally require a checkpoint_dir to be passed 257 | in order to allow Ray Tune to use this function """ 258 | checkpoint = 
torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank)) 259 | self.model.load_state_dict(checkpoint['model_state']) 260 | self.iters = checkpoint['iters'] 261 | self.startEpoch = checkpoint['epoch'] + 1 262 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 263 | 264 | if __name__ == '__main__': 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument("--local_rank", default=0, type=int) 267 | parser.add_argument("--yaml_config", default='./config/cifar100.yaml', type=str) 268 | parser.add_argument("--config", default='default', type=str) 269 | args = parser.parse_args() 270 | 271 | params = YParams(os.path.abspath(args.yaml_config), args.config) 272 | 273 | # setup distributed training variables and initialize cluster if using 274 | params['world_size'] = 1 275 | if 'WORLD_SIZE' in os.environ: 276 | params['world_size'] = int(os.environ['WORLD_SIZE']) 277 | 278 | params['local_rank'] = args.local_rank 279 | params['world_rank'] = 0 280 | if params['world_size'] > 1: 281 | torch.cuda.set_device(args.local_rank) 282 | dist.init_process_group(backend='nccl', 283 | init_method='env://') 284 | params['world_rank'] = dist.get_rank() 285 | params['global_batch_size'] = params.batch_size 286 | params['batch_size'] = int(params.batch_size//params['world_size']) 287 | 288 | # EXTRA: enable cuDNN autotuning. 289 | if params.enable_extra_opts: 290 | torch.backends.cudnn.benchmark = True 291 | 292 | # setup output directory 293 | expDir = os.path.join('./expts', args.config) 294 | if params.world_rank==0: 295 | if not os.path.isdir(expDir): 296 | os.makedirs(expDir) 297 | os.makedirs(os.path.join(expDir, 'checkpoints/')) 298 | 299 | params['experiment_dir'] = os.path.abspath(expDir) 300 | params['checkpoint_path'] = os.path.join(expDir, 'checkpoints/ckpt.tar') 301 | params['resuming'] = True if os.path.isfile(params.checkpoint_path) else False 302 | 303 | if params.world_rank==0: 304 | params.log() 305 | params['log_to_screen'] = params.log_to_screen and params.world_rank==0 306 | params['log_to_tensorboard'] = params.log_to_tensorboard and params.world_rank==0 307 | 308 | trainer = Trainer(params) 309 | trainer.train() 310 | -------------------------------------------------------------------------------- /train_simple.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.nn.parallel import DistributedDataParallel 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import logging 11 | logging.basicConfig(format='%(levelname)s - %(message)s', level=logging.INFO) 12 | 13 | import models.resnet 14 | from utils.YParams import YParams 15 | from utils.cifar100_data_loader import get_data_loader 16 | 17 | class Trainer(): 18 | 19 | def __init__(self, params): 20 | self.params = params 21 | self.device = torch.cuda.current_device() 22 | 23 | # first construct the dataloader on rank0 in case the data is not downloaded 24 | if params.world_rank == 0: 25 | logging.info('rank %d, begin data loader init'%params.world_rank) 26 | self.train_data_loader, self.train_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=True) 27 | self.valid_data_loader, self.valid_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=False) 28 | logging.info('rank %d, data loader initialized'%params.world_rank) 29 | 30 | # wait for rank0 to finish downloading 
the data 31 | if dist.is_initialized(): 32 | dist.barrier() 33 | 34 | # now construct the dataloaders on other ranks 35 | if params.world_rank != 0: 36 | logging.info('rank %d, begin data loader init'%params.world_rank) 37 | self.train_data_loader, self.train_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=True) 38 | self.valid_data_loader, self.valid_sampler = get_data_loader(params, params.data_path, dist.is_initialized(), is_train=False) 39 | logging.info('rank %d, data loader initialized'%params.world_rank) 40 | 41 | self.model = models.resnet.resnet50(num_classes=params.num_classes).to(self.device) 42 | 43 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=params.lr, 44 | momentum=params.momentum, weight_decay=params.weight_decay) 45 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.2, patience=10, mode='min') 46 | self.criterion = torch.nn.CrossEntropyLoss().to(self.device) 47 | 48 | if dist.is_initialized(): 49 | self.model = DistributedDataParallel(self.model, 50 | device_ids=[params.local_rank], 51 | output_device=[params.local_rank]) 52 | self.iters = 0 53 | self.startEpoch = 0 54 | if params.resuming: 55 | logging.info("Loading checkpoint %s"%params.checkpoint_path) 56 | self.restore_checkpoint(params.checkpoint_path) 57 | self.epoch = self.startEpoch 58 | 59 | if params.log_to_screen: 60 | logging.info(self.model) 61 | 62 | if params.log_to_tensorboard: 63 | self.writer = SummaryWriter(os.path.join(params.experiment_dir, 'tb_logs')) 64 | 65 | def train(self): 66 | if self.params.log_to_screen: 67 | logging.info("Starting Training Loop...") 68 | 69 | for epoch in range(self.startEpoch, self.params.max_epochs): 70 | if dist.is_initialized(): 71 | self.train_sampler.set_epoch(epoch) 72 | self.valid_sampler.set_epoch(epoch) 73 | 74 | if epoch < params.lr_warmup_epochs: 75 | self.optimizer.param_groups[0]['lr'] = params.lr*float(epoch+1.)/float(params.lr_warmup_epochs) 76 | 77 | start = time.time() 78 | tr_time, data_time, train_logs = self.train_one_epoch() 79 | valid_time, valid_logs = self.validate_one_epoch() 80 | if epoch >= params.lr_warmup_epochs: 81 | self.scheduler.step(valid_logs['loss']) 82 | 83 | if self.params.world_rank == 0: 84 | if self.params.save_checkpoint: 85 | #checkpoint at the end of every epoch 86 | self.save_checkpoint(self.params.checkpoint_path) 87 | 88 | if self.params.log_to_tensorboard: 89 | self.writer.add_scalar('loss/train', train_logs['loss'], self.epoch) 90 | self.writer.add_scalar('loss/valid', valid_logs['loss'], self.epoch) 91 | self.writer.add_scalar('acc1/train', train_logs['acc1'], self.epoch) 92 | self.writer.add_scalar('acc1/valid', valid_logs['acc1'], self.epoch) 93 | self.writer.add_scalar('learning_rate', self.optimizer.param_groups[0]['lr'], self.epoch) 94 | 95 | if self.params.log_to_screen: 96 | logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start)) 97 | logging.info('train data time={}, train time={}, valid step time={}, train acc1={}, valid acc1={}'.format(data_time, tr_time, 98 | valid_time, 99 | train_logs['acc1'], 100 | valid_logs['acc1'])) 101 | 102 | def train_one_epoch(self): 103 | self.epoch += 1 104 | tr_time = 0 105 | data_time = 0 106 | report_time = report_bs = 0 107 | for i, data in enumerate(self.train_data_loader, 0): 108 | self.iters += 1 109 | iter_start = time.time() 110 | data_start = time.time() 111 | images, labels = map(lambda x: x.to(self.device), data) 112 | data_time += time.time() - data_start 
113 | 114 | tr_start = time.time() 115 | self.model.zero_grad() 116 | self.model.train() 117 | outputs = self.model(images) 118 | loss = self.criterion(outputs, labels) 119 | loss.backward() 120 | self.optimizer.step() 121 | tr_time += time.time() - tr_start 122 | iter_time = time.time() - iter_start 123 | report_time += iter_time 124 | report_bs += len(images) 125 | 126 | if i % self.params.log_freq == 0: 127 | logging.info('Epoch: {}, Iteration: {}, Avg img/sec: {}'.format(self.epoch, i, report_bs / report_time)) 128 | report_time = report_bs = 0 129 | 130 | # save metrics of last batch 131 | _, preds = outputs.max(1) 132 | acc1 = preds.eq(labels).sum().float()/labels.shape[0] 133 | logs = {'loss': loss, 134 | 'acc1': acc1} 135 | 136 | if dist.is_initialized(): 137 | for key in sorted(logs.keys()): 138 | dist.all_reduce(logs[key].detach()) 139 | logs[key] = float(logs[key]/dist.get_world_size()) 140 | 141 | return tr_time, data_time, logs 142 | 143 | def validate_one_epoch(self): 144 | self.model.eval() 145 | 146 | valid_start = time.time() 147 | loss = 0.0 148 | correct = 0.0 149 | with torch.no_grad(): 150 | for data in self.valid_data_loader: 151 | images, labels = map(lambda x: x.to(self.device), data) 152 | outputs = self.model(images) 153 | loss += self.criterion(outputs, labels) 154 | _, preds = outputs.max(1) 155 | correct += preds.eq(labels).sum().float()/labels.shape[0] 156 | 157 | logs = {'loss': loss/len(self.valid_data_loader), 158 | 'acc1': correct/len(self.valid_data_loader)} 159 | valid_time = time.time() - valid_start 160 | 161 | if dist.is_initialized(): 162 | for key in sorted(logs.keys()): 163 | logs[key] = torch.as_tensor(logs[key]).to(self.device) 164 | dist.all_reduce(logs[key].detach()) 165 | logs[key] = float(logs[key]/dist.get_world_size()) 166 | 167 | return valid_time, logs 168 | 169 | def save_checkpoint(self, checkpoint_path, model=None): 170 | """ We intentionally require a checkpoint_dir to be passed 171 | in order to allow Ray Tune to use this function """ 172 | 173 | if not model: 174 | model = self.model 175 | 176 | torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(), 177 | 'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path) 178 | 179 | def restore_checkpoint(self, checkpoint_path): 180 | """ We intentionally require a checkpoint_dir to be passed 181 | in order to allow Ray Tune to use this function """ 182 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank)) 183 | self.model.load_state_dict(checkpoint['model_state']) 184 | self.iters = checkpoint['iters'] 185 | self.startEpoch = checkpoint['epoch'] + 1 186 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 187 | 188 | if __name__ == '__main__': 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument("--local_rank", default=0, type=int) 191 | parser.add_argument("--yaml_config", default='./config/cifar100.yaml', type=str) 192 | parser.add_argument("--config", default='default', type=str) 193 | args = parser.parse_args() 194 | 195 | params = YParams(os.path.abspath(args.yaml_config), args.config) 196 | 197 | # setup distributed training variables and initialize cluster if using 198 | params['world_size'] = 1 199 | if 'WORLD_SIZE' in os.environ: 200 | params['world_size'] = int(os.environ['WORLD_SIZE']) 201 | 202 | params['local_rank'] = args.local_rank 203 | params['world_rank'] = 0 204 | if params['world_size'] > 1: 205 | torch.cuda.set_device(args.local_rank) 206 | 
dist.init_process_group(backend='nccl', 207 | init_method='env://') 208 | params['world_rank'] = dist.get_rank() 209 | params['global_batch_size'] = params.batch_size 210 | params['batch_size'] = int(params.batch_size//params['world_size']) 211 | 212 | torch.backends.cudnn.benchmark = True 213 | 214 | # setup output directory 215 | expDir = os.path.join('./expts', args.config) 216 | if params.world_rank==0: 217 | if not os.path.isdir(expDir): 218 | os.makedirs(expDir) 219 | os.makedirs(os.path.join(expDir, 'checkpoints/')) 220 | 221 | params['experiment_dir'] = os.path.abspath(expDir) 222 | params['checkpoint_path'] = os.path.join(expDir, 'checkpoints/ckpt.tar') 223 | params['resuming'] = True if os.path.isfile(params.checkpoint_path) else False 224 | 225 | if params.world_rank==0: 226 | params.log() 227 | params['log_to_screen'] = params.log_to_screen and params.world_rank==0 228 | params['log_to_tensorboard'] = params.log_to_tensorboard and params.world_rank==0 229 | 230 | trainer = Trainer(params) 231 | trainer.train() 232 | -------------------------------------------------------------------------------- /tutorial_images/acc1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/acc1.png -------------------------------------------------------------------------------- /tutorial_images/acc2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/acc2.png -------------------------------------------------------------------------------- /tutorial_images/acc3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/acc3.png -------------------------------------------------------------------------------- /tutorial_images/bs128_learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/bs128_learning.png -------------------------------------------------------------------------------- /tutorial_images/nsys_amp_nhwc_extra_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_amp_nhwc_extra_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_amp_nhwc_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_amp_nhwc_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_amp_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_amp_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_amp_zoomed_kernels.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_amp_zoomed_kernels.png -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline_full.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_baseline_full.png -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/nsys_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/throughputScaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc20-dl-tutorial/5c9f83c99a15543a5fe780865dcf81321f7342ad/tutorial_images/throughputScaling.png -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | class YParams(): 5 | """ Yaml file parser """ 6 | def __init__(self, yaml_filename, config_name, print_params=False): 7 | self._yaml_filename = yaml_filename 8 | self._config_name = config_name 9 | self.params = {} 10 | 11 | if print_params: 12 | print("------------------ Configuration ------------------") 13 | 14 | with open(yaml_filename) as _file: 15 | 16 | for key, val in YAML().load(_file)[config_name].items(): 17 | if print_params: print(key, val) 18 | if val =='None': val = None 19 | 20 | self.params[key] = val 21 | self.__setattr__(key, val) 22 | 23 | if print_params: 24 | print("---------------------------------------------------") 25 | 26 | def __getitem__(self, key): 27 | return self.params[key] 28 | 29 | def __setitem__(self, key, val): 30 | self.params[key] = val 31 | self.__setattr__(key, val) 32 | 33 | def __contains__(self, key): 34 | return (key in self.params) 35 | 36 | def update_params(self, config): 37 | for key, val in config.items(): 38 | self.params[key] = val 39 | self.__setattr__(key, val) 40 | 41 | def log(self): 42 | logging.info("------------------ Configuration ------------------") 43 | logging.info("Configuration file: "+str(self._yaml_filename)) 44 | logging.info("Configuration name: "+str(self._config_name)) 45 | for key, val in self.params.items(): 46 | logging.info(str(key) + ' ' + str(val)) 47 | logging.info("---------------------------------------------------") 48 | -------------------------------------------------------------------------------- /utils/cifar100_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | from torch.utils.data.distributed import DistributedSampler 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | 7 | 8 | def get_data_loader(params, files_pattern, distributed, is_train): 9 | 10 | if is_train: 11 | transform = transforms.Compose([ 12 | transforms.RandomCrop(32, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.RandomRotation(params.rnd_rotation_angle), 15 | 
transforms.ToTensor(), 16 | transforms.Normalize(tuple(params.cifar100_mean), 17 | tuple(params.cifar100_std))]) 18 | else: 19 | transform = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize(tuple(params.cifar100_mean), 22 | tuple(params.cifar100_std))]) 23 | 24 | dataset = datasets.CIFAR100(root=params.data_path, 25 | train=is_train, 26 | download=True if (is_train and params.world_rank==0) else False, 27 | transform=transform) 28 | 29 | sampler = DistributedSampler(dataset, shuffle=True) if distributed else None 30 | 31 | dataloader = DataLoader(dataset, 32 | batch_size=int(params.batch_size) if is_train else int(params.valid_batch_size_per_gpu), 33 | num_workers=params.num_data_workers, 34 | shuffle=(sampler is None), 35 | sampler=sampler, 36 | drop_last=True, 37 | pin_memory=torch.cuda.is_available()) 38 | 39 | return dataloader, sampler 40 | --------------------------------------------------------------------------------
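
As a quick sanity check of the data pipeline above, the sketch below (not part of the repository) shows how the `YParams` config reader and `get_data_loader` might be exercised standalone in a single process. It assumes you run from the repo root and that the `bs128` config defines the fields the loader reads (`data_path`, `batch_size`, `num_data_workers`, `cifar100_mean`, `cifar100_std`, `rnd_rotation_angle`); the `world_rank` entry is set by hand here because the training scripts normally provide it.

```
# Illustrative sketch only: load a config and pull one training batch.
from utils.YParams import YParams
from utils.cifar100_data_loader import get_data_loader

params = YParams('./config/cifar100.yaml', 'bs128')
params['world_rank'] = 0  # single process, so this rank triggers the CIFAR-100 download

# distributed=False -> no DistributedSampler, plain shuffled DataLoader
train_loader, _ = get_data_loader(params, params.data_path, False, is_train=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # expected: [batch_size, 3, 32, 32] and [batch_size]
```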