├── LICENSE
├── README.md
├── datasets
│   ├── Lagr_u1c_diffusion-demo.h5
│   ├── Lagr_u3c_diffusion-demo.h5
│   └── preprocessing-lagr_u1c-diffusion.py
├── guided_diffusion
│   ├── __init__.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── logger.py
│   ├── losses.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   ├── turb_datasets.py
│   └── unet.py
├── ppdm
│   ├── ppdm
│   │   ├── __init__.py
│   │   └── ppdm.py
│   └── setup.py
├── resources
│   ├── Sampling.png
│   └── Training.png
├── scripts
│   ├── turb_losses.py
│   ├── turb_model.py
│   ├── turb_sample.py
│   ├── turb_sample_history.py
│   └── turb_train.py
└── setup.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Smart-TURB

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# diffusion-lagr

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10563386.svg)](https://doi.org/10.5281/zenodo.10563386)

This is the codebase for [Synthetic Lagrangian Turbulence by Generative Diffusion Models](https://arxiv.org/abs/2307.08529).

This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications tailored to the Lagrangian turbulence data hosted on the [Smart-TURB portal](http://smart-turb.roma2.infn.it) as the [TURB-Lagr](https://smart-turb.roma2.infn.it/init/routes/#/logging/view_dataset/2/tabmeta) dataset.

# Usage

## Development Environment

Our software was developed and tested on a system with the following specifications:

- **Operating System**: Ubuntu 20.04.4 LTS
- **Python Version**: 3.7.16
- **PyTorch Version**: 1.13.1
- **MPI Implementation**: Open MPI (OpenRTE) 4.0.2
- **CUDA Version**: 11.5
- **GPU Model**: NVIDIA A100

## Installation

We recommend using a Conda environment to manage dependencies. The code relies on the MPI library and [parallel h5py](https://docs.h5py.org/en/stable/mpi.html); note, however, that MPI is not mandatory for all functionalities. See [Training](#training) and [Sampling](#sampling) for details. After setting up your environment, clone this repository and navigate to it in your terminal. Then run:

```
pip install -e .
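# Optional, illustrative environment setup (package names and versions are
# assumptions, not pinned by this repo). An MPI-enabled h5py build can be
# installed from conda-forge before the editable install above, e.g.:
#   conda create -n diffusion-lagr python=3.7
#   conda activate diffusion-lagr
#   conda install -c conda-forge "h5py>=2.9=mpi*" mpi4py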
```

This should install the `guided_diffusion` Python package that the scripts depend on.

### Troubleshooting Installation

During the installation process, you might encounter a couple of known issues. Here are some tips to help you resolve them:

1. **Parallel h5py Installation**: Setting up parallel h5py can sometimes pose challenges. As a workaround, you can install the serial version of h5py, comment out the specific lines of code found [here](https://github.com/SmartTURB/diffusion-lagr/blob/master/guided_diffusion/turb_datasets.py#L34) and [here](https://github.com/SmartTURB/diffusion-lagr/blob/master/guided_diffusion/turb_datasets.py#L76), and uncomment the lines immediately following them.
2. **PyTorch Installation**: In our experience, it is sometimes necessary to reinstall PyTorch to match your system environment. You can download and install it from the [official website](https://pytorch.org/).

## Preparing Data

The data needed for this project can be obtained from the Smart-TURB portal. Follow these steps to download the data:

1. Visit the [Smart-TURB portal](http://smart-turb.roma2.infn.it).
2. Navigate to `TURB-Lagr` under the `Datasets` section.
3. Click on `Files` -> `data` -> `Lagr_u3c_diffusion.h5`.

The file can also be downloaded directly via this [link](https://smart-turb.roma2.infn.it/init/files/api_file_download/1/___FOLDERSEPARATOR___scratch___FOLDERSEPARATOR___smartturb___FOLDERSEPARATOR___tov___FOLDERSEPARATOR___turb-lagr___FOLDERSEPARATOR___data___FOLDERSEPARATOR___Lagr_u3c_diffusion___POINT___h5/15728642096).

### Data Details and Example Usage

Here is an example of how you can read the data:

```python
import h5py
import numpy as np

with h5py.File('datasets/Lagr_u3c_diffusion.h5', 'r') as h5f:
    rx0 = np.array(h5f.get('min'))
    rx1 = np.array(h5f.get('max'))
    u3c = np.array(h5f.get('train'))

velocities = (u3c+1)*(rx1-rx0)/2 + rx0
```

The `u3c` variable is a 3D array of shape `(327680, 2000, 3)`: 327,680 trajectories, each of length 2000, with 3 velocity components. Each component is normalized to the range `[-1, 1]` using the min-max method. The `rx0` and `rx1` variables store the minimum and maximum values of each of the 3 components, respectively. The last line of the code sample recovers the original velocities from the normalized data.

The data file `Lagr_u3c_diffusion.h5` mentioned above is used for training the `DM-3c` model. For training `DM-1c`, we do not distinguish between the 3 velocity components, thereby tripling the number of trajectories. You can generate the appropriate data with the [`datasets/preprocessing-lagr_u1c-diffusion.py`](https://github.com/SmartTURB/diffusion-lagr/blob/master/datasets/preprocessing-lagr_u1c-diffusion.py) script, which concatenates the three velocity components, applies min-max normalization, and saves the result as `Lagr_u1c_diffusion.h5`.

## Training

![Training](resources/Training.png)

To train your model, you'll first need to choose certain hyperparameters. These fall into three groups: model architecture, diffusion process, and training flags. Detailed information about them can be found in the upstream [openai/improved-diffusion](https://github.com/openai/improved-diffusion) repository.
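As a quick orientation, here is a sketch of how the model flags used below are conventionally interpreted in guided-diffusion-style code (an assumption based on the upstream conventions that this repo's `script_util.py` follows; nothing below is imported from this repo):

```python
# Illustrative interpretation of the DM-1c MODEL_FLAGS (see below).
image_size = 2000                   # --image_size: trajectory length
num_channels = 128                  # --num_channels
channel_mult = (1, 1, 2, 3, 4)      # --channel_mult
attention_resolutions = (250, 125)  # --attention_resolutions

# Channel width at each of the five U-Net resolution levels:
widths = [num_channels * m for m in channel_mult]
print(widths)  # [128, 128, 256, 384, 512]

# Attention blocks act where the trajectory has been downsampled to the
# listed lengths, i.e. at these downsampling factors:
print([image_size // r for r in attention_resolutions])  # [8, 16]
```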
The run flags for the two models featured in our paper are as follows (please refer to Fig. 2 in [the paper](https://arxiv.org/abs/2307.08529)):

For the `DM-1c` model, use the following flags:

```sh
DATA_FLAGS="--dataset_path datasets/Lagr_u1c_diffusion.h5 --dataset_name train"
MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 1 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
```

For the `DM-3c` model, you only need to change `--dataset_path` to `datasets/Lagr_u3c_diffusion.h5` and `--in_channels` to `3`:

```sh
DATA_FLAGS="--dataset_path datasets/Lagr_u3c_diffusion.h5 --dataset_name train"
MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 3 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
```

After defining your hyperparameters, you can initiate an experiment using the following command:

```sh
mpiexec -n $NUM_GPUS python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

The training process is distributed; for our models we set `$NUM_GPUS` to 4. Note that the `--batch_size` flag is the per-GPU batch size, so the effective batch size is `$NUM_GPUS * batch_size = 256`, as reported in the paper (Fig. 2).

The log files and model checkpoints will be saved to a logging directory specified by the `OPENAI_LOGDIR` environment variable. If this variable is not set, a temporary directory in `/tmp` will be created and used instead.

### Demo

To assist with testing the software installation and understanding the hyperparameters mentioned above, we provide two smaller datasets: `datasets/Lagr_u1c_diffusion-demo.h5` and `datasets/Lagr_u3c_diffusion-demo.h5`. The `train` dataset within these files has shape `(768, 2000, 1)` and `(256, 2000, 3)`, respectively.

To run the demo, use the same flags as for the `DM-1c` and `DM-3c` models above, changing the `--dataset_path` flag to the appropriate demo dataset; full commands follow below.
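Before launching, you may also want to point the logging directory somewhere persistent (the path below is just an example; any writable directory works):

```sh
export OPENAI_LOGDIR=./logs/demo-dm1c
```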
For the `DM-1c` model:

```sh
# Set the flags
DATA_FLAGS="--dataset_path datasets/Lagr_u1c_diffusion-demo.h5 --dataset_name train"
MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 1 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
TRAIN_FLAGS="--lr 1e-4 --batch_size 64"

# Training command
python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

For the `DM-3c` model:

```sh
# Set the flags
DATA_FLAGS="--dataset_path datasets/Lagr_u3c_diffusion-demo.h5 --dataset_name train"
MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 3 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
TRAIN_FLAGS="--lr 1e-4 --batch_size 64"

# Training command
python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```

Remember, for this demo you can simplify the run by using the serial version of h5py, as described in [Troubleshooting Installation](#troubleshooting-installation).

## Sampling

![Sampling](resources/Sampling.png)

The training script from the previous section stores checkpoints as `.pt` files within the designated logging directory, with naming patterns such as `ema_0.9999_200000.pt` or `model200000.pt`. For better sample quality, it is advisable to sample from the Exponential Moving Average (EMA) models.

Before sampling, set `SAMPLE_FLAGS` to specify the number of samples (`--num_samples`), the batch size (`--batch_size`), and the path to the model checkpoint (`--model_path`). For example:

```sh
SAMPLE_FLAGS="--num_samples 179200 --batch_size 64 --model_path ema_0.9999_250000.pt"
```

Then, run the following command:

```sh
python scripts/turb_sample.py $SAMPLE_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS
```

This command generates a file named `samples_179200x2000x3.npz` (for `DM-3c`, as an example). You can use the following code to read the file and recover the generated velocities:

```python
import h5py
import numpy as np

with h5py.File('datasets/Lagr_u3c_diffusion.h5', 'r') as h5f:
    rx0 = np.array(h5f.get('min'))
    rx1 = np.array(h5f.get('max'))

u3c = (np.load('samples_179200x2000x3.npz')['arr_0']+1)*(rx1-rx0)/2 + rx0
```

Just like for training, you can use multiple GPUs for sampling. Please note that `$MODEL_FLAGS` and `$DIFFUSION_FLAGS` must be the same as those used in training.

Training the `DM-1c` and `DM-3c` models took one and two days, respectively, on four NVIDIA A100 GPUs. Since this computational demand can be a bottleneck, we provide the checkpoints used in the paper via the following links: [`DM-1c`](https://www.dropbox.com/scl/fi/ox7ytzyh2qcqswoqingiv/ema_0.9999_250000.pt?rlkey=1ld2ccsttj6s6f5tftvcntu0e&dl=0) and [`DM-3c`](https://www.dropbox.com/scl/fi/o7aun6o7lfk99eikds4c2/ema_0.9999_400000.pt?rlkey=mkxaxs0kw4ighb330a3yca0mo&dl=0).
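For example, to draw samples from the released `DM-1c` checkpoint, the training flags can be reused as-is (a minimal sketch; it assumes the downloaded file keeps the name `ema_0.9999_250000.pt` and sits in the working directory):

```sh
MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 1 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
SAMPLE_FLAGS="--num_samples 179200 --batch_size 64 --model_path ema_0.9999_250000.pt"
python scripts/turb_sample.py $SAMPLE_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS
```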
173 | -------------------------------------------------------------------------------- /datasets/Lagr_u1c_diffusion-demo.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/datasets/Lagr_u1c_diffusion-demo.h5 -------------------------------------------------------------------------------- /datasets/Lagr_u3c_diffusion-demo.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/datasets/Lagr_u3c_diffusion-demo.h5 -------------------------------------------------------------------------------- /datasets/preprocessing-lagr_u1c-diffusion.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | with h5py.File('Lagr_u3c_diffusion.h5', 'r') as h5f: 5 | rx0 = np.array(h5f.get('min')) 6 | rx1 = np.array(h5f.get('max')) 7 | u3c = np.array(h5f.get('train')) 8 | 9 | x_train = (u3c+1)*(rx1-rx0)/2 + rx0 10 | x_train = np.concatenate((x_train[..., :1], x_train[..., 1:2], x_train[..., -1:])) 11 | 12 | rx0, rx1 = np.amin(x_train), np.amax(x_train) 13 | x_train = 2*(x_train-rx0)/(rx1-rx0) - 1 14 | 15 | with h5py.File('Lagr_u1c_diffusion.h5', 'w') as hf: 16 | hf.create_dataset('min', data=rx0) 17 | hf.create_dataset('max', data=rx1) 18 | hf.create_dataset('train', data=x_train) 19 | -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 
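Rank 0 reads the file once and broadcasts the bytes to the other ranks in 2**30-byte chunks, since a single MPI broadcast has a message-size limit.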
57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /guided_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | 22 | The beta schedule library consists of beta schedules which remain similar 23 | in the limit of num_diffusion_timesteps. 24 | Beta schedules may be added, but should not be removed or changed once 25 | they are committed to maintain backwards compatibility. 26 | """ 27 | if schedule_name == "linear": 28 | # Linear schedule from Ho et al, extended to work for any number of 29 | # diffusion steps. 
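# With T timesteps this ramps beta linearly from 0.1/T to 20/T,
# reproducing Ho et al's 1e-4 .. 0.02 range at T = 1000.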
30 | scale = 1000 / num_diffusion_timesteps 31 | beta_start = scale * 0.0001 32 | beta_end = scale * 0.02 33 | return np.linspace( 34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 35 | ) 36 | elif schedule_name == "cosine": 37 | return betas_for_alpha_bar( 38 | num_diffusion_timesteps, 39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 40 | ) 41 | elif schedule_name.startswith("power"): 42 | power = int(schedule_name[5:]) 43 | return betas_for_alpha_bar( 44 | num_diffusion_timesteps, 45 | lambda t: 1 - t**power, 46 | ) 47 | elif schedule_name.startswith("exp"): 48 | t0 = float(schedule_name[3:]) 49 | return betas_for_alpha_bar( 50 | num_diffusion_timesteps, 51 | lambda t: 2 - math.exp((t0 + math.log(2)) * t - t0), 52 | ) 53 | elif schedule_name.startswith("tanh"): 54 | t0, t1 = schedule_name.split(",") 55 | t0, t1 = float(t0[4:]), float(t1) 56 | return betas_for_alpha_bar( 57 | num_diffusion_timesteps, 58 | lambda t: -math.tanh((t0 + t1) * t - t0) + math.tanh(t1), 59 | ) 60 | else: 61 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 62 | 63 | 64 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 65 | """ 66 | Create a beta schedule that discretizes the given alpha_t_bar function, 67 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 68 | 69 | :param num_diffusion_timesteps: the number of betas to produce. 70 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 71 | produces the cumulative product of (1-beta) up to that 72 | part of the diffusion process. 73 | :param max_beta: the maximum beta to use; use values lower than 1 to 74 | prevent singularities. 75 | """ 76 | betas = [] 77 | for i in range(num_diffusion_timesteps): 78 | t1 = i / num_diffusion_timesteps 79 | t2 = (i + 1) / num_diffusion_timesteps 80 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 81 | return np.array(betas) 82 | 83 | 84 | class ModelMeanType(enum.Enum): 85 | """ 86 | Which type of output the model predicts. 87 | """ 88 | 89 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 90 | START_X = enum.auto() # the model predicts x_0 91 | EPSILON = enum.auto() # the model predicts epsilon 92 | 93 | 94 | class ModelVarType(enum.Enum): 95 | """ 96 | What is used as the model's output variance. 97 | 98 | The LEARNED_RANGE option has been added to allow the model to predict 99 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 100 | """ 101 | 102 | LEARNED = enum.auto() 103 | FIXED_SMALL = enum.auto() 104 | FIXED_LARGE = enum.auto() 105 | LEARNED_RANGE = enum.auto() 106 | 107 | 108 | class LossType(enum.Enum): 109 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 110 | RESCALED_MSE = ( 111 | enum.auto() 112 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 113 | KL = enum.auto() # use the variational lower-bound 114 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 115 | 116 | def is_vb(self): 117 | return self == LossType.KL or self == LossType.RESCALED_KL 118 | 119 | 120 | class GaussianDiffusion: 121 | """ 122 | Utilities for training and sampling diffusion models. 123 | 124 | Ported directly from here, and then adapted over time to further experimentation. 
125 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 126 | 127 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 128 | starting at T and going to 1. 129 | :param model_mean_type: a ModelMeanType determining what the model outputs. 130 | :param model_var_type: a ModelVarType determining how variance is output. 131 | :param loss_type: a LossType determining the loss function to use. 132 | :param rescale_timesteps: if True, pass floating point timesteps into the 133 | model so that they are always scaled like in the 134 | original paper (0 to 1000). 135 | """ 136 | 137 | def __init__( 138 | self, 139 | *, 140 | betas, 141 | model_mean_type, 142 | model_var_type, 143 | loss_type, 144 | rescale_timesteps=False, 145 | ): 146 | self.model_mean_type = model_mean_type 147 | self.model_var_type = model_var_type 148 | self.loss_type = loss_type 149 | self.rescale_timesteps = rescale_timesteps 150 | 151 | # Use float64 for accuracy. 152 | betas = np.array(betas, dtype=np.float64) 153 | self.betas = betas 154 | assert len(betas.shape) == 1, "betas must be 1-D" 155 | assert (betas > 0).all() and (betas <= 1).all() 156 | 157 | self.num_timesteps = int(betas.shape[0]) 158 | 159 | alphas = 1.0 - betas 160 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 161 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 162 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 163 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 164 | 165 | # calculations for diffusion q(x_t | x_{t-1}) and others 166 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 167 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 168 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 169 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 170 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 171 | 172 | # calculations for posterior q(x_{t-1} | x_t, x_0) 173 | self.posterior_variance = ( 174 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 175 | ) 176 | # log calculation clipped because the posterior variance is 0 at the 177 | # beginning of the diffusion chain. 178 | self.posterior_log_variance_clipped = np.log( 179 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 180 | ) 181 | self.posterior_mean_coef1 = ( 182 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 183 | ) 184 | self.posterior_mean_coef2 = ( 185 | (1.0 - self.alphas_cumprod_prev) 186 | * np.sqrt(alphas) 187 | / (1.0 - self.alphas_cumprod) 188 | ) 189 | 190 | def q_mean_variance(self, x_start, t): 191 | """ 192 | Get the distribution q(x_t | x_0). 193 | 194 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 195 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 196 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 197 | """ 198 | mean = ( 199 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 200 | ) 201 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 202 | log_variance = _extract_into_tensor( 203 | self.log_one_minus_alphas_cumprod, t, x_start.shape 204 | ) 205 | return mean, variance, log_variance 206 | 207 | def q_sample(self, x_start, t, noise=None): 208 | """ 209 | Diffuse the data for a given number of diffusion steps. 
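Concretely: x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise, with noise ~ N(0, I).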
210 | 211 | In other words, sample from q(x_t | x_0). 212 | 213 | :param x_start: the initial data batch. 214 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 215 | :param noise: if specified, the split-out normal noise. 216 | :return: A noisy version of x_start. 217 | """ 218 | if noise is None: 219 | noise = th.randn_like(x_start) 220 | assert noise.shape == x_start.shape 221 | return ( 222 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 223 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 224 | * noise 225 | ) 226 | 227 | def q_posterior_mean_variance(self, x_start, x_t, t): 228 | """ 229 | Compute the mean and variance of the diffusion posterior: 230 | 231 | q(x_{t-1} | x_t, x_0) 232 | 233 | """ 234 | assert x_start.shape == x_t.shape 235 | posterior_mean = ( 236 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 237 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 238 | ) 239 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 240 | posterior_log_variance_clipped = _extract_into_tensor( 241 | self.posterior_log_variance_clipped, t, x_t.shape 242 | ) 243 | assert ( 244 | posterior_mean.shape[0] 245 | == posterior_variance.shape[0] 246 | == posterior_log_variance_clipped.shape[0] 247 | == x_start.shape[0] 248 | ) 249 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 250 | 251 | def p_mean_variance( 252 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 253 | ): 254 | """ 255 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 256 | the initial x, x_0. 257 | 258 | :param model: the model, which takes a signal and a batch of timesteps 259 | as input. 260 | :param x: the [N x C x ...] tensor at time t. 261 | :param t: a 1-D Tensor of timesteps. 262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 263 | :param denoised_fn: if not None, a function which applies to the 264 | x_start prediction before it is used to sample. Applies before 265 | clip_denoised. 266 | :param model_kwargs: if not None, a dict of extra keyword arguments to 267 | pass to the model. This can be used for conditioning. 268 | :return: a dict with the following keys: 269 | - 'mean': the model mean output. 270 | - 'variance': the model variance output. 271 | - 'log_variance': the log of 'variance'. 272 | - 'pred_xstart': the prediction for x_0. 273 | """ 274 | if model_kwargs is None: 275 | model_kwargs = {} 276 | 277 | B, C = x.shape[:2] 278 | assert t.shape == (B,) 279 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 280 | 281 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 282 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 283 | model_output, model_var_values = th.split(model_output, C, dim=1) 284 | if self.model_var_type == ModelVarType.LEARNED: 285 | model_log_variance = model_var_values 286 | model_variance = th.exp(model_log_variance) 287 | else: 288 | min_log = _extract_into_tensor( 289 | self.posterior_log_variance_clipped, t, x.shape 290 | ) 291 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 292 | # The model_var_values is [-1, 1] for [min_var, max_var]. 
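# Map model_var_values to a fraction in [0, 1] and interpolate the
# log-variance linearly between the clipped posterior variance (min)
# and beta_t (max).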
293 | frac = (model_var_values + 1) / 2 294 | model_log_variance = frac * max_log + (1 - frac) * min_log 295 | model_variance = th.exp(model_log_variance) 296 | else: 297 | model_variance, model_log_variance = { 298 | # for fixedlarge, we set the initial (log-)variance like so 299 | # to get a better decoder log likelihood. 300 | ModelVarType.FIXED_LARGE: ( 301 | np.append(self.posterior_variance[1], self.betas[1:]), 302 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 303 | ), 304 | ModelVarType.FIXED_SMALL: ( 305 | self.posterior_variance, 306 | self.posterior_log_variance_clipped, 307 | ), 308 | }[self.model_var_type] 309 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 310 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 311 | 312 | def process_xstart(x): 313 | if denoised_fn is not None: 314 | x = denoised_fn(x) 315 | if clip_denoised: 316 | return x.clamp(-1, 1) 317 | return x 318 | 319 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 320 | pred_xstart = process_xstart( 321 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 322 | ) 323 | model_mean = model_output 324 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 325 | if self.model_mean_type == ModelMeanType.START_X: 326 | pred_xstart = process_xstart(model_output) 327 | else: 328 | pred_xstart = process_xstart( 329 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 330 | ) 331 | model_mean, _, _ = self.q_posterior_mean_variance( 332 | x_start=pred_xstart, x_t=x, t=t 333 | ) 334 | else: 335 | raise NotImplementedError(self.model_mean_type) 336 | 337 | assert ( 338 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 339 | ) 340 | return { 341 | "mean": model_mean, 342 | "variance": model_variance, 343 | "log_variance": model_log_variance, 344 | "pred_xstart": pred_xstart, 345 | } 346 | 347 | def _predict_xstart_from_eps(self, x_t, t, eps): 348 | assert x_t.shape == eps.shape 349 | return ( 350 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 351 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 352 | ) 353 | 354 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 355 | assert x_t.shape == xprev.shape 356 | return ( # (xprev - coef2*x_t) / coef1 357 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 358 | - _extract_into_tensor( 359 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 360 | ) 361 | * x_t 362 | ) 363 | 364 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 365 | return ( 366 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 367 | - pred_xstart 368 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 369 | 370 | def _scale_timesteps(self, t): 371 | if self.rescale_timesteps: 372 | return t.float() * (1000.0 / self.num_timesteps) 373 | return t 374 | 375 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 376 | """ 377 | Compute the mean for the previous step, given a function cond_fn that 378 | computes the gradient of a conditional log probability with respect to 379 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 380 | condition on y. 381 | 382 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
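Concretely, the returned mean is p_mean_var["mean"] + p_mean_var["variance"] * cond_fn(x, t).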
383 | """ 384 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) 385 | new_mean = ( 386 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 387 | ) 388 | return new_mean 389 | 390 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 391 | """ 392 | Compute what the p_mean_variance output would have been, should the 393 | model's score function be conditioned by cond_fn. 394 | 395 | See condition_mean() for details on cond_fn. 396 | 397 | Unlike condition_mean(), this instead uses the conditioning strategy 398 | from Song et al (2020). 399 | """ 400 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 401 | 402 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 403 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn( 404 | x, self._scale_timesteps(t), **model_kwargs 405 | ) 406 | 407 | out = p_mean_var.copy() 408 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 409 | out["mean"], _, _ = self.q_posterior_mean_variance( 410 | x_start=out["pred_xstart"], x_t=x, t=t 411 | ) 412 | return out 413 | 414 | def p_sample( 415 | self, 416 | model, 417 | x, 418 | t, 419 | clip_denoised=True, 420 | denoised_fn=None, 421 | cond_fn=None, 422 | model_kwargs=None, 423 | ): 424 | """ 425 | Sample x_{t-1} from the model at the given timestep. 426 | 427 | :param model: the model to sample from. 428 | :param x: the current tensor at x_{t-1}. 429 | :param t: the value of t, starting at 0 for the first diffusion step. 430 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 431 | :param denoised_fn: if not None, a function which applies to the 432 | x_start prediction before it is used to sample. 433 | :param cond_fn: if not None, this is a gradient function that acts 434 | similarly to the model. 435 | :param model_kwargs: if not None, a dict of extra keyword arguments to 436 | pass to the model. This can be used for conditioning. 437 | :return: a dict containing the following keys: 438 | - 'sample': a random sample from the model. 439 | - 'pred_xstart': a prediction of x_0. 440 | """ 441 | out = self.p_mean_variance( 442 | model, 443 | x, 444 | t, 445 | clip_denoised=clip_denoised, 446 | denoised_fn=denoised_fn, 447 | model_kwargs=model_kwargs, 448 | ) 449 | noise = th.randn_like(x) 450 | nonzero_mask = ( 451 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 452 | ) # no noise when t == 0 453 | if cond_fn is not None: 454 | out["mean"] = self.condition_mean( 455 | cond_fn, out, x, t, model_kwargs=model_kwargs 456 | ) 457 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 458 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 459 | 460 | def p_sample_loop( 461 | self, 462 | model, 463 | shape, 464 | noise=None, 465 | clip_denoised=True, 466 | denoised_fn=None, 467 | cond_fn=None, 468 | model_kwargs=None, 469 | device=None, 470 | progress=False, 471 | ): 472 | """ 473 | Generate samples from the model. 474 | 475 | :param model: the model module. 476 | :param shape: the shape of the samples, (N, C, H, W). 477 | :param noise: if specified, the noise from the encoder to sample. 478 | Should be of the same shape as `shape`. 479 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 480 | :param denoised_fn: if not None, a function which applies to the 481 | x_start prediction before it is used to sample. 482 | :param cond_fn: if not None, this is a gradient function that acts 483 | similarly to the model. 
484 | :param model_kwargs: if not None, a dict of extra keyword arguments to 485 | pass to the model. This can be used for conditioning. 486 | :param device: if specified, the device to create the samples on. 487 | If not specified, use a model parameter's device. 488 | :param progress: if True, show a tqdm progress bar. 489 | :return: a non-differentiable batch of samples. 490 | """ 491 | final = None 492 | for sample in self.p_sample_loop_progressive( 493 | model, 494 | shape, 495 | noise=noise, 496 | clip_denoised=clip_denoised, 497 | denoised_fn=denoised_fn, 498 | cond_fn=cond_fn, 499 | model_kwargs=model_kwargs, 500 | device=device, 501 | progress=progress, 502 | ): 503 | final = sample 504 | return final["sample"] 505 | 506 | def p_sample_loop_history( 507 | self, 508 | model, 509 | shape, 510 | noise=None, 511 | clip_denoised=True, 512 | denoised_fn=None, 513 | cond_fn=None, 514 | model_kwargs=None, 515 | device=None, 516 | progress=False, 517 | ): 518 | """ 519 | Generate samples from the model with the denoising history. 520 | 521 | :param model: the model module. 522 | :param shape: the shape of the samples, (N, C, H, W). 523 | :param noise: if specified, the noise from the encoder to sample. 524 | Should be of the same shape as `shape`. 525 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 526 | :param denoised_fn: if not None, a function which applies to the 527 | x_start prediction before it is used to sample. 528 | :param cond_fn: if not None, this is a gradient function that acts 529 | similarly to the model. 530 | :param model_kwargs: if not None, a dict of extra keyword arguments to 531 | pass to the model. This can be used for conditioning. 532 | :param device: if specified, the device to create the samples on. 533 | If not specified, use a model parameter's device. 534 | :param progress: if True, show a tqdm progress bar. 535 | :return: a non-differentiable batch of samples with the denoising history. 536 | """ 537 | if noise is None: 538 | assert isinstance(shape, (tuple, list)) 539 | if device is None: 540 | device = next(model.parameters()).device 541 | noise = th.randn(*shape, device=device) 542 | 543 | sample_history = [noise] 544 | for sample in self.p_sample_loop_progressive( 545 | model, 546 | shape, 547 | noise=noise, 548 | clip_denoised=clip_denoised, 549 | denoised_fn=denoised_fn, 550 | cond_fn=cond_fn, 551 | model_kwargs=model_kwargs, 552 | device=device, 553 | progress=progress, 554 | ): 555 | sample_history.append(sample["sample"]) 556 | return th.stack(sample_history, dim=1) 557 | 558 | def p_sample_loop_progressive( 559 | self, 560 | model, 561 | shape, 562 | noise=None, 563 | clip_denoised=True, 564 | denoised_fn=None, 565 | cond_fn=None, 566 | model_kwargs=None, 567 | device=None, 568 | progress=False, 569 | ): 570 | """ 571 | Generate samples from the model and yield intermediate samples from 572 | each timestep of diffusion. 573 | 574 | Arguments are the same as p_sample_loop(). 575 | Returns a generator over dicts, where each dict is the return value of 576 | p_sample(). 577 | """ 578 | if device is None: 579 | device = next(model.parameters()).device 580 | assert isinstance(shape, (tuple, list)) 581 | if noise is not None: 582 | img = noise 583 | else: 584 | img = th.randn(*shape, device=device) 585 | indices = list(range(self.num_timesteps))[::-1] 586 | 587 | if progress: 588 | # Lazy import so that we don't depend on tqdm. 
589 | from tqdm.auto import tqdm 590 | 591 | indices = tqdm(indices) 592 | 593 | for i in indices: 594 | t = th.tensor([i] * shape[0], device=device) 595 | with th.no_grad(): 596 | out = self.p_sample( 597 | model, 598 | img, 599 | t, 600 | clip_denoised=clip_denoised, 601 | denoised_fn=denoised_fn, 602 | cond_fn=cond_fn, 603 | model_kwargs=model_kwargs, 604 | ) 605 | yield out 606 | img = out["sample"] 607 | 608 | def ddim_sample( 609 | self, 610 | model, 611 | x, 612 | t, 613 | clip_denoised=True, 614 | denoised_fn=None, 615 | cond_fn=None, 616 | model_kwargs=None, 617 | eta=0.0, 618 | ): 619 | """ 620 | Sample x_{t-1} from the model using DDIM. 621 | 622 | Same usage as p_sample(). 623 | """ 624 | out = self.p_mean_variance( 625 | model, 626 | x, 627 | t, 628 | clip_denoised=clip_denoised, 629 | denoised_fn=denoised_fn, 630 | model_kwargs=model_kwargs, 631 | ) 632 | if cond_fn is not None: 633 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 634 | 635 | # Usually our model outputs epsilon, but we re-derive it 636 | # in case we used x_start or x_prev prediction. 637 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 638 | 639 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 640 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 641 | sigma = ( 642 | eta 643 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 644 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 645 | ) 646 | # Equation 12. 647 | noise = th.randn_like(x) 648 | mean_pred = ( 649 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 650 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 651 | ) 652 | nonzero_mask = ( 653 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 654 | ) # no noise when t == 0 655 | sample = mean_pred + nonzero_mask * sigma * noise 656 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 657 | 658 | def ddim_reverse_sample( 659 | self, 660 | model, 661 | x, 662 | t, 663 | clip_denoised=True, 664 | denoised_fn=None, 665 | model_kwargs=None, 666 | eta=0.0, 667 | ): 668 | """ 669 | Sample x_{t+1} from the model using DDIM reverse ODE. 670 | """ 671 | assert eta == 0.0, "Reverse ODE only for deterministic path" 672 | out = self.p_mean_variance( 673 | model, 674 | x, 675 | t, 676 | clip_denoised=clip_denoised, 677 | denoised_fn=denoised_fn, 678 | model_kwargs=model_kwargs, 679 | ) 680 | # Usually our model outputs epsilon, but we re-derive it 681 | # in case we used x_start or x_prev prediction. 682 | eps = ( 683 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 684 | - out["pred_xstart"] 685 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 686 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 687 | 688 | # Equation 12. reversed 689 | mean_pred = ( 690 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 691 | + th.sqrt(1 - alpha_bar_next) * eps 692 | ) 693 | 694 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 695 | 696 | def ddim_sample_loop( 697 | self, 698 | model, 699 | shape, 700 | noise=None, 701 | clip_denoised=True, 702 | denoised_fn=None, 703 | cond_fn=None, 704 | model_kwargs=None, 705 | device=None, 706 | progress=False, 707 | eta=0.0, 708 | ): 709 | """ 710 | Generate samples from the model using DDIM. 711 | 712 | Same usage as p_sample_loop(). 
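With the default eta=0.0 the per-step noise scale sigma vanishes, so sampling is deterministic given the initial noise; larger eta adds stochasticity.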
713 | """ 714 | final = None 715 | for sample in self.ddim_sample_loop_progressive( 716 | model, 717 | shape, 718 | noise=noise, 719 | clip_denoised=clip_denoised, 720 | denoised_fn=denoised_fn, 721 | cond_fn=cond_fn, 722 | model_kwargs=model_kwargs, 723 | device=device, 724 | progress=progress, 725 | eta=eta, 726 | ): 727 | final = sample 728 | return final["sample"] 729 | 730 | def ddim_sample_loop_progressive( 731 | self, 732 | model, 733 | shape, 734 | noise=None, 735 | clip_denoised=True, 736 | denoised_fn=None, 737 | cond_fn=None, 738 | model_kwargs=None, 739 | device=None, 740 | progress=False, 741 | eta=0.0, 742 | ): 743 | """ 744 | Use DDIM to sample from the model and yield intermediate samples from 745 | each timestep of DDIM. 746 | 747 | Same usage as p_sample_loop_progressive(). 748 | """ 749 | if device is None: 750 | device = next(model.parameters()).device 751 | assert isinstance(shape, (tuple, list)) 752 | if noise is not None: 753 | img = noise 754 | else: 755 | img = th.randn(*shape, device=device) 756 | indices = list(range(self.num_timesteps))[::-1] 757 | 758 | if progress: 759 | # Lazy import so that we don't depend on tqdm. 760 | from tqdm.auto import tqdm 761 | 762 | indices = tqdm(indices) 763 | 764 | for i in indices: 765 | t = th.tensor([i] * shape[0], device=device) 766 | with th.no_grad(): 767 | out = self.ddim_sample( 768 | model, 769 | img, 770 | t, 771 | clip_denoised=clip_denoised, 772 | denoised_fn=denoised_fn, 773 | cond_fn=cond_fn, 774 | model_kwargs=model_kwargs, 775 | eta=eta, 776 | ) 777 | yield out 778 | img = out["sample"] 779 | 780 | def _vb_terms_bpd( 781 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 782 | ): 783 | """ 784 | Get a term for the variational lower-bound. 785 | 786 | The resulting units are bits (rather than nats, as one might expect). 787 | This allows for comparison to other papers. 788 | 789 | :return: a dict with the following keys: 790 | - 'output': a shape [N] tensor of NLLs or KLs. 791 | - 'pred_xstart': the x_0 predictions. 792 | """ 793 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 794 | x_start=x_start, x_t=x_t, t=t 795 | ) 796 | out = self.p_mean_variance( 797 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 798 | ) 799 | kl = normal_kl( 800 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 801 | ) 802 | kl = mean_flat(kl) / np.log(2.0) 803 | 804 | decoder_nll = -discretized_gaussian_log_likelihood( 805 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 806 | ) 807 | assert decoder_nll.shape == x_start.shape 808 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 809 | 810 | # At the first timestep return the decoder NLL, 811 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 812 | output = th.where((t == 0), decoder_nll, kl) 813 | return {"output": output, "pred_xstart": out["pred_xstart"]} 814 | 815 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 816 | """ 817 | Compute training losses for a single timestep. 818 | 819 | :param model: the model to evaluate loss on. 820 | :param x_start: the [N x C x ...] tensor of inputs. 821 | :param t: a batch of timestep indices. 822 | :param model_kwargs: if not None, a dict of extra keyword arguments to 823 | pass to the model. This can be used for conditioning. 824 | :param noise: if specified, the specific Gaussian noise to try to remove. 825 | :return: a dict with the key "loss" containing a tensor of shape [N]. 
826 | Some mean or variance settings may also have other keys. 827 | """ 828 | if model_kwargs is None: 829 | model_kwargs = {} 830 | if noise is None: 831 | noise = th.randn_like(x_start) 832 | x_t = self.q_sample(x_start, t, noise=noise) 833 | 834 | terms = {} 835 | 836 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 837 | terms["loss"] = self._vb_terms_bpd( 838 | model=model, 839 | x_start=x_start, 840 | x_t=x_t, 841 | t=t, 842 | clip_denoised=False, 843 | model_kwargs=model_kwargs, 844 | )["output"] 845 | if self.loss_type == LossType.RESCALED_KL: 846 | terms["loss"] *= self.num_timesteps 847 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 848 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 849 | 850 | if self.model_var_type in [ 851 | ModelVarType.LEARNED, 852 | ModelVarType.LEARNED_RANGE, 853 | ]: 854 | B, C = x_t.shape[:2] 855 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 856 | model_output, model_var_values = th.split(model_output, C, dim=1) 857 | # Learn the variance using the variational bound, but don't let 858 | # it affect our mean prediction. 859 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 860 | terms["vb"] = self._vb_terms_bpd( 861 | model=lambda *args, r=frozen_out: r, 862 | x_start=x_start, 863 | x_t=x_t, 864 | t=t, 865 | clip_denoised=False, 866 | )["output"] 867 | if self.loss_type == LossType.RESCALED_MSE: 868 | # Divide by 1000 for equivalence with initial implementation. 869 | # Without a factor of 1/1000, the VB term hurts the MSE term. 870 | terms["vb"] *= self.num_timesteps / 1000.0 871 | 872 | target = { 873 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 874 | x_start=x_start, x_t=x_t, t=t 875 | )[0], 876 | ModelMeanType.START_X: x_start, 877 | ModelMeanType.EPSILON: noise, 878 | }[self.model_mean_type] 879 | assert model_output.shape == target.shape == x_start.shape 880 | terms["mse"] = mean_flat((target - model_output) ** 2) 881 | if "vb" in terms: 882 | terms["loss"] = terms["mse"] + terms["vb"] 883 | else: 884 | terms["loss"] = terms["mse"] 885 | else: 886 | raise NotImplementedError(self.loss_type) 887 | 888 | return terms 889 | 890 | def calc_losses_loop(self, model, x_start, model_kwargs=None): 891 | """ 892 | Compute training losses for all timesteps. 893 | 894 | :param model: the model to evaluate loss on. 895 | :param x_start: the [N x C x ...] tensor of inputs. 896 | :param model_kwargs: if not None, a dict of extra keyword arguments to 897 | pass to the model. This can be used for conditioning. 898 | :return: a dict with the key "loss" containing a tensor of shape [N x T]. 899 | Some mean or variance settings may also have other keys. 
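Note: this calls training_losses() once per timestep under th.no_grad()
(num_timesteps forward passes per batch), so it is intended for
evaluation rather than training.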
900 | """ 901 | device = x_start.device 902 | batch_size = x_start.shape[0] 903 | 904 | outs = [] 905 | for t in list(range(self.num_timesteps))[::-1]: 906 | t_batch = th.tensor([t] * batch_size, device=device) 907 | noise = th.randn_like(x_start) 908 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 909 | # Compute training losses at the current timestep 910 | with th.no_grad(): 911 | out = self.training_losses( 912 | model, 913 | x_start=x_start, 914 | t=t_batch, 915 | model_kwargs=model_kwargs, 916 | noise=noise 917 | ) 918 | outs.append(out) 919 | 920 | output = {key: [] for key in out.keys()} 921 | for out in outs: 922 | for key in out: 923 | output[key].append(out[key]) 924 | for key in output: 925 | output[key] = th.stack(output[key], dim=1) 926 | 927 | return output 928 | 929 | def _prior_bpd(self, x_start): 930 | """ 931 | Get the prior KL term for the variational lower-bound, measured in 932 | bits-per-dim. 933 | 934 | This term can't be optimized, as it only depends on the encoder. 935 | 936 | :param x_start: the [N x C x ...] tensor of inputs. 937 | :return: a batch of [N] KL values (in bits), one per batch element. 938 | """ 939 | batch_size = x_start.shape[0] 940 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 941 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 942 | kl_prior = normal_kl( 943 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 944 | ) 945 | return mean_flat(kl_prior) / np.log(2.0) 946 | 947 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 948 | """ 949 | Compute the entire variational lower-bound, measured in bits-per-dim, 950 | as well as other related quantities. 951 | 952 | :param model: the model to evaluate loss on. 953 | :param x_start: the [N x C x ...] tensor of inputs. 954 | :param clip_denoised: if True, clip denoised samples. 955 | :param model_kwargs: if not None, a dict of extra keyword arguments to 956 | pass to the model. This can be used for conditioning. 957 | 958 | :return: a dict containing the following keys: 959 | - total_bpd: the total variational lower-bound, per batch element. 960 | - prior_bpd: the prior term in the lower-bound. 961 | - vb: an [N x T] tensor of terms in the lower-bound. 962 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 963 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
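        A minimal sketch of how the returned terms relate (assumes a trained
        `model` and inputs `x_start`; hypothetical names):

            out = diffusion.calc_bpd_loop(model, x_start)
            # the total is the per-timestep terms plus the prior term:
            assert th.allclose(out["total_bpd"], out["vb"].sum(dim=1) + out["prior_bpd"])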
964 | """ 965 | device = x_start.device 966 | batch_size = x_start.shape[0] 967 | 968 | vb = [] 969 | xstart_mse = [] 970 | mse = [] 971 | for t in list(range(self.num_timesteps))[::-1]: 972 | t_batch = th.tensor([t] * batch_size, device=device) 973 | noise = th.randn_like(x_start) 974 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 975 | # Calculate VLB term at the current timestep 976 | with th.no_grad(): 977 | out = self._vb_terms_bpd( 978 | model, 979 | x_start=x_start, 980 | x_t=x_t, 981 | t=t_batch, 982 | clip_denoised=clip_denoised, 983 | model_kwargs=model_kwargs, 984 | ) 985 | vb.append(out["output"]) 986 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 987 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 988 | mse.append(mean_flat((eps - noise) ** 2)) 989 | 990 | vb = th.stack(vb, dim=1) 991 | xstart_mse = th.stack(xstart_mse, dim=1) 992 | mse = th.stack(mse, dim=1) 993 | 994 | prior_bpd = self._prior_bpd(x_start) 995 | total_bpd = vb.sum(dim=1) + prior_bpd 996 | return { 997 | "total_bpd": total_bpd, 998 | "prior_bpd": prior_bpd, 999 | "vb": vb, 1000 | "xstart_mse": xstart_mse, 1001 | "mse": mse, 1002 | } 1003 | 1004 | 1005 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1006 | """ 1007 | Extract values from a 1-D numpy array for a batch of indices. 1008 | 1009 | :param arr: the 1-D numpy array. 1010 | :param timesteps: a tensor of indices into the array to extract. 1011 | :param broadcast_shape: a larger shape of K dimensions with the batch 1012 | dimension equal to the length of timesteps. 1013 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 1014 | """ 1015 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1016 | while len(res.shape) < len(broadcast_shape): 1017 | res = res[..., None] 1018 | return res.expand(broadcast_shape) 1019 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max 
widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 
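    Note: this writer requires TensorFlow, and the resulting raw event files
    can be inspected by pointing TensorBoard at the log directory, e.g.
    `tensorboard --logdir <dir>`.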
153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 
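    For example, `set_level(DEBUG)` makes `debug(...)` messages visible,
    while `set_level(DISABLED)` silences all logging.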
273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here 
instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. 
diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 
38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
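    Illustrative call, as one might use it inside a module's forward pass
    (a sketch; `self._forward` is a hypothetical method here):

        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)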
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
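        With sampling probabilities p_t, the returned weights are
        1 / (num_timesteps * p_t), so the weighted loss remains an unbiased
        estimate of the uniform-timestep objective.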
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
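        See LossSecondMomentResampler below for a concrete implementation,
        which weights each timestep by the root-mean-square of its recent
        losses (mixed with a small uniform probability for stability).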
121 |         """
122 | 
123 | 
124 | class LossSecondMomentResampler(LossAwareSampler):
125 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 |         self.diffusion = diffusion
127 |         self.history_per_term = history_per_term
128 |         self.uniform_prob = uniform_prob
129 |         self._loss_history = np.zeros(
130 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 |         )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)  # np.int is deprecated (removed in NumPy >= 1.24)
133 | 
134 |     def weights(self):
135 |         if not self._warmed_up():
136 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 |         weights /= np.sum(weights)
139 |         weights *= 1 - self.uniform_prob
140 |         weights += self.uniform_prob / len(weights)
141 |         return weights
142 | 
143 |     def update_with_all_losses(self, ts, losses):
144 |         for t, loss in zip(ts, losses):
145 |             if self._loss_counts[t] == self.history_per_term:
146 |                 # Shift out the oldest loss term.
147 |                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 |                 self._loss_history[t, -1] = loss
149 |             else:
150 |                 self._loss_history[t, self._loss_counts[t]] = loss
151 |                 self._loss_counts[t] += 1
152 | 
153 |     def _warmed_up(self):
154 |         return (self._loss_counts == self.history_per_term).all()
155 | 
--------------------------------------------------------------------------------
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 | 
4 | from .gaussian_diffusion import GaussianDiffusion
5 | 
6 | 
7 | def space_timesteps(num_timesteps, section_counts):
8 |     """
9 |     Create a list of timesteps to use from an original diffusion process,
10 |     given the number of timesteps we want to take from equally-sized portions
11 |     of the original process.
12 | 
13 |     For example, if there are 300 timesteps and the section counts are [10,15,20],
14 |     then the first 100 timesteps are strided down to 10 timesteps, the second 100
15 |     are strided down to 15 timesteps, and the final 100 are strided down to 20.
16 | 
17 |     If the stride is a string starting with "ddim", then the fixed striding
18 |     from the DDIM paper is used, and only one section is allowed.
19 | 
20 |     :param num_timesteps: the number of diffusion steps in the original
21 |                           process to divide up.
22 |     :param section_counts: either a list of numbers, or a string containing
23 |                            comma-separated numbers, indicating the step count
24 |                            per section. As a special case, use "ddimN" where N
25 |                            is a number of steps to use the striding from the
26 |                            DDIM paper.
27 |     :return: a set of diffusion steps from the original process to use.
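    Example (illustrative):
        space_timesteps(300, [10, 15, 20])  # 45 steps from 3 sections of 100
        space_timesteps(1000, "ddim50")     # 50 evenly strided steps: {0, 20, ..., 980}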
28 |     """
29 |     if isinstance(section_counts, str):
30 |         if section_counts.startswith("ddim"):
31 |             desired_count = int(section_counts[len("ddim") :])
32 |             for i in range(1, num_timesteps):
33 |                 if len(range(0, num_timesteps, i)) == desired_count:
34 |                     return set(range(0, num_timesteps, i))
35 |             raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 |             )
38 |         section_counts = [int(x) for x in section_counts.split(",")]
39 |     size_per = num_timesteps // len(section_counts)
40 |     extra = num_timesteps % len(section_counts)
41 |     start_idx = 0
42 |     all_steps = []
43 |     for i, section_count in enumerate(section_counts):
44 |         size = size_per + (1 if i < extra else 0)
45 |         if size < section_count:
46 |             raise ValueError(
47 |                 f"cannot divide section of {size} steps into {section_count}"
48 |             )
49 |         if section_count <= 1:
50 |             frac_stride = 1
51 |         else:
52 |             frac_stride = (size - 1) / (section_count - 1)
53 |         cur_idx = 0.0
54 |         taken_steps = []
55 |         for _ in range(section_count):
56 |             taken_steps.append(start_idx + round(cur_idx))
57 |             cur_idx += frac_stride
58 |         all_steps += taken_steps
59 |         start_idx += size
60 |     return set(all_steps)
61 | 
62 | 
63 | class SpacedDiffusion(GaussianDiffusion):
64 |     """
65 |     A diffusion process which can skip steps in a base diffusion process.
66 | 
67 |     :param use_timesteps: a collection (sequence or set) of timesteps from the
68 |                           original diffusion process to retain.
69 |     :param kwargs: the kwargs to create the base diffusion process.
70 |     """
71 | 
72 |     def __init__(self, use_timesteps, **kwargs):
73 |         self.use_timesteps = set(use_timesteps)
74 |         self.timestep_map = []
75 |         self.original_num_steps = len(kwargs["betas"])
76 | 
77 |         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
78 |         last_alpha_cumprod = 1.0
79 |         new_betas = []
80 |         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 |             if i in self.use_timesteps:
82 |                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 |                 last_alpha_cumprod = alpha_cumprod
84 |                 self.timestep_map.append(i)
85 |         kwargs["betas"] = np.array(new_betas)
86 |         super().__init__(**kwargs)
87 | 
88 |     def p_mean_variance(
89 |         self, model, *args, **kwargs
90 |     ):  # pylint: disable=signature-differs
91 |         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 | 
93 |     def training_losses(
94 |         self, model, *args, **kwargs
95 |     ):  # pylint: disable=signature-differs
96 |         return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 | 
98 |     def condition_mean(self, cond_fn, *args, **kwargs):
99 |         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 | 
101 |     def condition_score(self, cond_fn, *args, **kwargs):
102 |         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 | 
104 |     def _wrap_model(self, model):
105 |         if isinstance(model, _WrappedModel):
106 |             return model
107 |         return _WrappedModel(
108 |             model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 |         )
110 | 
111 |     def _scale_timesteps(self, t):
112 |         # Scaling is done by the wrapped model.
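        # (_WrappedModel below maps each spaced index back to its original
        # timestep and, if rescale_timesteps is set, rescales it to [0, 1000).)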
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | ------------------------------------------------------------------------------- 3 | Copied from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py 4 | 5 | Functions have been modified to make the inputs "dims", "in_channels" and "out_channels" of UNetModel changeable. 6 | ------------------------------------------------------------------------------- 7 | """ 8 | 9 | import argparse 10 | import inspect 11 | 12 | from . import gaussian_diffusion as gd 13 | from .respace import SpacedDiffusion, space_timesteps 14 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 15 | 16 | NUM_CLASSES = 1000 17 | 18 | 19 | def diffusion_defaults(): 20 | """ 21 | Defaults for image and classifier training. 22 | """ 23 | return dict( 24 | learn_sigma=False, 25 | diffusion_steps=1000, 26 | noise_schedule="linear", 27 | timestep_respacing="", 28 | use_kl=False, 29 | predict_xstart=False, 30 | rescale_timesteps=False, 31 | rescale_learned_sigmas=False, 32 | ) 33 | 34 | 35 | def classifier_defaults(): 36 | """ 37 | Defaults for classifier models. 38 | """ 39 | return dict( 40 | image_size=64, 41 | classifier_use_fp16=False, 42 | classifier_width=128, 43 | classifier_depth=2, 44 | classifier_attention_resolutions="32,16,8", # 16 45 | classifier_use_scale_shift_norm=True, # False 46 | classifier_resblock_updown=True, # False 47 | classifier_pool="attention", 48 | ) 49 | 50 | 51 | def model_and_diffusion_defaults(): 52 | """ 53 | Defaults for image training. 
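    For the 1-D trajectory data used in this repository, these defaults are
    meant to be overridden on the command line; an illustrative (not
    prescriptive) choice would be `--dims 1 --in_channels 3 --image_size 2000`
    to match trajectories of length 2000 with 3 velocity components.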
54 | """ 55 | res = dict( 56 | dims=2, 57 | image_size=64, 58 | in_channels=3, 59 | num_channels=128, 60 | num_res_blocks=2, 61 | num_heads=4, 62 | num_heads_upsample=-1, 63 | num_head_channels=-1, 64 | attention_resolutions="16,8", 65 | channel_mult="", 66 | dropout=0.0, 67 | class_cond=False, 68 | use_checkpoint=False, 69 | use_scale_shift_norm=True, 70 | resblock_updown=False, 71 | use_fp16=False, 72 | use_new_attention_order=False, 73 | ) 74 | res.update(diffusion_defaults()) 75 | return res 76 | 77 | 78 | def classifier_and_diffusion_defaults(): 79 | res = classifier_defaults() 80 | res.update(diffusion_defaults()) 81 | return res 82 | 83 | 84 | def create_model_and_diffusion( 85 | dims, 86 | image_size, 87 | in_channels, 88 | class_cond, 89 | learn_sigma, 90 | num_channels, 91 | num_res_blocks, 92 | channel_mult, 93 | num_heads, 94 | num_head_channels, 95 | num_heads_upsample, 96 | attention_resolutions, 97 | dropout, 98 | diffusion_steps, 99 | noise_schedule, 100 | timestep_respacing, 101 | use_kl, 102 | predict_xstart, 103 | rescale_timesteps, 104 | rescale_learned_sigmas, 105 | use_checkpoint, 106 | use_scale_shift_norm, 107 | resblock_updown, 108 | use_fp16, 109 | use_new_attention_order, 110 | ): 111 | model = create_model( 112 | dims, 113 | image_size, 114 | in_channels, 115 | num_channels, 116 | num_res_blocks, 117 | channel_mult=channel_mult, 118 | learn_sigma=learn_sigma, 119 | class_cond=class_cond, 120 | use_checkpoint=use_checkpoint, 121 | attention_resolutions=attention_resolutions, 122 | num_heads=num_heads, 123 | num_head_channels=num_head_channels, 124 | num_heads_upsample=num_heads_upsample, 125 | use_scale_shift_norm=use_scale_shift_norm, 126 | dropout=dropout, 127 | resblock_updown=resblock_updown, 128 | use_fp16=use_fp16, 129 | use_new_attention_order=use_new_attention_order, 130 | ) 131 | diffusion = create_gaussian_diffusion( 132 | steps=diffusion_steps, 133 | learn_sigma=learn_sigma, 134 | noise_schedule=noise_schedule, 135 | use_kl=use_kl, 136 | predict_xstart=predict_xstart, 137 | rescale_timesteps=rescale_timesteps, 138 | rescale_learned_sigmas=rescale_learned_sigmas, 139 | timestep_respacing=timestep_respacing, 140 | ) 141 | return model, diffusion 142 | 143 | 144 | def create_model( 145 | dims, 146 | image_size, 147 | in_channels, 148 | num_channels, 149 | num_res_blocks, 150 | channel_mult="", 151 | learn_sigma=False, 152 | class_cond=False, 153 | use_checkpoint=False, 154 | attention_resolutions="16", 155 | num_heads=1, 156 | num_head_channels=-1, 157 | num_heads_upsample=-1, 158 | use_scale_shift_norm=False, 159 | dropout=0, 160 | resblock_updown=False, 161 | use_fp16=False, 162 | use_new_attention_order=False, 163 | ): 164 | if channel_mult == "": 165 | if image_size == 512: 166 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 167 | elif image_size == 256: 168 | channel_mult = (1, 1, 2, 2, 4, 4) 169 | elif image_size == 128: 170 | channel_mult = (1, 1, 2, 3, 4) 171 | elif image_size == 64: 172 | channel_mult = (1, 2, 3, 4) 173 | else: 174 | raise ValueError(f"unsupported image size: {image_size}") 175 | else: 176 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 177 | 178 | attention_ds = [] 179 | for res in attention_resolutions.split(","): 180 | attention_ds.append(image_size // int(res)) 181 | 182 | return UNetModel( 183 | image_size=image_size, 184 | in_channels=in_channels, 185 | model_channels=num_channels, 186 | out_channels=(in_channels if not learn_sigma else 2*in_channels), 187 | num_res_blocks=num_res_blocks, 188 | 
attention_resolutions=tuple(attention_ds), 189 | dropout=dropout, 190 | channel_mult=channel_mult, 191 | dims=dims, 192 | num_classes=(NUM_CLASSES if class_cond else None), 193 | use_checkpoint=use_checkpoint, 194 | use_fp16=use_fp16, 195 | num_heads=num_heads, 196 | num_head_channels=num_head_channels, 197 | num_heads_upsample=num_heads_upsample, 198 | use_scale_shift_norm=use_scale_shift_norm, 199 | resblock_updown=resblock_updown, 200 | use_new_attention_order=use_new_attention_order, 201 | ) 202 | 203 | 204 | def create_classifier_and_diffusion( 205 | image_size, 206 | classifier_use_fp16, 207 | classifier_width, 208 | classifier_depth, 209 | classifier_attention_resolutions, 210 | classifier_use_scale_shift_norm, 211 | classifier_resblock_updown, 212 | classifier_pool, 213 | learn_sigma, 214 | diffusion_steps, 215 | noise_schedule, 216 | timestep_respacing, 217 | use_kl, 218 | predict_xstart, 219 | rescale_timesteps, 220 | rescale_learned_sigmas, 221 | ): 222 | classifier = create_classifier( 223 | image_size, 224 | classifier_use_fp16, 225 | classifier_width, 226 | classifier_depth, 227 | classifier_attention_resolutions, 228 | classifier_use_scale_shift_norm, 229 | classifier_resblock_updown, 230 | classifier_pool, 231 | ) 232 | diffusion = create_gaussian_diffusion( 233 | steps=diffusion_steps, 234 | learn_sigma=learn_sigma, 235 | noise_schedule=noise_schedule, 236 | use_kl=use_kl, 237 | predict_xstart=predict_xstart, 238 | rescale_timesteps=rescale_timesteps, 239 | rescale_learned_sigmas=rescale_learned_sigmas, 240 | timestep_respacing=timestep_respacing, 241 | ) 242 | return classifier, diffusion 243 | 244 | 245 | def create_classifier( 246 | image_size, 247 | classifier_use_fp16, 248 | classifier_width, 249 | classifier_depth, 250 | classifier_attention_resolutions, 251 | classifier_use_scale_shift_norm, 252 | classifier_resblock_updown, 253 | classifier_pool, 254 | ): 255 | if image_size == 512: 256 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 257 | elif image_size == 256: 258 | channel_mult = (1, 1, 2, 2, 4, 4) 259 | elif image_size == 128: 260 | channel_mult = (1, 1, 2, 3, 4) 261 | elif image_size == 64: 262 | channel_mult = (1, 2, 3, 4) 263 | else: 264 | raise ValueError(f"unsupported image size: {image_size}") 265 | 266 | attention_ds = [] 267 | for res in classifier_attention_resolutions.split(","): 268 | attention_ds.append(image_size // int(res)) 269 | 270 | return EncoderUNetModel( 271 | image_size=image_size, 272 | in_channels=3, 273 | model_channels=classifier_width, 274 | out_channels=1000, 275 | num_res_blocks=classifier_depth, 276 | attention_resolutions=tuple(attention_ds), 277 | channel_mult=channel_mult, 278 | use_fp16=classifier_use_fp16, 279 | num_head_channels=64, 280 | use_scale_shift_norm=classifier_use_scale_shift_norm, 281 | resblock_updown=classifier_resblock_updown, 282 | pool=classifier_pool, 283 | ) 284 | 285 | 286 | def sr_model_and_diffusion_defaults(): 287 | res = model_and_diffusion_defaults() 288 | res["large_size"] = 256 289 | res["small_size"] = 64 290 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 291 | for k in res.copy().keys(): 292 | if k not in arg_names: 293 | del res[k] 294 | return res 295 | 296 | 297 | def sr_create_model_and_diffusion( 298 | large_size, 299 | small_size, 300 | class_cond, 301 | learn_sigma, 302 | num_channels, 303 | num_res_blocks, 304 | num_heads, 305 | num_head_channels, 306 | num_heads_upsample, 307 | attention_resolutions, 308 | dropout, 309 | diffusion_steps, 310 | noise_schedule, 311 | 
timestep_respacing, 312 | use_kl, 313 | predict_xstart, 314 | rescale_timesteps, 315 | rescale_learned_sigmas, 316 | use_checkpoint, 317 | use_scale_shift_norm, 318 | resblock_updown, 319 | use_fp16, 320 | ): 321 | model = sr_create_model( 322 | large_size, 323 | small_size, 324 | num_channels, 325 | num_res_blocks, 326 | learn_sigma=learn_sigma, 327 | class_cond=class_cond, 328 | use_checkpoint=use_checkpoint, 329 | attention_resolutions=attention_resolutions, 330 | num_heads=num_heads, 331 | num_head_channels=num_head_channels, 332 | num_heads_upsample=num_heads_upsample, 333 | use_scale_shift_norm=use_scale_shift_norm, 334 | dropout=dropout, 335 | resblock_updown=resblock_updown, 336 | use_fp16=use_fp16, 337 | ) 338 | diffusion = create_gaussian_diffusion( 339 | steps=diffusion_steps, 340 | learn_sigma=learn_sigma, 341 | noise_schedule=noise_schedule, 342 | use_kl=use_kl, 343 | predict_xstart=predict_xstart, 344 | rescale_timesteps=rescale_timesteps, 345 | rescale_learned_sigmas=rescale_learned_sigmas, 346 | timestep_respacing=timestep_respacing, 347 | ) 348 | return model, diffusion 349 | 350 | 351 | def sr_create_model( 352 | large_size, 353 | small_size, 354 | num_channels, 355 | num_res_blocks, 356 | learn_sigma, 357 | class_cond, 358 | use_checkpoint, 359 | attention_resolutions, 360 | num_heads, 361 | num_head_channels, 362 | num_heads_upsample, 363 | use_scale_shift_norm, 364 | dropout, 365 | resblock_updown, 366 | use_fp16, 367 | ): 368 | _ = small_size # hack to prevent unused variable 369 | 370 | if large_size == 512: 371 | channel_mult = (1, 1, 2, 2, 4, 4) 372 | elif large_size == 256: 373 | channel_mult = (1, 1, 2, 2, 4, 4) 374 | elif large_size == 64: 375 | channel_mult = (1, 2, 3, 4) 376 | else: 377 | raise ValueError(f"unsupported large size: {large_size}") 378 | 379 | attention_ds = [] 380 | for res in attention_resolutions.split(","): 381 | attention_ds.append(large_size // int(res)) 382 | 383 | return SuperResModel( 384 | image_size=large_size, 385 | in_channels=3, 386 | model_channels=num_channels, 387 | out_channels=(3 if not learn_sigma else 6), 388 | num_res_blocks=num_res_blocks, 389 | attention_resolutions=tuple(attention_ds), 390 | dropout=dropout, 391 | channel_mult=channel_mult, 392 | num_classes=(NUM_CLASSES if class_cond else None), 393 | use_checkpoint=use_checkpoint, 394 | num_heads=num_heads, 395 | num_head_channels=num_head_channels, 396 | num_heads_upsample=num_heads_upsample, 397 | use_scale_shift_norm=use_scale_shift_norm, 398 | resblock_updown=resblock_updown, 399 | use_fp16=use_fp16, 400 | ) 401 | 402 | 403 | def create_gaussian_diffusion( 404 | *, 405 | steps=1000, 406 | learn_sigma=False, 407 | sigma_small=False, 408 | noise_schedule="linear", 409 | use_kl=False, 410 | predict_xstart=False, 411 | rescale_timesteps=False, 412 | rescale_learned_sigmas=False, 413 | timestep_respacing="", 414 | ): 415 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 416 | if use_kl: 417 | loss_type = gd.LossType.RESCALED_KL 418 | elif rescale_learned_sigmas: 419 | loss_type = gd.LossType.RESCALED_MSE 420 | else: 421 | loss_type = gd.LossType.MSE 422 | if not timestep_respacing: 423 | timestep_respacing = [steps] 424 | return SpacedDiffusion( 425 | use_timesteps=space_timesteps(steps, timestep_respacing), 426 | betas=betas, 427 | model_mean_type=( 428 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 429 | ), 430 | model_var_type=( 431 | ( 432 | gd.ModelVarType.FIXED_LARGE 433 | if not sigma_small 434 | else 
gd.ModelVarType.FIXED_SMALL 435 | ) 436 | if not learn_sigma 437 | else gd.ModelVarType.LEARNED_RANGE 438 | ), 439 | loss_type=loss_type, 440 | rescale_timesteps=rescale_timesteps, 441 | ) 442 | 443 | 444 | def add_dict_to_argparser(parser, default_dict): 445 | for k, v in default_dict.items(): 446 | v_type = type(v) 447 | if v is None: 448 | v_type = str 449 | elif isinstance(v, bool): 450 | v_type = str2bool 451 | parser.add_argument(f"--{k}", default=v, type=v_type) 452 | 453 | 454 | def args_to_dict(args, keys): 455 | return {k: getattr(args, k) for k in keys} 456 | 457 | 458 | def str2bool(v): 459 | """ 460 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 461 | """ 462 | if isinstance(v, bool): 463 | return v 464 | if v.lower() in ("yes", "true", "t", "y", "1"): 465 | return True 466 | elif v.lower() in ("no", "false", "f", "n", "0"): 467 | return False 468 | else: 469 | raise argparse.ArgumentTypeError("boolean value expected") 470 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 
82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | #if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | #if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | for k, v in cond.items() 187 | } 188 | last_batch = (i + self.microbatch) >= batch.shape[0] 189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 190 | 191 | compute_losses = functools.partial( 192 | self.diffusion.training_losses, 193 | self.ddp_model, 194 | micro, 195 | t, 196 | model_kwargs=micro_cond, 197 | ) 198 | 199 | if last_batch or not self.use_ddp: 200 | losses = compute_losses() 201 | else: 202 | with self.ddp_model.no_sync(): 203 | losses = compute_losses() 204 | 205 | if isinstance(self.schedule_sampler, LossAwareSampler): 206 | self.schedule_sampler.update_with_local_losses( 207 | t, losses["loss"].detach() 208 | ) 209 | 210 | loss = (losses["loss"] * weights).mean() 211 | log_loss_dict( 212 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 213 | ) 214 | self.mp_trainer.backward(loss) 215 | 216 | def _update_ema(self): 217 | for rate, params in zip(self.ema_rate, self.ema_params): 218 | update_ema(params, self.mp_trainer.master_params, rate=rate) 219 | 220 | def _anneal_lr(self): 221 | if not self.lr_anneal_steps: 222 | return 223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 224 | lr = self.lr * (1 - frac_done) 225 | for param_group in self.opt.param_groups: 226 | param_group["lr"] = lr 227 | 228 | def log_step(self): 229 | logger.logkv("step", self.step + self.resume_step) 230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 231 | 232 | def save(self): 233 | def save_checkpoint(rate, params): 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 235 | if dist.get_rank() == 0: 236 | logger.log(f"saving model {rate}...") 237 | if not rate: 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" 239 | else: 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 242 | th.save(state_dict, f) 243 | 244 | save_checkpoint(0, self.mp_trainer.master_params) 245 | for rate, params in zip(self.ema_rate, self.ema_params): 246 | save_checkpoint(rate, params) 247 | 248 | if dist.get_rank() == 0: 249 | with bf.BlobFile( 250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 251 | "wb", 252 | ) as f: 253 | th.save(self.opt.state_dict(), f) 254 | 255 | dist.barrier() 256 | 257 | 258 | def parse_resume_step_from_filename(filename): 259 | """ 260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 261 | checkpoint's number of steps. 262 | """ 263 | split = filename.split("model") 264 | if len(split) < 2: 265 | return 0 266 | split1 = split[-1].split(".")[0] 267 | try: 268 | return int(split1) 269 | except ValueError: 270 | return 0 271 | 272 | 273 | def get_blob_logdir(): 274 | # You can change this to be a separate path to save checkpoints to 275 | # a blobstore or some external drive. 
276 |     return logger.get_dir()
277 | 
278 | 
279 | def find_resume_checkpoint():
280 |     # On your infrastructure, you may want to override this to automatically
281 |     # discover the latest checkpoint on your blob storage, etc.
282 |     return None
283 | 
284 | 
285 | def find_ema_checkpoint(main_checkpoint, step, rate):
286 |     if main_checkpoint is None:
287 |         return None
288 |     filename = f"ema_{rate}_{(step):06d}.pt"
289 |     path = bf.join(bf.dirname(main_checkpoint), filename)
290 |     if bf.exists(path):
291 |         return path
292 |     return None
293 | 
294 | 
295 | def log_loss_dict(diffusion, ts, losses):
296 |     for key, values in losses.items():
297 |         logger.logkv_mean(key, values.mean().item())
298 |         # Log the quantiles (four quartiles, in particular).
299 |         for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
300 |             quartile = int(4 * sub_t / diffusion.num_timesteps)
301 |             logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
302 | 
--------------------------------------------------------------------------------
/guided_diffusion/turb_datasets.py:
--------------------------------------------------------------------------------
1 | from mpi4py import MPI
2 | import h5py
3 | from torch.utils.data import DataLoader, Dataset
4 | import numpy as np
5 | 
6 | 
7 | def load_data(
8 |     *,
9 |     dataset_path,
10 |     dataset_name,
11 |     batch_size,
12 |     class_cond=False,
13 |     deterministic=False,
14 | ):
15 |     """
16 |     For a dataset, create a generator over (trajectories, kwargs) pairs.
17 | 
18 |     Each batch is an NCL float tensor (trajectories x components x time
19 |     samples); the kwargs dict contains zero or more keys, each mapping to a
20 |     batched Tensor of its own. For class labels the key is "y" and the
21 |     values are integer tensors of class labels.
22 | 
23 |     :param dataset_path: a dataset path.
24 |     :param dataset_name: a dataset name.
25 |     :param batch_size: the batch size of each returned pair.
26 |     :param class_cond: if True, include a "y" key in returned dicts for class
27 |                        label. Not implemented.
28 |     :param deterministic: if True, yield results in a deterministic order.
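
    Example (a minimal sketch; path and dataset name assume the
    `Lagr_u3c_diffusion.h5` layout described in the README):

        data = load_data(
            dataset_path='datasets/Lagr_u3c_diffusion.h5',
            dataset_name='train',
            batch_size=32,
        )
        batch, cond = next(data)  # batch: (32, 3, 2000) float tensor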
29 | """ 30 | comm = MPI.COMM_WORLD 31 | rank = comm.Get_rank() 32 | size = comm.Get_size() 33 | 34 | with h5py.File(dataset_path, 'r', driver='mpio', comm=MPI.COMM_SELF) as f: 35 | #with h5py.File(dataset_path, 'r') as f: # replace the above line with this line for serial h5py 36 | len_dataset = f[dataset_name].len() 37 | 38 | chunk_size = len_dataset // size 39 | start_idx = rank * chunk_size 40 | 41 | dataset = TurbDataset( 42 | dataset_path, dataset_name, class_cond, start_idx, chunk_size, 43 | ) 44 | 45 | shuffle = True if deterministic else False 46 | loader = DataLoader( 47 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=True 48 | ) 49 | 50 | while True: 51 | yield from loader 52 | 53 | 54 | class TurbDataset(Dataset): 55 | def __init__( 56 | self, 57 | dataset_path, 58 | dataset_name, 59 | class_cond, 60 | start_idx, 61 | chunk_size, 62 | ): 63 | super().__init__() 64 | self.dataset_path = dataset_path 65 | self.dataset_name = dataset_name 66 | self.class_cond = class_cond 67 | self.start_idx = start_idx 68 | self.chunk_size = chunk_size 69 | 70 | def __len__(self): 71 | return self.chunk_size 72 | 73 | def __getitem__(self, idx): 74 | idx += self.start_idx 75 | 76 | with h5py.File(self.dataset_path, 'r', driver='mpio', comm=MPI.COMM_SELF) as f: 77 | #with h5py.File(self.dataset_path, 'r') as f: # replace the above line with this line for serial h5py 78 | data = f[self.dataset_name][idx].astype(np.float32) 79 | data = np.moveaxis(data, -1, 0) 80 | 81 | out_dict = {} 82 | if self.class_cond: 83 | raise NotImplementedError() 84 | out_dict["y"] = f[self.dataset_name + '_y'][idx] 85 | 86 | return data, out_dict 87 | -------------------------------------------------------------------------------- /guided_diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | checkpoint, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | ) 20 | 21 | 22 | class AttentionPool2d(nn.Module): 23 | """ 24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spacial_dim: int, 30 | embed_dim: int, 31 | num_heads_channels: int, 32 | output_dim: int = None, 33 | ): 34 | super().__init__() 35 | self.positional_embedding = nn.Parameter( 36 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 37 | ) 38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 40 | self.num_heads = embed_dim // num_heads_channels 41 | self.attention = QKVAttention(self.num_heads) 42 | 43 | def forward(self, x): 44 | b, c, *_spatial = x.shape 45 | x = x.reshape(b, c, -1) # NC(HW) 46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 48 | x = self.qkv_proj(x) 49 | x = self.attention(x) 50 | x = self.c_proj(x) 51 | return x[:, :, 0] 52 | 53 | 54 | class TimestepBlock(nn.Module): 55 | """ 56 | Any module where forward() takes timestep embeddings as a second argument. 57 | """ 58 | 59 | @abstractmethod 60 | def forward(self, x, emb): 61 | """ 62 | Apply the module to `x` given `emb` timestep embeddings. 
63 | """ 64 | 65 | 66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 67 | """ 68 | A sequential module that passes timestep embeddings to the children that 69 | support it as an extra input. 70 | """ 71 | 72 | def forward(self, x, emb): 73 | for layer in self: 74 | if isinstance(layer, TimestepBlock): 75 | x = layer(x, emb) 76 | else: 77 | x = layer(x) 78 | return x 79 | 80 | 81 | class Upsample(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | upsampling occurs in the inner-two dimensions. 89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 92 | super().__init__() 93 | self.channels = channels 94 | self.out_channels = out_channels or channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | if use_conv: 98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | if self.dims == 3: 103 | x = F.interpolate( 104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 105 | ) 106 | else: 107 | x = F.interpolate(x, scale_factor=2, mode="nearest") 108 | if self.use_conv: 109 | x = self.conv(x) 110 | return x 111 | 112 | 113 | class Downsample(nn.Module): 114 | """ 115 | A downsampling layer with an optional convolution. 116 | 117 | :param channels: channels in the inputs and outputs. 118 | :param use_conv: a bool determining if a convolution is applied. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 120 | downsampling occurs in the inner-two dimensions. 121 | """ 122 | 123 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | self.use_conv = use_conv 128 | self.dims = dims 129 | stride = 2 if dims != 3 else (1, 2, 2) 130 | if use_conv: 131 | self.op = conv_nd( 132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 133 | ) 134 | else: 135 | assert self.channels == self.out_channels 136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 137 | 138 | def forward(self, x): 139 | assert x.shape[1] == self.channels 140 | return self.op(x) 141 | 142 | 143 | class ResBlock(TimestepBlock): 144 | """ 145 | A residual block that can optionally change the number of channels. 146 | 147 | :param channels: the number of input channels. 148 | :param emb_channels: the number of timestep embedding channels. 149 | :param dropout: the rate of dropout. 150 | :param out_channels: if specified, the number of out channels. 151 | :param use_conv: if True and out_channels is specified, use a spatial 152 | convolution instead of a smaller 1x1 convolution to change the 153 | channels in the skip connection. 154 | :param dims: determines if the signal is 1D, 2D, or 3D. 155 | :param use_checkpoint: if True, use gradient checkpointing on this module. 156 | :param up: if True, use this block for upsampling. 157 | :param down: if True, use this block for downsampling. 
158 | """ 159 | 160 | def __init__( 161 | self, 162 | channels, 163 | emb_channels, 164 | dropout, 165 | out_channels=None, 166 | use_conv=False, 167 | use_scale_shift_norm=False, 168 | dims=2, 169 | use_checkpoint=False, 170 | up=False, 171 | down=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | nn.SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | 188 | self.updown = up or down 189 | 190 | if up: 191 | self.h_upd = Upsample(channels, False, dims) 192 | self.x_upd = Upsample(channels, False, dims) 193 | elif down: 194 | self.h_upd = Downsample(channels, False, dims) 195 | self.x_upd = Downsample(channels, False, dims) 196 | else: 197 | self.h_upd = self.x_upd = nn.Identity() 198 | 199 | self.emb_layers = nn.Sequential( 200 | nn.SiLU(), 201 | linear( 202 | emb_channels, 203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 204 | ), 205 | ) 206 | self.out_layers = nn.Sequential( 207 | normalization(self.out_channels), 208 | nn.SiLU(), 209 | nn.Dropout(p=dropout), 210 | zero_module( 211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 212 | ), 213 | ) 214 | 215 | if self.out_channels == channels: 216 | self.skip_connection = nn.Identity() 217 | elif use_conv: 218 | self.skip_connection = conv_nd( 219 | dims, channels, self.out_channels, 3, padding=1 220 | ) 221 | else: 222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 223 | 224 | def forward(self, x, emb): 225 | """ 226 | Apply the block to a Tensor, conditioned on a timestep embedding. 227 | 228 | :param x: an [N x C x ...] Tensor of features. 229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 230 | :return: an [N x C x ...] Tensor of outputs. 231 | """ 232 | return checkpoint( 233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 234 | ) 235 | 236 | def _forward(self, x, emb): 237 | if self.updown: 238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 239 | h = in_rest(x) 240 | h = self.h_upd(h) 241 | x = self.x_upd(x) 242 | h = in_conv(h) 243 | else: 244 | h = self.in_layers(x) 245 | emb_out = self.emb_layers(emb).type(h.dtype) 246 | while len(emb_out.shape) < len(h.shape): 247 | emb_out = emb_out[..., None] 248 | if self.use_scale_shift_norm: 249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 250 | scale, shift = th.chunk(emb_out, 2, dim=1) 251 | h = out_norm(h) * (1 + scale) + shift 252 | h = out_rest(h) 253 | else: 254 | h = h + emb_out 255 | h = self.out_layers(h) 256 | return self.skip_connection(x) + h 257 | 258 | 259 | class AttentionBlock(nn.Module): 260 | """ 261 | An attention block that allows spatial positions to attend to each other. 262 | 263 | Originally ported from here, but adapted to the N-d case. 264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
265 | """ 266 | 267 | def __init__( 268 | self, 269 | channels, 270 | num_heads=1, 271 | num_head_channels=-1, 272 | use_checkpoint=False, 273 | use_new_attention_order=False, 274 | ): 275 | super().__init__() 276 | self.channels = channels 277 | if num_head_channels == -1: 278 | self.num_heads = num_heads 279 | else: 280 | assert ( 281 | channels % num_head_channels == 0 282 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 283 | self.num_heads = channels // num_head_channels 284 | self.use_checkpoint = use_checkpoint 285 | self.norm = normalization(channels) 286 | self.qkv = conv_nd(1, channels, channels * 3, 1) 287 | if use_new_attention_order: 288 | # split qkv before split heads 289 | self.attention = QKVAttention(self.num_heads) 290 | else: 291 | # split heads before split qkv 292 | self.attention = QKVAttentionLegacy(self.num_heads) 293 | 294 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 295 | 296 | def forward(self, x): 297 | return checkpoint(self._forward, (x,), self.parameters(), True) 298 | 299 | def _forward(self, x): 300 | b, c, *spatial = x.shape 301 | x = x.reshape(b, c, -1) 302 | qkv = self.qkv(self.norm(x)) 303 | h = self.attention(qkv) 304 | h = self.proj_out(h) 305 | return (x + h).reshape(b, c, *spatial) 306 | 307 | 308 | def count_flops_attn(model, _x, y): 309 | """ 310 | A counter for the `thop` package to count the operations in an 311 | attention operation. 312 | Meant to be used like: 313 | macs, params = thop.profile( 314 | model, 315 | inputs=(inputs, timestamps), 316 | custom_ops={QKVAttention: QKVAttention.count_flops}, 317 | ) 318 | """ 319 | b, c, *spatial = y[0].shape 320 | num_spatial = int(np.prod(spatial)) 321 | # We perform two matmuls with the same number of ops. 322 | # The first computes the weight matrix, the second computes 323 | # the combination of the value vectors. 324 | matmul_ops = 2 * b * (num_spatial ** 2) * c 325 | model.total_ops += th.DoubleTensor([matmul_ops]) 326 | 327 | 328 | class QKVAttentionLegacy(nn.Module): 329 | """ 330 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 331 | """ 332 | 333 | def __init__(self, n_heads): 334 | super().__init__() 335 | self.n_heads = n_heads 336 | 337 | def forward(self, qkv): 338 | """ 339 | Apply QKV attention. 340 | 341 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 342 | :return: an [N x (H * C) x T] tensor after attention. 343 | """ 344 | bs, width, length = qkv.shape 345 | assert width % (3 * self.n_heads) == 0 346 | ch = width // (3 * self.n_heads) 347 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 348 | scale = 1 / math.sqrt(math.sqrt(ch)) 349 | weight = th.einsum( 350 | "bct,bcs->bts", q * scale, k * scale 351 | ) # More stable with f16 than dividing afterwards 352 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 353 | a = th.einsum("bts,bcs->bct", weight, v) 354 | return a.reshape(bs, -1, length) 355 | 356 | @staticmethod 357 | def count_flops(model, _x, y): 358 | return count_flops_attn(model, _x, y) 359 | 360 | 361 | class QKVAttention(nn.Module): 362 | """ 363 | A module which performs QKV attention and splits in a different order. 364 | """ 365 | 366 | def __init__(self, n_heads): 367 | super().__init__() 368 | self.n_heads = n_heads 369 | 370 | def forward(self, qkv): 371 | """ 372 | Apply QKV attention. 373 | 374 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
375 |         :return: an [N x (H * C) x T] tensor after attention.
376 |         """
377 |         bs, width, length = qkv.shape
378 |         assert width % (3 * self.n_heads) == 0
379 |         ch = width // (3 * self.n_heads)
380 |         q, k, v = qkv.chunk(3, dim=1)
381 |         scale = 1 / math.sqrt(math.sqrt(ch))
382 |         weight = th.einsum(
383 |             "bct,bcs->bts",
384 |             (q * scale).view(bs * self.n_heads, ch, length),
385 |             (k * scale).view(bs * self.n_heads, ch, length),
386 |         )  # More stable with f16 than dividing afterwards
387 |         weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388 |         a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
389 |         return a.reshape(bs, -1, length)
390 | 
391 |     @staticmethod
392 |     def count_flops(model, _x, y):
393 |         return count_flops_attn(model, _x, y)
394 | 
395 | 
396 | class UNetModel(nn.Module):
397 |     """
398 |     The full UNet model with attention and timestep embedding.
399 | 
400 |     :param in_channels: channels in the input Tensor.
401 |     :param model_channels: base channel count for the model.
402 |     :param out_channels: channels in the output Tensor.
403 |     :param num_res_blocks: number of residual blocks per downsample.
404 |     :param attention_resolutions: a collection of downsample rates at which
405 |         attention will take place. May be a set, list, or tuple.
406 |         For example, if this contains 4, then at 4x downsampling, attention
407 |         will be used.
408 |     :param dropout: the dropout probability.
409 |     :param channel_mult: channel multiplier for each level of the UNet.
410 |     :param conv_resample: if True, use learned convolutions for upsampling and
411 |         downsampling.
412 |     :param dims: determines if the signal is 1D, 2D, or 3D.
413 |     :param num_classes: if specified (as an int), then this model will be
414 |         class-conditional with `num_classes` classes.
415 |     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
416 |     :param num_heads: the number of attention heads in each attention layer.
417 |     :param num_head_channels: if specified, ignore num_heads and instead use
418 |         a fixed channel width per attention head.
419 |     :param num_heads_upsample: works with num_heads to set a different number
420 |         of heads for upsampling. Deprecated.
421 |     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
422 |     :param resblock_updown: use residual blocks for up/downsampling.
423 |     :param use_new_attention_order: use a different attention pattern for potentially
424 |         increased efficiency.
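
    Example (a minimal sketch, not the paper's configuration; sizes here are
    illustrative assumptions for 3-component trajectories of length 2000):

        model = UNetModel(
            image_size=2000, in_channels=3, model_channels=64,
            out_channels=3, num_res_blocks=2, attention_resolutions=(4, 8),
            dims=1,
        )
        # x: [N x 3 x 2000], timesteps: [N] -> output: [N x 3 x 2000]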
425 | """ 426 | 427 | def __init__( 428 | self, 429 | image_size, 430 | in_channels, 431 | model_channels, 432 | out_channels, 433 | num_res_blocks, 434 | attention_resolutions, 435 | dropout=0, 436 | channel_mult=(1, 2, 4, 8), 437 | conv_resample=True, 438 | dims=2, 439 | num_classes=None, 440 | use_checkpoint=False, 441 | use_fp16=False, 442 | num_heads=1, 443 | num_head_channels=-1, 444 | num_heads_upsample=-1, 445 | use_scale_shift_norm=False, 446 | resblock_updown=False, 447 | use_new_attention_order=False, 448 | ): 449 | super().__init__() 450 | 451 | if num_heads_upsample == -1: 452 | num_heads_upsample = num_heads 453 | 454 | self.image_size = image_size 455 | self.in_channels = in_channels 456 | self.model_channels = model_channels 457 | self.out_channels = out_channels 458 | self.num_res_blocks = num_res_blocks 459 | self.attention_resolutions = attention_resolutions 460 | self.dropout = dropout 461 | self.channel_mult = channel_mult 462 | self.conv_resample = conv_resample 463 | self.num_classes = num_classes 464 | self.use_checkpoint = use_checkpoint 465 | self.dtype = th.float16 if use_fp16 else th.float32 466 | self.num_heads = num_heads 467 | self.num_head_channels = num_head_channels 468 | self.num_heads_upsample = num_heads_upsample 469 | 470 | time_embed_dim = model_channels * 4 471 | self.time_embed = nn.Sequential( 472 | linear(model_channels, time_embed_dim), 473 | nn.SiLU(), 474 | linear(time_embed_dim, time_embed_dim), 475 | ) 476 | 477 | if self.num_classes is not None: 478 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 479 | 480 | ch = input_ch = int(channel_mult[0] * model_channels) 481 | self.input_blocks = nn.ModuleList( 482 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 483 | ) 484 | self._feature_size = ch 485 | input_block_chans = [ch] 486 | ds = 1 487 | for level, mult in enumerate(channel_mult): 488 | for _ in range(num_res_blocks): 489 | layers = [ 490 | ResBlock( 491 | ch, 492 | time_embed_dim, 493 | dropout, 494 | out_channels=int(mult * model_channels), 495 | dims=dims, 496 | use_checkpoint=use_checkpoint, 497 | use_scale_shift_norm=use_scale_shift_norm, 498 | ) 499 | ] 500 | ch = int(mult * model_channels) 501 | if ds in attention_resolutions: 502 | layers.append( 503 | AttentionBlock( 504 | ch, 505 | use_checkpoint=use_checkpoint, 506 | num_heads=num_heads, 507 | num_head_channels=num_head_channels, 508 | use_new_attention_order=use_new_attention_order, 509 | ) 510 | ) 511 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 512 | self._feature_size += ch 513 | input_block_chans.append(ch) 514 | if level != len(channel_mult) - 1: 515 | out_ch = ch 516 | self.input_blocks.append( 517 | TimestepEmbedSequential( 518 | ResBlock( 519 | ch, 520 | time_embed_dim, 521 | dropout, 522 | out_channels=out_ch, 523 | dims=dims, 524 | use_checkpoint=use_checkpoint, 525 | use_scale_shift_norm=use_scale_shift_norm, 526 | down=True, 527 | ) 528 | if resblock_updown 529 | else Downsample( 530 | ch, conv_resample, dims=dims, out_channels=out_ch 531 | ) 532 | ) 533 | ) 534 | ch = out_ch 535 | input_block_chans.append(ch) 536 | ds *= 2 537 | self._feature_size += ch 538 | 539 | self.middle_block = TimestepEmbedSequential( 540 | ResBlock( 541 | ch, 542 | time_embed_dim, 543 | dropout, 544 | dims=dims, 545 | use_checkpoint=use_checkpoint, 546 | use_scale_shift_norm=use_scale_shift_norm, 547 | ), 548 | AttentionBlock( 549 | ch, 550 | use_checkpoint=use_checkpoint, 551 | num_heads=num_heads, 552 | 
num_head_channels=num_head_channels, 553 | use_new_attention_order=use_new_attention_order, 554 | ), 555 | ResBlock( 556 | ch, 557 | time_embed_dim, 558 | dropout, 559 | dims=dims, 560 | use_checkpoint=use_checkpoint, 561 | use_scale_shift_norm=use_scale_shift_norm, 562 | ), 563 | ) 564 | self._feature_size += ch 565 | 566 | self.output_blocks = nn.ModuleList([]) 567 | for level, mult in list(enumerate(channel_mult))[::-1]: 568 | for i in range(num_res_blocks + 1): 569 | ich = input_block_chans.pop() 570 | layers = [ 571 | ResBlock( 572 | ch + ich, 573 | time_embed_dim, 574 | dropout, 575 | out_channels=int(model_channels * mult), 576 | dims=dims, 577 | use_checkpoint=use_checkpoint, 578 | use_scale_shift_norm=use_scale_shift_norm, 579 | ) 580 | ] 581 | ch = int(model_channels * mult) 582 | if ds in attention_resolutions: 583 | layers.append( 584 | AttentionBlock( 585 | ch, 586 | use_checkpoint=use_checkpoint, 587 | num_heads=num_heads_upsample, 588 | num_head_channels=num_head_channels, 589 | use_new_attention_order=use_new_attention_order, 590 | ) 591 | ) 592 | if level and i == num_res_blocks: 593 | out_ch = ch 594 | layers.append( 595 | ResBlock( 596 | ch, 597 | time_embed_dim, 598 | dropout, 599 | out_channels=out_ch, 600 | dims=dims, 601 | use_checkpoint=use_checkpoint, 602 | use_scale_shift_norm=use_scale_shift_norm, 603 | up=True, 604 | ) 605 | if resblock_updown 606 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 607 | ) 608 | ds //= 2 609 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 610 | self._feature_size += ch 611 | 612 | self.out = nn.Sequential( 613 | normalization(ch), 614 | nn.SiLU(), 615 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), 616 | ) 617 | 618 | def convert_to_fp16(self): 619 | """ 620 | Convert the torso of the model to float16. 621 | """ 622 | self.input_blocks.apply(convert_module_to_f16) 623 | self.middle_block.apply(convert_module_to_f16) 624 | self.output_blocks.apply(convert_module_to_f16) 625 | 626 | def convert_to_fp32(self): 627 | """ 628 | Convert the torso of the model to float32. 629 | """ 630 | self.input_blocks.apply(convert_module_to_f32) 631 | self.middle_block.apply(convert_module_to_f32) 632 | self.output_blocks.apply(convert_module_to_f32) 633 | 634 | def forward(self, x, timesteps, y=None): 635 | """ 636 | Apply the model to an input batch. 637 | 638 | :param x: an [N x C x ...] Tensor of inputs. 639 | :param timesteps: a 1-D batch of timesteps. 640 | :param y: an [N] Tensor of labels, if class-conditional. 641 | :return: an [N x C x ...] Tensor of outputs. 642 | """ 643 | assert (y is not None) == ( 644 | self.num_classes is not None 645 | ), "must specify y if and only if the model is class-conditional" 646 | 647 | hs = [] 648 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 649 | 650 | if self.num_classes is not None: 651 | assert y.shape == (x.shape[0],) 652 | emb = emb + self.label_emb(y) 653 | 654 | h = x.type(self.dtype) 655 | for module in self.input_blocks: 656 | h = module(h, emb) 657 | hs.append(h) 658 | h = self.middle_block(h, emb) 659 | for module in self.output_blocks: 660 | h = th.cat([h, hs.pop()], dim=1) 661 | h = module(h, emb) 662 | h = h.type(x.dtype) 663 | return self.out(h) 664 | 665 | 666 | class SuperResModel(UNetModel): 667 | """ 668 | A UNetModel that performs super-resolution. 669 | 670 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 
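
    The constructor doubles `in_channels` because the bilinearly upsampled
    `low_res` conditioning signal is concatenated to the input along the
    channel axis in forward().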
671 | """ 672 | 673 | def __init__(self, image_size, in_channels, *args, **kwargs): 674 | super().__init__(image_size, in_channels * 2, *args, **kwargs) 675 | 676 | def forward(self, x, timesteps, low_res=None, **kwargs): 677 | _, _, new_height, new_width = x.shape 678 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 679 | x = th.cat([x, upsampled], dim=1) 680 | return super().forward(x, timesteps, **kwargs) 681 | 682 | 683 | class EncoderUNetModel(nn.Module): 684 | """ 685 | The half UNet model with attention and timestep embedding. 686 | 687 | For usage, see UNet. 688 | """ 689 | 690 | def __init__( 691 | self, 692 | image_size, 693 | in_channels, 694 | model_channels, 695 | out_channels, 696 | num_res_blocks, 697 | attention_resolutions, 698 | dropout=0, 699 | channel_mult=(1, 2, 4, 8), 700 | conv_resample=True, 701 | dims=2, 702 | use_checkpoint=False, 703 | use_fp16=False, 704 | num_heads=1, 705 | num_head_channels=-1, 706 | num_heads_upsample=-1, 707 | use_scale_shift_norm=False, 708 | resblock_updown=False, 709 | use_new_attention_order=False, 710 | pool="adaptive", 711 | ): 712 | super().__init__() 713 | 714 | if num_heads_upsample == -1: 715 | num_heads_upsample = num_heads 716 | 717 | self.in_channels = in_channels 718 | self.model_channels = model_channels 719 | self.out_channels = out_channels 720 | self.num_res_blocks = num_res_blocks 721 | self.attention_resolutions = attention_resolutions 722 | self.dropout = dropout 723 | self.channel_mult = channel_mult 724 | self.conv_resample = conv_resample 725 | self.use_checkpoint = use_checkpoint 726 | self.dtype = th.float16 if use_fp16 else th.float32 727 | self.num_heads = num_heads 728 | self.num_head_channels = num_head_channels 729 | self.num_heads_upsample = num_heads_upsample 730 | 731 | time_embed_dim = model_channels * 4 732 | self.time_embed = nn.Sequential( 733 | linear(model_channels, time_embed_dim), 734 | nn.SiLU(), 735 | linear(time_embed_dim, time_embed_dim), 736 | ) 737 | 738 | ch = int(channel_mult[0] * model_channels) 739 | self.input_blocks = nn.ModuleList( 740 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 741 | ) 742 | self._feature_size = ch 743 | input_block_chans = [ch] 744 | ds = 1 745 | for level, mult in enumerate(channel_mult): 746 | for _ in range(num_res_blocks): 747 | layers = [ 748 | ResBlock( 749 | ch, 750 | time_embed_dim, 751 | dropout, 752 | out_channels=int(mult * model_channels), 753 | dims=dims, 754 | use_checkpoint=use_checkpoint, 755 | use_scale_shift_norm=use_scale_shift_norm, 756 | ) 757 | ] 758 | ch = int(mult * model_channels) 759 | if ds in attention_resolutions: 760 | layers.append( 761 | AttentionBlock( 762 | ch, 763 | use_checkpoint=use_checkpoint, 764 | num_heads=num_heads, 765 | num_head_channels=num_head_channels, 766 | use_new_attention_order=use_new_attention_order, 767 | ) 768 | ) 769 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 770 | self._feature_size += ch 771 | input_block_chans.append(ch) 772 | if level != len(channel_mult) - 1: 773 | out_ch = ch 774 | self.input_blocks.append( 775 | TimestepEmbedSequential( 776 | ResBlock( 777 | ch, 778 | time_embed_dim, 779 | dropout, 780 | out_channels=out_ch, 781 | dims=dims, 782 | use_checkpoint=use_checkpoint, 783 | use_scale_shift_norm=use_scale_shift_norm, 784 | down=True, 785 | ) 786 | if resblock_updown 787 | else Downsample( 788 | ch, conv_resample, dims=dims, out_channels=out_ch 789 | ) 790 | ) 791 | ) 792 | ch = out_ch 793 | input_block_chans.append(ch) 
794 | ds *= 2 795 | self._feature_size += ch 796 | 797 | self.middle_block = TimestepEmbedSequential( 798 | ResBlock( 799 | ch, 800 | time_embed_dim, 801 | dropout, 802 | dims=dims, 803 | use_checkpoint=use_checkpoint, 804 | use_scale_shift_norm=use_scale_shift_norm, 805 | ), 806 | AttentionBlock( 807 | ch, 808 | use_checkpoint=use_checkpoint, 809 | num_heads=num_heads, 810 | num_head_channels=num_head_channels, 811 | use_new_attention_order=use_new_attention_order, 812 | ), 813 | ResBlock( 814 | ch, 815 | time_embed_dim, 816 | dropout, 817 | dims=dims, 818 | use_checkpoint=use_checkpoint, 819 | use_scale_shift_norm=use_scale_shift_norm, 820 | ), 821 | ) 822 | self._feature_size += ch 823 | self.pool = pool 824 | if pool == "adaptive": 825 | self.out = nn.Sequential( 826 | normalization(ch), 827 | nn.SiLU(), 828 | nn.AdaptiveAvgPool2d((1, 1)), 829 | zero_module(conv_nd(dims, ch, out_channels, 1)), 830 | nn.Flatten(), 831 | ) 832 | elif pool == "attention": 833 | assert num_head_channels != -1 834 | self.out = nn.Sequential( 835 | normalization(ch), 836 | nn.SiLU(), 837 | AttentionPool2d( 838 | (image_size // ds), ch, num_head_channels, out_channels 839 | ), 840 | ) 841 | elif pool == "spatial": 842 | self.out = nn.Sequential( 843 | nn.Linear(self._feature_size, 2048), 844 | nn.ReLU(), 845 | nn.Linear(2048, self.out_channels), 846 | ) 847 | elif pool == "spatial_v2": 848 | self.out = nn.Sequential( 849 | nn.Linear(self._feature_size, 2048), 850 | normalization(2048), 851 | nn.SiLU(), 852 | nn.Linear(2048, self.out_channels), 853 | ) 854 | else: 855 | raise NotImplementedError(f"Unexpected {pool} pooling") 856 | 857 | def convert_to_fp16(self): 858 | """ 859 | Convert the torso of the model to float16. 860 | """ 861 | self.input_blocks.apply(convert_module_to_f16) 862 | self.middle_block.apply(convert_module_to_f16) 863 | 864 | def convert_to_fp32(self): 865 | """ 866 | Convert the torso of the model to float32. 867 | """ 868 | self.input_blocks.apply(convert_module_to_f32) 869 | self.middle_block.apply(convert_module_to_f32) 870 | 871 | def forward(self, x, timesteps): 872 | """ 873 | Apply the model to an input batch. 874 | 875 | :param x: an [N x C x ...] Tensor of inputs. 876 | :param timesteps: a 1-D batch of timesteps. 877 | :return: an [N x K] Tensor of outputs. 
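
    With the "spatial" pooling modes, the per-resolution feature means are
    concatenated and fed to a linear head; the other modes pool the final
    feature map directly.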
878 | """ 879 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 880 | 881 | results = [] 882 | h = x.type(self.dtype) 883 | for module in self.input_blocks: 884 | h = module(h, emb) 885 | if self.pool.startswith("spatial"): 886 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 887 | h = self.middle_block(h, emb) 888 | if self.pool.startswith("spatial"): 889 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 890 | h = th.cat(results, axis=-1) 891 | return self.out(h) 892 | else: 893 | h = h.type(x.dtype) 894 | return self.out(h) 895 | -------------------------------------------------------------------------------- /ppdm/ppdm/__init__.py: -------------------------------------------------------------------------------- 1 | from ppdm.ppdm import comput_jsd 2 | from ppdm.ppdm import struct_func 3 | from ppdm.ppdm import comput_batch_mean_err 4 | from ppdm.ppdm import corr_func 5 | -------------------------------------------------------------------------------- /ppdm/ppdm/ppdm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import distance 3 | 4 | def comput_jsd(y_true, y_pred, bins): 5 | y_min = min(y_true.min(), y_pred.min()) 6 | y_max = max(y_true.max(), y_pred.max()) 7 | hist_true, bin_edges = np.histogram(y_true, bins=bins, range=(y_min, y_max)) 8 | hist_pred, bin_edges = np.histogram(y_pred, bins=bins, range=(y_min, y_max)) 9 | return distance.jensenshannon(hist_true, hist_pred, 2.0)**2 10 | 11 | def struct_func(p, dt, u): 12 | du_p = (u[:, dt:] - u[:, :-dt]) ** p 13 | return np.mean(du_p), np.std(du_p) 14 | 15 | def comput_batch_mean_err(Sp_batch): 16 | Sp_mean = np.mean(Sp_batch, axis=0) 17 | Sp_min = np.amin(Sp_batch, axis=0) 18 | Sp_max = np.amax(Sp_batch, axis=0) 19 | return Sp_mean, np.vstack([Sp_mean - Sp_min, Sp_max - Sp_mean]) 20 | 21 | def corr_func(dt, u): 22 | return np.mean(u[:, :-dt] * u[:, dt:]) -------------------------------------------------------------------------------- /ppdm/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name='ppdm', 7 | version='0.1', 8 | # list folders, not files 9 | packages=['ppdm'], 10 | ) 11 | -------------------------------------------------------------------------------- /resources/Sampling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/resources/Sampling.png -------------------------------------------------------------------------------- /resources/Training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/resources/Training.png -------------------------------------------------------------------------------- /scripts/turb_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate the losses for a diffusion model. 
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | from guided_diffusion import dist_util, logger 13 | from guided_diffusion.turb_datasets import load_data 14 | from guided_diffusion.script_util import ( 15 | model_and_diffusion_defaults, 16 | create_model_and_diffusion, 17 | add_dict_to_argparser, 18 | args_to_dict, 19 | ) 20 | 21 | 22 | def main(): 23 | args = create_argparser().parse_args() 24 | 25 | dist_util.setup_dist() 26 | logger.configure() 27 | 28 | logger.log("creating model and diffusion...") 29 | model, diffusion = create_model_and_diffusion( 30 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 31 | ) 32 | model.load_state_dict( 33 | dist_util.load_state_dict(args.model_path, map_location="cpu") 34 | ) 35 | model.to(dist_util.dev()) 36 | model.eval() 37 | 38 | logger.log("creating data loader...") 39 | data = load_data( 40 | dataset_path=args.dataset_path, 41 | dataset_name=args.dataset_name, 42 | batch_size=args.batch_size, 43 | class_cond=args.class_cond, 44 | deterministic=True, 45 | ) 46 | 47 | logger.log("evaluating...") 48 | import os 49 | seed = 0*4 + int(os.environ["CUDA_VISIBLE_DEVICES"]) 50 | th.manual_seed(seed) 51 | run_losses_evaluation(model, diffusion, data, args.num_samples) 52 | 53 | 54 | def run_losses_evaluation(model, diffusion, data, num_samples): 55 | all_total_loss = [] 56 | all_losses = {"loss": [], "vb": [], "mse": []} 57 | num_complete = 0 58 | while num_complete < num_samples: 59 | batch, model_kwargs = next(data) 60 | batch = batch.to(dist_util.dev()) 61 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 62 | minibatch_losses = diffusion.calc_losses_loop( 63 | model, batch, model_kwargs=model_kwargs 64 | ) 65 | 66 | for key in minibatch_losses: 67 | losses = minibatch_losses[key] 68 | gathered_losses = [th.zeros_like(losses) for _ in range(dist.get_world_size())] 69 | dist.all_gather(gathered_losses, losses) # gather not supported with NCCL 70 | all_losses[key].extend([losses.cpu().numpy() for losses in gathered_losses]) 71 | 72 | total_loss = minibatch_losses["loss"] 73 | total_loss = total_loss.mean() / dist.get_world_size() 74 | dist.all_reduce(total_loss) 75 | all_total_loss.append(total_loss.item()) 76 | num_complete += dist.get_world_size() * batch.shape[0] 77 | 78 | logger.log(f"done {num_complete} samples: total_loss={np.mean(all_total_loss)}") 79 | 80 | if dist.get_rank() == 0: 81 | for name in minibatch_losses: 82 | losses = np.concatenate(all_losses[name]) 83 | shape_str = "x".join([str(x) for x in losses.shape]) 84 | out_path = os.path.join(logger.get_dir(), f"{name}_losses_{shape_str}.npz") 85 | logger.log(f"saving {name} losses to {out_path}") 86 | np.savez(out_path, losses) 87 | 88 | dist.barrier() 89 | logger.log("evaluation complete") 90 | 91 | 92 | def create_argparser(): 93 | defaults = dict( 94 | dataset_path="", dataset_name="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 95 | ) 96 | defaults.update(model_and_diffusion_defaults()) 97 | parser = argparse.ArgumentParser() 98 | add_dict_to_argparser(parser, defaults) 99 | return parser 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /scripts/turb_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Print the model summary of UNetModel being used. 
3 | """ 4 | 5 | import argparse 6 | from torchsummary import summary 7 | 8 | from guided_diffusion import dist_util, logger 9 | from guided_diffusion.script_util import ( 10 | model_and_diffusion_defaults, 11 | create_model_and_diffusion, 12 | args_to_dict, 13 | add_dict_to_argparser, 14 | ) 15 | 16 | 17 | def main(): 18 | args = create_argparser().parse_args() 19 | 20 | dist_util.setup_dist() 21 | logger.configure() 22 | 23 | logger.log("creating model and diffusion...") 24 | model, diffusion = create_model_and_diffusion( 25 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 26 | ) 27 | model.to(dist_util.dev()) 28 | 29 | summary(model, [(args.in_channels, args.image_size), ()]) 30 | 31 | 32 | def create_argparser(): 33 | defaults = dict( 34 | dataset_path="", 35 | dataset_name="", 36 | schedule_sampler="uniform", 37 | lr=1e-4, 38 | weight_decay=0.0, 39 | lr_anneal_steps=0, 40 | batch_size=1, 41 | microbatch=-1, # -1 disables microbatches 42 | ema_rate="0.9999", # comma-separated list of EMA values 43 | log_interval=10, 44 | save_interval=10000, 45 | resume_checkpoint="", 46 | use_fp16=False, 47 | fp16_scale_growth=1e-3, 48 | ) 49 | defaults.update(model_and_diffusion_defaults()) 50 | parser = argparse.ArgumentParser() 51 | add_dict_to_argparser(parser, defaults) 52 | return parser 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /scripts/turb_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of Lagrangian trajectories from a model and save them as a large 3 | numpy array. This can be used to produce samples for statistical evaluation. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from guided_diffusion import dist_util, logger 14 | from guided_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model and diffusion...") 30 | model, diffusion = create_model_and_diffusion( 31 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("sampling...") 42 | all_images = [] 43 | all_labels = [] 44 | #noise = th.zeros( 45 | # noise = th.ones( 46 | # (args.batch_size, args.in_channels, args.image_size), 47 | # dtype=th.float32, 48 | # device=dist_util.dev() 49 | # ) * 2 50 | # noise = th.from_numpy( 51 | # np.load('../velocity_module-IS64-NC128-NRB3-DS4000-NScosine-LR1e-4-BS256-sample/fixed_noise_64x1x64x64.npy') 52 | # ).to(dtype=th.float32, device=dist_util.dev()) 53 | import os 54 | seed = 0*8 + int(os.environ["CUDA_VISIBLE_DEVICES"]) 55 | th.manual_seed(seed) 56 | while len(all_images) * args.batch_size < args.num_samples: 57 | model_kwargs = {} 58 | if args.class_cond: 59 | classes = th.randint( 60 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 61 | ) 62 | model_kwargs["y"] = classes 63 | sample_fn = ( 64 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 65 | ) 
66 | #sample_fn = diffusion.p_sample_loop_history 67 | sample = sample_fn( 68 | model, 69 | (args.batch_size, args.in_channels, args.image_size), 70 | #noise=noise, 71 | clip_denoised=args.clip_denoised, 72 | model_kwargs=model_kwargs, 73 | ) 74 | sample = sample.clamp(-1, 1) 75 | #sample[:, -1] = sample[:, -1].clamp(-1, 1) 76 | sample = sample.permute(0, 2, 1) 77 | #sample = sample.permute(0, 1, 3, 2) 78 | sample = sample.contiguous() 79 | 80 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 81 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 82 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 83 | if args.class_cond: 84 | gathered_labels = [ 85 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather(gathered_labels, classes) 88 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 89 | logger.log(f"created {len(all_images) * args.batch_size} samples") 90 | 91 | arr = np.concatenate(all_images, axis=0) 92 | arr = arr[: args.num_samples] 93 | if args.class_cond: 94 | label_arr = np.concatenate(all_labels, axis=0) 95 | label_arr = label_arr[: args.num_samples] 96 | if dist.get_rank() == 0: 97 | shape_str = "x".join([str(x) for x in arr.shape]) 98 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 99 | logger.log(f"saving to {out_path}") 100 | if args.class_cond: 101 | np.savez(out_path, arr, label_arr) 102 | else: 103 | np.savez(out_path, arr) 104 | 105 | dist.barrier() 106 | logger.log("sampling complete") 107 | 108 | 109 | def create_argparser(): 110 | defaults = dict( 111 | clip_denoised=True, 112 | num_samples=10000, 113 | batch_size=16, 114 | use_ddim=False, 115 | model_path="", 116 | ) 117 | defaults.update(model_and_diffusion_defaults()) 118 | parser = argparse.ArgumentParser() 119 | add_dict_to_argparser(parser, defaults) 120 | return parser 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /scripts/turb_sample_history.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of Lagrangian trajectories from a model and save them as a large 3 | numpy array. This can be used to produce samples for statistical evaluation. 
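
Unlike turb_sample.py, this variant uses p_sample_loop_history to keep the
intermediate states of the reverse diffusion and writes one .npz file per
batch (under a samples_history-seed0 subdirectory of the log dir).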
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from guided_diffusion import dist_util, logger 14 | from guided_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model and diffusion...") 30 | model, diffusion = create_model_and_diffusion( 31 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("sampling...") 42 | #noise = th.zeros( 43 | # noise = th.ones( 44 | # (args.batch_size, args.in_channels, args.image_size), 45 | # dtype=th.float32, 46 | # device=dist_util.dev() 47 | # ) * 2 48 | # noise = th.from_numpy( 49 | # np.load('../velocity_module-IS64-NC128-NRB3-DS4000-NScosine-LR1e-4-BS256-sample/fixed_noise_64x1x64x64.npy') 50 | # ).to(dtype=th.float32, device=dist_util.dev()) 51 | import os 52 | seed = 0*8 + int(os.environ["CUDA_VISIBLE_DEVICES"]) 53 | th.manual_seed(seed) 54 | curr_batch, num_complete = 0, 0 55 | while num_complete < args.num_samples: 56 | model_kwargs = {} 57 | if args.class_cond: 58 | classes = th.randint( 59 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 60 | ) 61 | model_kwargs["y"] = classes 62 | sample_fn = diffusion.p_sample_loop_history 63 | sample = sample_fn( 64 | model, 65 | (args.batch_size, args.in_channels, args.image_size), 66 | #noise=noise, 67 | clip_denoised=args.clip_denoised, 68 | model_kwargs=model_kwargs, 69 | ) 70 | sample[:, -1] = sample[:, -1].clamp(-1, 1) 71 | sample = sample.permute(0, 1, 3, 2) 72 | sample = sample.contiguous() 73 | 74 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 75 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 76 | all_images = [sample.cpu().numpy() for sample in gathered_samples] 77 | if args.class_cond: 78 | gathered_labels = [ 79 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 80 | ] 81 | dist.all_gather(gathered_labels, classes) 82 | all_labels = [labels.cpu().numpy() for labels in gathered_labels] 83 | curr_batch += 1 84 | num_complete += dist.get_world_size() * args.batch_size 85 | logger.log(f"created {num_complete} samples") 86 | 87 | arr = np.concatenate(all_images, axis=0) 88 | if args.class_cond: 89 | label_arr = np.concatenate(all_labels, axis=0) 90 | if dist.get_rank() == 0: 91 | shape_str = "x".join([str(x) for x in arr.shape]) 92 | out_path = os.path.join( 93 | logger.get_dir(), "samples_history-seed0", 94 | f"batch{curr_batch:03d}-samples_history_{shape_str}.npz" 95 | ) 96 | logger.log(f"saving to {out_path}") 97 | if args.class_cond: 98 | np.savez(out_path, arr, label_arr) 99 | else: 100 | np.savez(out_path, arr) 101 | 102 | dist.barrier() 103 | logger.log("sampling complete") 104 | 105 | 106 | def create_argparser(): 107 | defaults = dict( 108 | clip_denoised=True, 109 | num_samples=10000, 110 | batch_size=16, 111 | use_ddim=False, 112 | model_path="", 113 | ) 114 | defaults.update(model_and_diffusion_defaults()) 115 | parser = argparse.ArgumentParser() 116 | add_dict_to_argparser(parser, 
defaults) 117 | return parser 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /scripts/turb_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on Lagrangian trajectories in 3d turbulence. 3 | """ 4 | 5 | import argparse 6 | 7 | from guided_diffusion import dist_util, logger 8 | from guided_diffusion.turb_datasets import load_data 9 | from guided_diffusion.resample import create_named_schedule_sampler 10 | from guided_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from guided_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure() 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | dataset_path=args.dataset_path, 35 | dataset_name=args.dataset_name, 36 | batch_size=args.batch_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | dataset_path="", 63 | dataset_name="", 64 | schedule_sampler="uniform", 65 | lr=1e-4, 66 | weight_decay=0.0, 67 | lr_anneal_steps=0, 68 | batch_size=1, 69 | microbatch=-1, # -1 disables microbatches 70 | ema_rate="0.9999", # comma-separated list of EMA values 71 | log_interval=10, 72 | save_interval=10000, 73 | resume_checkpoint="", 74 | use_fp16=False, 75 | fp16_scale_growth=1e-3, 76 | ) 77 | defaults.update(model_and_diffusion_defaults()) 78 | parser = argparse.ArgumentParser() 79 | add_dict_to_argparser(parser, defaults) 80 | return parser 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="guided-diffusion", 5 | py_modules=["guided_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | --------------------------------------------------------------------------------