├── .gitmodules ├── .gitignore ├── environment.yml ├── LICENSE ├── benchmarks ├── pyproject.toml ├── era5_zarr_benchmark.py └── zstd_benchmark.py ├── Dockerfile ├── zarr_ML_optimization ├── README.md ├── trainer_utils.py ├── era5_dataloader.py └── train_unet.py ├── zarr_dali_example ├── DALI_images.py └── zarr_dali_minimal.py ├── README.md └── rechunk └── era5_rechunking.ipynb /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Data files and folders 7 | data/** 8 | !data/**/ 9 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gpuhackathon 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - cupy~=13.3.0 7 | - dask-jobqueue~=0.9.0 8 | - rapidsai::kvikio>=25.04.00 9 | - jupyterlab~=4.3.5 10 | - nvcomp~=4.2.0.14 11 | - nvidia-dali-python~=1.45.0 12 | - nvtx~=0.2.11 13 | - python=3.12 14 | - pip~=25.0 15 | - pytorch~=2.6.0 16 | - segmentation-models-pytorch~=0.4.0 17 | - xarray>=2025.4.0 18 | - zarr~=3.0.8 19 | - pip: 20 | - zarr @ git+https://github.com/akshaysubr/zarr-python.git@gpu-codecs 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Xarray on GPUs Project Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmarks/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "benchmark" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "kvikio-cu12>=25.2.0", 9 | "nvtx>=0.2.11", 10 | "zarr", 11 | "ipython", 12 | "xarray", 13 | "rich>=13.9.4", 14 | "kvikio-zarr-v3", 15 | ] 16 | 17 | [tool.uv.sources] 18 | zarr = { git = "https://github.com/akshaysubr/zarr-python", rev = "gpu-codecs" } # need to push tags here 19 | # zarr = { git = "https://github.com/TomAugspurger/zarr-python", rev = "gpu-codecs" } 20 | # xarray = { git = "https://github.com/pydata/xarray", rev = "fix-cupy"} 21 | xarray = { git = "https://github.com/dcherian/xarray", rev = "fix-cupy" } 22 | kvikio_zarr_v3 = { git = "https://github.com/TomAugspurger/kvikio-zarr-v3" } 23 | 24 | [dependency-groups] 25 | dev = [ 26 | "ruff>=0.9.7", 27 | ] 28 | 29 | [tool.uv] 30 | override-dependencies = [ 31 | # "zarr @ git+https://github.com/akshaysubr/zarr-python@gpu-codecs", 32 | "zarr==3.0.4", 33 | "nvidia-nvcomp-cu12==4.2.0.14", 34 | ] 35 | 36 | [tool.ruff.lint] 37 | select = ["E", "W", "F", "I", "C", "B", "UP", "N", "S", "ERA", "PD", "EXE", "PGH"] 38 | 39 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ARG PYT_VER=25.02 16 | FROM nvcr.io/nvidia/pytorch:$PYT_VER-py3 as builder 17 | 18 | # Update pip and setuptools 19 | RUN pip install --upgrade pip setuptools 20 | 21 | # Setup git lfs, graphviz gl1(vtk dep) 22 | RUN apt-get update && \ 23 | apt-get install -y git-lfs graphviz libgl1 && \ 24 | git lfs install && \ 25 | pip install torchviz 26 | 27 | SHELL ["/bin/bash", "-c"] 28 | 29 | RUN git clone -b gpu-codecs https://github.com/akshaysubr/zarr-python.git /opt/zarr-python && \ 30 | cd /opt/zarr-python && \ 31 | pip install .[gpu] 32 | 33 | ENV _CUDA_COMPAT_TIMEOUT=90 34 | -------------------------------------------------------------------------------- /zarr_ML_optimization/README.md: -------------------------------------------------------------------------------- 1 | # Zarr ML end-to-end example 2 | 3 | This folder contains an end-to-end example of training a UNet model using the DALI library with Zarr data format. 4 | The code for this example is parallelized using PyTorch DDP (Distributed Data Parallel) and can be run on multiple GPUs or nodes. 5 | 6 | 7 | ## How to run! 8 | To run on 1 GPU, use the following command: 9 | 10 | ```bash 11 | module load conda 12 | conda activate gpuhackathon 13 | ``` 14 | 15 | To run on 1 GPU, use the following command: 16 | ``` 17 | ./train_unet.py 18 | ``` 19 | 20 | To run on single node multi-GPU, use the following command: 21 | 22 | ```bash 23 | torchrun --nnodes=1 --nproc-per-node=4 train_unet.py --distributed 24 | ``` 25 | 26 | To run with nsys: 27 | 28 | ```bash 29 | module purge 30 | module load ncarenv/23.09 31 | module reset 32 | module load cuda 33 | ``` 34 | Check the result of the following matches `/glade/u/apps/common/23.08/spack/opt/spack/cuda/12.2.1/bin/nsys`: 35 | 36 | ```bash 37 | which nsys 38 | ``` 39 | 40 | Then run with nsys profiling with 41 | 42 | ``` 43 | nsys profile -t nvtx,cuda,osrt --gpu-metrics-device all --force-overwrite=true --output=training_benchmark python train_unet.py 44 | ``` 45 | 46 | To run without data loading: 47 | 48 | ``` 49 | export synthetic="True"; nsys profile -t nvtx,cuda,osrt --gpu-metrics-device all --force-overwrite=true --output=training_benchmark python train_unet.py 50 | ``` 51 | 52 | ---------------- 53 | 54 | To run on multi-node multi-GPU, use the following command (full example [here](https://github.com/negin513/distributed-pytorch-hpc/blob/main/scripts/run_mpi.sh)): 55 | 56 | ```bash 57 | MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -np 8 ./train_unet.py --distributed 58 | ``` 59 | 60 | In the above command, replace `$head_node_ip` with the IP address of the head node. 61 | For example using PBS: 62 | ``` bash 63 | # Determine the number of nodes: 64 | nnodes=$(< $PBS_NODEFILE wc -l) 65 | nodes=( $( cat $PBS_NODEFILE ) ) 66 | head_node=${nodes[0]} 67 | head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}') 68 | ``` 69 | 70 | 71 | # 72 | Average throughput on 1 GPU (A100-40GB): 92.34 samples/sec, ~1000 seconds per epoch, (~35 minutes) 73 | Average throughput on 4 GPUs (A100-40GB): 800.5 samples/sec, ~500 seconds per epoch , (~ 16 minutes) 74 | -------------------------------------------------------------------------------- /benchmarks/era5_zarr_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is a GPU/CPU I/O benchmark for reading a Zarr dataset. 3 | It compares read performance between CPU-based and GPU-native approaches using Zarr v3. 4 | 5 | The script uses: 6 | - `zarr.config.enable_gpu()` for GPU-backed reads (via CuPy), 7 | - `GDSStore` from `kvikio_zarr_v3` for GPU Direct Storage (GDS) support, 8 | - `nvtx` annotations for profiling iterations with NVIDIA Nsight tools. 9 | 10 | The dataset is assumed to be a 4D array stored under the key 'combined', typically in (time, channel, height, width) format. 11 | 12 | The benchmark: 13 | - Reads pairs of time steps in a loop, 14 | - Measures elapsed time, 15 | - Computes effective I/O bandwidth in GB/s. 16 | """ 17 | import asyncio 18 | from contextlib import nullcontext 19 | import math 20 | import time 21 | from pathlib import Path 22 | from tempfile import TemporaryDirectory 23 | 24 | import numpy as np 25 | import nvtx 26 | 27 | import zarr 28 | from zarr.abc.codec import Codec 29 | from zarr.abc.store import Store 30 | from zarr.codecs import NvcompZstdCodec, ZstdCodec 31 | from zarr.storage import LocalStore 32 | 33 | from kvikio_zarr_v3 import GDSStore 34 | 35 | 36 | def get_store(path: Path, cls: Store = LocalStore) -> LocalStore: 37 | async def _get_store(path: Path) -> LocalStore: 38 | return await cls.open(path) 39 | 40 | return asyncio.run(_get_store(path)) 41 | 42 | @nvtx.annotate(color="red", domain="benchmark") 43 | def read( 44 | store: Store, 45 | gpu: bool = True, 46 | ) -> None: 47 | with zarr.config.enable_gpu() if gpu else nullcontext(): 48 | g = zarr.open_group(store=store) 49 | a = g.get("combined") 50 | size = tuple(a.shape) 51 | print(f"Opened array with compressors: {a.compressors}") 52 | 53 | color = "green" if gpu else "blue" 54 | start_time = time.time() 55 | niters = min(100, size[0] // 2) 56 | for i in range(niters): 57 | with nvtx.annotate(message=f"iteration {i}", color=color): 58 | start_time_index = 2 * i 59 | end_time_index = 2 * (i + 1) 60 | result = a[start_time_index:end_time_index, :, :, :] 61 | end_time = time.time() 62 | elapsed_time = end_time - start_time 63 | total_bytes_gb = 2.0 * (niters) * math.prod(size[1:]) / (1024.0) ** 3 64 | print(f"Total time to read data: {elapsed_time} s") 65 | print(f"Effective I/O bandwidth: {total_bytes_gb / elapsed_time} GB/s") 66 | 67 | 68 | if __name__ == "__main__": 69 | path = Path("/glade/derecho/scratch/katelynw/era5/rechunked_stacked_test.zarr") 70 | store = get_store(path, cls=GDSStore) 71 | read(store, gpu=False) 72 | read(store, gpu=True) 73 | -------------------------------------------------------------------------------- /zarr_dali_example/DALI_images.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | #----------------- 4 | # This is not working yet. 5 | #----------------- 6 | 7 | from nvidia.dali import pipeline_def, Pipeline 8 | import nvidia.dali.fn as fn, types 9 | #from nvidia.dali.plugin.pytorch import DALIGenericIterator 10 | 11 | from test_case.ERA5TimeSeriesDataset import ERA5Dataset, PyTorchERA5Dataset 12 | 13 | 14 | @pipeline_def 15 | def create_dali_pipeline(dataset, batch_size, device='gpu'): 16 | """ 17 | Creates a DALI pipeline for loading and preprocessing ERA5 time-series data. 18 | 19 | Args: 20 | dataset (ERA5TimeSeriesDataset): The dataset to load from. 21 | batch_size (int): Batch size for the pipeline. 22 | device (str): Device to use ('gpu' or 'cpu'). Defaults to 'gpu'. 23 | """ 24 | 25 | # Define external source to fetch data from the dataset 26 | inputs, targets = fn.external_source( 27 | source=dataset, 28 | num_outputs=2, 29 | dtype=types.FLOAT, 30 | device=device, 31 | parallel=True, 32 | batch=True, # NS: error if not batch=True 33 | ) 34 | print ('----') 35 | print(f"inputs: {inputs}, targets: {targets}") 36 | 37 | print(f"inputs: {inputs}, targets: {targets}") 38 | 39 | return inputs, targets 40 | 41 | 42 | # Example usage: 43 | if __name__ == "__main__": 44 | ## Example usage of the ERA5TimeSeriesDataset class 45 | data_path = "/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv" 46 | start_year = 2001 47 | end_year = 2010 48 | 49 | # for now, just surface variables!cd 50 | input_vars = ['t2m', 'V500', 'U500', 'T500', 'Z500', 'Q500'] 51 | target_vars = ['t2m'] 52 | 53 | # Enable GPU support 54 | zarr.config.enable_gpu() 55 | 56 | #dataset = ERA5TimeSeriesDataset(data_path, start_year, end_year, input_vars=input_vars) 57 | 58 | era5_dataset = ERA5Dataset( 59 | data_path=data_path, 60 | start_year=start_year, 61 | end_year=end_year, 62 | input_vars=input_vars, 63 | target_vars=target_vars 64 | ) 65 | 66 | 67 | 68 | print(era5_dataset) 69 | print(f"Total samples: {len(era5_dataset)}") 70 | 71 | #pytorch_dataset = PyTorchERA5Dataset(train_dataset) 72 | 73 | 74 | # Create DALI pipeline 75 | batch_size = 32 76 | 77 | #pipe = create_dali_pipeline(era5_dataset.fetch_timeseries(), batch_size=batch_size, device='gpu', num_threads=4, device_id=0) 78 | #pipe.build() # This fails here! 79 | 80 | pipe = Pipeline(batch_size=batch_size, num_threads=2, device_id=0) 81 | with pipe: 82 | inputs, targets = create_dali_pipeline(era5_dataset.fetch_timeseries(), batch_size=batch_size, device='gpu') 83 | pipe.set_outputs(inputs, targets) 84 | 85 | 86 | pipe.build() 87 | pipe_out = pipe.run() 88 | 89 | # Create DALI iterator 90 | #dali_iter = DALIGenericIterator(pipe, output_map=["inputs", "targets"], size=len(era5_dataset)) 91 | 92 | # Fetch a batch of data 93 | #for batch in dali_iter: 94 | # inputs = batch[0]["inputs"] 95 | # targets = batch[0]["targets"] 96 | # print(f"Input shape: {inputs.shape}, Target shape: {targets.shape}") 97 | # break # just a test 98 | -------------------------------------------------------------------------------- /benchmarks/zstd_benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | Zarr v3 I/O Benchmark: CPU vs GPU Read Performance 3 | 4 | This script benchmarks the I/O performance of writing and reading a synthetic 5 | Zarr v3 dataset using CPU and GPU. It demonstrates how to: 6 | 7 | - Create a 4D array in Zarr v3 using a specified compression codec (CPU or GPU). 8 | - Read the dataset using either CPU-based or GPU-accelerated access. 9 | - Annotate profiling regions using NVTX for use with NVIDIA Nsight tools. 10 | - Compute and report effective I/O bandwidth in GB/s. 11 | 12 | """ 13 | 14 | import asyncio 15 | from contextlib import nullcontext 16 | import math 17 | import time 18 | from pathlib import Path 19 | from tempfile import TemporaryDirectory 20 | 21 | import numpy as np 22 | import nvtx 23 | 24 | import zarr 25 | from zarr.abc.codec import Codec 26 | from zarr.abc.store import Store 27 | from zarr.codecs import NvcompZstdCodec, ZstdCodec 28 | from zarr.storage import LocalStore 29 | 30 | def get_store(path: Path) -> LocalStore: 31 | async def _get_store(path: Path) -> LocalStore: 32 | return await LocalStore.open(path) 33 | 34 | return asyncio.run(_get_store(path)) 35 | 36 | @nvtx.annotate(color="red", domain="benchmark") 37 | def write( 38 | size: tuple[int, int, int, int], 39 | chunks: tuple[int, int, int, int], 40 | store: Store, 41 | write_codec: str | Codec, 42 | read_codec: str | Codec, 43 | ) -> None: 44 | src = np.random.uniform(size=size).astype(np.float32) # allocate on CPU 45 | z = zarr.create_array( 46 | store, 47 | name="a", 48 | shape=src.shape, 49 | chunks=chunks, 50 | dtype=src.dtype, 51 | overwrite=True, 52 | zarr_format=3, 53 | compressors=write_codec, 54 | ) 55 | z[:] = src 56 | 57 | @nvtx.annotate(color="red", domain="benchmark") 58 | def read( 59 | size: tuple[int, int, int, int], 60 | store: Store, 61 | gpu: bool = True, 62 | ) -> None: 63 | with zarr.config.enable_gpu() if gpu else nullcontext(): 64 | g = zarr.open_group(store=store) 65 | a = g.get("a") 66 | print(f"Opened array with compressors: {a.compressors}") 67 | 68 | color = "green" if gpu else "blue" 69 | start_time = time.time() 70 | for i in range(size[0] // 2): 71 | with nvtx.annotate(message=f"iteration {i}", color=color): 72 | start_time_index = 2 * i 73 | end_time_index = 2 * (i + 1) 74 | result = a[start_time_index:end_time_index, :, :, :] 75 | end_time = time.time() 76 | elapsed_time = end_time - start_time 77 | total_bytes_gb = 2.0 * (size[0] // 2) * math.prod(size[1:]) / (1024.0) ** 3 78 | print(f"Total time to read data: {elapsed_time} s") 79 | print(f"Effective I/O bandwidth: {total_bytes_gb / elapsed_time} GB/s") 80 | 81 | 82 | if __name__ == "__main__": 83 | cpu_codec = ZstdCodec() 84 | gpu_codec = NvcompZstdCodec() 85 | 86 | dataset_size = (10, 128, 1280, 640) 87 | chunks = (1, 1, 1280, 640) 88 | 89 | with TemporaryDirectory() as tmpdir: 90 | path = Path(tmpdir) / "benchmark.zarr" 91 | store = get_store(path) 92 | write(dataset_size, chunks, store, cpu_codec, gpu_codec) 93 | read(dataset_size, store, gpu=False) 94 | read(dataset_size, store, gpu=True) 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerating AI/ML Workflows in Earth Sciences with GPU-Native Xarray and Zarr 2 | 3 | Read about this project in the [Xarray blog](https://xarray.dev/blog/gpu-pipeline). 4 | 5 | 🏔️⚡ A collaborative benchmarking and optimization effort from [NSF-NCAR](https://www.ncar.ucar.edu/), [Development Seed](https://developmentseed.org/), and [NVIDIA](https://www.nvidia.com/) to accelerate data-intensive geoscience AI/ML workflows using GPU-native technologies like Zarr v3, CuPy, KvikIO, and NVIDIA DALI. 6 | 7 | ## 📌 Overview 8 | 9 | This repository contains code, benchmarks, and examples from Xarray on GPUs hackathon project during the 10 | [NREL/NCAR/NOAA Open Hackathon](https://www.openhackathons.org/s/siteevent/a0CUP00000rwYYZ2A2/se000355) 11 | in Golden, Colorado from 18-27 February 2025. The goal of this project is to provide a proof-of-concept example of optimizing the performance of geospatial machine learning workflows on GPUs by using [Zarr-python v3](https://zarr.dev/) and [NVIDIA DALI](https://developer.nvidia.com/dali). 12 | 13 | 📖 [Read the full blog post](https://xarray.dev/blog/gpu-pipeline) 14 | 15 | In this project, we demonstrate how to: 16 | 17 | - Optimize chunking strategies for Zarr datasets 18 | - Read ERA5 Zarr v3 data directly into GPU memory using CuPy and KvikIO 19 | - Apply GPU-based decompression using NVIDIA's nvCOMP 20 | - Build end-to-end GPU-native DALI pipelines 21 | - Improve training throughput for U-Net-based ML models 22 | 23 | 24 | ## 📂 Repository Structure 25 | 26 | In this repository, you will find the following: 27 | 28 | - `benchmarks/`: Scripts to evaluate read and write performance for Zarr v3 datasets on both CPU and GPU. 29 | - `zarr_dali_example/`: Contains a minimal example of using DALI to read Zarr data and train a model. 30 | - `zarr_ML_optimization`: Contains an example benchmark for training a U-Net model using DALI with Zarr data format. 31 | - `rechunk` : Contains a notebook that demonstrates how to optimize chunking strategies for Zarr datasets. 32 | 33 | See [zarr_ML_optimization/README.md](zarr_ML_optimization/README.md) for more details on running the U-Net training example. 34 | 35 | 36 | # Creating the Environment 37 | 38 | ## Basic 39 | 40 | Start by cloning the repo & setting up the `conda` environment: 41 | ```bash 42 | git clone https://github.com/pangeo-data/ncar-hackathon-xarray-on-gpus.git 43 | cd ncar-hackathon-xarray-on-gpus 44 | conda env create --file environment.yml 45 | conda activate gpuhackathon 46 | ``` 47 | 48 | ### Advanced using `conda-lock` 49 | 50 | This is for those who want full reproducibility of the virtual environment. 51 | Create a virtual environment with just Python and conda-lock installed first. 52 | 53 | ``` 54 | conda create --name gpuhackathon python=3.11 conda-lock=2.5.7 55 | conda activate gpuhackathon 56 | ``` 57 | 58 | Generate a unified [`conda-lock.yml`](https://github.com/conda/conda-lock) file 59 | based on the dependency specification in `environment.yml`. Use only when 60 | creating a new `conda-lock.yml` file or refreshing an existing one. 61 | ``` 62 | conda-lock lock --mamba --file environment.yml --platform linux-64 --with-cuda=12.8 63 | ``` 64 | 65 | Installing/Updating a virtual environment from a lockile. Use this to sync your 66 | dependencies to the exact versions in the `conda-lock.yml` file. 67 | 68 | ``` 69 | conda-lock install --mamba --name gpuhackathon conda-lock.yml 70 | ``` 71 | See also https://conda.github.io/conda-lock/output/#unified-lockfile for more 72 | usage details. 73 | 74 | -------------------------------------------------------------------------------- /zarr_dali_example/zarr_dali_minimal.py: -------------------------------------------------------------------------------- 1 | """ 2 | dali + Zarr (GPU) example. 3 | 4 | This script adapts the GPU example from 5 | https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html 6 | to use Zarr for storage. 7 | 8 | To run it, you'll currently need to use my fork of zarr-python: 9 | 10 | pip install git+https://github.com/TomAugspurger/zarr-python/@tom/fix/gpu 11 | 12 | That should be in zarr `main` soon. You'll also need the data. 13 | 14 | ``` 15 | mkdir -p data/images 16 | cd data/images 17 | curl -O https://docs.nvidia.com/deeplearning/dali/user-guide/docs/_images/examples_general_data_loading_external_input_12_2.png 18 | curl -O curl -O https://docs.nvidia.com/deeplearning/dali/user-guide/docs/_images/examples_general_data_loading_external_input_19_2.png 19 | 20 | ``` 21 | 22 | And a `file_list.txt` like 23 | 24 | ``` 25 | examples_general_data_loading_external_input_12_2.png 0 26 | examples_general_data_loading_external_input_19_2.png 1 27 | ``` 28 | 29 | Then run `make_data()` to create the zarr store. 30 | """ 31 | 32 | 33 | import types 34 | from random import shuffle 35 | from nvidia.dali.pipeline import Pipeline 36 | import nvidia.dali.fn as fn 37 | import zarr 38 | import zarr.storage 39 | from PIL import Image 40 | 41 | 42 | batch_size = 16 43 | 44 | 45 | # create the data 46 | # Right now, assuming a chunksize of 1 along the dimension being sampled. 47 | # We have some interesting options here w.r.t. the chunksize and shuffling. 48 | # 49 | 50 | 51 | def make_data(): 52 | # TODO: figure out the shape here. 53 | # goes from 4 -> 3 somewhere. 54 | store = zarr.storage.LocalStore(root="data/example.zarr") 55 | group = zarr.create_group(store, overwrite=True) 56 | 57 | TOTAL_SAMPLES = 100 58 | 59 | # note: the images from the docs vary in size while Zarr requires 60 | # uniform chunk sizes. I've truncated the images to 231 x 300 61 | 62 | arr = group.create_array( 63 | name="images", 64 | shape=(TOTAL_SAMPLES, 231, 300, 3), 65 | chunks=(1, 231, 300, 3), 66 | dtype="uint8", 67 | overwrite=True, 68 | ) 69 | 70 | labels = group.create_array( 71 | name="labels", 72 | shape=(TOTAL_SAMPLES,), 73 | chunks=(1,), 74 | dtype="uint8", 75 | overwrite=True, 76 | ) 77 | 78 | # TODO: use file list 79 | # assuming you've downloaded these two 80 | img = Image.open( 81 | "data/images/examples_general_data_loading_external_input_12_2.png" 82 | ) 83 | arr[0] = img 84 | labels[0] = 0 85 | img = Image.open( 86 | "data/images/examples_general_data_loading_external_input_19_2.png" 87 | ) 88 | arr[1] = img 89 | labels[1] = 1 90 | 91 | 92 | class ExternalInputIterator: 93 | def __init__(self, batch_size: int): 94 | self.root = "data/example.zarr/" 95 | self.variable = "images" 96 | self.batch_size = batch_size 97 | 98 | # Does this class get serialized? Is it safe to store 99 | # references to zarr arrays here? 100 | # self.images = zarr.open_array(self.root, path=self.variable) 101 | # self.labels = zarr.open_array(self.root, path="labels") 102 | 103 | self.indices = list( 104 | range(zarr.open_array(self.root, path=self.variable).shape[0]) 105 | ) 106 | shuffle(self.indices) 107 | self.i = 0 108 | self.n = len(self.indices) 109 | 110 | def __iter__(self): 111 | self.i = 0 112 | self.n = len(self.indices) 113 | return self 114 | 115 | def __next__(self): 116 | batch = [] 117 | labels = [] 118 | 119 | arr = zarr.open(self.root, path=self.variable) 120 | arr_labels = zarr.open(self.root, path="labels") 121 | 122 | for _ in range(self.batch_size): 123 | batch.append(arr[self.i]) 124 | labels.append(arr_labels[self.i]) 125 | self.i = (self.i + 1) % self.n 126 | return (batch, labels) 127 | 128 | 129 | def main(): 130 | make_data() 131 | print (" Data created!") 132 | eii = ExternalInputIterator(batch_size) 133 | zarr.config.enable_gpu() 134 | pipe = Pipeline(batch_size=batch_size, num_threads=2, device_id=0) 135 | # note: using the `device="gpu"` variant from https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/external_input.html 136 | with pipe: 137 | images, labels = fn.external_source(source=eii, num_outputs=2, dtype=types.UINT8, device="gpu") 138 | enhance = fn.brightness_contrast(images, contrast=2) 139 | pipe.set_outputs(enhance, labels) 140 | 141 | pipe.build() 142 | pipe_out = pipe.run() 143 | 144 | batch_cpu = pipe_out[0].as_cpu() 145 | labels_cpu = pipe_out[1].as_cpu() 146 | 147 | print(batch_cpu.at(0).shape) 148 | print(labels_cpu.at(0)) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /zarr_ML_optimization/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import socket 4 | import numpy as np 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | def setup_logging(world_rank: int, level: int = logging.INFO) -> None: 10 | """Sets up basic logging. Logs only from rank 0.""" 11 | if world_rank == 0: 12 | logging.basicConfig( 13 | level=level, 14 | format="%(asctime)s [%(levelname)s] : %(message)s", 15 | datefmt="%Y-%m-%d %H:%M:%S", 16 | ) 17 | else: 18 | logging.basicConfig(level=logging.CRITICAL + 1) 19 | 20 | def set_random_seeds(random_seed=0): 21 | """ 22 | Sets random seeds for reproducibility 23 | """ 24 | torch.manual_seed(random_seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = True 27 | np.random.seed(random_seed) 28 | 29 | def custom_loss(predictions, targets, lambda_std=0.1): 30 | """ 31 | Another custom loss function combining RMSE with standard deviation matching. 32 | 33 | The function handles two key aspects of the prediction quality: 34 | 1. Accuracy: Through RMSE calculation 35 | 2. Variability: Through standard deviation matching 36 | 37 | """ 38 | # Calculate RMSE for prediction accuracy 39 | rmse_loss = torch.nn.functional.mse_loss( 40 | predictions, targets, reduction="mean" 41 | ).sqrt() 42 | 43 | # Calculate standard deviation component 44 | # We'll calculate std over the batch dimension (dim=0) and average over spatial dimensions 45 | # unbiased=False removes the Bessel correction and addresses the warning 46 | pred_std = torch.std(predictions.view(-1), unbiased=False) 47 | target_std = torch.std(targets.view(-1), unbiased=False) 48 | 49 | # Average the standard deviation differences across spatial dimensions 50 | std_loss = torch.mean(torch.abs(pred_std - target_std)) 51 | 52 | # Combine the losses with the weighting factor 53 | total_loss = rmse_loss + lambda_std * std_loss 54 | 55 | # Store components for monitoring 56 | loss_components = { 57 | "rmse": rmse_loss.item(), 58 | "std_diff": std_loss.item(), 59 | "total": total_loss.item(), 60 | } 61 | 62 | return total_loss, loss_components 63 | 64 | def init_process_group( 65 | distributed: bool, backend: str = "nccl" 66 | ) -> tuple[int, int, int]: 67 | """ 68 | Initialize the process group for distributed training. 69 | """ 70 | if distributed: 71 | # Try MPI detection first 72 | try: 73 | from mpi4py import MPI 74 | 75 | comm = MPI.COMM_WORLD 76 | shmem_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) 77 | 78 | local_rank = shmem_comm.Get_rank() 79 | world_size = comm.Get_size() 80 | world_rank = comm.Get_rank() 81 | 82 | if "MASTER_ADDR" not in os.environ: 83 | os.environ["MASTER_ADDR"] = comm.bcast(socket.gethostbyname(socket.gethostname()), root=0) 84 | 85 | if "MASTER_PORT" not in os.environ: 86 | os.environ["MASTER_PORT"] = str(np.random.randint(1000, 8000)) 87 | 88 | except: 89 | if "LOCAL_RANK" in os.environ: 90 | # Environment variables set by torch.distributed.launch or torchrun 91 | LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 92 | WORLD_SIZE = int(os.environ["WORLD_SIZE"]) 93 | WORLD_RANK = int(os.environ["RANK"]) 94 | elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: 95 | # Environment variables set by mpirun 96 | LOCAL_RANK = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 97 | WORLD_SIZE = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 98 | WORLD_RANK = int(os.environ["OMPI_COMM_WORLD_RANK"]) 99 | elif "PMI_RANK" in os.environ: 100 | # Environment variables set by cray-mpich 101 | LOCAL_RANK = int(os.environ["PMI_LOCAL_RANK"]) 102 | WORLD_SIZE = int(os.environ["PMI_SIZE"]) 103 | WORLD_RANK = int(os.environ["PMI_RANK"]) 104 | else: 105 | raise RuntimeError( 106 | "Can't find the environment variables for local rank!" 107 | ) 108 | else: # Non-distributed mode 109 | # for running without torchrun or mpirun (i.e. ./train_unet.py) 110 | LOCAL_RANK = 0 111 | WORLD_SIZE = 1 112 | WORLD_RANK = 0 113 | 114 | # --------------------- 115 | # Initialize distributed training 116 | if distributed: 117 | torch.distributed.init_process_group( 118 | backend=backend, rank=WORLD_RANK, world_size=WORLD_SIZE 119 | ) 120 | torch.cuda.set_device(LOCAL_RANK) 121 | if WORLD_RANK == 0: 122 | print("----Distrbuted Setup-----") 123 | print("LOCAL_RANK : ", LOCAL_RANK) 124 | print("WORLD_SIZE : ", WORLD_SIZE) 125 | print("WORLD_RANK : ", WORLD_RANK) 126 | print("cuda device : ", torch.cuda.device_count()) 127 | print("pytorch version : ", torch.__version__) 128 | print("nccl version : ", torch.cuda.nccl.version()) 129 | print("torch config : ", torch.__config__.show()) 130 | print(torch.__config__.parallel_info()) 131 | print("-------------------------") 132 | 133 | return LOCAL_RANK, WORLD_SIZE, WORLD_RANK -------------------------------------------------------------------------------- /rechunk/era5_rechunking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "153c98aa-2f72-4cb4-a96a-c01374d84930", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Dask imports\n", 11 | "\n", 12 | "from dask_jobqueue import PBSCluster\n", 13 | "from dask.distributed import Client" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "338d23f5-92d2-422b-bbe4-01955aceff50", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Dask cluster config\n", 24 | "\n", 25 | "cluster = PBSCluster(\n", 26 | " # Basic job directives\n", 27 | " job_name = 'hackathon-rechunk',\n", 28 | " queue = 'casper',\n", 29 | " walltime = '120:00',\n", 30 | " # Make sure you change the project code if running this notebook!!\n", 31 | " account = 'UCSG0002',\n", 32 | " log_directory = 'dask-logs',\n", 33 | " # These settings impact the resources assigned to the job\n", 34 | " cores = 1,\n", 35 | " memory = '10GiB',\n", 36 | " resource_spec = 'select=1:ncpus=1:mem=10GB',\n", 37 | " # These settings define the resources assigned to a worker\n", 38 | " processes = 1,\n", 39 | " # This controls where Dask will write data to disk if memory is exhausted\n", 40 | " local_directory = '/local_scratch/pbs.$PBS_JOBID/dask/spill',\n", 41 | " # This specifies which network interface the cluster will use\n", 42 | " interface = 'ext'\n", 43 | ")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "0d4322e9-cc4f-4b45-815c-9b8228eb03a2", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "# Create the client to load the Dashboard\n", 54 | "client = Client(cluster)\n", 55 | "\n", 56 | "# Display the client repr\n", 57 | "client" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "0c737b08-9cf2-4e90-9646-2013641815b7", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# Scale and wait for workers\n", 68 | "\n", 69 | "cluster.scale(40)\n", 70 | "client.wait_for_workers(40)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "9d7d7583-c695-43c9-86a8-12f20b5d432d", 77 | "metadata": { 78 | "scrolled": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "import xarray as xr\n", 83 | "import pandas as pd\n", 84 | "import dask\n", 85 | "\n", 86 | "# Read in files\n", 87 | "ds = xr.open_mfdataset('/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL_202*.zarr',\n", 88 | " engine = 'zarr',\n", 89 | " consolidated=True,\n", 90 | " data_vars='minimal',\n", 91 | " coords='minimal',\n", 92 | " compat='override',\n", 93 | " parallel=True)\n", 94 | "\n", 95 | "# Rechunk the data\n", 96 | "ds = ds.chunk({\"time\": 1, \"level\": 1, \"latitude\": 640, \"longitude\": 1280})\n", 97 | "\n", 98 | "# Remove the old encoding info and set compression to none\n", 99 | "for k, v in ds.variables.items():\n", 100 | " v.encoding['compressors'] = None\n", 101 | " del v.encoding['chunks']\n", 102 | " del v.encoding['preferred_chunks']\n", 103 | "\n", 104 | "# Remove the old encoding info (default compression will then apply when written to Zarr)\n", 105 | "# for k, v in ds.variables.items():\n", 106 | "# del v.encoding['compressors']\n", 107 | "# del v.encoding['chunks']\n", 108 | "# del v.encoding['preferred_chunks']\n" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "53fd0270-d21e-4f2f-a769-1701900f66f4", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# Some not particularly polished data wrangling to combine the arrays\n", 119 | "# Skip this to write separate arrays\n", 120 | "\n", 121 | "full_variables = ['Q', 'T', 'U', 'V']\n", 122 | "single_level_variables = ['Q500', 'T500', 'U500', 'V500', 'Z500', 't2m', 'SP']\n", 123 | "\n", 124 | "ds1 = xr.concat([ds[x] for x in single_level_variables],\n", 125 | " pd.Index(single_level_variables,\n", 126 | " name='channel')).transpose('time',\n", 127 | " 'channel',\n", 128 | " 'latitude',\n", 129 | " 'longitude')\n", 130 | "\n", 131 | "c = xr.concat([ds[x] for x in full_variables], dim=full_variables)\n", 132 | "\n", 133 | "s = c.stack(channel = ('concat_dim','level')).transpose('time',\n", 134 | " 'channel',\n", 135 | " 'latitude',\n", 136 | " 'longitude').reset_index('channel')\n", 137 | "\n", 138 | "s['channel'] = s['concat_dim'] + s['level'].astype('str')\n", 139 | "\n", 140 | "ds2 = s.drop_vars(['level', 'concat_dim'])\n", 141 | "\n", 142 | "combined = xr.concat([ds1, ds2], dim='channel').rename('combined')\n", 143 | "\n", 144 | "combined.encoding" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "a3b394cc-9186-4a83-8a5d-2fedc3f10825", 151 | "metadata": { 152 | "scrolled": true 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "# Write to Zarr v3 with consolidated metdata\n", 157 | "\n", 158 | "combined.to_zarr('/glade/derecho/scratch/katelynw/era5/rechunked_stacked_uncompressed_test.zarr',\n", 159 | " zarr_version=3,\n", 160 | " consolidated=True)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "cd075006-9d9c-43b4-82ce-9cb1a7d1c576", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Shutdown the cluster\n", 171 | "\n", 172 | "client.shutdown()" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "9eedf120-3afa-4e26-a345-f58cbdc032a7", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# Open up the new dataset and check the encoding\n", 183 | "\n", 184 | "ds_new = xr.open_dataset('/glade/derecho/scratch/katelynw/era5/rechunked_stacked_uncompressed_test.zarr')\n", 185 | "\n", 186 | "ds_new.combined.encoding" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "kernelspec": { 192 | "display_name": "Python [conda env:my-env]", 193 | "language": "python", 194 | "name": "conda-env-my-env-py" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.12.9" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 5 211 | } 212 | -------------------------------------------------------------------------------- /zarr_ML_optimization/era5_dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | This module defines classes to handle ERA5 datasets stored in Zarr format, 4 | including support for PyTorch DataLoader and NVIDIA DALI pipelines. 5 | 6 | - ERA5Dataset: Load multi-year ERA5 data from Zarr stores. (No PyTorch dependency) 7 | - PyTorchERA5Dataset: PyTorch-compatible wrapper for ERA5Dataset. 8 | 9 | - SeqZarrSource: NVIDIA DALI-compatible external source for ERA5 Zarr data. 10 | - seqzarr_pipeline: DALI pipeline for loading Zarr data using SeqZarrSource. 11 | 12 | Example: 13 | python ERA5TimeSeriesDataset.py 14 | - Use the `--use-dali` flag to load data using DALI pipeline. 15 | """ 16 | import os 17 | from contextlib import nullcontext 18 | 19 | import numpy as np 20 | import cupy as cp 21 | import torch 22 | import xarray as xr 23 | import zarr 24 | 25 | import nvidia.dali as dali 26 | from nvidia.dali.pipeline import pipeline_def 27 | from torch.utils.data import Dataset, DataLoader 28 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 29 | import nvidia.dali.fn as fn 30 | 31 | class ERA5Dataset: 32 | """ 33 | Load multiple years of ERA5 and forcing datasets from Zarr. 34 | Each __getitem__(index) returns (input, target) as NumPy arrays. 35 | """ 36 | 37 | def __init__(self, data_path, start_year, end_year, input_vars, target_vars=None,forecast_step=1, use_synthetic=False): 38 | """ 39 | Initializes the dataset. 40 | 41 | Args: 42 | data_path (str): Path to the zarr store base.... 43 | start_year (int): Start year for the dataset. 44 | end_year (int): End year for the dataset. 45 | input_vars (list): List of input variable names. 46 | target_vars (list, optional): List of target variable names. Defaults to input_vars. 47 | """ 48 | self.data_path = data_path 49 | self.start_year = start_year 50 | self.end_year = end_year 51 | self.input_vars = input_vars 52 | self.target_vars = target_vars if target_vars is not None else input_vars 53 | self.normalized= False 54 | self.forecast_step = forecast_step 55 | self.use_synthetic = True if use_synthetic else False 56 | 57 | # load all zarr: 58 | self.dataset = self._load_data() 59 | self.ds_x, self.ds_y = self.fetch_timeseries(self.forecast_step) # Precompute pairs 60 | self.length = self.ds_x.sizes['time'] # Update length based on valid pairs 61 | 62 | def _load_data(self): 63 | """Loads all zarr files into a dictionary keyed by year.""" 64 | zarr_paths = [] 65 | for year in range(self.start_year, self.end_year + 1): 66 | zarr_path = os.path.join(self.data_path, f"SixHourly_y_TOTAL_{year}-01-01_{year}-12-31_rechunked_uncompressed.zarr") 67 | if os.path.exists(zarr_path): 68 | zarr_paths.append(zarr_path) 69 | else: 70 | print (f"{zarr_path} does not exist for year {year}. Skipping...") 71 | ds = xr.open_mfdataset(zarr_paths, engine='zarr', consolidated=True, combine='by_coords')[self.input_vars] 72 | self.length = ds.sizes['time'] 73 | return ds 74 | 75 | def __len__(self): 76 | """Returns the total number of samples in the dataset.""" 77 | return self.length 78 | 79 | 80 | def fetch_timeseries(self, forecast_step=1): 81 | """ 82 | Fetches the input and target timeseries data for a given forecast step. 83 | """ 84 | ds_x = self.dataset.isel(time=slice(None, -forecast_step)) 85 | ds_y = self.dataset.isel(time=slice(forecast_step, None)) 86 | return ds_x, ds_y 87 | 88 | def normalize (self, mean_file=None, std_file=None): 89 | """ 90 | Normalize the dataset using the mean and std files. 91 | """ 92 | if mean_file is not None and std_file is not None: 93 | mean = xr.open_dataset(mean_file) 94 | std = xr.open_dataset(std_file) 95 | else: 96 | mean = self.dataset.mean(dim='time') 97 | std = self.dataset.std(dim='time') 98 | self.dataset = (self.dataset - mean) / std 99 | self.normalized = True 100 | 101 | def __repr__(self): 102 | """Returns a summary of all datasets loaded.""" 103 | return self.dataset.__repr__() 104 | 105 | def __getitem__(self, index): 106 | """Enable direct indexing""" 107 | if self.use_synthetic: 108 | x_data = np.zeros([6, 640, 1280], dtype=np.float32) 109 | y_data = np.zeros([6, 640, 1280], dtype=np.float32) 110 | else: 111 | x_data = self.ds_x.isel(time=index).to_array().values 112 | y_data = self.ds_y.isel(time=index).to_array().values 113 | return (x_data, y_data) 114 | 115 | 116 | 117 | class PyTorchERA5Dataset(Dataset): 118 | """ 119 | Wraps the ERA5TimeSeriesDataset so it can be used in PyTorch DataLoader. 120 | """ 121 | def __init__(self, era5_dataset, forecast_step=1): 122 | """ 123 | era5_dataset (ERA5Dataset): An instance of the custom ERA5 dataset. 124 | forecast_step (int): The forecast step to use for fetching timeseries data. 125 | """ 126 | self.era5_dataset = era5_dataset 127 | self.forecast_step = forecast_step 128 | self.ds_x, self.ds_y = self.era5_dataset.fetch_timeseries(forecast_step=self.forecast_step) 129 | self.use_synthetic = self.era5_dataset.use_synthetic 130 | 131 | def __len__(self): 132 | """Returns the total number of samples in the dataset.""" 133 | return self.ds_x.sizes['time'] 134 | 135 | def __getitem__(self, index): 136 | """ 137 | Returns a single sample (input, target) as PyTorch tensors. 138 | 139 | Args: 140 | index (int): Index of the sample to retrieve. 141 | 142 | Returns: 143 | tuple: (input_tensor, target_tensor) 144 | """ 145 | if self.use_synthetic: 146 | x_tensor = torch.zeros([6, 640, 1280], dtype=torch.float32) 147 | y_tensor = torch.zeros([6, 640, 1280], dtype=torch.float32) 148 | else: 149 | x_data = self.ds_x.isel(time=index).to_array().values 150 | y_data = self.ds_y.isel(time=index).to_array().values 151 | # Extract data at the given index 152 | x_data = self.ds_x.isel(time=index).to_array().values 153 | y_data = self.ds_y.isel(time=index).to_array().values 154 | 155 | # Convert to PyTorch tensors 156 | x_tensor = torch.from_numpy(x_data).float() 157 | y_tensor = torch.from_numpy(y_data).float() 158 | 159 | return x_tensor, y_tensor 160 | 161 | def __repr__(self): 162 | x_tensor, y_tensor = self[0] 163 | """Returns a summary of all datasets loaded.""" 164 | return ( 165 | f"PyTorchERA5Dataset(forecast_step={self.forecast_step}, " 166 | f"use_synthetic={self.use_synthetic}, " 167 | f"length={len(self)}, " 168 | f"input_tensor_shape={tuple(x_tensor.shape)}, " 169 | f"target_tensor_shape={tuple(y_tensor.shape)}, " 170 | ) 171 | 172 | class SeqZarrSource: 173 | """ 174 | DALI Source for loading a zarr array. 175 | The arrays will be indexed along the first dimension (usually time). 176 | 177 | https://github.com/NVIDIA/modulus/blob/e6d7b02fb19ab9cdb3138de228ca3d6f0c99e7d1/examples/weather/unified_recipe/seq_zarr_datapipe.py#L186 178 | """ 179 | 180 | def __init__( 181 | self, 182 | file_store: str = "/glade/derecho/scratch/negins/era5/rechunked_stacked_uncompressed_test.zarr", 183 | variables: list[str] = ["combined"], 184 | start_year: int = 2010, 185 | end_year: int = 2010, 186 | num_steps: int = 2, 187 | batch_size: int = 16, 188 | shuffle: bool = True, 189 | process_rank: int = 0, 190 | world_size: int = 1, 191 | batch: bool = True, 192 | gpu: bool = True, 193 | ): 194 | # Set up parameters 195 | self.file_store = file_store 196 | self.variables = variables 197 | self.num_steps = num_steps 198 | self.batch_size = batch_size 199 | self.shuffle = shuffle 200 | self.batch = batch 201 | self.gpu = gpu 202 | 203 | # Check if all zarr arrays have the same first dimension 204 | _zarr_dataset: zarr.Group = zarr.open(self.file_store, mode="r") 205 | self.first_dim: int = _zarr_dataset[variables[0]].shape[0] 206 | for variable in self.variables: 207 | if _zarr_dataset[variable].shape[0] != self.first_dim: 208 | raise ValueError("All zarr arrays must have the same first dimension.") 209 | 210 | # Get number of samples 211 | self.indices: np.ndarray = np.arange( 212 | batch_size 213 | * world_size 214 | * ((self.first_dim - self.num_steps) // batch_size // world_size) 215 | ) 216 | self.indices: np.ndarray = np.array_split(self.indices, world_size)[ 217 | process_rank 218 | ] 219 | 220 | # Get number of full batches, ignore possible last incomplete batch for now. 221 | self.num_batches: int = len(self.indices) // self.batch_size 222 | 223 | # Set up last epoch 224 | self.last_epoch = None 225 | 226 | # Set zarr dataset 227 | self.zarr_dataset = None 228 | 229 | # Set call 230 | if self.batch: 231 | self._call = self.__call__ 232 | self.batch_mapping: np.ndarray = np.stack( 233 | np.array_split( 234 | self.indices[ 235 | : len(self.indices) - len(self.indices) % self.batch_size 236 | ], 237 | self.batch_size, 238 | ), 239 | axis=1, 240 | ) 241 | else: 242 | self._call = self._sample_call 243 | 244 | print (self.batch_mapping.shape) 245 | 246 | def __call__(self, index: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]: 247 | with zarr.config.enable_gpu() if self.gpu else nullcontext(): 248 | # Open Zarr dataset 249 | if self.zarr_dataset is None: 250 | self.zarr_dataset: zarr.Group = zarr.open(self.file_store, mode="r") 251 | 252 | #index: int = index[ 253 | # 0 254 | #] # turn [np.ndarray()] with one element to np.ndarray() 255 | index = int(index[0]) 256 | 257 | if index > self.batch_mapping.shape[0]: 258 | raise StopIteration() 259 | 260 | # Get batch indices 261 | if self.gpu: 262 | self.batch_mapping = cp.asanyarray(self.batch_mapping) 263 | batch_idx: np.ndarray = self.batch_mapping[index] 264 | time_idx: np.ndarray = cp.concatenate( 265 | [idx + cp.arange(self.num_steps) for idx in batch_idx] 266 | ) 267 | # print(time_idx) 268 | 269 | # Get data 270 | data = [] 271 | 272 | # Get slices 273 | for i, variable in enumerate(self.variables): 274 | batch_data = self.zarr_dataset[variable][time_idx.tolist()] 275 | data.append( 276 | cp.reshape( 277 | batch_data, 278 | (self.batch_size, self.num_steps, *batch_data.shape[1:]), 279 | ) 280 | ) 281 | # assert len(data) == 6 # number of variables 282 | # assert data[0].shape == (16, 2, 640, 1280) # BTHW 283 | 284 | # Stack variables along channel dimension, and split into two along timestep dim 285 | data_stack = cp.stack(data, axis=2) 286 | # assert data_stack.shape == (16, 2, 6, 640, 1280) # BTCHW 287 | data_x = data_stack[:, 0, :, :, :] 288 | # assert data_x.shape == (16, 6, 640, 1280) # BCHW 289 | data_y = data_stack[:, 1, :, :, :] 290 | # assert data_y.shape == (16, 6, 640, 1280) # BCHW 291 | 292 | # Return list to satisfy batch_processing=True 293 | return [data_x], [data_y] 294 | 295 | 296 | def __len__(self): 297 | if self.batch: 298 | print(f"Batch mapping shape: {self.batch_mapping.shape}") 299 | return len(self.batch_mapping) 300 | else: 301 | return len(self.indices) 302 | 303 | def __repr__(self): 304 | return ( 305 | f"{self.__class__.__name__}(\n" 306 | f" file_store={self.file_store!r},\n" 307 | f" variables={self.variables},\n" 308 | f" num_steps={self.num_steps},\n" 309 | f" batch_size={self.batch_size},\n" 310 | f" shuffle={self.shuffle},\n" 311 | f" batch={self.batch},\n" 312 | f" gpu={self.gpu},\n" 313 | f" first_dim={self.first_dim},\n" 314 | f" total_samples={len(self.indices)},\n" 315 | f" num_batches={self.num_batches}\n" 316 | f")" 317 | ) 318 | 319 | 320 | def build_seqzarr_pipeline(source: SeqZarrSource, batch_size: int = 16): 321 | """ 322 | Build the DALI pipeline for loading Zarr data. 323 | """ 324 | @pipeline_def( 325 | batch_size=4, 326 | num_threads=2, 327 | prefetch_queue_depth=2, 328 | py_num_workers=2, 329 | device_id=0, 330 | py_start_method="spawn", 331 | ) 332 | 333 | def seqzarr_pipeline(): 334 | """ 335 | Pipeline to load Zarr stores via a DALI External Source operator. 336 | """ 337 | # Zarr source 338 | source = SeqZarrSource(batch_size=16) 339 | print (source) 340 | print ("shape of this source:", source.__len__()) 341 | 342 | # generate indexes for the external source 343 | def index_generator(idx: int) -> np.ndarray: 344 | return np.array([idx]) 345 | 346 | indexes = dali.fn.external_source( 347 | source=index_generator, 348 | dtype=dali.types.INT64, 349 | device="gpu" if source.gpu else "cpu", 350 | batch=True, 351 | ) 352 | 353 | print (indexes) 354 | 355 | # Use DALI to read current batch from SeqZarrSource 356 | data_x, data_y = dali.fn.python_function( 357 | indexes, 358 | function=source, 359 | batch_processing=True, 360 | num_outputs=2, 361 | device="gpu" if source.gpu else "cpu", 362 | ) 363 | 364 | #data_x = data_x.squeeze(0).squeeze(1) 365 | #data_y = data_y.squeeze(0).squeeze(1) 366 | 367 | data_x = fn.reshape(data_x, shape=[16,6,640,1280]) 368 | data_y = fn.reshape(data_y, shape=[16,6,640,1280]) 369 | 370 | # if self.device.type == "cuda": 371 | # Move tensors to GPU as external_source won't do that automatically 372 | if not source.gpu: 373 | data_x = data_x.gpu() 374 | data_y = data_y.gpu() 375 | 376 | 377 | # Set outputs 378 | return data_x, data_y 379 | return seqzarr_pipeline() 380 | 381 | # -------------------------------------------------# 382 | # ----------------- Example usage -----------------# 383 | # -------------------------------------------------# 384 | if __name__ == "__main__": 385 | import argparse 386 | 387 | # Set up simple argument parser 388 | parser = argparse.ArgumentParser(description="ERA5 Data Loader...") 389 | parser.add_argument( 390 | "--use-dali", 391 | action="store_true", 392 | help="Use DALI pipeline instead of standard loading" 393 | ) 394 | args = parser.parse_args() 395 | 396 | # Hardcoded parameters (as in original code) 397 | data_path = "/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv" 398 | input_vars = ['t2m', 'V500', 'U500', 'T500', 'Z500', 'Q500'] 399 | target_vars = ['t2m', 'V500', 'U500', 'T500', 'Z500', 'Q500'] 400 | start_year = 2010 401 | end_year = 2010 402 | 403 | if not args.use_dali: 404 | # Standard dataset loading 405 | train_dataset = ERA5Dataset( 406 | data_path=data_path, 407 | start_year=start_year, 408 | end_year=end_year, 409 | input_vars=input_vars, 410 | target_vars=target_vars 411 | ) 412 | print(train_dataset) 413 | train_pytorch = PyTorchERA5Dataset( 414 | train_dataset, 415 | forecast_step=1 416 | ) 417 | print(train_pytorch) 418 | 419 | # Example of using PyTorch DataLoader 420 | train_loader = DataLoader(train_pytorch, batch_size=16, pin_memory=True, shuffle=True) 421 | print (f"Number of batches: {len(train_loader)}") 422 | print (f"Batch size: {train_loader.batch_size}") 423 | 424 | for i, batch in enumerate(train_loader): 425 | inputs, targets = batch 426 | 427 | print(f"Batch {i+1}: inputs shape = {inputs.shape}, targets shape = {targets.shape}") 428 | 429 | sample_size_bytes = ( 430 | inputs.element_size() * inputs.nelement() + 431 | targets.element_size() * targets.nelement() 432 | ) 433 | sample_size_mb = sample_size_bytes / 1024 / 1024 / inputs.shape[0] # per sample 434 | print(f"Estimated sample size: {sample_size_mb:.2f} MB") 435 | 436 | print(f"Total samples in dataset: {len(train_loader)}") 437 | 438 | break 439 | 440 | else: 441 | # DALI pipeline loading 442 | source = SeqZarrSource(batch_size=16) 443 | print ("...shape of this source:", source.__len__()) 444 | pipe = build_seqzarr_pipeline(source=source) 445 | pipe.build() 446 | 447 | 448 | 449 | #pipe = seqzarr_pipeline() 450 | train_loader = DALIGenericIterator( 451 | pipelines=pipe, 452 | output_map=["input", "target"], 453 | auto_reset=True, 454 | last_batch_padded=False, 455 | #fill_last_batch=False, 456 | #size = -1, 457 | ) 458 | print (f"Number of batches: {len(train_loader)}") -------------------------------------------------------------------------------- /zarr_ML_optimization/train_unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | A U-Net benchmark using PyTorch and ERA5 surface variables 4 | """ 5 | import os 6 | import time 7 | import argparse 8 | import logging 9 | 10 | import xarray as xr 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.utils.data import Dataset, DataLoader 17 | from torch.utils.data.distributed import DistributedSampler 18 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 19 | import segmentation_models_pytorch as smp 20 | 21 | from era5_dataloader import ( 22 | ERA5Dataset, 23 | PyTorchERA5Dataset, 24 | build_seqzarr_pipeline, 25 | SeqZarrSource, 26 | ) 27 | 28 | from trainer_utils import setup_logging, set_random_seeds, init_process_group, custom_loss 29 | 30 | 31 | # --- Logger Setup --- 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def main(): 36 | num_epochs_default = 1 37 | batch_size_default = 16 38 | learning_rate_default = 0.001 # Adjusted for Adam optimizer 39 | 40 | parser = argparse.ArgumentParser( 41 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 42 | ) 43 | parser.add_argument( 44 | "--num_epochs", 45 | type=int, 46 | help="Number of training epochs.", 47 | default=num_epochs_default, 48 | ) 49 | parser.add_argument( 50 | "--batch_size", 51 | type=int, 52 | help="Training batch size for one process.", 53 | default=batch_size_default, 54 | ) 55 | parser.add_argument( 56 | "--learning_rate", 57 | type=float, 58 | help="Learning rate.", 59 | default=learning_rate_default, 60 | ) 61 | parser.add_argument( 62 | "--distributed", 63 | action="store_true", 64 | help="Use distributed data parallel (DDP).", 65 | ) 66 | parser.add_argument( 67 | "--skip-training", 68 | action="store_true", 69 | help="Skip training & validation for benchmarking purposes.", 70 | dest="notraining", 71 | ) 72 | parser.add_argument( 73 | "--synth", 74 | "--use-synthetic", 75 | action="store_true", 76 | help="Use synthetic data to skip loading ERA5 data (for benchmarking).", 77 | ) 78 | parser.add_argument( 79 | "--use-dali", 80 | action="store_true", 81 | help="Use DALI pipeline instead of regular Pytorch Dataloader.", 82 | ) 83 | parser.add_argument( 84 | "--skip-validation", 85 | action="store_false", 86 | help="Skip validation loop.", 87 | dest= "validation", 88 | default=0, 89 | ) 90 | parser.add_argument( 91 | "--era5_path", 92 | type=str, 93 | help="Path to the ERA5 Zarr dataset.", 94 | default="/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv", 95 | ) 96 | parser.add_argument( 97 | "--num_workers", 98 | type=int, 99 | help="Number of workers for DataLoader.", 100 | default=1, 101 | ) 102 | 103 | argv = parser.parse_args() 104 | num_epochs = argv.num_epochs 105 | batch_size = argv.batch_size 106 | learning_rate = argv.learning_rate 107 | distributed = argv.distributed 108 | use_synthetic = argv.synth 109 | use_dali = argv.use_dali 110 | 111 | # --------------------------- 112 | # Set random seeds for reproducibility! 113 | random_seed = 0 114 | set_random_seeds(random_seed=random_seed) 115 | 116 | # -------------------------- 117 | # Initialize the process group for single or multi-GPU training 118 | LOCAL_RANK, WORLD_SIZE, WORLD_RANK = init_process_group( 119 | distributed=distributed, backend="nccl" 120 | ) 121 | # -------------------------- 122 | setup_logging(world_rank=WORLD_RANK) 123 | logger.info(f"Distributed : {distributed}") 124 | logger.info(f"Using DALI : {use_dali}") 125 | logger.info(f"Synthetic data : {use_synthetic}") 126 | 127 | # -------------------------- 128 | # Read the ERA5 Zarr dataset 129 | #data_path = "/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/" 130 | #data_path = "/glade/derecho/scratch/negins/CREDIT_data/ERA5_mlevel_arXiv" # Updated path for optimized zarr chunks 131 | data_path = argv.era5_path 132 | if use_dali: 133 | input_vars = ["combined"] * 6 # 6 input variables -- here for proof of concept, we combine all 6 input variables into one "combined" variable 134 | target_vars = [ 135 | "combined" 136 | ] * 6 # Stacked input variables into one "combined" variable 137 | # input_vars = ['t2m','V500', 'U500', 'T500', 'Z500', 'Q500'] # 6 input variables 138 | # target_vars = ['t2m','V500', 'U500', 'T500', 'Z500', 'Q500'] # Predict all 6 variables 139 | else: 140 | input_vars = [ 141 | "t2m", 142 | "V500", 143 | "U500", 144 | "T500", 145 | "Z500", 146 | "Q500", 147 | ] # 6 input variables 148 | target_vars = [ 149 | "t2m", 150 | "V500", 151 | "U500", 152 | "T500", 153 | "Z500", 154 | "Q500", 155 | ] 156 | 157 | train_start_year, train_end_year = 2013, 2013 158 | val_start_year, val_end_year = 2013, 2013 159 | 160 | # ----------------------------------------------------------------------- 161 | # Create train and validation datasets 162 | # ----------------------------------------------------------------------- 163 | # 1) Training dataset 164 | 165 | requested_workers = argv.num_workers 166 | num_workers = min( 167 | requested_workers, 168 | os.cpu_count() // 2, # Safe default 169 | torch.cuda.device_count() * 4 # GPU-aware 170 | ) 171 | 172 | if use_dali: 173 | # pipe_train = seqzarr_pipeline() 174 | # train_loader = DALIGenericIterator( 175 | # pipelines=pipe_train, output_map=["input", "target"] 176 | # ) 177 | train_source = SeqZarrSource(batch_size=batch_size ) 178 | num_train_samples=train_source.__len__() 179 | logger.info("...shape of this source:", train_source.__len__()) 180 | logger.info(f"DALI Train Source: Total samples reported by SeqZarrSource.__len__(): {num_train_samples}") 181 | pipe_train = build_seqzarr_pipeline(source=train_source) 182 | pipe_train.build() 183 | train_loader = DALIGenericIterator( 184 | pipelines=pipe_train, output_map=["input", "target"], 185 | ) 186 | logger.info(f"Train loader effective size (batches): {len(train_loader)}") 187 | if distributed: 188 | raise NotImplementedError("DALI pipeline with distributed not working yet") 189 | elif not use_dali: 190 | train_dataset = ERA5Dataset( 191 | data_path=data_path, 192 | start_year=train_start_year, 193 | end_year=train_end_year, 194 | input_vars=input_vars, 195 | target_vars=target_vars, 196 | forecast_step=1, 197 | use_synthetic=use_synthetic, 198 | ) 199 | 200 | # Normalize the dataset using pre-computed mean and std files 201 | mean_file = "/glade/derecho/scratch/negins/hackathon-files/mean_6h_0.25deg.nc" # pre-computed mean file for normalization -- copied over from /glade/campaign/cisl/aiml/ksha/CREDIT/ 202 | std_file = "/glade/derecho/scratch/negins/hackathon-files/std_6h_0.25deg.nc" # pre-computed std file for normalization -- copied over from /glade/campaign/cisl/aiml/ksha/CREDIT/ 203 | train_dataset.normalize(mean_file=mean_file, std_file=std_file) 204 | train_pytorch = PyTorchERA5Dataset(train_dataset, forecast_step=1) 205 | 206 | if distributed: 207 | train_sampler = DistributedSampler(dataset=train_pytorch, shuffle=False) 208 | train_loader = DataLoader( 209 | train_pytorch, 210 | batch_size=batch_size, 211 | num_workers=num_workers, 212 | sampler=train_sampler, 213 | drop_last=True, # Drop last incomplete batch -- easier benchmarking 214 | ) 215 | else: 216 | train_loader = DataLoader( 217 | train_pytorch, 218 | batch_size=batch_size, 219 | pin_memory=True, 220 | num_workers=num_workers, 221 | persistent_workers=True, 222 | drop_last=True, 223 | ) 224 | 225 | # -------------------------- 226 | # 2) validation dataset 227 | # -------------------------- 228 | if use_dali: 229 | source = SeqZarrSource(batch_size=batch_size) 230 | logger.info("...shape of this source:", source.__len__()) 231 | num_val_samples = source.__len__() 232 | pipe_val = build_seqzarr_pipeline(source=source) 233 | pipe_val.build() 234 | val_loader = DALIGenericIterator( 235 | pipelines=pipe_val, output_map=["input", "target"], 236 | ) 237 | if distributed: 238 | raise NotImplementedError("DALI pipeline with distributed not working yet") 239 | elif not use_dali: # classic pytorch dataset 240 | val_dataset = ERA5Dataset( 241 | data_path=data_path, 242 | start_year=val_start_year, 243 | end_year=val_end_year, 244 | input_vars=input_vars, 245 | target_vars=target_vars, 246 | use_synthetic=use_synthetic, 247 | ) 248 | val_dataset.normalize(mean_file=mean_file, std_file=std_file) 249 | val_pytorch = PyTorchERA5Dataset(val_dataset, forecast_step=1) 250 | 251 | val_sampler = DistributedSampler(dataset=val_pytorch, shuffle=False) if distributed else None 252 | val_loader = DataLoader( 253 | val_pytorch, 254 | batch_size=batch_size, 255 | num_workers=num_workers, 256 | sampler=val_sampler, 257 | drop_last=True, # Drop last incomplete batch -- easier benchmarking 258 | ) 259 | 260 | if not use_dali: 261 | logger.info(f"Using PyTorch DataLoader (workers: {num_workers})") 262 | logger.info(f"Train samples: {len(train_loader.dataset)}") 263 | logger.info(f"Validation samples: {len(val_loader.dataset)}") 264 | 265 | # -------------------------- 266 | # Define the U-Net model using segmentation_models_pytorch 267 | ENCODER = "resnet18" # Encoder backbone 268 | ENCODER_WEIGHTS = "imagenet" # Pretrained weights 269 | # ENCODER_WEIGHTS = None # No pretrained weights 270 | CLASSES = input_vars # Number of output channels (same as input variables) 271 | ACTIVATION = None # No activation for regression tasks 272 | 273 | # Create the U-Net model 274 | model = smp.Unet( 275 | encoder_name=ENCODER, 276 | encoder_weights="imagenet", 277 | decoder_attention_type="scse", 278 | in_channels=len(input_vars), 279 | classes=len(CLASSES), # Number of output channels 280 | activation=ACTIVATION, 281 | ) 282 | 283 | # -------------------------- 284 | # Move the model to GPU 285 | if distributed: 286 | torch.cuda.set_device(LOCAL_RANK) 287 | device = torch.device("cuda:{}".format(LOCAL_RANK)) 288 | model = model.to(device) 289 | # Wrap the model with DDP 290 | ddp_model = torch.nn.parallel.DistributedDataParallel( 291 | model, 292 | device_ids=[LOCAL_RANK], 293 | output_device=LOCAL_RANK, 294 | find_unused_parameters=True, 295 | ) 296 | model = ddp_model.to(device) 297 | else: 298 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 299 | model = model.to(device) 300 | 301 | # -------------------------- 302 | # Define the loss function and optimizer 303 | #criterion = torch.nn.L1Loss() 304 | optimizer = torch.optim.AdamW( 305 | model.parameters(), lr=learning_rate, weight_decay=1e-4 306 | ) 307 | 308 | # -------------------------- 309 | # Training Loop 310 | 311 | training_start_time = time.time() 312 | epoch_metrics_history = [] 313 | 314 | logger.info("-" * 50) 315 | logger.info("Starting training loop...") 316 | 317 | num_train_steps = num_train_samples if use_dali else len(train_loader) 318 | num_val_steps = num_val_samples if use_dali else len(val_loader) 319 | 320 | for epoch in range(num_epochs): 321 | epoch_train_steps = 0 322 | epoch_val_steps = 0 323 | running_loss = 0.0 324 | epoch_train_losses = [] 325 | 326 | epoch_start_time = time.time() 327 | 328 | model.train() 329 | 330 | for i, batch in enumerate(train_loader): 331 | start_time = time.time() # Start time for the train step 332 | 333 | if len(batch) == 1: # DALI 334 | inputs = batch[0]["input"].squeeze(dim=(0, 2)) 335 | targets = batch[0]["target"].squeeze(dim=(0, 2)) 336 | if i == num_train_samples-2: 337 | logger.info(f"Last batch in epoch {epoch+1} has shape: {inputs.shape}, {targets.shape}") 338 | break 339 | 340 | 341 | else: # non-DALI 342 | inputs, targets = batch 343 | 344 | inputs, targets = inputs.to(device), targets.to(device) 345 | sample_train_size = ( 346 | inputs.element_size() * inputs.nelement() 347 | + targets.element_size() * targets.nelement() 348 | ) / batch_size 349 | sample_train_size_mb = sample_train_size / 1024 / 1024 # Convert to MB 350 | 351 | if not argv.notraining: # training 352 | 353 | optimizer.zero_grad() 354 | # Forward pass 355 | outputs = model(inputs) 356 | 357 | # Compute loss 358 | loss, loss_components = custom_loss(outputs, targets) 359 | 360 | # Backprop 361 | loss.backward() 362 | # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Prevents exploding gradients 363 | 364 | optimizer.step() 365 | torch.cuda.synchronize() 366 | 367 | epoch_train_losses.append(loss_components) 368 | running_loss += loss.item() 369 | 370 | step_train_time = ( 371 | time.time() - start_time 372 | ) # Compute elapsed time in milliseconds 373 | 374 | epoch_train_steps += 1 375 | 376 | if WORLD_RANK == 0: 377 | print( 378 | f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_train_steps}], " 379 | f"Loss: {loss.item():.4f}," 380 | f"RMSE: {loss_components['rmse']:.2f}, Std Diff: {loss_components['std_diff']:.2f},", 381 | f"Time per training step: {step_train_time:.4f} sec.", 382 | ) 383 | 384 | else: # Skip training for benchmarking purposes 385 | # Time should come out as 0.0 386 | torch.cuda.synchronize() 387 | step_train_time = time.time() - start_time 388 | epoch_train_steps += 1 389 | if WORLD_RANK == 0: 390 | print( 391 | f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_train_steps}], " 392 | f"Time per training step: {step_train_time:.4f} sec." 393 | ) 394 | 395 | # End of training loop for this epoch 396 | stop_train_time = time.time() 397 | 398 | # reset the training loader for the next epoch 399 | if use_dali: 400 | train_loader.reset() 401 | 402 | # ----------------------------------------------------------------- 403 | # Validation Loop 404 | # ----------------------------------------------------------------- 405 | logger.info("-" * 50) 406 | logger.info("Starting validation loop...") 407 | model.eval() 408 | epoch_val_losses = [] 409 | 410 | start_val_time = time.time() 411 | sample_val_size_mb = 0 412 | if argv.validation: 413 | with torch.no_grad(): 414 | 415 | # for i, (inputs, targets) in enumerate(val_loader): 416 | for i, batch in enumerate(val_loader): 417 | 418 | step_val_start_time = time.time() # Start time for the step 419 | 420 | if len(batch) == 1: # DALI 421 | inputs = batch[0]["input"].squeeze(dim=(0, 2)) 422 | targets = batch[0]["target"].squeeze(dim=(0, 2)) 423 | if i == num_train_samples-2: 424 | logger.info(f"Last batch in epoch {epoch+1} has shape: {inputs.shape}, {targets.shape}") 425 | break 426 | else: 427 | inputs, targets = batch 428 | 429 | inputs, targets = inputs.to(device), targets.to(device) 430 | 431 | sample_val_size = ( 432 | inputs.element_size() * inputs.nelement() 433 | + targets.element_size() * targets.nelement() 434 | ) / batch_size 435 | sample_val_size_mb = sample_val_size / 1024 / 1024 # Convert to MB 436 | 437 | if not argv.notraining: 438 | outputs = model(inputs) 439 | loss, loss_components = custom_loss(outputs, targets) 440 | epoch_val_losses.append(loss_components) 441 | torch.cuda.synchronize() 442 | step_val_time = time.time() - step_val_start_time 443 | epoch_val_steps += 1 444 | 445 | if WORLD_RANK == 0: 446 | print( 447 | f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_val_steps}], " 448 | f"Validation Loss: {loss.item():.4f}," 449 | f"RMSE: {loss_components['rmse']:.2f}, Std Diff: {loss_components['std_diff']:.2f},", 450 | f"Time per validation step: {step_val_time:.4f} sec.", 451 | ) 452 | 453 | else: 454 | # Skip validation for benchmarking purposes 455 | step_val_time = time.time() - step_val_start_time 456 | epoch_val_steps += 1 457 | if WORLD_RANK == 0: 458 | print( 459 | f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_val_steps}], " 460 | f"Time per validation step: {step_val_time:.4f} sec." 461 | ) 462 | 463 | torch.cuda.synchronize() 464 | stop_val_time = time.time() 465 | 466 | 467 | if use_dali: 468 | val_loader.reset() 469 | 470 | if WORLD_RANK == 0: 471 | epoch_time = time.time() - epoch_start_time 472 | val_time = stop_val_time - start_val_time 473 | train_time = stop_train_time - epoch_start_time 474 | 475 | # Throughput calculation 476 | num_train_batches_processed = epoch_train_steps 477 | num_val_batches_processed = epoch_val_steps 478 | 479 | total_samples_processed_epoch = (num_train_batches_processed + num_val_batches_processed) * batch_size * WORLD_SIZE 480 | 481 | throughput_sps = total_samples_processed_epoch / epoch_time 482 | throughput_mbps = ( 483 | sample_train_size_mb * batch_size * WORLD_SIZE / train_time + 484 | sample_val_size_mb * batch_size * WORLD_SIZE / val_time 485 | ) 486 | throughput_mbps = throughput_sps * sample_train_size_mb 487 | 488 | current_epoch_metrics = { 489 | "epoch": epoch + 1, 490 | "epoch_time": epoch_time, 491 | "train_time": train_time, 492 | "val_time": val_time, 493 | "throughput_sps": throughput_sps, 494 | "throughput_mbps": throughput_mbps, 495 | "total_samples": total_samples_processed_epoch, 496 | "sample_train_size_mb": sample_train_size_mb, 497 | "sample_val_size_mb": sample_val_size_mb, 498 | } 499 | 500 | epoch_metrics_history.append(current_epoch_metrics) 501 | 502 | 503 | # Logging output 504 | print("\n" + "-" * 60) 505 | print(f"Epoch [{epoch+1}/{num_epochs}] Summary") 506 | print(f" Total Epoch Time : {epoch_time:.2f} sec") 507 | print(f" Training Time : {train_time:.2f} sec") 508 | print(f" Validation Time : {val_time:.2f} sec") 509 | print(f" WORLD_SIZE : {WORLD_SIZE}") 510 | print(f" Total Samples : {total_samples_processed_epoch}") 511 | print(f" Sample Size : {sample_train_size_mb:.2f} MB") 512 | print(f" Throughput (samples) : {throughput_sps:.2f} samples/sec") 513 | print(f" Throughput (MB) : {throughput_mbps:.2f} MB/sec") 514 | print("\n" + "-" * 60) 515 | 516 | 517 | if distributed: 518 | torch.distributed.barrier() 519 | torch.distributed.destroy_process_group() 520 | logger.info("Destroyed process group!") 521 | 522 | # End of training loop for all epochs 523 | 524 | logger.info("Training completed.") 525 | 526 | if WORLD_RANK == 0: 527 | total_time = time.time() - training_start_time 528 | avg_epoch_wall_time = np.mean([m['epoch_time'] for m in epoch_metrics_history]) 529 | avg_train_loop_time = np.mean([m['train_time'] for m in epoch_metrics_history]) 530 | avg_val_loop_time = np.mean([m['val_time'] for m in epoch_metrics_history]) 531 | avg_tput_sps = np.mean([m['throughput_sps'] for m in epoch_metrics_history]) 532 | avg_tput_mbps = np.mean([m['throughput_mbps'] for m in epoch_metrics_history]) 533 | avg_train_sample_size = np.mean([m['sample_train_size_mb'] for m in epoch_metrics_history]) 534 | avg_val_sample_size = np.mean([m['sample_val_size_mb'] for m in epoch_metrics_history]) 535 | 536 | print("-" * 50) 537 | print ("Overall Training Summary (Averages Over All Epochs)") 538 | print(f"Total training time (sec) : {total_time:.2f}") 539 | print(f"Average epoch wall time (sec) : {avg_epoch_wall_time:.2f}") 540 | print(f"Average train loop time (sec) : {avg_train_loop_time:.2f}") 541 | print(f"Average validation loop time (sec) : {avg_val_loop_time:.2f}") 542 | print(f"Average throughput (samples/sec) : {avg_tput_sps:.2f}") 543 | print(f"Average throughput (MB/sec) : {avg_tput_mbps:.2f}") 544 | print(f"Average sample size (MB) : {avg_train_sample_size:.2f}") 545 | print("-" * 50) 546 | 547 | 548 | if __name__ == "__main__": 549 | main() 550 | --------------------------------------------------------------------------------