├── improved_diffusion ├── __init__.py ├── utils.py ├── dist_util.py ├── fp16_util.py ├── RRDB.py ├── losses.py ├── metrics.py ├── image_datasets.py ├── respace.py ├── nn.py ├── resample.py ├── script_util.py ├── sampling_util.py ├── logger.py ├── train_util.py ├── unet.py └── gaussian_diffusion.py ├── environment.yml ├── .gitignore ├── datasets ├── preprocess_vaihingen.py ├── vaih.py ├── monu.py ├── city.py └── transforms.py ├── image_sample_diff_medical.py ├── image_sample_diff_city.py ├── image_sample_diff_vaih.py ├── README.md ├── image_train_diff_medical.py ├── image_train_diff_vaih.py └── image_train_diff_city.py /improved_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: segdiff 2 | channels: 3 | - anaconda 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.8.12 9 | - pip=21.2.4 10 | - pytorch=1.9.0 11 | - torchvision=0.10.0 12 | - cudatoolkit=11.1 13 | - mpi4py=3.1.2 14 | - tqdm=4.62.3 15 | - scikit-learn=0.24.2 16 | - scikit-image=0.18.3 17 | - matplotlib=3.4.3 18 | - seaborn=0.11.2 19 | - pip: 20 | - opencv-python==4.5.1.48 21 | - blobfile==1.2.3 22 | - pycocotools==2.0.2 23 | - gitpython==3.1.24 24 | - kornia==0.5.11 25 | - h5py==3.4.0 26 | - imagecodecs==2021.11.20 27 | -------------------------------------------------------------------------------- /improved_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_random_seed(seed, deterministic=False): 8 | """Set random seed. 9 | Args: 10 | seed (int): Seed to be used. 11 | deterministic (bool): Whether to set the deterministic option for 12 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 13 | to True and `torch.backends.cudnn.benchmark` to False. 14 | Default: False. 15 | """ 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | if deterministic: 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | 25 | def set_random_seed_for_iterations(seed): 26 | """Set random seed. 27 | Args: 28 | seed (int): Seed to be used. 29 | Note: 30 | Unlike `set_random_seed`, this helper takes no `deterministic` flag, 31 | seeds only the current CUDA device (via `torch.cuda.manual_seed`), 32 | and leaves the cuDNN deterministic settings untouched. 
33 | """ 34 | random.seed(seed) 35 | np.random.seed(seed) 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data 107 | .vscode 108 | .idea 109 | 110 | # custom 111 | *.pkl 112 | *.pkl.json 113 | *.log.json 114 | work_dirs/ 115 | work_dirs 116 | pretrained 117 | pretrained/ 118 | # Pytorch 119 | *.pth 120 | trash/ 121 | trash 122 | -------------------------------------------------------------------------------- /improved_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 
24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | port = comm.bcast(_find_free_port(), root=0) 40 | os.environ["MASTER_PORT"] = str(port) 41 | dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | if th.cuda.is_available(): 49 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 50 | return th.device("cpu") 51 | 52 | 53 | def load_state_dict(path, **kwargs): 54 | """ 55 | Load a PyTorch file without redundant fetches across MPI ranks. 56 | """ 57 | if MPI.COMM_WORLD.Get_rank() == 0: 58 | with bf.BlobFile(path, "rb") as f: 59 | data = f.read() 60 | else: 61 | data = None 62 | data = MPI.COMM_WORLD.bcast(data) 63 | return th.load(io.BytesIO(data), **kwargs) 64 | 65 | 66 | def sync_params(params): 67 | """ 68 | Synchronize a sequence of Tensors across ranks from rank 0. 69 | """ 70 | for p in params: 71 | with th.no_grad(): 72 | dist.broadcast(p, 0) 73 | 74 | 75 | def _find_free_port(): 76 | try: 77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 78 | s.bind(("", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | finally: 82 | s.close() 83 | -------------------------------------------------------------------------------- /improved_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 
56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /datasets/preprocess_vaihingen.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import h5py 4 | import os 5 | import cv2 6 | import numpy as np 7 | from cv2 import resize 8 | 9 | 10 | def get_img(cfile): 11 | img = cv2.cvtColor(cv2.imread(cfile, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 12 | img = resize(img, (256,256), interpolation=cv2.INTER_NEAREST) 13 | return img 14 | 15 | 16 | def get_mask(cfile): 17 | GT = cv2.imread(cfile, 0) 18 | GT = resize(GT, (256, 256), interpolation=cv2.INTER_LINEAR) 19 | GT[GT >= 0.5] = 1 # GT is uint8 in [0, 255], so any pixel left nonzero by the bilinear resize becomes foreground 20 | GT[GT < 0.5] = 0 21 | return GT 22 | 23 | 24 | def main(args, out_path): 25 | data_folder_path = Path(args['path']) 26 | imgs_list = sorted(list(data_folder_path.glob("building_[0-9]*.tif"))) 27 | masks_list = sorted(list(data_folder_path.glob("building_mask_[0-9]*.tif"))) 28 | 29 | hf_tri = h5py.File(str(out_path / "full_training_vaih.hdf5"), 'w') 30 | hf_test = h5py.File(str(out_path / "full_test_vaih.hdf5"), 'w') 31 | 32 | imgs_tri = hf_tri.create_group('imgs') 33 | mask_single_tri = hf_tri.create_group('mask_single') 34 | 35 | imgs_test = hf_test.create_group('imgs') 36 | mask_single_test = hf_test.create_group('mask_single') 37 | 38 | for image_path in imgs_list[:100]: # the first 100 images form the training split 39 | print('training: ' + str(image_path)) 40 | img = get_img(str(image_path)) 41 | imgs_tri.create_dataset(image_path.stem, data=img, dtype=np.uint8) 42 | 43 | for image_path in imgs_list[100:]: # the rest form the test split 44 | print('validation: ' + str(image_path)) 45 | img = get_img(str(image_path)) 46 | imgs_test.create_dataset(image_path.stem, data=img, dtype=np.uint8) 47 | 48 | for mask_path in masks_list[:100]: 49 | print('training: ' + str(mask_path)) 50 | mask = get_mask(str(mask_path)) 51 | mask_single_tri.create_dataset(mask_path.stem, data=mask, dtype=np.uint8) 52 | 53 | for mask_path in masks_list[100:]: 54 | print('validation: ' + str(mask_path)) 55 | mask = get_mask(str(mask_path)) 56 | mask_single_test.create_dataset(mask_path.stem, data=mask, dtype=np.uint8) 57 | 58 | hf_tri.close() 59 | hf_test.close() 60 | 61 | 62 | if __name__ == '__main__': 63 | import argparse 64 | folder_path = Path(__file__).absolute().parent.parent.parent / "data" / "Vaihingen" 65 | folder_path.mkdir(parents=True, exist_ok=True) 66 | parser = argparse.ArgumentParser(description='Preprocess the Vaihingen buildings dataset into HDF5 train/test files') 67 | parser.add_argument('-path', 68 | '--path', 69 | default='', 70 | help='Data path; should point to the unzipped "buildings" folder', 71 | required=True) 72 | args = vars(parser.parse_args()) 73 | main(args, out_path=folder_path) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /image_sample_diff_medical.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Sample segmentation masks for the MoNuSeg validation set with a trained model, 3 | using majority voting over multiple diffusion runs, and save the results. 4 | """ 5 | 6 | import argparse 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch.distributed as dist 12 | 13 | from improved_diffusion import dist_util, logger 14 | from datasets.monu import create_dataset 15 | from improved_diffusion.sampling_util import sampling_major_vote_func 16 | from improved_diffusion.script_util import ( 17 | model_and_diffusion_defaults, 18 | create_model_and_diffusion, 19 | add_dict_to_argparser, 20 | args_to_dict, 21 | ) 22 | from improved_diffusion.utils import set_random_seed 23 | import warnings 24 | warnings.filterwarnings('ignore') 25 | 26 | 27 | def main(): 28 | args = create_argparser().parse_args() 29 | 30 | original_logs_path = Path(args.model_path).parent 31 | logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote" 32 | 33 | args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text())) 34 | logger.info(args.__dict__) 35 | dist_util.setup_dist() 36 | 37 | logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}") 38 | 39 | logger.log("creating model and diffusion...") 40 | model, diffusion = create_model_and_diffusion( 41 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 42 | ) 43 | model.load_state_dict( 44 | dist_util.load_state_dict(args.model_path, map_location="cpu") 45 | ) 46 | model.to(dist_util.dev()) 47 | model.eval() 48 | 49 | test_dataset = create_dataset( 50 | mode='val', 51 | ) 52 | 53 | if args.__dict__.get("seed") is None: 54 | seed = 1234 55 | else: 56 | seed = int(args.__dict__.get("seed")) 57 | set_random_seed(seed, deterministic=True) 58 | logger.log("sampling major vote val") 59 | (logs_path / "major_vote").mkdir(exist_ok=True) 60 | step = int(Path(args.model_path).stem.split("_")[-1]) 61 | sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised, 62 | step=step, n_rounds=len(test_dataset)) 63 | 64 | dist.barrier() 65 | logger.log("sampling complete") 66 | 67 | 68 | def create_argparser(): 69 | defaults = dict( 70 | clip_denoised=True, 71 | num_samples=10000, 72 | batch_size=16, 73 | use_ddim=False, 74 | model_path="", 75 | ) 76 | defaults.update(model_and_diffusion_defaults()) 77 | parser = argparse.ArgumentParser() 78 | add_dict_to_argparser(parser, defaults) 79 | return parser 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /improved_diffusion/RRDB.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_layer(block, n_layers): 8 | layers = [] 9 | for _ in range(n_layers): 10 | layers.append(block()) 11 | return nn.Sequential(*layers) 12 | 13 | 14 | class ResidualDenseBlock_5C(nn.Module): 15 | def __init__(self, nf=64, gc=32, bias=True): 16 | super(ResidualDenseBlock_5C, self).__init__() 17 | # gc: growth channel, i.e. 
intermediate channels 18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | 25 | # initialization 26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 27 | 28 | def forward(self, x): 29 | x1 = self.lrelu(self.conv1(x)) 30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 34 | return x5 * 0.2 + x 35 | 36 | 37 | class RRDB(nn.Module): 38 | '''Residual in Residual Dense Block''' 39 | 40 | def __init__(self, nf=1, gc=32): 41 | super(RRDB, self).__init__() 42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 45 | 46 | def forward(self, x): 47 | out = self.RDB1(x) 48 | out = self.RDB2(out) 49 | out = self.RDB3(out) 50 | return out * 0.2 + x 51 | 52 | class RRDBNet(nn.Module): 53 | def __init__(self, in_nc=3, out_nc=128, nf=64, nb=3, gc=32): 54 | super(RRDBNet, self).__init__() 55 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 56 | 57 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 58 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 59 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 60 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 62 | 63 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 64 | 65 | def forward(self, x): 66 | fea = self.conv_first(x) 67 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 68 | fea = fea + trunk 69 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 70 | 71 | return out 72 | 73 | -------------------------------------------------------------------------------- /improved_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 
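# The expression returned below is the closed-form KL between two
# diagonal Gaussians:
#   KL(N(mean1, e^{logvar1}) || N(mean2, e^{logvar2}))
#     = 0.5 * (logvar2 - logvar1 + e^{logvar1 - logvar2}
#              + (mean1 - mean2)^2 * e^{-logvar2} - 1)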
28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /image_sample_diff_city.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample segmentation masks for the Cityscapes validation set with a trained model, 3 | using majority voting over multiple diffusion runs, and save the results. 
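Example invocation (from the README):
    CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_city.py --model_path path-for-model-weights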
4 | """ 5 | 6 | import argparse 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch.distributed as dist 12 | 13 | from improved_diffusion.sampling_util import sampling_major_vote_func 14 | from improved_diffusion import dist_util, logger 15 | from datasets.city import create_dataset 16 | from improved_diffusion.script_util import ( 17 | model_and_diffusion_defaults, 18 | create_model_and_diffusion, 19 | add_dict_to_argparser, 20 | args_to_dict, 21 | ) 22 | from improved_diffusion.utils import set_random_seed 23 | import warnings 24 | warnings.filterwarnings('ignore') 25 | 26 | 27 | def main(): 28 | args = create_argparser().parse_args() 29 | 30 | original_logs_path = Path(args.model_path).parent 31 | logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote" 32 | 33 | args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text())) 34 | logger.info(args.__dict__) 35 | dist_util.setup_dist() 36 | 37 | logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}") 38 | 39 | logger.log("creating model and diffusion...") 40 | model, diffusion = create_model_and_diffusion( 41 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 42 | ) 43 | model.load_state_dict( 44 | dist_util.load_state_dict(args.model_path, map_location="cpu") 45 | ) 46 | model.to(dist_util.dev()) 47 | model.eval() 48 | 49 | test_dataset = create_dataset( 50 | class_name=args.class_name, 51 | mode='val', 52 | expansion=args.expansion, 53 | ) 54 | 55 | if args.__dict__.get("seed") is None: 56 | seed = 1234 57 | else: 58 | seed = int(args.__dict__.get("seed")) 59 | set_random_seed(seed, deterministic=True) 60 | logger.log("sampling major vote val") 61 | (logs_path / "major_vote").mkdir(exist_ok=True) 62 | step = int(Path(args.model_path).stem.split("_")[-1]) 63 | sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised, 64 | step=step, n_rounds=len(test_dataset)) 65 | 66 | dist.barrier() 67 | logger.log("sampling complete") 68 | 69 | 70 | def create_argparser(): 71 | defaults = dict( 72 | clip_denoised=True, 73 | num_samples=10000, 74 | batch_size=16, 75 | use_ddim=False, 76 | model_path="", 77 | ) 78 | defaults.update(model_and_diffusion_defaults()) 79 | parser = argparse.ArgumentParser() 80 | add_dict_to_argparser(parser, defaults) 81 | return parser 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /image_sample_diff_vaih.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import datetime 8 | import json 9 | from pathlib import Path 10 | 11 | import torch.distributed as dist 12 | from mpi4py import MPI 13 | 14 | from improved_diffusion import dist_util, logger 15 | from improved_diffusion.sampling_util import sampling_major_vote_func 16 | from improved_diffusion.script_util import ( 17 | model_and_diffusion_defaults, 18 | create_model_and_diffusion, 19 | add_dict_to_argparser, 20 | args_to_dict, 21 | ) 22 | from improved_diffusion.utils import set_random_seed 23 | from datasets.vaih import VaihDataset 24 | import warnings 25 | warnings.filterwarnings('ignore') 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | 31 | original_logs_path = Path(args.model_path).parent 32 | logs_path = original_logs_path / f"{Path(args.model_path).stem}_major_vote" 33 | 34 | args.__dict__.update(json.loads((original_logs_path / 'args.json').read_text())) 35 | logger.info(args.__dict__) 36 | dist_util.setup_dist() 37 | 38 | logger.configure(dir=str(logs_path), log_suffix=f"val_{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}") 39 | 40 | logger.log("creating model and diffusion...") 41 | model, diffusion = create_model_and_diffusion( 42 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 43 | ) 44 | model.load_state_dict( 45 | dist_util.load_state_dict(args.model_path, map_location="cpu") 46 | ) 47 | model.to(dist_util.dev()) 48 | model.eval() 49 | 50 | test_dataset = VaihDataset( 51 | mode='val', 52 | image_size=args.image_size, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | ) 56 | 57 | if args.__dict__.get("seed") is None: 58 | seed = 1234 59 | else: 60 | seed = int(args.__dict__.get("seed")) 61 | set_random_seed(seed, deterministic=True) 62 | logger.log("sampling major vote val") 63 | (logs_path / "major_vote").mkdir(exist_ok=True) 64 | step = int(Path(args.model_path).stem.split("_")[-1]) 65 | sampling_major_vote_func(diffusion, model, str(logs_path / "major_vote"), test_dataset, logger, args.clip_denoised, 66 | step=step, n_rounds=len(test_dataset)) 67 | 68 | dist.barrier() 69 | logger.log("sampling complete") 70 | 71 | 72 | def create_argparser(): 73 | defaults = dict( 74 | clip_denoised=True, 75 | num_samples=10000, 76 | batch_size=16, 77 | use_ddim=False, 78 | model_path="", 79 | ) 80 | defaults.update(model_and_diffusion_defaults()) 81 | parser = argparse.ArgumentParser() 82 | add_dict_to_argparser(parser, defaults) 83 | return parser 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /improved_diffusion/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.morphology import binary_dilation, disk 3 | 4 | 5 | def WCov_metric(pred, gt_mask): 6 | A1 = float(np.count_nonzero(pred)) 7 | A2 = float(np.count_nonzero(gt_mask)) 8 | if A1 >= A2: return A2 / A1 9 | if A2 > A1: return A1 / A2 10 | 11 | 12 | def FBound_metric(pred, gt_mask): 13 | tmp1 = db_eval_boundary(pred, gt_mask, 1)[0] 14 | tmp2 = db_eval_boundary(pred, gt_mask, 2)[0] 15 | tmp3 = db_eval_boundary(pred, gt_mask, 3)[0] 16 | tmp4 = db_eval_boundary(pred, gt_mask, 4)[0] 17 | tmp5 = db_eval_boundary(pred, gt_mask, 5)[0] 18 | return (tmp1 + tmp2 + tmp3 + tmp4 + tmp5) / 5.0 19 | 20 | 21 | def db_eval_boundary(foreground_mask, gt_mask, bound_th): 22 | """ 23 | Compute mean,recall and decay from per-frame evaluation. 
24 | Calculates precision/recall for boundaries between foreground_mask and 25 | gt_mask using morphological operators to speed it up. 26 | Arguments: 27 | foreground_mask (ndarray): binary segmentation image. 28 | gt_mask (ndarray): binary annotated image. 29 | bound_th (int or float): boundary tolerance in pixels if >= 1, otherwise a fraction of the image diagonal. 30 | Returns: 31 | F (float): boundaries F-measure, P (float): boundaries precision, R (float): boundaries recall, 32 | followed by the raw match counts used to compute them. 33 | """ 34 | assert np.atleast_3d(foreground_mask).shape[2] == 1 35 | 36 | bound_pix = bound_th if bound_th >= 1 else \ 37 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 38 | 39 | # Get the pixel boundaries of both masks 40 | fg_boundary = seg2bmap(foreground_mask) 41 | gt_boundary = seg2bmap(gt_mask) 42 | 43 | fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 44 | gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 45 | 46 | # Get the intersection 47 | gt_match = gt_boundary * fg_dil 48 | fg_match = fg_boundary * gt_dil 49 | 50 | # Area of the intersection 51 | n_fg = np.sum(fg_boundary) 52 | n_gt = np.sum(gt_boundary) 53 | 54 | # % Compute precision and recall 55 | if n_fg == 0 and n_gt > 0: 56 | precision = 1 57 | recall = 0 58 | elif n_fg > 0 and n_gt == 0: 59 | precision = 0 60 | recall = 1 61 | elif n_fg == 0 and n_gt == 0: 62 | precision = 1 63 | recall = 1 64 | else: 65 | precision = np.sum(fg_match) / float(n_fg) 66 | recall = np.sum(gt_match) / float(n_gt) 67 | 68 | # Compute F measure 69 | if precision + recall == 0: 70 | F = 0 71 | else: 72 | F = 2 * precision * recall / (precision + recall) 73 | 74 | return F, precision, recall, np.sum(fg_match), n_fg, np.sum(gt_match), n_gt 75 | 76 | 77 | def seg2bmap(seg, width=None, height=None): 78 | """ 79 | From a segmentation, compute a binary boundary map with 1 pixel wide 80 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 81 | origin from the actual segment boundary. 82 | Arguments: 83 | seg : Segments labeled from 1..k. 84 | width : Width of desired bmap <= seg.shape[1] 85 | height : Height of desired bmap <= seg.shape[0] 86 | Returns: 87 | bmap (ndarray): Binary boundary map. 88 | David Martin 89 | January 2003 90 | """ 91 | seg = seg.astype(bool) 92 | seg[seg > 0] = 1 93 | 94 | assert np.atleast_3d(seg).shape[2] == 1 95 | 96 | width = seg.shape[1] if width is None else width 97 | height = seg.shape[0] if height is None else height 98 | 99 | h, w = seg.shape[:2] 100 | 101 | ar1 = float(width) / float(height) 102 | ar2 = float(w) / float(h) 103 | 104 | assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \ 105 | "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height) 106 | 107 | e = np.zeros_like(seg) 108 | s = np.zeros_like(seg) 109 | se = np.zeros_like(seg) 110 | 111 | e[:, :-1] = seg[:, 1:] 112 | s[:-1, :] = seg[1:, :] 113 | se[:-1, :-1] = seg[1:, 1:] 114 | 115 | b = seg ^ e | seg ^ s | seg ^ se 116 | b[-1, :] = seg[-1, :] ^ e[-1, :] 117 | b[:, -1] = seg[:, -1] ^ s[:, -1] 118 | b[-1, -1] = 0 119 | 120 | if w == width and h == height: 121 | bmap = b 122 | else: 123 | bmap = np.zeros((height, width)) 124 | for x in range(w): 125 | for y in range(h): 126 | if b[y, x]: 127 | j = int(np.floor(y * height / h)) # scale into the target size ("*", not "+", as in the MATLAB original) 128 | i = int(np.floor(x * width / w)) # and cast to int for array indexing 129 | bmap[j, i] = 1 130 | 131 | return bmap -------------------------------------------------------------------------------- /improved_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | from mpi4py import MPI 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def load_data( 9 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False 10 | ): 11 | """ 12 | For a dataset, create a generator over (images, kwargs) pairs. 13 | 14 | Each image is an NCHW float tensor, and the kwargs dict contains zero or 15 | more keys, each of which maps to a batched Tensor of its own. 16 | The kwargs dict can be used for class labels, in which case the key is "y" 17 | and the values are integer tensors of class labels. 18 | 19 | :param data_dir: a dataset directory. 20 | :param batch_size: the batch size of each returned pair. 21 | :param image_size: the size to which images are resized. 22 | :param class_cond: if True, include a "y" key in returned dicts for class 23 | label. If classes are not available and this is true, an 24 | exception will be raised. 25 | :param deterministic: if True, yield results in a deterministic order. 26 | """ 27 | if not data_dir: 28 | raise ValueError("unspecified data directory") 29 | all_files = _list_image_files_recursively(data_dir) 30 | classes = None 31 | if class_cond: 32 | # Assume classes are the first part of the filename, 33 | # before an underscore. 34 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 35 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 36 | classes = [sorted_classes[x] for x in class_names] 37 | dataset = ImageDataset( 38 | image_size, 39 | all_files, 40 | classes=classes, 41 | shard=MPI.COMM_WORLD.Get_rank(), 42 | num_shards=MPI.COMM_WORLD.Get_size(), 43 | ) 44 | if deterministic: 45 | loader = DataLoader( 46 | dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True 47 | ) 48 | else: 49 | loader = DataLoader( 50 | dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True 51 | ) 52 | while True: 53 | yield from loader 54 | 55 | 56 | def _list_image_files_recursively(data_dir): 57 | results = [] 58 | for entry in sorted(bf.listdir(data_dir)): 59 | full_path = bf.join(data_dir, entry) 60 | ext = entry.split(".")[-1] 61 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 62 | results.append(full_path) 63 | elif bf.isdir(full_path): 64 | results.extend(_list_image_files_recursively(full_path)) 65 | return results 66 | 67 | 68 | class ImageDataset(Dataset): 69 | def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1): 70 | super().__init__() 71 | self.resolution = resolution 72 | self.local_images = image_paths[shard:][::num_shards] 73 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 74 | 75 | def __len__(self): 76 | return len(self.local_images) 77 | 78 | def __getitem__(self, idx): 79 | path = self.local_images[idx] 80 | with bf.BlobFile(path, "rb") as f: 81 | pil_image = Image.open(f) 82 | pil_image.load() 83 | 84 | # We are not on a new enough PIL to support the `reducing_gap` 85 | # argument, which uses BOX downsampling at powers of two first. 86 | # Thus, we do it by hand to improve downsample quality. 87 | while min(*pil_image.size) >= 2 * self.resolution: 88 | pil_image = pil_image.resize( 89 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 90 | ) 91 | 92 | scale = self.resolution / min(*pil_image.size) 93 | pil_image = pil_image.resize( 94 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 95 | ) 96 | 97 | arr = np.array(pil_image.convert("RGB")) 98 | crop_y = (arr.shape[0] - self.resolution) // 2 99 | crop_x = (arr.shape[1] - self.resolution) // 2 100 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 101 | arr = arr.astype(np.float32) / 127.5 - 1 102 | 103 | out_dict = {} 104 | if self.local_classes is not None: 105 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 106 | return np.transpose(arr, [2, 0, 1]), out_dict 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official repository of the paper [SegDiff: Image Segmentation with Diffusion Probabilistic Models](https://arxiv.org/abs/2112.00390) 2 | 3 | The code is based on [Improved Denoising Diffusion Probabilistic Models](https://github.com/openai/improved-diffusion). 4 | 5 | ## Installation 6 | ### Conda environment 7 | To create the environment, use the following conda command: 8 | ``` 9 | conda env create -f environment.yml 10 | ``` 11 | 12 | ## Project structure and data preparations 13 | The project needs to be arranged in the following format: 14 | 15 | ``` 16 | segdiff/ # git clone the source code here 17 | 18 | data/ # the root of the data folders 19 | Vaihingen/ 20 | Medical/MoNuSeg/ 21 | cityscapes_instances/ 22 | ``` 23 | 24 | ### Vaihingen 25 | 26 | Download the dataset from this [link](https://drive.google.com/file/d/1nenpWH4BdplSiHdfXs0oYfiA5qL42plB/view), 27 | unzip its content (a folder named `buildings`), and run the preprocessing script: 28 | ``` 29 | python datasets/preprocess_vaihingen.py --path buildings-folder-path 30 | ``` 31 | 32 | The Vaihingen dataset should then have the following format: 33 | ``` 34 | Vaihingen/ 35 | full_test_vaih.hdf5 36 | full_training_vaih.hdf5 37 | ``` 38 | 39 | ### MonuSeg 40 | The challenge's general [website](https://monuseg.grand-challenge.org/). 41 | Download the 42 | [train](https://drive.google.com/file/d/1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA/view?usp=sharing) 43 | and [test](https://drive.google.com/file/d/1NKkSQ5T0ZNQ8aUhh0a8Dt2YKYCQXIViw/view?usp=sharing) sets. 
44 | 45 | Launch the MATLAB [code](https://drive.google.com/file/d/1YDtIiLZX0lQzZp_JbqneHXHvRo45ZWGX/view) 46 | to preprocess the data. 47 | 48 | The MonuSeg dataset should have the following format: 49 | ``` 50 | MonuSeg/ 51 | Test/ 52 | img/ 53 | XX.tif 54 | mask/ 55 | XX.png 56 | Training/ 57 | img/ 58 | XX.tif 59 | mask/ 60 | XX.png 61 | ``` 62 | 63 | ### Cityscapes 64 | 65 | Download the [cityscapes](https://www.cityscapes-dataset.com) dataset with the splits from 66 | [PolyRNN++](https://github.com/fidler-lab/polyrnn-pp), and follow the instructions [here](https://github.com/shirgur/ACDRNet) to prepare it. 67 | 68 | To get the cityscapes_final_v5 annotations, you can sign up for the PolygonRNN++ code at http://www.cs.toronto.edu/polyrnn/code_signup/; the cityscapes_final_v5 folder is inside its data folder. 69 | 70 | The Cityscapes dataset should have the following format: 71 | ``` 72 | cityscapes_instances/ 73 | full/ 74 | all_classes_instances.json 75 | train/ 76 | all_classes_instances.json 77 | train_val/ 78 | all_classes_instances.json 79 | val/ 80 | all_classes_instances.json 81 | all_images.hdf5 82 | ``` 83 | 84 | 85 | ## Train and Evaluate 86 | Execute the following commands (multi-GPU training is supported: select the GPUs with CUDA_VISIBLE_DEVICES and pass their actual number to -n). 87 | 88 | Training options: 89 | ``` 90 | # Training 91 | --batch_size Batch size 92 | --lr Learning rate 93 | 94 | # Architecture 95 | --rrdb_blocks Number of RRDB blocks 96 | --dropout Dropout 97 | --diffusion_steps Number of steps for the diffusion model 98 | 99 | # Cityscapes 100 | --class_name Name of the cityscapes class, options are ["bike", "bus", "person", "train", "motorcycle", "car", "rider"] 101 | --expansion Boolean flag for the expansion setting 102 | 103 | # Misc 104 | --save_interval Interval for saving model weights 105 | ``` 106 | 107 | ### MonuSeg 108 | Training script example: 109 | ``` 110 | CUDA_VISIBLE_DEVICES=0,1,2,3 mpiexec -n 4 python image_train_diff_medical.py --rrdb_blocks 12 --batch_size 2 --lr 0.0001 --diffusion_steps 100 111 | ``` 112 | 113 | Evaluation script example: 114 | ``` 115 | CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_medical.py --model_path path-for-model-weights 116 | ``` 117 | 118 | ### Cityscapes 119 | Training script example: 120 | ``` 121 | CUDA_VISIBLE_DEVICES=0,1 mpiexec -n 2 python image_train_diff_city.py --class_name "train" --expansion True --rrdb_blocks 15 --lr 0.0001 --batch_size 15 --diffusion_steps 100 122 | ``` 123 | 124 | Evaluation script example: 125 | ``` 126 | CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_city.py --model_path path-for-model-weights 127 | ``` 128 | 129 | ### Vaihingen 130 | Training script example: 131 | ``` 132 | CUDA_VISIBLE_DEVICES=0,1 mpiexec -n 2 python image_train_diff_vaih.py --lr 0.0001 --batch_size 4 --dropout 0.1 --rrdb_blocks 6 --diffusion_steps 100 133 | ``` 134 | 135 | Evaluation script example: 136 | ``` 137 | CUDA_VISIBLE_DEVICES=0 mpiexec -n 1 python image_sample_diff_vaih.py --model_path path-for-model-weights 138 | ``` 139 | 140 | ## Citation 141 | ``` 142 | @article{amit2021segdiff, 143 | title={Segdiff: Image segmentation with diffusion probabilistic models}, 144 | author={Amit, Tomer and Nachmani, Eliya and Shaharbany, Tal and Wolf, Lior}, 145 | journal={arXiv preprint arXiv:2112.00390}, 146 | year={2021} 147 | } 148 | ``` 149 | 150 | -------------------------------------------------------------------------------- /image_train_diff_medical.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | import datetime 7 | import json 8 | import os 9 | from pathlib import Path 10 | 11 | import git 12 | from mpi4py import MPI 13 | 14 | from improved_diffusion import dist_util, logger 15 | from datasets.monu import load_data, create_dataset 16 | from improved_diffusion.resample import create_named_schedule_sampler 17 | from improved_diffusion.script_util import ( 18 | model_and_diffusion_defaults, 19 | create_model_and_diffusion, 20 | args_to_dict, 21 | add_dict_to_argparser, 22 | ) 23 | from improved_diffusion.train_util import TrainLoop 24 | from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations 25 | import warnings 26 | warnings.filterwarnings('ignore') 27 | 28 | 29 | def main(): 30 | args = create_argparser().parse_args() 31 | args.use_fp16 = True 32 | args.clip_denoised = False 33 | args.learn_sigma = False 34 | args.sigma_small = False 35 | args.image_size = 256 36 | args.num_res_blocks = 3 37 | args.noise_schedule = "linear" 38 | args.rescale_learned_sigmas = False 39 | args.rescale_timesteps = False 40 | args.use_scale_shift_norm = False 41 | args.deeper_net = True 42 | # args.start_print_iter = 4 43 | # args.save_interval = 4 44 | 45 | exp_name = f"monu_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{MPI.COMM_WORLD.Get_rank()}" 46 | logs_root = Path(__file__).absolute().parent.parent / "logs" 47 | log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}" 48 | os.environ["OPENAI_LOGDIR"] = str(log_path) 49 | set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True) 50 | set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank()) 51 | dist_util.setup_dist() 52 | logger.configure(dir=str(log_path)) 53 | 54 | if args.resume_checkpoint: 55 | resumed_checkpoint_arg = args.resume_checkpoint 56 | args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text())) 57 | args.resume_checkpoint = resumed_checkpoint_arg 58 | 59 | logger.info(args.__dict__) 60 | 61 | (Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4)) 62 | logger.info(f"log folder path: {Path(log_path).resolve()}") 63 | 64 | repo = git.Repo(search_parent_directories=True) 65 | sha = repo.head.object.hexsha 66 | 67 | logger.log(f"git commit hash {sha}") 68 | 69 | logger.log("creating model and diffusion...") 70 | model, diffusion = create_model_and_diffusion( 71 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 72 | ) 73 | model.to(dist_util.dev()) 74 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 75 | 76 | logger.log("creating data loader...") 77 | data = load_data( 78 | data_dir=args.data_dir, 79 | batch_size=args.batch_size, 80 | image_size=args.image_size, 81 | class_cond=args.class_cond, 82 | class_name=args.class_name, 83 | expansion=args.expansion 84 | ) 85 | val_dataset = create_dataset( 86 | mode='val', 87 | image_size=args.image_size 88 | ) 89 | 90 | logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}") 91 | 92 | logger.log("training...") 93 | TrainLoop( 94 | model=model, 95 | diffusion=diffusion, 96 | data=data, 97 | batch_size=args.batch_size, 98 | microbatch=args.microbatch, 99 | lr=args.lr, 100 | ema_rate=args.ema_rate, 101 | log_interval=args.log_interval, 102 | 
save_interval=args.save_interval, 103 | resume_checkpoint=args.resume_checkpoint, 104 | use_fp16=args.use_fp16, 105 | fp16_scale_growth=args.fp16_scale_growth, 106 | schedule_sampler=schedule_sampler, 107 | weight_decay=args.weight_decay, 108 | lr_anneal_steps=args.lr_anneal_steps, 109 | clip_denoised=args.clip_denoised, 110 | logger=logger, 111 | image_size=args.image_size, 112 | val_dataset=val_dataset, 113 | run_without_test=args.run_without_test, 114 | args=args 115 | # dist_util=dist_util, 116 | ).run_loop(max_iter=300000, start_print_iter=args.start_print_iter) 117 | 118 | 119 | def create_argparser(): 120 | defaults = dict( 121 | data_dir="", 122 | schedule_sampler="uniform", 123 | lr=0.00002, 124 | weight_decay=0.0, 125 | lr_anneal_steps=0, 126 | clip_denoised=False, 127 | batch_size=4, 128 | microbatch=-1, # -1 disables microbatches 129 | ema_rate="0.9999", # comma-separated list of EMA values 130 | save_interval=5000, 131 | start_print_iter=75000, 132 | log_interval=200, 133 | run_without_test=False, 134 | resume_checkpoint="", 135 | use_fp16=False, 136 | fp16_scale_growth=1e-3, 137 | ) 138 | defaults.update(model_and_diffusion_defaults()) 139 | parser = argparse.ArgumentParser() 140 | add_dict_to_argparser(parser, defaults) 141 | return parser 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /image_train_diff_vaih.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | import datetime 7 | import json 8 | import os 9 | from pathlib import Path 10 | 11 | import git 12 | from mpi4py import MPI 13 | 14 | from improved_diffusion import dist_util, logger 15 | from datasets.vaih import load_data 16 | from improved_diffusion.resample import create_named_schedule_sampler 17 | from improved_diffusion.script_util import ( 18 | model_and_diffusion_defaults, 19 | create_model_and_diffusion, 20 | args_to_dict, 21 | add_dict_to_argparser, 22 | ) 23 | from improved_diffusion.train_util import TrainLoop 24 | from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations 25 | from datasets.vaih import VaihDataset 26 | import warnings 27 | warnings.filterwarnings('ignore') 28 | 29 | 30 | def main(): 31 | args = create_argparser().parse_args() 32 | args.use_fp16 = True 33 | args.clip_denoised = False 34 | args.learn_sigma = False 35 | args.sigma_small = False 36 | args.num_channels = 128 37 | args.image_size = 256 38 | args.num_res_blocks = 3 39 | args.noise_schedule = "linear" 40 | args.rescale_learned_sigmas = False 41 | args.rescale_timesteps = False 42 | args.use_scale_shift_norm = False 43 | args.deeper_net = True 44 | 45 | exp_name = f"vaih_256_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{MPI.COMM_WORLD.Get_rank()}" 46 | 47 | logs_root = Path(__file__).absolute().parent.parent / "logs" 48 | log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}" 49 | os.environ["OPENAI_LOGDIR"] = str(log_path) 50 | set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True) 51 | set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank()) 52 | dist_util.setup_dist() 53 | logger.configure(dir=str(log_path)) 54 | 55 | if args.resume_checkpoint: 56 | resumed_checkpoint_arg = args.resume_checkpoint 57 | args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text())) 58 | 
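# Restore the checkpoint path given on the command line; the saved
# args.json loaded above would otherwise overwrite it.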
args.resume_checkpoint = resumed_checkpoint_arg 59 | 60 | logger.info(args.__dict__) 61 | 62 | (Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4)) 63 | logger.info(f"log folder path: {Path(log_path).resolve()}") 64 | 65 | repo = git.Repo(search_parent_directories=True) 66 | sha = repo.head.object.hexsha 67 | 68 | logger.log(f"git commit hash {sha}") 69 | 70 | logger.log("creating model and diffusion...") 71 | model, diffusion = create_model_and_diffusion( 72 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 73 | ) 74 | model.to(dist_util.dev()) 75 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 76 | 77 | logger.log("creating data loader...") 78 | data = load_data( 79 | data_dir=args.data_dir, 80 | batch_size=args.batch_size, 81 | image_size=args.image_size, 82 | class_cond=args.class_cond 83 | ) 84 | val_dataset = VaihDataset( 85 | mode='val', 86 | image_size=args.image_size, 87 | shard=MPI.COMM_WORLD.Get_rank(), 88 | num_shards=MPI.COMM_WORLD.Get_size(), 89 | ) 90 | 91 | logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}") 92 | 93 | logger.log("training...") 94 | TrainLoop( 95 | model=model, 96 | diffusion=diffusion, 97 | data=data, 98 | batch_size=args.batch_size, 99 | microbatch=args.microbatch, 100 | lr=args.lr, 101 | ema_rate=args.ema_rate, 102 | log_interval=args.log_interval, 103 | save_interval=args.save_interval, 104 | resume_checkpoint=args.resume_checkpoint, 105 | use_fp16=args.use_fp16, 106 | fp16_scale_growth=args.fp16_scale_growth, 107 | schedule_sampler=schedule_sampler, 108 | weight_decay=args.weight_decay, 109 | lr_anneal_steps=args.lr_anneal_steps, 110 | clip_denoised=args.clip_denoised, 111 | logger=logger, 112 | image_size=args.image_size, 113 | val_dataset=val_dataset, 114 | run_without_test=args.run_without_test, 115 | args=args 116 | # dist_util=dist_util, 117 | ).run_loop(max_iter=300000, start_print_iter=args.start_print_iter) 118 | 119 | 120 | def create_argparser(): 121 | defaults = dict( 122 | data_dir="", 123 | schedule_sampler="uniform", 124 | lr=0.00002, 125 | weight_decay=0.0, 126 | lr_anneal_steps=0, 127 | clip_denoised=False, 128 | batch_size=4, 129 | microbatch=-1, # -1 disables microbatches 130 | ema_rate="0.9999", # comma-separated list of EMA values 131 | save_interval=5000, 132 | start_print_iter=75000, 133 | log_interval=200, 134 | run_without_test=False, 135 | resume_checkpoint="", 136 | use_fp16=False, 137 | fp16_scale_growth=1e-3, 138 | ) 139 | defaults.update(model_and_diffusion_defaults()) 140 | parser = argparse.ArgumentParser() 141 | add_dict_to_argparser(parser, defaults) 142 | return parser 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /image_train_diff_city.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
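Example invocation (from the README):
    CUDA_VISIBLE_DEVICES=0,1 mpiexec -n 2 python image_train_diff_city.py --class_name "train" --expansion True --rrdb_blocks 15 --lr 0.0001 --batch_size 15 --diffusion_steps 100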
3 | """ 4 | 5 | import argparse 6 | import datetime 7 | import json 8 | import os 9 | from pathlib import Path 10 | 11 | import git 12 | from mpi4py import MPI 13 | 14 | from improved_diffusion import dist_util, logger 15 | from datasets.city import load_data, create_dataset 16 | from improved_diffusion.resample import create_named_schedule_sampler 17 | from improved_diffusion.script_util import ( 18 | model_and_diffusion_defaults, 19 | create_model_and_diffusion, 20 | args_to_dict, 21 | add_dict_to_argparser, 22 | ) 23 | from improved_diffusion.train_util import TrainLoop 24 | from improved_diffusion.utils import set_random_seed, set_random_seed_for_iterations 25 | import warnings 26 | warnings.filterwarnings('ignore') 27 | 28 | 29 | def main(): 30 | args = create_argparser().parse_args() 31 | args.use_fp16 = True 32 | args.clip_denoised = False 33 | args.learn_sigma = False 34 | args.sigma_small = False 35 | args.num_channels = 128 36 | args.image_size = 128 37 | args.num_res_blocks = 3 38 | args.noise_schedule = "linear" 39 | args.rescale_learned_sigmas = False 40 | args.rescale_timesteps = False 41 | args.use_scale_shift_norm = False 42 | args.deeper_net = True 43 | 44 | exp_name = f"city_{args.rrdb_blocks}_{args.lr}_{args.batch_size}_{args.diffusion_steps}_{str(args.dropout)}_{args.class_name}_{MPI.COMM_WORLD.Get_rank()}" 45 | if args.expansion: 46 | exp_name += "_expansion" 47 | logs_root = Path(__file__).absolute().parent.parent / "logs" 48 | log_path = logs_root / f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')}_{exp_name}" 49 | os.environ["OPENAI_LOGDIR"] = str(log_path) 50 | set_random_seed(MPI.COMM_WORLD.Get_rank(), deterministic=True) 51 | set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank()) 52 | dist_util.setup_dist() 53 | logger.configure(dir=str(log_path)) 54 | 55 | if args.resume_checkpoint: 56 | resumed_checkpoint_arg = args.resume_checkpoint 57 | args.__dict__.update(json.loads((Path(args.resume_checkpoint) / 'args.json').read_text())) 58 | args.resume_checkpoint = resumed_checkpoint_arg 59 | 60 | logger.info(args.__dict__) 61 | 62 | (Path(log_path) / 'args.json').write_text(json.dumps(args.__dict__, indent=4)) 63 | logger.info(f"log folder path: {Path(log_path).resolve()}") 64 | 65 | repo = git.Repo(search_parent_directories=True) 66 | sha = repo.head.object.hexsha 67 | 68 | logger.log(f"git commit hash {sha}") 69 | 70 | logger.log("creating model and diffusion...") 71 | model, diffusion = create_model_and_diffusion( 72 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 73 | ) 74 | model.to(dist_util.dev()) 75 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 76 | 77 | logger.log("creating data loader...") 78 | data = load_data( 79 | data_dir=args.data_dir, 80 | batch_size=args.batch_size, 81 | image_size=args.image_size, 82 | class_cond=args.class_cond, 83 | class_name=args.class_name, 84 | expansion=args.expansion 85 | ) 86 | val_dataset = create_dataset( 87 | class_name=args.class_name, 88 | mode='val', 89 | expansion=args.expansion, 90 | ) 91 | 92 | logger.log(f"gpu {MPI.COMM_WORLD.Get_rank()} / {MPI.COMM_WORLD.Get_size()} val length {len(val_dataset)}") 93 | 94 | logger.log("training...") 95 | TrainLoop( 96 | model=model, 97 | diffusion=diffusion, 98 | data=data, 99 | batch_size=args.batch_size, 100 | microbatch=args.microbatch, 101 | lr=args.lr, 102 | ema_rate=args.ema_rate, 103 | log_interval=args.log_interval, 104 | save_interval=args.save_interval, 105 | resume_checkpoint=args.resume_checkpoint, 106 
| use_fp16=args.use_fp16, 107 | fp16_scale_growth=args.fp16_scale_growth, 108 | schedule_sampler=schedule_sampler, 109 | weight_decay=args.weight_decay, 110 | lr_anneal_steps=args.lr_anneal_steps, 111 | clip_denoised=args.clip_denoised, 112 | logger=logger, 113 | image_size=args.image_size, 114 | val_dataset=val_dataset, 115 | run_without_test=args.run_without_test, 116 | args=args 117 | # dist_util=dist_util, 118 | ).run_loop(max_iter=300000, start_print_iter=args.start_print_iter) 119 | 120 | 121 | def create_argparser(): 122 | defaults = dict( 123 | data_dir="", 124 | schedule_sampler="uniform", 125 | lr=0.00002, 126 | weight_decay=0.0, 127 | lr_anneal_steps=0, 128 | clip_denoised=False, 129 | batch_size=4, 130 | microbatch=-1, # -1 disables microbatches 131 | ema_rate="0.9999", # comma-separated list of EMA values 132 | save_interval=5000, 133 | start_print_iter=75000, 134 | log_interval=200, 135 | run_without_test=False, 136 | resume_checkpoint="", 137 | use_fp16=False, 138 | fp16_scale_growth=1e-3, 139 | ) 140 | defaults.update(model_and_diffusion_defaults()) 141 | parser = argparse.ArgumentParser() 142 | add_dict_to_argparser(parser, defaults) 143 | return parser 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /improved_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there are 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
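For example, space_timesteps(300, [10, 15, 20]) returns 10 + 15 + 20 = 45 steps: 10 evenly strided over the first 100 original steps, 15 over the second 100, and 20 over the final 100.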
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel( 102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 103 | ) 104 | 105 | def _scale_timesteps(self, t): 106 | # Scaling is done by the wrapped model. 
107 | return t 108 | 109 | 110 | class _WrappedModel: 111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 112 | self.model = model 113 | self.timestep_map = timestep_map 114 | self.rescale_timesteps = rescale_timesteps 115 | self.original_num_steps = original_num_steps 116 | 117 | def __call__(self, x, ts, **kwargs): 118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 119 | new_ts = map_tensor[ts] 120 | if self.rescale_timesteps: 121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 122 | return self.model(x, new_ts, **kwargs) 123 | -------------------------------------------------------------------------------- /improved_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def swap_ema(target_params, source_params): 69 | """ 70 | Swap the values of the two parameter sequences in place (e.g., to evaluate 71 | with EMA weights and later restore the raw training weights). 72 | 73 | :param target_params: the target parameter sequence. 74 | :param source_params: the source parameter sequence. 75 | """ 76 | for targ, src in zip(target_params, source_params): 77 | temp = targ.data.clone() 78 | targ.data.copy_(src.data) 79 | src.data.copy_(temp) 80 | 81 | 82 | def zero_module(module): 83 | """ 84 | Zero out the parameters of a module and return it. 85 | """ 86 | for p in module.parameters(): 87 | p.detach().zero_() 88 | return module 89 | 90 | 91 | def scale_module(module, scale): 92 | """ 93 | Scale the parameters of a module and return it. 94 | """ 95 | for p in module.parameters(): 96 | p.detach().mul_(scale) 97 | return module 98 | 99 | 100 | def mean_flat(tensor): 101 | """ 102 | Take the mean over all non-batch dimensions.
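For example, an [N x C x H x W] tensor is reduced to a length-N vector of per-sample means.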
103 | """ 104 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 105 | 106 | 107 | def normalization(channels): 108 | """ 109 | Make a standard normalization layer. 110 | 111 | :param channels: number of input channels. 112 | :return: an nn.Module for normalization. 113 | """ 114 | return GroupNorm32(32, channels) 115 | 116 | 117 | def timestep_embedding(timesteps, dim, max_period=10000): 118 | """ 119 | Create sinusoidal timestep embeddings. 120 | 121 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 122 | These may be fractional. 123 | :param dim: the dimension of the output. 124 | :param max_period: controls the minimum frequency of the embeddings. 125 | :return: an [N x dim] Tensor of positional embeddings. 126 | """ 127 | half = dim // 2 128 | freqs = th.exp( 129 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 130 | ).to(device=timesteps.device) 131 | args = timesteps[:, None].float() * freqs[None] 132 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 133 | if dim % 2: 134 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 135 | return embedding 136 | 137 | 138 | def checkpoint(func, inputs, params, flag): 139 | """ 140 | Evaluate a function without caching intermediate activations, allowing for 141 | reduced memory at the expense of extra compute in the backward pass. 142 | 143 | :param func: the function to evaluate. 144 | :param inputs: the argument sequence to pass to `func`. 145 | :param params: a sequence of parameters `func` depends on but does not 146 | explicitly take as arguments. 147 | :param flag: if False, disable gradient checkpointing. 148 | """ 149 | if flag: 150 | args = tuple(inputs) + tuple(params) 151 | return CheckpointFunction.apply(func, len(inputs), *args) 152 | else: 153 | return func(*inputs) 154 | 155 | 156 | class CheckpointFunction(th.autograd.Function): 157 | @staticmethod 158 | def forward(ctx, run_function, length, *args): 159 | ctx.run_function = run_function 160 | ctx.input_tensors = list(args[:length]) 161 | ctx.input_params = list(args[length:]) 162 | with th.no_grad(): 163 | output_tensors = ctx.run_function(*ctx.input_tensors) 164 | return output_tensors 165 | 166 | @staticmethod 167 | def backward(ctx, *output_grads): 168 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 169 | with th.enable_grad(): 170 | # Fixes a bug where the first op in run_function modifies the 171 | # Tensor storage in place, which is not allowed for detach()'d 172 | # Tensors. 
173 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 174 | output_tensors = ctx.run_function(*shallow_copies) 175 | input_grads = th.autograd.grad( 176 | output_tensors, 177 | ctx.input_tensors + ctx.input_params, 178 | output_grads, 179 | allow_unused=True, 180 | ) 181 | del ctx.input_tensors 182 | del ctx.input_params 183 | del output_tensors 184 | return (None, None) + input_grads 185 | -------------------------------------------------------------------------------- /datasets/vaih.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from matplotlib import pyplot as plt 8 | from mpi4py import MPI 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from datasets.transforms import \ 12 | Compose, ToPILImage, Resize, RandomHorizontalFlip, ToTensor, Normalize, \ 13 | RandomAffine, RandomVerticalFlip, ColorJitter 14 | 15 | 16 | def load_data( 17 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False 18 | ): 19 | """ 20 | For a dataset, create a generator over (images, kwargs) pairs. 21 | 22 | Each image is an NCHW float tensor, and the kwargs dict contains zero or 23 | more keys, each of which maps to a batched Tensor of its own. 24 | The kwargs dict can be used for class labels, in which case the key is "y" 25 | and the values are integer tensors of class labels. 26 | 27 | :param data_dir: a dataset directory. 28 | :param batch_size: the batch size of each returned pair. 29 | :param image_size: the size to which images are resized. 30 | :param class_cond: if True, include a "y" key in returned dicts for class 31 | label. If classes are not available and this is true, an 32 | exception will be raised. 33 | :param deterministic: if True, yield results in a deterministic order.
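Note: this Vaihingen loader ignores data_dir and class_cond; it always builds the HDF5 training split via VaihDataset, sharded across MPI ranks.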
34 | """ 35 | 36 | dataset = VaihDataset( 37 | mode='train', 38 | image_size=image_size, 39 | shard=MPI.COMM_WORLD.Get_rank(), 40 | num_shards=MPI.COMM_WORLD.Get_size(), 41 | ) 42 | if deterministic: 43 | loader = DataLoader( 44 | dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True 45 | ) 46 | else: 47 | loader = DataLoader( 48 | dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True 49 | ) 50 | while True: 51 | yield from loader 52 | 53 | 54 | class VaihDataset(Dataset): 55 | 56 | CLASSES = ('building',) 57 | 58 | PALETTE = [[255, 0, 0]] 59 | 60 | def __init__(self, mode, std=np.array([0.22645572 * 255, 0.15276193 * 255, 0.140702 * 255]), 61 | mean=np.array([0.47341759 * 255, 0.28791303 * 255, 0.2850705 * 255]), no_aug=False, 62 | image_size=256, max_data_size=None, shard=0, num_shards=1, small_image_size=None): 63 | 64 | self.mode = mode 65 | self.mean = torch.from_numpy(mean) 66 | self.std = torch.from_numpy(std) 67 | 68 | if mode == 'train' and not no_aug: 69 | self.transformations = Compose([ToPILImage(), 70 | Resize(size=(image_size, image_size)), 71 | RandomAffine(degrees=[0, 360], scale=(0.75, 1.5)), 72 | ColorJitter(brightness=0.6, 73 | contrast=0.5, 74 | saturation=0.4, 75 | hue=0.025), 76 | RandomVerticalFlip(), 77 | RandomHorizontalFlip(), 78 | ToTensor(), 79 | Normalize(self.mean, self.std)]) 80 | else: 81 | self.transformations = Compose([ToPILImage(), 82 | Resize(size=(image_size, image_size)), 83 | ToTensor(), 84 | Normalize(self.mean, self.std)]) 85 | if mode == 'train': 86 | self.data_length = 100 87 | else: 88 | self.data_length = 68 89 | 90 | if max_data_size is not None: 91 | self.data_length = max_data_size 92 | 93 | if self.mode == 'train': 94 | self.data = h5py.File( 95 | str(Path(__file__).absolute().parent.parent.parent / "data/Vaihingen/full_training_vaih.hdf5"), 'r') 96 | 97 | else: 98 | self.data = h5py.File( 99 | str(Path(__file__).absolute().parent.parent.parent / "data/Vaihingen/full_test_vaih.hdf5"), 'r') 100 | 101 | self.small_image_size = small_image_size 102 | self.mask = self.data['mask_single'] 103 | self.imgs = self.data['imgs'] 104 | self.img_list = list(self.imgs)[shard::num_shards] 105 | self.mask_list = list(self.mask)[shard::num_shards] 106 | 107 | def __len__(self): 108 | return len(self.img_list) 109 | 110 | def __getitem__(self, item): 111 | cimage = self.img_list[item] 112 | img = np.array(self.imgs.get(cimage)) 113 | cmask = self.mask_list[item] 114 | mask = np.array(self.mask.get(cmask)) 115 | img = img.astype(np.uint8) 116 | mask = mask.astype(np.uint8) 117 | img, mask = self.transformations(img, mask) 118 | out_dict = {"conditioned_image": img} 119 | mask = (2 * mask - 1.0).unsqueeze(0) 120 | if self.small_image_size is not None: 121 | out_dict["low_res"] = F.interpolate(mask.unsqueeze(0), self.small_image_size, mode="nearest").squeeze(0) 122 | return mask, out_dict, str(Path(cimage).stem) 123 | 124 | 125 | if __name__ == '__main__': 126 | mean = np.array([0, 0, 0]) 127 | std = np.array([1, 1, 1]) 128 | dataset = VaihDataset('train', mean=mean, std=std, image_size=256) 129 | dataset2 = VaihDataset('train', mean=mean, std=std, image_size=256, no_aug=True) 130 | for i in range(10): 131 | mask, out_dict, _ = dataset[0] 132 | img = out_dict["conditioned_image"] 133 | plt.imshow(img.permute(1,2,0).numpy().astype(np.uint8)) 134 | plt.show() 135 | 136 | plt.imshow(mask.permute(1,2,0).numpy(), cmap='gray') 137 | plt.show() 138 | 139 | mask, out_dict, _ = dataset2[0] 140 | img = 
out_dict["conditioned_image"] 141 | plt.imshow(img.permute(1,2,0).numpy().astype(np.uint8)) 142 | plt.show() -------------------------------------------------------------------------------- /improved_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 
93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term.
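# The per-timestep history is a fixed-size FIFO: drop the oldest loss and append the newest, so weights() always reflects the last history_per_term losses.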
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /datasets/monu.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import imageio 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import tifffile 8 | import torch 9 | from mpi4py import MPI 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from datasets.transforms import \ 14 | Compose, ToPILImage, ColorJitter, RandomHorizontalFlip, ToTensor, Normalize, RandomVerticalFlip, RandomAffine, \ 15 | Resize, RandomCrop 16 | 17 | 18 | def cv2_loader(path, is_mask): 19 | if is_mask: 20 | # img = cv2.imread(path, 0) 21 | img = imageio.imread(path) 22 | img[img > 0] = 1 23 | else: 24 | # img = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) 25 | # img = imageio.imread(path) 26 | img = tifffile.imread(path) 27 | return img 28 | 29 | 30 | def get_monu_transform(image_size): 31 | 32 | transform_train = Compose([ 33 | ToPILImage(), 34 | Resize((512, 512)), 35 | RandomCrop((image_size, image_size)), 36 | RandomHorizontalFlip(), 37 | RandomVerticalFlip(), 38 | RandomAffine(int(22), scale=(float(0.75), float(1.25))), 39 | ColorJitter(brightness=0.4, 40 | contrast=0.4, 41 | saturation=0.4, 42 | hue=0.1), 43 | ToTensor(), 44 | Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78]) 45 | ]) 46 | transform_test = Compose([ 47 | ToPILImage(), 48 | Resize((512, 512)), 49 | ToTensor(), 50 | Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78]) 51 | ]) 52 | return transform_train, transform_test 53 | 54 | 55 | def create_dataset(mode="train", image_size=256): 56 | datadir = str(Path(__file__).absolute().parent.parent.parent / "data/Medical/MoNuSeg") 57 | 58 | transform_train, transform_test = get_monu_transform(image_size) 59 | if mode == "train": 60 | return MonuDataset(datadir, train=True, transform=transform_train, image_size=image_size) 61 | else: 62 | return MonuDataset(datadir, train=False, transform=transform_test) 63 | 64 | 65 | def load_data( 66 | *, data_dir, batch_size, image_size, class_name, class_cond=False, expansion, deterministic=False 67 | ): 68 | """ 69 | For a dataset, create a generator over (images, kwargs) pairs. 70 | 71 | Each image is an NCHW float tensor, and the kwargs dict contains zero or 72 | more keys, each of which maps to a batched Tensor of its own. 73 | The kwargs dict can be used for class labels, in which case the key is "y" 74 | and the values are integer tensors of class labels. 75 | 76 | :param data_dir: a dataset directory. 77 | :param batch_size: the batch size of each returned pair. 78 | :param image_size: the size to which images are resized. 79 | :param class_cond: if True, include a "y" key in returned dicts for class 80 | label. If classes are not available and this is true, an 81 | exception will be raised. 82 | :param deterministic: if True, yield results in a deterministic order.
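:param class_name: unused by this MoNuSeg loader; kept for signature compatibility with the Cityscapes loader. :param expansion: unused here as well; the loader always builds the training split at the default image size.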
83 | """ 84 | 85 | dataset = create_dataset(mode="train") 86 | 87 | if deterministic: 88 | loader = DataLoader( 89 | dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True 90 | ) 91 | else: 92 | loader = DataLoader( 93 | dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True 94 | ) 95 | while True: 96 | yield from loader 97 | 98 | 99 | class MonuDataset(torch.utils.data.Dataset): 100 | def __init__(self, root, transform=None, target_transform=None, train=False, loader=cv2_loader, pSize=8, image_size=256): 101 | self.root = root 102 | if train: 103 | self.imgs_root = os.path.join(self.root, 'Training', 'img') 104 | self.masks_root = os.path.join(self.root, 'Training', 'mask') 105 | else: 106 | self.imgs_root = os.path.join(self.root, 'Test', 'img') 107 | self.masks_root = os.path.join(self.root, 'Test', 'mask') 108 | self.image_size = image_size 109 | self.paths = sorted(os.listdir(self.imgs_root)) 110 | self.transform = transform 111 | self.target_transform = target_transform 112 | self.loader = loader 113 | self.train = train 114 | self.pSize = pSize 115 | self.masks = [] 116 | self.imgs = [] 117 | self.mean = torch.from_numpy(np.array([142.07, 98.48, 132.96])) 118 | self.std = torch.from_numpy(np.array([65.78, 57.05, 57.78])) 119 | 120 | shard = MPI.COMM_WORLD.Get_rank() 121 | num_shards = MPI.COMM_WORLD.Get_size() 122 | 123 | for file_path in tqdm(self.paths): 124 | mask_path = file_path.split('.')[0] + '.png' 125 | self.imgs.append(self.loader(os.path.join(self.imgs_root, file_path), is_mask=False)) 126 | self.masks.append(self.loader(os.path.join(self.masks_root, mask_path), is_mask=True)) 127 | 128 | self.imgs = self.imgs[shard::num_shards] 129 | self.masks = self.masks[shard::num_shards] 130 | self.paths = self.paths[shard::num_shards] 131 | 132 | print('num of data:{}'.format(len(self.paths))) 133 | 134 | def __getitem__(self, index): 135 | img = self.imgs[index] 136 | mask = self.masks[index] 137 | 138 | img, mask = self.transform(img, mask) 139 | out_dict = {"conditioned_image": img} 140 | mask = 2 * mask - 1.0 141 | return mask.unsqueeze(0), out_dict, f"{Path(self.paths[index]).stem}_{index}" 142 | 143 | def __len__(self): 144 | return len(self.paths) 145 | 146 | 147 | if __name__ == "__main__": 148 | val_dataset = create_dataset( 149 | mode='val', 150 | image_size=256, 151 | ) 152 | 153 | ds = torch.utils.data.DataLoader(val_dataset, 154 | batch_size=1, 155 | num_workers=0, 156 | shuffle=False, 157 | drop_last=True) 158 | pbar = tqdm(ds) 159 | mean0_list = [] 160 | mean1_list = [] 161 | mean2_list = [] 162 | std0_list = [] 163 | std1_list = [] 164 | std2_list = [] 165 | for i, (mask, out_dict, _) in enumerate(pbar): 166 | img = out_dict["conditioned_image"] 167 | plt.imshow(img.squeeze().permute(1,2,0).numpy().astype(np.uint8)) 168 | plt.show() 169 | 170 | plt.imshow(mask.squeeze().numpy(), cmap='gray') 171 | plt.show() 172 | a = img.mean(dim=(0, 2, 3)) 173 | b = img.std(dim=(0, 2, 3)) 174 | mean0_list.append(a[0].item()) 175 | mean1_list.append(a[1].item()) 176 | mean2_list.append(a[2].item()) 177 | std0_list.append(b[0].item()) 178 | std1_list.append(b[1].item()) 179 | std2_list.append(b[2].item()) 180 | print(np.mean(mean0_list)) 181 | print(np.mean(mean1_list)) 182 | print(np.mean(mean2_list)) 183 | 184 | print(np.mean(std0_list)) 185 | print(np.mean(std1_list)) 186 | print(np.mean(std2_list)) 187 | 188 | # a = img.squeeze().permute(1, 2, 0).cpu().numpy() 189 | # b = mask.squeeze().cpu().numpy() 190 | # a = (a - a.min()) / (a.max() 
- a.min()) 191 | # cv2.imwrite('kaki.jpg', 255*a) 192 | # cv2.imwrite('kaki_mask.jpg', 255*b) -------------------------------------------------------------------------------- /datasets/city.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import h5py 7 | import numpy as np 8 | import pycocotools.mask as maskUtils 9 | import torch 10 | from PIL import Image 11 | from matplotlib import pyplot as plt 12 | from mpi4py import MPI 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision.transforms.functional import resize 15 | from tqdm import tqdm 16 | 17 | from datasets.transforms import \ 18 | Compose, ToPILImage, RandomHorizontalFlip, ToTensor, Normalize, RandomAffine 19 | 20 | 21 | def create_dataset(mode="train", class_name="train", expansion=False): 22 | shard = MPI.COMM_WORLD.Get_rank() 23 | num_shards = MPI.COMM_WORLD.Get_size() 24 | data_inst_path = str(Path(__file__).absolute().parent.parent.parent / "data/cityscapes_instances/") 25 | 26 | print('loading "{}" annotations into memory...'.format(mode)) 27 | data = json.load(open(os.path.join(data_inst_path, mode, 'all_classes_instances.json'), 'r')) 28 | 29 | annotations = data['data'][class_name][shard::num_shards] 30 | 31 | hdf5_obj = h5py.File(os.path.join(data_inst_path, 'all_images.hdf5'), 'r') 32 | images = [hdf5_obj[ann['img']['file_name']] for ann in annotations] 33 | 34 | return CityscapesInstances( 35 | images, 36 | annotations, 37 | mode=mode, 38 | expansion=expansion 39 | ) 40 | 41 | 42 | def load_data( 43 | *, data_dir, batch_size, image_size, class_name, class_cond=False, expansion, deterministic=False 44 | ): 45 | """ 46 | For a dataset, create a generator over (images, kwargs) pairs. 47 | 48 | Each image is an NCHW float tensor, and the kwargs dict contains zero or 49 | more keys, each of which maps to a batched Tensor of its own. 50 | The kwargs dict can be used for class labels, in which case the key is "y" 51 | and the values are integer tensors of class labels. 52 | 53 | :param data_dir: a dataset directory. 54 | :param batch_size: the batch size of each returned pair. 55 | :param image_size: the size to which images are resized. 56 | :param class_cond: if True, include a "y" key in returned dicts for class 57 | label. If classes are not available and this is true, an 58 | exception will be raised. 59 | :param deterministic: if True, yield results in a deterministic order.
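:param class_name: which Cityscapes instance class to load (e.g. "train", "car", "person"). :param expansion: if True, crop each instance with an enlarged bounding box (randomly 10-20% during training, fixed 15% otherwise).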
60 | """ 61 | 62 | dataset = create_dataset(mode="train", class_name=class_name, expansion=expansion) 63 | 64 | if deterministic: 65 | loader = DataLoader( 66 | dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True 67 | ) 68 | else: 69 | loader = DataLoader( 70 | dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True 71 | ) 72 | while True: 73 | yield from loader 74 | 75 | 76 | class CityscapesInstances(Dataset): 77 | CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 78 | 'bicycle') 79 | 80 | def __init__(self, 81 | images, 82 | annotations, 83 | no_aug=False, 84 | mode='train', 85 | loops=100, 86 | expansion=False, 87 | std=np.array([58.395, 57.12, 57.375]), 88 | mean=np.array([123.675, 116.28, 103.53]), 89 | ): 90 | super(CityscapesInstances, self).__init__() 91 | 92 | self.loops = loops 93 | self.mode = mode 94 | self.mean = torch.from_numpy(mean) 95 | self.std = torch.from_numpy(std) 96 | self.expansion = expansion 97 | image_size = 128 98 | 99 | if mode == 'train' and not no_aug: 100 | self.transformations = Compose([ 101 | ToPILImage(), 102 | # Resize((image_size, image_size)), 103 | RandomHorizontalFlip(), 104 | RandomAffine(22, scale=(0.75, 1.25)), 105 | ToTensor(), 106 | Normalize(self.mean, self.std) 107 | # transforms.NormalizeInstance() 108 | ]) 109 | else: 110 | self.transformations = Compose([ 111 | ToPILImage(), 112 | # Resize((image_size, image_size), do_mask=False), 113 | ToTensor(), 114 | Normalize(self.mean, self.std), 115 | # transforms.NormalizeInstance() 116 | ]) 117 | 118 | self.instance_images = [] 119 | self.instance_masks = [] 120 | 121 | self.annotations = annotations 122 | 123 | for item in tqdm(range(len(images))): 124 | ann = self.annotations[item] 125 | mask = self._poly2mask(ann['segmentation'], ann['img']['height'], ann['img']['width']) 126 | bbox = np.maximum(0, np.array(ann['bbox']).astype(np.int32)) 127 | 128 | if self.expansion: 129 | if self.mode == 'train': 130 | bounding_box_expansion = random.randint(10, 20) 131 | else: 132 | bounding_box_expansion = 15 133 | 134 | increase_axis_by = bbox[3] * (bounding_box_expansion / 100) 135 | increase_each_coordinate = increase_axis_by / 2 136 | 137 | x_1 = bbox[1] - increase_each_coordinate 138 | x_2 = bbox[1] + bbox[3] + increase_each_coordinate 139 | 140 | increase_axis_by = bbox[2] * (bounding_box_expansion / 100) 141 | increase_each_coordinate = increase_axis_by / 2 142 | 143 | y_1 = bbox[0] - increase_each_coordinate 144 | y_2 = bbox[0] + bbox[2] + increase_each_coordinate 145 | 146 | # check the axis order 147 | x_2 = round(min(x_2, images[item].shape[0])) 148 | y_2 = round(min(y_2, images[item].shape[1])) 149 | 150 | x_1 = round(max(x_1, 0)) 151 | y_1 = round(max(y_1, 0)) 152 | 153 | instance_image = images[item][x_1:x_2, y_1:y_2] 154 | instance_mask = mask[x_1:x_2, y_1:y_2] 155 | else: 156 | instance_image = images[item][bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 157 | instance_mask = mask[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 158 | 159 | size = [image_size, image_size] 160 | self.instance_images.append(resize(torch.from_numpy(instance_image).permute(2, 0, 1), size, Image.BILINEAR).permute(1, 2, 0).numpy()) 161 | 162 | if mode == 'train' and not no_aug: 163 | self.instance_masks.append(resize(torch.from_numpy(instance_mask).unsqueeze(0), size, Image.NEAREST).squeeze(0).numpy()) 164 | else: 165 | self.instance_masks.append(instance_mask) 166 | 167 | @staticmethod 168 | def _poly2mask(mask_ann, img_h, img_w): 
169 | if isinstance(mask_ann, list): 170 | # polygon -- a single object might consist of multiple parts 171 | # we merge all parts into one mask rle code 172 | rles = maskUtils.frPyObjects(mask_ann, img_h, img_w) 173 | rle = maskUtils.merge(rles) 174 | elif isinstance(mask_ann['counts'], list): 175 | # uncompressed RLE 176 | rle = maskUtils.frPyObjects(mask_ann, img_h, img_w) 177 | else: 178 | # rle 179 | rle = mask_ann 180 | mask = maskUtils.decode(rle) 181 | return mask 182 | 183 | def __len__(self): 184 | return len(self.annotations) 185 | 186 | def __getitem__(self, item): 187 | ann = self.annotations[item] 188 | 189 | instance_image, instance_mask = self.transformations(self.instance_images[item], self.instance_masks[item]) 190 | 191 | out_dict = {"conditioned_image": instance_image} 192 | instance_mask = 2 * instance_mask - 1.0 193 | return instance_mask.unsqueeze(0), out_dict, Path(ann["img"]['file_name']).stem 194 | 195 | 196 | def main(): 197 | dataset = create_dataset(class_name="train", mode='train') 198 | for i in range(10): 199 | mask, out_dict, _ = dataset[i] 200 | img = out_dict["conditioned_image"] 201 | plt.imshow(img.permute(1, 2, 0).numpy().astype(np.uint8)) 202 | plt.show() 203 | 204 | plt.imshow(mask.permute(1, 2, 0).numpy(), cmap='gray') 205 | plt.show() 206 | 207 | 208 | if __name__ == '__main__': 209 | main() 210 | -------------------------------------------------------------------------------- /improved_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def model_and_diffusion_defaults(): 12 | """ 13 | Defaults for image training.
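The returned keys match the keyword arguments of create_model_and_diffusion(), so the dict can be passed straight to add_dict_to_argparser and read back with args_to_dict.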
14 | """ 15 | return dict( 16 | image_size=64, 17 | num_channels=128, 18 | num_res_blocks=2, 19 | num_heads=4, 20 | num_heads_upsample=-1, 21 | attention_resolutions="16,8", 22 | dropout=0.0, 23 | rrdb_blocks=10, 24 | deeper_net=False, 25 | learn_sigma=False, 26 | sigma_small=False, 27 | class_cond=False, 28 | class_name="train", 29 | expansion=False, 30 | diffusion_steps=100, 31 | noise_schedule="linear", 32 | timestep_respacing="", 33 | use_kl=False, 34 | predict_xstart=False, 35 | rescale_timesteps=True, 36 | rescale_learned_sigmas=True, 37 | use_checkpoint=False, 38 | use_scale_shift_norm=True, 39 | seed=None, 40 | ) 41 | 42 | 43 | def create_model_and_diffusion( 44 | image_size, 45 | class_cond, 46 | learn_sigma, 47 | sigma_small, 48 | num_channels, 49 | num_res_blocks, 50 | num_heads, 51 | num_heads_upsample, 52 | attention_resolutions, 53 | dropout, 54 | rrdb_blocks, 55 | deeper_net, 56 | class_name, 57 | expansion, 58 | diffusion_steps, 59 | noise_schedule, 60 | timestep_respacing, 61 | use_kl, 62 | predict_xstart, 63 | rescale_timesteps, 64 | rescale_learned_sigmas, 65 | use_checkpoint, 66 | use_scale_shift_norm, 67 | seed, 68 | ): 69 | _ = seed # hack to prevent unused variable 70 | _ = expansion 71 | _ = class_name 72 | model = create_model( 73 | image_size, 74 | num_channels, 75 | num_res_blocks, 76 | learn_sigma=learn_sigma, 77 | class_cond=class_cond, 78 | use_checkpoint=use_checkpoint, 79 | attention_resolutions=attention_resolutions, 80 | num_heads=num_heads, 81 | num_heads_upsample=num_heads_upsample, 82 | use_scale_shift_norm=use_scale_shift_norm, 83 | dropout=dropout, 84 | rrdb_blocks=rrdb_blocks, 85 | deeper_net=deeper_net 86 | ) 87 | diffusion = create_gaussian_diffusion( 88 | steps=diffusion_steps, 89 | learn_sigma=learn_sigma, 90 | sigma_small=sigma_small, 91 | noise_schedule=noise_schedule, 92 | use_kl=use_kl, 93 | predict_xstart=predict_xstart, 94 | rescale_timesteps=rescale_timesteps, 95 | rescale_learned_sigmas=rescale_learned_sigmas, 96 | timestep_respacing=timestep_respacing, 97 | ) 98 | return model, diffusion 99 | 100 | 101 | def create_model( 102 | image_size, 103 | num_channels, 104 | num_res_blocks, 105 | learn_sigma, 106 | class_cond, 107 | use_checkpoint, 108 | attention_resolutions, 109 | num_heads, 110 | num_heads_upsample, 111 | use_scale_shift_norm, 112 | dropout, 113 | rrdb_blocks, 114 | deeper_net 115 | ): 116 | if image_size == 256: 117 | if deeper_net: 118 | channel_mult = (1, 1, 1, 2, 2, 4, 4) 119 | else: 120 | channel_mult = (1, 1, 2, 2, 4, 4) 121 | elif image_size == 128: 122 | channel_mult = (1, 1, 2, 2, 4, 4) 123 | elif image_size == 64: 124 | channel_mult = (1, 2, 3, 4) 125 | elif image_size == 32: 126 | channel_mult = (1, 2, 2, 2) 127 | else: 128 | raise ValueError(f"unsupported image size: {image_size}") 129 | 130 | attention_ds = [] 131 | for res in attention_resolutions.split(","): 132 | attention_ds.append(image_size // int(res)) 133 | 134 | return UNetModel( 135 | in_channels=1, 136 | model_channels=num_channels, 137 | out_channels=(1 if not learn_sigma else 2), 138 | num_res_blocks=num_res_blocks, 139 | attention_resolutions=tuple(attention_ds), 140 | dropout=dropout, 141 | channel_mult=channel_mult, 142 | num_classes=(NUM_CLASSES if class_cond else None), 143 | use_checkpoint=use_checkpoint, 144 | num_heads=num_heads, 145 | num_heads_upsample=num_heads_upsample, 146 | use_scale_shift_norm=use_scale_shift_norm, 147 | rrdb_blocks=rrdb_blocks 148 | ) 149 | 150 | 151 | def sr_model_and_diffusion_defaults(): 152 | res = 
model_and_diffusion_defaults() 153 | res["large_size"] = 256 154 | res["small_size"] = 64 155 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 156 | for k in res.copy().keys(): 157 | if k not in arg_names: 158 | del res[k] 159 | return res 160 | 161 | 162 | def sr_create_model_and_diffusion( 163 | large_size, 164 | small_size, 165 | class_cond, 166 | learn_sigma, 167 | num_channels, 168 | num_res_blocks, 169 | num_heads, 170 | num_heads_upsample, 171 | attention_resolutions, 172 | dropout, 173 | rrdb_blocks, 174 | deeper_net, 175 | diffusion_steps, 176 | noise_schedule, 177 | timestep_respacing, 178 | use_kl, 179 | predict_xstart, 180 | rescale_timesteps, 181 | rescale_learned_sigmas, 182 | use_checkpoint, 183 | use_scale_shift_norm, 184 | ): 185 | model = sr_create_model( 186 | large_size, 187 | small_size, 188 | num_channels, 189 | num_res_blocks, 190 | learn_sigma=learn_sigma, 191 | class_cond=class_cond, 192 | use_checkpoint=use_checkpoint, 193 | attention_resolutions=attention_resolutions, 194 | num_heads=num_heads, 195 | num_heads_upsample=num_heads_upsample, 196 | use_scale_shift_norm=use_scale_shift_norm, 197 | dropout=dropout, 198 | rrdb_blocks=rrdb_blocks, 199 | deeper_net=deeper_net, 200 | ) 201 | diffusion = create_gaussian_diffusion( 202 | steps=diffusion_steps, 203 | learn_sigma=learn_sigma, 204 | noise_schedule=noise_schedule, 205 | use_kl=use_kl, 206 | predict_xstart=predict_xstart, 207 | rescale_timesteps=rescale_timesteps, 208 | rescale_learned_sigmas=rescale_learned_sigmas, 209 | timestep_respacing=timestep_respacing, 210 | ) 211 | return model, diffusion 212 | 213 | 214 | def sr_create_model( 215 | large_size, 216 | small_size, 217 | num_channels, 218 | num_res_blocks, 219 | learn_sigma, 220 | class_cond, 221 | use_checkpoint, 222 | attention_resolutions, 223 | num_heads, 224 | num_heads_upsample, 225 | use_scale_shift_norm, 226 | dropout, 227 | rrdb_blocks, 228 | deeper_net, 229 | ): 230 | _ = small_size # hack to prevent unused variable 231 | 232 | if large_size == 256: 233 | if deeper_net: 234 | channel_mult = (1, 1, 1, 2, 2, 4, 4) 235 | else: 236 | channel_mult = (1, 1, 2, 2, 4, 4) 237 | elif large_size == 64: 238 | channel_mult = (1, 2, 3, 4) 239 | else: 240 | raise ValueError(f"unsupported large size: {large_size}") 241 | 242 | attention_ds = [] 243 | for res in attention_resolutions.split(","): 244 | attention_ds.append(large_size // int(res)) 245 | 246 | return SuperResModel( 247 | in_channels=1, 248 | model_channels=num_channels, 249 | out_channels=(1 if not learn_sigma else 2), 250 | num_res_blocks=num_res_blocks, 251 | attention_resolutions=tuple(attention_ds), 252 | dropout=dropout, 253 | channel_mult=channel_mult, 254 | num_classes=(NUM_CLASSES if class_cond else None), 255 | use_checkpoint=use_checkpoint, 256 | num_heads=num_heads, 257 | num_heads_upsample=num_heads_upsample, 258 | use_scale_shift_norm=use_scale_shift_norm, 259 | rrdb_blocks=rrdb_blocks, 260 | ) 261 | 262 | 263 | def create_gaussian_diffusion( 264 | *, 265 | steps=1000, 266 | learn_sigma=False, 267 | sigma_small=False, 268 | noise_schedule="linear", 269 | use_kl=False, 270 | predict_xstart=False, 271 | rescale_timesteps=False, 272 | rescale_learned_sigmas=False, 273 | timestep_respacing="", 274 | ): 275 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 276 | if use_kl: 277 | loss_type = gd.LossType.RESCALED_KL 278 | elif rescale_learned_sigmas: 279 | loss_type = gd.LossType.RESCALED_MSE 280 | else: 281 | loss_type = gd.LossType.MSE 282 | if not 
timestep_respacing: 283 | timestep_respacing = [steps] 284 | return SpacedDiffusion( 285 | use_timesteps=space_timesteps(steps, timestep_respacing), 286 | betas=betas, 287 | model_mean_type=( 288 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 289 | ), 290 | model_var_type=( 291 | ( 292 | gd.ModelVarType.FIXED_LARGE 293 | if not sigma_small 294 | else gd.ModelVarType.FIXED_SMALL 295 | ) 296 | if not learn_sigma 297 | else gd.ModelVarType.LEARNED_RANGE 298 | ), 299 | loss_type=loss_type, 300 | rescale_timesteps=rescale_timesteps, 301 | ) 302 | 303 | 304 | def add_dict_to_argparser(parser, default_dict): 305 | for k, v in default_dict.items(): 306 | v_type = type(v) 307 | if v is None: 308 | v_type = str 309 | elif isinstance(v, bool): 310 | v_type = str2bool 311 | parser.add_argument(f"--{k}", default=v, type=v_type) 312 | 313 | 314 | def args_to_dict(args, keys): 315 | return {k: getattr(args, k) for k in keys} 316 | 317 | 318 | def str2bool(v): 319 | """ 320 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 321 | """ 322 | if isinstance(v, bool): 323 | return v 324 | if v.lower() in ("yes", "true", "t", "y", "1"): 325 | return True 326 | elif v.lower() in ("no", "false", "f", "n", "0"): 327 | return False 328 | else: 329 | raise argparse.ArgumentTypeError("boolean value expected") 330 | -------------------------------------------------------------------------------- /improved_diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | import torchvision.utils as tvu 9 | from PIL import Image 10 | from kornia import denormalize 11 | from sklearn.metrics import f1_score, jaccard_score 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | 15 | from . import dist_util 16 | from .metrics import FBound_metric, WCov_metric 17 | from datasets.monu import MonuDataset 18 | from .utils import set_random_seed_for_iterations 19 | 20 | cityspallete = [ 21 | 0, 0, 0, 22 | 128, 64, 128, 23 | 244, 35, 232, 24 | 70, 70, 70, 25 | 102, 102, 156, 26 | 190, 153, 153, 27 | 153, 153, 153, 28 | 250, 170, 30, 29 | 220, 220, 0, 30 | 107, 142, 35, 31 | 152, 251, 152, 32 | 0, 130, 180, 33 | 220, 20, 60, 34 | 255, 0, 0, 35 | 0, 0, 142, 36 | 0, 0, 70, 37 | 0, 60, 100, 38 | 0, 80, 100, 39 | 0, 0, 230, 40 | 119, 11, 32, 41 | ] 42 | 43 | 44 | def calculate_metrics(x, gt): 45 | predict = x.detach().cpu().numpy().astype('uint8') 46 | target = gt.detach().cpu().numpy().astype('uint8') 47 | return f1_score(target.flatten(), predict.flatten()), jaccard_score(target.flatten(), predict.flatten()), \ 48 | WCov_metric(predict, target), FBound_metric(predict, target) 49 | 50 | 51 | def sampling_major_vote_func(diffusion_model, ddp_model, output_folder, dataset, logger, clip_denoised, step, n_rounds=3): 52 | ddp_model.eval() 53 | batch_size = 1 54 | major_vote_number = 9 55 | loader = DataLoader(dataset, batch_size=batch_size) 56 | loader_iter = iter(loader) 57 | 58 | f1_score_list = [] 59 | miou_list = [] 60 | fbound_list = [] 61 | wcov_list = [] 62 | 63 | with torch.no_grad(): 64 | for round_index in tqdm( 65 | range(n_rounds), desc="Generating segmentation samples for evaluation."
66 | ): 67 | gt_mask, condition_on, name = next(loader_iter) 68 | set_random_seed_for_iterations(step + int(name[0].split("_")[1])) 69 | gt_mask = (gt_mask + 1.0) / 2.0 70 | condition_on = condition_on["conditioned_image"] 71 | former_frame_for_feature_extraction = condition_on.to(dist_util.dev()) 72 | 73 | for i in range(gt_mask.shape[0]): 74 | gt_img = Image.fromarray(gt_mask[i][0].detach().cpu().numpy().astype('uint8')) 75 | gt_img.putpalette(cityspallete) 76 | gt_img.save( 77 | os.path.join(output_folder, f"{name[i]}_gt_palette.png")) 78 | gt_img = Image.fromarray((gt_mask[i][0].detach().cpu().numpy() - 1).astype(np.uint8)) 79 | gt_img.save( 80 | os.path.join(output_folder, f"{name[i]}_gt.png")) 81 | 82 | for i in range(condition_on.shape[0]): 83 | denorm_condition_on = denormalize(condition_on.clone(), mean=dataset.mean, std=dataset.std) 84 | tvu.save_image( 85 | denorm_condition_on[i,] / 255., 86 | os.path.join(output_folder, f"{name[i]}_condition_on.png") 87 | ) 88 | 89 | if isinstance(dataset, MonuDataset): 90 | _, _, W, H = former_frame_for_feature_extraction.shape 91 | kernel_size = dataset.image_size 92 | stride = 256 93 | patches = [] 94 | for y, x in np.ndindex((((W - kernel_size) // stride) + 1, ((H - kernel_size) // stride) + 1)): 95 | y = y * stride 96 | x = x * stride 97 | patches.append(former_frame_for_feature_extraction[0, 98 | :, 99 | y: min(y + kernel_size, W), 100 | x: min(x + kernel_size, H)]) 101 | patches = torch.stack(patches) 102 | 103 | major_vote_list = [] 104 | for i in range(major_vote_number): 105 | x_list = [] 106 | 107 | for index in range(math.ceil(patches.shape[0] / 4)): 108 | model_kwargs = {"conditioned_image": patches[index * 4: min((index + 1) * 4, patches.shape[0])]} 109 | x = diffusion_model.p_sample_loop( 110 | ddp_model, 111 | (model_kwargs["conditioned_image"].shape[0], gt_mask.shape[1], model_kwargs["conditioned_image"].shape[2], model_kwargs["conditioned_image"].shape[3]), 112 | progress=True, 113 | clip_denoised=clip_denoised, 114 | model_kwargs=model_kwargs 115 | ) 116 | 117 | x_list.append(x) 118 | out = torch.cat(x_list) 119 | 120 | output = torch.zeros((former_frame_for_feature_extraction.shape[0], gt_mask.shape[1], former_frame_for_feature_extraction.shape[2], former_frame_for_feature_extraction.shape[3])) 121 | idx_sum = torch.zeros((former_frame_for_feature_extraction.shape[0], gt_mask.shape[1], former_frame_for_feature_extraction.shape[2], former_frame_for_feature_extraction.shape[3])) 122 | for index, val in enumerate(out): 123 | y, x = np.unravel_index(index, (((W - kernel_size) // stride) + 1, ((H - kernel_size) // stride) + 1)) 124 | y = y * stride 125 | x = x * stride 126 | 127 | idx_sum[0, 128 | :, 129 | y: min(y + kernel_size, W), 130 | x: min(x + kernel_size, H)] += 1 131 | 132 | output[0, 133 | :, 134 | y: min(y + kernel_size, W), 135 | x: min(x + kernel_size, H)] += val[:, :min(y + kernel_size, W) - y, :min(x + kernel_size, H) - x].cpu().data.numpy() 136 | 137 | output = output / idx_sum 138 | major_vote_list.append(output) 139 | 140 | x = torch.cat(major_vote_list) 141 | 142 | else: 143 | model_kwargs = { 144 | "conditioned_image": torch.cat([former_frame_for_feature_extraction] * major_vote_number)} 145 | 146 | x = diffusion_model.p_sample_loop( 147 | ddp_model, 148 | (major_vote_number, gt_mask.shape[1], former_frame_for_feature_extraction.shape[2], 149 | former_frame_for_feature_extraction.shape[3]), 150 | progress=True, 151 | clip_denoised=clip_denoised, 152 | model_kwargs=model_kwargs 153 | ) 154 | 155 | x = (x + 1.0) 
/ 2.0 156 | 157 | if x.shape[2] != gt_mask.shape[2] or x.shape[3] != gt_mask.shape[3]: 158 | x = F.interpolate(x, gt_mask.shape[2:], mode='bilinear') 159 | 160 | x = torch.clamp(x, 0.0, 1.0) 161 | 162 | # major vote result 163 | x = x.mean(dim=0, keepdim=True).round() 164 | 165 | for i in range(x.shape[0]): 166 | # save as outer training ids 167 | # current_output = x[i][0] + 1 168 | # current_output[current_output == current_output.max()] = 0 169 | out_img = Image.fromarray(x[i][0].detach().cpu().numpy().astype('uint8')) 170 | out_img.putpalette(cityspallete) 171 | out_img.save( 172 | os.path.join(output_folder, f"{name[i]}_model_output_palette.png")) 173 | out_img = Image.fromarray((x[i][0].detach().cpu().numpy() - 1).astype(np.uint8)) 174 | out_img.save( 175 | os.path.join(output_folder, f"{name[i]}_model_output.png")) 176 | 177 | for index, (gt_im, out_im) in enumerate(zip(gt_mask, x)): 178 | 179 | f1, miou, wcov, fbound = calculate_metrics(out_im[0], gt_im[0]) 180 | f1_score_list.append(f1) 181 | miou_list.append(miou) 182 | wcov_list.append(wcov) 183 | fbound_list.append(fbound) 184 | 185 | logger.info( 186 | f"{name[index]} iou {miou_list[-1]}, f1_Score {f1_score_list[-1]}, WCov {wcov_list[-1]}, boundF {fbound_list[-1]}") 187 | 188 | my_length = len(miou_list) 189 | length_of_data = torch.tensor(len(miou_list), device=dist_util.dev()) 190 | gathered_length_of_data = [torch.tensor(1, device=dist_util.dev()) for _ in range(dist.get_world_size())] 191 | dist.all_gather(gathered_length_of_data, length_of_data) 192 | max_len = torch.max(torch.stack(gathered_length_of_data)) 193 | 194 | iou_tensor = torch.tensor(miou_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev()) 195 | f1_tensor = torch.tensor(f1_score_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev()) 196 | wcov_tensor = torch.tensor(wcov_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev()) 197 | boundf_tensor = torch.tensor(fbound_list + [torch.tensor(-1)] * (max_len - my_length), device=dist_util.dev()) 198 | gathered_miou = [torch.ones_like(iou_tensor) * -1 for _ in range(dist.get_world_size())] 199 | gathered_f1 = [torch.ones_like(f1_tensor) * -1 for _ in range(dist.get_world_size())] 200 | gathered_wcov = [torch.ones_like(wcov_tensor) * -1 for _ in range(dist.get_world_size())] 201 | gathered_boundf = [torch.ones_like(boundf_tensor) * -1 for _ in range(dist.get_world_size())] 202 | 203 | dist.all_gather(gathered_miou, iou_tensor) 204 | dist.all_gather(gathered_f1, f1_tensor) 205 | dist.all_gather(gathered_wcov, wcov_tensor) 206 | dist.all_gather(gathered_boundf, boundf_tensor) 207 | 208 | # if dist.get_rank() == 0: 209 | logger.info("measure total avg") 210 | gathered_miou = torch.cat(gathered_miou) 211 | gathered_miou = gathered_miou[gathered_miou != -1] 212 | logger.info(f"mean iou {gathered_miou.mean()}") 213 | 214 | gathered_f1 = torch.cat(gathered_f1) 215 | gathered_f1 = gathered_f1[gathered_f1 != -1] 216 | logger.info(f"mean f1 {gathered_f1.mean()}") 217 | gathered_wcov = torch.cat(gathered_wcov) 218 | gathered_wcov = gathered_wcov[gathered_wcov != -1] 219 | logger.info(f"mean WCov {gathered_wcov.mean()}") 220 | gathered_boundf = torch.cat(gathered_boundf) 221 | gathered_boundf = gathered_boundf[gathered_boundf != -1] 222 | logger.info(f"mean boundF {gathered_boundf.mean()}") 223 | 224 | dist.barrier() 225 | return gathered_miou.mean().item() 226 | -------------------------------------------------------------------------------- 
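# A minimal sketch (not part of the repo) of the majority-vote decoding that
# sampling_major_vote_func implements above: draw several conditional samples,
# map them from [-1, 1] to [0, 1], average, and round. `diffusion`, `model`,
# and `cond_img` are hypothetical placeholders for a SpacedDiffusion instance,
# a trained conditional UNet, and one normalized conditioning image [1, 3, H, W].
import torch

def majority_vote_mask(diffusion, model, cond_img, votes=9, clip_denoised=False):
    # Repeat the conditioning image so all votes are sampled as one batch.
    model_kwargs = {"conditioned_image": cond_img.repeat(votes, 1, 1, 1)}
    x = diffusion.p_sample_loop(
        model,
        (votes, 1, cond_img.shape[2], cond_img.shape[3]),
        clip_denoised=clip_denoised,
        model_kwargs=model_kwargs,
    )
    x = ((x + 1.0) / 2.0).clamp(0.0, 1.0)  # masks are trained in [-1, 1]
    return x.mean(dim=0, keepdim=True).round()  # pixel-wise majority vote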
/improved_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(f"{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')} {elem}") 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
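# TensorBoard only uses the step to order points along the x-axis, so the simple per-writekvs() counter incremented below is sufficient.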
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
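    # Example: under `mpirun -n 4 python script.py`, Open MPI exports
    # OMPI_COMM_WORLD_RANK=0..3, so each process gets its rank here without
    # triggering MPI_Init via an mpi4py import; a bare `python script.py`
    # falls through to `return 0` below and behaves as rank 0.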
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /improved_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | from pathlib import Path 5 | 6 | import blobfile as bf 7 | import numpy as np 8 | import torch as th 9 | import torch.distributed as dist 10 | from mpi4py import MPI 11 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 12 | from torch.optim import AdamW 13 | from tqdm import tqdm 14 | 15 | from . 
import dist_util, logger 16 | from .fp16_util import ( 17 | make_master_params, 18 | master_params_to_model_params, 19 | model_grads_to_master_grads, 20 | unflatten_master_params, 21 | zero_grad, 22 | ) 23 | from .nn import update_ema 24 | from .resample import LossAwareSampler, UniformSampler 25 | # For ImageNet experiments, this was a good default value. 26 | # We found that the lg_loss_scale quickly climbed to 27 | # 20-21 within the first ~1K steps of training. 28 | from .sampling_util import sampling_major_vote_func 29 | from .utils import set_random_seed_for_iterations 30 | 31 | INITIAL_LOG_LOSS_SCALE = 20.0 32 | 33 | 34 | class TrainLoop: 35 | def __init__( 36 | self, 37 | *, 38 | model, 39 | diffusion, 40 | data, 41 | batch_size, 42 | microbatch, 43 | lr, 44 | ema_rate, 45 | log_interval, 46 | save_interval, 47 | resume_checkpoint, 48 | logger, 49 | image_size, 50 | val_dataset, 51 | clip_denoised=True, 52 | use_fp16=False, 53 | fp16_scale_growth=1e-3, 54 | schedule_sampler=None, 55 | weight_decay=0.0, 56 | lr_anneal_steps=0, 57 | run_without_test=False, 58 | args=None 59 | ): 60 | self.model = model 61 | self.diffusion = diffusion 62 | self.data = data 63 | self.batch_size = batch_size 64 | self.microbatch = microbatch if microbatch > 0 else batch_size 65 | self.lr = lr 66 | self.args = args 67 | self.ema_rate = ( 68 | [ema_rate] 69 | if isinstance(ema_rate, float) 70 | else [float(x) for x in ema_rate.split(",")] 71 | ) 72 | self.log_interval = log_interval 73 | self.save_interval = save_interval 74 | self.resume_checkpoint = resume_checkpoint 75 | self.use_fp16 = use_fp16 76 | self.fp16_scale_growth = fp16_scale_growth 77 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 78 | self.weight_decay = weight_decay 79 | self.lr_anneal_steps = lr_anneal_steps 80 | 81 | self.step = 1 82 | self.resume_step = 0 83 | self.global_batch = self.batch_size * dist.get_world_size() 84 | 85 | self.model_params = list(self.model.parameters()) 86 | self.master_params = self.model_params 87 | self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE 88 | self.sync_cuda = th.cuda.is_available() 89 | 90 | # if self.resume_checkpoint: 91 | self._load_and_sync_parameters(self.resume_checkpoint) 92 | 93 | if self.use_fp16: 94 | self._setup_fp16() 95 | 96 | self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) 97 | if self.resume_checkpoint: 98 | 99 | self._load_optimizer_state(resume_checkpoint) 100 | # Model was resumed, either due to a restart or a checkpoint 101 | # being specified at the command line. 102 | self.ema_params = [ 103 | self._load_ema_parameters(rate, resume_checkpoint) for rate in self.ema_rate 104 | ] 105 | else: 106 | self.ema_params = [ 107 | copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) 108 | ] 109 | 110 | if th.cuda.is_available(): 111 | self.use_ddp = True 112 | self.ddp_model = DDP( 113 | self.model, 114 | device_ids=[dist_util.dev()], 115 | output_device=dist_util.dev(), 116 | broadcast_buffers=False, 117 | bucket_cap_mb=128, 118 | find_unused_parameters=False, 119 | ) 120 | self.ema_model = copy.deepcopy(self.model).to(th.device("cpu")) 121 | else: 122 | if dist.get_world_size() > 1: 123 | logger.warn( 124 | "Distributed training requires CUDA. " 125 | "Gradients will not be synchronized properly!" 
126 | ) 127 | self.use_ddp = False 128 | self.ddp_model = self.model 129 | 130 | self.val_dataset = val_dataset 131 | self.logger = logger 132 | self.ema_val_best_iou = 0 133 | self.val_best_iou = 0 134 | self.clip_denoised = clip_denoised 135 | self.val_current_model_name = "" 136 | self.val_current_model_ema_name = "" 137 | self.current_model_checkpoint_name = "" 138 | self.run_without_test = run_without_test 139 | 140 | def _load_and_sync_parameters(self, logs_path): 141 | # resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 142 | 143 | # model_checkpoint = bf.join( 144 | # bf.dirname(logs_path), f"model.pt" 145 | # ) 146 | logger.log(f"model folder path") 147 | if logs_path: 148 | if Path(logs_path).exists(): 149 | model_path = list(Path(logs_path).glob("model*.pt"))[0] 150 | self.resume_step = parse_resume_step_from_filename(str(model_path)) 151 | self.step = self.resume_step 152 | 153 | logger.log(f"loading model from checkpoint: {model_path} from step {self.step}...") 154 | 155 | self.model.load_state_dict( 156 | dist_util.load_state_dict( 157 | str(model_path), map_location=dist_util.dev() 158 | ) 159 | ) 160 | 161 | dist_util.sync_params(self.model.parameters()) 162 | 163 | def _load_ema_parameters(self, rate, logs_path): 164 | ema_params = copy.deepcopy(self.master_params) 165 | 166 | ema_checkpoint = Path(logs_path) / "ema.pt" 167 | 168 | if ema_checkpoint.exists(): 169 | # if dist.get_rank() == 0: 170 | logger.log(f"loading EMA from checkpoint: {str(ema_checkpoint)}...") 171 | state_dict = dist_util.load_state_dict( 172 | str(ema_checkpoint), map_location=dist_util.dev() 173 | ) 174 | ema_params = self._state_dict_to_master_params(state_dict) 175 | 176 | dist_util.sync_params(ema_params) 177 | return ema_params 178 | 179 | def _load_optimizer_state(self, logs_path): 180 | 181 | opt_checkpoint = Path(logs_path) / "opt.pt" 182 | 183 | if opt_checkpoint.exists(): 184 | logger.log(f"loading optimizer state from checkpoint: {str(opt_checkpoint)}") 185 | state_dict = dist_util.load_state_dict( 186 | str(opt_checkpoint), map_location=dist_util.dev() 187 | ) 188 | self.opt.load_state_dict(state_dict) 189 | 190 | def _setup_fp16(self): 191 | self.master_params = make_master_params(self.model_params) 192 | self.model.convert_to_fp16() 193 | 194 | def run_loop(self, max_iter=250000, start_print_iter=100000, vis_batch_size=8, n_rounds=3): 195 | if dist.get_rank() == 0: 196 | pbar = tqdm() 197 | while ( 198 | self.step < max_iter 199 | ): 200 | self.ddp_model.train() 201 | batch, cond, _ = next(self.data) 202 | self.run_step(batch, cond) 203 | if dist.get_rank() == 0: 204 | pbar.update(1) 205 | if self.step % self.log_interval == 0 and self.step != 0: 206 | logger.log(f"interval") 207 | logger.dumpkvs() 208 | logger.log(f"class {self.args.class_name} lr {self.lr}, expansion {self.args.expansion}, " 209 | f"rrdb blocks {self.args.rrdb_blocks} gpus {MPI.COMM_WORLD.Get_size()}") 210 | 211 | if self.step % self.save_interval == 0: 212 | logger.log(f"save model for checkpoint") 213 | self.save_state_dict() 214 | dist.barrier() 215 | 216 | if self.step % self.save_interval == 0 and self.step >= start_print_iter or self.step == 60000: 217 | if self.run_without_test: 218 | if dist.get_rank() == 0: 219 | self.save_checkpoint(self.ema_rate[0], self.ema_params[0], name=f"model") 220 | else: 221 | self.ddp_model.eval() 222 | 223 | logger.log(f"ema sampling") 224 | output_folder = os.path.join(os.environ["OPENAI_LOGDIR"], f"{self.step}_val_ema_major") 225 | os.mkdir(output_folder) 
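                    # Descriptive note on the evaluation flow below: the EMA
                    # weights are loaded into a CPU-resident copy of the model,
                    # which is moved to the GPU only for majority-vote sampling
                    # over the validation set and then moved back to the CPU to
                    # free memory for training. Rank 0 compares the resulting
                    # mIoU against the best seen so far and keeps only the
                    # best-performing EMA checkpoint on disk.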
226 | self.ema_model = self.ema_model.to(dist_util.dev()) 227 | self.ema_model.load_state_dict(self._master_params_to_state_dict(self.ema_params[0])) 228 | self.ema_model.eval() 229 | ema_val_miou = sampling_major_vote_func(self.diffusion, self.ema_model, output_folder=output_folder, 230 | dataset=self.val_dataset, logger=self.logger, 231 | clip_denoised=self.clip_denoised, step=self.step, n_rounds=len(self.val_dataset)) 232 | self.ema_model = self.ema_model.to(th.device("cpu")) # release gpu memory 233 | 234 | if dist.get_rank() == 0: 235 | if self.ema_val_best_iou < ema_val_miou: 236 | logger.log(f"best iou ema val: {ema_val_miou} step {self.step}") 237 | self.ema_val_best_iou = ema_val_miou 238 | 239 | ema_filename = self.save_checkpoint(self.ema_rate[0], self.ema_params[0], name=f"val_{ema_val_miou:.7f}") 240 | 241 | if self.val_current_model_ema_name != "": 242 | ckpt_path = bf.join(get_blob_logdir(), self.val_current_model_ema_name) 243 | if os.path.exists(ckpt_path): 244 | os.remove(ckpt_path) 245 | 246 | self.val_current_model_ema_name = ema_filename 247 | 248 | set_random_seed_for_iterations(MPI.COMM_WORLD.Get_rank() + self.step) 249 | dist.barrier() 250 | self.step += 1 251 | 252 | def run_step(self, batch, cond): 253 | self.forward_backward(batch, cond) 254 | if self.use_fp16: 255 | self.optimize_fp16() 256 | else: 257 | self.optimize_normal() 258 | self.log_step() 259 | 260 | def forward_backward(self, batch, cond): 261 | zero_grad(self.model_params) 262 | for i in range(0, batch.shape[0], self.microbatch): 263 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 264 | micro_cond = { 265 | k: v[i : i + self.microbatch].to(dist_util.dev()) 266 | for k, v in cond.items() 267 | } 268 | last_batch = (i + self.microbatch) >= batch.shape[0] 269 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 270 | 271 | compute_losses = functools.partial( 272 | self.diffusion.training_losses, 273 | self.ddp_model, 274 | micro, 275 | t, 276 | model_kwargs=micro_cond, 277 | ) 278 | 279 | if last_batch or not self.use_ddp: 280 | losses = compute_losses() 281 | else: 282 | with self.ddp_model.no_sync(): 283 | losses = compute_losses() 284 | 285 | if isinstance(self.schedule_sampler, LossAwareSampler): 286 | self.schedule_sampler.update_with_local_losses( 287 | t, losses["loss"].detach() 288 | ) 289 | 290 | loss = (losses["loss"] * weights).mean() 291 | log_loss_dict( 292 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 293 | ) 294 | if self.use_fp16: 295 | loss_scale = 2 ** self.lg_loss_scale 296 | (loss * loss_scale).backward() 297 | else: 298 | loss.backward() 299 | 300 | def optimize_fp16(self): 301 | if any(not th.isfinite(p.grad).all() for p in self.model_params): 302 | self.lg_loss_scale -= 1 303 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 304 | return 305 | 306 | model_grads_to_master_grads(self.model_params, self.master_params) 307 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 308 | self._log_grad_norm() 309 | self._anneal_lr() 310 | self.opt.step() 311 | for rate, params in zip(self.ema_rate, self.ema_params): 312 | update_ema(params, self.master_params, rate=rate) 313 | master_params_to_model_params(self.model_params, self.master_params) 314 | self.lg_loss_scale += self.fp16_scale_growth 315 | 316 | def optimize_normal(self): 317 | self._log_grad_norm() 318 | self._anneal_lr() 319 | self.opt.step() 320 | for rate, params in zip(self.ema_rate, self.ema_params): 321 | update_ema(params, 
self.master_params, rate=rate) 322 | 323 | def _log_grad_norm(self): 324 | sqsum = 0.0 325 | for p in self.master_params: 326 | sqsum += (p.grad ** 2).sum().item() 327 | logger.logkv_mean("grad_norm", np.sqrt(sqsum)) 328 | 329 | def _anneal_lr(self): 330 | if not self.lr_anneal_steps: 331 | return 332 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 333 | lr = self.lr * (1 - frac_done) 334 | for param_group in self.opt.param_groups: 335 | param_group["lr"] = lr 336 | 337 | def log_step(self): 338 | logger.logkv("step", self.step + self.resume_step) 339 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 340 | if self.use_fp16: 341 | logger.logkv("lg_loss_scale", self.lg_loss_scale) 342 | 343 | def save_checkpoint(self, rate, params, name): 344 | state_dict = self._master_params_to_state_dict(params) 345 | if dist.get_rank() == 0: 346 | logger.log(f"saving model {rate}...") 347 | if not rate: 348 | filename = f"model_{name}_{(self.step+self.resume_step):06d}.pt" 349 | else: 350 | filename = f"ema_{name}_{rate}_{(self.step+self.resume_step):06d}.pt" 351 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 352 | th.save(state_dict, f) 353 | return filename 354 | 355 | def save_state_dict(self): 356 | 357 | if dist.get_rank() == 0: 358 | with bf.BlobFile(bf.join(get_blob_logdir(), f"opt.pt"), "wb",) as f: 359 | th.save(self.opt.state_dict(), f) 360 | 361 | with bf.BlobFile(bf.join(get_blob_logdir(), f"model{self.step}.pt"), "wb") as f: 362 | th.save(self._master_params_to_state_dict(self.master_params), f) 363 | 364 | if self.current_model_checkpoint_name != "": 365 | ckpt_path = bf.join(get_blob_logdir(), self.current_model_checkpoint_name) 366 | if os.path.exists(ckpt_path): 367 | os.remove(ckpt_path) 368 | 369 | self.current_model_checkpoint_name = bf.join(get_blob_logdir(), f"model{self.step}.pt") 370 | 371 | with bf.BlobFile(bf.join(get_blob_logdir(), f"ema.pt"), "wb") as f: 372 | th.save(self._master_params_to_state_dict(self.ema_params[0]), f) 373 | # 374 | # checkpoint = { 375 | # 'step': self.step, 376 | # 'state_dict': self._master_params_to_state_dict(self.master_params), 377 | # 'ema_state_dict': self._master_params_to_state_dict(self.ema_params[0]), 378 | # 'optimizer': self.opt.state_dict() 379 | # } 380 | # 381 | # current_model_checkpoint_name = bf.join(get_blob_logdir(), file_name) 382 | # th.save(checkpoint, current_model_checkpoint_name) 383 | # 384 | # if self.current_model_checkpoint_name != "": 385 | # ckpt_path = bf.join(get_blob_logdir(), self.current_model_checkpoint_name) 386 | # if os.path.exists(ckpt_path): 387 | # os.remove(ckpt_path) 388 | # 389 | # self.current_model_checkpoint_name = current_model_checkpoint_name 390 | 391 | def save(self, name): 392 | 393 | filename = self.save_checkpoint(0, self.master_params, name) 394 | for rate, params in zip(self.ema_rate, self.ema_params): 395 | filename_ema = self.save_checkpoint(rate, params, name) 396 | 397 | # if dist.get_rank() == 0: 398 | # with bf.BlobFile( 399 | # bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 400 | # "wb", 401 | # ) as f: 402 | # th.save(self.opt.state_dict(), f) 403 | 404 | # dist.barrier() 405 | 406 | return filename, filename_ema 407 | 408 | def _master_params_to_state_dict(self, master_params): 409 | if self.use_fp16: 410 | master_params = unflatten_master_params( 411 | list(self.model.parameters()), master_params 412 | ) 413 | state_dict = self.model.state_dict() 414 | for i, (name, _value) in 
enumerate(self.model.named_parameters()): 415 | assert name in state_dict 416 | state_dict[name] = master_params[i] 417 | return state_dict 418 | 419 | def _state_dict_to_master_params(self, state_dict): 420 | params = [state_dict[name] for name, _ in self.model.named_parameters()] 421 | if self.use_fp16: 422 | return make_master_params(params) 423 | else: 424 | return params 425 | 426 | 427 | def parse_resume_step_from_filename(filename): 428 | """ 429 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 430 | checkpoint's number of steps. 431 | """ 432 | split = filename.split("model") 433 | if len(split) < 2: 434 | return 0 435 | split1 = split[-1].split(".")[0] 436 | try: 437 | return int(split1) 438 | except ValueError: 439 | return 0 440 | 441 | 442 | def get_blob_logdir(): 443 | return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) 444 | 445 | 446 | def find_resume_checkpoint(): 447 | # On your infrastructure, you may want to override this to automatically 448 | # discover the latest checkpoint on your blob storage, etc. 449 | return None 450 | 451 | 452 | def find_ema_checkpoint(main_checkpoint, step, rate): 453 | if main_checkpoint is None: 454 | return None 455 | filename = f"ema_{rate}_{(step):06d}.pt" 456 | path = bf.join(bf.dirname(main_checkpoint), filename) 457 | if bf.exists(path): 458 | return path 459 | return None 460 | 461 | 462 | def log_loss_dict(diffusion, ts, losses): 463 | for key, values in losses.items(): 464 | logger.logkv_mean(key, values.mean().item()) 465 | # Log the quantiles (four quartiles, in particular). 466 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 467 | quartile = int(4 * sub_t / diffusion.num_timesteps) 468 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 469 | -------------------------------------------------------------------------------- /improved_diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .RRDB import RRDBNet 11 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 12 | from .nn import ( 13 | SiLU, 14 | conv_nd, 15 | linear, 16 | avg_pool_nd, 17 | zero_module, 18 | normalization, 19 | timestep_embedding, 20 | checkpoint, 21 | ) 22 | 23 | 24 | class TimestepBlock(nn.Module): 25 | """ 26 | Any module where forward() takes timestep embeddings as a second argument. 27 | """ 28 | 29 | @abstractmethod 30 | def forward(self, x, emb): 31 | """ 32 | Apply the module to `x` given `emb` timestep embeddings. 33 | """ 34 | 35 | 36 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 37 | """ 38 | A sequential module that passes timestep embeddings to the children that 39 | support it as an extra input. 40 | """ 41 | 42 | def forward(self, x, emb): 43 | for layer in self: 44 | if isinstance(layer, TimestepBlock): 45 | x = layer(x, emb) 46 | else: 47 | x = layer(x) 48 | return x 49 | 50 | 51 | class Upsample(nn.Module): 52 | """ 53 | An upsampling layer with an optional convolution. 54 | 55 | :param channels: channels in the inputs and outputs. 56 | :param use_conv: a bool determining if a convolution is applied. 57 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 58 | upsampling occurs in the inner-two dimensions. 
59 | """ 60 | 61 | def __init__(self, channels, use_conv, dims=2): 62 | super().__init__() 63 | self.channels = channels 64 | self.use_conv = use_conv 65 | self.dims = dims 66 | if use_conv: 67 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 68 | 69 | def forward(self, x): 70 | assert x.shape[1] == self.channels 71 | if self.dims == 3: 72 | x = F.interpolate( 73 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 74 | ) 75 | else: 76 | x = F.interpolate(x, scale_factor=2, mode="nearest") 77 | if self.use_conv: 78 | x = self.conv(x) 79 | return x 80 | 81 | 82 | class Downsample(nn.Module): 83 | """ 84 | A downsampling layer with an optional convolution. 85 | 86 | :param channels: channels in the inputs and outputs. 87 | :param use_conv: a bool determining if a convolution is applied. 88 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 89 | downsampling occurs in the inner-two dimensions. 90 | """ 91 | 92 | def __init__(self, channels, use_conv, dims=2): 93 | super().__init__() 94 | self.channels = channels 95 | self.use_conv = use_conv 96 | self.dims = dims 97 | stride = 2 if dims != 3 else (1, 2, 2) 98 | if use_conv: 99 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 100 | else: 101 | self.op = avg_pool_nd(stride) 102 | 103 | def forward(self, x): 104 | assert x.shape[1] == self.channels 105 | return self.op(x) 106 | 107 | 108 | class ResBlock(TimestepBlock): 109 | """ 110 | A residual block that can optionally change the number of channels. 111 | 112 | :param channels: the number of input channels. 113 | :param emb_channels: the number of timestep embedding channels. 114 | :param dropout: the rate of dropout. 115 | :param out_channels: if specified, the number of out channels. 116 | :param use_conv: if True and out_channels is specified, use a spatial 117 | convolution instead of a smaller 1x1 convolution to change the 118 | channels in the skip connection. 119 | :param dims: determines if the signal is 1D, 2D, or 3D. 120 | :param use_checkpoint: if True, use gradient checkpointing on this module. 
121 | """ 122 | 123 | def __init__( 124 | self, 125 | channels, 126 | emb_channels, 127 | dropout, 128 | out_channels=None, 129 | use_conv=False, 130 | use_scale_shift_norm=False, 131 | dims=2, 132 | use_checkpoint=False, 133 | ): 134 | super().__init__() 135 | self.channels = channels 136 | self.emb_channels = emb_channels 137 | self.dropout = dropout 138 | self.out_channels = out_channels or channels 139 | self.use_conv = use_conv 140 | self.use_checkpoint = use_checkpoint 141 | self.use_scale_shift_norm = use_scale_shift_norm 142 | 143 | self.in_layers = nn.Sequential( 144 | normalization(channels), 145 | SiLU(), 146 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 147 | ) 148 | self.emb_layers = nn.Sequential( 149 | SiLU(), 150 | linear( 151 | emb_channels, 152 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 153 | ), 154 | ) 155 | self.out_layers = nn.Sequential( 156 | normalization(self.out_channels), 157 | SiLU(), 158 | nn.Dropout(p=dropout), 159 | zero_module( 160 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 161 | ), 162 | ) 163 | 164 | if self.out_channels == channels: 165 | self.skip_connection = nn.Identity() 166 | elif use_conv: 167 | self.skip_connection = conv_nd( 168 | dims, channels, self.out_channels, 3, padding=1 169 | ) 170 | else: 171 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 172 | 173 | def forward(self, x, emb): 174 | """ 175 | Apply the block to a Tensor, conditioned on a timestep embedding. 176 | 177 | :param x: an [N x C x ...] Tensor of features. 178 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 179 | :return: an [N x C x ...] Tensor of outputs. 180 | """ 181 | return checkpoint( 182 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 183 | ) 184 | 185 | def _forward(self, x, emb): 186 | h = self.in_layers(x) 187 | emb_out = self.emb_layers(emb).type(h.dtype) 188 | while len(emb_out.shape) < len(h.shape): 189 | emb_out = emb_out[..., None] 190 | if self.use_scale_shift_norm: 191 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 192 | scale, shift = th.chunk(emb_out, 2, dim=1) 193 | h = out_norm(h) * (1 + scale) + shift 194 | h = out_rest(h) 195 | else: 196 | h = h + emb_out 197 | h = self.out_layers(h) 198 | return self.skip_connection(x) + h 199 | 200 | 201 | class AttentionBlock(nn.Module): 202 | """ 203 | An attention block that allows spatial positions to attend to each other. 204 | 205 | Originally ported from here, but adapted to the N-d case. 206 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
207 | """ 208 | 209 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 210 | super().__init__() 211 | self.channels = channels 212 | self.num_heads = num_heads 213 | self.use_checkpoint = use_checkpoint 214 | 215 | self.norm = normalization(channels) 216 | self.qkv = conv_nd(1, channels, channels * 3, 1) 217 | self.attention = QKVAttention() 218 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 219 | 220 | def forward(self, x): 221 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 222 | 223 | def _forward(self, x): 224 | b, c, *spatial = x.shape 225 | x = x.reshape(b, c, -1) 226 | qkv = self.qkv(self.norm(x)) 227 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 228 | h = self.attention(qkv) 229 | h = h.reshape(b, -1, h.shape[-1]) 230 | h = self.proj_out(h) 231 | return (x + h).reshape(b, c, *spatial) 232 | 233 | 234 | class QKVAttention(nn.Module): 235 | """ 236 | A module which performs QKV attention. 237 | """ 238 | 239 | def forward(self, qkv): 240 | """ 241 | Apply QKV attention. 242 | 243 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 244 | :return: an [N x C x T] tensor after attention. 245 | """ 246 | ch = qkv.shape[1] // 3 247 | q, k, v = th.split(qkv, ch, dim=1) 248 | scale = 1 / math.sqrt(math.sqrt(ch)) 249 | weight = th.einsum( 250 | "bct,bcs->bts", q * scale, k * scale 251 | ) # More stable with f16 than dividing afterwards 252 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 253 | return th.einsum("bts,bcs->bct", weight, v) 254 | 255 | @staticmethod 256 | def count_flops(model, _x, y): 257 | """ 258 | A counter for the `thop` package to count the operations in an 259 | attention operation. 260 | 261 | Meant to be used like: 262 | 263 | macs, params = thop.profile( 264 | model, 265 | inputs=(inputs, timestamps), 266 | custom_ops={QKVAttention: QKVAttention.count_flops}, 267 | ) 268 | 269 | """ 270 | b, c, *spatial = y[0].shape 271 | num_spatial = int(np.prod(spatial)) 272 | # We perform two matmuls with the same number of ops. 273 | # The first computes the weight matrix, the second computes 274 | # the combination of the value vectors. 275 | matmul_ops = 2 * b * (num_spatial ** 2) * c 276 | model.total_ops += th.DoubleTensor([matmul_ops]) 277 | 278 | 279 | class UNetModel(nn.Module): 280 | """ 281 | The full UNet model with attention and timestep embedding. 282 | 283 | :param in_channels: channels in the input Tensor. 284 | :param model_channels: base channel count for the model. 285 | :param out_channels: channels in the output Tensor. 286 | :param num_res_blocks: number of residual blocks per downsample. 287 | :param attention_resolutions: a collection of downsample rates at which 288 | attention will take place. May be a set, list, or tuple. 289 | For example, if this contains 4, then at 4x downsampling, attention 290 | will be used. 291 | :param dropout: the dropout probability. 292 | :param channel_mult: channel multiplier for each level of the UNet. 293 | :param conv_resample: if True, use learned convolutions for upsampling and 294 | downsampling. 295 | :param dims: determines if the signal is 1D, 2D, or 3D. 296 | :param num_classes: if specified (as an int), then this model will be 297 | class-conditional with `num_classes` classes. 298 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 299 | :param num_heads: the number of attention heads in each attention layer. 
300 | """ 301 | 302 | def __init__( 303 | self, 304 | in_channels, 305 | model_channels, 306 | out_channels, 307 | num_res_blocks, 308 | attention_resolutions, 309 | dropout=0, 310 | channel_mult=(1, 2, 4, 8), 311 | conv_resample=True, 312 | dims=2, 313 | num_classes=None, 314 | use_checkpoint=False, 315 | num_heads=1, 316 | num_heads_upsample=-1, 317 | use_scale_shift_norm=False, 318 | rrdb_blocks=3, 319 | ): 320 | super().__init__() 321 | 322 | if num_heads_upsample == -1: 323 | num_heads_upsample = num_heads 324 | 325 | self.in_channels = in_channels 326 | self.model_channels = model_channels 327 | self.out_channels = out_channels 328 | self.num_res_blocks = num_res_blocks 329 | self.attention_resolutions = attention_resolutions 330 | self.dropout = dropout 331 | self.channel_mult = channel_mult 332 | self.conv_resample = conv_resample 333 | self.num_classes = num_classes 334 | self.use_checkpoint = use_checkpoint 335 | self.num_heads = num_heads 336 | self.num_heads_upsample = num_heads_upsample 337 | 338 | time_embed_dim = model_channels * 4 339 | self.time_embed = nn.Sequential( 340 | linear(model_channels, time_embed_dim), 341 | SiLU(), 342 | linear(time_embed_dim, time_embed_dim), 343 | ) 344 | 345 | if self.num_classes is not None: 346 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 347 | self.rrdb = RRDBNet(nb=rrdb_blocks, out_nc=model_channels) 348 | self.input_blocks = nn.ModuleList( 349 | [ 350 | TimestepEmbedSequential( 351 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 352 | ) 353 | ] 354 | ) 355 | input_block_chans = [model_channels] 356 | ch = model_channels 357 | ds = 1 358 | for level, mult in enumerate(channel_mult): 359 | for _ in range(num_res_blocks): 360 | layers = [ 361 | ResBlock( 362 | ch, 363 | time_embed_dim, 364 | dropout, 365 | out_channels=mult * model_channels, 366 | dims=dims, 367 | use_checkpoint=use_checkpoint, 368 | use_scale_shift_norm=use_scale_shift_norm, 369 | ) 370 | ] 371 | ch = mult * model_channels 372 | if ds in attention_resolutions: 373 | layers.append( 374 | AttentionBlock( 375 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 376 | ) 377 | ) 378 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 379 | input_block_chans.append(ch) 380 | if level != len(channel_mult) - 1: 381 | self.input_blocks.append( 382 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 383 | ) 384 | input_block_chans.append(ch) 385 | ds *= 2 386 | 387 | self.middle_block = TimestepEmbedSequential( 388 | ResBlock( 389 | ch, 390 | time_embed_dim, 391 | dropout, 392 | dims=dims, 393 | use_checkpoint=use_checkpoint, 394 | use_scale_shift_norm=use_scale_shift_norm, 395 | ), 396 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 397 | ResBlock( 398 | ch, 399 | time_embed_dim, 400 | dropout, 401 | dims=dims, 402 | use_checkpoint=use_checkpoint, 403 | use_scale_shift_norm=use_scale_shift_norm, 404 | ), 405 | ) 406 | 407 | self.output_blocks = nn.ModuleList([]) 408 | for level, mult in list(enumerate(channel_mult))[::-1]: 409 | for i in range(num_res_blocks + 1): 410 | layers = [ 411 | ResBlock( 412 | ch + input_block_chans.pop(), 413 | time_embed_dim, 414 | dropout, 415 | out_channels=model_channels * mult, 416 | dims=dims, 417 | use_checkpoint=use_checkpoint, 418 | use_scale_shift_norm=use_scale_shift_norm, 419 | ) 420 | ] 421 | ch = model_channels * mult 422 | if ds in attention_resolutions: 423 | layers.append( 424 | AttentionBlock( 425 | ch, 426 | use_checkpoint=use_checkpoint, 427 | 
num_heads=num_heads_upsample, 428 | ) 429 | ) 430 | if level and i == num_res_blocks: 431 | layers.append(Upsample(ch, conv_resample, dims=dims)) 432 | ds //= 2 433 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 434 | 435 | self.out = nn.Sequential( 436 | normalization(ch), 437 | SiLU(), 438 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 439 | ) 440 | 441 | def convert_to_fp16(self): 442 | """ 443 | Convert the torso of the model to float16. 444 | """ 445 | self.input_blocks.apply(convert_module_to_f16) 446 | self.middle_block.apply(convert_module_to_f16) 447 | self.output_blocks.apply(convert_module_to_f16) 448 | self.rrdb.apply(convert_module_to_f16) 449 | 450 | def convert_to_fp32(self): 451 | """ 452 | Convert the torso of the model to float32. 453 | """ 454 | self.input_blocks.apply(convert_module_to_f32) 455 | self.middle_block.apply(convert_module_to_f32) 456 | self.output_blocks.apply(convert_module_to_f32) 457 | self.rrdb.apply(convert_module_to_f32) 458 | 459 | @property 460 | def inner_dtype(self): 461 | """ 462 | Get the dtype used by the torso of the model. 463 | """ 464 | return next(self.input_blocks.parameters()).dtype 465 | 466 | def forward(self, x, timesteps, y=None, conditioned_image=None): 467 | """ 468 | Apply the model to an input batch. 469 | 470 | :param x: an [N x C x ...] Tensor of inputs. 471 | :param timesteps: a 1-D batch of timesteps. 472 | :param y: an [N] Tensor of labels, if class-conditional. 473 | :return: an [N x C x ...] Tensor of outputs. 474 | """ 475 | assert (y is not None) == ( 476 | self.num_classes is not None 477 | ), "must specify y if and only if the model is class-conditional" 478 | 479 | hs = [] 480 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 481 | 482 | if self.num_classes is not None: 483 | assert y.shape == (x.shape[0],) 484 | emb = emb + self.label_emb(y) 485 | former_frames_features = self.rrdb(conditioned_image.type(self.inner_dtype)) 486 | h = x.type(self.inner_dtype) 487 | for i, module in enumerate(self.input_blocks): 488 | h = module(h, emb) 489 | if i == 0: 490 | h = h + former_frames_features 491 | hs.append(h) 492 | h = self.middle_block(h, emb) 493 | for module in self.output_blocks: 494 | cat_in = th.cat([h, hs.pop()], dim=1) 495 | h = module(cat_in, emb) 496 | h = h.type(x.dtype) 497 | return self.out(h) 498 | 499 | def get_feature_vectors(self, x, timesteps, y=None): 500 | """ 501 | Apply the model and return all of the intermediate tensors. 502 | 503 | :param x: an [N x C x ...] Tensor of inputs. 504 | :param timesteps: a 1-D batch of timesteps. 505 | :param y: an [N] Tensor of labels, if class-conditional. 506 | :return: a dict with the following keys: 507 | - 'down': a list of hidden state tensors from downsampling. 508 | - 'middle': the tensor of the output of the lowest-resolution 509 | block in the model. 510 | - 'up': a list of hidden state tensors from upsampling. 
511 | """ 512 | hs = [] 513 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 514 | if self.num_classes is not None: 515 | assert y.shape == (x.shape[0],) 516 | emb = emb + self.label_emb(y) 517 | result = dict(down=[], up=[]) 518 | h = x.type(self.inner_dtype) 519 | for module in self.input_blocks: 520 | h = module(h, emb) 521 | hs.append(h) 522 | result["down"].append(h.type(x.dtype)) 523 | h = self.middle_block(h, emb) 524 | result["middle"] = h.type(x.dtype) 525 | for module in self.output_blocks: 526 | cat_in = th.cat([h, hs.pop()], dim=1) 527 | h = module(cat_in, emb) 528 | result["up"].append(h.type(x.dtype)) 529 | return result 530 | 531 | 532 | class SuperResModel(UNetModel): 533 | """ 534 | A UNetModel that performs super-resolution. 535 | 536 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 537 | """ 538 | 539 | def __init__(self, in_channels, *args, **kwargs): 540 | super().__init__(in_channels * 2, *args, **kwargs) 541 | 542 | def forward(self, x, timesteps, low_res=None, **kwargs): 543 | _, _, new_height, new_width = x.shape 544 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="nearest") 545 | x = th.cat([x, upsampled], dim=1) 546 | return super().forward(x, timesteps, **kwargs) 547 | 548 | def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs): 549 | _, new_height, new_width, _ = x.shape 550 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="nearest") 551 | x = th.cat([x, upsampled], dim=1) 552 | return super().get_feature_vectors(x, timesteps, **kwargs) 553 | 554 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import sys 5 | import random 6 | from PIL import Image 7 | 8 | try: 9 | import accimage 10 | except ImportError: 11 | accimage = None 12 | import numpy as np 13 | import numbers 14 | import types 15 | import collections 16 | import warnings 17 | 18 | from torchvision.transforms import functional as F 19 | 20 | if sys.version_info < (3, 3): 21 | Sequence = collections.Sequence 22 | Iterable = collections.Iterable 23 | else: 24 | Sequence = collections.abc.Sequence 25 | Iterable = collections.abc.Iterable 26 | 27 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad", 28 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", 29 | "RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop", 30 | "ColorJitter", "RandomRotation", "RandomAffine", 31 | "RandomPerspective"] 32 | 33 | _pil_interpolation_to_str = { 34 | Image.NEAREST: 'PIL.Image.NEAREST', 35 | Image.BILINEAR: 'PIL.Image.BILINEAR', 36 | Image.BICUBIC: 'PIL.Image.BICUBIC', 37 | Image.LANCZOS: 'PIL.Image.LANCZOS', 38 | Image.HAMMING: 'PIL.Image.HAMMING', 39 | Image.BOX: 'PIL.Image.BOX', 40 | } 41 | 42 | 43 | class Compose(object): 44 | def __init__(self, transforms): 45 | self.transforms = transforms 46 | 47 | def __call__(self, img, mask): 48 | for t in self.transforms: 49 | img, mask = t(img, mask) 50 | return img, mask 51 | 52 | 53 | class ToTensor(object): 54 | def __call__(self, img, mask): 55 | # return F.to_tensor(img), F.to_tensor(mask) 56 | img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() 57 | mask = torch.from_numpy(np.array(mask)).float() 58 | return img, mask 59 | 60 | 61 | class 
ToPILImage(object): 62 | def __init__(self, mode=None): 63 | self.mode = mode 64 | 65 | def __call__(self, img, mask): 66 | return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode) 67 | 68 | 69 | class Normalize(object): 70 | def __init__(self, mean, std, inplace=False): 71 | self.mean = mean 72 | self.std = std 73 | self.inplace = inplace 74 | 75 | def __call__(self, img, mask): 76 | return F.normalize(img, self.mean, self.std, self.inplace), mask 77 | 78 | 79 | class Resize(object): 80 | def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True): 81 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 82 | self.size = size 83 | self.interpolation = interpolation 84 | self.do_mask = do_mask 85 | 86 | def __call__(self, img, mask): 87 | if self.do_mask: 88 | return F.resize(img, self.size, self.interpolation), F.resize(mask, self.size, Image.NEAREST) 89 | else: 90 | return F.resize(img, self.size, self.interpolation), mask 91 | 92 | 93 | class CenterCrop(object): 94 | def __init__(self, size): 95 | if isinstance(size, numbers.Number): 96 | self.size = (int(size), int(size)) 97 | else: 98 | self.size = size 99 | 100 | def __call__(self, img, mask): 101 | return F.center_crop(img, self.size), F.center_crop(mask, self.size) 102 | 103 | 104 | class Pad(object): 105 | def __init__(self, padding, fill=0, padding_mode='constant'): 106 | assert isinstance(padding, (numbers.Number, tuple)) 107 | assert isinstance(fill, (numbers.Number, str, tuple)) 108 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 109 | if isinstance(padding, Sequence) and len(padding) not in [2, 4]: 110 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 111 | "{} element tuple".format(len(padding))) 112 | 113 | self.padding = padding 114 | self.fill = fill 115 | self.padding_mode = padding_mode 116 | 117 | def __call__(self, img, mask): 118 | return F.pad(img, self.padding, self.fill, self.padding_mode), \ 119 | F.pad(mask, self.padding, self.fill, self.padding_mode) 120 | 121 | 122 | class Lambda(object): 123 | def __init__(self, lambd): 124 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 125 | self.lambd = lambd 126 | 127 | def __call__(self, img, mask): 128 | return self.lambd(img), self.lambd(mask) 129 | 130 | 131 | class Lambda_image(object): 132 | def __init__(self, lambd): 133 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" 134 | self.lambd = lambd 135 | 136 | def __call__(self, img, mask): 137 | return self.lambd(img), mask 138 | 139 | 140 | class RandomTransforms(object): 141 | def __init__(self, transforms): 142 | assert isinstance(transforms, (list, tuple)) 143 | self.transforms = transforms 144 | 145 | def __call__(self, *args, **kwargs): 146 | raise NotImplementedError() 147 | 148 | 149 | class RandomApply(RandomTransforms): 150 | def __init__(self, transforms, p=0.5): 151 | super(RandomApply, self).__init__(transforms) 152 | self.p = p 153 | 154 | def __call__(self, img, mask): 155 | if self.p < random.random(): 156 | return img, mask 157 | for t in self.transforms: 158 | img, mask = t(img, mask) 159 | return img, mask 160 | 161 | 162 | class RandomOrder(RandomTransforms): 163 | def __call__(self, img, mask): 164 | order = list(range(len(self.transforms))) 165 | random.shuffle(order) 166 | for i in order: 167 | img, mask = self.transforms[i](img, mask) 168 | return img, mask 169 | 170 | 171 | class RandomChoice(RandomTransforms): 172 | def 
__call__(self, img, mask): 173 | t = random.choice(self.transforms) 174 | return t(img, mask) 175 | 176 | 177 | class RandomCrop(object): 178 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 179 | if isinstance(size, numbers.Number): 180 | self.size = (int(size), int(size)) 181 | else: 182 | self.size = size 183 | self.padding = padding 184 | self.pad_if_needed = pad_if_needed 185 | self.fill = fill 186 | self.padding_mode = padding_mode 187 | 188 | @staticmethod 189 | def get_params(img, output_size): 190 | w, h = img.size 191 | th, tw = output_size 192 | if w == tw and h == th: 193 | return 0, 0, h, w 194 | 195 | i = random.randint(0, h - th) 196 | j = random.randint(0, w - tw) 197 | return i, j, th, tw 198 | 199 | def __call__(self, img, mask): 200 | if self.padding is not None: 201 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 202 | 203 | # pad the width if needed 204 | if self.pad_if_needed and img.size[0] < self.size[1]: 205 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 206 | # pad the height if needed 207 | if self.pad_if_needed and img.size[1] < self.size[0]: 208 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 209 | 210 | i, j, h, w = self.get_params(img, self.size) 211 | 212 | return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w) 213 | 214 | 215 | class RandomHorizontalFlip(object): 216 | def __init__(self, p=0.5): 217 | self.p = p 218 | 219 | def __call__(self, img, mask): 220 | if random.random() < self.p: 221 | return F.hflip(img), F.hflip(mask) 222 | return img, mask 223 | 224 | 225 | class RandomVerticalFlip(object): 226 | def __init__(self, p=0.5): 227 | self.p = p 228 | 229 | def __call__(self, img, mask): 230 | if random.random() < self.p: 231 | return F.vflip(img), F.vflip(mask) 232 | return img, mask 233 | 234 | 235 | class RandomPerspective(object): 236 | def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): 237 | self.p = p 238 | self.interpolation = interpolation 239 | self.distortion_scale = distortion_scale 240 | 241 | def __call__(self, img, mask): 242 | if not F._is_pil_image(img): 243 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 244 | 245 | if random.random() < self.p: 246 | width, height = img.size 247 | startpoints, endpoints = self.get_params(width, height, self.distortion_scale) 248 | return F.perspective(img, startpoints, endpoints, self.interpolation), \ 249 | F.perspective(mask, startpoints, endpoints, Image.NEAREST) 250 | return img, mask 251 | 252 | @staticmethod 253 | def get_params(width, height, distortion_scale): 254 | half_height = int(height / 2) 255 | half_width = int(width / 2) 256 | topleft = (random.randint(0, int(distortion_scale * half_width)), 257 | random.randint(0, int(distortion_scale * half_height))) 258 | topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 259 | random.randint(0, int(distortion_scale * half_height))) 260 | botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), 261 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 262 | botleft = (random.randint(0, int(distortion_scale * half_width)), 263 | random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) 264 | startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] 265 | endpoints = [topleft, topright, botright, botleft] 266 | return startpoints, endpoints 267 | 268 | 269 | class RandomResizedCrop(object): 270 | def __init__(self, size, mask_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 271 | if isinstance(size, tuple): 272 | self.size = size 273 | self.mask_size = mask_size 274 | else: 275 | self.size = (size, size) 276 | self.mask_size = (mask_size, mask_size) 277 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 278 | warnings.warn("range should be of kind (min, max)") 279 | 280 | self.interpolation = interpolation 281 | self.scale = scale 282 | self.ratio = ratio 283 | 284 | @staticmethod 285 | def get_params(img, scale, ratio): 286 | area = img.size[0] * img.size[1] 287 | 288 | for attempt in range(10): 289 | target_area = random.uniform(*scale) * area 290 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 291 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 292 | 293 | w = int(round(math.sqrt(target_area * aspect_ratio))) 294 | h = int(round(math.sqrt(target_area / aspect_ratio))) 295 | 296 | if w <= img.size[0] and h <= img.size[1]: 297 | i = random.randint(0, img.size[1] - h) 298 | j = random.randint(0, img.size[0] - w) 299 | return i, j, h, w 300 | 301 | # Fallback to central crop 302 | in_ratio = img.size[0] / img.size[1] 303 | if (in_ratio < min(ratio)): 304 | w = img.size[0] 305 | h = w / min(ratio) 306 | elif (in_ratio > max(ratio)): 307 | h = img.size[1] 308 | w = h * max(ratio) 309 | else: # whole image 310 | w = img.size[0] 311 | h = img.size[1] 312 | i = (img.size[1] - h) // 2 313 | j = (img.size[0] - w) // 2 314 | return i, j, h, w 315 | 316 | def __call__(self, img, mask): 317 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 318 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \ 319 | F.resized_crop(mask, i, j, h, w, self.mask_size, Image.NEAREST) 320 | 321 | 322 | class FiveCrop(object): 323 | def __init__(self, size): 324 | self.size = size 325 | if isinstance(size, numbers.Number): 326 | self.size = (int(size), int(size)) 327 | else: 328 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
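            # Note: F.five_crop returns a 5-tuple (tl, tr, bl, br, center) of
            # crops for both img and mask, so any transform chained after
            # FiveCrop in a Compose pipeline would need to accept tuples.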
329 | self.size = size 330 | 331 | def __call__(self, img, mask): 332 | return F.five_crop(img, self.size), F.five_crop(mask, self.size) 333 | 334 | 335 | class TenCrop(object): 336 | def __init__(self, size, vertical_flip=False): 337 | self.size = size 338 | if isinstance(size, numbers.Number): 339 | self.size = (int(size), int(size)) 340 | else: 341 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 342 | self.size = size 343 | self.vertical_flip = vertical_flip 344 | 345 | def __call__(self, img, mask): 346 | return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip) 347 | 348 | 349 | class ColorJitter(object): 350 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 351 | self.brightness = self._check_input(brightness, 'brightness') 352 | self.contrast = self._check_input(contrast, 'contrast') 353 | self.saturation = self._check_input(saturation, 'saturation') 354 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 355 | clip_first_on_zero=False) 356 | 357 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 358 | if isinstance(value, numbers.Number): 359 | if value < 0: 360 | raise ValueError("If {} is a single number, it must be non-negative.".format(name)) 361 | value = [center - value, center + value] 362 | if clip_first_on_zero: 363 | value[0] = max(value[0], 0) 364 | elif isinstance(value, (tuple, list)) and len(value) == 2: 365 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 366 | raise ValueError("{} values should be between {}".format(name, bound)) 367 | else: 368 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) 369 | 370 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 371 | # or (0., 0.)
for hue, do nothing 372 | if value[0] == value[1] == center: 373 | value = None 374 | return value 375 | 376 | @staticmethod 377 | def get_params(brightness, contrast, saturation, hue): 378 | transforms = [] 379 | 380 | if brightness is not None: 381 | brightness_factor = random.uniform(brightness[0], brightness[1]) 382 | transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor))) 383 | 384 | if contrast is not None: 385 | contrast_factor = random.uniform(contrast[0], contrast[1]) 386 | transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor))) 387 | 388 | if saturation is not None: 389 | saturation_factor = random.uniform(saturation[0], saturation[1]) 390 | transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor))) 391 | 392 | if hue is not None: 393 | hue_factor = random.uniform(hue[0], hue[1]) 394 | transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor))) 395 | 396 | random.shuffle(transforms) 397 | transform = Compose(transforms) 398 | 399 | return transform 400 | 401 | def __call__(self, img, mask): 402 | transform = self.get_params(self.brightness, self.contrast, 403 | self.saturation, self.hue) 404 | return transform(img, mask) 405 | 406 | 407 | class RandomRotation(object): 408 | def __init__(self, degrees, resample=False, expand=False, center=None): 409 | if isinstance(degrees, numbers.Number): 410 | if degrees < 0: 411 | raise ValueError("If degrees is a single number, it must be positive.") 412 | self.degrees = (-degrees, degrees) 413 | else: 414 | if len(degrees) != 2: 415 | raise ValueError("If degrees is a sequence, it must be of len 2.") 416 | self.degrees = degrees 417 | 418 | self.resample = resample 419 | self.expand = expand 420 | self.center = center 421 | 422 | @staticmethod 423 | def get_params(degrees): 424 | angle = random.uniform(degrees[0], degrees[1]) 425 | 426 | return angle 427 | 428 | def __call__(self, img, mask): 429 | angle = self.get_params(self.degrees) 430 | 431 | return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \ 432 | F.rotate(mask, angle, Image.NEAREST, self.expand, self.center) 433 | 434 | 435 | class RandomAffine(object): 436 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 437 | if isinstance(degrees, numbers.Number): 438 | if degrees < 0: 439 | raise ValueError("If degrees is a single number, it must be positive.") 440 | self.degrees = (-degrees, degrees) 441 | else: 442 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 443 | "degrees should be a list or tuple and it must be of length 2." 444 | self.degrees = degrees 445 | 446 | if translate is not None: 447 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 448 | "translate should be a list or tuple and it must be of length 2." 449 | for t in translate: 450 | if not (0.0 <= t <= 1.0): 451 | raise ValueError("translation values should be between 0 and 1") 452 | self.translate = translate 453 | 454 | if scale is not None: 455 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 456 | "scale should be a list or tuple and it must be of length 2." 
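            # Usage sketch (hypothetical values): because get_params draws one
            # (angle, translations, scale, shear) tuple that __call__ applies
            # to both inputs, the image/mask pair stays geometrically aligned:
            #
            #     aug = RandomAffine(degrees=10, translate=(0.1, 0.1),
            #                        scale=(0.75, 1.25))
            #     img_aug, mask_aug = aug(img, mask)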
457 | for s in scale: 458 | if s <= 0: 459 | raise ValueError("scale values should be positive") 460 | self.scale = scale 461 | 462 | if shear is not None: 463 | if isinstance(shear, numbers.Number): 464 | if shear < 0: 465 | raise ValueError("If shear is a single number, it must be positive.") 466 | self.shear = (-shear, shear) 467 | else: 468 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 469 | "shear should be a list or tuple and it must be of length 2." 470 | self.shear = shear 471 | else: 472 | self.shear = shear 473 | 474 | self.resample = resample 475 | self.fillcolor = fillcolor 476 | 477 | @staticmethod 478 | def get_params(degrees, translate, scale_ranges, shears, img_size): 479 | angle = random.uniform(degrees[0], degrees[1]) 480 | if translate is not None: 481 | max_dx = translate[0] * img_size[0] 482 | max_dy = translate[1] * img_size[1] 483 | translations = (np.round(random.uniform(-max_dx, max_dx)), 484 | np.round(random.uniform(-max_dy, max_dy))) 485 | else: 486 | translations = (0, 0) 487 | 488 | if scale_ranges is not None: 489 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 490 | else: 491 | scale = 1.0 492 | 493 | if shears is not None: 494 | shear = random.uniform(shears[0], shears[1]) 495 | else: 496 | shear = 0.0 497 | 498 | return angle, translations, scale, shear 499 | 500 | def __call__(self, img, mask): 501 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 502 | return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \ 503 | F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor) 504 | 505 | class RandomAffineFromSet(object): 506 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 507 | assert isinstance(degrees, (tuple, list)), \ 508 | "degrees should be a list or tuple." 509 | self.degrees = degrees 510 | 511 | if translate is not None: 512 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 513 | "translate should be a list or tuple and it must be of length 2." 514 | for t in translate: 515 | if not (0.0 <= t <= 1.0): 516 | raise ValueError("translation values should be between 0 and 1") 517 | self.translate = translate 518 | 519 | if scale is not None: 520 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 521 | "scale should be a list or tuple and it must be of length 2." 522 | for s in scale: 523 | if s <= 0: 524 | raise ValueError("scale values should be positive") 525 | self.scale = scale 526 | 527 | if shear is not None: 528 | if isinstance(shear, numbers.Number): 529 | if shear < 0: 530 | raise ValueError("If shear is a single number, it must be positive.") 531 | self.shear = (-shear, shear) 532 | else: 533 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 534 | "shear should be a list or tuple and it must be of length 2." 
535 | self.shear = shear 536 | else: 537 | self.shear = shear 538 | 539 | self.resample = resample 540 | self.fillcolor = fillcolor 541 | 542 | @staticmethod 543 | def get_params(degrees, translate, scale_ranges, shears, img_size): 544 | angle = random.choice(degrees) 545 | if translate is not None: 546 | max_dx = translate[0] * img_size[0] 547 | max_dy = translate[1] * img_size[1] 548 | translations = (np.round(random.uniform(-max_dx, max_dx)), 549 | np.round(random.uniform(-max_dy, max_dy))) 550 | else: 551 | translations = (0, 0) 552 | 553 | if scale_ranges is not None: 554 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 555 | else: 556 | scale = 1.0 557 | 558 | if shears is not None: 559 | shear = random.uniform(shears[0], shears[1]) 560 | else: 561 | shear = 0.0 562 | 563 | return angle, translations, scale, shear 564 | 565 | def __call__(self, img, mask): 566 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 567 | return F.affine(img, *ret, resample=Image.BILINEAR, fillcolor=self.fillcolor), \ 568 | F.affine(mask, *ret, resample=Image.NEAREST, fillcolor=self.fillcolor) -------------------------------------------------------------------------------- /improved_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | 22 | The beta schedule library consists of beta schedules which remain similar 23 | in the limit of num_diffusion_timesteps. 24 | Beta schedules may be added, but should not be removed or changed once 25 | they are committed to maintain backwards compatibility. 26 | """ 27 | if schedule_name == "linear": 28 | # Linear schedule from Ho et al, extended to work for any number of 29 | # diffusion steps. 30 | scale = 1000 / num_diffusion_timesteps 31 | beta_start = scale * 0.0001 32 | beta_end = scale * 0.02 33 | return np.linspace( 34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 35 | ) 36 | elif schedule_name == "cosine": 37 | return betas_for_alpha_bar( 38 | num_diffusion_timesteps, 39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 40 | ) 41 | else: 42 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 43 | 44 | 45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 46 | """ 47 | Create a beta schedule that discretizes the given alpha_t_bar function, 48 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 49 | 50 | :param num_diffusion_timesteps: the number of betas to produce. 51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 52 | produces the cumulative product of (1-beta) up to that 53 | part of the diffusion process. 54 | :param max_beta: the maximum beta to use; use values lower than 1 to 55 | prevent singularities. 
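
    Note that each returned beta satisfies
    1 - beta_i = alpha_bar(t_{i+1}) / alpha_bar(t_i) with
    t_i = i / num_diffusion_timesteps, so (up to the max_beta clipping) the
    cumulative products of (1 - beta) reproduce alpha_bar at the grid
    points, normalized by alpha_bar(0).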
56 | """ 57 | betas = [] 58 | for i in range(num_diffusion_timesteps): 59 | t1 = i / num_diffusion_timesteps 60 | t2 = (i + 1) / num_diffusion_timesteps 61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 62 | return np.array(betas) 63 | 64 | 65 | class ModelMeanType(enum.Enum): 66 | """ 67 | Which type of output the model predicts. 68 | """ 69 | 70 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 71 | START_X = enum.auto() # the model predicts x_0 72 | EPSILON = enum.auto() # the model predicts epsilon 73 | 74 | 75 | class ModelVarType(enum.Enum): 76 | """ 77 | What is used as the model's output variance. 78 | 79 | The LEARNED_RANGE option has been added to allow the model to predict 80 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 81 | """ 82 | 83 | LEARNED = enum.auto() 84 | FIXED_SMALL = enum.auto() 85 | FIXED_LARGE = enum.auto() 86 | LEARNED_RANGE = enum.auto() 87 | 88 | 89 | class LossType(enum.Enum): 90 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 91 | RESCALED_MSE = ( 92 | enum.auto() 93 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 94 | KL = enum.auto() # use the variational lower-bound 95 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 96 | 97 | def is_vb(self): 98 | return self == LossType.KL or self == LossType.RESCALED_KL 99 | 100 | 101 | class GaussianDiffusion: 102 | """ 103 | Utilities for training and sampling diffusion models. 104 | 105 | Ported directly from here, and then adapted over time to further experimentation. 106 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 107 | 108 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 109 | starting at T and going to 1. 110 | :param model_mean_type: a ModelMeanType determining what the model outputs. 111 | :param model_var_type: a ModelVarType determining how variance is output. 112 | :param loss_type: a LossType determining the loss function to use. 113 | :param rescale_timesteps: if True, pass floating point timesteps into the 114 | model so that they are always scaled like in the 115 | original paper (0 to 1000). 116 | """ 117 | 118 | def __init__( 119 | self, 120 | *, 121 | betas, 122 | model_mean_type, 123 | model_var_type, 124 | loss_type, 125 | rescale_timesteps=False, 126 | ): 127 | self.model_mean_type = model_mean_type 128 | self.model_var_type = model_var_type 129 | self.loss_type = loss_type 130 | self.rescale_timesteps = rescale_timesteps 131 | 132 | # Use float64 for accuracy. 
133 | betas = np.array(betas, dtype=np.float64) 134 | self.betas = betas 135 | assert len(betas.shape) == 1, "betas must be 1-D" 136 | assert (betas > 0).all() and (betas <= 1).all() 137 | 138 | self.num_timesteps = int(betas.shape[0]) 139 | 140 | alphas = 1.0 - betas 141 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 142 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 143 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 144 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 145 | 146 | # calculations for diffusion q(x_t | x_{t-1}) and others 147 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 148 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 149 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 150 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 151 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 152 | 153 | # calculations for posterior q(x_{t-1} | x_t, x_0) 154 | self.posterior_variance = ( 155 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 156 | ) 157 | # log calculation clipped because the posterior variance is 0 at the 158 | # beginning of the diffusion chain. 159 | self.posterior_log_variance_clipped = np.log( 160 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 161 | ) 162 | self.posterior_mean_coef1 = ( 163 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 164 | ) 165 | self.posterior_mean_coef2 = ( 166 | (1.0 - self.alphas_cumprod_prev) 167 | * np.sqrt(alphas) 168 | / (1.0 - self.alphas_cumprod) 169 | ) 170 | 171 | def q_mean_variance(self, x_start, t): 172 | """ 173 | Get the distribution q(x_t | x_0). 174 | 175 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 176 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 177 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 178 | """ 179 | mean = ( 180 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 181 | ) 182 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 183 | log_variance = _extract_into_tensor( 184 | self.log_one_minus_alphas_cumprod, t, x_start.shape 185 | ) 186 | return mean, variance, log_variance 187 | 188 | def q_sample(self, x_start, t, noise=None): 189 | """ 190 | Diffuse the data for a given number of diffusion steps. 191 | 192 | In other words, sample from q(x_t | x_0). 193 | 194 | :param x_start: the initial data batch. 195 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 196 | :param noise: if specified, the split-out normal noise. 197 | :return: A noisy version of x_start. 
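
        In closed form the returned sample is
        sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise,
        i.e. a draw from N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).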
198 | """ 199 | if noise is None: 200 | noise = th.randn_like(x_start) 201 | assert noise.shape == x_start.shape 202 | return ( 203 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 204 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 205 | * noise 206 | ) 207 | 208 | def q_posterior_mean_variance(self, x_start, x_t, t): 209 | """ 210 | Compute the mean and variance of the diffusion posterior: 211 | 212 | q(x_{t-1} | x_t, x_0) 213 | 214 | """ 215 | assert x_start.shape == x_t.shape 216 | posterior_mean = ( 217 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 218 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 219 | ) 220 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 221 | posterior_log_variance_clipped = _extract_into_tensor( 222 | self.posterior_log_variance_clipped, t, x_t.shape 223 | ) 224 | assert ( 225 | posterior_mean.shape[0] 226 | == posterior_variance.shape[0] 227 | == posterior_log_variance_clipped.shape[0] 228 | == x_start.shape[0] 229 | ) 230 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 231 | 232 | def p_mean_variance( 233 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 234 | ): 235 | """ 236 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 237 | the initial x, x_0. 238 | 239 | :param model: the model, which takes a signal and a batch of timesteps 240 | as input. 241 | :param x: the [N x C x ...] tensor at time t. 242 | :param t: a 1-D Tensor of timesteps. 243 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 244 | :param denoised_fn: if not None, a function which applies to the 245 | x_start prediction before it is used to sample. Applies before 246 | clip_denoised. 247 | :param model_kwargs: if not None, a dict of extra keyword arguments to 248 | pass to the model. This can be used for conditioning. 249 | :return: a dict with the following keys: 250 | - 'mean': the model mean output. 251 | - 'variance': the model variance output. 252 | - 'log_variance': the log of 'variance'. 253 | - 'pred_xstart': the prediction for x_0. 254 | """ 255 | if model_kwargs is None: 256 | model_kwargs = {} 257 | 258 | B, C = x.shape[:2] 259 | assert t.shape == (B,) 260 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 261 | 262 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 263 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 264 | model_output, model_var_values = th.split(model_output, C, dim=1) 265 | if self.model_var_type == ModelVarType.LEARNED: 266 | model_log_variance = model_var_values 267 | model_variance = th.exp(model_log_variance) 268 | else: 269 | min_log = _extract_into_tensor( 270 | self.posterior_log_variance_clipped, t, x.shape 271 | ) 272 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 273 | # The model_var_values is [-1, 1] for [min_var, max_var]. 274 | frac = (model_var_values + 1) / 2 275 | model_log_variance = frac * max_log + (1 - frac) * min_log 276 | model_variance = th.exp(model_log_variance) 277 | else: 278 | model_variance, model_log_variance = { 279 | # for fixedlarge, we set the initial (log-)variance like so 280 | # to get a better decoder log likelihood. 
281 |                 ModelVarType.FIXED_LARGE: (
282 |                     np.append(self.posterior_variance[1], self.betas[1:]),
283 |                     np.log(np.append(self.posterior_variance[1], self.betas[1:])),
284 |                 ),
285 |                 ModelVarType.FIXED_SMALL: (
286 |                     self.posterior_variance,
287 |                     self.posterior_log_variance_clipped,
288 |                 ),
289 |             }[self.model_var_type]
290 |             model_variance = _extract_into_tensor(model_variance, t, x.shape)
291 |             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
292 | 
293 |         def process_xstart(x):
294 |             if denoised_fn is not None:
295 |                 x = denoised_fn(x)
296 |             if clip_denoised:
297 |                 return x.clamp(-1, 1)
298 |             return x
299 | 
300 |         if self.model_mean_type == ModelMeanType.PREVIOUS_X:
301 |             pred_xstart = process_xstart(
302 |                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
303 |             )
304 |             model_mean = model_output
305 |         elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
306 |             if self.model_mean_type == ModelMeanType.START_X:
307 |                 pred_xstart = process_xstart(model_output)
308 |             else:
309 |                 pred_xstart = process_xstart(
310 |                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
311 |                 )
312 |             model_mean, _, _ = self.q_posterior_mean_variance(
313 |                 x_start=pred_xstart, x_t=x, t=t
314 |             )
315 |         else:
316 |             raise NotImplementedError(self.model_mean_type)
317 | 
318 |         assert (
319 |             model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
320 |         )
321 |         return {
322 |             "mean": model_mean,
323 |             "variance": model_variance,
324 |             "log_variance": model_log_variance,
325 |             "pred_xstart": pred_xstart,
326 |         }
327 | 
328 |     def _predict_xstart_from_eps(self, x_t, t, eps):
329 |         assert x_t.shape == eps.shape
330 |         return (
331 |             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
332 |             - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
333 |         )
334 | 
335 |     def _predict_xstart_from_xprev(self, x_t, t, xprev):
336 |         assert x_t.shape == xprev.shape
337 |         return (  # (xprev - coef2*x_t) / coef1
338 |             _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
339 |             - _extract_into_tensor(
340 |                 self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
341 |             )
342 |             * x_t
343 |         )
344 | 
345 |     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
346 |         return (
347 |             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
348 |             - pred_xstart
349 |         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
350 | 
351 |     def _scale_timesteps(self, t):
352 |         if self.rescale_timesteps:
353 |             return t.float() * (1000.0 / self.num_timesteps)
354 |         return t
355 | 
356 |     def p_sample(
357 |         self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
358 |     ):
359 |         """
360 |         Sample x_{t-1} from the model at the given timestep.
361 | 
362 |         :param model: the model to sample from.
363 |         :param x: the current tensor at timestep t (i.e. x_t).
364 |         :param t: the value of t, starting at 0 for the first diffusion step.
365 |         :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
366 |         :param denoised_fn: if not None, a function which applies to the
367 |             x_start prediction before it is used to sample.
368 |         :param model_kwargs: if not None, a dict of extra keyword arguments to
369 |             pass to the model. This can be used for conditioning.
370 |         :return: a dict containing the following keys:
371 |                  - 'sample': a random sample from the model.
372 |                  - 'pred_xstart': a prediction of x_0.
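
        When t > 0 the sample is mean + exp(0.5 * log_variance) * z with
        z ~ N(0, I); when t == 0 no noise is added and the mean is returned
        as the final sample.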
373 | """ 374 | out = self.p_mean_variance( 375 | model, 376 | x, 377 | t, 378 | clip_denoised=clip_denoised, 379 | denoised_fn=denoised_fn, 380 | model_kwargs=model_kwargs, 381 | ) 382 | noise = th.randn_like(x) 383 | nonzero_mask = ( 384 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 385 | ) # no noise when t == 0 386 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 387 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 388 | 389 | def p_sample_loop( 390 | self, 391 | model, 392 | shape, 393 | noise=None, 394 | clip_denoised=True, 395 | denoised_fn=None, 396 | model_kwargs=None, 397 | device=None, 398 | progress=False, 399 | ): 400 | """ 401 | Generate samples from the model. 402 | 403 | :param model: the model module. 404 | :param shape: the shape of the samples, (N, C, H, W). 405 | :param noise: if specified, the noise from the encoder to sample. 406 | Should be of the same shape as `shape`. 407 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 408 | :param denoised_fn: if not None, a function which applies to the 409 | x_start prediction before it is used to sample. 410 | :param model_kwargs: if not None, a dict of extra keyword arguments to 411 | pass to the model. This can be used for conditioning. 412 | :param device: if specified, the device to create the samples on. 413 | If not specified, use a model parameter's device. 414 | :param progress: if True, show a tqdm progress bar. 415 | :return: a non-differentiable batch of samples. 416 | """ 417 | final = None 418 | for sample in self.p_sample_loop_progressive( 419 | model, 420 | shape, 421 | noise=noise, 422 | clip_denoised=clip_denoised, 423 | denoised_fn=denoised_fn, 424 | model_kwargs=model_kwargs, 425 | device=device, 426 | progress=progress, 427 | ): 428 | final = sample 429 | return final["sample"] 430 | 431 | def p_sample_loop_progressive( 432 | self, 433 | model, 434 | shape, 435 | noise=None, 436 | clip_denoised=True, 437 | denoised_fn=None, 438 | model_kwargs=None, 439 | device=None, 440 | progress=False, 441 | ): 442 | """ 443 | Generate samples from the model and yield intermediate samples from 444 | each timestep of diffusion. 445 | 446 | Arguments are the same as p_sample_loop(). 447 | Returns a generator over dicts, where each dict is the return value of 448 | p_sample(). 449 | """ 450 | if device is None: 451 | device = next(model.parameters()).device 452 | assert isinstance(shape, (tuple, list)) 453 | if noise is not None: 454 | img = noise 455 | else: 456 | img = th.randn(*shape).to(device=device) 457 | indices = list(range(self.num_timesteps))[::-1] 458 | 459 | if progress: 460 | # Lazy import so that we don't depend on tqdm. 461 | from tqdm.auto import tqdm 462 | 463 | indices = tqdm(indices) 464 | 465 | for i in indices: 466 | t = th.tensor([i] * shape[0], device=device) 467 | with th.no_grad(): 468 | out = self.p_sample( 469 | model, 470 | img, 471 | t, 472 | clip_denoised=clip_denoised, 473 | denoised_fn=denoised_fn, 474 | model_kwargs=model_kwargs, 475 | ) 476 | yield out 477 | img = out["sample"] 478 | 479 | def ddim_sample( 480 | self, 481 | model, 482 | x, 483 | t, 484 | clip_denoised=True, 485 | denoised_fn=None, 486 | model_kwargs=None, 487 | eta=0.0, 488 | ): 489 | """ 490 | Sample x_{t-1} from the model using DDIM. 491 | 492 | Same usage as p_sample(). 
493 | """ 494 | out = self.p_mean_variance( 495 | model, 496 | x, 497 | t, 498 | clip_denoised=clip_denoised, 499 | denoised_fn=denoised_fn, 500 | model_kwargs=model_kwargs, 501 | ) 502 | # Usually our model outputs epsilon, but we re-derive it 503 | # in case we used x_start or x_prev prediction. 504 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 505 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 506 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 507 | sigma = ( 508 | eta 509 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 510 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 511 | ) 512 | # Equation 12. 513 | noise = th.randn_like(x) 514 | mean_pred = ( 515 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 516 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 517 | ) 518 | nonzero_mask = ( 519 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 520 | ) # no noise when t == 0 521 | sample = mean_pred + nonzero_mask * sigma * noise 522 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 523 | 524 | def ddim_reverse_sample( 525 | self, 526 | model, 527 | x, 528 | t, 529 | clip_denoised=True, 530 | denoised_fn=None, 531 | model_kwargs=None, 532 | eta=0.0, 533 | ): 534 | """ 535 | Sample x_{t+1} from the model using DDIM reverse ODE. 536 | """ 537 | assert eta == 0.0, "Reverse ODE only for deterministic path" 538 | out = self.p_mean_variance( 539 | model, 540 | x, 541 | t, 542 | clip_denoised=clip_denoised, 543 | denoised_fn=denoised_fn, 544 | model_kwargs=model_kwargs, 545 | ) 546 | # Usually our model outputs epsilon, but we re-derive it 547 | # in case we used x_start or x_prev prediction. 548 | eps = ( 549 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 550 | - out["pred_xstart"] 551 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 552 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 553 | 554 | # Equation 12. reversed 555 | mean_pred = ( 556 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 557 | + th.sqrt(1 - alpha_bar_next) * eps 558 | ) 559 | 560 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 561 | 562 | def ddim_sample_loop( 563 | self, 564 | model, 565 | shape, 566 | noise=None, 567 | clip_denoised=True, 568 | denoised_fn=None, 569 | model_kwargs=None, 570 | device=None, 571 | progress=False, 572 | eta=0.0, 573 | ): 574 | """ 575 | Generate samples from the model using DDIM. 576 | 577 | Same usage as p_sample_loop(). 578 | """ 579 | final = None 580 | for sample in self.ddim_sample_loop_progressive( 581 | model, 582 | shape, 583 | noise=noise, 584 | clip_denoised=clip_denoised, 585 | denoised_fn=denoised_fn, 586 | model_kwargs=model_kwargs, 587 | device=device, 588 | progress=progress, 589 | eta=eta, 590 | ): 591 | final = sample 592 | return final["sample"] 593 | 594 | def ddim_sample_loop_progressive( 595 | self, 596 | model, 597 | shape, 598 | noise=None, 599 | clip_denoised=True, 600 | denoised_fn=None, 601 | model_kwargs=None, 602 | device=None, 603 | progress=False, 604 | eta=0.0, 605 | ): 606 | """ 607 | Use DDIM to sample from the model and yield intermediate samples from 608 | each timestep of DDIM. 609 | 610 | Same usage as p_sample_loop_progressive(). 
611 | """ 612 | if device is None: 613 | device = next(model.parameters()).device 614 | assert isinstance(shape, (tuple, list)) 615 | if noise is not None: 616 | img = noise 617 | else: 618 | img = th.randn(*shape).to(device=device) 619 | indices = list(range(self.num_timesteps))[::-1] 620 | 621 | if progress: 622 | # Lazy import so that we don't depend on tqdm. 623 | from tqdm.auto import tqdm 624 | 625 | indices = tqdm(indices) 626 | 627 | for i in indices: 628 | t = th.tensor([i] * shape[0], device=device) 629 | with th.no_grad(): 630 | out = self.ddim_sample( 631 | model, 632 | img, 633 | t, 634 | clip_denoised=clip_denoised, 635 | denoised_fn=denoised_fn, 636 | model_kwargs=model_kwargs, 637 | eta=eta, 638 | ) 639 | yield out 640 | img = out["sample"] 641 | 642 | def _vb_terms_bpd( 643 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 644 | ): 645 | """ 646 | Get a term for the variational lower-bound. 647 | 648 | The resulting units are bits (rather than nats, as one might expect). 649 | This allows for comparison to other papers. 650 | 651 | :return: a dict with the following keys: 652 | - 'output': a shape [N] tensor of NLLs or KLs. 653 | - 'pred_xstart': the x_0 predictions. 654 | """ 655 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 656 | x_start=x_start, x_t=x_t, t=t 657 | ) 658 | out = self.p_mean_variance( 659 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 660 | ) 661 | kl = normal_kl( 662 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 663 | ) 664 | kl = mean_flat(kl) / np.log(2.0) 665 | 666 | decoder_nll = -discretized_gaussian_log_likelihood( 667 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 668 | ) 669 | assert decoder_nll.shape == x_start.shape 670 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 671 | 672 | # At the first timestep return the decoder NLL, 673 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 674 | output = th.where((t == 0), decoder_nll, kl) 675 | return {"output": output, "pred_xstart": out["pred_xstart"]} 676 | 677 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 678 | """ 679 | Compute training losses for a single timestep. 680 | 681 | :param model: the model to evaluate loss on. 682 | :param x_start: the [N x C x ...] tensor of inputs. 683 | :param t: a batch of timestep indices. 684 | :param model_kwargs: if not None, a dict of extra keyword arguments to 685 | pass to the model. This can be used for conditioning. 686 | :param noise: if specified, the specific Gaussian noise to try to remove. 687 | :return: a dict with the key "loss" containing a tensor of shape [N]. 688 | Some mean or variance settings may also have other keys. 
689 | """ 690 | if model_kwargs is None: 691 | model_kwargs = {} 692 | if noise is None: 693 | noise = th.randn_like(x_start) 694 | x_t = self.q_sample(x_start, t, noise=noise) 695 | 696 | terms = {} 697 | 698 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 699 | terms["loss"] = self._vb_terms_bpd( 700 | model=model, 701 | x_start=x_start, 702 | x_t=x_t, 703 | t=t, 704 | clip_denoised=False, 705 | model_kwargs=model_kwargs, 706 | )["output"] 707 | if self.loss_type == LossType.RESCALED_KL: 708 | terms["loss"] *= self.num_timesteps 709 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 710 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 711 | 712 | if self.model_var_type in [ 713 | ModelVarType.LEARNED, 714 | ModelVarType.LEARNED_RANGE, 715 | ]: 716 | B, C = x_t.shape[:2] 717 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 718 | model_output, model_var_values = th.split(model_output, C, dim=1) 719 | # Learn the variance using the variational bound, but don't let 720 | # it affect our mean prediction. 721 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 722 | terms["vb"] = self._vb_terms_bpd( 723 | model=lambda *args, r=frozen_out: r, 724 | x_start=x_start, 725 | x_t=x_t, 726 | t=t, 727 | clip_denoised=False, 728 | )["output"] 729 | if self.loss_type == LossType.RESCALED_MSE: 730 | # Divide by 1000 for equivalence with initial implementation. 731 | # Without a factor of 1/1000, the VB term hurts the MSE term. 732 | terms["vb"] *= self.num_timesteps / 1000.0 733 | 734 | target = { 735 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 736 | x_start=x_start, x_t=x_t, t=t 737 | )[0], 738 | ModelMeanType.START_X: x_start, 739 | ModelMeanType.EPSILON: noise, 740 | }[self.model_mean_type] 741 | assert model_output.shape == target.shape == x_start.shape 742 | terms["mse"] = mean_flat((target - model_output) ** 2) 743 | terms["sum"] = (target - model_output).pow(2).sum(dim=(1, 2, 3)) 744 | if "vb" in terms: 745 | terms["loss"] = terms["mse"] + terms["vb"] 746 | else: 747 | terms["loss"] = terms["sum"] 748 | else: 749 | raise NotImplementedError(self.loss_type) 750 | 751 | return terms 752 | 753 | def _prior_bpd(self, x_start): 754 | """ 755 | Get the prior KL term for the variational lower-bound, measured in 756 | bits-per-dim. 757 | 758 | This term can't be optimized, as it only depends on the encoder. 759 | 760 | :param x_start: the [N x C x ...] tensor of inputs. 761 | :return: a batch of [N] KL values (in bits), one per batch element. 762 | """ 763 | batch_size = x_start.shape[0] 764 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 765 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 766 | kl_prior = normal_kl( 767 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 768 | ) 769 | return mean_flat(kl_prior) / np.log(2.0) 770 | 771 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 772 | """ 773 | Compute the entire variational lower-bound, measured in bits-per-dim, 774 | as well as other related quantities. 775 | 776 | :param model: the model to evaluate loss on. 777 | :param x_start: the [N x C x ...] tensor of inputs. 778 | :param clip_denoised: if True, clip denoised samples. 779 | :param model_kwargs: if not None, a dict of extra keyword arguments to 780 | pass to the model. This can be used for conditioning. 
781 | 782 | :return: a dict containing the following keys: 783 | - total_bpd: the total variational lower-bound, per batch element. 784 | - prior_bpd: the prior term in the lower-bound. 785 | - vb: an [N x T] tensor of terms in the lower-bound. 786 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 787 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 788 | """ 789 | device = x_start.device 790 | batch_size = x_start.shape[0] 791 | 792 | vb = [] 793 | xstart_mse = [] 794 | mse = [] 795 | for t in list(range(self.num_timesteps))[::-1]: 796 | t_batch = th.tensor([t] * batch_size, device=device) 797 | noise = th.randn_like(x_start) 798 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 799 | # Calculate VLB term at the current timestep 800 | with th.no_grad(): 801 | out = self._vb_terms_bpd( 802 | model, 803 | x_start=x_start, 804 | x_t=x_t, 805 | t=t_batch, 806 | clip_denoised=clip_denoised, 807 | model_kwargs=model_kwargs, 808 | ) 809 | vb.append(out["output"]) 810 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 811 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 812 | mse.append(mean_flat((eps - noise) ** 2)) 813 | 814 | vb = th.stack(vb, dim=1) 815 | xstart_mse = th.stack(xstart_mse, dim=1) 816 | mse = th.stack(mse, dim=1) 817 | 818 | prior_bpd = self._prior_bpd(x_start) 819 | total_bpd = vb.sum(dim=1) + prior_bpd 820 | return { 821 | "total_bpd": total_bpd, 822 | "prior_bpd": prior_bpd, 823 | "vb": vb, 824 | "xstart_mse": xstart_mse, 825 | "mse": mse, 826 | } 827 | 828 | 829 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 830 | """ 831 | Extract values from a 1-D numpy array for a batch of indices. 832 | 833 | :param arr: the 1-D numpy array. 834 | :param timesteps: a tensor of indices into the array to extract. 835 | :param broadcast_shape: a larger shape of K dimensions with the batch 836 | dimension equal to the length of timesteps. 837 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 838 | """ 839 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 840 | while len(res.shape) < len(broadcast_shape): 841 | res = res[..., None] 842 | return res.expand(broadcast_shape) 843 | --------------------------------------------------------------------------------
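A minimal usage sketch (not part of the repository): "ToyEpsModel" below is a
hypothetical stand-in for the UNet in improved_diffusion/unet.py, and the
snippet assumes it is run from the repository root so that improved_diffusion
is importable.

import torch as th

from improved_diffusion.gaussian_diffusion import (
    GaussianDiffusion,
    LossType,
    ModelMeanType,
    ModelVarType,
    get_named_beta_schedule,
)


class ToyEpsModel(th.nn.Module):
    """Stand-in epsilon predictor whose output shape matches its input."""

    def __init__(self):
        super().__init__()
        self.conv = th.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x, t):
        # A real model would also condition on a timestep embedding of t.
        return self.conv(x)


diffusion = GaussianDiffusion(
    betas=get_named_beta_schedule("cosine", 100),
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.FIXED_SMALL,
    loss_type=LossType.MSE,
)
model = ToyEpsModel()

# Per-example training losses for a random batch at uniformly drawn timesteps.
x_start = th.randn(4, 3, 32, 32)
t = th.randint(0, diffusion.num_timesteps, (4,))
losses = diffusion.training_losses(model, x_start, t)
print(losses["loss"].shape)  # torch.Size([4])

# Ancestral sampling from pure noise; ddim_sample_loop is the DDIM analogue.
with th.no_grad():
    sample = diffusion.p_sample_loop(model, (4, 3, 32, 32))
print(sample.shape)  # torch.Size([4, 3, 32, 32])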