├── distributed ├── __init__.py ├── helpers.py ├── mappings.py └── layers.py ├── figs └── minibatch_0.jpg ├── tutorial_images ├── mp_comp.png ├── dp_timings.png ├── mp_dp_comp.png ├── nsys_dali.png ├── baseline_tb.png ├── nsys_baseline.png ├── nsys_dali_amp.png ├── vit_schematic.png ├── nsys_dali_zoomed.png ├── nsys_baseline_zoomed.png ├── nsys_dali_amp_zoomed.png ├── weather_forecasting.gif ├── nsys_dali_amp_fused_jit.png ├── nsys_nativedata_4workers.png ├── nsys_dali_amp_fused_jit_zoomed.png └── nsys_nativedata_4workers_zoomed.png ├── sample_nsys_profiles ├── dali.nsys-rep ├── 4workers.nsys-rep ├── baseline.nsys-rep ├── dali_amp_bf16.nsys-rep └── dali_amp_bf16_fused_jit.nsys-rep ├── export_DDP_vars.sh ├── example_logs └── base │ └── 1GPU │ └── 00 │ └── logs │ └── events.out.tfevents.1698839594.nid001332.2071366.0 ├── test_model_dims.py ├── utils ├── __init__.py ├── plots.py ├── loss.py ├── YParams.py ├── logging_utils.py ├── metrics.py ├── data_loader.py ├── dali_es_helper.py ├── data_loader_dali.py └── comm.py ├── submit_pm.sh ├── submit_pm_dp.sh ├── test_data_loader.py ├── submit_pm_mp.sh ├── start_tensorboard.ipynb ├── config └── ViT.yaml ├── .gitignore ├── networks ├── helpers.py └── vit.py ├── train.py ├── train_mp.py ├── train_mp_graphs.py └── README.md /distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # model parallelism helpers and routines 2 | -------------------------------------------------------------------------------- /figs/minibatch_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/figs/minibatch_0.jpg -------------------------------------------------------------------------------- /tutorial_images/mp_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/mp_comp.png -------------------------------------------------------------------------------- /tutorial_images/dp_timings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/dp_timings.png -------------------------------------------------------------------------------- /tutorial_images/mp_dp_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/mp_dp_comp.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali.png -------------------------------------------------------------------------------- /tutorial_images/baseline_tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/baseline_tb.png -------------------------------------------------------------------------------- /sample_nsys_profiles/dali.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/sample_nsys_profiles/dali.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_baseline.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali_amp.png -------------------------------------------------------------------------------- /tutorial_images/vit_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/vit_schematic.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali_zoomed.png -------------------------------------------------------------------------------- /sample_nsys_profiles/4workers.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/sample_nsys_profiles/4workers.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/baseline.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/sample_nsys_profiles/baseline.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali_amp_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/weather_forecasting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/weather_forecasting.gif -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp_bf16.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/sample_nsys_profiles/dali_amp_bf16.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_fused_jit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali_amp_fused_jit.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_4workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_nativedata_4workers.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_amp_fused_jit_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_dali_amp_fused_jit_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_nativedata_4workers_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/tutorial_images/nsys_nativedata_4workers_zoomed.png -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp_bf16_fused_jit.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/sample_nsys_profiles/dali_amp_bf16_fused_jit.nsys-rep -------------------------------------------------------------------------------- /export_DDP_vars.sh: -------------------------------------------------------------------------------- 1 | export RANK=$SLURM_PROCID 2 | export LOCAL_RANK=$SLURM_LOCALID 3 | export WORLD_SIZE=$SLURM_NTASKS 4 | export MASTER_PORT=29500 # default from torch launcher 5 | -------------------------------------------------------------------------------- /example_logs/base/1GPU/00/logs/events.out.tfevents.1698839594.nid001332.2071366.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/sc23-dl-tutorial/HEAD/example_logs/base/1GPU/00/logs/events.out.tfevents.1698839594.nid001332.2071366.0 -------------------------------------------------------------------------------- /test_model_dims.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks.vit import ViT 3 | from utils.YParams import YParams 4 | from torchinfo import summary 5 | 6 | params = YParams('./config/ViT.yaml', 'short_mp') 7 | model = ViT(params) 8 | summary(model, input_size=(16,20,360,720)) 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | def get_data_loader_distributed(params, location, distributed, train): 3 | if params.data_loader_config.startswith("dali"): 4 | from .data_loader_dali import get_data_loader 5 | else: 6 | from .data_loader import get_data_loader 7 | return get_data_loader(params, location, distributed, train) 8 | 9 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def generate_images(fields): 6 | inp, tar, gen = [x.detach().float().cpu().numpy() for x in fields] 7 | fig, ax = plt.subplots(1, 2, figsize=(12,6)) 8 | plt.title('2m temperature') 9 | ax[0].imshow(tar[0,2,:,:], cmap="turbo") 10 | ax[0].set_title("ERA5 target") 11 | ax[1].imshow(gen[0,2,:,:], cmap="turbo") 12 | ax[1].set_title("ViT prediction") 13 | fig.tight_layout() 14 | return fig 15 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def l2_loss(pred, target): 4 | num_examples = pred.shape[0] 5 | diff_norms = torch.norm(pred.reshape(num_examples,-1) - target.reshape(num_examples,-1), 2, 1) 6 | y_norms = torch.norm(target.reshape(num_examples,-1), 2, 1) 7 | return torch.mean(diff_norms/y_norms) 8 | 9 | 10 | @torch.jit.script 11 | def l2_loss_opt(pred: torch.Tensor, target: torch.Tensor): 12 | num_examples = pred.shape[0] 13 | diff_norms = torch.norm(pred.reshape(num_examples,-1) - target.reshape(num_examples,-1), 2, 1) 14 | y_norms = torch.norm(target.reshape(num_examples,-1), 2, 1) 15 | return torch.mean(diff_norms/y_norms) 16 | 17 | -------------------------------------------------------------------------------- /submit_pm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -q shared 4 | #SBATCH -A ntrain4 5 | #SBATCH --cpus-per-task 32 6 | #SBATCH --gpus-per-task 1 7 | #SBATCH --gpu-bind none 8 | #SBATCH --time=01:00:00 9 | #SBATCH --image=nersc/pytorch:ngc-23.07-v0 10 | #SBATCH --module=gpu,nccl-2.18 11 | #SBATCH --reservation=sc23_dl_tutorial_1 12 | #SBATCH -J vit-era5 13 | #SBATCH -o %x-%j.out 14 | 15 | DATADIR=/pscratch/sd/s/shas1693/data/sc23_tutorial_data/downsampled 16 | LOGDIR=${SCRATCH}/sc23-dl-tutorial/logs 17 | mkdir -p ${LOGDIR} 18 | args="${@}" 19 | 20 | export FI_MR_CACHE_MONITOR=userfaultfd 21 | export HDF5_USE_FILE_LOCKING=FALSE 22 | 23 | # Profiling 24 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 25 | echo "Enabling profiling..." 26 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 27 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 28 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 29 | fi 30 | 31 | export MASTER_ADDR=$(hostname) 32 | 33 | # Reversing order of GPUs to match default CPU affinities from Slurm 34 | export CUDA_VISIBLE_DEVICES=3,2,1,0 35 | 36 | set -x 37 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 38 | bash -c " 39 | source export_DDP_vars.sh 40 | ${PROFILE_CMD} python train.py ${args} 41 | " 42 | -------------------------------------------------------------------------------- /submit_pm_dp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4 4 | #SBATCH -q regular 5 | #SBATCH --ntasks-per-node 4 6 | #SBATCH --cpus-per-task 32 7 | #SBATCH --gpus-per-node 4 8 | #SBATCH --time=01:00:00 9 | #SBATCH --image=nersc/pytorch:ngc-23.07-v0 10 | #SBATCH --module=gpu,nccl-2.18 11 | #SBATCH --reservation=sc23_dl_tutorial_2 12 | #SBATCH -J vit-era5 13 | #SBATCH -o %x-%j.out 14 | 15 | DATADIR=/pscratch/sd/s/shas1693/data/sc23_tutorial_data/downsampled 16 | LOGDIR=${SCRATCH}/sc23-dl-tutorial/logs 17 | mkdir -p ${LOGDIR} 18 | args="${@}" 19 | 20 | export FI_MR_CACHE_MONITOR=userfaultfd 21 | export HDF5_USE_FILE_LOCKING=FALSE 22 | 23 | # Profiling 24 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 25 | echo "Enabling profiling..." 26 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 27 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 28 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 29 | fi 30 | 31 | export MASTER_ADDR=$(hostname) 32 | 33 | # Reversing order of GPUs to match default CPU affinities from Slurm 34 | export CUDA_VISIBLE_DEVICES=3,2,1,0 35 | 36 | set -x 37 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 38 | bash -c " 39 | source export_DDP_vars.sh 40 | ${PROFILE_CMD} python train.py ${args} 41 | " 42 | -------------------------------------------------------------------------------- /test_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import get_data_loader_distributed 3 | import numpy as np 4 | from utils.YParams import YParams 5 | from networks.vit import ViT 6 | import matplotlib.pyplot as plt 7 | 8 | params = YParams('./config/ViT.yaml', 'short') 9 | params.global_batch_size = 1 10 | params.local_batch_size = 1 11 | 12 | valid_dataloader, dataset_valid = get_data_loader_distributed(params, params.valid_data_path, distributed=False, train=False) 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | params.device = device 16 | model = ViT(params) 17 | model = model.to(device) 18 | 19 | with torch.no_grad(): 20 | for i, data in enumerate(valid_dataloader, 0): 21 | if i >= 1: 22 | break 23 | print("Doing iteration {}".format(i)) 24 | inp, tar = map(lambda x: x.to(device, dtype = torch.float), data) 25 | print("input shape = {}".format(inp.shape)) 26 | print("target shape = {}".format(tar.shape)) 27 | plt.rcParams["figure.figsize"] = (20,20) 28 | plt.figure() 29 | for ch in range(inp.shape[1]): 30 | plt.subplot(inp.shape[1],1, ch+1) 31 | plt.imshow(inp[0,ch,:,:].cpu(), cmap = 'RdBu') 32 | plt.colorbar() 33 | plt.savefig("figs/minibatch_" + str(i) + ".jpg") 34 | gen = model(inp) 35 | print("prediction shape = {}".format(gen.shape)) 36 | 37 | -------------------------------------------------------------------------------- /submit_pm_mp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4 4 | #SBATCH --ntasks-per-node 4 5 | #SBATCH --cpus-per-task 32 6 | #SBATCH --gpus-per-node 4 7 | #SBATCH --time=01:00:00 8 | #SBATCH --image=nersc/pytorch:ngc-23.07-v0 9 | #SBATCH --module=gpu,nccl-2.18 10 | #SBATCH --reservation=sc23_dl_tutorial_2 11 | #SBATCH -J vit-era5-mp 12 | #SBATCH -o %x-%j.out 13 | 14 | DATADIR=/pscratch/sd/s/shas1693/data/sc23_tutorial_data/downsampled 15 | LOGDIR=${SCRATCH}/sc23-dl-tutorial/logs 16 | mkdir -p ${LOGDIR} 17 | args="${@}" 18 | #args="--config=mp --row_parallel_size=4" 19 | 20 | export FI_MR_CACHE_MONITOR=userfaultfd 21 | export HDF5_USE_FILE_LOCKING=FALSE 22 | 23 | # Profiling 24 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 25 | echo "Enabling profiling..." 26 | NSYS_ARGS="--trace=cuda,cublas,nvtx --cuda-graph-trace=node --kill none -c cudaProfilerApi -f true" 27 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 28 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 29 | fi 30 | 31 | export MASTER_ADDR=$(hostname) 32 | 33 | # Reversing order of GPUs to match default CPU affinities from Slurm 34 | export CUDA_VISIBLE_DEVICES=3,2,1,0 35 | 36 | # if cuda graphs, use train_mp_graphs.py 37 | set -x 38 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 39 | bash -c " 40 | source export_DDP_vars.sh 41 | ${PROFILE_CMD} python train_mp.py ${args} 42 | " 43 | -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | class YParams(): 5 | """ Yaml file parser """ 6 | def __init__(self, yaml_filename, config_name, print_params=False): 7 | self._yaml_filename = yaml_filename 8 | self._config_name = config_name 9 | self.params = {} 10 | 11 | with open(yaml_filename) as _file: 12 | 13 | for key, val in YAML().load(_file)[config_name].items(): 14 | if val =='None': val = None 15 | 16 | self.params[key] = val 17 | self.__setattr__(key, val) 18 | 19 | if print_params: 20 | self.log() 21 | 22 | def __getitem__(self, key): 23 | return self.params[key] 24 | 25 | def __setitem__(self, key, val): 26 | self.params[key] = val 27 | 28 | def get(self, key, default=None): 29 | """Get a parameter value""" 30 | if hasattr(self, key): 31 | return getattr(self, key) 32 | else: 33 | return self.params.get(key, default) 34 | 35 | def log(self): 36 | logging.info("------------------ Configuration ------------------") 37 | logging.info("Configuration file: "+str(self._yaml_filename)) 38 | logging.info("Configuration name: "+str(self._config_name)) 39 | for key, val in self.params.items(): 40 | logging.info(str(key) + ' ' + str(val)) 41 | logging.info("---------------------------------------------------") 42 | 43 | def update(self, new_params): 44 | self.params.update(new_params) 45 | for key, val in new_params.items(): 46 | self.__setattr__(key, val) 47 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | _format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 5 | 6 | def slurm_filter(record): 7 | return int(os.environ['SLURM_PROCID']) == 0 8 | 9 | def config_logger(log_level=logging.INFO): 10 | logging.basicConfig(format=_format, level=log_level) 11 | root_logger = logging.getLogger() 12 | root_logger.addFilter(slurm_filter) 13 | 14 | def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'): 15 | 16 | if not os.path.exists(os.path.dirname(log_filename)): 17 | os.makedirs(os.path.dirname(log_filename)) 18 | 19 | if logger_name is not None: 20 | log = logging.getLogger(logger_name) 21 | else: 22 | log = logging.getLogger() 23 | 24 | fh = logging.FileHandler(log_filename) 25 | fh.setLevel(log_level) 26 | fh.setFormatter(logging.Formatter(_format)) 27 | log.addHandler(fh) 28 | 29 | def log_versions(): 30 | import torch 31 | import subprocess 32 | 33 | logging.info('--------------- Versions ---------------') 34 | logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip())) 35 | logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip())) 36 | logging.info('Torch: ' + str(torch.__version__)) 37 | logging.info('----------------------------------------') 38 | 39 | class disable_logging(object): 40 | """ 41 | A context manager to disable logging temporarily. 42 | """ 43 | 44 | def __init__(self, level=logging.ERROR): # pragma: no cover 45 | """ 46 | Initialize the context manager. 47 | """ 48 | logging.disable(level=level) 49 | 50 | def __enter__(self): # pragma: no cover 51 | """ 52 | Enter the context manager. 53 | """ 54 | return self 55 | 56 | def __exit__(self, type, value, traceback): # pragma: no cover 57 | """ 58 | Exit the context manager and enable logging. 59 | """ 60 | logging.disable(level=logging.NOTSET) 61 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | @torch.jit.script 5 | def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor: 6 | return 90. - j * 180./float(num_lat-1) 7 | 8 | @torch.jit.script 9 | def latitude_weighting_factor(j: torch.Tensor, num_lat: int, s: torch.Tensor) -> torch.Tensor: 10 | return num_lat * torch.cos(3.1416/180. * lat(j, num_lat))/s 11 | 12 | @torch.jit.script 13 | def weighted_rmse_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 14 | #takes in arrays of size [n, c, h, w] and returns latitude-weighted rmse for each chann 15 | num_lat = pred.shape[2] 16 | #num_long = target.shape[2] 17 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device) 18 | 19 | s = torch.sum(torch.cos(3.1416/180. * lat(lat_t, num_lat))) 20 | weight = torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1)) 21 | result = torch.sqrt(torch.mean(weight * (pred - target)**2., dim=(-1,-2))) 22 | return result 23 | 24 | @torch.jit.script 25 | def weighted_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 26 | result = weighted_rmse_channels(pred, target) 27 | return torch.mean(result, dim=0) 28 | 29 | @torch.jit.script 30 | def weighted_acc_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 31 | #takes in arrays of size [n, c, h, w] and returns latitude-weighted acc 32 | num_lat = pred.shape[2] 33 | #num_long = target.shape[2] 34 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device) 35 | s = torch.sum(torch.cos(3.1416/180. * lat(lat_t, num_lat))) 36 | weight = torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1)) 37 | result = torch.sum(weight * pred * target, dim=(-1,-2)) / torch.sqrt(torch.sum(weight * pred * pred, dim=(-1,-2)) * torch.sum(weight * target * 38 | target, dim=(-1,-2))) 39 | return result 40 | 41 | @torch.jit.script 42 | def weighted_acc(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 43 | result = weighted_acc_channels(pred, target) 44 | return torch.mean(result, dim=0) 45 | -------------------------------------------------------------------------------- /start_tensorboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "314d73bf-204b-410b-9b58-4db9a461e1c4", 6 | "metadata": {}, 7 | "source": [ 8 | "# TensorBoard Launcher\n", 9 | "\n", 10 | "This notebook allows you to start TensorBoard on Perlmutter and view it in a normal browser tab.\n", 11 | "\n", 12 | "The notebook code below assumes you are using the hands-on tutorial path for tensorboard logs.\n", 13 | "\n", 14 | "When you run the cells below, TensorBoard will start but will not display here in the notebook. Instead, the final cell which calls `nersc_tensorboard_helper.tb_address()` will display a URL that you can click to open a new tab with TensorBoard." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "97bcd738-b1f5-40ed-acb0-2b2572f741e2", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import nersc_tensorboard_helper\n", 28 | "\n", 29 | "%load_ext tensorboard" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "2b4a192d-f5a7-4f8a-becb-ae52ed9cc036", 36 | "metadata": { 37 | "tags": [] 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "log_dir = os.path.expandvars('${SCRATCH}/sc23-dl-tutorial/logs')" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "4dd0750e-7f27-4a66-9fce-4be91ea88835", 48 | "metadata": { 49 | "tags": [] 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "%%capture\n", 54 | "%tensorboard --logdir $log_dir --port 0" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "7f4158e2-d17a-40f3-a6e5-961e2615e4c3", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "nersc_tensorboard_helper.tb_address()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "0ad851ac-485a-436c-a88b-356aa0843641", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "pytorch-1.13.1", 81 | "language": "python", 82 | "name": "pytorch-1.13.1" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": "ipython3", 94 | "version": "3.9.15" 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 5 99 | } 100 | -------------------------------------------------------------------------------- /config/ViT.yaml: -------------------------------------------------------------------------------- 1 | base: &base 2 | 3 | # Model config 4 | embed_dim: 384 5 | depth: 12 6 | dropout: 0.0 7 | patch_size: 8 8 | num_heads: 8 9 | 10 | # Training config 11 | img_size: [360, 720] 12 | dt: 1 13 | global_batch_size: 16 # number of samples per training batch 14 | num_iters: 30000 15 | amp_mode: none 16 | enable_fused: false 17 | enable_jit: false 18 | expdir: '/logs' 19 | lr_schedule: 'cosine' 20 | lr: 5E-4 21 | warmup: 0 22 | optimizer: 'Adam' 23 | 24 | # Data 25 | data_loader_config: 'pytorch' 26 | num_data_workers: 0 # number of dataloader worker threads per proc 27 | n_in_channels: 20 28 | n_out_channels: 20 29 | train_data_path: '/data/train' 30 | valid_data_path: '/data/valid' 31 | inf_data_path: '/data/test' 32 | time_means_path: '/data/stats/time_means.npy' 33 | global_means_path: '/data/stats/global_means.npy' 34 | global_stds_path: '/data/stats/global_stds.npy' 35 | limit_nsamples: None 36 | limit_nsamples_val: None 37 | 38 | # Comms 39 | wireup_info: env 40 | wireup_store: tcp 41 | 42 | # limit the number of samples 43 | short: &short_ls 44 | <<: *base 45 | limit_nsamples: 512 46 | limit_nsamples_val: 128 47 | num_iters: 128 48 | 49 | # add optimization flags 50 | short_opt: 51 | <<: *short_ls 52 | data_loader_config: 'dali' 53 | num_data_workers: 8 54 | amp_mode: fp16 55 | enable_jit: true 56 | enable_fused: true 57 | 58 | # no samples limits 59 | opt: &opt 60 | <<: *base 61 | data_loader_config: 'dali' 62 | num_data_workers: 8 63 | amp_mode: fp16 64 | num_iters: 30000 65 | enable_fused: True 66 | enable_apex: True 67 | 68 | # ----- Data parallel scaling configs 69 | bs16_opt: 70 | <<: *opt 71 | global_batch_size: 16 72 | lr: 5e-4 73 | 74 | bs32_opt: 75 | <<: *opt 76 | global_batch_size: 32 77 | lr: 7.07e-4 78 | 79 | bs64_opt: 80 | <<: *opt 81 | global_batch_size: 64 82 | lr: 1e-3 83 | 84 | bs128_opt: 85 | <<: *opt 86 | global_batch_size: 128 87 | lr: 1.41e-3 88 | 89 | bs256_opt: 90 | <<: *opt 91 | global_batch_size: 256 92 | lr: 2e-3 93 | 94 | bs512_opt: 95 | <<: *opt 96 | global_batch_size: 512 97 | lr: 2.83e-3 98 | 99 | bs1024_opt: 100 | <<: *opt 101 | global_batch_size: 1024 102 | lr: 4e-3 103 | 104 | bs2048_opt: 105 | <<: *opt 106 | global_batch_size: 2048 107 | lr: 5.66e-3 108 | 109 | # Model parallel configs 110 | mp: &mp 111 | <<: *base 112 | num_iters: 30000 113 | global_batch_size: 64 114 | lr: 1e-3 115 | num_data_workers: 8 116 | embed_dim: 1024 # change to bigger model 117 | data_loader_config: 'dali' 118 | amp_mode: fp16 119 | enable_jit: true 120 | enable_fused: true 121 | 122 | mp_bs16: 123 | <<: *mp 124 | global_batch_size: 16 125 | lr: 5e-4 126 | 127 | mp_bs32: 128 | <<: *mp 129 | global_batch_size: 32 130 | lr: 7.07e-4 131 | 132 | # larger seq length (use local bs = 1 here) 133 | mp_patch4: 134 | <<: *mp 135 | patch_size: 4 136 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # output logs 163 | *.out 164 | -------------------------------------------------------------------------------- /networks/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import torch 6 | import warnings 7 | 8 | # These functions are directly pulled from timm: 9 | # https://github.com/huggingface/pytorch-image-models/tree/main/timm 10 | 11 | @torch.jit.script 12 | def drop_path(x: torch.Tensor, drop_prob: float = 0., training: bool = False) -> torch.Tensor: 13 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 14 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 15 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 16 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 17 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 18 | 'survival rate' as the argument. 19 | """ 20 | if drop_prob == 0. or not training: 21 | return x 22 | keep_prob = 1. - drop_prob 23 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 24 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 25 | random_tensor.floor_() # binarize 26 | output = x.div(keep_prob) * random_tensor 27 | return output 28 | 29 | 30 | class DropPath(nn.Module): 31 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 32 | """ 33 | def __init__(self, drop_prob=None): 34 | super(DropPath, self).__init__() 35 | self.drop_prob = drop_prob 36 | 37 | def forward(self, x): 38 | return drop_path(x, self.drop_prob, self.training) 39 | 40 | 41 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 42 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 43 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 44 | def norm_cdf(x): # pragma: no cover 45 | # Computes standard normal cumulative distribution function 46 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 47 | 48 | if (mean < a - 2 * std) or (mean > b + 2 * std): 49 | warnings.warn( 50 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 51 | "The distribution of values may be incorrect.", 52 | stacklevel=2, 53 | ) 54 | 55 | with torch.no_grad(): 56 | # Values are generated by using a truncated uniform distribution and 57 | # then using the inverse CDF for the normal distribution. 58 | # Get upper and lower cdf values 59 | l = norm_cdf((a - mean) / std) 60 | u = norm_cdf((b - mean) / std) 61 | 62 | # Uniformly fill tensor with values from [l, u], then translate to 63 | # [2l-1, 2u-1]. 64 | tensor.uniform_(2 * l - 1, 2 * u - 1) 65 | 66 | # Use inverse cdf transform for normal distribution to get truncated 67 | # standard normal 68 | tensor.erfinv_() 69 | 70 | # Transform to proper mean, std 71 | tensor.mul_(std * math.sqrt(2.0)) 72 | tensor.add_(mean) 73 | 74 | # Clamp to ensure it's in the proper range 75 | tensor.clamp_(min=a, max=b) 76 | return tensor 77 | 78 | 79 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # pragma: no cover 80 | r"""Fills the input Tensor with values drawn from a truncated 81 | normal distribution. The values are effectively drawn from the 82 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 83 | with values outside :math:`[a, b]` redrawn until they are within 84 | the bounds. The method used for generating the random values works 85 | best when :math:`a \leq \text{mean} \leq b`. 86 | Args: 87 | tensor: an n-dimensional `torch.Tensor` 88 | mean: the mean of the normal distribution 89 | std: the standard deviation of the normal distribution 90 | a: the minimum cutoff value 91 | b: the maximum cutoff value 92 | """ 93 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 94 | -------------------------------------------------------------------------------- /distributed/helpers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.distributed as dist 4 | from utils import comm 5 | 6 | 7 | def get_memory_format(tensor): 8 | """Helper routine to get the memory format""" 9 | if tensor.is_contiguous(memory_format=torch.channels_last): 10 | return torch.channels_last 11 | else: 12 | return torch.contiguous_format 13 | 14 | def sync_params(model): 15 | """Helper routine to ensure shared weights are the same after initialization""" 16 | with torch.no_grad(): 17 | # distributed sync step 18 | for param in model.parameters(): 19 | if not hasattr(param, "is_shared_mp"): 20 | param.is_shared_mp = ["model"] 21 | 22 | for comm_group in param.is_shared_mp: 23 | if comm.get_size(comm_group) > 1: 24 | tlist = [ 25 | torch.empty_like(param) 26 | for x in range(comm.get_size(comm_group)) 27 | ] 28 | tlist[comm.get_rank(comm_group)] = param 29 | # gather all weights in the comm group 30 | dist.all_gather(tlist, param, group=comm.get_group(comm_group)) 31 | # use weight of rank 0 32 | # important to use copy here otherwise the handle gets detaches from the optimizer 33 | param.copy_(tlist[0]) 34 | 35 | # distributed primitives 36 | def _reduce(input_, use_fp32=True, group=None): 37 | """All-reduce the input tensor across model parallel group.""" 38 | 39 | # Bypass the function if we are using only 1 GPU. 40 | if dist.get_world_size(group=group) == 1: 41 | return input_ 42 | 43 | # All-reduce. 44 | if use_fp32: 45 | dtype = input_.dtype 46 | inputf_ = input_.float().contiguous() 47 | dist.all_reduce(inputf_, group=group) 48 | input_ = inputf_.to(dtype) 49 | else: 50 | input_ = input_.contiguous() 51 | dist.all_reduce(input_, group=group) 52 | 53 | return input_ 54 | 55 | 56 | def split_tensor_along_dim(tensor, dim, num_chunks): 57 | """Helper routine to split a tensor along a given dimension""" 58 | assert ( 59 | dim < tensor.dim() 60 | ), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" 61 | assert ( 62 | tensor.shape[dim] % num_chunks == 0 63 | ), f"Error, cannot split dim {dim} evenly. Dim size is \ 64 | {tensor.shape[dim]} and requested numnber of splits is {num_chunks}" 65 | chunk_size = tensor.shape[dim] // num_chunks 66 | tensor_list = torch.split(tensor, chunk_size, dim=dim) 67 | 68 | return tensor_list 69 | 70 | def _split(input_, dim_, group=None): 71 | """Split the tensor along dim.""" 72 | # get input format 73 | input_format = get_memory_format(input_) 74 | 75 | # Bypass the function if we are using only 1 GPU. 76 | comm_size = dist.get_world_size(group=group) 77 | if comm_size == 1: 78 | return input_ 79 | 80 | # Split along last dimension. 81 | input_list = split_tensor_along_dim(input_, dim_, comm_size) 82 | 83 | # Note: torch.split does not create contiguous tensors by default. 84 | rank = dist.get_rank(group=group) 85 | output = input_list[rank].contiguous(memory_format=input_format) 86 | 87 | return output 88 | 89 | def _gather(input_, dim_, group=None): 90 | """Gather tensors and concatinate along the last dimension.""" 91 | # get input format 92 | input_format = get_memory_format(input_) 93 | 94 | comm_size = dist.get_world_size(group=group) 95 | # Bypass the function if we are using only 1 GPU. 96 | if comm_size == 1: 97 | return input_ 98 | 99 | # sanity checks 100 | assert ( 101 | dim_ < input_.dim() 102 | ), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions." 103 | 104 | # Size and dimension. 105 | comm_rank = dist.get_rank(group=group) 106 | 107 | input_ = input_.contiguous(memory_format=input_format) 108 | tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] 109 | tensor_list[comm_rank] = input_ 110 | dist.all_gather(tensor_list, input_, group=group) 111 | 112 | # Note: torch.cat already creates a contiguous tensor. 113 | output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format) 114 | 115 | return output 116 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import glob 4 | import torch 5 | import random 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Dataset 8 | from torch.utils.data.distributed import DistributedSampler 9 | from torch import Tensor 10 | import h5py 11 | 12 | def worker_init(wrk_id): 13 | np.random.seed(torch.utils.data.get_worker_info().seed%(2**32 - 1)) 14 | 15 | def get_data_loader(params, files_pattern, distributed, train): 16 | dataset = ERA5Dataset(params, files_pattern, train) 17 | 18 | if distributed: 19 | if hasattr(params, 'data_num_shards'): 20 | # this is for model parallelism 21 | assert hasattr(params, 'data_shard_id'), 'please set data_num_shards and data_shard_id' 22 | sampler = DistributedSampler(dataset, shuffle=train, num_replicas=params.data_num_shards, rank=params.data_shard_id) 23 | else: 24 | sampler = DistributedSampler(dataset, shuffle=train) 25 | else: 26 | sampler = None 27 | 28 | 29 | dataloader = DataLoader(dataset, 30 | batch_size=int(params.local_batch_size), 31 | num_workers=params.num_data_workers, 32 | shuffle=(sampler is None), 33 | sampler=sampler, 34 | worker_init_fn=worker_init, 35 | drop_last=True, 36 | # persistent_workers=train, 37 | pin_memory=torch.cuda.is_available()) 38 | 39 | if train: 40 | return dataloader, dataset, sampler 41 | else: 42 | return dataloader, dataset 43 | 44 | class ERA5Dataset(Dataset): 45 | def __init__(self, params, location, train): 46 | self.params = params 47 | self.location = location 48 | self.train = train 49 | self.dt = params.dt 50 | self.n_in_channels = params.n_in_channels 51 | self.n_out_channels = params.n_out_channels 52 | self.normalize = True 53 | self.means = np.load(params.global_means_path)[0] 54 | self.stds = np.load(params.global_stds_path)[0] 55 | self.limit_nsamples = params.limit_nsamples if train else params.limit_nsamples_val 56 | self._get_files_stats() 57 | 58 | def _get_files_stats(self): 59 | self.files_paths = glob.glob(self.location + "/*.h5") 60 | self.files_paths.sort() 61 | self.years = [int(os.path.splitext(os.path.basename(x))[0][-4:]) for x in self.files_paths] 62 | self.n_years = len(self.files_paths) 63 | 64 | with h5py.File(self.files_paths[0], 'r') as _f: 65 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 66 | self.n_samples_per_year = _f['fields'].shape[0] 67 | self.img_shape_x = self.params.img_size[0] 68 | self.img_shape_y = self.params.img_size[1] 69 | assert(self.img_shape_x <= _f['fields'].shape[2] and self.img_shape_y <= _f['fields'].shape[3]), 'image shapes are greater than dataset image shapes' 70 | 71 | self.n_samples_total = self.n_years * self.n_samples_per_year 72 | if self.limit_nsamples is not None: 73 | self.n_samples_total = min(self.n_samples_total, self.limit_nsamples) 74 | logging.info("Overriding total number of samples to: {}".format(self.n_samples_total)) 75 | self.files = [None for _ in range(self.n_years)] 76 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 77 | logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels)) 78 | 79 | def _open_file(self, year_idx): 80 | _file = h5py.File(self.files_paths[year_idx], 'r') 81 | self.files[year_idx] = _file['fields'] 82 | 83 | def __len__(self): 84 | return self.n_samples_total 85 | 86 | def _normalize(self, img): 87 | if self.normalize: 88 | img -= self.means 89 | img /= self.stds 90 | return torch.as_tensor(img) 91 | 92 | def __getitem__(self, global_idx): 93 | year_idx = int(global_idx / self.n_samples_per_year) # which year 94 | local_idx = int(global_idx % self.n_samples_per_year) # which sample in that year 95 | 96 | # open image file 97 | if self.files[year_idx] is None: 98 | self._open_file(year_idx) 99 | step = self.dt # time step 100 | 101 | # boundary conditions to ensure we don't pull data that is not in a specific year 102 | local_idx = local_idx % (self.n_samples_per_year - step) 103 | if local_idx < step: 104 | local_idx += step 105 | 106 | # pre-process and get the image fields 107 | inp_field = self.files[year_idx][local_idx,:,0:self.img_shape_x,0:self.img_shape_y] 108 | tar_field = self.files[year_idx][local_idx+step,:,0:self.img_shape_x,0:self.img_shape_y] 109 | inp, tar = self._normalize(inp_field), self._normalize(tar_field) 110 | 111 | return inp, tar 112 | -------------------------------------------------------------------------------- /utils/dali_es_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import cupyx as cpx 5 | import h5py 6 | import logging 7 | 8 | class ERA5ES(object): 9 | # very important: the seed has to be constant across the workers, or otherwise mayhem: 10 | def __init__(self, location, 11 | train, batch_size, 12 | dt, img_size, 13 | n_in_channels, n_out_channels, 14 | num_shards, 15 | shard_id, 16 | limit_nsamples, 17 | enable_logging=True, 18 | seed=333): 19 | self.batch_size = batch_size 20 | self.location = location 21 | self.img_size = img_size 22 | self.train = train 23 | self.dt = dt 24 | self.n_in_channels = n_in_channels 25 | self.n_out_channels = n_out_channels 26 | self.rng = np.random.default_rng(seed = seed) 27 | self.num_shards = num_shards 28 | self.shard_id = shard_id 29 | self.limit_nsamples = limit_nsamples 30 | 31 | self._get_files_stats(enable_logging) 32 | self.shuffle = True if train else False 33 | 34 | def _get_files_stats(self, enable_logging): 35 | self.files_paths = glob.glob(self.location + "/*.h5") 36 | self.files_paths.sort() 37 | self.years = [int(os.path.splitext(os.path.basename(x))[0][-4:]) for x in self.files_paths] 38 | self.n_years = len(self.files_paths) 39 | 40 | with h5py.File(self.files_paths[0], 'r') as _f: 41 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 42 | self.n_samples_per_year = _f['fields'].shape[0] 43 | self.img_shape_x = self.img_size[0] 44 | self.img_shape_y = self.img_size[1] 45 | assert(self.img_shape_x <= _f['fields'].shape[2] and self.img_shape_y <= _f['fields'].shape[3]), 'image shapes are greater than dataset image shapes' 46 | 47 | self.n_samples_total = self.n_years * self.n_samples_per_year 48 | if self.limit_nsamples is not None: 49 | self.n_samples_total = min(self.n_samples_total, self.limit_nsamples) 50 | logging.info("Overriding total number of samples to: {}".format(self.n_samples_total)) 51 | self.n_samples_shard = self.n_samples_total // self.num_shards 52 | self.files = [None for _ in range(self.n_years)] 53 | self.dsets = [None for _ in range(self.n_years)] 54 | if enable_logging: 55 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 56 | logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels)) 57 | if self.num_shards > 1: 58 | logging.info("Using shards of size {} per rank".format(self.n_samples_shard)) 59 | 60 | # number of steps per epoch 61 | self.num_steps_per_epoch = self.n_samples_shard // self.batch_size 62 | self.last_epoch = None 63 | 64 | self.index_permutation = None 65 | # prepare buffers for double buffering 66 | self.current_buffer = 0 67 | self.inp_buffs = [cpx.zeros_pinned((self.n_in_channels, self.img_shape_x, self.img_shape_y), dtype=np.float32), 68 | cpx.zeros_pinned((self.n_in_channels, self.img_shape_x, self.img_shape_y), dtype=np.float32)] 69 | self.tar_buffs = [cpx.zeros_pinned((self.n_out_channels, self.img_shape_x, self.img_shape_y), dtype=np.float32), 70 | cpx.zeros_pinned((self.n_out_channels, self.img_shape_x, self.img_shape_y), dtype=np.float32)] 71 | 72 | def __len__(self): 73 | return self.n_samples_shard 74 | 75 | def __del__(self): 76 | for f in self.files: 77 | if f is not None: 78 | f.close() 79 | 80 | def __call__(self, sample_info): 81 | # check if epoch is done 82 | if sample_info.iteration >= self.num_steps_per_epoch: 83 | raise StopIteration 84 | 85 | # check if we need to shuffle again 86 | if sample_info.epoch_idx != self.last_epoch: 87 | self.last_epoch = sample_info.epoch_idx 88 | if self.shuffle: 89 | self.index_permutation = self.rng.permutation(self.n_samples_total) 90 | else: 91 | self.index_permutation = np.arange(self.n_samples_total) 92 | # shard the data 93 | start = self.n_samples_shard * self.shard_id 94 | end = start + self.n_samples_shard 95 | self.index_permutation = self.index_permutation[start:end] 96 | 97 | # determine local and sample idx 98 | sample_idx = self.index_permutation[sample_info.idx_in_epoch] 99 | year_idx = int(sample_idx / self.n_samples_per_year) #which year we are on 100 | local_idx = int(sample_idx % self.n_samples_per_year) #which sample in that year we are on - determines indices for centering 101 | 102 | step = self.dt # time step 103 | 104 | # boundary conditions to ensure we don't pull data that is not in a specific year 105 | local_idx = local_idx % (self.n_samples_per_year - step) 106 | if local_idx < step: 107 | local_idx += step 108 | 109 | if self.files[year_idx] is None: 110 | self.files[year_idx] = h5py.File(self.files_paths[year_idx], 'r') 111 | self.dsets[year_idx] = self.files[year_idx]['fields'] 112 | 113 | tmp_inp = self.dsets[year_idx][local_idx, ...] 114 | tmp_tar = self.dsets[year_idx][local_idx+step, ...] 115 | 116 | # handles to buffers buffers 117 | inp = self.inp_buffs[self.current_buffer] 118 | tar = self.tar_buffs[self.current_buffer] 119 | self.current_buffer = (self.current_buffer + 1) % 2 120 | 121 | # crop the pixels: 122 | inp[...] = tmp_inp[..., :self.img_shape_x, :self.img_shape_y] 123 | tar[...] = tmp_tar[..., :self.img_shape_x, :self.img_shape_y] 124 | 125 | 126 | return inp, tar 127 | -------------------------------------------------------------------------------- /distributed/mappings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.parallel import DistributedDataParallel 4 | from utils import comm 5 | 6 | # torch utils 7 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 8 | 9 | # helper functions 10 | from distributed.helpers import _reduce 11 | 12 | 13 | class _CopyToParallelRegion(torch.autograd.Function): 14 | """Pass the input to the parallel region.""" 15 | 16 | @staticmethod 17 | def symbolic(graph, input_, comm_id_): 18 | """symbolic method""" 19 | return input_ 20 | 21 | @staticmethod 22 | def forward(ctx, input_, comm_id_): 23 | ctx.comm_id = comm_id_ 24 | return input_ 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | if comm.is_distributed(ctx.comm_id): 29 | return _reduce(grad_output, group=comm.get_group(ctx.comm_id)), None 30 | else: 31 | return grad_output, None 32 | 33 | 34 | class _ReduceFromParallelRegion(torch.autograd.Function): 35 | """All-reduce the input from the parallel region.""" 36 | 37 | @staticmethod 38 | def symbolic(graph, input_, comm_id_): # pragma: no cover 39 | """symbolic method""" 40 | if comm.is_distributed(comm_id_): 41 | return _reduce(input_, group=comm.get_group(comm_id_)) 42 | else: 43 | return input_ 44 | 45 | @staticmethod 46 | def forward(ctx, input_, comm_id_): # pragma: no cover 47 | if comm.is_distributed(comm_id_): 48 | return _reduce(input_, group=comm.get_group(comm_id_)) 49 | else: 50 | return input_ 51 | 52 | @staticmethod 53 | def backward(ctx, grad_output): # pragma: no cover 54 | return grad_output, None 55 | 56 | 57 | # matmul parallel 58 | def copy_to_parallel_region(input_, comm_name): # pragma: no cover 59 | """Parallel copy helper""" 60 | return _CopyToParallelRegion.apply(input_, comm_name) 61 | 62 | 63 | def reduce_from_parallel_region(input_, comm_name): # pragma: no cover 64 | """Parallel reduction helper""" 65 | return _ReduceFromParallelRegion.apply(input_, comm_name) 66 | 67 | 68 | def gather_from_parallel_region(input_, dim, comm_name): 69 | """Parallel gather helper""" 70 | return _GatherFromParallelRegion.apply(input_, dim, comm_name) 71 | 72 | 73 | def init_ddp_model_and_reduction_hooks(model, 74 | device_ids, 75 | output_device, 76 | bucket_cap_mb = 25, 77 | broadcast_buffers = True, 78 | find_unused_parameters = False, 79 | gradient_as_bucket_view = True, 80 | static_graph = False): 81 | # early exit if we are not in a distributed setting: 82 | if not dist.is_initialized(): 83 | return model 84 | 85 | # set this to false in init and then find out if we can use it: 86 | need_hooks = False 87 | ddp_group = comm.get_group("data") 88 | # this is the trivial case 89 | if comm.get_size("model") == 1: 90 | # the simple case, we can just continue then 91 | ddp_group = None 92 | else: 93 | # count parameters and reduction groups 94 | num_parameters_total = 0 95 | num_parameters_shared_model = 0 96 | for param in model.parameters(): 97 | # # if it does not have any annotation, we assume it is shared between all model ranks 98 | # # not needed here, sync_params annotates everything 99 | # if not hasattr(param, "is_shared_mp"): 100 | # param.is_shared_mp = ["model"] 101 | # add the sharing type to the dict 102 | num_parameters_total += 1 103 | if "model" in param.is_shared_mp: 104 | num_parameters_shared_model += 1 105 | 106 | # if all parameters are shared between all model ranks, then the situation is easy 107 | if (num_parameters_shared_model == num_parameters_total): 108 | # we can always use DDP 109 | ddp_group = None 110 | # register some pre-multiply reduction hooks 111 | print("Setting up gradient hooks to account for shared parameter multiplicity") 112 | for param in model.parameters(): 113 | param.register_hook(lambda grad: grad * float(comm.get_size("model"))) 114 | else: 115 | ddp_group = comm.get_group("data") 116 | broadcast_buffers = False 117 | need_hooks = True 118 | 119 | model = DistributedDataParallel(model, 120 | device_ids = device_ids, 121 | output_device = output_device, 122 | bucket_cap_mb = bucket_cap_mb, 123 | broadcast_buffers = broadcast_buffers, 124 | find_unused_parameters = find_unused_parameters, 125 | gradient_as_bucket_view = gradient_as_bucket_view, 126 | static_graph = static_graph, 127 | process_group = ddp_group) 128 | if not need_hooks: 129 | return model 130 | 131 | # define comm hook: 132 | def reduction_comm_hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]: 133 | # allreduce everything first: 134 | buff = bucket.buffer() 135 | # get future for allreduce 136 | fut = dist.all_reduce(buff, op=dist.ReduceOp.AVG, group=comm.get_group("data"), async_op=True).get_future() 137 | # get grads for shared weights 138 | params = bucket.parameters() 139 | def grad_reduction(fut, grads, group): 140 | # reduce remaining gradients 141 | coalesced = _flatten_dense_tensors(grads) 142 | dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=comm.get_group(group), async_op=False) 143 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 144 | buf.copy_(synced) 145 | return bucket.buffer() 146 | 147 | for group in comm.get_names(): 148 | if group == "data": 149 | continue 150 | grads = [] 151 | for p in params: 152 | if group in p.is_shared_mp: 153 | if p.grad is not None: 154 | grads.append(p.grad.data) 155 | if not grads: 156 | continue 157 | # append the new reduction functions 158 | fut = fut.then(lambda x: grad_reduction(x, grads=grads, group=group)) 159 | 160 | return fut 161 | # register model comm hook 162 | model.register_comm_hook(state=None, hook=reduction_comm_hook) 163 | return model 164 | -------------------------------------------------------------------------------- /utils/data_loader_dali.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 5 | from torch import Tensor 6 | 7 | #concurrent futures 8 | import concurrent.futures as cf 9 | 10 | # distributed stuff 11 | import torch.distributed as dist 12 | from utils import comm 13 | 14 | #dali stuff 15 | from nvidia.dali.pipeline import Pipeline 16 | import nvidia.dali.fn as fn 17 | import nvidia.dali.types as types 18 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy 19 | 20 | # es helper 21 | import utils.dali_es_helper as esh 22 | 23 | def get_data_loader(params, files_pattern, distributed, train): 24 | dataloader = DaliDataLoader(params, files_pattern, train) 25 | 26 | if train: 27 | return dataloader, None, None 28 | else: 29 | return dataloader, None 30 | 31 | class DaliDataLoader(object): 32 | def get_pipeline(self): 33 | pipeline = Pipeline(batch_size = self.batch_size, 34 | num_threads = 2, 35 | device_id = self.device_index, 36 | py_num_workers = self.num_data_workers, 37 | py_start_method='spawn', 38 | seed = self.model_seed) 39 | 40 | 41 | with pipeline: # get input and target 42 | # get input and target 43 | inp, tar = fn.external_source(source = esh.ERA5ES(self.location, 44 | self.train, 45 | self.batch_size, 46 | self.dt, 47 | self.img_size, 48 | self.n_in_channels, 49 | self.n_out_channels, 50 | self.num_shards, 51 | self.shard_id, 52 | self.limit_nsamples, 53 | enable_logging = False, 54 | seed=self.global_seed), 55 | num_outputs = 2, 56 | layout = ["CHW", "CHW"], 57 | batch = False, 58 | no_copy = True, 59 | parallel = True) 60 | 61 | # upload to GPU 62 | inp = inp.gpu() 63 | tar = tar.gpu() 64 | 65 | if self.normalize: 66 | inp = fn.normalize(inp, 67 | device = "gpu", 68 | axis_names = "HW", 69 | batch = False, 70 | mean = self.in_bias, 71 | stddev = self.in_scale) 72 | 73 | tar = fn.normalize(tar, 74 | device = "gpu", 75 | axis_names = "HW", 76 | batch = False, 77 | mean = self.out_bias, 78 | stddev = self.out_scale) 79 | 80 | pipeline.set_outputs(inp, tar) 81 | return pipeline 82 | 83 | def __init__(self, params, location, train, seed = 333): 84 | # set up seeds 85 | # this one is the same on all ranks 86 | self.global_seed = seed 87 | # this one is the same for all ranks of the same model 88 | model_id = comm.get_world_rank() // comm.get_size("model") 89 | self.model_seed = self.global_seed + model_id 90 | # this seed is supposed to be diffferent for every rank 91 | self.local_seed = self.global_seed + comm.get_world_rank() 92 | 93 | self.num_data_workers = params.num_data_workers 94 | self.device_index = torch.cuda.current_device() 95 | self.batch_size = int(params.local_batch_size) 96 | 97 | self.location = location 98 | self.train = train 99 | self.dt = params.dt 100 | self.n_in_channels = params.n_in_channels 101 | self.n_out_channels = params.n_out_channels 102 | self.img_size = params.img_size 103 | self.limit_nsamples = params.limit_nsamples if train else params.limit_nsamples_val 104 | 105 | # load stats 106 | self.normalize = True 107 | means = np.load(params.global_means_path)[0][:self.n_in_channels] 108 | stds = np.load(params.global_stds_path)[0][:self.n_in_channels] 109 | self.in_bias = means 110 | self.in_scale = stds 111 | means = np.load(params.global_means_path)[0][:self.n_out_channels] 112 | stds = np.load(params.global_stds_path)[0][:self.n_out_channels] 113 | self.out_bias = means 114 | self.out_scale = stds 115 | 116 | # set sharding 117 | if dist.is_initialized(): 118 | self.num_shards = params.data_num_shards 119 | self.shard_id = params.data_shard_id 120 | else: 121 | self.num_shards = 1 122 | self.shard_id = 0 123 | 124 | # get img source data 125 | extsource = esh.ERA5ES(self.location, 126 | self.train, 127 | self.batch_size, 128 | self.dt, 129 | self.img_size, 130 | self.n_in_channels, 131 | self.n_out_channels, 132 | self.num_shards, 133 | self.shard_id, 134 | self.limit_nsamples, 135 | seed=self.global_seed) 136 | self.num_batches = extsource.num_steps_per_epoch 137 | del extsource 138 | 139 | # create pipeline 140 | self.pipeline = self.get_pipeline() 141 | self.pipeline.start_py_workers() 142 | self.pipeline.build() 143 | 144 | # create iterator 145 | self.iterator = DALIGenericIterator([self.pipeline], ['inp', 'tar'], 146 | auto_reset = True, 147 | last_batch_policy = LastBatchPolicy.DROP, 148 | prepare_first_batch = True) 149 | 150 | def __len__(self): 151 | return self.num_batches 152 | 153 | def __iter__(self): 154 | #self.iterator.reset() 155 | for token in self.iterator: 156 | inp = token[0]['inp'] 157 | tar = token[0]['tar'] 158 | 159 | yield inp, tar 160 | -------------------------------------------------------------------------------- /distributed/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import comm 5 | 6 | from torch.cuda import amp 7 | 8 | from networks.helpers import trunc_normal_ 9 | 10 | # matmul parallel 11 | from distributed.mappings import copy_to_parallel_region 12 | from distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region 13 | from typing import Tuple 14 | 15 | class DistributedMatmul(nn.Module): 16 | """Distributed Matrix Multiply""" 17 | 18 | def __init__( 19 | self, 20 | inp_dim, 21 | out_dim, 22 | comm_inp_name, 23 | comm_out_name, 24 | bias=True, 25 | ): 26 | super(DistributedMatmul, self).__init__() 27 | 28 | # get sizes 29 | self.comm_inp_name = comm_inp_name 30 | self.comm_out_name = comm_out_name 31 | comm_inp_size = comm.get_size(self.comm_inp_name) 32 | comm_out_size = comm.get_size(self.comm_out_name) 33 | 34 | assert ( 35 | inp_dim % comm_inp_size == 0 36 | ), f"Error, the size of input feature dim ({inp_dim}) has to be evenly divisible by the input feature comm dim ({comm_inp_size})" 37 | assert ( 38 | out_dim % comm_out_size == 0 39 | ), f"Error, the size of output feature dim ({out_dim}) has to be evenly divisible by the output feature comm dim ({comm_out_size})" 40 | 41 | # compute reduced dims 42 | inp_dim_local = inp_dim // comm_inp_size 43 | out_dim_local = out_dim // comm_out_size 44 | 45 | # parameters 46 | # weights are shared on all comm dims other than the ones used (comm_inp_name, comm_out_name) 47 | comm_names_shared = [c for c in comm.get_names(meta=False) if c not in [comm_inp_name, comm_out_name]] 48 | self.weight = nn.Parameter(torch.ones(out_dim_local, inp_dim_local)) 49 | self.weight.is_shared_mp = comm_names_shared 50 | self.weight.sharded_dims_mp = [ 51 | self.comm_out_name, 52 | self.comm_inp_name, 53 | None, 54 | None, 55 | ] 56 | if bias: 57 | self.bias = nn.Parameter(torch.ones(1, 1, out_dim_local)) 58 | self.bias.is_shared_mp = comm_names_shared 59 | self.bias.sharded_dims_mp = [None, self.comm_out_name, None, None] 60 | 61 | # init weights 62 | self._init_weights() 63 | 64 | def _init_weights(self): 65 | trunc_normal_(self.weight, std=0.02) 66 | if hasattr(self, "bias"): 67 | nn.init.constant_(self.bias, 0.0) 68 | 69 | # since this method is full of custom autograd, it cannot be jitted from torch frontend. 70 | @torch.jit.ignore 71 | def forward(self, x): 72 | # print("before matmul, shape = {}".format(x.shape)) 73 | x_cp = copy_to_parallel_region(x, self.comm_out_name) 74 | x_loc = F.linear(x_cp, self.weight, bias=None) 75 | x_out = reduce_from_parallel_region(x_loc, self.comm_inp_name) 76 | if hasattr(self, "bias"): 77 | x_out = x_out + self.bias 78 | # print("after matmul, shape = {}".format(x_out.shape)) 79 | return x_out 80 | 81 | 82 | class DistributedMLP(nn.Module): 83 | """Distributed MLP layer""" 84 | 85 | def __init__( 86 | self, 87 | in_features, 88 | hidden_features=None, 89 | out_features=None, 90 | comm_inp_name="col_matmul", 91 | comm_hidden_name="row_matmul", 92 | act_layer=nn.GELU, 93 | drop=0.0 94 | ): 95 | 96 | super(DistributedMLP, self).__init__() 97 | out_features = out_features or in_features 98 | hidden_features = hidden_features or in_features 99 | 100 | # get effective embedding size: 101 | comm_inp_size = comm.get_size(comm_inp_name) 102 | comm_hid_size = comm.get_size(comm_hidden_name) 103 | 104 | self.fc1 = DistributedMatmul( 105 | in_features, 106 | hidden_features, 107 | comm_inp_name=comm_inp_name, 108 | comm_out_name=comm_hidden_name, 109 | bias=True, 110 | ) 111 | 112 | self.fc2 = DistributedMatmul( 113 | hidden_features, 114 | out_features, 115 | comm_inp_name=comm_hidden_name, 116 | comm_out_name=comm_inp_name, 117 | bias=True, 118 | ) 119 | 120 | self.act = act_layer() 121 | self.drop = nn.Dropout(drop) 122 | 123 | def forward(self, x): 124 | x = self.fc1(x) 125 | x = self.act(x) 126 | x = self.drop(x) 127 | x = self.fc2(x) 128 | x = self.drop(x) 129 | return x 130 | 131 | 132 | class DistributedAttention(nn.Module): 133 | """Distributed Attention layer""" 134 | 135 | def __init__( 136 | self, 137 | dim, 138 | comm_inp_name, 139 | comm_hidden_name, 140 | num_heads=8, 141 | qkv_bias=False, 142 | qk_norm=False, 143 | attn_drop=0., 144 | proj_drop=0., 145 | norm_layer=nn.LayerNorm, 146 | ): 147 | 148 | super(DistributedAttention, self).__init__() 149 | 150 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 151 | self.num_heads = num_heads 152 | 153 | assert num_heads % comm.get_size(comm_hidden_name) == 0, 'heads are not evenly split across model ranks' 154 | self.num_heads_local = num_heads // comm.get_size(comm_hidden_name) 155 | self.head_dim = dim // self.num_heads 156 | self.scale = (dim // self.num_heads) ** -0.5 157 | self.fused_attn = True 158 | 159 | self.comm_inp_name = comm_inp_name 160 | self.comm_hidden_name = comm_hidden_name 161 | 162 | self.qkv = DistributedMatmul(dim, dim * 3, comm_inp_name, comm_hidden_name, bias=qkv_bias) 163 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 164 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 165 | self.attn_drop = nn.Dropout(attn_drop) 166 | self.proj = DistributedMatmul(dim, dim, comm_hidden_name, comm_inp_name, bias=False) 167 | self.proj_drop = nn.Dropout(proj_drop) 168 | 169 | def forward(self, x): 170 | B, N, C = x.shape 171 | 172 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads_local, self.head_dim).permute(2, 0, 3, 1, 4) 173 | q, k, v = qkv.unbind(0) 174 | q, k = self.q_norm(q), self.k_norm(k) 175 | 176 | if self.fused_attn: 177 | x = F.scaled_dot_product_attention( 178 | q, k, v, 179 | dropout_p=self.attn_drop.p, 180 | ) 181 | else: 182 | q = q * self.scale 183 | attn = q @ k.transpose(-2, -1) 184 | attn = attn.softmax(dim=-1) 185 | attn = self.attn_drop(attn) 186 | x = attn @ v 187 | 188 | # transpose back 189 | x = x.transpose(1, 2).reshape(B, N, self.num_heads_local * self.head_dim) 190 | 191 | # this is distributed again 192 | x = self.proj(x) 193 | 194 | # generally we have to be super careful with dropout layers, since 195 | # those are normalized over the dropouts. That would need to be reduced across nodes 196 | x = self.proj_drop(x) 197 | 198 | return x 199 | 200 | -------------------------------------------------------------------------------- /networks/vit.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial 5 | from networks.helpers import DropPath, trunc_normal_ 6 | 7 | # mp stuff 8 | from utils import comm 9 | from distributed.layers import DistributedMatmul, DistributedMLP, DistributedAttention 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 13 | super().__init__() 14 | out_features = out_features or in_features 15 | hidden_features = hidden_features or in_features 16 | self.fc1 = nn.Linear(in_features, hidden_features) 17 | self.act = act_layer() 18 | self.fc2 = nn.Linear(hidden_features, out_features) 19 | self.drop = nn.Dropout(drop) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.drop(x) 25 | x = self.fc2(x) 26 | x = self.drop(x) 27 | return x 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim, 33 | num_heads=8, 34 | qkv_bias=False, 35 | qk_norm=False, 36 | attn_drop=0., 37 | proj_drop=0., 38 | norm_layer=nn.LayerNorm, 39 | ): 40 | super().__init__() 41 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 42 | self.num_heads = num_heads 43 | self.head_dim = dim // num_heads 44 | self.scale = self.head_dim ** -0.5 45 | self.fused_attn = True 46 | 47 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 48 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 49 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 50 | self.attn_drop = nn.Dropout(attn_drop) 51 | self.proj = nn.Linear(dim, dim) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | 54 | def forward(self, x): 55 | B, N, C = x.shape 56 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 57 | q, k, v = qkv.unbind(0) 58 | q, k = self.q_norm(q), self.k_norm(k) 59 | 60 | if self.fused_attn: 61 | x = F.scaled_dot_product_attention( 62 | q, k, v, 63 | dropout_p=self.attn_drop.p, 64 | ) 65 | else: 66 | q = q * self.scale 67 | attn = q @ k.transpose(-2, -1) 68 | attn = attn.softmax(dim=-1) 69 | attn = self.attn_drop(attn) 70 | x = attn @ v 71 | 72 | x = x.transpose(1, 2).reshape(B, N, C) 73 | x = self.proj(x) 74 | x = self.proj_drop(x) 75 | return x 76 | 77 | 78 | class Block(nn.Module): 79 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 80 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, 81 | comm_inp_name="col_matmul", comm_hidden_name="row_matmul"): 82 | super().__init__() 83 | 84 | if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: 85 | self.attn = DistributedAttention( 86 | dim, comm_inp_name=comm_inp_name, comm_hidden_name=comm_hidden_name, 87 | num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 88 | norm_layer=norm_layer) 89 | else: 90 | self.attn = Attention( 91 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, 92 | norm_layer=norm_layer) 93 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 94 | 95 | self.norm1 = norm_layer(dim) 96 | self.norm2 = norm_layer(dim) 97 | 98 | mlp_hidden_dim = int(dim * mlp_ratio) 99 | 100 | # distribute MLP for model parallelism 101 | if (comm.get_size(comm_inp_name) * comm.get_size(comm_hidden_name)) > 1: 102 | self.mlp = DistributedMLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, 103 | comm_inp_name=comm_inp_name, 104 | comm_hidden_name=comm_hidden_name 105 | ) 106 | else: 107 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 108 | 109 | def forward(self, x): 110 | y = self.attn(self.norm1(x)) 111 | x = x + self.drop_path(y) 112 | x = x + self.drop_path(self.mlp(self.norm2(x))) 113 | return x 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """ Image to Patch Embedding 118 | """ 119 | def __init__(self, img_size=[224,224], patch_size=16, in_chans=3, embed_dim=768): 120 | super().__init__() 121 | # grid of patches 122 | self.h = img_size[0] // patch_size 123 | self.w = img_size[1] // patch_size 124 | num_patches = self.h * self.w 125 | self.img_size = img_size 126 | self.patch_size = patch_size 127 | self.num_patches = num_patches 128 | 129 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 130 | 131 | def forward(self, x): 132 | B, C, H, W = x.shape 133 | x = self.proj(x).flatten(2).transpose(1, 2) 134 | return x 135 | 136 | class VisionTransformer(nn.Module): 137 | def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, out_chans=3, embed_dim=768, depth=12, 138 | num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., 139 | drop_path_rate=0., norm_layer=nn.LayerNorm, 140 | comm_inp_name="col_matmul", comm_hidden_name="row_matmul", **kwargs): 141 | super().__init__() 142 | self.num_features = self.embed_dim = embed_dim 143 | self.patch_size = patch_size 144 | self.img_size = img_size 145 | self.out_ch = out_chans 146 | self.drop_rate = drop_rate 147 | self.comm_inp_name = comm_inp_name 148 | self.comm_hidden_name = comm_hidden_name 149 | 150 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim) 151 | num_patches = self.patch_embed.num_patches 152 | 153 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) 154 | self.pos_drop = nn.Dropout(p=drop_rate) 155 | 156 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 157 | 158 | self.blocks = nn.ModuleList([ 159 | Block( 160 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 161 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 162 | comm_inp_name=comm_inp_name, comm_hidden_name=comm_hidden_name) 163 | for i in range(depth)]) 164 | 165 | self.norm = norm_layer(embed_dim) 166 | 167 | self.out_size = self.out_ch * self.patch_size * self.patch_size 168 | 169 | self.head = nn.Linear(embed_dim, self.out_size, bias=False) 170 | 171 | trunc_normal_(self.pos_embed, std=.02) 172 | self.apply(self._init_weights) 173 | 174 | def _init_weights(self, m): 175 | if isinstance(m, nn.Linear): 176 | trunc_normal_(m.weight, std=.02) 177 | if isinstance(m, nn.Linear) and m.bias is not None: 178 | nn.init.constant_(m.bias, 0) 179 | elif isinstance(m, nn.LayerNorm): 180 | nn.init.constant_(m.bias, 0) 181 | nn.init.constant_(m.weight, 1.0) 182 | 183 | def prepare_tokens(self, x): 184 | B, nc, w, h = x.shape 185 | x = self.patch_embed(x) # patch linear embedding 186 | # add positional encoding to each token 187 | x = x + self.pos_embed 188 | return self.pos_drop(x) 189 | 190 | def forward_head(self, x): 191 | B, _, _ = x.shape # B x N x embed_dim 192 | x = x.reshape(B, self.patch_embed.h, self.patch_embed.w, self.embed_dim) 193 | B, h, w, _ = x.shape 194 | 195 | # apply head 196 | x = self.head(x) 197 | x = x.reshape(shape=(B, h, w, self.patch_size, self.patch_size, self.out_ch)) 198 | x = torch.einsum("nhwpqc->nchpwq", x) 199 | x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1])) 200 | 201 | return x 202 | 203 | def forward(self, x): 204 | x = self.prepare_tokens(x) 205 | for blk in self.blocks: 206 | x = blk(x) 207 | x = self.norm(x) 208 | x = self.forward_head(x) 209 | return x 210 | 211 | def ViT(params, **kwargs): 212 | model = VisionTransformer( 213 | img_size=params.img_size, 214 | in_chans=params.n_in_channels, out_chans=params.n_out_channels, 215 | patch_size=params.patch_size, 216 | embed_dim=params.embed_dim, depth=params.depth, num_heads=params.num_heads, mlp_ratio=4, 217 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 218 | drop_path_rate=params.dropout, 219 | drop_rate=params.dropout, 220 | attn_drop_rate=params.dropout, 221 | **kwargs) 222 | return model 223 | 224 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from utils.logging_utils import disable_logging 4 | import torch 5 | import math 6 | import numpy as np 7 | import torch.distributed as dist 8 | import datetime as dt 9 | from typing import Union 10 | 11 | # dummy placeholders 12 | _COMM_LIST = [] 13 | _COMM_NAMES = {} 14 | _COMM_NAMES_META = [] 15 | 16 | # world comm 17 | def get_size(comm_id: Union[str, int]) -> int: 18 | """Returns the size of a specified communicator.""" 19 | if isinstance(comm_id, int): 20 | cid = comm_id 21 | else: 22 | cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) 23 | 24 | if not dist.is_initialized() or (cid >= len(_COMM_LIST)): 25 | return 1 26 | else: 27 | return dist.get_world_size(group=_COMM_LIST[cid]) 28 | 29 | 30 | def get_rank(comm_id: Union[str, int]) -> int: 31 | """Returns the rank of a specified communicator.""" 32 | if isinstance(comm_id, int): 33 | cid = comm_id 34 | else: 35 | cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) 36 | 37 | if not dist.is_initialized() or (cid >= len(_COMM_LIST)): 38 | return 0 39 | else: 40 | return dist.get_rank(group=_COMM_LIST[cid]) 41 | 42 | 43 | def get_group(comm_id: Union[str, int]) -> int: 44 | """Returns the group of a specified communicator.""" 45 | if isinstance(comm_id, int): 46 | cid = comm_id 47 | else: 48 | cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) 49 | 50 | if not dist.is_initialized() or (cid >= len(_COMM_LIST)): 51 | raise IndexError(f"Error, comm with id {comm_id} not available.") 52 | else: 53 | return _COMM_LIST[cid] 54 | 55 | 56 | # specialized routines for world comms 57 | def get_world_size(): 58 | """Returns the world size""" 59 | if not dist.is_initialized(): 60 | return 1 61 | else: 62 | return dist.get_world_size() 63 | 64 | 65 | def get_world_rank(): 66 | """Returns the world rank""" 67 | if not dist.is_initialized(): 68 | return 0 69 | else: 70 | return dist.get_rank() 71 | 72 | 73 | def get_local_rank(): 74 | """Returns the local rank of the current process.""" 75 | if os.getenv("LOCAL_RANK") is not None and False: 76 | # Use PyTorch env var if available 77 | return int(os.getenv("LOCAL_RANK")) 78 | 79 | if not dist.is_initialized(): 80 | return 0 81 | else: 82 | num_gpu = int(os.getenv("NGPU_PER_NODE", torch.cuda.device_count())) 83 | return get_world_rank() % num_gpu 84 | 85 | 86 | def get_names(meta=True): 87 | """Returns the names of all available communicators.""" 88 | if meta: 89 | return _COMM_NAMES 90 | else: 91 | return [c for c,v in _COMM_NAMES.items() if c not in _COMM_NAMES_META] 92 | 93 | 94 | def is_distributed(name: str): 95 | """check if distributed.""" 96 | return name in _COMM_NAMES 97 | 98 | 99 | 100 | def init(params, verbose = False): 101 | init_process_group(info=params.wireup_info, store=params.wireup_store) 102 | 103 | # do individual wireup for model parallel comms: 104 | model_parallel_sizes = params.get("model_parallel_sizes", [1]) 105 | model_parallel_names = params.get("model_parallel_names", ["model"]) 106 | params.model_parallel_size = init_model_parallel_info( 107 | names=model_parallel_names, 108 | sizes=model_parallel_sizes, 109 | verbose=verbose 110 | ) 111 | 112 | def init_process_group(info: str, store: str): 113 | """Initial torch distributed process group based on ``info`` and ``store`` 114 | Uses NCCL 115 | Args: 116 | info: either ``env`` or ``mpi`` 117 | store: either ``file`` or ``tcp`` 118 | 119 | """ 120 | # set up global and local communicator 121 | if info == "env": 122 | world_size = int(os.getenv('WORLD_SIZE', 1)) 123 | world_rank = int(os.getenv('RANK', 0)) 124 | if os.getenv('WORLD_RANK') is not None: 125 | # Use WORLD_RANK if available for backwards compatibility 126 | world_rank = int(os.getenv('WORLD_RANK')) 127 | port = int(os.getenv('MASTER_PORT', 0)) 128 | master_address = os.getenv('MASTER_ADDR') 129 | if os.getenv('MASTER_ADDRESS') is not None: 130 | # Use MASTER_ADDRESS if available for backwards compatibility 131 | master_address = os.getenv('MASTER_ADDRESS') 132 | elif info == "mpi": 133 | import socket 134 | from mpi4py import MPI 135 | mpi_comm = MPI.COMM_WORLD.Dup() 136 | world_size = mpi_comm.Get_size() 137 | world_rank = mpi_comm.Get_rank() 138 | my_host = socket.gethostname() 139 | port = 29500 140 | master_address = None 141 | if world_rank == 0: 142 | master_address_info = socket.getaddrinfo(my_host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP) 143 | master_address = master_address_info[0][-1][0] 144 | master_address = mpi_comm.bcast(master_address, root=0) 145 | os.environ["MASTER_ADDRESS"] = master_address 146 | os.environ["MASTER_PORT"] = str(port) 147 | else: 148 | raise ValueError(f"Error, wireup-info {info} not supported") 149 | 150 | # set local rank to 0 if env var not available 151 | local_rank = int(os.getenv('LOCAL_RANK', 0)) 152 | 153 | if world_size > 1: 154 | with disable_logging(): 155 | if store == "file": 156 | wireup_file_path = os.getenv('WIREUP_FILE_PATH') 157 | store = dist.FileStore(wireup_file_path, world_size) 158 | elif store == "tcp": 159 | # create tcp store 160 | store = dist.TCPStore(host_name = master_address, 161 | port = port, 162 | world_size = world_size, 163 | is_master = (world_rank == 0), 164 | timeout = dt.timedelta(seconds=900)) 165 | else: 166 | store = None 167 | 168 | # initialize process groups 169 | dist.init_process_group(backend = 'nccl', 170 | rank = world_rank, 171 | world_size = world_size, 172 | store = store) 173 | 174 | def init_model_parallel_info(names, sizes, verbose=False): 175 | """Create communicators for model parallelism _COMM_LIST, _COMM_NAMES""" 176 | world_size = get_world_size() 177 | world_rank = get_world_rank() 178 | local_rank = get_local_rank() 179 | 180 | model_parallel_names = names 181 | model_parallel_sizes = sizes 182 | 183 | assert(len(model_parallel_names) == len(model_parallel_sizes)), "Please specify names for your communicators" 184 | model_parallel_size = math.prod(model_parallel_sizes) 185 | 186 | assert ( (world_size % model_parallel_size == 0) ), \ 187 | "Error, please make sure that the product of model parallel ranks evenly divides the total number of ranks" 188 | 189 | # we set this to be orthogonal to the MP groups 190 | # we can play tricks with the ddp_group later, in case if all the weights are shared 191 | data_parallel_size = world_size // model_parallel_size 192 | 193 | # create orthogonal communicators first 194 | global _COMM_LIST 195 | global _COMM_NAMES 196 | 197 | if world_size > 1: 198 | # set up the strides: 199 | model_parallel_sizes_reversed = model_parallel_sizes[::-1] 200 | model_grid = np.reshape(np.arange(0, model_parallel_size), model_parallel_sizes[::-1]) 201 | perm = np.roll(np.arange(0,len(model_parallel_sizes)), 1).tolist() 202 | ranks_lookup = {} 203 | 204 | comm_count = 0 205 | for mpname in model_parallel_names: 206 | base_group = np.reshape(model_grid, (-1, model_grid.shape[-1])) 207 | model_groups = [] 208 | for goffset in range(0, world_size, model_parallel_size): 209 | model_groups += sorted((goffset + base_group).tolist()) 210 | 211 | if verbose and world_rank == 0: 212 | print(f"Creating comm groups for id {mpname}: {model_groups}") 213 | 214 | for grp in model_groups: 215 | if len(grp) > 1: 216 | tmp_group = dist.new_group(ranks = grp) 217 | if world_rank in grp: 218 | _COMM_LIST.append(tmp_group) 219 | _COMM_NAMES[mpname] = comm_count 220 | comm_count += 1 221 | ranks_lookup[mpname] = model_groups 222 | 223 | # go for the next step 224 | model_grid = np.transpose(model_grid, perm) 225 | 226 | # helper routine for creating meta comms 227 | def merge_comms(comm_count, ranks_lookup, comm_name_1, comm_name_2, merge_name): 228 | if ((get_size(comm_name_1) == 1) and (get_size(comm_name_2) > 1)): 229 | if verbose and world_rank == 0: 230 | print(f'Creating comm groups for id {merge_name}: {ranks_lookup[comm_name_2]}') 231 | _COMM_LIST.append(get_group(comm_name_2)) 232 | _COMM_NAMES[merge_name] = comm_count 233 | _COMM_NAMES_META.append(merge_name) 234 | comm_count += 1 235 | elif ((get_size(comm_name_1) > 1) and (get_size(comm_name_2) == 1)): 236 | if verbose and world_rank == 0: 237 | print(f'Creating comm groups for id {merge_name}: {ranks_lookup[comm_name_1]}') 238 | _COMM_LIST.append(get_group(comm_name_1)) 239 | _COMM_NAMES[merge_name] = comm_count 240 | _COMM_NAMES_META.append(merge_name) 241 | comm_count += 1 242 | elif ((get_size(comm_name_1) > 1) and (get_size(comm_name_2) > 1)): 243 | # fuse the lists: 244 | def merge_ranks(list1, list2): 245 | coll = list1 + list2 246 | pooled = [set(subList) for subList in coll] 247 | merging = True 248 | while merging: 249 | merging=False 250 | for i,group in enumerate(pooled): 251 | merged = next((g for g in pooled[i+1:] if g.intersection(group)),None) 252 | if not merged: continue 253 | group.update(merged) 254 | pooled.remove(merged) 255 | merging = True 256 | return [list(x) for x in pooled] 257 | 258 | model_groups = merge_ranks(ranks_lookup[comm_name_1], ranks_lookup[comm_name_2]) 259 | if verbose and world_rank == 0: 260 | print(f'Creating comm groups for id {merge_name}: {model_groups}') 261 | for grp in model_groups: 262 | tmp_group = dist.new_group(ranks = grp) 263 | if world_rank in grp: 264 | _COMM_LIST.append(tmp_group) 265 | _COMM_NAMES[merge_name] = comm_count 266 | _COMM_NAMES_META.append(merge_name) 267 | comm_count += 1 268 | return comm_count 269 | 270 | # # no spatial for now: merge spatial 271 | # comm_count = merge_comms(comm_count, ranks_lookup, "h", "w", "spatial") 272 | 273 | # merge matmul 274 | comm_count = merge_comms(comm_count, ranks_lookup, "row_matmul", "col_matmul", "matmul") 275 | 276 | # now the data and model comm: 277 | model_groups = np.reshape(np.arange(0, world_size), (-1, model_parallel_size)).tolist() 278 | for grp in model_groups: 279 | if len(grp) > 1: 280 | tmp_group = dist.new_group(ranks = grp) 281 | if world_rank in grp: 282 | _COMM_LIST.append(tmp_group) 283 | _COMM_NAMES["model"] = comm_count 284 | _COMM_NAMES_META.append("model") 285 | comm_count += 1 286 | 287 | if data_parallel_size == world_size: 288 | if verbose and world_rank == 0: 289 | print(f"Creating comm groups for id data: {[list(range(0, world_size))]}") 290 | 291 | _COMM_LIST.append(None) 292 | _COMM_NAMES["data"] = comm_count 293 | else: 294 | data_groups = [sorted(list(i)) for i in zip(*model_groups)] 295 | if verbose and world_rank == 0: 296 | print(f"Creating comm groups for id data: {data_groups}") 297 | 298 | for grp in data_groups: 299 | tmp_group = dist.new_group(ranks = grp) 300 | if world_rank in grp: 301 | _COMM_LIST.append(tmp_group) 302 | _COMM_NAMES["data"] = comm_count 303 | _COMM_NAMES_META.append("data") 304 | 305 | # if verbose and world_rank == 0: 306 | # print(f"comm lists are: {_COMM_LIST}") 307 | # print(f"comm names are: {_COMM_NAMES}") 308 | 309 | return model_parallel_size 310 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import numpy as np 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.cuda.amp import autocast, GradScaler 11 | import torch.multiprocessing 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.nn.parallel import DistributedDataParallel 14 | 15 | import logging 16 | from utils import logging_utils 17 | logging_utils.config_logger() 18 | from utils.YParams import YParams 19 | from utils import get_data_loader_distributed 20 | from utils.loss import l2_loss, l2_loss_opt 21 | from utils.metrics import weighted_rmse 22 | from utils.plots import generate_images 23 | from networks import vit 24 | 25 | def train(params, args, local_rank, world_rank, world_size): 26 | # set device and benchmark mode 27 | torch.backends.cudnn.benchmark = True 28 | torch.cuda.set_device(local_rank) 29 | device = torch.device('cuda:%d'%local_rank) 30 | 31 | # get data loader 32 | logging.info('rank %d, begin data loader init'%world_rank) 33 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(params, params.train_data_path, params.distributed, train=True) 34 | val_data_loader, valid_dataset = get_data_loader_distributed(params, params.valid_data_path, params.distributed, train=False) 35 | logging.info('rank %d, data loader initialized'%(world_rank)) 36 | 37 | # create model 38 | model = vit.ViT(params).to(device) 39 | 40 | if params.enable_jit: 41 | model = torch.compile(model) 42 | 43 | if params.amp_dtype == torch.float16: 44 | scaler = GradScaler() 45 | if params.distributed and not args.noddp: 46 | if args.disable_broadcast_buffers: 47 | model = DistributedDataParallel(model, device_ids=[local_rank], 48 | bucket_cap_mb=args.bucket_cap_mb, 49 | broadcast_buffers=False, 50 | gradient_as_bucket_view=True) 51 | else: 52 | model = DistributedDataParallel(model, device_ids=[local_rank], 53 | bucket_cap_mb=args.bucket_cap_mb) 54 | 55 | if params.enable_fused: 56 | optimizer = optim.Adam(model.parameters(), lr = params.lr, fused=True, betas=(0.9, 0.95)) 57 | else: 58 | optimizer = optim.Adam(model.parameters(), lr = params.lr, betas=(0.9, 0.95)) 59 | 60 | if world_rank == 0: 61 | logging.info(model) 62 | 63 | iters = 0 64 | startEpoch = 0 65 | 66 | if params.lr_schedule == 'cosine': 67 | if params.warmup > 0: 68 | lr_scale = lambda x: min((x+1)/params.warmup, 0.5*(1 + np.cos(np.pi*x/params.num_iters))) 69 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 70 | else: 71 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_iters) 72 | else: 73 | scheduler = None 74 | 75 | # select loss function 76 | if params.enable_jit: 77 | loss_func = l2_loss_opt 78 | else: 79 | loss_func = l2_loss 80 | 81 | if world_rank==0: 82 | logging.info("Starting Training Loop...") 83 | 84 | # Log initial loss on train and validation to tensorboard 85 | with torch.no_grad(): 86 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 87 | gen = model(inp) 88 | tr_loss = loss_func(gen, tar) 89 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 90 | gen = model(inp) 91 | val_loss = loss_func(gen, tar) 92 | val_rmse = weighted_rmse(gen, tar) 93 | if params.distributed: 94 | torch.distributed.all_reduce(tr_loss) 95 | torch.distributed.all_reduce(val_loss) 96 | torch.distributed.all_reduce(val_rmse) 97 | if world_rank==0: 98 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item()/world_size, 0) 99 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item()/world_size, 0) 100 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0]/world_size, 0) 101 | 102 | params.num_epochs = params.num_iters//len(train_data_loader) 103 | iters = 0 104 | t1 = time.time() 105 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 106 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 107 | if params.distributed and (train_sampler is not None): 108 | train_sampler.set_epoch(epoch) 109 | start = time.time() 110 | tr_loss = [] 111 | tr_time = 0. 112 | dat_time = 0. 113 | log_time = 0. 114 | 115 | model.train() 116 | step_count = 0 117 | for i, data in enumerate(train_data_loader, 0): 118 | if world_rank == 0: 119 | if (epoch == 3 and i == 0): 120 | torch.cuda.profiler.start() 121 | if (epoch == 3 and i == len(train_data_loader) - 1): 122 | torch.cuda.profiler.stop() 123 | 124 | torch.cuda.nvtx.range_push(f"step {i}") 125 | iters += 1 126 | dat_start = time.time() 127 | torch.cuda.nvtx.range_push(f"data copy in {i}") 128 | 129 | inp, tar = map(lambda x: x.to(device), data) 130 | torch.cuda.nvtx.range_pop() # copy in 131 | 132 | tr_start = time.time() 133 | b_size = inp.size(0) 134 | 135 | optimizer.zero_grad() 136 | 137 | torch.cuda.nvtx.range_push(f"forward") 138 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 139 | gen = model(inp) 140 | loss = loss_func(gen, tar) 141 | torch.cuda.nvtx.range_pop() #forward 142 | 143 | if params.amp_dtype == torch.float16: 144 | scaler.scale(loss).backward() 145 | torch.cuda.nvtx.range_push(f"optimizer") 146 | scaler.step(optimizer) 147 | torch.cuda.nvtx.range_pop() # optimizer 148 | scaler.update() 149 | else: 150 | loss.backward() 151 | torch.cuda.nvtx.range_push(f"optimizer") 152 | optimizer.step() 153 | torch.cuda.nvtx.range_pop() # optimizer 154 | 155 | if params.distributed: 156 | torch.distributed.all_reduce(loss) 157 | tr_loss.append(loss.item()/world_size) 158 | 159 | torch.cuda.nvtx.range_pop() # step 160 | # lr step 161 | scheduler.step() 162 | 163 | tr_end = time.time() 164 | tr_time += tr_end - tr_start 165 | dat_time += tr_start - dat_start 166 | step_count += 1 167 | 168 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 169 | end = time.time() 170 | 171 | if world_rank==0: 172 | iters_per_sec = step_count / (end - start) 173 | samples_per_sec = params["global_batch_size"] * iters_per_sec 174 | logging.info('Time taken for epoch %i is %f sec, avg %f samples/sec', 175 | epoch + 1, end - start, samples_per_sec) 176 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 177 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 178 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 179 | args.tboard_writer.add_scalar('Avg iters per sec', iters_per_sec, iters) 180 | args.tboard_writer.add_scalar('Avg samples per sec', samples_per_sec, iters) 181 | fig = generate_images([inp, tar, gen]) 182 | args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True) 183 | 184 | val_start = time.time() 185 | val_loss = torch.zeros(1, device=device) 186 | val_rmse = torch.zeros((params.n_out_channels), dtype=torch.float32, device=device) 187 | valid_steps = 0 188 | model.eval() 189 | 190 | with torch.inference_mode(): 191 | with torch.no_grad(): 192 | for i, data in enumerate(val_data_loader, 0): 193 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 194 | inp, tar = map(lambda x: x.to(device), data) 195 | gen = model(inp) 196 | val_loss += loss_func(gen, tar) 197 | val_rmse += weighted_rmse(gen, tar) 198 | valid_steps += 1 199 | 200 | if params.distributed: 201 | torch.distributed.all_reduce(val_loss) 202 | val_loss /= world_size 203 | torch.distributed.all_reduce(val_rmse) 204 | val_rmse /= world_size 205 | 206 | val_rmse /= valid_steps # Avg validation rmse 207 | val_loss /= valid_steps 208 | val_end = time.time() 209 | if world_rank==0: 210 | logging.info(' Avg val loss={}'.format(val_loss.item())) 211 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 212 | args.tboard_writer.add_scalar('Loss/valid', val_loss, iters) 213 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], iters) 214 | args.tboard_writer.flush() 215 | 216 | t2 = time.time() 217 | tottime = t2 - t1 218 | 219 | 220 | if __name__ == '__main__': 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument("--run_num", default='00', type=str, help='tag for indexing the current experiment') 223 | parser.add_argument("--yaml_config", default='./config/ViT.yaml', type=str, help='path to yaml file containing training configs') 224 | parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file') 225 | parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode') 226 | parser.add_argument("--enable_fused", action='store_true', help='enable fused Adam optimizer') 227 | parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation') 228 | parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)') 229 | parser.add_argument("--num_iters", default=None, type=int, help='number of iters to run') 230 | parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader') 231 | parser.add_argument("--data_loader_config", default=None, type=str, choices=['pytorch', 'dali'], help="dataloader configuration. choices: 'pytorch', 'dali'") 232 | parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb') 233 | parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers') 234 | parser.add_argument("--noddp", action='store_true', help='disable DDP communication') 235 | args = parser.parse_args() 236 | 237 | run_num = args.run_num 238 | 239 | params = YParams(os.path.abspath(args.yaml_config), args.config) 240 | 241 | # Update config with modified args 242 | # set up amp 243 | if args.amp_mode != 'none': 244 | params.update({"amp_mode": args.amp_mode}) 245 | amp_dtype = torch.float32 246 | if params.amp_mode == "fp16": 247 | amp_dtype = torch.float16 248 | elif params.amp_mode == "bf16": 249 | amp_dtype = torch.bfloat16 250 | params.update({"amp_enabled": amp_dtype is not torch.float32, 251 | "amp_dtype" : amp_dtype, 252 | "enable_fused" : args.enable_fused, 253 | "enable_jit" : args.enable_jit 254 | }) 255 | 256 | if args.data_loader_config: 257 | params.update({"data_loader_config" : args.data_loader_config}) 258 | 259 | if args.num_iters: 260 | params.update({"num_iters" : args.num_iters}) 261 | 262 | if args.num_data_workers: 263 | params.update({"num_data_workers" : args.num_data_workers}) 264 | 265 | params.distributed = False 266 | if 'WORLD_SIZE' in os.environ: 267 | params.distributed = int(os.environ['WORLD_SIZE']) > 1 268 | world_size = int(os.environ['WORLD_SIZE']) 269 | else: 270 | world_size = 1 271 | 272 | world_rank = 0 273 | local_rank = 0 274 | if params.distributed: 275 | torch.distributed.init_process_group(backend='nccl', 276 | init_method='env://') 277 | world_rank = torch.distributed.get_rank() 278 | local_rank = int(os.environ['LOCAL_RANK']) 279 | 280 | if args.local_batch_size: 281 | # Manually override batch size 282 | params.local_batch_size = args.local_batch_size 283 | params.update({"global_batch_size" : world_size*args.local_batch_size}) 284 | else: 285 | # Compute local batch size based on number of ranks 286 | params.local_batch_size = params.global_batch_size//world_size 287 | 288 | # for dali data loader, set the actual number of data shards and id 289 | params.data_num_shards = world_size 290 | params.data_shard_id = world_rank 291 | 292 | # Set up directory 293 | baseDir = params.expdir 294 | expDir = os.path.join(baseDir, args.config + '/%dGPU/'%(world_size) + str(run_num) + '/') 295 | if world_rank==0: 296 | if not os.path.isdir(expDir): 297 | os.makedirs(expDir) 298 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 299 | params.log() 300 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 301 | 302 | params.experiment_dir = os.path.abspath(expDir) 303 | 304 | train(params, args, local_rank, world_rank, world_size) 305 | 306 | if params.distributed: 307 | torch.distributed.barrier() 308 | logging.info('DONE ---- rank %d'%world_rank) 309 | 310 | -------------------------------------------------------------------------------- /train_mp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import numpy as np 5 | import argparse 6 | import pynvml 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.cuda.amp import autocast, GradScaler 12 | import torch.multiprocessing 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.nn.parallel import DistributedDataParallel 15 | from torch.distributed import ReduceOp 16 | 17 | import logging 18 | from utils import logging_utils 19 | logging_utils.config_logger() 20 | from utils.YParams import YParams 21 | from utils import get_data_loader_distributed 22 | from utils import comm 23 | from utils.loss import l2_loss, l2_loss_opt 24 | from utils.metrics import weighted_rmse 25 | from networks import vit 26 | 27 | from distributed.mappings import init_ddp_model_and_reduction_hooks 28 | from distributed.helpers import sync_params 29 | 30 | from utils.plots import generate_images 31 | 32 | def train(params, args, local_rank, world_rank, world_size): 33 | # set device and benchmark mode 34 | torch.backends.cudnn.benchmark = True 35 | torch.cuda.set_device(local_rank) 36 | device = torch.device('cuda:%d'%local_rank) 37 | # torch.autograd.set_detect_anomaly(True) 38 | 39 | # init pynvml and get handle 40 | pynvml.nvmlInit() 41 | nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) 42 | 43 | # get data loader 44 | logging.info('rank %d, begin data loader init'%world_rank) 45 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(params, params.train_data_path, params.distributed, train=True) 46 | val_data_loader, valid_dataset = get_data_loader_distributed(params, params.valid_data_path, params.distributed, train=False) 47 | logging.info('rank %d, data loader initialized'%(world_rank)) 48 | 49 | # create model 50 | model = vit.ViT(params).to(device) 51 | 52 | if params.enable_jit: 53 | model = torch.compile(model) 54 | 55 | if params.amp_dtype == torch.float16: 56 | scaler = GradScaler() 57 | 58 | # weight initialization needs to be synced across shared weights 59 | if comm.get_size("model") > 1: 60 | sync_params(model) 61 | 62 | if params.distributed and not args.noddp: 63 | model = init_ddp_model_and_reduction_hooks(model, device_ids=[local_rank], 64 | output_device=[local_rank], 65 | bucket_cap_mb=args.bucket_cap_mb) 66 | 67 | 68 | if params.enable_fused: 69 | optimizer = optim.Adam(model.parameters(), lr = params.lr, fused=True, betas=(0.9, 0.95)) 70 | else: 71 | optimizer = optim.Adam(model.parameters(), lr = params.lr, betas=(0.9, 0.95)) 72 | 73 | if world_rank == 0: 74 | logging.info(model) 75 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 76 | logging.info(f"Scaffolding memory high watermark: {all_mem_gb} GB.") 77 | 78 | iters = 0 79 | startEpoch = 0 80 | 81 | if params.lr_schedule == 'cosine': 82 | if params.warmup > 0: 83 | lr_scale = lambda x: min((x+1)/params.warmup, 0.5*(1 + np.cos(np.pi*x/params.num_iters))) 84 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 85 | else: 86 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_iters) 87 | else: 88 | scheduler = None 89 | 90 | # select loss function 91 | if params.enable_jit: 92 | loss_func = l2_loss_opt 93 | else: 94 | loss_func = l2_loss 95 | 96 | if world_rank==0: 97 | logging.info("Starting Training Loop...") 98 | 99 | # Log initial loss on train and validation to tensorboard 100 | with torch.no_grad(): 101 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 102 | gen = model(inp) 103 | tr_loss = loss_func(gen, tar) 104 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 105 | gen = model(inp) 106 | val_loss = loss_func(gen, tar) 107 | val_rmse = weighted_rmse(gen, tar) 108 | if params.distributed: 109 | torch.distributed.all_reduce(tr_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 110 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 111 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("data")) 112 | if world_rank==0: 113 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item(), 0) 114 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item(), 0) 115 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], 0) 116 | 117 | params.num_epochs = params.num_iters//len(train_data_loader) 118 | iters = 0 119 | t1 = time.time() 120 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 121 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 122 | if params.distributed and (train_sampler is not None): 123 | train_sampler.set_epoch(epoch) 124 | start = time.time() 125 | tr_loss = [] 126 | tr_time = 0. 127 | dat_time = 0. 128 | log_time = 0. 129 | 130 | model.train() 131 | step_count = 0 132 | 133 | for i, data in enumerate(train_data_loader, 0): 134 | if world_rank == 0: 135 | if (epoch == 3 and i == 0): 136 | torch.cuda.profiler.start() 137 | if (epoch == 3 and i == len(train_data_loader) - 1): 138 | torch.cuda.profiler.stop() 139 | 140 | torch.cuda.nvtx.range_push(f"step {i}") 141 | iters += 1 142 | dat_start = time.time() 143 | torch.cuda.nvtx.range_push(f"data copy in {i}") 144 | 145 | inp, tar = map(lambda x: x.to(device), data) 146 | torch.cuda.nvtx.range_pop() # copy in 147 | 148 | tr_start = time.time() 149 | b_size = inp.size(0) 150 | 151 | optimizer.zero_grad() 152 | 153 | torch.cuda.nvtx.range_push(f"forward") 154 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 155 | gen = model(inp) 156 | loss = loss_func(gen, tar) 157 | torch.cuda.nvtx.range_pop() #forward 158 | 159 | if world_rank == 0 and i == 1: # print the mem used 160 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 161 | logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.") 162 | 163 | if params.amp_dtype == torch.float16: 164 | scaler.scale(loss).backward() 165 | torch.cuda.nvtx.range_push(f"optimizer") 166 | scaler.step(optimizer) 167 | torch.cuda.nvtx.range_pop() # optimizer 168 | scaler.update() 169 | else: 170 | loss.backward() 171 | torch.cuda.nvtx.range_push(f"optimizer") 172 | optimizer.step() 173 | torch.cuda.nvtx.range_pop() # optimizer 174 | 175 | if params.distributed: 176 | torch.distributed.all_reduce(loss, op=ReduceOp.AVG, group=comm.get_group("data")) 177 | tr_loss.append(loss.item()) 178 | 179 | torch.cuda.nvtx.range_pop() # step 180 | # lr step 181 | scheduler.step() 182 | 183 | tr_end = time.time() 184 | tr_time += tr_end - tr_start 185 | dat_time += tr_start - dat_start 186 | step_count += 1 187 | 188 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 189 | end = time.time() 190 | 191 | if world_rank==0: 192 | iters_per_sec = step_count / (end - start) 193 | samples_per_sec = params["global_batch_size"] * iters_per_sec 194 | logging.info('Time taken for epoch %i is %f sec, avg %f samples/sec', 195 | epoch + 1, end - start, samples_per_sec) 196 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 197 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 198 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 199 | args.tboard_writer.add_scalar('Avg iters per sec', iters_per_sec, iters) 200 | args.tboard_writer.add_scalar('Avg samples per sec', samples_per_sec, iters) 201 | fig = generate_images([inp, tar, gen]) 202 | args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True) 203 | 204 | val_start = time.time() 205 | val_loss = torch.zeros(1, device=device) 206 | val_rmse = torch.zeros((params.n_out_channels), dtype=torch.float32, device=device) 207 | valid_steps = 0 208 | model.eval() 209 | 210 | with torch.inference_mode(): 211 | with torch.no_grad(): 212 | for i, data in enumerate(val_data_loader, 0): 213 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 214 | inp, tar = map(lambda x: x.to(device), data) 215 | gen = model(inp) 216 | val_loss += loss_func(gen, tar) 217 | val_rmse += weighted_rmse(gen, tar) 218 | valid_steps += 1 219 | 220 | if params.distributed: 221 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 222 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("data")) 223 | 224 | val_rmse /= valid_steps # Avg validation rmse 225 | val_loss /= valid_steps 226 | val_end = time.time() 227 | if world_rank==0: 228 | logging.info(' Avg val loss={}'.format(val_loss.item())) 229 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 230 | args.tboard_writer.add_scalar('Loss/valid', val_loss, iters) 231 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], iters) 232 | args.tboard_writer.flush() 233 | 234 | torch.cuda.synchronize() 235 | t2 = time.time() 236 | tottime = t2 - t1 237 | pynvml.nvmlShutdown() 238 | 239 | 240 | if __name__ == '__main__': 241 | parser = argparse.ArgumentParser() 242 | parser.add_argument("--run_num", default='00', type=str, help='tag for indexing the current experiment') 243 | parser.add_argument("--yaml_config", default='./config/ViT.yaml', type=str, help='path to yaml file containing training configs') 244 | parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file') 245 | parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode') 246 | parser.add_argument("--enable_fused", action='store_true', help='enable fused Adam optimizer') 247 | parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation') 248 | parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)') 249 | parser.add_argument("--num_iters", default=None, type=int, help='number of iters to run') 250 | parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader') 251 | parser.add_argument("--data_loader_config", default=None, type=str, choices=['pytorch', 'dali'], help="dataloader configuration. choices: 'pytorch', 'dali'") 252 | parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb') 253 | parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers') 254 | parser.add_argument("--noddp", action='store_true', help='disable DDP communication') 255 | 256 | # model parallelism arguments 257 | parser.add_argument("--row_parallel_size", default=1, type=int, help="Number of row comms") 258 | parser.add_argument("--col_parallel_size", default=1, type=int, help="Number of col comms") # not used here 259 | 260 | args = parser.parse_args() 261 | 262 | run_num = args.run_num 263 | 264 | params = YParams(os.path.abspath(args.yaml_config), args.config) 265 | 266 | # Update config with modified args 267 | # set up amp 268 | if args.amp_mode != 'none': 269 | params.update({"amp_mode": args.amp_mode}) 270 | amp_dtype = torch.float32 271 | if params.amp_mode == "fp16": 272 | amp_dtype = torch.float16 273 | elif params.amp_mode == "bf16": 274 | amp_dtype = torch.bfloat16 275 | params.update({"amp_enabled": amp_dtype is not torch.float32, 276 | "amp_dtype" : amp_dtype, 277 | "enable_fused" : args.enable_fused, 278 | "enable_jit" : args.enable_jit 279 | }) 280 | 281 | if args.data_loader_config: 282 | params.update({"data_loader_config" : args.data_loader_config}) 283 | 284 | if args.num_iters: 285 | params.update({"num_iters" : args.num_iters}) 286 | 287 | if args.num_data_workers: 288 | params.update({"num_data_workers" : args.num_data_workers}) 289 | 290 | params.distributed = False 291 | 292 | # setup model parallel sizes 293 | # we do not use col parallel size for this tutorial, but leave it in 294 | # so that an interested user can begin to extend 295 | assert ( 296 | args.col_parallel_size == 1 297 | ), f"col_parallel_size is not used in this example, please set to 1." 298 | 299 | 300 | params["model_parallel_sizes"] = [ 301 | args.row_parallel_size, 302 | args.col_parallel_size 303 | ] 304 | params["model_parallel_names"] = ["row_matmul", "col_matmul"] 305 | 306 | # initialize comm 307 | comm.init(params, verbose=True) 308 | 309 | # get info from comm 310 | world_size = comm.get_world_size() 311 | world_rank = comm.get_world_rank() 312 | local_rank = comm.get_local_rank() 313 | params.distributed = (world_size > 1) 314 | 315 | assert ( 316 | params["global_batch_size"] % comm.get_size("data") == 0 317 | ), f"Error, cannot evenly distribute {params['global_batch_size']} across {comm.get_size('data')} GPU." 318 | 319 | if args.local_batch_size: 320 | # Manually override batch size 321 | params.local_batch_size = args.local_batch_size 322 | params.update({"global_batch_size" : comm.get_size("data") * args.local_batch_size}) 323 | else: 324 | # Compute local batch size based on number of ranks 325 | params.local_batch_size = int(params["global_batch_size"] // comm.get_size("data")) 326 | 327 | # for data loader, set the actual number of data shards and id 328 | params.data_num_shards = comm.get_size("data") 329 | params.data_shard_id = comm.get_rank("data") 330 | 331 | # Set up directory 332 | baseDir = params.expdir 333 | expDir = os.path.join(baseDir, args.config + '/%dMP/'%(comm.get_size("model")) + str(run_num) + '/') 334 | if world_rank==0: 335 | if not os.path.isdir(expDir): 336 | os.makedirs(expDir) 337 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 338 | params.log() 339 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 340 | 341 | params.experiment_dir = os.path.abspath(expDir) 342 | 343 | train(params, args, local_rank, world_rank, world_size) 344 | 345 | if params.distributed: 346 | torch.distributed.barrier() 347 | logging.info('DONE ---- rank %d'%world_rank) 348 | 349 | -------------------------------------------------------------------------------- /train_mp_graphs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import gc 5 | import numpy as np 6 | import argparse 7 | import pynvml 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.cuda.amp import autocast, GradScaler 13 | import torch.multiprocessing 14 | from torch.utils.tensorboard import SummaryWriter 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.distributed import ReduceOp 17 | 18 | import logging 19 | from utils import logging_utils 20 | logging_utils.config_logger() 21 | from utils.YParams import YParams 22 | from utils import get_data_loader_distributed 23 | from utils import comm 24 | from utils.loss import l2_loss, l2_loss_opt 25 | from utils.metrics import weighted_rmse 26 | from networks import vit 27 | 28 | from distributed.mappings import init_ddp_model_and_reduction_hooks 29 | from distributed.helpers import sync_params 30 | 31 | from utils.plots import generate_images 32 | 33 | 34 | def capture_model(params, model, loss_func, scaler, capture_stream, device, num_warmup=20): 35 | logging.info("Capturing Model") 36 | inp_shape = (params.local_batch_size, params.n_in_channels, params.img_size[0], params.img_size[1]) 37 | tar_shape = (params.local_batch_size, params.n_in_channels, params.img_size[0], params.img_size[1]) 38 | 39 | # no zeros because loss is rel loss 40 | static_input = torch.randn(inp_shape, dtype=torch.float32, device=device) 41 | static_label = torch.randn(tar_shape, dtype=torch.float32, device=device) 42 | 43 | capture_stream.wait_stream(torch.cuda.current_stream()) 44 | with torch.cuda.stream(capture_stream): 45 | for _ in range(num_warmup): 46 | model.zero_grad(set_to_none=True) 47 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 48 | static_output = model(static_input).to(device) 49 | static_loss = loss_func(static_output, static_label) 50 | 51 | if params.amp_dtype == torch.float16: 52 | scaler.scale(static_loss).backward() 53 | else: 54 | static_loss.backward() 55 | 56 | # sync here 57 | capture_stream.synchronize() 58 | 59 | gc.collect() 60 | torch.cuda.empty_cache() 61 | 62 | # create graph 63 | graph = torch.cuda.CUDAGraph() 64 | 65 | # zero grads before capture: 66 | model.zero_grad(set_to_none=True) 67 | 68 | # do the capture with the context manager: 69 | with torch.cuda.graph(graph): 70 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 71 | static_output = model(static_input).to(device) 72 | static_loss = loss_func(static_output, static_label) 73 | 74 | if params.amp_dtype == torch.float16: 75 | scaler.scale(static_loss).backward() 76 | else: 77 | static_loss.backward() 78 | 79 | torch.cuda.current_stream().wait_stream(capture_stream) 80 | 81 | return graph, static_input, static_output, static_label, static_loss 82 | 83 | def train(params, args, local_rank, world_rank, world_size): 84 | # set device and benchmark mode 85 | torch.backends.cudnn.benchmark = True 86 | torch.cuda.set_device(local_rank) 87 | device = torch.device('cuda:%d'%local_rank) 88 | 89 | # init pynvml and get handle 90 | pynvml.nvmlInit() 91 | nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) 92 | 93 | # get data loader 94 | logging.info('rank %d, begin data loader init'%world_rank) 95 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(params, params.train_data_path, params.distributed, train=True) 96 | val_data_loader, valid_dataset = get_data_loader_distributed(params, params.valid_data_path, params.distributed, train=False) 97 | logging.info('rank %d, data loader initialized'%(world_rank)) 98 | 99 | # create model 100 | model = vit.ViT(params).to(device) 101 | 102 | if params.enable_jit: 103 | model = torch.compile(model) 104 | 105 | if params.amp_dtype == torch.float16: 106 | scaler = GradScaler() 107 | else: 108 | scaler = None 109 | 110 | # weight initialization needs to be synced across shared weights 111 | if comm.get_size("model") > 1: 112 | sync_params(model) 113 | 114 | capture_stream = torch.cuda.Stream() 115 | if params.distributed: 116 | with torch.cuda.stream(capture_stream): 117 | model = init_ddp_model_and_reduction_hooks(model, device_ids=[local_rank], 118 | output_device=[local_rank], 119 | bucket_cap_mb=args.bucket_cap_mb) 120 | capture_stream.synchronize() 121 | 122 | 123 | if world_rank == 0: 124 | logging.info(model) 125 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 126 | logging.info(f"Scaffolding memory high watermark: {all_mem_gb} GB.") 127 | 128 | if params.enable_fused: 129 | optimizer = optim.Adam(model.parameters(), lr = params.lr, fused=True, betas=(0.9, 0.95)) 130 | else: 131 | optimizer = optim.Adam(model.parameters(), lr = params.lr, betas=(0.9, 0.95)) 132 | 133 | iters = 0 134 | startEpoch = 0 135 | 136 | if params.lr_schedule == 'cosine': 137 | if params.warmup > 0: 138 | lr_scale = lambda x: min((x+1)/params.warmup, 0.5*(1 + np.cos(np.pi*x/params.num_iters))) 139 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 140 | else: 141 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_iters) 142 | else: 143 | scheduler = None 144 | 145 | # select loss function 146 | if params.enable_jit: 147 | loss_func = l2_loss_opt 148 | else: 149 | loss_func = l2_loss 150 | 151 | # capture the model 152 | graph, static_input, static_output, static_label, static_loss = capture_model(params, model, loss_func, scaler, 153 | capture_stream, device, num_warmup=20) 154 | 155 | if world_rank==0: 156 | logging.info("Starting Training Loop...") 157 | 158 | # Log initial loss on train and validation to tensorboard 159 | with torch.no_grad(): 160 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 161 | gen = model(inp) 162 | tr_loss = loss_func(gen, tar) 163 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 164 | gen = model(inp) 165 | val_loss = loss_func(gen, tar) 166 | val_rmse = weighted_rmse(gen, tar) 167 | if params.distributed: 168 | torch.distributed.all_reduce(tr_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 169 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 170 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("data")) 171 | if world_rank==0: 172 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item(), 0) 173 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item(), 0) 174 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], 0) 175 | 176 | params.num_epochs = params.num_iters//len(train_data_loader) 177 | iters = 0 178 | t1 = time.time() 179 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 180 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 181 | if params.distributed and (train_sampler is not None): 182 | train_sampler.set_epoch(epoch) 183 | start = time.time() 184 | tr_loss = [] 185 | tr_time = 0. 186 | dat_time = 0. 187 | log_time = 0. 188 | 189 | model.train() 190 | step_count = 0 191 | 192 | for i, data in enumerate(train_data_loader, 0): 193 | if world_rank == 0: 194 | if (epoch == 3 and i == 0): 195 | torch.cuda.profiler.start() 196 | if (epoch == 3 and i == len(train_data_loader) - 1): 197 | torch.cuda.profiler.stop() 198 | 199 | torch.cuda.nvtx.range_push(f"step {i}") 200 | iters += 1 201 | dat_start = time.time() 202 | torch.cuda.nvtx.range_push(f"data copy in {i}") 203 | 204 | inp, tar = map(lambda x: x.to(device), data) 205 | torch.cuda.nvtx.range_pop() # copy in 206 | 207 | tr_start = time.time() 208 | b_size = inp.size(0) 209 | 210 | static_input.copy_(inp) 211 | static_label.copy_(tar) 212 | graph.replay() 213 | 214 | if world_rank == 0 and i == 1: 215 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 216 | logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.") 217 | 218 | if params.amp_dtype == torch.float16: 219 | torch.cuda.nvtx.range_push(f"optimizer") 220 | scaler.step(optimizer) 221 | torch.cuda.nvtx.range_pop() # optimizer 222 | scaler.update() 223 | else: 224 | torch.cuda.nvtx.range_push(f"optimizer") 225 | optimizer.step() 226 | torch.cuda.nvtx.range_pop() # optimizer 227 | 228 | 229 | if params.distributed: 230 | torch.distributed.all_reduce(static_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 231 | tr_loss.append(static_loss.item()) 232 | 233 | torch.cuda.nvtx.range_pop() # step 234 | # lr step 235 | scheduler.step() 236 | 237 | tr_end = time.time() 238 | tr_time += tr_end - tr_start 239 | dat_time += tr_start - dat_start 240 | step_count += 1 241 | 242 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 243 | end = time.time() 244 | 245 | if world_rank==0: 246 | iters_per_sec = step_count / (end - start) 247 | samples_per_sec = params["global_batch_size"] * iters_per_sec 248 | logging.info('Time taken for epoch %i is %f sec, avg %f samples/sec', 249 | epoch + 1, end - start, samples_per_sec) 250 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 251 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 252 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 253 | args.tboard_writer.add_scalar('Avg iters per sec', iters_per_sec, iters) 254 | args.tboard_writer.add_scalar('Avg samples per sec', samples_per_sec, iters) 255 | fig = generate_images([inp, tar, gen]) 256 | args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True) 257 | 258 | val_start = time.time() 259 | val_loss = torch.zeros(1, device=device) 260 | val_rmse = torch.zeros((params.n_out_channels), dtype=torch.float32, device=device) 261 | valid_steps = 0 262 | model.eval() 263 | 264 | with torch.inference_mode(): 265 | with torch.no_grad(): 266 | for i, data in enumerate(val_data_loader, 0): 267 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 268 | inp, tar = map(lambda x: x.to(device), data) 269 | gen = model(inp) 270 | val_loss += loss_func(gen, tar) 271 | val_rmse += weighted_rmse(gen, tar) 272 | valid_steps += 1 273 | 274 | if params.distributed: 275 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("data")) 276 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("data")) 277 | 278 | val_rmse /= valid_steps # Avg validation rmse 279 | val_loss /= valid_steps 280 | val_end = time.time() 281 | if world_rank==0: 282 | logging.info(' Avg val loss={}'.format(val_loss.item())) 283 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 284 | args.tboard_writer.add_scalar('Loss/valid', val_loss, iters) 285 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], iters) 286 | args.tboard_writer.flush() 287 | 288 | torch.cuda.synchronize() 289 | t2 = time.time() 290 | tottime = t2 - t1 291 | pynvml.nvmlShutdown() 292 | 293 | 294 | if __name__ == '__main__': 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument("--run_num", default='00', type=str, help='tag for indexing the current experiment') 297 | parser.add_argument("--yaml_config", default='./config/ViT.yaml', type=str, help='path to yaml file containing training configs') 298 | parser.add_argument("--config", default='base', type=str, help='name of desired config in yaml file') 299 | parser.add_argument("--amp_mode", default='none', type=str, choices=['none', 'fp16', 'bf16'], help='select automatic mixed precision mode') 300 | parser.add_argument("--enable_fused", action='store_true', help='enable fused Adam optimizer') 301 | parser.add_argument("--enable_jit", action='store_true', help='enable JIT compilation') 302 | parser.add_argument("--local_batch_size", default=None, type=int, help='local batchsize (manually override global_batch_size config setting)') 303 | parser.add_argument("--num_iters", default=None, type=int, help='number of iters to run') 304 | parser.add_argument("--num_data_workers", default=None, type=int, help='number of data workers for data loader') 305 | parser.add_argument("--data_loader_config", default=None, type=str, choices=['pytorch', 'dali'], help="dataloader configuration. choices: 'pytorch', 'dali'") 306 | parser.add_argument("--bucket_cap_mb", default=25, type=int, help='max message bucket size in mb') 307 | parser.add_argument("--disable_broadcast_buffers", action='store_true', help='disable syncing broadcasting buffers') 308 | parser.add_argument("--noddp", action='store_true', help='disable DDP communication') 309 | 310 | # model parallelism arguments 311 | parser.add_argument("--row_parallel_size", default=1, type=int, help="Number of row comms") 312 | parser.add_argument("--col_parallel_size", default=1, type=int, help="Number of col comms") 313 | 314 | args = parser.parse_args() 315 | 316 | run_num = args.run_num 317 | 318 | params = YParams(os.path.abspath(args.yaml_config), args.config) 319 | 320 | # Update config with modified args 321 | # set up amp 322 | if args.amp_mode != 'none': 323 | params.update({"amp_mode": args.amp_mode}) 324 | amp_dtype = torch.float32 325 | if params.amp_mode == "fp16": 326 | amp_dtype = torch.float16 327 | elif params.amp_mode == "bf16": 328 | amp_dtype = torch.bfloat16 329 | params.update({"amp_enabled": amp_dtype is not torch.float32, 330 | "amp_dtype" : amp_dtype, 331 | "enable_fused" : args.enable_fused, 332 | "enable_jit" : args.enable_jit 333 | }) 334 | 335 | if args.data_loader_config: 336 | params.update({"data_loader_config" : args.data_loader_config}) 337 | 338 | if args.num_iters: 339 | params.update({"num_iters" : args.num_iters}) 340 | 341 | if args.num_data_workers: 342 | params.update({"num_data_workers" : args.num_data_workers}) 343 | 344 | params.distributed = False 345 | 346 | # setup model parallel sizes 347 | # we do not use col parallel size for this tutorial, but leave it in 348 | # so that an interested user can begin to extend 349 | assert ( 350 | args.col_parallel_size == 1 351 | ), f"col_parallel_size is not used in this example, please set to 1." 352 | 353 | params["model_parallel_sizes"] = [ 354 | args.row_parallel_size, 355 | args.col_parallel_size 356 | ] 357 | params["model_parallel_names"] = ["row_matmul", "col_matmul"] 358 | 359 | # initialize comm 360 | comm.init(params, verbose=True) 361 | 362 | # get info from comm 363 | world_size = comm.get_world_size() 364 | world_rank = comm.get_world_rank() 365 | local_rank = comm.get_local_rank() 366 | params.distributed = (world_size > 1) 367 | 368 | assert ( 369 | params["global_batch_size"] % comm.get_size("data") == 0 370 | ), f"Error, cannot evenly distribute {params['global_batch_size']} across {comm.get_size('data')} GPU." 371 | 372 | if args.local_batch_size: 373 | # Manually override batch size 374 | params.local_batch_size = args.local_batch_size 375 | params.update({"global_batch_size" : comm.get_size("data") * args.local_batch_size}) 376 | else: 377 | # Compute local batch size based on number of ranks 378 | params.local_batch_size = int(params["global_batch_size"] // comm.get_size("data")) 379 | 380 | # for data loader, set the actual number of data shards and id 381 | params.data_num_shards = comm.get_size("data") 382 | params.data_shard_id = comm.get_rank("data") 383 | 384 | # Set up directory 385 | baseDir = params.expdir 386 | expDir = os.path.join(baseDir, args.config + '/%dMP/'%(comm.get_size("model")) + str(run_num) + '/') 387 | if world_rank==0: 388 | if not os.path.isdir(expDir): 389 | os.makedirs(expDir) 390 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 391 | params.log() 392 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 393 | 394 | params.experiment_dir = os.path.abspath(expDir) 395 | 396 | train(params, args, local_rank, world_rank, world_size) 397 | 398 | if params.distributed: 399 | torch.distributed.barrier() 400 | logging.info('DONE ---- rank %d'%world_rank) 401 | 402 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SC23 Deep Learning at Scale Tutorial 2 | 3 | This repository contains the example code material for the SC23 tutorial: 4 | *Deep Learning at Scale*. 5 | 6 | **Contents** 7 | * [Links](#links) 8 | * [Installation](#installation-and-setup) 9 | * [Model, data, and code overview](#model-data-and-training-code-overview) 10 | * [Single GPU training](#single-gpu-training) 11 | * [Single GPU performance](#single-gpu-performance-profiling-and-optimization) 12 | * [Distributed training with data parallelism](#distributed-training-with-data-parallelism) 13 | * [Multi-GPU model parallelism](#model-parallelism) 14 | 15 | ## Links 16 | 17 | Tutorial slides: https://drive.google.com/drive/folders/1wN1bCjHk2iocI6nowuzSugQopAwetCjR?usp=sharing 18 | 19 | Join the Slack workspace: https://join.slack.com/t/nersc-dl-tutorial/shared_invite/zt-25yvx25rr-RPWN1UclFvwgnRyr39Qt0w 20 | 21 | NERSC JupyterHub: https://jupyter.nersc.gov 22 | 23 | Data download (only needed if you want to run our examples elsewhere): https://portal.nersc.gov/project/dasrepo/pharring/sc23_data 24 | 25 | ## Installation and Setup 26 | 27 | ### Software environment 28 | 29 | The instructions in this README are intended to be used with NERSC's Perlmutter machine. 30 | 31 | Access to the Perlmutter machine is provided for this tutorial via [jupyter.nersc.gov](https://jupyter.nersc.gov). 32 | Training account setup instructions will be given during the session. Once you have your provided account credentials, you can log in to Jupyter via the link (leave the OTP field blank when logging into Jupyter). 33 | Once logged into the hub, start a session by clicking the button for Perlmutter Login Node (other options will not work with this tutorial material). 34 | This will open up a session on a Perlmutter login node, from which you can submit jobs to the GPU nodes and monitor their progress. 35 | 36 | To begin, start a terminal from JupyterHub and clone this repository with: 37 | ```bash 38 | git clone https://github.com/NERSC/sc23-dl-tutorial.git 39 | ``` 40 | You can use the Jupyter file browser to view and edit source files and scripts. For all of the example commands provided below, make sure you are running them from within the top-level folder of the repository. In your terminal, change to the directory with 41 | ```bash 42 | cd sc23-dl-tutorial 43 | ``` 44 | 45 | For running slurm jobs on Perlmutter, we will use training accounts which are provided under the `ntrain4` project. The slurm script `submit_pm.sh` included in the repository is configured to work automatically as is, but if you submit your own custom jobs via `salloc` or `sbatch` you must include the following flags for slurm: 46 | * `-A ntrain4_g` is required for training accounts 47 | * `--reservation=` is required to access the set of GPU nodes we have reserved for the duration of the tutorial. For the morning session use `` set to `sc23_dl_tutorial_1`, and for the afternoon session use `` set to `sc23_dl_tutorial_2` (we have two different size reservations for the single-GPU and multi-GPU sections respectively) 48 | 49 | The code can be run using the `nersc/pytorch:ngc-23.07-v0` docker container. On Perlmutter, docker containers are run via [shifter](https://docs.nersc.gov/development/shifter/), and this container is already downloaded and automatically invoked by our job submission scripts. Our container is based on the [NVIDIA NGC 23.07 pytorch container](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-07.html), with a few additional packages added. 50 | 51 | ### Installing Nsight Systems 52 | In this tutorial, we will be generating profile files using NVIDIA Nsight Systems on the remote systems. In order to open and view these 53 | files on your local computer, you will need to install the Nsight Systems program, which you can download [here](https://developer.nvidia.com/gameworksdownload#?search=nsight%20systems). Select the download option required for your system (e.g. Mac OS host for MacOS, Window Host for Windows, or Linux Host .rpm/.deb/.run for Linux). You may need to sign up and create a login to NVIDIA's developer program if you do not 54 | already have an account to access the download. Proceed to run and install the program using your selected installation method. 55 | 56 | ## Model, data, and training code overview 57 | 58 | The model in this repository is adapted from modern applications of deep learning for weather forecasting, e.g. [FourCastNet](https://arxiv.org/abs/2202.11214), [GraphCast](https://arxiv.org/abs/2212.12794), [Pangu-Weather](https://arxiv.org/abs/2211.02556), and others. These models are trained on a combination of observed and simulated data describing the atmospheric state on Earth over the past several decades, and they achieve impressive performance in terms of accuracy and forecast speed when compared against traditional numerical weather prediction (NWP) models. 59 | 60 | ![weather forecasting animation](tutorial_images/weather_forecasting.gif) 61 | 62 | For these examples we will be using a [vision transformer](https://arxiv.org/abs/2010.11929) (ViT) architecture, for which our implementation can be found in [`networks/vit.py`](networks/vit.py). ViTs are a widely-used architecture in computer vision, known for scaling well to large datasets and being able to model long-range dependencies easily via the use of self-attention layers. While 'vanilla' ViTs are not necessarily state-of-the-art on the weather forecasting task, they are a good model to use for educational purposes as they are widely used in a variety of applications and the techniques outlined here (e.g. channel-wise tensor parallelism) would transfer well to other applications (e.g. NLP/LLMs). 63 | 64 | ![vision transformer schematic](tutorial_images/vit_schematic.png) 65 | 66 | Data-driven weather models are typically trained on the [ERA5 reanalysis dataset](https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5) from the European Center for Medium-range Weather Forecasts (ECMWF). This dataset represents 40 years of atmospheric data on a 25km global grid, combining simulation outputs assimilated with observations. The basic data loading pipeline for training models is defined in [`utils/data_loader.py`](utils/data_loader.py), whose primary components are: 67 | * The `ERA5Dataset`, which accesses the data stored on disk and serves input-output pairs of the atmospheric variables for training and validation. Each pair is a randomly-sampled snapshots of the atmosphere, separated by a 6 hour timestep. The model is given the first snapshot as input and is trained to predict the snapshot 6 hours later. 68 | * For this repository, we will be using a spatially-downsampled version of the data so training runs a little faster. 69 | * The above dataset object is passed to a PyTorch `DataLoader` which takes the samples and combines them into a batch for each training step. 70 | 71 | It is common practice to decay the learning rate according to some schedule as the model trains, so that the optimizer can settle into sharper minima during gradient descent. Here we opt for the cosine learning rate decay schedule, which starts at an intial learning rate and decays continuously throughout training according to a cosine function. This is handled by the `LambdaLR` or `CosineAnnealingLR` utilities from PyTorch, set in [`train.py`](train.py) -- the `LambdaLR` uses custom logic to implement learning rate warm-up if desired for distributed training. 72 | 73 | As we will see in the [Single GPU performance profiling and optimization](#Single-GPU-performance-profiling-and-optimization) section, we'll be able to speed up the baseline data loading pipeline significantly by making various improvements. Another option introduced in that section is to do data loading using NVIDIA's DALI library, for which the implementation can be found in [`utils/data_loader_dali.py`](utils/data_loader_dali.py). 74 | 75 | The script to train the model is [`train.py`](train.py), which uses the following arguments to load the desired training setup: 76 | ``` 77 | --yaml_config YAML_CONFIG path to yaml file containing training configs 78 | --config CONFIG name of desired config in yaml file 79 | ``` 80 | 81 | Based on the selected configuration, the train script will then: 82 | 1. Set up the data loaders and construct our ViT model, the Adam optimizer, and our L2 loss function. 83 | 2. Loop over training epochs to run the training. See if you can identify the following key components: 84 | * Looping over data batches from our data loader. 85 | * Applying the forward pass of the model and computing the loss function. 86 | * Calling `backward()` on the loss value to backpropagate gradients. Note the use of the `grad_scaler` will be explained below when enabling mixed precision. 87 | * Applying the model to the validation dataset and logging training and validation metrics to visualize in TensorBoard (see if you can find where we construct the TensorBoard `SummaryWriter` and where our specific metrics are logged via the `add_scalar` call). 88 | 89 | More info on the model and data can be found in the [slides](https://drive.google.com/drive/folders/1wN1bCjHk2iocI6nowuzSugQopAwetCjR?usp=drive_link). If you are experimenting with this repository after the tutorial date, you can download the data from here: https://portal.nersc.gov/project/dasrepo/pharring/sc23_data. 90 | Note that you will have to adjust the data path in `submit_pm.sh` to point your personal copy after downloading. 91 | 92 | ## Single GPU training 93 | 94 | First, let us look at the performance of the training script without optimizations on a single GPU. 95 | 96 | On Perlmutter for the tutorial, we will be submitting jobs to the batch queue. To submit this job, use the following command: 97 | ``` 98 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short 99 | ``` 100 | `submit_pm.sh` is a batch submission script that defines resources to be requested by SLURM as well as the command to run. 101 | Note that any arguments for `train.py`, such as the desired config (`--config`), can be added after `submit_pm.sh` when submitting, and they will be passed to `train.py` properly. 102 | When using batch submission, you can see the job output by viewing the file `vit-era5-.out` in the submission 103 | directory. You can find the job id of your job using the command `squeue --me` and looking at the first column of the output. 104 | 105 | This will run 128 training iterations on a single GPU using a default batch size of 16. 106 | See [`config/ViT.yaml`](config/ViT.yaml) for specific configuration details. 107 | Note we will use the default batch size for the optimization work in the next section 108 | and will push beyond to larger batch sizes in the distributed training section. 109 | 110 | While the model predicts many atmospheric variables, we will focus on the prediction error of surface wind at 10m `u10` to represent model quality. 111 | In the baseline configuration, the model converges to a u10 RMSE of about `0.13` on 112 | the validation dataset in about 22k training iterations. This takes around 22 hours hours to run, so to save time we have already included an example TensorBoard log for the `base` config in the `example_logs` directory for you. 113 | We want to compare our training results against the `base` config baseline, and TensorBoard makes this easy as long as all training runs are stored in the same place. 114 | To copy the example TensorBoard log to the scratch directory where our training jobs will output their logs, do 115 | ``` 116 | mkdir -p $SCRATCH/sc23-dl-tutorial/logs 117 | cp -r ./example_logs/base $SCRATCH/sc23-dl-tutorial/logs 118 | ``` 119 | 120 | To view results in TensorBoard, open the [`start_tensorboard.ipynb`](start_tensorboard.ipynb) notebook and follow the instructions in it to launch a TensorBoard session in your browser. Once you have TensorBoard open, you should see a dashboard with data for the loss values, learning rate, and average iterations per second. Looking at the validation loss for the `base` config, you should see the following training curve: 121 | ![baseline training](tutorial_images/baseline_tb.png) 122 | 123 | As our training with the `short` config runs, it should also dump the training metrics to the TensorBoard directory, and TensorBoard will parse the data and display it for you. You can hit the refresh button in the upper-right corner of TensorBoard to update the plots with the latest data. 124 | 125 | ## Single GPU performance profiling and optimization 126 | 127 | This is the performance of the baseline script for the first four epochs on a 40GB A100 card with batch size 16 using the `short` config, which limits the number of training and validation samples to 512 and 128 samples respectively: 128 | ``` 129 | 2023-09-26 21:29:00,679 - root - INFO - Starting Training Loop... 130 | 2023-09-26 21:30:08,688 - root - INFO - Time taken for epoch 1 is 63.020848512649536 sec, avg 8.12429556382808 samples/sec 131 | 2023-09-26 21:30:08,690 - root - INFO - Avg train loss=0.579061 132 | 2023-09-26 21:30:17,316 - root - INFO - Avg val loss=0.419114 133 | 2023-09-26 21:30:17,316 - root - INFO - Total validation time: 8.258756637573242 sec 134 | 2023-09-26 21:31:11,898 - root - INFO - Time taken for epoch 2 is 54.578805923461914 sec, avg 9.380930772248819 samples/sec 135 | 2023-09-26 21:31:11,898 - root - INFO - Avg train loss=0.390744 136 | 2023-09-26 21:31:18,989 - root - INFO - Avg val loss=0.375897 137 | 2023-09-26 21:31:18,989 - root - INFO - Total validation time: 6.766376972198486 sec 138 | 2023-09-26 21:32:13,578 - root - INFO - Time taken for epoch 3 is 54.58618688583374 sec, avg 9.37966231403635 samples/sec 139 | 2023-09-26 21:32:13,579 - root - INFO - Avg train loss=0.356790 140 | 2023-09-26 21:32:20,685 - root - INFO - Avg val loss=0.353825 141 | 2023-09-26 21:32:20,685 - root - INFO - Total validation time: 6.767474889755249 sec 142 | 2023-09-26 21:33:15,322 - root - INFO - Time taken for epoch 4 is 54.63401126861572 sec, avg 9.371451740614114 samples/sec 143 | 2023-09-26 21:33:15,322 - root - INFO - Avg train loss=0.343523 144 | 2023-09-26 21:33:22,444 - root - INFO - Avg val loss=0.347524 145 | 2023-09-26 21:33:22,444 - root - INFO - Total validation time: 6.78272819519043 sec 146 | ``` 147 | After the first epoch, we see that the throughput achieved is about 9.3 samples/s. 148 | 149 | ### Profiling with Nsight Systems 150 | #### Adding NVTX ranges and profiler controls 151 | Before generating a profile with Nsight, we can add NVTX ranges to the script to add context to the produced timeline. 152 | We can add some manually defined NVTX ranges to the code using `torch.cuda.nvtx.range_push` and `torch.cuda.nvtx.range_pop`. 153 | We can also add calls to `torch.cuda.profiler.start()` and `torch.cuda.profiler.stop()` to control the duration of the profiling 154 | (e.g., limit profiling to single epoch). You can `grep` through `train.py` for these API calls to see what we've added in this example. 155 | 156 | To generate a profile using our scripts on Perlmutter, run the following command: 157 | ``` 158 | ENABLE_PROFILING=1 PROFILE_OUTPUT=baseline sbatch -n1 -t 20 submit_pm.sh --config=short 159 | ``` 160 | This command will run four epochs of the training script, profiling only the last epoch run. It will produce a file `baseline.nsys-rep` that can be opened in the Nsight System's program. The arg `--trace=cuda,nvtx` is optional and is used here to disable OS Runtime tracing for speed. The arg `-c cudaProfilerApi` instructs the profiler to only profile the duration of the runtime between the `torch.cuda.profiler.start()` and `torch.cuda.profiler.stop()` calls. 161 | 162 | Loading this profile ([`baseline.nsys-rep`](sample_nsys_profiles/baseline.nsys-rep)) in Nsight Systems will look like this: 163 | ![NSYS Baseline](tutorial_images/nsys_baseline.png) 164 | 165 | From this zoomed out view, we can see some idle gaps between training iterations. These gaps are due to the data loading, which we will address in the next section. 166 | 167 | Beyond this, we can zoom into a single iteration and get an idea of where compute time is being spent: 168 | ![NSYS Baseline zoomed](tutorial_images/nsys_baseline_zoomed.png) 169 | 170 | 171 | ### Data loading optimizations 172 | #### Improving the native PyTorch dataloader performance 173 | The PyTorch dataloader has several knobs we can adjust to improve performance. If you look at the `DataLoader` initialization in 174 | `utils/data_loader.py`, you'll see we've already set several useful options, like `pin_memory` and `persistent_workers`. 175 | `pin_memory` has the data loader read input data into pinned host memory, which typically yields better host-to-device and device-to-host 176 | memcopy bandwidth. `persistent_workers` allows PyTorch to reuse workers between epochs, instead of the default behavior which is to 177 | respawn them. One knob we've left to adjust is the `num_workers` argument, which we can control via the `--num_data_workers` command 178 | line arg to our script. The default used by PyTorch is `num_workers=0`, which runs data loading *sequentially* in the training Python process. This is one source of the large gaps we observed in the first profile. By setting `num_workers>0`, we enable PyTorch to use multiprocessing to perform data loading in a side process to hide this cost. We can experiment with the number of workers to see if performance is improved. 179 | 180 | We can run this experiment on Perlmutter by running the following command: 181 | ``` 182 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 183 | ``` 184 | 185 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16 and 4 data workers: 186 | ``` 187 | 2023-09-26 21:18:44,034 - root - INFO - Time taken for epoch 1 is 43.38622999191284 sec, avg 11.800979252989633 samples/sec 188 | 2023-09-26 21:18:44,035 - root - INFO - Avg train loss=0.577452 189 | 2023-09-26 21:18:49,678 - root - INFO - Avg val loss=0.418861 190 | 2023-09-26 21:18:49,679 - root - INFO - Total validation time: 5.192107677459717 sec 191 | 2023-09-26 21:19:30,999 - root - INFO - Time taken for epoch 2 is 41.31834650039673 sec, avg 12.391589774655767 samples/sec 192 | 2023-09-26 21:19:31,001 - root - INFO - Avg train loss=0.390701 193 | 2023-09-26 21:19:36,231 - root - INFO - Avg val loss=0.372989 194 | 2023-09-26 21:19:36,232 - root - INFO - Total validation time: 4.828763484954834 sec 195 | 2023-09-26 21:20:17,169 - root - INFO - Time taken for epoch 3 is 40.93515610694885 sec, avg 12.507586355902198 samples/sec 196 | 2023-09-26 21:20:17,171 - root - INFO - Avg train loss=0.356448 197 | 2023-09-26 21:20:22,409 - root - INFO - Avg val loss=0.355308 198 | 2023-09-26 21:20:22,409 - root - INFO - Total validation time: 4.8364222049713135 sec 199 | 2023-09-26 21:21:03,627 - root - INFO - Time taken for epoch 4 is 41.21541452407837 sec, avg 12.42253671137738 samples/sec 200 | 2023-09-26 21:21:03,629 - root - INFO - Avg train loss=0.343769 201 | 2023-09-26 21:21:08,695 - root - INFO - Avg val loss=0.347322 202 | 2023-09-26 21:21:08,695 - root - INFO - Total validation time: 4.662991523742676 sec 203 | ``` 204 | 205 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16 and 8 data workers: 206 | ``` 207 | 2023-09-26 21:18:59,332 - root - INFO - Time taken for epoch 1 is 45.54166626930237 sec, avg 11.242452065156796 samples/sec 208 | 2023-09-26 21:18:59,333 - root - INFO - Avg train loss=0.577049 209 | 2023-09-26 21:19:05,821 - root - INFO - Avg val loss=0.419312 210 | 2023-09-26 21:19:05,821 - root - INFO - Total validation time: 6.034433841705322 sec 211 | 2023-09-26 21:19:47,276 - root - INFO - Time taken for epoch 2 is 41.4513418674469 sec, avg 12.351831736527942 samples/sec 212 | 2023-09-26 21:19:47,277 - root - INFO - Avg train loss=0.389672 213 | 2023-09-26 21:19:53,126 - root - INFO - Avg val loss=0.373399 214 | 2023-09-26 21:19:53,126 - root - INFO - Total validation time: 5.442654848098755 sec 215 | 2023-09-26 21:20:36,164 - root - INFO - Time taken for epoch 3 is 43.03392195701599 sec, avg 11.897590940268149 samples/sec 216 | 2023-09-26 21:20:36,165 - root - INFO - Avg train loss=0.355648 217 | 2023-09-26 21:20:41,650 - root - INFO - Avg val loss=0.353144 218 | 2023-09-26 21:20:41,650 - root - INFO - Total validation time: 5.0764687061309814 sec 219 | 2023-09-26 21:21:24,205 - root - INFO - Time taken for epoch 4 is 42.55116081237793 sec, avg 12.032574205380119 samples/sec 220 | 2023-09-26 21:21:24,206 - root - INFO - Avg train loss=0.342547 221 | 2023-09-26 21:21:30,034 - root - INFO - Avg val loss=0.346312 222 | 2023-09-26 21:21:30,034 - root - INFO - Total validation time: 5.32970404624939 sec 223 | ``` 224 | 225 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16 and 16 data workers: 226 | ``` 227 | 2023-09-26 21:27:28,037 - root - INFO - Time taken for epoch 1 is 47.84179139137268 sec, avg 10.701940397915974 samples/sec 228 | 2023-09-26 21:27:28,037 - root - INFO - Avg train loss=0.575174 229 | 2023-09-26 21:27:34,156 - root - INFO - Avg val loss=0.418625 230 | 2023-09-26 21:27:34,156 - root - INFO - Total validation time: 5.6687445640563965 sec 231 | 2023-09-26 21:28:20,330 - root - INFO - Time taken for epoch 2 is 46.170273542404175 sec, avg 11.089386324076328 samples/sec 232 | 2023-09-26 21:28:20,331 - root - INFO - Avg train loss=0.388281 233 | 2023-09-26 21:28:25,466 - root - INFO - Avg val loss=0.373171 234 | 2023-09-26 21:28:25,467 - root - INFO - Total validation time: 4.725477695465088 sec 235 | 2023-09-26 21:29:11,989 - root - INFO - Time taken for epoch 3 is 46.51985311508179 sec, avg 11.006053667740602 samples/sec 236 | 2023-09-26 21:29:11,991 - root - INFO - Avg train loss=0.354430 237 | 2023-09-26 21:29:17,389 - root - INFO - Avg val loss=0.351720 238 | 2023-09-26 21:29:17,390 - root - INFO - Total validation time: 4.990921974182129 sec 239 | 2023-09-26 21:30:02,644 - root - INFO - Time taken for epoch 4 is 45.25181460380554 sec, avg 11.314463397384783 samples/sec 240 | 2023-09-26 21:30:02,645 - root - INFO - Avg train loss=0.341476 241 | 2023-09-26 21:30:07,853 - root - INFO - Avg val loss=0.345648 242 | 2023-09-26 21:30:07,853 - root - INFO - Total validation time: 4.801238775253296 sec 243 | ``` 244 | 245 | Increasing the number of workers to 4 improves throughput to around 12.4 samples per second, while increasing to more workers yields a slight degradation in performance. 246 | 247 | We can run the 4 worker configuration through profiler using the instructions in the previous section with the added `--num_data_workers` 248 | argument and load that profile in Nsight Systems. This is what this profile ([`4workers.nsys-rep`](sample_nsys_profiles/4workers.nsys-rep)) looks like: 249 | ![NSYS Native Data](tutorial_images/nsys_nativedata_4workers.png) 250 | 251 | and zoomed in: 252 | ![NSYS Native Data Zoomed](tutorial_images/nsys_nativedata_4workers_zoomed.png) 253 | 254 | With 4 data workers, the idle gaps between steps are resolved, improving the throughput. Looking at the zoomed in profile, we 255 | still see that the H2D copy in of the input data (i.e. the light green activity at the beginning of the step) takes some time and runs in same CUDA stream as the compute. One option here is to implement a prefetching 256 | mechanism in PyTorch directly using CUDA streams to concurrently load and copy in the next batch of input during the current batch, however 257 | this is left as an exercise outside of this tutorial. A good example of this can be found in [here](https://github.com/NVIDIA/DeepLearningExamples/blob/41f582bd9f65f6ebede77532b7cd64f038a8a380/PyTorch/Classification/ConvNets/image_classification/dataloaders.py#L354) 258 | 259 | #### Using NVIDIA DALI 260 | While we were able to get more performance out of the PyTorch native DataLoader, there are several potential overheads we cannot overcome in 261 | PyTorch alone: 262 | 1. The PyTorch DataLoader will use CPU operations for all I/O operations as well as data augmentations 263 | 2. The PyTorch DataLoader uses multi-processing to spawn data workers, which has performance overheads compared to true threads 264 | 265 | The NVIDIA DALI library is a data loading library that can address both of these points: 266 | 1. DALI can perform a wide array of data augmentation operations on the GPU, benefitting from acceleration relative to the CPU. 267 | 2. DALI maintains its own worker threads in the C++ backend, enabling much more performant threading and concurrent operation. 268 | 269 | For this tutorial, we've provided an alternative data loader using DALI to accelerate the data augementations used in this training script that can be found in `utils/data_loader_dali.py`. This data loader is enabled via the command line 270 | argument `--data_loader_config=dali` to the training script. 271 | 272 | We can run this experiment on Perlmutter using DALI with 8 worker threads by running the following command: 273 | ``` 274 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali 275 | ``` 276 | 277 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16 and DALI: 278 | ``` 279 | 2023-09-26 22:01:24,018 - root - INFO - Time taken for epoch 1 is 38.48570990562439 sec, avg 12.887900501674608 samples/sec 280 | 2023-09-26 22:01:24,020 - root - INFO - Avg train loss=0.587751 281 | 2023-09-26 22:01:28,215 - root - INFO - Avg val loss=0.425913 282 | 2023-09-26 22:01:28,215 - root - INFO - Total validation time: 3.625275135040283 sec 283 | 2023-09-26 22:02:06,757 - root - INFO - Time taken for epoch 2 is 38.5366051197052 sec, avg 13.286069138928777 samples/sec 284 | 2023-09-26 22:02:06,759 - root - INFO - Avg train loss=0.394755 285 | 2023-09-26 22:02:10,374 - root - INFO - Avg val loss=0.376960 286 | 2023-09-26 22:02:10,375 - root - INFO - Total validation time: 3.0912325382232666 sec 287 | 2023-09-26 22:02:48,918 - root - INFO - Time taken for epoch 3 is 38.53870248794556 sec, avg 13.285346079312022 samples/sec 288 | 2023-09-26 22:02:48,921 - root - INFO - Avg train loss=0.359927 289 | 2023-09-26 22:02:52,485 - root - INFO - Avg val loss=0.355281 290 | 2023-09-26 22:02:52,485 - root - INFO - Total validation time: 3.052870988845825 sec 291 | 2023-09-26 22:03:31,039 - root - INFO - Time taken for epoch 4 is 38.549081325531006 sec, avg 13.281769173080217 samples/sec 292 | 2023-09-26 22:03:31,041 - root - INFO - Avg train loss=0.345901 293 | 2023-09-26 22:03:34,623 - root - INFO - Avg val loss=0.349484 294 | 2023-09-26 22:03:34,623 - root - INFO - Total validation time: 3.0705220699310303 sec 295 | ``` 296 | 297 | We can run the DALI case through profiler using the instructions in the earlier section with the added `--data_loader_config=dali` 298 | argument and load that profile in Nsight Systems. This is what this profile ([`dali.nsys-rep`](sample_nsys_profiles/dali.nsys-rep)) looks like: 299 | ![NSYS DALI](tutorial_images/nsys_dali.png) 300 | 301 | and zoomed in to a single iteration: 302 | ![NSYS DALI Zoomed](tutorial_images/nsys_dali_zoomed.png) 303 | 304 | With DALI, you will see that there are now multiple CUDA stream rows in the timeline view, corresponding to internal streams DALI uses 305 | to run data augmentation kernels and any memory movement concurrently with the existing PyTorch compute kernels. Stream 16 in this view shows concurrent H2D memory copies of the batch input data, which is an improvement over the native dataloader. 306 | 307 | ### Enabling Mixed Precision Training 308 | Now that the data loading performance has been improved, we can start focusing on pushing compute performance. As a first step to improve the compute performance of this training script, we can enable automatic mixed precision (AMP) in PyTorch. AMP provides a simple way for users to convert existing FP32 training scripts to mixed FP32/FP16 of FP32/BF16 precision, unlocking 309 | faster computation with Tensor Cores on NVIDIA GPUs. 310 | 311 | The AMP module in torch is composed of two main parts: `torch.cuda.amp.GradScaler` and `torch.cuda.amp.autocast`. `torch.cuda.amp.GradScaler` handles automatic loss scaling to control the range of FP16 gradients when using FP16 precision. Note that since BF16 precision maintains the range of FP32, loss scaling is not required when using AMP with this data type. 312 | The `torch.cuda.amp.autocast` context manager handles converting model operations to BF16/FP16 where appropriate. 313 | 314 | As a quick note, the A100 GPUs we've been using to report results thus far have been able to benefit from Tensor Core compute via the use of TF32 precision operations, enabled by default for CUDNN and CUBLAS in PyTorch. We can measure the benefit of TF32 precision usage on the A100 GPU by temporarily disabling it via setting the environment variable `NVIDIA_TF32_OVERRIDE=0`. 315 | We can run this experiment on Perlmutter by running the following command: 316 | ``` 317 | NVIDIA_TF32_OVERRIDE=0 sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali 318 | ``` 319 | yields the following result for 4 epochs: 320 | ``` 321 | 2023-09-26 22:37:05,159 - root - INFO - Time taken for epoch 1 is 50.52403998374939 sec, avg 9.817108848768507 samples/sec 322 | 2023-09-26 22:37:05,160 - root - INFO - Avg train loss=0.585963 323 | 2023-09-26 22:37:10,101 - root - INFO - Avg val loss=0.428734 324 | 2023-09-26 22:37:10,102 - root - INFO - Total validation time: 4.387829065322876 sec 325 | 2023-09-26 22:38:00,735 - root - INFO - Time taken for epoch 2 is 50.62814474105835 sec, avg 10.112952047100768 samples/sec 326 | 2023-09-26 22:38:00,736 - root - INFO - Avg train loss=0.394807 327 | 2023-09-26 22:38:05,347 - root - INFO - Avg val loss=0.378771 328 | 2023-09-26 22:38:05,348 - root - INFO - Total validation time: 4.096112012863159 sec 329 | 2023-09-26 22:38:55,989 - root - INFO - Time taken for epoch 3 is 50.63650107383728 sec, avg 10.111283148363873 samples/sec 330 | 2023-09-26 22:38:55,991 - root - INFO - Avg train loss=0.360278 331 | 2023-09-26 22:39:00,564 - root - INFO - Avg val loss=0.355521 332 | 2023-09-26 22:39:00,564 - root - INFO - Total validation time: 4.063924789428711 sec 333 | 2023-09-26 22:39:51,199 - root - INFO - Time taken for epoch 4 is 50.62860679626465 sec, avg 10.112859752596927 samples/sec 334 | 2023-09-26 22:39:51,200 - root - INFO - Avg train loss=0.345876 335 | 2023-09-26 22:39:55,772 - root - INFO - Avg val loss=0.349507 336 | 2023-09-26 22:39:55,773 - root - INFO - Total validation time: 4.065291404724121 sec 337 | ``` 338 | From here, we can see that running in FP32 without TF32 acceleration is reduced, hence we are seeing some benefits from 339 | TF32 Tensor Core operations without any code changes to add AMP. With that said, AMP can still provide more performance improvement for A100 GPUs, 340 | as TF32 is a compute type only, leaving all data in full precision FP32. FP16 precision has the compute benefits of Tensor Cores combined with a reduction in storage and memory bandwidth requirements. 341 | 342 | We can run this experiment using AMP on Perlmutter by running one of the following commands: 343 | ``` 344 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali --amp_mode=fp16 345 | ``` 346 | for AMP with FP16 precision or 347 | ``` 348 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali --amp_mode=bf16 349 | ``` 350 | for AMP with BF16 precision. 351 | 352 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16, DALI, and AMP FP16: 353 | ``` 354 | 2023-09-26 22:42:50,782 - root - INFO - Time taken for epoch 1 is 13.934328317642212 sec, avg 35.59554423387713 samples/sec 355 | 2023-09-26 22:42:50,782 - root - INFO - Avg train loss=0.585250 356 | 2023-09-26 22:42:53,841 - root - INFO - Avg val loss=0.427797 357 | 2023-09-26 22:42:53,841 - root - INFO - Total validation time: 2.5032312870025635 sec 358 | 2023-09-26 22:43:05,905 - root - INFO - Time taken for epoch 2 is 12.058186531066895 sec, avg 42.46077954432662 samples/sec 359 | 2023-09-26 22:43:05,906 - root - INFO - Avg train loss=0.396558 360 | 2023-09-26 22:43:07,896 - root - INFO - Avg val loss=0.381889 361 | 2023-09-26 22:43:07,896 - root - INFO - Total validation time: 1.4918978214263916 sec 362 | 2023-09-26 22:43:19,939 - root - INFO - Time taken for epoch 3 is 12.037509441375732 sec, avg 42.53371534149218 samples/sec 363 | 2023-09-26 22:43:19,940 - root - INFO - Avg train loss=0.362181 364 | 2023-09-26 22:43:21,919 - root - INFO - Avg val loss=0.357079 365 | 2023-09-26 22:43:21,919 - root - INFO - Total validation time: 1.4730119705200195 sec 366 | 2023-09-26 22:43:33,977 - root - INFO - Time taken for epoch 4 is 12.047151803970337 sec, avg 42.499671983153895 samples/sec 367 | 2023-09-26 22:43:33,978 - root - INFO - Avg train loss=0.347396 368 | 2023-09-26 22:43:35,959 - root - INFO - Avg val loss=0.351097 369 | 2023-09-26 22:43:35,960 - root - INFO - Total validation time: 1.48984694480896 sec 370 | ``` 371 | 372 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16, DALI, and AMP BF16: 373 | ``` 374 | 2023-09-26 22:55:22,111 - root - INFO - Time taken for epoch 1 is 13.740016222000122 sec, avg 36.09893845727918 samples/sec 375 | 2023-09-26 22:55:22,112 - root - INFO - Avg train loss=0.581764 376 | 2023-09-26 22:55:25,160 - root - INFO - Avg val loss=0.423542 377 | 2023-09-26 22:55:25,161 - root - INFO - Total validation time: 2.495091438293457 sec 378 | 2023-09-26 22:55:37,176 - root - INFO - Time taken for epoch 2 is 12.007216453552246 sec, avg 42.64102358615585 samples/sec 379 | 2023-09-26 22:55:37,177 - root - INFO - Avg train loss=0.392484 380 | 2023-09-26 22:55:39,214 - root - INFO - Avg val loss=0.374372 381 | 2023-09-26 22:55:39,214 - root - INFO - Total validation time: 1.5221989154815674 sec 382 | 2023-09-26 22:55:51,228 - root - INFO - Time taken for epoch 3 is 12.009136199951172 sec, avg 42.63420711325447 samples/sec 383 | 2023-09-26 22:55:51,229 - root - INFO - Avg train loss=0.357453 384 | 2023-09-26 22:55:53,237 - root - INFO - Avg val loss=0.353668 385 | 2023-09-26 22:55:53,237 - root - INFO - Total validation time: 1.491905927658081 sec 386 | 2023-09-26 22:56:05,255 - root - INFO - Time taken for epoch 4 is 12.012868881225586 sec, avg 42.620959661033474 samples/sec 387 | 2023-09-26 22:56:05,256 - root - INFO - Avg train loss=0.343864 388 | 2023-09-26 22:56:07,237 - root - INFO - Avg val loss=0.347740 389 | 2023-09-26 22:56:07,237 - root - INFO - Total validation time: 1.470574140548706 sec 390 | ``` 391 | 392 | For this model, we see a massive improvement when using AMP with either FP16 or BF16 precision, improving throughput to over 42 samples/s in each case. BF16 has a slight edge over FP16 due to the lack of loss scaling. 393 | 394 | We can run the case with AMP BF16 enabled through profiler using the instructions in the earlier section with the added `--amp_mode=bf16` 395 | argument and load that profile in Nsight Systems. This is what this profile ([`dali_amp_bf16.nsys-rep`](sample_nsys_profiles/dali_amp_bf16.nsys-rep)) looks like: 396 | ![NSYS DALI AMP](tutorial_images/nsys_dali_amp.png) 397 | 398 | and zoomed in to a single iteration: 399 | ![NSYS DALI AMP Zoomed](tutorial_images/nsys_dali_amp_zoomed.png) 400 | 401 | With AMP enabled, we see that the `forward` (and, correspondingly the backward) time is significantly reduced. The transformer 402 | architecture we are using relies mainly on GEMM operations that greatly benefit from mixed precision. 403 | 404 | ### Just-in-time (JIT) compiliation via `torch.compile` and fused optimizers 405 | While AMP provided a large increase in compute speed already, there are a few other optimizations available for PyTorch to improve 406 | compute throughput. A first (and simple change) is to enable the `fused` option in the Adam optimizer from `torch.optim.Adam`. 407 | In the past, this fused optimizer was mainly available in 408 | [APEX](https://github.com/NVIDIA/apex) but has now been made available in PyTorch directly. Enabling the `fused` option resultsin fewer kernels to perform the weight 409 | update than the unfused Adam optimizer, reducing latency and making more efficient use of GPU bandwidth by increasing register 410 | reuse. We can enabled the use of the fused optimizer in our training script by adding the flag `--enable_fused`. 411 | 412 | We can run this experiment using the fused optimizer on Perlmutter by running the following command: 413 | ``` 414 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali --amp_mode=bf16 --enable_fused 415 | ``` 416 | 417 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16, DALI, and AMP, and the fused optimizer: 418 | ``` 419 | 2023-09-26 23:06:32,768 - root - INFO - Time taken for epoch 1 is 13.392464637756348 sec, avg 37.0357520752129 samples/sec 420 | 2023-09-26 23:06:32,769 - root - INFO - Avg train loss=0.587116 421 | 2023-09-26 23:06:35,805 - root - INFO - Avg val loss=0.428104 422 | 2023-09-26 23:06:35,806 - root - INFO - Total validation time: 2.46842885017395 sec 423 | 2023-09-26 23:06:47,807 - root - INFO - Time taken for epoch 2 is 11.996378421783447 sec, avg 42.67954727655909 samples/sec 424 | 2023-09-26 23:06:47,808 - root - INFO - Avg train loss=0.395509 425 | 2023-09-26 23:06:49,794 - root - INFO - Avg val loss=0.377574 426 | 2023-09-26 23:06:49,794 - root - INFO - Total validation time: 1.474884033203125 sec 427 | 2023-09-26 23:07:01,795 - root - INFO - Time taken for epoch 3 is 11.994306564331055 sec, avg 42.686919602555214 samples/sec 428 | 2023-09-26 23:07:01,796 - root - INFO - Avg train loss=0.359626 429 | 2023-09-26 23:07:03,782 - root - INFO - Avg val loss=0.356546 430 | 2023-09-26 23:07:03,782 - root - INFO - Total validation time: 1.4720070362091064 sec 431 | 2023-09-26 23:07:15,797 - root - INFO - Time taken for epoch 4 is 12.009339809417725 sec, avg 42.63348428183284 samples/sec 432 | 2023-09-26 23:07:15,798 - root - INFO - Avg train loss=0.345925 433 | 2023-09-26 23:07:17,786 - root - INFO - Avg val loss=0.349518 434 | 2023-09-26 23:07:17,786 - root - INFO - Total validation time: 1.4778716564178467 sec 435 | ``` 436 | 437 | In additional to optimizer fusion, for more general fusion of operations in PyTorch, we can enable 438 | JIT compilation, done in our training script via the flag `--enable_jit`. This option wraps the model in `torch.compile` which 439 | will compile/fuse eligible operations in the model, further reducing latency. 440 | 441 | We can run this experiment using JIT on Perlmutter by running the following command: 442 | ``` 443 | sbatch -n 1 -t 20 ./submit_pm.sh --config=short --num_data_workers 8 --data_loader_config=dali --amp_mode=bf16 --enable_fused --enable_jit 444 | ``` 445 | 446 | This is the performance of the training script for the first four epochs on a 40GB A100 card with batch size 16, DALI, AMP, fused optimizer and JIT: 447 | ``` 448 | 2023-09-26 23:13:06,278 - root - INFO - Time taken for epoch 1 is 43.2601523399353 sec, avg 11.465516720848926 samples/sec 449 | 2023-09-26 23:13:06,279 - root - INFO - Avg train loss=0.586837 450 | 2023-09-26 23:13:16,808 - root - INFO - Avg val loss=0.429435 451 | 2023-09-26 23:13:16,808 - root - INFO - Total validation time: 9.924283266067505 sec 452 | 2023-09-26 23:13:28,794 - root - INFO - Time taken for epoch 2 is 11.979468584060669 sec, avg 42.73979237119447 samples/sec 453 | 2023-09-26 23:13:28,794 - root - INFO - Avg train loss=0.397797 454 | 2023-09-26 23:13:30,768 - root - INFO - Avg val loss=0.381870 455 | 2023-09-26 23:13:30,768 - root - INFO - Total validation time: 1.462252140045166 sec 456 | 2023-09-26 23:13:42,724 - root - INFO - Time taken for epoch 3 is 11.948866605758667 sec, avg 42.849252309273034 samples/sec 457 | 2023-09-26 23:13:42,724 - root - INFO - Avg train loss=0.362702 458 | 2023-09-26 23:13:44,678 - root - INFO - Avg val loss=0.357342 459 | 2023-09-26 23:13:44,679 - root - INFO - Total validation time: 1.4342505931854248 sec 460 | 2023-09-26 23:13:56,639 - root - INFO - Time taken for epoch 4 is 11.952066659927368 sec, avg 42.83777982234843 samples/sec 461 | 2023-09-26 23:13:56,640 - root - INFO - Avg train loss=0.347776 462 | 2023-09-26 23:13:58,597 - root - INFO - Avg val loss=0.351206 463 | 2023-09-26 23:13:58,597 - root - INFO - Total validation time: 1.4387221336364746 sec 464 | ``` 465 | 466 | Running a profile ([`dali_amp_bf16_fused_jit.nsys-rep`](sample_nsys_profiles/dali_amp_bf16_fused_jit.nsys-rep)) using these new options and loading in Nsight Systems looks like this: 467 | ![NSYS DALI AMP APEX JIT](tutorial_images/nsys_dali_amp_fused_jit.png) 468 | 469 | and zoomed in to a single iteration: 470 | ![NSYS DALI AMP APEX JIT Zoomed](tutorial_images/nsys_dali_amp_fused_jit_zoomed.png) 471 | 472 | As the compute cost of this model is mostly dominated by large GEMMs, latency reductions via optimizer and pointwise operation fusion are less impactful, but they still provide a small performance boost in this case. 473 | 474 | ## Distributed training with data parallelism 475 | 476 | Instructions for hands-on with mulit-GPU and multi-node training using distributed data parallelism. 477 | 478 | Now that we have model training code that is optimized for training on a single GPU, 479 | we are ready to utilize multiple GPUs and multiple nodes to accelerate the workflow 480 | with *distributed training*. We will use the recommended `DistributedDataParallel` 481 | wrapper in PyTorch with the NCCL backend for optimized communication operations on 482 | systems with NVIDIA GPUs. Refer to the PyTorch documentation for additional details 483 | on the distributed package: https://pytorch.org/docs/stable/distributed.html 484 | 485 | ### Code basics 486 | 487 | To submit multi-GPU and multi-node jobs, we can use the same slurm script but specify either 488 | the number of tasks (GPUs) with `-n ` or `-N