├── distributed ├── __init__.py ├── helpers.py ├── mappings.py └── layers.py ├── figs └── minibatch_0.jpg ├── tutorial_images ├── mp_comp.png ├── dp_timings.png ├── mp_dp_comp.png ├── nsys_dali.png ├── baseline_tb.png ├── nsys_4workers.png ├── nsys_baseline.png ├── vit_schematic.png ├── nsys_dali_bf16.png ├── nsys_dali_zoomed.png ├── nsys_4workers_zoomed.png ├── nsys_baseline_zoomed.png ├── weather_forecasting.gif ├── nsys_dali_bf16_zoomed.png ├── nsys_dali_bf16_fused_jit.png └── nsys_dali_bf16_fused_jit_zoomed.png ├── sample_nsys_profiles ├── dali.nsys-rep ├── 4workers.nsys-rep ├── baseline.nsys-rep ├── dali_amp_bf16.nsys-rep └── dali_amp_bf16_fused_jit.nsys-rep ├── export_DDP_vars.sh ├── example_logs └── base │ └── 1GPU │ └── 00 │ └── logs │ └── events.out.tfevents.1698839594.nid001332.2071366.0 ├── test_model_dims.py ├── utils ├── __init__.py ├── plots.py ├── loss.py ├── YParams.py ├── check_rank_generator.ipynb ├── logging_utils.py ├── metrics.py ├── data_loader.py ├── comm.py ├── data_loader_dali.py ├── dali_es_helper.py └── rank_generator.py ├── tests ├── run_tests.sh └── test_distributed.py ├── submit_pm.sh ├── submit_pm_dp.sh ├── submit_pm_mp.sh ├── test_data_loader.py ├── start_tensorboard.ipynb ├── config └── ViT.yaml ├── networks ├── helpers.py └── vit.py ├── .gitignore ├── train.py ├── train_mp.py └── train_mp_graphs.py /distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # model parallelism helpers and routines 2 | -------------------------------------------------------------------------------- /figs/minibatch_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/figs/minibatch_0.jpg -------------------------------------------------------------------------------- /tutorial_images/mp_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/mp_comp.png -------------------------------------------------------------------------------- /tutorial_images/dp_timings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/dp_timings.png -------------------------------------------------------------------------------- /tutorial_images/mp_dp_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/mp_dp_comp.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali.png -------------------------------------------------------------------------------- /tutorial_images/baseline_tb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/baseline_tb.png -------------------------------------------------------------------------------- /tutorial_images/nsys_4workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_4workers.png 
-------------------------------------------------------------------------------- /tutorial_images/nsys_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_baseline.png -------------------------------------------------------------------------------- /tutorial_images/vit_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/vit_schematic.png -------------------------------------------------------------------------------- /sample_nsys_profiles/dali.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/sample_nsys_profiles/dali.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_bf16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali_bf16.png -------------------------------------------------------------------------------- /sample_nsys_profiles/4workers.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/sample_nsys_profiles/4workers.nsys-rep -------------------------------------------------------------------------------- /sample_nsys_profiles/baseline.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/sample_nsys_profiles/baseline.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_4workers_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_4workers_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_baseline_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_baseline_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/weather_forecasting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/weather_forecasting.gif -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp_bf16.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/sample_nsys_profiles/dali_amp_bf16.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_bf16_zoomed.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali_bf16_zoomed.png -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_bf16_fused_jit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali_bf16_fused_jit.png -------------------------------------------------------------------------------- /sample_nsys_profiles/dali_amp_bf16_fused_jit.nsys-rep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/sample_nsys_profiles/dali_amp_bf16_fused_jit.nsys-rep -------------------------------------------------------------------------------- /tutorial_images/nsys_dali_bf16_fused_jit_zoomed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/tutorial_images/nsys_dali_bf16_fused_jit_zoomed.png -------------------------------------------------------------------------------- /export_DDP_vars.sh: -------------------------------------------------------------------------------- 1 | export RANK=$SLURM_PROCID 2 | export LOCAL_RANK=$SLURM_LOCALID 3 | export WORLD_SIZE=$SLURM_NTASKS 4 | export MASTER_PORT=29500 # default from torch launcher 5 | -------------------------------------------------------------------------------- /example_logs/base/1GPU/00/logs/events.out.tfevents.1698839594.nid001332.2071366.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NERSC/dl-at-scale-training/main/example_logs/base/1GPU/00/logs/events.out.tfevents.1698839594.nid001332.2071366.0 -------------------------------------------------------------------------------- /test_model_dims.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks.vit import ViT 3 | from utils.YParams import YParams 4 | from torchinfo import summary 5 | 6 | params = YParams('./config/ViT.yaml', 'short_mp') 7 | model = ViT(params) 8 | summary(model, input_size=(16,20,360,720)) 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | def get_data_loader_distributed(params, location, distributed, train): 2 | if params.data_loader_config.startswith("dali"): 3 | from .data_loader_dali import get_data_loader 4 | else: 5 | from .data_loader import get_data_loader 6 | return get_data_loader(params, location, distributed, train) 7 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | 6 | def generate_images(fields): 7 | inp, tar, gen = [x.detach().float().cpu().numpy() for x in fields] 8 | fig, ax = plt.subplots(1, 2, figsize=(12, 6)) 9 | plt.title("2m temperature") 10 | ax[0].imshow(tar[0, 2, :, :], cmap="turbo") 11 | ax[0].set_title("ERA5 target") 12 | ax[1].imshow(gen[0, 2, :, :], cmap="turbo") 13 | ax[1].set_title("ViT prediction") 14 | fig.tight_layout() 15 | return fig 16 | 
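# Usage sketch (added illustration; these lines are not part of the original
# utils/plots.py). generate_images expects `fields` to be a 3-tuple of
# (input, target, prediction) tensors shaped [N, C, H, W]; the hard-coded title
# and channel index 2 suggest that channel corresponds to 2m temperature. The
# random tensors below are placeholders for real model inputs/outputs, with
# shapes taken from config/ViT.yaml (20 channels, 360 x 720 grid); the output
# path is hypothetical.
if __name__ == "__main__":
    import torch

    inp = torch.randn(1, 20, 360, 720)
    tar = torch.randn(1, 20, 360, 720)
    gen = torch.randn(1, 20, 360, 720)  # in practice: gen = model(inp)
    fig = generate_images((inp, tar, gen))
    fig.savefig("figs/example_prediction.jpg")  # hypothetical output file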
-------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def l2_loss(pred, target): 5 | num_examples = pred.shape[0] 6 | diff_norms = torch.norm( 7 | pred.reshape(num_examples, -1) - target.reshape(num_examples, -1), 2, 1 8 | ) 9 | y_norms = torch.norm(target.reshape(num_examples, -1), 2, 1) 10 | return torch.mean(diff_norms / y_norms) 11 | 12 | 13 | @torch.jit.script 14 | def l2_loss_opt(pred: torch.Tensor, target: torch.Tensor): 15 | num_examples = pred.shape[0] 16 | diff_norms = torch.norm( 17 | pred.reshape(num_examples, -1) - target.reshape(num_examples, -1), 2, 1 18 | ) 19 | y_norms = torch.norm(target.reshape(num_examples, -1), 2, 1) 20 | return torch.mean(diff_norms / y_norms) 21 | -------------------------------------------------------------------------------- /tests/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | image=nersc/pytorch:24.08.01 4 | dp=1 5 | tp=1 6 | cp=4 7 | nodes=1 8 | 9 | # parse args 10 | for arg in "$@" 11 | do 12 | if [[ $arg == tp=* ]]; then 13 | tp="${arg#*=}" 14 | elif [[ $arg == cp=* ]]; then 15 | cp="${arg#*=}" 16 | elif [[ $arg == dp=* ]]; then 17 | dp="${arg#*=}" 18 | elif [[ $arg == nodes=* ]]; then 19 | nodes="${arg#*=}" 20 | fi 21 | done 22 | 23 | ngpu_per_node=$(( (${tp} * ${cp} * ${dp})/$nodes )) 24 | export MASTER_ADDR=$(hostname) 25 | srun --nodes $nodes --ntasks-per-node $ngpu_per_node --gpus-per-node $ngpu_per_node -u shifter --image=$image --module=gpu,nccl-plugin \ 26 | bash -c " 27 | source export_DDP_vars.sh 28 | export TP=${tp} 29 | export CP=${cp} 30 | export NVIDIA_TF32_OVERRIDE=0 31 | python -m pytest -s tests/test_distributed.py 32 | " 33 | -------------------------------------------------------------------------------- /submit_pm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -q shared 4 | #SBATCH -A ntrain4 5 | #SBATCH --cpus-per-task 32 6 | #SBATCH --gpus-per-task 1 7 | #SBATCH --gpu-bind none 8 | #SBATCH --time=01:00:00 9 | #SBATCH --image=nersc/pytorch:24.08.01 10 | #SBATCH --module=gpu,nccl-plugin 11 | #SBATCH --reservation=dlscale_training_1 12 | #SBATCH -J vit-era5 13 | #SBATCH -o %x-%j.out 14 | 15 | DATADIR=/pscratch/sd/s/shas1693/data/dl-at-scale-training-data 16 | LOGDIR=${SCRATCH}/dl-at-scale-training/logs 17 | mkdir -p ${LOGDIR} 18 | args="${@}" 19 | 20 | export FI_MR_CACHE_MONITOR=userfaultfd 21 | export HDF5_USE_FILE_LOCKING=FALSE 22 | 23 | # Profiling 24 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 25 | echo "Enabling profiling..." 
26 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 27 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 28 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 29 | fi 30 | 31 | export MASTER_ADDR=$(hostname) 32 | 33 | # Reversing order of GPUs to match default CPU affinities from Slurm 34 | export CUDA_VISIBLE_DEVICES=3,2,1,0 35 | 36 | set -x 37 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 38 | bash -c " 39 | source export_DDP_vars.sh 40 | ${PROFILE_CMD} python train.py ${args} 41 | " 42 | -------------------------------------------------------------------------------- /submit_pm_dp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4 4 | #SBATCH -q regular 5 | #SBATCH --ntasks-per-node 4 6 | #SBATCH --cpus-per-task 32 7 | #SBATCH --gpus-per-node 4 8 | #SBATCH --time=01:00:00 9 | #SBATCH --image=nersc/pytorch:24.08.01 10 | #SBATCH --module=gpu,nccl-plugin 11 | #SBATCH --reservation=dlscale_training_2 12 | #SBATCH -J vit-era5 13 | #SBATCH -o %x-%j.out 14 | 15 | DATADIR=/pscratch/sd/s/shas1693/data/dl-at-scale-training-data 16 | LOGDIR=${SCRATCH}/dl-at-scale-training/logs 17 | mkdir -p ${LOGDIR} 18 | args="${@}" 19 | 20 | export FI_MR_CACHE_MONITOR=userfaultfd 21 | export HDF5_USE_FILE_LOCKING=FALSE 22 | 23 | # Profiling 24 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 25 | echo "Enabling profiling..." 26 | NSYS_ARGS="--trace=cuda,cublas,nvtx --kill none -c cudaProfilerApi -f true" 27 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 28 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 29 | fi 30 | 31 | export MASTER_ADDR=$(hostname) 32 | 33 | # Reversing order of GPUs to match default CPU affinities from Slurm 34 | export CUDA_VISIBLE_DEVICES=3,2,1,0 35 | 36 | set -x 37 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 38 | bash -c " 39 | source export_DDP_vars.sh 40 | ${PROFILE_CMD} python train.py ${args} 41 | " 42 | -------------------------------------------------------------------------------- /submit_pm_mp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -C gpu 3 | #SBATCH -A ntrain4 4 | #SBATCH -q regular 5 | #SBATCH --ntasks-per-node 4 6 | #SBATCH --cpus-per-task 32 7 | #SBATCH --gpus-per-node 4 8 | #SBATCH --time=01:00:00 9 | #SBATCH --image=nersc/pytorch:24.08.01 10 | #SBATCH --module=gpu,nccl-plugin 11 | #SBATCH --reservation=dlscale_training_2 12 | #SBATCH -J vit-era5-mp 13 | #SBATCH -o %x-%j.out 14 | 15 | DATADIR=/pscratch/sd/s/shas1693/data/dl-at-scale-training-data 16 | LOGDIR=${SCRATCH}/dl-at-scale-training/logs 17 | mkdir -p ${LOGDIR} 18 | args="${@}" 19 | 20 | export HDF5_USE_FILE_LOCKING=FALSE 21 | 22 | # Profiling 23 | if [ "${ENABLE_PROFILING:-0}" -eq 1 ]; then 24 | echo "Enabling profiling..." 
25 | NSYS_ARGS="--trace=cuda,cublas,nvtx --cuda-graph-trace=node --kill none -c cudaProfilerApi -f true" 26 | NSYS_OUTPUT=${LOGDIR}/${PROFILE_OUTPUT:-"profile"} 27 | export PROFILE_CMD="nsys profile $NSYS_ARGS -o $NSYS_OUTPUT" 28 | fi 29 | 30 | export MASTER_ADDR=$(hostname) 31 | 32 | # Reversing order of GPUs to match default CPU affinities from Slurm 33 | export CUDA_VISIBLE_DEVICES=3,2,1,0 34 | 35 | # if cuda graphs, use train_mp_graphs.py 36 | set -x 37 | srun -u shifter -V ${DATADIR}:/data -V ${LOGDIR}:/logs \ 38 | bash -c " 39 | source export_DDP_vars.sh 40 | ${PROFILE_CMD} python train_mp.py ${args} 41 | " 42 | -------------------------------------------------------------------------------- /test_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import get_data_loader_distributed 3 | import numpy as np 4 | from utils.YParams import YParams 5 | from networks.vit import ViT 6 | import matplotlib.pyplot as plt 7 | 8 | params = YParams('./config/ViT.yaml', 'short') 9 | params.global_batch_size = 1 10 | params.local_batch_size = 1 11 | 12 | valid_dataloader, dataset_valid = get_data_loader_distributed(params, params.valid_data_path, distributed=False, train=False) 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | params.device = device 16 | model = ViT(params) 17 | model = model.to(device) 18 | 19 | with torch.no_grad(): 20 | for i, data in enumerate(valid_dataloader, 0): 21 | if i >= 1: 22 | break 23 | print("Doing iteration {}".format(i)) 24 | inp, tar = map(lambda x: x.to(device, dtype = torch.float), data) 25 | print("input shape = {}".format(inp.shape)) 26 | print("target shape = {}".format(tar.shape)) 27 | plt.rcParams["figure.figsize"] = (20,20) 28 | plt.figure() 29 | for ch in range(inp.shape[1]): 30 | plt.subplot(inp.shape[1],1, ch+1) 31 | plt.imshow(inp[0,ch,:,:].cpu(), cmap = 'RdBu') 32 | plt.colorbar() 33 | plt.savefig("figs/minibatch_" + str(i) + ".jpg") 34 | gen = model(inp) 35 | print("prediction shape = {}".format(gen.shape)) 36 | 37 | -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | 5 | class YParams: 6 | """Yaml file parser""" 7 | 8 | def __init__(self, yaml_filename, config_name, print_params=False): 9 | self._yaml_filename = yaml_filename 10 | self._config_name = config_name 11 | self.params = {} 12 | 13 | with open(yaml_filename) as _file: 14 | 15 | for key, val in YAML().load(_file)[config_name].items(): 16 | if val == "None": 17 | val = None 18 | 19 | self.params[key] = val 20 | self.__setattr__(key, val) 21 | 22 | if print_params: 23 | self.log() 24 | 25 | def __getitem__(self, key): 26 | return self.params[key] 27 | 28 | def __setitem__(self, key, val): 29 | self.params[key] = val 30 | 31 | def get(self, key, default=None): 32 | """Get a parameter value""" 33 | if hasattr(self, key): 34 | return getattr(self, key) 35 | else: 36 | return self.params.get(key, default) 37 | 38 | def log(self): 39 | logging.info("------------------ Configuration ------------------") 40 | logging.info("Configuration file: " + str(self._yaml_filename)) 41 | logging.info("Configuration name: " + str(self._config_name)) 42 | for key, val in self.params.items(): 43 | logging.info(str(key) + " " + str(val)) 44 | logging.info("---------------------------------------------------") 45 | 46 
| def update(self, new_params): 47 | self.params.update(new_params) 48 | for key, val in new_params.items(): 49 | self.__setattr__(key, val) 50 | -------------------------------------------------------------------------------- /utils/check_rank_generator.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e9b95e95-3be0-49ea-ba2b-557c12dfdd70", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from rank_generator import RankGenerator" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "22e66d4c-5902-42ff-96dc-1194ce64dc55", 19 | "metadata": { 20 | "tags": [] 21 | }, 22 | "outputs": [ 23 | { 24 | "data": { 25 | "text/plain": [ 26 | "[[0, 4, 8, 12, 16, 20, 24, 28],\n", 27 | " [1, 5, 9, 13, 17, 21, 25, 29],\n", 28 | " [2, 6, 10, 14, 18, 22, 26, 30],\n", 29 | " [3, 7, 11, 15, 19, 23, 27, 31]]" 30 | ] 31 | }, 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "rg = RankGenerator(\n", 39 | " tp=2,\n", 40 | " dp=8,\n", 41 | " pp=1,\n", 42 | " cp=2,\n", 43 | " order='tp-cp-pp-dp',\n", 44 | " )\n", 45 | "rg.get_ranks('dp')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "99f3fb0f-4705-4931-8250-9a12523633a0", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [] 55 | } 56 | ], 57 | "metadata": { 58 | "kernelspec": { 59 | "display_name": "pytorch-2.3.1", 60 | "language": "python", 61 | "name": "pytorch-2.3.1" 62 | }, 63 | "language_info": { 64 | "codemirror_mode": { 65 | "name": "ipython", 66 | "version": 3 67 | }, 68 | "file_extension": ".py", 69 | "mimetype": "text/x-python", 70 | "name": "python", 71 | "nbconvert_exporter": "python", 72 | "pygments_lexer": "ipython3", 73 | "version": "3.11.9" 74 | } 75 | }, 76 | "nbformat": 4, 77 | "nbformat_minor": 5 78 | } 79 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | _format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 5 | 6 | 7 | def slurm_filter(record): 8 | return int(os.environ["SLURM_PROCID"]) == 0 9 | 10 | 11 | def config_logger(log_level=logging.INFO): 12 | logging.basicConfig(format=_format, level=log_level) 13 | root_logger = logging.getLogger() 14 | root_logger.addFilter(slurm_filter) 15 | 16 | 17 | def log_to_file( 18 | logger_name=None, log_level=logging.INFO, log_filename="tensorflow.log" 19 | ): 20 | 21 | if not os.path.exists(os.path.dirname(log_filename)): 22 | os.makedirs(os.path.dirname(log_filename)) 23 | 24 | if logger_name is not None: 25 | log = logging.getLogger(logger_name) 26 | else: 27 | log = logging.getLogger() 28 | 29 | fh = logging.FileHandler(log_filename) 30 | fh.setLevel(log_level) 31 | fh.setFormatter(logging.Formatter(_format)) 32 | log.addHandler(fh) 33 | 34 | 35 | def log_versions(): 36 | import torch 37 | import subprocess 38 | 39 | logging.info("--------------- Versions ---------------") 40 | logging.info( 41 | "git branch: " + str(subprocess.check_output(["git", "branch"]).strip()) 42 | ) 43 | logging.info( 44 | "git hash: " 45 | + str(subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()) 46 | ) 47 | logging.info("Torch: " + str(torch.__version__)) 48 | logging.info("----------------------------------------") 49 
| 50 | 51 | class disable_logging(object): 52 | """ 53 | A context manager to disable logging temporarily. 54 | """ 55 | 56 | def __init__(self, level=logging.ERROR): # pragma: no cover 57 | """ 58 | Initialize the context manager. 59 | """ 60 | logging.disable(level=level) 61 | 62 | def __enter__(self): # pragma: no cover 63 | """ 64 | Enter the context manager. 65 | """ 66 | return self 67 | 68 | def __exit__(self, type, value, traceback): # pragma: no cover 69 | """ 70 | Exit the context manager and enable logging. 71 | """ 72 | logging.disable(level=logging.NOTSET) 73 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | @torch.jit.script 6 | def lat(j: torch.Tensor, num_lat: int) -> torch.Tensor: 7 | return 90.0 - j * 180.0 / float(num_lat - 1) 8 | 9 | 10 | @torch.jit.script 11 | def latitude_weighting_factor( 12 | j: torch.Tensor, num_lat: int, s: torch.Tensor 13 | ) -> torch.Tensor: 14 | return num_lat * torch.cos(3.1416 / 180.0 * lat(j, num_lat)) / s 15 | 16 | 17 | @torch.jit.script 18 | def weighted_rmse_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 19 | # takes in arrays of size [n, c, h, w] and returns latitude-weighted rmse for each chann 20 | num_lat = pred.shape[2] 21 | # num_long = target.shape[2] 22 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device) 23 | 24 | s = torch.sum(torch.cos(3.1416 / 180.0 * lat(lat_t, num_lat))) 25 | weight = torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1)) 26 | result = torch.sqrt(torch.mean(weight * (pred - target) ** 2.0, dim=(-1, -2))) 27 | return result 28 | 29 | 30 | @torch.jit.script 31 | def weighted_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 32 | result = weighted_rmse_channels(pred, target) 33 | return torch.mean(result, dim=0) 34 | 35 | 36 | @torch.jit.script 37 | def weighted_acc_channels(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 38 | # takes in arrays of size [n, c, h, w] and returns latitude-weighted acc 39 | num_lat = pred.shape[2] 40 | # num_long = target.shape[2] 41 | lat_t = torch.arange(start=0, end=num_lat, device=pred.device) 42 | s = torch.sum(torch.cos(3.1416 / 180.0 * lat(lat_t, num_lat))) 43 | weight = torch.reshape(latitude_weighting_factor(lat_t, num_lat, s), (1, 1, -1, 1)) 44 | result = torch.sum(weight * pred * target, dim=(-1, -2)) / torch.sqrt( 45 | torch.sum(weight * pred * pred, dim=(-1, -2)) 46 | * torch.sum(weight * target * target, dim=(-1, -2)) 47 | ) 48 | return result 49 | 50 | 51 | @torch.jit.script 52 | def weighted_acc(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 53 | result = weighted_acc_channels(pred, target) 54 | return torch.mean(result, dim=0) 55 | -------------------------------------------------------------------------------- /start_tensorboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "314d73bf-204b-410b-9b58-4db9a461e1c4", 6 | "metadata": {}, 7 | "source": [ 8 | "# TensorBoard Launcher\n", 9 | "\n", 10 | "This notebook allows you to start TensorBoard on Perlmutter and view it in a normal browser tab.\n", 11 | "\n", 12 | "The notebook code below assumes you are using the hands-on tutorial path for tensorboard logs.\n", 13 | "\n", 14 | "When you run the cells below, TensorBoard will start but will 
not display here in the notebook. Instead, the final cell which calls `nersc_tensorboard_helper.tb_address()` will display a URL that you can click to open a new tab with TensorBoard." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "97bcd738-b1f5-40ed-acb0-2b2572f741e2", 21 | "metadata": { 22 | "tags": [] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import nersc_tensorboard_helper\n", 28 | "\n", 29 | "%load_ext tensorboard" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "2b4a192d-f5a7-4f8a-becb-ae52ed9cc036", 36 | "metadata": { 37 | "tags": [] 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "log_dir = os.path.expandvars('${SCRATCH}/dl-at-scale-training/logs')" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "4dd0750e-7f27-4a66-9fce-4be91ea88835", 48 | "metadata": { 49 | "tags": [] 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "%%capture\n", 54 | "%tensorboard --logdir $log_dir --port 0" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "7f4158e2-d17a-40f3-a6e5-961e2615e4c3", 61 | "metadata": { 62 | "tags": [] 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "nersc_tensorboard_helper.tb_address()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "0ad851ac-485a-436c-a88b-356aa0843641", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "tensorflow-2.15.0", 81 | "language": "python", 82 | "name": "tensorflow-2.15.0" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": "ipython3", 94 | "version": "3.9.18" 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 5 99 | } 100 | -------------------------------------------------------------------------------- /config/ViT.yaml: -------------------------------------------------------------------------------- 1 | base: &base 2 | 3 | # Model config 4 | embed_dim: 384 5 | depth: 12 6 | dropout: 0.0 7 | patch_size: 8 8 | num_heads: 8 9 | 10 | # Training config 11 | img_size: [360, 720] 12 | dt: 1 13 | global_batch_size: 16 # number of samples per training batch 14 | num_iters: 30000 15 | amp_mode: none 16 | enable_fused: false 17 | enable_jit: false 18 | expdir: '/logs' 19 | lr_schedule: 'cosine' 20 | lr: 5E-4 21 | warmup: 0 22 | optimizer: 'Adam' 23 | 24 | # Data 25 | data_loader_config: 'pytorch' 26 | num_data_workers: 0 # number of dataloader worker threads per proc 27 | n_in_channels: 20 28 | n_out_channels: 20 29 | train_data_path: '/data/train' 30 | valid_data_path: '/data/valid' 31 | inf_data_path: '/data/test' 32 | time_means_path: '/data/stats/time_means.npy' 33 | global_means_path: '/data/stats/global_means.npy' 34 | global_stds_path: '/data/stats/global_stds.npy' 35 | limit_nsamples: None 36 | limit_nsamples_val: None 37 | 38 | # Comms 39 | wireup_info: env 40 | wireup_store: tcp 41 | 42 | # limit the number of samples 43 | short: &short_ls 44 | <<: *base 45 | limit_nsamples: 512 46 | limit_nsamples_val: 128 47 | num_iters: 128 48 | 49 | # add optimization flags 50 | short_opt: 51 | <<: *short_ls 52 | data_loader_config: 'dali' 53 | num_data_workers: 8 54 | amp_mode: fp16 55 | enable_jit: true 56 | enable_fused: true 57 | 58 | # 
no samples limits 59 | opt: &opt 60 | <<: *base 61 | data_loader_config: 'dali' 62 | num_data_workers: 8 63 | amp_mode: fp16 64 | num_iters: 30000 65 | enable_jit: true 66 | enable_fused: true 67 | 68 | # ----- Data parallel scaling configs 69 | bs16_opt: 70 | <<: *opt 71 | global_batch_size: 16 72 | lr: 5e-4 73 | 74 | bs32_opt: 75 | <<: *opt 76 | global_batch_size: 32 77 | lr: 7.07e-4 78 | 79 | bs64_opt: 80 | <<: *opt 81 | global_batch_size: 64 82 | lr: 1e-3 83 | 84 | bs128_opt: 85 | <<: *opt 86 | global_batch_size: 128 87 | lr: 1.41e-3 88 | 89 | bs256_opt: 90 | <<: *opt 91 | global_batch_size: 256 92 | lr: 2e-3 93 | 94 | bs512_opt: 95 | <<: *opt 96 | global_batch_size: 512 97 | lr: 2.83e-3 98 | 99 | bs1024_opt: 100 | <<: *opt 101 | global_batch_size: 1024 102 | lr: 4e-3 103 | 104 | bs2048_opt: 105 | <<: *opt 106 | global_batch_size: 2048 107 | lr: 5.66e-3 108 | 109 | # Model parallel configs 110 | mp: &mp 111 | <<: *base 112 | num_iters: 30000 113 | global_batch_size: 64 114 | lr: 1e-3 115 | num_data_workers: 8 116 | embed_dim: 1024 # change to bigger model 117 | data_loader_config: 'dali' 118 | amp_mode: fp16 119 | enable_jit: false 120 | enable_fused: true 121 | 122 | mp_bs16: 123 | <<: *mp 124 | global_batch_size: 16 125 | lr: 5e-4 126 | 127 | mp_bs32: 128 | <<: *mp 129 | global_batch_size: 32 130 | lr: 7.07e-4 131 | 132 | # larger seq length (use local bs = 1 here) 133 | mp_patch4: 134 | <<: *mp 135 | patch_size: 4 136 | -------------------------------------------------------------------------------- /networks/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | import torch 6 | import warnings 7 | 8 | # These functions are directly pulled from timm: 9 | # https://github.com/huggingface/pytorch-image-models/tree/main/timm 10 | 11 | 12 | @torch.jit.script 13 | def drop_path( 14 | x: torch.Tensor, drop_prob: float = 0.0, training: bool = False 15 | ) -> torch.Tensor: 16 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 17 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 18 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 19 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 20 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 21 | 'survival rate' as the argument. 
22 | """ 23 | if drop_prob == 0.0 or not training: 24 | return x 25 | keep_prob = 1.0 - drop_prob 26 | shape = (x.shape[0],) + (1,) * ( 27 | x.ndim - 1 28 | ) # work with diff dim tensors, not just 2D ConvNets 29 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 30 | random_tensor.floor_() # binarize 31 | output = x.div(keep_prob) * random_tensor 32 | return output 33 | 34 | 35 | class DropPath(nn.Module): 36 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 37 | 38 | def __init__(self, drop_prob=None): 39 | super(DropPath, self).__init__() 40 | self.drop_prob = drop_prob 41 | 42 | def forward(self, x): 43 | return drop_path(x, self.drop_prob, self.training) 44 | 45 | 46 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 47 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 48 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 49 | def norm_cdf(x): # pragma: no cover 50 | # Computes standard normal cumulative distribution function 51 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 52 | 53 | if (mean < a - 2 * std) or (mean > b + 2 * std): 54 | warnings.warn( 55 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 56 | "The distribution of values may be incorrect.", 57 | stacklevel=2, 58 | ) 59 | 60 | with torch.no_grad(): 61 | # Values are generated by using a truncated uniform distribution and 62 | # then using the inverse CDF for the normal distribution. 63 | # Get upper and lower cdf values 64 | l = norm_cdf((a - mean) / std) 65 | u = norm_cdf((b - mean) / std) 66 | 67 | # Uniformly fill tensor with values from [l, u], then translate to 68 | # [2l-1, 2u-1]. 69 | tensor.uniform_(2 * l - 1, 2 * u - 1) 70 | 71 | # Use inverse cdf transform for normal distribution to get truncated 72 | # standard normal 73 | tensor.erfinv_() 74 | 75 | # Transform to proper mean, std 76 | tensor.mul_(std * math.sqrt(2.0)) 77 | tensor.add_(mean) 78 | 79 | # Clamp to ensure it's in the proper range 80 | tensor.clamp_(min=a, max=b) 81 | return tensor 82 | 83 | 84 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): # pragma: no cover 85 | r"""Fills the input Tensor with values drawn from a truncated 86 | normal distribution. The values are effectively drawn from the 87 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 88 | with values outside :math:`[a, b]` redrawn until they are within 89 | the bounds. The method used for generating the random values works 90 | best when :math:`a \leq \text{mean} \leq b`. 
91 | Args: 92 | tensor: an n-dimensional `torch.Tensor` 93 | mean: the mean of the normal distribution 94 | std: the standard deviation of the normal distribution 95 | a: the minimum cutoff value 96 | b: the maximum cutoff value 97 | """ 98 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import glob 4 | import torch 5 | import random 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Dataset 8 | from torch.utils.data.distributed import DistributedSampler 9 | from torch import Tensor 10 | import h5py 11 | 12 | 13 | def worker_init(wrk_id): 14 | np.random.seed(torch.utils.data.get_worker_info().seed % (2**32 - 1)) 15 | 16 | 17 | def get_data_loader(params, files_pattern, distributed, train): 18 | dataset = ERA5Dataset(params, files_pattern, train) 19 | 20 | if distributed: 21 | if hasattr(params, "data_num_shards"): 22 | # this is for model parallelism 23 | assert hasattr( 24 | params, "data_shard_id" 25 | ), "please set data_num_shards and data_shard_id" 26 | sampler = DistributedSampler( 27 | dataset, 28 | shuffle=train, 29 | num_replicas=params.data_num_shards, 30 | rank=params.data_shard_id, 31 | ) 32 | else: 33 | sampler = DistributedSampler(dataset, shuffle=train) 34 | else: 35 | sampler = None 36 | 37 | dataloader = DataLoader( 38 | dataset, 39 | batch_size=int(params.local_batch_size), 40 | num_workers=params.num_data_workers, 41 | shuffle=(sampler is None), 42 | sampler=sampler, 43 | worker_init_fn=worker_init, 44 | drop_last=True, 45 | # persistent_workers=train, 46 | pin_memory=torch.cuda.is_available(), 47 | ) 48 | 49 | if train: 50 | return dataloader, dataset, sampler 51 | else: 52 | return dataloader, dataset 53 | 54 | 55 | class ERA5Dataset(Dataset): 56 | def __init__(self, params, location, train): 57 | self.params = params 58 | self.location = location 59 | self.train = train 60 | self.dt = params.dt 61 | self.n_in_channels = params.n_in_channels 62 | self.n_out_channels = params.n_out_channels 63 | self.normalize = True 64 | self.means = np.load(params.global_means_path)[0] 65 | self.stds = np.load(params.global_stds_path)[0] 66 | self.limit_nsamples = ( 67 | params.limit_nsamples if train else params.limit_nsamples_val 68 | ) 69 | self._get_files_stats() 70 | 71 | def _get_files_stats(self): 72 | self.files_paths = glob.glob(self.location + "/*.h5") 73 | self.files_paths.sort() 74 | self.years = [ 75 | 
int(os.path.splitext(os.path.basename(x))[0][-4:]) for x in self.files_paths 76 | ] 77 | self.n_years = len(self.files_paths) 78 | 79 | with h5py.File(self.files_paths[0], "r") as _f: 80 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 81 | self.n_samples_per_year = _f["fields"].shape[0] 82 | self.img_shape_x = self.params.img_size[0] 83 | self.img_shape_y = self.params.img_size[1] 84 | assert ( 85 | self.img_shape_x <= _f["fields"].shape[2] 86 | and self.img_shape_y <= _f["fields"].shape[3] 87 | ), "image shapes are greater than dataset image shapes" 88 | 89 | self.n_samples_total = self.n_years * self.n_samples_per_year 90 | if self.limit_nsamples is not None: 91 | self.n_samples_total = min(self.n_samples_total, self.limit_nsamples) 92 | logging.info( 93 | "Overriding total number of samples to: {}".format(self.n_samples_total) 94 | ) 95 | self.files = [None for _ in range(self.n_years)] 96 | logging.info("Number of samples per year: {}".format(self.n_samples_per_year)) 97 | logging.info( 98 | "Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format( 99 | self.location, 100 | self.n_samples_total, 101 | self.img_shape_x, 102 | self.img_shape_y, 103 | self.n_in_channels, 104 | ) 105 | ) 106 | 107 | def _open_file(self, year_idx): 108 | _file = h5py.File(self.files_paths[year_idx], "r") 109 | self.files[year_idx] = _file["fields"] 110 | 111 | def __len__(self): 112 | return self.n_samples_total 113 | 114 | def _normalize(self, img): 115 | if self.normalize: 116 | img -= self.means 117 | img /= self.stds 118 | return torch.as_tensor(img) 119 | 120 | def __getitem__(self, global_idx): 121 | year_idx = int(global_idx / self.n_samples_per_year) # which year 122 | local_idx = int( 123 | global_idx % self.n_samples_per_year 124 | ) # which sample in that year 125 | 126 | # open image file 127 | if self.files[year_idx] is None: 128 | self._open_file(year_idx) 129 | step = self.dt # time step 130 | 131 | # boundary conditions to ensure we don't pull data that is not in a specific year 132 | local_idx = local_idx % (self.n_samples_per_year - step) 133 | if local_idx < step: 134 | local_idx += step 135 | 136 | # pre-process and get the image fields 137 | inp_field = self.files[year_idx][ 138 | local_idx, :, 0 : self.img_shape_x, 0 : self.img_shape_y 139 | ] 140 | tar_field = self.files[year_idx][ 141 | local_idx + step, :, 0 : self.img_shape_x, 0 : self.img_shape_y 142 | ] 143 | inp, tar = self._normalize(inp_field), self._normalize(tar_field) 144 | 145 | return inp, tar 146 | -------------------------------------------------------------------------------- /utils/comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from utils.logging_utils import disable_logging 4 | from utils.rank_generator import RankGenerator 5 | import torch 6 | import math 7 | import numpy as np 8 | import torch.distributed as dist 9 | import datetime as dt 10 | from typing import Union 11 | 12 | # dummy placeholder 13 | _COMM_GROUPS = {} 14 | 15 | 16 | # routines for specific comm groups 17 | def get_names(): 18 | """Returns the names of all available communicators.""" 19 | return _COMM_GROUPS.keys() 20 | 21 | 22 | def is_initialized(comm_name): 23 | """check if initialized.""" 24 | return comm_name in _COMM_GROUPS 25 | 26 | 27 | def get_group(comm_name): 28 | """Returns the group of a specified communicator.""" 29 | if not is_initialized(comm_name): 30 | raise IndexError(f"Error, comm {comm_name} not 
initialized.") 31 | return _COMM_GROUPS[comm_name] 32 | 33 | 34 | def get_size(comm_name): 35 | """Returns the size of a specified communicator.""" 36 | if (not dist.is_initialized()) or (not is_initialized(comm_name)): 37 | return 1 38 | else: 39 | return dist.get_world_size(group=get_group(comm_name)) 40 | 41 | 42 | def get_rank(comm_name): 43 | """Returns the rank in a specified communicator.""" 44 | if (not dist.is_initialized()) or (not is_initialized(comm_name)): 45 | return 0 46 | else: 47 | return dist.get_rank(group=get_group(comm_name)) 48 | 49 | 50 | # routines for world comms 51 | def get_world_size(): 52 | """Returns the world size""" 53 | if not dist.is_initialized(): 54 | return 1 55 | else: 56 | return dist.get_world_size() 57 | 58 | 59 | def get_world_rank(): 60 | """Returns the world rank""" 61 | if not dist.is_initialized(): 62 | return 0 63 | else: 64 | return dist.get_rank() 65 | 66 | 67 | def get_local_rank(): 68 | """Returns the local rank of the current process.""" 69 | if not dist.is_initialized(): 70 | return 0 71 | else: 72 | if os.getenv("LOCAL_RANK") is not None: 73 | # Use env var if available 74 | return int(os.getenv("LOCAL_RANK")) 75 | else: 76 | return get_world_rank() % torch.cuda.device_count() 77 | 78 | 79 | def init(params, verbose=False): 80 | # init torch.distributed 81 | init_process_group() 82 | 83 | # set model parallel sizes 84 | tp = params.get("tp", 1) 85 | cp = params.get("cp", 1) 86 | pp = params.get("pp", 1) 87 | assert pp == 1, "ERROR: pipeline parallel not implemented" 88 | model_parallel_size = tp * cp * pp 89 | dp = get_world_size() // model_parallel_size 90 | assert dp >= 1, "ERROR: data parallel wireup failed since dp = {}".format(dp) 91 | logging.info("Setting DP = {}, TP = {}, CP = {}, PP = {}".format(dp, tp, cp, pp)) 92 | 93 | # init model + dp groups individually 94 | init_model_parallel_info( 95 | tp=tp, 96 | cp=cp, 97 | dp=dp, 98 | pp=pp, 99 | order=params.get("order", "tp-dp"), 100 | verbose=verbose, 101 | ) 102 | 103 | 104 | def init_process_group(): 105 | """Initial torch distributed process group 106 | Uses NCCL 107 | """ 108 | world_size = int(os.getenv("WORLD_SIZE", 1)) 109 | world_rank = int(os.getenv("RANK", 0)) 110 | port = int(os.getenv("MASTER_PORT", 0)) 111 | master_address = os.getenv("MASTER_ADDR") 112 | local_rank = int(os.getenv("LOCAL_RANK", 0)) 113 | 114 | if world_size > 1: 115 | with disable_logging(): 116 | # create tcp store 117 | store = dist.TCPStore( 118 | host_name=master_address, 119 | port=port, 120 | world_size=world_size, 121 | is_master=(world_rank == 0), 122 | timeout=dt.timedelta(seconds=900), 123 | ) 124 | 125 | # initialize process groups 126 | dist.init_process_group( 127 | backend="nccl", rank=world_rank, world_size=world_size, store=store 128 | ) 129 | 130 | 131 | def init_model_parallel_info(tp=1, pp=1, dp=1, cp=1, order="tp-dp", verbose=False): 132 | 133 | world_size = get_world_size() 134 | world_rank = get_world_rank() 135 | 136 | rank_gen = RankGenerator( 137 | tp=tp, 138 | dp=dp, 139 | pp=pp, 140 | cp=cp, 141 | order=order, 142 | ) 143 | 144 | def generator_wrapper(group_type, **kwargs): 145 | """The `RankGenerator` class produces a hyper-rectangle for a given set of 146 | tensor, pipeline, data, and context parallelism. 
147 | """ 148 | ranks = rank_gen.get_ranks(group_type, **kwargs) 149 | for x in ranks: 150 | yield x 151 | 152 | # build the different parallel groups 153 | global _COMM_GROUPS # others need access to this 154 | groups_to_build = ["dp", "tp", "cp", "pp", "tp-cp", "dp-cp"] 155 | for grp in groups_to_build: 156 | for ranks in generator_wrapper(grp): 157 | group = dist.new_group(ranks) 158 | if world_rank in ranks: 159 | _COMM_GROUPS[grp] = group 160 | 161 | 162 | def process_comm_list(input_list): 163 | """Given a list of comms, merge them 164 | Ex: ['tp', 'cp'] is ['tp-cp'] 165 | """ 166 | if not input_list or all(item is None for item in input_list): 167 | return [] 168 | 169 | # filter out None values (ex: [None, 'tp] becomes ['tp']) 170 | filtered_list = [item for item in input_list if item is not None] 171 | 172 | if not filtered_list: 173 | return [] 174 | elif len(filtered_list) == 1: 175 | return filtered_list 176 | else: 177 | return ["-".join(filtered_list)] 178 | -------------------------------------------------------------------------------- /utils/data_loader_dali.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 5 | from torch import Tensor 6 | 7 | # concurrent futures 8 | import concurrent.futures as cf 9 | 10 | # distributed stuff 11 | import torch.distributed as dist 12 | from utils import comm 13 | 14 | # dali stuff 15 | from nvidia.dali.pipeline import Pipeline 16 | import nvidia.dali.fn as fn 17 | import nvidia.dali.types as types 18 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy 19 | 20 | # es helper 21 | import utils.dali_es_helper as esh 22 | 23 | 24 | def get_data_loader(params, files_pattern, distributed, train): 25 | dataloader = DaliDataLoader(params, files_pattern, train) 26 | 27 | if train: 28 | return dataloader, None, None 29 | else: 30 | return dataloader, None 31 | 32 | 33 | class DaliDataLoader(object): 34 | def get_pipeline(self): 35 | pipeline = Pipeline( 36 | batch_size=self.batch_size, 37 | num_threads=2, 38 | device_id=self.device_index, 39 | py_num_workers=self.num_data_workers, 40 | py_start_method="spawn", 41 | seed=self.global_seed, 42 | ) 43 | 44 | with pipeline: # get input and target 45 | # get input and target 46 | inp, tar = fn.external_source( 47 | source=esh.ERA5ES( 48 | self.location, 49 | self.train, 50 | self.batch_size, 51 | self.dt, 52 | self.img_size, 53 | self.n_in_channels, 54 | self.n_out_channels, 55 | self.num_shards, 56 | self.shard_id, 57 | self.limit_nsamples, 58 | enable_logging=False, 59 | seed=self.global_seed, 60 | ), 61 | num_outputs=2, 62 | layout=["CHW", "CHW"], 63 | batch=False, 64 | no_copy=True, 65 | parallel=True, 66 | ) 67 | 68 | # upload to GPU 69 | inp = inp.gpu() 70 | tar = tar.gpu() 71 | 72 | if self.normalize: 73 | inp = fn.normalize( 74 | inp, 75 | device="gpu", 76 | axis_names="HW", 77 | batch=False, 78 | mean=self.in_bias, 79 | stddev=self.in_scale, 80 | ) 81 | 82 | tar = fn.normalize( 83 | tar, 84 | device="gpu", 85 | axis_names="HW", 86 | batch=False, 87 | mean=self.out_bias, 88 | stddev=self.out_scale, 89 | ) 90 | 91 | pipeline.set_outputs(inp, tar) 92 | return pipeline 93 | 94 | def __init__(self, params, location, train, seed=333): 95 | # set up seeds 96 | # this one is the same on all ranks 97 | self.global_seed = seed 98 | # this one is the same for all ranks of the same model 99 | model_id = 
comm.get_world_rank() // comm.get_size("tp-cp-pp") 100 | self.model_seed = self.global_seed + model_id 101 | # this seed is supposed to be diffferent for every rank 102 | self.local_seed = self.global_seed + comm.get_world_rank() 103 | 104 | self.num_data_workers = params.num_data_workers 105 | self.device_index = torch.cuda.current_device() 106 | self.batch_size = int(params.local_batch_size) 107 | 108 | self.location = location 109 | self.train = train 110 | self.dt = params.dt 111 | self.n_in_channels = params.n_in_channels 112 | self.n_out_channels = params.n_out_channels 113 | self.img_size = params.img_size 114 | self.limit_nsamples = ( 115 | params.limit_nsamples if train else params.limit_nsamples_val 116 | ) 117 | 118 | # load stats 119 | self.normalize = True 120 | means = np.load(params.global_means_path)[0][: self.n_in_channels] 121 | stds = np.load(params.global_stds_path)[0][: self.n_in_channels] 122 | self.in_bias = means 123 | self.in_scale = stds 124 | means = np.load(params.global_means_path)[0][: self.n_out_channels] 125 | stds = np.load(params.global_stds_path)[0][: self.n_out_channels] 126 | self.out_bias = means 127 | self.out_scale = stds 128 | 129 | # set sharding 130 | if dist.is_initialized(): 131 | self.num_shards = params.data_num_shards 132 | self.shard_id = params.data_shard_id 133 | else: 134 | self.num_shards = 1 135 | self.shard_id = 0 136 | 137 | # get img source data 138 | extsource = esh.ERA5ES( 139 | self.location, 140 | self.train, 141 | self.batch_size, 142 | self.dt, 143 | self.img_size, 144 | self.n_in_channels, 145 | self.n_out_channels, 146 | self.num_shards, 147 | self.shard_id, 148 | self.limit_nsamples, 149 | seed=self.global_seed, 150 | ) 151 | self.num_batches = extsource.num_steps_per_epoch 152 | del extsource 153 | 154 | # create pipeline 155 | self.pipeline = self.get_pipeline() 156 | self.pipeline.start_py_workers() 157 | self.pipeline.build() 158 | 159 | # create iterator 160 | self.iterator = DALIGenericIterator( 161 | [self.pipeline], 162 | ["inp", "tar"], 163 | auto_reset=True, 164 | last_batch_policy=LastBatchPolicy.DROP, 165 | prepare_first_batch=True, 166 | ) 167 | 168 | def __len__(self): 169 | return self.num_batches 170 | 171 | def __iter__(self): 172 | # self.iterator.reset() 173 | for token in self.iterator: 174 | inp = token[0]["inp"] 175 | tar = token[0]["tar"] 176 | 177 | yield inp, tar 178 | -------------------------------------------------------------------------------- /distributed/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from utils import comm 4 | 5 | def init_params_for_shared_weights(model): 6 | """Helper routine to ensure shared weights are the same after initialization""" 7 | with torch.no_grad(): 8 | # distributed sync step 9 | for param in model.parameters(): 10 | if not hasattr(param, "is_shared_mp"): 11 | # all sharded weights manually annotate this field 12 | # if weight doesnt have annotation, then it is a shared weight 13 | # layers like patch-embed, decoder head, pos-embed are fully 14 | # shared (and not sharded) in this example 15 | param.is_shared_mp = ["tp-cp"] # only TP-CP implemented for now 16 | # careful about this stuff.. 
17 | param.mark_for_reduction = [] # not all params need special handling 18 | 19 | for comm_group in param.is_shared_mp: 20 | if comm.get_size(comm_group) > 1: 21 | tlist = [ 22 | torch.empty_like(param) 23 | for x in range(comm.get_size(comm_group)) 24 | ] 25 | tlist[comm.get_rank(comm_group)] = param 26 | # gather all weights in the comm group 27 | dist.all_gather(tlist, param, group=comm.get_group(comm_group)) 28 | # use weight of rank 0 29 | # important to use copy here otherwise the handle gets detaches from the optimizer 30 | param.copy_(tlist[0]) 31 | 32 | 33 | # distributed primitives 34 | # helper routine to compute uneven splitting in balanced way: 35 | def compute_split_shapes(size, num_chunks): 36 | # treat trivial case first 37 | if num_chunks == 1: 38 | return [size] 39 | 40 | # first, check if we can split using div-up to balance the load: 41 | chunk_size = (size + num_chunks - 1) // num_chunks 42 | last_chunk_size = max(0, size - chunk_size * (num_chunks - 1)) 43 | if last_chunk_size == 0: 44 | # in this case, the last shard would be empty, split with floor instead: 45 | chunk_size = size // num_chunks 46 | last_chunk_size = size - chunk_size * (num_chunks - 1) 47 | 48 | # generate sections list 49 | sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size] 50 | 51 | return sections 52 | 53 | 54 | def _reduce(input_, comm_name): 55 | """All-reduce the input tensor across model parallel group.""" 56 | # Bypass the function if we are using only 1 GPU or if 57 | # communicator is not initialized 58 | if comm.get_size(comm_name) == 1: 59 | return input_ 60 | 61 | # All-reduce. 62 | dist.all_reduce(input_.contiguous(), group=comm.get_group(comm_name)) 63 | 64 | return input_ 65 | 66 | 67 | def split_tensor_along_dim(tensor, dim, num_chunks): 68 | """Helper routine to split a tensor along a given dimension""" 69 | if dim >= tensor.dim(): # scattering from dim that doesnt exist 70 | raise ValueError( 71 | f"Error: Scattering along {dim} for a tensor of size {tensor.dim()}" 72 | ) 73 | if tensor.shape[dim] < num_chunks: 74 | raise ValueError( 75 | f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into {num_chunks} chunks" 76 | ) 77 | 78 | # get split 79 | sections = compute_split_shapes(tensor.shape[dim], num_chunks) 80 | tensor_list = list(torch.split(tensor, sections, dim=dim)) 81 | 82 | return tensor_list 83 | 84 | 85 | def _split(input_, dim_, comm_name): 86 | """Split the tensor along dim.""" 87 | # Bypass the function if we are using only 1 GPU or if 88 | # communicator is not initialized 89 | comm_size = comm.get_size(comm_name) 90 | if comm_size == 1: 91 | return input_ 92 | 93 | # Split along dimension. 94 | input_list = split_tensor_along_dim(input_, dim_, comm_size) 95 | 96 | # Note: torch.split does not create contiguous tensors by default. 97 | comm_rank = comm.get_rank(comm_name) 98 | output = input_list[comm_rank].contiguous() 99 | 100 | return output 101 | 102 | 103 | def _gather(input_, dim_, shapes_, comm_name): 104 | """ 105 | Gather tensors and concatinate along the dimension dim_. 
106 | """ 107 | comm_size = comm.get_size(comm_name) 108 | if (shapes_ is not None) and (len(shapes_) != comm_size): 109 | raise ValueError(f"Error: passed shapes of size not equal to {comm_size}") 110 | if dim_ >= input_.dim(): # gathering along dim that doesnt exist 111 | raise ValueError( 112 | f"Error: Gathering along {dim} for a tensor of size {tensor.dim()}" 113 | ) 114 | 115 | # Bypass the function if we are using only 1 GPU or if 116 | # communicator is not initialized 117 | if comm_size == 1: 118 | return input_ 119 | 120 | comm_rank = comm.get_rank(comm_name) 121 | input_ = input_.contiguous() 122 | input_shape = list(input_.shape) 123 | if shapes_ is not None: 124 | input_list = [] 125 | for src in range(comm_size): 126 | input_shape[dim_] = shapes_[src] 127 | input_list.append( 128 | torch.empty(input_shape, dtype=input_.dtype, device=input_.device) 129 | ) 130 | else: 131 | # assume equal shape on all ranks 132 | input_list = [torch.empty_like(input_) for _ in range(comm_size)] 133 | 134 | dist.all_gather(input_list, input_, group=comm.get_group(comm_name)) 135 | output = torch.cat(input_list, dim=dim_).contiguous() 136 | 137 | return output 138 | 139 | def _reduce_scatter(input_, dim_, comm_name): 140 | """ 141 | Reduces and scatters along dim_ 142 | """ 143 | comm_size = comm.get_size(comm_name) 144 | if dim_ >= input_.dim(): # RS along dim that doesnt exist 145 | raise ValueError( 146 | f"Error: Reduce-scatter along {dim} for a tensor of size {tensor.dim()}" 147 | ) 148 | 149 | # Bypass the function if we are using only 1 GPU or if 150 | # communicator is not initialized 151 | if comm_size == 1: 152 | return input_ 153 | 154 | comm_rank = comm.get_rank(comm_name) 155 | input_ = input_.contiguous() 156 | 157 | # Split along dimension. Make sure the individual tensors are contiguous! 
158 | input_list = [ 159 | t.contiguous() for t in split_tensor_along_dim(input_, dim_, comm_size) 160 | ] 161 | 162 | output = torch.empty_like(input_list[comm_rank].contiguous()) 163 | dist.reduce_scatter(output, input_list, group=comm.get_group(comm_name)) 164 | 165 | return output 166 | -------------------------------------------------------------------------------- /utils/dali_es_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import cupyx as cpx 5 | import h5py 6 | import logging 7 | 8 | 9 | class ERA5ES(object): 10 | # very important: the seed has to be constant across the workers, or otherwise mayhem: 11 | def __init__( 12 | self, 13 | location, 14 | train, 15 | batch_size, 16 | dt, 17 | img_size, 18 | n_in_channels, 19 | n_out_channels, 20 | num_shards, 21 | shard_id, 22 | limit_nsamples, 23 | enable_logging=True, 24 | seed=333, 25 | ): 26 | self.batch_size = batch_size 27 | self.location = location 28 | self.img_size = img_size 29 | self.train = train 30 | self.dt = dt 31 | self.n_in_channels = n_in_channels 32 | self.n_out_channels = n_out_channels 33 | self.rng = np.random.default_rng(seed=seed) 34 | self.num_shards = num_shards 35 | self.shard_id = shard_id 36 | self.limit_nsamples = limit_nsamples 37 | 38 | self._get_files_stats(enable_logging) 39 | self.shuffle = True if train else False 40 | 41 | def _get_files_stats(self, enable_logging): 42 | self.files_paths = glob.glob(self.location + "/*.h5") 43 | self.files_paths.sort() 44 | self.years = [ 45 | int(os.path.splitext(os.path.basename(x))[0][-4:]) for x in self.files_paths 46 | ] 47 | self.n_years = len(self.files_paths) 48 | 49 | with h5py.File(self.files_paths[0], "r") as _f: 50 | logging.info("Getting file stats from {}".format(self.files_paths[0])) 51 | self.n_samples_per_year = _f["fields"].shape[0] 52 | self.img_shape_x = self.img_size[0] 53 | self.img_shape_y = self.img_size[1] 54 | assert ( 55 | self.img_shape_x <= _f["fields"].shape[2] 56 | and self.img_shape_y <= _f["fields"].shape[3] 57 | ), "image shapes are greater than dataset image shapes" 58 | 59 | self.n_samples_total = self.n_years * self.n_samples_per_year 60 | if self.limit_nsamples is not None: 61 | self.n_samples_total = min(self.n_samples_total, self.limit_nsamples) 62 | logging.info( 63 | "Overriding total number of samples to: {}".format(self.n_samples_total) 64 | ) 65 | self.n_samples_shard = self.n_samples_total // self.num_shards 66 | self.files = [None for _ in range(self.n_years)] 67 | self.dsets = [None for _ in range(self.n_years)] 68 | if enable_logging: 69 | logging.info( 70 | "Number of samples per year: {}".format(self.n_samples_per_year) 71 | ) 72 | logging.info( 73 | "Found data at path {}. Number of examples: {}. 
Image Shape: {} x {} x {}".format( 74 | self.location, 75 | self.n_samples_total, 76 | self.img_shape_x, 77 | self.img_shape_y, 78 | self.n_in_channels, 79 | ) 80 | ) 81 | if self.num_shards > 1: 82 | logging.info( 83 | "Using shards of size {} per rank".format(self.n_samples_shard) 84 | ) 85 | 86 | # number of steps per epoch 87 | self.num_steps_per_epoch = self.n_samples_shard // self.batch_size 88 | self.last_epoch = None 89 | 90 | self.index_permutation = None 91 | # prepare buffers for double buffering 92 | self.current_buffer = 0 93 | self.inp_buffs = [ 94 | cpx.zeros_pinned( 95 | (self.n_in_channels, self.img_shape_x, self.img_shape_y), 96 | dtype=np.float32, 97 | ), 98 | cpx.zeros_pinned( 99 | (self.n_in_channels, self.img_shape_x, self.img_shape_y), 100 | dtype=np.float32, 101 | ), 102 | ] 103 | self.tar_buffs = [ 104 | cpx.zeros_pinned( 105 | (self.n_out_channels, self.img_shape_x, self.img_shape_y), 106 | dtype=np.float32, 107 | ), 108 | cpx.zeros_pinned( 109 | (self.n_out_channels, self.img_shape_x, self.img_shape_y), 110 | dtype=np.float32, 111 | ), 112 | ] 113 | 114 | def __len__(self): 115 | return self.n_samples_shard 116 | 117 | def __del__(self): 118 | for f in self.files: 119 | if f is not None: 120 | f.close() 121 | 122 | def __call__(self, sample_info): 123 | # check if epoch is done 124 | if sample_info.iteration >= self.num_steps_per_epoch: 125 | raise StopIteration 126 | 127 | # check if we need to shuffle again 128 | if sample_info.epoch_idx != self.last_epoch: 129 | self.last_epoch = sample_info.epoch_idx 130 | if self.shuffle: 131 | self.index_permutation = self.rng.permutation(self.n_samples_total) 132 | else: 133 | self.index_permutation = np.arange(self.n_samples_total) 134 | # shard the data 135 | start = self.n_samples_shard * self.shard_id 136 | end = start + self.n_samples_shard 137 | self.index_permutation = self.index_permutation[start:end] 138 | 139 | # determine local and sample idx 140 | sample_idx = self.index_permutation[sample_info.idx_in_epoch] 141 | year_idx = int(sample_idx / self.n_samples_per_year) # which year we are on 142 | local_idx = int( 143 | sample_idx % self.n_samples_per_year 144 | ) # which sample in that year we are on - determines indices for centering 145 | 146 | step = self.dt # time step 147 | 148 | # boundary conditions to ensure we don't pull data that is not in a specific year 149 | local_idx = local_idx % (self.n_samples_per_year - step) 150 | if local_idx < step: 151 | local_idx += step 152 | 153 | if self.files[year_idx] is None: 154 | self.files[year_idx] = h5py.File(self.files_paths[year_idx], "r") 155 | self.dsets[year_idx] = self.files[year_idx]["fields"] 156 | 157 | tmp_inp = self.dsets[year_idx][local_idx, ...] 158 | tmp_tar = self.dsets[year_idx][local_idx + step, ...] 159 | 160 | # handles to buffers buffers 161 | inp = self.inp_buffs[self.current_buffer] 162 | tar = self.tar_buffs[self.current_buffer] 163 | self.current_buffer = (self.current_buffer + 1) % 2 164 | 165 | # crop the pixels: 166 | inp[...] = tmp_inp[..., : self.img_shape_x, : self.img_shape_y] 167 | tar[...] 
= tmp_tar[..., : self.img_shape_x, : self.img_shape_y] 168 | 169 | return inp, tar 170 | -------------------------------------------------------------------------------- /utils/rank_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Callable, List, Optional 4 | 5 | """ 6 | Utility to generate hyperrectangle of process groups for any general model parallelism 7 | Taken from MegatronLM: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py 8 | """ 9 | 10 | 11 | def generate_masked_orthogonal_rank_groups( 12 | world_size: int, parallel_size: List[int], mask: List[bool] 13 | ) -> List[List[int]]: 14 | """Generate orthogonal parallel groups based on the parallel size and mask. 15 | 16 | Arguments: 17 | world_size (int): world size 18 | 19 | parallel_size (List[int]): 20 | The parallel size of each orthogonal parallel type. For example, if 21 | tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, 22 | and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. 23 | 24 | mask (List[bool]): 25 | The mask controls which parallel methods the generated groups represent. If mask[i] is 26 | True, it means the generated group contains the i-th parallelism method. For example, 27 | if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then 28 | the generated group is the `tp-dp` group, if the mask = [False, True, False], then the 29 | generated group is the `pp` group. 30 | 31 | Algorithm: 32 | For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and 33 | local_rank satisfy the following equation: 34 | global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1) 35 | tp_rank \in [0, tp_size) 36 | dp_rank \in [0, dp_size) 37 | pp_rank \in [0, pp_size) 38 | 39 | If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each. 40 | For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the 41 | dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].) 42 | The tp_rank and pp_rank will be combined to form the `dp_group_index`. 43 | dp_group_index = tp_rank + pp_rank * tp_size (2) 44 | 45 | So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in 46 | range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the 47 | equation (1). 48 | 49 | This function solve this math problem. 50 | 51 | For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], 52 | and the mask = [False, True, False]. Then, 53 | dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2 54 | dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2 55 | ... 56 | dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2 57 | 58 | dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4] 59 | dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5] 60 | ... 61 | dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23] 62 | """ 63 | 64 | def prefix_product(a: List[int], init=1) -> List[int]: 65 | r = [init] 66 | for v in a: 67 | init = init * v 68 | r.append(init) 69 | return r 70 | 71 | def inner_product(a: List[int], b: List[int]) -> int: 72 | return sum([x * y for x, y in zip(a, b)]) 73 | 74 | def decompose(index, shape, stride=None): 75 | """ 76 | This function solve the math problem below: 77 | There is an equation: 78 | index = sum(idx[i] * stride[i]) 79 | And given the value of index, stride. 80 | Return the idx. 
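        For example, with shape = [2, 3, 4] the prefix-product strides are
        [1, 2, 6], and index = 11 decomposes to idx = [1, 2, 1], since
        1*1 + 2*2 + 1*6 = 11.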
81 | This function will used to get the pp/dp/pp_rank 82 | from group_index and rank_in_group. 83 | """ 84 | if stride is None: 85 | stride = prefix_product(shape) 86 | idx = [(index // d) % s for s, d in zip(shape, stride)] 87 | # stride is a prefix_product result. And the value of stride[-1] 88 | # is not used. 89 | assert ( 90 | sum([x * y for x, y in zip(idx, stride[:-1])]) == index 91 | ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) 92 | return idx 93 | 94 | masked_shape = [s for s, m in zip(parallel_size, mask) if m] 95 | unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] 96 | 97 | global_stride = prefix_product(parallel_size) 98 | masked_stride = [d for d, m in zip(global_stride, mask) if m] 99 | unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] 100 | 101 | group_size = prefix_product(masked_shape)[-1] 102 | num_of_group = world_size // group_size 103 | 104 | ranks = [] 105 | for group_index in range(num_of_group): 106 | # get indices from unmasked for group_index. 107 | decomposed_group_idx = decompose(group_index, unmasked_shape) 108 | rank = [] 109 | for rank_in_group in range(group_size): 110 | # get indices from masked for rank_in_group. 111 | decomposed_rank_idx = decompose(rank_in_group, masked_shape) 112 | rank.append( 113 | inner_product(decomposed_rank_idx, masked_stride) 114 | + inner_product(decomposed_group_idx, unmasked_stride) 115 | ) 116 | ranks.append(rank) 117 | return ranks 118 | 119 | 120 | class RankGenerator(object): 121 | """A class for generating rank groups for different modes of parallelism.""" 122 | 123 | def __init__( 124 | self, tp: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0 125 | ) -> None: 126 | self.tp = tp 127 | self.dp = dp 128 | self.pp = pp 129 | self.cp = cp 130 | self.rank_offset = rank_offset 131 | self.world_size = tp * dp * pp * cp 132 | 133 | self.name_to_size = { 134 | "tp": self.tp, 135 | "pp": self.pp, 136 | "dp": self.dp, 137 | "cp": self.cp, 138 | } 139 | self.order = order 140 | order = order.lower() 141 | 142 | for name in self.name_to_size.keys(): 143 | if name not in order and self.name_to_size[name] != 1: 144 | raise RuntimeError( 145 | f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't" 146 | f"specified the order ({self.order})." 147 | ) 148 | elif name not in order: 149 | order = order + "-" + name 150 | 151 | self.order = order 152 | self.ordered_size = [] 153 | self.ordered_size_w_ep = [] 154 | 155 | for token in order.split("-"): 156 | self.ordered_size.append(self.name_to_size[token]) 157 | 158 | def get_mask(self, order: str, token: str): 159 | """Create a mask for the specified tokens based on the given order. 160 | 161 | Args: 162 | order (str): The order of parallelism types (e.g., 'tp-dp-pp'). 163 | token (str): The specific parallelism types to include in the mask, 164 | separated by hyphens (e.g., 'tp-dp'). 165 | """ 166 | ordered_token = order.split("-") 167 | token = token.split("-") 168 | mask = [False] * len(ordered_token) 169 | for t in token: 170 | mask[ordered_token.index(t)] = True 171 | return mask 172 | 173 | def get_ranks(self, token): 174 | """Get rank group by input token. 175 | 176 | Args: 177 | token (str): 178 | Specify the ranks type that want to get. If we want 179 | to obtain multiple parallel types, we can use a hyphen 180 | '-' to separate them. For example, if we want to obtain 181 | the TP_DP group, the token should be 'tp-dp'. 
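        For example, with tp=2, cp=1, dp=2, pp=1 and order 'tp-cp-dp-pp',
        get_ranks('tp') returns [[0, 1], [2, 3]] while get_ranks('dp')
        returns [[0, 2], [1, 3]].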
182 | """ 183 | parallel_size = self.ordered_size 184 | order = self.order 185 | mask = self.get_mask(order, token) 186 | ranks = generate_masked_orthogonal_rank_groups( 187 | self.world_size, parallel_size, mask 188 | ) 189 | if self.rank_offset > 0: 190 | for rank_group in ranks: 191 | for i in range(len(rank_group)): 192 | rank_group[i] += self.rank_offset 193 | return ranks 194 | -------------------------------------------------------------------------------- /distributed/mappings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.parallel import DistributedDataParallel 4 | from utils import comm 5 | from functools import partial 6 | 7 | # torch utils 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | # helper functions 11 | from distributed.helpers import ( 12 | _reduce, 13 | _split, 14 | _gather, 15 | _reduce_scatter, 16 | compute_split_shapes, 17 | ) 18 | 19 | 20 | class _CopyToParallelRegion(torch.autograd.Function): 21 | """Pass the input to the parallel region.""" 22 | 23 | @staticmethod 24 | def symbolic(graph, input_, comm_name_): 25 | """symbolic method""" 26 | return input_ 27 | 28 | @staticmethod 29 | def forward(ctx, input_, comm_name_): 30 | ctx.comm_name = comm_name_ 31 | return input_ 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | return _reduce(grad_output, comm_name=ctx.comm_name), None 36 | 37 | 38 | class _ReduceFromParallelRegion(torch.autograd.Function): 39 | """All-reduce the input from the parallel region.""" 40 | 41 | @staticmethod 42 | def symbolic(graph, input_, comm_name_): 43 | """symbolic method""" 44 | return _reduce(input_, comm_name=comm_name_) 45 | 46 | @staticmethod 47 | def forward(ctx, input_, comm_name_): 48 | return _reduce(input_, comm_name=comm_name_) 49 | 50 | @staticmethod 51 | def backward(ctx, grad_output): 52 | return grad_output, None 53 | 54 | 55 | class _GatherFromParallelRegion(torch.autograd.Function): 56 | """Gather the input and keep it on the rank.""" 57 | 58 | @staticmethod 59 | def symbolic(graph, input_, dim_, shapes_, comm_name_): 60 | return _gather(input_, dim_, shapes_, comm_name_) 61 | 62 | @staticmethod 63 | def forward(ctx, input_, dim_, shapes_, comm_name_): 64 | ctx.dim = dim_ 65 | ctx.comm_name = comm_name_ 66 | return _gather(input_, dim_, shapes_, comm_name_) 67 | 68 | @staticmethod 69 | def backward(ctx, grad_output): 70 | return _split(grad_output, ctx.dim, ctx.comm_name), None, None, None 71 | 72 | 73 | class _ScatterToParallelRegion(torch.autograd.Function): 74 | """Split the input and keep only the corresponding chunk to the rank.""" 75 | 76 | @staticmethod 77 | def symbolic(graph, input_, dim_, comm_name_): 78 | return _split(input_, dim_, comm_name_) 79 | 80 | @staticmethod 81 | def forward(ctx, input_, dim_, comm_name_): 82 | ctx.dim = dim_ 83 | ctx.comm_name = comm_name_ 84 | ctx.split_shapes = compute_split_shapes( 85 | input_.shape[dim_], comm.get_size(comm_name_) 86 | ) 87 | return _split(input_, dim_, comm_name_) 88 | 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | return ( 92 | _gather(grad_output, ctx.dim, ctx.split_shapes, ctx.comm_name), 93 | None, 94 | None, 95 | ) 96 | 97 | 98 | class _ReduceScatterToParallelRegion(torch.autograd.Function): 99 | """Reduce the inputs and scatter to ranks.""" 100 | 101 | @staticmethod 102 | def symbolic(graph, input_, dim_, shapes_, comm_name_): 103 | return _reduce_scatter(input_, dim_, shapes_, 
comm_name_) 104 | 105 | @staticmethod 106 | def forward(ctx, input_, dim_, comm_name_): 107 | ctx.dim = dim_ 108 | ctx.comm_name = comm_name_ 109 | ctx.split_shapes = compute_split_shapes( 110 | input_.shape[dim_], comm.get_size(comm_name_) 111 | ) 112 | return _reduce_scatter(input_, dim_, comm_name_) 113 | 114 | @staticmethod 115 | def backward(ctx, grad_output): 116 | return ( 117 | _gather(grad_output, ctx.dim, ctx.split_shapes, ctx.comm_name), 118 | None, 119 | None, 120 | ) 121 | 122 | 123 | class _AllGatherFromParallelRegion(torch.autograd.Function): 124 | """Reduce the inputs and scatter to ranks.""" 125 | 126 | @staticmethod 127 | def symbolic(graph, input_, dim_, shapes_, comm_name_): 128 | return _gather(input_, dim_, shapes_, comm_name_) 129 | 130 | @staticmethod 131 | def forward(ctx, input_, dim_, shapes_, comm_name_): 132 | ctx.dim = dim_ 133 | ctx.comm_name = comm_name_ 134 | return _gather(input_, dim_, shapes_, comm_name_) 135 | 136 | @staticmethod 137 | def backward(ctx, grad_output): 138 | return _reduce_scatter(grad_output, ctx.dim, ctx.comm_name), None, None, None 139 | 140 | 141 | # matmul parallel 142 | @torch.compiler.disable 143 | def copy_to_parallel_region(input_, comm_name): 144 | """Parallel copy helper""" 145 | return _CopyToParallelRegion.apply(input_, comm_name) 146 | 147 | 148 | @torch.compiler.disable 149 | def reduce_from_parallel_region(input_, comm_name): 150 | """Parallel reduction helper""" 151 | return _ReduceFromParallelRegion.apply(input_, comm_name) 152 | 153 | 154 | @torch.compiler.disable 155 | def gather_from_parallel_region(input_, dim, shapes, comm_name): 156 | """Parallel gather helper""" 157 | return _GatherFromParallelRegion.apply(input_, dim, shapes, comm_name) 158 | 159 | 160 | @torch.compiler.disable 161 | def all_gather_from_parallel_region(input_, dim, shapes, comm_name): 162 | """ 163 | Parallel allgather helper that combines reduce-scatter 164 | in the bwd pass 165 | """ 166 | return _AllGatherFromParallelRegion.apply(input_, dim, shapes, comm_name) 167 | 168 | 169 | @torch.compiler.disable 170 | def reduce_scatter_to_parallel_region(input_, dim, shapes, comm_name): 171 | """Parallel reduce scatter helper""" 172 | return _ReduceScatterToParallelRegion.apply(input_, dim, shapes, comm_name) 173 | 174 | 175 | @torch.compiler.disable 176 | def scatter_to_parallel_region(input_, dim, comm_name): 177 | """Parallel scatter helper""" 178 | return _ScatterToParallelRegion.apply(input_, dim, comm_name) 179 | 180 | 181 | def init_ddp_model_and_reduction_hooks( 182 | model, 183 | device_ids, 184 | output_device, 185 | bucket_cap_mb=25, 186 | broadcast_buffers=True, 187 | find_unused_parameters=False, 188 | gradient_as_bucket_view=True, 189 | static_graph=False, 190 | ): 191 | # early exit if we are not in a distributed setting: 192 | if not dist.is_initialized(): 193 | return model 194 | 195 | need_hooks = False 196 | if comm.get_size("tp-cp") == 1: 197 | # no model parallel, just use DDP with 198 | # the full world size 199 | ddp_group = None 200 | elif comm.get_size("cp") == 1: 201 | # only cp requires additional allreduce 202 | # if no cp, use DDP 203 | ddp_group = comm.get_group("dp") 204 | else: 205 | broadcast_buffers = False 206 | ddp_group = comm.get_group("dp") 207 | need_hooks = True # need a grad hook for additional reduce 208 | 209 | model = DistributedDataParallel( 210 | model, 211 | device_ids=device_ids, 212 | output_device=output_device, 213 | bucket_cap_mb=bucket_cap_mb, 214 | broadcast_buffers=broadcast_buffers, 215 | 
find_unused_parameters=find_unused_parameters, 216 | gradient_as_bucket_view=gradient_as_bucket_view, 217 | static_graph=static_graph, 218 | process_group=ddp_group, 219 | ) 220 | if not need_hooks: 221 | return model 222 | 223 | # define comm hook because some params need additional allreduce 224 | def reduction_comm_hook( 225 | state: object, bucket: dist.GradBucket 226 | ) -> torch.futures.Future[torch.Tensor]: 227 | # allreduce everything first 228 | buff = bucket.buffer() 229 | # get future for allreduce 230 | 231 | # do the normal DDP all reduce 232 | fut = dist.all_reduce( 233 | buff, op=dist.ReduceOp.AVG, group=comm.get_group("dp"), async_op=True 234 | ).get_future() 235 | 236 | # get grads for shared weights 237 | params = bucket.parameters() 238 | 239 | def grad_reduction(fut, grads, group): 240 | # reduce remaining gradients 241 | coalesced = _flatten_dense_tensors(grads) 242 | # extra allreduce for param wgrads that need it 243 | dist.all_reduce( 244 | coalesced, 245 | op=dist.ReduceOp.SUM, 246 | group=comm.get_group(group), 247 | async_op=False, 248 | ) 249 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 250 | buf.copy_(synced) 251 | return bucket.buffer() 252 | 253 | append_hooks = False 254 | for group in comm.get_names(): 255 | if group == "dp": 256 | continue 257 | grads = [] 258 | for p in params: 259 | # p needs an allreduce in group 260 | if group in p.mark_for_reduction: 261 | if p.grad is not None: 262 | grads.append(p.grad.data) 263 | if not grads: 264 | continue 265 | # append the new reduction functions 266 | append_hooks = True 267 | fut = fut.then(partial(grad_reduction, grads=grads, group=group)) 268 | 269 | if not append_hooks: 270 | # this bucket's params only needed the DP allreduce 271 | # return the bucket directly 272 | return fut.then(lambda fut: fut.value()[0]) 273 | else: 274 | # got some additional allreduce chained to fut 275 | # the grad_reduction will return the bucket 276 | return fut 277 | 278 | # register model comm hook 279 | model.register_comm_hook(state=None, hook=reduction_comm_hook) 280 | return model 281 | -------------------------------------------------------------------------------- /distributed/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import comm 5 | 6 | from torch.cuda import amp 7 | 8 | from networks.helpers import trunc_normal_ 9 | 10 | # matmul parallel 11 | from distributed.mappings import ( 12 | copy_to_parallel_region, 13 | gather_from_parallel_region, 14 | all_gather_from_parallel_region, 15 | reduce_from_parallel_region, 16 | reduce_scatter_to_parallel_region, 17 | ) 18 | from typing import Tuple 19 | 20 | 21 | class DistributedMatmul(nn.Module): 22 | """Distributed Matrix Multiply 23 | Y = XW 24 | W is sharded in a 1D fashion: either row or col parallel 25 | W is a (in_dim, out_dim) size matrix when unsharded 26 | So shape of W is either (in_dim/n, out_dim) or (in_dim, out_dim/n) 27 | X is assumed sharded similarly to match the dimensions 28 | comm_act_name is an orthogonal comm used for sharding the activation 29 | X using m procs (batch_seq/m, in_dim) 30 | """ 31 | 32 | def __init__( 33 | self, 34 | inp_dim, 35 | out_dim, 36 | comm_inp_name, 37 | comm_out_name, 38 | comm_act_name="cp", 39 | bias=True, 40 | ): 41 | super(DistributedMatmul, self).__init__() 42 | 43 | # get sizes 44 | self.comm_inp_name = comm_inp_name 45 | self.comm_out_name = 
comm_out_name 46 | comm_inp_size = comm.get_size(self.comm_inp_name) 47 | comm_out_size = comm.get_size(self.comm_out_name) 48 | 49 | assert not ( 50 | comm_inp_size > 1 and comm_out_size > 1 51 | ), "Error, weights are sharded in a 2D fashion, not supported currently" 52 | assert ( 53 | inp_dim % comm_inp_size == 0 54 | ), f"Error, the size of input feature dim ({inp_dim}) has to be evenly divisible by the input feature comm dim ({comm_inp_size})" 55 | assert ( 56 | out_dim % comm_out_size == 0 57 | ), f"Error, the size of output feature dim ({out_dim}) has to be evenly divisible by the output feature comm dim ({comm_out_size})" 58 | 59 | # compute reduced dims 60 | inp_dim_local = inp_dim // comm_inp_size 61 | out_dim_local = out_dim // comm_out_size 62 | 63 | # parameters 64 | self.weight = nn.Parameter(torch.ones(out_dim_local, inp_dim_local)) 65 | self.weight.is_shared_mp = [ 66 | comm_act_name 67 | ] # weights are sharded in tp but shared across cp 68 | self.weight.mark_for_reduction = [ 69 | comm_act_name 70 | ] # shared weights must be additionally reduced 71 | if bias: 72 | self.bias = nn.Parameter(torch.ones(1, 1, out_dim_local)) 73 | # if inp dim of W is sharded, then the bias is shared across this group and also 74 | # shared in cp grp 75 | self.bias.is_shared_mp = [self.comm_inp_name, comm_act_name] 76 | self.bias.mark_for_reduction = [ 77 | comm_act_name 78 | ] # shared bias must be additionally reduced 79 | 80 | # init weights 81 | self._init_weights() 82 | 83 | def _init_weights(self): 84 | trunc_normal_(self.weight, std=0.02) 85 | if hasattr(self, "bias"): 86 | nn.init.constant_(self.bias, 0.0) 87 | 88 | def forward(self, x): 89 | x_cp = copy_to_parallel_region(x, self.comm_out_name) 90 | # don't add bias (else allreduce will add it too often) 91 | x_loc = F.linear(x_cp, self.weight, bias=None) 92 | x_out = reduce_from_parallel_region(x_loc, self.comm_inp_name) 93 | if hasattr(self, "bias"): 94 | x_out = x_out + self.bias 95 | return x_out 96 | 97 | 98 | class DistributedMLP(nn.Module): 99 | """Distributed MLP layer 100 | Currently implements 1D tensor parallelism 101 | """ 102 | 103 | def __init__( 104 | self, 105 | in_features, 106 | hidden_features=None, 107 | out_features=None, 108 | comm_tp_name="tp", 109 | comm_cp_name="cp", 110 | act_layer=nn.GELU, 111 | drop=0.0, 112 | ): 113 | 114 | super(DistributedMLP, self).__init__() 115 | out_features = out_features or in_features 116 | hidden_features = hidden_features or in_features 117 | 118 | self.fc1 = DistributedMatmul( 119 | in_features, 120 | hidden_features, 121 | comm_inp_name=None, 122 | comm_out_name=comm_tp_name, 123 | comm_act_name=comm_cp_name, 124 | bias=True, 125 | ) 126 | 127 | self.fc2 = DistributedMatmul( 128 | hidden_features, 129 | out_features, 130 | comm_inp_name=comm_tp_name, 131 | comm_out_name=None, 132 | comm_act_name=comm_cp_name, 133 | bias=True, 134 | ) 135 | 136 | self.act = act_layer() 137 | self.drop = nn.Dropout(drop) 138 | 139 | def forward(self, x): 140 | x = self.fc1(x) 141 | x = self.act(x) 142 | x = self.drop(x) 143 | x = self.fc2(x) 144 | x = self.drop(x) 145 | return x 146 | 147 | 148 | class DistributedAttention(nn.Module): 149 | """Distributed Attention layer""" 150 | 151 | def __init__( 152 | self, 153 | dim, 154 | comm_tp_name="tp", 155 | comm_cp_name="cp", 156 | cp_shapes=None, 157 | num_heads=8, 158 | qkv_bias=False, 159 | attn_drop=0.0, 160 | proj_drop=0.0, 161 | ): 162 | 163 | super(DistributedAttention, self).__init__() 164 | 165 | assert dim % num_heads == 0, "dim should be 
divisible by num_heads" 166 | 167 | self.num_heads = num_heads 168 | assert ( 169 | num_heads % comm.get_size(comm_tp_name) == 0 170 | ), "heads are not evenly split across TP model ranks" 171 | 172 | self.num_heads_local = num_heads // comm.get_size(comm_tp_name) 173 | self.head_dim = dim // self.num_heads 174 | self.scale = (dim // self.num_heads) ** -0.5 175 | self.fused_attn = True 176 | 177 | self.comm_tp_name = comm_tp_name 178 | self.comm_cp_name = comm_cp_name 179 | self.cp_shapes = cp_shapes 180 | 181 | # qkv is col parallel in the weights 182 | self.q = DistributedMatmul( 183 | dim, 184 | dim, 185 | comm_inp_name=None, 186 | comm_out_name=comm_tp_name, 187 | bias=qkv_bias, 188 | comm_act_name=comm_cp_name, 189 | ) 190 | self.k = DistributedMatmul( 191 | dim, 192 | dim, 193 | comm_inp_name=None, 194 | comm_out_name=comm_tp_name, 195 | bias=qkv_bias, 196 | comm_act_name=comm_cp_name, 197 | ) 198 | self.v = DistributedMatmul( 199 | dim, 200 | dim, 201 | comm_inp_name=None, 202 | comm_out_name=comm_tp_name, 203 | bias=qkv_bias, 204 | comm_act_name=comm_cp_name, 205 | ) 206 | self.attn_drop = nn.Dropout(attn_drop) 207 | 208 | # proj is row parallel in the weights 209 | self.proj = DistributedMatmul( 210 | dim, 211 | dim, 212 | comm_inp_name=comm_tp_name, 213 | comm_out_name=None, 214 | comm_act_name=comm_cp_name, 215 | ) 216 | self.proj_drop = nn.Dropout(proj_drop) 217 | 218 | def forward(self, x): 219 | # note: N is local sequence shard if CP is on 220 | B, N, C = x.shape 221 | 222 | q = ( 223 | self.q(x) 224 | .reshape(B, N, self.num_heads_local, self.head_dim) 225 | .permute(0, 2, 1, 3) 226 | ) 227 | k = ( 228 | self.k(x) 229 | .reshape(B, N, self.num_heads_local, self.head_dim) 230 | .permute(0, 2, 1, 3) 231 | ) 232 | v = ( 233 | self.v(x) 234 | .reshape(B, N, self.num_heads_local, self.head_dim) 235 | .permute(0, 2, 1, 3) 236 | ) 237 | 238 | k = all_gather_from_parallel_region( 239 | k, dim=2, shapes=self.cp_shapes, comm_name=self.comm_cp_name 240 | ) 241 | v = all_gather_from_parallel_region( 242 | v, dim=2, shapes=self.cp_shapes, comm_name=self.comm_cp_name 243 | ) 244 | 245 | if self.fused_attn: 246 | x = F.scaled_dot_product_attention( 247 | q, 248 | k, 249 | v, 250 | dropout_p=self.attn_drop.p, 251 | ) 252 | else: 253 | q = q * self.scale 254 | attn = q @ k.transpose(-2, -1) 255 | attn = attn.softmax(dim=-1) 256 | attn = self.attn_drop(attn) 257 | x = attn @ v 258 | 259 | # transpose back 260 | x = x.transpose(1, 2).reshape(B, N, self.num_heads_local * self.head_dim) 261 | 262 | # this is distributed again 263 | x = self.proj(x) 264 | 265 | # generally we have to be super careful with dropout layers, since 266 | # those are normalized over the dropouts. 
That would need to be reduced across nodes 267 | x = self.proj_drop(x) 268 | 269 | return x 270 | 271 | 272 | class DistributedLayerNorm(nn.Module): 273 | """ 274 | Distributed layer norm layer 275 | Sequence parallel only 276 | """ 277 | 278 | def __init__( 279 | self, 280 | normalized_shape, 281 | eps=1e-05, 282 | elementwise_affine=True, 283 | bias=True, 284 | device=None, 285 | dtype=None, 286 | comm_tp_name="tp", 287 | comm_cp_name="cp", 288 | ): 289 | super(DistributedLayerNorm, self).__init__() 290 | 291 | self.norm = nn.LayerNorm( 292 | normalized_shape, 293 | eps=eps, 294 | elementwise_affine=elementwise_affine, 295 | bias=bias, 296 | device=device, 297 | dtype=dtype, 298 | ) 299 | 300 | if elementwise_affine: 301 | # affine weights need additional allreduce and are shared 302 | # across all groups 303 | self.norm.weight.is_shared_mp = [comm_tp_name, comm_cp_name] 304 | self.norm.weight.mark_for_reduction = [comm_cp_name] 305 | if bias: 306 | self.norm.bias.is_shared_mp = [comm_tp_name, comm_cp_name] 307 | self.norm.bias.mark_for_reduction = [comm_cp_name] 308 | 309 | def forward(self, x): 310 | return self.norm(x) 311 | -------------------------------------------------------------------------------- /networks/vit.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial 5 | from networks.helpers import DropPath, trunc_normal_ 6 | 7 | # mp stuff 8 | from utils import comm 9 | from distributed.layers import ( 10 | DistributedMatmul, 11 | DistributedMLP, 12 | DistributedAttention, 13 | DistributedLayerNorm, 14 | ) 15 | from distributed.helpers import compute_split_shapes 16 | from distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region 17 | 18 | 19 | class MLP(nn.Module): 20 | def __init__( 21 | self, 22 | in_features, 23 | hidden_features=None, 24 | out_features=None, 25 | act_layer=nn.GELU, 26 | drop=0.0, 27 | ): 28 | super().__init__() 29 | out_features = out_features or in_features 30 | hidden_features = hidden_features or in_features 31 | self.fc1 = nn.Linear(in_features, hidden_features) 32 | self.act = act_layer() 33 | self.fc2 = nn.Linear(hidden_features, out_features) 34 | self.drop = nn.Dropout(drop) 35 | 36 | def forward(self, x): 37 | x = self.fc1(x) 38 | x = self.act(x) 39 | x = self.drop(x) 40 | x = self.fc2(x) 41 | x = self.drop(x) 42 | return x 43 | 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): 47 | super().__init__() 48 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 49 | self.num_heads = num_heads 50 | self.head_dim = dim // num_heads 51 | self.scale = self.head_dim**-0.5 52 | self.fused_attn = True 53 | 54 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 55 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 56 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 57 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 58 | self.attn_drop = nn.Dropout(attn_drop) 59 | self.proj = nn.Linear(dim, dim) 60 | self.proj_drop = nn.Dropout(proj_drop) 61 | 62 | def forward(self, x): 63 | B, N, C = x.shape 64 | 65 | q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 66 | k = self.k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 67 | v = self.v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 68 | 69 | if self.fused_attn: 70 | x = 
F.scaled_dot_product_attention( 71 | q, 72 | k, 73 | v, 74 | dropout_p=self.attn_drop.p, 75 | ) 76 | else: 77 | q = q * self.scale 78 | attn = q @ k.transpose(-2, -1) 79 | attn = attn.softmax(dim=-1) 80 | attn = self.attn_drop(attn) 81 | x = attn @ v 82 | 83 | x = x.transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | def __init__( 91 | self, 92 | dim, 93 | num_heads, 94 | mlp_ratio=4.0, 95 | qkv_bias=False, 96 | drop=0.0, 97 | attn_drop=0.0, 98 | drop_path=0.0, 99 | act_layer=nn.GELU, 100 | norm_layer=nn.LayerNorm, 101 | cp_shapes=None, 102 | ): 103 | super().__init__() 104 | 105 | mlp_hidden_dim = int(dim * mlp_ratio) 106 | 107 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 108 | if (comm.get_size("tp-cp")) > 1: 109 | # model parallelism is on, distribute the layers 110 | # tp: tensor parallel shards the weights 111 | # cp: context parallel shards the sequence 112 | self.attn = DistributedAttention( 113 | dim, 114 | num_heads=num_heads, 115 | qkv_bias=qkv_bias, 116 | attn_drop=attn_drop, 117 | proj_drop=drop, 118 | comm_tp_name="tp", 119 | comm_cp_name="cp", 120 | cp_shapes=cp_shapes, 121 | ) 122 | self.mlp = DistributedMLP( 123 | in_features=dim, 124 | hidden_features=mlp_hidden_dim, 125 | act_layer=act_layer, 126 | drop=drop, 127 | comm_tp_name="tp", 128 | comm_cp_name="cp", 129 | ) 130 | self.norm1 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp") 131 | self.norm2 = DistributedLayerNorm(dim, comm_tp_name="tp", comm_cp_name="cp") 132 | else: 133 | self.norm1 = norm_layer(dim) 134 | self.norm2 = norm_layer(dim) 135 | self.attn = Attention( 136 | dim, 137 | num_heads=num_heads, 138 | qkv_bias=qkv_bias, 139 | attn_drop=attn_drop, 140 | proj_drop=drop, 141 | ) 142 | self.mlp = MLP( 143 | in_features=dim, 144 | hidden_features=mlp_hidden_dim, 145 | act_layer=act_layer, 146 | drop=drop, 147 | ) 148 | 149 | def forward(self, x): 150 | y = self.attn(self.norm1(x)) 151 | x = x + self.drop_path(y) 152 | x = x + self.drop_path(self.mlp(self.norm2(x))) 153 | return x 154 | 155 | 156 | class PatchEmbed(nn.Module): 157 | """Image to Patch Embedding""" 158 | 159 | def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768): 160 | super().__init__() 161 | # grid of patches 162 | self.h = img_size[0] // patch_size 163 | self.w = img_size[1] // patch_size 164 | num_patches = self.h * self.w 165 | self.img_size = img_size 166 | self.patch_size = patch_size 167 | self.num_patches = num_patches 168 | 169 | self.proj = nn.Conv2d( 170 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 171 | ) 172 | 173 | def forward(self, x): 174 | B, C, H, W = x.shape 175 | x = self.proj(x).flatten(2).transpose(1, 2) 176 | return x 177 | 178 | 179 | class VisionTransformer(nn.Module): 180 | def __init__( 181 | self, 182 | img_size=[224, 224], 183 | patch_size=16, 184 | in_chans=3, 185 | out_chans=3, 186 | embed_dim=768, 187 | depth=12, 188 | num_heads=12, 189 | mlp_ratio=4.0, 190 | qkv_bias=False, 191 | drop_rate=0.0, 192 | attn_drop_rate=0.0, 193 | drop_path_rate=0.0, 194 | norm_layer=nn.LayerNorm, 195 | **kwargs 196 | ): 197 | super().__init__() 198 | self.num_features = self.embed_dim = embed_dim 199 | self.patch_size = patch_size 200 | self.img_size = img_size 201 | self.out_ch = out_chans 202 | self.drop_rate = drop_rate 203 | 204 | self.patch_embed = PatchEmbed( 205 | img_size=img_size, 206 | patch_size=patch_size, 207 | in_chans=in_chans, 208 | 
embed_dim=self.embed_dim, 209 | ) 210 | num_patches = self.patch_embed.num_patches 211 | 212 | # if context parallel, split the sequence/context 213 | self.cp_shapes = compute_split_shapes(num_patches, comm.get_size("cp")) 214 | 215 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) 216 | self.pos_drop = nn.Dropout(p=drop_rate) 217 | 218 | dpr = [ 219 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 220 | ] # stochastic depth decay rule 221 | 222 | self.blocks = nn.ModuleList( 223 | [ 224 | Block( 225 | dim=embed_dim, 226 | num_heads=num_heads, 227 | mlp_ratio=mlp_ratio, 228 | qkv_bias=qkv_bias, 229 | drop=drop_rate, 230 | attn_drop=attn_drop_rate, 231 | drop_path=dpr[i], 232 | norm_layer=norm_layer, 233 | cp_shapes=self.cp_shapes, 234 | ) 235 | for i in range(depth) 236 | ] 237 | ) 238 | 239 | self.norm = norm_layer(embed_dim) 240 | 241 | self.out_size = self.out_ch * self.patch_size * self.patch_size 242 | 243 | self.head = nn.Linear(embed_dim, self.out_size, bias=False) 244 | 245 | trunc_normal_(self.pos_embed, std=0.02) 246 | self.apply(self._init_weights) 247 | 248 | def _init_weights(self, m): 249 | if isinstance(m, nn.Linear): 250 | trunc_normal_(m.weight, std=0.02) 251 | if isinstance(m, nn.Linear) and m.bias is not None: 252 | nn.init.constant_(m.bias, 0) 253 | elif isinstance(m, nn.LayerNorm): 254 | nn.init.constant_(m.bias, 0) 255 | nn.init.constant_(m.weight, 1.0) 256 | 257 | def prepare_tokens(self, x): 258 | B, nc, w, h = x.shape 259 | x = self.patch_embed(x) # patch linear embedding 260 | # add positional encoding to each token 261 | x = x + self.pos_embed 262 | return self.pos_drop(x) 263 | 264 | def forward_head(self, x): 265 | B, _, _ = x.shape # B x N x embed_dim 266 | x = x.reshape(B, self.patch_embed.h, self.patch_embed.w, self.embed_dim) 267 | B, h, w, _ = x.shape 268 | 269 | # apply head 270 | x = self.head(x) 271 | x = x.reshape(shape=(B, h, w, self.patch_size, self.patch_size, self.out_ch)) 272 | x = torch.einsum("nhwpqc->nchpwq", x) 273 | x = x.reshape(shape=(B, self.out_ch, self.img_size[0], self.img_size[1])) 274 | 275 | return x 276 | 277 | def forward(self, x): 278 | x = self.prepare_tokens(x) 279 | 280 | # split sequence if cp is on (shape of x is (batch, seq, embed)) 281 | x = scatter_to_parallel_region(x, dim=1, comm_name="cp") 282 | 283 | # if cp is on, each block operates on a sequence shard 284 | for blk in self.blocks: 285 | x = blk(x) 286 | x = self.norm(x) 287 | 288 | # gather sequence if cp is on 289 | x = gather_from_parallel_region(x, dim=1, shapes=self.cp_shapes, comm_name="cp") 290 | 291 | x = self.forward_head(x) 292 | return x 293 | 294 | 295 | def ViT(params, **kwargs): 296 | model = VisionTransformer( 297 | img_size=tuple(params.img_size), 298 | in_chans=params.n_in_channels, 299 | out_chans=params.n_out_channels, 300 | patch_size=params.patch_size, 301 | embed_dim=params.embed_dim, 302 | depth=params.depth, 303 | num_heads=params.num_heads, 304 | mlp_ratio=4, 305 | qkv_bias=True, 306 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 307 | drop_path_rate=float(params.dropout), 308 | drop_rate=float(params.dropout), 309 | attn_drop_rate=float(params.dropout), 310 | **kwargs 311 | ) 312 | return model 313 | -------------------------------------------------------------------------------- /tests/test_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | import unittest 5 | import datetime as dt 6 | 
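# note: these tests expect a multi-rank launch; setUpClass below reads
# WORLD_SIZE, RANK, MASTER_ADDR and MASTER_PORT plus the TP and CP sizes
# from the environment, and asserts if run on a single rank.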
from utils.rank_generator import RankGenerator 7 | from utils import comm 8 | 9 | from networks.vit import MLP, Attention 10 | from parameterized import parameterized 11 | 12 | # distributed 13 | from distributed.layers import ( 14 | DistributedMatmul, 15 | DistributedMLP, 16 | DistributedAttention, 17 | DistributedLayerNorm, 18 | ) 19 | from distributed.helpers import compute_split_shapes 20 | from distributed.mappings import scatter_to_parallel_region, gather_from_parallel_region 21 | 22 | 23 | class TestDistributed(unittest.TestCase): 24 | @classmethod 25 | def setUpClass(cls): 26 | cls.world_size = int(os.getenv("WORLD_SIZE", 1)) 27 | cls.world_rank = int(os.getenv("RANK", 0)) 28 | port = int(os.getenv("MASTER_PORT", 0)) 29 | master_address = os.getenv("MASTER_ADDR") 30 | 31 | # get model parallel sizes 32 | tp = int(os.getenv("TP", 1)) 33 | cp = int(os.getenv("CP", 1)) 34 | pp = 1 35 | order = "cp-tp-dp-pp" 36 | model_parallel_size = tp * cp * pp 37 | dp = cls.world_size // model_parallel_size 38 | assert dp >= 1, "ERROR: data parallel wireup failed since dp = {}".format(dp) 39 | 40 | cls.print_to_screen = cls.world_rank == 0 41 | if cls.print_to_screen: 42 | print( 43 | "Distributed unit tests with DP = {}, TP = {}, CP = {}, PP = {}".format( 44 | dp, tp, cp, pp 45 | ) 46 | ) 47 | 48 | if torch.cuda.is_available(): 49 | if cls.print_to_screen: 50 | print("Running test on GPU") 51 | local_rank = cls.world_rank % torch.cuda.device_count() 52 | cls.device = torch.device(f"cuda:{local_rank}") 53 | torch.cuda.manual_seed(333) 54 | comm_backend = "nccl" 55 | else: 56 | if cls.print_to_screen: 57 | print("Running test on CPU") 58 | cls.device = torch.device("cpu") 59 | comm_backend = "gloo" 60 | torch.manual_seed(333) 61 | 62 | if cls.world_size > 1: 63 | # create tcp store 64 | store = dist.TCPStore( 65 | host_name=master_address, 66 | port=port, 67 | world_size=cls.world_size, 68 | is_master=(cls.world_rank == 0), 69 | timeout=dt.timedelta(seconds=900), 70 | ) 71 | 72 | # initialize process groups 73 | dist.init_process_group( 74 | backend=comm_backend, 75 | rank=cls.world_rank, 76 | world_size=cls.world_size, 77 | store=store, 78 | ) 79 | else: 80 | assert False, "Running distributed tests on single GPU" 81 | 82 | # init model + dp groups individually 83 | comm.init_model_parallel_info(tp=tp, cp=cp, dp=dp, pp=pp, order=order) 84 | 85 | @classmethod 86 | def tearDownClass(cls): 87 | dist.destroy_process_group(None) 88 | 89 | def _copy_mlp_weights(self, mlp_layer, mlp_layer_distributed): 90 | """copy the weights, bias of mlp into the correct shard of mlp_dist""" 91 | tp = comm.get_size("tp") 92 | # fc1 is col sharded, fc2 is row sharded (careful: PyT does AW^T) 93 | embed_local = mlp_layer.fc1.weight.shape[0] // tp 94 | rank_tp = comm.get_rank("tp") # which tp rank 95 | 96 | with torch.no_grad(): 97 | # copy sharded weights and biases for fc1 98 | start = rank_tp * embed_local 99 | end = start + embed_local 100 | mlp_layer_distributed.fc1.weight.copy_(mlp_layer.fc1.weight[start:end, :]) 101 | mlp_layer_distributed.fc1.bias.copy_( 102 | mlp_layer.fc1.bias[start:end].view(1, 1, -1) 103 | ) 104 | # copy sharded weights for fc2 105 | mlp_layer_distributed.fc2.weight.copy_(mlp_layer.fc2.weight[:, start:end]) 106 | # copy shared bias for fc2 across all shards 107 | mlp_layer_distributed.fc2.bias.copy_(mlp_layer.fc2.bias.view(1, 1, -1)) 108 | 109 | # tests to run with input parameterization 110 | # inputs are batch, seq, embed, tolerance 111 | @parameterized.expand([[4, 1024, 2048, 1e-4], [4, 
4050, 2048, 1e-4]]) 112 | def test_distributed_mlp(self, batch, seq, embed, tolerance): 113 | # set the ops 114 | mlp_layer = MLP(in_features=embed, hidden_features=4 * embed).to(self.device) 115 | mlp_layer_distributed = DistributedMLP( 116 | in_features=embed, 117 | hidden_features=4 * embed, 118 | comm_tp_name="tp", 119 | comm_cp_name="cp", 120 | ).to(self.device) 121 | 122 | # sync the local and distributed weights 123 | self._copy_mlp_weights(mlp_layer, mlp_layer_distributed) 124 | 125 | ############################################################# 126 | # non-distributed op 127 | ############################################################# 128 | # create tensor 129 | inp = torch.randn((batch, seq, embed), dtype=torch.float32, device=self.device) 130 | inp.requires_grad = True 131 | 132 | # forward pass 133 | out = mlp_layer(inp) 134 | 135 | # backward pass 136 | with torch.no_grad(): 137 | out_grad = torch.randn_like(out) 138 | out.backward(out_grad) # vjp with random vector 139 | inp_grad = inp.grad.clone() 140 | 141 | ############################################################# 142 | # distributed op 143 | ############################################################# 144 | cp_shapes = compute_split_shapes(seq, comm.get_size("cp")) 145 | # split the input tensor to get local tensor 146 | with torch.no_grad(): 147 | inp_local = scatter_to_parallel_region(inp, dim=1, comm_name="cp") 148 | inp_local.requires_grad = True 149 | 150 | # forward pass local 151 | out_local = mlp_layer_distributed(inp_local) 152 | 153 | # backward pass local 154 | with torch.no_grad(): 155 | out_grad_local = scatter_to_parallel_region(out_grad, dim=1, comm_name="cp") 156 | out_local.backward(out_grad_local) # vjp with same random local vector 157 | inp_grad_local = inp_local.grad.clone() 158 | 159 | ############################################################# 160 | # evaluate forward pass 161 | ############################################################# 162 | with torch.no_grad(): 163 | out_gather = gather_from_parallel_region( 164 | out_local, dim=1, shapes=cp_shapes, comm_name="cp" 165 | ) 166 | err = torch.mean( 167 | torch.norm(out - out_gather, p="fro", dim=(-1, -2)) 168 | / torch.norm(out, p="fro", dim=(-1, -2)) 169 | ) 170 | if self.print_to_screen: 171 | print(f"final relative error of output in mlp: {err.item()}") 172 | self.assertTrue(err.item() <= tolerance) 173 | 174 | ############################################################# 175 | # evaluate backward pass 176 | ############################################################# 177 | with torch.no_grad(): 178 | inp_grad_gather = gather_from_parallel_region( 179 | inp_grad_local, dim=1, shapes=cp_shapes, comm_name="cp" 180 | ) 181 | err = torch.mean( 182 | torch.norm(inp_grad - inp_grad_gather, p="fro", dim=(-1, -2)) 183 | / torch.norm(inp_grad, p="fro", dim=(-1, -2)) 184 | ) 185 | if self.print_to_screen: 186 | print(f"final relative error of gradients in mlp: {err.item()}") 187 | self.assertTrue(err.item() <= tolerance) 188 | 189 | def _copy_attn_weights(self, attn_layer, attn_layer_distributed): 190 | """copy the weights, bias of attn into the correct shard of attn_dist""" 191 | tp = comm.get_size("tp") 192 | embed = attn_layer.proj.weight.shape[1] 193 | embed_local = embed // tp 194 | rank_tp = comm.get_rank("tp") # which tp rank 195 | 196 | with torch.no_grad(): 197 | # copy sharded weights and biases for qkv 198 | start = rank_tp * embed_local 199 | end = start + embed_local 200 | 
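# note: q/k/v are column-parallel (output features sharded over TP), so each
# TP rank copies rows [start:end] of the full (out, in) weight; proj is
# row-parallel, so its columns [:, start:end] are copied further below.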
attn_layer_distributed.q.weight.copy_(attn_layer.q.weight[start:end, :]) 201 | attn_layer_distributed.q.bias.copy_( 202 | attn_layer.q.bias[start:end].view(1, 1, -1) 203 | ) 204 | attn_layer_distributed.k.weight.copy_(attn_layer.k.weight[start:end, :]) 205 | attn_layer_distributed.k.bias.copy_( 206 | attn_layer.k.bias[start:end].view(1, 1, -1) 207 | ) 208 | attn_layer_distributed.v.weight.copy_(attn_layer.v.weight[start:end, :]) 209 | attn_layer_distributed.v.bias.copy_( 210 | attn_layer.v.bias[start:end].view(1, 1, -1) 211 | ) 212 | # copy sharded weights for proj 213 | start = rank_tp * embed_local 214 | end = start + embed_local 215 | attn_layer_distributed.proj.weight.copy_( 216 | attn_layer.proj.weight[:, start:end] 217 | ) 218 | attn_layer_distributed.proj.bias.copy_(attn_layer.proj.bias.view(1, 1, -1)) 219 | 220 | # tests to run with input parameterization 221 | # inputs are batch, seq, embed, num_heads, tolerance 222 | @parameterized.expand([[4, 1024, 2048, 8, 1e-4], [4, 4050, 2048, 8, 1e-4]]) 223 | def test_distributed_attention(self, batch, seq, embed, num_heads, tolerance): 224 | # set the ops 225 | attn_layer = Attention(dim=embed, num_heads=num_heads, qkv_bias=True).to( 226 | self.device 227 | ) 228 | cp_shapes = compute_split_shapes(seq, comm.get_size("cp")) 229 | attn_layer_distributed = DistributedAttention( 230 | dim=embed, 231 | num_heads=num_heads, 232 | qkv_bias=True, 233 | comm_tp_name="tp", 234 | comm_cp_name="cp", 235 | cp_shapes=cp_shapes, 236 | ).to(self.device) 237 | 238 | # sync the local and distributed weights 239 | self._copy_attn_weights(attn_layer, attn_layer_distributed) 240 | 241 | ############################################################# 242 | # non-distributed op 243 | ############################################################# 244 | # create tensor 245 | inp = torch.randn((batch, seq, embed), dtype=torch.float32, device=self.device) 246 | inp.requires_grad = True 247 | 248 | # forward pass 249 | out = attn_layer(inp) 250 | 251 | # backward pass 252 | with torch.no_grad(): 253 | out_grad = torch.randn_like(out) 254 | out.backward(out_grad) # vjp with random vector 255 | inp_grad = inp.grad.clone() 256 | 257 | ############################################################# 258 | # distributed op 259 | ############################################################# 260 | # split the input tensor to get local tensor 261 | with torch.no_grad(): 262 | inp_local = scatter_to_parallel_region(inp, dim=1, comm_name="cp") 263 | inp_local.requires_grad = True 264 | 265 | # forward pass local 266 | out_local = attn_layer_distributed(inp_local) 267 | 268 | # backward pass local 269 | with torch.no_grad(): 270 | out_grad_local = scatter_to_parallel_region(out_grad, dim=1, comm_name="cp") 271 | out_local.backward(out_grad_local) # vjp with same random local vector 272 | inp_grad_local = inp_local.grad.clone() 273 | 274 | ############################################################# 275 | # evaluate forward pass 276 | ############################################################# 277 | with torch.no_grad(): 278 | out_gather = gather_from_parallel_region( 279 | out_local, dim=1, shapes=cp_shapes, comm_name="cp" 280 | ) 281 | err = torch.mean( 282 | torch.norm(out - out_gather, p="fro", dim=(-1, -2)) 283 | / torch.norm(out, p="fro", dim=(-1, -2)) 284 | ) 285 | if self.print_to_screen: 286 | print(f"final relative error of output in sa: {err.item()}") 287 | self.assertTrue(err.item() <= tolerance) 288 | 289 | ############################################################# 
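# note: the gathered cp shards of the input gradient should match the
# single-GPU gradient, since the backward of scatter_to_parallel_region is a
# gather over the cp group (see distributed/mappings.py).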
290 | # evaluate backward pass 291 | ############################################################# 292 | with torch.no_grad(): 293 | inp_grad_gather = gather_from_parallel_region( 294 | inp_grad_local, dim=1, shapes=cp_shapes, comm_name="cp" 295 | ) 296 | err = torch.mean( 297 | torch.norm(inp_grad - inp_grad_gather, p="fro", dim=(-1, -2)) 298 | / torch.norm(inp_grad, p="fro", dim=(-1, -2)) 299 | ) 300 | if self.print_to_screen: 301 | print(f"final relative error of gradients in sa: {err.item()}") 302 | self.assertTrue(err.item() <= tolerance) 303 | 304 | 305 | if __name__ == "__main__": 306 | unittest.main() 307 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import numpy as np 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.amp import autocast, GradScaler 11 | import torch.multiprocessing 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.nn.parallel import DistributedDataParallel 14 | 15 | import logging 16 | from utils import logging_utils 17 | 18 | logging_utils.config_logger() 19 | from utils.YParams import YParams 20 | from utils import get_data_loader_distributed 21 | from utils.loss import l2_loss, l2_loss_opt 22 | from utils.metrics import weighted_rmse 23 | from utils.plots import generate_images 24 | from networks import vit 25 | 26 | 27 | def train(params, args, local_rank, world_rank, world_size): 28 | # set device and benchmark mode 29 | torch.backends.cudnn.benchmark = True 30 | torch.cuda.set_device(local_rank) 31 | device = torch.device("cuda:%d" % local_rank) 32 | 33 | # get data loader 34 | logging.info("rank %d, begin data loader init" % world_rank) 35 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed( 36 | params, params.train_data_path, params.distributed, train=True 37 | ) 38 | val_data_loader, valid_dataset = get_data_loader_distributed( 39 | params, params.valid_data_path, params.distributed, train=False 40 | ) 41 | logging.info("rank %d, data loader initialized" % (world_rank)) 42 | 43 | # create model 44 | model = vit.ViT(params).to(device) 45 | 46 | if params.enable_jit: 47 | model = torch.compile(model) 48 | 49 | if params.amp_dtype == torch.float16: 50 | scaler = GradScaler("cuda") 51 | if params.distributed and not args.noddp: 52 | if args.disable_broadcast_buffers: 53 | model = DistributedDataParallel( 54 | model, 55 | device_ids=[local_rank], 56 | bucket_cap_mb=args.bucket_cap_mb, 57 | broadcast_buffers=False, 58 | gradient_as_bucket_view=True, 59 | ) 60 | else: 61 | model = DistributedDataParallel( 62 | model, device_ids=[local_rank], bucket_cap_mb=args.bucket_cap_mb 63 | ) 64 | 65 | if params.enable_fused: 66 | optimizer = optim.Adam( 67 | model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95) 68 | ) 69 | else: 70 | optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95)) 71 | 72 | if world_rank == 0: 73 | logging.info(model) 74 | 75 | iters = 0 76 | startEpoch = 0 77 | 78 | if params.lr_schedule == "cosine": 79 | if params.warmup > 0: 80 | lr_scale = lambda x: min( 81 | (x + 1) / params.warmup, 82 | 0.5 * (1 + np.cos(np.pi * x / params.num_iters)), 83 | ) 84 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 85 | else: 86 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 87 | optimizer, T_max=params.num_iters 88 
| ) 89 | else: 90 | scheduler = None 91 | 92 | # select loss function 93 | if params.enable_jit: 94 | loss_func = l2_loss_opt 95 | else: 96 | loss_func = l2_loss 97 | 98 | if world_rank == 0: 99 | logging.info("Starting Training Loop...") 100 | 101 | # Log initial loss on train and validation to tensorboard 102 | with torch.no_grad(): 103 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 104 | gen = model(inp) 105 | tr_loss = loss_func(gen, tar) 106 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 107 | gen = model(inp) 108 | val_loss = loss_func(gen, tar) 109 | val_rmse = weighted_rmse(gen, tar) 110 | if params.distributed: 111 | torch.distributed.all_reduce(tr_loss) 112 | torch.distributed.all_reduce(val_loss) 113 | torch.distributed.all_reduce(val_rmse) 114 | if world_rank == 0: 115 | args.tboard_writer.add_scalar("Loss/train", tr_loss.item() / world_size, 0) 116 | args.tboard_writer.add_scalar("Loss/valid", val_loss.item() / world_size, 0) 117 | args.tboard_writer.add_scalar( 118 | "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0] / world_size, 0 119 | ) 120 | 121 | params.num_epochs = params.num_iters // len(train_data_loader) 122 | iters = 0 123 | t1 = time.time() 124 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 125 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 126 | if params.distributed and (train_sampler is not None): 127 | train_sampler.set_epoch(epoch) 128 | start = time.time() 129 | tr_loss = [] 130 | tr_time = 0.0 131 | dat_time = 0.0 132 | log_time = 0.0 133 | 134 | model.train() 135 | step_count = 0 136 | for i, data in enumerate(train_data_loader, 0): 137 | if world_rank == 0: 138 | if epoch == 3 and i == 0: 139 | torch.cuda.profiler.start() 140 | if epoch == 3 and i == len(train_data_loader) - 1: 141 | torch.cuda.profiler.stop() 142 | 143 | torch.cuda.nvtx.range_push(f"step {i}") 144 | iters += 1 145 | dat_start = time.time() 146 | torch.cuda.nvtx.range_push(f"data copy in {i}") 147 | 148 | inp, tar = map(lambda x: x.to(device), data) 149 | torch.cuda.nvtx.range_pop() # copy in 150 | 151 | tr_start = time.time() 152 | b_size = inp.size(0) 153 | 154 | optimizer.zero_grad() 155 | 156 | torch.cuda.nvtx.range_push(f"forward") 157 | with autocast("cuda", enabled=params.amp_enabled, dtype=params.amp_dtype): 158 | gen = model(inp) 159 | loss = loss_func(gen, tar) 160 | torch.cuda.nvtx.range_pop() # forward 161 | 162 | if params.amp_dtype == torch.float16: 163 | scaler.scale(loss).backward() 164 | torch.cuda.nvtx.range_push(f"optimizer") 165 | scaler.step(optimizer) 166 | torch.cuda.nvtx.range_pop() # optimizer 167 | scaler.update() 168 | else: 169 | loss.backward() 170 | torch.cuda.nvtx.range_push(f"optimizer") 171 | optimizer.step() 172 | torch.cuda.nvtx.range_pop() # optimizer 173 | 174 | if params.distributed: 175 | torch.distributed.all_reduce(loss) 176 | tr_loss.append(loss.item() / world_size) 177 | 178 | torch.cuda.nvtx.range_pop() # step 179 | # lr step 180 | scheduler.step() 181 | 182 | tr_end = time.time() 183 | tr_time += tr_end - tr_start 184 | dat_time += tr_start - dat_start 185 | step_count += 1 186 | 187 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 188 | end = time.time() 189 | 190 | if world_rank == 0: 191 | iters_per_sec = step_count / (end - start) 192 | samples_per_sec = params["global_batch_size"] * iters_per_sec 193 | logging.info( 194 | "Time taken for epoch %i is %f sec, avg %f samples/sec", 195 | epoch + 1, 196 | end - start, 197 | 
samples_per_sec, 198 | ) 199 | logging.info(" Avg train loss=%f" % np.mean(tr_loss)) 200 | args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters) 201 | args.tboard_writer.add_scalar( 202 | "Learning Rate", optimizer.param_groups[0]["lr"], iters 203 | ) 204 | args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters) 205 | args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters) 206 | fig = generate_images([inp, tar, gen]) 207 | args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True) 208 | 209 | val_start = time.time() 210 | val_loss = torch.zeros(1, device=device) 211 | val_rmse = torch.zeros( 212 | (params.n_out_channels), dtype=torch.float32, device=device 213 | ) 214 | valid_steps = 0 215 | model.eval() 216 | 217 | with torch.inference_mode(): 218 | with torch.no_grad(): 219 | for i, data in enumerate(val_data_loader, 0): 220 | with autocast( 221 | "cuda", enabled=params.amp_enabled, dtype=params.amp_dtype 222 | ): 223 | inp, tar = map(lambda x: x.to(device), data) 224 | gen = model(inp) 225 | val_loss += loss_func(gen, tar) 226 | val_rmse += weighted_rmse(gen, tar) 227 | valid_steps += 1 228 | 229 | if params.distributed: 230 | torch.distributed.all_reduce(val_loss) 231 | val_loss /= world_size 232 | torch.distributed.all_reduce(val_rmse) 233 | val_rmse /= world_size 234 | 235 | val_rmse /= valid_steps # Avg validation rmse 236 | val_loss /= valid_steps 237 | val_end = time.time() 238 | if world_rank == 0: 239 | logging.info(" Avg val loss={}".format(val_loss.item())) 240 | logging.info(" Total validation time: {} sec".format(val_end - val_start)) 241 | args.tboard_writer.add_scalar("Loss/valid", val_loss, iters) 242 | args.tboard_writer.add_scalar( 243 | "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters 244 | ) 245 | args.tboard_writer.flush() 246 | 247 | t2 = time.time() 248 | tottime = t2 - t1 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument( 254 | "--run_num", 255 | default="00", 256 | type=str, 257 | help="tag for indexing the current experiment", 258 | ) 259 | parser.add_argument( 260 | "--yaml_config", 261 | default="./config/ViT.yaml", 262 | type=str, 263 | help="path to yaml file containing training configs", 264 | ) 265 | parser.add_argument( 266 | "--config", default="base", type=str, help="name of desired config in yaml file" 267 | ) 268 | parser.add_argument( 269 | "--amp_mode", 270 | default="none", 271 | type=str, 272 | choices=["none", "fp16", "bf16"], 273 | help="select automatic mixed precision mode", 274 | ) 275 | parser.add_argument( 276 | "--enable_fused", action="store_true", help="enable fused Adam optimizer" 277 | ) 278 | parser.add_argument( 279 | "--enable_jit", action="store_true", help="enable JIT compilation" 280 | ) 281 | parser.add_argument( 282 | "--local_batch_size", 283 | default=None, 284 | type=int, 285 | help="local batchsize (manually override global_batch_size config setting)", 286 | ) 287 | parser.add_argument( 288 | "--num_iters", default=None, type=int, help="number of iters to run" 289 | ) 290 | parser.add_argument( 291 | "--num_data_workers", 292 | default=None, 293 | type=int, 294 | help="number of data workers for data loader", 295 | ) 296 | parser.add_argument( 297 | "--data_loader_config", 298 | default=None, 299 | type=str, 300 | choices=["pytorch", "dali"], 301 | help="dataloader configuration. 
choices: 'pytorch', 'dali'", 302 | ) 303 | parser.add_argument( 304 | "--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb" 305 | ) 306 | parser.add_argument( 307 | "--disable_broadcast_buffers", 308 | action="store_true", 309 | help="disable syncing broadcasting buffers", 310 | ) 311 | parser.add_argument( 312 | "--noddp", action="store_true", help="disable DDP communication" 313 | ) 314 | args = parser.parse_args() 315 | 316 | run_num = args.run_num 317 | 318 | params = YParams(os.path.abspath(args.yaml_config), args.config) 319 | 320 | # Update config with modified args 321 | # set up amp 322 | if args.amp_mode != "none": 323 | params.update({"amp_mode": args.amp_mode}) 324 | amp_dtype = torch.float32 325 | 326 | if params.amp_mode == "fp16": 327 | amp_dtype = torch.float16 328 | elif params.amp_mode == "bf16": 329 | amp_dtype = torch.bfloat16 330 | 331 | params.update( 332 | {"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype} 333 | ) 334 | 335 | if args.enable_fused: 336 | params.update({"enable_fused": args.enable_fused}) 337 | 338 | if args.enable_jit: 339 | params.update({"enable_jit": args.enable_jit}) 340 | 341 | if args.data_loader_config: 342 | params.update({"data_loader_config": args.data_loader_config}) 343 | 344 | if args.num_iters: 345 | params.update({"num_iters": args.num_iters}) 346 | 347 | if args.num_data_workers: 348 | params.update({"num_data_workers": args.num_data_workers}) 349 | 350 | params.distributed = False 351 | if "WORLD_SIZE" in os.environ: 352 | params.distributed = int(os.environ["WORLD_SIZE"]) > 1 353 | world_size = int(os.environ["WORLD_SIZE"]) 354 | else: 355 | world_size = 1 356 | 357 | world_rank = 0 358 | local_rank = 0 359 | if params.distributed: 360 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 361 | world_rank = torch.distributed.get_rank() 362 | local_rank = int(os.environ["LOCAL_RANK"]) 363 | 364 | if args.local_batch_size: 365 | # Manually override batch size 366 | params.local_batch_size = args.local_batch_size 367 | params.update({"global_batch_size": world_size * args.local_batch_size}) 368 | else: 369 | # Compute local batch size based on number of ranks 370 | params.local_batch_size = params.global_batch_size // world_size 371 | 372 | # for dali data loader, set the actual number of data shards and id 373 | params.data_num_shards = world_size 374 | params.data_shard_id = world_rank 375 | 376 | # Set up directory 377 | baseDir = params.expdir 378 | expDir = os.path.join( 379 | baseDir, args.config + "/%dGPU/" % (world_size) + str(run_num) + "/" 380 | ) 381 | if world_rank == 0: 382 | if not os.path.isdir(expDir): 383 | os.makedirs(expDir) 384 | logging_utils.log_to_file( 385 | logger_name=None, log_filename=os.path.join(expDir, "out.log") 386 | ) 387 | params.log() 388 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/")) 389 | 390 | params.experiment_dir = os.path.abspath(expDir) 391 | 392 | train(params, args, local_rank, world_rank, world_size) 393 | 394 | if params.distributed: 395 | torch.distributed.barrier() 396 | logging.info("DONE ---- rank %d" % world_rank) 397 | -------------------------------------------------------------------------------- /train_mp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import numpy as np 5 | import argparse 6 | import pynvml 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from 
torch.amp import autocast, GradScaler 12 | import torch.multiprocessing 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.nn.parallel import DistributedDataParallel 15 | from torch.distributed import ReduceOp 16 | 17 | import logging 18 | from utils import logging_utils 19 | 20 | logging_utils.config_logger() 21 | from utils.YParams import YParams 22 | from utils import get_data_loader_distributed 23 | from utils import comm 24 | from utils.loss import l2_loss, l2_loss_opt 25 | from utils.metrics import weighted_rmse 26 | from networks import vit 27 | 28 | from distributed.mappings import init_ddp_model_and_reduction_hooks 29 | from distributed.helpers import init_params_for_shared_weights 30 | 31 | from utils.plots import generate_images 32 | 33 | 34 | def train(params, args, local_rank, world_rank, world_size): 35 | # set device and benchmark mode 36 | torch.backends.cudnn.benchmark = True 37 | torch.cuda.set_device(local_rank) 38 | device = torch.device("cuda:%d" % local_rank) 39 | 40 | # init pynvml and get handle 41 | pynvml.nvmlInit() 42 | nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) 43 | 44 | # get data loader 45 | logging.info("rank %d, begin data loader init" % world_rank) 46 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed( 47 | params, params.train_data_path, params.distributed, train=True 48 | ) 49 | val_data_loader, valid_dataset = get_data_loader_distributed( 50 | params, params.valid_data_path, params.distributed, train=False 51 | ) 52 | logging.info("rank %d, data loader initialized" % (world_rank)) 53 | 54 | # create model 55 | model = vit.ViT(params).to(device) 56 | 57 | if params.enable_jit: 58 | # if params.distributed and not args.noddp: 59 | # torch._dynamo.config.optimize_ddp = False 60 | model = torch.compile(model) 61 | 62 | if params.amp_dtype == torch.float16: 63 | scaler = GradScaler('cuda') 64 | 65 | # weight initialization needs to be synced across shared weights 66 | if comm.get_size("tp-cp") > 1: 67 | init_params_for_shared_weights(model) 68 | 69 | if params.distributed and not args.noddp: 70 | model = init_ddp_model_and_reduction_hooks(model, device_ids=[local_rank], 71 | output_device=[local_rank], 72 | bucket_cap_mb=args.bucket_cap_mb) 73 | 74 | if params.enable_fused: 75 | optimizer = optim.Adam( 76 | model.parameters(), lr=params.lr, fused=True, betas=(0.9, 0.95) 77 | ) 78 | else: 79 | optimizer = optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.95)) 80 | 81 | if world_rank == 0: 82 | logging.info(model) 83 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / ( 84 | 1024.0 * 1024.0 * 1024.0 85 | ) 86 | logging.info(f"Scaffolding memory high watermark: {all_mem_gb} GB.") 87 | 88 | iters = 0 89 | startEpoch = 0 90 | 91 | if params.lr_schedule == "cosine": 92 | if params.warmup > 0: 93 | lr_scale = lambda x: min( 94 | (x + 1) / params.warmup, 95 | 0.5 * (1 + np.cos(np.pi * x / params.num_iters)), 96 | ) 97 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 98 | else: 99 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 100 | optimizer, T_max=params.num_iters 101 | ) 102 | else: 103 | scheduler = None 104 | 105 | # select loss function 106 | if params.enable_jit: 107 | loss_func = l2_loss_opt 108 | else: 109 | loss_func = l2_loss 110 | 111 | if world_rank == 0: 112 | logging.info("Starting Training Loop...") 113 | 114 | # Log initial loss on train and validation to tensorboard 115 | with torch.no_grad(): 116 | inp, tar = map(lambda x: 
x.to(device), next(iter(train_data_loader))) 117 | gen = model(inp) 118 | tr_loss = loss_func(gen, tar) 119 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 120 | gen = model(inp) 121 | val_loss = loss_func(gen, tar) 122 | val_rmse = weighted_rmse(gen, tar) 123 | if params.distributed: 124 | torch.distributed.all_reduce( 125 | tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp") 126 | ) 127 | torch.distributed.all_reduce( 128 | val_loss, op=ReduceOp.AVG, group=comm.get_group("dp") 129 | ) 130 | torch.distributed.all_reduce( 131 | val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp") 132 | ) 133 | if world_rank == 0: 134 | args.tboard_writer.add_scalar("Loss/train", tr_loss.item(), 0) 135 | args.tboard_writer.add_scalar("Loss/valid", val_loss.item(), 0) 136 | args.tboard_writer.add_scalar( 137 | "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], 0 138 | ) 139 | 140 | params.num_epochs = params.num_iters // len(train_data_loader) 141 | iters = 0 142 | t1 = time.time() 143 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 144 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 145 | if params.distributed and (train_sampler is not None): 146 | train_sampler.set_epoch(epoch) 147 | start = time.time() 148 | tr_loss = [] 149 | tr_time = 0.0 150 | dat_time = 0.0 151 | log_time = 0.0 152 | 153 | model.train() 154 | step_count = 0 155 | 156 | for i, data in enumerate(train_data_loader, 0): 157 | if world_rank == 0: 158 | if epoch == 3 and i == 0: 159 | torch.cuda.profiler.start() 160 | if epoch == 3 and i == len(train_data_loader) - 1: 161 | torch.cuda.profiler.stop() 162 | 163 | torch.cuda.nvtx.range_push(f"step {i}") 164 | iters += 1 165 | dat_start = time.time() 166 | torch.cuda.nvtx.range_push(f"data copy in {i}") 167 | 168 | inp, tar = map(lambda x: x.to(device), data) 169 | torch.cuda.nvtx.range_pop() # copy in 170 | 171 | tr_start = time.time() 172 | b_size = inp.size(0) 173 | 174 | optimizer.zero_grad() 175 | 176 | torch.cuda.nvtx.range_push(f"forward") 177 | with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype): 178 | gen = model(inp) 179 | loss = loss_func(gen, tar) 180 | torch.cuda.nvtx.range_pop() # forward 181 | 182 | if world_rank == 0 and i == 1: # print the mem used 183 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / ( 184 | 1024.0 * 1024.0 * 1024.0 185 | ) 186 | logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.") 187 | 188 | if params.amp_dtype == torch.float16: 189 | scaler.scale(loss).backward() 190 | torch.cuda.nvtx.range_push(f"optimizer") 191 | scaler.step(optimizer) 192 | torch.cuda.nvtx.range_pop() # optimizer 193 | scaler.update() 194 | else: 195 | loss.backward() 196 | torch.cuda.nvtx.range_push(f"optimizer") 197 | optimizer.step() 198 | torch.cuda.nvtx.range_pop() # optimizer 199 | 200 | if params.distributed: 201 | torch.distributed.all_reduce( 202 | loss, op=ReduceOp.AVG, group=comm.get_group("dp") 203 | ) 204 | tr_loss.append(loss.item()) 205 | 206 | torch.cuda.nvtx.range_pop() # step 207 | # lr step 208 | scheduler.step() 209 | 210 | tr_end = time.time() 211 | tr_time += tr_end - tr_start 212 | dat_time += tr_start - dat_start 213 | step_count += 1 214 | 215 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 216 | end = time.time() 217 | 218 | if world_rank == 0: 219 | iters_per_sec = step_count / (end - start) 220 | samples_per_sec = params["global_batch_size"] * iters_per_sec 221 | logging.info( 222 | "Time taken for epoch %i is %f sec, 
avg %f samples/sec", 223 | epoch + 1, 224 | end - start, 225 | samples_per_sec, 226 | ) 227 | logging.info(" Avg train loss=%f" % np.mean(tr_loss)) 228 | args.tboard_writer.add_scalar("Loss/train", np.mean(tr_loss), iters) 229 | args.tboard_writer.add_scalar( 230 | "Learning Rate", optimizer.param_groups[0]["lr"], iters 231 | ) 232 | args.tboard_writer.add_scalar("Avg iters per sec", iters_per_sec, iters) 233 | args.tboard_writer.add_scalar("Avg samples per sec", samples_per_sec, iters) 234 | fig = generate_images([inp, tar, gen]) 235 | args.tboard_writer.add_figure("Visualization, t2m", fig, iters, close=True) 236 | 237 | val_start = time.time() 238 | val_loss = torch.zeros(1, device=device) 239 | val_rmse = torch.zeros( 240 | (params.n_out_channels), dtype=torch.float32, device=device 241 | ) 242 | valid_steps = 0 243 | model.eval() 244 | 245 | with torch.inference_mode(): 246 | with torch.no_grad(): 247 | for i, data in enumerate(val_data_loader, 0): 248 | with autocast('cuda', enabled=params.amp_enabled, dtype=params.amp_dtype): 249 | inp, tar = map(lambda x: x.to(device), data) 250 | gen = model(inp) 251 | val_loss += loss_func(gen, tar) 252 | val_rmse += weighted_rmse(gen, tar) 253 | valid_steps += 1 254 | 255 | if params.distributed: 256 | torch.distributed.all_reduce( 257 | val_loss, op=ReduceOp.AVG, group=comm.get_group("dp") 258 | ) 259 | torch.distributed.all_reduce( 260 | val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp") 261 | ) 262 | 263 | val_rmse /= valid_steps # Avg validation rmse 264 | val_loss /= valid_steps 265 | val_end = time.time() 266 | if world_rank == 0: 267 | logging.info(" Avg val loss={}".format(val_loss.item())) 268 | logging.info(" Total validation time: {} sec".format(val_end - val_start)) 269 | args.tboard_writer.add_scalar("Loss/valid", val_loss, iters) 270 | args.tboard_writer.add_scalar( 271 | "RMSE(u10m)/valid", val_rmse.cpu().numpy()[0], iters 272 | ) 273 | args.tboard_writer.flush() 274 | 275 | torch.cuda.synchronize() 276 | t2 = time.time() 277 | tottime = t2 - t1 278 | pynvml.nvmlShutdown() 279 | 280 | 281 | if __name__ == "__main__": 282 | parser = argparse.ArgumentParser() 283 | parser.add_argument( 284 | "--run_num", 285 | default="00", 286 | type=str, 287 | help="tag for indexing the current experiment", 288 | ) 289 | parser.add_argument( 290 | "--yaml_config", 291 | default="./config/ViT.yaml", 292 | type=str, 293 | help="path to yaml file containing training configs", 294 | ) 295 | parser.add_argument( 296 | "--config", default="base", type=str, help="name of desired config in yaml file" 297 | ) 298 | parser.add_argument( 299 | "--amp_mode", 300 | default="none", 301 | type=str, 302 | choices=["none", "fp16", "bf16"], 303 | help="select automatic mixed precision mode", 304 | ) 305 | parser.add_argument( 306 | "--enable_fused", action="store_true", help="enable fused Adam optimizer" 307 | ) 308 | parser.add_argument( 309 | "--enable_jit", action="store_true", help="enable JIT compilation" 310 | ) 311 | parser.add_argument( 312 | "--local_batch_size", 313 | default=None, 314 | type=int, 315 | help="local batchsize (manually override global_batch_size config setting)", 316 | ) 317 | parser.add_argument( 318 | "--num_iters", default=None, type=int, help="number of iters to run" 319 | ) 320 | parser.add_argument( 321 | "--num_data_workers", 322 | default=None, 323 | type=int, 324 | help="number of data workers for data loader", 325 | ) 326 | parser.add_argument( 327 | "--data_loader_config", 328 | default=None, 329 | type=str, 330 | 
choices=["pytorch", "dali"], 331 | help="dataloader configuration. choices: 'pytorch', 'dali'", 332 | ) 333 | parser.add_argument( 334 | "--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb" 335 | ) 336 | parser.add_argument( 337 | "--disable_broadcast_buffers", 338 | action="store_true", 339 | help="disable syncing broadcasting buffers", 340 | ) 341 | parser.add_argument( 342 | "--noddp", action="store_true", help="disable DDP communication" 343 | ) 344 | 345 | # model parallelism arguments 346 | parser.add_argument( 347 | "--tensor_parallel", 348 | default=1, 349 | type=int, 350 | help="Number of GPUs for tensor parallelism", 351 | ) 352 | parser.add_argument( 353 | "--context_parallel", 354 | default=1, 355 | type=int, 356 | help="Number of GPUs for context parallelism", 357 | ) 358 | parser.add_argument( 359 | "--parallel_order", 360 | default="tp-cp-dp", 361 | type=str, 362 | help="Order of ranks for parallelism", 363 | ) 364 | 365 | args = parser.parse_args() 366 | 367 | run_num = args.run_num 368 | 369 | params = YParams(os.path.abspath(args.yaml_config), args.config) 370 | 371 | # Update config with modified args 372 | # set up amp 373 | if args.amp_mode != "none": 374 | params.update({"amp_mode": args.amp_mode}) 375 | amp_dtype = torch.float32 376 | if params.amp_mode == "fp16": 377 | amp_dtype = torch.float16 378 | elif params.amp_mode == "bf16": 379 | amp_dtype = torch.bfloat16 380 | 381 | params.update( 382 | {"amp_enabled": amp_dtype is not torch.float32, "amp_dtype": amp_dtype} 383 | ) 384 | 385 | if args.enable_fused: 386 | params.update({"enable_fused": args.enable_fused}) 387 | 388 | if args.enable_jit: 389 | params.update({"enable_jit": args.enable_jit}) 390 | 391 | if args.data_loader_config: 392 | params.update({"data_loader_config": args.data_loader_config}) 393 | 394 | if args.num_iters: 395 | params.update({"num_iters": args.num_iters}) 396 | 397 | if args.num_data_workers: 398 | params.update({"num_data_workers": args.num_data_workers}) 399 | 400 | params.distributed = False 401 | 402 | # setup model parallel sizes 403 | params["tp"] = args.tensor_parallel 404 | params["cp"] = args.context_parallel 405 | params["order"] = args.parallel_order 406 | # initialize comm 407 | comm.init(params, verbose=True) 408 | 409 | # get info from comm 410 | world_size = comm.get_world_size() 411 | world_rank = comm.get_world_rank() 412 | local_rank = comm.get_local_rank() 413 | params.distributed = world_size > 1 414 | 415 | assert ( 416 | params["global_batch_size"] % comm.get_size("dp") == 0 417 | ), f"Error, cannot evenly distribute {params['global_batch_size']} across {comm.get_size('dp')} GPU." 
418 | 419 | if args.local_batch_size: 420 | # Manually override batch size 421 | params.local_batch_size = args.local_batch_size 422 | params.update( 423 | {"global_batch_size": comm.get_size("dp") * args.local_batch_size} 424 | ) 425 | else: 426 | # Compute local batch size based on number of ranks 427 | params.local_batch_size = int( 428 | params["global_batch_size"] // comm.get_size("dp") 429 | ) 430 | 431 | # for data loader, set the actual number of data shards and id 432 | params.data_num_shards = comm.get_size("dp") 433 | params.data_shard_id = comm.get_rank("dp") 434 | 435 | # Set up directory 436 | baseDir = params.expdir 437 | expDir = os.path.join( 438 | baseDir, args.config + "/%dMP/" % (comm.get_size("tp-cp")) + str(run_num) + "/" 439 | ) 440 | if world_rank == 0: 441 | if not os.path.isdir(expDir): 442 | os.makedirs(expDir) 443 | logging_utils.log_to_file( 444 | logger_name=None, log_filename=os.path.join(expDir, "out.log") 445 | ) 446 | params.log() 447 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, "logs/")) 448 | 449 | params.experiment_dir = os.path.abspath(expDir) 450 | 451 | train(params, args, local_rank, world_rank, world_size) 452 | 453 | if params.distributed: 454 | torch.distributed.barrier() 455 | logging.info("DONE ---- rank %d" % world_rank) 456 | -------------------------------------------------------------------------------- /train_mp_graphs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | import gc 5 | import numpy as np 6 | import argparse 7 | import pynvml 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.cuda.amp import autocast, GradScaler 13 | import torch.multiprocessing 14 | from torch.utils.tensorboard import SummaryWriter 15 | from torch.nn.parallel import DistributedDataParallel 16 | from torch.distributed import ReduceOp 17 | 18 | import logging 19 | from utils import logging_utils 20 | logging_utils.config_logger() 21 | from utils.YParams import YParams 22 | from utils import get_data_loader_distributed 23 | from utils import comm 24 | from utils.loss import l2_loss, l2_loss_opt 25 | from utils.metrics import weighted_rmse 26 | from networks import vit 27 | 28 | from distributed.mappings import init_ddp_model_and_reduction_hooks 29 | from distributed.helpers import init_params_for_shared_weights 30 | 31 | from utils.plots import generate_images 32 | 33 | 34 | def capture_model(params, model, loss_func, scaler, capture_stream, device, num_warmup=20): 35 | logging.info("Capturing Model") 36 | inp_shape = (params.local_batch_size, params.n_in_channels, params.img_size[0], params.img_size[1]) 37 | tar_shape = (params.local_batch_size, params.n_in_channels, params.img_size[0], params.img_size[1]) 38 | 39 | # no zeros because loss is rel loss 40 | static_input = torch.randn(inp_shape, dtype=torch.float32, device=device) 41 | static_label = torch.randn(tar_shape, dtype=torch.float32, device=device) 42 | 43 | capture_stream.wait_stream(torch.cuda.current_stream()) 44 | with torch.cuda.stream(capture_stream): 45 | for _ in range(num_warmup): 46 | model.zero_grad(set_to_none=True) 47 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 48 | static_output = model(static_input).to(device) 49 | static_loss = loss_func(static_output, static_label) 50 | 51 | if params.amp_dtype == torch.float16: 52 | scaler.scale(static_loss).backward() 53 | else: 54 | static_loss.backward() 55 | 56 | # sync here 
57 | capture_stream.synchronize() 58 | 59 | gc.collect() 60 | torch.cuda.empty_cache() 61 | 62 | # create graph 63 | graph = torch.cuda.CUDAGraph() 64 | 65 | # zero grads before capture: 66 | model.zero_grad(set_to_none=True) 67 | 68 | # do the capture with the context manager: 69 | with torch.cuda.graph(graph): 70 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 71 | static_output = model(static_input).to(device) 72 | static_loss = loss_func(static_output, static_label) 73 | 74 | if params.amp_dtype == torch.float16: 75 | scaler.scale(static_loss).backward() 76 | else: 77 | static_loss.backward() 78 | 79 | torch.cuda.current_stream().wait_stream(capture_stream) 80 | 81 | return graph, static_input, static_output, static_label, static_loss 82 | 83 | def train(params, args, local_rank, world_rank, world_size): 84 | # set device and benchmark mode 85 | torch.backends.cudnn.benchmark = True 86 | torch.cuda.set_device(local_rank) 87 | device = torch.device('cuda:%d'%local_rank) 88 | 89 | # init pynvml and get handle 90 | pynvml.nvmlInit() 91 | nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(device.index) 92 | 93 | # get data loader 94 | logging.info('rank %d, begin data loader init'%world_rank) 95 | train_data_loader, train_dataset, train_sampler = get_data_loader_distributed(params, params.train_data_path, params.distributed, train=True) 96 | val_data_loader, valid_dataset = get_data_loader_distributed(params, params.valid_data_path, params.distributed, train=False) 97 | logging.info('rank %d, data loader initialized'%(world_rank)) 98 | 99 | # create model 100 | model = vit.ViT(params).to(device) 101 | 102 | if params.enable_jit: 103 | model = torch.compile(model) 104 | 105 | if params.amp_dtype == torch.float16: 106 | scaler = GradScaler() 107 | else: 108 | scaler = None 109 | 110 | # weight initialization needs to be synced across shared weights 111 | if comm.get_size("tp-cp") > 1: 112 | init_params_for_shared_weights(model) 113 | 114 | capture_stream = torch.cuda.Stream() 115 | if params.distributed: 116 | with torch.cuda.stream(capture_stream): 117 | model = init_ddp_model_and_reduction_hooks(model, device_ids=[local_rank], 118 | output_device=[local_rank], 119 | bucket_cap_mb=args.bucket_cap_mb) 120 | capture_stream.synchronize() 121 | 122 | 123 | if world_rank == 0: 124 | logging.info(model) 125 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 
126 | logging.info(f"Scaffolding memory high watermark: {all_mem_gb} GB.") 127 | 128 | if params.enable_fused: 129 | optimizer = optim.Adam(model.parameters(), lr = params.lr, fused=True, betas=(0.9, 0.95)) 130 | else: 131 | optimizer = optim.Adam(model.parameters(), lr = params.lr, betas=(0.9, 0.95)) 132 | 133 | iters = 0 134 | startEpoch = 0 135 | 136 | if params.lr_schedule == 'cosine': 137 | if params.warmup > 0: 138 | lr_scale = lambda x: min((x+1)/params.warmup, 0.5*(1 + np.cos(np.pi*x/params.num_iters))) 139 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scale) 140 | else: 141 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.num_iters) 142 | else: 143 | scheduler = None 144 | 145 | # select loss function 146 | if params.enable_jit: 147 | loss_func = l2_loss_opt 148 | else: 149 | loss_func = l2_loss 150 | 151 | # capture the model 152 | graph, static_input, static_output, static_label, static_loss = capture_model(params, model, loss_func, scaler, 153 | capture_stream, device, num_warmup=20) 154 | 155 | if world_rank==0: 156 | logging.info("Starting Training Loop...") 157 | 158 | # Log initial loss on train and validation to tensorboard 159 | with torch.no_grad(): 160 | inp, tar = map(lambda x: x.to(device), next(iter(train_data_loader))) 161 | gen = model(inp) 162 | tr_loss = loss_func(gen, tar) 163 | inp, tar = map(lambda x: x.to(device), next(iter(val_data_loader))) 164 | gen = model(inp) 165 | val_loss = loss_func(gen, tar) 166 | val_rmse = weighted_rmse(gen, tar) 167 | if params.distributed: 168 | torch.distributed.all_reduce(tr_loss, op=ReduceOp.AVG, group=comm.get_group("dp")) 169 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")) 170 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")) 171 | if world_rank==0: 172 | args.tboard_writer.add_scalar('Loss/train', tr_loss.item(), 0) 173 | args.tboard_writer.add_scalar('Loss/valid', val_loss.item(), 0) 174 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], 0) 175 | 176 | params.num_epochs = params.num_iters//len(train_data_loader) 177 | iters = 0 178 | t1 = time.time() 179 | for epoch in range(startEpoch, startEpoch + params.num_epochs): 180 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 181 | if params.distributed and (train_sampler is not None): 182 | train_sampler.set_epoch(epoch) 183 | start = time.time() 184 | tr_loss = [] 185 | tr_time = 0. 186 | dat_time = 0. 187 | log_time = 0. 188 | 189 | model.train() 190 | step_count = 0 191 | 192 | for i, data in enumerate(train_data_loader, 0): 193 | if world_rank == 0: 194 | if (epoch == 3 and i == 0): 195 | torch.cuda.profiler.start() 196 | if (epoch == 3 and i == len(train_data_loader) - 1): 197 | torch.cuda.profiler.stop() 198 | 199 | torch.cuda.nvtx.range_push(f"step {i}") 200 | iters += 1 201 | dat_start = time.time() 202 | torch.cuda.nvtx.range_push(f"data copy in {i}") 203 | 204 | inp, tar = map(lambda x: x.to(device), data) 205 | torch.cuda.nvtx.range_pop() # copy in 206 | 207 | tr_start = time.time() 208 | b_size = inp.size(0) 209 | 210 | static_input.copy_(inp) 211 | static_label.copy_(tar) 212 | graph.replay() 213 | 214 | if world_rank == 0 and i == 1: 215 | all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle).used / (1024. * 1024. * 1024.) 
216 | logging.info(f" Memory usage after forward pass: {all_mem_gb} GB.") 217 | 218 | if params.amp_dtype == torch.float16: 219 | torch.cuda.nvtx.range_push(f"optimizer") 220 | scaler.step(optimizer) 221 | torch.cuda.nvtx.range_pop() # optimizer 222 | scaler.update() 223 | else: 224 | torch.cuda.nvtx.range_push(f"optimizer") 225 | optimizer.step() 226 | torch.cuda.nvtx.range_pop() # optimizer 227 | 228 | 229 | if params.distributed: 230 | torch.distributed.all_reduce(static_loss, op=ReduceOp.AVG, group=comm.get_group("dp")) 231 | tr_loss.append(static_loss.item()) 232 | 233 | torch.cuda.nvtx.range_pop() # step 234 | # lr step 235 | scheduler.step() 236 | 237 | tr_end = time.time() 238 | tr_time += tr_end - tr_start 239 | dat_time += tr_start - dat_start 240 | step_count += 1 241 | 242 | torch.cuda.synchronize() # device sync to ensure accurate epoch timings 243 | end = time.time() 244 | 245 | if world_rank==0: 246 | iters_per_sec = step_count / (end - start) 247 | samples_per_sec = params["global_batch_size"] * iters_per_sec 248 | logging.info('Time taken for epoch %i is %f sec, avg %f samples/sec', 249 | epoch + 1, end - start, samples_per_sec) 250 | logging.info(' Avg train loss=%f'%np.mean(tr_loss)) 251 | args.tboard_writer.add_scalar('Loss/train', np.mean(tr_loss), iters) 252 | args.tboard_writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], iters) 253 | args.tboard_writer.add_scalar('Avg iters per sec', iters_per_sec, iters) 254 | args.tboard_writer.add_scalar('Avg samples per sec', samples_per_sec, iters) 255 | fig = generate_images([inp, tar, gen]) 256 | args.tboard_writer.add_figure('Visualization, t2m', fig, iters, close=True) 257 | 258 | val_start = time.time() 259 | val_loss = torch.zeros(1, device=device) 260 | val_rmse = torch.zeros((params.n_out_channels), dtype=torch.float32, device=device) 261 | valid_steps = 0 262 | model.eval() 263 | 264 | with torch.inference_mode(): 265 | with torch.no_grad(): 266 | for i, data in enumerate(val_data_loader, 0): 267 | with autocast(enabled=params.amp_enabled, dtype=params.amp_dtype): 268 | inp, tar = map(lambda x: x.to(device), data) 269 | gen = model(inp) 270 | val_loss += loss_func(gen, tar) 271 | val_rmse += weighted_rmse(gen, tar) 272 | valid_steps += 1 273 | 274 | if params.distributed: 275 | torch.distributed.all_reduce(val_loss, op=ReduceOp.AVG, group=comm.get_group("dp")) 276 | torch.distributed.all_reduce(val_rmse, op=ReduceOp.AVG, group=comm.get_group("dp")) 277 | 278 | val_rmse /= valid_steps # Avg validation rmse 279 | val_loss /= valid_steps 280 | val_end = time.time() 281 | if world_rank==0: 282 | logging.info(' Avg val loss={}'.format(val_loss.item())) 283 | logging.info(' Total validation time: {} sec'.format(val_end - val_start)) 284 | args.tboard_writer.add_scalar('Loss/valid', val_loss, iters) 285 | args.tboard_writer.add_scalar('RMSE(u10m)/valid', val_rmse.cpu().numpy()[0], iters) 286 | args.tboard_writer.flush() 287 | 288 | torch.cuda.synchronize() 289 | t2 = time.time() 290 | tottime = t2 - t1 291 | pynvml.nvmlShutdown() 292 | 293 | 294 | if __name__ == '__main__': 295 | parser = argparse.ArgumentParser() 296 | parser.add_argument( 297 | "--run_num", 298 | default="00", 299 | type=str, 300 | help="tag for indexing the current experiment", 301 | ) 302 | parser.add_argument( 303 | "--yaml_config", 304 | default="./config/ViT.yaml", 305 | type=str, 306 | help="path to yaml file containing training configs", 307 | ) 308 | parser.add_argument( 309 | "--config", default="base", type=str, help="name of 
desired config in yaml file" 310 | ) 311 | parser.add_argument( 312 | "--amp_mode", 313 | default="none", 314 | type=str, 315 | choices=["none", "fp16", "bf16"], 316 | help="select automatic mixed precision mode", 317 | ) 318 | parser.add_argument( 319 | "--enable_fused", action="store_true", help="enable fused Adam optimizer" 320 | ) 321 | parser.add_argument( 322 | "--enable_jit", action="store_true", help="enable JIT compilation" 323 | ) 324 | parser.add_argument( 325 | "--local_batch_size", 326 | default=None, 327 | type=int, 328 | help="local batchsize (manually override global_batch_size config setting)", 329 | ) 330 | parser.add_argument( 331 | "--num_iters", default=None, type=int, help="number of iters to run" 332 | ) 333 | parser.add_argument( 334 | "--num_data_workers", 335 | default=None, 336 | type=int, 337 | help="number of data workers for data loader", 338 | ) 339 | parser.add_argument( 340 | "--data_loader_config", 341 | default=None, 342 | type=str, 343 | choices=["pytorch", "dali"], 344 | help="dataloader configuration. choices: 'pytorch', 'dali'", 345 | ) 346 | parser.add_argument( 347 | "--bucket_cap_mb", default=25, type=int, help="max message bucket size in mb" 348 | ) 349 | parser.add_argument( 350 | "--disable_broadcast_buffers", 351 | action="store_true", 352 | help="disable syncing broadcasting buffers", 353 | ) 354 | parser.add_argument( 355 | "--noddp", action="store_true", help="disable DDP communication" 356 | ) 357 | 358 | # model parallelism arguments 359 | parser.add_argument( 360 | "--tensor_parallel", 361 | default=1, 362 | type=int, 363 | help="Number of GPUs for tensor parallelism", 364 | ) 365 | parser.add_argument( 366 | "--context_parallel", 367 | default=1, 368 | type=int, 369 | help="Number of GPUs for context parallelism", 370 | ) 371 | parser.add_argument( 372 | "--parallel_order", 373 | default="tp-cp-dp", 374 | type=str, 375 | help="Order of ranks for parallelism", 376 | ) 377 | 378 | args = parser.parse_args() 379 | 380 | 381 | run_num = args.run_num 382 | 383 | params = YParams(os.path.abspath(args.yaml_config), args.config) 384 | 385 | # Update config with modified args 386 | # set up amp 387 | if args.amp_mode != 'none': 388 | params.update({"amp_mode": args.amp_mode}) 389 | amp_dtype = torch.float32 390 | if params.amp_mode == "fp16": 391 | amp_dtype = torch.float16 392 | elif params.amp_mode == "bf16": 393 | amp_dtype = torch.bfloat16 394 | params.update({"amp_enabled": amp_dtype is not torch.float32, 395 | "amp_dtype" : amp_dtype, 396 | "enable_fused" : args.enable_fused, 397 | "enable_jit" : args.enable_jit 398 | }) 399 | 400 | if args.data_loader_config: 401 | params.update({"data_loader_config" : args.data_loader_config}) 402 | 403 | if args.num_iters: 404 | params.update({"num_iters" : args.num_iters}) 405 | 406 | if args.num_data_workers: 407 | params.update({"num_data_workers" : args.num_data_workers}) 408 | 409 | params.distributed = False 410 | 411 | # setup model parallel sizes 412 | params["tp"] = args.tensor_parallel 413 | params["cp"] = args.context_parallel 414 | params["order"] = args.parallel_order 415 | # initialize comm 416 | comm.init(params, verbose=True) 417 | 418 | # get info from comm 419 | world_size = comm.get_world_size() 420 | world_rank = comm.get_world_rank() 421 | local_rank = comm.get_local_rank() 422 | params.distributed = (world_size > 1) 423 | 424 | assert ( 425 | params["global_batch_size"] % comm.get_size("dp") == 0 426 | ), f"Error, cannot evenly distribute {params['global_batch_size']} across 
{comm.get_size('dp')} GPU." 427 | 428 | if args.local_batch_size: 429 | # Manually override batch size 430 | params.local_batch_size = args.local_batch_size 431 | params.update({"global_batch_size" : comm.get_size("dp") * args.local_batch_size}) 432 | else: 433 | # Compute local batch size based on number of ranks 434 | params.local_batch_size = int(params["global_batch_size"] // comm.get_size("dp")) 435 | 436 | # for data loader, set the actual number of data shards and id 437 | params.data_num_shards = comm.get_size("dp") 438 | params.data_shard_id = comm.get_rank("dp") 439 | 440 | # Set up directory 441 | baseDir = params.expdir 442 | expDir = os.path.join( 443 | baseDir, args.config + "/%dMP/" % (comm.get_size("tp-cp")) + str(run_num) + "/" 444 | ) 445 | if world_rank==0: 446 | if not os.path.isdir(expDir): 447 | os.makedirs(expDir) 448 | logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log')) 449 | params.log() 450 | args.tboard_writer = SummaryWriter(log_dir=os.path.join(expDir, 'logs/')) 451 | 452 | params.experiment_dir = os.path.abspath(expDir) 453 | 454 | train(params, args, local_rank, world_rank, world_size) 455 | 456 | if params.distributed: 457 | torch.distributed.barrier() 458 | logging.info('DONE ---- rank %d'%world_rank) 459 | 460 | --------------------------------------------------------------------------------
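
Note on train_mp_graphs.py: the capture_model routine follows the standard CUDA graph recipe -- warm up on a side stream, capture a single forward/backward pass into a torch.cuda.CUDAGraph, then replay that graph each step after copying fresh batches into the static input tensors, with the optimizer step kept outside the capture (so GradScaler can still unscale and step when fp16 is enabled). Below is a minimal, self-contained sketch of that pattern only; the tiny Linear model, shapes, and hyperparameters are placeholders rather than the tutorial's ViT or config, and the DDP/AMP handling is omitted. It assumes a CUDA-capable GPU.

# cuda_graph_sketch.py -- minimal capture/replay illustration (hypothetical toy
# model, not the tutorial's ViT); mirrors only the pattern of train_mp_graphs.py.
import torch

def main():
    device = torch.device("cuda")
    model = torch.nn.Linear(16, 16).to(device)
    loss_func = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # static tensors that the captured graph will read from / write to
    static_input = torch.randn(8, 16, device=device)
    static_label = torch.randn(8, 16, device=device)

    # 1) warm up on a side stream so workspaces/allocations exist before capture
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            model.zero_grad(set_to_none=True)
            loss_func(model(static_input), static_label).backward()
    torch.cuda.current_stream().wait_stream(side_stream)

    # 2) capture one forward+backward; grads are None here, so backward allocates
    #    static .grad buffers from the graph's private memory pool
    model.zero_grad(set_to_none=True)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_loss = loss_func(model(static_input), static_label)
        static_loss.backward()

    # 3) training loop: copy real data into the static tensors and replay; each
    #    replay rewrites the static .grad buffers in place, so no per-step
    #    zero_grad is needed, and the optimizer runs outside the graph as in
    #    the tutorial script
    for _ in range(5):
        static_input.copy_(torch.randn(8, 16, device=device))
        static_label.copy_(torch.randn(8, 16, device=device))
        graph.replay()
        optimizer.step()
        print(static_loss.item())

if __name__ == "__main__":
    if torch.cuda.is_available():
        main()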