├── .gitignore ├── README.md ├── README_summit.md ├── benchy-conf.yaml ├── config └── UNet.yaml ├── docker └── Dockerfile ├── example_logs └── base │ └── 1GPU │ └── 00 │ ├── logs │ └── events.out.tfevents.1636677146.nid001104.65966.0 │ └── out.log ├── export_DDP_vars.sh ├── export_DDP_vars_summit.sh ├── launch_summit.sh ├── networks └── UNet.py ├── run_summit.sh ├── sample_nsys_profiles ├── 16workers.nsys-rep ├── 4gpu_baseline.nsys-rep ├── 4gpu_bucketcap100mb.nsys-rep ├── 4gpu_nobroadcast.nsys-rep ├── baseline.nsys-rep ├── dali.nsys-rep ├── dali_amp.nsys-rep ├── dali_amp_apex_jit.nsys-rep ├── summit_6gpu_baseline.qdrep ├── summit_6gpu_bucketcap100mb.qdrep └── summit_6gpu_nobroadcast.qdrep ├── start_tensorboard.ipynb ├── start_tensorboard_summit.ipynb ├── submit_cgpu.sh ├── submit_pm.sh ├── submit_summit.sh ├── summit_scaling_logs └── plot_weak_scale.py ├── train.py ├── train_graph.py ├── tutorial_images ├── baseline_tb.png ├── baseline_tb_summit.png ├── bs512.png ├── bs512_short.png ├── bs576_short_summit.png ├── bs576_summit.png ├── bs_compare.png ├── bs_compare_summit.png ├── nbody2hydro.png ├── nsys_4gpu_baseline.png ├── nsys_4gpu_baseline_zoomed.png ├── nsys_4gpu_bucketcap100mb_zoomed.png ├── nsys_4gpu_nobroadcast.png ├── nsys_baseline.png ├── nsys_baseline_zoomed.png ├── nsys_dali.png ├── nsys_dali_amp.png ├── nsys_dali_amp_apex_jit.png ├── nsys_dali_amp_apex_jit_zoomed.png ├── nsys_dali_amp_zoomed.png ├── nsys_dali_zoomed.png ├── nsys_nativedata_16workers.png ├── nsys_nativedata_16workers_zoomed.png ├── nsys_nativedata_8workers.png ├── nsys_nativedata_8workers_zoomed.png ├── nsys_summit_6gpu_baseline.png ├── nsys_summit_6gpu_baseline_zoomed.png ├── nsys_summit_6gpu_bucketcap100mb_zoomed.png ├── nsys_summit_6gpu_nobroadcast.png ├── relative.png ├── scale_perfComm.png ├── scale_perfDiffBS.png ├── scale_perfEff.png └── scale_perfEff_bs128.png └── utils ├── YParams.py ├── __init__.py ├── convert_to_npy.py ├── data_loader.py ├── data_loader_dali.py ├── logging_utils.py └── symmetry.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.out 3 | *.slr 4 | ./joblogs 5 | .ipynb_checkpoints/ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SC22 Deep Learning at Scale Tutorial 2 | 3 | This repository contains the example code material for the SC22 tutorial: 4 | *Deep Learning at Scale*. 
5 | 6 | **Contents** 7 | * [Links](#links) 8 | * [Installation](#installation-and-setup) 9 | * [Model, data, and code overview](#model-data-and-training-code-overview) 10 | * [Single GPU training](#single-gpu-training) 11 | * [Single GPU performance](#single-gpu-performance-profiling-and-optimization) 12 | * [Distributed training](#distributed-gpu-training) 13 | * [Multi GPU performance](#multi-gpu-performance-profiling-and-optimization) 14 | * [Putting it all together](#putting-it-all-together) 15 | 16 | ## Links 17 | 18 | Tutorial slides: https://drive.google.com/drive/folders/1T8u7kA-PLgs1rhFxF7zT7c81HJ2bouWO?usp=sharing 19 | 20 | Join the Slack workspace: https://join.slack.com/t/nersc-dl-tutorial/shared_invite/zt-1jnpj4ggz-ks4dCyCsI8Z8iVRWO4LBsg 21 | 22 | NERSC JupyterHub: https://jupyter.nersc.gov 23 | 24 | Data download (only needed if you want to run our examples elsewhere): https://portal.nersc.gov/project/dasrepo/pharring/ 25 | 26 | ## Installation and Setup 27 | 28 | ### Software environment 29 | 30 | The instructions in this README are intended to be used with NERSC's Perlmutter machine. However, in-person participants can also choose to run on the OLCF Summit system if they want. Instructions for running on Summit are given in [`README_summit.md`](./README_summit.md). 31 | 32 | Access to the Perlmutter machine is provided for this tutorial via [jupyter.nersc.gov](https://jupyter.nersc.gov). 33 | Training account setup instructions will be given during the session. Once you have your provided account credentials, you can log in to Jupyter via the link (leave the OTP field blank when logging into Jupyter). 34 | Once logged into the hub, start a session by clicking the button for Perlmutter Shared CPU Node (other options will not work with this tutorial material). This will open up a session on a Perlmutter login node, from which you can submit jobs to the GPU nodes and monitor their progress. 35 | 36 | To begin, start a terminal from JupyterHub and clone this repository with: 37 | ```bash 38 | git clone https://github.com/NERSC/sc22-dl-tutorial.git 39 | ``` 40 | You can use the Jupyter file browser to view and edit source files and scripts. For all of the example commands provided below, make sure you are running them from within the top-level folder of the repository. In your terminal, change to the directory with 41 | ```bash 42 | cd sc22-dl-tutorial 43 | ``` 44 | 45 | For running slurm jobs on Perlmutter, we will use training accounts which are provided under the `ntrain4` project. The slurm script `submit_pm.sh` included in the repository is configured to work automatically as is, but if you submit your own custom jobs via `salloc` or `sbatch` you must include the following flags for slurm: 46 | * `-A ntrain4_g` is required for training accounts 47 | * `--reservation=sc22_tutorial` is required to access the set of GPU nodes we have reserved for the duration of the tutorial. 48 | 49 | The code can be run using the `nersc/sc22-dl-tutorial` docker container. On Perlmutter, docker containers are run via [shifter](https://docs.nersc.gov/development/shifter/), and this container is already downloaded and automatically invoked by our job submission scripts. Our container is based on the [NVIDIA ngc 22.10 pytorch container](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-22-10.html), with a few additional packages added. See the dockerfile in [`docker/Dockerfile`](docker/Dockerfile) for details. 

### Installing Nsight Systems
In this tutorial, we will be generating profile files using NVIDIA Nsight Systems on the remote systems. In order to open and view these files on your local computer, you will need to install the Nsight Systems program, which you can download [here](https://developer.nvidia.com/gameworksdownload#?search=nsight%20systems). Select the download option required for your system (e.g. Mac OS host for macOS, Windows Host for Windows, or Linux Host .rpm/.deb/.run for Linux). You may need to sign up and create a login to NVIDIA's developer program if you do not already have an account to access the download. Proceed to run and install the program using your selected installation method.

## Model, data, and training code overview

The model in this repository is adapted from a cosmological application of deep learning ([Harrington et al. 2021](https://arxiv.org/abs/2106.12662)), which aims to augment computationally expensive simulations by using a [U-Net](https://arxiv.org/abs/1505.04597) model to reconstruct physical fields of interest (namely, hydrodynamic quantities associated with diffuse gas in the universe):

![n-body to hydro schematic](tutorial_images/nbody2hydro.png)

The U-Net model architecture used in these examples can be found in [`networks/UNet.py`](networks/UNet.py). U-Nets are a popular and capable architecture, as they can extract long-range features through sequential downsampling convolutions, while fine-grained details can be propagated to the upsampling path via skip connections. This particular U-Net is relatively lightweight, to better accommodate our 3D data samples.

The basic data loading pipeline is defined in [`utils/data_loader.py`](utils/data_loader.py), whose primary components are:
* The `RandomCropDataset`, which accesses the simulation data stored on disk and randomly crops sub-volumes of the physical fields to serve for training and validation. For this repository, we will be using a crop size of 64^3.
* The `RandomRotator` transform, which applies random rotations and reflections to the samples as data augmentations.
* The above components are assembled to feed a PyTorch `DataLoader`, which takes the augmented samples and combines them into a batch for each training step.

It is common practice to decay the learning rate according to some schedule as the model trains, so that the optimizer can settle into sharper minima during gradient descent. Here we opt for the cosine learning rate decay schedule, which starts at an initial learning rate and decays continuously throughout training according to a cosine function. This is handled by the `lr_schedule` routine defined in [`utils/__init__.py`](utils/__init__.py), which also has logic to implement learning rate scaling and warm-up for use in the [Distributed GPU training](#Distributed-GPU-training) section.

As we will see in the [Single GPU performance profiling and optimization](#Single-GPU-performance-profiling-and-optimization) section, the random rotations add considerable overhead to data loading during training, and we can achieve performance gains by doing these preprocessing steps on the GPU instead, using NVIDIA's DALI library. Code for this is found in [`utils/data_loader_dali.py`](utils/data_loader_dali.py).
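
Returning to the learning rate schedule for a moment, here is a minimal sketch of a cosine decay with a linear warm-up of the kind described above. It is illustrative only: the parameter names (`start_lr`, `end_lr`, `warmup_steps`, `tot_steps`) are assumptions, and the real logic lives in the `lr_schedule` routine in [`utils/__init__.py`](utils/__init__.py).
```
import math

def cosine_lr(step, start_lr, tot_steps, warmup_steps=0, end_lr=0.0):
    # Illustrative sketch, not the repository's exact implementation.
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warm-up from 0 up to the (possibly scaled) starting learning rate
        return start_lr * step / warmup_steps
    # Cosine decay from start_lr down to end_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, tot_steps - warmup_steps)
    return end_lr + 0.5 * (start_lr - end_lr) * (1.0 + math.cos(math.pi * progress))
```
At every training step, the returned value would be written into each of the optimizer's parameter groups.
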

The script to train the model is [`train.py`](train.py), which uses the following arguments to load the desired training setup:
```
--yaml_config YAML_CONFIG   path to yaml file containing training configs
--config CONFIG             name of desired config in yaml file
```

Based on the selected configuration, the train script will then:
1. Set up the data loaders and construct our U-Net model, the Adam optimizer, and our L1 loss function.
2. Loop over training epochs to run the training. See if you can identify the following key components:
* Looping over data batches from our data loader.
* Applying the forward pass of the model and computing the loss function.
* Calling `backward()` on the loss value to backpropagate gradients. Note that the use of the `grad_scaler` will be explained below when we enable mixed precision.
* Applying the model to the validation dataset and logging training and validation metrics to visualize in TensorBoard (see if you can find where we construct the TensorBoard `SummaryWriter` and where our specific metrics are logged via the `add_scalar` call).

Besides the `train.py` script, we have a slightly more complex [`train_graph.py`](train_graph.py) script, which implements the same functionality with added capability for using the CUDA Graphs APIs introduced in PyTorch 1.10. This topic will be covered in the [Single GPU performance profiling and optimization](#Single-GPU-performance-profiling-and-optimization) section.

More info on the model and data can be found in the [slides](https://drive.google.com/drive/u/1/folders/1Ei56_HDjLMPbdLq9QdQhoxN3J1YdzZw0). If you are experimenting with this repository after the tutorial date, you can download the data from here: https://portal.nersc.gov/project/dasrepo/pharring/.
Note that you will have to adjust the data path in `submit_pm.sh` to point to your personal copy after downloading.


## Single GPU training

First, let us look at the performance of the training script without optimizations on a single GPU.

On Perlmutter for the tutorial, we will be submitting jobs to the batch queue. To submit this job, use the following command:
```
sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4
```
`submit_pm.sh` is a batch submission script that defines the resources to be requested by SLURM as well as the command to run.
Note that any arguments for `train.py`, such as the desired config (`--config`), can be added after `submit_pm.sh` when submitting, and they will be passed to `train.py` properly.
When using batch submission, you can see the job output by viewing the file `pm-crop64-<jobid>.out` in the submission directory. You can find the job id of your job using the command `squeue --me` and looking at the first column of the output.

For interactive jobs, you can run the Python script directly using the following command (**NOTE: please don't run training on the Perlmutter login nodes**):
```
python train.py --config=short --num_epochs 4
```
For V100 systems, you will likely need to update the config to reduce the local batch size to 32 due to the reduced memory capacity. Otherwise, instructions are the same.

This will run 4 epochs of training on a single GPU using a default batch size of 64.
See [`config/UNet.yaml`](config/UNet.yaml) for specific configuration details.
115 | Note we will use the default batch size for the optimization work in the next section 116 | and will push beyond to larger batch sizes in the distributed training section. 117 | 118 | In the baseline configuration, the model converges to a loss of about `4.75e-3` on 119 | the validation dataset in 10 epochs. This takes around 2 hours to run, so to save time we have already included an example TensorBoard log for the `base` config in the `example_logs` directory for you. 120 | We want to compare our training results against the `base` config baseline, and TensorBoard makes this easy as long as all training runs are stored in the same place. 121 | To copy the example TensorBoard log to the scratch directory where our training jobs will output their logs, do 122 | ``` 123 | mkdir -p $SCRATCH/sc22-dl-tutorial/logs 124 | cp -r ./example_logs/base $SCRATCH/sc22-dl-tutorial/logs 125 | ``` 126 | 127 | To view results in TensorBoard, open the [`start_tensorboard.ipynb`](start_tensorboard.ipynb) notebook and follow the instructions in it to launch a TensorBoard session in your browser. Once you have TensorBoard open, you should see a dashboard with data for the loss values, learning rate, and average iterations per second. Looking at the validation loss for the `base` config, you should see the following training curve: 128 | ![baseline training](tutorial_images/baseline_tb.png) 129 | 130 | As our training with the `short` config runs, it should also dump the training metrics to the TensorBoard directory, and TensorBoard will parse the data and display it for you. You can hit the refresh button in the upper-right corner of TensorBoard to update the plots with the latest data. 131 | 132 | ## Single GPU performance profiling and optimization 133 | 134 | This is the performance of the baseline script for the first four epochs on a 40GB A100 card with batch size 64: 135 | ``` 136 | 2022-11-09 15:33:33,897 - root - INFO - Time taken for epoch 1 is 73.79664635658264 sec, avg 55.503877238652294 samples/sec 137 | 2022-11-09 15:33:33,901 - root - INFO - Avg train loss=0.066406 138 | 2022-11-09 15:33:39,679 - root - INFO - Avg val loss=0.042361 139 | 2022-11-09 15:33:39,681 - root - INFO - Total validation time: 5.777978897094727 sec 140 | 2022-11-09 15:34:28,412 - root - INFO - Time taken for epoch 2 is 48.72997832298279 sec, avg 84.05503431279347 samples/sec 141 | 2022-11-09 15:34:28,414 - root - INFO - Avg train loss=0.028927 142 | 2022-11-09 15:34:33,504 - root - INFO - Avg val loss=0.026633 143 | 2022-11-09 15:34:33,504 - root - INFO - Total validation time: 5.089476585388184 sec 144 | 2022-11-09 15:35:22,528 - root - INFO - Time taken for epoch 3 is 49.02241778373718 sec, avg 83.55361047408023 samples/sec 145 | 2022-11-09 15:35:22,531 - root - INFO - Avg train loss=0.019387 146 | 2022-11-09 15:35:27,788 - root - INFO - Avg val loss=0.021904 147 | 2022-11-09 15:35:27,788 - root - INFO - Total validation time: 5.256815195083618 sec 148 | 2022-11-09 15:36:15,871 - root - INFO - Time taken for epoch 4 is 48.08129024505615 sec, avg 85.18906167292717 samples/sec 149 | 2022-11-09 15:36:15,872 - root - INFO - Avg train loss=0.017213 150 | 2022-11-09 15:36:20,641 - root - INFO - Avg val loss=0.020661 151 | 2022-11-09 15:36:20,642 - root - INFO - Total validation time: 4.768946886062622 sec 152 | ``` 153 | After the first epoch, we see that the throughput achieved is about 85 samples/s. 

### Profiling with Nsight Systems
#### Adding NVTX ranges and profiler controls
Before generating a profile with Nsight, we can add NVTX ranges to the script to add context to the produced timeline.
We can add some manually defined NVTX ranges to the code using `torch.cuda.nvtx.range_push` and `torch.cuda.nvtx.range_pop`.
We can also add calls to `torch.cuda.profiler.start()` and `torch.cuda.profiler.stop()` to control the duration of the profiling (e.g., limit profiling to a single epoch).

To generate a profile using our scripts on Perlmutter, run the following command:
```
ENABLE_PROFILING=1 PROFILE_OUTPUT=baseline sbatch -n1 submit_pm.sh --config=short --num_epochs 4 --enable_manual_profiling
```
If running interactively, this is the full command from the batch submission script:
```
nsys profile -o baseline --trace=cuda,nvtx -c cudaProfilerApi --kill none -f true python train.py --config=short --num_epochs 4 --enable_manual_profiling
```
This command will run four epochs of the training script, profiling only 60 steps of the last epoch. It will produce a file `baseline.nsys-rep` that can be opened in the Nsight Systems program. The arg `--trace=cuda,nvtx` is optional and is used here to disable OS Runtime tracing for speed.

Loading this profile ([`baseline.nsys-rep`](sample_nsys_profiles/baseline.nsys-rep)) in Nsight Systems will look like this:
![NSYS Baseline](tutorial_images/nsys_baseline.png)

From this zoomed-out view, we can see a lot of idle gaps between iterations. These gaps are due to the data loading, which we will address in the next section.

Beyond this, we can zoom into a single iteration and get an idea of where compute time is being spent:
![NSYS Baseline zoomed](tutorial_images/nsys_baseline_zoomed.png)


#### Using the benchy profiling tool
As an alternative to manually specifying NVTX ranges, we've included the use of a simple profiling tool [`benchy`](https://github.com/romerojosh/benchy) that overrides the PyTorch dataloader in the script to produce throughput information to the terminal, as well as add NVTX ranges/profiler start and stop calls. This tool also runs a sequence of tests to measure and report the throughput of the dataloader in isolation (`IO`), the model running with synthetic/cached data (`SYNTHETIC`), and the throughput of the model running normally with real data (`FULL`).

To run using benchy on Perlmutter, use the following command:
```
sbatch -n1 submit_pm.sh --config=short --num_epochs 15 --enable_benchy
```
If running interactively:
```
python train.py --config=short --num_epochs 15 --enable_benchy
```
benchy uses epoch boundaries to separate the test trials it runs, so in these cases we increase the epoch limit to 15 to ensure the full experiment runs.

benchy will report throughput measurements directly to the terminal, including a simple summary of averages at the end of the job. For this case on Perlmutter, the summary output from benchy is:
```
BENCHY::SUMMARY::IO average trial throughput: 84.468 +/- 0.463
BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 481.618 +/- 1.490
BENCHY::SUMMARY::FULL average trial throughput: 83.212 +/- 0.433
```
From these throughput values, we can see that the `SYNTHETIC` (i.e. compute) throughput is greater than the `IO` (i.e.
data loading) throughput. 201 | The `FULL` (i.e. real) throughput is bounded by the slower of these two values, which is `IO` in this case. What these throughput 202 | values indicate is the GPU can achieve much greater training throughput for this model, but is being limited by the data loading 203 | speed. 204 | 205 | ### Data loading optimizations 206 | #### Improving the native PyTorch dataloader performance 207 | The PyTorch dataloader has several knobs we can adjust to improve performance. If you look at the `DataLoader` initialization in 208 | `utils/data_loader.py`, you'll see we've already set several useful options, like `pin_memory` and `persistent_workers`. 209 | `pin_memory` has the data loader read input data into pinned host memory, which typically yields better host-to-device and device-to-host 210 | memcopy bandwidth. `persistent_workers` allows PyTorch to reuse workers between epochs, instead of the default behavior which is to 211 | respawn them. One knob we've left to adjust is the `num_workers` argument, which we can control via the `--num_data_workers` command 212 | line arg to our script. The default in our config is two workers, but we can experiment with this value to see if increasing the number 213 | of workers improves performance. 214 | 215 | We can run this experiment on Perlmutter by running the following command: 216 | ``` 217 | sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 218 | ``` 219 | If running interactively: 220 | ``` 221 | python train.py --config=short --num_epochs 4 --num_data_workers 222 | ``` 223 | 224 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64 and 4 data workers: 225 | ``` 226 | 2022-11-09 15:39:11,669 - root - INFO - Time taken for epoch 1 is 45.737974405288696 sec, avg 89.55359421265447 samples/sec 227 | 2022-11-09 15:39:11,670 - root - INFO - Avg train loss=0.073490 228 | 2022-11-09 15:39:15,088 - root - INFO - Avg val loss=0.045034 229 | 2022-11-09 15:39:15,088 - root - INFO - Total validation time: 3.416804075241089 sec 230 | 2022-11-09 15:39:44,070 - root - INFO - Time taken for epoch 2 is 28.980637788772583 sec, avg 141.33574388024115 samples/sec 231 | 2022-11-09 15:39:44,073 - root - INFO - Avg train loss=0.030368 232 | 2022-11-09 15:39:47,385 - root - INFO - Avg val loss=0.026028 233 | 2022-11-09 15:39:47,385 - root - INFO - Total validation time: 3.31168794631958 sec 234 | 2022-11-09 15:40:15,724 - root - INFO - Time taken for epoch 3 is 28.337406396865845 sec, avg 144.5439269436113 samples/sec 235 | 2022-11-09 15:40:15,727 - root - INFO - Avg train loss=0.019323 236 | 2022-11-09 15:40:19,103 - root - INFO - Avg val loss=0.020982 237 | 2022-11-09 15:40:19,103 - root - INFO - Total validation time: 3.376103639602661 sec 238 | 2022-11-09 15:40:47,585 - root - INFO - Time taken for epoch 4 is 28.479787349700928 sec, avg 143.8212985829409 samples/sec 239 | 2022-11-09 15:40:47,586 - root - INFO - Avg train loss=0.017431 240 | 2022-11-09 15:40:50,858 - root - INFO - Avg val loss=0.020178 241 | 2022-11-09 15:40:50,859 - root - INFO - Total validation time: 3.2716639041900635 sec 242 | ``` 243 | 244 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64 and 8 data workers: 245 | ``` 246 | 2022-11-09 15:43:17,251 - root - INFO - Time taken for epoch 1 is 33.487831115722656 sec, avg 122.31308697913593 samples/sec 247 | 2022-11-09 15:43:17,252 - root - INFO - Avg train loss=0.073041 248 | 
2022-11-09 15:43:20,238 - root - INFO - Avg val loss=0.046040 249 | 2022-11-09 15:43:20,239 - root - INFO - Total validation time: 2.9856343269348145 sec 250 | 2022-11-09 15:43:41,047 - root - INFO - Time taken for epoch 2 is 20.80655312538147 sec, avg 196.861055039596 samples/sec 251 | 2022-11-09 15:43:41,047 - root - INFO - Avg train loss=0.029219 252 | 2022-11-09 15:43:43,839 - root - INFO - Avg val loss=0.025189 253 | 2022-11-09 15:43:43,839 - root - INFO - Total validation time: 2.791450262069702 sec 254 | 2022-11-09 15:44:04,333 - root - INFO - Time taken for epoch 3 is 20.492927312850952 sec, avg 199.8738363470128 samples/sec 255 | 2022-11-09 15:44:04,336 - root - INFO - Avg train loss=0.019040 256 | 2022-11-09 15:44:07,148 - root - INFO - Avg val loss=0.021613 257 | 2022-11-09 15:44:07,148 - root - INFO - Total validation time: 2.8117287158966064 sec 258 | 2022-11-09 15:44:27,724 - root - INFO - Time taken for epoch 4 is 20.574484825134277 sec, avg 199.0815339879728 samples/sec 259 | 2022-11-09 15:44:27,724 - root - INFO - Avg train loss=0.017410 260 | 2022-11-09 15:44:30,501 - root - INFO - Avg val loss=0.020545 261 | 2022-11-09 15:44:30,501 - root - INFO - Total validation time: 2.776258707046509 sec 262 | ``` 263 | 264 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64 and 16 data workers: 265 | ``` 266 | 2022-11-09 15:47:46,373 - root - INFO - Time taken for epoch 1 is 31.55659580230713 sec, avg 129.79853801912748 samples/sec 267 | 2022-11-09 15:47:46,373 - root - INFO - Avg train loss=0.066807 268 | 2022-11-09 15:47:49,513 - root - INFO - Avg val loss=0.044213 269 | 2022-11-09 15:47:49,513 - root - INFO - Total validation time: 3.139697790145874 sec 270 | 2022-11-09 15:48:09,174 - root - INFO - Time taken for epoch 2 is 19.658984661102295 sec, avg 208.3525711327522 samples/sec 271 | 2022-11-09 15:48:09,174 - root - INFO - Avg train loss=0.027332 272 | 2022-11-09 15:48:12,697 - root - INFO - Avg val loss=0.024337 273 | 2022-11-09 15:48:12,697 - root - INFO - Total validation time: 3.5227739810943604 sec 274 | 2022-11-09 15:48:32,134 - root - INFO - Time taken for epoch 3 is 19.434626817703247 sec, avg 210.75784157938665 samples/sec 275 | 2022-11-09 15:48:32,136 - root - INFO - Avg train loss=0.018133 276 | 2022-11-09 15:48:35,182 - root - INFO - Avg val loss=0.020886 277 | 2022-11-09 15:48:35,182 - root - INFO - Total validation time: 3.0449066162109375 sec 278 | 2022-11-09 15:48:54,834 - root - INFO - Time taken for epoch 4 is 19.65017533302307 sec, avg 208.44597722832904 samples/sec 279 | 2022-11-09 15:48:54,834 - root - INFO - Avg train loss=0.016385 280 | 2022-11-09 15:48:57,621 - root - INFO - Avg val loss=0.020395 281 | 2022-11-09 15:48:57,621 - root - INFO - Total validation time: 2.786870241165161 sec 282 | ``` 283 | 284 | Increasing the number of workers to 8 improves performance to around 200 samples per second, while increasing to 16 workers yields only a slight improvement from this. 285 | 286 | We can run the 16 worker configuration through profiler using the instructions in the previous section with the added `--num_data_workers` 287 | argument and load that profile in Nsight Systems. 
This is what this profile ([`16workers.nsys-rep`](sample_nsys_profiles/16workers.nsys-rep)) looks like:
![NSYS Native Data](tutorial_images/nsys_nativedata_16workers.png)

and zoomed in:
![NSYS Native Data Zoomed](tutorial_images/nsys_nativedata_16workers_zoomed.png)

With 16 data workers, the large gaps between steps are somewhat alleviated, improving the throughput. However, from the zoomed-out view, we still see large gaps between groups of 16 iterations. Looking at the zoomed-in profile, we still see that the H2D copy of the input data takes some time and could be improved. One option here is to implement a prefetching mechanism in PyTorch directly using CUDA streams to concurrently load and copy in the next batch of input during the current batch; however, this is left as an exercise outside of this tutorial. A good example of this can be found [here](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ConvNets/image_classification/dataloaders.py#L347).

Using benchy, we can also check how the various throughputs compare using 16 data workers. Running this configuration on Perlmutter using the tool yields the following:
```
BENCHY::SUMMARY::IO average trial throughput: 234.450 +/- 47.531
BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 415.428 +/- 18.697
BENCHY::SUMMARY::FULL average trial throughput: 166.868 +/- 2.432
```
`IO` is faster as expected, and the `FULL` throughput increases correspondingly. However, `IO` is still lower than `SYNTHETIC`, meaning we should still address data loading before focusing on compute improvements.

#### Using NVIDIA DALI
While we were able to get more performance out of the PyTorch native DataLoader, there are several overheads we cannot overcome in PyTorch alone:
1. The PyTorch DataLoader will use CPU operations for all I/O operations as well as data augmentations.
2. The PyTorch DataLoader uses multi-processing to spawn data workers, which has performance overheads compared to true threads.

The NVIDIA DALI library is a data loading library that can address both of these points:
1. DALI can perform a wide array of data augmentation operations on the GPU, benefitting from acceleration relative to the CPU.
2. DALI maintains its own worker threads in the C++ backend, enabling much more performant threading and concurrent operation.

For this tutorial, we've provided an alternative data loader using DALI to accelerate the data augmentations used in this training script (e.g. 3D cropping, rotations, and flips), which can be found in `utils/data_loader_dali.py`. This data loader is enabled via the command line argument `--data_loader_config=dali-lowmem` to the training script.
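
To give a feel for what such a pipeline looks like, below is a rough sketch of GPU-side random cropping and flipping with DALI. The file names, batch size, layout handling, and operator arguments here are assumptions for illustration only; the repository's actual pipeline is in [`utils/data_loader_dali.py`](utils/data_loader_dali.py) and differs in detail.
```
from nvidia.dali import fn, pipeline_def
from nvidia.dali.plugin.pytorch import DALIGenericIterator

@pipeline_def(batch_size=64, num_threads=8, device_id=0)
def crop_flip_pipeline(file_list):
    # Read full volumes from .npy files, then move them to the GPU for augmentation
    # (assumes the volumes are stored as DHWC arrays; real code sets layouts explicitly)
    data = fn.readers.numpy(file_root=".", files=file_list, name="reader")
    data = data.gpu()
    # Random 64^3 crop; the crop position is drawn uniformly per sample
    data = fn.crop(data,
                   crop_d=64, crop_h=64, crop_w=64,
                   crop_pos_x=fn.random.uniform(range=[0.0, 1.0]),
                   crop_pos_y=fn.random.uniform(range=[0.0, 1.0]),
                   crop_pos_z=fn.random.uniform(range=[0.0, 1.0]))
    # Random reflections along each axis
    data = fn.flip(data,
                   horizontal=fn.random.coin_flip(),
                   vertical=fn.random.coin_flip(),
                   depthwise=fn.random.coin_flip())
    return data

pipe = crop_flip_pipeline(file_list=["sample_0.npy"])  # hypothetical file name
pipe.build()
loader = DALIGenericIterator([pipe], output_map=["data"], reader_name="reader")
for batch in loader:
    x = batch[0]["data"]  # a CUDA tensor, ready to feed the model
```
All of the augmentation work above runs in DALI's own GPU streams, which is what removes the CPU worker and H2D copy bottlenecks seen with the native dataloader.
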
320 | 321 | We can run this experiment on Perlmutter using DALI with 8 worker threads by running the following command: 322 | ``` 323 | sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem 324 | ``` 325 | If running interactively: 326 | ``` 327 | python train.py --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem 328 | ``` 329 | 330 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64 and DALI: 331 | ``` 332 | 2022-11-09 16:06:54,501 - root - INFO - Time taken for epoch 1 is 218.65716981887817 sec, avg 18.439825244879255 samples/sec 333 | 2022-11-09 16:06:54,502 - root - INFO - Avg train loss=0.067234 334 | 2022-11-09 16:07:02,075 - root - INFO - Avg val loss=0.045643 335 | 2022-11-09 16:07:02,079 - root - INFO - Total validation time: 7.572189092636108 sec 336 | 2022-11-09 16:07:14,529 - root - INFO - Time taken for epoch 2 is 12.448298931121826 sec, avg 329.0409414702956 samples/sec 337 | 2022-11-09 16:07:14,529 - root - INFO - Avg train loss=0.029191 338 | 2022-11-09 16:07:15,911 - root - INFO - Avg val loss=0.026629 339 | 2022-11-09 16:07:15,911 - root - INFO - Total validation time: 1.3810546398162842 sec 340 | 2022-11-09 16:07:23,015 - root - INFO - Time taken for epoch 3 is 7.101759910583496 sec, avg 576.7584446069318 samples/sec 341 | 2022-11-09 16:07:23,018 - root - INFO - Avg train loss=0.019404 342 | 2022-11-09 16:07:24,046 - root - INFO - Avg val loss=0.021680 343 | 2022-11-09 16:07:24,046 - root - INFO - Total validation time: 1.0282492637634277 sec 344 | 2022-11-09 16:07:31,144 - root - INFO - Time taken for epoch 4 is 7.096238374710083 sec, avg 577.2072165159393 samples/sec 345 | 2022-11-09 16:07:31,145 - root - INFO - Avg train loss=0.017324 346 | 2022-11-09 16:07:31,782 - root - INFO - Avg val loss=0.020755 347 | 2022-11-09 16:07:31,783 - root - INFO - Total validation time: 0.6371126174926758 sec 348 | ``` 349 | 350 | We can run the DALI case through profiler using the instructions in the earlier section with the added `--data_loader_config=dali-lowmem` 351 | argument and load that profile in Nsight Systems. This is what this profile ([`dali.nsys-rep`](sample_nsys_profiles/dali.nsys-rep)) looks like: 352 | ![NSYS DALI](tutorial_images/nsys_dali.png) 353 | 354 | and zoomed in to a single iteration: 355 | ![NSYS DALI Zoomed](tutorial_images/nsys_dali_zoomed.png) 356 | 357 | With DALI, you will see that there are now multiple CUDA stream rows in the timeline view, corresponding to internal streams DALI uses 358 | to run data augmentation kernels and any memory movement concurrently with the existing PyTorch compute kernels. Stream 16 in this view, in particular, shows concurrent H2D memory copies of the batch input data, which is an improvement over the native dataloader. 359 | 360 | Running this case using benchy on Perlmutter results in the following throughput measurements: 361 | ``` 362 | BENCHY::SUMMARY::IO average trial throughput: 943.632 +/- 83.507 363 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 592.984 +/- 0.052 364 | BENCHY::SUMMARY::FULL average trial throughput: 578.118 +/- 0.046 365 | ``` 366 | One thing we can notice here is that the `SYNTHETIC` speed is increased from previous cases. 
This is because the synthetic data sample that 367 | is cached and reused from the DALI data loader is already resident on the GPU, in contrast to the case using the PyTorch dataloader where 368 | the cached sample is in CPU memory. As a result, the `SYNTHETIC` result here is improved due to no longer requiring a H2D memory copy. 369 | In general, we now see that the `IO` throughput is greater than the `SYNTHETIC`, meaning the data loader can keep up with the compute 370 | throughput with additional headroom for compute speed improvements. 371 | 372 | ### Enabling Mixed Precision Training 373 | Now that the data loading performance is faster than the synthetic compute throughput, we can start looking at improving compute performance. As a first step to improve the compute performance of this training script, we can enable automatic mixed precision (AMP) in PyTorch. AMP provides a simple way for users to convert existing FP32 training scripts to mixed FP32/FP16 precision, unlocking 374 | faster computation with Tensor Cores on NVIDIA GPUs. 375 | 376 | The AMP module in torch is composed of two main parts: `torch.cuda.amp.GradScaler` and `torch.cuda.amp.autocast`. `torch.cuda.amp.GradScaler` handles automatic loss scaling to control the range of FP16 gradients. 377 | The `torch.cuda.amp.autocast` context manager handles converting model operations to FP16 where appropriate. 378 | 379 | As a quick note, the A100 GPUs we've been using to report results thus far have been able to benefit from Tensor Core compute via the use of TF32 precision operations, enabled by default for CUDNN and CUBLAS in PyTorch. We can measure the benefit of TF32 precision usage on the A100 GPU by temporarily disabling it via setting the environment variable `NVIDIA_TF32_OVERRIDE=0`. 
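
As a reference for how these two pieces fit into a training step, here is a minimal sketch of the usual autocast/`GradScaler` pattern. The names `model`, `optimizer`, `loss_fn`, and `train_loader` are placeholders; the tutorial's actual logic is in [`train.py`](train.py).
```
import torch

scaler = torch.cuda.amp.GradScaler()

for inp, tar in train_loader:
    optimizer.zero_grad()
    # Forward pass and loss computation run in mixed precision
    with torch.cuda.amp.autocast():
        out = model(inp)
        loss = loss_fn(out, tar)
    # Scale the loss so FP16 gradients stay in a representable range,
    # then unscale before the optimizer applies the weight update
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
With that pattern in mind, let's first quantify what TF32 alone is buying us.
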
380 | We can run this experiment on Perlmutter by running the following command: 381 | ``` 382 | NVIDIA_TF32_OVERRIDE=0 sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem 383 | ``` 384 | yields the following result for 4 epochs: 385 | ``` 386 | 2022-11-09 16:14:15,095 - root - INFO - Time taken for epoch 1 is 239.92680501937866 sec, avg 16.80512521172588 samples/sec 387 | 2022-11-09 16:14:15,097 - root - INFO - Avg train loss=0.067225 388 | 2022-11-09 16:14:22,583 - root - INFO - Avg val loss=0.041215 389 | 2022-11-09 16:14:22,585 - root - INFO - Total validation time: 7.484572410583496 sec 390 | 2022-11-09 16:14:50,312 - root - INFO - Time taken for epoch 2 is 27.725804328918457 sec, avg 147.73241387005703 samples/sec 391 | 2022-11-09 16:14:50,317 - root - INFO - Avg train loss=0.027006 392 | 2022-11-09 16:14:51,934 - root - INFO - Avg val loss=0.024100 393 | 2022-11-09 16:14:51,934 - root - INFO - Total validation time: 1.6165187358856201 sec 394 | 2022-11-09 16:15:19,669 - root - INFO - Time taken for epoch 3 is 27.71122097969055 sec, avg 147.81015975448872 samples/sec 395 | 2022-11-09 16:15:19,671 - root - INFO - Avg train loss=0.018199 396 | 2022-11-09 16:15:21,012 - root - INFO - Avg val loss=0.020106 397 | 2022-11-09 16:15:21,012 - root - INFO - Total validation time: 1.3401463031768799 sec 398 | 2022-11-09 16:15:48,762 - root - INFO - Time taken for epoch 4 is 27.7261164188385 sec, avg 147.73075096867782 samples/sec 399 | 2022-11-09 16:15:48,762 - root - INFO - Avg train loss=0.016480 400 | 2022-11-09 16:15:49,956 - root - INFO - Avg val loss=0.019319 401 | 2022-11-09 16:15:49,956 - root - INFO - Total validation time: 1.193620204925537 sec 402 | ``` 403 | From here, we can see that running in FP32 without TF32 acceleration is much slower and we are already seeing great performance from 404 | TF32 Tensor Core operations without any code changes to add AMP. With that said, AMP can still be a useful improvement for A100 GPUs, 405 | as TF32 is a compute type only, leaving all data in full precision FP32. FP16 precision has the compute benefits of Tensor Cores combined with a reduction in storage and memory bandwidth requirements. 406 | 407 | We can run this experiment using AMP on Perlmutter by running the following command: 408 | ``` 409 | sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 410 | ``` 411 | If running interactively: 412 | ``` 413 | python train.py --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 414 | ``` 415 | 416 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64, DALI, and AMP: 417 | ``` 418 | 2022-11-09 16:30:03,601 - root - INFO - Time taken for epoch 1 is 293.50703406333923 sec, avg 13.737319832443568 samples/sec 419 | 2022-11-09 16:30:03,602 - root - INFO - Avg train loss=0.065381 420 | 2022-11-09 16:30:03,603 - root - WARNING - DALI iterator does not support resetting while epoch is not finished. Ignoring... 
2022-11-09 16:30:11,238 - root - INFO - Avg val loss=0.042530
2022-11-09 16:30:11,238 - root - INFO - Total validation time: 7.635042905807495 sec
2022-11-09 16:30:23,569 - root - INFO - Time taken for epoch 2 is 12.32933497428894 sec, avg 332.21580957461373 samples/sec
2022-11-09 16:30:23,570 - root - INFO - Avg train loss=0.027131
2022-11-09 16:30:24,948 - root - INFO - Avg val loss=0.026105
2022-11-09 16:30:24,949 - root - INFO - Total validation time: 1.378551721572876 sec
2022-11-09 16:30:30,479 - root - INFO - Time taken for epoch 3 is 5.5291588306427 sec, avg 740.7998441462547 samples/sec
2022-11-09 16:30:30,479 - root - INFO - Avg train loss=0.018360
2022-11-09 16:30:31,495 - root - INFO - Avg val loss=0.021196
2022-11-09 16:30:31,495 - root - INFO - Total validation time: 1.015498161315918 sec
2022-11-09 16:30:36,787 - root - INFO - Time taken for epoch 4 is 5.289811372756958 sec, avg 774.3187254454474 samples/sec
2022-11-09 16:30:36,787 - root - INFO - Avg train loss=0.016491
2022-11-09 16:30:37,415 - root - INFO - Avg val loss=0.020216
2022-11-09 16:30:37,415 - root - INFO - Total validation time: 0.6275067329406738 sec
```

We can run the case with AMP enabled through the profiler using the instructions in the earlier section with the added `--amp_mode=fp16` argument and load that profile in Nsight Systems. This is what this profile ([`dali_amp.nsys-rep`](sample_nsys_profiles/dali_amp.nsys-rep)) looks like:
![NSYS DALI AMP](tutorial_images/nsys_dali_amp.png)

and zoomed in to a single iteration:
![NSYS DALI AMP Zoomed](tutorial_images/nsys_dali_amp_zoomed.png)

With AMP enabled, we see that the `forward` (and, correspondingly, the backward) time is significantly reduced. As this is a CNN, the forward and backward convolution ops are well-suited to benefit from acceleration with Tensor Cores, and that is where we see the most benefit.

Running this case using benchy on Perlmutter results in the following throughput measurements:
```
BENCHY::SUMMARY::IO average trial throughput: 929.612 +/- 92.659
BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 820.332 +/- 0.371
BENCHY::SUMMARY::FULL average trial throughput: 665.790 +/- 1.026
```
From these results, we can see a big improvement in the `SYNTHETIC` and `FULL` throughput from using mixed-precision training over TF32 alone.

### Just-in-time (JIT) compilation and APEX fused optimizers
While AMP provided a large increase in compute speed already, there are a few other optimizations available for PyTorch to improve compute throughput. A first (and simple) change is to replace the Adam optimizer from `torch.optim.Adam` with a fused version from [APEX](https://github.com/NVIDIA/apex), `apex.optimizers.FusedAdam`. This fused optimizer uses fewer kernels to perform the weight update than the standard PyTorch optimizer, reducing latency and making more efficient use of GPU bandwidth by increasing register reuse. We can enable the use of the `FusedAdam` optimizer in our training script by adding the flag `--enable_apex`.
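
Conceptually the swap is a one-line change. The sketch below assumes an `enable_apex` flag and a configured learning rate `lr`; the real wiring in `train.py` sits behind the `--enable_apex` argument and may differ in detail.
```
import torch
from apex.optimizers import FusedAdam

if enable_apex:
    # FusedAdam fuses the elementwise update math into far fewer CUDA kernels
    optimizer = FusedAdam(model.parameters(), lr=lr)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
```
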
461 | 462 | We can run this experiment using APEX on Perlmutter by running the following command: 463 | ``` 464 | sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 --enable_apex 465 | ``` 466 | If running interactively: 467 | ``` 468 | python train.py --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 --enable_apex 469 | ``` 470 | 471 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64, DALI, and AMP, and APEX: 472 | ``` 473 | 2022-11-09 16:41:30,604 - root - INFO - Time taken for epoch 1 is 244.000159740448 sec, avg 16.52457934572251 samples/sec 474 | 2022-11-09 16:41:30,605 - root - INFO - Avg train loss=0.067127 475 | 2022-11-09 16:41:38,195 - root - INFO - Avg val loss=0.042797 476 | 2022-11-09 16:41:38,195 - root - INFO - Total validation time: 7.589394569396973 sec 477 | 2022-11-09 16:41:50,463 - root - INFO - Time taken for epoch 2 is 12.266955137252808 sec, avg 333.90519115547215 samples/sec 478 | 2022-11-09 16:41:50,463 - root - INFO - Avg train loss=0.028232 479 | 2022-11-09 16:41:51,829 - root - INFO - Avg val loss=0.025897 480 | 2022-11-09 16:41:51,829 - root - INFO - Total validation time: 1.3656654357910156 sec 481 | 2022-11-09 16:41:57,088 - root - INFO - Time taken for epoch 3 is 5.256546497344971 sec, avg 779.2188278119196 samples/sec 482 | 2022-11-09 16:41:57,088 - root - INFO - Avg train loss=0.018998 483 | 2022-11-09 16:41:58,074 - root - INFO - Avg val loss=0.021492 484 | 2022-11-09 16:41:58,075 - root - INFO - Total validation time: 0.9862794876098633 sec 485 | 2022-11-09 16:42:03,412 - root - INFO - Time taken for epoch 4 is 5.336004257202148 sec, avg 767.6155794800048 samples/sec 486 | 2022-11-09 16:42:03,412 - root - INFO - Avg train loss=0.017139 487 | 2022-11-09 16:42:04,020 - root - INFO - Avg val loss=0.020624 488 | 2022-11-09 16:42:04,020 - root - INFO - Total validation time: 0.6071317195892334 sec 489 | ``` 490 | 491 | While APEX provides some already fused kernels, for more general fusion of eligible pointwise operations in PyTorch, we can enable 492 | JIT compilation, done in our training script via the flag `--enable_jit`. 
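
To give a flavor of what the JIT fuser does, here is a small self-contained example (not code from this repository): TorchScript can compile a chain of pointwise operations like the one below into a single kernel, avoiding extra memory round-trips for the intermediate results.
```
import torch

@torch.jit.script
def scaled_residual_act(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Several pointwise ops that the JIT fuser can combine into one kernel
    return scale * torch.relu(x + y) + 0.1 * x

x = torch.randn(1024, 1024, device="cuda")
y = torch.randn(1024, 1024, device="cuda")
out = scaled_residual_act(x, y, 2.0)
```
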
493 | 494 | We can run this experiment using JIT on Perlmutter by running the following command: 495 | ``` 496 | sbatch -n 1 ./submit_pm.sh --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 --enable_apex --enable_jit 497 | ``` 498 | If running interactively: 499 | ``` 500 | python train.py --config=short --num_epochs 4 --num_data_workers 8 --data_loader_config=dali-lowmem --amp_mode=fp16 --enable_apex --enable_jit 501 | ``` 502 | 503 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 64, DALI, and AMP, APEX and JIT: 504 | ``` 505 | 2022-11-09 16:43:43,077 - root - INFO - Time taken for epoch 1 is 238.02903866767883 sec, avg 16.93910970933771 samples/sec 506 | 2022-11-09 16:43:43,081 - root - INFO - Avg train loss=0.075173 507 | 2022-11-09 16:43:50,692 - root - INFO - Avg val loss=0.048126 508 | 2022-11-09 16:43:50,692 - root - INFO - Total validation time: 7.610842704772949 sec 509 | 2022-11-09 16:44:03,150 - root - INFO - Time taken for epoch 2 is 12.457206010818481 sec, avg 328.80567251138194 samples/sec 510 | 2022-11-09 16:44:03,151 - root - INFO - Avg train loss=0.030912 511 | 2022-11-09 16:44:04,513 - root - INFO - Avg val loss=0.027476 512 | 2022-11-09 16:44:04,513 - root - INFO - Total validation time: 1.362241506576538 sec 513 | 2022-11-09 16:44:09,757 - root - INFO - Time taken for epoch 3 is 5.242457389831543 sec, avg 781.3129788989315 samples/sec 514 | 2022-11-09 16:44:09,758 - root - INFO - Avg train loss=0.020107 515 | 2022-11-09 16:44:10,752 - root - INFO - Avg val loss=0.021986 516 | 2022-11-09 16:44:10,752 - root - INFO - Total validation time: 0.9937717914581299 sec 517 | 2022-11-09 16:44:15,990 - root - INFO - Time taken for epoch 4 is 5.2364501953125 sec, avg 782.2092920250833 samples/sec 518 | 2022-11-09 16:44:15,990 - root - INFO - Avg train loss=0.017781 519 | 2022-11-09 16:44:16,587 - root - INFO - Avg val loss=0.020978 520 | 2022-11-09 16:44:16,587 - root - INFO - Total validation time: 0.5963444709777832 sec 521 | ``` 522 | 523 | Running a profile ([`dali_amp_apex_jit.nsys-rep`](sample_nsys_profiles/dali_amp_apex_jit.nsys-rep)) using these new options and loading in Nsight Systems looks like this: 524 | ![NSYS DALI AMP APEX JIT](tutorial_images/nsys_dali_amp_apex_jit.png) 525 | 526 | and zoomed in to a single iteration: 527 | ![NSYS DALI AMP APEX JIT Zoomed](tutorial_images/nsys_dali_amp_apex_jit_zoomed.png) 528 | 529 | Running this case with APEX and JIT enabled using benchy on Perlmutter results in the following throughput measurements: 530 | ``` 531 | BENCHY::SUMMARY::IO average trial throughput: 936.818 +/- 95.516 532 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 893.160 +/- 0.250 533 | BENCHY::SUMMARY::FULL average trial throughput: 683.573 +/- 1.248 534 | ``` 535 | We see a modest gain in the `SYNTHETIC` throughput, resuling in a slight increase in the `FULL` throughput. 536 | 537 | ### Using CUDA Graphs (optional) 538 | In this repository, we've included an alternative training script [train_graph.py](train_graph.py) that illustrates applying 539 | PyTorch's new CUDA Graphs functionality to the existing model and training loop. Our tutorial model configuration does not benefit 540 | much using CUDA Graphs, but for models with more CPU latency issues (e.g. from many small kernel launches), CUDA graphs are 541 | something to consider to improve. 
Compare [train.py](train.py) and [train_graph.py](train_graph.py) to see how to use CUDA Graphs in PyTorch.

### Full training with optimizations
Now you can run the full model training on a single GPU with our optimizations. For convenience, we provide a configuration with the optimizations already enabled. Submit the full training with:

```
sbatch -n 1 -t 40 ./submit_pm.sh --config=bs64_opt
```

## Distributed GPU training

Now that we have model training code that is optimized for training on a single GPU, we are ready to utilize multiple GPUs and multiple nodes to accelerate the workflow with *distributed training*. We will use the recommended `DistributedDataParallel` wrapper in PyTorch with the NCCL backend for optimized communication operations on systems with NVIDIA GPUs. Refer to the PyTorch documentation for additional details on the distributed package: https://pytorch.org/docs/stable/distributed.html

### Code basics

To submit a multi-GPU job, use `submit_pm.sh` with the `-n` option set to the desired number of GPUs. For example, to launch a training with multiple GPUs, you will use commands like:
```
sbatch -n NUM_GPU submit_pm.sh [OPTIONS]
```
This script automatically uses the slurm flags `--ntasks-per-node 4`, `--cpus-per-task 32`, `--gpus-per-node 4`, so slurm will allocate all the CPUs and GPUs available on each Perlmutter GPU node, and launch one process for each GPU in the job. This way, multi-node trainings can easily be launched simply by setting `-n` to multiples of 4.

*Question: why do you think we run 1 task (cpu process) per GPU, instead of 1 task per node (each running 4 GPUs)?*

PyTorch `DistributedDataParallel`, or DDP for short, is flexible and can initialize process groups with a variety of methods. For this code, we will use the standard approach of initializing via environment variables, which can be easily read from the slurm environment. Take a look at the `export_DDP_vars.sh` helper script, which is used by our job script to expose to PyTorch DDP the global rank and node-local rank of each process, along with the total number of ranks and the address and port to use for network communication. In the [`train.py`](train.py) script, near the bottom in the main script execution, we set up the distributed backend using these environment variables via `torch.distributed.init_process_group`.

When distributing a batch of samples in DDP training, we must make sure each rank gets a properly-sized subset of the full batch. See if you can find where we use the `DistributedSampler` from PyTorch to properly partition the data in [`utils/data_loader.py`](utils/data_loader.py). Note that in this particular example, we are already cropping samples randomly from a large simulation volume, so the partitioning does not ensure each rank gets unique data, but simply shortens the number of steps needed to complete an "epoch". For datasets with a fixed number of unique samples, `DistributedSampler` will also ensure each rank sees a unique minibatch.
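
Putting those two pieces together, a minimal sketch of the process group initialization and sampler setup looks like the following. It assumes the standard `RANK`/`LOCAL_RANK`/`WORLD_SIZE` (plus `MASTER_ADDR`/`MASTER_PORT`) environment variables are set by the job script; `train_dataset` and `local_batch_size` are placeholders, and the tutorial's real code in [`train.py`](train.py) and [`utils/data_loader.py`](utils/data_loader.py) differs in detail.
```
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Process layout, populated by the job script (e.g. export_DDP_vars.sh)
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])

# NCCL backend; env:// init also reads MASTER_ADDR and MASTER_PORT
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(local_rank)

# Each rank draws a different shard of indices, shortening the per-rank epoch
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=local_batch_size, sampler=sampler,
                          num_workers=8, pin_memory=True)
```
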
573 | 574 | In `train.py`, after our U-Net model is constructed, 575 | we convert it to a distributed data parallel model by wrapping it as: 576 | ``` 577 | model = DistributedDataParallel(model, device_ids=[local_rank]) 578 | ``` 579 | 580 | The DistributedDataParallel (DDP) model wrapper takes care of broadcasting 581 | initial model weights to all workers and performing all-reduce on the gradients 582 | in the training backward pass to properly synchronize and update the model 583 | weights in the distributed setting. 584 | 585 | *Question: why does DDP broadcast the initial model weights to all workers? What would happen if it didn't?* 586 | 587 | ### Large batch convergence 588 | 589 | To speed up training, we try to use larger batch sizes, spread across more GPUs, 590 | with larger learning rates. The base config uses a batchsize of 64 for single-GPU training, so we will set `base_batch_size=64` in our configs and then increase the `global_batch_size` parameter in increments of 64 for every additional GPU we add to the distributed training. Then, we can take the ratio of `global_batch_size` and `base_batch_size` to decide how much to scale up the learning rate as the global batch size grows. In this section, we will make use of the square-root scaling rule, which multiplies the base initial learning rate by `sqrt(global_batch_size/base_batch_size)`. Take a look at [`utils/__init__.py`](utils/__init__.py) to see how this is implemented. 591 | 592 | *Question: how do you think the loss curves would change if we didn't increase the learning rate at all as we scale up?* 593 | 594 | *Question: what do you think would happen if we simply increased our learning rate without increasing batch size?* 595 | 596 | As a first attempt, let's try increasing the batchsize from 64 to 512, distributing our training across 8 GPUs (thus two GPU nodes on Perlmutter). To submit a job with this config, do 597 | ``` 598 | sbatch -t 10 -n 8 submit_pm.sh --config=bs512_test 599 | ``` 600 | 601 | Looking at the TensorBoard log, we can see that the rate of convergence is increased initially, but the validation loss plateaus quickly and our final accuracy ends up worse than the single-GPU training: 602 | ![batchsize 512 bad](tutorial_images/bs512_short.png) 603 | 604 | From the plot, we see that with a global batch size of 512 we complete each epoch in a much shorter amount of time, so training concludes rapidly. This affects our learning rate schedule, which depends on the total number of steps as set in `train.py`: 605 | ``` 606 | params.lr_schedule['tot_steps'] = params.num_epochs*(params.Nsamples//params.global_batch_size) 607 | ``` 608 | 609 | If we increase the total number of epochs, we will run longer (thus giving the model more training iterations to update weights) and the learning rate will decay more slowly, giving us more time to converge quickly with a larger learning rate. To try this out, run the `bs512_opt` config, which runs for 40 epochs rather than the default 10: 610 | ``` 611 | sbatch -t 20 -n 8 submit_pm.sh --config=bs512_opt 612 | ``` 613 | With the longer training, we can see that our higher batch size results are slightly better than the baseline configuration. Furthermore, the minimum in the loss is reached sooner, despite running for more epochs: 614 | ![batchsize 512 good](tutorial_images/bs512.png) 615 | 616 | Based on our findings, we can strategize to have trainings with larger batch sizes run for half as many total iterations as the baseline, as a rule of thumb. 
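
As a concrete reference for the square-root scaling rule described above, the adjustment amounts to something like the following sketch (the base learning rate value here is made up for illustration; the actual logic lives in [`utils/__init__.py`](utils/__init__.py)):
```
import math

base_batch_size = 64
base_lr = 1e-3            # illustrative value, not the config's actual learning rate
global_batch_size = 512   # e.g. 8 GPUs x 64 samples each

scaled_lr = base_lr * math.sqrt(global_batch_size / base_batch_size)
# 512 / 64 = 8, so the initial learning rate grows by sqrt(8), roughly 2.8x
```
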
You can see these choices implemented in the different configs for the various global batch sizes: `bs256_opt`, `bs512_opt`, `bs2048_opt`. However, to really compare how our convergence is improving between these configurations, we must consider the actual time-to-solution. To do this in TensorBoard, select the "Relative" option on the left-hand side, which will change the x-axis in each plot to show the walltime of the job (in hours), relative to the first data point:

![relative option for tensorboard](tutorial_images/relative.png)

With this selected, we can compare results between these different configs as a function of time, and see that all of them improve over the baseline. Furthermore, the rate of convergence improves as we add more GPUs and increase the global batch size:

![comparison across batchsizes](tutorial_images/bs_compare.png)

Based on our study, we see that scaling up our U-Net can definitely speed up training and reduce time-to-solution. Compared to our un-optimized single-GPU baseline from the first section, which took around 2 hours to train, we can now converge in about 10 minutes, which is a great speedup! We have also seen that there are several considerations to be aware of and several key hyperparameters to tune. We encourage you to now play with some of these settings and observe how they can affect the results. The main parameters in `config/UNet.yaml` to consider are:

* `num_epochs`, to adjust how long it takes for the learning rate to decay and for training to conclude.
* `lr_schedule`, to choose how to scale up the learning rate, or change the start and end learning rates.
* `global_batch_size`. We ask that you limit yourself to a maximum of 8 GPUs initially for this section, to ensure everyone gets sufficient access to compute resources.

You should also consider the following questions:
* *What are the limitations to scaling up batch size and learning rates?*
* *What would happen to the learning curves and runtime if we did "strong scaling" instead (hold global batch size fixed as we increase GPUs, and respectively decrease the local batch size)?*

## Multi-GPU performance profiling and optimization

With distributed training enabled and large batch convergence tested, we are ready to optimize the multi-GPU training throughput. We start with understanding and plotting the performance of our application as we scale. Then we can go into more detail and profile the multi-GPU training with Nsight Systems to understand the communication performance.

### Weak and Strong Throughput Scaling

First, we want to measure the scaling efficiency. An example command to generate the points for 8 nodes is:
```
BENCHY_OUTPUT=weak_scale sbatch -N 8 ./submit_pm.sh --num_data_workers 4 --local_batch_size 64 --config=bs64_opt --enable_benchy --num_epochs 15
```


The plot shows the throughput as we scale up to 32 nodes. The solid green line shows the real data throughput, while the dotted green line shows the ideal throughput, i.e. the single-GPU throughput multiplied by the number of GPUs used. For example, at 32 nodes we get around 84% scaling efficiency. The blue lines show the data throughput from running the data-loader in isolation. The orange lines show the throughput for synthetic data.
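
For reference, the scaling efficiency quoted here is simply the measured throughput expressed as a fraction of perfect linear scaling; a quick back-of-the-envelope helper (the function and variable names are hypothetical):
```
def scaling_efficiency(measured_throughput, single_gpu_throughput, num_gpus):
    # Measured samples/sec as a fraction of perfect (linear) scaling
    ideal_throughput = single_gpu_throughput * num_gpus
    return measured_throughput / ideal_throughput

# At 32 nodes (128 GPUs), an efficiency of ~0.84 means the job delivers about
# 84% of 128x the single-GPU samples/sec.
```
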
651 | 652 | Next we can further break down the performance of the application by switching off the communication between workers. An example command to generate the points for 8 nodes, adding the `noddp` flag, is: 653 | ``` 654 | BENCHY_OUTPUT=weak_scale_noddp sbatch -N 8 ./submit_pm.sh --num_data_workers 4 --local_batch_size 64 --config=bs64_opt --enable_benchy --noddp --num_epochs 15 655 | ``` 656 | 657 | 658 | 659 | The solid orange line is with synthetic data, so there is no I/O overhead, and the dotted orange line is with synthetic data and the communication between workers switched off. The dotted orange line therefore represents the pure compute of the application, and comparing it with the solid orange line gives the communication overhead. For example, in this case the communication overhead for 32 nodes is around 12%. 660 | 661 | 662 | ### Profiling with Nsight Systems 663 | 664 | Using the optimized options for compute and I/O, we profile the communication baseline with 665 | 4 GPUs (1 node) on Perlmutter: 666 | ``` 667 | ENABLE_PROFILING=1 PROFILE_OUTPUT=4gpu_baseline sbatch -n 4 ./submit_pm.sh --config=bs64_opt --num_epochs 4 --num_data_workers 8 --local_batch_size 16 --enable_apex --enable_jit --enable_manual_profiling 668 | ``` 669 | To reflect both the strong-scaling case and the limits of large-batch training, the 670 | `local_batch_size`, i.e. the per-GPU batch size, is set to 16 to make the effect of communication visible. Loading this profile ([`4gpu_baseline.nsys-rep`](sample_nsys_profiles/4gpu_baseline.nsys-rep)) in Nsight Systems will look like this: 671 | ![NSYS 4gpu_Baseline](tutorial_images/nsys_4gpu_baseline.png) 672 | where stream 20 shows the NCCL communication calls. 673 | 674 | By default, for our model there are 8 NCCL calls per iteration, as shown in the zoomed-in view: 675 | ![NSYS 4gpu_Baseline_zoomed](tutorial_images/nsys_4gpu_baseline_zoomed.png) 676 | 677 | The performance of this run: 678 | ``` 679 | 2022-11-11 00:13:06,915 - root - INFO - Time taken for epoch 1 is 232.8502073287964 sec, avg 281.1764728538557 samples/sec 680 | 2022-11-11 00:13:06,923 - root - INFO - Avg train loss=0.014707 681 | 2022-11-11 00:13:24,620 - root - INFO - Avg val loss=0.007693 682 | 2022-11-11 00:13:24,626 - root - INFO - Total validation time: 17.69732642173767 sec 683 | 2022-11-11 00:13:57,324 - root - INFO - Time taken for epoch 2 is 32.692954301834106 sec, avg 2004.5909401440472 samples/sec 684 | 2022-11-11 00:13:57,326 - root - INFO - Avg train loss=0.006410 685 | 2022-11-11 00:13:59,928 - root - INFO - Avg val loss=0.006294 686 | 2022-11-11 00:13:59,936 - root - INFO - Total validation time: 2.601088762283325 sec 687 | 2022-11-11 00:14:34,003 - root - INFO - Time taken for epoch 3 is 34.06396484375 sec, avg 1923.909923598469 samples/sec 688 | 2022-11-11 00:14:34,008 - root - INFO - Avg train loss=0.005792 689 | 2022-11-11 00:14:36,751 - root - INFO - Avg val loss=0.005899 690 | 2022-11-11 00:14:36,751 - root - INFO - Total validation time: 2.7424585819244385 sec 691 | ``` 692 | 693 | ### Adjusting DistributedDataParallel options 694 | 695 | The [tuning knobs](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) 696 | for `DistributedDataParallel` include `broadcast_buffers`, `bucket_cap_mb`, etc. `broadcast_buffers` adds 697 | additional communication (syncing buffers) and is enabled by default, which is often not necessary.
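For example, disabling it just means passing the keyword argument when wrapping the model. This is a minimal sketch, assuming the same `model` and `local_rank` variables as in the DDP snippet earlier; in this repo the option is actually exposed through the `--disable_broadcast_buffers` flag handled in `train.py`:

```python
from torch.nn.parallel import DistributedDataParallel

# Sketch only: broadcast_buffers defaults to True; turning it off skips the
# per-iteration buffer synchronization, which our U-Net (no batch norm) does not need.
model = DistributedDataParallel(model,
                                device_ids=[local_rank],
                                broadcast_buffers=False)
```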
`bucket_cap_mb` 698 | sets an upper limit for the message size per NCCL call; adjusting it can change the total number of communication 699 | calls per iteration. The proper bucket size depends on the overlap between communication and computation, and requires 700 | tuning. 701 | 702 | Since there is no batch norm layer in our model, it is safe to disable `broadcast_buffers` with the added knob `--disable_broadcast_buffers`: 703 | ``` 704 | ENABLE_PROFILING=1 PROFILE_OUTPUT=4gpu_nobroadcast sbatch -n 4 ./submit_pm.sh --config=bs64_opt --num_epochs 4 --num_data_workers 8 --local_batch_size 16 --enable_apex --enable_jit --enable_manual_profiling --disable_broadcast_buffers 705 | ``` 706 | Loading this profile ([`4gpu_nobroadcast.nsys-rep`](sample_nsys_profiles/4gpu_nobroadcast.nsys-rep)) in Nsight Systems will look like this: 707 | ![NSYS 4gpu_nobroadcast](tutorial_images/nsys_4gpu_nobroadcast.png) 708 | The per-step timing is slightly improved compared to the baseline. 709 | 710 | The performance of this run: 711 | ``` 712 | 2022-11-11 00:13:04,938 - root - INFO - Time taken for epoch 1 is 231.4636251926422 sec, avg 282.86085965131264 samples/sec 713 | 2022-11-11 00:13:04,939 - root - INFO - Avg train loss=0.015009 714 | 2022-11-11 00:13:22,565 - root - INFO - Avg val loss=0.007752 715 | 2022-11-11 00:13:22,566 - root - INFO - Total validation time: 17.62473154067993 sec 716 | 2022-11-11 00:13:54,745 - root - INFO - Time taken for epoch 2 is 32.171440839767456 sec, avg 2037.0862569198412 samples/sec 717 | 2022-11-11 00:13:54,747 - root - INFO - Avg train loss=0.006358 718 | 2022-11-11 00:13:57,350 - root - INFO - Avg val loss=0.006510 719 | 2022-11-11 00:13:57,352 - root - INFO - Total validation time: 2.6025969982147217 sec 720 | 2022-11-11 00:14:29,527 - root - INFO - Time taken for epoch 3 is 32.17182660102844 sec, avg 2037.061830922121 samples/sec 721 | 2022-11-11 00:14:29,528 - root - INFO - Avg train loss=0.005735 722 | 2022-11-11 00:14:33,794 - root - INFO - Avg val loss=0.005844 723 | 2022-11-11 00:14:33,794 - root - INFO - Total validation time: 4.264811754226685 sec 724 | ``` 725 | Compared to the baseline, this gives an improvement of a few percent in `samples/sec` (performance may vary slightly from run to run). 726 | 727 | To show the effect of the message bucket size, we add another knob to the code, `--bucket_cap_mb`. The current 728 | default value in PyTorch is 25 MB. We profile a run with a 100 MB bucket size with the following command: 729 | ``` 730 | ENABLE_PROFILING=1 PROFILE_OUTPUT=4gpu_bucket100mb sbatch -n 4 ./submit_pm.sh --config=bs64_opt --num_epochs 4 --num_data_workers 8 --local_batch_size 16 --enable_apex --enable_jit --enable_manual_profiling --disable_broadcast_buffers --bucket_cap_mb 100 731 | ``` 732 | Loading this profile ([`4gpu_bucketcap100mb.nsys-rep`](sample_nsys_profiles/4gpu_bucketcap100mb.nsys-rep)) in Nsight Systems (zoomed in to a single iteration) will look like this: 733 | ![NSYS 4gpu_bucketcap100mb_zoomed](tutorial_images/nsys_4gpu_bucketcap100mb_zoomed.png) 734 | The total number of NCCL calls per step is now reduced to 5.
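A rough way to see why the call count drops: DDP groups gradients into buckets of roughly `bucket_cap_mb` each, so the number of all-reduce calls scales with the model's gradient size divided by the bucket cap. DDP's actual bucketing also depends on parameter registration order and communication/compute overlap, so the sketch below is only an approximation:

```python
import math

# Approximate bucket (and hence per-iteration all-reduce) count from the model size.
# "model" is the U-Net constructed in train.py; gradients here are assumed fp32.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
model_mb = param_bytes / 1e6
for bucket_cap_mb in (25, 100):   # PyTorch default vs. the tuned value above
    print(f"{bucket_cap_mb} MB cap -> ~{math.ceil(model_mb / bucket_cap_mb)} buckets")
```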
735 | 736 | The performance of this run: 737 | ``` 738 | 2022-11-11 00:26:18,684 - root - INFO - Time taken for epoch 1 is 229.9525065422058 sec, avg 284.71966226636096 samples/sec 739 | 2022-11-11 00:26:18,685 - root - INFO - Avg train loss=0.014351 740 | 2022-11-11 00:26:36,334 - root - INFO - Avg val loss=0.007701 741 | 2022-11-11 00:26:36,334 - root - INFO - Total validation time: 17.648387670516968 sec 742 | 2022-11-11 00:27:08,169 - root - INFO - Time taken for epoch 2 is 31.827385425567627 sec, avg 2059.107247538892 samples/sec 743 | 2022-11-11 00:27:08,169 - root - INFO - Avg train loss=0.006380 744 | 2022-11-11 00:27:10,782 - root - INFO - Avg val loss=0.006292 745 | 2022-11-11 00:27:10,782 - root - INFO - Total validation time: 2.6118412017822266 sec 746 | 2022-11-11 00:27:42,651 - root - INFO - Time taken for epoch 3 is 31.86245894432068 sec, avg 2056.8406259706285 samples/sec 747 | 2022-11-11 00:27:42,651 - root - INFO - Avg train loss=0.005768 748 | 2022-11-11 00:27:45,328 - root - INFO - Avg val loss=0.005847 749 | 2022-11-11 00:27:45,329 - root - INFO - Total validation time: 2.67659068107605 sec 750 | ``` 751 | Similarly, to understand the cross node performance, we run the baseline and optimized options with 2 nodes on Perlmutter. 752 | 753 | Baseline: 754 | ``` 755 | ENABLE_PROFILING=1 PROFILE_OUTPUT=8gpu_baseline sbatch -N 2 ./submit_pm.sh --config=bs64_opt --num_epochs 4 --num_data_workers 8 --local_batch_size 16 --enable_apex --enable_jit --enable_manual_profiling 756 | ``` 757 | and the performance of the run: 758 | ``` 759 | 2022-11-11 00:20:02,171 - root - INFO - Time taken for epoch 1 is 213.5172028541565 sec, avg 306.33597258520246 samples/sec 760 | 2022-11-11 00:20:02,173 - root - INFO - Avg train loss=0.028723 761 | 2022-11-11 00:20:18,038 - root - INFO - Avg val loss=0.010950 762 | 2022-11-11 00:20:18,039 - root - INFO - Total validation time: 15.865173101425171 sec 763 | 2022-11-11 00:20:36,759 - root - INFO - Time taken for epoch 2 is 18.71891736984253 sec, avg 3501.0571768206546 samples/sec 764 | 2022-11-11 00:20:36,760 - root - INFO - Avg train loss=0.007529 765 | 2022-11-11 00:20:38,613 - root - INFO - Avg val loss=0.007636 766 | 2022-11-11 00:20:38,615 - root - INFO - Total validation time: 1.8524699211120605 sec 767 | 2022-11-11 00:20:58,163 - root - INFO - Time taken for epoch 3 is 19.54378581047058 sec, avg 3353.290945548999 samples/sec 768 | 2022-11-11 00:20:58,166 - root - INFO - Avg train loss=0.006395 769 | 2022-11-11 00:20:59,522 - root - INFO - Avg val loss=0.006702 770 | 2022-11-11 00:20:59,522 - root - INFO - Total validation time: 1.3556835651397705 sec 771 | ``` 772 | Optimized: 773 | ``` 774 | ENABLE_PROFILING=1 PROFILE_OUTPUT=8gpu_bucket100mb sbatch -N 2 ./submit_pm.sh --config=bs64_opt --num_epochs 4 --num_data_workers 8 --local_batch_size 16 --enable_apex --enable_jit --enable_manual_profiling --disable_broadcast_buffers --bucket_cap_mb 100 775 | ``` 776 | and the performance of the run: 777 | ``` 778 | 2022-11-11 00:20:22,173 - root - INFO - Time taken for epoch 1 is 217.77708864212036 sec, avg 300.3438075503293 samples/sec 779 | 2022-11-11 00:20:22,176 - root - INFO - Avg train loss=0.028697 780 | 2022-11-11 00:20:38,198 - root - INFO - Avg val loss=0.009447 781 | 2022-11-11 00:20:38,200 - root - INFO - Total validation time: 16.021918296813965 sec 782 | 2022-11-11 00:20:56,770 - root - INFO - Time taken for epoch 2 is 18.569196939468384 sec, avg 3529.285634356368 samples/sec 783 | 2022-11-11 00:20:56,772 - root - INFO - Avg train 
loss=0.007177 784 | 2022-11-11 00:20:59,962 - root - INFO - Avg val loss=0.007194 785 | 2022-11-11 00:20:59,964 - root - INFO - Total validation time: 3.190000534057617 sec 786 | 2022-11-11 00:21:18,232 - root - INFO - Time taken for epoch 3 is 18.263245820999146 sec, avg 3588.409236908287 samples/sec 787 | 2022-11-11 00:21:18,232 - root - INFO - Avg train loss=0.006346 788 | 2022-11-11 00:21:19,560 - root - INFO - Avg val loss=0.006503 789 | 2022-11-11 00:21:19,560 - root - INFO - Total validation time: 1.3270776271820068 sec 790 | ``` 791 | Note that the batch size is set to a small value to tune the knobs at smaller scale. To have a better scaliing efficiency, we 792 | want to increase the per GPU compute intensity by increasing the per GPU batch size. 793 | 794 | ## Putting it all together 795 | 796 | With all of our multi-GPU settings and optimizations in place, we now leave it to you to take what you've learned and try to achieve the best performance on this problem. Specifically, try to further tune things to either reach the lowest possible validation loss, or converge to the single-GPU validation loss (`~4.7e-3`) in the shortest amount of time. Some ideas for things to adjust are: 797 | * Further tune `num_epochs` to adjust how long it takes for learning rate to decay, and for training to conclude. 798 | * Play with the learning rate: try out a different scaling rule, such as linear scale-up of learning rate, or come up with your own learning rate schedule. 799 | * Change other components, such as the optimizer used. Here we have used the standard Adam optimizer, but many practitioners also use the SGD optimizer (with momentum) in distributed training. 800 | 801 | The [PyTorch docs](https://pytorch.org/docs/stable/index.html) will be helpful if you are attempting more advanced changes. 802 | -------------------------------------------------------------------------------- /README_summit.md: -------------------------------------------------------------------------------- 1 | # SC22 Deep Learning at Scale Tutorial (Summit Commands) 2 | 3 | Please refer to main [(README.md)](https://github.com/tsaris/sc22-dl-tutorial/blob/main/README.md) for details of the tutorial and how to run on NERSC's Perlmutter machine. This page has the commands on how to run on OLCF's Summit machine. 4 | 5 | Data location on Summit: `/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data` 6 | 7 | ## Installation and Setup 8 | 9 | ### Software environment 10 | 11 | For running jobs on Summit, we will use training accounts which are provided under the `TRN001` project. The script `submit_summit.sh` included in the repository is configured to work automatically as is. 12 | * `-P TRN001` is required for training accounts 13 | 14 | To begin, start a terminal and login to Summit: 15 | ```bash 16 | mkdir -p $WORLDWORK/trn001/$USER 17 | cd $WORLDWORK/trn001/$USER/ 18 | git clone https://github.com/tsaris/sc22-dl-tutorial.git 19 | cd sc22-dl-tutorial 20 | mkdir logs 21 | ``` 22 | 23 | ### Installing Nsight Systems 24 | In this tutorial, we will be generating profile files using NVIDIA Nsight Systems on the remote systems. In order to open and view these 25 | files on your local computer, you will need to install the Nsight Systems program, which you can download [here](https://developer.nvidia.com/gameworksdownload#?dn=nsight-systems-2021-4-1-73). Select the download option required for your system (e.g. Mac OS host for MacOS, Window Host for Windows, or Linux Host .rpm/.deb/.run for Linux). 
You may need to sign up and create a login to NVIDIA's developer program if you do not 26 | already have an account to access the download. Proceed to run and install the program using your selected installation method. 27 | 28 | ## Single GPU training [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#single-gpu-training) 29 | 30 | On Summit for the tutorial, we will be submitting jobs to the batch queue. To submit this job, use the following command: 31 | ``` 32 | bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=shorter_sm --num_epochs 3" 33 | ``` 34 | 35 | To view the results in TensorBoard: 36 | * login to `jupyter.olcf.ornl.gov` from your browsher with the olcf credentials 37 | * from `jupyter.olcf.ornl.gov` open the file `start_tensorboard_summit.ipynb` that you can find after you clone the repo at `$WORLDWORK/trn001/$USER/sc22-dl-tutorial` 38 | * select option `SC22 Training Series Lab for DL Tutorial Participants` 39 | 40 | ## Single GPU performance profiling and optimization [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial/blob/main/README.md#single-gpu-performance-profiling-and-optimization) 41 | 42 | This is the performance of the baseline script with `Nsamples: 512` and `Nsamples_val: 64` with batch size of 32 on 16GB V100 card. 43 | 44 | ``` 45 | 2022-11-07 15:13:32,641 - root - INFO - Time taken for epoch 1 is 191.95538854599 sec, avg 2.6672864141937414 samples/sec 46 | 2022-11-07 15:13:32,642 - root - INFO - Avg train loss=0.123056 47 | 2022-11-07 15:13:34,624 - root - INFO - Avg val loss=0.118149 48 | 2022-11-07 15:13:34,624 - root - INFO - Total validation time: 1.9810168743133545 sec 49 | 2022-11-07 15:16:36,183 - root - INFO - Time taken for epoch 2 is 181.55334734916687 sec, avg 2.820107739546723 samples/sec 50 | 2022-11-07 15:16:36,183 - root - INFO - Avg train loss=0.082647 51 | 2022-11-07 15:16:38,190 - root - INFO - Avg val loss=0.108615 52 | 2022-11-07 15:16:38,191 - root - INFO - Total validation time: 2.006861686706543 sec 53 | 2022-11-07 15:19:34,602 - root - INFO - Time taken for epoch 3 is 176.40894150733948 sec, avg 2.902347214518592 samples/sec 54 | 2022-11-07 15:19:34,602 - root - INFO - Avg train loss=0.074187 55 | 2022-11-07 15:19:36,335 - root - INFO - Avg val loss=0.100554 56 | 2022-11-07 15:19:36,335 - root - INFO - Total validation time: 1.7319042682647705 sec 57 | ``` 58 | 59 | ### Profiling with Nsight Systems 60 | #### Adding NVTX ranges and profiler controls [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#adding-nvtx-ranges-and-profiler-controls) 61 | 62 | To generate a profile using our scripts on Summit, run the following command: 63 | ``` 64 | ENABLE_PROFILING=1 PROFILE_OUTPUT=baseline bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 2 --enable_manual_profiling" 65 | ``` 66 | This command will run two epochs of the training script, profiling only 30 steps of the second epoch. It will produce a file baseline.qdrep that can be opened in the Nsight System's program. 
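As a reminder of what `--enable_manual_profiling` does under the hood, the sketch below shows the general pattern (simplified; names and step counts are illustrative, the real annotations are in `train.py`): a window of training steps is bracketed with CUDA profiler start/stop calls so that `nsys` launched with `-c cudaProfilerApi` only captures that window, and NVTX ranges mark the individual steps on the timeline.

```python
import torch

def profiled_steps(num_steps=60, start=10, stop=40):
    # Stand-in workload; the real script runs the U-Net forward/backward here.
    x = torch.randn(8, 64, device="cuda")
    w = torch.randn(64, 64, device="cuda", requires_grad=True)
    for step in range(num_steps):
        if step == start:
            torch.cuda.profiler.start()    # nsys begins capturing here (-c cudaProfilerApi)
        if step == stop:
            torch.cuda.profiler.stop()     # ...and stops here, keeping the profile small
        torch.cuda.nvtx.range_push(f"step {step}")   # named range on the NVTX row in Nsight Systems
        loss = (x @ w).square().mean()
        loss.backward()
        w.grad = None
        torch.cuda.nvtx.range_pop()
```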
67 | 68 | #### Using the benchy profiling tool [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#using-the-benchy-profiling-tool) 69 | 70 | To run using benchy on Summit, use the following command: 71 | ``` 72 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --enable_benchy" 73 | ``` 74 | benchy uses epoch boundaries to separate the test trials it runs, so in these cases we increase the epoch limit (here `--num_epochs 15`) to ensure the full experiment runs. 75 | 76 | benchy will report throughput measurements directly to the terminal, including a simple summary of averages at the end of the job. For this case on Summit, the summary output from benchy is: 77 | ``` 78 | BENCHY::SUMMARY::IO average trial throughput: 4.750 +/- 0.122 79 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 102.148 +/- 0.079 80 | BENCHY::SUMMARY::FULL average trial throughput: 4.212 +/- 0.065 81 | ``` 82 | 83 | From these throughput values, we can see that the SYNTHETIC (i.e. compute) throughput is greater than the IO (i.e. data loading) throughput. The FULL (i.e. real) throughput is bounded by the slower of these two values, which is IO in this case. What these throughput values indicate is that the GPU can achieve much greater training throughput for this model, but is being limited by the data loading speed. 84 | 85 | In fact, without data-loading optimizations this is very slow on Summit (the above job took ~1 hour), so we recommend starting your Summit runs with the data-loading optimizations already enabled. 86 | 87 | ### Data loading optimizations 88 | #### Improving the native PyTorch dataloader performance [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#improving-the-native-pytorch-dataloader-performance) 89 | 90 | The PyTorch dataloader has several knobs we can adjust to improve performance. One knob we've left to adjust is the `num_workers` argument, which we can control via the `--num_data_workers` command line arg to our script. The default in our config is two workers, but this will be very slow if we use an `Nsamples` larger than 512, so we set it to seven workers. As a reminder, each Summit node has 42 physical cores (168 hardware threads) and 6 GPUs. 91 | 92 | We can run this experiment on Summit with the following command. This will take ~1 hour, so we recommend moving on to the runs with the DALI optimization from the next section instead. 93 | ``` 94 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --enable_benchy" 95 | ``` 96 | output 97 | ``` 98 | BENCHY::SUMMARY::IO average trial throughput: 4.750 +/- 0.122 99 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 102.148 +/- 0.079 100 | BENCHY::SUMMARY::FULL average trial throughput: 4.212 +/- 0.065 101 | ``` 102 | 103 | #### Using NVIDIA DALI [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#using-nvidia-dali) 104 | 105 | To use NVIDIA DALI, use the `--data_loader_config=dali-lowmem` flag.
106 | 107 | ``` 108 | bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --data_loader_config=dali-lowmem --enable_benchy" 109 | ``` 110 | output 111 | ``` 112 | BENCHY::SUMMARY::IO average trial throughput: 582.159 +/- 40.639 113 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 102.855 +/- 0.141 114 | BENCHY::SUMMARY::FULL average trial throughput: 101.400 +/- 0.020 115 | ``` 116 | 117 | ### Enabling Mixed Precision Training [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#enabling-mixed-precision-training) 118 | 119 | To enable mixed precision training use the `--amp_mode fp16` flag. On Summit the bf16 won't work. 120 | 121 | ``` 122 | bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --data_loader_config=dali-lowmem --amp_mode fp16 --enable_benchy" 123 | ``` 124 | output 125 | ``` 126 | BENCHY::SUMMARY::IO average trial throughput: 907.329 +/- 69.924 127 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 270.235 +/- 0.014 128 | BENCHY::SUMMARY::FULL average trial throughput: 260.223 +/- 1.274 129 | ``` 130 | 131 | ### Just-in-time (JIT) compiliation and APEX fused optimizers [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#just-in-time-jit-compiliation-and-apex-fused-optimizers) 132 | 133 | To enable APEX use the `--enable_apex` flag. 134 | 135 | ``` 136 | bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --data_loader_config=dali-lowmem --amp_mode fp16 --enable_apex --enable_benchy" 137 | ``` 138 | output 139 | ``` 140 | BENCHY::SUMMARY::IO average trial throughput: 602.045 +/- 60.880 141 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 298.335 +/- 0.317 142 | BENCHY::SUMMARY::FULL average trial throughput: 281.767 +/- 7.210 143 | ``` 144 | 145 | To enable JIT use the `--enable_jit` flag. 146 | 147 | ``` 148 | bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=short_sm --num_epochs 15 --num_data_workers 7 --data_loader_config=dali-lowmem --amp_mode fp16 --enable_apex --enable_jit --enable_benchy" 149 | ``` 150 | output 151 | ``` 152 | BENCHY::SUMMARY::IO average trial throughput: 1303.459 +/- 0.094 153 | BENCHY::SUMMARY:: SYNTHETIC average trial throughput: 294.373 +/- 0.021 154 | BENCHY::SUMMARY::FULL average trial throughput: 285.010 +/- 0.068 155 | ``` 156 | 157 | ### Full training with optimizations 158 | 159 | Now you can run the full model training on a single GPU with our optimizations. For convenience, we provide a configuration with the optimizations already enabled for 3 epochs. 
Submit the full training with: 160 | 161 | ``` 162 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 1 --config=bs32_opt_sm" 163 | ``` 164 | output 165 | ``` 166 | 2022-11-08 04:23:38,659 - root - INFO - Time taken for epoch 1 is 774.4071238040924 sec, avg 84.5859987421436 samples/sec 167 | 2022-11-08 04:23:38,660 - root - INFO - Avg train loss=0.010418 168 | 2022-11-08 04:24:13,654 - root - INFO - Avg val loss=0.006322 169 | 2022-11-08 04:24:13,654 - root - INFO - Total validation time: 34.99340867996216 sec 170 | 2022-11-08 04:34:57,201 - root - INFO - Time taken for epoch 2 is 643.5462810993195 sec, avg 101.83572172626032 samples/sec 171 | 2022-11-08 04:34:57,201 - root - INFO - Avg train loss=0.005482 172 | 2022-11-08 04:35:26,425 - root - INFO - Avg val loss=0.005525 173 | 2022-11-08 04:35:26,425 - root - INFO - Total validation time: 29.22293472290039 sec 174 | 2022-11-08 04:46:09,960 - root - INFO - Time taken for epoch 3 is 643.5349762439728 sec, avg 101.83751065482791 samples/sec 175 | 2022-11-08 04:46:09,961 - root - INFO - Avg train loss=0.005130 176 | 2022-11-08 04:46:39,230 - root - INFO - Avg val loss=0.005329 177 | 2022-11-08 04:46:39,230 - root - INFO - Total validation time: 29.268264055252075 sec 178 | ``` 179 | 180 | ![baseline_tb_summit](tutorial_images/baseline_tb_summit.png) 181 | 182 | ## Distributed GPU training [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#distributed-gpu-training) 183 | 184 | ### Large batch convergence [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#large-batch-convergence) 185 | 186 | To enable multi-gpu training, we need to have the `-g 6`, since Summit has six GPUs per node, and the `-nnodes` for the desired number of nodes. 187 | 188 | As a first attempt, let's try increasing the batchsize from 32 to 576, distributing our training across 18 GPUs (thus 3 Summit nodes). 
To submit a job with this config, do 189 | 190 | ``` 191 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 3 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs576_test_sm" 192 | ``` 193 | output 194 | ``` 195 | 2022-11-08 05:24:02,375 - root - INFO - Time taken for epoch 1 is 232.76213240623474 sec, avg 279.6331144036922 samples/sec 196 | 2022-11-08 05:24:02,375 - root - INFO - Avg train loss=0.060278 197 | 2022-11-08 05:24:17,270 - root - INFO - Avg val loss=0.023173 198 | 2022-11-08 05:24:17,271 - root - INFO - Total validation time: 14.894511938095093 sec 199 | 2022-11-08 05:24:54,600 - root - INFO - Time taken for epoch 2 is 37.32917857170105 sec, avg 1759.0529047906603 samples/sec 200 | 2022-11-08 05:24:54,601 - root - INFO - Avg train loss=0.013236 201 | 2022-11-08 05:24:56,534 - root - INFO - Avg val loss=0.010874 202 | 2022-11-08 05:24:56,534 - root - INFO - Total validation time: 1.9319498538970947 sec 203 | 2022-11-08 05:25:33,315 - root - INFO - Time taken for epoch 3 is 36.77945351600647 sec, avg 1785.3446346456165 samples/sec 204 | 2022-11-08 05:25:33,315 - root - INFO - Avg train loss=0.008215 205 | 2022-11-08 05:25:35,192 - root - INFO - Avg val loss=0.008668 206 | 2022-11-08 05:25:35,253 - root - INFO - Total validation time: 1.876542568206787 sec 207 | ``` 208 | 209 | Looking at the TensorBoard log, we can see that the rate of convergence is increased initially, but the validation loss plateaus quickly and our final accuracy ends up worse than the single-GPU training: 210 | ![bs576_short_summit](tutorial_images/bs576_short_summit.png) 211 | 212 | If we increase the total number of epochs, we will run longer (thus giving the model more training iterations to update weights) and the learning rate will decay more slowly, giving us more time to converge quickly with a larger learning rate. To try this out, run the `bs576_opt_sm` config, which runs for 12 epochs rather than 3 on 18 GPUs as well: 213 | 214 | ``` 215 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 3 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs576_opt_sm" 216 | ``` 217 | output 218 | ``` 219 | 2022-11-08 05:38:52,219 - root - INFO - Time taken for epoch 1 is 207.89257979393005 sec, avg 313.0847674530633 samples/sec 220 | 2022-11-08 05:38:52,220 - root - INFO - Avg train loss=0.059437 221 | 2022-11-08 05:39:06,496 - root - INFO - Avg val loss=0.024136 222 | 2022-11-08 05:39:06,496 - root - INFO - Total validation time: 14.276114225387573 sec 223 | 2022-11-08 05:39:43,781 - root - INFO - Time taken for epoch 2 is 37.28419780731201 sec, avg 1761.1750784972573 samples/sec 224 | 2022-11-08 05:39:43,781 - root - INFO - Avg train loss=0.013361 225 | 2022-11-08 05:39:45,970 - root - INFO - Avg val loss=0.011936 226 | 2022-11-08 05:39:45,971 - root - INFO - Total validation time: 2.1886651515960693 sec 227 | 2022-11-08 05:40:23,166 - root - INFO - Time taken for epoch 3 is 37.195136308670044 sec, avg 1765.392105437559 samples/sec 228 | 2022-11-08 05:40:23,167 - root - INFO - Avg train loss=0.008965 229 | 2022-11-08 05:40:25,059 - root - INFO - Avg val loss=0.009923 230 | 2022-11-08 05:40:25,083 - root - INFO - Total validation time: 1.8922131061553955 sec 231 | ``` 232 | 233 | With the longer training, we can see that our higher batch size results are slightly better than the baseline configuration. 
Furthermore, the minimum in the loss is reached sooner, despite running for more epochs: 234 | ![bs576_summit](tutorial_images/bs576_summit.png) 235 | 236 | For 72 GPUs with 48 epochs and batch size of 2304 237 | 238 | ``` 239 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 12 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs2304_opt_sm" 240 | ``` 241 | output 242 | ``` 243 | 2022-11-08 05:57:22,846 - root - INFO - Time taken for epoch 1 is 119.31746411323547 sec, avg 540.6752521892049 samples/sec 244 | 2022-11-08 05:57:22,847 - root - INFO - Avg train loss=0.109694 245 | 2022-11-08 05:57:25,648 - root - INFO - Avg val loss=0.081074 246 | 2022-11-08 05:57:25,648 - root - INFO - Total validation time: 2.8007771968841553 sec 247 | 2022-11-08 05:57:36,728 - root - INFO - Time taken for epoch 2 is 11.079113721847534 sec, avg 6030.807308010724 samples/sec 248 | 2022-11-08 05:57:36,728 - root - INFO - Avg train loss=0.045270 249 | 2022-11-08 05:57:37,957 - root - INFO - Avg val loss=0.037623 250 | 2022-11-08 05:57:37,959 - root - INFO - Total validation time: 1.2282912731170654 sec 251 | 2022-11-08 05:57:47,571 - root - INFO - Time taken for epoch 3 is 9.61080002784729 sec, avg 6952.178778707356 samples/sec 252 | 2022-11-08 05:57:47,571 - root - INFO - Avg train loss=0.026268 253 | 2022-11-08 05:57:48,132 - root - INFO - Avg val loss=0.033672 254 | 2022-11-08 05:57:48,146 - root - INFO - Total validation time: 0.5604500770568848 sec 255 | 256 | ``` 257 | 258 | For 288 GPUs with 96 epochs and batch size of 9216. 259 | 260 | ``` 261 | bsub -P trn001 -W 2:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 48 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs9216_opt_sm" 262 | ``` 263 | output 264 | ``` 265 | 2022-11-08 06:06:07,478 - root - INFO - Time taken for epoch 1 is 56.57166934013367 sec, avg 1140.3587829824428 samples/sec 266 | 2022-11-08 06:06:07,478 - root - INFO - Avg train loss=0.138469 267 | 2022-11-08 06:06:07,489 - root - INFO - Avg val loss=nan 268 | 2022-11-08 06:06:07,489 - root - INFO - Total validation time: 0.00026726722717285156 sec 269 | 2022-11-08 06:06:44,690 - root - INFO - Time taken for epoch 2 is 37.20052146911621 sec, avg 1981.9077015145829 samples/sec 270 | 2022-11-08 06:06:44,691 - root - INFO - Avg train loss=0.108786 271 | 2022-11-08 06:06:47,000 - root - INFO - Avg val loss=0.116034 272 | 2022-11-08 06:06:47,000 - root - INFO - Total validation time: 2.3085243701934814 sec 273 | 2022-11-08 06:07:00,177 - root - INFO - Time taken for epoch 3 is 13.176549434661865 sec, avg 5595.3950892525145 samples/sec 274 | 2022-11-08 06:07:00,177 - root - INFO - Avg train loss=0.078056 275 | 2022-11-08 06:07:01,561 - root - INFO - Avg val loss=0.079437 276 | 2022-11-08 06:08:40,605 - root - INFO - Total validation time: 1.3832640647888184 sec 277 | ``` 278 | 279 | And a summary plot of the above runs: 280 | bs_compare_summit.png 281 | ![bs_compare_summit](tutorial_images/bs_compare_summit.png) 282 | 283 | ## Multi-GPU performance profiling and optimization 284 | 285 | You can find example json logs from benchy that run on 1 to 32 Summit nodes and made the plots bellow on the `summit_scaling_logs` directory. An example script to make the scaling plots bellow is here: `summit_scaling_logs/plot_weak_scale.py`. To run this on Summit setup a python env with `module load python/3.8-anaconda3`. 
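If you would rather sketch such a plot yourself, the basic idea is just throughput versus node count with an ideal linear-scaling reference line. A minimal example with placeholder numbers (the real `plot_weak_scale.py` script reads the benchy JSON logs instead of hard-coded values):

```python
import matplotlib.pyplot as plt

# Placeholder throughputs purely for illustration; real values come from the benchy JSON logs.
nodes      = [1, 2, 4, 8, 16, 32]
throughput = [600, 1180, 2300, 4450, 8600, 16300]   # FULL samples/sec per node count (hypothetical)

ideal = [throughput[0] * n for n in nodes]           # linear scaling from the single-node value
plt.plot(nodes, throughput, "o-", label="measured")
plt.plot(nodes, ideal, "--", label="ideal")
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Summit nodes")
plt.ylabel("throughput (samples/sec)")
plt.legend()
plt.savefig("weak_scaling.png")
```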
286 | 287 | ### Weak and Strong Throughput Scaling [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#weak-and-strong-throughput-scaling) 288 | 289 | First we want to measure the scaling efficiency. An example command to generate the points for 8 nodes is: 290 | 291 | ``` 292 | BENCHY_OUTPUT=weak_scale_8 bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 8 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs32_opt_sm --num_epochs 15 --local_batch_size 32 --enable_benchy" 293 | ``` 294 | 295 | Next we can further breakdown the performance of the applications, by switching off the communication between workers. An example command to generate the points for 8 nodes and adding the noddp flag is: 296 | 297 | ``` 298 | BENCHY_OUTPUT=weak_scale_8_noddp bsub -P trn001 -W 1:00 -J sc22.tut -o logs/sc22.tut.o%J -nnodes 8 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=bs32_opt_sm --num_epochs 15 --local_batch_size 32 --enable_benchy --noddp" 299 | ``` 300 | 301 | ### Profiling with Nsight Systems [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#profiling-with-nsight-systems-1) 302 | 303 | Using the optimized options for compute and I/O, we profile the communication baseline with 304 | 6 GPUs (1 node) on Summit: 305 | ``` 306 | ENABLE_PROFILING=1 PROFILE_OUTPUT=6gpu_baseline bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.baseline.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=short_opt_sm --num_epochs 8 --local_batch_size 8 --enable_manual_profiling" 307 | ``` 308 | the example proflie [`6gpu_baseline.qdrep`](sample_nsys_profiles/6gpu_baseline.qdrep) can be viewed via Nsight Systems, which looks like 309 | ![NSYS 6gpu_Baseline](tutorial_images/nsys_summit_6gpu_baseline.png) 310 | 311 | By default, for our model there are 8 NCCL calls per iteration, as shown in zoomed-in view: 312 | ![NSYS 6gpu_Baseline_zoomed](tutorial_images/nsys_summit_6gpu_baseline_zoomed.png) 313 | 314 | The performance of this run: 315 | ``` 316 | 2022-11-09 09:03:03,764 - root - INFO - Time taken for epoch 2 is 34.94349765777588 sec, avg 118.13356637701686 samples/sec 317 | 2022-11-09 09:03:03,770 - root - INFO - Avg train loss=0.041073 318 | 2022-11-09 09:03:04,178 - root - INFO - Avg val loss=0.031329 319 | 2022-11-09 09:03:04,178 - root - INFO - Total validation time: 0.40726447105407715 sec 320 | 2022-11-09 09:03:13,658 - root - INFO - Time taken for epoch 3 is 9.478790998458862 sec, avg 435.4985778957636 samples/sec 321 | 2022-11-09 09:03:13,658 - root - INFO - Avg train loss=0.019790 322 | 2022-11-09 09:03:14,071 - root - INFO - Avg val loss=0.021612 323 | 2022-11-09 09:03:14,072 - root - INFO - Total validation time: 0.41242146492004395 sec 324 | 2022-11-09 09:03:23,531 - root - INFO - Time taken for epoch 4 is 9.458720207214355 sec, avg 436.42267765267985 samples/sec 325 | 2022-11-09 09:03:23,531 - root - INFO - Avg train loss=0.013929 326 | 2022-11-09 09:03:23,946 - root - INFO - Avg val loss=0.016024 327 | 2022-11-09 09:03:23,946 - root - INFO - Total validation time: 0.41402602195739746 sec 328 | ``` 329 | 330 | ### Adjusting DistributedDataParallel options [(Look Perlmutter section for more details)](https://github.com/tsaris/sc22-dl-tutorial#adjusting-distributeddataparallel-options) 331 | 332 | Since there is no batch norm layer in our model, it's safe to disable the `broadcast_buffers` with the added knob `--disable_broadcast_buffers`: 333 | ``` 334 | ENABLE_PROFILING=1 
PROFILE_OUTPUT=6gpu_nobroadcast bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.nobroadcast.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=short_opt_sm --num_epochs 8 --local_batch_size 8 --enable_manual_profiling --disable_broadcast_buffers" 335 | ``` 336 | The performance of this run: 337 | ``` 338 | 2022-11-09 09:03:27,409 - root - INFO - Time taken for epoch 2 is 39.85852289199829 sec, avg 103.56630653838674 samples/sec 339 | 2022-11-09 09:03:27,415 - root - INFO - Avg train loss=0.039671 340 | 2022-11-09 09:03:27,819 - root - INFO - Avg val loss=0.030174 341 | 2022-11-09 09:03:27,820 - root - INFO - Total validation time: 0.40375518798828125 sec 342 | 2022-11-09 09:03:37,169 - root - INFO - Time taken for epoch 3 is 9.348496437072754 sec, avg 441.5683343077338 samples/sec 343 | 2022-11-09 09:03:37,175 - root - INFO - Avg train loss=0.019775 344 | 2022-11-09 09:03:37,579 - root - INFO - Avg val loss=0.021372 345 | 2022-11-09 09:03:37,580 - root - INFO - Total validation time: 0.4038228988647461 sec 346 | 2022-11-09 09:03:46,932 - root - INFO - Time taken for epoch 4 is 9.35121488571167 sec, avg 441.4399680096583 samples/sec 347 | 2022-11-09 09:03:46,938 - root - INFO - Avg train loss=0.013456 348 | 2022-11-09 09:03:47,342 - root - INFO - Avg val loss=0.016038 349 | 2022-11-09 09:03:47,343 - root - INFO - Total validation time: 0.4036595821380615 sec 350 | ``` 351 | The generated profile [`6gpu_nobroadcast.qdrep`](sample_nsys_profiles/6gpu_nobroadcast.qdrep) and it looks like 352 | ![NSYS 6gpu_nobroadcast](tutorial_images/nsys_summit_6gpu_nobroadcast.png) 353 | 354 | To show the effect of the message bucket size, we add another knob to the code, `--bucket_cap_mb`. We profile a run with 100 mb bucket size with following command: 355 | ``` 356 | ENABLE_PROFILING=1 PROFILE_OUTPUT=6gpu_bucketcap100mb bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.100mb.o%J -nnodes 1 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=short_opt_sm --num_epochs 8 --local_batch_size 8 --enable_manual_profiling --disable_broadcast_buffers --bucket_cap_mb 100" 357 | ``` 358 | The performance of this run: 359 | ``` 360 | 2022-11-09 09:03:25,663 - root - INFO - Time taken for epoch 2 is 39.40925359725952 sec, avg 104.74697242900984 samples/sec 361 | 2022-11-09 09:03:25,669 - root - INFO - Avg train loss=0.040747 362 | 2022-11-09 09:03:26,075 - root - INFO - Avg val loss=0.031916 363 | 2022-11-09 09:03:26,075 - root - INFO - Total validation time: 0.40433478355407715 sec 364 | 2022-11-09 09:03:35,384 - root - INFO - Time taken for epoch 3 is 9.308758020401001 sec, avg 443.45335768242205 samples/sec 365 | 2022-11-09 09:03:35,391 - root - INFO - Avg train loss=0.021065 366 | 2022-11-09 09:03:35,799 - root - INFO - Avg val loss=0.020096 367 | 2022-11-09 09:03:35,800 - root - INFO - Total validation time: 0.40817737579345703 sec 368 | 2022-11-09 09:03:45,086 - root - INFO - Time taken for epoch 4 is 9.285616397857666 sec, avg 444.5585325872812 samples/sec 369 | 2022-11-09 09:03:45,092 - root - INFO - Avg train loss=0.013695 370 | 2022-11-09 09:03:45,500 - root - INFO - Avg val loss=0.015974 371 | 2022-11-09 09:03:45,500 - root - INFO - Total validation time: 0.407240629196167 sec 372 | ``` 373 | and corresponding profile [`6gpu_bucketcap100mb.qdrep`](sample_nsys_profiles/6gpu_bucketcap100mb.qdrep), which looks like 374 | ![NSYS 6gpu_bucketcap100mb_zoomed](tutorial_images/nsys_summit_6gpu_bucketcap100mb_zoomed.png) 375 | The total number of NCCL calls per step now reduced to 
5. 376 | 377 | Similarly, to understand the cross node performance, we run the baseline and optimized options with 2 nodes on Summit: 378 | 379 | Baseline: 380 | ``` 381 | bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.n2.baseline.o%J -nnodes 2 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=short_opt_sm --num_epochs 8 --local_batch_size 8" 382 | ``` 383 | and the performance of the run: 384 | ``` 385 | 2022-11-09 08:52:51,817 - root - INFO - Time taken for epoch 6 is 4.794616460800171 sec, avg 860.9656337998473 samples/sec 386 | 2022-11-09 08:52:51,821 - root - INFO - Avg train loss=0.014148 387 | 2022-11-09 08:52:52,177 - root - INFO - Avg val loss=0.016380 388 | 2022-11-09 08:52:52,178 - root - INFO - Total validation time: 0.3557288646697998 sec 389 | 2022-11-09 08:52:56,955 - root - INFO - Time taken for epoch 7 is 4.7768919467926025 sec, avg 864.1602209092682 samples/sec 390 | 2022-11-09 08:52:56,960 - root - INFO - Avg train loss=0.012832 391 | 2022-11-09 08:52:57,206 - root - INFO - Avg val loss=0.014343 392 | 2022-11-09 08:52:57,206 - root - INFO - Total validation time: 0.2450239658355713 sec 393 | 2022-11-09 08:53:01,999 - root - INFO - Time taken for epoch 8 is 4.791579484939575 sec, avg 861.5113268964287 samples/sec 394 | 2022-11-09 08:53:01,999 - root - INFO - Avg train loss=0.011258 395 | 2022-11-09 08:53:02,249 - root - INFO - Avg val loss=0.013502 396 | 2022-11-09 08:53:02,250 - root - INFO - Total validation time: 0.24966740608215332 sec 397 | ``` 398 | 399 | Optimized: 400 | ``` 401 | bsub -P trn001 -W 0:30 -J sc22.tut -o logs/sc22.tut.n2.100mb.o%J -nnodes 2 -alloc_flags "gpumps smt4" "./submit_summit.sh -g 6 --config=short_opt_sm --num_epochs 8 --local_batch_size 8 --disable_broadcast_buffers --bucket_cap_mb 100" 402 | ``` 403 | and the performance of the run: 404 | ``` 405 | 2022-11-09 08:52:45,939 - root - INFO - Time taken for epoch 6 is 4.755844593048096 sec, avg 867.9846280162615 samples/sec 406 | 2022-11-09 08:52:45,944 - root - INFO - Avg train loss=0.014258 407 | 2022-11-09 08:52:46,298 - root - INFO - Avg val loss=0.016293 408 | 2022-11-09 08:52:46,298 - root - INFO - Total validation time: 0.3531801700592041 sec 409 | 2022-11-09 08:52:51,017 - root - INFO - Time taken for epoch 7 is 4.719024419784546 sec, avg 874.7570753593325 samples/sec 410 | 2022-11-09 08:52:51,022 - root - INFO - Avg train loss=0.012921 411 | 2022-11-09 08:52:51,267 - root - INFO - Avg val loss=0.015027 412 | 2022-11-09 08:52:51,267 - root - INFO - Total validation time: 0.24351286888122559 sec 413 | 2022-11-09 08:52:55,998 - root - INFO - Time taken for epoch 8 is 4.730204105377197 sec, avg 872.6896150860331 samples/sec 414 | 2022-11-09 08:52:56,003 - root - INFO - Avg train loss=0.011467 415 | 2022-11-09 08:52:56,248 - root - INFO - Avg val loss=0.014088 416 | 2022-11-09 08:52:56,249 - root - INFO - Total validation time: 0.24413108825683594 sec 417 | ``` 418 | -------------------------------------------------------------------------------- /benchy-conf.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | report_freq: 5 # frequency of terminal throughput output 3 | exit_after_tests: True # terminate process after benchy trials are completed 4 | profiler_mode: 'single' # controls what processes start CUDA profiling (options: 'single', 'all', 'none') 5 | output_filename: 'benchy_output.json' # output filename for JSON output 6 | output_dir: './' # directory to write output file 7 | use_distributed_barrier: False # 
enable/disable use of process global barrier for timings 8 | 9 | IO: 10 | run_benchmark: True # enable/disable IO benchmarking 11 | nbatches: 50 # number of batches to run per trial 12 | ntrials: 3 # number of trials to run 13 | nwarmup: 3 # number of warmup trials to run 14 | 15 | synthetic: 16 | run_benchmark: True # enable/disable synthetic/cached data benchmarking 17 | nbatches: 50 # number of batches to run per trial 18 | ntrials: 3 # number of trials to run 19 | nwarmup: 3 # number of warmup trials to run 20 | 21 | full: 22 | run_benchmark: True # enable/disable full training benchmarking 23 | nbatches: 50 # number of batches to run per trial 24 | ntrials: 3 # number of trials to run 25 | nwarmup: 3 # number of warmup trials to run 26 | -------------------------------------------------------------------------------- /config/UNet.yaml: -------------------------------------------------------------------------------- 1 | base: &base 2 | 3 | # Training config 4 | weight_init: {conv_init: 'normal', conv_scale: 0.02, conv_bias: 0.} 5 | lambda_rho: 0 # weight for additional rho loss term 6 | full_scale: True # whether or not to use all 6 of the scales in U-Net 7 | global_batch_size: 64 # number of samples per training batch 8 | base_batch_size: 64 # single GPU batch size 9 | Nsamples: 65536 10 | Nsamples_val: 8192 11 | num_epochs: 10 12 | amp_mode: none 13 | enable_apex: False 14 | enable_benchy: False 15 | enable_jit: False 16 | expdir: '/logs' 17 | 18 | # params for setting learning rate (cosine decay schedule): 19 | # start_lr: initial learning rate 20 | # end_lr: final learning rate 21 | # warmup_steps: number of steps over which to do linear warm-up of learning rate 22 | # *not used when training single-GPU or when scaling='none'* 23 | # scaling: 'none' initial lr doesn't change with respect to global batch size 24 | # 'linear' scale up according to lr_global_batchsize = (global_batchsize/base_batchsize)*lr_base_batchsize 25 | # 'sqrt' scale up according to lr_global_batchsize = sqrt(global_batchsize/base_batchsize)*lr_base_batchsize 26 | lr_schedule: {scaling: 'sqrt', start_lr: 2.E-4, end_lr: 0., warmup_steps: 128} 27 | 28 | # Data 29 | data_loader_config: 'lowmem' # choices: 'synthetic', 'inmem', 'lowmem', 'dali-lowmem' 30 | box_size: [1024, 512] # total size of simulation boxes (train, validation) 31 | data_size: 64 # size of crops for training 32 | num_data_workers: 2 # number of dataloader worker threads per proc 33 | N_out_channels: 5 34 | # HDF5 files for PyTorch native dataloader 35 | train_path: '/data/downsamp_2048crop_train.h5' 36 | val_path: '/data/downsamp_1024crop_valid.h5' 37 | # numpy files for DALI dataloader 38 | train_path_npy_data: '/data/downsamp_2048crop_train_data.npy' 39 | train_path_npy_label: '/data/downsamp_2048crop_train_label.npy' 40 | val_path_npy_data: '/data/downsamp_1024crop_valid_data.npy' 41 | val_path_npy_label: '/data/downsamp_1024crop_valid_label.npy' 42 | use_cache: None # set this to a cache dir (e.g., NVMe on CoriGPU) if you copied data there 43 | 44 | # A short config for testing/profiling on a single A100 45 | short: &short 46 | <<: *base 47 | Nsamples: 4096 48 | Nsamples_val: 512 49 | num_epochs: 4 50 | 51 | # Short config with full optimizations 52 | short_opt: 53 | <<: *short 54 | data_loader_config: 'dali-lowmem' 55 | num_data_workers: 8 56 | amp_mode: fp16 57 | enable_apex: True 58 | enable_jit: True 59 | 60 | # Full training, batch size 64, with optimizations 61 | bs64_opt: &bs64_opt 62 | <<: *base 63 | data_loader_config: 'dali-lowmem' 
64 | global_batch_size: 64 65 | amp_mode: fp16 66 | enable_apex: True 67 | enable_jit: True 68 | 69 | bs256_opt: # 4 GPUs 70 | <<: *bs64_opt 71 | global_batch_size: 256 72 | num_epochs: 20 73 | 74 | bs512_test: # 8 GPUs 75 | <<: *bs64_opt 76 | global_batch_size: 512 77 | 78 | bs512_opt: # 8 GPUs 79 | <<: *bs64_opt 80 | global_batch_size: 512 81 | num_epochs: 40 82 | 83 | bs2048_opt: # 32 GPUs 84 | <<: *bs64_opt 85 | global_batch_size: 2048 86 | num_epochs: 160 87 | 88 | # Warning: training may be unstable 89 | bs8192_opt: # 128 GPUs 90 | <<: *bs64_opt 91 | global_batch_size: 8192 92 | num_epochs: 320 93 | 94 | 95 | ####################################################### 96 | ##### Summit configs ################################## 97 | ####################################################### 98 | 99 | # A short config for testing/profiling on a single V100 100 | base_sm: &base_sm 101 | <<: *base 102 | # Training config 103 | global_batch_size: 32 # number of samples per training batch 104 | base_batch_size: 32 # single GPU batch size 105 | expdir: './logs/' 106 | 107 | # Data 108 | # HDF5 files for PyTorch native dataloader 109 | train_path: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_2048crop_train.h5' 110 | val_path: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_1024crop_valid.h5' 111 | # numpy files for DALI dataloader 112 | train_path_npy_data: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_2048crop_train_data.npy' 113 | train_path_npy_label: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_2048crop_train_label.npy' 114 | val_path_npy_data: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_1024crop_valid_data.npy' 115 | val_path_npy_label: '/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data/downsamp_1024crop_valid_label.npy' 116 | 117 | # A short config for testing/profiling on a single V100 118 | shorter_sm: &shorter_sm 119 | <<: *base_sm 120 | Nsamples: 512 121 | Nsamples_val: 64 122 | num_epochs: 4 123 | 124 | # A short config for testing/profiling on a single V100 125 | short_sm: &short_sm 126 | <<: *base_sm 127 | Nsamples: 4096 128 | Nsamples_val: 512 129 | num_epochs: 4 130 | 131 | # Short config with full optimizations 132 | short_opt_sm: 133 | <<: *short_sm 134 | data_loader_config: 'dali-lowmem' 135 | num_data_workers: 7 136 | amp_mode: fp16 137 | enable_apex: True 138 | enable_jit: True 139 | 140 | # Full training, batch size 32, with optimizations 141 | bs32_opt_sm: &bs32_opt_sm 142 | <<: *base_sm 143 | data_loader_config: 'dali-lowmem' 144 | num_data_workers: 7 145 | amp_mode: fp16 146 | enable_apex: True 147 | enable_jit: True 148 | global_batch_size: 32 149 | num_epochs: 3 150 | 151 | bs576_test_sm: # 18 GPUs (3 nodes) 152 | <<: *bs32_opt_sm 153 | global_batch_size: 576 154 | 155 | bs576_opt_sm: # 18 GPUs (3 nodes) 156 | <<: *bs32_opt_sm 157 | global_batch_size: 576 158 | num_epochs: 12 159 | 160 | bs2304_opt_sm: # 72 GPUs (12 nodes) 161 | <<: *bs32_opt_sm 162 | global_batch_size: 2304 163 | num_epochs: 48 164 | 165 | bs9216_opt_sm: # 288 GPUs (48 nodes) 166 | <<: *bs32_opt_sm 167 | global_batch_size: 9216 168 | num_epochs: 96 169 | 170 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.10-py3 2 | 3 | # Upgrade pip 4 | RUN python -m pip install -U setuptools pip 5 | 6 | # Install pip 
dependencies 7 | RUN pip install ruamel.yaml && \ 8 | pip install h5py 9 | 10 | # Install benchy lib 11 | RUN git clone https://github.com/romerojosh/benchy.git && \ 12 | cd benchy && \ 13 | python setup.py install && \ 14 | cd ../ && rm -rf benchy 15 | -------------------------------------------------------------------------------- /example_logs/base/1GPU/00/logs/events.out.tfevents.1636677146.nid001104.65966.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/example_logs/base/1GPU/00/logs/events.out.tfevents.1636677146.nid001104.65966.0 -------------------------------------------------------------------------------- /example_logs/base/1GPU/00/out.log: -------------------------------------------------------------------------------- 1 | 2021-11-12 00:32:26,118 - root - INFO - ------------------ Configuration ------------------ 2 | 2021-11-12 00:32:26,118 - root - INFO - Configuration file: /global/u2/p/pharring/production/sc21/sc21-dl-tutorial/config/UNet.yaml 3 | 2021-11-12 00:32:26,118 - root - INFO - Configuration name: base 4 | 2021-11-12 00:32:26,118 - root - INFO - weight_init ordereddict([('conv_init', 'normal'), ('conv_scale', 0.02), ('conv_bias', 0.0)]) 5 | 2021-11-12 00:32:26,118 - root - INFO - lambda_rho 0 6 | 2021-11-12 00:32:26,119 - root - INFO - full_scale True 7 | 2021-11-12 00:32:26,119 - root - INFO - global_batch_size 64 8 | 2021-11-12 00:32:26,119 - root - INFO - base_batch_size 64 9 | 2021-11-12 00:32:26,119 - root - INFO - Nsamples 65536 10 | 2021-11-12 00:32:26,119 - root - INFO - Nsamples_val 8192 11 | 2021-11-12 00:32:26,119 - root - INFO - num_epochs 10 12 | 2021-11-12 00:32:26,119 - root - INFO - enable_amp False 13 | 2021-11-12 00:32:26,119 - root - INFO - enable_apex False 14 | 2021-11-12 00:32:26,119 - root - INFO - enable_benchy False 15 | 2021-11-12 00:32:26,119 - root - INFO - enable_jit False 16 | 2021-11-12 00:32:26,119 - root - INFO - expdir /logs 17 | 2021-11-12 00:32:26,119 - root - INFO - lr_schedule ordereddict([('scaling', 'sqrt'), ('start_lr', 0.0002), ('end_lr', 0.0), ('warmup_steps', 128)]) 18 | 2021-11-12 00:32:26,119 - root - INFO - data_loader_config lowmem 19 | 2021-11-12 00:32:26,119 - root - INFO - box_size [1024, 512] 20 | 2021-11-12 00:32:26,119 - root - INFO - data_size 64 21 | 2021-11-12 00:32:26,119 - root - INFO - num_data_workers 2 22 | 2021-11-12 00:32:26,119 - root - INFO - N_out_channels 5 23 | 2021-11-12 00:32:26,119 - root - INFO - train_path /data/downsamp_2048crop_train.h5 24 | 2021-11-12 00:32:26,119 - root - INFO - val_path /data/downsamp_1024crop_valid.h5 25 | 2021-11-12 00:32:26,119 - root - INFO - train_path_npy_data /data/downsamp_2048crop_train_data.npy 26 | 2021-11-12 00:32:26,119 - root - INFO - train_path_npy_label /data/downsamp_2048crop_train_label.npy 27 | 2021-11-12 00:32:26,119 - root - INFO - val_path_npy_data /data/downsamp_1024crop_valid_data.npy 28 | 2021-11-12 00:32:26,119 - root - INFO - val_path_npy_label /data/downsamp_1024crop_valid_label.npy 29 | 2021-11-12 00:32:26,119 - root - INFO - use_cache None 30 | 2021-11-12 00:32:26,119 - root - INFO - --------------------------------------------------- 31 | 2021-11-12 00:32:26,150 - root - INFO - rank 0, begin data loader init 32 | 2021-11-12 00:32:26,441 - root - INFO - rank 0, data loader initialized with config lowmem 33 | 2021-11-12 00:32:47,010 - root - INFO - Starting Training Loop... 
34 | 2021-11-12 00:46:23,492 - root - INFO - Time taken for epoch 1 is 752.9529612064362 sec, avg 87.03863770585806 samples/sec 35 | 2021-11-12 00:46:23,492 - root - INFO - Avg train loss=0.014280 36 | 2021-11-12 00:47:36,657 - root - INFO - Avg val loss=0.007129 37 | 2021-11-12 00:47:36,657 - root - INFO - Total validation time: 73.16440415382385 sec 38 | 2021-11-12 00:59:52,873 - root - INFO - Time taken for epoch 2 is 736.1673448085785 sec, avg 89.02323698837927 samples/sec 39 | 2021-11-12 00:59:52,874 - root - INFO - Avg train loss=0.006163 40 | 2021-11-12 01:01:03,747 - root - INFO - Avg val loss=0.006131 41 | 2021-11-12 01:01:03,747 - root - INFO - Total validation time: 70.87276315689087 sec 42 | 2021-11-12 01:13:09,400 - root - INFO - Time taken for epoch 3 is 725.6503672599792 sec, avg 90.31346631499791 samples/sec 43 | 2021-11-12 01:13:09,401 - root - INFO - Avg train loss=0.005703 44 | 2021-11-12 01:14:20,543 - root - INFO - Avg val loss=0.005793 45 | 2021-11-12 01:14:20,543 - root - INFO - Total validation time: 71.14171385765076 sec 46 | 2021-11-12 01:26:34,245 - root - INFO - Time taken for epoch 4 is 733.644660949707 sec, avg 89.32934905457267 samples/sec 47 | 2021-11-12 01:26:34,245 - root - INFO - Avg train loss=0.005422 48 | 2021-11-12 01:27:46,311 - root - INFO - Avg val loss=0.005409 49 | 2021-11-12 01:27:46,311 - root - INFO - Total validation time: 72.06519365310669 sec 50 | 2021-11-12 01:40:11,457 - root - INFO - Time taken for epoch 5 is 745.0990128517151 sec, avg 87.95609559214725 samples/sec 51 | 2021-11-12 01:40:11,458 - root - INFO - Avg train loss=0.005207 52 | 2021-11-12 01:41:38,000 - root - INFO - Avg val loss=0.005207 53 | 2021-11-12 01:41:38,000 - root - INFO - Total validation time: 86.54100251197815 sec 54 | 2021-11-12 01:54:08,192 - root - INFO - Time taken for epoch 6 is 750.1890060901642 sec, avg 87.35931807580144 samples/sec 55 | 2021-11-12 01:54:08,192 - root - INFO - Avg train loss=0.005045 56 | 2021-11-12 01:55:27,159 - root - INFO - Avg val loss=0.005085 57 | 2021-11-12 01:55:27,159 - root - INFO - Total validation time: 78.966463804245 sec 58 | 2021-11-12 02:07:54,065 - root - INFO - Time taken for epoch 7 is 746.9029471874237 sec, avg 87.74366234165463 samples/sec 59 | 2021-11-12 02:07:54,066 - root - INFO - Avg train loss=0.004935 60 | 2021-11-12 02:09:12,587 - root - INFO - Avg val loss=0.004972 61 | 2021-11-12 02:09:12,588 - root - INFO - Total validation time: 78.52156686782837 sec 62 | 2021-11-12 02:21:41,208 - root - INFO - Time taken for epoch 8 is 748.6173532009125 sec, avg 87.5427208837511 samples/sec 63 | 2021-11-12 02:21:41,208 - root - INFO - Avg train loss=0.004817 64 | 2021-11-12 02:22:59,974 - root - INFO - Avg val loss=0.004856 65 | 2021-11-12 02:22:59,974 - root - INFO - Total validation time: 78.76536536216736 sec 66 | 2021-11-12 02:35:29,456 - root - INFO - Time taken for epoch 9 is 749.4797053337097 sec, avg 87.44199413754608 samples/sec 67 | 2021-11-12 02:35:29,457 - root - INFO - Avg train loss=0.004782 68 | 2021-11-12 02:36:48,701 - root - INFO - Avg val loss=0.004855 69 | 2021-11-12 02:36:48,701 - root - INFO - Total validation time: 79.24383068084717 sec 70 | 2021-11-12 02:49:31,830 - root - INFO - Time taken for epoch 10 is 763.1262722015381 sec, avg 85.87831711118477 samples/sec 71 | 2021-11-12 02:49:31,831 - root - INFO - Avg train loss=0.004771 72 | 2021-11-12 02:50:57,833 - root - INFO - Avg val loss=0.004844 73 | 2021-11-12 02:50:57,834 - root - INFO - Total validation time: 86.00221180915833 sec 74 | 2021-11-12 
02:50:57,853 - root - INFO - DONE ---- rank 0 75 | -------------------------------------------------------------------------------- /export_DDP_vars.sh: -------------------------------------------------------------------------------- 1 | export RANK=$SLURM_PROCID 2 | export LOCAL_RANK=$SLURM_LOCALID 3 | export WORLD_SIZE=$SLURM_NTASKS 4 | export MASTER_PORT=29500 # default from torch launcher 5 | -------------------------------------------------------------------------------- /export_DDP_vars_summit.sh: -------------------------------------------------------------------------------- 1 | export RANK=$OMPI_COMM_WORLD_RANK 2 | export LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK 3 | export WORLD_SIZE=$OMPI_COMM_WORLD_SIZE 4 | export MASTER_ADDR=$(cat $LSB_DJOB_HOSTFILE | sort | uniq | grep -v batch | grep -v login | head -1) 5 | export MASTER_PORT=29500 # default from torch launcher 6 | -------------------------------------------------------------------------------- /launch_summit.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | 3 | # 4 | # the following mapping is tied to "smt4" 5 | # this mapping may not be optimal for a given application 6 | # 7 | 8 | APP=$1 9 | 10 | grank=$PMIX_RANK 11 | lrank=$(($PMIX_RANK%6)) 12 | export PAMI_ENABLE_STRIPING=0 13 | 14 | case ${lrank} in 15 | [0]) 16 | export PAMI_IBV_DEVICE_NAME=mlx5_0:1 17 | export OMP_PLACES={0:28} 18 | numactl --physcpubind=0-27 --membind=0 $APP 19 | #${APP} 20 | ;; 21 | [1]) 22 | export PAMI_IBV_DEVICE_NAME=mlx5_1:1 23 | export OMP_PLACES={28:28} 24 | numactl --physcpubind=28-55 --membind=0 $APP 25 | #${APP} 26 | ;; 27 | [2]) 28 | export PAMI_IBV_DEVICE_NAME=mlx5_0:1 29 | export OMP_PLACES={56:28} 30 | numactl --physcpubind=56-83 --membind=0 $APP 31 | #${APP} 32 | ;; 33 | [3]) 34 | export PAMI_IBV_DEVICE_NAME=mlx5_3:1 35 | export OMP_PLACES={88:28} 36 | numactl --physcpubind=88-115 --membind=8 $APP 37 | #${APP} 38 | ;; 39 | [4]) 40 | export PAMI_IBV_DEVICE_NAME=mlx5_2:1 41 | export OMP_PLACES={116:28} 42 | numactl --physcpubind=116-143 --membind=8 $APP 43 | #${APP} 44 | ;; 45 | [5]) 46 | export PAMI_IBV_DEVICE_NAME=mlx5_3:1 47 | export OMP_PLACES={144:28} 48 | numactl --physcpubind=144-171 --membind=8 $APP 49 | #${APP} 50 | ;; 51 | esac 52 | -------------------------------------------------------------------------------- /networks/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def down_conv(in_channels, out_channels): 5 | return nn.Sequential( 6 | nn.Conv3d(in_channels, out_channels, 4, stride=2, padding=1), 7 | nn.LeakyReLU(inplace=True), 8 | ) 9 | 10 | def up_conv(in_channels, out_channels): 11 | return nn.Sequential( 12 | nn.ConvTranspose3d(in_channels, out_channels, 4, stride=2, padding=1, output_padding=0), 13 | nn.ReLU(inplace=True), 14 | ) 15 | 16 | 17 | def inverse_transf(x): 18 | return torch.exp(14.*x) 19 | 20 | 21 | def loss_func(gen_output, target, lambda_rho): 22 | 23 | # first part of the loss 24 | l1_loss = nn.functional.l1_loss(gen_output, target) 25 | 26 | if lambda_rho == 0.: 27 | return l1_loss 28 | 29 | # Transform T and rho back to original space, compute additional L1 30 | orig_gen = inverse_transf(gen_output[:,0,:,:,:]) 31 | orig_tar = inverse_transf(target[:,0,:,:,:]) 32 | orig_l1_loss = nn.functional.l1_loss(orig_gen, orig_tar) 33 | return l1_loss + lambda_rho * orig_l1_loss 34 | 35 | 36 | @torch.jit.script 37 | def loss_func_opt(gen_output: torch.Tensor, target: 
torch.Tensor, lambda_rho: float): 38 | 39 | # first part of the loss 40 | l1_loss = torch.mean(torch.abs(gen_output - target)) 41 | 42 | # Transform T and rho back to original space, compute additional L1 43 | orig_gen = inverse_transf(gen_output[:,0,:,:,:]) 44 | orig_tar = inverse_transf(target[:,0,:,:,:]) 45 | orig_l1_loss = torch.mean(torch.abs(orig_gen - orig_tar)) 46 | return l1_loss + lambda_rho * orig_l1_loss 47 | 48 | 49 | @torch.jit.script 50 | def loss_func_opt_final(gen_output: torch.Tensor, target: torch.Tensor, lambda_rho: torch.Tensor): 51 | 52 | # first part of the loss 53 | l1_loss = torch.abs(gen_output - target) 54 | 55 | # Transform T and rho back to original space, compute additional L1 56 | orig_gen = inverse_transf(gen_output) 57 | orig_tar = inverse_transf(target) 58 | orig_l1_loss = torch.abs(orig_gen - orig_tar) 59 | 60 | # combine 61 | loss = l1_loss + lambda_rho * orig_l1_loss 62 | 63 | return torch.mean(loss) 64 | 65 | 66 | class UNet(nn.Module): 67 | 68 | def __init__(self, params): 69 | super().__init__() 70 | self.full_scale = params.full_scale 71 | self.conv_down1 = down_conv(4, 64) 72 | self.conv_down2 = down_conv(64, 128) 73 | self.conv_down3 = down_conv(128, 256) 74 | self.conv_down4 = down_conv(256, 512) 75 | self.conv_down5 = down_conv(512, 512) 76 | if self.full_scale: 77 | self.conv_down6 = down_conv(512, 512) 78 | 79 | self.conv_up6 = up_conv(512, 512) 80 | self.conv_up5 = up_conv(512+512, 512) 81 | else: 82 | self.conv_up5 = up_conv(512, 512) 83 | self.conv_up4 = up_conv(512+512, 256) 84 | self.conv_up3 = up_conv(256+256, 128) 85 | self.conv_up2 = up_conv(128+128, 64) 86 | self.conv_last = nn.ConvTranspose3d(64+64, params.N_out_channels, 4, stride=2, padding=1, output_padding=0) 87 | self.tanh = nn.Tanh() 88 | 89 | 90 | def forward(self, x): 91 | conv1 = self.conv_down1(x) # 64 92 | conv2 = self.conv_down2(conv1) # 128 93 | conv3 = self.conv_down3(conv2) # 256 94 | conv4 = self.conv_down4(conv3) # 512 95 | conv5 = self.conv_down5(conv4) # 512 96 | if self.full_scale: 97 | conv6 = self.conv_down6(conv5) # 512 98 | 99 | x = self.conv_up6(conv6) # 512 100 | x = torch.cat([x, conv5], dim=1) 101 | else: 102 | x = conv5 103 | x = self.conv_up5(x) # 512 104 | x = torch.cat([x, conv4], dim=1) 105 | x = self.conv_up4(x) # 256 106 | x = torch.cat([x, conv3], dim=1) 107 | x = self.conv_up3(x) # 128 108 | x = torch.cat([x, conv2], dim=1) 109 | x = self.conv_up2(x) # 64 110 | x = torch.cat([x, conv1], dim=1) 111 | x = self.conv_last(x) # 5 112 | out = self.tanh(x) 113 | return out 114 | 115 | def get_weights_function(self, params): 116 | def weights_init(m): 117 | classname = m.__class__.__name__ 118 | if classname.find('Conv') != -1: 119 | nn.init.normal_(m.weight.data, 0.0, params['conv_scale']) 120 | if params['conv_bias'] is not None: 121 | m.bias.data.fill_(params['conv_bias']) 122 | return weights_init 123 | 124 | -------------------------------------------------------------------------------- /run_summit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ENABLE_PROFILING=${1} 4 | PROFILE_OUTPUT=${2} 5 | 6 | # Profiling 7 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ] && [ "$PMIX_RANK" -eq 0 ] ; then 8 | echo "Enabling profiling..." 
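  # Only rank 0 reaches this branch (the PMIX_RANK check above), so a single
  # nsys report is written per job; the resulting PROFILE_CMD is prepended to
  # the python command at the bottom of this script.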
9 | NSYS_ARGS="--trace=cuda,nvtx,osrt --kill none -c cudaProfilerApi -f true" 10 | NSYS_OUTPUT=${PROFILE_OUTPUT:-"profile"} 11 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 12 | fi 13 | 14 | echo ${PROFILE_CMD} 15 | 16 | source export_DDP_vars_summit.sh && \ 17 | ${PROFILE_CMD} python train.py ${@:3} 18 | 19 | -------------------------------------------------------------------------------- /sample_nsys_profiles/16workers.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/16workers.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/4gpu_baseline.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/4gpu_baseline.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/4gpu_bucketcap100mb.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/4gpu_bucketcap100mb.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/4gpu_nobroadcast.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/4gpu_nobroadcast.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/baseline.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/baseline.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/dali.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/dali.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/dali_amp.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp_apex_jit.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/dali_amp_apex_jit.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/summit_6gpu_baseline.qdrep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/summit_6gpu_baseline.qdrep -------------------------------------------------------------------------------- 
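The sample reports in this folder can be opened in the Nsight Systems GUI; they were produced by wrapping `train.py` with the `nsys profile ... -c cudaProfilerApi` command that the submit scripts build when `ENABLE_PROFILING=1`. Because capture is gated on the CUDA profiler API, nsys records nothing until the application itself starts the profiler. A minimal sketch of the application-side hooks (the step indices below are placeholders; `train.py` gates on specific epochs and steps when run with `--enable_manual_profiling`, and only on rank 0):

```python
import torch

# Toy loop: the start/stop steps are arbitrary values chosen for illustration.
for step in range(60):
    if step == 10:
        torch.cuda.profiler.start()   # cudaProfilerStart(): nsys begins recording here
    torch.cuda.nvtx.range_push(f"step {step}")
    # ... data copy-in, forward, backward, optimizer step ...
    torch.cuda.nvtx.range_pop()
    if step == 50:
        torch.cuda.profiler.stop()    # cudaProfilerStop(): nsys stops recording here
```

The NVTX ranges appear as labeled regions in the resulting timeline, which is what makes individual steps easy to pick out when zooming into the reports.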
/sample_nsys_profiles/summit_6gpu_bucketcap100mb.qdrep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/summit_6gpu_bucketcap100mb.qdrep -------------------------------------------------------------------------------- /sample_nsys_profiles/summit_6gpu_nobroadcast.qdrep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/sample_nsys_profiles/summit_6gpu_nobroadcast.qdrep -------------------------------------------------------------------------------- /start_tensorboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "314d73bf-204b-410b-9b58-4db9a461e1c4", 6 | "metadata": {}, 7 | "source": [ 8 | "# TensorBoard Launcher\n", 9 | "\n", 10 | "This notebook allows you to start TensorBoard on Perlmutter and view it in a normal browser tab.\n", 11 | "\n", 12 | "The notebook code below assumes you are using the hands-on tutorial path for tensorboard logs.\n", 13 | "\n", 14 | "When you run the cells below, TensorBoard will start but will not display here in the notebook. Instead, the final cell which calls `nersc_tensorboard_helper.tb_address()` will display a URL that you can click to open a new tab with TensorBoard." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "97bcd738-b1f5-40ed-acb0-2b2572f741e2", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import os\n", 25 | "import nersc_tensorboard_helper\n", 26 | "\n", 27 | "%load_ext tensorboard" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "2b4a192d-f5a7-4f8a-becb-ae52ed9cc036", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "log_dir = os.path.expandvars('$SCRATCH/sc22-dl-tutorial/logs')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "4dd0750e-7f27-4a66-9fce-4be91ea88835", 44 | "metadata": { 45 | "tags": [] 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "%%capture\n", 50 | "%tensorboard --logdir $log_dir --port 0" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "7f4158e2-d17a-40f3-a6e5-961e2615e4c3", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "nersc_tensorboard_helper.tb_address()" 61 | ] 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "pytorch-1.9.0", 67 | "language": "python", 68 | "name": "pytorch-1.9.0" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 3 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython3", 80 | "version": "3.8.11" 81 | } 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 5 85 | } 86 | -------------------------------------------------------------------------------- /start_tensorboard_summit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "314d73bf-204b-410b-9b58-4db9a461e1c4", 6 | "metadata": {}, 7 | "source": [ 8 | "# TensorBoard Launcher\n", 9 | "\n", 10 | "This notebook allows you to start TensorBoard on Summit and view it in a normal 
browser tab.\n", 11 | "\n", 12 | "The notebook code below assumes you are using the hands-on tutorial path for tensorboard logs.\n", 13 | "\n", 14 | "When you run the cells below, TensorBoard will start but will not display here in the notebook. Instead, the final cell which calls `tb_address()` will display a URL that you can click to open a new tab with TensorBoard." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "id": "9aa0f14d-0542-4315-886c-4134b0ba8673", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "outputs": [ 25 | { 26 | "name": "stderr", 27 | "output_type": "stream", 28 | "text": [ 29 | "/tmp/ipykernel_244/1904480837.py:4: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n", 30 | " from IPython.core.display import display, HTML\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "import os, pwd\n", 36 | "from tensorboard import notebook\n", 37 | "import getpass\n", 38 | "from IPython.core.display import display, HTML\n", 39 | "\n", 40 | "def get_pid_owner(pid):\n", 41 | " # the /proc/PID is owned by process creator\n", 42 | " proc_stat_file = os.stat(\"/proc/%d\" % pid)\n", 43 | " # get UID via stat call\n", 44 | " uid = proc_stat_file.st_uid\n", 45 | " # look up the username from uid\n", 46 | " username = pwd.getpwuid(uid)[0]\n", 47 | " \n", 48 | " return username\n", 49 | "\n", 50 | "def get_tb_port(username):\n", 51 | " \n", 52 | " for tb_nb in notebook.manager.get_all():\n", 53 | " if get_pid_owner(tb_nb.pid) == username:\n", 54 | " return tb_nb.port\n", 55 | " \n", 56 | "def tb_address():\n", 57 | " \n", 58 | " username = getpass.getuser()\n", 59 | " tb_port = get_tb_port(username)\n", 60 | " \n", 61 | " address = \"https://jupyter.olcf.ornl.gov\" + os.environ['JUPYTERHUB_SERVICE_PREFIX'] + 'proxy/' + str(tb_port) + \"/\"\n", 62 | "\n", 63 | " address = address.strip()\n", 64 | " \n", 65 | " display(HTML('%s'%(address,address)))\n", 66 | " \n", 67 | "%load_ext tensorboard " 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 7, 73 | "id": "b95365cc-256d-4bbe-800b-26ee2fdca21e", 74 | "metadata": { 75 | "tags": [] 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "username = os.environ['JUPYTERHUB_USER']\n", 80 | "log_dir = os.path.expandvars('/gpfs/alpine/trn001/world-shared/%s/sc22-dl-tutorial/logs'%username)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 8, 86 | "id": "c86e34c8-cdd3-4436-8939-f9d24faf80b0", 87 | "metadata": { 88 | "tags": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "%%capture\n", 93 | "%tensorboard --logdir $log_dir --port 0" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 9, 99 | "id": "dcb87cde-89da-40f8-a985-b09f78f791a5", 100 | "metadata": { 101 | "tags": [] 102 | }, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/html": [ 107 | "https://jupyter.olcf.ornl.gov/user/atsaris/proxy/43281/" 108 | ], 109 | "text/plain": [ 110 | "" 111 | ] 112 | }, 113 | "metadata": {}, 114 | "output_type": "display_data" 115 | } 116 | ], 117 | "source": [ 118 | "tb_address()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "d0f31cbd-9c26-4905-98f3-d3996a81e422", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3 (ipykernel)", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | 
"codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.7.13" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 5 151 | } 152 | -------------------------------------------------------------------------------- /submit_cgpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4 4 | #SBATCH --ntasks-per-node 8 5 | #SBATCH --cpus-per-task 10 6 | #SBATCH --gpus-per-task 1 7 | #SBATCH --time=0:30:00 8 | #SBATCH --image=nersc/sc22-dl-tutorial:latest 9 | #SBATCH -J pm-crop64 10 | #SBATCH -o %x-%j.out 11 | 12 | DATADIR=/global/cscratch1/sd/sfarrell/sc21-dl-tutorial/data 13 | LOGDIR=${SCRATCH}/sc21-dl-tutorial/logs 14 | mkdir -p ${LOGDIR} 15 | args="${@}" 16 | 17 | hostname 18 | 19 | #~/dummy 20 | 21 | # Profiling 22 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 23 | echo "Enabling profiling..." 24 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 25 | PROFILE_OUTPUT=/logs/$SLURM_JOB_ID 26 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $PROFILE_OUTPUT" 27 | fi 28 | 29 | set -x 30 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 31 | bash -c " 32 | source export_DDP_vars.sh 33 | ${PROFILE_CMD} python train.py --config=V100_crop64_sqrt ${args} 34 | " 35 | -------------------------------------------------------------------------------- /submit_pm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4_g 4 | #SBATCH --ntasks-per-node 4 5 | #SBATCH --cpus-per-task 32 6 | #SBATCH --gpus-per-node 4 7 | #SBATCH --time=0:15:00 8 | #SBATCH --image=nersc/sc22-dl-tutorial:latest 9 | #SBATCH --reservation=sc22_tutorial 10 | #SBATCH -J pm-crop64 11 | #SBATCH -o %x-%j.out 12 | 13 | DATADIR=/pscratch/sd/j/joshr/nbody2hydro/datacopies 14 | LOGDIR=${SCRATCH}/sc22-dl-tutorial/logs 15 | mkdir -p ${LOGDIR} 16 | args="${@}" 17 | 18 | hostname 19 | 20 | export NCCL_NET_GDR_LEVEL=PHB 21 | 22 | # Profiling 23 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 24 | echo "Enabling profiling..." 
25 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 26 | NSYS_OUTPUT=${PROFILE_OUTPUT:-"profile"} 27 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 28 | fi 29 | 30 | BENCHY_CONFIG=benchy-conf.yaml 31 | BENCHY_OUTPUT=${BENCHY_OUTPUT:-"benchy_output"} 32 | sed "s/.*output_filename.*/ output_filename: ${BENCHY_OUTPUT}.json/" ${BENCHY_CONFIG} > benchy-run-${SLURM_JOBID}.yaml 33 | export BENCHY_CONFIG_FILE=benchy-run-${SLURM_JOBID}.yaml 34 | export MASTER_ADDR=$(hostname) 35 | 36 | # Reversing order of GPUs to match default CPU affinities from Slurm 37 | export CUDA_VISIBLE_DEVICES=3,2,1,0 38 | 39 | set -x 40 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 41 | bash -c " 42 | source export_DDP_vars.sh 43 | ${PROFILE_CMD} python train.py ${args} 44 | " 45 | rm benchy-run-${SLURM_JOBID}.yaml 46 | -------------------------------------------------------------------------------- /submit_summit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | nnodes=$(cat ${LSB_DJOB_HOSTFILE} | sort | uniq | grep -v login | grep -v batch | wc -l) 4 | module use /sw/aaims/summit/modulefiles 5 | module load open-ce-pyt 6 | 7 | DATADIR=/gpfs/alpine/stf011/world-shared/atsaris/SC22_tutorial_data 8 | LOGDIR=$WORLDWORK/trn001/$USER/sc22-dl-tutorial/logs 9 | mkdir -p ${LOGDIR} 10 | 11 | if [ "$1" != "-g" ]; then 12 | echo "You need to specify -g for gpus per node as a first argument" ; exit 1; 13 | fi 14 | 15 | re='^[0-9]+$' 16 | if ! [[ $2 =~ $re ]] ; then 17 | echo "error: gpus per node must be a number" >&2; exit 1 18 | fi 19 | args="${@:3}" 20 | 21 | hostname 22 | 23 | #~/dummy 24 | 25 | #export NCCL_NET_GDR_LEVEL=PHB 26 | 27 | BENCHY_CONFIG=benchy-conf.yaml 28 | BENCHY_OUTPUT=${BENCHY_OUTPUT:-"benchy_output"} 29 | sed "s/.*output_filename.*/ output_filename: ${BENCHY_OUTPUT}.json/" ${BENCHY_CONFIG} > benchy-run-${LSB_JOBID}.yaml 30 | export BENCHY_CONFIG_FILE=benchy-run-${LSB_JOBID}.yaml 31 | export MASTER_ADDR=$(hostname) 32 | 33 | set -x 34 | 35 | if [ -z "$ENABLE_PROFILING" ] 36 | then 37 | ENABLE_PROFILING=0 38 | fi 39 | 40 | if [ -z "$PROFILE_OUTPUT" ] 41 | then 42 | PROFILE_OUTPUT=0 43 | fi 44 | 45 | time jsrun -n${nnodes} -a"$(($2))" -c42 -g"$(($2))" -r1 \ 46 | --smpiargs="-disable_gpu_hooks" \ 47 | --bind=proportional-packed:7 \ 48 | --launch_distribution=packed stdbuf -o0 \ 49 | ./launch_summit.sh \ 50 | "./run_summit.sh ${ENABLE_PROFILING} ${PROFILE_OUTPUT} ${args}" 51 | 52 | rm benchy-run-${LSB_JOBID}.yaml 53 | -------------------------------------------------------------------------------- /summit_scaling_logs/plot_weak_scale.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import matplotlib.pyplot as plt 3 | import statistics 4 | import json 5 | import numpy as np 6 | 7 | def return_avg_ips(file_loc): 8 | avg_ips = 0 9 | steps = 0 10 | myfile = open(file_loc, 'r') 11 | myline = myfile.readline().strip() 12 | while myline: 13 | stud_obj = json.loads(myline) 14 | io = stud_obj["IO"]["trial_throughput"][0] 15 | syth = stud_obj["SYNTHETIC"]["trial_throughput"][0] 16 | full = stud_obj["FULL"]["trial_throughput"][0] 17 | myline = myfile.readline().strip() 18 | 19 | myfile.close() 20 | return io, syth, full 21 | 22 | 23 | n1 = return_avg_ips('weak_scale_1.json') 24 | n2 = return_avg_ips('weak_scale_2.json') 25 | n4 = return_avg_ips('weak_scale_4.json') 26 | n8 = return_avg_ips('weak_scale_8.json') 27 | n16 = return_avg_ips('weak_scale_16.json') 
28 | n32 = return_avg_ips('weak_scale_32.json') 29 | 30 | # Real 31 | io = [n1[0], n2[0], n4[0], n8[0], n16[0], n32[0]] 32 | syth = [n1[1], n2[1], n4[1], n8[1], n16[1], n32[1]] 33 | full = [n1[2], n2[2], n4[2], n8[2], n16[2], n32[2]] 34 | 35 | # Ideal 36 | io_id = [n1[0], n1[0]*2, n1[0]*4, n1[0]*8, n1[0]*16, n1[0]*32] 37 | syth_id = [n1[1], n1[1]*2, n1[1]*4, n1[1]*8, n1[1]*16, n1[1]*32] 38 | full_id = [n1[2], n1[2]*2, n1[2]*4, n1[2]*8, n1[2]*16, n1[2]*32] 39 | 40 | bX = [1, 2, 4, 8, 16, 32] 41 | 42 | fig, axarr = plt.subplots(1,1) 43 | font = {'family': 'sans-serif', 44 | 'color': 'black', 45 | 'weight': 'normal', 46 | } 47 | 48 | axarr.plot(bX, full, '-o', color='g', label='real data') 49 | axarr.plot(bX, syth, '-o', color='darkorange', label='synthetic data') 50 | axarr.plot(bX, io, '-o', color='darkblue', label='data only') 51 | axarr.plot(bX, full_id, '--o', color='g', label='real data ideal') 52 | axarr.plot(bX, syth_id, '--o', color='darkorange', label='synthetic data ideal') 53 | axarr.plot(bX, io_id, '--o', color='darkblue', label='data only ideal') 54 | axarr.set_ylabel('Avg. Throughput (img/sec)',fontdict=font) 55 | axarr.set_xlabel('Nodes',fontdict=font) 56 | axarr.set_title('3D U-Net Cosmo Weak Scaling', fontdict=font) 57 | axarr.legend() 58 | axarr.grid() 59 | 60 | plt.show() 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import numpy as np 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.cuda.amp import autocast, GradScaler 11 | import torch.multiprocessing 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.nn.parallel import DistributedDataParallel 14 | 15 | import logging 16 | from utils import logging_utils 17 | logging_utils.config_logger() 18 | from utils.YParams import YParams 19 | from utils import get_data_loader_distributed, lr_schedule 20 | from networks import UNet 21 | 22 | import apex.optimizers as aoptim 23 | 24 | def train(params, args, local_rank, world_rank, world_size): 25 | # set device and benchmark mode 26 | torch.backends.cudnn.benchmark = True 27 | torch.cuda.set_device(local_rank) 28 | device = torch.device('cuda:%d'%local_rank) 29 | 30 | # get data loader 31 | logging.info('rank %d, begin data loader init'%world_rank) 32 | train_data_loader, val_data_loader = get_data_loader_distributed(params, world_rank, device.index) 33 | logging.info('rank %d, data loader initialized with config %s'%(world_rank, params.data_loader_config)) 34 | 35 | # create model 36 | model = UNet.UNet(params).to(device) 37 | model.apply(model.get_weights_function(params.weight_init)) 38 | 39 | if params.amp_dtype == torch.float16: 40 | scaler = GradScaler() 41 | if params.distributed and not args.noddp: 42 | if args.disable_broadcast_buffers: 43 | model = DistributedDataParallel(model, device_ids=[local_rank], 44 | bucket_cap_mb=args.bucket_cap_mb, 45 | broadcast_buffers=False, 46 | gradient_as_bucket_view=True) 47 | else: 48 | model = DistributedDataParallel(model, device_ids=[local_rank], 49 | bucket_cap_mb=args.bucket_cap_mb) 50 | 51 | if params.enable_apex: 52 | optimizer = aoptim.FusedAdam(model.parameters(), lr = params.lr_schedule['start_lr'], 53 | adam_w_mode=False, set_grad_none=True) 54 | else: 55 | optimizer = optim.Adam(model.parameters(), lr = params.lr_schedule['start_lr']) 56 | 57 | if 
params.enable_jit: 58 | model_handle = model.module if (params.distributed and not args.noddp) else model 59 | model_handle = torch.jit.script(model_handle) 60 | 61 | # select loss function 62 | if params.enable_jit: 63 | loss_func = UNet.loss_func_opt_final 64 | lambda_rho = torch.zeros((1,5,1,1,1), dtype=torch.float32).to(device) 65 | lambda_rho[:,0,:,:,:] = params.lambda_rho 66 | else: 67 | loss_func = UNet.loss_func 68 | lambda_rho = params.lambda_rho 69 | 70 | # start training 71 | iters = 0 72 | startEpoch = 0 73 | params.lr_schedule['tot_steps'] = params.num_epochs*(params.Nsamples//params.global_batch_size) 74 | 75 | if world_rank==0: 76 | logging.info("Starting Training Loop...") 77 | 78 | # Log initial loss on train and validation to tensorboard 79 | if not args.enable_benchy: 80 | with torch.no_grad(): 81 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 82 | tr_loss = loss_func(model(inp), tar, lambda_rho) 83 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 84 | val_loss= loss_func(model(inp), tar, lambda_rho) 85 | if params.distributed: 86 | torch.distributed.all_reduce(tr_loss) 87 | torch.distributed.all_reduce(val_loss) 88 | if world_rank==0: 89 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item()/world_size, 0) 90 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item()/world_size, 0) 91 | 92 | iters = 0 93 | t1 = time.time() 94 | for epoch in range(startEpoch, startEpoch+params.num_epochs): 95 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 96 | start = time.time() 97 | tr_loss = [] 98 | tr_time = 0. 99 | dat_time = 0. 100 | log_time = 0. 101 | 102 | model.train() 103 | step_count = 0 104 | for i, data in enumerate(train_data_loader, 0): 105 | if (args.enable_manual_profiling and world_rank==0): 106 | if (epoch == 3 and i == 0): 107 | torch.cuda.profiler.start() 108 | if (epoch == 3 and i == 59): 109 | torch.cuda.profiler.stop() 110 | 111 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"step {i}") 112 | iters += 1 113 | dat_start = time.time() 114 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"data copy in {i}") 115 | inp, tar = map(lambda x: x.to(device), data) 116 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # copy in 117 | tr_start = time.time() 118 | b_size = inp.size(0) 119 | 120 | lr_schedule(optimizer, iters, global_bs=params.global_batch_size, base_bs=params.base_batch_size, **params.lr_schedule) 121 | optimizer.zero_grad() 122 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"forward") 123 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 124 | gen = model(inp) 125 | loss = loss_func(gen, tar, lambda_rho) 126 | tr_loss.append(loss.item()) 127 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() #forward 128 | 129 | if params.amp_dtype == torch.float16: 130 | scaler.scale(loss).backward() 131 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer") 132 | scaler.step(optimizer) 133 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer 134 | scaler.update() 135 | else: 136 | loss.backward() 137 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer") 138 | optimizer.step() 139 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer 140 | 141 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # step 142 | 143 | tr_end = time.time() 144 | tr_time += tr_end - tr_start 145 | dat_time += tr_start - dat_start 146 
| step_count += 1 147 | 148 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 149 | end = time.time() 150 | if world_rank==0: 151 | logging.info('Time taken for epoch {} is {} sec, avg {} samples/sec'.format(epoch + 1, end-start, 152 | (step_count * params["global_batch_size"])/(end-start))) 153 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 154 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 155 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 156 | args.tboard_writer.add_scalar('Avg iters per sec', step_count/(end-start), iters) 157 | 158 | val_start = time.time() 159 | val_loss = [] 160 | model.eval() 161 | if not args.enable_benchy: 162 | with torch.no_grad(): 163 | for i, data in enumerate(val_data_loader, 0): 164 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 165 | inp, tar = map(lambda x: x.to(device), data) 166 | gen = model(inp) 167 | loss = loss_func(gen, tar, lambda_rho) 168 | if params.distributed: 169 | torch.distributed.all_reduce(loss) 170 | val_loss.append(loss.item()/world_size) 171 | val_end = time.time() 172 | if world_rank==0: 173 | logging.info(' Avg val loss=%f'%np.mean(val_loss)) 174 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 175 | args.tboard_writer.add_scalar('Loss/valid', np.mean(val_loss), iters) 176 | args.tboard_writer.flush() 177 | 178 | t2 = time.time() 179 | tottime = t2 - t1 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--run_num", default='00', type=str, help='tag for indexing the current experiment') 187 | parser.add_argument("--yaml_config", default='./config/UNet.yaml', type=str, help='path to yaml file containing training configs') 188 | parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file') 189 | parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode') 190 | parser.add_argument("--enable_apex", action='store_true', help='enable apex fused Adam optimizer') 191 | parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation') 192 | parser.add_argument("--enable_benchy", action='store_true', help='enable benchy tool usage') 193 | parser.add_argument("--enable_manual_profiling", action='store_true', help='enable manual nvtx ranges and profiler start/stop calls') 194 | parser.add_argument("--data_loader_config", default=None, type=str, 195 | choices=['synthetic', 'inmem', 'lowmem', 'dali-lowmem'], 196 | help="dataloader configuration. 
choices: 'synthetic', 'inmem', 'lowmem', 'dali-lowmem'") 197 | parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)') 198 | parser.add_argument("--num_epochs", default=None, type=int, help='number of epochs to run') 199 | parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader') 200 | parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb') 201 | parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers') 202 | parser.add_argument("--noddp", action='store_true', help='disable DDP communication') 203 | args = parser.parse_args() 204 | 205 | if (args.enable_benchy and args.enable_manual_profiling): 206 | raise RuntimeError("Enable either benchy profiling or manual profiling, not both.") 207 | 208 | run_num = args.run_num 209 | 210 | params = YParams(os.path.abspath(args.yaml_config), args.config) 211 | 212 | # Update config with modified args 213 | # set up amp 214 | if args.amp_mode != 'none': 215 | params.update({"amp_mode": args.amp_mode}) 216 | amp_dtype = torch.float32 217 | if params.amp_mode == "fp16": 218 | amp_dtype = torch.float16 219 | elif params.amp_mode == "bf16": 220 | amp_dtype = torch.bfloat16 221 | params.update({"amp_enabled": amp_dtype is not torch.float32, 222 | "amp_dtype" : amp_dtype, 223 | "enable_apex" : args.enable_apex, 224 | "enable_jit" : args.enable_jit, 225 | "enable_benchy" : args.enable_benchy}) 226 | 227 | if args.data_loader_config: 228 | params.update({"data_loader_config" : args.data_loader_config}) 229 | 230 | if args.num_epochs: 231 | params.update({"num_epochs" : args.num_epochs}) 232 | 233 | if args.num_data_workers: 234 | params.update({"num_data_workers" : args.num_data_workers}) 235 | 236 | params.distributed = False 237 | if 'WORLD_SIZE' in os.environ: 238 | params.distributed = int(os.environ['WORLD_SIZE']) > 1 239 | world_size = int(os.environ['WORLD_SIZE']) 240 | else: 241 | world_size = 1 242 | 243 | world_rank = 0 244 | local_rank = 0 245 | if params.distributed: 246 | torch.distributed.init_process_group(backend='nccl', 247 | init_method='env://') 248 | world_rank = torch.distributed.get_rank() 249 | local_rank = int(os.environ['LOCAL_RANK']) 250 | 251 | if args.local_batch_size: 252 | # Manually override batch size 253 | params.local_batch_size = args.local_batch_size 254 | params.update({"global_batch_size" : world_size*args.local_batch_size}) 255 | else: 256 | # Compute local batch size based on number of ranks 257 | params.local_batch_size = params.global_batch_size//world_size 258 | 259 | # Set up directory 260 | baseDir = params.expdir 261 | expDir = os.path.join(baseDir, args.config+'/%dGPU/'%(world_size)+str(run_num)+'/') 262 | if world_rank==0: 263 | if not os.path.isdir(expDir): 264 | os.makedirs(expDir) 265 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 266 | params.log() 267 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 268 | 269 | params.experiment_dir = os.path.abspath(expDir) 270 | 271 | train(params, args, local_rank, world_rank, world_size) 272 | if params.distributed: 273 | torch.distributed.barrier() 274 | logging.info('DONE ---- rank %d'%world_rank) 275 | 276 | -------------------------------------------------------------------------------- /train_graph.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import gc 4 | import time 5 | import numpy as np 6 | import argparse 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.cuda.amp import autocast, GradScaler 12 | import torch.multiprocessing 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.nn.parallel import DistributedDataParallel 15 | 16 | import logging 17 | from utils import logging_utils 18 | logging_utils.config_logger() 19 | from utils.YParams import YParams 20 | from utils import get_data_loader_distributed, lr_schedule 21 | from networks import UNet 22 | 23 | import apex.optimizers as aoptim 24 | 25 | 26 | def capture_model(params, model, loss_func, lambda_rho, scaler, capture_stream, device, num_warmup=20): 27 | print("Capturing Model") 28 | inp_shape = (params.local_batch_size, 4, params.data_size, params.data_size, params.data_size) 29 | tar_shape = (params.local_batch_size, 5, params.data_size, params.data_size, params.data_size) 30 | static_input = torch.zeros(inp_shape, dtype=torch.float32, device=device) 31 | static_label = torch.zeros(tar_shape, dtype=torch.float32, device=device) 32 | 33 | capture_stream.wait_stream(torch.cuda.current_stream()) 34 | with torch.cuda.stream(capture_stream): 35 | for _ in range(num_warmup): 36 | model.zero_grad(set_to_none=True) 37 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 38 | static_output = model(static_input) 39 | static_loss = loss_func(static_output, static_label, lambda_rho) 40 | 41 | if params.amp_dtype == torch.float16: 42 | scaler.scale(static_loss).backward() 43 | else: 44 | static_loss.backward() 45 | 46 | # sync here 47 | capture_stream.synchronize() 48 | 49 | gc.collect() 50 | torch.cuda.empty_cache() 51 | 52 | # create graph 53 | graph = torch.cuda.CUDAGraph() 54 | 55 | # zero grads before capture: 56 | model.zero_grad(set_to_none=True) 57 | 58 | # do the capture with the context manager: 59 | with torch.cuda.graph(graph): 60 | 61 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 62 | static_output = model(static_input) 63 | static_loss = loss_func(static_output, static_label, lambda_rho) 64 | 65 | if params.amp_dtype == torch.float16: 66 | scaler.scale(static_loss).backward() 67 | else: 68 | static_loss.backward() 69 | 70 | torch.cuda.current_stream().wait_stream(capture_stream) 71 | 72 | return graph, static_input, static_output, static_label, static_loss 73 | 74 | 75 | def train(params, args, local_rank, world_rank, world_size): 76 | # set device and benchmark mode 77 | torch.backends.cudnn.benchmark = True 78 | torch.cuda.set_device(local_rank) 79 | device = torch.device('cuda:%d'%local_rank) 80 | 81 | # get data loader 82 | logging.info('rank %d, begin data loader init'%world_rank) 83 | train_data_loader, val_data_loader = get_data_loader_distributed(params, world_rank, device.index) 84 | logging.info('rank %d, data loader initialized with config %s'%(world_rank, params.data_loader_config)) 85 | 86 | # create model 87 | model = UNet.UNet(params).to(device) 88 | model.apply(model.get_weights_function(params.weight_init)) 89 | 90 | if params.amp_dtype == torch.float16: 91 | scaler = GradScaler() 92 | else: 93 | scaler = None 94 | 95 | capture_stream = torch.cuda.Stream() 96 | with torch.cuda.stream(capture_stream): 97 | if params.distributed: 98 | model = DistributedDataParallel(model, device_ids=[local_rank]) 99 | capture_stream.synchronize() 100 | 
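    # DDP is constructed (and synchronized) on a side stream above because the
    # forward/backward pass will later be captured into a CUDA graph; this
    # follows the PyTorch CUDA Graphs guidance for combining capture with DDP.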
101 | if params.enable_apex: 102 | optimizer = aoptim.FusedAdam(model.parameters(), lr = params.lr_schedule['start_lr'], 103 | adam_w_mode=False, set_grad_none=True) 104 | else: 105 | optimizer = optim.Adam(model.parameters(), lr = params.lr_schedule['start_lr']) 106 | 107 | if params.enable_jit: 108 | model_handle = model.module if params.distributed else model 109 | model_handle = torch.jit.script(model_handle) 110 | 111 | # select loss function 112 | if params.enable_jit: 113 | loss_func = UNet.loss_func_opt_final 114 | lambda_rho = torch.zeros((1,5,1,1,1), dtype=torch.float32).to(device) 115 | lambda_rho[:,0,:,:,:] = params.lambda_rho 116 | else: 117 | loss_func = UNet.loss_func 118 | lambda_rho = params.lambda_rho 119 | 120 | # capture the model 121 | graph, static_input, static_output, static_label, static_loss = capture_model(params, model, loss_func, lambda_rho, scaler, 122 | capture_stream, device, num_warmup=20) 123 | 124 | # start training 125 | iters = 0 126 | startEpoch = 0 127 | params.lr_schedule['tot_steps'] = params.num_epochs*(params.Nsamples//params.global_batch_size) 128 | 129 | if world_rank==0: 130 | logging.info("Starting Training Loop...") 131 | 132 | # Log initial loss on train and validation to tensorboard 133 | if not args.enable_benchy: 134 | with torch.no_grad(): 135 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 136 | tr_loss = loss_func(model(inp), tar, lambda_rho) 137 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 138 | val_loss= loss_func(model(inp), tar, lambda_rho) 139 | if params.distributed: 140 | torch.distributed.all_reduce(tr_loss) 141 | torch.distributed.all_reduce(val_loss) 142 | if world_rank==0: 143 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item()/world_size, 0) 144 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item()/world_size, 0) 145 | 146 | iters = 0 147 | t1 = time.time() 148 | for epoch in range(startEpoch, startEpoch+params.num_epochs): 149 | start = time.time() 150 | tr_loss = [] 151 | tr_time = 0. 152 | dat_time = 0. 153 | log_time = 0. 
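    # Each step below copies the batch into the static capture buffers and
    # replays the captured graph in place of a regular forward/backward call.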
154 | 155 | model.train() 156 | step_count = 0 157 | for i, data in enumerate(train_data_loader, 0): 158 | if (args.enable_manual_profiling and world_rank==0): 159 | if (epoch == 1 and i == 0): 160 | torch.cuda.profiler.start() 161 | if (epoch == 1 and i == 29): 162 | torch.cuda.profiler.stop() 163 | 164 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"step {i}") 165 | iters += 1 166 | dat_start = time.time() 167 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"data copy in {i}") 168 | inp, tar = map(lambda x: x.to(device), data) 169 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # copy in 170 | tr_start = time.time() 171 | b_size = inp.size(0) 172 | 173 | lr_schedule(optimizer, iters, global_bs=params.global_batch_size, base_bs=params.base_batch_size, **params.lr_schedule) 174 | 175 | static_input.copy_(inp) 176 | static_label.copy_(tar) 177 | graph.replay() 178 | 179 | if params.amp_dtype == torch.float16: 180 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer") 181 | scaler.step(optimizer) 182 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer 183 | scaler.update() 184 | else: 185 | if args.enable_manual_profiling: torch.cuda.nvtx.range_push(f"optimizer") 186 | optimizer.step() 187 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # optimizer 188 | 189 | tr_loss.append(static_loss.item()) 190 | if args.enable_manual_profiling: torch.cuda.nvtx.range_pop() # step 191 | 192 | tr_end = time.time() 193 | tr_time += tr_end - tr_start 194 | dat_time += tr_start - dat_start 195 | step_count += 1 196 | 197 | end = time.time() 198 | if world_rank==0: 199 | logging.info('Time taken for epoch {} is {} sec, avg {} samples/sec'.format(epoch + 1, end-start, 200 | (step_count * params["global_batch_size"])/(end-start))) 201 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 202 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 203 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 204 | args.tboard_writer.add_scalar('Avg iters per sec', step_count/(end-start), iters) 205 | 206 | val_start = time.time() 207 | val_loss = [] 208 | model.eval() 209 | if not args.enable_benchy: 210 | with torch.no_grad(): 211 | for i, data in enumerate(val_data_loader, 0): 212 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 213 | inp, tar = map(lambda x: x.to(device), data) 214 | gen = model(inp) 215 | loss = loss_func(gen, tar, lambda_rho) 216 | if params.distributed: 217 | torch.distributed.all_reduce(loss) 218 | val_loss.append(loss.item()/world_size) 219 | val_end = time.time() 220 | if world_rank==0: 221 | logging.info(' Avg val loss=%f'%np.mean(val_loss)) 222 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 223 | args.tboard_writer.add_scalar('Loss/valid', np.mean(val_loss), iters) 224 | args.tboard_writer.flush() 225 | 226 | t2 = time.time() 227 | tottime = t2 - t1 228 | 229 | 230 | 231 | if __name__ == '__main__': 232 | 233 | parser = argparse.ArgumentParser() 234 | parser.add_argument("--run_num", default='00', type=str, help='tag for indexing the current experiment') 235 | parser.add_argument("--yaml_config", default='./config/UNet.yaml', type=str, help='path to yaml file containing training configs') 236 | parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file') 237 | parser.add_argument("--amp_mode", default=None, type=str, choices=['none', 'fp16', 
'bf16'], help='select automatic mixed precision mode') 238 | parser.add_argument("--enable_apex", action='store_true', help='enable apex fused Adam optimizer') 239 | parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation') 240 | parser.add_argument("--enable_benchy", action='store_true', help='enable benchy tool usage') 241 | parser.add_argument("--enable_manual_profiling", action='store_true', help='enable manual nvtx ranges and profiler start/stop calls') 242 | parser.add_argument("--data_loader_config", default=None, type=str, 243 | choices=['synthetic', 'inmem', 'lowmem', 'dali-lowmem'], 244 | help="dataloader configuration. choices: 'synthetic', 'inmem', 'lowmem', 'dali-lowmem'") 245 | parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)') 246 | parser.add_argument("--num_epochs", default=None, type=int, help='number of epochs to run') 247 | parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader') 248 | args = parser.parse_args() 249 | 250 | if (args.enable_benchy and args.enable_manual_profiling): 251 | raise RuntimeError("Enable either benchy profiling or manual profiling, not both.") 252 | 253 | run_num = args.run_num 254 | 255 | params = YParams(os.path.abspath(args.yaml_config), args.config) 256 | 257 | # Update config with modified args 258 | # set up amp 259 | if args.amp_mode is not None: 260 | params.update({"amp_mode": args.amp_mode}) 261 | amp_dtype = torch.float32 262 | if params.amp_mode == "fp16": 263 | amp_dtype = torch.float16 264 | elif params.amp_mode == "bf16": 265 | amp_dtype = torch.bfloat16 266 | params.update({"amp_enabled": amp_dtype is not None, 267 | "amp_dtype" : amp_dtype, 268 | "enable_apex" : args.enable_apex, 269 | "enable_jit" : args.enable_jit, 270 | "enable_benchy" : args.enable_benchy}) 271 | 272 | if args.data_loader_config: 273 | params.update({"data_loader_config" : args.data_loader_config}) 274 | 275 | if args.num_epochs: 276 | params.update({"num_epochs" : args.num_epochs}) 277 | 278 | if args.num_data_workers: 279 | params.update({"num_data_workers" : args.num_data_workers}) 280 | 281 | params.distributed = False 282 | if 'WORLD_SIZE' in os.environ: 283 | params.distributed = int(os.environ['WORLD_SIZE']) > 1 284 | world_size = int(os.environ['WORLD_SIZE']) 285 | else: 286 | world_size = 1 287 | 288 | world_rank = 0 289 | local_rank = 0 290 | if params.distributed: 291 | torch.distributed.init_process_group(backend='nccl', 292 | init_method='env://') 293 | world_rank = torch.distributed.get_rank() 294 | local_rank = int(os.environ['LOCAL_RANK']) 295 | 296 | if args.local_batch_size: 297 | # Manually override batch size 298 | params.local_batch_size = args.local_batch_size 299 | params.update({"global_batch_size" : world_size*args.local_batch_size}) 300 | else: 301 | # Compute local batch size based on number of ranks 302 | params.local_batch_size = params.global_batch_size//world_size 303 | 304 | # Set up directory 305 | baseDir = params.expdir 306 | expDir = os.path.join(baseDir, args.config+'/%dGPU/'%(world_size)+str(run_num)+'/') 307 | if world_rank==0: 308 | if not os.path.isdir(expDir): 309 | os.makedirs(expDir) 310 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 311 | params.log() 312 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 313 | 314 | params.experiment_dir = os.path.abspath(expDir) 315 | 
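    # The distributed setup above relies on RANK, LOCAL_RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT from the environment (init_process_group
    # uses the env:// rendezvous); the submit scripts source
    # export_DDP_vars.sh (Slurm) or export_DDP_vars_summit.sh (LSF) to
    # provide them before launching python.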
316 | train(params, args, local_rank, world_rank, world_size) 317 | if params.distributed: 318 | torch.distributed.barrier() 319 | logging.info('DONE ---- rank %d'%world_rank) 320 | 321 | -------------------------------------------------------------------------------- /tutorial_images/baseline_tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/baseline_tb.png -------------------------------------------------------------------------------- /tutorial_images/baseline_tb_summit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/baseline_tb_summit.png -------------------------------------------------------------------------------- /tutorial_images/bs512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs512.png -------------------------------------------------------------------------------- /tutorial_images/bs512_short.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs512_short.png -------------------------------------------------------------------------------- /tutorial_images/bs576_short_summit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs576_short_summit.png -------------------------------------------------------------------------------- /tutorial_images/bs576_summit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs576_summit.png -------------------------------------------------------------------------------- /tutorial_images/bs_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs_compare.png -------------------------------------------------------------------------------- /tutorial_images/bs_compare_summit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/bs_compare_summit.png -------------------------------------------------------------------------------- /tutorial_images/nbody2hydro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nbody2hydro.png -------------------------------------------------------------------------------- /tutorial_images/nsys_4gpu_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_4gpu_baseline.png 
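Looping back to `train_graph.py` above: `capture_model()` records one forward/backward pass into a `torch.cuda.CUDAGraph` after a warm-up on a side stream, and the training loop then replays that graph after copying each batch into the static buffers. A minimal, self-contained sketch of the same capture-and-replay pattern (the toy model, shapes, and iteration counts are placeholders, not taken from this repo; it needs a CUDA-capable GPU):

```python
import torch

device = torch.device("cuda")
model = torch.nn.Linear(32, 32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.L1Loss()

# Static buffers whose memory addresses get baked into the captured graph
static_input = torch.zeros(8, 32, device=device)
static_target = torch.zeros(8, 32, device=device)

# Warm up on a side stream before capture
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one forward/backward; grads are None going in, so backward
# allocates static .grad tensors inside the graph's private memory pool
graph = torch.cuda.CUDAGraph()
model.zero_grad(set_to_none=True)
with torch.cuda.graph(graph):
    static_loss = loss_fn(model(static_input), static_target)
    static_loss.backward()

# Replay: refresh the static buffers, replay the graph, then step the
# optimizer outside the graph. The replayed backward overwrites the static
# grads in place, so no zero_grad is needed between iterations.
for _ in range(5):
    static_input.copy_(torch.randn(8, 32, device=device))
    static_target.copy_(torch.randn(8, 32, device=device))
    graph.replay()
    optimizer.step()
    print(static_loss.item())
```

One motivation for keeping the optimizer step outside the graph, as `train_graph.py` does, is that the learning-rate schedule and (for fp16) the GradScaler can still adjust values between replays.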
-------------------------------------------------------------------------------- /tutorial_images/nsys_4gpu_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_4gpu_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_4gpu_bucketcap100mb_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_4gpu_bucketcap100mb_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_4gpu_nobroadcast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_4gpu_nobroadcast.png -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_baseline.png -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali_amp.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_apex_jit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali_amp_apex_jit.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_apex_jit_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali_amp_apex_jit_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali_amp_zoomed.png -------------------------------------------------------------------------------- 
/tutorial_images/nsys_dali_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_dali_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_16workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_nativedata_16workers.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_16workers_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_nativedata_16workers_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_8workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_nativedata_8workers.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_8workers_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_nativedata_8workers_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_summit_6gpu_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_summit_6gpu_baseline.png -------------------------------------------------------------------------------- /tutorial_images/nsys_summit_6gpu_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_summit_6gpu_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_summit_6gpu_bucketcap100mb_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_summit_6gpu_bucketcap100mb_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_summit_6gpu_nobroadcast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/nsys_summit_6gpu_nobroadcast.png -------------------------------------------------------------------------------- /tutorial_images/relative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/relative.png 
-------------------------------------------------------------------------------- /tutorial_images/scale_perfComm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/scale_perfComm.png -------------------------------------------------------------------------------- /tutorial_images/scale_perfDiffBS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/scale_perfDiffBS.png -------------------------------------------------------------------------------- /tutorial_images/scale_perfEff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/scale_perfEff.png -------------------------------------------------------------------------------- /tutorial_images/scale_perfEff_bs128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc22-dl-tutorial/21ea7dc9419a474cfa40c9790adb935554c4c590/tutorial_images/scale_perfEff_bs128.png -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | class YParams(): 5 | """ Yaml file parser """ 6 | def __init__(self, yaml_filename, config_name, print_params=False): 7 | self._yaml_filename = yaml_filename 8 | self._config_name = config_name 9 | self.params = {} 10 | 11 | with open(yaml_filename) as _file: 12 | 13 | for key, val in YAML().load(_file)[config_name].items(): 14 | if val =='None': val = None 15 | 16 | self.params[key] = val 17 | self.__setattr__(key, val) 18 | 19 | if print_params: 20 | self.log() 21 | 22 | def __getitem__(self, key): 23 | return self.params[key] 24 | 25 | def __setitem__(self, key, val): 26 | self.params[key] = val 27 | 28 | def log(self): 29 | logging.info("------------------ Configuration ------------------") 30 | logging.info("Configuration file: "+str(self._yaml_filename)) 31 | logging.info("Configuration name: "+str(self._config_name)) 32 | for key, val in self.params.items(): 33 | logging.info(str(key) + ' ' + str(val)) 34 | logging.info("---------------------------------------------------") 35 | 36 | def update(self, new_params): 37 | self.params.update(new_params) 38 | for key, val in new_params.items(): 39 | self.__setattr__(key, val) 40 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | def get_data_loader_distributed(params, world_rank, device_id=0): 5 | if params.data_loader_config.startswith("dali"): 6 | if params.data_loader_config == "dali-lowmem": 7 | from .data_loader_dali import get_data_loader_distributed 8 | else: 9 | raise NotImplementedError(f"Error, data loader config {params.data_loader_config} not supported!") 10 | else: 11 | from .data_loader import get_data_loader_distributed 12 | 13 | return get_data_loader_distributed(params, world_rank, device_id) 14 | 15 | def lr_schedule(optimizer, iternum, global_bs, base_bs, scaling='none', start_lr=1e-4, tot_steps=1000, 
end_lr=0., warmup_steps=0): 16 | if scaling=='sqrt': 17 | init_lr = np.sqrt(global_bs/base_bs)*start_lr 18 | elif scaling=='linear': 19 | init_lr = (global_bs/base_bs)*start_lr 20 | elif scaling=='none': 21 | init_lr = start_lr 22 | 23 | if global_bs > base_bs and scaling != 'none': 24 | # warm-up lr rate 25 | if iternum