├── LICENSE
├── README.md
├── file1000031_mask.pt
├── improved_diffusion.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── improved_diffusion
│   ├── __init__.py
│   ├── complex_two_img_datasets.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion_duo.py
│   ├── logger.py
│   ├── losses.py
│   ├── nn.py
│   ├── nn_complex.py
│   ├── resample.py
│   ├── respace_duo.py
│   ├── script_util_duo.py
│   ├── train_util.py
│   ├── unet.py
│   └── unet_test.py
├── scripts
│   ├── data_process.py
│   ├── image_sample_complex_duo.py
│   ├── image_train.py
│   └── test_eval.py
└── setup.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 cpeng93
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## DiffuseRecon
 2 | 
 3 | This codebase is modified from [Improved DDPM](https://github.com/openai/improved-diffusion).
 4 | 
 5 | ## Installation
 6 | 
 7 | Clone this repository and navigate to it in your terminal. Then run:
 8 | 
 9 | ```
10 | pip install -e .
11 | ```
12 | 
13 | This should install the `improved_diffusion` python package that the scripts depend on.
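14 | 
15 | To double-check the editable install, a quick optional sanity check (this assumes `python` resolves to the environment you installed into):
16 | 
17 | ```
18 | python -c "import improved_diffusion"
19 | ```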
20 | 
21 | ## Data Preparation and Pre-Trained Checkpoints
22 | 
23 | A pre-trained checkpoint can be downloaded via this [OneDrive link](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/cpeng26_jh_edu/ESGvudC6-ZlApb5xmkmDVzoBVk3Fn1QHXMFxBEvkayulgQ?e=5Xcjfv) or this [Google Drive link](https://drive.google.com/file/d/1rii1GJXW6pZNu3vajDcJe9huFrsX9zBS/view?usp=sharing).
24 | 
25 | For fastMRI, the simplified h5 data can be downloaded by following the instructions in [ReconFormer](https://github.com/guopengf/ReconFormer), i.e. through this [link](https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/pguo4_jh_edu/EtXsMeyrJB1Pn-JOjM_UqhUBdY1KPrvs-PwF2fW7gERKIA?e=uuBINy). DiffuseRecon converts it to a normalized format via scripts/data_process.py:
26 | 
27 | ```
28 | python scripts/data_process.py
29 | ```
30 | 
31 | ## Sampling
32 | 
33 | ```
34 | python scripts/image_sample_complex_duo.py --model_path img_space_dual/ema_0.9999_150000.pt --data_path EVAL_PATH \
35 | --image_size 320 --num_channels 128 --num_res_blocks 3 --learn_sigma False --dropout 0.3 --diffusion_steps 4000 \
36 | --noise_schedule cosine --timestep_respacing 100 --save_path test/ --num_samples 1 --batch_size 5
37 | ```
38 | 
39 | Note that `timestep_respacing` sets the number of initial coarse sampling steps.
40 | 
41 | ## Training
42 | 
43 | ```
44 | mpiexec -n GPU_NUMS python scripts/image_train.py --data_dir TRAIN_PATH --image_size 320 --num_channels 128 \
45 | --num_res_blocks 3 --learn_sigma False --dropout 0.3 --diffusion_steps 4000 --noise_schedule cosine --lr 1e-4 --batch_size 1 \
46 | --save_dir img_space_dual
47 | ```
48 | 
49 | ## TODO
50 | - Upload the PSNR evaluation script.
51 | - Currently, the refinement step is fixed at 20 (line 592, gaussian_diffusion_duo.py); make this an adjustable input.
52 | - Graphics.
--------------------------------------------------------------------------------
/file1000031_mask.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cpeng93/DiffuseRecon/13a9e166185a959e7196b0c9e3d170b8ec50dc0e/file1000031_mask.pt
--------------------------------------------------------------------------------
/improved_diffusion.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
 1 | Metadata-Version: 1.0
 2 | Name: improved-diffusion
 3 | Version: 0.0.0
 4 | Summary: UNKNOWN
 5 | Home-page: UNKNOWN
 6 | Author: UNKNOWN
 7 | Author-email: UNKNOWN
 8 | License: UNKNOWN
 9 | Description: UNKNOWN
10 | Platform: UNKNOWN
--------------------------------------------------------------------------------
/improved_diffusion.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | improved_diffusion.egg-info/PKG-INFO
4 | improved_diffusion.egg-info/SOURCES.txt
5 | improved_diffusion.egg-info/dependency_links.txt
6 | improved_diffusion.egg-info/requires.txt
7 | improved_diffusion.egg-info/top_level.txt
--------------------------------------------------------------------------------
/improved_diffusion.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/improved_diffusion.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | blobfile>=1.0.5
2 | torch
3 | tqdm
--------------------------------------------------------------------------------
/improved_diffusion.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | improved_diffusion
2 | 
--------------------------------------------------------------------------------
/improved_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 | 
--------------------------------------------------------------------------------
/improved_diffusion/complex_two_img_datasets.py:
--------------------------------------------------------------------------------
 1 | from PIL import Image
 2 | import blobfile as bf
 3 | from mpi4py import MPI
 4 | import numpy as np
 5 | from torch.utils.data import DataLoader, Dataset
 6 | import pickle, os
 7 | 
 8 | def load_data(
 9 |     *, data_dir, batch_size, image_size, class_cond=False, deterministic=False
10 | ):
11 |     """
12 |     For a dataset, create a generator over (images, kwargs) pairs.
13 | 
14 |     Each image is an NCHW float tensor, and the kwargs dict contains zero or
15 |     more keys, each of which maps to a batched Tensor of its own.
16 |     The kwargs dict can be used for class labels, in which case the key is "y"
17 |     and the values are integer tensors of class labels.
18 | 
19 |     :param data_dir: a dataset directory.
20 |     :param batch_size: the batch size of each returned pair.
21 |     :param image_size: the size to which images are resized.
22 |     :param class_cond: if True, include a "y" key in returned dicts for class
23 |                        labels. If classes are not available and this is True, an
24 |                        exception will be raised.
25 |     :param deterministic: if True, yield results in a deterministic order.
26 |     """
27 |     if not data_dir:
28 |         raise ValueError("unspecified data directory")
29 |     all_files = _list_image_files_recursively(data_dir)
30 |     classes = None
31 |     if class_cond:
32 |         # Assume classes are the first part of the filename,
33 |         # before an underscore.
34 |         class_names = [bf.basename(path).split("_")[0] for path in all_files]
35 |         sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
36 |         classes = [sorted_classes[x] for x in class_names]
37 |     dataset = ImageDataset(
38 |         image_size,
39 |         all_files,
40 |         classes=classes,
41 |         shard=MPI.COMM_WORLD.Get_rank(),
42 |         num_shards=MPI.COMM_WORLD.Get_size(),
43 |     )
44 |     if deterministic:
45 |         loader = DataLoader(
46 |             dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
47 |         )
48 |     else:
49 |         loader = DataLoader(
50 |             dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
51 |         )
52 |     while True:
53 |         yield from loader
54 | 
55 | 
56 | def _list_image_files_recursively(data_dir):
57 |     results = []
58 |     for entry in sorted(bf.listdir(data_dir)):
59 |         full_path = bf.join(data_dir, entry)
60 |         ext = entry.split(".")[-1]
61 |         if "." in entry and ext.lower() in ["pt"]:
62 |             results.append(full_path)
63 |         elif bf.isdir(full_path):
64 |             results.extend(_list_image_files_recursively(full_path))
65 |     return results
66 | 
67 | 
68 | class ImageDataset(Dataset):
69 |     def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
70 |         super().__init__()
71 |         self.resolution = resolution
72 |         self.local_images = image_paths[shard:][::num_shards]
73 |         self.local_classes = None if classes is None else classes[shard:][::num_shards]
74 | 
75 |     def __len__(self):
76 |         return len(self.local_images)
77 | 
78 |     def __getitem__(self, idx):
79 |         """
80 |         fastMRI is preprocessed into per-slice pickle files, with the complex raw data stored under the 'img' key; each sample stacks the real/imaginary parts of two adjacent slices.
81 |         """
82 | 
83 |         path = self.local_images[idx]
84 |         slice_num = path.split('_')[-1][:-3]
85 |         next_path = path.replace(slice_num+'.pt',str(int(slice_num)+1)+'.pt')  # pair each slice with the next slice in the volume
86 |         if not os.path.isfile(next_path):  # at the end of a volume, fall back to the previous pair
87 |             next_path = path
88 |             path = path.replace(slice_num+'.pt',str(int(slice_num)-1)+'.pt')
89 | 
90 |         arr = pickle.load(open(path,'rb'))['img']
91 |         arr_next = pickle.load(open(next_path,'rb'))['img']
92 |         real = np.real(arr)
93 |         imag = np.imag(arr)
94 |         real_next = np.real(arr_next)
95 |         imag_next = np.imag(arr_next)
96 |         out = np.stack([real,imag,real_next,imag_next]).astype(np.float32)  # (4, H, W)
97 |         max_val = abs(out).max()
98 |         out /= max_val  # normalize by the maximum magnitude
99 |         return out, {}
--------------------------------------------------------------------------------
/improved_diffusion/dist_util.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Helpers for distributed training.
 3 | """
 4 | 
 5 | import io
 6 | import os
 7 | import socket
 8 | 
 9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 | 
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 | 
18 | SETUP_RETRY_COUNT = 3
19 | 
20 | 
21 | def setup_dist():
22 |     """
23 |     Set up a distributed process group.
24 |     """
25 |     if dist.is_initialized():
26 |         return
27 | 
28 |     comm = MPI.COMM_WORLD
29 |     backend = "gloo" if not th.cuda.is_available() else "nccl"
30 | 
31 |     if backend == "gloo":
32 |         hostname = "localhost"
33 |     else:
34 |         hostname = socket.gethostbyname(socket.getfqdn())
35 |     os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
36 |     os.environ["RANK"] = str(comm.rank)
37 |     os.environ["WORLD_SIZE"] = str(comm.size)
38 | 
39 |     port = comm.bcast(_find_free_port(), root=0)
40 |     os.environ["MASTER_PORT"] = str(port)
41 |     dist.init_process_group(backend=backend, init_method="env://")
42 | 
43 | 
44 | def dev():
45 |     """
46 |     Get the device to use for torch.distributed.
47 |     """
48 |     if th.cuda.is_available():
49 |         return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
50 |     return th.device("cpu")
51 | 
52 | 
53 | def load_state_dict(path, **kwargs):
54 |     """
55 |     Load a PyTorch file without redundant fetches across MPI ranks.
56 |     """
57 |     if MPI.COMM_WORLD.Get_rank() == 0:
58 |         with bf.BlobFile(path, "rb") as f:
59 |             data = f.read()
60 |     else:
61 |         data = None
62 |     data = MPI.COMM_WORLD.bcast(data)
63 |     return th.load(io.BytesIO(data), **kwargs)
64 | 
65 | 
66 | def sync_params(params):
67 |     """
68 |     Synchronize a sequence of Tensors across ranks from rank 0.
69 | """ 70 | for p in params: 71 | with th.no_grad(): 72 | dist.broadcast(p, 0) 73 | 74 | 75 | def _find_free_port(): 76 | try: 77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 78 | s.bind(("", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | finally: 82 | s.close() 83 | -------------------------------------------------------------------------------- /improved_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /improved_diffusion/gaussian_diffusion_duo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 
6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | import time 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | 22 | The beta schedule library consists of beta schedules which remain similar 23 | in the limit of num_diffusion_timesteps. 24 | Beta schedules may be added, but should not be removed or changed once 25 | they are committed to maintain backwards compatibility. 26 | """ 27 | if schedule_name == "linear": 28 | # Linear schedule from Ho et al, extended to work for any number of 29 | # diffusion steps. 30 | scale = 1000 / num_diffusion_timesteps 31 | beta_start = scale * 0.0001 32 | beta_end = scale * 0.02 33 | return np.linspace( 34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 35 | ) 36 | elif schedule_name == "cosine": 37 | return betas_for_alpha_bar( 38 | num_diffusion_timesteps, 39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 40 | ) 41 | else: 42 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 43 | 44 | 45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 46 | """ 47 | Create a beta schedule that discretizes the given alpha_t_bar function, 48 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 49 | 50 | :param num_diffusion_timesteps: the number of betas to produce. 51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 52 | produces the cumulative product of (1-beta) up to that 53 | part of the diffusion process. 54 | :param max_beta: the maximum beta to use; use values lower than 1 to 55 | prevent singularities. 56 | """ 57 | betas = [] 58 | for i in range(num_diffusion_timesteps): 59 | t1 = i / num_diffusion_timesteps 60 | t2 = (i + 1) / num_diffusion_timesteps 61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 62 | return np.array(betas) 63 | 64 | 65 | class ModelMeanType(enum.Enum): 66 | """ 67 | Which type of output the model predicts. 68 | """ 69 | 70 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 71 | START_X = enum.auto() # the model predicts x_0 72 | EPSILON = enum.auto() # the model predicts epsilon 73 | 74 | 75 | class ModelVarType(enum.Enum): 76 | """ 77 | What is used as the model's output variance. 78 | 79 | The LEARNED_RANGE option has been added to allow the model to predict 80 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 81 | """ 82 | 83 | LEARNED = enum.auto() 84 | FIXED_SMALL = enum.auto() 85 | FIXED_LARGE = enum.auto() 86 | LEARNED_RANGE = enum.auto() 87 | 88 | 89 | class LossType(enum.Enum): 90 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 91 | RESCALED_MSE = ( 92 | enum.auto() 93 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 94 | KL = enum.auto() # use the variational lower-bound 95 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 96 | 97 | def is_vb(self): 98 | return self == LossType.KL or self == LossType.RESCALED_KL 99 | 100 | 101 | class GaussianDiffusion: 102 | """ 103 | Utilities for training and sampling diffusion models. 104 | 105 | Ported directly from here, and then adapted over time to further experimentation. 
106 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 107 | 108 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 109 | starting at T and going to 1. 110 | :param model_mean_type: a ModelMeanType determining what the model outputs. 111 | :param model_var_type: a ModelVarType determining how variance is output. 112 | :param loss_type: a LossType determining the loss function to use. 113 | :param rescale_timesteps: if True, pass floating point timesteps into the 114 | model so that they are always scaled like in the 115 | original paper (0 to 1000). 116 | """ 117 | 118 | def __init__( 119 | self, 120 | *, 121 | betas, 122 | model_mean_type, 123 | model_var_type, 124 | loss_type, 125 | rescale_timesteps=False, 126 | ): 127 | self.model_mean_type = model_mean_type 128 | self.model_var_type = model_var_type 129 | self.loss_type = loss_type 130 | self.rescale_timesteps = rescale_timesteps 131 | 132 | # Use float64 for accuracy. 133 | betas = np.array(betas, dtype=np.float64) 134 | self.betas = betas 135 | assert len(betas.shape) == 1, "betas must be 1-D" 136 | assert (betas > 0).all() and (betas <= 1).all() 137 | 138 | self.num_timesteps = int(betas.shape[0]) 139 | 140 | alphas = 1.0 - betas 141 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 142 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 143 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 144 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 145 | 146 | # calculations for diffusion q(x_t | x_{t-1}) and others 147 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 148 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 149 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 150 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 151 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 152 | 153 | # calculations for posterior q(x_{t-1} | x_t, x_0) 154 | self.posterior_variance = ( 155 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 156 | ) 157 | # log calculation clipped because the posterior variance is 0 at the 158 | # beginning of the diffusion chain. 159 | self.posterior_log_variance_clipped = np.log( 160 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 161 | ) 162 | self.posterior_mean_coef1 = ( 163 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 164 | ) 165 | self.posterior_mean_coef2 = ( 166 | (1.0 - self.alphas_cumprod_prev) 167 | * np.sqrt(alphas) 168 | / (1.0 - self.alphas_cumprod) 169 | ) 170 | 171 | def q_mean_variance(self, x_start, t): 172 | """ 173 | Get the distribution q(x_t | x_0). 174 | 175 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 176 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 177 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 178 | """ 179 | mean = ( 180 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 181 | ) 182 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 183 | log_variance = _extract_into_tensor( 184 | self.log_one_minus_alphas_cumprod, t, x_start.shape 185 | ) 186 | return mean, variance, log_variance 187 | 188 | def q_sample(self, x_start, t, noise=None): 189 | """ 190 | Diffuse the data for a given number of diffusion steps. 
191 | 
192 |         In other words, sample from q(x_t | x_0).
193 | 
194 |         :param x_start: the initial data batch.
195 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
196 |         :param noise: if specified, the split-out normal noise.
197 |         :return: A noisy version of x_start.
198 |         """
199 |         if noise is None:
200 |             noise = th.randn_like(x_start)
201 |         assert noise.shape == x_start.shape
202 |         return (
203 |             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
204 |             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
205 |             * noise
206 |         )
207 | 
208 | 
209 |     def q_sample_fft_duo(self, k_start, t, noise=None):
210 |         """
211 |         Diffuse k-space data for a given number of diffusion steps.
212 | 
213 |         In other words, sample from q(k_t | k_0); the image-space noise is moved into k-space with a 2-D FFT so that it matches the domain of k_start.
214 | 
215 |         :param k_start: the initial k-space batch (real/imag pairs for two adjacent slices).
216 |         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
217 |         :param noise: if specified, the image-space noise to transform and add.
218 |         :return: A noisy version of k_start.
219 |         """
220 |         if noise is None:
221 |             noise = th.randn_like(k_start)
222 |         assert noise.shape == k_start.shape
223 | 
224 |         noise1 = noise[:,[0]]+noise[:,[1]]*1j
225 |         noise2 = noise[:,[2]]+noise[:,[3]]*1j
226 |         noise1 = th.fft.fft2(noise1)
227 |         noise2 = th.fft.fft2(noise2)
228 |         noise = th.cat([noise1.real,noise1.imag,noise2.real,noise2.imag],1)  # k-space noise, back to 4 real channels
229 | 
230 |         return (
231 |             _extract_into_tensor(self.sqrt_alphas_cumprod, t, k_start.shape) * k_start
232 |             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, k_start.shape)
233 |             * noise
234 |         )
235 | 
236 |     def q_posterior_mean_variance(self, x_start, x_t, t):
237 |         """
238 |         Compute the mean and variance of the diffusion posterior:
239 | 
240 |             q(x_{t-1} | x_t, x_0)
241 | 
242 |         """
243 |         assert x_start.shape == x_t.shape
244 |         posterior_mean = (
245 |             _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
246 |             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
247 |         )
248 |         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
249 |         posterior_log_variance_clipped = _extract_into_tensor(
250 |             self.posterior_log_variance_clipped, t, x_t.shape
251 |         )
252 |         assert (
253 |             posterior_mean.shape[0]
254 |             == posterior_variance.shape[0]
255 |             == posterior_log_variance_clipped.shape[0]
256 |             == x_start.shape[0]
257 |         )
258 |         return posterior_mean, posterior_variance, posterior_log_variance_clipped
259 | 
260 |     def p_mean_variance(
261 |         self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
262 |     ):
263 |         """
264 |         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
265 |         the initial x, x_0.
266 | 
267 |         :param model: the model, which takes a signal and a batch of timesteps
268 |                       as input.
269 |         :param x: the [N x C x ...] tensor at time t.
270 |         :param t: a 1-D Tensor of timesteps.
271 |         :param clip_denoised: if True, clip the denoised signal into [-1, 1].
272 |         :param denoised_fn: if not None, a function which applies to the
273 |                             x_start prediction before it is used to sample. Applies before
274 |                             clip_denoised.
275 |         :param model_kwargs: if not None, a dict of extra keyword arguments to
276 |                              pass to the model. This can be used for conditioning.
277 |         :return: a dict with the following keys:
278 |                  - 'mean': the model mean output.
279 |                  - 'variance': the model variance output.
280 |                  - 'log_variance': the log of 'variance'.
281 |                  - 'pred_xstart': the prediction for x_0 ('model_output', the raw network output, is also returned).
282 | """ 283 | if model_kwargs is None: 284 | model_kwargs = {} 285 | 286 | B, C = x.shape[:2] 287 | assert t.shape == (B,) 288 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 289 | 290 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 291 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 292 | model_output, model_var_values = th.split(model_output, C, dim=1) 293 | if self.model_var_type == ModelVarType.LEARNED: 294 | model_log_variance = model_var_values 295 | model_variance = th.exp(model_log_variance) 296 | else: 297 | min_log = _extract_into_tensor( 298 | self.posterior_log_variance_clipped, t, x.shape 299 | ) 300 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 301 | # The model_var_values is [-1, 1] for [min_var, max_var]. 302 | frac = (model_var_values + 1) / 2 303 | model_log_variance = frac * max_log + (1 - frac) * min_log 304 | model_variance = th.exp(model_log_variance) 305 | else: 306 | model_variance, model_log_variance = { 307 | # for fixedlarge, we set the initial (log-)variance like so 308 | # to get a better decoder log likelihood. 309 | ModelVarType.FIXED_LARGE: ( 310 | np.append(self.posterior_variance[1], self.betas[1:]), 311 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 312 | ), 313 | ModelVarType.FIXED_SMALL: ( 314 | self.posterior_variance, 315 | self.posterior_log_variance_clipped, 316 | ), 317 | }[self.model_var_type] 318 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 319 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 320 | 321 | def process_xstart(x): 322 | if denoised_fn is not None: 323 | x = denoised_fn(x) 324 | if clip_denoised: 325 | return x.clamp(-1, 1) 326 | return x 327 | 328 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 329 | pred_xstart = process_xstart( 330 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 331 | ) 332 | model_mean = model_output 333 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 334 | if self.model_mean_type == ModelMeanType.START_X: 335 | pred_xstart = process_xstart(model_output) 336 | else: 337 | pred_xstart = process_xstart( 338 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 339 | ) 340 | model_mean, _, _ = self.q_posterior_mean_variance( 341 | x_start=pred_xstart, x_t=x, t=t 342 | ) 343 | else: 344 | raise NotImplementedError(self.model_mean_type) 345 | 346 | assert ( 347 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 348 | ) 349 | return { 350 | "mean": model_mean, 351 | "variance": model_variance, 352 | "log_variance": model_log_variance, 353 | "pred_xstart": pred_xstart, 354 | "model_output": model_output 355 | } 356 | 357 | def _predict_xstart_from_eps(self, x_t, t, eps): 358 | assert x_t.shape == eps.shape 359 | return ( 360 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 361 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 362 | ) 363 | 364 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 365 | assert x_t.shape == xprev.shape 366 | return ( # (xprev - coef2*x_t) / coef1 367 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 368 | - _extract_into_tensor( 369 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 370 | ) 371 | * x_t 372 | ) 373 | 374 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 375 | return ( 376 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, 
t, x_t.shape) * x_t 377 | - pred_xstart 378 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 379 | 380 | def _scale_timesteps(self, t): 381 | if self.rescale_timesteps: 382 | return t.float() * (1000.0 / self.num_timesteps) 383 | return t 384 | 385 | def p_sample( 386 | self, model, x, t, no_noise=False, clip_denoised=True, denoised_fn=None, model_kwargs=None 387 | ): 388 | """ 389 | Sample x_{t-1} from the model at the given timestep. 390 | 391 | :param model: the model to sample from. 392 | :param x: the current tensor at x_{t-1}. 393 | :param t: the value of t, starting at 0 for the first diffusion step. 394 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 395 | :param denoised_fn: if not None, a function which applies to the 396 | x_start prediction before it is used to sample. 397 | :param model_kwargs: if not None, a dict of extra keyword arguments to 398 | pass to the model. This can be used for conditioning. 399 | :return: a dict containing the following keys: 400 | - 'sample': a random sample from the model. 401 | - 'pred_xstart': a prediction of x_0. 402 | """ 403 | out = self.p_mean_variance( 404 | model, 405 | x, 406 | t, 407 | clip_denoised=clip_denoised, 408 | denoised_fn=denoised_fn, 409 | model_kwargs=model_kwargs, 410 | ) 411 | noise = th.randn_like(x) 412 | nonzero_mask = ( 413 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 414 | ) 415 | if no_noise: 416 | sample = out["mean"] 417 | else: 418 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 419 | return {"sample": sample, "pred_xstart": out["pred_xstart"],"model_output":out["model_output"]} 420 | 421 | def p_sample_loop( 422 | self, 423 | model, 424 | shape, 425 | noise=None, 426 | clip_denoised=True, 427 | denoised_fn=None, 428 | model_kwargs=None, 429 | device=None, 430 | progress=False, 431 | ): 432 | """ 433 | Generate samples from the model. 434 | 435 | :param model: the model module. 436 | :param shape: the shape of the samples, (N, C, H, W). 437 | :param noise: if specified, the noise from the encoder to sample. 438 | Should be of the same shape as `shape`. 439 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 440 | :param denoised_fn: if not None, a function which applies to the 441 | x_start prediction before it is used to sample. 442 | :param model_kwargs: if not None, a dict of extra keyword arguments to 443 | pass to the model. This can be used for conditioning. 444 | :param device: if specified, the device to create the samples on. 445 | If not specified, use a model parameter's device. 446 | :param progress: if True, show a tqdm progress bar. 447 | :return: a non-differentiable batch of samples. 448 | """ 449 | final = None 450 | for sample in self.p_sample_loop_progressive( 451 | model, 452 | shape, 453 | noise=noise, 454 | clip_denoised=clip_denoised, 455 | denoised_fn=denoised_fn, 456 | model_kwargs=model_kwargs, 457 | device=device, 458 | progress=progress, 459 | ): 460 | final = sample 461 | return final["sample"] 462 | 463 | def p_sample_loop_progressive( 464 | self, 465 | model, 466 | shape, 467 | noise=None, 468 | clip_denoised=True, 469 | denoised_fn=None, 470 | model_kwargs=None, 471 | device=None, 472 | progress=False, 473 | ): 474 | """ 475 | Generate samples from the model and yield intermediate samples from 476 | each timestep of diffusion. 477 | 478 | Arguments are the same as p_sample_loop(). 
479 |         Returns a generator over dicts, where each dict is the return value of
480 |         p_sample().
481 |         """
482 |         if device is None:
483 |             device = next(model.parameters()).device
484 |         assert isinstance(shape, (tuple, list))
485 |         if noise is not None:
486 |             img = noise
487 |         else:
488 |             img = th.randn(*shape, device=device)
489 |         indices = list(range(self.num_timesteps))[::-1]
490 | 
491 |         if progress:
492 |             # Lazy import so that we don't depend on tqdm.
493 |             from tqdm.auto import tqdm
494 | 
495 |             indices = tqdm(indices)
496 | 
497 |         for i in indices:
498 |             t = th.tensor([i] * shape[0], device=device)
499 |             with th.no_grad():
500 |                 out = self.p_sample(
501 |                     model,
502 |                     img,
503 |                     t,
504 |                     clip_denoised=clip_denoised,
505 |                     denoised_fn=denoised_fn,
506 |                     model_kwargs=model_kwargs,
507 |                 )
508 |                 yield out
509 |                 img = out["sample"]
510 | 
511 |     def p_sample_loop_condition(
512 |         self,
513 |         model,
514 |         shape,
515 |         kspace,
516 |         mask,
517 |         noise=None,
518 |         clip_denoised=True,
519 |         denoised_fn=None,
520 |         model_kwargs=None,
521 |         device=None,
522 |         progress=False,
523 |         refine=False
524 |     ):
525 |         """
526 |         Generate samples from the model, conditioned on observed k-space.
527 | 
528 |         :param model: the model module.
529 |         :param shape: the shape of the samples, (N, C, H, W).
530 |         :param kspace: the observed (undersampled) k-space measurements.
531 |         :param mask: the sampling mask; observed frequencies come from kspace, the rest from the model estimate.
532 |         :param noise: if specified, the starting point for sampling. Should be of the same shape as `shape`.
533 |         :param clip_denoised: if True, clip x_start predictions to [-1, 1].
534 |         :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample.
535 |         :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning.
536 |         :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device.
537 |         :param progress: if True, show a tqdm progress bar.
538 |         :param refine: if True, run only the final 20 refinement steps, without injecting fresh noise.
539 |         :return: a list of non-differentiable samples, one per timestep
540 |                  (the intermediate output of every step, not just the final one).
541 |         """
542 |         final = []
543 |         for sample in self.p_sample_loop_condition_progressive(
544 |             model,
545 |             shape,
546 |             kspace,
547 |             mask,
548 |             noise=noise,
549 |             clip_denoised=clip_denoised,
550 |             denoised_fn=denoised_fn,
551 |             model_kwargs=model_kwargs,
552 |             device=device,
553 |             progress=progress,
554 |             refine=refine
555 |         ):
556 |             final.append(sample)
557 |         return final
558 | 
559 |     def p_sample_loop_condition_progressive(
560 |         self,
561 |         model,
562 |         shape,
563 |         kspace,
564 |         mask,
565 |         noise=None,
566 |         clip_denoised=True,
567 |         denoised_fn=None,
568 |         model_kwargs=None,
569 |         device=None,
570 |         progress=False,
571 |         refine=False
572 |     ):
573 |         """
574 |         Generate samples conditioned on observed k-space and yield the
575 |         intermediate sample from each timestep of diffusion.
576 | 
577 |         Arguments are the same as p_sample_loop_condition().
578 |         Yields the current image estimate after the k-space mixing step of
579 |         each timestep (see mixture_duo()).
580 |         """
581 |         if device is None:
582 |             device = next(model.parameters()).device
583 |         assert isinstance(shape, (tuple, list))
584 |         if noise is not None:
585 |             img = noise
586 |             print('PROVIDED INPUT')
587 |         else:
588 |             img = th.randn(*shape, device=device)
589 |             print('NO INPUT')
590 |         indices = list(range(self.num_timesteps))[::-1]
591 |         if refine:
592 |             indices = indices[-20:]  # refinement is currently fixed at the last 20 steps (see README TODO)
593 | 
594 |         if progress:
595 |             # Lazy import so that we don't depend on tqdm.
596 |             from tqdm.auto import tqdm
597 | 
598 |             indices = tqdm(indices)
599 | 
600 |         for i in indices:
601 |             print('ITER:',i)
602 |             t = th.tensor([i] * shape[0], device=device)
603 |             with th.no_grad():
604 |                 out = self.p_sample(
605 |                     model,
606 |                     img,
607 |                     t,
608 |                     no_noise=refine,
609 |                     clip_denoised=clip_denoised,
610 |                     denoised_fn=denoised_fn,
611 |                     model_kwargs=model_kwargs,
612 |                 )
613 | 
614 | 
615 |             if i != 0 and not refine:
616 |                 kspace_t_minus_1 = self.q_sample_fft_duo(kspace,t-1,out["model_output"])  # noise the observed k-space to level t-1 using the predicted epsilon
617 |                 img = self.mixture_duo(out["sample"], kspace_t_minus_1,mask)
618 |             else:
619 |                 print('replace kspace')
620 |                 kspace_t_minus_1 = kspace  # last step / refinement: mix in the clean observed k-space
621 |                 img = self.mixture_duo(out["sample"], kspace_t_minus_1,mask)
622 |             yield img  # alternatively: out['pred_xstart']
623 | 
624 | 
625 | 
626 |     def mixture_duo(self,img,kspace,mask):  # k-space data consistency for the two-slice representation
627 |         img1 = img[:,[0]]+img[:,[1]]*1j  # channels (real, imag) -> complex slice 1
628 |         img2 = img[:,[2]]+img[:,[3]]*1j  # channels (real, imag) -> complex slice 2
629 |         img1 = th.fft.fft2(img1)
630 |         img2 = th.fft.fft2(img2)
631 |         kspace1 = kspace[:,[0]]+kspace[:,[1]]*1j
632 |         kspace2 = kspace[:,[2]]+kspace[:,[3]]*1j
633 |         out1 = mask[:,[0]]*kspace1+(1-mask[:,[0]])*img1  # keep observed frequencies, fill the rest from the estimate
634 |         out1 = th.fft.ifft2(out1)
635 |         out2 = mask[:,[1]]*kspace2+(1-mask[:,[1]])*img2
636 |         out2 = th.fft.ifft2(out2)
637 |         out = th.cat([out1.real,out1.imag,out2.real,out2.imag],1)
638 |         return out
639 | 
640 | 
641 |     def ddim_sample(
642 |         self,
643 |         model,
644 |         x,
645 |         t,
646 |         clip_denoised=True,
647 |         denoised_fn=None,
648 |         model_kwargs=None,
649 |         eta=0.0,
650 |     ):
651 |         """
652 |         Sample x_{t-1} from the model using DDIM.
653 | 
654 |         Same usage as p_sample().
655 |         """
656 |         out = self.p_mean_variance(
657 |             model,
658 |             x,
659 |             t,
660 |             clip_denoised=clip_denoised,
661 |             denoised_fn=denoised_fn,
662 |             model_kwargs=model_kwargs,
663 |         )
664 |         # Usually our model outputs epsilon, but we re-derive it
665 |         # in case we used x_start or x_prev prediction.
666 |         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
667 |         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
668 |         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
669 |         sigma = (
670 |             eta
671 |             * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
672 |             * th.sqrt(1 - alpha_bar / alpha_bar_prev)
673 |         )
674 |         # Equation 12.
675 |         noise = th.randn_like(x)
676 |         mean_pred = (
677 |             out["pred_xstart"] * th.sqrt(alpha_bar_prev)
678 |             + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
679 |         )
680 |         nonzero_mask = (
681 |             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
682 |         )  # no noise when t == 0
683 |         sample = mean_pred + nonzero_mask * sigma * noise
684 |         return {"sample": sample, "pred_xstart": out["pred_xstart"],"model_output":out["model_output"]}
685 | 
686 |     def ddim_reverse_sample(
687 |         self,
688 |         model,
689 |         x,
690 |         t,
691 |         clip_denoised=True,
692 |         denoised_fn=None,
693 |         model_kwargs=None,
694 |         eta=0.0,
695 |     ):
696 |         """
697 |         Sample x_{t+1} from the model using DDIM reverse ODE.
698 |         """
699 |         assert eta == 0.0, "Reverse ODE only for deterministic path"
700 |         out = self.p_mean_variance(
701 |             model,
702 |             x,
703 |             t,
704 |             clip_denoised=clip_denoised,
705 |             denoised_fn=denoised_fn,
706 |             model_kwargs=model_kwargs,
707 |         )
708 |         # Usually our model outputs epsilon, but we re-derive it
709 |         # in case we used x_start or x_prev prediction.
710 | eps = ( 711 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 712 | - out["pred_xstart"] 713 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 714 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 715 | 716 | # Equation 12. reversed 717 | mean_pred = ( 718 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 719 | + th.sqrt(1 - alpha_bar_next) * eps 720 | ) 721 | 722 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 723 | 724 | def ddim_sample_loop( 725 | self, 726 | model, 727 | shape, 728 | noise=None, 729 | clip_denoised=True, 730 | denoised_fn=None, 731 | model_kwargs=None, 732 | device=None, 733 | progress=False, 734 | eta=0.0, 735 | ): 736 | """ 737 | Generate samples from the model using DDIM. 738 | 739 | Same usage as p_sample_loop(). 740 | """ 741 | final = None 742 | for sample in self.ddim_sample_loop_progressive( 743 | model, 744 | shape, 745 | noise=noise, 746 | clip_denoised=clip_denoised, 747 | denoised_fn=denoised_fn, 748 | model_kwargs=model_kwargs, 749 | device=device, 750 | progress=progress, 751 | eta=eta, 752 | ): 753 | final = sample 754 | return final["sample"] 755 | 756 | def ddim_sample_loop_condition( 757 | self, 758 | model, 759 | shape, 760 | kspace, 761 | mask, 762 | noise=None, 763 | clip_denoised=True, 764 | denoised_fn=None, 765 | model_kwargs=None, 766 | device=None, 767 | progress=False, 768 | eta=0.0, 769 | ): 770 | """ 771 | Generate samples from the model using DDIM. 772 | 773 | Same usage as p_sample_loop(). 774 | """ 775 | final = None 776 | final = [] 777 | for sample in self.ddim_sample_loop_condition_progressive( 778 | model, 779 | shape, 780 | kspace, 781 | mask, 782 | noise=noise, 783 | clip_denoised=clip_denoised, 784 | denoised_fn=denoised_fn, 785 | model_kwargs=model_kwargs, 786 | device=device, 787 | progress=progress, 788 | eta=eta, 789 | ): 790 | final.append(sample) 791 | return final 792 | 793 | 794 | def ddim_sample_loop_progressive( 795 | self, 796 | model, 797 | shape, 798 | noise=None, 799 | clip_denoised=True, 800 | denoised_fn=None, 801 | model_kwargs=None, 802 | device=None, 803 | progress=False, 804 | eta=0.0, 805 | ): 806 | """ 807 | Use DDIM to sample from the model and yield intermediate samples from 808 | each timestep of DDIM. 809 | 810 | Same usage as p_sample_loop_progressive(). 811 | """ 812 | if device is None: 813 | device = next(model.parameters()).device 814 | assert isinstance(shape, (tuple, list)) 815 | if noise is not None: 816 | img = noise 817 | else: 818 | img = th.randn(*shape, device=device) 819 | indices = list(range(self.num_timesteps))[::-1] 820 | 821 | if progress: 822 | # Lazy import so that we don't depend on tqdm. 823 | from tqdm.auto import tqdm 824 | 825 | indices = tqdm(indices) 826 | 827 | for i in indices: 828 | t = th.tensor([i] * shape[0], device=device) 829 | with th.no_grad(): 830 | out = self.ddim_sample( 831 | model, 832 | img, 833 | t, 834 | clip_denoised=clip_denoised, 835 | denoised_fn=denoised_fn, 836 | model_kwargs=model_kwargs, 837 | eta=eta, 838 | ) 839 | yield out 840 | img = out["sample"] 841 | 842 | 843 | def ddim_sample_loop_condition_progressive( 844 | self, 845 | model, 846 | shape, 847 | kspace, 848 | mask, 849 | noise=None, 850 | clip_denoised=True, 851 | denoised_fn=None, 852 | model_kwargs=None, 853 | device=None, 854 | progress=False, 855 | eta=0.0, 856 | ): 857 | """ 858 | Use DDIM to sample from the model and yield intermediate samples from 859 | each timestep of DDIM. 
860 | 861 | Same usage as p_sample_loop_progressive(). 862 | """ 863 | if device is None: 864 | device = next(model.parameters()).device 865 | assert isinstance(shape, (tuple, list)) 866 | if noise is not None: 867 | img = noise 868 | else: 869 | img = th.randn(*shape, device=device) 870 | indices = list(range(self.num_timesteps))[::-1] 871 | 872 | if progress: 873 | # Lazy import so that we don't depend on tqdm. 874 | from tqdm.auto import tqdm 875 | 876 | indices = tqdm(indices) 877 | 878 | for i in indices: 879 | print('ITER:',i) 880 | t = th.tensor([i] * shape[0], device=device) 881 | with th.no_grad(): 882 | out = self.ddim_sample( 883 | model, 884 | img, 885 | t, 886 | clip_denoised=clip_denoised, 887 | denoised_fn=denoised_fn, 888 | model_kwargs=model_kwargs, 889 | eta=eta, 890 | ) 891 | # if i != 0: 892 | # kspace_t_minus_1 = self.q_sample_fft(kspace,t-1,out["model_output"]) 893 | # else: 894 | kspace_t_minus_1 = kspace 895 | out = self.mixture(out["sample"], kspace_t_minus_1,mask) 896 | yield out 897 | img = out 898 | 899 | def _vb_terms_bpd( 900 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 901 | ): 902 | """ 903 | Get a term for the variational lower-bound. 904 | 905 | The resulting units are bits (rather than nats, as one might expect). 906 | This allows for comparison to other papers. 907 | 908 | :return: a dict with the following keys: 909 | - 'output': a shape [N] tensor of NLLs or KLs. 910 | - 'pred_xstart': the x_0 predictions. 911 | """ 912 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 913 | x_start=x_start, x_t=x_t, t=t 914 | ) 915 | out = self.p_mean_variance( 916 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 917 | ) 918 | kl = normal_kl( 919 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 920 | ) 921 | kl = mean_flat(kl) / np.log(2.0) 922 | 923 | decoder_nll = -discretized_gaussian_log_likelihood( 924 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 925 | ) 926 | assert decoder_nll.shape == x_start.shape 927 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 928 | 929 | # At the first timestep return the decoder NLL, 930 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 931 | output = th.where((t == 0), decoder_nll, kl) 932 | return {"output": output, "pred_xstart": out["pred_xstart"]} 933 | 934 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 935 | """ 936 | Compute training losses for a single timestep. 937 | 938 | :param model: the model to evaluate loss on. 939 | :param x_start: the [N x C x ...] tensor of inputs. 940 | :param t: a batch of timestep indices. 941 | :param model_kwargs: if not None, a dict of extra keyword arguments to 942 | pass to the model. This can be used for conditioning. 943 | :param noise: if specified, the specific Gaussian noise to try to remove. 944 | :return: a dict with the key "loss" containing a tensor of shape [N]. 945 | Some mean or variance settings may also have other keys. 
946 | """ 947 | if model_kwargs is None: 948 | model_kwargs = {} 949 | if noise is None: 950 | noise = th.randn_like(x_start) 951 | x_t = self.q_sample(x_start, t, noise=noise) 952 | 953 | terms = {} 954 | 955 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 956 | terms["loss"] = self._vb_terms_bpd( 957 | model=model, 958 | x_start=x_start, 959 | x_t=x_t, 960 | t=t, 961 | clip_denoised=False, 962 | model_kwargs=model_kwargs, 963 | )["output"] 964 | if self.loss_type == LossType.RESCALED_KL: 965 | terms["loss"] *= self.num_timesteps 966 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 967 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 968 | 969 | if self.model_var_type in [ 970 | ModelVarType.LEARNED, 971 | ModelVarType.LEARNED_RANGE, 972 | ]: 973 | B, C = x_t.shape[:2] 974 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 975 | model_output, model_var_values = th.split(model_output, C, dim=1) 976 | # Learn the variance using the variational bound, but don't let 977 | # it affect our mean prediction. 978 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 979 | terms["vb"] = self._vb_terms_bpd( 980 | model=lambda *args, r=frozen_out: r, 981 | x_start=x_start, 982 | x_t=x_t, 983 | t=t, 984 | clip_denoised=False, 985 | )["output"] 986 | if self.loss_type == LossType.RESCALED_MSE: 987 | # Divide by 1000 for equivalence with initial implementation. 988 | # Without a factor of 1/1000, the VB term hurts the MSE term. 989 | terms["vb"] *= self.num_timesteps / 1000.0 990 | 991 | target = { 992 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 993 | x_start=x_start, x_t=x_t, t=t 994 | )[0], 995 | ModelMeanType.START_X: x_start, 996 | ModelMeanType.EPSILON: noise, 997 | }[self.model_mean_type] 998 | assert model_output.shape == target.shape == x_start.shape 999 | terms["mse"] = mean_flat((target - model_output) ** 2) 1000 | if "vb" in terms: 1001 | terms["loss"] = terms["mse"] + terms["vb"] 1002 | else: 1003 | terms["loss"] = terms["mse"] 1004 | else: 1005 | raise NotImplementedError(self.loss_type) 1006 | 1007 | return terms 1008 | 1009 | def _prior_bpd(self, x_start): 1010 | """ 1011 | Get the prior KL term for the variational lower-bound, measured in 1012 | bits-per-dim. 1013 | 1014 | This term can't be optimized, as it only depends on the encoder. 1015 | 1016 | :param x_start: the [N x C x ...] tensor of inputs. 1017 | :return: a batch of [N] KL values (in bits), one per batch element. 1018 | """ 1019 | batch_size = x_start.shape[0] 1020 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 1021 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 1022 | kl_prior = normal_kl( 1023 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 1024 | ) 1025 | return mean_flat(kl_prior) / np.log(2.0) 1026 | 1027 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 1028 | """ 1029 | Compute the entire variational lower-bound, measured in bits-per-dim, 1030 | as well as other related quantities. 1031 | 1032 | :param model: the model to evaluate loss on. 1033 | :param x_start: the [N x C x ...] tensor of inputs. 1034 | :param clip_denoised: if True, clip denoised samples. 1035 | :param model_kwargs: if not None, a dict of extra keyword arguments to 1036 | pass to the model. This can be used for conditioning. 
1037 | 1038 | :return: a dict containing the following keys: 1039 | - total_bpd: the total variational lower-bound, per batch element. 1040 | - prior_bpd: the prior term in the lower-bound. 1041 | - vb: an [N x T] tensor of terms in the lower-bound. 1042 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 1043 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 1044 | """ 1045 | device = x_start.device 1046 | batch_size = x_start.shape[0] 1047 | 1048 | vb = [] 1049 | xstart_mse = [] 1050 | mse = [] 1051 | for t in list(range(self.num_timesteps))[::-1]: 1052 | t_batch = th.tensor([t] * batch_size, device=device) 1053 | noise = th.randn_like(x_start) 1054 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 1055 | # Calculate VLB term at the current timestep 1056 | with th.no_grad(): 1057 | out = self._vb_terms_bpd( 1058 | model, 1059 | x_start=x_start, 1060 | x_t=x_t, 1061 | t=t_batch, 1062 | clip_denoised=clip_denoised, 1063 | model_kwargs=model_kwargs, 1064 | ) 1065 | vb.append(out["output"]) 1066 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 1067 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 1068 | mse.append(mean_flat((eps - noise) ** 2)) 1069 | 1070 | vb = th.stack(vb, dim=1) 1071 | xstart_mse = th.stack(xstart_mse, dim=1) 1072 | mse = th.stack(mse, dim=1) 1073 | 1074 | prior_bpd = self._prior_bpd(x_start) 1075 | total_bpd = vb.sum(dim=1) + prior_bpd 1076 | return { 1077 | "total_bpd": total_bpd, 1078 | "prior_bpd": prior_bpd, 1079 | "vb": vb, 1080 | "xstart_mse": xstart_mse, 1081 | "mse": mse, 1082 | } 1083 | 1084 | 1085 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 1086 | """ 1087 | Extract values from a 1-D numpy array for a batch of indices. 1088 | 1089 | :param arr: the 1-D numpy array. 1090 | :param timesteps: a tensor of indices into the array to extract. 1091 | :param broadcast_shape: a larger shape of K dimensions with the batch 1092 | dimension equal to the length of timesteps. 1093 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
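`_extract_into_tensor` (whose body continues below) simply gathers one per-timestep scalar per batch element and right-pads singleton dimensions so the result broadcasts:

```
import numpy as np
import torch as th
from improved_diffusion.gaussian_diffusion_duo import _extract_into_tensor

arr = np.linspace(0.1, 1.0, 10)            # one value per timestep
t = th.tensor([0, 4, 9])
res = _extract_into_tensor(arr, t, (3, 2, 8, 8))
print(res.shape)                           # torch.Size([3, 2, 8, 8])
print(res[:, 0, 0, 0])                     # tensor([0.1000, 0.5000, 1.0000])
```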
1094 | """ 1095 | # print(arr,timesteps) 1096 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 1097 | while len(res.shape) < len(broadcast_shape): 1098 | res = res[..., None] 1099 | return res.expand(broadcast_shape) 1100 | -------------------------------------------------------------------------------- /improved_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
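Typical use of the logging API above: configure once, accumulate key/value diagnostics during an iteration, then flush them to every configured backend:

```
from improved_diffusion import logger

logger.configure(dir="logs/run1")     # default formats: stdout, log, csv
for step in range(100):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 0.1)    # running average if logged repeatedly
    if step % 10 == 0:
        logger.dumpkvs()              # writes one row of diagnostics
```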
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | 447 | if dir is None: 448 | dir = osp.join( 449 | tempfile.gettempdir(), 450 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 451 | ) 452 | assert isinstance(dir, str) 453 | dir = os.path.expanduser(dir) 454 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 455 | 456 | rank = get_rank_without_mpi_import() 457 | if rank > 0: 458 | log_suffix = log_suffix + "-rank%03i" % rank 459 | 460 | if format_strs is None: 461 | if rank == 0: 462 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 463 | else: 464 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 465 | format_strs = filter(None, format_strs) 466 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 467 | 468 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 469 | if output_formats: 470 | log("Logging to %s" % dir) 471 | 472 | 473 | def _configure_default_logger(): 474 | configure() 475 | Logger.DEFAULT = Logger.CURRENT 476 | 477 | 478 | def reset(): 479 | if Logger.CURRENT is not Logger.DEFAULT: 480 | Logger.CURRENT.close() 481 | Logger.CURRENT = Logger.DEFAULT 482 | log("Reset logger") 483 | 484 | 485 | @contextmanager 486 | def scoped_configure(dir=None, format_strs=None, comm=None): 487 | prevlogger = Logger.CURRENT 488 | configure(dir=dir, format_strs=format_strs, comm=comm) 489 | try: 490 | yield 491 | finally: 492 | Logger.CURRENT.close() 493 | Logger.CURRENT = prevlogger 494 | 495 | -------------------------------------------------------------------------------- /improved_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
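A quick sanity check for `normal_kl`, whose body follows: identical Gaussians give zero divergence, and scalar arguments broadcast against tensors.

```
import torch as th
from improved_diffusion.losses import normal_kl

mean, logvar = th.zeros(4), th.zeros(4)
print(normal_kl(mean, logvar, mean, logvar))  # tensor([0., 0., 0., 0.])
print(normal_kl(mean, logvar, 1.0, 0.0))      # 0.5 per element: 0.5*(mean diff)^2
```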
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /improved_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /improved_diffusion/nn_complex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.cfloat()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs).to(th.cfloat) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs).to(th.cfloat) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs).to(th.cfloat) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs).to(th.cfloat) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 
71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels).to(th.cfloat) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].cfloat() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /improved_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 
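The sampling contract above is easiest to see with `UniformSampler`, defined earlier in `resample.py`: the returned weights are 1/(len(p) * p[t]), so a uniform distribution yields unit weights and the weighted loss remains an unbiased estimate of the uniform objective. A sketch, assuming `diffusion` has been built elsewhere:

```
import torch as th
from improved_diffusion.resample import UniformSampler

sampler = UniformSampler(diffusion)
t, w = sampler.sample(8, th.device("cpu"))
print(t)   # 8 uniformly sampled timestep indices
print(w)   # tensor of ones: uniform sampling needs no reweighting
```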
82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /improved_diffusion/respace_duo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion_duo import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 
12 | 13 | For example, if there are 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {desired_count} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process.
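Concretely, `space_timesteps` behaves as follows; the last call matches the README's `--timestep_respacing 100` with 4000 training steps:

```
from improved_diffusion.respace_duo import space_timesteps

print(sorted(space_timesteps(300, [10, 15, 20]))[:5])  # [0, 11, 22, 33, 44]
print(sorted(space_timesteps(100, "ddim25")))          # {0, 4, 8, ..., 96}
print(len(space_timesteps(4000, "100")))               # 100 retained steps
```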
70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel( 102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 103 | ) 104 | 105 | def _scale_timesteps(self, t): 106 | # Scaling is done by the wrapped model. 107 | return t 108 | 109 | 110 | class _WrappedModel: 111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 112 | self.model = model 113 | self.timestep_map = timestep_map 114 | self.rescale_timesteps = rescale_timesteps 115 | self.original_num_steps = original_num_steps 116 | 117 | def __call__(self, x, ts, **kwargs): 118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 119 | new_ts = map_tensor[ts] 120 | if self.rescale_timesteps: 121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 122 | return self.model(x, new_ts, **kwargs) 123 | -------------------------------------------------------------------------------- /improved_diffusion/script_util_duo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion_duo as gd 5 | from .respace_duo import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel 7 | 8 | 9 | 10 | def model_and_diffusion_defaults(): 11 | """ 12 | Defaults for image training. 
13 | """ 14 | return dict( 15 | image_size=64, 16 | num_channels=128, 17 | num_res_blocks=2, 18 | num_heads=4, 19 | num_heads_upsample=-1, 20 | attention_resolutions="16,8", 21 | dropout=0.0, 22 | learn_sigma=False, 23 | sigma_small=False, 24 | class_cond=False, 25 | diffusion_steps=1000, 26 | noise_schedule="linear", 27 | timestep_respacing="", 28 | use_kl=False, 29 | predict_xstart=False, 30 | rescale_timesteps=True, 31 | rescale_learned_sigmas=True, 32 | use_checkpoint=False, 33 | use_scale_shift_norm=True, 34 | ) 35 | 36 | 37 | def create_model_and_diffusion( 38 | image_size, 39 | class_cond, 40 | learn_sigma, 41 | sigma_small, 42 | num_channels, 43 | num_res_blocks, 44 | num_heads, 45 | num_heads_upsample, 46 | attention_resolutions, 47 | dropout, 48 | diffusion_steps, 49 | noise_schedule, 50 | timestep_respacing, 51 | use_kl, 52 | predict_xstart, 53 | rescale_timesteps, 54 | rescale_learned_sigmas, 55 | use_checkpoint, 56 | use_scale_shift_norm, 57 | ): 58 | model = create_model( 59 | image_size, 60 | num_channels, 61 | num_res_blocks, 62 | learn_sigma=learn_sigma, 63 | class_cond=class_cond, 64 | use_checkpoint=use_checkpoint, 65 | attention_resolutions=attention_resolutions, 66 | num_heads=num_heads, 67 | num_heads_upsample=num_heads_upsample, 68 | use_scale_shift_norm=use_scale_shift_norm, 69 | dropout=dropout, 70 | ) 71 | diffusion = create_gaussian_diffusion( 72 | steps=diffusion_steps, 73 | learn_sigma=learn_sigma, 74 | sigma_small=sigma_small, 75 | noise_schedule=noise_schedule, 76 | use_kl=use_kl, 77 | predict_xstart=predict_xstart, 78 | rescale_timesteps=rescale_timesteps, 79 | rescale_learned_sigmas=rescale_learned_sigmas, 80 | timestep_respacing=timestep_respacing, 81 | ) 82 | return model, diffusion 83 | 84 | def create_model_and_two_diffusion( 85 | image_size, 86 | class_cond, 87 | learn_sigma, 88 | sigma_small, 89 | num_channels, 90 | num_res_blocks, 91 | num_heads, 92 | num_heads_upsample, 93 | attention_resolutions, 94 | dropout, 95 | diffusion_steps, 96 | noise_schedule, 97 | timestep_respacing, 98 | use_kl, 99 | predict_xstart, 100 | rescale_timesteps, 101 | rescale_learned_sigmas, 102 | use_checkpoint, 103 | use_scale_shift_norm, 104 | ): 105 | model = create_model( 106 | image_size, 107 | num_channels, 108 | num_res_blocks, 109 | learn_sigma=learn_sigma, 110 | class_cond=class_cond, 111 | use_checkpoint=use_checkpoint, 112 | attention_resolutions=attention_resolutions, 113 | num_heads=num_heads, 114 | num_heads_upsample=num_heads_upsample, 115 | use_scale_shift_norm=use_scale_shift_norm, 116 | dropout=dropout, 117 | ) 118 | diffusion = create_gaussian_diffusion( 119 | steps=diffusion_steps, 120 | learn_sigma=learn_sigma, 121 | sigma_small=sigma_small, 122 | noise_schedule=noise_schedule, 123 | use_kl=use_kl, 124 | predict_xstart=predict_xstart, 125 | rescale_timesteps=rescale_timesteps, 126 | rescale_learned_sigmas=rescale_learned_sigmas, 127 | timestep_respacing=timestep_respacing, 128 | ) 129 | diffusion_two = create_gaussian_diffusion( 130 | steps=diffusion_steps, 131 | learn_sigma=learn_sigma, 132 | sigma_small=sigma_small, 133 | noise_schedule=noise_schedule, 134 | use_kl=use_kl, 135 | predict_xstart=predict_xstart, 136 | rescale_timesteps=rescale_timesteps, 137 | rescale_learned_sigmas=rescale_learned_sigmas, 138 | timestep_respacing='500', 139 | ) 140 | return model, diffusion, diffusion_two 141 | 142 | def create_model( 143 | image_size, 144 | num_channels, 145 | num_res_blocks, 146 | learn_sigma, 147 | class_cond, 148 | use_checkpoint, 149 | 
attention_resolutions, 150 | num_heads, 151 | num_heads_upsample, 152 | use_scale_shift_norm, 153 | dropout, 154 | ): 155 | if image_size == 256: 156 | channel_mult = (1, 1, 2, 2, 4, 4) 157 | elif image_size == 320: 158 | channel_mult = (1, 1, 2, 2, 4, 4) 159 | elif image_size == 64: 160 | channel_mult = (1, 2, 3, 4) 161 | elif image_size == 32: 162 | channel_mult = (1, 2, 2, 2) 163 | else: 164 | raise ValueError(f"unsupported image size: {image_size}") 165 | 166 | attention_ds = [] 167 | for res in attention_resolutions.split(","): 168 | attention_ds.append(image_size // int(res)) 169 | 170 | return UNetModel( 171 | in_channels=4, 172 | model_channels=num_channels, 173 | out_channels=(4 if not learn_sigma else 8), 174 | num_res_blocks=num_res_blocks, 175 | attention_resolutions=tuple(attention_ds), 176 | dropout=dropout, 177 | channel_mult=channel_mult, 178 | num_classes=(NUM_CLASSES if class_cond else None), 179 | use_checkpoint=use_checkpoint, 180 | num_heads=num_heads, 181 | num_heads_upsample=num_heads_upsample, 182 | use_scale_shift_norm=use_scale_shift_norm, 183 | ) 184 | 185 | def sr_model_and_diffusion_defaults(): 186 | res = model_and_diffusion_defaults() 187 | res["large_size"] = 256 188 | res["small_size"] = 64 189 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 190 | for k in res.copy().keys(): 191 | if k not in arg_names: 192 | del res[k] 193 | return res 194 | 195 | 196 | def sr_create_model_and_diffusion( 197 | large_size, 198 | small_size, 199 | class_cond, 200 | learn_sigma, 201 | num_channels, 202 | num_res_blocks, 203 | num_heads, 204 | num_heads_upsample, 205 | attention_resolutions, 206 | dropout, 207 | diffusion_steps, 208 | noise_schedule, 209 | timestep_respacing, 210 | use_kl, 211 | predict_xstart, 212 | rescale_timesteps, 213 | rescale_learned_sigmas, 214 | use_checkpoint, 215 | use_scale_shift_norm, 216 | ): 217 | model = sr_create_model( 218 | large_size, 219 | small_size, 220 | num_channels, 221 | num_res_blocks, 222 | learn_sigma=learn_sigma, 223 | class_cond=class_cond, 224 | use_checkpoint=use_checkpoint, 225 | attention_resolutions=attention_resolutions, 226 | num_heads=num_heads, 227 | num_heads_upsample=num_heads_upsample, 228 | use_scale_shift_norm=use_scale_shift_norm, 229 | dropout=dropout, 230 | ) 231 | diffusion = create_gaussian_diffusion( 232 | steps=diffusion_steps, 233 | learn_sigma=learn_sigma, 234 | noise_schedule=noise_schedule, 235 | use_kl=use_kl, 236 | predict_xstart=predict_xstart, 237 | rescale_timesteps=rescale_timesteps, 238 | rescale_learned_sigmas=rescale_learned_sigmas, 239 | timestep_respacing=timestep_respacing, 240 | ) 241 | return model, diffusion 242 | 243 | 244 | 245 | def create_gaussian_diffusion( 246 | *, 247 | steps=1000, 248 | learn_sigma=False, 249 | sigma_small=False, 250 | noise_schedule="linear", 251 | use_kl=False, 252 | predict_xstart=False, 253 | rescale_timesteps=False, 254 | rescale_learned_sigmas=False, 255 | timestep_respacing="", 256 | ): 257 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 258 | if use_kl: 259 | loss_type = gd.LossType.RESCALED_KL 260 | elif rescale_learned_sigmas: 261 | loss_type = gd.LossType.RESCALED_MSE 262 | else: 263 | loss_type = gd.LossType.MSE 264 | if not timestep_respacing: 265 | timestep_respacing = [steps] 266 | return SpacedDiffusion( 267 | use_timesteps=space_timesteps(steps, timestep_respacing), 268 | betas=betas, 269 | model_mean_type=( 270 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 271 | ), 272 | 
model_var_type=( 273 | ( 274 | gd.ModelVarType.FIXED_LARGE 275 | if not sigma_small 276 | else gd.ModelVarType.FIXED_SMALL 277 | ) 278 | if not learn_sigma 279 | else gd.ModelVarType.LEARNED_RANGE 280 | ), 281 | loss_type=loss_type, 282 | rescale_timesteps=rescale_timesteps, 283 | ) 284 | 285 | 286 | def add_dict_to_argparser(parser, default_dict): 287 | for k, v in default_dict.items(): 288 | v_type = type(v) 289 | if v is None: 290 | v_type = str 291 | elif isinstance(v, bool): 292 | v_type = str2bool 293 | parser.add_argument(f"--{k}", default=v, type=v_type) 294 | 295 | 296 | def args_to_dict(args, keys): 297 | return {k: getattr(args, k) for k in keys} 298 | 299 | 300 | def str2bool(v): 301 | """ 302 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 303 | """ 304 | if isinstance(v, bool): 305 | return v 306 | if v.lower() in ("yes", "true", "t", "y", "1"): 307 | return True 308 | elif v.lower() in ("no", "false", "f", "n", "0"): 309 | return False 310 | else: 311 | raise argparse.ArgumentTypeError("boolean value expected") 312 | -------------------------------------------------------------------------------- /improved_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 10 | from torch.optim import AdamW 11 | 12 | from . import dist_util, logger 13 | from .fp16_util import ( 14 | make_master_params, 15 | master_params_to_model_params, 16 | model_grads_to_master_grads, 17 | unflatten_master_params, 18 | zero_grad, 19 | ) 20 | from .nn import update_ema 21 | from .resample import LossAwareSampler, UniformSampler 22 | 23 | # For ImageNet experiments, this was a good default value. 24 | # We found that the lg_loss_scale quickly climbed to 25 | # 20-21 within the first ~1K steps of training. 
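The two helpers above are what the entry-point scripts use to expose the defaults as CLI flags and convert parsed args back into constructor kwargs; a sketch of the usual pattern:

```
import argparse
from improved_diffusion.script_util_duo import (
    add_dict_to_argparser,
    args_to_dict,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)

defaults = model_and_diffusion_defaults()
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)   # one --flag per default entry
args = parser.parse_args(["--image_size", "320", "--learn_sigma", "False"])
model, diffusion = create_model_and_diffusion(
    **args_to_dict(args, defaults.keys())
)
```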
26 | INITIAL_LOG_LOSS_SCALE = 20.0 27 | 28 | 29 | class TrainLoop: 30 | def __init__( 31 | self, 32 | *, 33 | model, 34 | diffusion, 35 | data, 36 | batch_size, 37 | microbatch, 38 | lr, 39 | ema_rate, 40 | log_interval, 41 | save_interval, 42 | resume_checkpoint, 43 | use_fp16=False, 44 | fp16_scale_growth=1e-3, 45 | schedule_sampler=None, 46 | weight_decay=0.0, 47 | lr_anneal_steps=0, 48 | ): 49 | self.model = model 50 | self.diffusion = diffusion 51 | self.data = data 52 | self.batch_size = batch_size 53 | self.microbatch = microbatch if microbatch > 0 else batch_size 54 | self.lr = lr 55 | self.ema_rate = ( 56 | [ema_rate] 57 | if isinstance(ema_rate, float) 58 | else [float(x) for x in ema_rate.split(",")] 59 | ) 60 | self.log_interval = log_interval 61 | self.save_interval = save_interval 62 | self.resume_checkpoint = resume_checkpoint 63 | self.use_fp16 = use_fp16 64 | self.fp16_scale_growth = fp16_scale_growth 65 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 66 | self.weight_decay = weight_decay 67 | self.lr_anneal_steps = lr_anneal_steps 68 | 69 | self.step = 0 70 | self.resume_step = 0 71 | self.global_batch = self.batch_size * dist.get_world_size() 72 | 73 | self.model_params = list(self.model.parameters()) 74 | self.master_params = self.model_params 75 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 76 | self.sync_cuda = th.cuda.is_available() 77 | 78 | self._load_and_sync_parameters() 79 | if self.use_fp16: 80 | self._setup_fp16() 81 | 82 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 83 | if self.resume_step: 84 | self._load_optimizer_state() 85 | # Model was resumed, either due to a restart or a checkpoint 86 | # being specified at the command line. 87 | self.ema_params = [ 88 | self._load_ema_parameters(rate) for rate in self.ema_rate 89 | ] 90 | else: 91 | self.ema_params = [ 92 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 93 | ] 94 | 95 | if th.cuda.is_available(): 96 | self.use_ddp = True 97 | self.ddp_model = DDP( 98 | self.model, 99 | device_ids=[dist_util.dev()], 100 | output_device=dist_util.dev(), 101 | broadcast_buffers=False, 102 | bucket_cap_mb=128, 103 | find_unused_parameters=False, 104 | ) 105 | else: 106 | if dist.get_world_size() > 1: 107 | logger.warn( 108 | "Distributed training requires CUDA. " 109 | "Gradients will not be synchronized properly!" 
110 | ) 111 | self.use_ddp = False 112 | self.ddp_model = self.model 113 | 114 | def _load_and_sync_parameters(self): 115 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 116 | 117 | if resume_checkpoint: 118 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 119 | if dist.get_rank() == 0: 120 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 121 | self.model.load_state_dict( 122 | dist_util.load_state_dict( 123 | resume_checkpoint, map_location=dist_util.dev() 124 | ) 125 | ) 126 | 127 | dist_util.sync_params(self.model.parameters()) 128 | 129 | def _load_ema_parameters(self, rate): 130 | ema_params = copy.deepcopy(self.master_params) 131 | 132 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 133 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 134 | if ema_checkpoint: 135 | if dist.get_rank() == 0: 136 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 137 | state_dict = dist_util.load_state_dict( 138 | ema_checkpoint, map_location=dist_util.dev() 139 | ) 140 | ema_params = self._state_dict_to_master_params(state_dict) 141 | 142 | dist_util.sync_params(ema_params) 143 | return ema_params 144 | 145 | def _load_optimizer_state(self): 146 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 147 | opt_checkpoint = bf.join( 148 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 149 | ) 150 | if bf.exists(opt_checkpoint): 151 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 152 | state_dict = dist_util.load_state_dict( 153 | opt_checkpoint, map_location=dist_util.dev() 154 | ) 155 | self.opt.load_state_dict(state_dict) 156 | 157 | def _setup_fp16(self): 158 | self.master_params = make_master_params(self.model_params) 159 | self.model.convert_to_fp16() 160 | 161 | def run_loop(self): 162 | while ( 163 | not self.lr_anneal_steps 164 | or self.step + self.resume_step < self.lr_anneal_steps 165 | ): 166 | batch, cond = next(self.data) 167 | self.run_step(batch, cond) 168 | if self.step % self.log_interval == 0: 169 | logger.dumpkvs() 170 | if self.step % self.save_interval == 0: 171 | self.save() 172 | # Run for a finite amount of time in integration tests. 173 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 174 | return 175 | self.step += 1 176 | # Save the last checkpoint if it wasn't already saved. 
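Driving the loop takes one call once the pieces exist; a sketch, assuming distributed state has been initialized first (e.g. via `dist_util.setup_dist()`, as the training script does) and `data` yields `(batch, cond_dict)` pairs:

```
from improved_diffusion.train_util import TrainLoop

TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=1,
    microbatch=-1,          # <= 0 disables microbatching
    lr=1e-4,
    ema_rate="0.9999",
    log_interval=10,
    save_interval=10000,
    resume_checkpoint="",
).run_loop()
```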
177 | if (self.step - 1) % self.save_interval != 0: 178 | self.save() 179 | 180 | def run_step(self, batch, cond): 181 | self.forward_backward(batch, cond) 182 | if self.use_fp16: 183 | self.optimize_fp16() 184 | else: 185 | self.optimize_normal() 186 | self.log_step() 187 | 188 | def forward_backward(self, batch, cond): 189 | zero_grad(self.model_params) 190 | for i in range(0, batch.shape[0], self.microbatch): 191 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 192 | micro_cond = { 193 | k: v[i : i + self.microbatch].to(dist_util.dev()) 194 | for k, v in cond.items() 195 | } 196 | last_batch = (i + self.microbatch) >= batch.shape[0] 197 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 198 | 199 | compute_losses = functools.partial( 200 | self.diffusion.training_losses, 201 | self.ddp_model, 202 | micro, 203 | t, 204 | model_kwargs=micro_cond, 205 | ) 206 | 207 | if last_batch or not self.use_ddp: 208 | losses = compute_losses() 209 | else: 210 | with self.ddp_model.no_sync(): 211 | losses = compute_losses() 212 | 213 | if isinstance(self.schedule_sampler, LossAwareSampler): 214 | self.schedule_sampler.update_with_local_losses( 215 | t, losses["loss"].detach() 216 | ) 217 | 218 | loss = (losses["loss"] * weights).mean() 219 | log_loss_dict( 220 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 221 | ) 222 | if self.use_fp16: 223 | loss_scale = 2 ** self.lg_loss_scale 224 | (loss * loss_scale).backward() 225 | else: 226 | loss.backward() 227 | 228 | def optimize_fp16(self): 229 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 230 | self.lg_loss_scale -= 1 231 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 232 | return 233 | 234 | model_grads_to_master_grads(self.model_params, self.master_params) 235 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 236 | self._log_grad_norm() 237 | self._anneal_lr() 238 | self.opt.step() 239 | for rate, params in zip(self.ema_rate, self.ema_params): 240 | update_ema(params, self.master_params, rate=rate) 241 | master_params_to_model_params(self.model_params, self.master_params) 242 | self.lg_loss_scale += self.fp16_scale_growth 243 | 244 | def optimize_normal(self): 245 | self._log_grad_norm() 246 | self._anneal_lr() 247 | self.opt.step() 248 | for rate, params in zip(self.ema_rate, self.ema_params): 249 | update_ema(params, self.master_params, rate=rate) 250 | 251 | def _log_grad_norm(self): 252 | sqsum = 0.0 253 | for p in self.master_params: 254 | sqsum += (p.grad ** 2).sum().item() 255 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 256 | 257 | def _anneal_lr(self): 258 | if not self.lr_anneal_steps: 259 | return 260 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 261 | lr = self.lr * (1 - frac_done) 262 | for param_group in self.opt.param_groups: 263 | param_group["lr"] = lr 264 | 265 | def log_step(self): 266 | logger.logkv("step", self.step + self.resume_step) 267 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 268 | if self.use_fp16: 269 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 270 | 271 | def save(self): 272 | def save_checkpoint(rate, params): 273 | state_dict = self._master_params_to_state_dict(params) 274 | if dist.get_rank() == 0: 275 | logger.log(f"saving model {rate}...") 276 | if not rate: 277 | filename = f"model{(self.step+self.resume_step):06d}.pt" 278 | else: 279 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 280 | with 
bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 281 | th.save(state_dict, f) 282 | 283 | save_checkpoint(0, self.master_params) 284 | for rate, params in zip(self.ema_rate, self.ema_params): 285 | save_checkpoint(rate, params) 286 | 287 | if dist.get_rank() == 0: 288 | with bf.BlobFile( 289 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 290 | "wb", 291 | ) as f: 292 | th.save(self.opt.state_dict(), f) 293 | 294 | dist.barrier() 295 | 296 | def _master_params_to_state_dict(self, master_params): 297 | if self.use_fp16: 298 | master_params = unflatten_master_params( 299 | self.model.parameters(), master_params 300 | ) 301 | state_dict = self.model.state_dict() 302 | for i, (name, _value) in enumerate(self.model.named_parameters()): 303 | assert name in state_dict 304 | state_dict[name] = master_params[i] 305 | return state_dict 306 | 307 | def _state_dict_to_master_params(self, state_dict): 308 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 309 | if self.use_fp16: 310 | return make_master_params(params) 311 | else: 312 | return params 313 | 314 | 315 | def parse_resume_step_from_filename(filename): 316 | """ 317 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 318 | checkpoint's number of steps. 319 | """ 320 | split = filename.split("model") 321 | if len(split) < 2: 322 | return 0 323 | split1 = split[-1].split(".")[0] 324 | try: 325 | return int(split1) 326 | except ValueError: 327 | return 0 328 | 329 | 330 | def get_blob_logdir(): 331 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 332 | 333 | 334 | def find_resume_checkpoint(): 335 | # On your infrastructure, you may want to override this to automatically 336 | # discover the latest checkpoint on your blob storage, etc. 337 | return None 338 | 339 | 340 | def find_ema_checkpoint(main_checkpoint, step, rate): 341 | if main_checkpoint is None: 342 | return None 343 | filename = f"ema_{rate}_{(step):06d}.pt" 344 | path = bf.join(bf.dirname(main_checkpoint), filename) 345 | if bf.exists(path): 346 | return path 347 | return None 348 | 349 | 350 | def log_loss_dict(diffusion, ts, losses): 351 | for key, values in losses.items(): 352 | logger.logkv_mean(key, values.mean().item()) 353 | # Log the quantiles (four quartiles, in particular). 354 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 355 | quartile = int(4 * sub_t / diffusion.num_timesteps) 356 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 357 | -------------------------------------------------------------------------------- /improved_diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .nn import ( 12 | SiLU, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | checkpoint, 20 | ) 21 | 22 | 23 | class TimestepBlock(nn.Module): 24 | """ 25 | Any module where forward() takes timestep embeddings as a second argument. 26 | """ 27 | 28 | @abstractmethod 29 | def forward(self, x, emb): 30 | """ 31 | Apply the module to `x` given `emb` timestep embeddings. 
32 | """ 33 | 34 | 35 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 36 | """ 37 | A sequential module that passes timestep embeddings to the children that 38 | support it as an extra input. 39 | """ 40 | 41 | def forward(self, x, emb): 42 | for layer in self: 43 | if isinstance(layer, TimestepBlock): 44 | x = layer(x, emb) 45 | else: 46 | x = layer(x) 47 | return x 48 | 49 | 50 | class Upsample(nn.Module): 51 | """ 52 | An upsampling layer with an optional convolution. 53 | 54 | :param channels: channels in the inputs and outputs. 55 | :param use_conv: a bool determining if a convolution is applied. 56 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 57 | upsampling occurs in the inner-two dimensions. 58 | """ 59 | 60 | def __init__(self, channels, use_conv, dims=2): 61 | super().__init__() 62 | self.channels = channels 63 | self.use_conv = use_conv 64 | self.dims = dims 65 | if use_conv: 66 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 67 | 68 | def forward(self, x): 69 | assert x.shape[1] == self.channels 70 | if self.dims == 3: 71 | x = F.interpolate( 72 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 73 | ) 74 | else: 75 | x = F.interpolate(x, scale_factor=2, mode="nearest") 76 | if self.use_conv: 77 | x = self.conv(x) 78 | return x 79 | 80 | 81 | class Downsample(nn.Module): 82 | """ 83 | A downsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | downsampling occurs in the inner-two dimensions. 89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2): 92 | super().__init__() 93 | self.channels = channels 94 | self.use_conv = use_conv 95 | self.dims = dims 96 | stride = 2 if dims != 3 else (1, 2, 2) 97 | if use_conv: 98 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 99 | else: 100 | self.op = avg_pool_nd(stride) 101 | 102 | def forward(self, x): 103 | assert x.shape[1] == self.channels 104 | return self.op(x) 105 | 106 | 107 | class ResBlock(TimestepBlock): 108 | """ 109 | A residual block that can optionally change the number of channels. 110 | 111 | :param channels: the number of input channels. 112 | :param emb_channels: the number of timestep embedding channels. 113 | :param dropout: the rate of dropout. 114 | :param out_channels: if specified, the number of out channels. 115 | :param use_conv: if True and out_channels is specified, use a spatial 116 | convolution instead of a smaller 1x1 convolution to change the 117 | channels in the skip connection. 118 | :param dims: determines if the signal is 1D, 2D, or 3D. 119 | :param use_checkpoint: if True, use gradient checkpointing on this module. 
120 | """ 121 | 122 | def __init__( 123 | self, 124 | channels, 125 | emb_channels, 126 | dropout, 127 | out_channels=None, 128 | use_conv=False, 129 | use_scale_shift_norm=False, 130 | dims=2, 131 | use_checkpoint=False, 132 | ): 133 | super().__init__() 134 | self.channels = channels 135 | self.emb_channels = emb_channels 136 | self.dropout = dropout 137 | self.out_channels = out_channels or channels 138 | self.use_conv = use_conv 139 | self.use_checkpoint = use_checkpoint 140 | self.use_scale_shift_norm = use_scale_shift_norm 141 | 142 | self.in_layers = nn.Sequential( 143 | normalization(channels), 144 | SiLU(), 145 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 146 | ) 147 | self.emb_layers = nn.Sequential( 148 | SiLU(), 149 | linear( 150 | emb_channels, 151 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 152 | ), 153 | ) 154 | self.out_layers = nn.Sequential( 155 | normalization(self.out_channels), 156 | SiLU(), 157 | nn.Dropout(p=dropout), 158 | zero_module( 159 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 160 | ), 161 | ) 162 | 163 | if self.out_channels == channels: 164 | self.skip_connection = nn.Identity() 165 | elif use_conv: 166 | self.skip_connection = conv_nd( 167 | dims, channels, self.out_channels, 3, padding=1 168 | ) 169 | else: 170 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 171 | 172 | def forward(self, x, emb): 173 | """ 174 | Apply the block to a Tensor, conditioned on a timestep embedding. 175 | 176 | :param x: an [N x C x ...] Tensor of features. 177 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 178 | :return: an [N x C x ...] Tensor of outputs. 179 | """ 180 | return checkpoint( 181 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 182 | ) 183 | 184 | def _forward(self, x, emb): 185 | h = self.in_layers(x) 186 | emb_out = self.emb_layers(emb).type(h.dtype) 187 | while len(emb_out.shape) < len(h.shape): 188 | emb_out = emb_out[..., None] 189 | if self.use_scale_shift_norm: 190 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 191 | scale, shift = th.chunk(emb_out, 2, dim=1) 192 | h = out_norm(h) * (1 + scale) + shift 193 | h = out_rest(h) 194 | else: 195 | h = h + emb_out 196 | h = self.out_layers(h) 197 | return self.skip_connection(x) + h 198 | 199 | 200 | class AttentionBlock(nn.Module): 201 | """ 202 | An attention block that allows spatial positions to attend to each other. 203 | 204 | Originally ported from here, but adapted to the N-d case. 205 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
206 | """ 207 | 208 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 209 | super().__init__() 210 | self.channels = channels 211 | self.num_heads = num_heads 212 | self.use_checkpoint = use_checkpoint 213 | 214 | self.norm = normalization(channels) 215 | self.qkv = conv_nd(1, channels, channels * 3, 1) 216 | self.attention = QKVAttention() 217 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 218 | 219 | def forward(self, x): 220 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 221 | 222 | def _forward(self, x): 223 | b, c, *spatial = x.shape 224 | x = x.reshape(b, c, -1) 225 | qkv = self.qkv(self.norm(x)) 226 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 227 | h = self.attention(qkv) 228 | h = h.reshape(b, -1, h.shape[-1]) 229 | h = self.proj_out(h) 230 | return (x + h).reshape(b, c, *spatial) 231 | 232 | 233 | class QKVAttention(nn.Module): 234 | """ 235 | A module which performs QKV attention. 236 | """ 237 | 238 | def forward(self, qkv): 239 | """ 240 | Apply QKV attention. 241 | 242 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 243 | :return: an [N x C x T] tensor after attention. 244 | """ 245 | ch = qkv.shape[1] // 3 246 | q, k, v = th.split(qkv, ch, dim=1) 247 | scale = 1 / math.sqrt(math.sqrt(ch)) 248 | weight = th.einsum( 249 | "bct,bcs->bts", q * scale, k * scale 250 | ) # More stable with f16 than dividing afterwards 251 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 252 | return th.einsum("bts,bcs->bct", weight, v) 253 | 254 | @staticmethod 255 | def count_flops(model, _x, y): 256 | """ 257 | A counter for the `thop` package to count the operations in an 258 | attention operation. 259 | 260 | Meant to be used like: 261 | 262 | macs, params = thop.profile( 263 | model, 264 | inputs=(inputs, timestamps), 265 | custom_ops={QKVAttention: QKVAttention.count_flops}, 266 | ) 267 | 268 | """ 269 | b, c, *spatial = y[0].shape 270 | num_spatial = int(np.prod(spatial)) 271 | # We perform two matmuls with the same number of ops. 272 | # The first computes the weight matrix, the second computes 273 | # the combination of the value vectors. 274 | matmul_ops = 2 * b * (num_spatial ** 2) * c 275 | model.total_ops += th.DoubleTensor([matmul_ops]) 276 | 277 | 278 | class UNetModel(nn.Module): 279 | """ 280 | The full UNet model with attention and timestep embedding. 281 | 282 | :param in_channels: channels in the input Tensor. 283 | :param model_channels: base channel count for the model. 284 | :param out_channels: channels in the output Tensor. 285 | :param num_res_blocks: number of residual blocks per downsample. 286 | :param attention_resolutions: a collection of downsample rates at which 287 | attention will take place. May be a set, list, or tuple. 288 | For example, if this contains 4, then at 4x downsampling, attention 289 | will be used. 290 | :param dropout: the dropout probability. 291 | :param channel_mult: channel multiplier for each level of the UNet. 292 | :param conv_resample: if True, use learned convolutions for upsampling and 293 | downsampling. 294 | :param dims: determines if the signal is 1D, 2D, or 3D. 295 | :param num_classes: if specified (as an int), then this model will be 296 | class-conditional with `num_classes` classes. 297 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 298 | :param num_heads: the number of attention heads in each attention layer. 
299 | """ 300 | 301 | def __init__( 302 | self, 303 | in_channels, 304 | model_channels, 305 | out_channels, 306 | num_res_blocks, 307 | attention_resolutions, 308 | dropout=0, 309 | channel_mult=(1, 2, 4, 8), 310 | conv_resample=True, 311 | dims=2, 312 | num_classes=None, 313 | use_checkpoint=False, 314 | num_heads=1, 315 | num_heads_upsample=-1, 316 | use_scale_shift_norm=False, 317 | ): 318 | super().__init__() 319 | 320 | if num_heads_upsample == -1: 321 | num_heads_upsample = num_heads 322 | 323 | self.in_channels = in_channels 324 | self.model_channels = model_channels 325 | self.out_channels = out_channels 326 | self.num_res_blocks = num_res_blocks 327 | self.attention_resolutions = attention_resolutions 328 | self.dropout = dropout 329 | self.channel_mult = channel_mult 330 | self.conv_resample = conv_resample 331 | self.num_classes = num_classes 332 | self.use_checkpoint = use_checkpoint 333 | self.num_heads = num_heads 334 | self.num_heads_upsample = num_heads_upsample 335 | 336 | time_embed_dim = model_channels * 4 337 | self.time_embed = nn.Sequential( 338 | linear(model_channels, time_embed_dim), 339 | SiLU(), 340 | linear(time_embed_dim, time_embed_dim), 341 | ) 342 | 343 | if self.num_classes is not None: 344 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 345 | 346 | self.input_blocks = nn.ModuleList( 347 | [ 348 | TimestepEmbedSequential( 349 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 350 | ) 351 | ] 352 | ) 353 | input_block_chans = [model_channels] 354 | ch = model_channels 355 | ds = 1 356 | for level, mult in enumerate(channel_mult): 357 | for _ in range(num_res_blocks): 358 | layers = [ 359 | ResBlock( 360 | ch, 361 | time_embed_dim, 362 | dropout, 363 | out_channels=mult * model_channels, 364 | dims=dims, 365 | use_checkpoint=use_checkpoint, 366 | use_scale_shift_norm=use_scale_shift_norm, 367 | ) 368 | ] 369 | ch = mult * model_channels 370 | if ds in attention_resolutions: 371 | layers.append( 372 | AttentionBlock( 373 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 374 | ) 375 | ) 376 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 377 | input_block_chans.append(ch) 378 | if level != len(channel_mult) - 1: 379 | self.input_blocks.append( 380 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 381 | ) 382 | input_block_chans.append(ch) 383 | ds *= 2 384 | 385 | self.middle_block = TimestepEmbedSequential( 386 | ResBlock( 387 | ch, 388 | time_embed_dim, 389 | dropout, 390 | dims=dims, 391 | use_checkpoint=use_checkpoint, 392 | use_scale_shift_norm=use_scale_shift_norm, 393 | ), 394 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 395 | ResBlock( 396 | ch, 397 | time_embed_dim, 398 | dropout, 399 | dims=dims, 400 | use_checkpoint=use_checkpoint, 401 | use_scale_shift_norm=use_scale_shift_norm, 402 | ), 403 | ) 404 | 405 | self.output_blocks = nn.ModuleList([]) 406 | for level, mult in list(enumerate(channel_mult))[::-1]: 407 | for i in range(num_res_blocks + 1): 408 | layers = [ 409 | ResBlock( 410 | ch + input_block_chans.pop(), 411 | time_embed_dim, 412 | dropout, 413 | out_channels=model_channels * mult, 414 | dims=dims, 415 | use_checkpoint=use_checkpoint, 416 | use_scale_shift_norm=use_scale_shift_norm, 417 | ) 418 | ] 419 | ch = model_channels * mult 420 | if ds in attention_resolutions: 421 | layers.append( 422 | AttentionBlock( 423 | ch, 424 | use_checkpoint=use_checkpoint, 425 | num_heads=num_heads_upsample, 426 | ) 427 | ) 428 | if level and i == num_res_blocks: 429 
| layers.append(Upsample(ch, conv_resample, dims=dims)) 430 | ds //= 2 431 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 432 | 433 | self.out = nn.Sequential( 434 | normalization(ch), 435 | SiLU(), 436 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 437 | ) 438 | 439 | def convert_to_fp16(self): 440 | """ 441 | Convert the torso of the model to float16. 442 | """ 443 | self.input_blocks.apply(convert_module_to_f16) 444 | self.middle_block.apply(convert_module_to_f16) 445 | self.output_blocks.apply(convert_module_to_f16) 446 | 447 | def convert_to_fp32(self): 448 | """ 449 | Convert the torso of the model to float32. 450 | """ 451 | self.input_blocks.apply(convert_module_to_f32) 452 | self.middle_block.apply(convert_module_to_f32) 453 | self.output_blocks.apply(convert_module_to_f32) 454 | 455 | @property 456 | def inner_dtype(self): 457 | """ 458 | Get the dtype used by the torso of the model. 459 | """ 460 | return next(self.input_blocks.parameters()).dtype 461 | 462 | def forward(self, x, timesteps, y=None): 463 | """ 464 | Apply the model to an input batch. 465 | 466 | :param x: an [N x C x ...] Tensor of inputs. 467 | :param timesteps: a 1-D batch of timesteps. 468 | :param y: an [N] Tensor of labels, if class-conditional. 469 | :return: an [N x C x ...] Tensor of outputs. 470 | """ 471 | # print(timesteps,timesteps.shape) 472 | assert (y is not None) == ( 473 | self.num_classes is not None 474 | ), "must specify y if and only if the model is class-conditional" 475 | 476 | hs = [] 477 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 478 | 479 | if self.num_classes is not None: 480 | assert y.shape == (x.shape[0],) 481 | emb = emb + self.label_emb(y) 482 | 483 | h = x.type(self.inner_dtype) 484 | for module in self.input_blocks: 485 | h = module(h, emb) 486 | hs.append(h) 487 | h = self.middle_block(h, emb) 488 | for module in self.output_blocks: 489 | cat_in = th.cat([h, hs.pop()], dim=1) 490 | h = module(cat_in, emb) 491 | h = h.type(x.dtype) 492 | return self.out(h) 493 | 494 | def get_feature_vectors(self, x, timesteps, y=None): 495 | """ 496 | Apply the model and return all of the intermediate tensors. 497 | 498 | :param x: an [N x C x ...] Tensor of inputs. 499 | :param timesteps: a 1-D batch of timesteps. 500 | :param y: an [N] Tensor of labels, if class-conditional. 501 | :return: a dict with the following keys: 502 | - 'down': a list of hidden state tensors from downsampling. 503 | - 'middle': the tensor of the output of the lowest-resolution 504 | block in the model. 505 | - 'up': a list of hidden state tensors from upsampling. 506 | """ 507 | hs = [] 508 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 509 | if self.num_classes is not None: 510 | assert y.shape == (x.shape[0],) 511 | emb = emb + self.label_emb(y) 512 | result = dict(down=[], up=[]) 513 | h = x.type(self.inner_dtype) 514 | for module in self.input_blocks: 515 | h = module(h, emb) 516 | hs.append(h) 517 | result["down"].append(h.type(x.dtype)) 518 | h = self.middle_block(h, emb) 519 | result["middle"] = h.type(x.dtype) 520 | for module in self.output_blocks: 521 | cat_in = th.cat([h, hs.pop()], dim=1) 522 | h = module(cat_in, emb) 523 | result["up"].append(h.type(x.dtype)) 524 | return result 525 | 526 | 527 | class SuperResModel(UNetModel): 528 | """ 529 | A UNetModel that performs super-resolution. 530 | 531 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 
532 | """ 533 | 534 | def __init__(self, in_channels, *args, **kwargs): 535 | super().__init__(in_channels * 2, *args, **kwargs) 536 | 537 | def forward(self, x, timesteps, low_res=None, **kwargs): 538 | _, _, new_height, new_width = x.shape 539 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 540 | x = th.cat([x, upsampled], dim=1) 541 | return super().forward(x, timesteps, **kwargs) 542 | 543 | def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs): 544 | _, new_height, new_width, _ = x.shape 545 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 546 | x = th.cat([x, upsampled], dim=1) 547 | return super().get_feature_vectors(x, timesteps, **kwargs) 548 | 549 | -------------------------------------------------------------------------------- /improved_diffusion/unet_test.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from nn_complex import ( 12 | SiLU, 13 | conv_nd, 14 | linear, 15 | avg_pool_nd, 16 | zero_module, 17 | normalization, 18 | timestep_embedding, 19 | checkpoint, 20 | ) 21 | 22 | 23 | class TimestepBlock(nn.Module): 24 | """ 25 | Any module where forward() takes timestep embeddings as a second argument. 26 | """ 27 | 28 | @abstractmethod 29 | def forward(self, x, emb): 30 | """ 31 | Apply the module to `x` given `emb` timestep embeddings. 32 | """ 33 | 34 | 35 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 36 | """ 37 | A sequential module that passes timestep embeddings to the children that 38 | support it as an extra input. 39 | """ 40 | 41 | def forward(self, x, emb): 42 | for layer in self: 43 | if isinstance(layer, TimestepBlock): 44 | x = layer(x, emb) 45 | else: 46 | x = layer(x) 47 | return x 48 | 49 | 50 | class Upsample(nn.Module): 51 | """ 52 | An upsampling layer with an optional convolution. 53 | 54 | :param channels: channels in the inputs and outputs. 55 | :param use_conv: a bool determining if a convolution is applied. 56 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 57 | upsampling occurs in the inner-two dimensions. 58 | """ 59 | 60 | def __init__(self, channels, use_conv, dims=2): 61 | super().__init__() 62 | self.channels = channels 63 | self.use_conv = use_conv 64 | self.dims = dims 65 | if use_conv: 66 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 67 | 68 | def forward(self, x): 69 | assert x.shape[1] == self.channels 70 | if self.dims == 3: 71 | x = F.interpolate( 72 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 73 | ) 74 | else: 75 | x = F.interpolate(x, scale_factor=2, mode="nearest") 76 | if self.use_conv: 77 | x = self.conv(x) 78 | return x 79 | 80 | 81 | class Downsample(nn.Module): 82 | """ 83 | A downsampling layer with an optional convolution. 84 | 85 | :param channels: channels in the inputs and outputs. 86 | :param use_conv: a bool determining if a convolution is applied. 87 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 88 | downsampling occurs in the inner-two dimensions. 
89 | """ 90 | 91 | def __init__(self, channels, use_conv, dims=2): 92 | super().__init__() 93 | self.channels = channels 94 | self.use_conv = use_conv 95 | self.dims = dims 96 | stride = 2 if dims != 3 else (1, 2, 2) 97 | if use_conv: 98 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 99 | else: 100 | self.op = avg_pool_nd(stride) 101 | 102 | def forward(self, x): 103 | assert x.shape[1] == self.channels 104 | return self.op(x) 105 | 106 | 107 | class ResBlock(TimestepBlock): 108 | """ 109 | A residual block that can optionally change the number of channels. 110 | 111 | :param channels: the number of input channels. 112 | :param emb_channels: the number of timestep embedding channels. 113 | :param dropout: the rate of dropout. 114 | :param out_channels: if specified, the number of out channels. 115 | :param use_conv: if True and out_channels is specified, use a spatial 116 | convolution instead of a smaller 1x1 convolution to change the 117 | channels in the skip connection. 118 | :param dims: determines if the signal is 1D, 2D, or 3D. 119 | :param use_checkpoint: if True, use gradient checkpointing on this module. 120 | """ 121 | 122 | def __init__( 123 | self, 124 | channels, 125 | emb_channels, 126 | dropout, 127 | out_channels=None, 128 | use_conv=False, 129 | use_scale_shift_norm=False, 130 | dims=2, 131 | use_checkpoint=False, 132 | ): 133 | super().__init__() 134 | self.channels = channels 135 | self.emb_channels = emb_channels 136 | self.dropout = dropout 137 | self.out_channels = out_channels or channels 138 | self.use_conv = use_conv 139 | self.use_checkpoint = use_checkpoint 140 | self.use_scale_shift_norm = use_scale_shift_norm 141 | 142 | self.in_layers = nn.Sequential( 143 | normalization(channels), 144 | SiLU(), 145 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 146 | ) 147 | self.emb_layers = nn.Sequential( 148 | SiLU(), 149 | linear( 150 | emb_channels, 151 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 152 | ), 153 | ) 154 | self.out_layers = nn.Sequential( 155 | normalization(self.out_channels), 156 | SiLU(), 157 | nn.Dropout(p=dropout), 158 | zero_module( 159 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 160 | ), 161 | ) 162 | 163 | if self.out_channels == channels: 164 | self.skip_connection = nn.Identity() 165 | elif use_conv: 166 | self.skip_connection = conv_nd( 167 | dims, channels, self.out_channels, 3, padding=1 168 | ) 169 | else: 170 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 171 | 172 | def forward(self, x, emb): 173 | """ 174 | Apply the block to a Tensor, conditioned on a timestep embedding. 175 | 176 | :param x: an [N x C x ...] Tensor of features. 177 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 178 | :return: an [N x C x ...] Tensor of outputs. 
179 | """ 180 | return checkpoint( 181 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 182 | ) 183 | 184 | def _forward(self, x, emb): 185 | h = self.in_layers(x) 186 | emb_out = self.emb_layers(emb).type(h.dtype) 187 | while len(emb_out.shape) < len(h.shape): 188 | emb_out = emb_out[..., None] 189 | if self.use_scale_shift_norm: 190 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 191 | scale, shift = th.chunk(emb_out, 2, dim=1) 192 | h = out_norm(h) * (1 + scale) + shift 193 | h = out_rest(h) 194 | else: 195 | h = h + emb_out 196 | h = self.out_layers(h) 197 | return self.skip_connection(x) + h 198 | 199 | 200 | class AttentionBlock(nn.Module): 201 | """ 202 | An attention block that allows spatial positions to attend to each other. 203 | 204 | Originally ported from here, but adapted to the N-d case. 205 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 206 | """ 207 | 208 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 209 | super().__init__() 210 | self.channels = channels 211 | self.num_heads = num_heads 212 | self.use_checkpoint = use_checkpoint 213 | 214 | self.norm = normalization(channels) 215 | self.qkv = conv_nd(1, channels, channels * 3, 1) 216 | self.attention = QKVAttention() 217 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 218 | 219 | def forward(self, x): 220 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 221 | 222 | def _forward(self, x): 223 | b, c, *spatial = x.shape 224 | x = x.reshape(b, c, -1) 225 | qkv = self.qkv(self.norm(x)) 226 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 227 | h = self.attention(qkv) 228 | h = h.reshape(b, -1, h.shape[-1]) 229 | h = self.proj_out(h) 230 | return (x + h).reshape(b, c, *spatial) 231 | 232 | 233 | class QKVAttention(nn.Module): 234 | """ 235 | A module which performs QKV attention. 236 | """ 237 | 238 | def forward(self, qkv): 239 | """ 240 | Apply QKV attention. 241 | 242 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 243 | :return: an [N x C x T] tensor after attention. 244 | """ 245 | ch = qkv.shape[1] // 3 246 | q, k, v = th.split(qkv, ch, dim=1) 247 | scale = 1 / math.sqrt(math.sqrt(ch)) 248 | weight = th.einsum( 249 | "bct,bcs->bts", q * scale, k * scale 250 | ) # More stable with f16 than dividing afterwards 251 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 252 | return th.einsum("bts,bcs->bct", weight, v) 253 | 254 | @staticmethod 255 | def count_flops(model, _x, y): 256 | """ 257 | A counter for the `thop` package to count the operations in an 258 | attention operation. 259 | 260 | Meant to be used like: 261 | 262 | macs, params = thop.profile( 263 | model, 264 | inputs=(inputs, timestamps), 265 | custom_ops={QKVAttention: QKVAttention.count_flops}, 266 | ) 267 | 268 | """ 269 | b, c, *spatial = y[0].shape 270 | num_spatial = int(np.prod(spatial)) 271 | # We perform two matmuls with the same number of ops. 272 | # The first computes the weight matrix, the second computes 273 | # the combination of the value vectors. 274 | matmul_ops = 2 * b * (num_spatial ** 2) * c 275 | model.total_ops += th.DoubleTensor([matmul_ops]) 276 | 277 | 278 | class UNetModel(nn.Module): 279 | """ 280 | The full UNet model with attention and timestep embedding. 281 | 282 | :param in_channels: channels in the input Tensor. 283 | :param model_channels: base channel count for the model. 
284 | :param out_channels: channels in the output Tensor. 285 | :param num_res_blocks: number of residual blocks per downsample. 286 | :param attention_resolutions: a collection of downsample rates at which 287 | attention will take place. May be a set, list, or tuple. 288 | For example, if this contains 4, then at 4x downsampling, attention 289 | will be used. 290 | :param dropout: the dropout probability. 291 | :param channel_mult: channel multiplier for each level of the UNet. 292 | :param conv_resample: if True, use learned convolutions for upsampling and 293 | downsampling. 294 | :param dims: determines if the signal is 1D, 2D, or 3D. 295 | :param num_classes: if specified (as an int), then this model will be 296 | class-conditional with `num_classes` classes. 297 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 298 | :param num_heads: the number of attention heads in each attention layer. 299 | """ 300 | 301 | def __init__( 302 | self, 303 | in_channels, 304 | model_channels, 305 | out_channels, 306 | num_res_blocks, 307 | attention_resolutions, 308 | dropout=0, 309 | channel_mult=(1, 2, 4, 8), 310 | conv_resample=True, 311 | dims=2, 312 | num_classes=None, 313 | use_checkpoint=False, 314 | num_heads=1, 315 | num_heads_upsample=-1, 316 | use_scale_shift_norm=False, 317 | ): 318 | super().__init__() 319 | 320 | if num_heads_upsample == -1: 321 | num_heads_upsample = num_heads 322 | 323 | self.in_channels = in_channels 324 | self.model_channels = model_channels 325 | self.out_channels = out_channels 326 | self.num_res_blocks = num_res_blocks 327 | self.attention_resolutions = attention_resolutions 328 | self.dropout = dropout 329 | self.channel_mult = channel_mult 330 | self.conv_resample = conv_resample 331 | self.num_classes = num_classes 332 | self.use_checkpoint = use_checkpoint 333 | self.num_heads = num_heads 334 | self.num_heads_upsample = num_heads_upsample 335 | 336 | time_embed_dim = model_channels * 4 337 | self.time_embed = nn.Sequential( 338 | linear(model_channels, time_embed_dim), 339 | SiLU(), 340 | linear(time_embed_dim, time_embed_dim), 341 | ) 342 | 343 | if self.num_classes is not None: 344 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 345 | 346 | self.input_blocks = nn.ModuleList( 347 | [ 348 | TimestepEmbedSequential( 349 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 350 | ) 351 | ] 352 | ) 353 | input_block_chans = [model_channels] 354 | ch = model_channels 355 | ds = 1 356 | for level, mult in enumerate(channel_mult): 357 | for _ in range(num_res_blocks): 358 | layers = [ 359 | ResBlock( 360 | ch, 361 | time_embed_dim, 362 | dropout, 363 | out_channels=mult * model_channels, 364 | dims=dims, 365 | use_checkpoint=use_checkpoint, 366 | use_scale_shift_norm=use_scale_shift_norm, 367 | ) 368 | ] 369 | ch = mult * model_channels 370 | if ds in attention_resolutions: 371 | layers.append( 372 | AttentionBlock( 373 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 374 | ) 375 | ) 376 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 377 | input_block_chans.append(ch) 378 | if level != len(channel_mult) - 1: 379 | self.input_blocks.append( 380 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 381 | ) 382 | input_block_chans.append(ch) 383 | ds *= 2 384 | 385 | self.middle_block = TimestepEmbedSequential( 386 | ResBlock( 387 | ch, 388 | time_embed_dim, 389 | dropout, 390 | dims=dims, 391 | use_checkpoint=use_checkpoint, 392 | use_scale_shift_norm=use_scale_shift_norm, 393 | ), 394 | 
AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 395 | ResBlock( 396 | ch, 397 | time_embed_dim, 398 | dropout, 399 | dims=dims, 400 | use_checkpoint=use_checkpoint, 401 | use_scale_shift_norm=use_scale_shift_norm, 402 | ), 403 | ) 404 | 405 | self.output_blocks = nn.ModuleList([]) 406 | for level, mult in list(enumerate(channel_mult))[::-1]: 407 | for i in range(num_res_blocks + 1): 408 | layers = [ 409 | ResBlock( 410 | ch + input_block_chans.pop(), 411 | time_embed_dim, 412 | dropout, 413 | out_channels=model_channels * mult, 414 | dims=dims, 415 | use_checkpoint=use_checkpoint, 416 | use_scale_shift_norm=use_scale_shift_norm, 417 | ) 418 | ] 419 | ch = model_channels * mult 420 | if ds in attention_resolutions: 421 | layers.append( 422 | AttentionBlock( 423 | ch, 424 | use_checkpoint=use_checkpoint, 425 | num_heads=num_heads_upsample, 426 | ) 427 | ) 428 | if level and i == num_res_blocks: 429 | layers.append(Upsample(ch, conv_resample, dims=dims)) 430 | ds //= 2 431 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 432 | 433 | self.out = nn.Sequential( 434 | normalization(ch), 435 | SiLU(), 436 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 437 | ) 438 | 439 | def convert_to_fp16(self): 440 | """ 441 | Convert the torso of the model to float16. 442 | """ 443 | self.input_blocks.apply(convert_module_to_f16) 444 | self.middle_block.apply(convert_module_to_f16) 445 | self.output_blocks.apply(convert_module_to_f16) 446 | 447 | def convert_to_fp32(self): 448 | """ 449 | Convert the torso of the model to float32. 450 | """ 451 | self.input_blocks.apply(convert_module_to_f32) 452 | self.middle_block.apply(convert_module_to_f32) 453 | self.output_blocks.apply(convert_module_to_f32) 454 | 455 | @property 456 | def inner_dtype(self): 457 | """ 458 | Get the dtype used by the torso of the model. 459 | """ 460 | return next(self.input_blocks.parameters()).dtype 461 | 462 | def forward(self, x, timesteps, y=None): 463 | """ 464 | Apply the model to an input batch. 465 | 466 | :param x: an [N x C x ...] Tensor of inputs. 467 | :param timesteps: a 1-D batch of timesteps. 468 | :param y: an [N] Tensor of labels, if class-conditional. 469 | :return: an [N x C x ...] Tensor of outputs. 470 | """ 471 | assert (y is not None) == ( 472 | self.num_classes is not None 473 | ), "must specify y if and only if the model is class-conditional" 474 | 475 | hs = [] 476 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 477 | 478 | if self.num_classes is not None: 479 | assert y.shape == (x.shape[0],) 480 | emb = emb + self.label_emb(y) 481 | 482 | h = x.type(self.inner_dtype) 483 | for module in self.input_blocks: 484 | h = module(h, emb) 485 | hs.append(h) 486 | h = self.middle_block(h, emb) 487 | for module in self.output_blocks: 488 | cat_in = th.cat([h, hs.pop()], dim=1) 489 | h = module(cat_in, emb) 490 | h = h.type(x.dtype) 491 | return self.out(h) 492 | 493 | def get_feature_vectors(self, x, timesteps, y=None): 494 | """ 495 | Apply the model and return all of the intermediate tensors. 496 | 497 | :param x: an [N x C x ...] Tensor of inputs. 498 | :param timesteps: a 1-D batch of timesteps. 499 | :param y: an [N] Tensor of labels, if class-conditional. 500 | :return: a dict with the following keys: 501 | - 'down': a list of hidden state tensors from downsampling. 502 | - 'middle': the tensor of the output of the lowest-resolution 503 | block in the model. 504 | - 'up': a list of hidden state tensors from upsampling. 
505 | """ 506 | hs = [] 507 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 508 | if self.num_classes is not None: 509 | assert y.shape == (x.shape[0],) 510 | emb = emb + self.label_emb(y) 511 | result = dict(down=[], up=[]) 512 | h = x.type(self.inner_dtype) 513 | for module in self.input_blocks: 514 | h = module(h, emb) 515 | hs.append(h) 516 | result["down"].append(h.type(x.dtype)) 517 | h = self.middle_block(h, emb) 518 | result["middle"] = h.type(x.dtype) 519 | for module in self.output_blocks: 520 | cat_in = th.cat([h, hs.pop()], dim=1) 521 | h = module(cat_in, emb) 522 | result["up"].append(h.type(x.dtype)) 523 | return result 524 | 525 | 526 | class SuperResModel(UNetModel): 527 | """ 528 | A UNetModel that performs super-resolution. 529 | 530 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 531 | """ 532 | 533 | def __init__(self, in_channels, *args, **kwargs): 534 | super().__init__(in_channels * 2, *args, **kwargs) 535 | 536 | def forward(self, x, timesteps, low_res=None, **kwargs): 537 | _, _, new_height, new_width = x.shape 538 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 539 | x = th.cat([x, upsampled], dim=1) 540 | return super().forward(x, timesteps, **kwargs) 541 | 542 | def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs): 543 | _, new_height, new_width, _ = x.shape 544 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 545 | x = th.cat([x, upsampled], dim=1) 546 | return super().get_feature_vectors(x, timesteps, **kwargs) 547 | 548 | -------------------------------------------------------------------------------- /scripts/data_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py, os,pickle,imageio 3 | 4 | def normalize_complex(data, eps=0.): 5 | mag = np.abs(data) 6 | mag_std = mag.std() 7 | return data / (mag_std + eps), mag_std 8 | 9 | 10 | # out_kspace_dir = '../mri_recon/DDPM/noscale/train/kspace/' 11 | # visual_dir = '../mri_recon/DDPM/noscale/train/visual/' 12 | out_img_dir = './data/val/img/' 13 | inp_dir = '../mri_recon/T1/val/' 14 | 15 | os.makedirs(out_img_dir, exist_ok=True) 16 | 17 | files = os.listdir(inp_dir) 18 | for file in files: 19 | data = h5py.File(inp_dir+file, 'r')['kspace'] 20 | for i in range(data.shape[0]): 21 | norm_kspace,std_kspace = normalize_complex(data[i]) 22 | img = np.fft.ifft2(norm_kspace) 23 | img = np.fft.fftshift(img) 24 | norm_img,std_img = normalize_complex(img) 25 | print(file,i) 26 | pickle.dump({'img':norm_img},open(out_img_dir+file.replace('.h5','_'+str(i)+'.pt'), 27 | 'wb')) -------------------------------------------------------------------------------- /scripts/image_sample_complex_duo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import numpy as np 5 | import torch as th 6 | import torch.distributed as dist 7 | import pickle 8 | from improved_diffusion import dist_util, logger 9 | from improved_diffusion.script_util_duo import ( 10 | model_and_diffusion_defaults, 11 | create_model_and_two_diffusion, 12 | add_dict_to_argparser, 13 | args_to_dict, 14 | ) 15 | import imageio 16 | mask = pickle.load(open('file1000031_mask.pt','rb')).view(1,1,320,320).cuda() 17 | mask= th.cat([mask,mask],1) 18 | images = ['file1000031'] 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | 23 | dist_util.setup_dist() 24 | 
25 | 
26 |     logger.log("creating model and diffusion...")
27 |     model, diffusion, diffusion_two = create_model_and_two_diffusion(
28 |         **args_to_dict(args, model_and_diffusion_defaults().keys())
29 |     )
30 |     model.load_state_dict(
31 |         dist_util.load_state_dict(args.model_path, map_location="cpu")
32 |     )
33 |     model.to(dist_util.dev())
34 |     model.eval()
35 |     for image in images:
36 |         slice = 16
37 |         coarse = []
38 |         for i in range(slice - 1, slice + 2):
39 |             file_name1 = image + '_' + str(i) + '.pt'
40 |             file_name2 = image + '_' + str(i + 1) + '.pt'
41 |             kspace = load_data(args.data_path, file_name1, file_name2, args.batch_size)
42 |             # save for refining
43 |             if i == slice:
44 |                 input = kspace[[0]]
45 |             logger.log("sampling...")
46 |             samples = []
47 |             for _ in range(2):
48 |                 model_kwargs = {}
49 |                 sample = diffusion.p_sample_loop_condition(
50 |                     model,
51 |                     (args.batch_size, 4, args.image_size, args.image_size),
52 |                     kspace,
53 |                     mask,
54 |                     clip_denoised=args.clip_denoised,
55 |                     model_kwargs=model_kwargs
56 |                 )[-1]
57 |                 samples.append(sample)
58 |             samples = th.cat(samples)
59 |             coarse.append(samples.contiguous())
60 | 
61 |         coarse = th.stack(coarse)
62 |         print(coarse.shape)
63 |         aggregate = []
64 |         for k in range(2):
65 |             aggregate.append((coarse[k, :, [2, 3]].mean(0) + coarse[k + 1, :, [0, 1]].mean(0)).view(1, 2, 320, 320) / 2)
66 |         aggregate = th.cat(aggregate, 1)
67 |         print(aggregate.shape)
68 | 
69 |         sample2 = diffusion_two.p_sample_loop_condition(
70 |             model,
71 |             (1, 4, args.image_size, args.image_size),
72 |             input,
73 |             mask,
74 |             noise=aggregate.float(),
75 |             clip_denoised=args.clip_denoised,
76 |             model_kwargs={},
77 |             refine=True
78 |         )
79 |         sample2 = sample2[-1].cpu().data.numpy()
80 |         pickle.dump({'coarse': coarse.cpu().data.numpy(), 'fine': sample2},
81 |                     open(os.path.join(args.save_path, image + '_' + str(slice) + '.pt'), 'wb'))
82 |         vis = np.abs(sample2[0, 0] + sample2[0, 1] * 1j)
83 |         imageio.imsave(os.path.join(args.save_path, image + '_' + str(slice) + '.png'), vis / vis.max())
84 | 
85 | def load_data(data_path, file1, file2, batch_size):
86 |     # load two slices
87 |     img_prior1 = pickle.load(open(os.path.join(data_path, file1), 'rb'))['img']
88 |     img_prior2 = pickle.load(open(os.path.join(data_path, file2), 'rb'))['img']
89 |     print('loading', file1, file2)
90 |     data = np.stack([np.real(img_prior1), np.imag(img_prior1), np.real(img_prior2), np.imag(img_prior2)]).astype(
91 |         np.float32)
92 |     max_val = abs(data[:2]).max()
93 |     data[:2] /= max_val
94 |     max_val = abs(data[2:4]).max()
95 |     data[2:4] /= max_val
96 |     # Normalizing by the max value keeps this model robust across different preprocessing schemes.
97 |     # To avoid using the ground-truth max value, an averaged max value estimated from the training
98 |     # set gives similar performance, e.g.
99 |     # data /= 7.21 (average max value); in general, the max value is at DC and should be accessible.
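    # Shape walk-through: img_prior1/img_prior2 are complex 320x320 slices, so
    # `data` is a float32 (4, 320, 320) stack of [real1, imag1, real2, imag2];
    # below, each slice is FFT'd back to k-space and tiled to
    # (batch_size, 4, 320, 320) so every batch element conditions on the same
    # measured k-space.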
100 | data1 = data[0] + data[1] * 1j 101 | data2 = data[2] + data[3] * 1j 102 | kspace1 = np.fft.fft2(data1) 103 | kspace2 = np.fft.fft2(data2) 104 | kspace = th.FloatTensor( 105 | np.stack([np.real(kspace1), np.imag(kspace1), np.real(kspace2), np.imag(kspace2)])) \ 106 | .cuda().view(1, 4, 320, 320).repeat(batch_size, 1, 1, 1).float() 107 | return kspace 108 | 109 | def create_argparser(): 110 | defaults = dict( 111 | clip_denoised=True, 112 | num_samples=100, 113 | batch_size=5, 114 | use_ddim=False, 115 | model_path="", 116 | data_path="", 117 | save_path="" 118 | ) 119 | defaults.update(model_and_diffusion_defaults()) 120 | parser = argparse.ArgumentParser() 121 | add_dict_to_argparser(parser, defaults) 122 | return parser 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | from improved_diffusion import dist_util, logger 8 | from improved_diffusion.complex_two_img_datasets import load_data 9 | from improved_diffusion.resample import create_named_schedule_sampler 10 | from improved_diffusion.script_util_duo import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from improved_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure(args.save_dir) 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | data_dir=args.data_dir, 35 | batch_size=args.batch_size, 36 | image_size=args.image_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | data_dir="", 63 | schedule_sampler="uniform", 64 | lr=1e-4, 65 | weight_decay=0.0, 66 | lr_anneal_steps=0, 67 | batch_size=1, 68 | microbatch=-1, # -1 disables microbatches 69 | ema_rate="0.9999", # comma-separated list of EMA values 70 | log_interval=10, 71 | save_interval=10000, 72 | resume_checkpoint="", 73 | use_fp16=False, 74 | fp16_scale_growth=1e-3, 75 | save_dir='img_space_320', 76 | ) 77 | defaults.update(model_and_diffusion_defaults()) 78 | parser = argparse.ArgumentParser() 79 | add_dict_to_argparser(parser, defaults) 80 | return parser 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /scripts/test_eval.py: 
--------------------------------------------------------------------------------
1 | from skimage.metrics import peak_signal_noise_ratio as psnr
2 | from skimage.metrics import structural_similarity as ssim
3 | import numpy as np
4 | import os, pickle, h5py
5 | 
6 | # provide your path for DiffuseRecon results, saved as pickles of dimension H, W, number of slices
7 | duo_dir = 'kspace_duo_same_mask_all/vols/results/'
8 | # provide your path for the ground truth
9 | orig_gt_dir = '/cis/home/cpeng/mri_recon/T1/val/'
10 | files = os.listdir(orig_gt_dir)
11 | def norm(img):
12 |     img -= img.mean()
13 |     img /= img.std()
14 |     return img
15 | 
16 | def normalize_complex(data, eps=0.):
17 |     mag = np.abs(data)
18 |     mag_std = mag.std()
19 |     return data / (mag_std + eps), mag_std
20 | 
21 | psnr_DDPM_duo = []
22 | for file in files:
23 |     data = h5py.File(orig_gt_dir + file, 'r')['kspace']
24 |     orig_target = []
25 |     # same preprocessing as data_process.py: normalized in image space at the end
26 |     for i in range(data.shape[0]):
27 |         norm_kspace, std_kspace = normalize_complex(data[i])
28 |         img = np.fft.ifft2(norm_kspace)
29 |         img = np.fft.fftshift(img)
30 |         norm_img, std_img = normalize_complex(img)
31 |         # norm in the image space
32 |         orig_target.append(norm(np.abs(norm_img)))
33 | 
34 |     orig_target = np.asarray(orig_target).transpose(1, 2, 0)
35 |     DDPM_duo = pickle.load(open(duo_dir + file.replace('.h5', '_full.pt'), 'rb'))
36 |     orig_target = orig_target[..., 4:-1]
37 |     DDPM_duo = DDPM_duo[..., 4:-1]
38 |     min_overlap = min(orig_target.shape[-1], DDPM_duo.shape[-1])  # compare only the slices present in both volumes
39 |     data_range = orig_target.max() - orig_target.min()
40 |     psnr_DDPM_duo.append(psnr(orig_target[..., :min_overlap], DDPM_duo[..., :min_overlap], data_range=data_range))
41 |     print(psnr_DDPM_duo[-1])
42 | 
43 | print(np.mean(psnr_DDPM_duo))
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name="improved-diffusion",
5 |     py_modules=["improved_diffusion"],
6 |     install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 | 
--------------------------------------------------------------------------------
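A minimal sketch of the checkpoint-naming convention used by `TrainLoop.save()` in `improved_diffusion/train_util.py` and parsed on resume; the directory here follows the `img_space_dual` save directory from the training example:

```python
from improved_diffusion.train_util import parse_resume_step_from_filename

# TrainLoop.save() writes model{step:06d}.pt, ema_{rate}_{step:06d}.pt and
# opt{step:06d}.pt side by side in the blob log directory.
assert parse_resume_step_from_filename("img_space_dual/model150000.pt") == 150000

# Filenames without a "model" segment parse to step 0; EMA checkpoints are
# instead located with find_ema_checkpoint(main_checkpoint, step, rate).
assert parse_resume_step_from_filename("img_space_dual/ema_0.9999_150000.pt") == 0
```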