├── figs └── teaser.png ├── house_diffusion ├── __init__.py ├── losses.py ├── dist_util.py ├── script_util.py ├── respace.py ├── nn.py ├── resample.py ├── fp16_util.py ├── train_util.py ├── transformer.py ├── logger.py ├── rplanhg_datasets.py └── gaussian_diffusion.py ├── setup.py ├── .gitignore ├── LICENSE ├── requirements.txt ├── scripts ├── script.sh ├── image_train.py └── image_sample.py ├── prepare.sh ├── README.md └── LICENSE_GPL /figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aminshabani/house_diffusion/HEAD/figs/teaser.png -------------------------------------------------------------------------------- /house_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "HouseDiffusion" based on the implementation from "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="house-diffusion", 5 | py_modules=["house_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | classify_image_graph_def.pb 4 | datasets 5 | scripts/rplan 6 | scripts/houses.npz 7 | scripts/outputs 8 | *.npz 9 | *.gif 10 | *.pt 11 | .env 12 | house_diffusion.egg-info 13 | scripts/ckpts 14 | scripts/ckpts_backup 15 | slurm-* 16 | embedder/ckpts 17 | 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The code and the model weights in this 
repository are not allowed for commercial usage. For research purposes, the terms follow the GPL v3, as in the separate file "LICENSE_GPL". 2 | -- Authors of the paper "HouseDiffusion: Vector Floorplan Generation via a Diffusion Model with Discrete and Continuous Denoising". -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | blobfile==2.0.1 2 | cairosvg==2.6.0 3 | drawSvg==1.9.0 4 | imageio==2.19.2 5 | matplotlib==3.5.1 6 | mpi4py==3.1.4 7 | networkx==2.8.2 8 | numpy==1.21.5 9 | opencv_python==4.6.0.66 10 | Pillow==9.4.0 11 | pytorch_fid==0.3.0 12 | setuptools==57.5.0 13 | Shapely==1.8.4 14 | tensorflow==2.11.0 15 | torch==2.0.0.dev20221212 16 | tqdm==4.61.2 17 | webcolors==1.12 18 | -------------------------------------------------------------------------------- /scripts/script.sh: -------------------------------------------------------------------------------- 1 | MODEL_FLAGS="--dataset rplan --batch_size 512 --set_name train --target_set 8" 2 | TRAIN_FLAGS="--lr 1e-3 --save_interval 5000 --weight_decay 0.05 --log_interval 500" 3 | SAMPLE_FLAGS="--batch_size 64 --num_samples 64" 4 | 5 | CUDA_VISIBLE_DEVICES='0' python image_train.py $MODEL_FLAGS $TRAIN_FLAGS 6 | #CUDA_VISIBLE_DEVICES='1' python image_sample.py $MODEL_FLAGS --model_path ckpts/exp/model250000.pt $SAMPLE_FLAGS 7 | -------------------------------------------------------------------------------- /prepare.sh: -------------------------------------------------------------------------------- 1 | module load python 2 | module load scipy-stack 3 | module load mpi4py/3.1.3 4 | virtualenv --no-download .env 5 | source .env/bin/activate 6 | pip install --no-index --upgrade pip 7 | pip install --no-index torch torchvision 8 | pip install --no-index matplotlib Pillow scikit-learn scipy opencv_python imageio tensorboard 9 | pip install --no-index tqdm seaborn msgpack PyYAML 
ConfigArgParse urllib3 10 | pip install scikit-spatial 11 | pip install -e . 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HouseDiffusion 2 | **[HouseDiffusion: Vector Floorplan Generation via a Diffusion Model with Discrete and Continuous Denoising](https://arxiv.org/abs/2211.13287)** 3 | 4 | ## Installation 5 | **1. Clone our repo and install the requirements:** 6 | 7 | Our implementation is based on the public implementation of [guided-diffusion](https://github.com/openai/guided-diffusion). For installation instructions, please refer to their repository. Keep in mind that our current version has not been cleaned up, so some features inherited from the original repository may not work correctly. 8 | 9 | ``` 10 | git clone https://github.com/aminshabani/house_diffusion.git 11 | cd house_diffusion 12 | pip install -r requirements.txt 13 | pip install -e . 14 | ``` 15 | **2. Download the dataset and create the datasets directory** 16 | 17 | - You can download the datasets from [RPLAN's website](http://staff.ustc.edu.cn/~fuxm/projects/DeepLayout/index.html) or by filling out [this](https://docs.google.com/forms/d/e/1FAIpQLSfwteilXzURRKDI5QopWCyOGkeb_CFFbRwtQ0SOPhEg0KGSfw/viewform) form. 18 | - We also use the data preprocessing from House-GAN++, which is available at [this](https://github.com/sepidsh/Housegan-data-reader) link. 19 | Put all of the processed files from the downloaded dataset into a `datasets` folder in the repository root (the current directory): 20 | 21 | ``` 22 | house_diffusion 23 | ├── datasets 24 | │ ├── rplan 25 | │ │ └── 0.json 26 | │ │ └── 1.json 27 | │ │ └── ... 28 | │ └── ... 29 | ├── house_diffusion 30 | ├── scripts 31 | └── ... 32 | ``` 33 | - We have provided a temporary model that you can download from [Google Drive](https://drive.google.com/file/d/16zKmtxwY5lF6JE-CJGkRf3-OFoD1TrdR/view?usp=share_link).
34 | 35 | ## Running the code 36 | 37 | **1. Training** 38 | 39 | You can run a single training experiment with the following command from inside the `scripts` directory: 40 | ``` 41 | python image_train.py --dataset rplan --batch_size 32 --set_name train --target_set 8 42 | ``` 43 | **2. Sampling** 44 | To sample floorplans, run the following command from inside the `scripts` directory. For different visualizations, see the `save_samples` function in `scripts/image_sample.py`. 45 | 46 | ``` 47 | python image_sample.py --dataset rplan --batch_size 32 --set_name eval --target_set 8 --model_path ckpts/exp/model250000.pt --num_samples 64 48 | ``` 49 | You can also run the corresponding code from `scripts/script.sh`. 50 | 51 | 52 | ## Citation 53 | 54 | ``` 55 | @article{shabani2022housediffusion, 56 | title={HouseDiffusion: Vector Floorplan Generation via a Diffusion Model with Discrete and Continuous Denoising}, 57 | author={Shabani, Mohammad Amin and Hosseini, Sepidehsadat and Furukawa, Yasutaka}, 58 | journal={arXiv preprint arXiv:2211.13287}, 59 | year={2022} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /house_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases.
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /house_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 4 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 
24 | """ 25 | if dist.is_initialized(): 26 | return 27 | ## temporary removed to manually set the CUDA_VISIBLE_DEVICES 28 | #os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 29 | 30 | comm = MPI.COMM_WORLD 31 | backend = "gloo" if not th.cuda.is_available() else "nccl" 32 | 33 | if backend == "gloo": 34 | hostname = "localhost" 35 | else: 36 | hostname = socket.gethostbyname(socket.getfqdn()) 37 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 38 | os.environ["RANK"] = str(comm.rank) 39 | os.environ["WORLD_SIZE"] = str(comm.size) 40 | 41 | port = comm.bcast(_find_free_port(), root=0) 42 | os.environ["MASTER_PORT"] = str(port) 43 | dist.init_process_group(backend=backend, init_method="env://") 44 | 45 | 46 | def dev(): 47 | """ 48 | Get the device to use for torch.distributed. 49 | """ 50 | if th.cuda.is_available(): 51 | return th.device(f"cuda") 52 | return th.device("cpu") 53 | 54 | 55 | def load_state_dict(path, **kwargs): 56 | """ 57 | Load a PyTorch file without redundant fetches across MPI ranks. 58 | """ 59 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 60 | if MPI.COMM_WORLD.Get_rank() == 0: 61 | with bf.BlobFile(path, "rb") as f: 62 | data = f.read() 63 | num_chunks = len(data) // chunk_size 64 | if len(data) % chunk_size: 65 | num_chunks += 1 66 | MPI.COMM_WORLD.bcast(num_chunks) 67 | for i in range(0, len(data), chunk_size): 68 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 69 | else: 70 | num_chunks = MPI.COMM_WORLD.bcast(None) 71 | data = bytes() 72 | for _ in range(num_chunks): 73 | data += MPI.COMM_WORLD.bcast(None) 74 | 75 | return th.load(io.BytesIO(data), **kwargs) 76 | 77 | 78 | def sync_params(params): 79 | """ 80 | Synchronize a sequence of Tensors across ranks from rank 0. 
81 | """ 82 | for p in params: 83 | with th.no_grad(): 84 | dist.broadcast(p, 0) 85 | 86 | 87 | def _find_free_port(): 88 | try: 89 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 90 | s.bind(("", 0)) 91 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 92 | return s.getsockname()[1] 93 | finally: 94 | s.close() 95 | -------------------------------------------------------------------------------- /scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | from house_diffusion import dist_util, logger 8 | from house_diffusion.rplanhg_datasets import load_rplanhg_data 9 | from house_diffusion.resample import create_named_schedule_sampler 10 | from house_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | update_arg_parser, 16 | ) 17 | from house_diffusion.train_util import TrainLoop 18 | 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | update_arg_parser(args) 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | if args.dataset=='rplan': 36 | data = load_rplanhg_data( 37 | batch_size=args.batch_size, 38 | analog_bit=args.analog_bit, 39 | target_set=args.target_set, 40 | set_name=args.set_name, 41 | ) 42 | else: 43 | print('dataset not exist!') 44 | assert False 45 | 46 | logger.log("training...") 47 | TrainLoop( 48 | model=model, 49 | diffusion=diffusion, 50 | data=data, 51 | batch_size=args.batch_size, 52 | microbatch=args.microbatch, 53 | lr=args.lr, 
54 | ema_rate=args.ema_rate, 55 | log_interval=args.log_interval, 56 | save_interval=args.save_interval, 57 | resume_checkpoint=args.resume_checkpoint, 58 | use_fp16=args.use_fp16, 59 | fp16_scale_growth=args.fp16_scale_growth, 60 | schedule_sampler=schedule_sampler, 61 | weight_decay=args.weight_decay, 62 | lr_anneal_steps=args.lr_anneal_steps, 63 | analog_bit=args.analog_bit, 64 | ).run_loop() 65 | 66 | 67 | def create_argparser(): 68 | defaults = dict( 69 | dataset = '', 70 | schedule_sampler= "uniform", #"loss-second-moment", "uniform", 71 | lr=1e-4, 72 | weight_decay=0.0, 73 | lr_anneal_steps=0, 74 | batch_size=1, 75 | microbatch=-1, # -1 disables microbatches 76 | ema_rate="0.9999", # comma-separated list of EMA values 77 | log_interval=10, 78 | save_interval=10000, 79 | resume_checkpoint="", 80 | use_fp16=False, 81 | fp16_scale_growth=1e-3, 82 | ) 83 | parser = argparse.ArgumentParser() 84 | defaults.update(model_and_diffusion_defaults()) 85 | add_dict_to_argparser(parser, defaults) 86 | return parser 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /house_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .transformer import TransformerModel 7 | 8 | def diffusion_defaults(): 9 | """ 10 | Defaults for image and classifier training. 
11 | """ 12 | return dict( 13 | analog_bit=False, 14 | learn_sigma=False, 15 | diffusion_steps=1000, 16 | noise_schedule="cosine", 17 | timestep_respacing="", 18 | use_kl=False, 19 | predict_xstart=False, 20 | rescale_timesteps=False, 21 | rescale_learned_sigmas=False, 22 | target_set=-1, 23 | set_name='', 24 | ) 25 | 26 | def update_arg_parser(args): 27 | args.num_channels = 512 28 | num_coords = 16 if args.analog_bit else 2 29 | if args.dataset=='rplan': 30 | args.input_channels = num_coords + (2*8 if not args.analog_bit else 0) # . , . , . , . , ' 31 | args.condition_channels = 89 32 | args.out_channels = num_coords * 1 33 | args.use_unet = False 34 | 35 | elif args.dataset=='st3d': 36 | args.input_channels = num_coords + (2*8 if not args.analog_bit else 0) # . , . , . , . , ' 37 | args.condition_channels = 89 38 | args.out_channels = num_coords * 1 39 | args.use_unet = False 40 | 41 | elif args.dataset=='zind': 42 | args.input_channels = num_coords + 2 * 8 43 | args.condition_channels = 89 44 | args.out_channels = num_coords * 1 45 | args.use_unet = False 46 | 47 | elif args.dataset=='layout': 48 | args.use_unet = True 49 | pass #TODO NEED TO COMPLETE 50 | 51 | elif args.dataset=='outdoor': 52 | args.use_unet = True 53 | pass #TODO NEED TO COMPLETE 54 | else: 55 | assert False, "DATASET NOT FOUND" 56 | 57 | def model_and_diffusion_defaults(): 58 | """ 59 | Defaults for image training. 
60 | """ 61 | res = dict( 62 | dataset='', 63 | use_checkpoint=False, 64 | input_channels=0, 65 | condition_channels=0, 66 | out_channels=0, 67 | use_unet=False, 68 | num_channels=128 69 | ) 70 | res.update(diffusion_defaults()) 71 | return res 72 | 73 | def create_model_and_diffusion( 74 | input_channels, 75 | condition_channels, 76 | num_channels, 77 | out_channels, 78 | dataset, 79 | use_checkpoint, 80 | use_unet, 81 | learn_sigma, 82 | diffusion_steps, 83 | noise_schedule, 84 | timestep_respacing, 85 | use_kl, 86 | predict_xstart, 87 | rescale_timesteps, 88 | rescale_learned_sigmas, 89 | analog_bit, 90 | target_set, 91 | set_name, 92 | ): 93 | model = TransformerModel(input_channels, condition_channels, num_channels, out_channels, dataset, use_checkpoint, use_unet, analog_bit) 94 | 95 | diffusion = create_gaussian_diffusion( 96 | steps=diffusion_steps, 97 | learn_sigma=learn_sigma, 98 | noise_schedule=noise_schedule, 99 | use_kl=use_kl, 100 | predict_xstart=predict_xstart, 101 | rescale_timesteps=rescale_timesteps, 102 | rescale_learned_sigmas=rescale_learned_sigmas, 103 | timestep_respacing=timestep_respacing, 104 | ) 105 | return model, diffusion 106 | 107 | def create_gaussian_diffusion( 108 | *, 109 | steps=1000, 110 | learn_sigma=False, 111 | sigma_small=False, 112 | noise_schedule="linear", 113 | use_kl=False, 114 | predict_xstart=False, 115 | rescale_timesteps=False, 116 | rescale_learned_sigmas=False, 117 | timestep_respacing="", 118 | ): 119 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 120 | if use_kl: 121 | loss_type = gd.LossType.RESCALED_KL 122 | elif rescale_learned_sigmas: 123 | loss_type = gd.LossType.RESCALED_MSE 124 | else: 125 | loss_type = gd.LossType.MSE 126 | if not timestep_respacing: 127 | timestep_respacing = [steps] 128 | return SpacedDiffusion( 129 | use_timesteps=space_timesteps(steps, timestep_respacing), 130 | betas=betas, 131 | model_mean_type=( 132 | gd.ModelMeanType.EPSILON if not predict_xstart else 
gd.ModelMeanType.START_X 133 | ), 134 | model_var_type=( 135 | ( 136 | gd.ModelVarType.FIXED_LARGE 137 | if not sigma_small 138 | else gd.ModelVarType.FIXED_SMALL 139 | ) 140 | if not learn_sigma 141 | else gd.ModelVarType.LEARNED_RANGE 142 | ), 143 | loss_type=loss_type, 144 | rescale_timesteps=rescale_timesteps, 145 | ) 146 | 147 | 148 | def add_dict_to_argparser(parser, default_dict): 149 | for k, v in default_dict.items(): 150 | v_type = type(v) 151 | if v is None: 152 | v_type = str 153 | elif isinstance(v, bool): 154 | v_type = str2bool 155 | parser.add_argument(f"--{k}", default=v, type=v_type) 156 | 157 | 158 | def args_to_dict(args, keys): 159 | return {k: getattr(args, k) for k in keys} 160 | 161 | 162 | def str2bool(v): 163 | """ 164 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 165 | """ 166 | if isinstance(v, bool): 167 | return v 168 | if v.lower() in ("yes", "true", "t", "y", "1"): 169 | return True 170 | elif v.lower() in ("no", "false", "f", "n", "0"): 171 | return False 172 | else: 173 | raise argparse.ArgumentTypeError("boolean value expected") 174 | -------------------------------------------------------------------------------- /house_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 
16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 
66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /house_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor, padding_mask): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | tensor = tensor * padding_mask.unsqueeze(1) 91 | tensor = tensor.mean(dim=list(range(1, len(tensor.shape))))/th.sum(padding_mask, dim=1) 92 | return tensor 93 | 94 | 95 | def normalization(channels): 96 | """ 97 | Make a standard normalization layer. 98 | 99 | :param channels: number of input channels. 100 | :return: an nn.Module for normalization. 101 | """ 102 | return GroupNorm32(32, channels) 103 | 104 | 105 | def timestep_embedding(timesteps, dim, max_period=10000): 106 | """ 107 | Create sinusoidal timestep embeddings. 108 | 109 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 110 | These may be fractional. 111 | :param dim: the dimension of the output. 
112 | :param max_period: controls the minimum frequency of the embeddings. 113 | :return: an [N x dim] Tensor of positional embeddings. 114 | """ 115 | half = dim // 2 116 | freqs = th.exp( 117 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 118 | ).to(device=timesteps.device) 119 | args = timesteps[:, None].float() * freqs[None] 120 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 121 | if dim % 2: 122 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 123 | return embedding 124 | 125 | 126 | def checkpoint(func, inputs, params, flag): 127 | """ 128 | Evaluate a function without caching intermediate activations, allowing for 129 | reduced memory at the expense of extra compute in the backward pass. 130 | 131 | :param func: the function to evaluate. 132 | :param inputs: the argument sequence to pass to `func`. 133 | :param params: a sequence of parameters `func` depends on but does not 134 | explicitly take as arguments. 135 | :param flag: if False, disable gradient checkpointing. 136 | """ 137 | if flag: 138 | args = tuple(inputs) + tuple(params) 139 | return CheckpointFunction.apply(func, len(inputs), *args) 140 | else: 141 | return func(*inputs) 142 | 143 | 144 | class CheckpointFunction(th.autograd.Function): 145 | @staticmethod 146 | def forward(ctx, run_function, length, *args): 147 | ctx.run_function = run_function 148 | ctx.input_tensors = list(args[:length]) 149 | ctx.input_params = list(args[length:]) 150 | with th.no_grad(): 151 | output_tensors = ctx.run_function(*ctx.input_tensors) 152 | return output_tensors 153 | 154 | @staticmethod 155 | def backward(ctx, *output_grads): 156 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 157 | with th.enable_grad(): 158 | # Fixes a bug where the first op in run_function modifies the 159 | # Tensor storage in place, which is not allowed for detach()'d 160 | # Tensors. 
161 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 162 | output_tensors = ctx.run_function(*shallow_copies) 163 | input_grads = th.autograd.grad( 164 | output_tensors, 165 | ctx.input_tensors + ctx.input_params, 166 | output_grads, 167 | allow_unused=True, 168 | ) 169 | del ctx.input_tensors 170 | del ctx.input_params 171 | del output_tensors 172 | return (None, None) + input_grads 173 | -------------------------------------------------------------------------------- /house_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 
45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 
93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /house_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | 
return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, 
param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /house_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 
| import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | analog_bit=None, 42 | ): 43 | self.analog_bit=analog_bit 44 | self.model = model 45 | self.diffusion = diffusion 46 | self.data = data 47 | self.batch_size = batch_size 48 | self.microbatch = microbatch if microbatch > 0 else batch_size 49 | self.lr = lr 50 | self.ema_rate = ( 51 | [ema_rate] 52 | if isinstance(ema_rate, float) 53 | else [float(x) for x in ema_rate.split(",")] 54 | ) 55 | self.log_interval = log_interval 56 | self.save_interval = save_interval 57 | self.resume_checkpoint = resume_checkpoint 58 | self.use_fp16 = use_fp16 59 | self.fp16_scale_growth = fp16_scale_growth 60 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 61 | self.weight_decay = weight_decay 62 | self.lr_anneal_steps = lr_anneal_steps 63 | 64 | self.step = 0 65 | self.resume_step = 0 66 | self.global_batch = self.batch_size * dist.get_world_size() 67 | 68 | self.sync_cuda = th.cuda.is_available() 69 | 70 | self._load_and_sync_parameters() 71 | self.mp_trainer = MixedPrecisionTrainer( 72 | 
model=self.model, 73 | use_fp16=self.use_fp16, 74 | fp16_scale_growth=fp16_scale_growth, 75 | ) 76 | 77 | self.opt = AdamW( 78 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 79 | ) 80 | if self.resume_step: 81 | self._load_optimizer_state() 82 | # Model was resumed, either due to a restart or a checkpoint 83 | # being specified at the command line. 84 | self.ema_params = [ 85 | self._load_ema_parameters(rate) for rate in self.ema_rate 86 | ] 87 | else: 88 | self.ema_params = [ 89 | copy.deepcopy(self.mp_trainer.master_params) 90 | for _ in range(len(self.ema_rate)) 91 | ] 92 | 93 | if th.cuda.is_available(): 94 | self.use_ddp = True 95 | self.ddp_model = DDP( 96 | self.model, 97 | device_ids=[dist_util.dev()], 98 | output_device=dist_util.dev(), 99 | broadcast_buffers=False, 100 | bucket_cap_mb=128, 101 | find_unused_parameters=False, 102 | ) 103 | else: 104 | if dist.get_world_size() > 1: 105 | logger.warn( 106 | "Distributed training requires CUDA. " 107 | "Gradients will not be synchronized properly!" 
108 | ) 109 | self.use_ddp = False 110 | self.ddp_model = self.model 111 | 112 | def _load_and_sync_parameters(self): 113 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 114 | 115 | if resume_checkpoint: 116 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 117 | # if dist.get_rank() == 0: 118 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 119 | self.model.load_state_dict( 120 | dist_util.load_state_dict( 121 | resume_checkpoint, map_location=dist_util.dev() 122 | ) 123 | ) 124 | 125 | dist_util.sync_params(self.model.parameters()) 126 | 127 | def _load_ema_parameters(self, rate): 128 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 129 | 130 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 131 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 132 | if ema_checkpoint: 133 | if dist.get_rank() == 0: 134 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 135 | state_dict = dist_util.load_state_dict( 136 | ema_checkpoint, map_location=dist_util.dev() 137 | ) 138 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 139 | 140 | dist_util.sync_params(ema_params) 141 | return ema_params 142 | 143 | def _load_optimizer_state(self): 144 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 145 | opt_checkpoint = bf.join( 146 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 147 | ) 148 | if bf.exists(opt_checkpoint): 149 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 150 | state_dict = dist_util.load_state_dict( 151 | opt_checkpoint, map_location=dist_util.dev() 152 | ) 153 | self.opt.load_state_dict(state_dict) 154 | 155 | def run_loop(self): 156 | while ( 157 | not self.lr_anneal_steps 158 | or self.step + self.resume_step < self.lr_anneal_steps 159 | ): 160 | batch, cond = next(self.data) 161 | self.run_step(batch, cond) 162 | if self.step % 100000 == 
0: 163 | lr = self.lr * (0.1**(self.step//100000)) 164 | logger.log(f"Step {self.step}: Updating learning rate to {lr}") 165 | for param_group in self.opt.param_groups: 166 | param_group["lr"] = lr 167 | if self.step % self.log_interval == 0: 168 | logger.dumpkvs() 169 | if self.step % self.save_interval == 0: 170 | self.save() 171 | # Run for a finite amount of time in integration tests. 172 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 173 | return 174 | self.step += 1 175 | # Save the last checkpoint if it wasn't already saved. 176 | if (self.step - 1) % self.save_interval != 0: 177 | self.save() 178 | 179 | def run_step(self, batch, cond): 180 | self.forward_backward(batch, cond) 181 | took_step = self.mp_trainer.optimize(self.opt) 182 | if took_step: 183 | self._update_ema() 184 | self._anneal_lr() 185 | self.log_step() 186 | 187 | def forward_backward(self, batch, cond): 188 | self.mp_trainer.zero_grad() 189 | for i in range(0, batch.shape[0], self.microbatch): 190 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 191 | micro_cond = { 192 | k: v[i : i + self.microbatch].to(dist_util.dev()) 193 | for k, v in cond.items() 194 | } 195 | model_kwargs = micro_cond 196 | 197 | last_batch = (i + self.microbatch) >= batch.shape[0] 198 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 199 | 200 | compute_losses = functools.partial( 201 | self.diffusion.training_losses, 202 | self.ddp_model, 203 | micro, 204 | t, 205 | model_kwargs=model_kwargs, 206 | analog_bit=self.analog_bit, 207 | ) 208 | 209 | if last_batch or not self.use_ddp: 210 | losses = compute_losses() 211 | else: 212 | with self.ddp_model.no_sync(): 213 | losses = compute_losses() 214 | 215 | if isinstance(self.schedule_sampler, LossAwareSampler): 216 | self.schedule_sampler.update_with_local_losses( 217 | t, losses["loss"].detach() 218 | ) 219 | 220 | loss = (losses["loss"] * weights).mean() 221 | log_loss_dict( 222 | self.diffusion, t, {k: v * 
weights for k, v in losses.items()} 223 | ) 224 | self.mp_trainer.backward(loss) 225 | 226 | def _update_ema(self): 227 | for rate, params in zip(self.ema_rate, self.ema_params): 228 | update_ema(params, self.mp_trainer.master_params, rate=rate) 229 | 230 | def _anneal_lr(self): 231 | if not self.lr_anneal_steps: 232 | return 233 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 234 | lr = self.lr * (1 - frac_done) 235 | for param_group in self.opt.param_groups: 236 | param_group["lr"] = lr 237 | 238 | def log_step(self): 239 | logger.logkv("step", self.step + self.resume_step) 240 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 241 | 242 | def save(self): 243 | def save_checkpoint(rate, params): 244 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 245 | if dist.get_rank() == 0: 246 | logger.log(f"saving model {rate}...") 247 | if not rate: 248 | filename = f"model{(self.step+self.resume_step):06d}.pt" 249 | else: 250 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 251 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 252 | th.save(state_dict, f) 253 | 254 | save_checkpoint(0, self.mp_trainer.master_params) 255 | for rate, params in zip(self.ema_rate, self.ema_params): 256 | save_checkpoint(rate, params) 257 | 258 | if dist.get_rank() == 0: 259 | with bf.BlobFile( 260 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 261 | "wb", 262 | ) as f: 263 | th.save(self.opt.state_dict(), f) 264 | 265 | dist.barrier() 266 | 267 | 268 | def parse_resume_step_from_filename(filename): 269 | """ 270 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 271 | checkpoint's number of steps. 
272 | """ 273 | split = filename.split("model") 274 | if len(split) < 2: 275 | return 0 276 | split1 = split[-1].split(".")[0] 277 | try: 278 | return int(split1) 279 | except ValueError: 280 | return 0 281 | 282 | 283 | def get_blob_logdir(): 284 | # You can change this to be a separate path to save checkpoints to 285 | # a blobstore or some external drive. 286 | return logger.get_dir() 287 | 288 | 289 | def find_resume_checkpoint(): 290 | # On your infrastructure, you may want to override this to automatically 291 | # discover the latest checkpoint on your blob storage, etc. 292 | return None 293 | 294 | 295 | def find_ema_checkpoint(main_checkpoint, step, rate): 296 | if main_checkpoint is None: 297 | return None 298 | filename = f"ema_{rate}_{(step):06d}.pt" 299 | path = bf.join(bf.dirname(main_checkpoint), filename) 300 | if bf.exists(path): 301 | return path 302 | return None 303 | 304 | 305 | def log_loss_dict(diffusion, ts, losses): 306 | for key, values in losses.items(): 307 | logger.logkv_mean(key, values.mean().item()) 308 | # Log the quantiles (four quartiles, in particular). 
309 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 310 | quartile = int(4 * sub_t / diffusion.num_timesteps) 311 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 312 | -------------------------------------------------------------------------------- /house_diffusion/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .nn import timestep_embedding 7 | 8 | def dec2bin(xinp, bits): 9 | mask = 2 ** th.arange(bits - 1, -1, -1).to(xinp.device, xinp.dtype) 10 | return xinp.unsqueeze(-1).bitwise_and(mask).ne(0).float() 11 | 12 | class PositionalEncoding(nn.Module): 13 | 14 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 15 | super().__init__() 16 | self.dropout = nn.Dropout(p=dropout) 17 | 18 | position = th.arange(max_len).unsqueeze(1) 19 | div_term = th.exp(th.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 20 | pe = th.zeros(1, max_len, d_model) 21 | pe[0, :, 0::2] = th.sin(position * div_term) 22 | pe[0, :, 1::2] = th.cos(position * div_term) 23 | self.register_buffer('pe', pe) 24 | 25 | def forward(self, x): 26 | """ 27 | Args: 28 | x: Tensor, shape [batch_size, seq_len, embedding_dim] 29 | """ 30 | x = x + self.pe[0:1, :x.size(1)] 31 | return self.dropout(x) 32 | 33 | class FeedForward(nn.Module): 34 | def __init__(self, d_model, d_ff, dropout, activation): 35 | super().__init__() 36 | # We set d_ff as a default to 2048 37 | self.linear_1 = nn.Linear(d_model, d_ff) 38 | self.dropout = nn.Dropout(dropout) 39 | self.linear_2 = nn.Linear(d_ff, d_model) 40 | self.activation = activation 41 | def forward(self, x): 42 | x = self.dropout(self.activation(self.linear_1(x))) 43 | x = self.linear_2(x) 44 | return x 45 | 46 | def attention(q, k, v, d_k, mask=None, dropout=None): 47 | scores = th.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) 48 | if 
mask is not None: 49 | mask = mask.unsqueeze(1) 50 | scores = scores.masked_fill(mask == 1, -1e9) 51 | scores = F.softmax(scores, dim=-1) 52 | if dropout is not None: 53 | scores = dropout(scores) 54 | output = th.matmul(scores, v) 55 | return output 56 | 57 | class MultiHeadAttention(nn.Module): 58 | def __init__(self, heads, d_model, dropout = 0.1): 59 | super().__init__() 60 | self.d_model = d_model 61 | self.d_k = d_model // heads 62 | self.h = heads 63 | self.q_linear = nn.Linear(d_model, d_model) 64 | self.v_linear = nn.Linear(d_model, d_model) 65 | self.k_linear = nn.Linear(d_model, d_model) 66 | self.dropout = nn.Dropout(dropout) 67 | self.out = nn.Linear(d_model, d_model) 68 | 69 | def forward(self, q, k, v, mask=None): 70 | bs = q.size(0) 71 | # perform linear operation and split into h heads 72 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 73 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 74 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 75 | # transpose to get dimensions bs * h * sl * d_model 76 | k = k.transpose(1,2) 77 | q = q.transpose(1,2) 78 | v = v.transpose(1,2)# calculate attention using function we will define next 79 | scores = attention(q, k, v, self.d_k, mask, self.dropout) 80 | # concatenate heads and put through final linear layer 81 | concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model) 82 | output = self.out(concat) 83 | return output 84 | 85 | class EncoderLayer(nn.Module): 86 | def __init__(self, d_model, heads, dropout, activation): 87 | super().__init__() 88 | self.norm_1 = nn.InstanceNorm1d(d_model) 89 | self.norm_2 = nn.InstanceNorm1d(d_model) 90 | self.self_attn = MultiHeadAttention(heads, d_model) 91 | self.door_attn = MultiHeadAttention(heads, d_model) 92 | self.gen_attn = MultiHeadAttention(heads, d_model) 93 | self.ff = FeedForward(d_model, d_model*2, dropout, activation) 94 | self.dropout = nn.Dropout(dropout) 95 | 96 | def forward(self, x, door_mask, self_mask, gen_mask): 97 | assert 
(gen_mask.max()==1 and gen_mask.min()==0), f"{gen_mask.max()}, {gen_mask.min()}" 98 | x2 = self.norm_1(x) 99 | x = x + self.dropout(self.door_attn(x2,x2,x2,door_mask)) \ 100 | + self.dropout(self.self_attn(x2, x2, x2, self_mask)) \ 101 | + self.dropout(self.gen_attn(x2, x2, x2, gen_mask)) 102 | x2 = self.norm_2(x) 103 | x = x + self.dropout(self.ff(x2)) 104 | return x 105 | 106 | class TransformerModel(nn.Module): 107 | """ 108 | The full Transformer model with timestep embedding. 109 | """ 110 | 111 | def __init__( 112 | self, 113 | in_channels, 114 | condition_channels, 115 | model_channels, 116 | out_channels, 117 | dataset, 118 | use_checkpoint, 119 | use_unet, 120 | analog_bit, 121 | ): 122 | super().__init__() 123 | self.in_channels = in_channels 124 | self.condition_channels = condition_channels 125 | self.model_channels = model_channels 126 | self.out_channels = out_channels 127 | self.time_channels = model_channels 128 | self.use_checkpoint = use_checkpoint 129 | self.analog_bit = analog_bit 130 | self.use_unet = use_unet 131 | self.num_layers = 4 132 | 133 | # self.pos_encoder = PositionalEncoding(model_channels, 0.001) 134 | # self.activation = nn.SiLU() 135 | self.activation = nn.ReLU() 136 | 137 | self.time_embed = nn.Sequential( 138 | nn.Linear(self.model_channels, self.model_channels), 139 | nn.SiLU(), 140 | nn.Linear(self.model_channels, self.time_channels), 141 | ) 142 | self.input_emb = nn.Linear(self.in_channels, self.model_channels) 143 | self.condition_emb = nn.Linear(self.condition_channels, self.model_channels) 144 | 145 | if use_unet: 146 | self.unet = UNet(self.model_channels, 1) 147 | 148 | self.transformer_layers = nn.ModuleList([EncoderLayer(self.model_channels, 4, 0.1, self.activation) for x in range(self.num_layers)]) 149 | # self.transformer_layers = nn.ModuleList([nn.TransformerEncoderLayer(self.model_channels, 4, self.model_channels*2, 0.1, self.activation, batch_first=True) for x in range(self.num_layers)]) 150 | 151 | 
self.output_linear1 = nn.Linear(self.model_channels, self.model_channels) 152 | self.output_linear2 = nn.Linear(self.model_channels, self.model_channels//2) 153 | self.output_linear3 = nn.Linear(self.model_channels//2, self.out_channels) 154 | 155 | if not self.analog_bit: 156 | self.output_linear_bin1 = nn.Linear(162+self.model_channels, self.model_channels) 157 | self.output_linear_bin2 = EncoderLayer(self.model_channels, 1, 0.1, self.activation) 158 | self.output_linear_bin3 = EncoderLayer(self.model_channels, 1, 0.1, self.activation) 159 | self.output_linear_bin4 = nn.Linear(self.model_channels, 16) 160 | 161 | print(f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}") 162 | 163 | def expand_points(self, points, connections): 164 | def average_points(point1, point2): 165 | points_new = (point1+point2)/2 166 | return points_new 167 | p1 = points 168 | p1 = p1.view([p1.shape[0], p1.shape[1], 2, -1]) 169 | p5 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()] 170 | p5 = p5.view([p5.shape[0], p5.shape[1], 2, -1]) 171 | p3 = average_points(p1, p5) 172 | p2 = average_points(p1, p3) 173 | p4 = average_points(p3, p5) 174 | p1_5 = average_points(p1, p2) 175 | p2_5 = average_points(p2, p3) 176 | p3_5 = average_points(p3, p4) 177 | p4_5 = average_points(p4, p5) 178 | points_new = th.cat((p1.view_as(points), p1_5.view_as(points), p2.view_as(points), 179 | p2_5.view_as(points), p3.view_as(points), p3_5.view_as(points), p4.view_as(points), p4_5.view_as(points), p5.view_as(points)), 2) 180 | return points_new.detach() 181 | 182 | def create_image(self, points, connections, room_indices, img_size=256, res=200): 183 | img = th.zeros((points.shape[0], 1, img_size, img_size), device=points.device) 184 | points = (points+1)*(img_size//2) 185 | points[points>=img_size] = img_size-1 186 | points[points<0] = 0 187 | p1 = points 188 | p2 = points[th.arange(points.shape[0])[:, None], connections[:,:,1].long()] 189 | 190 | 
slope = (p2[:,:,1]-p1[:,:,1])/((p2[:,:,0]-p1[:,:,0])) 191 | slope[slope.isnan()] = 0 192 | slope[slope.isinf()] = 1 193 | 194 | m = th.linspace(0, 1, res, device=points.device) 195 | new_shape = [p2.shape[0], res, p2.shape[1], p2.shape[2]] 196 | 197 | new_p2 = p2.unsqueeze(1).expand(new_shape) 198 | new_p1 = p1.unsqueeze(1).expand(new_shape) 199 | new_room_indices = room_indices.unsqueeze(1).expand([p2.shape[0], res, p2.shape[1], 1]) 200 | 201 | inc = new_p2 - new_p1 202 | 203 | xs = m.view(1,-1,1) * inc[:,:,:,0] 204 | xs = xs + new_p1[:,:,:,0] 205 | xs = xs.long() 206 | 207 | x_inc = th.where(inc[:,:,:,0]==0, inc[:,:,:,1], inc[:,:,:,0]) 208 | x_inc = m.view(1,-1,1) * x_inc 209 | ys = x_inc * slope.unsqueeze(1) + new_p1[:,:,:,1] 210 | ys = ys.long() 211 | 212 | img[th.arange(xs.shape[0])[:, None], :, xs.view(img.shape[0], -1), ys.view(img.shape[0], -1)] = new_room_indices.reshape(img.shape[0], -1, 1).float() 213 | return img.detach() 214 | 215 | def forward(self, x, timesteps, xtalpha, epsalpha, is_syn=False, **kwargs): 216 | """ 217 | Apply the model to an input batch. 218 | 219 | :param x: an [N x S x C] Tensor of inputs. 220 | :param timesteps: a 1-D batch of timesteps. 221 | :param y: an [N] Tensor of labels, if class-conditional. 222 | :return: an [N x S x C] Tensor of outputs. 
223 | """ 224 | # prefix = 'syn_' if is_syn else '' 225 | prefix = 'syn_' if is_syn else '' 226 | x = x.permute([0, 2, 1]).float() # -> convert [N x C x S] to [N x S x C] 227 | 228 | if not self.analog_bit: 229 | x = self.expand_points(x, kwargs[f'{prefix}connections']) 230 | 231 | # Different input embeddings (Input, Time, Conditions) 232 | time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 233 | time_emb = time_emb.unsqueeze(1) 234 | input_emb = self.input_emb(x) 235 | if self.condition_channels>0: 236 | cond = None 237 | for key in [f'{prefix}room_types', f'{prefix}corner_indices', f'{prefix}room_indices']: 238 | if cond is None: 239 | cond = kwargs[key] 240 | else: 241 | cond = th.cat((cond, kwargs[key]), 2) 242 | cond_emb = self.condition_emb(cond.float()) 243 | 244 | # PositionalEncoding and DM model 245 | out = input_emb + cond_emb + time_emb.repeat((1, input_emb.shape[1], 1)) 246 | for layer in self.transformer_layers: 247 | out = layer(out, kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask']) 248 | 249 | out_dec = self.output_linear1(out) 250 | out_dec = self.activation(out_dec) 251 | out_dec = self.output_linear2(out_dec) 252 | out_dec = self.output_linear3(out_dec) 253 | 254 | if not self.analog_bit: 255 | out_bin_start = x*xtalpha.repeat([1,1,9]) - out_dec.repeat([1,1,9]) * epsalpha.repeat([1,1,9]) 256 | out_bin = (out_bin_start/2 + 0.5) # -> [0,1] 257 | out_bin = out_bin * 256 #-> [0, 256] 258 | out_bin = dec2bin(out_bin.round().int(), 8) 259 | out_bin_inp = out_bin.reshape([x.shape[0], x.shape[1], 16*9]) 260 | out_bin_inp[out_bin_inp==0] = -1 261 | 262 | out_bin = th.cat((out_bin_start, out_bin_inp, cond_emb), 2) 263 | out_bin = self.activation(self.output_linear_bin1(out_bin)) 264 | out_bin = self.output_linear_bin2(out_bin, kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask']) 265 | out_bin = self.output_linear_bin3(out_bin, 
kwargs[f'{prefix}door_mask'], kwargs[f'{prefix}self_mask'], kwargs[f'{prefix}gen_mask']) 266 | out_bin = self.output_linear_bin4(out_bin) 267 | 268 | out_bin = out_bin.permute([0, 2, 1]) # -> convert back [N x S x C] to [N x C x S] 269 | 270 | out_dec = out_dec.permute([0, 2, 1]) # -> convert back [N x S x C] to [N x C x S] 271 | 272 | if not self.analog_bit: 273 | return out_dec, out_bin 274 | else: 275 | return out_dec, None 276 | -------------------------------------------------------------------------------- /house_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = 
str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 
127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 
224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. 
(See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 
399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | # tempfile.gettempdir(), 451 | 'ckpts', 452 | datetime.datetime.now().strftime("openai_%Y_%m_%d_%H_%M_%S_%f"), 453 | ) 454 | assert isinstance(dir, str) 455 | dir = os.path.expanduser(dir) 456 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 457 | 458 | rank = 
get_rank_without_mpi_import() 459 | if rank > 0: 460 | log_suffix = log_suffix + "-rank%03i" % rank 461 | 462 | if format_strs is None: 463 | if rank == 0: 464 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 465 | else: 466 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 467 | format_strs = filter(None, format_strs) 468 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 469 | 470 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 471 | if output_formats: 472 | log("Logging to %s" % dir) 473 | 474 | 475 | def _configure_default_logger(): 476 | configure() 477 | Logger.DEFAULT = Logger.CURRENT 478 | 479 | 480 | def reset(): 481 | if Logger.CURRENT is not Logger.DEFAULT: 482 | Logger.CURRENT.close() 483 | Logger.CURRENT = Logger.DEFAULT 484 | log("Reset logger") 485 | 486 | 487 | @contextmanager 488 | def scoped_configure(dir=None, format_strs=None, comm=None): 489 | prevlogger = Logger.CURRENT 490 | configure(dir=dir, format_strs=format_strs, comm=comm) 491 | try: 492 | yield 493 | finally: 494 | Logger.CURRENT.close() 495 | Logger.CURRENT = prevlogger 496 | 497 | -------------------------------------------------------------------------------- /scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | 12 | import io 13 | import PIL.Image as Image 14 | import drawSvg as drawsvg 15 | import cairosvg 16 | import imageio 17 | from tqdm import tqdm 18 | import matplotlib.pyplot as plt 19 | from pytorch_fid.fid_score import calculate_fid_given_paths 20 | from house_diffusion.rplanhg_datasets import load_rplanhg_data 21 | from house_diffusion import dist_util, logger 22 | from house_diffusion.script_util import ( 23 | model_and_diffusion_defaults, 24 | create_model_and_diffusion, 25 | add_dict_to_argparser, 26 | args_to_dict, 27 | update_arg_parser, 28 | ) 29 | import webcolors 30 | import networkx as nx 31 | from collections import defaultdict 32 | from shapely.geometry import Polygon 33 | from shapely.geometry.base import geom_factory 34 | from shapely.geos import lgeos 35 | 36 | # import random 37 | # th.manual_seed(0) 38 | # random.seed(0) 39 | # np.random.seed(0) 40 | 41 | bin_to_int = lambda x: int("".join([str(int(i.cpu().data)) for i in x]), 2) 42 | def bin_to_int_sample(sample, resolution=256): 43 | sample_new = th.zeros([sample.shape[0], sample.shape[1], sample.shape[2], 2]) 44 | sample[sample<0] = 0 45 | sample[sample>0] = 1 46 | for i in range(sample.shape[0]): 47 | for j in range(sample.shape[1]): 48 | for k in range(sample.shape[2]): 49 | sample_new[i, j, k, 0] = bin_to_int(sample[i, j, k, :8]) 50 | sample_new[i, j, k, 1] = bin_to_int(sample[i, j, k, 8:]) 51 | sample = sample_new 52 | sample = sample/(resolution/2) - 1 53 | return sample 54 | 55 | def get_graph(indx, g_true, ID_COLOR, draw_graph, save_svg): 56 | # build true graph 57 | G_true = nx.Graph() 58 | colors_H = [] 59 | node_size = [] 60 | edge_color = [] 61 | linewidths = [] 62 | edgecolors = [] 63 | # add nodes 64 | for k, label in enumerate(g_true[0]): 65 | _type = label 66 | if _type >= 0 and _type not in [11, 12]: 67 | G_true.add_nodes_from([(k, {'label':k})]) 68 | 
colors_H.append(ID_COLOR[_type]) 69 | node_size.append(1000) 70 | edgecolors.append('blue') 71 | linewidths.append(0.0) 72 | # add outside node 73 | G_true.add_nodes_from([(-1, {'label':-1})]) 74 | colors_H.append("white") 75 | node_size.append(750) 76 | edgecolors.append('black') 77 | linewidths.append(3.0) 78 | # add edges 79 | for k, m, l in g_true[1]: 80 | k = int(k) 81 | l = int(l) 82 | _type_k = g_true[0][k] 83 | _type_l = g_true[0][l] 84 | if m > 0 and (_type_k not in [11, 12] and _type_l not in [11, 12]): 85 | G_true.add_edges_from([(k, l)]) 86 | edge_color.append('#D3A2C7') 87 | elif m > 0 and (_type_k==11 or _type_l==11): 88 | if _type_k==11: 89 | G_true.add_edges_from([(l, -1)]) 90 | else: 91 | G_true.add_edges_from([(k, -1)]) 92 | edge_color.append('#727171') 93 | if draw_graph: 94 | plt.figure() 95 | pos = nx.nx_agraph.graphviz_layout(G_true, prog='neato') 96 | nx.draw(G_true, pos, node_size=node_size, linewidths=linewidths, node_color=colors_H, font_size=14, font_color='white',\ 97 | font_weight='bold', edgecolors=edgecolors, width=4.0, with_labels=False) 98 | if save_svg: 99 | plt.savefig(f'outputs/graphs_gt/{indx}.svg') 100 | else: 101 | plt.savefig(f'outputs/graphs_gt/{indx}.jpg') 102 | plt.close('all') 103 | return G_true 104 | 105 | def estimate_graph(indx, polys, nodes, G_gt, ID_COLOR, draw_graph, save_svg): 106 | nodes = np.array(nodes) 107 | G_gt = G_gt[1-th.where((G_gt == th.tensor([0,0,0], device='cuda')).all(dim=1))[0]] 108 | G_gt = get_graph(indx, [nodes, G_gt], ID_COLOR, draw_graph, save_svg) 109 | G_estimated = nx.Graph() 110 | colors_H = [] 111 | node_size = [] 112 | edge_color = [] 113 | linewidths = [] 114 | edgecolors = [] 115 | edge_labels = {} 116 | # add nodes 117 | for k, label in enumerate(nodes): 118 | _type = label 119 | if _type >= 0 and _type not in [11, 12]: 120 | G_estimated.add_nodes_from([(k, {'label':k})]) 121 | colors_H.append(ID_COLOR[_type]) 122 | node_size.append(1000) 123 | linewidths.append(0.0) 124 | # add 
outside node 125 | G_estimated.add_nodes_from([(-1, {'label':-1})]) 126 | colors_H.append("white") 127 | node_size.append(750) 128 | edgecolors.append('black') 129 | linewidths.append(3.0) 130 | # add node-to-door connections 131 | doors_inds = np.where((nodes == 11) | (nodes == 12))[0] 132 | rooms_inds = np.where((nodes != 11) & (nodes != 12))[0] 133 | doors_rooms_map = defaultdict(list) 134 | for k in doors_inds: 135 | for l in rooms_inds: 136 | if k > l: 137 | p1, p2 = polys[k], polys[l] 138 | p1, p2 = Polygon(p1), Polygon(p2) 139 | if not p1.is_valid: 140 | p1 = geom_factory(lgeos.GEOSMakeValid(p1._geom)) 141 | if not p2.is_valid: 142 | p2 = geom_factory(lgeos.GEOSMakeValid(p2._geom)) 143 | iou = p1.intersection(p2).area/ p1.union(p2).area 144 | if iou > 0 and iou < 0.2: 145 | doors_rooms_map[k].append((l, iou)) 146 | # draw connections 147 | for k in doors_rooms_map.keys(): 148 | _conn = doors_rooms_map[k] 149 | _conn = sorted(_conn, key=lambda tup: tup[1], reverse=True) 150 | _conn_top2 = _conn[:2] 151 | if nodes[k] != 11: 152 | if len(_conn_top2) > 1: 153 | l1, l2 = _conn_top2[0][0], _conn_top2[1][0] 154 | edge_labels[(l1, l2)] = k 155 | G_estimated.add_edges_from([(l1, l2)]) 156 | else: 157 | if len(_conn) > 0: 158 | l1 = _conn[0][0] 159 | edge_labels[(-1, l1)] = k 160 | G_estimated.add_edges_from([(-1, l1)]) 161 | # add missed edges 162 | G_estimated_complete = G_estimated.copy() 163 | for k, l in G_gt.edges(): 164 | if not G_estimated.has_edge(k, l): 165 | G_estimated_complete.add_edges_from([(k, l)]) 166 | # add edges colors 167 | colors = [] 168 | mistakes = 0 169 | for k, l in G_estimated_complete.edges(): 170 | if G_gt.has_edge(k, l) and not G_estimated.has_edge(k, l): 171 | colors.append('yellow') 172 | mistakes += 1 173 | elif G_estimated.has_edge(k, l) and not G_gt.has_edge(k, l): 174 | colors.append('red') 175 | mistakes += 1 176 | elif G_estimated.has_edge(k, l) and G_gt.has_edge(k, l): 177 | colors.append('green') 178 | else: 179 | print('ERR') 
180 | if draw_graph: 181 | plt.figure() 182 | pos = nx.nx_agraph.graphviz_layout(G_estimated_complete, prog='neato') 183 | weights = [4 for u, v in G_estimated_complete.edges()] 184 | nx.draw(G_estimated_complete, pos, edge_color=colors, linewidths=linewidths, edgecolors=edgecolors, node_size=node_size, node_color=colors_H, font_size=14, font_weight='bold', font_color='white', width=weights, with_labels=False) 185 | if save_svg: 186 | plt.savefig(f'outputs/graphs_pred/{indx}.svg') 187 | else: 188 | plt.savefig(f'outputs/graphs_pred/{indx}.jpg') 189 | plt.close('all') 190 | return mistakes 191 | 192 | def save_samples( 193 | sample, ext, model_kwargs, 194 | tmp_count, num_room_types, 195 | save_gif=False, save_edges=False, 196 | door_indices = [11, 12, 13], ID_COLOR=None, 197 | is_syn=False, draw_graph=False, save_svg=False): 198 | prefix = 'syn_' if is_syn else '' 199 | graph_errors = [] 200 | if not save_gif: 201 | sample = sample[-1:] 202 | for i in tqdm(range(sample.shape[1])): 203 | resolution = 256 204 | images = [] 205 | images2 = [] 206 | images3 = [] 207 | for k in range(sample.shape[0]): 208 | draw = drawsvg.Drawing(resolution, resolution, displayInline=False) 209 | draw.append(drawsvg.Rectangle(0,0,resolution,resolution, fill='black')) 210 | draw2 = drawsvg.Drawing(resolution, resolution, displayInline=False) 211 | draw2.append(drawsvg.Rectangle(0,0,resolution,resolution, fill='black')) 212 | draw3 = drawsvg.Drawing(resolution, resolution, displayInline=False) 213 | draw3.append(drawsvg.Rectangle(0,0,resolution,resolution, fill='black')) 214 | draw_color = drawsvg.Drawing(resolution, resolution, displayInline=False) 215 | draw_color.append(drawsvg.Rectangle(0,0,resolution,resolution, fill='white')) 216 | polys = [] 217 | types = [] 218 | for j, point in (enumerate(sample[k][i])): 219 | if model_kwargs[f'{prefix}src_key_padding_mask'][i][j]==1: 220 | continue 221 | point = point.cpu().data.numpy() 222 | if j==0: 223 | poly = [] 224 | if j>0 and 
(model_kwargs[f'{prefix}room_indices'][i, j]!=model_kwargs[f'{prefix}room_indices'][i, j-1]).any(): 225 | polys.append(poly) 226 | types.append(c) 227 | poly = [] 228 | pred_center = False 229 | if pred_center: 230 | point = point/2 + 1 231 | point = point * resolution//2 232 | else: 233 | point = point/2 + 0.5 234 | point = point * resolution 235 | poly.append((point[0], point[1])) 236 | c = np.argmax(model_kwargs[f'{prefix}room_types'][i][j-1].cpu().numpy()) 237 | polys.append(poly) 238 | types.append(c) 239 | for poly, c in zip(polys, types): 240 | if c in door_indices or c==0: 241 | continue 242 | room_type = c 243 | c = webcolors.hex_to_rgb(ID_COLOR[c]) 244 | draw_color.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='black', stroke_width=1)) 245 | draw.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill='black', fill_opacity=0.0, stroke=webcolors.rgb_to_hex([int(x/2) for x in c]), stroke_width=0.5*(resolution/256))) 246 | draw2.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill=ID_COLOR[room_type], fill_opacity=1.0, stroke=webcolors.rgb_to_hex([int(x/2) for x in c]), stroke_width=0.5*(resolution/256))) 247 | for corner in poly: 248 | draw.append(drawsvg.Circle(corner[0], corner[1], 2*(resolution/256), fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='gray', stroke_width=0.25)) 249 | draw3.append(drawsvg.Circle(corner[0], corner[1], 2*(resolution/256), fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='gray', stroke_width=0.25)) 250 | for poly, c in zip(polys, types): 251 | if c not in door_indices: 252 | continue 253 | room_type = c 254 | c = webcolors.hex_to_rgb(ID_COLOR[c]) 255 | draw_color.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='black', stroke_width=1)) 256 | draw.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill='black', 
fill_opacity=0.0, stroke=webcolors.rgb_to_hex([int(x/2) for x in c]), stroke_width=0.5*(resolution/256))) 257 | draw2.append(drawsvg.Lines(*np.array(poly).flatten().tolist(), close=True, fill=ID_COLOR[room_type], fill_opacity=1.0, stroke=webcolors.rgb_to_hex([int(x/2) for x in c]), stroke_width=0.5*(resolution/256))) 258 | for corner in poly: 259 | draw.append(drawsvg.Circle(corner[0], corner[1], 2*(resolution/256), fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='gray', stroke_width=0.25)) 260 | draw3.append(drawsvg.Circle(corner[0], corner[1], 2*(resolution/256), fill=ID_COLOR[room_type], fill_opacity=1.0, stroke='gray', stroke_width=0.25)) 261 | images.append(Image.open(io.BytesIO(cairosvg.svg2png(draw.asSvg())))) 262 | images2.append(Image.open(io.BytesIO(cairosvg.svg2png(draw2.asSvg())))) 263 | images3.append(Image.open(io.BytesIO(cairosvg.svg2png(draw3.asSvg())))) 264 | if k==sample.shape[0]-1 or True: 265 | if save_edges: 266 | draw.saveSvg(f'outputs/{ext}/{tmp_count+i}_{k}_{ext}.svg') 267 | if save_svg: 268 | draw_color.saveSvg(f'outputs/{ext}/{tmp_count+i}c_{k}_{ext}.svg') 269 | else: 270 | Image.open(io.BytesIO(cairosvg.svg2png(draw_color.asSvg()))).save(f'outputs/{ext}/{tmp_count+i}c_{ext}.png') 271 | if k==sample.shape[0]-1: 272 | if 'graph' in model_kwargs: 273 | graph_errors.append(estimate_graph(tmp_count+i, polys, types, model_kwargs[f'{prefix}graph'][i], ID_COLOR=ID_COLOR, draw_graph=draw_graph, save_svg=save_svg)) 274 | else: 275 | graph_errors.append(0) 276 | if save_gif: 277 | imageio.mimwrite(f'outputs/gif/{tmp_count+i}.gif', images, fps=10, loop=1) 278 | imageio.mimwrite(f'outputs/gif/{tmp_count+i}_v2.gif', images2, fps=10, loop=1) 279 | imageio.mimwrite(f'outputs/gif/{tmp_count+i}_v3.gif', images3, fps=10, loop=1) 280 | return graph_errors 281 | 282 | def main(): 283 | args = create_argparser().parse_args() 284 | update_arg_parser(args) 285 | 286 | dist_util.setup_dist() 287 | logger.configure() 288 | 289 | logger.log("creating model and 
diffusion...") 290 | model, diffusion = create_model_and_diffusion( 291 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 292 | ) 293 | model.load_state_dict( 294 | dist_util.load_state_dict(args.model_path, map_location="cpu") 295 | ) 296 | model.to(dist_util.dev()) 297 | model.eval() 298 | 299 | errors = [] 300 | for _ in range(5): 301 | logger.log("sampling...") 302 | tmp_count = 0 303 | os.makedirs('outputs/pred', exist_ok=True) 304 | os.makedirs('outputs/gt', exist_ok=True) 305 | os.makedirs('outputs/gif', exist_ok=True) 306 | os.makedirs('outputs/graphs_gt', exist_ok=True) 307 | os.makedirs('outputs/graphs_pred', exist_ok=True) 308 | 309 | if args.dataset=='rplan': 310 | ID_COLOR = {1: '#EE4D4D', 2: '#C67C7B', 3: '#FFD274', 4: '#BEBEBE', 5: '#BFE3E8', 311 | 6: '#7BA779', 7: '#E87A90', 8: '#FF8C69', 10: '#1F849B', 11: '#727171', 312 | 13: '#785A67', 12: '#D3A2C7'} 313 | num_room_types = 14 314 | data = load_rplanhg_data( 315 | batch_size=args.batch_size, 316 | analog_bit=args.analog_bit, 317 | set_name=args.set_name, 318 | target_set=args.target_set, 319 | ) 320 | else: 321 | print("dataset does not exist!") 322 | assert False 323 | graph_errors = [] 324 | while tmp_count < args.num_samples: 325 | model_kwargs = {} 326 | sample_fn = ( 327 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 328 | ) 329 | data_sample, model_kwargs = next(data) 330 | for key in model_kwargs: 331 | model_kwargs[key] = model_kwargs[key].cuda() 332 | 333 | sample = sample_fn( 334 | model, 335 | data_sample.shape, 336 | clip_denoised=args.clip_denoised, 337 | model_kwargs=model_kwargs, 338 | analog_bit=args.analog_bit, 339 | ) 340 | sample_gt = data_sample.cuda().unsqueeze(0) 341 | sample = sample.permute([0, 1, 3, 2]) 342 | sample_gt = sample_gt.permute([0, 1, 3, 2]) 343 | if args.analog_bit: 344 | sample_gt = bin_to_int_sample(sample_gt) 345 | sample = bin_to_int_sample(sample) 346 | 347 | graph_error = save_samples(sample_gt, 'gt', 
model_kwargs, tmp_count, num_room_types, ID_COLOR=ID_COLOR, draw_graph=args.draw_graph, save_svg=args.save_svg) 348 | graph_error = save_samples(sample, 'pred', model_kwargs, tmp_count, num_room_types, ID_COLOR=ID_COLOR, is_syn=True, draw_graph=args.draw_graph, save_svg=args.save_svg) 349 | graph_errors.extend(graph_error) 350 | tmp_count+=sample_gt.shape[1] 351 | logger.log("sampling complete") 352 | fid_score = calculate_fid_given_paths(['outputs/gt', 'outputs/pred'], 64, 'cuda', 2048) 353 | print(f'FID: {fid_score}') 354 | print(f'Compatibility: {np.mean(graph_errors)}') 355 | errors.append([fid_score, np.mean(graph_errors)]) 356 | errors = np.array(errors) 357 | print(f'Diversity mean: {errors[:, 0].mean()} \t Diversity std: {errors[:, 0].std()}') 358 | print(f'Compatibility mean: {errors[:, 1].mean()} \t Compatibility std: {errors[:, 1].std()}') 359 | 360 | def create_argparser(): 361 | defaults = dict( 362 | dataset='', 363 | clip_denoised=True, 364 | num_samples=10000, 365 | batch_size=16, 366 | use_ddim=False, 367 | model_path="", 368 | draw_graph=True, 369 | save_svg=True, 370 | ) 371 | defaults.update(model_and_diffusion_defaults()) 372 | parser = argparse.ArgumentParser() 373 | add_dict_to_argparser(parser, defaults) 374 | return parser 375 | 376 | 377 | if __name__ == "__main__": 378 | main() 379 | -------------------------------------------------------------------------------- /house_diffusion/rplanhg_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch as th 4 | 5 | from PIL import Image, ImageDraw 6 | import blobfile as bf 7 | from mpi4py import MPI 8 | import numpy as np 9 | from torch.utils.data import DataLoader, Dataset 10 | from glob import glob 11 | import json 12 | import os 13 | import cv2 as cv 14 | from tqdm import tqdm 15 | from shapely import geometry as gm 16 | from shapely.ops import unary_union 17 | from collections import defaultdict 18 | 
import copy 19 | 20 | def load_rplanhg_data( 21 | batch_size, 22 | analog_bit, 23 | target_set = 8, 24 | set_name = 'train', 25 | ): 26 | """ 27 | For a dataset, create a generator over (shapes, kwargs) pairs. 28 | """ 29 | print(f"loading {set_name} of target set {target_set}") 30 | deterministic = False if set_name=='train' else True 31 | dataset = RPlanhgDataset(set_name, analog_bit, target_set) 32 | if deterministic: 33 | loader = DataLoader( 34 | dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False 35 | ) 36 | else: 37 | loader = DataLoader( 38 | dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=False 39 | ) 40 | while True: 41 | yield from loader 42 | 43 | def make_non_manhattan(poly, polygon, house_poly): 44 | dist = abs(poly[2]-poly[0]) 45 | direction = np.argmin(dist) 46 | center = poly.mean(0) 47 | min = poly.min(0) 48 | max = poly.max(0) 49 | 50 | tmp = np.random.randint(3, 7) 51 | new_min_y = center[1]-(max[1]-min[1])/tmp 52 | new_max_y = center[1]+(max[1]-min[1])/tmp 53 | if center[0]<128: 54 | new_min_x = min[0]-(max[0]-min[0])/np.random.randint(2,5) 55 | new_max_x = center[0] 56 | poly1=[[min[0], min[1]], [new_min_x, new_min_y], [new_min_x, new_max_y], [min[0], max[1]], [max[0], max[1]], [max[0], min[1]]] 57 | else: 58 | new_min_x = center[0] 59 | new_max_x = max[0]+(max[0]-min[0])/np.random.randint(2,5) 60 | poly1=[[min[0], min[1]], [min[0], max[1]], [max[0], max[1]], [new_max_x, new_max_y], [new_max_x, new_min_y], [max[0], min[1]]] 61 | 62 | new_min_x = center[0]-(max[0]-min[0])/tmp 63 | new_max_x = center[0]+(max[0]-min[0])/tmp 64 | if center[1]<128: 65 | new_min_y = min[1]-(max[1]-min[1])/np.random.randint(2,5) 66 | new_max_y = center[1] 67 | poly2=[[min[0], min[1]], [min[0], max[1]], [max[0], max[1]], [max[0], min[1]], [new_max_x, new_min_y], [new_min_x, new_min_y]] 68 | else: 69 | new_min_y = center[1] 70 | new_max_y = max[1]+(max[1]-min[1])/np.random.randint(2,5) 71 | poly2=[[min[0], min[1]], 
[min[0], max[1]], [new_min_x, new_max_y], [new_max_x, new_max_y], [max[0], max[1]], [max[0], min[1]]] 72 | p1 = gm.Polygon(poly1) 73 | iou1 = house_poly.intersection(p1).area/ p1.area 74 | p2 = gm.Polygon(poly2) 75 | iou2 = house_poly.intersection(p2).area/ p2.area 76 | if iou1>0.9 and iou2>0.9: 77 | return poly 78 | if iou110: 176 | continue 177 | if len(room[0])!=4: 178 | continue 179 | if np.random.randint(2): 180 | continue 181 | poly = gm.Polygon(room[0]) 182 | house_polygon = unary_union([gm.Polygon(room[0]) for room in h]) 183 | room[0] = make_non_manhattan(room[0], poly, house_polygon) 184 | 185 | for h, graph in tqdm(zip(self.org_houses, self.org_graphs), desc='processing dataset'): 186 | house = [] 187 | corner_bounds = [] 188 | num_points = 0 189 | for i, room in enumerate(h): 190 | if room[1]>10: 191 | room[1] = {15:11, 17:12, 16:13}[room[1]] 192 | room[0] = np.reshape(room[0], [len(room[0]), 2])/256. - 0.5 # [[x0,y0],[x1,y1],...,[x15,y15]] and map to 0-1 - > -0.5, 0.5 193 | room[0] = room[0] * 2 # map to [-1, 1] 194 | if self.set_name=='train': 195 | cnumber_dist[room[1]].append(len(room[0])) 196 | # Adding conditions 197 | num_room_corners = len(room[0]) 198 | rtype = np.repeat(np.array([get_one_hot(room[1], 25)]), num_room_corners, 0) 199 | room_index = np.repeat(np.array([get_one_hot(len(house)+1, 32)]), num_room_corners, 0) 200 | corner_index = np.array([get_one_hot(x, 32) for x in range(num_room_corners)]) 201 | # Src_key_padding_mask 202 | padding_mask = np.repeat(1, num_room_corners) 203 | padding_mask = np.expand_dims(padding_mask, 1) 204 | # Generating corner bounds for attention masks 205 | connections = np.array([[i,(i+1)%num_room_corners] for i in range(num_room_corners)]) 206 | connections += num_points 207 | corner_bounds.append([num_points, num_points+num_room_corners]) 208 | num_points += num_room_corners 209 | room = np.concatenate((room[0], rtype, corner_index, room_index, padding_mask, connections), 1) 210 | house.append(room) 211 | 
212 | house_layouts = np.concatenate(house, 0) 213 | if len(house_layouts)>max_num_points: 214 | continue 215 | padding = np.zeros((max_num_points-len(house_layouts), 94)) 216 | gen_mask = np.ones((max_num_points, max_num_points)) 217 | gen_mask[:len(house_layouts), :len(house_layouts)] = 0 218 | house_layouts = np.concatenate((house_layouts, padding), 0) 219 | 220 | door_mask = np.ones((max_num_points, max_num_points)) 221 | self_mask = np.ones((max_num_points, max_num_points)) 222 | for i in range(len(corner_bounds)): 223 | for j in range(len(corner_bounds)): 224 | if i==j: 225 | self_mask[corner_bounds[i][0]:corner_bounds[i][1],corner_bounds[j][0]:corner_bounds[j][1]] = 0 226 | elif any(np.equal([i, 1, j], graph).all(1)) or any(np.equal([j, 1, i], graph).all(1)): 227 | door_mask[corner_bounds[i][0]:corner_bounds[i][1],corner_bounds[j][0]:corner_bounds[j][1]] = 0 228 | houses.append(house_layouts) 229 | door_masks.append(door_mask) 230 | self_masks.append(self_mask) 231 | gen_masks.append(gen_mask) 232 | graphs.append(graph) 233 | self.max_num_points = max_num_points 234 | self.houses = houses 235 | self.door_masks = door_masks 236 | self.self_masks = self_masks 237 | self.gen_masks = gen_masks 238 | self.num_coords = 2 239 | self.graphs = graphs 240 | 241 | np.savez_compressed(f'processed_rplan/rplan_{set_name}_{target_set}', graphs=self.graphs, houses=self.houses, 242 | door_masks=self.door_masks, self_masks=self.self_masks, gen_masks=self.gen_masks) 243 | if self.set_name=='train': 244 | np.savez_compressed(f'processed_rplan/rplan_{set_name}_{target_set}_cndist', cnumber_dist=cnumber_dist) 245 | 246 | if set_name=='eval': 247 | houses = [] 248 | graphs = [] 249 | door_masks = [] 250 | self_masks = [] 251 | gen_masks = [] 252 | len_house_layouts = 0 253 | for h, graph in tqdm(zip(self.org_houses, self.org_graphs), desc='processing dataset'): 254 | house = [] 255 | corner_bounds = [] 256 | num_points = 0 257 | num_room_corners_total = 
[cnumber_dist[room[1]][random.randint(0, len(cnumber_dist[room[1]])-1)] for room in h] 258 | while np.sum(num_room_corners_total)>=max_num_points: 259 | num_room_corners_total = [cnumber_dist[room[1]][random.randint(0, len(cnumber_dist[room[1]])-1)] for room in h] 260 | for i, room in enumerate(h): 261 | # Adding conditions 262 | num_room_corners = num_room_corners_total[i] 263 | rtype = np.repeat(np.array([get_one_hot(room[1], 25)]), num_room_corners, 0) 264 | room_index = np.repeat(np.array([get_one_hot(len(house)+1, 32)]), num_room_corners, 0) 265 | corner_index = np.array([get_one_hot(x, 32) for x in range(num_room_corners)]) 266 | # Src_key_padding_mask 267 | padding_mask = np.repeat(1, num_room_corners) 268 | padding_mask = np.expand_dims(padding_mask, 1) 269 | # Generating corner bounds for attention masks 270 | connections = np.array([[i,(i+1)%num_room_corners] for i in range(num_room_corners)]) 271 | connections += num_points 272 | corner_bounds.append([num_points, num_points+num_room_corners]) 273 | num_points += num_room_corners 274 | room = np.concatenate((np.zeros([num_room_corners, 2]), rtype, corner_index, room_index, padding_mask, connections), 1) 275 | house.append(room) 276 | 277 | house_layouts = np.concatenate(house, 0) 278 | if np.sum([len(room[0]) for room in h])>max_num_points: 279 | continue 280 | padding = np.zeros((max_num_points-len(house_layouts), 94)) 281 | gen_mask = np.ones((max_num_points, max_num_points)) 282 | gen_mask[:len(house_layouts), :len(house_layouts)] = 0 283 | house_layouts = np.concatenate((house_layouts, padding), 0) 284 | 285 | door_mask = np.ones((max_num_points, max_num_points)) 286 | self_mask = np.ones((max_num_points, max_num_points)) 287 | for i, room in enumerate(h): 288 | if room[1]==1: 289 | living_room_index = i 290 | break 291 | for i in range(len(corner_bounds)): 292 | is_connected = False 293 | for j in range(len(corner_bounds)): 294 | if i==j: 295 | 
self_mask[corner_bounds[i][0]:corner_bounds[i][1],corner_bounds[j][0]:corner_bounds[j][1]] = 0 296 | elif any(np.equal([i, 1, j], graph).all(1)) or any(np.equal([j, 1, i], graph).all(1)): 297 | door_mask[corner_bounds[i][0]:corner_bounds[i][1],corner_bounds[j][0]:corner_bounds[j][1]] = 0 298 | is_connected = True 299 | if not is_connected: 300 | door_mask[corner_bounds[i][0]:corner_bounds[i][1],corner_bounds[living_room_index][0]:corner_bounds[living_room_index][1]] = 0 301 | 302 | houses.append(house_layouts) 303 | door_masks.append(door_mask) 304 | self_masks.append(self_mask) 305 | gen_masks.append(gen_mask) 306 | graphs.append(graph) 307 | self.syn_houses = houses 308 | self.syn_door_masks = door_masks 309 | self.syn_self_masks = self_masks 310 | self.syn_gen_masks = gen_masks 311 | self.syn_graphs = graphs 312 | np.savez_compressed(f'processed_rplan/rplan_{set_name}_{target_set}_syn', graphs=self.syn_graphs, houses=self.syn_houses, 313 | door_masks=self.syn_door_masks, self_masks=self.syn_self_masks, gen_masks=self.syn_gen_masks) 314 | 315 | def __len__(self): 316 | return len(self.houses) 317 | 318 | def __getitem__(self, idx): 319 | # idx = int(idx//20) 320 | arr = self.houses[idx][:, :self.num_coords] 321 | graph = np.concatenate((self.graphs[idx], np.zeros([200-len(self.graphs[idx]), 3])), 0) 322 | 323 | cond = { 324 | 'door_mask': self.door_masks[idx], 325 | 'self_mask': self.self_masks[idx], 326 | 'gen_mask': self.gen_masks[idx], 327 | 'room_types': self.houses[idx][:, self.num_coords:self.num_coords+25], 328 | 'corner_indices': self.houses[idx][:, self.num_coords+25:self.num_coords+57], 329 | 'room_indices': self.houses[idx][:, self.num_coords+57:self.num_coords+89], 330 | 'src_key_padding_mask': 1-self.houses[idx][:, self.num_coords+89], 331 | 'connections': self.houses[idx][:, self.num_coords+90:self.num_coords+92], 332 | 'graph': graph, 333 | } 334 | if self.set_name == 'eval': 335 | syn_graph = np.concatenate((self.syn_graphs[idx], 
np.zeros([200-len(self.syn_graphs[idx]), 3])), 0) 336 | assert (graph == syn_graph).all(), idx 337 | cond.update({ 338 | 'syn_door_mask': self.syn_door_masks[idx], 339 | 'syn_self_mask': self.syn_self_masks[idx], 340 | 'syn_gen_mask': self.syn_gen_masks[idx], 341 | 'syn_room_types': self.syn_houses[idx][:, self.num_coords:self.num_coords+25], 342 | 'syn_corner_indices': self.syn_houses[idx][:, self.num_coords+25:self.num_coords+57], 343 | 'syn_room_indices': self.syn_houses[idx][:, self.num_coords+57:self.num_coords+89], 344 | 'syn_src_key_padding_mask': 1-self.syn_houses[idx][:, self.num_coords+89], 345 | 'syn_connections': self.syn_houses[idx][:, self.num_coords+90:self.num_coords+92], 346 | 'syn_graph': syn_graph, 347 | }) 348 | if self.set_name == 'train': 349 | #### Random Rotate 350 | rotation = random.randint(0,3) 351 | if rotation == 1: 352 | arr[:, [0, 1]] = arr[:, [1, 0]] 353 | arr[:, 0] = -arr[:, 0] 354 | elif rotation == 2: 355 | arr[:, [0, 1]] = -arr[:, [1, 0]] 356 | elif rotation == 3: 357 | arr[:, [0, 1]] = arr[:, [1, 0]] 358 | arr[:, 1] = -arr[:, 1] 359 | 360 | ## To generate any rotation uncomment this 361 | 362 | # if self.non_manhattan: 363 | # theta = random.random()*np.pi/2 364 | # rot_mat = np.array([[np.cos(theta), -np.sin(theta), 0], 365 | # [np.sin(theta), np.cos(theta), 0]]) 366 | # arr = np.matmul(arr,rot_mat)[:,:2] 367 | 368 | # Random Scale 369 | # arr = arr * np.random.normal(1., .5) 370 | 371 | # Random Shift 372 | # arr[:, 0] = arr[:, 0] + np.random.normal(0., .1) 373 | # arr[:, 1] = arr[:, 1] + np.random.normal(0., .1) 374 | 375 | if not self.analog_bit: 376 | arr = np.transpose(arr, [1, 0]) 377 | return arr.astype(float), cond 378 | else: 379 | ONE_HOT_RES = 256 380 | arr_onehot = np.zeros((ONE_HOT_RES*2, arr.shape[1])) - 1 381 | xs = ((arr[:, 0]+1)*(ONE_HOT_RES/2)).astype(int) 382 | ys = ((arr[:, 1]+1)*(ONE_HOT_RES/2)).astype(int) 383 | xs = np.array([get_bin(x, 8) for x in xs]) 384 | ys = np.array([get_bin(x, 8) for x in ys]) 385 
| arr_onehot = np.concatenate([xs, ys], 1) 386 | arr_onehot = np.transpose(arr_onehot, [1, 0]) 387 | arr_onehot[arr_onehot==0] = -1 388 | return arr_onehot.astype(float), cond 389 | 390 | def make_sequence(self, edges): 391 | polys = [] 392 | v_curr = tuple(edges[0][:2]) 393 | e_ind_curr = 0 394 | e_visited = [0] 395 | seq_tracker = [v_curr] 396 | find_next = False 397 | while len(e_visited) < len(edges): 398 | if find_next == False: 399 | if v_curr == tuple(edges[e_ind_curr][2:]): 400 | v_curr = tuple(edges[e_ind_curr][:2]) 401 | else: 402 | v_curr = tuple(edges[e_ind_curr][2:]) 403 | find_next = not find_next 404 | else: 405 | # look for next edge 406 | for k, e in enumerate(edges): 407 | if k not in e_visited: 408 | if (v_curr == tuple(e[:2])): 409 | v_curr = tuple(e[2:]) 410 | e_ind_curr = k 411 | e_visited.append(k) 412 | break 413 | elif (v_curr == tuple(e[2:])): 414 | v_curr = tuple(e[:2]) 415 | e_ind_curr = k 416 | e_visited.append(k) 417 | break 418 | 419 | # extract next sequence 420 | if v_curr == seq_tracker[-1]: 421 | polys.append(seq_tracker) 422 | for k, e in enumerate(edges): 423 | if k not in e_visited: 424 | v_curr = tuple(edges[0][:2]) 425 | seq_tracker = [v_curr] 426 | find_next = False 427 | e_ind_curr = k 428 | e_visited.append(k) 429 | break 430 | else: 431 | seq_tracker.append(v_curr) 432 | polys.append(seq_tracker) 433 | 434 | return polys 435 | 436 | def build_graph(self, rms_type, fp_eds, eds_to_rms, out_size=64): 437 | # create edges 438 | triples = [] 439 | nodes = rms_type 440 | # encode connections 441 | for k in range(len(nodes)): 442 | for l in range(len(nodes)): 443 | if l > k: 444 | is_adjacent = any([True for e_map in eds_to_rms if (l in e_map) and (k in e_map)]) 445 | if is_adjacent: 446 | if 'train' in self.set_name: 447 | triples.append([k, 1, l]) 448 | else: 449 | triples.append([k, 1, l]) 450 | else: 451 | if 'train' in self.set_name: 452 | triples.append([k, -1, l]) 453 | else: 454 | triples.append([k, -1, l]) 455 | # get 
rooms masks 456 | eds_to_rms_tmp = [] 457 | for l in range(len(eds_to_rms)): 458 | eds_to_rms_tmp.append([eds_to_rms[l][0]]) 459 | rms_masks = [] 460 | im_size = 256 461 | fp_mk = np.zeros((out_size, out_size)) 462 | for k in range(len(nodes)): 463 | # add rooms and doors 464 | eds = [] 465 | for l, e_map in enumerate(eds_to_rms_tmp): 466 | if (k in e_map): 467 | eds.append(l) 468 | # draw rooms 469 | rm_im = Image.new('L', (im_size, im_size)) 470 | dr = ImageDraw.Draw(rm_im) 471 | for eds_poly in [eds]: 472 | poly = self.make_sequence(np.array([fp_eds[l][:4] for l in eds_poly]))[0] 473 | poly = [(im_size*x, im_size*y) for x, y in poly] 474 | if len(poly) >= 2: 475 | dr.polygon(poly, fill='white') 476 | else: 477 | print("Empty room") 478 | exit(0) 479 | rm_im = rm_im.resize((out_size, out_size)) 480 | rm_arr = np.array(rm_im) 481 | inds = np.where(rm_arr>0) 482 | rm_arr[inds] = 1.0 483 | rms_masks.append(rm_arr) 484 | if rms_type[k] != 15 and rms_type[k] != 17: 485 | fp_mk[inds] = k+1 486 | # trick to remove overlap 487 | for k in range(len(nodes)): 488 | if rms_type[k] != 15 and rms_type[k] != 17: 489 | rm_arr = np.zeros((out_size, out_size)) 490 | inds = np.where(fp_mk==k+1) 491 | rm_arr[inds] = 1.0 492 | rms_masks[k] = rm_arr 493 | # convert to array 494 | nodes = np.array(nodes) 495 | triples = np.array(triples) 496 | rms_masks = np.array(rms_masks) 497 | return nodes, triples, rms_masks 498 | 499 | def is_adjacent(box_a, box_b, threshold=0.03): 500 | x0, y0, x1, y1 = box_a 501 | x2, y2, x3, y3 = box_b 502 | h1, h2 = x1-x0, x3-x2 503 | w1, w2 = y1-y0, y3-y2 504 | xc1, xc2 = (x0+x1)/2.0, (x2+x3)/2.0 505 | yc1, yc2 = (y0+y1)/2.0, (y2+y3)/2.0 506 | delta_x = np.abs(xc2-xc1) - (h1 + h2)/2.0 507 | delta_y = np.abs(yc2-yc1) - (w1 + w2)/2.0 508 | delta = max(delta_x, delta_y) 509 | return delta < threshold 510 | 511 | def reader(filename): 512 | with open(filename) as f: 513 | info =json.load(f) 514 | rms_bbs=np.asarray(info['boxes']) 515 | fp_eds=info['edges'] 516 | 
rms_type=info['room_type'] 517 | eds_to_rms=info['ed_rm'] 518 | s_r=0 519 | for rmk in range(len(rms_type)): 520 | if(rms_type[rmk]!=17): 521 | s_r=s_r+1 522 | rms_bbs = np.array(rms_bbs)/256.0 523 | fp_eds = np.array(fp_eds)/256.0 524 | fp_eds = fp_eds[:, :4] 525 | tl = np.min(rms_bbs[:, :2], 0) 526 | br = np.max(rms_bbs[:, 2:], 0) 527 | shift = (tl+br)/2.0 - 0.5 528 | rms_bbs[:, :2] -= shift 529 | rms_bbs[:, 2:] -= shift 530 | fp_eds[:, :2] -= shift 531 | fp_eds[:, 2:] -= shift 532 | tl -= shift 533 | br -= shift 534 | return rms_type,fp_eds,rms_bbs,eds_to_rms 535 | 536 | if __name__ == '__main__': 537 | dataset = RPlanhgDataset('eval', False, 8) 538 | -------------------------------------------------------------------------------- /LICENSE_GPL: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. 
Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. 
Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. 
Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. 
A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 
163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 
229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /house_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | from tqdm.auto import tqdm 17 | 18 | 19 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 20 | """ 21 | Get a pre-defined beta schedule for the given name. 22 | 23 | The beta schedule library consists of beta schedules which remain similar 24 | in the limit of num_diffusion_timesteps. 25 | Beta schedules may be added, but should not be removed or changed once 26 | they are committed to maintain backwards compatibility. 27 | """ 28 | if schedule_name == "linear": 29 | # Linear schedule from Ho et al, extended to work for any number of 30 | # diffusion steps. 31 | scale = 1000 / num_diffusion_timesteps 32 | beta_start = scale * 0.0001 33 | beta_end = scale * 0.02 34 | return np.linspace( 35 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 36 | ) 37 | elif schedule_name == "cosine": 38 | print("COSINE") 39 | return betas_for_alpha_bar( 40 | num_diffusion_timesteps, 41 | # lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 42 | lambda t: math.cos((t) / 1.000 * math.pi / 2) ** 2, 43 | ) 44 | else: 45 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 46 | 47 | 48 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 49 | """ 50 | Create a beta schedule that discretizes the given alpha_t_bar function, 51 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 52 | 53 | :param num_diffusion_timesteps: the number of betas to produce. 54 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 55 | produces the cumulative product of (1-beta) up to that 56 | part of the diffusion process. 57 | :param max_beta: the maximum beta to use; use values lower than 1 to 58 | prevent singularities. 
59 | """ 60 | betas = [] 61 | for i in range(num_diffusion_timesteps): 62 | t1 = i / num_diffusion_timesteps 63 | t2 = (i + 1) / num_diffusion_timesteps 64 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 65 | return np.array(betas) 66 | 67 | 68 | class ModelMeanType(enum.Enum): 69 | """ 70 | Which type of output the model predicts. 71 | """ 72 | 73 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 74 | START_X = enum.auto() # the model predicts x_0 75 | EPSILON = enum.auto() # the model predicts epsilon 76 | 77 | 78 | class ModelVarType(enum.Enum): 79 | """ 80 | What is used as the model's output variance. 81 | 82 | The LEARNED_RANGE option has been added to allow the model to predict 83 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 84 | """ 85 | 86 | LEARNED = enum.auto() 87 | FIXED_SMALL = enum.auto() 88 | FIXED_LARGE = enum.auto() 89 | LEARNED_RANGE = enum.auto() 90 | 91 | 92 | class LossType(enum.Enum): 93 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 94 | RESCALED_MSE = ( 95 | enum.auto() 96 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 97 | KL = enum.auto() # use the variational lower-bound 98 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 99 | 100 | def is_vb(self): 101 | return self == LossType.KL or self == LossType.RESCALED_KL 102 | 103 | 104 | class GaussianDiffusion: 105 | """ 106 | Utilities for training and sampling diffusion models. 107 | 108 | Ported directly from here, and then adapted over time to further experimentation. 109 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 110 | 111 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 112 | starting at T and going to 1. 113 | :param model_mean_type: a ModelMeanType determining what the model outputs. 114 | :param model_var_type: a ModelVarType determining how variance is output. 
115 | :param loss_type: a LossType determining the loss function to use. 116 | :param rescale_timesteps: if True, pass floating point timesteps into the 117 | model so that they are always scaled like in the 118 | original paper (0 to 1000). 119 | """ 120 | 121 | def __init__( 122 | self, 123 | *, 124 | betas, 125 | model_mean_type, 126 | model_var_type, 127 | loss_type, 128 | rescale_timesteps=False, 129 | ): 130 | self.model_mean_type = model_mean_type 131 | self.model_var_type = model_var_type 132 | self.loss_type = loss_type 133 | self.rescale_timesteps = rescale_timesteps 134 | 135 | # Use float64 for accuracy. 136 | betas = np.array(betas, dtype=np.float64) 137 | self.betas = betas 138 | assert len(betas.shape) == 1, "betas must be 1-D" 139 | assert (betas > 0).all() and (betas <= 1).all() 140 | 141 | self.num_timesteps = int(betas.shape[0]) 142 | 143 | alphas = 1.0 - betas 144 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 145 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 146 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 147 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 148 | 149 | # calculations for diffusion q(x_t | x_{t-1}) and others 150 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 151 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 152 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 153 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 154 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 155 | 156 | # calculations for posterior q(x_{t-1} | x_t, x_0) 157 | self.posterior_variance = ( 158 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 159 | ) 160 | # log calculation clipped because the posterior variance is 0 at the 161 | # beginning of the diffusion chain. 
162 | self.posterior_log_variance_clipped = np.log( 163 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 164 | ) 165 | self.posterior_mean_coef1 = ( 166 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 167 | ) 168 | self.posterior_mean_coef2 = ( 169 | (1.0 - self.alphas_cumprod_prev) 170 | * np.sqrt(alphas) 171 | / (1.0 - self.alphas_cumprod) 172 | ) 173 | 174 | def q_mean_variance(self, x_start, t): 175 | """ 176 | Get the distribution q(x_t | x_0). 177 | 178 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 179 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 180 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 181 | """ 182 | mean = ( 183 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 184 | ) 185 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 186 | log_variance = _extract_into_tensor( 187 | self.log_one_minus_alphas_cumprod, t, x_start.shape 188 | ) 189 | return mean, variance, log_variance 190 | 191 | def q_sample(self, x_start, t, noise=None): 192 | """ 193 | Diffuse the data for a given number of diffusion steps. 194 | 195 | In other words, sample from q(x_t | x_0). 196 | 197 | :param x_start: the initial data batch. 198 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 199 | :param noise: if specified, the split-out normal noise. 200 | :return: A noisy version of x_start. 
201 | """ 202 | if noise is None: 203 | noise = th.randn_like(x_start) 204 | assert noise.shape == x_start.shape 205 | return ( 206 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 207 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 208 | * noise 209 | ) 210 | 211 | def q_posterior_mean_variance(self, x_start, x_t, t): 212 | """ 213 | Compute the mean and variance of the diffusion posterior: 214 | 215 | q(x_{t-1} | x_t, x_0) 216 | 217 | """ 218 | assert x_start.shape == x_t.shape 219 | posterior_mean = ( 220 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 221 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 222 | ) 223 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 224 | posterior_log_variance_clipped = _extract_into_tensor( 225 | self.posterior_log_variance_clipped, t, x_t.shape 226 | ) 227 | assert ( 228 | posterior_mean.shape[0] 229 | == posterior_variance.shape[0] 230 | == posterior_log_variance_clipped.shape[0] 231 | == x_start.shape[0] 232 | ) 233 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 234 | 235 | def p_mean_variance( 236 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, analog_bit=None 237 | ): 238 | """ 239 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 240 | the initial x, x_0. 241 | 242 | :param model: the model, which takes a signal and a batch of timesteps 243 | as input. 244 | :param x: the [N x C x ...] tensor at time t. 245 | :param t: a 1-D Tensor of timesteps. 246 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 247 | :param denoised_fn: if not None, a function which applies to the 248 | x_start prediction before it is used to sample. Applies before 249 | clip_denoised. 250 | :param model_kwargs: if not None, a dict of extra keyword arguments to 251 | pass to the model. 
This can be used for conditioning. 252 | :return: a dict with the following keys: 253 | - 'mean': the model mean output. 254 | - 'variance': the model variance output. 255 | - 'log_variance': the log of 'variance'. 256 | - 'pred_xstart': the prediction for x_0. 257 | """ 258 | if model_kwargs is None: 259 | model_kwargs = {} 260 | 261 | B, C = x.shape[:2] 262 | assert t.shape == (B,) 263 | xtalpha = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape).permute([0,2,1]) 264 | epsalpha = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape).permute([0,2,1]) 265 | model_output_dec, model_output_bin = model(x, self._scale_timesteps(t), xtalpha=xtalpha, epsalpha=epsalpha, is_syn=True, **model_kwargs) 266 | model_output = model_output_dec 267 | 268 | if analog_bit: 269 | predict_descrete = 0 270 | else: 271 | predict_descrete = 32 272 | 273 | if t[0] < predict_descrete: 274 | def bin2dec(b, bits): 275 | mask = 2 ** th.arange(bits - 1, -1, -1).to(b.device, b.dtype) 276 | return th.sum(mask * b, -1) 277 | model_output_bin[model_output_bin>0] = 1 278 | model_output_bin[model_output_bin<=0] = 0 279 | model_output_bin = bin2dec(model_output_bin.round().int().permute([0,2,1]).reshape(model_output_bin.shape[0], 280 | model_output_bin.shape[2], 2, 8), 8).permute([0,2,1]) 281 | 282 | model_output_bin = ((model_output_bin/256) - 0.5) * 2 283 | model_output = model_output_bin 284 | 285 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 286 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 287 | model_output, model_var_values = th.split(model_output, C, dim=1) 288 | if self.model_var_type == ModelVarType.LEARNED: 289 | model_log_variance = model_var_values 290 | model_variance = th.exp(model_log_variance) 291 | else: 292 | min_log = _extract_into_tensor( 293 | self.posterior_log_variance_clipped, t, x.shape 294 | ) 295 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 296 | # The model_var_values is [-1, 1] 
for [min_var, max_var]. 297 | frac = (model_var_values + 1) / 2 298 | model_log_variance = frac * max_log + (1 - frac) * min_log 299 | model_variance = th.exp(model_log_variance) 300 | else: 301 | model_variance, model_log_variance = { 302 | # for fixedlarge, we set the initial (log-)variance like so 303 | # to get a better decoder log likelihood. 304 | ModelVarType.FIXED_LARGE: ( 305 | np.append(self.posterior_variance[1], self.betas[1:]), 306 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 307 | ), 308 | ModelVarType.FIXED_SMALL: ( 309 | self.posterior_variance, 310 | self.posterior_log_variance_clipped, 311 | ), 312 | }[self.model_var_type] 313 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 314 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 315 | 316 | def process_xstart(x): 317 | if denoised_fn is not None: 318 | x = denoised_fn(x) 319 | if clip_denoised: 320 | return x.clamp(-1, 1) 321 | return x 322 | 323 | if t[0] >= predict_descrete: 324 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 325 | pred_xstart = process_xstart( 326 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 327 | ) 328 | model_mean = model_output 329 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 330 | if self.model_mean_type == ModelMeanType.START_X: 331 | pred_xstart = process_xstart(model_output) 332 | else: 333 | pred_xstart = process_xstart( 334 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 335 | ) 336 | model_mean, _, _ = self.q_posterior_mean_variance( 337 | x_start=pred_xstart, x_t=x, t=t 338 | ) 339 | else: 340 | raise NotImplementedError(self.model_mean_type) 341 | else: 342 | pred_xstart = process_xstart(model_output) 343 | model_mean, _, _ = self.q_posterior_mean_variance( 344 | x_start=pred_xstart, x_t=x, t=t 345 | ) 346 | 347 | assert ( 348 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 349 | ) 350 | return { 
351 | "mean": model_mean, 352 | "variance": model_variance, 353 | "log_variance": model_log_variance, 354 | "pred_xstart": pred_xstart, 355 | } 356 | 357 | def _predict_xstart_from_eps(self, x_t, t, eps): 358 | assert x_t.shape == eps.shape 359 | return ( 360 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 361 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 362 | ) 363 | 364 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 365 | assert x_t.shape == xprev.shape 366 | return ( # (xprev - coef2*x_t) / coef1 367 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 368 | - _extract_into_tensor( 369 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 370 | ) 371 | * x_t 372 | ) 373 | 374 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 375 | return ( 376 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 377 | - pred_xstart 378 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 379 | 380 | def _scale_timesteps(self, t): 381 | if self.rescale_timesteps: 382 | return t.float() * (1000.0 / self.num_timesteps) 383 | return t 384 | 385 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 386 | """ 387 | Compute the mean for the previous step, given a function cond_fn that 388 | computes the gradient of a conditional log probability with respect to 389 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 390 | condition on y. 391 | 392 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
393 | """ 394 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) 395 | new_mean = ( 396 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 397 | ) 398 | return new_mean 399 | 400 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 401 | """ 402 | Compute what the p_mean_variance output would have been, should the 403 | model's score function be conditioned by cond_fn. 404 | 405 | See condition_mean() for details on cond_fn. 406 | 407 | Unlike condition_mean(), this instead uses the conditioning strategy 408 | from Song et al (2020). 409 | """ 410 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 411 | 412 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 413 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn( 414 | x, self._scale_timesteps(t), **model_kwargs 415 | ) 416 | 417 | out = p_mean_var.copy() 418 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 419 | out["mean"], _, _ = self.q_posterior_mean_variance( 420 | x_start=out["pred_xstart"], x_t=x, t=t 421 | ) 422 | return out 423 | 424 | def p_sample( 425 | self, 426 | model, 427 | x, 428 | t, 429 | clip_denoised=True, 430 | denoised_fn=None, 431 | cond_fn=None, 432 | model_kwargs=None, 433 | analog_bit=None, 434 | ): 435 | """ 436 | Sample x_{t-1} from the model at the given timestep. 437 | 438 | :param model: the model to sample from. 439 | :param x: the current tensor at x_{t-1}. 440 | :param t: the value of t, starting at 0 for the first diffusion step. 441 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 442 | :param denoised_fn: if not None, a function which applies to the 443 | x_start prediction before it is used to sample. 444 | :param cond_fn: if not None, this is a gradient function that acts 445 | similarly to the model. 446 | :param model_kwargs: if not None, a dict of extra keyword arguments to 447 | pass to the model. This can be used for conditioning. 
448 | :return: a dict containing the following keys: 449 | - 'sample': a random sample from the model. 450 | - 'pred_xstart': a prediction of x_0. 451 | """ 452 | out = self.p_mean_variance( 453 | model, 454 | x, 455 | t, 456 | clip_denoised=clip_denoised, 457 | denoised_fn=denoised_fn, 458 | model_kwargs=model_kwargs, 459 | analog_bit=analog_bit, 460 | ) 461 | noise = th.randn_like(x) 462 | nonzero_mask = ( 463 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 464 | ) # no noise when t == 0 465 | if cond_fn is not None: 466 | out["mean"] = self.condition_mean( 467 | cond_fn, out, x, t, model_kwargs=model_kwargs 468 | ) 469 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 470 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 471 | 472 | def p_sample_loop( 473 | self, 474 | model, 475 | shape, 476 | noise=None, 477 | clip_denoised=True, 478 | denoised_fn=None, 479 | cond_fn=None, 480 | model_kwargs=None, 481 | device=None, 482 | progress=False, 483 | analog_bit=None, 484 | ): 485 | """ 486 | Generate samples from the model. 487 | 488 | :param model: the model module. 489 | :param shape: the shape of the samples, (N, C, H, W). 490 | :param noise: if specified, the noise from the encoder to sample. 491 | Should be of the same shape as `shape`. 492 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 493 | :param denoised_fn: if not None, a function which applies to the 494 | x_start prediction before it is used to sample. 495 | :param cond_fn: if not None, this is a gradient function that acts 496 | similarly to the model. 497 | :param model_kwargs: if not None, a dict of extra keyword arguments to 498 | pass to the model. This can be used for conditioning. 499 | :param device: if specified, the device to create the samples on. 500 | If not specified, use a model parameter's device. 501 | :param progress: if True, show a tqdm progress bar. 502 | :return: a non-differentiable batch of samples. 
503 | """ 504 | myfinal = [] 505 | final = None 506 | for i, sample in tqdm(enumerate(self.p_sample_loop_progressive( 507 | model, 508 | shape, 509 | noise=noise, 510 | clip_denoised=clip_denoised, 511 | denoised_fn=denoised_fn, 512 | cond_fn=cond_fn, 513 | model_kwargs=model_kwargs, 514 | device=device, 515 | progress=progress, 516 | analog_bit=analog_bit, 517 | ))): 518 | if i>970: 519 | myfinal.append(sample['sample']) 520 | final = sample 521 | return th.stack(myfinal) 522 | # return final["sample"] 523 | 524 | def p_sample_loop_progressive( 525 | self, 526 | model, 527 | shape, 528 | noise=None, 529 | clip_denoised=True, 530 | denoised_fn=None, 531 | cond_fn=None, 532 | model_kwargs=None, 533 | device=None, 534 | progress=False, 535 | analog_bit=None, 536 | ): 537 | """ 538 | Generate samples from the model and yield intermediate samples from 539 | each timestep of diffusion. 540 | 541 | Arguments are the same as p_sample_loop(). 542 | Returns a generator over dicts, where each dict is the return value of 543 | p_sample(). 544 | """ 545 | if device is None: 546 | device = next(model.parameters()).device 547 | assert isinstance(shape, (tuple, list)) 548 | if noise is not None: 549 | img = noise 550 | else: 551 | img = th.randn(*shape, device=device) 552 | indices = list(range(self.num_timesteps))[::-1] 553 | 554 | if progress: 555 | # Lazy import so that we don't depend on tqdm. 
556 | 557 | indices = tqdm(indices) 558 | 559 | for i in indices: 560 | t = th.tensor([i] * shape[0], device=device) 561 | with th.no_grad(): 562 | out = self.p_sample( 563 | model, 564 | img, 565 | t, 566 | clip_denoised=clip_denoised, 567 | denoised_fn=denoised_fn, 568 | cond_fn=cond_fn, 569 | model_kwargs=model_kwargs, 570 | analog_bit=analog_bit, 571 | ) 572 | yield out 573 | img = out["sample"] 574 | 575 | def ddim_sample( 576 | self, 577 | model, 578 | x, 579 | t, 580 | clip_denoised=True, 581 | denoised_fn=None, 582 | cond_fn=None, 583 | model_kwargs=None, 584 | eta=0.0, 585 | ): 586 | """ 587 | Sample x_{t-1} from the model using DDIM. 588 | 589 | Same usage as p_sample(). 590 | """ 591 | out = self.p_mean_variance( 592 | model, 593 | x, 594 | t, 595 | clip_denoised=clip_denoised, 596 | denoised_fn=denoised_fn, 597 | model_kwargs=model_kwargs, 598 | ) 599 | if cond_fn is not None: 600 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 601 | 602 | # Usually our model outputs epsilon, but we re-derive it 603 | # in case we used x_start or x_prev prediction. 604 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 605 | 606 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 607 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 608 | sigma = ( 609 | eta 610 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 611 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 612 | ) 613 | # Equation 12. 
614 | noise = th.randn_like(x) 615 | mean_pred = ( 616 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 617 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 618 | ) 619 | nonzero_mask = ( 620 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 621 | ) # no noise when t == 0 622 | sample = mean_pred + nonzero_mask * sigma * noise 623 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 624 | 625 | def ddim_reverse_sample( 626 | self, 627 | model, 628 | x, 629 | t, 630 | clip_denoised=True, 631 | denoised_fn=None, 632 | model_kwargs=None, 633 | eta=0.0, 634 | ): 635 | """ 636 | Sample x_{t+1} from the model using DDIM reverse ODE. 637 | """ 638 | assert eta == 0.0, "Reverse ODE only for deterministic path" 639 | out = self.p_mean_variance( 640 | model, 641 | x, 642 | t, 643 | clip_denoised=clip_denoised, 644 | denoised_fn=denoised_fn, 645 | model_kwargs=model_kwargs, 646 | ) 647 | # Usually our model outputs epsilon, but we re-derive it 648 | # in case we used x_start or x_prev prediction. 649 | eps = ( 650 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 651 | - out["pred_xstart"] 652 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 653 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 654 | 655 | # Equation 12. reversed 656 | mean_pred = ( 657 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 658 | + th.sqrt(1 - alpha_bar_next) * eps 659 | ) 660 | 661 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 662 | 663 | def ddim_sample_loop( 664 | self, 665 | model, 666 | shape, 667 | noise=None, 668 | clip_denoised=True, 669 | denoised_fn=None, 670 | cond_fn=None, 671 | model_kwargs=None, 672 | device=None, 673 | progress=False, 674 | eta=0.0, 675 | ): 676 | """ 677 | Generate samples from the model using DDIM. 678 | 679 | Same usage as p_sample_loop(). 
680 | """ 681 | myfinal = [] 682 | final = None 683 | for i, sample in tqdm(enumerate(self.ddim_sample_loop_progressive( 684 | model, 685 | shape, 686 | noise=noise, 687 | clip_denoised=clip_denoised, 688 | denoised_fn=denoised_fn, 689 | cond_fn=cond_fn, 690 | model_kwargs=model_kwargs, 691 | device=device, 692 | progress=progress, 693 | eta=eta, 694 | ))): 695 | if i>990: 696 | myfinal.append(sample['sample']) 697 | final = sample 698 | return th.stack(myfinal) 699 | # return final["sample"] 700 | 701 | def ddim_sample_loop_progressive( 702 | self, 703 | model, 704 | shape, 705 | noise=None, 706 | clip_denoised=True, 707 | denoised_fn=None, 708 | cond_fn=None, 709 | model_kwargs=None, 710 | device=None, 711 | progress=False, 712 | eta=0.0, 713 | ): 714 | """ 715 | Use DDIM to sample from the model and yield intermediate samples from 716 | each timestep of DDIM. 717 | 718 | Same usage as p_sample_loop_progressive(). 719 | """ 720 | if device is None: 721 | device = next(model.parameters()).device 722 | assert isinstance(shape, (tuple, list)) 723 | if noise is not None: 724 | img = noise 725 | else: 726 | img = th.randn(*shape, device=device) 727 | indices = list(range(self.num_timesteps))[::-1] 728 | 729 | if progress: 730 | # Lazy import so that we don't depend on tqdm. 731 | from tqdm.auto import tqdm 732 | 733 | indices = tqdm(indices) 734 | 735 | for i in indices: 736 | t = th.tensor([i] * shape[0], device=device) 737 | with th.no_grad(): 738 | out = self.ddim_sample( 739 | model, 740 | img, 741 | t, 742 | clip_denoised=clip_denoised, 743 | denoised_fn=denoised_fn, 744 | cond_fn=cond_fn, 745 | model_kwargs=model_kwargs, 746 | eta=eta, 747 | ) 748 | yield out 749 | img = out["sample"] 750 | 751 | def _vb_terms_bpd( 752 | self, model, x_start, x_t, t, padding_mask, clip_denoised=True, model_kwargs=None, 753 | ): 754 | """ 755 | Get a term for the variational lower-bound. 756 | 757 | The resulting units are bits (rather than nats, as one might expect). 
758 | This allows for comparison to other papers. 759 | 760 | :return: a dict with the following keys: 761 | - 'output': a shape [N] tensor of NLLs or KLs. 762 | - 'pred_xstart': the x_0 predictions. 763 | """ 764 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 765 | x_start=x_start, x_t=x_t, t=t 766 | ) 767 | out = self.p_mean_variance( 768 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 769 | ) 770 | kl = normal_kl( 771 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 772 | ) 773 | kl = mean_flat(kl, padding_mask) / np.log(2.0) 774 | 775 | decoder_nll = -discretized_gaussian_log_likelihood( 776 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 777 | ) 778 | assert decoder_nll.shape == x_start.shape 779 | decoder_nll = mean_flat(decoder_nll, padding_mask) / np.log(2.0) 780 | 781 | # At the first timestep return the decoder NLL, 782 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 783 | # output = th.where((t == 0), decoder_nll, kl) 784 | output = kl 785 | return {"output": output, "pred_xstart": out["pred_xstart"]} 786 | 787 | def training_losses(self, model, x_start, t, model_kwargs, analog_bit, noise=None): 788 | """ 789 | Compute training losses for a single timestep. 790 | 791 | :param model: the model to evaluate loss on. 792 | :param x_start: the [N x C x ...] tensor of inputs. 793 | :param t: a batch of timestep indices. 794 | :param model_kwargs: if not None, a dict of extra keyword arguments to 795 | pass to the model. This can be used for conditioning. 796 | :param noise: if specified, the specific Gaussian noise to try to remove. 797 | :return: a dict with the key "loss" containing a tensor of shape [N]. 798 | Some mean or variance settings may also have other keys. 
799 | """ 800 | if model_kwargs is None: 801 | model_kwargs = {} 802 | if noise is None: 803 | noise = th.randn_like(x_start) 804 | x_t = self.q_sample(x_start, t, noise=noise) 805 | 806 | terms = {} 807 | tmp_mask = (1 - model_kwargs['src_key_padding_mask']) 808 | 809 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 810 | terms["loss"] = self._vb_terms_bpd( 811 | model=model, 812 | x_start=x_start, 813 | x_t=x_t, 814 | padding_mask = tmp_mask, 815 | t=t, 816 | clip_denoised=False, 817 | model_kwargs=model_kwargs, 818 | )["output"] 819 | if self.loss_type == LossType.RESCALED_KL: 820 | terms["loss"] *= self.num_timesteps 821 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 822 | xtalpha = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape).permute([0,2,1]) 823 | epsalpha = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape).permute([0,2,1]) 824 | model_output_dec, model_output_bin = model(x_t, self._scale_timesteps(t), xtalpha=xtalpha, epsalpha=epsalpha, **model_kwargs) 825 | # model_output_dec = model(x_t, self._scale_timesteps(t), **model_kwargs) 826 | 827 | if self.model_var_type in [ 828 | ModelVarType.LEARNED, 829 | ModelVarType.LEARNED_RANGE, 830 | ]: 831 | B, C = x_t.shape[:2] 832 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 833 | model_output, model_var_values = th.split(model_output, C, dim=1) 834 | # Learn the variance using the variational bound, but don't let 835 | # it affect our mean prediction. 836 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 837 | terms["vb"] = self._vb_terms_bpd( 838 | model=lambda *args, r=frozen_out: r, 839 | x_start=x_start, 840 | x_t=x_t, 841 | padding_mask = tmp_mask, 842 | t=t, 843 | clip_denoised=False, 844 | )["output"] 845 | if self.loss_type == LossType.RESCALED_MSE: 846 | # Divide by 1000 for equivalence with initial implementation. 
847 | # Without a factor of 1/1000, the VB term hurts the MSE term. 848 | terms["vb"] *= self.num_timesteps / 1000.0 849 | 850 | target = { 851 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 852 | x_start=x_start, x_t=x_t, t=t 853 | )[0], 854 | ModelMeanType.START_X: x_start, 855 | ModelMeanType.EPSILON: noise, 856 | }[self.model_mean_type] 857 | 858 | if not analog_bit: 859 | def dec2bin(xinp, bits): 860 | mask = 2 ** th.arange(bits - 1, -1, -1).to(xinp.device, xinp.dtype) 861 | return xinp.unsqueeze(-1).bitwise_and(mask).ne(0).float() 862 | bin_target = x_start.detach() 863 | bin_target = (bin_target/2 + 0.5) # -> [0,1] 864 | bin_target = bin_target * 256 #-> [0, 256] 865 | bin_target = dec2bin(bin_target.permute([0,2,1]).round().int(), 8) 866 | bin_target = bin_target.reshape([target.shape[0], target.shape[2], 16]).permute([0,2,1]) 867 | t_weights = (t<10).cuda().unsqueeze(1).unsqueeze(2) 868 | t_weights = t_weights * (t_weights.shape[0]/max(1, t_weights.sum())) 869 | bin_target[bin_target==0] = -1 870 | assert model_output_bin.shape == bin_target.shape 871 | 872 | assert model_output_dec.shape == target.shape == x_start.shape 873 | 874 | if not analog_bit: 875 | terms["mse_bin"] = mean_flat(((bin_target - model_output_bin) ** 2) * t_weights, tmp_mask) 876 | terms["mse_dec"] = mean_flat(((target - model_output_dec) ** 2), tmp_mask) 877 | 878 | if "vb" in terms: 879 | terms["loss"] = terms["mse"] + terms["vb"] 880 | else: 881 | if not analog_bit: 882 | terms["loss"] = terms["mse_dec"] + terms["mse_bin"] 883 | else: 884 | terms["loss"] = terms["mse_dec"] 885 | else: 886 | raise NotImplementedError(self.loss_type) 887 | 888 | return terms 889 | 890 | def _prior_bpd(self, x_start): 891 | """ 892 | Get the prior KL term for the variational lower-bound, measured in 893 | bits-per-dim. 894 | 895 | This term can't be optimized, as it only depends on the encoder. 896 | 897 | :param x_start: the [N x C x ...] tensor of inputs. 
898 | :return: a batch of [N] KL values (in bits), one per batch element. 899 | """ 900 | batch_size = x_start.shape[0] 901 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 902 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 903 | kl_prior = normal_kl( 904 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 905 | ) 906 | return mean_flat(kl_prior) / np.log(2.0) 907 | 908 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 909 | """ 910 | Compute the entire variational lower-bound, measured in bits-per-dim, 911 | as well as other related quantities. 912 | 913 | :param model: the model to evaluate loss on. 914 | :param x_start: the [N x C x ...] tensor of inputs. 915 | :param clip_denoised: if True, clip denoised samples. 916 | :param model_kwargs: if not None, a dict of extra keyword arguments to 917 | pass to the model. This can be used for conditioning. 918 | 919 | :return: a dict containing the following keys: 920 | - total_bpd: the total variational lower-bound, per batch element. 921 | - prior_bpd: the prior term in the lower-bound. 922 | - vb: an [N x T] tensor of terms in the lower-bound. 923 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 924 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
925 | """ 926 | device = x_start.device 927 | batch_size = x_start.shape[0] 928 | 929 | vb = [] 930 | xstart_mse = [] 931 | mse = [] 932 | for t in list(range(self.num_timesteps))[::-1]: 933 | t_batch = th.tensor([t] * batch_size, device=device) 934 | noise = th.randn_like(x_start) 935 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 936 | # Calculate VLB term at the current timestep 937 | with th.no_grad(): 938 | out = self._vb_terms_bpd( 939 | model, 940 | x_start=x_start, 941 | x_t=x_t, 942 | t=t_batch, 943 | clip_denoised=clip_denoised, 944 | model_kwargs=model_kwargs, 945 | ) 946 | vb.append(out["output"]) 947 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 948 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 949 | mse.append(mean_flat((eps - noise) ** 2)) 950 | 951 | vb = th.stack(vb, dim=1) 952 | xstart_mse = th.stack(xstart_mse, dim=1) 953 | mse = th.stack(mse, dim=1) 954 | 955 | prior_bpd = self._prior_bpd(x_start) 956 | total_bpd = vb.sum(dim=1) + prior_bpd 957 | return { 958 | "total_bpd": total_bpd, 959 | "prior_bpd": prior_bpd, 960 | "vb": vb, 961 | "xstart_mse": xstart_mse, 962 | "mse": mse, 963 | } 964 | 965 | 966 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 967 | """ 968 | Extract values from a 1-D numpy array for a batch of indices. 969 | 970 | :param arr: the 1-D numpy array. 971 | :param timesteps: a tensor of indices into the array to extract. 972 | :param broadcast_shape: a larger shape of K dimensions with the batch 973 | dimension equal to the length of timesteps. 974 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 975 | """ 976 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 977 | while len(res.shape) < len(broadcast_shape): 978 | res = res[..., None] 979 | return res.expand(broadcast_shape) 980 | --------------------------------------------------------------------------------