├── .gitignore
├── LICENSE
├── README.md
├── dataset_tool.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── environment.yml
├── exps
│   ├── generate_samples.sh
│   ├── generate_video_from_samples.sh
│   ├── json
│   │   └── example_ns.json
│   └── main.sh
├── generate.py
├── media
│   ├── dataset.gif
│   ├── generated.gif
│   └── generated128.gif
├── resume.py
├── scripts
│   ├── f2id.py
│   ├── generate_frames_from_dataset.py
│   └── generate_frames_from_samples.py
├── torch_utils
│   ├── __init__.py
│   ├── distributed.py
│   ├── misc.py
│   ├── persistence.py
│   └── training_stats.py
├── train.py
└── training
    ├── __init__.py
    ├── dataset_edm.py
    ├── datasets
    │   ├── __init__.py
    │   ├── dataset.py
    │   ├── dataset_fake.py
    │   └── dataset_ns.py
    ├── img_utils.py
    ├── loss.py
    ├── networks.py
    ├── noise_samplers.py
    └── training_loop.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.bak
3 | *.py~
4 | *.py#
5 | *.pt
6 | downloads/
7 | fid-refs/
8 | exps/
9 | slurm*.out
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 NVIDIA Corporation
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EDM in function spaces (edm-fs)
2 | 
3 | This is the public code release of Christopher Beckham's internship at NVIDIA, using neural operators for time-series modelling of climate data. This repository is based on [EDM](https://github.com/NVlabs/edm/).
4 | 
5 | ## Setup
6 | 
7 | This repository is based on the EDM codebase and as such has a similar set of requirements. First, create a conda environment using the `environment.yml` file in the root directory of this repository. This is the same as the corresponding environment file in the original EDM repository from Karras et al., located [here](https://github.com/NVlabs/edm).
8 | 
9 | To create an environment called `edm_fs` we do:
10 | 
11 | ```
12 | conda env create -n edm_fs -f environment.yml
13 | ```
14 | 
15 | Some dependencies may be flexible, but PyTorch 1.12 is absolutely crucial (see [this issue](https://github.com/NVlabs/edm/issues/18)).
16 | 
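As a quick sanity check that the environment resolved to a compatible version (a minimal snippet, not part of the repo):

```
import torch
# The codebase is pinned to PyTorch 1.12.x; newer versions are known to break it.
assert torch.__version__.startswith("1.12"), f"got torch {torch.__version__}"
```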
17 | ### Dependencies
18 | 
19 | We also need to install the `neuraloperator` library. For this repo we need to use my fork of it, which lives [here](https://github.com/christopher-beckham/neuraloperator). (Whoever builds on top of this repo may find it useful to see what has changed and try to reconcile it with the latest version of the library.)
20 | 
21 | Clone it, switch to the `dev_refactor` branch, and install it with the following steps:
22 | 
23 | ```
24 | git clone git@github.com:christopher-beckham/neuraloperator.git
25 | cd neuraloperator
26 | git checkout dev_refactor
27 | pip install -e .
28 | ```
29 | 
30 | Note that you may also need to install other dependencies required by `neuraloperator`; check its `requirements.txt` file. We also need to install:
31 | 
32 | - jstyleson: `pip install jstyleson`
33 | - torchvision: `pip install torchvision==0.13 --no-deps` (0.13 is compatible with PyTorch 1.12, and you must specify `--no-deps` so that it doesn't force-install PyTorch 2.0)
34 | - ffmpeg: `conda install -c conda-forge ffmpeg`
35 | 
36 | ### Environment variables
37 | 
38 | cd into `exps`, copy the template env file (`cp env.sh.bak env.sh`), and define the following env variables in it:
39 | 
40 | - `$SAVE_DIR`: the location where experiments are to be saved.
41 | - `$DATA_DIR`: the location of the dataset. (Internally this pointed to the raw CWA/ERA dataset zarr file; for the open-source substitute below, it is whatever directory you download the data into.)
42 | 
43 | ### Dataset
44 | 
45 | Since the climate dataset used for the internship is not (yet) public, we have to use an open-source substitute. For this we use a 2D Navier-Stokes dataset which consists of trajectories in time. Concretely, the neural operator diffusion model is trained to model the conditional distribution `p(u_t | y_{t-k}, ..., y_{t+k})` where:
46 | - `u` is the function to model (whose samples are at a fixed discretisation defined by `resolution` in `training.datasets.NSDataset`, e.g. `128` for 128 px on both spatial dimensions);
47 | - `y` are samples from the same function at a much coarser discretisation, determined by `lowres_scale_factor` in `training.datasets.NSDataset` (e.g. `lowres_scale_factor=0.0625` is a 16x reduction in spatial resolution);
48 | - `k` denotes the size of the context window.
49 | 
50 | cd into whatever directory `$DATA_DIR` points to and download the data:
51 | 
52 | ```
53 | wget https://zenodo.org/records/7495555/files/2D_NS_Re40.npy
54 | ```
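To make the conditioning setup concrete, here is a minimal sketch of how one might inspect the dataset. The argument names follow the descriptions above and `exps/json/example_ns.json`; the exact constructor signature and return format live in `training/datasets/dataset_ns.py`, so treat this as illustrative rather than authoritative:

```
from training.datasets import NSDataset

# u_t is modelled at `resolution`; y_t is the same field observed at
# `lowres_scale_factor` times that resolution (0.0625 => 16x coarser).
dataset = NSDataset(resolution=128, lowres_scale_factor=0.0625)
item = dataset[0]
print(len(dataset), type(item))  # inspect what one (u, y-window) sample looks like
```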
55 | 
56 | ### Visualising the dataset
57 | 
58 | We can visualise the dataset as a video. For example, if we define `y` to be a 16x reduction of the original resolution then we run the following:
59 | 
60 | ```
61 | python -m scripts.generate_frames_from_dataset \
62 |     --outdir=/tmp/dataset \
63 |     --dataset_kwargs '{"lowres_scale_factor": 0.0625}'
64 | cd /tmp/dataset
65 | ffmpeg -framerate 30 -pattern_type glob -i '*.png' \
66 |     -c:v libx264 -pix_fmt yuv420p out.mp4
67 | ```
68 | 
69 | Here is an example output (in gif format):
70 | 
71 | ![dataset viz](media/dataset.gif)
72 | 
73 | Here, the `y` variable is actually 8x8 px (i.e. `128*0.0625 = 8`) but has been upsampled back to 128 px with bilinear resampling.
74 | 
75 | ## Running experiments
76 | 
77 | This code assumes you have a Slurm-based environment and that you are either in an interactive job or will be launching one. If this is not the case you can still run the code below, but you must ensure that `$SLURM_JOB_ID` is defined; in a non-Slurm environment you can simply set it to any other unique identifier.
78 | 
79 | Experiments are launched by going into `exps` and running `main.sh` with the following arguments:
80 | 
81 | ```
82 | bash main.sh <experiment_group> <json_cfg_file> <num_gpus>
83 | ```
84 | 
85 | `<experiment_group>` means that the experiment will be saved to `$SAVE_DIR/<experiment_group>/<SLURM_JOB_ID>`. Example json config files are in `exps/json`, and you can consult the full set of supported arguments in `train.py`. `<num_gpus>` specifies how many GPUs to train on.
86 | 
87 | For running experiments with `sbatch`, write a wrapper script which calls `main.sh` and specifies any required arguments.
88 | 
89 | ### Generation
90 | 
91 | To generate a trajectory using a pretrained model, we can use the `generate_samples.sh` script in `exps`. This is a convenience wrapper on top of `generate.py` and it is run with the following arguments:
92 | 
93 | ```
94 | bash generate_samples.sh \
95 |     <experiment_name> \
96 |     <num_diffusion_steps> \
97 |     <trajectory_length> \
98 |     <outfile> [<extra_args>]
99 | ```
100 | 
101 | Respectively, these arguments correspond to:
102 | - the name of the experiment, _relative_ to `$SAVE_DIR`;
103 | - the number of diffusion steps to perform (higher gives better quality but takes longer);
104 | - the length of the trajectory to generate;
105 | - and the output file.
106 | 
107 | Please see `generate.py` for the full list of supported arguments, and see the next section for an example of how to generate with this script.
108 | 
109 | ## Pretrained models
110 | 
111 | An example pretrained model can be downloaded [here](https://drive.google.com/file/d/1lpH6WVPqjZU1qNCH_2aWejU834mo6Urj/view?usp=drive_link). Download it to `$SAVE_DIR` and untar it via:
112 | 
113 | ```
114 | cd $SAVE_DIR && tar -xvzf test_ns_ws3_ngf64_v2.tar.gz
115 | ```
116 | 
117 | To generate a sample file from this model with 200 diffusion timesteps, cd into `exps` and run:
118 | 
119 | ```
120 | bash generate_samples.sh \
121 |     test_ns_ws3_ngf64_v2/4148123 \
122 |     200 \
123 |     64 \
124 |     samples200.pt
125 | ```
126 | 
127 | This will spit out a file called `samples200.pt` in the same directory. To generate a 30-fps video from these samples, run the following:
128 | 
129 | ```
130 | bash generate_video_from_samples.sh \
131 |     samples200.pt \
132 |     30 \
133 |     samples200.pt.mp4
134 | ```
135 | 
136 | Example video (in gif format):
137 | 
138 | ![generated viz](media/generated.gif)
139 | 
140 | Here the first column denotes the ground truth `u_t` (the same across each row), the middle column the generated function from diffusion `\tilde{u_t}`, and the third column the low-res function `y_t` (again, the same for each row).
141 | 
142 | ### Super-resolution
143 | 
144 | Since this model was trained with samples from `u` at 64 px, we can perform 2x super-resolution by passing in `--resolution=128` like so:
145 | 
146 | ```
147 | bash generate_samples.sh \
148 |     test_ns_ws3_ngf64_v2/4148123 \
149 |     200 \
150 |     64 \
151 |     samples200_128.pt \
152 |     --resolution=128 --batch_size=16
153 | ```
154 | 
155 | Example video (in gif format):
156 | 
157 | ![generated viz](media/generated128.gif)
158 | 
159 | ## Bugs / limitations
160 | 
161 | This section details some things that should be improved or considered by whoever forks this repository.
162 | 
163 | ### Generation
164 | 
165 | If you make post-hoc changes to the model code (e.g. `training/networks.py`) and then want to generate samples, you should also pass `--reload_network`, e.g.
166 | 
167 | ```
168 | bash generate_samples.sh ... --reload_network
169 | ```
170 | 
171 | By default, EDM's training script pickles not just the model weights but also the network code, and this can be frustrating if one wants to make post-hoc changes to the code which are backward compatible with existing pretrained models. `--reload_network` tells the generation script to instead instantiate the model from the network definition currently in `training/networks.py` and then load only the weights from the pickle.
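For reference, a minimal sketch of what that reload amounts to, assuming the snapshot follows EDM's convention of storing the EMA network under the `'ema'` key and that the network class is decorated with `torch_utils.persistence` (which records the constructor arguments as `init_args`/`init_kwargs`). The class name below is hypothetical; substitute whichever class your config used:

```
import pickle
import dnnlib

# Unpickling restores the network using the code that was pickled at
# training time, not the code currently on disk.
with open('network-snapshot.pkl', 'rb') as f:
    net_old = pickle.load(f)['ema']

# Rebuild the same class from the current source tree, then copy the weights.
net_new = dnnlib.util.construct_class_by_name(
    *net_old.init_args,
    class_name='training.networks.EDMPrecond',  # hypothetical: use your model class
    **net_old.init_kwargs)
net_new.load_state_dict(net_old.state_dict())
```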
172 | 
173 | ### Training
174 | 
175 | Neural operators require significantly more parameters than their finite-dimensional counterparts, and this is exacerbated when training high-resolution diffusion models. I suggest future work look at latent consistency models, e.g. performing function-space diffusion in the latent space of a pretrained autoencoder. Otherwise, the code should be modified to support `float16` training to alleviate the memory burden.
176 | 
177 | ## Credits
178 | 
179 | Thanks to my co-authors Kamyar Azizzadenesheli, Nikola Kovachki, Jean Kossaifi, Boris Bonev, and Anima Anandkumar. Special thanks to Tero Karras, Morteza Mardani, Noah Brenowitz, and Miika Aittala.
180 | 
--------------------------------------------------------------------------------
/dataset_tool.py:
--------------------------------------------------------------------------------
1 | """Tool for creating ZIP/PNG based datasets."""
2 | 
3 | import functools
4 | import gzip
5 | import io
6 | import json
7 | import os
8 | import pickle
9 | import re
10 | import sys
11 | import tarfile
12 | import zipfile
13 | from pathlib import Path
14 | from typing import Callable, Optional, Tuple, Union
15 | import click
16 | import numpy as np
17 | import PIL.Image
18 | from tqdm import tqdm
19 | 
20 | #----------------------------------------------------------------------------
21 | # Parse a 'M,N' or 'MxN' integer tuple.
22 | # Example: '4x2' returns (4,2)
23 | 
24 | def parse_tuple(s: str) -> Tuple[int, int]:
25 |     m = re.match(r'^(\d+)[x,](\d+)$', s)
26 |     if m:
27 |         return int(m.group(1)), int(m.group(2))
28 |     raise click.ClickException(f'cannot parse tuple {s}')
29 | 
30 | #----------------------------------------------------------------------------
31 | 
32 | def maybe_min(a: int, b: Optional[int]) -> int:
33 |     if b is not None:
34 |         return min(a, b)
35 |     return a
36 | 
37 | #----------------------------------------------------------------------------
38 | 
39 | def file_ext(name: Union[str, Path]) -> str:
40 |     return str(name).split('.')[-1]
41 | 
42 | #----------------------------------------------------------------------------
43 | 
44 | def is_image_ext(fname: Union[str, Path]) -> bool:
45 |     ext = file_ext(fname).lower()
46 |     return f'.{ext}' in PIL.Image.EXTENSION
47 | 
48 | #----------------------------------------------------------------------------
49 | 
50 | def open_image_folder(source_dir, *, max_images: Optional[int]):
51 |     input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
52 |     arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
53 |     max_idx = maybe_min(len(input_images), max_images)
54 | 
55 |     # Load labels.
56 |     labels = dict()
57 |     meta_fname = os.path.join(source_dir, 'dataset.json')
58 |     if os.path.isfile(meta_fname):
59 |         with open(meta_fname, 'r') as file:
60 |             data = json.load(file)['labels']
61 |             if data is not None:
62 |                 labels = {x[0]: x[1] for x in data}
63 | 
64 |     # No labels available => determine from top-level directory names.
65 | if len(labels) == 0: 66 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 67 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 68 | if len(toplevel_indices) > 1: 69 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 70 | 71 | def iterate_images(): 72 | for idx, fname in enumerate(input_images): 73 | img = np.array(PIL.Image.open(fname)) 74 | yield dict(img=img, label=labels.get(arch_fnames.get(fname))) 75 | if idx >= max_idx - 1: 76 | break 77 | return max_idx, iterate_images() 78 | 79 | #---------------------------------------------------------------------------- 80 | 81 | def open_image_zip(source, *, max_images: Optional[int]): 82 | with zipfile.ZipFile(source, mode='r') as z: 83 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 84 | max_idx = maybe_min(len(input_images), max_images) 85 | 86 | # Load labels. 87 | labels = dict() 88 | if 'dataset.json' in z.namelist(): 89 | with z.open('dataset.json', 'r') as file: 90 | data = json.load(file)['labels'] 91 | if data is not None: 92 | labels = {x[0]: x[1] for x in data} 93 | 94 | def iterate_images(): 95 | with zipfile.ZipFile(source, mode='r') as z: 96 | for idx, fname in enumerate(input_images): 97 | with z.open(fname, 'r') as file: 98 | img = np.array(PIL.Image.open(file)) 99 | yield dict(img=img, label=labels.get(fname)) 100 | if idx >= max_idx - 1: 101 | break 102 | return max_idx, iterate_images() 103 | 104 | #---------------------------------------------------------------------------- 105 | 106 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): 107 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python 108 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb 109 | 110 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 111 | max_idx = maybe_min(txn.stat()['entries'], max_images) 112 | 113 | def iterate_images(): 114 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 115 | for idx, (_key, value) in enumerate(txn.cursor()): 116 | try: 117 | try: 118 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) 119 | if img is None: 120 | raise IOError('cv2.imdecode failed') 121 | img = img[:, :, ::-1] # BGR => RGB 122 | except IOError: 123 | img = np.array(PIL.Image.open(io.BytesIO(value))) 124 | yield dict(img=img, label=None) 125 | if idx >= max_idx - 1: 126 | break 127 | except: 128 | print(sys.exc_info()[1]) 129 | 130 | return max_idx, iterate_images() 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def open_cifar10(tarball: str, *, max_images: Optional[int]): 135 | images = [] 136 | labels = [] 137 | 138 | with tarfile.open(tarball, 'r:gz') as tar: 139 | for batch in range(1, 6): 140 | member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') 141 | with tar.extractfile(member) as file: 142 | data = pickle.load(file, encoding='latin1') 143 | images.append(data['data'].reshape(-1, 3, 32, 32)) 144 | labels.append(data['labels']) 145 | 146 | images = np.concatenate(images) 147 | labels = np.concatenate(labels) 148 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC 149 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 150 | assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] 151 
| assert np.min(images) == 0 and np.max(images) == 255 152 | assert np.min(labels) == 0 and np.max(labels) == 9 153 | 154 | max_idx = maybe_min(len(images), max_images) 155 | 156 | def iterate_images(): 157 | for idx, img in enumerate(images): 158 | yield dict(img=img, label=int(labels[idx])) 159 | if idx >= max_idx - 1: 160 | break 161 | 162 | return max_idx, iterate_images() 163 | 164 | #---------------------------------------------------------------------------- 165 | 166 | def open_mnist(images_gz: str, *, max_images: Optional[int]): 167 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') 168 | assert labels_gz != images_gz 169 | images = [] 170 | labels = [] 171 | 172 | with gzip.open(images_gz, 'rb') as f: 173 | images = np.frombuffer(f.read(), np.uint8, offset=16) 174 | with gzip.open(labels_gz, 'rb') as f: 175 | labels = np.frombuffer(f.read(), np.uint8, offset=8) 176 | 177 | images = images.reshape(-1, 28, 28) 178 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) 179 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 180 | assert labels.shape == (60000,) and labels.dtype == np.uint8 181 | assert np.min(images) == 0 and np.max(images) == 255 182 | assert np.min(labels) == 0 and np.max(labels) == 9 183 | 184 | max_idx = maybe_min(len(images), max_images) 185 | 186 | def iterate_images(): 187 | for idx, img in enumerate(images): 188 | yield dict(img=img, label=int(labels[idx])) 189 | if idx >= max_idx - 1: 190 | break 191 | 192 | return max_idx, iterate_images() 193 | 194 | #---------------------------------------------------------------------------- 195 | 196 | def make_transform( 197 | transform: Optional[str], 198 | output_width: Optional[int], 199 | output_height: Optional[int] 200 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]: 201 | def scale(width, height, img): 202 | w = img.shape[1] 203 | h = img.shape[0] 204 | if width == w and height == h: 205 | return img 206 | img = PIL.Image.fromarray(img) 207 | ww = width if width is not None else w 208 | hh = height if height is not None else h 209 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) 210 | return np.array(img) 211 | 212 | def center_crop(width, height, img): 213 | crop = np.min(img.shape[:2]) 214 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] 215 | if img.ndim == 2: 216 | img = img[:, :, np.newaxis].repeat(3, axis=2) 217 | img = PIL.Image.fromarray(img, 'RGB') 218 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 219 | return np.array(img) 220 | 221 | def center_crop_wide(width, height, img): 222 | ch = int(np.round(width * img.shape[0] / img.shape[1])) 223 | if img.shape[1] < width or ch < height: 224 | return None 225 | 226 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] 227 | if img.ndim == 2: 228 | img = img[:, :, np.newaxis].repeat(3, axis=2) 229 | img = PIL.Image.fromarray(img, 'RGB') 230 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 231 | img = np.array(img) 232 | 233 | canvas = np.zeros([width, width, 3], dtype=np.uint8) 234 | canvas[(width - height) // 2 : (width + height) // 2, :] = img 235 | return canvas 236 | 237 | if transform is None: 238 | return functools.partial(scale, output_width, output_height) 239 | if transform == 'center-crop': 240 | if output_width is None or output_height is None: 241 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + 
' transform')
242 |         return functools.partial(center_crop, output_width, output_height)
243 |     if transform == 'center-crop-wide':
244 |         if output_width is None or output_height is None:
245 |             raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
246 |         return functools.partial(center_crop_wide, output_width, output_height)
247 |     assert False, 'unknown transform'
248 | 
249 | #----------------------------------------------------------------------------
250 | 
251 | def open_dataset(source, *, max_images: Optional[int]):
252 |     if os.path.isdir(source):
253 |         if source.rstrip('/').endswith('_lmdb'):
254 |             return open_lmdb(source, max_images=max_images)
255 |         else:
256 |             return open_image_folder(source, max_images=max_images)
257 |     elif os.path.isfile(source):
258 |         if os.path.basename(source) == 'cifar-10-python.tar.gz':
259 |             return open_cifar10(source, max_images=max_images)
260 |         elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
261 |             return open_mnist(source, max_images=max_images)
262 |         elif file_ext(source) == 'zip':
263 |             return open_image_zip(source, max_images=max_images)
264 |         else:
265 |             assert False, 'unknown archive type'
266 |     else:
267 |         raise click.ClickException(f'Missing input file or directory: {source}')
268 | 
269 | #----------------------------------------------------------------------------
270 | 
271 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
272 |     dest_ext = file_ext(dest)
273 | 
274 |     if dest_ext == 'zip':
275 |         if os.path.dirname(dest) != '':
276 |             os.makedirs(os.path.dirname(dest), exist_ok=True)
277 |         zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
278 |         def zip_write_bytes(fname: str, data: Union[bytes, str]):
279 |             zf.writestr(fname, data)
280 |         return '', zip_write_bytes, zf.close
281 |     else:
282 |         # If the output folder already exists, check that it is
283 |         # empty.
284 |         #
285 |         # Note: creating the output directory is not strictly
286 |         # necessary as folder_write_bytes() also mkdirs, but it's better
287 |         # to give an error message earlier in case the dest folder
288 |         # somehow cannot be created.
289 |         if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
290 |             raise click.ClickException('--dest folder must be empty')
291 |         os.makedirs(dest, exist_ok=True)
292 | 
293 |         def folder_write_bytes(fname: str, data: Union[bytes, str]):
294 |             os.makedirs(os.path.dirname(fname), exist_ok=True)
295 |             with open(fname, 'wb') as fout:
296 |                 if isinstance(data, str):
297 |                     data = data.encode('utf8')
298 |                 fout.write(data)
299 |         return dest, folder_write_bytes, lambda: None
300 | 
301 | #----------------------------------------------------------------------------
302 | 
303 | @click.command()
304 | @click.option('--source',     help='Input directory or archive name', metavar='PATH', type=str, required=True)
305 | @click.option('--dest',       help='Output directory or archive name', metavar='PATH', type=str, required=True)
306 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
307 | @click.option('--transform',  help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide']))
308 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
309 | 
310 | def main(
311 |     source: str,
312 |     dest: str,
313 |     max_images: Optional[int],
314 |     transform: Optional[str],
315 |     resolution: Optional[Tuple[int, int]]
316 | ):
317 |     """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
318 | 
319 |     The input dataset format is guessed from the --source argument:
320 | 
321 |     \b
322 |     --source *_lmdb/                    Load LSUN dataset
323 |     --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
324 |     --source train-images-idx3-ubyte.gz Load MNIST dataset
325 |     --source path/                      Recursively load all images from path/
326 |     --source dataset.zip                Recursively load all images from dataset.zip
327 | 
328 |     Specifying the output format and path:
329 | 
330 |     \b
331 |     --dest /path/to/dir                 Save output files under /path/to/dir
332 |     --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
333 | 
334 |     The output dataset format can be either an image folder or an uncompressed zip archive.
335 |     Zip archives make it easier to move datasets around file servers and clusters, and may
336 |     offer better training performance on network file systems.
337 | 
338 |     Images within the dataset archive will be stored as uncompressed PNG.
339 |     Uncompressed PNGs can be efficiently decoded in the training loop.
340 | 
341 |     Class labels are stored in a file called 'dataset.json' that is stored at the
342 |     dataset root folder. This file has the following structure:
343 | 
344 |     \b
345 |     {
346 |         "labels": [
347 |             ["00000/img00000000.png",6],
348 |             ["00000/img00000001.png",9],
349 |             ... repeated for every image in the dataset
350 |             ["00049/img00049999.png",1]
351 |         ]
352 |     }
353 | 
354 |     If the 'dataset.json' file cannot be found, class labels are determined from
355 |     top-level directory names.
356 | 
357 |     Image scale/crop and resolution requirements:
358 | 
359 |     Output images must be square-shaped and they must all have the same power-of-two
360 |     dimensions.
361 | 
362 |     To scale arbitrary input image size to a specific width and height, use the
363 |     --resolution option. Output resolution will be either the original
364 |     input resolution (if resolution was not specified) or the one specified with
365 |     --resolution option.
366 | 
367 |     Use the --transform=center-crop or --transform=center-crop-wide options to apply a
368 |     center crop transform on the input image.
These options should be used with the 369 | --resolution option. For example: 370 | 371 | \b 372 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ 373 | --transform=center-crop-wide --resolution=512x384 374 | """ 375 | 376 | PIL.Image.init() 377 | 378 | if dest == '': 379 | raise click.ClickException('--dest output filename or directory must not be an empty string') 380 | 381 | num_files, input_iter = open_dataset(source, max_images=max_images) 382 | archive_root_dir, save_bytes, close_dest = open_dest(dest) 383 | 384 | if resolution is None: resolution = (None, None) 385 | transform_image = make_transform(transform, *resolution) 386 | 387 | dataset_attrs = None 388 | 389 | labels = [] 390 | for idx, image in tqdm(enumerate(input_iter), total=num_files): 391 | idx_str = f'{idx:08d}' 392 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png' 393 | 394 | # Apply crop and resize. 395 | img = transform_image(image['img']) 396 | if img is None: 397 | continue 398 | 399 | # Error check to require uniform image attributes across 400 | # the whole dataset. 401 | channels = img.shape[2] if img.ndim == 3 else 1 402 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels} 403 | if dataset_attrs is None: 404 | dataset_attrs = cur_image_attrs 405 | width = dataset_attrs['width'] 406 | height = dataset_attrs['height'] 407 | if width != height: 408 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}') 409 | if dataset_attrs['channels'] not in [1, 3]: 410 | raise click.ClickException('Input images must be stored as RGB or grayscale') 411 | if width != 2 ** int(np.floor(np.log2(width))): 412 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two') 413 | elif dataset_attrs != cur_image_attrs: 414 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] 415 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) 416 | 417 | # Save the image as an uncompressed PNG. 
418 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels]) 419 | image_bits = io.BytesIO() 420 | img.save(image_bits, format='png', compress_level=0, optimize=False) 421 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) 422 | labels.append([archive_fname, image['label']] if image['label'] is not None else None) 423 | 424 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 425 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 426 | close_dest() 427 | 428 | #---------------------------------------------------------------------------- 429 | 430 | if __name__ == "__main__": 431 | main() 432 | 433 | #---------------------------------------------------------------------------- 434 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import EasyDict, make_cache_dir_path 2 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utility classes and functions.""" 2 | 3 | import ctypes 4 | import fnmatch 5 | import importlib 6 | import inspect 7 | import numpy as np 8 | import os 9 | import shutil 10 | import sys 11 | import types 12 | import io 13 | import pickle 14 | import re 15 | import requests 16 | import html 17 | import hashlib 18 | import glob 19 | import tempfile 20 | import urllib 21 | import urllib.request 22 | import uuid 23 | 24 | from distutils.util import strtobool 25 | from typing import Any, List, Tuple, Union, Optional 26 | 27 | 28 | # Util classes 29 | # ------------------------------------------------------------------------------------------ 30 | 31 | 32 | class EasyDict(dict): 33 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 34 | 35 | def __getattr__(self, name: str) -> Any: 36 | try: 37 | return self[name] 38 | except KeyError: 39 | raise AttributeError(name) 40 | 41 | def __setattr__(self, name: str, value: Any) -> None: 42 | self[name] = value 43 | 44 | def __delattr__(self, name: str) -> None: 45 | del self[name] 46 | 47 | 48 | class Logger(object): 49 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 50 | 51 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 52 | self.file = None 53 | 54 | if file_name is not None: 55 | self.file = open(file_name, file_mode) 56 | 57 | self.should_flush = should_flush 58 | self.stdout = sys.stdout 59 | self.stderr = sys.stderr 60 | 61 | sys.stdout = self 62 | sys.stderr = self 63 | 64 | def __enter__(self) -> "Logger": 65 | return self 66 | 67 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 68 | self.close() 69 | 70 | def write(self, text: Union[str, bytes]) -> None: 71 | """Write text to stdout (and a file) and optionally flush.""" 72 | if isinstance(text, bytes): 73 | text = text.decode() 74 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 75 | return 76 | 77 | if self.file is not None: 78 | self.file.write(text) 79 | 80 | self.stdout.write(text) 81 | 82 | if self.should_flush: 83 | self.flush() 84 | 85 | def flush(self) -> None: 86 | """Flush written text to both stdout and a 
file, if open.""" 87 | if self.file is not None: 88 | self.file.flush() 89 | 90 | self.stdout.flush() 91 | 92 | def close(self) -> None: 93 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 94 | self.flush() 95 | 96 | # if using multiple loggers, prevent closing in wrong order 97 | if sys.stdout is self: 98 | sys.stdout = self.stdout 99 | if sys.stderr is self: 100 | sys.stderr = self.stderr 101 | 102 | if self.file is not None: 103 | self.file.close() 104 | self.file = None 105 | 106 | 107 | # Cache directories 108 | # ------------------------------------------------------------------------------------------ 109 | 110 | _dnnlib_cache_dir = None 111 | 112 | def set_cache_dir(path: str) -> None: 113 | global _dnnlib_cache_dir 114 | _dnnlib_cache_dir = path 115 | 116 | def make_cache_dir_path(*paths: str) -> str: 117 | if _dnnlib_cache_dir is not None: 118 | return os.path.join(_dnnlib_cache_dir, *paths) 119 | if 'DNNLIB_CACHE_DIR' in os.environ: 120 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 121 | if 'HOME' in os.environ: 122 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 123 | if 'USERPROFILE' in os.environ: 124 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 125 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 126 | 127 | # Small util functions 128 | # ------------------------------------------------------------------------------------------ 129 | 130 | 131 | def format_time(seconds: Union[int, float]) -> str: 132 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 133 | s = int(np.rint(seconds)) 134 | 135 | if s < 60: 136 | return "{0}s".format(s) 137 | elif s < 60 * 60: 138 | return "{0}m {1:02}s".format(s // 60, s % 60) 139 | elif s < 24 * 60 * 60: 140 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 141 | else: 142 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 143 | 144 | 145 | def format_time_brief(seconds: Union[int, float]) -> str: 146 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 147 | s = int(np.rint(seconds)) 148 | 149 | if s < 60: 150 | return "{0}s".format(s) 151 | elif s < 60 * 60: 152 | return "{0}m {1:02}s".format(s // 60, s % 60) 153 | elif s < 24 * 60 * 60: 154 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 155 | else: 156 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 157 | 158 | 159 | def ask_yes_no(question: str) -> bool: 160 | """Ask the user the question until the user inputs a valid answer.""" 161 | while True: 162 | try: 163 | print("{0} [y/n]".format(question)) 164 | return strtobool(input().lower()) 165 | except ValueError: 166 | pass 167 | 168 | 169 | def tuple_product(t: Tuple) -> Any: 170 | """Calculate the product of the tuple elements.""" 171 | result = 1 172 | 173 | for v in t: 174 | result *= v 175 | 176 | return result 177 | 178 | 179 | _str_to_ctype = { 180 | "uint8": ctypes.c_ubyte, 181 | "uint16": ctypes.c_uint16, 182 | "uint32": ctypes.c_uint32, 183 | "uint64": ctypes.c_uint64, 184 | "int8": ctypes.c_byte, 185 | "int16": ctypes.c_int16, 186 | "int32": ctypes.c_int32, 187 | "int64": ctypes.c_int64, 188 | "float32": ctypes.c_float, 189 | "float64": ctypes.c_double 190 | } 191 | 192 | 193 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 194 | """Given a type name string (or an object having a __name__ attribute), 
return matching Numpy and ctypes types that have the same size in bytes.""" 195 | type_str = None 196 | 197 | if isinstance(type_obj, str): 198 | type_str = type_obj 199 | elif hasattr(type_obj, "__name__"): 200 | type_str = type_obj.__name__ 201 | elif hasattr(type_obj, "name"): 202 | type_str = type_obj.name 203 | else: 204 | raise RuntimeError("Cannot infer type name from input") 205 | 206 | assert type_str in _str_to_ctype.keys() 207 | 208 | my_dtype = np.dtype(type_str) 209 | my_ctype = _str_to_ctype[type_str] 210 | 211 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 212 | 213 | return my_dtype, my_ctype 214 | 215 | 216 | def is_pickleable(obj: Any) -> bool: 217 | try: 218 | with io.BytesIO() as stream: 219 | pickle.dump(obj, stream) 220 | return True 221 | except: 222 | return False 223 | 224 | 225 | # Functionality to import modules/objects by name, and call functions by name 226 | # ------------------------------------------------------------------------------------------ 227 | 228 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 229 | """Searches for the underlying module behind the name to some python object. 230 | Returns the module and the object name (original name with module part removed).""" 231 | 232 | # allow convenience shorthands, substitute them by full names 233 | obj_name = re.sub("^np.", "numpy.", obj_name) 234 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 235 | 236 | # list alternatives for (module_name, local_obj_name) 237 | parts = obj_name.split(".") 238 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 239 | 240 | # try each alternative in turn 241 | for module_name, local_obj_name in name_pairs: 242 | try: 243 | module = importlib.import_module(module_name) # may raise ImportError 244 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 245 | return module, local_obj_name 246 | except: 247 | pass 248 | 249 | # maybe some of the modules themselves contain errors? 250 | for module_name, _local_obj_name in name_pairs: 251 | try: 252 | importlib.import_module(module_name) # may raise ImportError 253 | except ImportError: 254 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 255 | raise 256 | 257 | # maybe the requested attribute is missing? 
258 | for module_name, local_obj_name in name_pairs: 259 | try: 260 | module = importlib.import_module(module_name) # may raise ImportError 261 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 262 | except ImportError: 263 | pass 264 | 265 | # we are out of luck, but we have no idea why 266 | raise ImportError(obj_name) 267 | 268 | 269 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 270 | """Traverses the object name and returns the last (rightmost) python object.""" 271 | if obj_name == '': 272 | return module 273 | obj = module 274 | for part in obj_name.split("."): 275 | obj = getattr(obj, part) 276 | return obj 277 | 278 | 279 | def get_obj_by_name(name: str) -> Any: 280 | """Finds the python object with the given name.""" 281 | module, obj_name = get_module_from_obj_name(name) 282 | return get_obj_from_module(module, obj_name) 283 | 284 | 285 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 286 | """Finds the python object with the given name and calls it as a function.""" 287 | assert func_name is not None 288 | func_obj = get_obj_by_name(func_name) 289 | assert callable(func_obj) 290 | return func_obj(*args, **kwargs) 291 | 292 | 293 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 294 | """Finds the python class with the given name and constructs it with the given arguments.""" 295 | return call_func_by_name(*args, func_name=class_name, **kwargs) 296 | 297 | 298 | def get_module_dir_by_obj_name(obj_name: str) -> str: 299 | """Get the directory path of the module containing the given object name.""" 300 | module, _ = get_module_from_obj_name(obj_name) 301 | return os.path.dirname(inspect.getfile(module)) 302 | 303 | 304 | def is_top_level_function(obj: Any) -> bool: 305 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 306 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 307 | 308 | 309 | def get_top_level_function_name(obj: Any) -> str: 310 | """Return the fully-qualified name of a top-level function.""" 311 | assert is_top_level_function(obj) 312 | module = obj.__module__ 313 | if module == '__main__': 314 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 315 | return module + "." + obj.__name__ 316 | 317 | 318 | # File system helpers 319 | # ------------------------------------------------------------------------------------------ 320 | 321 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 322 | """List all files recursively in a given directory while ignoring given file and directory names. 
323 | Returns list of tuples containing both absolute and relative paths.""" 324 | assert os.path.isdir(dir_path) 325 | base_name = os.path.basename(os.path.normpath(dir_path)) 326 | 327 | if ignores is None: 328 | ignores = [] 329 | 330 | result = [] 331 | 332 | for root, dirs, files in os.walk(dir_path, topdown=True): 333 | for ignore_ in ignores: 334 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 335 | 336 | # dirs need to be edited in-place 337 | for d in dirs_to_remove: 338 | dirs.remove(d) 339 | 340 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 341 | 342 | absolute_paths = [os.path.join(root, f) for f in files] 343 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 344 | 345 | if add_base_to_relative: 346 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 347 | 348 | assert len(absolute_paths) == len(relative_paths) 349 | result += zip(absolute_paths, relative_paths) 350 | 351 | return result 352 | 353 | 354 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 355 | """Takes in a list of tuples of (src, dst) paths and copies files. 356 | Will create all necessary directories.""" 357 | for file in files: 358 | target_dir_name = os.path.dirname(file[1]) 359 | 360 | # will create all intermediate-level directories 361 | if not os.path.exists(target_dir_name): 362 | os.makedirs(target_dir_name) 363 | 364 | shutil.copyfile(file[0], file[1]) 365 | 366 | 367 | # URL helpers 368 | # ------------------------------------------------------------------------------------------ 369 | 370 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 371 | """Determine whether the given object is a valid URL string.""" 372 | if not isinstance(obj, str) or not "://" in obj: 373 | return False 374 | if allow_file_urls and obj.startswith('file://'): 375 | return True 376 | try: 377 | res = requests.compat.urlparse(obj) 378 | if not res.scheme or not res.netloc or not "." in res.netloc: 379 | return False 380 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 381 | if not res.scheme or not res.netloc or not "." in res.netloc: 382 | return False 383 | except: 384 | return False 385 | return True 386 | 387 | 388 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 389 | """Download the given URL and return a binary-mode file object to access the data.""" 390 | assert num_attempts >= 1 391 | assert not (return_filename and (not cache)) 392 | 393 | # Doesn't look like an URL scheme so interpret it as a local filename. 394 | if not re.match('^[a-z]+://', url): 395 | return url if return_filename else open(url, "rb") 396 | 397 | # Handle file URLs. This code handles unusual file:// patterns that 398 | # arise on Windows: 399 | # 400 | # file:///c:/foo.txt 401 | # 402 | # which would translate to a local '/c:/foo.txt' filename that's 403 | # invalid. Drop the forward slash for such pathnames. 404 | # 405 | # If you touch this code path, you should test it on both Linux and 406 | # Windows. 407 | # 408 | # Some internet resources suggest using urllib.request.url2pathname() but 409 | # but that converts forward slashes to backslashes and this causes 410 | # its own set of problems. 
411 | if url.startswith('file://'): 412 | filename = urllib.parse.urlparse(url).path 413 | if re.match(r'^/[a-zA-Z]:', filename): 414 | filename = filename[1:] 415 | return filename if return_filename else open(filename, "rb") 416 | 417 | assert is_url(url) 418 | 419 | # Lookup from cache. 420 | if cache_dir is None: 421 | cache_dir = make_cache_dir_path('downloads') 422 | 423 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 424 | if cache: 425 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 426 | if len(cache_files) == 1: 427 | filename = cache_files[0] 428 | return filename if return_filename else open(filename, "rb") 429 | 430 | # Download. 431 | url_name = None 432 | url_data = None 433 | with requests.Session() as session: 434 | if verbose: 435 | print("Downloading %s ..." % url, end="", flush=True) 436 | for attempts_left in reversed(range(num_attempts)): 437 | try: 438 | with session.get(url) as res: 439 | res.raise_for_status() 440 | if len(res.content) == 0: 441 | raise IOError("No data received") 442 | 443 | if len(res.content) < 8192: 444 | content_str = res.content.decode("utf-8") 445 | if "download_warning" in res.headers.get("Set-Cookie", ""): 446 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 447 | if len(links) == 1: 448 | url = requests.compat.urljoin(url, links[0]) 449 | raise IOError("Google Drive virus checker nag") 450 | if "Google Drive - Quota exceeded" in content_str: 451 | raise IOError("Google Drive download quota exceeded -- please try again later") 452 | 453 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 454 | url_name = match[1] if match else url 455 | url_data = res.content 456 | if verbose: 457 | print(" done") 458 | break 459 | except KeyboardInterrupt: 460 | raise 461 | except: 462 | if not attempts_left: 463 | if verbose: 464 | print(" failed") 465 | raise 466 | if verbose: 467 | print(".", end="", flush=True) 468 | 469 | # Save to cache. 470 | if cache: 471 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 472 | safe_name = safe_name[:min(len(safe_name), 128)] 473 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 474 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 475 | os.makedirs(cache_dir, exist_ok=True) 476 | with open(temp_file, "wb") as f: 477 | f.write(url_data) 478 | os.replace(temp_file, cache_file) # atomic 479 | if return_filename: 480 | return cache_file 481 | 482 | # Return data as file object. 
483 |     assert not return_filename
484 |     return io.BytesIO(url_data)
485 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: edm
2 | channels:
3 |   - pytorch
4 |   - nvidia
5 |   - conda-forge
6 | dependencies:
7 |   - python>=3.8, < 3.10  # package build failures on 3.10
8 |   - pip
9 |   - numpy>=1.20
10 |   - click>=8.0
11 |   - pillow>=8.3.1
12 |   - scipy>=1.7.1
13 |   - pytorch=1.12.1
14 |   - psutil
15 |   - requests
16 |   - tqdm
17 |   - imageio
18 |   - pip:
19 |     - imageio-ffmpeg>=0.4.3
20 |     - pyspng
21 | 
--------------------------------------------------------------------------------
/exps/generate_samples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # -----------------------------------------------------------------------
4 | # This script is for generating a samples output file from an experiment.
5 | # -----------------------------------------------------------------------
6 | 
7 | cd ..
8 | 
9 | USAGE="Usage: bash generate_samples.sh <experiment_name> <num_diffusion_steps> <t_max> <outfile> [extra_args]"
10 | 
11 | EXPNAME=$1
12 | if [ -z $EXPNAME ]; then
13 |     echo "Must specify an experiment name!"
14 |     echo $USAGE
15 |     exit 1
16 | fi
17 | 
18 | STEPS=$2
19 | if [ -z $STEPS ]; then
20 |     echo "Must specify number of diffusion steps to perform"
21 |     echo $USAGE
22 |     exit 1
23 | fi
24 | 
25 | TMAX=$3
26 | if [ -z $TMAX ]; then
27 |     echo "You must specify how many frames you want to generate (i.e. the length of the trajectory)"
28 |     echo $USAGE
29 |     exit 1
30 | fi
31 | 
32 | EXPORT_TO=$4
33 | if [ -z $EXPORT_TO ]; then
34 |     echo "Must specify output filename!"
35 |     echo $USAGE
36 |     exit 1
37 | fi
38 | 
39 | shift 4
40 | 
41 | #mkdir -p $SAVE_DIR/$EXPNAME/samples
42 | 
43 | mkdir -p $SAVE_DIR/$EXPNAME/samples/$STEPS
44 | 
45 | echo "-----------------------------"
46 | echo "Processing n steps: " $STEPS
47 | echo "T max: " $TMAX
48 | echo "Extra arguments: " "$@"
49 | echo "Save to: " ${EXPORT_TO}
50 | echo "-----------------------------"
51 | 
52 | python generate.py \
53 |     --network=$SAVE_DIR/$EXPNAME/network-snapshot.pkl \
54 |     --outfile=$EXPORT_TO \
55 |     --t_max=$TMAX \
56 |     --steps=$STEPS "$@"
57 | 
--------------------------------------------------------------------------------
/exps/generate_video_from_samples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | #set -x
4 | 
5 | USAGE="Usage: bash generate_video_from_samples.sh <samples_file> <fps> <outfile>"
6 | 
7 | SAMPLES_PATH=$1
8 | FPS=$2
9 | OUTFILE=$3
10 | 
11 | if [ -z $SAMPLES_PATH ]; then
12 |     echo "Error: you must specify path to samples file"
13 |     echo $USAGE
14 |     exit 1
15 | fi
16 | 
17 | if [ -z $FPS ]; then
18 |     echo "Error: you must specify frame rate"
19 |     echo $USAGE
20 |     exit 1
21 | fi
22 | 
23 | if [ -z $OUTFILE ]; then
24 |     echo "Error: you must specify output mp4 file"
25 |     echo $USAGE
26 |     exit 1
27 | fi
28 | 
29 | 
30 | #if [ ! -f $SAVE_DIR/$SAMPLES_PATH ]; then
31 | #    echo "Path to file does not exist: $SAVE_DIR/$SAMPLES_PATH"
32 | #    exit 1
33 | #fi
34 | 
35 | TMP_DIR=`mktemp -d`
36 | echo "Created tmp directory: " $TMP_DIR
37 | 
38 | #FIGSIZE="10 6"
39 | FIGSIZE="20 12"
40 | 
41 | pushd .
42 | 
43 | cd ..
44 | python -m scripts.generate_frames_from_samples \
45 |     --samples=$SAMPLES_PATH \
46 |     --outdir=$TMP_DIR --figsize $FIGSIZE && \
47 |     cd $TMP_DIR && ffmpeg -r $FPS -i %07d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p out.mp4
48 | 
49 | # Go back to our directory
50 | popd
51 | 
52 | cp $TMP_DIR/out.mp4 $OUTFILE
53 | 
--------------------------------------------------------------------------------
/exps/json/example_ns.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_class": "training.datasets.NSDataset",
3 | 
4 |     // the low-res conditioning variable (y) should be 1/8th the size of samples
5 |     // from u.
6 |     "dataset_kwargs": {"lowres_scale_factor": 0.125},
7 | 
8 |     // samples from the function u should be at this spatial resolution.
9 |     "resolution": 64,
10 | 
11 |     // total window size is 7, i.e. 3 observations on each side of the current frame
12 |     "window_size": 3,
13 | 
14 |     "arch": "ddpmpp",
15 |     "tick": 5,
16 |     "snap": 5,
17 | 
18 |     // batch size
19 |     "batch": 128,
20 | 
21 |     // rank hyperparameter for UNO (from 0-1; larger = more parameters)
22 |     "rank": 0.1,
23 | 
24 |     // base width of UNO
25 |     "cbase": 64,
26 | 
27 |     // fraction of Fourier modes to retain (larger = more parameters)
28 |     "fmult": 0.5
29 | }
30 | 
--------------------------------------------------------------------------------
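An aside on the config format: these json files contain `//` comments, which the stdlib `json.load` rejects; this is why `jstyleson` is listed as a dependency. A minimal sketch of parsing such a file (illustrative; see `train.py` for the actual loading code):

```
import jstyleson

# jstyleson strips // comments before parsing, unlike the stdlib json module.
with open('exps/json/example_ns.json') as f:
    cfg = jstyleson.load(f)
print(cfg['resolution'], cfg['dataset_kwargs'])
```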
/exps/main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | METHOD="train.py"
4 | EXP_GROUP=$1
5 | CFG_FILE=$2
6 | N_GPU=$3
7 | 
8 | if [ -z $EXP_GROUP ]; then
9 |     echo "Must specify an experiment group name!"
10 |     exit 1
11 | fi
12 | 
13 | if [ -z $CFG_FILE ]; then
14 |     echo "Must specify a json file"
15 |     exit 1
16 | fi
17 | 
18 | if [ -z $SAVE_DIR ]; then
19 |     echo "SAVE_DIR not found, source env.sh?"
20 |     exit 1
21 | fi
22 | 
23 | if [ -z $N_GPU ]; then
24 |     echo "Must set number of gpus to train on. If > 1 we use torchrun"
25 |     exit 1
26 | fi
27 | 
28 | if [ -z "$SLURM_JOB_ID" ]; then
29 |     echo "SLURM_JOB_ID not set for some reason, are you in a login node?"
30 |     echo "Set variable to '999999' for now"
31 |     SLURM_JOB_ID=999999
32 | fi
33 | 
34 | EXP_NAME="${EXP_GROUP}/${SLURM_JOB_ID}"
35 | echo "Experiment name: " $EXP_NAME
36 | 
37 | cd ..
38 | 
39 | 
40 | # TODO: add argument for distributed
41 | 
42 | CFG_ABS_PATH=`pwd`/exps/${CFG_FILE}
43 | #python train.py --cfg=$CFG_ABS_PATH --savedir=${SAVE_DIR}/${EXP_NAME}
44 | 
45 | # If code does not exist for this experiment, copy
46 | # it over. Then cd into that directory and run the code.
47 | # But only if we're not in run_local mode.
48 | if [ -z $RUN_LOCAL ]; then
49 |     if [ ! -d ${SAVE_DIR}/${EXP_NAME}/code ]; then
50 |         mkdir -p ${SAVE_DIR}/${EXP_NAME}
51 |         echo "Copying code..."
52 |         rsync -r -v --exclude='exps' --exclude='.git' --exclude='__pycache__' --exclude '*.pyc' . ${SAVE_DIR}/${EXP_NAME}/code
53 |         if [ ! $? -eq 0 ]; then
54 |             echo "rsync returned error, terminating..."
55 |             exit 1
56 |         fi
57 |     fi
58 | fi
59 | 
60 | CFG_ABS_PATH=`pwd`/exps/${CFG_FILE}
61 | echo "Absolute path of cfg: " $CFG_ABS_PATH
62 | 
63 | if [[ ! "$N_GPU" -eq 1 ]]; then
64 |     CMD_TO_RUN="torchrun --standalone --nproc_per_node=${N_GPU} train.py"
65 | else
66 |     CMD_TO_RUN="python train.py"
67 | fi
68 | echo "Command to run: " $CMD_TO_RUN
69 | 
70 | if [ -z $RUN_LOCAL ]; then
71 |     cd ${SAVE_DIR}/${EXP_NAME}/code
72 |     ${CMD_TO_RUN} --cfg=$CFG_ABS_PATH --savedir=${SAVE_DIR}/${EXP_NAME}
73 | else
74 |     echo "RUN_LOCAL mode set, run code from this directory..."
75 | # --override_cfg = use the local cfg file, not the one in the experiment directory 76 | # also do not use torchrun, just debug with 1 gpu 77 | ${CMD_TO_RUN} --cfg=$CFG_ABS_PATH --savedir=${SAVE_DIR}/${EXP_NAME} --override_cfg 78 | fi 79 | echo "Current working directory: " `pwd` 80 | 81 | # Use this as a reference for trapping SIGTERM signal: 82 | # https://wikis.ch.cam.ac.uk/ro-walesdocs/wiki/index.php/Getting_started_with_SLURM 83 | 84 | #bash launch.sh $EXP_NAME 85 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | """Generate random images using the techniques described in the paper 2 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 3 | 4 | import os 5 | import re 6 | import click 7 | import tqdm 8 | import pickle 9 | import json 10 | import numpy as np 11 | import torch 12 | import PIL.Image 13 | import dnnlib 14 | from torch_utils import misc 15 | from torch.nn.functional import interpolate 16 | from torch_utils import distributed as dist 17 | from torchvision.utils import save_image 18 | 19 | from einops import rearrange 20 | 21 | from training.datasets.dataset import WindowedDataset 22 | 23 | from torchvision.transforms.functional import gaussian_blur 24 | 25 | #---------------------------------------------------------------------------- 26 | # Deterministic EDM sampler. 27 | 28 | def deterministic_edm_sampler( 29 | net, latents, class_labels=None, 30 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7 31 | ): 32 | # Adjust noise levels based on what's supported by the network. 33 | sigma_min = max(sigma_min, net.sigma_min) 34 | sigma_max = min(sigma_max, net.sigma_max) 35 | 36 | # Time step discretization. 37 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 38 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 39 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 40 | 41 | # Main sampling loop. 42 | x_next = latents.to(torch.float64) * t_steps[0] 43 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 44 | x_cur = x_next 45 | 46 | x_hat = x_cur 47 | t_hat = t_cur 48 | 49 | # Euler step. 50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 51 | d_cur = (x_hat - denoised) / t_hat 52 | x_next = x_hat + (t_next - t_hat) * d_cur 53 | 54 | # Apply 2nd order correction. 55 | if i < num_steps - 1: 56 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 57 | d_prime = (x_next - denoised) / t_next 58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 59 | 60 | return x_next 61 | 62 | def deterministic_ablation_sampler( 63 | net, latents, class_labels=None, 64 | num_steps=18, sigma_min=None, sigma_max=None, rho=7, 65 | solver='heun', discretization='edm', schedule='linear', scaling='none', 66 | epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1 67 | ): 68 | assert solver in ['euler', 'heun'] 69 | assert discretization in ['vp', 've', 'iddpm', 'edm'] 70 | assert schedule in ['vp', 've', 'linear'] 71 | assert scaling in ['vp', 'none'] 72 | 73 | # Helper functions for VP & VE noise level schedules. 
74 | vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 75 | vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) 76 | vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d 77 | ve_sigma = lambda t: t.sqrt() 78 | ve_sigma_deriv = lambda t: 0.5 / t.sqrt() 79 | ve_sigma_inv = lambda sigma: sigma ** 2 80 | 81 | # Select default noise level range based on the specified time step discretization. 82 | if sigma_min is None: 83 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) 84 | sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] 85 | if sigma_max is None: 86 | vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) 87 | sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] 88 | 89 | # Adjust noise levels based on what's supported by the network. 90 | sigma_min = max(sigma_min, net.sigma_min) 91 | sigma_max = min(sigma_max, net.sigma_max) 92 | 93 | # Compute corresponding betas for VP. 94 | vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) 95 | vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d 96 | 97 | # Define time steps in terms of noise level. 98 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 99 | if discretization == 'vp': 100 | orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) 101 | sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) 102 | elif discretization == 've': 103 | orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) 104 | sigma_steps = ve_sigma(orig_t_steps) 105 | elif discretization == 'iddpm': 106 | u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) 107 | alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 108 | for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 109 | u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() 110 | u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] 111 | sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] 112 | else: 113 | assert discretization == 'edm' 114 | sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 115 | 116 | # Define noise level schedule. 117 | if schedule == 'vp': 118 | sigma = vp_sigma(vp_beta_d, vp_beta_min) 119 | sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) 120 | sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) 121 | elif schedule == 've': 122 | sigma = ve_sigma 123 | sigma_deriv = ve_sigma_deriv 124 | sigma_inv = ve_sigma_inv 125 | else: 126 | assert schedule == 'linear' 127 | sigma = lambda t: t 128 | sigma_deriv = lambda t: 1 129 | sigma_inv = lambda sigma: sigma 130 | 131 | # Define scaling schedule. 132 | if scaling == 'vp': 133 | s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() 134 | s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) 135 | else: 136 | assert scaling == 'none' 137 | s = lambda t: 1 138 | s_deriv = lambda t: 0 139 | 140 | # Compute final time steps based on the corresponding noise levels. 
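    # In other words: map each chosen noise level back to a time value via
    # t_i = sigma_inv(sigma_i) (an identity for the 'linear' schedule), then
    # append t_N = 0 so the ODE is integrated all the way down to the data.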
141 | t_steps = sigma_inv(net.round_sigma(sigma_steps)) 142 | t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 143 | 144 | # Main sampling loop. 145 | t_next = t_steps[0] 146 | x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) 147 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 148 | x_cur = x_next 149 | 150 | x_hat = x_cur 151 | t_hat = t_cur 152 | 153 | # Euler step. 154 | h = t_next - t_hat 155 | denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) 156 | d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised 157 | x_prime = x_hat + alpha * h * d_cur 158 | t_prime = t_hat + alpha * h 159 | 160 | # Apply 2nd order correction. 161 | if solver == 'euler' or i == num_steps - 1: 162 | x_next = x_hat + h * d_cur 163 | else: 164 | assert solver == 'heun' 165 | denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) 166 | d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised 167 | x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) 168 | 169 | return x_next 170 | 171 | #---------------------------------------------------------------------------- 172 | # Wrapper for torch.Generator that allows specifying a different random seed 173 | # for each sample in a minibatch. 174 | 175 | class StackedRandomGenerator: 176 | def __init__(self, device, seeds): 177 | super().__init__() 178 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 179 | 180 | def randn(self, size, **kwargs): 181 | assert size[0] == len(self.generators) 182 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 183 | 184 | def randn_like(self, input): 185 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 186 | 187 | def randint(self, *args, size, **kwargs): 188 | assert size[0] == len(self.generators) 189 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 190 | 191 | #---------------------------------------------------------------------------- 192 | # Parse a comma separated list of numbers or ranges and return a list of ints. 
193 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
194 | 
195 | def parse_int_list(s):
196 |     if isinstance(s, list): return s
197 |     ranges = []
198 |     range_re = re.compile(r'^(\d+)-(\d+)$')
199 |     for p in s.split(','):
200 |         m = range_re.match(p)
201 |         if m:
202 |             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
203 |         else:
204 |             ranges.append(int(p))
205 |     return ranges
206 | 
207 | #----------------------------------------------------------------------------
208 | 
209 | @click.command()
210 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
211 | @click.option("--reload_network", help="If set, do not use network code pickled in checkpoint", is_flag=True)
212 | @click.option("--resolution", help="Desired resolution of noise (and therefore generated images)", type=int, default=None)
213 | @click.option('--outfile', help='Where to save the generated samples (.pt file)', metavar='PATH', type=str, required=True)
214 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
215 | # The number of forecasts (x's) we generate per x_t
216 | @click.option('--examples_per_t', help='Number of forecasts to generate per conditioning timestep', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
217 | # The number of timesteps y_t we consider, for t = {1, ..., t_max}.
218 | @click.option('--t_max', help='Number of timesteps (examples) to generate in total', metavar='INT', type=click.IntRange(min=1), default=2)
219 | # Batch size for generation.
220 | @click.option('--batch_size', help='Batch size for generation', metavar='INT', type=click.IntRange(min=1), default=32)
221 | @click.option('--num_workers', help='Number of workers for data loader', metavar='INT', type=click.IntRange(min=0), default=0)
222 | #@click.option('--noise_kwargs', type=str, default="{}")
223 | @click.option('--rbf_scale', help="RBF scale", metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=None)
224 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
225 | @click.option('--sigma_min', help='Lowest noise level  [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=0.0002)
226 | @click.option('--sigma_max', help='Highest noise level  [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
227 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
228 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
229 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
230 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
231 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
232 | 
233 | def main(network_pkl,
234 |          reload_network,
235 |          resolution,
236 |          outfile,
237 |          subdirs,
238 |          examples_per_t,
239 |          t_max,
240 |          batch_size,
241 |          num_workers,
242 |          #noise_kwargs,
243 |          device=torch.device('cuda'),
244 |          **sampler_kwargs):
245 |     """Generate random images using the techniques described in the paper
246 |     "Elucidating the Design Space of Diffusion-Based Generative Models".
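
    Unlike the original EDM script, sampling here is conditioned on windows of
    low-resolution observations `y` drawn from the training dataset, and the
    generated samples are written to a single `.pt` file (see `--outfile`).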
247 | 
248 |     Examples:
249 | 
250 |     \b
251 |     # Generate 64 forecasts for each of 2 timesteps and save them to out/samples.pt
252 |     python generate.py --outfile=out/samples.pt --t_max=2 --examples_per_t=64 \\
253 |         --network=<path-to-experiment>/network-snapshot.pkl
254 |     """
255 | 
256 |     if (t_max*examples_per_t) % batch_size != 0:
257 |         raise ValueError("t_max * examples_per_t must be evenly divisible by batch_size!" + \
258 |             " values are {} * {}, batch_size = {}".format(t_max, examples_per_t, batch_size))
259 | 
260 |     dist.init()
261 | 
262 |     # Load dataset because we need to be able to sample y's to condition on.
263 |     exp_dir = os.path.dirname(network_pkl)
264 |     config = dnnlib.EasyDict(json.loads(
265 |         open(os.path.join(exp_dir, "training_options.json"), "r").read()
266 |     ))
267 |     dist.print0('Loading dataset...')
268 |     dataset_obj = dnnlib.util.construct_class_by_name(**config.dataset_kwargs) # subclass of training.dataset.Dataset
269 |     dist.print0('Windowing dataset...')
270 |     dataset_obj = WindowedDataset(dataset_obj, window_size=config.window_size)
271 | 
272 |     # Load network.
273 |     if reload_network:
274 |         # If this is set, do NOT load the network code from the pickle. Reconstruct
275 |         # the network from the actual current code and only load in the weights.
276 |         # This should be set if you've made post-hoc changes to the network code
277 |         # but are loading in weights corresponding to an older version.
278 |         dist.print0('Constructing network...')
279 |         interface_kwargs = dict(
280 |             img_resolution=dataset_obj.resolution,
281 |             img_channels=dataset_obj.num_channels,
282 |             label_dim=dataset_obj.label_dim
283 |         )
284 |         net = dnnlib.util.construct_class_by_name(**config.network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
285 |         net.eval().requires_grad_(False).to(device)
286 |         dist.print0(f'Loading network, load weights from inside "{network_pkl}"...')
287 |         with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
288 |             net_weights = pickle.load(f)['ema'].to(device).state_dict()
289 |         net.load_state_dict(net_weights)
290 |     else:
291 |         dist.print0(f'Loading network from inside "{network_pkl}"...')
292 |         with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
293 |             net = pickle.load(f)['ema'].to(device)
294 | 
295 |     dist.print0("Sampler kwargs: {}".format(sampler_kwargs))
296 | 
297 |     dataset_sampler = misc.InfiniteSampler(
298 |         dataset=dataset_obj,
299 |         rank=dist.get_rank(),
300 |         num_replicas=dist.get_world_size(),
301 |         shuffle=False,
302 |         seed=0 # TODO make it an arg
303 |     )
304 |     data_loader_kwargs = dnnlib.EasyDict(
305 |         pin_memory=True,
306 |         num_workers=num_workers,
307 |         prefetch_factor=2 # each worker pre-loads this many batches in advance
308 |     )
309 |     dataset_iterator = iter(
310 |         torch.utils.data.DataLoader(
311 |             dataset=dataset_obj,
312 |             sampler=dataset_sampler,
313 |             batch_size=t_max, # only return `t_max` images
314 |             **data_loader_kwargs
315 |         )
316 |     )
317 | 
318 |     #noise_kwargs = json.loads(noise_kwargs)
319 |     dist.print0("Loading noise sampler...")
320 |     noise_sampler_kwargs = dnnlib.EasyDict(config.sampler_kwargs)
321 |     noise_sampler_kwargs.n_in = dataset_obj.num_channels
322 |     noise_sampler_kwargs.device = device
323 |     if resolution is not None:
324 |         noise_sampler_kwargs.Ln1 = resolution
325 |         noise_sampler_kwargs.Ln2 = resolution
326 |     # We can override arguments in the noise_sampler at generation time,
327 |     # for instance if we want to increase the resolution or change the
328 |     # smoothness of the noise.
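    # A minimal sketch of what such an override could look like. The exact
    # kwarg names depend on the noise sampler class recorded in
    # `config.sampler_kwargs`; `length_scale` below is a hypothetical name:
    #
    #   noise_sampler_kwargs.Ln1 = 2 * noise_sampler_kwargs.Ln1  # 2x spatial resolution
    #   noise_sampler_kwargs.Ln2 = 2 * noise_sampler_kwargs.Ln2
    #   noise_sampler_kwargs.length_scale = 0.1  # hypothetical smoothness knob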
329 | """ 330 | if len(noise_kwargs.keys()) > 0: 331 | for key in noise_kwargs.keys(): 332 | if key in noise_sampler_kwargs: 333 | noise_sampler_kwargs[key] = noise_kwargs[key] 334 | dist.print0(f' noise_sampler: override {key}={noise_kwargs[key]} ...') 335 | else: 336 | raise ValueError(f'Unknown key for noise_sampler: "{key}"') 337 | """ 338 | 339 | noise_sampler = dnnlib.util.construct_class_by_name(**noise_sampler_kwargs) 340 | 341 | # Pick latents and labels. 342 | #rnd = StackedRandomGenerator(device, np.arange(0, examples_per_t).tolist()) 343 | 344 | # shape: (t_max, nc, h, w) and (t_max, w, nc, h, w) 345 | images_real_, class_labels_ = next(dataset_iterator) 346 | # t = timestep, ws = window size, nc = num channels 347 | class_labels = rearrange(class_labels_, 't ws nc h w -> t (ws nc) h w') 348 | class_labels = rearrange(class_labels, 't N h w -> t 1 N h w').\ 349 | repeat(1, examples_per_t, 1, 1, 1) 350 | class_labels = rearrange(class_labels, 't rep N h w -> (t rep) N h w') 351 | #images_real = images_real.view(-1, *tuple(images_real.shape[2:])) 352 | 353 | class_labels = class_labels.to(device) 354 | 355 | # batch_size = the number of conditioning images 356 | 357 | # TODO parallelise this 358 | buf_samples = [] 359 | N_total = class_labels.size(0) 360 | n_iters = int(np.ceil(N_total / batch_size)) 361 | for j in range(n_iters): 362 | dist.print0("Processing batch: {} / {} ...".format(j+1, n_iters)) 363 | this_slice = slice(j*batch_size, (j+1)*batch_size) 364 | this_class_labels = class_labels[this_slice] 365 | 366 | this_latents = noise_sampler.sample(this_class_labels.size(0)).to(device) 367 | if this_class_labels.size(-1) != this_latents.size(-1): 368 | # If we're doing super-resolution 369 | dist.print0(f' `this_class_label` and `latents` spatial dim mismatch: {this_class_labels.size(-1)} and {this_latents.size(-1)}, upscaling `this_class_label`...') 370 | this_class_labels = interpolate( 371 | this_class_labels, 372 | (this_latents.size(-2), this_latents.size(-1)), 373 | mode='bilinear' 374 | ) 375 | 376 | # Generate images. 377 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 378 | have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling']) 379 | sampler_fn = deterministic_ablation_sampler if have_ablation_kwargs else deterministic_edm_sampler 380 | samples = sampler_fn(net, this_latents, this_class_labels, **sampler_kwargs) 381 | 382 | print(" samples min={}, max={}".format(samples.min(), samples.max())) 383 | 384 | samples_torch = ((samples*0.5 + 0.5)).cpu() 385 | buf_samples.append(samples_torch) 386 | 387 | # shape = (t_max*examples_per_t, ch_x, h, w) 388 | buf_samples = torch.cat(buf_samples, dim=0) 389 | # shape = (t_max, examples_per_t, ch_x, h, w) 390 | buf_samples = buf_samples.reshape( 391 | t_max, examples_per_t, *tuple(buf_samples.shape[1:]) 392 | ) 393 | 394 | # shape = (t_max, ch_x, h, w) 395 | images_real_ = (images_real_*0.5 + 0.5) 396 | # shape = (t_max, ch_y, h, w) 397 | class_labels_ = (class_labels_*0.5 + 0.5) 398 | 399 | outdir = os.path.dirname(outfile) 400 | if not os.path.exists(outdir): 401 | os.makedirs(outdir) 402 | dist.print0("Saving to: {}".format(outfile)) 403 | torch.save( 404 | dict(gen=buf_samples, x=images_real_, y=class_labels_, metadata={}), 405 | outfile 406 | ) 407 | 408 | # Done. 
409 | dist.print0('Done.') 410 | 411 | #---------------------------------------------------------------------------- 412 | 413 | if __name__ == "__main__": 414 | main() 415 | 416 | #---------------------------------------------------------------------------- 417 | -------------------------------------------------------------------------------- /media/dataset.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/cond-diffusion-operators-edm/572d3f9b6ccb79d529f353e2005fb16e07673ccc/media/dataset.gif -------------------------------------------------------------------------------- /media/generated.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/cond-diffusion-operators-edm/572d3f9b6ccb79d529f353e2005fb16e07673ccc/media/generated.gif -------------------------------------------------------------------------------- /media/generated128.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/cond-diffusion-operators-edm/572d3f9b6ccb79d529f353e2005fb16e07673ccc/media/generated128.gif -------------------------------------------------------------------------------- /resume.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | import glob 5 | import pickle 6 | import dnnlib 7 | 8 | import torch 9 | from torch_utils import distributed as dist 10 | 11 | #from train import run, Arguments 12 | from training import training_loop 13 | 14 | from omegaconf import OmegaConf as OC 15 | 16 | import warnings 17 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 18 | 19 | if __name__ == '__main__': 20 | 21 | # Have to do this as it was done in train.py as well. 22 | torch.multiprocessing.set_start_method('spawn') 23 | dist.init() 24 | 25 | exp_dir = sys.argv[1] 26 | config = dnnlib.EasyDict(json.loads( 27 | open(os.path.join(exp_dir, "training_options.json"), "r").read() 28 | )) 29 | # Convert any internal dictionaries into EasyDict as well. 30 | for key in config.keys(): 31 | if type(config[key]) is dict: 32 | config[key] = dnnlib.EasyDict(config[key]) 33 | 34 | snapshots = sorted( 35 | glob.glob("{}/network-snapshot.pkl".format(exp_dir)) 36 | ) 37 | if len(snapshots) != 0: 38 | latest_snapshot = snapshots[-1] 39 | dist.print0("Found checkpoint: {}".format(latest_snapshot)) 40 | config.resume_pkl = latest_snapshot 41 | # HACK: we actually have to open the pkl here to 42 | # get the epoch number. 43 | with dnnlib.util.open_url(config.resume_pkl, verbose=(dist.get_rank() == 0)) as f: 44 | config.resume_kimg = pickle.load(f)['cur_nimg'] // 1000 45 | dist.print0("cur_knimg={}".format(config.resume_kimg)) 46 | 47 | training_loop.training_loop(**config) -------------------------------------------------------------------------------- /scripts/f2id.py: -------------------------------------------------------------------------------- 1 | """Credit to Md Ashiqur Rahman for giving me this code.""" 2 | 3 | import torch 4 | from torch.nn.functional import interpolate 5 | from scipy import linalg 6 | import numpy as np 7 | 8 | def calculated_f2id(features1, features2, resolution=50, mode='linear'): 9 | ''' 10 | features1, features2: discretized feature functions of real and generated data points. 
11 |                           assumed to be 1D of shape (batch, 1, grid_size)
12 |     resolution: grid size that both feature functions are interpolated onto
13 | 
14 |     '''
15 |     if features1.shape[-1] != resolution:
16 |         features1 = interpolate(features1, size=resolution, mode=mode)
17 |     if features2.shape[-1] != resolution:
18 |         features2 = interpolate(features2, size=resolution, mode=mode)
19 | 
20 |     features1 = features1.reshape(features1.shape[0], -1)
21 |     features2 = features2.reshape(features2.shape[0], -1)
22 | 
23 |     #print(features1.shape, features2.shape)
24 | 
25 |     mu1 = torch.mean(features1, dim=0).cpu().detach().numpy()
26 |     mu2 = torch.mean(features2, dim=0).cpu().detach().numpy()
27 |     sigma1 = torch.cov(torch.transpose(features1, 0, -1)).cpu().detach().numpy()
28 |     sigma2 = torch.cov(torch.transpose(features2, 0, -1)).cpu().detach().numpy()
29 |     #print(sigma1.shape, sigma2.shape)
30 |     diff = mu1 - mu2
31 | 
32 |     s = 1/resolution
33 |     eps = 1e-6  # diagonal jitter for near-singular covariance products (was undefined in the original)
34 |     # Product might be almost singular
35 |     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
36 |     if not np.isfinite(covmean).all():
37 |         msg = ('fid calculation produces singular product; '
38 |                'adding %s to diagonal of cov estimates') % eps
39 |         print(msg)
40 |         offset = np.eye(sigma1.shape[0]) * eps
41 |         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
42 | 
43 |     # Numerical error might give slight imaginary component
44 |     if np.iscomplexobj(covmean):
45 |         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3/s):
46 |             m = np.max(np.abs(covmean.imag))
47 |             raise ValueError('Imaginary component {}'.format(m))
48 |         covmean = covmean.real
49 | 
50 |     tr_covmean = np.trace(covmean)
51 | 
52 |     return (s*diff.dot(diff) + s*np.trace(sigma1) +
53 |             s*np.trace(sigma2) - 2*s*tr_covmean)
54 | 
55 | if __name__ == '__main__':
56 |     k1 = torch.randn(1000, 1, 50)
57 |     k2 = torch.randn(1000, 1, 100)
58 | 
59 |     print(calculated_f2id(k1, k2))
--------------------------------------------------------------------------------
/scripts/generate_frames_from_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 | import torch
7 | from torch.utils.data import DataLoader
8 | #from training.dataset_zarr import CwaEraDataset
9 | 
10 | import dnnlib
11 | 
12 | def get_figure(u: torch.FloatTensor, y: torch.FloatTensor, idx: int, figsize=(8,4)):
13 |     fig, axs = plt.subplots(1,2, figsize=figsize)
14 |     fig.suptitle("time: {}".format(idx))
15 |     #x = x.transpose(0,1).transpose(1,2) # convert to TF format
16 |     axs[0].imshow(u)
17 |     axs[0].set_title("u")
18 |     axs[1].imshow(y)
19 |     axs[1].set_title("y")
20 |     return fig
21 | 
22 | def to_tf_format(x: torch.Tensor):
23 |     assert len(x.shape) == 3
24 |     return x.transpose(0,1).transpose(1,2)
25 | 
26 | def run(args):
27 | 
28 |     if not os.path.exists(args.outdir):
29 |         os.makedirs(args.outdir)
30 | 
31 |     if args.path is None:
32 |         print("args.path is `None` so setting to `$DATA_DIR`...")
33 |         if 'DATA_DIR' not in os.environ:
34 |             raise ValueError("DATA_DIR not set, please source env.sh")
35 |         else:
36 |             args.path = os.environ['DATA_DIR']
37 | 
38 |     dataset_kwargs = json.loads(args.dataset_kwargs)
39 | 
40 |     dataset_kwargs = dnnlib.EasyDict(
41 |         #class_name='training.dataset_zarr.CwaEraDataset',
42 |         class_name=args.dataset_class,
43 |         path=args.path,
44 |         resolution=args.resolution,
45 |         train=args.split=='train',
46 |         **dataset_kwargs
47 |     )
48 | 
49 |     ds = dnnlib.util.construct_class_by_name(**dataset_kwargs)
50 | 
51 |     loader = 
DataLoader( 52 | ds, batch_size=args.batch_size, num_workers=args.num_workers, 53 | shuffle=False 54 | ) 55 | counter = 0 56 | for b, (xbatch, ybatch) in enumerate(loader): 57 | # data is in [-1, 1] so rescale first 58 | xbatch = xbatch*0.5 + 0.5 59 | ybatch = ybatch*0.5 + 0.5 60 | for j in range(len(xbatch)): 61 | fig = get_figure( 62 | to_tf_format(xbatch[j]), to_tf_format(ybatch[j]), 63 | idx=counter 64 | ) 65 | fig.savefig("{}/{}.png".format(args.outdir, str(counter).zfill(7))) 66 | plt.close(fig) 67 | counter += 1 68 | print("processed: {} frames".format((b+1)*args.batch_size)) 69 | 70 | if __name__ == '__main__': 71 | 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--outdir", type=str, required=True) 74 | parser.add_argument("--resolution", type=int, default=128) 75 | parser.add_argument("--dataset_class", type=str, 76 | default="training.datasets.NSDataset") 77 | parser.add_argument("--dataset_kwargs", type=str, 78 | default="{}", 79 | help="JSON string for extra args to pass to dataset.") 80 | parser.add_argument( 81 | "--path", 82 | type=str, 83 | default=None, 84 | help="Path to the dataset. Defaults to $DATA_DIR if not set." 85 | ) 86 | parser.add_argument("--split", type=str, choices=['train', 'test'], 87 | default='train') 88 | parser.add_argument("--batch_size", type=int, default=128) 89 | parser.add_argument("--num_workers", type=int, default=4) 90 | args = parser.parse_args() 91 | 92 | run(args) -------------------------------------------------------------------------------- /scripts/generate_frames_from_samples.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import os 6 | 7 | from typing import List, Tuple 8 | 9 | def get_figure(gen: torch.FloatTensor, 10 | u: torch.FloatTensor, 11 | y: torch.FloatTensor, 12 | idx: int, 13 | rect: List = [0, 0.03, 1, 0.93], 14 | figsize: Tuple = (12,4)): 15 | """ 16 | Args: 17 | gen: shape (n_seeds, h, w) (batch and channel already indexed into) 18 | xu: (h, w) (batch and channel already indexed into) 19 | y: (h, w) (batch and channel already indexed into) 20 | """ 21 | n_seeds = gen.size(0) 22 | fig, axs = plt.subplots(n_seeds, 3, figsize=figsize) 23 | fig.tight_layout(rect=rect) 24 | axs = axs.flat 25 | fig.suptitle("t = {}".format(idx)) 26 | #x = x.transpose(0,1).transpose(1,2) # convert to TF format 27 | for s in range(n_seeds): 28 | axs[s*3].imshow(u) 29 | axs[s*3 + 1].imshow(gen[s]) 30 | axs[s*3 + 2].imshow(y) 31 | if s == 0: 32 | axs[s*3].set_title("$\\boldsymbol{u}_t$") 33 | axs[s*3 + 1].set_title("$\\tilde{\\boldsymbol{u}}_{t}$") 34 | axs[s*3 + 2].set_title("$\\boldsymbol{y}_t$") 35 | return fig 36 | 37 | def run(args): 38 | 39 | if len(args.figsize) != 2: 40 | raise ValueError("figsize must be a two-tuple, received: {}".format(args.figsize)) 41 | 42 | sample_dict = torch.load(args.samples) 43 | real = sample_dict['x'] 44 | y = sample_dict['y'] 45 | gen = sample_dict['gen'] 46 | 47 | print("real shape = {} (t, nc, h, w)".format(real.shape)) 48 | print("y shape = {} (t, ws, nc, h, w)".format(y.shape)) 49 | print("gen shape = {} (t, n_repeat, nc, h, w)".format(gen.shape)) 50 | 51 | if not os.path.exists(args.outdir): 52 | print("{} does not exist, creating...".format(args.outdir)) 53 | os.makedirs(args.outdir) 54 | 55 | """ 56 | n_seeds = gen.size(1) 57 | if args.seed > n_seeds-1: 58 | raise ValueError("Only {} seeds detected in `gen`, yet seed={}".format( 59 | n_seeds, args.seed 
60 | )) 61 | """ 62 | # index into this seed 63 | #gen = gen[:, args.seed] 64 | 65 | ch_u, ch_y = args.ch_u, args.ch_y 66 | 67 | counter = 0 68 | for j in range(len(real)): 69 | fig = get_figure( 70 | u=real[j, ch_u], 71 | gen=gen[j, :, ch_u], 72 | # for y, just index into the midpoint of the window 73 | y=y[j, (y.size(1)-1)//2, ch_y], 74 | idx=counter, 75 | figsize=args.figsize 76 | ) 77 | fig.savefig("{}/{}.png".format(args.outdir, str(counter).zfill(7))) 78 | plt.close(fig) 79 | counter += 1 80 | 81 | if __name__ == '__main__': 82 | 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("--outdir", type=str, required=True) 85 | parser.add_argument("--ch_u", type=int, default=0, 86 | help="Which index of the channel axis do we want to viz for x?") 87 | parser.add_argument("--ch_y", type=int, default=0, 88 | help="Which index of the channel axis do we want to viz for y?") 89 | parser.add_argument("--figsize", nargs="+", type=int, default=[12,8]) 90 | parser.add_argument("--seed", type=int, default=0, 91 | help="What seed do we index into for the `gen` tensor?") 92 | parser.add_argument( 93 | "--samples", 94 | type=str, 95 | required=True, 96 | help="Path to the samples.pt file. E.g. ///samples/samples.pt" 97 | ) 98 | args = parser.parse_args() 99 | 100 | run(args) -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # empty 2 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from . import training_stats 4 | 5 | #---------------------------------------------------------------------------- 6 | 7 | def init(): 8 | if 'MASTER_ADDR' not in os.environ: 9 | os.environ['MASTER_ADDR'] = 'localhost' 10 | if 'MASTER_PORT' not in os.environ: 11 | os.environ['MASTER_PORT'] = '29500' 12 | if 'RANK' not in os.environ: 13 | os.environ['RANK'] = '0' 14 | if 'LOCAL_RANK' not in os.environ: 15 | os.environ['LOCAL_RANK'] = '0' 16 | if 'WORLD_SIZE' not in os.environ: 17 | os.environ['WORLD_SIZE'] = '1' 18 | 19 | backend = 'gloo' if os.name == 'nt' else 'nccl' 20 | torch.distributed.init_process_group(backend=backend, init_method='env://') 21 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 22 | 23 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 24 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def get_rank(): 29 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def get_world_size(): 34 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def should_stop(): 39 | return False 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def update_progress(cur, total): 44 | _ = cur, total 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | def print0(*args, **kwargs): 49 | if get_rank() == 0: 50 | print(*args, **kwargs) 51 | 52 | 
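# Example usage (a sketch): init() fills in sensible single-process defaults
# for the env variables above, so the same code runs under torchrun or plain
# `python`:
#
#   from torch_utils import distributed as dist
#   dist.init()               # falls back to world_size=1 when not under torchrun
#   dist.print0("hello")      # printed once, by rank 0 only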
#---------------------------------------------------------------------------- 53 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import contextlib 3 | import numpy as np 4 | import torch 5 | import warnings 6 | import dnnlib 7 | 8 | #---------------------------------------------------------------------------- 9 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 10 | # same constant is used multiple times. 11 | 12 | _constant_cache = dict() 13 | 14 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 15 | value = np.asarray(value) 16 | if shape is not None: 17 | shape = tuple(shape) 18 | if dtype is None: 19 | dtype = torch.get_default_dtype() 20 | if device is None: 21 | device = torch.device('cpu') 22 | if memory_format is None: 23 | memory_format = torch.contiguous_format 24 | 25 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 26 | tensor = _constant_cache.get(key, None) 27 | if tensor is None: 28 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 29 | if shape is not None: 30 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 31 | tensor = tensor.contiguous(memory_format=memory_format) 32 | _constant_cache[key] = tensor 33 | return tensor 34 | 35 | #---------------------------------------------------------------------------- 36 | # Replace NaN/Inf with specified numerical values. 37 | 38 | try: 39 | nan_to_num = torch.nan_to_num # 1.8.0a0 40 | except AttributeError: 41 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 42 | assert isinstance(input, torch.Tensor) 43 | if posinf is None: 44 | posinf = torch.finfo(input.dtype).max 45 | if neginf is None: 46 | neginf = torch.finfo(input.dtype).min 47 | assert nan == 0 48 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 49 | 50 | #---------------------------------------------------------------------------- 51 | # Symbolic assert. 52 | 53 | try: 54 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 55 | except AttributeError: 56 | symbolic_assert = torch.Assert # 1.7.0 57 | 58 | #---------------------------------------------------------------------------- 59 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 60 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 61 | 62 | @contextlib.contextmanager 63 | def suppress_tracer_warnings(): 64 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 65 | warnings.filters.insert(0, flt) 66 | yield 67 | warnings.filters.remove(flt) 68 | 69 | #---------------------------------------------------------------------------- 70 | # Assert that the shape of a tensor matches the given list of integers. 71 | # None indicates that the size of a dimension is allowed to vary. 72 | # Performs symbolic assertion when used in torch.jit.trace(). 
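# Example (a sketch; callers typically do `from torch_utils import misc`):
#
#   misc.assert_shape(x, [None, 3, 64, 64])  # any batch size, 3-channel 64x64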
73 | 74 | def assert_shape(tensor, ref_shape): 75 | if tensor.ndim != len(ref_shape): 76 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 77 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 78 | if ref_size is None: 79 | pass 80 | elif isinstance(ref_size, torch.Tensor): 81 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 82 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 83 | elif isinstance(size, torch.Tensor): 84 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 85 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 86 | elif size != ref_size: 87 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 88 | 89 | #---------------------------------------------------------------------------- 90 | # Function decorator that calls torch.autograd.profiler.record_function(). 91 | 92 | def profiled_function(fn): 93 | def decorator(*args, **kwargs): 94 | with torch.autograd.profiler.record_function(fn.__name__): 95 | return fn(*args, **kwargs) 96 | decorator.__name__ = fn.__name__ 97 | return decorator 98 | 99 | #---------------------------------------------------------------------------- 100 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 101 | # indefinitely, shuffling items as it goes. 102 | 103 | class InfiniteSampler(torch.utils.data.Sampler): 104 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 105 | assert len(dataset) > 0 106 | assert num_replicas > 0 107 | assert 0 <= rank < num_replicas 108 | assert 0 <= window_size <= 1 109 | super().__init__(dataset) 110 | self.dataset = dataset 111 | self.rank = rank 112 | self.num_replicas = num_replicas 113 | self.shuffle = shuffle 114 | self.seed = seed 115 | self.window_size = window_size 116 | 117 | def __iter__(self): 118 | order = np.arange(len(self.dataset)) 119 | rnd = None 120 | window = 0 121 | if self.shuffle: 122 | rnd = np.random.RandomState(self.seed) 123 | rnd.shuffle(order) 124 | window = int(np.rint(order.size * self.window_size)) 125 | 126 | idx = 0 127 | while True: 128 | i = idx % order.size 129 | if idx % self.num_replicas == self.rank: 130 | yield order[i] 131 | if window >= 2: 132 | j = (i - rnd.randint(window)) % order.size 133 | order[i], order[j] = order[j], order[i] 134 | idx += 1 135 | 136 | #---------------------------------------------------------------------------- 137 | # Utilities for operating with torch.nn.Module parameters and buffers. 
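# Example (a sketch of typical usage, e.g. initialising an EMA copy of a
# network; `net` and the `copy` module are assumed from the caller's context):
#
#   ema = copy.deepcopy(net).eval().requires_grad_(False)
#   copy_params_and_buffers(src_module=net, dst_module=ema, require_all=True)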
138 | 139 | def params_and_buffers(module): 140 | assert isinstance(module, torch.nn.Module) 141 | return list(module.parameters()) + list(module.buffers()) 142 | 143 | def named_params_and_buffers(module): 144 | assert isinstance(module, torch.nn.Module) 145 | return list(module.named_parameters()) + list(module.named_buffers()) 146 | 147 | @torch.no_grad() 148 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 149 | assert isinstance(src_module, torch.nn.Module) 150 | assert isinstance(dst_module, torch.nn.Module) 151 | src_tensors = dict(named_params_and_buffers(src_module)) 152 | for name, tensor in named_params_and_buffers(dst_module): 153 | assert (name in src_tensors) or (not require_all) 154 | if name in src_tensors: 155 | tensor.copy_(src_tensors[name]) 156 | 157 | #---------------------------------------------------------------------------- 158 | # Context manager for easily enabling/disabling DistributedDataParallel 159 | # synchronization. 160 | 161 | @contextlib.contextmanager 162 | def ddp_sync(module, sync): 163 | assert isinstance(module, torch.nn.Module) 164 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 165 | yield 166 | else: 167 | with module.no_sync(): 168 | yield 169 | 170 | #---------------------------------------------------------------------------- 171 | # Check DistributedDataParallel consistency across processes. 172 | 173 | def check_ddp_consistency(module, ignore_regex=None): 174 | assert isinstance(module, torch.nn.Module) 175 | for name, tensor in named_params_and_buffers(module): 176 | fullname = type(module).__name__ + '.' + name 177 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 178 | continue 179 | tensor = tensor.detach() 180 | if tensor.is_floating_point(): 181 | tensor = nan_to_num(tensor) 182 | other = tensor.clone() 183 | torch.distributed.broadcast(tensor=other, src=0) 184 | assert (tensor == other).all(), fullname 185 | 186 | #---------------------------------------------------------------------------- 187 | # Print summary table of module hierarchy. 188 | 189 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 190 | assert isinstance(module, torch.nn.Module) 191 | assert not isinstance(module, torch.jit.ScriptModule) 192 | assert isinstance(inputs, (tuple, list)) 193 | 194 | # Register hooks. 195 | entries = [] 196 | nesting = [0] 197 | def pre_hook(_mod, _inputs): 198 | nesting[0] += 1 199 | def post_hook(mod, _inputs, outputs): 200 | nesting[0] -= 1 201 | if nesting[0] <= max_nesting: 202 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 203 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 204 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 205 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 206 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 207 | 208 | # Run module. 209 | outputs = module(*inputs) 210 | for hook in hooks: 211 | hook.remove() 212 | 213 | # Identify unique outputs, parameters, and buffers. 
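    # Hooks fire in post-order (submodules finish before their parents), so a
    # tensor shared between a module and its parent is attributed to the
    # innermost module that reported it first.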
214 | tensors_seen = set() 215 | for e in entries: 216 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 217 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 218 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 219 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 220 | 221 | # Filter out redundant entries. 222 | if skip_redundant: 223 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 224 | 225 | # Construct table. 226 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 227 | rows += [['---'] * len(rows[0])] 228 | param_total = 0 229 | buffer_total = 0 230 | submodule_names = {mod: name for name, mod in module.named_modules()} 231 | for e in entries: 232 | name = '' if e.mod is module else submodule_names[e.mod] 233 | param_size = sum(t.numel() for t in e.unique_params) 234 | buffer_size = sum(t.numel() for t in e.unique_buffers) 235 | output_shapes = [str(list(t.shape)) for t in e.outputs] 236 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 237 | rows += [[ 238 | name + (':0' if len(e.outputs) >= 2 else ''), 239 | str(param_size) if param_size else '-', 240 | str(buffer_size) if buffer_size else '-', 241 | (output_shapes + ['-'])[0], 242 | (output_dtypes + ['-'])[0], 243 | ]] 244 | for idx in range(1, len(e.outputs)): 245 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 246 | param_total += param_size 247 | buffer_total += buffer_size 248 | rows += [['---'] * len(rows[0])] 249 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 250 | 251 | # Print table. 252 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 253 | print() 254 | for row in rows: 255 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 256 | print() 257 | return outputs 258 | 259 | #---------------------------------------------------------------------------- 260 | 261 | def count_parameters(model): 262 | """Count number of both learnable and total parameters for a module""" 263 | learnable_parameters = filter(lambda p: p.requires_grad, model.parameters()) 264 | num_learned_params = sum([np.prod(p.size()) for p in learnable_parameters]) 265 | num_params = sum([np.prod(p.size()) for p in model.parameters()]) 266 | return num_learned_params, num_params -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | """Facilities for pickling Python code alongside other data. 2 | 3 | The pickled code is automatically imported into a separate Python module 4 | during unpickling. This way, any previously exported pickles will remain 5 | usable even if the original code is no longer available, or if the current 6 | version of the code is not consistent with what was originally pickled.""" 7 | 8 | import sys 9 | import pickle 10 | import io 11 | import inspect 12 | import copy 13 | import uuid 14 | import types 15 | import dnnlib 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | _version = 6 # internal version number 20 | _decorators = set() # {decorator_class, ...} 21 | _import_hooks = [] # [hook_function, ...] 
22 | _module_to_src_dict = dict() # {module: src, ...} 23 | _src_to_module_dict = dict() # {src: module, ...} 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def persistent_class(orig_class): 28 | r"""Class decorator that extends a given class to save its source code 29 | when pickled. 30 | 31 | Example: 32 | 33 | from torch_utils import persistence 34 | 35 | @persistence.persistent_class 36 | class MyNetwork(torch.nn.Module): 37 | def __init__(self, num_inputs, num_outputs): 38 | super().__init__() 39 | self.fc = MyLayer(num_inputs, num_outputs) 40 | ... 41 | 42 | @persistence.persistent_class 43 | class MyLayer(torch.nn.Module): 44 | ... 45 | 46 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 47 | source code alongside other internal state (e.g., parameters, buffers, 48 | and submodules). This way, any previously exported pickle will remain 49 | usable even if the class definitions have been modified or are no 50 | longer available. 51 | 52 | The decorator saves the source code of the entire Python module 53 | containing the decorated class. It does *not* save the source code of 54 | any imported modules. Thus, the imported modules must be available 55 | during unpickling, also including `torch_utils.persistence` itself. 56 | 57 | It is ok to call functions defined in the same module from the 58 | decorated class. However, if the decorated class depends on other 59 | classes defined in the same module, they must be decorated as well. 60 | This is illustrated in the above example in the case of `MyLayer`. 61 | 62 | It is also possible to employ the decorator just-in-time before 63 | calling the constructor. For example: 64 | 65 | cls = MyLayer 66 | if want_to_make_it_persistent: 67 | cls = persistence.persistent_class(cls) 68 | layer = cls(num_inputs, num_outputs) 69 | 70 | As an additional feature, the decorator also keeps track of the 71 | arguments that were used to construct each instance of the decorated 72 | class. The arguments can be queried via `obj.init_args` and 73 | `obj.init_kwargs`, and they are automatically pickled alongside other 74 | object state. This feature can be disabled on a per-instance basis 75 | by setting `self._record_init_args = False` in the constructor. 
76 | 77 | A typical use case is to first unpickle a previous instance of a 78 | persistent class, and then upgrade it to use the latest version of 79 | the source code: 80 | 81 | with open('old_pickle.pkl', 'rb') as f: 82 | old_net = pickle.load(f) 83 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 84 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 85 | """ 86 | assert isinstance(orig_class, type) 87 | if is_persistent(orig_class): 88 | return orig_class 89 | 90 | assert orig_class.__module__ in sys.modules 91 | orig_module = sys.modules[orig_class.__module__] 92 | orig_module_src = _module_to_src(orig_module) 93 | 94 | class Decorator(orig_class): 95 | _orig_module_src = orig_module_src 96 | _orig_class_name = orig_class.__name__ 97 | 98 | def __init__(self, *args, **kwargs): 99 | super().__init__(*args, **kwargs) 100 | record_init_args = getattr(self, '_record_init_args', True) 101 | self._init_args = copy.deepcopy(args) if record_init_args else None 102 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 103 | assert orig_class.__name__ in orig_module.__dict__ 104 | _check_pickleable(self.__reduce__()) 105 | 106 | @property 107 | def init_args(self): 108 | assert self._init_args is not None 109 | return copy.deepcopy(self._init_args) 110 | 111 | @property 112 | def init_kwargs(self): 113 | assert self._init_kwargs is not None 114 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 115 | 116 | def __reduce__(self): 117 | fields = list(super().__reduce__()) 118 | fields += [None] * max(3 - len(fields), 0) 119 | if fields[0] is not _reconstruct_persistent_obj: 120 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 121 | fields[0] = _reconstruct_persistent_obj # reconstruct func 122 | fields[1] = (meta,) # reconstruct args 123 | fields[2] = None # state dict 124 | return tuple(fields) 125 | 126 | Decorator.__name__ = orig_class.__name__ 127 | Decorator.__module__ = orig_class.__module__ 128 | _decorators.add(Decorator) 129 | return Decorator 130 | 131 | #---------------------------------------------------------------------------- 132 | 133 | def is_persistent(obj): 134 | r"""Test whether the given object or class is persistent, i.e., 135 | whether it will save its source code when pickled. 136 | """ 137 | try: 138 | if obj in _decorators: 139 | return True 140 | except TypeError: 141 | pass 142 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 143 | 144 | #---------------------------------------------------------------------------- 145 | 146 | def import_hook(hook): 147 | r"""Register an import hook that is called whenever a persistent object 148 | is being unpickled. A typical use case is to patch the pickled source 149 | code to avoid errors and inconsistencies when the API of some imported 150 | module has changed. 151 | 152 | The hook should have the following signature: 153 | 154 | hook(meta) -> modified meta 155 | 156 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 157 | 158 | type: Type of the persistent object, e.g. `'class'`. 159 | version: Internal version number of `torch_utils.persistence`. 160 | module_src Original source code of the Python module. 161 | class_name: Class name in the original Python module. 162 | state: Internal state of the object. 
163 | 164 | Example: 165 | 166 | @persistence.import_hook 167 | def wreck_my_network(meta): 168 | if meta.class_name == 'MyNetwork': 169 | print('MyNetwork is being imported. I will wreck it!') 170 | meta.module_src = meta.module_src.replace("True", "False") 171 | return meta 172 | """ 173 | assert callable(hook) 174 | _import_hooks.append(hook) 175 | 176 | #---------------------------------------------------------------------------- 177 | 178 | def _reconstruct_persistent_obj(meta): 179 | r"""Hook that is called internally by the `pickle` module to unpickle 180 | a persistent object. 181 | """ 182 | meta = dnnlib.EasyDict(meta) 183 | meta.state = dnnlib.EasyDict(meta.state) 184 | for hook in _import_hooks: 185 | meta = hook(meta) 186 | assert meta is not None 187 | 188 | assert meta.version == _version 189 | module = _src_to_module(meta.module_src) 190 | 191 | assert meta.type == 'class' 192 | orig_class = module.__dict__[meta.class_name] 193 | decorator_class = persistent_class(orig_class) 194 | obj = decorator_class.__new__(decorator_class) 195 | 196 | setstate = getattr(obj, '__setstate__', None) 197 | if callable(setstate): 198 | setstate(meta.state) # pylint: disable=not-callable 199 | else: 200 | obj.__dict__.update(meta.state) 201 | return obj 202 | 203 | #---------------------------------------------------------------------------- 204 | 205 | def _module_to_src(module): 206 | r"""Query the source code of a given Python module. 207 | """ 208 | src = _module_to_src_dict.get(module, None) 209 | if src is None: 210 | src = inspect.getsource(module) 211 | _module_to_src_dict[module] = src 212 | _src_to_module_dict[src] = module 213 | return src 214 | 215 | def _src_to_module(src): 216 | r"""Get or create a Python module for the given source code. 217 | """ 218 | module = _src_to_module_dict.get(src, None) 219 | if module is None: 220 | module_name = "_imported_module_" + uuid.uuid4().hex 221 | module = types.ModuleType(module_name) 222 | sys.modules[module_name] = module 223 | _module_to_src_dict[module] = src 224 | _src_to_module_dict[src] = module 225 | exec(src, module.__dict__) # pylint: disable=exec-used 226 | return module 227 | 228 | #---------------------------------------------------------------------------- 229 | 230 | def _check_pickleable(obj): 231 | r"""Check that the given object is pickleable, raising an exception if 232 | it is not. This function is expected to be considerably more efficient 233 | than actually pickling the object. 234 | """ 235 | def recurse(obj): 236 | if isinstance(obj, (list, tuple, set)): 237 | return [recurse(x) for x in obj] 238 | if isinstance(obj, dict): 239 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 240 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 241 | return None # Python primitive types are pickleable. 242 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 243 | return None # NumPy arrays and PyTorch tensors are pickleable. 244 | if is_persistent(obj): 245 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
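        # Anything else falls through and is pickled for real by the
        # BytesIO dump below, so unpickleable leaves raise an exception here.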
246 | return obj 247 | with io.BytesIO() as f: 248 | pickle.dump(recurse(obj), f) 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | """Facilities for reporting and collecting training statistics across 2 | multiple processes and devices. The interface is designed to minimize 3 | synchronization overhead as well as the amount of boilerplate in user 4 | code.""" 5 | 6 | import re 7 | import numpy as np 8 | import torch 9 | import dnnlib 10 | 11 | from . import misc 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 16 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 17 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 18 | _rank = 0 # Rank of the current process. 19 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 20 | _sync_called = False # Has _sync() been called yet? 21 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 22 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def init_multiprocessing(rank, sync_device): 27 | r"""Initializes `torch_utils.training_stats` for collecting statistics 28 | across multiple processes. 29 | 30 | This function must be called after 31 | `torch.distributed.init_process_group()` and before `Collector.update()`. 32 | The call is not necessary if multi-process collection is not needed. 33 | 34 | Args: 35 | rank: Rank of the current process. 36 | sync_device: PyTorch device to use for inter-process 37 | communication, or None to disable multi-process 38 | collection. Typically `torch.device('cuda', rank)`. 39 | """ 40 | global _rank, _sync_device 41 | assert not _sync_called 42 | _rank = rank 43 | _sync_device = sync_device 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | @misc.profiled_function 48 | def report(name, value): 49 | r"""Broadcasts the given set of scalars to all interested instances of 50 | `Collector`, across device and process boundaries. 51 | 52 | This function is expected to be extremely cheap and can be safely 53 | called from anywhere in the training loop, loss function, or inside a 54 | `torch.nn.Module`. 55 | 56 | Warning: The current implementation expects the set of unique names to 57 | be consistent across processes. Please make sure that `report()` is 58 | called at least once for each unique name by each process, and in the 59 | same order. If a given process has no scalars to broadcast, it can do 60 | `report(name, [])` (empty list). 61 | 62 | Args: 63 | name: Arbitrary string specifying the name of the statistic. 64 | Averages are accumulated separately for each unique name. 65 | value: Arbitrary set of scalars. Can be a list, tuple, 66 | NumPy array, PyTorch tensor, or Python scalar. 67 | 68 | Returns: 69 | The same `value` that was passed in. 
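
    Example (a sketch of typical use inside a loss function):

        training_stats.report('Loss/loss', loss)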
70 | """ 71 | if name not in _counters: 72 | _counters[name] = dict() 73 | 74 | elems = torch.as_tensor(value) 75 | if elems.numel() == 0: 76 | return value 77 | 78 | elems = elems.detach().flatten().to(_reduce_dtype) 79 | moments = torch.stack([ 80 | torch.ones_like(elems).sum(), 81 | elems.sum(), 82 | elems.square().sum(), 83 | ]) 84 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 85 | moments = moments.to(_counter_dtype) 86 | 87 | device = moments.device 88 | if device not in _counters[name]: 89 | _counters[name][device] = torch.zeros_like(moments) 90 | _counters[name][device].add_(moments) 91 | return value 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | def report0(name, value): 96 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 97 | but ignores any scalars provided by the other processes. 98 | See `report()` for further details. 99 | """ 100 | report(name, value if _rank == 0 else []) 101 | return value 102 | 103 | #---------------------------------------------------------------------------- 104 | 105 | class Collector: 106 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 107 | computes their long-term averages (mean and standard deviation) over 108 | user-defined periods of time. 109 | 110 | The averages are first collected into internal counters that are not 111 | directly visible to the user. They are then copied to the user-visible 112 | state as a result of calling `update()` and can then be queried using 113 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 114 | internal counters for the next round, so that the user-visible state 115 | effectively reflects averages collected between the last two calls to 116 | `update()`. 117 | 118 | Args: 119 | regex: Regular expression defining which statistics to 120 | collect. The default is to collect everything. 121 | keep_previous: Whether to retain the previous averages if no 122 | scalars were collected on a given round 123 | (default: True). 124 | """ 125 | def __init__(self, regex='.*', keep_previous=True): 126 | self._regex = re.compile(regex) 127 | self._keep_previous = keep_previous 128 | self._cumulative = dict() 129 | self._moments = dict() 130 | self.update() 131 | self._moments.clear() 132 | 133 | def names(self): 134 | r"""Returns the names of all statistics broadcasted so far that 135 | match the regular expression specified at construction time. 136 | """ 137 | return [name for name in _counters if self._regex.fullmatch(name)] 138 | 139 | def update(self): 140 | r"""Copies current values of the internal counters to the 141 | user-visible state and resets them for the next round. 142 | 143 | If `keep_previous=True` was specified at construction time, the 144 | operation is skipped for statistics that have received no scalars 145 | since the last update, retaining their previous averages. 146 | 147 | This method performs a number of GPU-to-CPU transfers and one 148 | `torch.distributed.all_reduce()`. It is intended to be called 149 | periodically in the main training loop, typically once every 150 | N training steps. 
151 | """ 152 | if not self._keep_previous: 153 | self._moments.clear() 154 | for name, cumulative in _sync(self.names()): 155 | if name not in self._cumulative: 156 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 157 | delta = cumulative - self._cumulative[name] 158 | self._cumulative[name].copy_(cumulative) 159 | if float(delta[0]) != 0: 160 | self._moments[name] = delta 161 | 162 | def _get_delta(self, name): 163 | r"""Returns the raw moments that were accumulated for the given 164 | statistic between the last two calls to `update()`, or zero if 165 | no scalars were collected. 166 | """ 167 | assert self._regex.fullmatch(name) 168 | if name not in self._moments: 169 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 170 | return self._moments[name] 171 | 172 | def num(self, name): 173 | r"""Returns the number of scalars that were accumulated for the given 174 | statistic between the last two calls to `update()`, or zero if 175 | no scalars were collected. 176 | """ 177 | delta = self._get_delta(name) 178 | return int(delta[0]) 179 | 180 | def mean(self, name): 181 | r"""Returns the mean of the scalars that were accumulated for the 182 | given statistic between the last two calls to `update()`, or NaN if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | if int(delta[0]) == 0: 187 | return float('nan') 188 | return float(delta[1] / delta[0]) 189 | 190 | def std(self, name): 191 | r"""Returns the standard deviation of the scalars that were 192 | accumulated for the given statistic between the last two calls to 193 | `update()`, or NaN if no scalars were collected. 194 | """ 195 | delta = self._get_delta(name) 196 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 197 | return float('nan') 198 | if int(delta[0]) == 1: 199 | return float(0) 200 | mean = float(delta[1] / delta[0]) 201 | raw_var = float(delta[2] / delta[0]) 202 | return np.sqrt(max(raw_var - np.square(mean), 0)) 203 | 204 | def as_dict(self): 205 | r"""Returns the averages accumulated between the last two calls to 206 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 207 | 208 | dnnlib.EasyDict( 209 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 210 | ... 211 | ) 212 | """ 213 | stats = dnnlib.EasyDict() 214 | for name in self.names(): 215 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 216 | return stats 217 | 218 | def __getitem__(self, name): 219 | r"""Convenience getter. 220 | `collector[name]` is a synonym for `collector.mean(name)`. 221 | """ 222 | return self.mean(name) 223 | 224 | #---------------------------------------------------------------------------- 225 | 226 | def _sync(names): 227 | r"""Synchronize the global cumulative counters across devices and 228 | processes. Called internally by `Collector.update()`. 229 | """ 230 | if len(names) == 0: 231 | return [] 232 | global _sync_called 233 | _sync_called = True 234 | 235 | # Collect deltas within current rank. 236 | deltas = [] 237 | device = _sync_device if _sync_device is not None else torch.device('cpu') 238 | for name in names: 239 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 240 | for counter in _counters[name].values(): 241 | delta.add_(counter.to(device)) 242 | counter.copy_(torch.zeros_like(counter)) 243 | deltas.append(delta) 244 | deltas = torch.stack(deltas) 245 | 246 | # Sum deltas across ranks. 
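# The three accumulated moments per statistic -- [count, sum, sum-of-squares]
# -- are each additive across devices and ranks, so the single summing
# all_reduce below merges them exactly. Mean and std are only derived later
# by Collector, roughly: mean = sum / count, and
# std = sqrt(max(sum_sq / count - mean**2, 0)).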
247 | if _sync_device is not None: 248 | torch.distributed.all_reduce(deltas) 249 | 250 | # Update cumulative values. 251 | deltas = deltas.cpu() 252 | for idx, name in enumerate(names): 253 | if name not in _cumulative: 254 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 255 | _cumulative[name].add_(deltas[idx]) 256 | 257 | # Return name-value pairs. 258 | return [(name, _cumulative[name]) for name in names] 259 | 260 | #---------------------------------------------------------------------------- 261 | # Convenience. 262 | 263 | default_collector = Collector() 264 | 265 | #---------------------------------------------------------------------------- 266 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Train diffusion-based generative model using the techniques described in the 2 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" 3 | 4 | import os 5 | import re 6 | import jstyleson as json 7 | import glob 8 | import click 9 | import torch 10 | import dnnlib 11 | import pickle 12 | from torch_utils import distributed as dist 13 | from training import training_loop 14 | 15 | from omegaconf import OmegaConf as OC 16 | import dataclasses 17 | from dataclasses import asdict, dataclass, field 18 | from typing import List, Union, Tuple, Dict 19 | 20 | import warnings 21 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 22 | 23 | #---------------------------------------------------------------------------- 24 | # Parse a comma separated list of numbers or ranges and return a list of ints. 25 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 26 | 27 | def parse_int_list(s): 28 | if isinstance(s, list): return s 29 | ranges = [] 30 | range_re = re.compile(r'^(\d+)-(\d+)$') 31 | for p in s.split(','): 32 | m = range_re.match(p) 33 | if m: 34 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 35 | else: 36 | ranges.append(int(p)) 37 | return ranges 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @dataclass 42 | class Arguments: 43 | 44 | # Main options. 45 | #outdir: str # we use `savedir` in out code, see main() 46 | 47 | dataset_class: str # the class name of the dataset, e.g. training.dataset_zarr.CwaEraDataset 48 | resolution: int # spatial resolution 49 | 50 | dataset_path: Union[str,None] = None # path to the data location 51 | dataset_kwargs: dict = field(default_factory=lambda: {}) 52 | window_size: int = 0 # window size for y 53 | 54 | #cond: bool = False # type=bool, default=False, show_default=True) 55 | arch: str = 'ddpmpp' # click.Choice(['ddpmpp', 'ncsnpp', 'adm']) 56 | precond: str = 'edm' # click.Choice(['vp', 've', 'edm']) 57 | 58 | # Hyperparameters. 59 | duration: int = 200 60 | batch: int = 512 # Total batch size 61 | batch_gpu: Union[None, int] = None # limit batch size per gpu 62 | 63 | # Architecture-related. 64 | cbase: int = 64 65 | cres: List[int] = field(default_factory=lambda: [1, 2, 4, 4]) # my defaults 66 | attn: List[int] = field(default_factory=lambda: [16]) # What spatial resolution to perform self-attn 67 | num_blocks: int = 4 # Number of residual blocks per resolution 68 | rank: float = 1.0 # Rank for factorisation of weight matrices 69 | fmult: float = 1.0 # Retain this *100% of max number of Fourier modes 70 | 71 | # Noise-related. 
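    # `rbf_scale` below is forwarded to training.noise_samplers.RBFKernel as
    # its `scale` (length-scale) argument: the covariance matrix there is
    # exp(-dist / (2 * scale**2)), so a smaller scale means a shorter
    # correlation length and hence rougher, more white-noise-like Gaussian
    # random field samples.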
72 |     rbf_scale: float = 0.05             # For GRF noise, how noisy do we want it? (Less is more noisy.)
73 | 
74 |     # Training / optimiser-related.
75 |     lr: float = 10e-4                   # learning rate (10e-4 == 1e-3); type=click.FloatRange(min=0, min_open=True)
76 |     eps: float = 1e-8                   # epsilon for Adam
77 |     ema: float = 0.5                    # type=click.FloatRange(min=0), default=0.5, show_default=True)
78 |     dropout: float = 0.13               # type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
79 |     xflip: bool = False                 # type=bool, default=False, show_default=True)
80 | 
81 |     # Performance-related.
82 |     fp16: bool = False                  # type=bool, default=False, show_default=True)
83 |     ls: float = 1.0                     # type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
84 |     bench: bool = True                  # type=bool, default=True, show_default=True)
85 |     cache: bool = True                  # type=bool, default=True, show_default=True)
86 |     workers: int = 1                    # type=click.IntRange(min=1), default=1, show_default=True)
87 | 
88 |     # I/O related.
89 |     tick: int = 50                      # type=click.IntRange(min=1), default=50, show_default=True)
90 |     snap: int = 50                      # type=click.IntRange(min=1), default=50, show_default=True)
91 |     dump: int = 500                     # type=click.IntRange(min=1), default=500, show_default=True)
92 |     seed: int = 0                       # seed
93 | 
94 |     # CHRIS modification: if it is non-str and `True`, find the latest
95 |     # checkpoint snapshot in the folder and load that in.
96 |     resume: Union[str, bool] = True     # by default, we will resume the experiment
97 | 
98 |     dry_run: bool = False
99 | 
100 | def main(kwargs, outdir):
101 |     """Train diffusion-based generative model using the techniques described in the
102 |     paper "Elucidating the Design Space of Diffusion-Based Generative Models".
103 | 
104 |     Example:
105 | 
106 |     # Train the neural operator diffusion model on the Navier-Stokes
107 |     # dataset (see also exps/main.sh and exps/json/example_ns.json):
108 |     python train.py --savedir=$SAVE_DIR/my_run \
109 |         --cfg=exps/json/example_ns.json
110 |     """
111 | 
112 |     # Convert back into a regular dictionary.
113 |     kwargs = dataclasses.asdict(kwargs)
114 |     opts = dnnlib.EasyDict(kwargs)
115 | 
116 |     torch.multiprocessing.set_start_method('spawn')
117 |     dist.init()
118 | 
119 |     # Initialize config dict.
120 |     c = dnnlib.EasyDict()
121 |     dist.print0("window size: {}".format(opts.window_size))
122 | 
123 |     if opts.dataset_path is None:
124 |         print("`dataset_path` is None, so setting it to `$DATA_DIR`...")
125 |         if 'DATA_DIR' not in os.environ:
126 |             raise ValueError("DATA_DIR not set, please source env.sh")
127 |         else:
128 |             opts.dataset_path = os.environ['DATA_DIR']
129 | 
130 |     c.dataset_kwargs = dnnlib.EasyDict(
131 |         #class_name='training.dataset_zarr.CwaEraDataset',
132 |         class_name=opts.dataset_class,
133 |         path=opts.dataset_path,
134 |         resolution=opts.resolution,
135 |         train=True,
136 |         **opts.dataset_kwargs
137 |     )
138 |     c.window_size = opts.window_size
139 | 
140 |     c.data_loader_kwargs = dnnlib.EasyDict(
141 |         pin_memory=True,
142 |         num_workers=opts.workers,
143 |         prefetch_factor=2 # number of batches each worker loads in advance
144 |     )
145 |     c.network_kwargs = dnnlib.EasyDict()
146 |     c.loss_kwargs = dnnlib.EasyDict()
147 |     c.sampler_kwargs = dnnlib.EasyDict(
148 |         class_name="training.noise_samplers.RBFKernel",
149 |         Ln1=opts.resolution,
150 |         Ln2=opts.resolution,
151 |         scale=opts.rbf_scale
152 |     )
153 | 
154 |     c.optimizer_kwargs = dnnlib.EasyDict(
155 |         class_name='torch.optim.Adam',
156 |         lr=opts.lr,
157 |         betas=[0.9,0.999],
158 |         eps=opts.eps
159 |     )
160 | 
161 |     # Validate dataset options.
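    # For reference, dnnlib.util.construct_class_by_name() imports `class_name`
    # dynamically and instantiates it with the remaining kwargs, so the call
    # below behaves roughly like this sketch (assuming the example NS config):
    #
    #   from training.datasets import NSDataset
    #   dataset_obj = NSDataset(path=opts.dataset_path, resolution=opts.resolution,
    #                           train=True, **opts.dataset_kwargs)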
162 |     try:
163 |         dataset_obj = dnnlib.util.construct_class_by_name(
164 |             **c.dataset_kwargs
165 |         )
166 |         # Chris B: omit these checks, we assume conditional training
167 |         #c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
168 |         #c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
169 |         #if opts.cond and not dataset_obj.has_labels:
170 |         #    raise click.ClickException('--cond=True requires labels specified in dataset.json')
171 |         del dataset_obj # conserve memory
172 |     except IOError as err:
173 |         raise click.ClickException(f'--data: {err}')
174 | 
175 |     # Network architecture.
176 |     if opts.arch == 'ddpmpp': # by default we use this
177 |         c.network_kwargs.update(
178 |             model_type='SongUNO', # was SongUNet previously, now UNO
179 |             embedding_type='positional',
180 |             encoder_type='standard',
181 |             decoder_type='standard',
182 |         )
183 |         c.network_kwargs.update(
184 |             channel_mult_noise=1,
185 |             resample_filter=[1,1],
186 |             model_channels=128,
187 |             channel_mult=[2,2,2]
188 |         )
189 |     elif opts.arch == 'ncsnpp':
190 |         c.network_kwargs.update(
191 |             model_type='SongUNO',
192 |             embedding_type='fourier',
193 |             encoder_type='residual',
194 |             decoder_type='standard'
195 |         )
196 |         c.network_kwargs.update(
197 |             channel_mult_noise=2,
198 |             resample_filter=[1,3,3,1],
199 |             model_channels=128,
200 |             channel_mult=[2,2,2]
201 |         )
202 |     else:
203 |         assert opts.arch == 'adm'
204 |         c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
205 | 
206 |     # Preconditioning & loss function.
207 |     PRECOND_VALUES = ['vp', 've', 'edm', 'recon']
208 |     if opts.precond == 'vp':
209 |         c.network_kwargs.class_name = 'training.networks.VPPrecond'
210 |         c.loss_kwargs.class_name = 'training.loss.VPLoss'
211 |     elif opts.precond == 've':
212 |         c.network_kwargs.class_name = 'training.networks.VEPrecond'
213 |         c.loss_kwargs.class_name = 'training.loss.VELoss'
214 |     elif opts.precond == 'edm':
215 |         c.network_kwargs.class_name = 'training.networks.EDMPrecond'
216 |         c.loss_kwargs.class_name = 'training.loss.EDMLoss'
217 |     elif opts.precond == 'recon':
218 |         # This is only used to train a deterministic autoencoder.
219 |         c.network_kwargs.class_name = 'training.networks.EDMPrecond'
220 |         c.loss_kwargs.class_name = 'training.loss.ReconLoss'
221 |     else:
222 |         raise ValueError("precond must be one of: {}".format(PRECOND_VALUES))
223 | 
224 |     # Network options.
225 |     if opts.cbase is not None:
226 |         c.network_kwargs.model_channels = opts.cbase
227 |     if opts.cres is not None:
228 |         c.network_kwargs.channel_mult = opts.cres
229 |     if opts.attn is not None:
230 |         c.network_kwargs.attn_resolutions = opts.attn
231 |     c.network_kwargs.update(
232 |         num_blocks=opts.num_blocks,
233 |         rank=opts.rank,
234 |         fmult=opts.fmult,
235 |         dropout=opts.dropout,
236 |         use_fp16=opts.fp16
237 |     )
238 |     if opts.precond == 'recon':
239 |         # Do not use skip connections if we're training this
240 |         # as an autoencoder.
241 |         c.network_kwargs.update(disable_skip=True)
242 | 
243 |     # Training options.
244 |     c.total_kimg = max(int(opts.duration * 1000), 1)
245 |     c.ema_halflife_kimg = int(opts.ema * 1000)
246 |     c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
247 |     c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
248 |     c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
249 | 
250 |     # Random seed.
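    # When no seed is given, rank 0 draws one and broadcasts it so that all
    # ranks agree on the global seed; with the dataclass default of seed=0,
    # the first branch below is the one normally taken.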
251 |     if opts.seed is not None:
252 |         c.seed = opts.seed
253 |     else:
254 |         seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
255 |         torch.distributed.broadcast(seed, src=0)
256 |         c.seed = int(seed)
257 | 
258 |     # Transfer learning and resume.
259 |     # CHRIS B: not sure I need this feature so I'll comment it out
260 |     """
261 |     if opts.transfer is not None:
262 |         if opts.resume is not None:
263 |             raise click.ClickException('--transfer and --resume cannot be specified at the same time')
264 |         c.resume_pkl = opts.transfer
265 |         c.ema_rampup_ratio = None
266 |     """
267 |     if opts.resume is not None:
268 |         if type(opts.resume) is str:
269 |             raise NotImplementedError()
270 |         else:
271 |             # Find the network snapshot file, if any (only a single network-snapshot.pkl is ever kept; see training_loop.py).
272 |             snapshots = sorted(
273 |                 glob.glob("{}/network-snapshot.pkl".format(outdir))
274 |             )
275 |             if len(snapshots) != 0:
276 |                 latest_snapshot = snapshots[-1]
277 |                 dist.print0("Found snapshot: {} ...".format(latest_snapshot))
278 |                 c.resume_pkl = latest_snapshot
279 |                 # HACK: we actually have to open the pkl here to
280 |                 # get the epoch number.
281 |                 with dnnlib.util.open_url(c.resume_pkl, verbose=(dist.get_rank() == 0)) as f:
282 |                     c.resume_kimg = pickle.load(f)['cur_nimg'] // 1000
283 | 
284 |     #c.resume_state_dump = opts.resume
285 |     c.resume_state_dump = None # keep things simple for now
286 | 
287 |     # Pick output directory. Unlike upstream EDM, every rank gets the
288 |     # same run_dir here; only rank 0 actually creates it and writes
289 |     # files into it (see below).
290 |     c.run_dir = outdir
291 | 
292 |     # Print options.
293 |     dist.print0()
294 |     dist.print0('Training options:')
295 |     dist.print0(json.dumps(c, indent=2))
296 |     dist.print0()
297 |     dist.print0(f'Output directory: {c.run_dir}')
298 |     dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
299 |     #dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
300 |     dist.print0(f'Network architecture: {opts.arch}')
301 |     dist.print0(f'Preconditioning & loss: {opts.precond}')
302 |     dist.print0(f'Number of GPUs: {dist.get_world_size()}')
303 |     dist.print0(f'Batch size: {c.batch_size}')
304 |     dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
305 |     dist.print0()
306 | 
307 |     # Dry run?
308 |     if opts.dry_run:
309 |         dist.print0('Dry run; exiting.')
310 |         return
311 | 
312 |     # Create output directory.
313 |     dist.print0('Creating output directory...')
314 |     if dist.get_rank() == 0:
315 |         os.makedirs(c.run_dir, exist_ok=True)
316 |         with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
317 |             json.dump(c, f, indent=2)
318 |         dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
319 | 
320 |     # Train.
321 |     training_loop.training_loop(**c)
322 | 
323 | #----------------------------------------------------------------------------
324 | 
325 | import argparse
326 | 
327 | 
328 | def parse_args():
329 |     parser = argparse.ArgumentParser(description="")
330 |     # parser.add_argument('--datadir', type=str, default="")
331 |     parser.add_argument("--savedir", type=str, required=True)
332 |     parser.add_argument("--cfg", type=str, required=True)
333 |     parser.add_argument(
334 |         "--override_cfg",
335 |         action="store_true",
336 |         help="By default, if a config.json already exists in savedir it is loaded "
337 |         + "instead of args.cfg, so that Slurm restarts resume with the original "
338 |         + "options. Set this flag to ignore the saved config and use args.cfg.",
339 |     )
340 |     args = parser.parse_args()
341 |     return args
342 | 
343 | if __name__ == "__main__":
344 | 
345 |     args = parse_args()
346 | 
347 |     saved_cfg_file = os.path.join(args.savedir, "config.json")
348 |     if os.path.exists(saved_cfg_file) and not args.override_cfg:
349 |         cfg_file = json.loads(open(saved_cfg_file, "r").read())
350 |         print("Found config in exp dir, loading it instead of --cfg...")
351 |     else:
352 |         cfg_file = json.loads(open(args.cfg, "r").read())
353 | 
354 |     # structured() allows type checking
355 |     conf = OC.structured(Arguments(**cfg_file))
356 | 
357 |     # Since type checking is already done, convert
358 |     # it back into a (dot-accessible) dictionary.
359 |     # (OC.to_object() returns an Arguments object)
360 |     main(OC.to_object(conf), args.savedir)
361 | 
362 | #----------------------------------------------------------------------------
363 | 
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # empty
2 | 
--------------------------------------------------------------------------------
/training/dataset_edm.py:
--------------------------------------------------------------------------------
1 | """Streaming images and labels from datasets created with dataset_tool.py."""
2 | 
3 | import os
4 | import numpy as np
5 | import zipfile
6 | import PIL.Image
7 | import json
8 | import torch
9 | import dnnlib
10 | 
11 | try:
12 |     import pyspng
13 | except ImportError:
14 |     pyspng = None
15 | 
16 | #----------------------------------------------------------------------------
17 | # Abstract base class for datasets.
18 | 
19 | class Dataset(torch.utils.data.Dataset):
20 |     def __init__(self,
21 |         name,                   # Name of the dataset.
22 |         raw_shape,              # Shape of the raw image data (NCHW).
23 |         max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
24 |         use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
25 |         xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
26 |         random_seed = 0,        # Random seed to use when applying max_size.
27 |         cache       = False,    # Cache images in CPU memory?
28 |     ):
29 |         self._name = name
30 |         self._raw_shape = list(raw_shape)
31 |         self._use_labels = use_labels
32 |         self._cache = cache
33 |         self._cached_images = dict() # {raw_idx: np.ndarray, ...}
34 |         self._raw_labels = None
35 |         self._label_shape = None
36 | 
37 |         # Apply max_size.
38 |         self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
39 |         if (max_size is not None) and (self._raw_idx.size > max_size):
40 |             np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
41 |             self._raw_idx = np.sort(self._raw_idx[:max_size])
42 | 
43 |         # Apply xflip.
44 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 45 | if xflip: 46 | self._raw_idx = np.tile(self._raw_idx, 2) 47 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 48 | 49 | def _get_raw_labels(self): 50 | if self._raw_labels is None: 51 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 52 | if self._raw_labels is None: 53 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 54 | assert isinstance(self._raw_labels, np.ndarray) 55 | assert self._raw_labels.shape[0] == self._raw_shape[0] 56 | assert self._raw_labels.dtype in [np.float32, np.int64] 57 | if self._raw_labels.dtype == np.int64: 58 | assert self._raw_labels.ndim == 1 59 | assert np.all(self._raw_labels >= 0) 60 | return self._raw_labels 61 | 62 | def close(self): # to be overridden by subclass 63 | pass 64 | 65 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 66 | raise NotImplementedError 67 | 68 | def _load_raw_labels(self): # to be overridden by subclass 69 | raise NotImplementedError 70 | 71 | def __getstate__(self): 72 | return dict(self.__dict__, _raw_labels=None) 73 | 74 | def __del__(self): 75 | try: 76 | self.close() 77 | except: 78 | pass 79 | 80 | def __len__(self): 81 | return self._raw_idx.size 82 | 83 | def __getitem__(self, idx): 84 | raw_idx = self._raw_idx[idx] 85 | image = self._cached_images.get(raw_idx, None) 86 | if image is None: 87 | image = self._load_raw_image(raw_idx) 88 | if self._cache: 89 | self._cached_images[raw_idx] = image 90 | assert isinstance(image, np.ndarray) 91 | assert list(image.shape) == self.image_shape 92 | assert image.dtype == np.uint8 93 | if self._xflip[idx]: 94 | assert image.ndim == 3 # CHW 95 | image = image[:, :, ::-1] 96 | return image.copy(), self.get_label(idx) 97 | 98 | def get_label(self, idx): 99 | label = self._get_raw_labels()[self._raw_idx[idx]] 100 | if label.dtype == np.int64: 101 | onehot = np.zeros(self.label_shape, dtype=np.float32) 102 | onehot[label] = 1 103 | label = onehot 104 | return label.copy() 105 | 106 | def get_details(self, idx): 107 | d = dnnlib.EasyDict() 108 | d.raw_idx = int(self._raw_idx[idx]) 109 | d.xflip = (int(self._xflip[idx]) != 0) 110 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 111 | return d 112 | 113 | @property 114 | def name(self): 115 | return self._name 116 | 117 | @property 118 | def image_shape(self): 119 | return list(self._raw_shape[1:]) 120 | 121 | @property 122 | def num_channels(self): 123 | assert len(self.image_shape) == 3 # CHW 124 | return self.image_shape[0] 125 | 126 | @property 127 | def resolution(self): 128 | assert len(self.image_shape) == 3 # CHW 129 | assert self.image_shape[1] == self.image_shape[2] 130 | return self.image_shape[1] 131 | 132 | @property 133 | def label_shape(self): 134 | if self._label_shape is None: 135 | raw_labels = self._get_raw_labels() 136 | if raw_labels.dtype == np.int64: 137 | self._label_shape = [int(np.max(raw_labels)) + 1] 138 | else: 139 | self._label_shape = raw_labels.shape[1:] 140 | return list(self._label_shape) 141 | 142 | @property 143 | def label_dim(self): 144 | assert len(self.label_shape) == 1 145 | return self.label_shape[0] 146 | 147 | @property 148 | def has_labels(self): 149 | return any(x != 0 for x in self.label_shape) 150 | 151 | @property 152 | def has_onehot_labels(self): 153 | return self._get_raw_labels().dtype == np.int64 154 | 155 | #---------------------------------------------------------------------------- 156 | # Dataset subclass that loads 
images recursively from the specified directory 157 | # or ZIP file. 158 | 159 | class ImageFolderDataset(Dataset): 160 | def __init__(self, 161 | path, # Path to directory or zip. 162 | resolution = None, # Ensure specific resolution, None = highest available. 163 | use_pyspng = True, # Use pyspng if available? 164 | **super_kwargs, # Additional arguments for the Dataset base class. 165 | ): 166 | self._path = path 167 | self._use_pyspng = use_pyspng 168 | self._zipfile = None 169 | 170 | if os.path.isdir(self._path): 171 | self._type = 'dir' 172 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 173 | elif self._file_ext(self._path) == '.zip': 174 | self._type = 'zip' 175 | self._all_fnames = set(self._get_zipfile().namelist()) 176 | else: 177 | raise IOError('Path must point to a directory or zip') 178 | 179 | PIL.Image.init() 180 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 181 | if len(self._image_fnames) == 0: 182 | raise IOError('No image files found in the specified path') 183 | 184 | name = os.path.splitext(os.path.basename(self._path))[0] 185 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 186 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 187 | raise IOError('Image files do not match the specified resolution') 188 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 189 | 190 | @staticmethod 191 | def _file_ext(fname): 192 | return os.path.splitext(fname)[1].lower() 193 | 194 | def _get_zipfile(self): 195 | assert self._type == 'zip' 196 | if self._zipfile is None: 197 | self._zipfile = zipfile.ZipFile(self._path) 198 | return self._zipfile 199 | 200 | def _open_file(self, fname): 201 | if self._type == 'dir': 202 | return open(os.path.join(self._path, fname), 'rb') 203 | if self._type == 'zip': 204 | return self._get_zipfile().open(fname, 'r') 205 | return None 206 | 207 | def close(self): 208 | try: 209 | if self._zipfile is not None: 210 | self._zipfile.close() 211 | finally: 212 | self._zipfile = None 213 | 214 | def __getstate__(self): 215 | return dict(super().__getstate__(), _zipfile=None) 216 | 217 | def _load_raw_image(self, raw_idx): 218 | fname = self._image_fnames[raw_idx] 219 | with self._open_file(fname) as f: 220 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 221 | image = pyspng.load(f.read()) 222 | else: 223 | image = np.array(PIL.Image.open(f)) 224 | if image.ndim == 2: 225 | image = image[:, :, np.newaxis] # HW => HWC 226 | image = image.transpose(2, 0, 1) # HWC => CHW 227 | return image 228 | 229 | def _load_raw_labels(self): 230 | fname = 'dataset.json' 231 | if fname not in self._all_fnames: 232 | return None 233 | with self._open_file(fname) as f: 234 | labels = json.load(f)['labels'] 235 | if labels is None: 236 | return None 237 | labels = dict(labels) 238 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 239 | labels = np.array(labels) 240 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 241 | return labels 242 | 243 | #---------------------------------------------------------------------------- 244 | -------------------------------------------------------------------------------- /training/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_ns import 
NSDataset
2 | from .dataset_fake import FakeDataset
--------------------------------------------------------------------------------
/training/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset as TorchDataset
4 | 
5 | class Dataset(TorchDataset):
6 | 
7 |     norm_params_set: bool = False
8 | 
9 |     @property
10 |     def name(self):
11 |         raise NotImplementedError("")
12 | 
13 |     @property
14 |     def num_channels(self):
15 |         raise NotImplementedError
16 | 
17 |     @property
18 |     def y_dim(self):
19 |         raise NotImplementedError
20 | 
21 |     def normalise(self, x: torch.Tensor, y: torch.Tensor, gain: float, check_valid: bool = False):
22 |         """Helper function to return normalised versions of x and y."""
23 |         self._set_normalisation_parameters(x, y)
24 |         self.gain = gain
25 |         u_normed = self.norm_u(x)
26 |         y_normed = self.norm_y(y)
27 |         if check_valid:
28 |             assert torch.isclose(self.denorm_u(u_normed), x, atol=1e-4).all()
29 |             assert torch.isclose(self.denorm_y(y_normed), y, atol=1e-4).all()
30 |         return u_normed, y_normed
31 | 
32 |     def _set_normalisation_parameters(self, u, y):
33 |         self.min_U = u.min(dim=0, keepdims=True)[0]
34 |         self.max_U = u.max(dim=0, keepdims=True)[0]
35 | 
36 |         self.min_y = y.min(dim=0, keepdims=True)[0]
37 |         self.max_y = y.max(dim=0, keepdims=True)[0]
38 | 
39 |     def norm_u(self, U):
40 |         """Convert U into a format amenable to training"""
41 |         U = (U - self.min_U) / (self.max_U - self.min_U + 1e-6)
42 |         U = ((U - 0.5) / 0.5) * self.gain
43 |         return U
44 | 
45 |     def denorm_u(self, u_normed):
46 |         """Invert norm_u, mapping normalised u back to its original range."""
47 |         # u_normed = u_normed*0.5 + 0.5
48 |         u_normed = (u_normed * 0.5 + (0.5*self.gain)) / self.gain
49 |         return u_normed * (self.max_U - self.min_U + 1e-6) + self.min_U
50 | 
51 |     def norm_y(self, y):
52 |         """Convert y into a format amenable to training"""
53 |         y = (y - self.min_y) / (self.max_y - self.min_y)
54 |         y = ((y-0.5)/0.5)*self.gain
55 |         return y
56 | 
57 |     def denorm_y(self, y_normed):
58 |         """Invert norm_y, mapping normalised y back to its original range."""
59 |         y_normed = (y_normed * 0.5 + (0.5*self.gain)) / self.gain
60 |         return (y_normed * (self.max_y - self.min_y)) + self.min_y
61 | 
62 | def window(yy: np.ndarray, t: int, k: int):
63 |     """Use this function to validate the windowing logic in WindowedDataset."""
64 |     # k is how many elements to the left/right, so the total window size is 2k + 1
65 |     assert t >= 0, "t must be >= 0"
66 |     assert t <= len(yy)-1, "t must be <= len-1"
67 |     #assert k % 2 != 0, "k must be odd numbered"
68 |     if t - k < 0:
69 |         last_part = yy[0: t+k+1 ]
70 |         num_missing = ((2*k)+1) - len(last_part)
71 |         padding = np.zeros((num_missing,), dtype=yy.dtype)
72 |         return np.concatenate((padding, last_part ), axis=0)
73 |     elif t+k > len(yy)-1:
74 |         #print("Exceeded")
75 |         first_part = yy[t-k :: ]
76 |         num_missing = ((2*k)+1) - len(first_part)
77 |         padding = np.zeros((num_missing,), dtype=yy.dtype)
78 |         return np.concatenate((first_part, padding ), axis=0)
79 |     else:
80 |         # The window fits entirely within the sequence.
81 |         return yy[t-k : t+k+1]
82 | 
83 | class WindowedDataset(Dataset):
84 |     """Wraps a dataset to allow sampling a window of inputs instead."""
85 | 
86 |     def __init__(self, dataset, window_size):
87 |         self.dataset = dataset
88 |         self.window_size = window_size
89 | 
90 |     @property
91 |     def num_channels(self):
92 |         return self.dataset.num_channels
93 | 
94 |     @property
95 |     def y_dim(self):
96 |         return self.dataset.y_dim
97 | 
98 |     @property
99 |     def resolution(self):
100 |         return self.dataset.resolution
101 | 
102 |     @property
103 |     def label_dim(self):
104 |         return (self.window_size*2 + 1)*self.y_dim
105 | 
106 |     def __len__(self):
107 |         return len(self.dataset)
108 | 
109 |     def __getitem__(self, t):
110 |         """Returns tensors of shape (nc, h, w) and (window_sz, nc, h, w)"""
111 |         k = self.window_size
112 |         u = self.dataset.__getitem__(t)[0]
113 |         if t - k < 0:
114 |             # If the left side of the window runs off past zero,
115 |             # valid_part is indexing dataset[0 : t+k+1]
116 |             valid_part = [ self.dataset.__getitem__(idx)[1] for idx in range(0, t+k+1) ]
117 |             # (n_valid, nc, h, w)
118 |             valid_part = torch.stack(valid_part, dim=0)
119 |             # (n_empty, nc, h, w)
120 |             num_missing = ((2*k)+1) - len(valid_part)
121 |             empty_part = torch.zeros_like(valid_part[0:1]).repeat(num_missing, 1, 1, 1)
122 |             # (window_sz, nc, h, w)
123 |             full = torch.cat((empty_part, valid_part), dim=0)
124 |         elif t+k > len(self)-1:
125 |             # If the right side of the window runs off past the length of the data,
126 |             # valid_part is indexing dataset[ t-k :: ]
127 |             valid_part = [ self.dataset.__getitem__(idx)[1] for idx in range(t-k, len(self)) ]
128 |             # (n_valid, nc, h, w)
129 |             valid_part = torch.stack(valid_part, dim=0)
130 |             # (n_empty, nc, h, w)
131 |             num_missing = ((2*k)+1) - len(valid_part)
132 |             empty_part = torch.zeros_like(valid_part[0:1]).repeat(num_missing, 1, 1, 1)
133 |             # (window_sz, nc, h, w)
134 |             full = torch.cat((valid_part, empty_part), dim=0)
135 |         else:
136 |             full = [ self.dataset.__getitem__(idx)[1] for idx in range(t-k, t+k+1) ]
137 |             full = torch.stack(full, dim=0)
138 | 
139 |         return u, full
--------------------------------------------------------------------------------
/training/datasets/dataset_fake.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .dataset import Dataset
3 | 
4 | class FakeDataset(Dataset):
5 | 
6 |     def __init__(self, path, resolution, train=True):
7 |         self.res = resolution
8 | 
9 |     @property
10 |     def name(self):
11 |         return "FakeDataset"
12 | 
13 |     @property
14 |     def num_channels(self):
15 |         return 10
16 | 
17 |     @property
18 |     def y_dim(self):
19 |         return 2
20 | 
21 |     def __len__(self):
22 |         return 100
23 | 
24 |     def __getitem__(self, idc):
25 |         x = torch.randn((self.num_channels, self.res, self.res))
26 |         y = torch.randn((self.y_dim, self.res, self.res))
27 |         return x, y
28 | 
29 | if __name__ == '__main__':
30 |     from .dataset import WindowedDataset
31 |     from torch.utils.data import DataLoader
32 | 
33 |     fake_ds = FakeDataset(None, 128, True)
34 |     w_ds = WindowedDataset(fake_ds, window_size=5)
35 |     print(fake_ds)
36 |     print(w_ds)
37 | 
38 |     loader = DataLoader(w_ds, batch_size=5)
39 |     print([ elem.shape for elem in next(iter(loader)) ])
--------------------------------------------------------------------------------
/training/datasets/dataset_ns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from .dataset import Dataset
4 | import torch
5 | 
6 | from torch.nn import functional as F
7 | 
8 | class NSDataset(Dataset):
9 |     """
10 |     https://zenodo.org/records/7495555
11 | 
12 |     2D Navier-Stokes data (200 trajectories, 500 time-steps each) at 64 x 64 spatial resolution.
13 | 
14 |     This dataset is a drop-in replacement for the ERA-CWA dataset which is not yet
15 |     public. It is not intended to be a difficult dataset to train on; rather, its
16 |     purpose is simply to provide something which can be trained on as a proof of
17 |     concept.
18 | 
19 |     We wish to train a conditional diffusion model `p(u_t | y_{t-k}, ..., y_{t+k})`, where
20 |     `u_t` is the NS function at its original resolution at timestep `t` and `y` is the
21 |     same function at a much coarser resolution, i.e. we want to do (function space) super resolution.
22 | 
23 |     `k` is handled by `WindowedDataset` and we need not worry about that here; we
24 |     simply need to write the class to give us `u_t` and `y_t`, and we can generate
25 |     `y_t` by downsampling `u_t` and upsampling it back to the same resolution as `u_t`.
26 |     """
27 | 
28 |     def __init__(self,
29 |         path,
30 |         resolution,
31 |         lowres_scale_factor=0.25, #
32 |         train=True #
33 |     ):
34 | 
35 |         # The dataset contains multiple trajectories, we only want one of them
36 |         # since this dataset is intended to represent a single time series.
37 | 
38 |         if train:
39 |             which_trajectory = 0
40 |         else:
41 |             which_trajectory = 1
42 | 
43 |         # u is the actual function we wish to learn, conditioned on low-res
44 |         # versions of the function, y.
45 |         u = np.load(os.path.join(path, "2D_NS_Re40.npy"))[which_trajectory][1:]
46 |         u = torch.from_numpy(u).float().unsqueeze(1)
47 |         self.u = F.interpolate(u, size=resolution, mode='bilinear')
48 | 
49 |         # Downscale by a factor of `lowres_scale_factor` (e.g. 0.25 => 1/4)
50 |         # then upscale back. y is meant to be a low-res version of u.
51 |         y = F.interpolate(self.u, scale_factor=lowres_scale_factor)
52 |         self.y = F.interpolate(y, size=resolution, mode='bilinear')
53 | 
54 |         self.u, self.y = self.normalise(
55 |             self.u, self.y, gain=1.0,
56 |             check_valid=True
57 |         )
58 | 
59 |         assert self.u.shape == self.y.shape
60 |         self.res = resolution
61 | 
62 |     @property
63 |     def resolution(self):
64 |         return self.res
65 | 
66 |     @property
67 |     def name(self):
68 |         return "NSDataset"
69 | 
70 |     @property
71 |     def num_channels(self):
72 |         """How many channels are in u?"""
73 |         return 1
74 | 
75 |     @property
76 |     def y_dim(self):
77 |         """How many channels are in y?"""
78 |         return 1
79 | 
80 |     def __len__(self):
81 |         return len(self.u)
82 | 
83 |     def __getitem__(self, idc):
84 |         return self.u[idc], self.y[idc]
--------------------------------------------------------------------------------
/training/img_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import glob
3 | 
4 | import torch
5 | import random
6 | import numpy as np
7 | 
8 | from torch.utils.data import DataLoader, Dataset
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch import Tensor
11 | import h5py
12 | import math
13 | import torchvision.transforms.functional as TF
14 | from torch_utils import distributed as dist
15 | 
16 | 
17 | def reshape_fields(img, inp_or_tar, crop_size_x, crop_size_y, rnd_x, rnd_y, params, y_roll, train, normalize=True):
18 |     # Takes in np array of size (n_history+1, c, h, w) and returns a torch tensor of size (n_channels*(n_history+1), crop_size_x, crop_size_y)
19 | 
20 |     if len(np.shape(img)) == 3:
21 |         img = np.expand_dims(img, 0)
22 | 
23 |     if img.shape[3] > 720:
24 |         img = img[:, :, 0:720] # remove the last row for era5 data (crop height to 720)
25 | 
26 |     #n_history = np.shape(img)[0] - 1 #for era5
27 |     n_history = params.n_history
28 | 
29 |     img_shape_x = np.shape(img)[-2]
30 |     img_shape_y = np.shape(img)[-1]
31 |     n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels
32 |     channels = params.in_channels if inp_or_tar =='inp' else params.out_channels
33 |     #print('channels', channels)
34 | 
35 |     # dist.print0('normalize', normalize)
36 |     # dist.print0('train', train)
37 | 
38 |     if normalize and train:
39 |         # Normalisation statistics are loaded from external .npy files whose
40 |         # paths are given in `params`. (Note: with the min-max/z-score block
41 |         # below commented out, these statistics are currently unused.)
42 |         mins = np.load(params.min_path)[:, channels]
43 |         maxs = np.load(params.max_path)[:, channels]
44 |         means = np.load(params.global_means_path)[:, channels]
45 |         stds = np.load(params.global_stds_path)[:, channels]
46 | 
47 |     if crop_size_x is None:
48 |         crop_size_x = img_shape_x
49 |     if crop_size_y is None:
50 |         crop_size_y = img_shape_y
51 | 
52 |     """
53 |     if normalize and train:
54 |         if params.normalization == 'minmax':
55 |             img -= mins
56 |             img /= (maxs - mins)
57 |         elif params.normalization == 'zscore':
58 |             #print('params.normalization == zscore')
59 |             img -= means
60 |             img /= stds
61 |     """
62 | 
63 |     if params.roll:
64 |         img = np.roll(img, y_roll, axis = -1)
65 | 
66 |     # if train and (crop_size_x or crop_size_y):
67 |     #     img = img[:,:,rnd_x:rnd_x+crop_size_x, rnd_y:rnd_y+crop_size_y]
68 | 
69 |     if (crop_size_x or crop_size_y):
70 |         img = img[:,:,rnd_x:rnd_x+crop_size_x, rnd_y:rnd_y+crop_size_y]
71 | 
72 |     if inp_or_tar == 'inp':
73 |         img = np.reshape(img, (n_channels*(n_history+1), crop_size_x, crop_size_y))
74 |     elif inp_or_tar == 'tar':
75 |         img = np.reshape(img, (n_channels, crop_size_x, crop_size_y))
76 | 
77 |     # do min max norm here
78 |     img = (img - img.min()) / (img.max() - img.min() + 1e-6)
79 | 
80 |     return torch.as_tensor(img)
81 | 
--------------------------------------------------------------------------------
/training/loss.py:
--------------------------------------------------------------------------------
1 | """Loss functions used in the paper
2 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
3 | 
4 | import torch
5 | from torch_utils import persistence
6 | 
7 | #----------------------------------------------------------------------------
8 | # Loss function corresponding to the variance preserving (VP) formulation
9 | # from the paper "Score-Based Generative Modeling through Stochastic
10 | # Differential Equations".
11 | 
12 | @persistence.persistent_class
13 | class VPLoss:
14 |     def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
15 |         self.beta_d = beta_d
16 |         self.beta_min = beta_min
17 |         self.epsilon_t = epsilon_t
18 | 
19 |     def __call__(self, net, images, labels, augment_pipe=None):
20 |         rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
21 |         sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
22 |         weight = 1 / sigma ** 2
23 |         y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
24 |         n = torch.randn_like(y) * sigma
25 |         D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
26 |         loss = weight * ((D_yn - y) ** 2)
27 |         return loss
28 | 
29 |     def sigma(self, t):
30 |         t = torch.as_tensor(t)
31 |         return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
32 | 
33 | #----------------------------------------------------------------------------
34 | # Loss function corresponding to the variance exploding (VE) formulation
35 | # from the paper "Score-Based Generative Modeling through Stochastic
36 | # Differential Equations".
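# (For intuition: VELoss below samples the noise level log-uniformly,
#  sigma = sigma_min * (sigma_max / sigma_min) ** u with u ~ U[0, 1],
#  and weights the squared error by 1 / sigma**2.)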
37 | 
38 | @persistence.persistent_class
39 | class VELoss:
40 |     def __init__(self, sigma_min=0.02, sigma_max=100):
41 |         self.sigma_min = sigma_min
42 |         self.sigma_max = sigma_max
43 | 
44 |     def __call__(self, net, images, labels, augment_pipe=None):
45 |         rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
46 |         sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
47 |         weight = 1 / sigma ** 2
48 |         y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
49 |         n = torch.randn_like(y) * sigma
50 |         D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
51 |         loss = weight * ((D_yn - y) ** 2)
52 |         return loss
53 | 
54 | #----------------------------------------------------------------------------
55 | # Improved loss function proposed in the paper "Elucidating the Design Space
56 | # of Diffusion-Based Generative Models" (EDM).
57 | 
58 | @persistence.persistent_class
59 | class EDMLoss:
60 |     def __init__(self, sampler, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
61 |         self.sampler = sampler
62 |         self.P_mean = P_mean
63 |         self.P_std = P_std
64 |         self.sigma_data = sigma_data
65 | 
66 |     def __call__(self, net, images, labels=None, augment_pipe=None):
67 | 
68 |         rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
69 |         sigma = (rnd_normal * self.P_std + self.P_mean).exp()
70 |         weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
71 | 
72 |         # We want to augment both x and y at once.
73 |         x_dim = images.size(1)
74 |         all_images = torch.cat((images, labels), dim=1)
75 |         all_images, augment_labels = augment_pipe(all_images) if augment_pipe is not None else (all_images, None)
76 |         # Extract out the x and y components (from the augmented tensor).
77 |         y = all_images[:, 0:x_dim]
78 |         labels = all_images[:, x_dim::]
79 | 
80 |         #n = torch.randn_like(y) * sigma
81 |         n = self.sampler.sample(y.size(0)) * sigma
82 | 
83 |         D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
84 |         loss = weight * ((D_yn - y) ** 2)
85 |         return loss
86 | 
87 | #----------------------------------------------------------------------------
88 | 
89 | @persistence.persistent_class
90 | class ReconLoss:
91 |     def __init__(self, *args, **kwargs):
92 |         pass
93 |     def __call__(self, net, images, labels=None, augment_pipe=None):
94 |         # We want to augment both x and y at once.
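        # Concatenating x and y along the channel axis before augmenting
        # guarantees both receive identical spatial transforms; the result is
        # split back at x_dim afterwards. Shape-wise (sketch):
        # (B, C_x, H, W) + (B, C_y, H, W) -> (B, C_x + C_y, H, W) -> augment
        # -> split at C_x.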
95 |         x_dim = images.size(1)
96 |         all_images = torch.cat((images, labels), dim=1)
97 |         all_images, augment_labels = augment_pipe(all_images) \
98 |             if augment_pipe is not None else (all_images, None)
99 |         # Extract out the x and y components (from the augmented tensor).
100 |         y = all_images[:, 0:x_dim]
101 |         labels = all_images[:, x_dim::]
102 |         sigma = torch.zeros((y.size(0), )).to(y.device)+1
103 |         D_y = net(y, sigma, labels, augment_labels=augment_labels)
104 | 
105 |         loss = ((D_y - y) ** 2)
106 |         return loss
--------------------------------------------------------------------------------
/training/noise_samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft as fft
3 | import numpy as np
4 | import cv2
5 | import math
6 | import os
7 | 
8 | def get_fixed_coords(Ln1, Ln2):
9 |     xs = torch.linspace(0, 1, steps=Ln1 + 1)[0:-1]
10 |     ys = torch.linspace(0, 1, steps=Ln2 + 1)[0:-1]
11 |     xx, yy = torch.meshgrid(xs, ys, indexing="xy")
12 |     coords = torch.cat([yy.reshape(-1, 1), xx.reshape(-1, 1)], dim=-1)
13 |     return coords
14 | 
15 | class NoiseSampler(object):
16 |     def sample(self, N):
17 |         raise NotImplementedError()
18 | 
19 | class RBFKernel(NoiseSampler):
20 |     @torch.no_grad()
21 |     def __init__(
22 |         self, n_in, Ln1, Ln2, scale=1, eps=0.01, device=None
23 |     ):
24 |         self.n_in = n_in
25 |         self.Ln1 = Ln1
26 |         self.Ln2 = Ln2
27 |         self.device = device
28 |         self.scale = scale
29 | 
30 |         # (s^2, 2)
31 |         meshgrid = get_fixed_coords(self.Ln1, self.Ln2).to(device)
32 |         # (s^2, s^2). NB: cdist gives the non-squared Euclidean distance, so this is an exponential kernel rather than a true squared-exponential RBF.
33 |         C = torch.exp(-torch.cdist(meshgrid, meshgrid) / (2 * scale**2))
34 |         # Need to add some regularisation or else the sqrt won't exist
35 |         I = torch.eye(C.size(-1)).to(device)
36 | 
37 |         # Not memory efficient
38 |         #C = C + (eps**2) * I
39 |         I.mul_(eps**2) # inplace multiply by eps**2
40 |         C.add_(I) # inplace add by I
41 |         del I # don't need it anymore
42 | 
43 |         # TODO: can we support f16 in this class to save gpu memory?
44 | 
45 |         self.L = torch.linalg.cholesky(C)
46 | 
47 |         del C # save memory
48 | 
49 |     @torch.no_grad()
50 |     def sample(self, N):
51 |         # (s^2, s^2) x (s^2, n_in) -> (s^2, n_in), applied per sample below.
52 |         # We can do this in one big torch.bmm, but I am concerned about memory
53 |         # so let's just do it iteratively.
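        # Why this works: if z ~ N(0, I) and L is the Cholesky factor of C
        # (L @ L.T == C), then x = L @ z has Cov(x) = L @ I @ L.T = C, i.e. x
        # is a draw from the Gaussian random field with covariance C; each of
        # the n_in channels gets an independent draw. A quick sanity check
        # (sketch, with `L` and `C` as built in __init__):
        #
        #   z = torch.randn(Ln1 * Ln2, 100000)
        #   emp_cov = (L @ z) @ (L @ z).T / 100000   # should approximate C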
54 | # L_padded = self.L.repeat(N, 1, 1) 55 | # z_mat = torch.randn((N, self.Ln1*self.Ln2, 2)).to(self.device) 56 | # sample = torch.bmm(L_padded, z_mat) 57 | samples = torch.zeros((N, self.Ln1 * self.Ln2, self.n_in)).to(self.device) 58 | for ix in range(N): 59 | # (s^2, s^2) * (s^2, 2) -> (s^2, 2) 60 | this_z = torch.randn(self.Ln1 * self.Ln2, self.n_in).to(self.device) 61 | samples[ix] = torch.matmul(self.L, this_z) 62 | 63 | # reshape into (N, s, s, n_in) 64 | sample_rshp = samples.reshape(-1, self.Ln1, self.Ln2, self.n_in) 65 | 66 | # reshape into (N, n_in, s, s) 67 | sample_rshp = sample_rshp.transpose(-1,-2).transpose(-2,-3) 68 | 69 | return sample_rshp -------------------------------------------------------------------------------- /training/training_loop.py: -------------------------------------------------------------------------------- 1 | """Main training loop.""" 2 | 3 | import os 4 | import time 5 | import copy 6 | import json 7 | import pickle 8 | import psutil 9 | import numpy as np 10 | import torch 11 | import dnnlib 12 | from training.datasets.dataset import WindowedDataset 13 | from torch_utils import distributed as dist 14 | from torch_utils import training_stats 15 | from torch_utils import misc 16 | from torchvision.utils import save_image 17 | 18 | from einops import rearrange 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def training_loop( 23 | run_dir = '.', # Output directory. 24 | dataset_kwargs = {}, # Options for training set. 25 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 26 | network_kwargs = {}, # Options for model and preconditioning. 27 | loss_kwargs = {}, # Options for loss function. 28 | sampler_kwargs = {}, # Options for noise sampler. 29 | optimizer_kwargs = {}, # Options for optimizer. 30 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 31 | seed = 0, # Global random seed. 32 | window_size = 0, # Number of examples per side of y_t 33 | batch_size = 512, # Total batch size for one training iteration. 34 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 35 | total_kimg = 200000, # Training duration, measured in thousands of training images. 36 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 37 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 38 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 39 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 40 | kimg_per_tick = 50, # Interval of progress prints. 41 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 42 | state_dump_ticks = 500, # How often to dump training state, None = disable. 43 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 44 | resume_state_dump = None, # Start from the given training state, None = reset training state. 45 | resume_kimg = 0, # Start from the given training progress. 46 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 47 | device = torch.device('cuda'), 48 | ): 49 | # Initialize. 
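    # Each rank derives a distinct seed from the global one via
    # (seed * world_size + rank), keeping runs deterministic while
    # decorrelating data order and noise across processes. TF32 is disabled
    # below so that matmuls and convolutions run in full FP32 precision.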
50 | start_time = time.time() 51 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 52 | torch.manual_seed(np.random.randint(1 << 31)) 53 | torch.backends.cudnn.benchmark = cudnn_benchmark 54 | torch.backends.cudnn.allow_tf32 = False 55 | torch.backends.cuda.matmul.allow_tf32 = False 56 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 57 | 58 | # Select batch size per GPU. 59 | batch_gpu_total = batch_size // dist.get_world_size() 60 | if batch_gpu is None or batch_gpu > batch_gpu_total: 61 | batch_gpu = batch_gpu_total 62 | num_accumulation_rounds = batch_gpu_total // batch_gpu 63 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 64 | 65 | # Load dataset. 66 | dist.print0('Loading dataset...') 67 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 68 | dist.print0('Windowing dataset...') 69 | dataset_obj = WindowedDataset(dataset_obj, window_size=window_size) 70 | dataset_sampler = misc.InfiniteSampler( 71 | dataset=dataset_obj, 72 | rank=dist.get_rank(), 73 | num_replicas=dist.get_world_size(), 74 | seed=seed 75 | ) 76 | dataset_iterator = iter( 77 | torch.utils.data.DataLoader( 78 | dataset=dataset_obj, 79 | sampler=dataset_sampler, 80 | batch_size=batch_gpu, 81 | **data_loader_kwargs 82 | ) 83 | ) 84 | 85 | # Construct network. 86 | dist.print0('Constructing network...') 87 | interface_kwargs = dict( 88 | img_resolution=dataset_obj.resolution, 89 | img_channels=dataset_obj.num_channels, 90 | label_dim=dataset_obj.label_dim 91 | ) 92 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 93 | net.train().requires_grad_(True).to(device) 94 | 95 | dist.print0("Number of params: {}".format(misc.count_parameters(net))) 96 | 97 | # Print network statistics. 98 | if dist.get_rank() == 0: 99 | with torch.no_grad(): 100 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) 101 | sigma = torch.ones([batch_gpu], device=device) 102 | labels = torch.zeros([batch_gpu, net.label_dim, net.img_resolution, net.img_resolution], device=device) 103 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) 104 | 105 | 106 | # Setup optimizer. 107 | dist.print0('Setting up optimizer...') 108 | 109 | # Kinda hacky but we need the loss_kwargs to get the sampler_fn 110 | #print(sampler_kwargs, "<<<<<<<<<") 111 | sampler_kwargs.n_in = dataset_obj.num_channels 112 | # Have to specify `device` here since it is not serialisable 113 | sampler_kwargs.device = device 114 | loss_kwargs.sampler = dnnlib.util.construct_class_by_name(**sampler_kwargs) 115 | 116 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss 117 | optimizer = dnnlib.util.construct_class_by_name( 118 | params=net.parameters(), 119 | **optimizer_kwargs) # subclass of torch.optim.Optimizer 120 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe 121 | ddp = torch.nn.parallel.DistributedDataParallel( 122 | net, 123 | device_ids=[device], 124 | broadcast_buffers=False, 125 | #find_unused_parameters=True 126 | ) 127 | ema = copy.deepcopy(net).eval().requires_grad_(False) 128 | 129 | # Resume training from previous snapshot. 
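    # The barrier pattern below makes rank 0 open (and, for URLs, download and
    # cache) the pickle first while the other ranks wait; once rank 0 reaches
    # its own barrier, the remaining ranks proceed and read the warm cache.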
130 | if resume_pkl is not None: 131 | dist.print0(f'Loading network weights from "{resume_pkl}"...') 132 | if dist.get_rank() != 0: 133 | torch.distributed.barrier() # rank 0 goes first 134 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: 135 | data = pickle.load(f) 136 | 137 | if dist.get_rank() == 0: 138 | torch.distributed.barrier() # other ranks follow 139 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) 140 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) 141 | del data # conserve memory 142 | if resume_state_dump: 143 | # Chris: is None by default 144 | dist.print0(f'Loading training state from "{resume_state_dump}"...') 145 | data = torch.load(resume_state_dump, map_location=torch.device('cpu')) 146 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) 147 | optimizer.load_state_dict(data['optimizer_state']) 148 | del data # conserve memory 149 | 150 | # Train. 151 | dist.print0(f'Training for {total_kimg} kimg...') 152 | dist.print0() 153 | cur_nimg = resume_kimg * 1000 154 | cur_tick = 0 155 | tick_start_nimg = cur_nimg 156 | tick_start_time = time.time() 157 | maintenance_time = tick_start_time - start_time 158 | dist.update_progress(cur_nimg // 1000, total_kimg) 159 | stats_jsonl = None 160 | first = True 161 | while True: 162 | 163 | # Accumulate gradients. 164 | optimizer.zero_grad(set_to_none=True) 165 | for round_idx in range(num_accumulation_rounds): 166 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): 167 | images, labels = next(dataset_iterator) 168 | labels = rearrange(labels, 'bs ws nc h w -> bs (ws nc) h w') 169 | images = images.to(device).to(torch.float32) #/ 127.5 - 1 170 | labels = labels.to(device) 171 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) 172 | training_stats.report('Loss/loss', loss) 173 | loss.sum().mul(loss_scaling / batch_gpu_total).backward() 174 | 175 | if first: 176 | # Dump some viz to disk. 177 | #save_image( 178 | # images[0:5]*0.5 + 0.5, 179 | # os.path.join(run_dir, "train_samples.png") 180 | #) 181 | if augment_pipe is not None: 182 | torch.save( 183 | augment_pipe(images)[0][0:8].cpu()*0.5 + 0.5, 184 | os.path.join(run_dir, "train_samples.pt") 185 | ) 186 | else: 187 | torch.save( 188 | images[0:8].cpu()*0.5 + 0.5, 189 | os.path.join(run_dir, "train_samples.pt") 190 | ) 191 | first = False 192 | 193 | # Update weights. 194 | for g in optimizer.param_groups: 195 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) 196 | for param in net.parameters(): 197 | if param.grad is not None: 198 | training_stats.report("Loss/any_nan", torch.isnan(param.grad).any().item()) 199 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 200 | optimizer.step() 201 | 202 | # Update EMA. 203 | ema_halflife_nimg = ema_halflife_kimg * 1000 204 | if ema_rampup_ratio is not None: 205 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 206 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 207 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 208 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 209 | 210 | # Perform maintenance tasks once per tick. 
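        # A "tick" is the unit of bookkeeping here: roughly kimg_per_tick
        # thousand images. Until a tick boundary is reached, the `continue`
        # below skips straight back to training, so snapshots, stats syncs and
        # status prints all happen at tick granularity, not every iteration.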
211 | cur_nimg += batch_size 212 | done = (cur_nimg >= total_kimg * 1000) 213 | if (not done) and (cur_tick != 0) and \ 214 | (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 215 | continue 216 | 217 | # Print status line, accumulating the same information in training_stats. 218 | tick_end_time = time.time() 219 | fields = [] 220 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 221 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 222 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 223 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 224 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 225 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 226 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 227 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 228 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 229 | torch.cuda.reset_peak_memory_stats() 230 | dist.print0(' '.join(fields)) 231 | 232 | # Check for abort. 233 | if (not done) and dist.should_stop(): 234 | done = True 235 | dist.print0() 236 | dist.print0('Aborting...') 237 | 238 | # Save network snapshot. 239 | if (snapshot_ticks is not None) and \ 240 | (done or cur_tick % snapshot_ticks == 0) and \ 241 | (cur_tick != 0): 242 | data = dict( 243 | ema=ema, 244 | loss_fn=loss_fn, 245 | augment_pipe=augment_pipe, 246 | dataset_kwargs=dict(dataset_kwargs), 247 | cur_nimg=cur_nimg # store iteration # here 248 | ) 249 | for key, value in data.items(): 250 | if isinstance(value, torch.nn.Module): 251 | value = copy.deepcopy(value).eval().requires_grad_(False) 252 | misc.check_ddp_consistency(value) 253 | data[key] = value.cpu() 254 | del value # conserve memory 255 | if dist.get_rank() == 0: 256 | # Save with filename corresponding to iteration. 257 | #with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 258 | # pickle.dump(data, f) 259 | # Chris B: I don't want disk space to scale linearly with time, just 260 | # save one snapshot. In the future, we can also do them for each 261 | # validation metric like in DiffusionOperators. 262 | with open(os.path.join(run_dir, f'network-snapshot.pkl'), 'wb') as f: 263 | pickle.dump(data, f) 264 | 265 | del data # conserve memory 266 | 267 | # Chris B: write out stats here. I want this to be synchronised with the 268 | # saving of the checkpoint, so we don't get duplicate entries in the stats 269 | # file. 270 | if dist.get_rank() == 0: 271 | if stats_jsonl is None: 272 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') 273 | stats_jsonl.write( 274 | json.dumps( 275 | dict(training_stats.default_collector.as_dict(), timestamp=time.time()) 276 | ) + '\n' 277 | ) 278 | stats_jsonl.flush() 279 | 280 | # Save full dump of the training state. 
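        # Note the distinction: network-snapshot.pkl (saved above) stores the
        # EMA weights used for sampling, while the state dump below would
        # store the raw net plus optimizer state needed for an exact resume.
        # With it disabled, resuming reloads the EMA weights but restarts the
        # optimizer from scratch (resume_state_dump is None in train.py).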
281 | # Chris B: I don't want to deal with this in my code yet 282 | #if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 283 | # torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.#join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 284 | 285 | # Update logs. 286 | training_stats.default_collector.update() 287 | dist.update_progress(cur_nimg // 1000, total_kimg) 288 | 289 | # Update state. 290 | cur_tick += 1 291 | tick_start_nimg = cur_nimg 292 | tick_start_time = time.time() 293 | maintenance_time = tick_start_time - tick_end_time 294 | if done: 295 | break 296 | 297 | # Done. 298 | dist.print0() 299 | dist.print0('Exiting...') 300 | 301 | #---------------------------------------------------------------------------- 302 | --------------------------------------------------------------------------------