├── LICENSE
├── README.md
├── datasets
├── Lagr_u1c_diffusion-demo.h5
├── Lagr_u3c_diffusion-demo.h5
└── preprocessing-lagr_u1c-diffusion.py
├── guided_diffusion
├── __init__.py
├── dist_util.py
├── fp16_util.py
├── gaussian_diffusion.py
├── logger.py
├── losses.py
├── nn.py
├── resample.py
├── respace.py
├── script_util.py
├── train_util.py
├── turb_datasets.py
└── unet.py
├── ppdm
├── ppdm
│ ├── __init__.py
│ └── ppdm.py
└── setup.py
├── resources
├── Sampling.png
└── Training.png
├── scripts
├── turb_losses.py
├── turb_model.py
├── turb_sample.py
├── turb_sample_history.py
└── turb_train.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Smart-TURB
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffusion-lagr
2 |
3 | [](https://doi.org/10.5281/zenodo.10563386)
4 |
5 | This is the codebase for [Synthetic Lagrangian Turbulence by Generative Diffusion Models](https://arxiv.org/abs/2307.08529).
6 |
7 | This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications specifically tailored to adapt the Lagrangian turbulence data in the Smart-TURB portal http://smart-turb.roma2.infn.it, under the [TURB-Lagr](https://smart-turb.roma2.infn.it/init/routes/#/logging/view_dataset/2/tabmeta) dataset.
8 |
9 | # Usage
10 |
11 | ## Development Environment
12 |
13 | Our software was developed and tested on a system with the following specifications:
14 |
15 | - **Operating System**: Ubuntu 20.04.4 LTS
16 | - **Python Version**: 3.7.16
17 | - **PyTorch Version**: 1.13.1
18 | - **MPI Implementation**: OpenRTE 4.0.2
19 | - **CUDA Version**: 11.5
20 | - **GPU Model**: NVIDIA A100
21 |
22 | ## Installation
23 |
24 | We recommend using a Conda environment to manage dependencies. The code relies on the MPI library and [parallel h5py](https://docs.h5py.org/en/stable/mpi.html). Note, however, that the use of MPI is not mandatory for all functionalities. See details in [Training](#Training) and [Sampling](#Sampling) for more information. After setting up your environment, clone this repository and navigate to it in your terminal. Then run:
25 |
26 | ```
27 | pip install -e .
28 | ```
29 |
30 | This should install the `guided_diffusion` python package that the scripts depend on.
31 |
32 | ### Troubleshooting Installation
33 |
34 | During the installation process, you might encounter a couple of known issues. Here are some tips to help you resolve them:
35 |
36 | 1. **Parallel h5py Installation**: Setting up parallel h5py can sometimes pose challenges. As a workaround, you can install the serial version of h5py, comment out the specific lines of code found [here](https://github.com/SmartTURB/diffusion-lagr/blob/master/guided_diffusion/turb_datasets.py#L34) and [here](https://github.com/SmartTURB/diffusion-lagr/blob/master/guided_diffusion/turb_datasets.py#L76), and uncomment the lines immediately following them.
37 | 2. **PyTorch Installation**: In our experience, sometimes it's necessary to reinstall PyTorch depending on your system environment. You can download and install PyTorch from their [official website](https://pytorch.org/).
38 |
39 | ## Preparing Data
40 |
41 | The data needed for this project can be obtained from the Smart-TURB portal. Follow these steps to download the data:
42 |
43 | 1. Visit the [Smart-TURB portal](http://smart-turb.roma2.infn.it).
44 | 2. Navigate to `TURB-Lagr` under the `Datasets` section.
45 | 3. Click on `Files` -> `data` -> `Lagr_u3c_diffusion.h5`,
46 |
47 | which can also be accessed directly by clicking on this [link](https://smart-turb.roma2.infn.it/init/files/api_file_download/1/___FOLDERSEPARATOR___scratch___FOLDERSEPARATOR___smartturb___FOLDERSEPARATOR___tov___FOLDERSEPARATOR___turb-lagr___FOLDERSEPARATOR___data___FOLDERSEPARATOR___Lagr_u3c_diffusion___POINT___h5/15728642096).
48 |
49 | ### Data Details and Example Usage
50 |
51 | Here is an example of how you can read the data:
52 |
53 | ```python
54 | import h5py
55 | import numpy as np
56 |
57 | with h5py.File('datasets/Lagr_u3c_diffusion.h5', 'r') as h5f:
58 | rx0 = np.array(h5f.get('min'))
59 | rx1 = np.array(h5f.get('max'))
60 | u3c = np.array(h5f.get('train'))
61 |
62 | velocities = (u3c+1)*(rx1-rx0)/2 + rx0
63 | ```
64 |
65 | The `u3c` variable is a 3D array with the shape `(327680, 2000, 3)`, representing 327,680 trajectories, each of size 2000, for 3 velocity components. Each component is normalized to the range `[-1, 1]` using the min-max method. The `rx0` and `rx1` variables store the minimum and maximum values for each of the 3 components, respectively. The last line of the code sample retrieves the original velocities from the normalized data.
66 |
67 | The data file `Lagr_u3c_diffusion.h5` mentioned above is used for training the `DM-3c` model. For training `DM-1c`, we do not distinguish between the 3 velocity components, thereby tripling the number of trajectories. You can generate the appropriate data by using the [`datasets/preprocessing-lagr_u1c-diffusion.py`](https://github.com/SmartTURB/diffusion-lagr/blob/master/datasets/preprocessing-lagr_u1c-diffusion.py) script. This script concatenates the three velocity components, applies min-max normalization, and saves the result as `Lagr_u1c_diffusion.h5`.
68 |
69 | ## Training
70 |
71 |
72 |
73 | To train your model, you'll first need to determine certain hyperparameters. We can categorize these hyperparameters into three groups: model architecture, diffusion process, and training flags. Detailed information about these can be found in the [parent repository](https://github.com/openai/improved-diffusion).
74 |
75 | The run flags for the two models featured in our paper are as follows (please refer to Fig.2 in [the paper](https://arxiv.org/abs/2307.08529)):
76 |
77 | For the `DM-1c` model, use the following flags:
78 |
79 | ```sh
80 | DATA_FLAGS="--dataset_path datasets/Lagr_u1c_diffusion.h5 --dataset_name train"
81 | MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 1 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
82 | DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
83 | TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
84 | ```
85 |
86 | For the `DM-3c` model, you only need to modify `--dataset_path` to `../datasets/Lagr_u3c_diffusion.h5` and `--in_channels` to `3`:
87 |
88 | ```sh
89 | DATA_FLAGS="--dataset_path datasets/Lagr_u3c_diffusion.h5 --dataset_name train"
90 | MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 3 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
91 | DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
92 | TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
93 | ```
94 |
95 | After defining your hyperparameters, you can initiate an experiment using the following command:
96 |
97 | ```sh
98 | mpiexec -n $NUM_GPUS python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
99 | ```
100 |
101 | The training process is distributed, and for our model, we set `$NUM_GPUS` to 4. Note that the `--batch_size` flag represents the batch size on each GPU, so the real batch size is `$NUM_GPUS * batch_size = 256`, as reported in the paper (Fig.2).
102 |
103 | The log files and model checkpoints will be saved to a logging directory specified by the `OPENAI_LOGDIR` environment variable. If this variable is not set, a temporary directory in `/tmp` will be created and used instead.
104 |
105 | ### Demo
106 |
107 | To assist with testing the software installation and understanding the hyperparameters mentioned above, we have provided two smaller datasets: `datasets/Lag_u1c_diffusion-demo.h5` and `datasets/Lag_u3c_diffusion-demo.h5`. The `train` dataset within these files has shapes of (768, 2000, 1) and (256, 2000, 3), respectively.
108 |
109 | To run the demo, use the same flags as for the `DM-1c` and `DM-3c` models above, ensuring that you modify the `--dataset_path` flag to the appropriate demo dataset.
110 |
111 | For the `DM-1c` model:
112 |
113 | ```sh
114 | # Set the flags
115 | DATA_FLAGS="--dataset_path datasets/Lagr_u1c_diffusion-demo.h5 --dataset_name train"
116 | MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 1 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
117 | DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
118 | TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
119 |
120 | # Training command
121 | python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
122 | ```
123 |
124 | For the `DM-3c` model:
125 |
126 | ```sh
127 | # Set the flags
128 | DATA_FLAGS="--dataset_path datasets/Lagr_u3c_diffusion-demo.h5 --dataset_name train"
129 | MODEL_FLAGS="--dims 1 --image_size 2000 --in_channels 3 --num_channels 128 --num_res_blocks 3 --attention_resolutions 250,125 --channel_mult 1,1,2,3,4"
130 | DIFFUSION_FLAGS="--diffusion_steps 800 --noise_schedule tanh6,1"
131 | TRAIN_FLAGS="--lr 1e-4 --batch_size 64"
132 |
133 | # Training command
134 | python scripts/turb_train.py $DATA_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
135 | ```
136 |
137 | Remember, for this demo, you can simplify the run by using the serial version of h5py as described in [Parallel h5py Installation](#h5py-installation).
138 |
139 | ## Sampling:
140 |
141 |
142 |
143 | The training script from the previous section stores checkpoints as `.pt` files within the designated logging directory. These checkpoint files will follow naming patterns such as `ema_0.9999_200000.pt` or `model200000.pt`. For improved sampling results, it's advised to sample from the Exponential Moving Average (EMA) models.
144 |
145 | Before sampling, set `SAMPLE_FLAGS` to specify the number of samples `--num_samples`, batch size `--batch_size`, and the path to the model `--model_path`. For example:
146 |
147 | ```sh
148 | SAMPLE_FLAGS="--num_samples 179200 --batch_size 64 --model_path ema_0.9999_250000.pt"
149 | ```
150 |
151 | Then, run the following command:
152 |
153 | ```sh
154 | python scripts/turb_sample.py $SAMPLE_FLAGS $MODEL_FLAGS $DIFFUSION_FLAGS
155 | ```
156 |
157 | After sampling with the above command, it will generate a file named `samples_179200x2000x3.npz` (for `DM-3c` as an example). You can use the following code to read and retrieve the generated velocities:
158 |
159 | ```python
160 | import h5py
161 | import numpy as np
162 |
163 | with h5py.File('datasets/Lagr_u3c_diffusion.h5', 'r') as h5f:
164 | rx0 = np.array(h5f.get('min'))
165 | rx1 = np.array(h5f.get('max'))
166 |
167 | u3c = (np.load('samples_179200x2000x3.npz')['arr_0']+1)*(rx1-rx0)/2 + rx0
168 | ```
169 |
170 | Just like for training, you can use multiple GPUs for sampling. Please note that the `$MODEL_FLAGS` and `$DIFFUSION_FLAGS` should be the same as those used in training.
171 |
172 | In training the DM-1c and DM-3c models, we utilized four Nvidia A100 GPUs for periods of one and two days, respectively. Acknowledging that extensive computational demands could be a bottleneck for users, we have provided the checkpoints used in the paper, accessible via the following links: [`DM-1c`](https://www.dropbox.com/scl/fi/ox7ytzyh2qcqswoqingiv/ema_0.9999_250000.pt?rlkey=1ld2ccsttj6s6f5tftvcntu0e&dl=0) and [`DM-3c`](https://www.dropbox.com/scl/fi/o7aun6o7lfk99eikds4c2/ema_0.9999_400000.pt?rlkey=mkxaxs0kw4ighb330a3yca0mo&dl=0).
173 |
--------------------------------------------------------------------------------
/datasets/Lagr_u1c_diffusion-demo.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/datasets/Lagr_u1c_diffusion-demo.h5
--------------------------------------------------------------------------------
/datasets/Lagr_u3c_diffusion-demo.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/datasets/Lagr_u3c_diffusion-demo.h5
--------------------------------------------------------------------------------
/datasets/preprocessing-lagr_u1c-diffusion.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 |
4 | with h5py.File('Lagr_u3c_diffusion.h5', 'r') as h5f:
5 | rx0 = np.array(h5f.get('min'))
6 | rx1 = np.array(h5f.get('max'))
7 | u3c = np.array(h5f.get('train'))
8 |
9 | x_train = (u3c+1)*(rx1-rx0)/2 + rx0
10 | x_train = np.concatenate((x_train[..., :1], x_train[..., 1:2], x_train[..., -1:]))
11 |
12 | rx0, rx1 = np.amin(x_train), np.amax(x_train)
13 | x_train = 2*(x_train-rx0)/(rx1-rx0) - 1
14 |
15 | with h5py.File('Lagr_u1c_diffusion.h5', 'w') as hf:
16 | hf.create_dataset('min', data=rx0)
17 | hf.create_dataset('max', data=rx1)
18 | hf.create_dataset('train', data=x_train)
19 |
--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 | return th.device(f"cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 | try:
88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
89 | s.bind(("", 0))
90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 | return s.getsockname()[1]
92 | finally:
93 | s.close()
94 |
--------------------------------------------------------------------------------
/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | for p in self.master_params:
203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 | opt.step()
205 | zero_master_grads(self.master_params)
206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207 | self.lg_loss_scale += self.fp16_scale_growth
208 | return True
209 |
210 | def _optimize_normal(self, opt: th.optim.Optimizer):
211 | grad_norm, param_norm = self._compute_norms()
212 | logger.logkv_mean("grad_norm", grad_norm)
213 | logger.logkv_mean("param_norm", param_norm)
214 | opt.step()
215 | return True
216 |
217 | def _compute_norms(self, grad_scale=1.0):
218 | grad_norm = 0.0
219 | param_norm = 0.0
220 | for p in self.master_params:
221 | with th.no_grad():
222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223 | if p.grad is not None:
224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226 |
227 | def master_params_to_state_dict(self, master_params):
228 | return master_params_to_state_dict(
229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230 | )
231 |
232 | def state_dict_to_master_params(self, state_dict):
233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234 |
235 |
236 | def check_overflow(value):
237 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
238 |
--------------------------------------------------------------------------------
/guided_diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | This code started out as a PyTorch port of Ho et al's diffusion models:
3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4 |
5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6 | """
7 |
8 | import enum
9 | import math
10 |
11 | import numpy as np
12 | import torch as th
13 |
14 | from .nn import mean_flat
15 | from .losses import normal_kl, discretized_gaussian_log_likelihood
16 |
17 |
18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
19 | """
20 | Get a pre-defined beta schedule for the given name.
21 |
22 | The beta schedule library consists of beta schedules which remain similar
23 | in the limit of num_diffusion_timesteps.
24 | Beta schedules may be added, but should not be removed or changed once
25 | they are committed to maintain backwards compatibility.
26 | """
27 | if schedule_name == "linear":
28 | # Linear schedule from Ho et al, extended to work for any number of
29 | # diffusion steps.
30 | scale = 1000 / num_diffusion_timesteps
31 | beta_start = scale * 0.0001
32 | beta_end = scale * 0.02
33 | return np.linspace(
34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
35 | )
36 | elif schedule_name == "cosine":
37 | return betas_for_alpha_bar(
38 | num_diffusion_timesteps,
39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
40 | )
41 | elif schedule_name.startswith("power"):
42 | power = int(schedule_name[5:])
43 | return betas_for_alpha_bar(
44 | num_diffusion_timesteps,
45 | lambda t: 1 - t**power,
46 | )
47 | elif schedule_name.startswith("exp"):
48 | t0 = float(schedule_name[3:])
49 | return betas_for_alpha_bar(
50 | num_diffusion_timesteps,
51 | lambda t: 2 - math.exp((t0 + math.log(2)) * t - t0),
52 | )
53 | elif schedule_name.startswith("tanh"):
54 | t0, t1 = schedule_name.split(",")
55 | t0, t1 = float(t0[4:]), float(t1)
56 | return betas_for_alpha_bar(
57 | num_diffusion_timesteps,
58 | lambda t: -math.tanh((t0 + t1) * t - t0) + math.tanh(t1),
59 | )
60 | else:
61 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
62 |
63 |
64 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
65 | """
66 | Create a beta schedule that discretizes the given alpha_t_bar function,
67 | which defines the cumulative product of (1-beta) over time from t = [0,1].
68 |
69 | :param num_diffusion_timesteps: the number of betas to produce.
70 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
71 | produces the cumulative product of (1-beta) up to that
72 | part of the diffusion process.
73 | :param max_beta: the maximum beta to use; use values lower than 1 to
74 | prevent singularities.
75 | """
76 | betas = []
77 | for i in range(num_diffusion_timesteps):
78 | t1 = i / num_diffusion_timesteps
79 | t2 = (i + 1) / num_diffusion_timesteps
80 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
81 | return np.array(betas)
82 |
83 |
84 | class ModelMeanType(enum.Enum):
85 | """
86 | Which type of output the model predicts.
87 | """
88 |
89 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
90 | START_X = enum.auto() # the model predicts x_0
91 | EPSILON = enum.auto() # the model predicts epsilon
92 |
93 |
94 | class ModelVarType(enum.Enum):
95 | """
96 | What is used as the model's output variance.
97 |
98 | The LEARNED_RANGE option has been added to allow the model to predict
99 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
100 | """
101 |
102 | LEARNED = enum.auto()
103 | FIXED_SMALL = enum.auto()
104 | FIXED_LARGE = enum.auto()
105 | LEARNED_RANGE = enum.auto()
106 |
107 |
108 | class LossType(enum.Enum):
109 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
110 | RESCALED_MSE = (
111 | enum.auto()
112 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
113 | KL = enum.auto() # use the variational lower-bound
114 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
115 |
116 | def is_vb(self):
117 | return self == LossType.KL or self == LossType.RESCALED_KL
118 |
119 |
120 | class GaussianDiffusion:
121 | """
122 | Utilities for training and sampling diffusion models.
123 |
124 | Ported directly from here, and then adapted over time to further experimentation.
125 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
126 |
127 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
128 | starting at T and going to 1.
129 | :param model_mean_type: a ModelMeanType determining what the model outputs.
130 | :param model_var_type: a ModelVarType determining how variance is output.
131 | :param loss_type: a LossType determining the loss function to use.
132 | :param rescale_timesteps: if True, pass floating point timesteps into the
133 | model so that they are always scaled like in the
134 | original paper (0 to 1000).
135 | """
136 |
137 | def __init__(
138 | self,
139 | *,
140 | betas,
141 | model_mean_type,
142 | model_var_type,
143 | loss_type,
144 | rescale_timesteps=False,
145 | ):
146 | self.model_mean_type = model_mean_type
147 | self.model_var_type = model_var_type
148 | self.loss_type = loss_type
149 | self.rescale_timesteps = rescale_timesteps
150 |
151 | # Use float64 for accuracy.
152 | betas = np.array(betas, dtype=np.float64)
153 | self.betas = betas
154 | assert len(betas.shape) == 1, "betas must be 1-D"
155 | assert (betas > 0).all() and (betas <= 1).all()
156 |
157 | self.num_timesteps = int(betas.shape[0])
158 |
159 | alphas = 1.0 - betas
160 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
161 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
162 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
163 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
164 |
165 | # calculations for diffusion q(x_t | x_{t-1}) and others
166 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
167 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
168 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
169 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
170 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
171 |
172 | # calculations for posterior q(x_{t-1} | x_t, x_0)
173 | self.posterior_variance = (
174 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
175 | )
176 | # log calculation clipped because the posterior variance is 0 at the
177 | # beginning of the diffusion chain.
178 | self.posterior_log_variance_clipped = np.log(
179 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
180 | )
181 | self.posterior_mean_coef1 = (
182 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
183 | )
184 | self.posterior_mean_coef2 = (
185 | (1.0 - self.alphas_cumprod_prev)
186 | * np.sqrt(alphas)
187 | / (1.0 - self.alphas_cumprod)
188 | )
189 |
190 | def q_mean_variance(self, x_start, t):
191 | """
192 | Get the distribution q(x_t | x_0).
193 |
194 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
195 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
196 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
197 | """
198 | mean = (
199 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
200 | )
201 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
202 | log_variance = _extract_into_tensor(
203 | self.log_one_minus_alphas_cumprod, t, x_start.shape
204 | )
205 | return mean, variance, log_variance
206 |
207 | def q_sample(self, x_start, t, noise=None):
208 | """
209 | Diffuse the data for a given number of diffusion steps.
210 |
211 | In other words, sample from q(x_t | x_0).
212 |
213 | :param x_start: the initial data batch.
214 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
215 | :param noise: if specified, the split-out normal noise.
216 | :return: A noisy version of x_start.
217 | """
218 | if noise is None:
219 | noise = th.randn_like(x_start)
220 | assert noise.shape == x_start.shape
221 | return (
222 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
223 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
224 | * noise
225 | )
226 |
227 | def q_posterior_mean_variance(self, x_start, x_t, t):
228 | """
229 | Compute the mean and variance of the diffusion posterior:
230 |
231 | q(x_{t-1} | x_t, x_0)
232 |
233 | """
234 | assert x_start.shape == x_t.shape
235 | posterior_mean = (
236 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
237 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
238 | )
239 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
240 | posterior_log_variance_clipped = _extract_into_tensor(
241 | self.posterior_log_variance_clipped, t, x_t.shape
242 | )
243 | assert (
244 | posterior_mean.shape[0]
245 | == posterior_variance.shape[0]
246 | == posterior_log_variance_clipped.shape[0]
247 | == x_start.shape[0]
248 | )
249 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
250 |
251 | def p_mean_variance(
252 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
253 | ):
254 | """
255 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
256 | the initial x, x_0.
257 |
258 | :param model: the model, which takes a signal and a batch of timesteps
259 | as input.
260 | :param x: the [N x C x ...] tensor at time t.
261 | :param t: a 1-D Tensor of timesteps.
262 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263 | :param denoised_fn: if not None, a function which applies to the
264 | x_start prediction before it is used to sample. Applies before
265 | clip_denoised.
266 | :param model_kwargs: if not None, a dict of extra keyword arguments to
267 | pass to the model. This can be used for conditioning.
268 | :return: a dict with the following keys:
269 | - 'mean': the model mean output.
270 | - 'variance': the model variance output.
271 | - 'log_variance': the log of 'variance'.
272 | - 'pred_xstart': the prediction for x_0.
273 | """
274 | if model_kwargs is None:
275 | model_kwargs = {}
276 |
277 | B, C = x.shape[:2]
278 | assert t.shape == (B,)
279 | model_output = model(x, self._scale_timesteps(t), **model_kwargs)
280 |
281 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
282 | assert model_output.shape == (B, C * 2, *x.shape[2:])
283 | model_output, model_var_values = th.split(model_output, C, dim=1)
284 | if self.model_var_type == ModelVarType.LEARNED:
285 | model_log_variance = model_var_values
286 | model_variance = th.exp(model_log_variance)
287 | else:
288 | min_log = _extract_into_tensor(
289 | self.posterior_log_variance_clipped, t, x.shape
290 | )
291 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
292 | # The model_var_values is [-1, 1] for [min_var, max_var].
293 | frac = (model_var_values + 1) / 2
294 | model_log_variance = frac * max_log + (1 - frac) * min_log
295 | model_variance = th.exp(model_log_variance)
296 | else:
297 | model_variance, model_log_variance = {
298 | # for fixedlarge, we set the initial (log-)variance like so
299 | # to get a better decoder log likelihood.
300 | ModelVarType.FIXED_LARGE: (
301 | np.append(self.posterior_variance[1], self.betas[1:]),
302 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
303 | ),
304 | ModelVarType.FIXED_SMALL: (
305 | self.posterior_variance,
306 | self.posterior_log_variance_clipped,
307 | ),
308 | }[self.model_var_type]
309 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
310 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
311 |
312 | def process_xstart(x):
313 | if denoised_fn is not None:
314 | x = denoised_fn(x)
315 | if clip_denoised:
316 | return x.clamp(-1, 1)
317 | return x
318 |
319 | if self.model_mean_type == ModelMeanType.PREVIOUS_X:
320 | pred_xstart = process_xstart(
321 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
322 | )
323 | model_mean = model_output
324 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
325 | if self.model_mean_type == ModelMeanType.START_X:
326 | pred_xstart = process_xstart(model_output)
327 | else:
328 | pred_xstart = process_xstart(
329 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
330 | )
331 | model_mean, _, _ = self.q_posterior_mean_variance(
332 | x_start=pred_xstart, x_t=x, t=t
333 | )
334 | else:
335 | raise NotImplementedError(self.model_mean_type)
336 |
337 | assert (
338 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
339 | )
340 | return {
341 | "mean": model_mean,
342 | "variance": model_variance,
343 | "log_variance": model_log_variance,
344 | "pred_xstart": pred_xstart,
345 | }
346 |
347 | def _predict_xstart_from_eps(self, x_t, t, eps):
348 | assert x_t.shape == eps.shape
349 | return (
350 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
351 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
352 | )
353 |
354 | def _predict_xstart_from_xprev(self, x_t, t, xprev):
355 | assert x_t.shape == xprev.shape
356 | return ( # (xprev - coef2*x_t) / coef1
357 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
358 | - _extract_into_tensor(
359 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
360 | )
361 | * x_t
362 | )
363 |
364 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
365 | return (
366 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
367 | - pred_xstart
368 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
369 |
370 | def _scale_timesteps(self, t):
371 | if self.rescale_timesteps:
372 | return t.float() * (1000.0 / self.num_timesteps)
373 | return t
374 |
375 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
376 | """
377 | Compute the mean for the previous step, given a function cond_fn that
378 | computes the gradient of a conditional log probability with respect to
379 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
380 | condition on y.
381 |
382 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
383 | """
384 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
385 | new_mean = (
386 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
387 | )
388 | return new_mean
389 |
390 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
391 | """
392 | Compute what the p_mean_variance output would have been, should the
393 | model's score function be conditioned by cond_fn.
394 |
395 | See condition_mean() for details on cond_fn.
396 |
397 | Unlike condition_mean(), this instead uses the conditioning strategy
398 | from Song et al (2020).
399 | """
400 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
401 |
402 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
403 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
404 | x, self._scale_timesteps(t), **model_kwargs
405 | )
406 |
407 | out = p_mean_var.copy()
408 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
409 | out["mean"], _, _ = self.q_posterior_mean_variance(
410 | x_start=out["pred_xstart"], x_t=x, t=t
411 | )
412 | return out
413 |
414 | def p_sample(
415 | self,
416 | model,
417 | x,
418 | t,
419 | clip_denoised=True,
420 | denoised_fn=None,
421 | cond_fn=None,
422 | model_kwargs=None,
423 | ):
424 | """
425 | Sample x_{t-1} from the model at the given timestep.
426 |
427 | :param model: the model to sample from.
428 | :param x: the current tensor at x_{t-1}.
429 | :param t: the value of t, starting at 0 for the first diffusion step.
430 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
431 | :param denoised_fn: if not None, a function which applies to the
432 | x_start prediction before it is used to sample.
433 | :param cond_fn: if not None, this is a gradient function that acts
434 | similarly to the model.
435 | :param model_kwargs: if not None, a dict of extra keyword arguments to
436 | pass to the model. This can be used for conditioning.
437 | :return: a dict containing the following keys:
438 | - 'sample': a random sample from the model.
439 | - 'pred_xstart': a prediction of x_0.
440 | """
441 | out = self.p_mean_variance(
442 | model,
443 | x,
444 | t,
445 | clip_denoised=clip_denoised,
446 | denoised_fn=denoised_fn,
447 | model_kwargs=model_kwargs,
448 | )
449 | noise = th.randn_like(x)
450 | nonzero_mask = (
451 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
452 | ) # no noise when t == 0
453 | if cond_fn is not None:
454 | out["mean"] = self.condition_mean(
455 | cond_fn, out, x, t, model_kwargs=model_kwargs
456 | )
457 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
458 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
459 |
460 | def p_sample_loop(
461 | self,
462 | model,
463 | shape,
464 | noise=None,
465 | clip_denoised=True,
466 | denoised_fn=None,
467 | cond_fn=None,
468 | model_kwargs=None,
469 | device=None,
470 | progress=False,
471 | ):
472 | """
473 | Generate samples from the model.
474 |
475 | :param model: the model module.
476 | :param shape: the shape of the samples, (N, C, H, W).
477 | :param noise: if specified, the noise from the encoder to sample.
478 | Should be of the same shape as `shape`.
479 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
480 | :param denoised_fn: if not None, a function which applies to the
481 | x_start prediction before it is used to sample.
482 | :param cond_fn: if not None, this is a gradient function that acts
483 | similarly to the model.
484 | :param model_kwargs: if not None, a dict of extra keyword arguments to
485 | pass to the model. This can be used for conditioning.
486 | :param device: if specified, the device to create the samples on.
487 | If not specified, use a model parameter's device.
488 | :param progress: if True, show a tqdm progress bar.
489 | :return: a non-differentiable batch of samples.
490 | """
491 | final = None
492 | for sample in self.p_sample_loop_progressive(
493 | model,
494 | shape,
495 | noise=noise,
496 | clip_denoised=clip_denoised,
497 | denoised_fn=denoised_fn,
498 | cond_fn=cond_fn,
499 | model_kwargs=model_kwargs,
500 | device=device,
501 | progress=progress,
502 | ):
503 | final = sample
504 | return final["sample"]
505 |
506 | def p_sample_loop_history(
507 | self,
508 | model,
509 | shape,
510 | noise=None,
511 | clip_denoised=True,
512 | denoised_fn=None,
513 | cond_fn=None,
514 | model_kwargs=None,
515 | device=None,
516 | progress=False,
517 | ):
518 | """
519 | Generate samples from the model with the denoising history.
520 |
521 | :param model: the model module.
522 | :param shape: the shape of the samples, (N, C, H, W).
523 | :param noise: if specified, the noise from the encoder to sample.
524 | Should be of the same shape as `shape`.
525 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
526 | :param denoised_fn: if not None, a function which applies to the
527 | x_start prediction before it is used to sample.
528 | :param cond_fn: if not None, this is a gradient function that acts
529 | similarly to the model.
530 | :param model_kwargs: if not None, a dict of extra keyword arguments to
531 | pass to the model. This can be used for conditioning.
532 | :param device: if specified, the device to create the samples on.
533 | If not specified, use a model parameter's device.
534 | :param progress: if True, show a tqdm progress bar.
535 | :return: a non-differentiable batch of samples with the denoising history.
536 | """
537 | if noise is None:
538 | assert isinstance(shape, (tuple, list))
539 | if device is None:
540 | device = next(model.parameters()).device
541 | noise = th.randn(*shape, device=device)
542 |
543 | sample_history = [noise]
544 | for sample in self.p_sample_loop_progressive(
545 | model,
546 | shape,
547 | noise=noise,
548 | clip_denoised=clip_denoised,
549 | denoised_fn=denoised_fn,
550 | cond_fn=cond_fn,
551 | model_kwargs=model_kwargs,
552 | device=device,
553 | progress=progress,
554 | ):
555 | sample_history.append(sample["sample"])
556 | return th.stack(sample_history, dim=1)
557 |
558 | def p_sample_loop_progressive(
559 | self,
560 | model,
561 | shape,
562 | noise=None,
563 | clip_denoised=True,
564 | denoised_fn=None,
565 | cond_fn=None,
566 | model_kwargs=None,
567 | device=None,
568 | progress=False,
569 | ):
570 | """
571 | Generate samples from the model and yield intermediate samples from
572 | each timestep of diffusion.
573 |
574 | Arguments are the same as p_sample_loop().
575 | Returns a generator over dicts, where each dict is the return value of
576 | p_sample().
577 | """
578 | if device is None:
579 | device = next(model.parameters()).device
580 | assert isinstance(shape, (tuple, list))
581 | if noise is not None:
582 | img = noise
583 | else:
584 | img = th.randn(*shape, device=device)
585 | indices = list(range(self.num_timesteps))[::-1]
586 |
587 | if progress:
588 | # Lazy import so that we don't depend on tqdm.
589 | from tqdm.auto import tqdm
590 |
591 | indices = tqdm(indices)
592 |
593 | for i in indices:
594 | t = th.tensor([i] * shape[0], device=device)
595 | with th.no_grad():
596 | out = self.p_sample(
597 | model,
598 | img,
599 | t,
600 | clip_denoised=clip_denoised,
601 | denoised_fn=denoised_fn,
602 | cond_fn=cond_fn,
603 | model_kwargs=model_kwargs,
604 | )
605 | yield out
606 | img = out["sample"]
607 |
608 | def ddim_sample(
609 | self,
610 | model,
611 | x,
612 | t,
613 | clip_denoised=True,
614 | denoised_fn=None,
615 | cond_fn=None,
616 | model_kwargs=None,
617 | eta=0.0,
618 | ):
619 | """
620 | Sample x_{t-1} from the model using DDIM.
621 |
622 | Same usage as p_sample().
623 | """
624 | out = self.p_mean_variance(
625 | model,
626 | x,
627 | t,
628 | clip_denoised=clip_denoised,
629 | denoised_fn=denoised_fn,
630 | model_kwargs=model_kwargs,
631 | )
632 | if cond_fn is not None:
633 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
634 |
635 | # Usually our model outputs epsilon, but we re-derive it
636 | # in case we used x_start or x_prev prediction.
637 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
638 |
639 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
640 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
641 | sigma = (
642 | eta
643 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
644 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
645 | )
646 | # Equation 12.
647 | noise = th.randn_like(x)
648 | mean_pred = (
649 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
650 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
651 | )
652 | nonzero_mask = (
653 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
654 | ) # no noise when t == 0
655 | sample = mean_pred + nonzero_mask * sigma * noise
656 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
657 |
658 | def ddim_reverse_sample(
659 | self,
660 | model,
661 | x,
662 | t,
663 | clip_denoised=True,
664 | denoised_fn=None,
665 | model_kwargs=None,
666 | eta=0.0,
667 | ):
668 | """
669 | Sample x_{t+1} from the model using DDIM reverse ODE.
670 | """
671 | assert eta == 0.0, "Reverse ODE only for deterministic path"
672 | out = self.p_mean_variance(
673 | model,
674 | x,
675 | t,
676 | clip_denoised=clip_denoised,
677 | denoised_fn=denoised_fn,
678 | model_kwargs=model_kwargs,
679 | )
680 | # Usually our model outputs epsilon, but we re-derive it
681 | # in case we used x_start or x_prev prediction.
682 | eps = (
683 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
684 | - out["pred_xstart"]
685 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
686 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
687 |
688 | # Equation 12. reversed
689 | mean_pred = (
690 | out["pred_xstart"] * th.sqrt(alpha_bar_next)
691 | + th.sqrt(1 - alpha_bar_next) * eps
692 | )
693 |
694 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
695 |
696 | def ddim_sample_loop(
697 | self,
698 | model,
699 | shape,
700 | noise=None,
701 | clip_denoised=True,
702 | denoised_fn=None,
703 | cond_fn=None,
704 | model_kwargs=None,
705 | device=None,
706 | progress=False,
707 | eta=0.0,
708 | ):
709 | """
710 | Generate samples from the model using DDIM.
711 |
712 | Same usage as p_sample_loop().
713 | """
714 | final = None
715 | for sample in self.ddim_sample_loop_progressive(
716 | model,
717 | shape,
718 | noise=noise,
719 | clip_denoised=clip_denoised,
720 | denoised_fn=denoised_fn,
721 | cond_fn=cond_fn,
722 | model_kwargs=model_kwargs,
723 | device=device,
724 | progress=progress,
725 | eta=eta,
726 | ):
727 | final = sample
728 | return final["sample"]
729 |
730 | def ddim_sample_loop_progressive(
731 | self,
732 | model,
733 | shape,
734 | noise=None,
735 | clip_denoised=True,
736 | denoised_fn=None,
737 | cond_fn=None,
738 | model_kwargs=None,
739 | device=None,
740 | progress=False,
741 | eta=0.0,
742 | ):
743 | """
744 | Use DDIM to sample from the model and yield intermediate samples from
745 | each timestep of DDIM.
746 |
747 | Same usage as p_sample_loop_progressive().
748 | """
749 | if device is None:
750 | device = next(model.parameters()).device
751 | assert isinstance(shape, (tuple, list))
752 | if noise is not None:
753 | img = noise
754 | else:
755 | img = th.randn(*shape, device=device)
756 | indices = list(range(self.num_timesteps))[::-1]
757 |
758 | if progress:
759 | # Lazy import so that we don't depend on tqdm.
760 | from tqdm.auto import tqdm
761 |
762 | indices = tqdm(indices)
763 |
764 | for i in indices:
765 | t = th.tensor([i] * shape[0], device=device)
766 | with th.no_grad():
767 | out = self.ddim_sample(
768 | model,
769 | img,
770 | t,
771 | clip_denoised=clip_denoised,
772 | denoised_fn=denoised_fn,
773 | cond_fn=cond_fn,
774 | model_kwargs=model_kwargs,
775 | eta=eta,
776 | )
777 | yield out
778 | img = out["sample"]
779 |
780 | def _vb_terms_bpd(
781 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
782 | ):
783 | """
784 | Get a term for the variational lower-bound.
785 |
786 | The resulting units are bits (rather than nats, as one might expect).
787 | This allows for comparison to other papers.
788 |
789 | :return: a dict with the following keys:
790 | - 'output': a shape [N] tensor of NLLs or KLs.
791 | - 'pred_xstart': the x_0 predictions.
792 | """
793 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
794 | x_start=x_start, x_t=x_t, t=t
795 | )
796 | out = self.p_mean_variance(
797 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
798 | )
799 | kl = normal_kl(
800 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
801 | )
802 | kl = mean_flat(kl) / np.log(2.0)
803 |
804 | decoder_nll = -discretized_gaussian_log_likelihood(
805 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
806 | )
807 | assert decoder_nll.shape == x_start.shape
808 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
809 |
810 | # At the first timestep return the decoder NLL,
811 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
812 | output = th.where((t == 0), decoder_nll, kl)
813 | return {"output": output, "pred_xstart": out["pred_xstart"]}
814 |
815 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
816 | """
817 | Compute training losses for a single timestep.
818 |
819 | :param model: the model to evaluate loss on.
820 | :param x_start: the [N x C x ...] tensor of inputs.
821 | :param t: a batch of timestep indices.
822 | :param model_kwargs: if not None, a dict of extra keyword arguments to
823 | pass to the model. This can be used for conditioning.
824 | :param noise: if specified, the specific Gaussian noise to try to remove.
825 | :return: a dict with the key "loss" containing a tensor of shape [N].
826 | Some mean or variance settings may also have other keys.
827 | """
828 | if model_kwargs is None:
829 | model_kwargs = {}
830 | if noise is None:
831 | noise = th.randn_like(x_start)
832 | x_t = self.q_sample(x_start, t, noise=noise)
833 |
834 | terms = {}
835 |
836 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
837 | terms["loss"] = self._vb_terms_bpd(
838 | model=model,
839 | x_start=x_start,
840 | x_t=x_t,
841 | t=t,
842 | clip_denoised=False,
843 | model_kwargs=model_kwargs,
844 | )["output"]
845 | if self.loss_type == LossType.RESCALED_KL:
846 | terms["loss"] *= self.num_timesteps
847 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
848 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
849 |
850 | if self.model_var_type in [
851 | ModelVarType.LEARNED,
852 | ModelVarType.LEARNED_RANGE,
853 | ]:
854 | B, C = x_t.shape[:2]
855 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
856 | model_output, model_var_values = th.split(model_output, C, dim=1)
857 | # Learn the variance using the variational bound, but don't let
858 | # it affect our mean prediction.
859 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
860 | terms["vb"] = self._vb_terms_bpd(
861 | model=lambda *args, r=frozen_out: r,
862 | x_start=x_start,
863 | x_t=x_t,
864 | t=t,
865 | clip_denoised=False,
866 | )["output"]
867 | if self.loss_type == LossType.RESCALED_MSE:
868 | # Divide by 1000 for equivalence with initial implementation.
869 | # Without a factor of 1/1000, the VB term hurts the MSE term.
870 | terms["vb"] *= self.num_timesteps / 1000.0
871 |
872 | target = {
873 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
874 | x_start=x_start, x_t=x_t, t=t
875 | )[0],
876 | ModelMeanType.START_X: x_start,
877 | ModelMeanType.EPSILON: noise,
878 | }[self.model_mean_type]
879 | assert model_output.shape == target.shape == x_start.shape
880 | terms["mse"] = mean_flat((target - model_output) ** 2)
881 | if "vb" in terms:
882 | terms["loss"] = terms["mse"] + terms["vb"]
883 | else:
884 | terms["loss"] = terms["mse"]
885 | else:
886 | raise NotImplementedError(self.loss_type)
887 |
888 | return terms
889 |
890 | def calc_losses_loop(self, model, x_start, model_kwargs=None):
891 | """
892 | Compute training losses for all timesteps.
893 |
894 | :param model: the model to evaluate loss on.
895 | :param x_start: the [N x C x ...] tensor of inputs.
896 | :param model_kwargs: if not None, a dict of extra keyword arguments to
897 | pass to the model. This can be used for conditioning.
898 | :return: a dict with the key "loss" containing a tensor of shape [N x T].
899 | Some mean or variance settings may also have other keys.
900 | """
901 | device = x_start.device
902 | batch_size = x_start.shape[0]
903 |
904 | outs = []
905 | for t in list(range(self.num_timesteps))[::-1]:
906 | t_batch = th.tensor([t] * batch_size, device=device)
907 | noise = th.randn_like(x_start)
908 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
909 | # Compute training losses at the current timestep
910 | with th.no_grad():
911 | out = self.training_losses(
912 | model,
913 | x_start=x_start,
914 | t=t_batch,
915 | model_kwargs=model_kwargs,
916 | noise=noise
917 | )
918 | outs.append(out)
919 |
920 | output = {key: [] for key in out.keys()}
921 | for out in outs:
922 | for key in out:
923 | output[key].append(out[key])
924 | for key in output:
925 | output[key] = th.stack(output[key], dim=1)
926 |
927 | return output
928 |
929 | def _prior_bpd(self, x_start):
930 | """
931 | Get the prior KL term for the variational lower-bound, measured in
932 | bits-per-dim.
933 |
934 | This term can't be optimized, as it only depends on the encoder.
935 |
936 | :param x_start: the [N x C x ...] tensor of inputs.
937 | :return: a batch of [N] KL values (in bits), one per batch element.
938 | """
939 | batch_size = x_start.shape[0]
940 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
941 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
942 | kl_prior = normal_kl(
943 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
944 | )
945 | return mean_flat(kl_prior) / np.log(2.0)
946 |
947 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
948 | """
949 | Compute the entire variational lower-bound, measured in bits-per-dim,
950 | as well as other related quantities.
951 |
952 | :param model: the model to evaluate loss on.
953 | :param x_start: the [N x C x ...] tensor of inputs.
954 | :param clip_denoised: if True, clip denoised samples.
955 | :param model_kwargs: if not None, a dict of extra keyword arguments to
956 | pass to the model. This can be used for conditioning.
957 |
958 | :return: a dict containing the following keys:
959 | - total_bpd: the total variational lower-bound, per batch element.
960 | - prior_bpd: the prior term in the lower-bound.
961 | - vb: an [N x T] tensor of terms in the lower-bound.
962 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
963 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
964 | """
965 | device = x_start.device
966 | batch_size = x_start.shape[0]
967 |
968 | vb = []
969 | xstart_mse = []
970 | mse = []
971 | for t in list(range(self.num_timesteps))[::-1]:
972 | t_batch = th.tensor([t] * batch_size, device=device)
973 | noise = th.randn_like(x_start)
974 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
975 | # Calculate VLB term at the current timestep
976 | with th.no_grad():
977 | out = self._vb_terms_bpd(
978 | model,
979 | x_start=x_start,
980 | x_t=x_t,
981 | t=t_batch,
982 | clip_denoised=clip_denoised,
983 | model_kwargs=model_kwargs,
984 | )
985 | vb.append(out["output"])
986 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
987 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
988 | mse.append(mean_flat((eps - noise) ** 2))
989 |
990 | vb = th.stack(vb, dim=1)
991 | xstart_mse = th.stack(xstart_mse, dim=1)
992 | mse = th.stack(mse, dim=1)
993 |
994 | prior_bpd = self._prior_bpd(x_start)
995 | total_bpd = vb.sum(dim=1) + prior_bpd
996 | return {
997 | "total_bpd": total_bpd,
998 | "prior_bpd": prior_bpd,
999 | "vb": vb,
1000 | "xstart_mse": xstart_mse,
1001 | "mse": mse,
1002 | }
1003 |
1004 |
1005 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
1006 | """
1007 | Extract values from a 1-D numpy array for a batch of indices.
1008 |
1009 | :param arr: the 1-D numpy array.
1010 | :param timesteps: a tensor of indices into the array to extract.
1011 | :param broadcast_shape: a larger shape of K dimensions with the batch
1012 | dimension equal to the length of timesteps.
1013 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1014 | """
1015 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1016 | while len(res.shape) < len(broadcast_shape):
1017 | res = res[..., None]
1018 | return res.expand(broadcast_shape)
1019 |
--------------------------------------------------------------------------------
/guided_diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log a value of some diagnostic
215 | Call this once for each diagnostic quantity, each iteration
216 | If called many times, last value will be used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 | The same as logkv(), but if called many times, values averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start)
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452 | )
453 | assert isinstance(dir, str)
454 | dir = os.path.expanduser(dir)
455 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
456 |
457 | rank = get_rank_without_mpi_import()
458 | if rank > 0:
459 | log_suffix = log_suffix + "-rank%03i" % rank
460 |
461 | if format_strs is None:
462 | if rank == 0:
463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464 | else:
465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466 | format_strs = filter(None, format_strs)
467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468 |
469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470 | if output_formats:
471 | log("Logging to %s" % dir)
472 |
473 |
474 | def _configure_default_logger():
475 | configure()
476 | Logger.DEFAULT = Logger.CURRENT
477 |
478 |
479 | def reset():
480 | if Logger.CURRENT is not Logger.DEFAULT:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = Logger.DEFAULT
483 | log("Reset logger")
484 |
485 |
486 | @contextmanager
487 | def scoped_configure(dir=None, format_strs=None, comm=None):
488 | prevlogger = Logger.CURRENT
489 | configure(dir=dir, format_strs=format_strs, comm=comm)
490 | try:
491 | yield
492 | finally:
493 | Logger.CURRENT.close()
494 | Logger.CURRENT = prevlogger
495 |
496 |
--------------------------------------------------------------------------------
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 | f"cannot create exactly {num_timesteps} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
/guided_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------------------------------------
3 | Copied from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py
4 |
5 | Functions have been modified to make the inputs "dims", "in_channels" and "out_channels" of UNetModel changeable.
6 | -------------------------------------------------------------------------------
7 | """
8 |
9 | import argparse
10 | import inspect
11 |
12 | from . import gaussian_diffusion as gd
13 | from .respace import SpacedDiffusion, space_timesteps
14 | from .unet import SuperResModel, UNetModel, EncoderUNetModel
15 |
16 | NUM_CLASSES = 1000
17 |
18 |
19 | def diffusion_defaults():
20 | """
21 | Defaults for image and classifier training.
22 | """
23 | return dict(
24 | learn_sigma=False,
25 | diffusion_steps=1000,
26 | noise_schedule="linear",
27 | timestep_respacing="",
28 | use_kl=False,
29 | predict_xstart=False,
30 | rescale_timesteps=False,
31 | rescale_learned_sigmas=False,
32 | )
33 |
34 |
35 | def classifier_defaults():
36 | """
37 | Defaults for classifier models.
38 | """
39 | return dict(
40 | image_size=64,
41 | classifier_use_fp16=False,
42 | classifier_width=128,
43 | classifier_depth=2,
44 | classifier_attention_resolutions="32,16,8", # 16
45 | classifier_use_scale_shift_norm=True, # False
46 | classifier_resblock_updown=True, # False
47 | classifier_pool="attention",
48 | )
49 |
50 |
51 | def model_and_diffusion_defaults():
52 | """
53 | Defaults for image training.
54 | """
55 | res = dict(
56 | dims=2,
57 | image_size=64,
58 | in_channels=3,
59 | num_channels=128,
60 | num_res_blocks=2,
61 | num_heads=4,
62 | num_heads_upsample=-1,
63 | num_head_channels=-1,
64 | attention_resolutions="16,8",
65 | channel_mult="",
66 | dropout=0.0,
67 | class_cond=False,
68 | use_checkpoint=False,
69 | use_scale_shift_norm=True,
70 | resblock_updown=False,
71 | use_fp16=False,
72 | use_new_attention_order=False,
73 | )
74 | res.update(diffusion_defaults())
75 | return res
76 |
77 |
78 | def classifier_and_diffusion_defaults():
79 | res = classifier_defaults()
80 | res.update(diffusion_defaults())
81 | return res
82 |
83 |
84 | def create_model_and_diffusion(
85 | dims,
86 | image_size,
87 | in_channels,
88 | class_cond,
89 | learn_sigma,
90 | num_channels,
91 | num_res_blocks,
92 | channel_mult,
93 | num_heads,
94 | num_head_channels,
95 | num_heads_upsample,
96 | attention_resolutions,
97 | dropout,
98 | diffusion_steps,
99 | noise_schedule,
100 | timestep_respacing,
101 | use_kl,
102 | predict_xstart,
103 | rescale_timesteps,
104 | rescale_learned_sigmas,
105 | use_checkpoint,
106 | use_scale_shift_norm,
107 | resblock_updown,
108 | use_fp16,
109 | use_new_attention_order,
110 | ):
111 | model = create_model(
112 | dims,
113 | image_size,
114 | in_channels,
115 | num_channels,
116 | num_res_blocks,
117 | channel_mult=channel_mult,
118 | learn_sigma=learn_sigma,
119 | class_cond=class_cond,
120 | use_checkpoint=use_checkpoint,
121 | attention_resolutions=attention_resolutions,
122 | num_heads=num_heads,
123 | num_head_channels=num_head_channels,
124 | num_heads_upsample=num_heads_upsample,
125 | use_scale_shift_norm=use_scale_shift_norm,
126 | dropout=dropout,
127 | resblock_updown=resblock_updown,
128 | use_fp16=use_fp16,
129 | use_new_attention_order=use_new_attention_order,
130 | )
131 | diffusion = create_gaussian_diffusion(
132 | steps=diffusion_steps,
133 | learn_sigma=learn_sigma,
134 | noise_schedule=noise_schedule,
135 | use_kl=use_kl,
136 | predict_xstart=predict_xstart,
137 | rescale_timesteps=rescale_timesteps,
138 | rescale_learned_sigmas=rescale_learned_sigmas,
139 | timestep_respacing=timestep_respacing,
140 | )
141 | return model, diffusion
142 |
143 |
144 | def create_model(
145 | dims,
146 | image_size,
147 | in_channels,
148 | num_channels,
149 | num_res_blocks,
150 | channel_mult="",
151 | learn_sigma=False,
152 | class_cond=False,
153 | use_checkpoint=False,
154 | attention_resolutions="16",
155 | num_heads=1,
156 | num_head_channels=-1,
157 | num_heads_upsample=-1,
158 | use_scale_shift_norm=False,
159 | dropout=0,
160 | resblock_updown=False,
161 | use_fp16=False,
162 | use_new_attention_order=False,
163 | ):
164 | if channel_mult == "":
165 | if image_size == 512:
166 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
167 | elif image_size == 256:
168 | channel_mult = (1, 1, 2, 2, 4, 4)
169 | elif image_size == 128:
170 | channel_mult = (1, 1, 2, 3, 4)
171 | elif image_size == 64:
172 | channel_mult = (1, 2, 3, 4)
173 | else:
174 | raise ValueError(f"unsupported image size: {image_size}")
175 | else:
176 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
177 |
178 | attention_ds = []
179 | for res in attention_resolutions.split(","):
180 | attention_ds.append(image_size // int(res))
181 |
182 | return UNetModel(
183 | image_size=image_size,
184 | in_channels=in_channels,
185 | model_channels=num_channels,
186 | out_channels=(in_channels if not learn_sigma else 2*in_channels),
187 | num_res_blocks=num_res_blocks,
188 | attention_resolutions=tuple(attention_ds),
189 | dropout=dropout,
190 | channel_mult=channel_mult,
191 | dims=dims,
192 | num_classes=(NUM_CLASSES if class_cond else None),
193 | use_checkpoint=use_checkpoint,
194 | use_fp16=use_fp16,
195 | num_heads=num_heads,
196 | num_head_channels=num_head_channels,
197 | num_heads_upsample=num_heads_upsample,
198 | use_scale_shift_norm=use_scale_shift_norm,
199 | resblock_updown=resblock_updown,
200 | use_new_attention_order=use_new_attention_order,
201 | )
202 |
203 |
204 | def create_classifier_and_diffusion(
205 | image_size,
206 | classifier_use_fp16,
207 | classifier_width,
208 | classifier_depth,
209 | classifier_attention_resolutions,
210 | classifier_use_scale_shift_norm,
211 | classifier_resblock_updown,
212 | classifier_pool,
213 | learn_sigma,
214 | diffusion_steps,
215 | noise_schedule,
216 | timestep_respacing,
217 | use_kl,
218 | predict_xstart,
219 | rescale_timesteps,
220 | rescale_learned_sigmas,
221 | ):
222 | classifier = create_classifier(
223 | image_size,
224 | classifier_use_fp16,
225 | classifier_width,
226 | classifier_depth,
227 | classifier_attention_resolutions,
228 | classifier_use_scale_shift_norm,
229 | classifier_resblock_updown,
230 | classifier_pool,
231 | )
232 | diffusion = create_gaussian_diffusion(
233 | steps=diffusion_steps,
234 | learn_sigma=learn_sigma,
235 | noise_schedule=noise_schedule,
236 | use_kl=use_kl,
237 | predict_xstart=predict_xstart,
238 | rescale_timesteps=rescale_timesteps,
239 | rescale_learned_sigmas=rescale_learned_sigmas,
240 | timestep_respacing=timestep_respacing,
241 | )
242 | return classifier, diffusion
243 |
244 |
245 | def create_classifier(
246 | image_size,
247 | classifier_use_fp16,
248 | classifier_width,
249 | classifier_depth,
250 | classifier_attention_resolutions,
251 | classifier_use_scale_shift_norm,
252 | classifier_resblock_updown,
253 | classifier_pool,
254 | ):
255 | if image_size == 512:
256 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
257 | elif image_size == 256:
258 | channel_mult = (1, 1, 2, 2, 4, 4)
259 | elif image_size == 128:
260 | channel_mult = (1, 1, 2, 3, 4)
261 | elif image_size == 64:
262 | channel_mult = (1, 2, 3, 4)
263 | else:
264 | raise ValueError(f"unsupported image size: {image_size}")
265 |
266 | attention_ds = []
267 | for res in classifier_attention_resolutions.split(","):
268 | attention_ds.append(image_size // int(res))
269 |
270 | return EncoderUNetModel(
271 | image_size=image_size,
272 | in_channels=3,
273 | model_channels=classifier_width,
274 | out_channels=1000,
275 | num_res_blocks=classifier_depth,
276 | attention_resolutions=tuple(attention_ds),
277 | channel_mult=channel_mult,
278 | use_fp16=classifier_use_fp16,
279 | num_head_channels=64,
280 | use_scale_shift_norm=classifier_use_scale_shift_norm,
281 | resblock_updown=classifier_resblock_updown,
282 | pool=classifier_pool,
283 | )
284 |
285 |
286 | def sr_model_and_diffusion_defaults():
287 | res = model_and_diffusion_defaults()
288 | res["large_size"] = 256
289 | res["small_size"] = 64
290 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
291 | for k in res.copy().keys():
292 | if k not in arg_names:
293 | del res[k]
294 | return res
295 |
296 |
297 | def sr_create_model_and_diffusion(
298 | large_size,
299 | small_size,
300 | class_cond,
301 | learn_sigma,
302 | num_channels,
303 | num_res_blocks,
304 | num_heads,
305 | num_head_channels,
306 | num_heads_upsample,
307 | attention_resolutions,
308 | dropout,
309 | diffusion_steps,
310 | noise_schedule,
311 | timestep_respacing,
312 | use_kl,
313 | predict_xstart,
314 | rescale_timesteps,
315 | rescale_learned_sigmas,
316 | use_checkpoint,
317 | use_scale_shift_norm,
318 | resblock_updown,
319 | use_fp16,
320 | ):
321 | model = sr_create_model(
322 | large_size,
323 | small_size,
324 | num_channels,
325 | num_res_blocks,
326 | learn_sigma=learn_sigma,
327 | class_cond=class_cond,
328 | use_checkpoint=use_checkpoint,
329 | attention_resolutions=attention_resolutions,
330 | num_heads=num_heads,
331 | num_head_channels=num_head_channels,
332 | num_heads_upsample=num_heads_upsample,
333 | use_scale_shift_norm=use_scale_shift_norm,
334 | dropout=dropout,
335 | resblock_updown=resblock_updown,
336 | use_fp16=use_fp16,
337 | )
338 | diffusion = create_gaussian_diffusion(
339 | steps=diffusion_steps,
340 | learn_sigma=learn_sigma,
341 | noise_schedule=noise_schedule,
342 | use_kl=use_kl,
343 | predict_xstart=predict_xstart,
344 | rescale_timesteps=rescale_timesteps,
345 | rescale_learned_sigmas=rescale_learned_sigmas,
346 | timestep_respacing=timestep_respacing,
347 | )
348 | return model, diffusion
349 |
350 |
351 | def sr_create_model(
352 | large_size,
353 | small_size,
354 | num_channels,
355 | num_res_blocks,
356 | learn_sigma,
357 | class_cond,
358 | use_checkpoint,
359 | attention_resolutions,
360 | num_heads,
361 | num_head_channels,
362 | num_heads_upsample,
363 | use_scale_shift_norm,
364 | dropout,
365 | resblock_updown,
366 | use_fp16,
367 | ):
368 | _ = small_size # hack to prevent unused variable
369 |
370 | if large_size == 512:
371 | channel_mult = (1, 1, 2, 2, 4, 4)
372 | elif large_size == 256:
373 | channel_mult = (1, 1, 2, 2, 4, 4)
374 | elif large_size == 64:
375 | channel_mult = (1, 2, 3, 4)
376 | else:
377 | raise ValueError(f"unsupported large size: {large_size}")
378 |
379 | attention_ds = []
380 | for res in attention_resolutions.split(","):
381 | attention_ds.append(large_size // int(res))
382 |
383 | return SuperResModel(
384 | image_size=large_size,
385 | in_channels=3,
386 | model_channels=num_channels,
387 | out_channels=(3 if not learn_sigma else 6),
388 | num_res_blocks=num_res_blocks,
389 | attention_resolutions=tuple(attention_ds),
390 | dropout=dropout,
391 | channel_mult=channel_mult,
392 | num_classes=(NUM_CLASSES if class_cond else None),
393 | use_checkpoint=use_checkpoint,
394 | num_heads=num_heads,
395 | num_head_channels=num_head_channels,
396 | num_heads_upsample=num_heads_upsample,
397 | use_scale_shift_norm=use_scale_shift_norm,
398 | resblock_updown=resblock_updown,
399 | use_fp16=use_fp16,
400 | )
401 |
402 |
403 | def create_gaussian_diffusion(
404 | *,
405 | steps=1000,
406 | learn_sigma=False,
407 | sigma_small=False,
408 | noise_schedule="linear",
409 | use_kl=False,
410 | predict_xstart=False,
411 | rescale_timesteps=False,
412 | rescale_learned_sigmas=False,
413 | timestep_respacing="",
414 | ):
415 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
416 | if use_kl:
417 | loss_type = gd.LossType.RESCALED_KL
418 | elif rescale_learned_sigmas:
419 | loss_type = gd.LossType.RESCALED_MSE
420 | else:
421 | loss_type = gd.LossType.MSE
422 | if not timestep_respacing:
423 | timestep_respacing = [steps]
424 | return SpacedDiffusion(
425 | use_timesteps=space_timesteps(steps, timestep_respacing),
426 | betas=betas,
427 | model_mean_type=(
428 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
429 | ),
430 | model_var_type=(
431 | (
432 | gd.ModelVarType.FIXED_LARGE
433 | if not sigma_small
434 | else gd.ModelVarType.FIXED_SMALL
435 | )
436 | if not learn_sigma
437 | else gd.ModelVarType.LEARNED_RANGE
438 | ),
439 | loss_type=loss_type,
440 | rescale_timesteps=rescale_timesteps,
441 | )
442 |
443 |
444 | def add_dict_to_argparser(parser, default_dict):
445 | for k, v in default_dict.items():
446 | v_type = type(v)
447 | if v is None:
448 | v_type = str
449 | elif isinstance(v, bool):
450 | v_type = str2bool
451 | parser.add_argument(f"--{k}", default=v, type=v_type)
452 |
453 |
454 | def args_to_dict(args, keys):
455 | return {k: getattr(args, k) for k in keys}
456 |
457 |
458 | def str2bool(v):
459 | """
460 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
461 | """
462 | if isinstance(v, bool):
463 | return v
464 | if v.lower() in ("yes", "true", "t", "y", "1"):
465 | return True
466 | elif v.lower() in ("no", "false", "f", "n", "0"):
467 | return False
468 | else:
469 | raise argparse.ArgumentTypeError("boolean value expected")
470 |
--------------------------------------------------------------------------------
/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 |
16 | # For ImageNet experiments, this was a good default value.
17 | # We found that the lg_loss_scale quickly climbed to
18 | # 20-21 within the first ~1K steps of training.
19 | INITIAL_LOG_LOSS_SCALE = 20.0
20 |
21 |
22 | class TrainLoop:
23 | def __init__(
24 | self,
25 | *,
26 | model,
27 | diffusion,
28 | data,
29 | batch_size,
30 | microbatch,
31 | lr,
32 | ema_rate,
33 | log_interval,
34 | save_interval,
35 | resume_checkpoint,
36 | use_fp16=False,
37 | fp16_scale_growth=1e-3,
38 | schedule_sampler=None,
39 | weight_decay=0.0,
40 | lr_anneal_steps=0,
41 | ):
42 | self.model = model
43 | self.diffusion = diffusion
44 | self.data = data
45 | self.batch_size = batch_size
46 | self.microbatch = microbatch if microbatch > 0 else batch_size
47 | self.lr = lr
48 | self.ema_rate = (
49 | [ema_rate]
50 | if isinstance(ema_rate, float)
51 | else [float(x) for x in ema_rate.split(",")]
52 | )
53 | self.log_interval = log_interval
54 | self.save_interval = save_interval
55 | self.resume_checkpoint = resume_checkpoint
56 | self.use_fp16 = use_fp16
57 | self.fp16_scale_growth = fp16_scale_growth
58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
59 | self.weight_decay = weight_decay
60 | self.lr_anneal_steps = lr_anneal_steps
61 |
62 | self.step = 0
63 | self.resume_step = 0
64 | self.global_batch = self.batch_size * dist.get_world_size()
65 |
66 | self.sync_cuda = th.cuda.is_available()
67 |
68 | self._load_and_sync_parameters()
69 | self.mp_trainer = MixedPrecisionTrainer(
70 | model=self.model,
71 | use_fp16=self.use_fp16,
72 | fp16_scale_growth=fp16_scale_growth,
73 | )
74 |
75 | self.opt = AdamW(
76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
77 | )
78 | if self.resume_step:
79 | self._load_optimizer_state()
80 | # Model was resumed, either due to a restart or a checkpoint
81 | # being specified at the command line.
82 | self.ema_params = [
83 | self._load_ema_parameters(rate) for rate in self.ema_rate
84 | ]
85 | else:
86 | self.ema_params = [
87 | copy.deepcopy(self.mp_trainer.master_params)
88 | for _ in range(len(self.ema_rate))
89 | ]
90 |
91 | if th.cuda.is_available():
92 | self.use_ddp = True
93 | self.ddp_model = DDP(
94 | self.model,
95 | device_ids=[dist_util.dev()],
96 | output_device=dist_util.dev(),
97 | broadcast_buffers=False,
98 | bucket_cap_mb=128,
99 | find_unused_parameters=False,
100 | )
101 | else:
102 | if dist.get_world_size() > 1:
103 | logger.warn(
104 | "Distributed training requires CUDA. "
105 | "Gradients will not be synchronized properly!"
106 | )
107 | self.use_ddp = False
108 | self.ddp_model = self.model
109 |
110 | def _load_and_sync_parameters(self):
111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
112 |
113 | if resume_checkpoint:
114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
115 | #if dist.get_rank() == 0:
116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
117 | self.model.load_state_dict(
118 | dist_util.load_state_dict(
119 | resume_checkpoint, map_location=dist_util.dev()
120 | )
121 | )
122 |
123 | dist_util.sync_params(self.model.parameters())
124 |
125 | def _load_ema_parameters(self, rate):
126 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
127 |
128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
130 | if ema_checkpoint:
131 | #if dist.get_rank() == 0:
132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
133 | state_dict = dist_util.load_state_dict(
134 | ema_checkpoint, map_location=dist_util.dev()
135 | )
136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
137 |
138 | dist_util.sync_params(ema_params)
139 | return ema_params
140 |
141 | def _load_optimizer_state(self):
142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
143 | opt_checkpoint = bf.join(
144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
145 | )
146 | if bf.exists(opt_checkpoint):
147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
148 | state_dict = dist_util.load_state_dict(
149 | opt_checkpoint, map_location=dist_util.dev()
150 | )
151 | self.opt.load_state_dict(state_dict)
152 |
153 | def run_loop(self):
154 | while (
155 | not self.lr_anneal_steps
156 | or self.step + self.resume_step < self.lr_anneal_steps
157 | ):
158 | batch, cond = next(self.data)
159 | self.run_step(batch, cond)
160 | if self.step % self.log_interval == 0:
161 | logger.dumpkvs()
162 | if self.step % self.save_interval == 0:
163 | self.save()
164 | # Run for a finite amount of time in integration tests.
165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
166 | return
167 | self.step += 1
168 | # Save the last checkpoint if it wasn't already saved.
169 | if (self.step - 1) % self.save_interval != 0:
170 | self.save()
171 |
172 | def run_step(self, batch, cond):
173 | self.forward_backward(batch, cond)
174 | took_step = self.mp_trainer.optimize(self.opt)
175 | if took_step:
176 | self._update_ema()
177 | self._anneal_lr()
178 | self.log_step()
179 |
180 | def forward_backward(self, batch, cond):
181 | self.mp_trainer.zero_grad()
182 | for i in range(0, batch.shape[0], self.microbatch):
183 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
184 | micro_cond = {
185 | k: v[i : i + self.microbatch].to(dist_util.dev())
186 | for k, v in cond.items()
187 | }
188 | last_batch = (i + self.microbatch) >= batch.shape[0]
189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
190 |
191 | compute_losses = functools.partial(
192 | self.diffusion.training_losses,
193 | self.ddp_model,
194 | micro,
195 | t,
196 | model_kwargs=micro_cond,
197 | )
198 |
199 | if last_batch or not self.use_ddp:
200 | losses = compute_losses()
201 | else:
202 | with self.ddp_model.no_sync():
203 | losses = compute_losses()
204 |
205 | if isinstance(self.schedule_sampler, LossAwareSampler):
206 | self.schedule_sampler.update_with_local_losses(
207 | t, losses["loss"].detach()
208 | )
209 |
210 | loss = (losses["loss"] * weights).mean()
211 | log_loss_dict(
212 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
213 | )
214 | self.mp_trainer.backward(loss)
215 |
216 | def _update_ema(self):
217 | for rate, params in zip(self.ema_rate, self.ema_params):
218 | update_ema(params, self.mp_trainer.master_params, rate=rate)
219 |
220 | def _anneal_lr(self):
221 | if not self.lr_anneal_steps:
222 | return
223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
224 | lr = self.lr * (1 - frac_done)
225 | for param_group in self.opt.param_groups:
226 | param_group["lr"] = lr
227 |
228 | def log_step(self):
229 | logger.logkv("step", self.step + self.resume_step)
230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
231 |
232 | def save(self):
233 | def save_checkpoint(rate, params):
234 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
235 | if dist.get_rank() == 0:
236 | logger.log(f"saving model {rate}...")
237 | if not rate:
238 | filename = f"model{(self.step+self.resume_step):06d}.pt"
239 | else:
240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
242 | th.save(state_dict, f)
243 |
244 | save_checkpoint(0, self.mp_trainer.master_params)
245 | for rate, params in zip(self.ema_rate, self.ema_params):
246 | save_checkpoint(rate, params)
247 |
248 | if dist.get_rank() == 0:
249 | with bf.BlobFile(
250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
251 | "wb",
252 | ) as f:
253 | th.save(self.opt.state_dict(), f)
254 |
255 | dist.barrier()
256 |
257 |
258 | def parse_resume_step_from_filename(filename):
259 | """
260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
261 | checkpoint's number of steps.
262 | """
263 | split = filename.split("model")
264 | if len(split) < 2:
265 | return 0
266 | split1 = split[-1].split(".")[0]
267 | try:
268 | return int(split1)
269 | except ValueError:
270 | return 0
271 |
272 |
273 | def get_blob_logdir():
274 | # You can change this to be a separate path to save checkpoints to
275 | # a blobstore or some external drive.
276 | return logger.get_dir()
277 |
278 |
279 | def find_resume_checkpoint():
280 | # On your infrastructure, you may want to override this to automatically
281 | # discover the latest checkpoint on your blob storage, etc.
282 | return None
283 |
284 |
285 | def find_ema_checkpoint(main_checkpoint, step, rate):
286 | if main_checkpoint is None:
287 | return None
288 | filename = f"ema_{rate}_{(step):06d}.pt"
289 | path = bf.join(bf.dirname(main_checkpoint), filename)
290 | if bf.exists(path):
291 | return path
292 | return None
293 |
294 |
295 | def log_loss_dict(diffusion, ts, losses):
296 | for key, values in losses.items():
297 | logger.logkv_mean(key, values.mean().item())
298 | # Log the quantiles (four quartiles, in particular).
299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
300 | quartile = int(4 * sub_t / diffusion.num_timesteps)
301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
302 |
--------------------------------------------------------------------------------
/guided_diffusion/turb_datasets.py:
--------------------------------------------------------------------------------
1 | from mpi4py import MPI
2 | import h5py
3 | from torch.utils.data import DataLoader, Dataset
4 | import numpy as np
5 |
6 |
7 | def load_data(
8 | *,
9 | dataset_path,
10 | dataset_name,
11 | batch_size,
12 | class_cond=False,
13 | deterministic=False,
14 | ):
15 | """
16 | For a dataset, create a generator over (images, kwargs) pairs.
17 |
18 | Each images is an NCHW float tensor, and the kwargs dict contains zero or
19 | more keys, each of which map to a batched Tensor of their own.
20 | The kwargs dict can be used for class labels, in which case the key is "y"
21 | and the values are integer tensors of class labels.
22 |
23 | :param dataset_path: a dataset path.
24 | :param dataset_name: a dataset name.
25 | :param batch_size: the batch size of each returned pair.
26 | :param class_cond: if True, include a "y" key in returned dicts for class
27 | label. Not implemented.
28 | :param deterministic: if True, yield results in a deterministic order.
29 | """
30 | comm = MPI.COMM_WORLD
31 | rank = comm.Get_rank()
32 | size = comm.Get_size()
33 |
34 | with h5py.File(dataset_path, 'r', driver='mpio', comm=MPI.COMM_SELF) as f:
35 | #with h5py.File(dataset_path, 'r') as f: # replace the above line with this line for serial h5py
36 | len_dataset = f[dataset_name].len()
37 |
38 | chunk_size = len_dataset // size
39 | start_idx = rank * chunk_size
40 |
41 | dataset = TurbDataset(
42 | dataset_path, dataset_name, class_cond, start_idx, chunk_size,
43 | )
44 |
45 | shuffle = True if deterministic else False
46 | loader = DataLoader(
47 | dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, drop_last=True
48 | )
49 |
50 | while True:
51 | yield from loader
52 |
53 |
54 | class TurbDataset(Dataset):
55 | def __init__(
56 | self,
57 | dataset_path,
58 | dataset_name,
59 | class_cond,
60 | start_idx,
61 | chunk_size,
62 | ):
63 | super().__init__()
64 | self.dataset_path = dataset_path
65 | self.dataset_name = dataset_name
66 | self.class_cond = class_cond
67 | self.start_idx = start_idx
68 | self.chunk_size = chunk_size
69 |
70 | def __len__(self):
71 | return self.chunk_size
72 |
73 | def __getitem__(self, idx):
74 | idx += self.start_idx
75 |
76 | with h5py.File(self.dataset_path, 'r', driver='mpio', comm=MPI.COMM_SELF) as f:
77 | #with h5py.File(self.dataset_path, 'r') as f: # replace the above line with this line for serial h5py
78 | data = f[self.dataset_name][idx].astype(np.float32)
79 | data = np.moveaxis(data, -1, 0)
80 |
81 | out_dict = {}
82 | if self.class_cond:
83 | raise NotImplementedError()
84 | out_dict["y"] = f[self.dataset_name + '_y'][idx]
85 |
86 | return data, out_dict
87 |
--------------------------------------------------------------------------------
/guided_diffusion/unet.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | import math
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32
11 | from .nn import (
12 | checkpoint,
13 | conv_nd,
14 | linear,
15 | avg_pool_nd,
16 | zero_module,
17 | normalization,
18 | timestep_embedding,
19 | )
20 |
21 |
22 | class AttentionPool2d(nn.Module):
23 | """
24 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
25 | """
26 |
27 | def __init__(
28 | self,
29 | spacial_dim: int,
30 | embed_dim: int,
31 | num_heads_channels: int,
32 | output_dim: int = None,
33 | ):
34 | super().__init__()
35 | self.positional_embedding = nn.Parameter(
36 | th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
37 | )
38 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
39 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
40 | self.num_heads = embed_dim // num_heads_channels
41 | self.attention = QKVAttention(self.num_heads)
42 |
43 | def forward(self, x):
44 | b, c, *_spatial = x.shape
45 | x = x.reshape(b, c, -1) # NC(HW)
46 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
47 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
48 | x = self.qkv_proj(x)
49 | x = self.attention(x)
50 | x = self.c_proj(x)
51 | return x[:, :, 0]
52 |
53 |
54 | class TimestepBlock(nn.Module):
55 | """
56 | Any module where forward() takes timestep embeddings as a second argument.
57 | """
58 |
59 | @abstractmethod
60 | def forward(self, x, emb):
61 | """
62 | Apply the module to `x` given `emb` timestep embeddings.
63 | """
64 |
65 |
66 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
67 | """
68 | A sequential module that passes timestep embeddings to the children that
69 | support it as an extra input.
70 | """
71 |
72 | def forward(self, x, emb):
73 | for layer in self:
74 | if isinstance(layer, TimestepBlock):
75 | x = layer(x, emb)
76 | else:
77 | x = layer(x)
78 | return x
79 |
80 |
81 | class Upsample(nn.Module):
82 | """
83 | An upsampling layer with an optional convolution.
84 |
85 | :param channels: channels in the inputs and outputs.
86 | :param use_conv: a bool determining if a convolution is applied.
87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
88 | upsampling occurs in the inner-two dimensions.
89 | """
90 |
91 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
92 | super().__init__()
93 | self.channels = channels
94 | self.out_channels = out_channels or channels
95 | self.use_conv = use_conv
96 | self.dims = dims
97 | if use_conv:
98 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
99 |
100 | def forward(self, x):
101 | assert x.shape[1] == self.channels
102 | if self.dims == 3:
103 | x = F.interpolate(
104 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
105 | )
106 | else:
107 | x = F.interpolate(x, scale_factor=2, mode="nearest")
108 | if self.use_conv:
109 | x = self.conv(x)
110 | return x
111 |
112 |
113 | class Downsample(nn.Module):
114 | """
115 | A downsampling layer with an optional convolution.
116 |
117 | :param channels: channels in the inputs and outputs.
118 | :param use_conv: a bool determining if a convolution is applied.
119 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
120 | downsampling occurs in the inner-two dimensions.
121 | """
122 |
123 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
124 | super().__init__()
125 | self.channels = channels
126 | self.out_channels = out_channels or channels
127 | self.use_conv = use_conv
128 | self.dims = dims
129 | stride = 2 if dims != 3 else (1, 2, 2)
130 | if use_conv:
131 | self.op = conv_nd(
132 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1
133 | )
134 | else:
135 | assert self.channels == self.out_channels
136 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
137 |
138 | def forward(self, x):
139 | assert x.shape[1] == self.channels
140 | return self.op(x)
141 |
142 |
143 | class ResBlock(TimestepBlock):
144 | """
145 | A residual block that can optionally change the number of channels.
146 |
147 | :param channels: the number of input channels.
148 | :param emb_channels: the number of timestep embedding channels.
149 | :param dropout: the rate of dropout.
150 | :param out_channels: if specified, the number of out channels.
151 | :param use_conv: if True and out_channels is specified, use a spatial
152 | convolution instead of a smaller 1x1 convolution to change the
153 | channels in the skip connection.
154 | :param dims: determines if the signal is 1D, 2D, or 3D.
155 | :param use_checkpoint: if True, use gradient checkpointing on this module.
156 | :param up: if True, use this block for upsampling.
157 | :param down: if True, use this block for downsampling.
158 | """
159 |
160 | def __init__(
161 | self,
162 | channels,
163 | emb_channels,
164 | dropout,
165 | out_channels=None,
166 | use_conv=False,
167 | use_scale_shift_norm=False,
168 | dims=2,
169 | use_checkpoint=False,
170 | up=False,
171 | down=False,
172 | ):
173 | super().__init__()
174 | self.channels = channels
175 | self.emb_channels = emb_channels
176 | self.dropout = dropout
177 | self.out_channels = out_channels or channels
178 | self.use_conv = use_conv
179 | self.use_checkpoint = use_checkpoint
180 | self.use_scale_shift_norm = use_scale_shift_norm
181 |
182 | self.in_layers = nn.Sequential(
183 | normalization(channels),
184 | nn.SiLU(),
185 | conv_nd(dims, channels, self.out_channels, 3, padding=1),
186 | )
187 |
188 | self.updown = up or down
189 |
190 | if up:
191 | self.h_upd = Upsample(channels, False, dims)
192 | self.x_upd = Upsample(channels, False, dims)
193 | elif down:
194 | self.h_upd = Downsample(channels, False, dims)
195 | self.x_upd = Downsample(channels, False, dims)
196 | else:
197 | self.h_upd = self.x_upd = nn.Identity()
198 |
199 | self.emb_layers = nn.Sequential(
200 | nn.SiLU(),
201 | linear(
202 | emb_channels,
203 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
204 | ),
205 | )
206 | self.out_layers = nn.Sequential(
207 | normalization(self.out_channels),
208 | nn.SiLU(),
209 | nn.Dropout(p=dropout),
210 | zero_module(
211 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
212 | ),
213 | )
214 |
215 | if self.out_channels == channels:
216 | self.skip_connection = nn.Identity()
217 | elif use_conv:
218 | self.skip_connection = conv_nd(
219 | dims, channels, self.out_channels, 3, padding=1
220 | )
221 | else:
222 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
223 |
224 | def forward(self, x, emb):
225 | """
226 | Apply the block to a Tensor, conditioned on a timestep embedding.
227 |
228 | :param x: an [N x C x ...] Tensor of features.
229 | :param emb: an [N x emb_channels] Tensor of timestep embeddings.
230 | :return: an [N x C x ...] Tensor of outputs.
231 | """
232 | return checkpoint(
233 | self._forward, (x, emb), self.parameters(), self.use_checkpoint
234 | )
235 |
236 | def _forward(self, x, emb):
237 | if self.updown:
238 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
239 | h = in_rest(x)
240 | h = self.h_upd(h)
241 | x = self.x_upd(x)
242 | h = in_conv(h)
243 | else:
244 | h = self.in_layers(x)
245 | emb_out = self.emb_layers(emb).type(h.dtype)
246 | while len(emb_out.shape) < len(h.shape):
247 | emb_out = emb_out[..., None]
248 | if self.use_scale_shift_norm:
249 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
250 | scale, shift = th.chunk(emb_out, 2, dim=1)
251 | h = out_norm(h) * (1 + scale) + shift
252 | h = out_rest(h)
253 | else:
254 | h = h + emb_out
255 | h = self.out_layers(h)
256 | return self.skip_connection(x) + h
257 |
258 |
259 | class AttentionBlock(nn.Module):
260 | """
261 | An attention block that allows spatial positions to attend to each other.
262 |
263 | Originally ported from here, but adapted to the N-d case.
264 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
265 | """
266 |
267 | def __init__(
268 | self,
269 | channels,
270 | num_heads=1,
271 | num_head_channels=-1,
272 | use_checkpoint=False,
273 | use_new_attention_order=False,
274 | ):
275 | super().__init__()
276 | self.channels = channels
277 | if num_head_channels == -1:
278 | self.num_heads = num_heads
279 | else:
280 | assert (
281 | channels % num_head_channels == 0
282 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
283 | self.num_heads = channels // num_head_channels
284 | self.use_checkpoint = use_checkpoint
285 | self.norm = normalization(channels)
286 | self.qkv = conv_nd(1, channels, channels * 3, 1)
287 | if use_new_attention_order:
288 | # split qkv before split heads
289 | self.attention = QKVAttention(self.num_heads)
290 | else:
291 | # split heads before split qkv
292 | self.attention = QKVAttentionLegacy(self.num_heads)
293 |
294 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
295 |
296 | def forward(self, x):
297 | return checkpoint(self._forward, (x,), self.parameters(), True)
298 |
299 | def _forward(self, x):
300 | b, c, *spatial = x.shape
301 | x = x.reshape(b, c, -1)
302 | qkv = self.qkv(self.norm(x))
303 | h = self.attention(qkv)
304 | h = self.proj_out(h)
305 | return (x + h).reshape(b, c, *spatial)
306 |
307 |
308 | def count_flops_attn(model, _x, y):
309 | """
310 | A counter for the `thop` package to count the operations in an
311 | attention operation.
312 | Meant to be used like:
313 | macs, params = thop.profile(
314 | model,
315 | inputs=(inputs, timestamps),
316 | custom_ops={QKVAttention: QKVAttention.count_flops},
317 | )
318 | """
319 | b, c, *spatial = y[0].shape
320 | num_spatial = int(np.prod(spatial))
321 | # We perform two matmuls with the same number of ops.
322 | # The first computes the weight matrix, the second computes
323 | # the combination of the value vectors.
324 | matmul_ops = 2 * b * (num_spatial ** 2) * c
325 | model.total_ops += th.DoubleTensor([matmul_ops])
326 |
327 |
328 | class QKVAttentionLegacy(nn.Module):
329 | """
330 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
331 | """
332 |
333 | def __init__(self, n_heads):
334 | super().__init__()
335 | self.n_heads = n_heads
336 |
337 | def forward(self, qkv):
338 | """
339 | Apply QKV attention.
340 |
341 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
342 | :return: an [N x (H * C) x T] tensor after attention.
343 | """
344 | bs, width, length = qkv.shape
345 | assert width % (3 * self.n_heads) == 0
346 | ch = width // (3 * self.n_heads)
347 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
348 | scale = 1 / math.sqrt(math.sqrt(ch))
349 | weight = th.einsum(
350 | "bct,bcs->bts", q * scale, k * scale
351 | ) # More stable with f16 than dividing afterwards
352 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
353 | a = th.einsum("bts,bcs->bct", weight, v)
354 | return a.reshape(bs, -1, length)
355 |
356 | @staticmethod
357 | def count_flops(model, _x, y):
358 | return count_flops_attn(model, _x, y)
359 |
360 |
361 | class QKVAttention(nn.Module):
362 | """
363 | A module which performs QKV attention and splits in a different order.
364 | """
365 |
366 | def __init__(self, n_heads):
367 | super().__init__()
368 | self.n_heads = n_heads
369 |
370 | def forward(self, qkv):
371 | """
372 | Apply QKV attention.
373 |
374 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
375 | :return: an [N x (H * C) x T] tensor after attention.
376 | """
377 | bs, width, length = qkv.shape
378 | assert width % (3 * self.n_heads) == 0
379 | ch = width // (3 * self.n_heads)
380 | q, k, v = qkv.chunk(3, dim=1)
381 | scale = 1 / math.sqrt(math.sqrt(ch))
382 | weight = th.einsum(
383 | "bct,bcs->bts",
384 | (q * scale).view(bs * self.n_heads, ch, length),
385 | (k * scale).view(bs * self.n_heads, ch, length),
386 | ) # More stable with f16 than dividing afterwards
387 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
388 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
389 | return a.reshape(bs, -1, length)
390 |
391 | @staticmethod
392 | def count_flops(model, _x, y):
393 | return count_flops_attn(model, _x, y)
394 |
395 |
396 | class UNetModel(nn.Module):
397 | """
398 | The full UNet model with attention and timestep embedding.
399 |
400 | :param in_channels: channels in the input Tensor.
401 | :param model_channels: base channel count for the model.
402 | :param out_channels: channels in the output Tensor.
403 | :param num_res_blocks: number of residual blocks per downsample.
404 | :param attention_resolutions: a collection of downsample rates at which
405 | attention will take place. May be a set, list, or tuple.
406 | For example, if this contains 4, then at 4x downsampling, attention
407 | will be used.
408 | :param dropout: the dropout probability.
409 | :param channel_mult: channel multiplier for each level of the UNet.
410 | :param conv_resample: if True, use learned convolutions for upsampling and
411 | downsampling.
412 | :param dims: determines if the signal is 1D, 2D, or 3D.
413 | :param num_classes: if specified (as an int), then this model will be
414 | class-conditional with `num_classes` classes.
415 | :param use_checkpoint: use gradient checkpointing to reduce memory usage.
416 | :param num_heads: the number of attention heads in each attention layer.
417 | :param num_heads_channels: if specified, ignore num_heads and instead use
418 | a fixed channel width per attention head.
419 | :param num_heads_upsample: works with num_heads to set a different number
420 | of heads for upsampling. Deprecated.
421 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
422 | :param resblock_updown: use residual blocks for up/downsampling.
423 | :param use_new_attention_order: use a different attention pattern for potentially
424 | increased efficiency.
425 | """
426 |
427 | def __init__(
428 | self,
429 | image_size,
430 | in_channels,
431 | model_channels,
432 | out_channels,
433 | num_res_blocks,
434 | attention_resolutions,
435 | dropout=0,
436 | channel_mult=(1, 2, 4, 8),
437 | conv_resample=True,
438 | dims=2,
439 | num_classes=None,
440 | use_checkpoint=False,
441 | use_fp16=False,
442 | num_heads=1,
443 | num_head_channels=-1,
444 | num_heads_upsample=-1,
445 | use_scale_shift_norm=False,
446 | resblock_updown=False,
447 | use_new_attention_order=False,
448 | ):
449 | super().__init__()
450 |
451 | if num_heads_upsample == -1:
452 | num_heads_upsample = num_heads
453 |
454 | self.image_size = image_size
455 | self.in_channels = in_channels
456 | self.model_channels = model_channels
457 | self.out_channels = out_channels
458 | self.num_res_blocks = num_res_blocks
459 | self.attention_resolutions = attention_resolutions
460 | self.dropout = dropout
461 | self.channel_mult = channel_mult
462 | self.conv_resample = conv_resample
463 | self.num_classes = num_classes
464 | self.use_checkpoint = use_checkpoint
465 | self.dtype = th.float16 if use_fp16 else th.float32
466 | self.num_heads = num_heads
467 | self.num_head_channels = num_head_channels
468 | self.num_heads_upsample = num_heads_upsample
469 |
470 | time_embed_dim = model_channels * 4
471 | self.time_embed = nn.Sequential(
472 | linear(model_channels, time_embed_dim),
473 | nn.SiLU(),
474 | linear(time_embed_dim, time_embed_dim),
475 | )
476 |
477 | if self.num_classes is not None:
478 | self.label_emb = nn.Embedding(num_classes, time_embed_dim)
479 |
480 | ch = input_ch = int(channel_mult[0] * model_channels)
481 | self.input_blocks = nn.ModuleList(
482 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
483 | )
484 | self._feature_size = ch
485 | input_block_chans = [ch]
486 | ds = 1
487 | for level, mult in enumerate(channel_mult):
488 | for _ in range(num_res_blocks):
489 | layers = [
490 | ResBlock(
491 | ch,
492 | time_embed_dim,
493 | dropout,
494 | out_channels=int(mult * model_channels),
495 | dims=dims,
496 | use_checkpoint=use_checkpoint,
497 | use_scale_shift_norm=use_scale_shift_norm,
498 | )
499 | ]
500 | ch = int(mult * model_channels)
501 | if ds in attention_resolutions:
502 | layers.append(
503 | AttentionBlock(
504 | ch,
505 | use_checkpoint=use_checkpoint,
506 | num_heads=num_heads,
507 | num_head_channels=num_head_channels,
508 | use_new_attention_order=use_new_attention_order,
509 | )
510 | )
511 | self.input_blocks.append(TimestepEmbedSequential(*layers))
512 | self._feature_size += ch
513 | input_block_chans.append(ch)
514 | if level != len(channel_mult) - 1:
515 | out_ch = ch
516 | self.input_blocks.append(
517 | TimestepEmbedSequential(
518 | ResBlock(
519 | ch,
520 | time_embed_dim,
521 | dropout,
522 | out_channels=out_ch,
523 | dims=dims,
524 | use_checkpoint=use_checkpoint,
525 | use_scale_shift_norm=use_scale_shift_norm,
526 | down=True,
527 | )
528 | if resblock_updown
529 | else Downsample(
530 | ch, conv_resample, dims=dims, out_channels=out_ch
531 | )
532 | )
533 | )
534 | ch = out_ch
535 | input_block_chans.append(ch)
536 | ds *= 2
537 | self._feature_size += ch
538 |
539 | self.middle_block = TimestepEmbedSequential(
540 | ResBlock(
541 | ch,
542 | time_embed_dim,
543 | dropout,
544 | dims=dims,
545 | use_checkpoint=use_checkpoint,
546 | use_scale_shift_norm=use_scale_shift_norm,
547 | ),
548 | AttentionBlock(
549 | ch,
550 | use_checkpoint=use_checkpoint,
551 | num_heads=num_heads,
552 | num_head_channels=num_head_channels,
553 | use_new_attention_order=use_new_attention_order,
554 | ),
555 | ResBlock(
556 | ch,
557 | time_embed_dim,
558 | dropout,
559 | dims=dims,
560 | use_checkpoint=use_checkpoint,
561 | use_scale_shift_norm=use_scale_shift_norm,
562 | ),
563 | )
564 | self._feature_size += ch
565 |
566 | self.output_blocks = nn.ModuleList([])
567 | for level, mult in list(enumerate(channel_mult))[::-1]:
568 | for i in range(num_res_blocks + 1):
569 | ich = input_block_chans.pop()
570 | layers = [
571 | ResBlock(
572 | ch + ich,
573 | time_embed_dim,
574 | dropout,
575 | out_channels=int(model_channels * mult),
576 | dims=dims,
577 | use_checkpoint=use_checkpoint,
578 | use_scale_shift_norm=use_scale_shift_norm,
579 | )
580 | ]
581 | ch = int(model_channels * mult)
582 | if ds in attention_resolutions:
583 | layers.append(
584 | AttentionBlock(
585 | ch,
586 | use_checkpoint=use_checkpoint,
587 | num_heads=num_heads_upsample,
588 | num_head_channels=num_head_channels,
589 | use_new_attention_order=use_new_attention_order,
590 | )
591 | )
592 | if level and i == num_res_blocks:
593 | out_ch = ch
594 | layers.append(
595 | ResBlock(
596 | ch,
597 | time_embed_dim,
598 | dropout,
599 | out_channels=out_ch,
600 | dims=dims,
601 | use_checkpoint=use_checkpoint,
602 | use_scale_shift_norm=use_scale_shift_norm,
603 | up=True,
604 | )
605 | if resblock_updown
606 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
607 | )
608 | ds //= 2
609 | self.output_blocks.append(TimestepEmbedSequential(*layers))
610 | self._feature_size += ch
611 |
612 | self.out = nn.Sequential(
613 | normalization(ch),
614 | nn.SiLU(),
615 | zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
616 | )
617 |
618 | def convert_to_fp16(self):
619 | """
620 | Convert the torso of the model to float16.
621 | """
622 | self.input_blocks.apply(convert_module_to_f16)
623 | self.middle_block.apply(convert_module_to_f16)
624 | self.output_blocks.apply(convert_module_to_f16)
625 |
626 | def convert_to_fp32(self):
627 | """
628 | Convert the torso of the model to float32.
629 | """
630 | self.input_blocks.apply(convert_module_to_f32)
631 | self.middle_block.apply(convert_module_to_f32)
632 | self.output_blocks.apply(convert_module_to_f32)
633 |
634 | def forward(self, x, timesteps, y=None):
635 | """
636 | Apply the model to an input batch.
637 |
638 | :param x: an [N x C x ...] Tensor of inputs.
639 | :param timesteps: a 1-D batch of timesteps.
640 | :param y: an [N] Tensor of labels, if class-conditional.
641 | :return: an [N x C x ...] Tensor of outputs.
642 | """
643 | assert (y is not None) == (
644 | self.num_classes is not None
645 | ), "must specify y if and only if the model is class-conditional"
646 |
647 | hs = []
648 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
649 |
650 | if self.num_classes is not None:
651 | assert y.shape == (x.shape[0],)
652 | emb = emb + self.label_emb(y)
653 |
654 | h = x.type(self.dtype)
655 | for module in self.input_blocks:
656 | h = module(h, emb)
657 | hs.append(h)
658 | h = self.middle_block(h, emb)
659 | for module in self.output_blocks:
660 | h = th.cat([h, hs.pop()], dim=1)
661 | h = module(h, emb)
662 | h = h.type(x.dtype)
663 | return self.out(h)
664 |
665 |
666 | class SuperResModel(UNetModel):
667 | """
668 | A UNetModel that performs super-resolution.
669 |
670 | Expects an extra kwarg `low_res` to condition on a low-resolution image.
671 | """
672 |
673 | def __init__(self, image_size, in_channels, *args, **kwargs):
674 | super().__init__(image_size, in_channels * 2, *args, **kwargs)
675 |
676 | def forward(self, x, timesteps, low_res=None, **kwargs):
677 | _, _, new_height, new_width = x.shape
678 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
679 | x = th.cat([x, upsampled], dim=1)
680 | return super().forward(x, timesteps, **kwargs)
681 |
682 |
683 | class EncoderUNetModel(nn.Module):
684 | """
685 | The half UNet model with attention and timestep embedding.
686 |
687 | For usage, see UNet.
688 | """
689 |
690 | def __init__(
691 | self,
692 | image_size,
693 | in_channels,
694 | model_channels,
695 | out_channels,
696 | num_res_blocks,
697 | attention_resolutions,
698 | dropout=0,
699 | channel_mult=(1, 2, 4, 8),
700 | conv_resample=True,
701 | dims=2,
702 | use_checkpoint=False,
703 | use_fp16=False,
704 | num_heads=1,
705 | num_head_channels=-1,
706 | num_heads_upsample=-1,
707 | use_scale_shift_norm=False,
708 | resblock_updown=False,
709 | use_new_attention_order=False,
710 | pool="adaptive",
711 | ):
712 | super().__init__()
713 |
714 | if num_heads_upsample == -1:
715 | num_heads_upsample = num_heads
716 |
717 | self.in_channels = in_channels
718 | self.model_channels = model_channels
719 | self.out_channels = out_channels
720 | self.num_res_blocks = num_res_blocks
721 | self.attention_resolutions = attention_resolutions
722 | self.dropout = dropout
723 | self.channel_mult = channel_mult
724 | self.conv_resample = conv_resample
725 | self.use_checkpoint = use_checkpoint
726 | self.dtype = th.float16 if use_fp16 else th.float32
727 | self.num_heads = num_heads
728 | self.num_head_channels = num_head_channels
729 | self.num_heads_upsample = num_heads_upsample
730 |
731 | time_embed_dim = model_channels * 4
732 | self.time_embed = nn.Sequential(
733 | linear(model_channels, time_embed_dim),
734 | nn.SiLU(),
735 | linear(time_embed_dim, time_embed_dim),
736 | )
737 |
738 | ch = int(channel_mult[0] * model_channels)
739 | self.input_blocks = nn.ModuleList(
740 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
741 | )
742 | self._feature_size = ch
743 | input_block_chans = [ch]
744 | ds = 1
745 | for level, mult in enumerate(channel_mult):
746 | for _ in range(num_res_blocks):
747 | layers = [
748 | ResBlock(
749 | ch,
750 | time_embed_dim,
751 | dropout,
752 | out_channels=int(mult * model_channels),
753 | dims=dims,
754 | use_checkpoint=use_checkpoint,
755 | use_scale_shift_norm=use_scale_shift_norm,
756 | )
757 | ]
758 | ch = int(mult * model_channels)
759 | if ds in attention_resolutions:
760 | layers.append(
761 | AttentionBlock(
762 | ch,
763 | use_checkpoint=use_checkpoint,
764 | num_heads=num_heads,
765 | num_head_channels=num_head_channels,
766 | use_new_attention_order=use_new_attention_order,
767 | )
768 | )
769 | self.input_blocks.append(TimestepEmbedSequential(*layers))
770 | self._feature_size += ch
771 | input_block_chans.append(ch)
772 | if level != len(channel_mult) - 1:
773 | out_ch = ch
774 | self.input_blocks.append(
775 | TimestepEmbedSequential(
776 | ResBlock(
777 | ch,
778 | time_embed_dim,
779 | dropout,
780 | out_channels=out_ch,
781 | dims=dims,
782 | use_checkpoint=use_checkpoint,
783 | use_scale_shift_norm=use_scale_shift_norm,
784 | down=True,
785 | )
786 | if resblock_updown
787 | else Downsample(
788 | ch, conv_resample, dims=dims, out_channels=out_ch
789 | )
790 | )
791 | )
792 | ch = out_ch
793 | input_block_chans.append(ch)
794 | ds *= 2
795 | self._feature_size += ch
796 |
797 | self.middle_block = TimestepEmbedSequential(
798 | ResBlock(
799 | ch,
800 | time_embed_dim,
801 | dropout,
802 | dims=dims,
803 | use_checkpoint=use_checkpoint,
804 | use_scale_shift_norm=use_scale_shift_norm,
805 | ),
806 | AttentionBlock(
807 | ch,
808 | use_checkpoint=use_checkpoint,
809 | num_heads=num_heads,
810 | num_head_channels=num_head_channels,
811 | use_new_attention_order=use_new_attention_order,
812 | ),
813 | ResBlock(
814 | ch,
815 | time_embed_dim,
816 | dropout,
817 | dims=dims,
818 | use_checkpoint=use_checkpoint,
819 | use_scale_shift_norm=use_scale_shift_norm,
820 | ),
821 | )
822 | self._feature_size += ch
823 | self.pool = pool
824 | if pool == "adaptive":
825 | self.out = nn.Sequential(
826 | normalization(ch),
827 | nn.SiLU(),
828 | nn.AdaptiveAvgPool2d((1, 1)),
829 | zero_module(conv_nd(dims, ch, out_channels, 1)),
830 | nn.Flatten(),
831 | )
832 | elif pool == "attention":
833 | assert num_head_channels != -1
834 | self.out = nn.Sequential(
835 | normalization(ch),
836 | nn.SiLU(),
837 | AttentionPool2d(
838 | (image_size // ds), ch, num_head_channels, out_channels
839 | ),
840 | )
841 | elif pool == "spatial":
842 | self.out = nn.Sequential(
843 | nn.Linear(self._feature_size, 2048),
844 | nn.ReLU(),
845 | nn.Linear(2048, self.out_channels),
846 | )
847 | elif pool == "spatial_v2":
848 | self.out = nn.Sequential(
849 | nn.Linear(self._feature_size, 2048),
850 | normalization(2048),
851 | nn.SiLU(),
852 | nn.Linear(2048, self.out_channels),
853 | )
854 | else:
855 | raise NotImplementedError(f"Unexpected {pool} pooling")
856 |
857 | def convert_to_fp16(self):
858 | """
859 | Convert the torso of the model to float16.
860 | """
861 | self.input_blocks.apply(convert_module_to_f16)
862 | self.middle_block.apply(convert_module_to_f16)
863 |
864 | def convert_to_fp32(self):
865 | """
866 | Convert the torso of the model to float32.
867 | """
868 | self.input_blocks.apply(convert_module_to_f32)
869 | self.middle_block.apply(convert_module_to_f32)
870 |
871 | def forward(self, x, timesteps):
872 | """
873 | Apply the model to an input batch.
874 |
875 | :param x: an [N x C x ...] Tensor of inputs.
876 | :param timesteps: a 1-D batch of timesteps.
877 | :return: an [N x K] Tensor of outputs.
878 | """
879 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
880 |
881 | results = []
882 | h = x.type(self.dtype)
883 | for module in self.input_blocks:
884 | h = module(h, emb)
885 | if self.pool.startswith("spatial"):
886 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
887 | h = self.middle_block(h, emb)
888 | if self.pool.startswith("spatial"):
889 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
890 | h = th.cat(results, axis=-1)
891 | return self.out(h)
892 | else:
893 | h = h.type(x.dtype)
894 | return self.out(h)
895 |
--------------------------------------------------------------------------------
/ppdm/ppdm/__init__.py:
--------------------------------------------------------------------------------
1 | from ppdm.ppdm import comput_jsd
2 | from ppdm.ppdm import struct_func
3 | from ppdm.ppdm import comput_batch_mean_err
4 | from ppdm.ppdm import corr_func
5 |
--------------------------------------------------------------------------------
/ppdm/ppdm/ppdm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import distance
3 |
4 | def comput_jsd(y_true, y_pred, bins):
5 | y_min = min(y_true.min(), y_pred.min())
6 | y_max = max(y_true.max(), y_pred.max())
7 | hist_true, bin_edges = np.histogram(y_true, bins=bins, range=(y_min, y_max))
8 | hist_pred, bin_edges = np.histogram(y_pred, bins=bins, range=(y_min, y_max))
9 | return distance.jensenshannon(hist_true, hist_pred, 2.0)**2
10 |
11 | def struct_func(p, dt, u):
12 | du_p = (u[:, dt:] - u[:, :-dt]) ** p
13 | return np.mean(du_p), np.std(du_p)
14 |
15 | def comput_batch_mean_err(Sp_batch):
16 | Sp_mean = np.mean(Sp_batch, axis=0)
17 | Sp_min = np.amin(Sp_batch, axis=0)
18 | Sp_max = np.amax(Sp_batch, axis=0)
19 | return Sp_mean, np.vstack([Sp_mean - Sp_min, Sp_max - Sp_mean])
20 |
21 | def corr_func(dt, u):
22 | return np.mean(u[:, :-dt] * u[:, dt:])
--------------------------------------------------------------------------------
/ppdm/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup
4 |
5 | setup(
6 | name='ppdm',
7 | version='0.1',
8 | # list folders, not files
9 | packages=['ppdm'],
10 | )
11 |
--------------------------------------------------------------------------------
/resources/Sampling.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/resources/Sampling.png
--------------------------------------------------------------------------------
/resources/Training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SmartTURB/diffusion-lagr/77da6cdc67775aa45af6c4148068f00a75e54c41/resources/Training.png
--------------------------------------------------------------------------------
/scripts/turb_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate the losses for a diffusion model.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 | from guided_diffusion import dist_util, logger
13 | from guided_diffusion.turb_datasets import load_data
14 | from guided_diffusion.script_util import (
15 | model_and_diffusion_defaults,
16 | create_model_and_diffusion,
17 | add_dict_to_argparser,
18 | args_to_dict,
19 | )
20 |
21 |
22 | def main():
23 | args = create_argparser().parse_args()
24 |
25 | dist_util.setup_dist()
26 | logger.configure()
27 |
28 | logger.log("creating model and diffusion...")
29 | model, diffusion = create_model_and_diffusion(
30 | **args_to_dict(args, model_and_diffusion_defaults().keys())
31 | )
32 | model.load_state_dict(
33 | dist_util.load_state_dict(args.model_path, map_location="cpu")
34 | )
35 | model.to(dist_util.dev())
36 | model.eval()
37 |
38 | logger.log("creating data loader...")
39 | data = load_data(
40 | dataset_path=args.dataset_path,
41 | dataset_name=args.dataset_name,
42 | batch_size=args.batch_size,
43 | class_cond=args.class_cond,
44 | deterministic=True,
45 | )
46 |
47 | logger.log("evaluating...")
48 | import os
49 | seed = 0*4 + int(os.environ["CUDA_VISIBLE_DEVICES"])
50 | th.manual_seed(seed)
51 | run_losses_evaluation(model, diffusion, data, args.num_samples)
52 |
53 |
54 | def run_losses_evaluation(model, diffusion, data, num_samples):
55 | all_total_loss = []
56 | all_losses = {"loss": [], "vb": [], "mse": []}
57 | num_complete = 0
58 | while num_complete < num_samples:
59 | batch, model_kwargs = next(data)
60 | batch = batch.to(dist_util.dev())
61 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
62 | minibatch_losses = diffusion.calc_losses_loop(
63 | model, batch, model_kwargs=model_kwargs
64 | )
65 |
66 | for key in minibatch_losses:
67 | losses = minibatch_losses[key]
68 | gathered_losses = [th.zeros_like(losses) for _ in range(dist.get_world_size())]
69 | dist.all_gather(gathered_losses, losses) # gather not supported with NCCL
70 | all_losses[key].extend([losses.cpu().numpy() for losses in gathered_losses])
71 |
72 | total_loss = minibatch_losses["loss"]
73 | total_loss = total_loss.mean() / dist.get_world_size()
74 | dist.all_reduce(total_loss)
75 | all_total_loss.append(total_loss.item())
76 | num_complete += dist.get_world_size() * batch.shape[0]
77 |
78 | logger.log(f"done {num_complete} samples: total_loss={np.mean(all_total_loss)}")
79 |
80 | if dist.get_rank() == 0:
81 | for name in minibatch_losses:
82 | losses = np.concatenate(all_losses[name])
83 | shape_str = "x".join([str(x) for x in losses.shape])
84 | out_path = os.path.join(logger.get_dir(), f"{name}_losses_{shape_str}.npz")
85 | logger.log(f"saving {name} losses to {out_path}")
86 | np.savez(out_path, losses)
87 |
88 | dist.barrier()
89 | logger.log("evaluation complete")
90 |
91 |
92 | def create_argparser():
93 | defaults = dict(
94 | dataset_path="", dataset_name="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
95 | )
96 | defaults.update(model_and_diffusion_defaults())
97 | parser = argparse.ArgumentParser()
98 | add_dict_to_argparser(parser, defaults)
99 | return parser
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/scripts/turb_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Print the model summary of UNetModel being used.
3 | """
4 |
5 | import argparse
6 | from torchsummary import summary
7 |
8 | from guided_diffusion import dist_util, logger
9 | from guided_diffusion.script_util import (
10 | model_and_diffusion_defaults,
11 | create_model_and_diffusion,
12 | args_to_dict,
13 | add_dict_to_argparser,
14 | )
15 |
16 |
17 | def main():
18 | args = create_argparser().parse_args()
19 |
20 | dist_util.setup_dist()
21 | logger.configure()
22 |
23 | logger.log("creating model and diffusion...")
24 | model, diffusion = create_model_and_diffusion(
25 | **args_to_dict(args, model_and_diffusion_defaults().keys())
26 | )
27 | model.to(dist_util.dev())
28 |
29 | summary(model, [(args.in_channels, args.image_size), ()])
30 |
31 |
32 | def create_argparser():
33 | defaults = dict(
34 | dataset_path="",
35 | dataset_name="",
36 | schedule_sampler="uniform",
37 | lr=1e-4,
38 | weight_decay=0.0,
39 | lr_anneal_steps=0,
40 | batch_size=1,
41 | microbatch=-1, # -1 disables microbatches
42 | ema_rate="0.9999", # comma-separated list of EMA values
43 | log_interval=10,
44 | save_interval=10000,
45 | resume_checkpoint="",
46 | use_fp16=False,
47 | fp16_scale_growth=1e-3,
48 | )
49 | defaults.update(model_and_diffusion_defaults())
50 | parser = argparse.ArgumentParser()
51 | add_dict_to_argparser(parser, defaults)
52 | return parser
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/scripts/turb_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of Lagrangian trajectories from a model and save them as a large
3 | numpy array. This can be used to produce samples for statistical evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 |
13 | from guided_diffusion import dist_util, logger
14 | from guided_diffusion.script_util import (
15 | NUM_CLASSES,
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | add_dict_to_argparser,
19 | args_to_dict,
20 | )
21 |
22 |
23 | def main():
24 | args = create_argparser().parse_args()
25 |
26 | dist_util.setup_dist()
27 | logger.configure()
28 |
29 | logger.log("creating model and diffusion...")
30 | model, diffusion = create_model_and_diffusion(
31 | **args_to_dict(args, model_and_diffusion_defaults().keys())
32 | )
33 | model.load_state_dict(
34 | dist_util.load_state_dict(args.model_path, map_location="cpu")
35 | )
36 | model.to(dist_util.dev())
37 | if args.use_fp16:
38 | model.convert_to_fp16()
39 | model.eval()
40 |
41 | logger.log("sampling...")
42 | all_images = []
43 | all_labels = []
44 | #noise = th.zeros(
45 | # noise = th.ones(
46 | # (args.batch_size, args.in_channels, args.image_size),
47 | # dtype=th.float32,
48 | # device=dist_util.dev()
49 | # ) * 2
50 | # noise = th.from_numpy(
51 | # np.load('../velocity_module-IS64-NC128-NRB3-DS4000-NScosine-LR1e-4-BS256-sample/fixed_noise_64x1x64x64.npy')
52 | # ).to(dtype=th.float32, device=dist_util.dev())
53 | import os
54 | seed = 0*8 + int(os.environ["CUDA_VISIBLE_DEVICES"])
55 | th.manual_seed(seed)
56 | while len(all_images) * args.batch_size < args.num_samples:
57 | model_kwargs = {}
58 | if args.class_cond:
59 | classes = th.randint(
60 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
61 | )
62 | model_kwargs["y"] = classes
63 | sample_fn = (
64 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
65 | )
66 | #sample_fn = diffusion.p_sample_loop_history
67 | sample = sample_fn(
68 | model,
69 | (args.batch_size, args.in_channels, args.image_size),
70 | #noise=noise,
71 | clip_denoised=args.clip_denoised,
72 | model_kwargs=model_kwargs,
73 | )
74 | sample = sample.clamp(-1, 1)
75 | #sample[:, -1] = sample[:, -1].clamp(-1, 1)
76 | sample = sample.permute(0, 2, 1)
77 | #sample = sample.permute(0, 1, 3, 2)
78 | sample = sample.contiguous()
79 |
80 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
81 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
82 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
83 | if args.class_cond:
84 | gathered_labels = [
85 | th.zeros_like(classes) for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(gathered_labels, classes)
88 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
89 | logger.log(f"created {len(all_images) * args.batch_size} samples")
90 |
91 | arr = np.concatenate(all_images, axis=0)
92 | arr = arr[: args.num_samples]
93 | if args.class_cond:
94 | label_arr = np.concatenate(all_labels, axis=0)
95 | label_arr = label_arr[: args.num_samples]
96 | if dist.get_rank() == 0:
97 | shape_str = "x".join([str(x) for x in arr.shape])
98 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
99 | logger.log(f"saving to {out_path}")
100 | if args.class_cond:
101 | np.savez(out_path, arr, label_arr)
102 | else:
103 | np.savez(out_path, arr)
104 |
105 | dist.barrier()
106 | logger.log("sampling complete")
107 |
108 |
109 | def create_argparser():
110 | defaults = dict(
111 | clip_denoised=True,
112 | num_samples=10000,
113 | batch_size=16,
114 | use_ddim=False,
115 | model_path="",
116 | )
117 | defaults.update(model_and_diffusion_defaults())
118 | parser = argparse.ArgumentParser()
119 | add_dict_to_argparser(parser, defaults)
120 | return parser
121 |
122 |
123 | if __name__ == "__main__":
124 | main()
125 |
--------------------------------------------------------------------------------
/scripts/turb_sample_history.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of Lagrangian trajectories from a model and save them as a large
3 | numpy array. This can be used to produce samples for statistical evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 |
13 | from guided_diffusion import dist_util, logger
14 | from guided_diffusion.script_util import (
15 | NUM_CLASSES,
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | add_dict_to_argparser,
19 | args_to_dict,
20 | )
21 |
22 |
23 | def main():
24 | args = create_argparser().parse_args()
25 |
26 | dist_util.setup_dist()
27 | logger.configure()
28 |
29 | logger.log("creating model and diffusion...")
30 | model, diffusion = create_model_and_diffusion(
31 | **args_to_dict(args, model_and_diffusion_defaults().keys())
32 | )
33 | model.load_state_dict(
34 | dist_util.load_state_dict(args.model_path, map_location="cpu")
35 | )
36 | model.to(dist_util.dev())
37 | if args.use_fp16:
38 | model.convert_to_fp16()
39 | model.eval()
40 |
41 | logger.log("sampling...")
42 | #noise = th.zeros(
43 | # noise = th.ones(
44 | # (args.batch_size, args.in_channels, args.image_size),
45 | # dtype=th.float32,
46 | # device=dist_util.dev()
47 | # ) * 2
48 | # noise = th.from_numpy(
49 | # np.load('../velocity_module-IS64-NC128-NRB3-DS4000-NScosine-LR1e-4-BS256-sample/fixed_noise_64x1x64x64.npy')
50 | # ).to(dtype=th.float32, device=dist_util.dev())
51 | import os
52 | seed = 0*8 + int(os.environ["CUDA_VISIBLE_DEVICES"])
53 | th.manual_seed(seed)
54 | curr_batch, num_complete = 0, 0
55 | while num_complete < args.num_samples:
56 | model_kwargs = {}
57 | if args.class_cond:
58 | classes = th.randint(
59 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
60 | )
61 | model_kwargs["y"] = classes
62 | sample_fn = diffusion.p_sample_loop_history
63 | sample = sample_fn(
64 | model,
65 | (args.batch_size, args.in_channels, args.image_size),
66 | #noise=noise,
67 | clip_denoised=args.clip_denoised,
68 | model_kwargs=model_kwargs,
69 | )
70 | sample[:, -1] = sample[:, -1].clamp(-1, 1)
71 | sample = sample.permute(0, 1, 3, 2)
72 | sample = sample.contiguous()
73 |
74 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
75 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
76 | all_images = [sample.cpu().numpy() for sample in gathered_samples]
77 | if args.class_cond:
78 | gathered_labels = [
79 | th.zeros_like(classes) for _ in range(dist.get_world_size())
80 | ]
81 | dist.all_gather(gathered_labels, classes)
82 | all_labels = [labels.cpu().numpy() for labels in gathered_labels]
83 | curr_batch += 1
84 | num_complete += dist.get_world_size() * args.batch_size
85 | logger.log(f"created {num_complete} samples")
86 |
87 | arr = np.concatenate(all_images, axis=0)
88 | if args.class_cond:
89 | label_arr = np.concatenate(all_labels, axis=0)
90 | if dist.get_rank() == 0:
91 | shape_str = "x".join([str(x) for x in arr.shape])
92 | out_path = os.path.join(
93 | logger.get_dir(), "samples_history-seed0",
94 | f"batch{curr_batch:03d}-samples_history_{shape_str}.npz"
95 | )
96 | logger.log(f"saving to {out_path}")
97 | if args.class_cond:
98 | np.savez(out_path, arr, label_arr)
99 | else:
100 | np.savez(out_path, arr)
101 |
102 | dist.barrier()
103 | logger.log("sampling complete")
104 |
105 |
106 | def create_argparser():
107 | defaults = dict(
108 | clip_denoised=True,
109 | num_samples=10000,
110 | batch_size=16,
111 | use_ddim=False,
112 | model_path="",
113 | )
114 | defaults.update(model_and_diffusion_defaults())
115 | parser = argparse.ArgumentParser()
116 | add_dict_to_argparser(parser, defaults)
117 | return parser
118 |
119 |
120 | if __name__ == "__main__":
121 | main()
122 |
--------------------------------------------------------------------------------
/scripts/turb_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on Lagrangian trajectories in 3d turbulence.
3 | """
4 |
5 | import argparse
6 |
7 | from guided_diffusion import dist_util, logger
8 | from guided_diffusion.turb_datasets import load_data
9 | from guided_diffusion.resample import create_named_schedule_sampler
10 | from guided_diffusion.script_util import (
11 | model_and_diffusion_defaults,
12 | create_model_and_diffusion,
13 | args_to_dict,
14 | add_dict_to_argparser,
15 | )
16 | from guided_diffusion.train_util import TrainLoop
17 |
18 |
19 | def main():
20 | args = create_argparser().parse_args()
21 |
22 | dist_util.setup_dist()
23 | logger.configure()
24 |
25 | logger.log("creating model and diffusion...")
26 | model, diffusion = create_model_and_diffusion(
27 | **args_to_dict(args, model_and_diffusion_defaults().keys())
28 | )
29 | model.to(dist_util.dev())
30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
31 |
32 | logger.log("creating data loader...")
33 | data = load_data(
34 | dataset_path=args.dataset_path,
35 | dataset_name=args.dataset_name,
36 | batch_size=args.batch_size,
37 | class_cond=args.class_cond,
38 | )
39 |
40 | logger.log("training...")
41 | TrainLoop(
42 | model=model,
43 | diffusion=diffusion,
44 | data=data,
45 | batch_size=args.batch_size,
46 | microbatch=args.microbatch,
47 | lr=args.lr,
48 | ema_rate=args.ema_rate,
49 | log_interval=args.log_interval,
50 | save_interval=args.save_interval,
51 | resume_checkpoint=args.resume_checkpoint,
52 | use_fp16=args.use_fp16,
53 | fp16_scale_growth=args.fp16_scale_growth,
54 | schedule_sampler=schedule_sampler,
55 | weight_decay=args.weight_decay,
56 | lr_anneal_steps=args.lr_anneal_steps,
57 | ).run_loop()
58 |
59 |
60 | def create_argparser():
61 | defaults = dict(
62 | dataset_path="",
63 | dataset_name="",
64 | schedule_sampler="uniform",
65 | lr=1e-4,
66 | weight_decay=0.0,
67 | lr_anneal_steps=0,
68 | batch_size=1,
69 | microbatch=-1, # -1 disables microbatches
70 | ema_rate="0.9999", # comma-separated list of EMA values
71 | log_interval=10,
72 | save_interval=10000,
73 | resume_checkpoint="",
74 | use_fp16=False,
75 | fp16_scale_growth=1e-3,
76 | )
77 | defaults.update(model_and_diffusion_defaults())
78 | parser = argparse.ArgumentParser()
79 | add_dict_to_argparser(parser, defaults)
80 | return parser
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="guided-diffusion",
5 | py_modules=["guided_diffusion"],
6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 |
--------------------------------------------------------------------------------