├── .gitignore ├── P2_weighting ├── LICENSE ├── README.md ├── datasets │ ├── README.md │ └── lsun_bedroom.py ├── evaluations │ ├── README.md │ ├── evaluator.py │ └── requirements.txt ├── guided_diffusion │ ├── __init__.py │ ├── dist_util.py │ ├── fp16_util.py │ ├── gaussian_diffusion.py │ ├── image_datasets.py │ ├── logger.py │ ├── losses.py │ ├── nn.py │ ├── resample.py │ ├── respace.py │ ├── script_util.py │ ├── train_util.py │ └── unet.py ├── scripts │ ├── classifier_sample.py │ ├── classifier_train.py │ ├── image_nll.py │ ├── image_sample.py │ ├── image_train.py │ ├── super_res_sample.py │ └── super_res_train.py └── setup.py ├── README.md ├── diffae ├── __init__.py ├── align.py ├── choices.py ├── cog.yaml ├── config.py ├── config_base.py ├── dataset.py ├── dataset_util.py ├── diffusion │ ├── __init__.py │ ├── base.py │ ├── diffusion.py │ └── resample.py ├── dist_utils.py ├── evals │ └── church256_autoenc.txt ├── experiment.py ├── experiment_classifier.py ├── lmdb_writer.py ├── metrics.py ├── model │ ├── __init__.py │ ├── blocks.py │ ├── latentnet.py │ ├── nn.py │ ├── unet.py │ └── unet_autoenc.py ├── predict.py ├── renderer.py ├── run_afhq256-dog.py ├── run_church256.py ├── ssim.py ├── templates.py ├── templates_cls.py └── templates_latent.py ├── environment.yaml ├── eval_diffaeB.py ├── gen_style_domA.py ├── imgs └── teaser.jpg ├── imgs_input_domA ├── img1.png ├── img2.png ├── img3.png ├── img4.png ├── img5.png └── img6.png ├── imgs_style_domB ├── img1.png ├── img2.png ├── img3.png ├── img4.png └── img5.png ├── scripts ├── eval.sh ├── prepare_train.sh └── train.sh ├── train_diffaeB.py └── utils ├── args.py ├── map_net.py ├── tester.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | diffae/checkpoints 6 | exp 7 | imgs_style_domA 8 | P2_weighting/models -------------------------------------------------------------------------------- /P2_weighting/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /P2_weighting/README.md: -------------------------------------------------------------------------------- 1 | # P2 weighting (CVPR 2022) 2 | 3 | This is the codebase for [Perception Prioritized Training of Diffusion Models](https://arxiv.org/abs/2204.00227). 4 | 5 | This repository is heavily based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion). 6 | 7 | P2 modifies the weighting scheme of the training objective function to improve sample quality. It encourages the diffusion model to focus on recovering signals from highly corrupted data, where the model learns global and perceptually rich concepts. The figure below shows the weighting schemes as a function of SNR. 8 | 9 | ![snr_weight](https://user-images.githubusercontent.com/36615789/161203299-8b02d76b-9c51-4529-8329-3ac08e9f3bc8.png) 10 | 11 | ## Pre-trained models 12 | 13 | All models are trained at 256x256 resolution. 14 | 15 | Here are the models trained on FFHQ, CelebA-HQ, CUB, AFHQ-Dogs, Flowers, and MetFaces: [drive](https://1drv.ms/u/s!AkQjJhxDm0Fyhqp_4gkYjwVRBe8V_w?e=Et3ITH) 16 | 17 | ## Requirements 18 | We tested on PyTorch 1.7.1 with a single RTX 8000 GPU. 19 | 20 | ## Sampling from pre-trained models 21 | 22 | First, set the `PYTHONPATH` variable to point to the root of the repository. Do the same when training new models. 23 | 24 | ``` 25 | export PYTHONPATH=$PYTHONPATH:$(pwd) 26 | ``` 27 | 28 | Put model checkpoints into a folder `models/`. 29 | 30 | Samples will be saved in `samples/`. 31 | 32 | ``` 33 | python scripts/image_sample.py --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 250 --model_path models/ffhq_p2.pt --sample_dir samples 34 | ``` 35 | 36 | The command above samples over 250 timesteps without DDIM. To sample with DDIM instead (for example, 25 steps), replace `--timestep_respacing 250` with `--timestep_respacing ddim25` and add `--use_ddim True`. 37 | 38 | ## Training your own models 39 | 40 | `--p2_gamma` and `--p2_k` are the two hyperparameters of P2 weighting. We used `--p2_gamma 0.5 --p2_k 1` and `--p2_gamma 1 --p2_k 1` in the paper. 41 | 42 | Logs and models will be saved in `logs/`. Set `--data_dir` to point to your dataset. 43 | 44 | We used a lightweight version (93M parameters) of [ADM](https://arxiv.org/abs/2105.05233) (which has over 500M parameters) as the default model configuration. You may modify the model configuration as needed. 45 | 46 | ``` 47 | python scripts/image_train.py --data_dir data/DATASET_NAME --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --lr 2e-5 --batch_size 8 --rescale_learned_sigmas True --p2_gamma 1 --p2_k 1 --log_dir logs 48 | ``` 49 | 50 | 51 |
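To make the weighting described above concrete, here is a minimal sketch (not part of the repository's code) of how the P2 weight can be computed per timestep for the linear beta schedule used in the commands above. It assumes the paper's convention SNR(t) = alpha_bar(t) / (1 - alpha_bar(t)) and the weighting factor 1 / (k + SNR(t))^gamma applied to the simple-loss weight; `p2_gamma` and `p2_k` here correspond to the `--p2_gamma` and `--p2_k` flags, and the actual implementation lives in `guided_diffusion/gaussian_diffusion.py`, which is not included in this listing.

```
import numpy as np

def p2_weights(num_timesteps=1000, p2_gamma=1.0, p2_k=1.0):
    # Linear beta schedule, matching --noise_schedule linear --diffusion_steps 1000.
    scale = 1000 / num_timesteps
    betas = np.linspace(scale * 0.0001, scale * 0.02, num_timesteps, dtype=np.float64)
    alphas_cumprod = np.cumprod(1.0 - betas)

    # Signal-to-noise ratio of x_t: SNR(t) = alpha_bar(t) / (1 - alpha_bar(t)).
    snr = alphas_cumprod / (1.0 - alphas_cumprod)

    # P2 weighting: down-weight the low-noise (high-SNR) steps so training focuses on
    # the heavily corrupted steps where global, perceptually rich content is learned.
    return 1.0 / (p2_k + snr) ** p2_gamma

weights = p2_weights(p2_gamma=1.0, p2_k=1.0)
print(weights[0], weights[-1])  # tiny weight at t=0 (high SNR), close to 1 at the last step
```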
-------------------------------------------------------------------------------- /P2_weighting/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase. 4 | 5 | ## Class-conditional ImageNet 6 | 7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 8 | 9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 10 | 11 | ``` 12 | for file in *.tar; do tar xf "$file"; rm "$file"; done 13 | ``` 14 | 15 | This will extract and remove each tar file in turn. 16 | 17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with a WNID (class ID) followed by an underscore, like `n01440764_2708.JPEG`. Conveniently (and not by accident), this is how the automated data loader expects to discover class labels. 18 | 19 | ## LSUN bedroom 20 | 21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script with `python3 download.py bedroom`. The result will be an LMDB database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 22 | 23 | ``` 24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 25 | ``` 26 | 27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 28 | -------------------------------------------------------------------------------- /P2_weighting/datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images.
3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /P2_weighting/evaluations/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Download batches 6 | 7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 8 | 9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 
10 | 11 | Here are links to download all of the sample and reference batches: 12 | 13 | * LSUN 14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 25 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 26 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 27 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 28 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 29 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 30 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 31 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 32 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 33 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 34 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 35 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 36 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 37 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 38 | * [ADM-G + 
ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz) 39 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz) 40 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz) 41 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) 42 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz) 43 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz) 44 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz) 45 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz) 46 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz) 47 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz) 48 | 49 | # Run evaluations 50 | 51 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`. 52 | 53 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB. 54 | 55 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging: 56 | 57 | ``` 58 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz 59 | ... 60 | computing reference batch activations... 61 | computing/reading reference batch statistics... 62 | computing sample batch activations... 63 | computing/reading sample batch statistics... 64 | Computing evaluations... 65 | Inception Score: 215.8370361328125 66 | FID: 3.9425574129223264 67 | sFID: 6.140433703346162 68 | Precision: 0.8265 69 | Recall: 0.5309 70 | ``` 71 | -------------------------------------------------------------------------------- /P2_weighting/evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """ 4 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 
18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = 
get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def 
master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 
154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 
99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 
113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /P2_weighting/guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | batch_size, 30 | microbatch, 31 | lr, 32 | ema_rate, 33 | log_interval, 34 | save_interval, 35 | resume_checkpoint, 36 | use_fp16=False, 37 | fp16_scale_growth=1e-3, 38 | schedule_sampler=None, 39 | weight_decay=0.0, 40 | lr_anneal_steps=0, 41 | ): 42 | self.model = model 43 | self.diffusion = diffusion 44 | self.data = data 45 | self.batch_size = batch_size 46 | self.microbatch = microbatch if microbatch > 0 else batch_size 47 | self.lr = lr 48 | self.ema_rate = ( 49 | [ema_rate] 50 | if isinstance(ema_rate, float) 51 | else [float(x) for x in ema_rate.split(",")] 52 | ) 53 | self.log_interval = log_interval 54 | self.save_interval = save_interval 55 | self.resume_checkpoint = resume_checkpoint 56 | self.use_fp16 = use_fp16 57 | self.fp16_scale_growth = fp16_scale_growth 58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 59 | self.weight_decay = weight_decay 60 | self.lr_anneal_steps = lr_anneal_steps 61 | 62 | self.step = 0 63 | self.resume_step = 0 64 | self.global_batch = self.batch_size * dist.get_world_size() 65 | 66 | self.sync_cuda = th.cuda.is_available() 67 | 68 | self._load_and_sync_parameters() 69 | self.mp_trainer = MixedPrecisionTrainer( 70 | model=self.model, 71 | use_fp16=self.use_fp16, 72 | fp16_scale_growth=fp16_scale_growth, 73 | ) 74 | 75 | self.opt = AdamW( 76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 77 | ) 78 | if self.resume_step: 79 | self._load_optimizer_state() 80 | # Model was resumed, either due to a restart or a checkpoint 81 | # being specified at the command line. 
82 | self.ema_params = [ 83 | self._load_ema_parameters(rate) for rate in self.ema_rate 84 | ] 85 | else: 86 | self.ema_params = [ 87 | copy.deepcopy(self.mp_trainer.master_params) 88 | for _ in range(len(self.ema_rate)) 89 | ] 90 | 91 | if th.cuda.is_available(): 92 | self.use_ddp = True 93 | self.ddp_model = DDP( 94 | self.model, 95 | device_ids=[dist_util.dev()], 96 | output_device=dist_util.dev(), 97 | broadcast_buffers=False, 98 | bucket_cap_mb=128, 99 | find_unused_parameters=False, 100 | ) 101 | else: 102 | if dist.get_world_size() > 1: 103 | logger.warn( 104 | "Distributed training requires CUDA. " 105 | "Gradients will not be synchronized properly!" 106 | ) 107 | self.use_ddp = False 108 | self.ddp_model = self.model 109 | 110 | def _load_and_sync_parameters(self): 111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 112 | 113 | if resume_checkpoint: 114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 115 | if dist.get_rank() == 0: 116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 117 | self.model.load_state_dict( 118 | dist_util.load_state_dict( 119 | resume_checkpoint, map_location=dist_util.dev() 120 | ) 121 | ) 122 | 123 | dist_util.sync_params(self.model.parameters()) 124 | 125 | def _load_ema_parameters(self, rate): 126 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 127 | 128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 130 | if ema_checkpoint: 131 | if dist.get_rank() == 0: 132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 133 | state_dict = dist_util.load_state_dict( 134 | ema_checkpoint, map_location=dist_util.dev() 135 | ) 136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 137 | 138 | dist_util.sync_params(ema_params) 139 | return ema_params 140 | 141 | def _load_optimizer_state(self): 142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 143 | opt_checkpoint = bf.join( 144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 145 | ) 146 | if bf.exists(opt_checkpoint): 147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 148 | state_dict = dist_util.load_state_dict( 149 | opt_checkpoint, map_location=dist_util.dev() 150 | ) 151 | self.opt.load_state_dict(state_dict) 152 | 153 | def run_loop(self): 154 | while ( 155 | not self.lr_anneal_steps 156 | or self.step + self.resume_step < self.lr_anneal_steps 157 | ): 158 | batch, cond = next(self.data) 159 | self.run_step(batch, cond) 160 | if self.step % self.log_interval == 0: 161 | logger.dumpkvs() 162 | if self.step % self.save_interval == 0: 163 | self.save() 164 | # Run for a finite amount of time in integration tests. 165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 166 | return 167 | self.step += 1 168 | # Save the last checkpoint if it wasn't already saved. 
169 | if (self.step - 1) % self.save_interval != 0: 170 | self.save() 171 | 172 | def run_step(self, batch, cond): 173 | self.forward_backward(batch, cond) 174 | took_step = self.mp_trainer.optimize(self.opt) 175 | if took_step: 176 | self._update_ema() 177 | self._anneal_lr() 178 | self.log_step() 179 | 180 | def forward_backward(self, batch, cond): 181 | self.mp_trainer.zero_grad() 182 | for i in range(0, batch.shape[0], self.microbatch): 183 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 184 | micro_cond = { 185 | k: v[i : i + self.microbatch].to(dist_util.dev()) 186 | for k, v in cond.items() 187 | } 188 | last_batch = (i + self.microbatch) >= batch.shape[0] 189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 190 | 191 | compute_losses = functools.partial( 192 | self.diffusion.training_losses, 193 | self.ddp_model, 194 | micro, 195 | t, 196 | model_kwargs=micro_cond, 197 | ) 198 | 199 | if last_batch or not self.use_ddp: 200 | losses = compute_losses() 201 | else: 202 | with self.ddp_model.no_sync(): 203 | losses = compute_losses() 204 | 205 | if isinstance(self.schedule_sampler, LossAwareSampler): 206 | self.schedule_sampler.update_with_local_losses( 207 | t, losses["loss"].detach() 208 | ) 209 | 210 | loss = (losses["loss"] * weights).mean() 211 | log_loss_dict( 212 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 213 | ) 214 | self.mp_trainer.backward(loss) 215 | 216 | def _update_ema(self): 217 | for rate, params in zip(self.ema_rate, self.ema_params): 218 | update_ema(params, self.mp_trainer.master_params, rate=rate) 219 | 220 | def _anneal_lr(self): 221 | if not self.lr_anneal_steps: 222 | return 223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 224 | lr = self.lr * (1 - frac_done) 225 | for param_group in self.opt.param_groups: 226 | param_group["lr"] = lr 227 | 228 | def log_step(self): 229 | logger.logkv("step", self.step + self.resume_step) 230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 231 | 232 | def save(self): 233 | def save_checkpoint(rate, params): 234 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 235 | if dist.get_rank() == 0: 236 | logger.log(f"saving model {rate}...") 237 | if not rate: 238 | filename = f"model{(self.step+self.resume_step):06d}.pt" 239 | else: 240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 242 | th.save(state_dict, f) 243 | 244 | save_checkpoint(0, self.mp_trainer.master_params) 245 | for rate, params in zip(self.ema_rate, self.ema_params): 246 | save_checkpoint(rate, params) 247 | 248 | if dist.get_rank() == 0: 249 | with bf.BlobFile( 250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 251 | "wb", 252 | ) as f: 253 | th.save(self.opt.state_dict(), f) 254 | 255 | dist.barrier() 256 | 257 | 258 | def parse_resume_step_from_filename(filename): 259 | """ 260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 261 | checkpoint's number of steps. 262 | """ 263 | split = filename.split("model") 264 | if len(split) < 2: 265 | return 0 266 | split1 = split[-1].split(".")[0] 267 | try: 268 | return int(split1) 269 | except ValueError: 270 | return 0 271 | 272 | 273 | def get_blob_logdir(): 274 | # You can change this to be a separate path to save checkpoints to 275 | # a blobstore or some external drive. 
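    # For example, an environment-variable override could look like the line below
    # (illustrative only; DIFFUSION_BLOB_LOGDIR is not a variable this codebase defines):
    #   return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())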
276 | return logger.get_dir() 277 | 278 | 279 | def find_resume_checkpoint(): 280 | # On your infrastructure, you may want to override this to automatically 281 | # discover the latest checkpoint on your blob storage, etc. 282 | return None 283 | 284 | 285 | def find_ema_checkpoint(main_checkpoint, step, rate): 286 | if main_checkpoint is None: 287 | return None 288 | filename = f"ema_{rate}_{(step):06d}.pt" 289 | path = bf.join(bf.dirname(main_checkpoint), filename) 290 | if bf.exists(path): 291 | return path 292 | return None 293 | 294 | 295 | def log_loss_dict(diffusion, ts, losses): 296 | for key, values in losses.items(): 297 | logger.logkv_mean(key, values.mean().item()) 298 | # Log the quantiles (four quartiles, in particular). 299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 300 | quartile = int(4 * sub_t / diffusion.num_timesteps) 301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 302 | -------------------------------------------------------------------------------- /P2_weighting/scripts/classifier_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Like image_sample.py, but use a noisy image classifier to guide the sampling 3 | process towards more realistic images. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | NUM_CLASSES, 17 | model_and_diffusion_defaults, 18 | classifier_defaults, 19 | create_model_and_diffusion, 20 | create_classifier, 21 | add_dict_to_argparser, 22 | args_to_dict, 23 | ) 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | logger.log("creating model and diffusion...") 33 | model, diffusion = create_model_and_diffusion( 34 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 35 | ) 36 | model.load_state_dict( 37 | dist_util.load_state_dict(args.model_path, map_location="cpu") 38 | ) 39 | model.to(dist_util.dev()) 40 | if args.use_fp16: 41 | model.convert_to_fp16() 42 | model.eval() 43 | 44 | logger.log("loading classifier...") 45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys())) 46 | classifier.load_state_dict( 47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu") 48 | ) 49 | classifier.to(dist_util.dev()) 50 | if args.classifier_use_fp16: 51 | classifier.convert_to_fp16() 52 | classifier.eval() 53 | 54 | def cond_fn(x, t, y=None): 55 | assert y is not None 56 | with th.enable_grad(): 57 | x_in = x.detach().requires_grad_(True) 58 | logits = classifier(x_in, t) 59 | log_probs = F.log_softmax(logits, dim=-1) 60 | selected = log_probs[range(len(logits)), y.view(-1)] 61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 62 | 63 | def model_fn(x, t, y=None): 64 | assert y is not None 65 | return model(x, t, y if args.class_cond else None) 66 | 67 | logger.log("sampling...") 68 | all_images = [] 69 | all_labels = [] 70 | while len(all_images) * args.batch_size < args.num_samples: 71 | model_kwargs = {} 72 | classes = th.randint( 73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 74 | ) 75 | model_kwargs["y"] = classes 76 | sample_fn = ( 77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 78 | ) 79 | sample = sample_fn( 
80 | model_fn, 81 | (args.batch_size, 3, args.image_size, args.image_size), 82 | clip_denoised=args.clip_denoised, 83 | model_kwargs=model_kwargs, 84 | cond_fn=cond_fn, 85 | device=dist_util.dev(), 86 | ) 87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 88 | sample = sample.permute(0, 2, 3, 1) 89 | sample = sample.contiguous() 90 | 91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] 95 | dist.all_gather(gathered_labels, classes) 96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 97 | logger.log(f"created {len(all_images) * args.batch_size} samples") 98 | 99 | arr = np.concatenate(all_images, axis=0) 100 | arr = arr[: args.num_samples] 101 | label_arr = np.concatenate(all_labels, axis=0) 102 | label_arr = label_arr[: args.num_samples] 103 | if dist.get_rank() == 0: 104 | shape_str = "x".join([str(x) for x in arr.shape]) 105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 106 | logger.log(f"saving to {out_path}") 107 | np.savez(out_path, arr, label_arr) 108 | 109 | dist.barrier() 110 | logger.log("sampling complete") 111 | 112 | 113 | def create_argparser(): 114 | defaults = dict( 115 | clip_denoised=True, 116 | num_samples=10000, 117 | batch_size=16, 118 | use_ddim=False, 119 | model_path="", 120 | classifier_path="", 121 | classifier_scale=1.0, 122 | ) 123 | defaults.update(model_and_diffusion_defaults()) 124 | defaults.update(classifier_defaults()) 125 | parser = argparse.ArgumentParser() 126 | add_dict_to_argparser(parser, defaults) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /P2_weighting/scripts/classifier_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a noised image classifier on ImageNet. 
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import blobfile as bf 9 | import torch as th 10 | import torch.distributed as dist 11 | import torch.nn.functional as F 12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 13 | from torch.optim import AdamW 14 | 15 | from guided_diffusion import dist_util, logger 16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer 17 | from guided_diffusion.image_datasets import load_data 18 | from guided_diffusion.resample import create_named_schedule_sampler 19 | from guided_diffusion.script_util import ( 20 | add_dict_to_argparser, 21 | args_to_dict, 22 | classifier_and_diffusion_defaults, 23 | create_classifier_and_diffusion, 24 | ) 25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | 31 | dist_util.setup_dist() 32 | logger.configure() 33 | 34 | logger.log("creating model and diffusion...") 35 | model, diffusion = create_classifier_and_diffusion( 36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys()) 37 | ) 38 | model.to(dist_util.dev()) 39 | if args.noised: 40 | schedule_sampler = create_named_schedule_sampler( 41 | args.schedule_sampler, diffusion 42 | ) 43 | 44 | resume_step = 0 45 | if args.resume_checkpoint: 46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint) 47 | if dist.get_rank() == 0: 48 | logger.log( 49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step" 50 | ) 51 | model.load_state_dict( 52 | dist_util.load_state_dict( 53 | args.resume_checkpoint, map_location=dist_util.dev() 54 | ) 55 | ) 56 | 57 | # Needed for creating correct EMAs and fp16 parameters. 58 | dist_util.sync_params(model.parameters()) 59 | 60 | mp_trainer = MixedPrecisionTrainer( 61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0 62 | ) 63 | 64 | model = DDP( 65 | model, 66 | device_ids=[dist_util.dev()], 67 | output_device=dist_util.dev(), 68 | broadcast_buffers=False, 69 | bucket_cap_mb=128, 70 | find_unused_parameters=False, 71 | ) 72 | 73 | logger.log("creating data loader...") 74 | data = load_data( 75 | data_dir=args.data_dir, 76 | batch_size=args.batch_size, 77 | image_size=args.image_size, 78 | class_cond=True, 79 | random_crop=True, 80 | ) 81 | if args.val_data_dir: 82 | val_data = load_data( 83 | data_dir=args.val_data_dir, 84 | batch_size=args.batch_size, 85 | image_size=args.image_size, 86 | class_cond=True, 87 | ) 88 | else: 89 | val_data = None 90 | 91 | logger.log(f"creating optimizer...") 92 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay) 93 | if args.resume_checkpoint: 94 | opt_checkpoint = bf.join( 95 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt" 96 | ) 97 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 98 | opt.load_state_dict( 99 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) 100 | ) 101 | 102 | logger.log("training classifier model...") 103 | 104 | def forward_backward_log(data_loader, prefix="train"): 105 | batch, extra = next(data_loader) 106 | labels = extra["y"].to(dist_util.dev()) 107 | 108 | batch = batch.to(dist_util.dev()) 109 | # Noisy images 110 | if args.noised: 111 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev()) 112 | batch = diffusion.q_sample(batch, t) 113 | else: 114 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev()) 115 | 116 | for i, 
(sub_batch, sub_labels, sub_t) in enumerate( 117 | split_microbatches(args.microbatch, batch, labels, t) 118 | ): 119 | logits = model(sub_batch, timesteps=sub_t) 120 | loss = F.cross_entropy(logits, sub_labels, reduction="none") 121 | 122 | losses = {} 123 | losses[f"{prefix}_loss"] = loss.detach() 124 | losses[f"{prefix}_acc@1"] = compute_top_k( 125 | logits, sub_labels, k=1, reduction="none" 126 | ) 127 | losses[f"{prefix}_acc@5"] = compute_top_k( 128 | logits, sub_labels, k=5, reduction="none" 129 | ) 130 | log_loss_dict(diffusion, sub_t, losses) 131 | del losses 132 | loss = loss.mean() 133 | if loss.requires_grad: 134 | if i == 0: 135 | mp_trainer.zero_grad() 136 | mp_trainer.backward(loss * len(sub_batch) / len(batch)) 137 | 138 | for step in range(args.iterations - resume_step): 139 | logger.logkv("step", step + resume_step) 140 | logger.logkv( 141 | "samples", 142 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(), 143 | ) 144 | if args.anneal_lr: 145 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations) 146 | forward_backward_log(data) 147 | mp_trainer.optimize(opt) 148 | if val_data is not None and not step % args.eval_interval: 149 | with th.no_grad(): 150 | with model.no_sync(): 151 | model.eval() 152 | forward_backward_log(val_data, prefix="val") 153 | model.train() 154 | if not step % args.log_interval: 155 | logger.dumpkvs() 156 | if ( 157 | step 158 | and dist.get_rank() == 0 159 | and not (step + resume_step) % args.save_interval 160 | ): 161 | logger.log("saving model...") 162 | save_model(mp_trainer, opt, step + resume_step) 163 | 164 | if dist.get_rank() == 0: 165 | logger.log("saving model...") 166 | save_model(mp_trainer, opt, step + resume_step) 167 | dist.barrier() 168 | 169 | 170 | def set_annealed_lr(opt, base_lr, frac_done): 171 | lr = base_lr * (1 - frac_done) 172 | for param_group in opt.param_groups: 173 | param_group["lr"] = lr 174 | 175 | 176 | def save_model(mp_trainer, opt, step): 177 | if dist.get_rank() == 0: 178 | th.save( 179 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params), 180 | os.path.join(logger.get_dir(), f"model{step:06d}.pt"), 181 | ) 182 | th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt")) 183 | 184 | 185 | def compute_top_k(logits, labels, k, reduction="mean"): 186 | _, top_ks = th.topk(logits, k, dim=-1) 187 | if reduction == "mean": 188 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 189 | elif reduction == "none": 190 | return (top_ks == labels[:, None]).float().sum(dim=-1) 191 | 192 | 193 | def split_microbatches(microbatch, *args): 194 | bs = len(args[0]) 195 | if microbatch == -1 or microbatch >= bs: 196 | yield tuple(args) 197 | else: 198 | for i in range(0, bs, microbatch): 199 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args) 200 | 201 | 202 | def create_argparser(): 203 | defaults = dict( 204 | data_dir="", 205 | val_data_dir="", 206 | noised=True, 207 | iterations=150000, 208 | lr=3e-4, 209 | weight_decay=0.0, 210 | anneal_lr=False, 211 | batch_size=4, 212 | microbatch=-1, 213 | schedule_sampler="uniform", 214 | resume_checkpoint="", 215 | log_interval=10, 216 | eval_interval=5, 217 | save_interval=10000, 218 | ) 219 | defaults.update(classifier_and_diffusion_defaults()) 220 | parser = argparse.ArgumentParser() 221 | add_dict_to_argparser(parser, defaults) 222 | return parser 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | 
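# Illustrative sketch (not part of the original script): compute_top_k() above returns plain
# top-k accuracy -- a float with reduction="mean", a per-sample 0/1 tensor with reduction="none".
#
#   import torch as th
#   logits = th.tensor([[2.0, 1.0, 0.1], [0.2, 0.3, 3.0]])
#   labels = th.tensor([0, 1])
#   compute_top_k(logits, labels, k=1, reduction="mean")  # 0.5: only the first sample is top-1 correct
#   compute_top_k(logits, labels, k=2, reduction="none")  # tensor([1., 1.]): both labels are in the top-2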
-------------------------------------------------------------------------------- /P2_weighting/scripts/image_nll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate the bits/dimension for an image model. 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | from guided_diffusion import dist_util, logger 12 | from guided_diffusion.image_datasets import load_data 13 | from guided_diffusion.script_util import ( 14 | model_and_diffusion_defaults, 15 | create_model_and_diffusion, 16 | add_dict_to_argparser, 17 | args_to_dict, 18 | ) 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.load_state_dict( 32 | dist_util.load_state_dict(args.model_path, map_location="cpu") 33 | ) 34 | model.to(dist_util.dev()) 35 | model.eval() 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | deterministic=True, 44 | ) 45 | 46 | logger.log("evaluating...") 47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) 48 | 49 | 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): 51 | all_bpd = [] 52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []} 53 | num_complete = 0 54 | while num_complete < num_samples: 55 | batch, model_kwargs = next(data) 56 | batch = batch.to(dist_util.dev()) 57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 58 | minibatch_metrics = diffusion.calc_bpd_loop( 59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs 60 | ) 61 | 62 | for key, term_list in all_metrics.items(): 63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() 64 | dist.all_reduce(terms) 65 | term_list.append(terms.detach().cpu().numpy()) 66 | 67 | total_bpd = minibatch_metrics["total_bpd"] 68 | total_bpd = total_bpd.mean() / dist.get_world_size() 69 | dist.all_reduce(total_bpd) 70 | all_bpd.append(total_bpd.item()) 71 | num_complete += dist.get_world_size() * batch.shape[0] 72 | 73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") 74 | 75 | if dist.get_rank() == 0: 76 | for name, terms in all_metrics.items(): 77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") 78 | logger.log(f"saving {name} terms to {out_path}") 79 | np.savez(out_path, np.mean(np.stack(terms), axis=0)) 80 | 81 | dist.barrier() 82 | logger.log("evaluation complete") 83 | 84 | 85 | def create_argparser(): 86 | defaults = dict( 87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /P2_weighting/scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. 
This can be used to produce samples for FID evaluation. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from guided_diffusion import dist_util, logger 14 | from guided_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | from torchvision import utils 22 | 23 | 24 | def main(): 25 | args = create_argparser().parse_args() 26 | 27 | dist_util.setup_dist() 28 | logger.configure(dir=args.sample_dir) 29 | 30 | logger.log("creating model and diffusion...") 31 | model, diffusion = create_model_and_diffusion( 32 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 33 | ) 34 | model.load_state_dict( 35 | dist_util.load_state_dict(args.model_path, map_location="cpu") 36 | ) 37 | model.to(dist_util.dev()) 38 | if args.use_fp16: 39 | model.convert_to_fp16() 40 | model.eval() 41 | 42 | logger.log("sampling...") 43 | all_images = [] 44 | all_labels = [] 45 | count = 0 46 | while count * args.batch_size < args.num_samples: 47 | model_kwargs = {} 48 | if args.class_cond: 49 | classes = th.randint( 50 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 51 | ) 52 | model_kwargs["y"] = classes 53 | sample_fn = ( 54 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 55 | ) 56 | sample = sample_fn( 57 | model, 58 | (args.batch_size, 3, args.image_size, args.image_size), 59 | clip_denoised=args.clip_denoised, 60 | model_kwargs=model_kwargs, 61 | ) 62 | # saving png 63 | for i in range(args.batch_size): 64 | out_path = os.path.join(logger.get_dir(), 65 | f"{str(count * args.batch_size + i).zfill(5)}.png") 66 | utils.save_image( 67 | sample[i].unsqueeze(0), 68 | out_path, 69 | nrow=1, 70 | normalize=True, 71 | range=(-1, 1), 72 | ) 73 | # saving npz 74 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 75 | sample = sample.permute(0, 2, 3, 1) 76 | sample = sample.contiguous() 77 | 78 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 79 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 80 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 81 | if args.class_cond: 82 | gathered_labels = [ 83 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 84 | ] 85 | dist.all_gather(gathered_labels, classes) 86 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 87 | logger.log(f"created {len(all_images) * args.batch_size} samples") 88 | 89 | arr = np.concatenate(all_images, axis=0) 90 | arr = arr[: args.num_samples] 91 | if args.class_cond: 92 | label_arr = np.concatenate(all_labels, axis=0) 93 | label_arr = label_arr[: args.num_samples] 94 | if dist.get_rank() == 0: 95 | shape_str = "x".join([str(x) for x in arr.shape]) 96 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 97 | logger.log(f"saving to {out_path}") 98 | if args.class_cond: 99 | np.savez(out_path, arr, label_arr) 100 | else: 101 | np.savez(out_path, arr) 102 | 103 | dist.barrier() 104 | logger.log("sampling complete") 105 | 106 | 107 | def create_argparser(): 108 | defaults = dict( 109 | clip_denoised=True, 110 | num_samples=10000, 111 | batch_size=16, 112 | use_ddim=False, 113 | model_path="", 114 | sample_dir="", 115 | ) 116 | defaults.update(model_and_diffusion_defaults()) 117 | parser = argparse.ArgumentParser() 118 | 
add_dict_to_argparser(parser, defaults) 119 | return parser 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /P2_weighting/scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | from guided_diffusion import dist_util, logger 8 | from guided_diffusion.image_datasets import load_data 9 | from guided_diffusion.resample import create_named_schedule_sampler 10 | from guided_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from guided_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure(dir=args.log_dir) 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | data_dir=args.data_dir, 35 | batch_size=args.batch_size, 36 | image_size=args.image_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | data_dir="", 63 | log_dir="", 64 | schedule_sampler="uniform", 65 | lr=1e-4, 66 | weight_decay=0.0, 67 | lr_anneal_steps=0, 68 | batch_size=1, 69 | microbatch=-1, # -1 disables microbatches 70 | ema_rate="0.9999", # comma-separated list of EMA values 71 | log_interval=10, 72 | save_interval=10000, 73 | resume_checkpoint="", 74 | use_fp16=False, 75 | fp16_scale_growth=1e-3, 76 | ) 77 | defaults.update(model_and_diffusion_defaults()) 78 | parser = argparse.ArgumentParser() 79 | add_dict_to_argparser(parser, defaults) 80 | return parser 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /P2_weighting/scripts/super_res_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of samples from a super resolution model, given a batch 3 | of samples from a regular model from image_sample.py. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import blobfile as bf 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | sr_model_and_diffusion_defaults, 17 | sr_create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model...") 30 | model, diffusion = sr_create_model_and_diffusion( 31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("loading data...") 42 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond) 43 | 44 | logger.log("creating samples...") 45 | all_images = [] 46 | while len(all_images) * args.batch_size < args.num_samples: 47 | model_kwargs = next(data) 48 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 49 | sample = diffusion.p_sample_loop( 50 | model, 51 | (args.batch_size, 3, args.large_size, args.large_size), 52 | clip_denoised=args.clip_denoised, 53 | model_kwargs=model_kwargs, 54 | ) 55 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 56 | sample = sample.permute(0, 2, 3, 1) 57 | sample = sample.contiguous() 58 | 59 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 60 | dist.all_gather(all_samples, sample) # gather not supported with NCCL 61 | for sample in all_samples: 62 | all_images.append(sample.cpu().numpy()) 63 | logger.log(f"created {len(all_images) * args.batch_size} samples") 64 | 65 | arr = np.concatenate(all_images, axis=0) 66 | arr = arr[: args.num_samples] 67 | if dist.get_rank() == 0: 68 | shape_str = "x".join([str(x) for x in arr.shape]) 69 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 70 | logger.log(f"saving to {out_path}") 71 | np.savez(out_path, arr) 72 | 73 | dist.barrier() 74 | logger.log("sampling complete") 75 | 76 | 77 | def load_data_for_worker(base_samples, batch_size, class_cond): 78 | with bf.BlobFile(base_samples, "rb") as f: 79 | obj = np.load(f) 80 | image_arr = obj["arr_0"] 81 | if class_cond: 82 | label_arr = obj["arr_1"] 83 | rank = dist.get_rank() 84 | num_ranks = dist.get_world_size() 85 | buffer = [] 86 | label_buffer = [] 87 | while True: 88 | for i in range(rank, len(image_arr), num_ranks): 89 | buffer.append(image_arr[i]) 90 | if class_cond: 91 | label_buffer.append(label_arr[i]) 92 | if len(buffer) == batch_size: 93 | batch = th.from_numpy(np.stack(buffer)).float() 94 | batch = batch / 127.5 - 1.0 95 | batch = batch.permute(0, 3, 1, 2) 96 | res = dict(low_res=batch) 97 | if class_cond: 98 | res["y"] = th.from_numpy(np.stack(label_buffer)) 99 | yield res 100 | buffer, label_buffer = [], [] 101 | 102 | 103 | def create_argparser(): 104 | defaults = dict( 105 | clip_denoised=True, 106 | num_samples=10000, 107 | batch_size=16, 108 | use_ddim=False, 109 | base_samples="", 110 | model_path="", 111 | ) 112 | defaults.update(sr_model_and_diffusion_defaults()) 113 | parser = argparse.ArgumentParser() 114 | add_dict_to_argparser(parser, defaults) 115 | return parser 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | 
-------------------------------------------------------------------------------- /P2_weighting/scripts/super_res_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a super-resolution model. 3 | """ 4 | 5 | import argparse 6 | 7 | import torch.nn.functional as F 8 | 9 | from guided_diffusion import dist_util, logger 10 | from guided_diffusion.image_datasets import load_data 11 | from guided_diffusion.resample import create_named_schedule_sampler 12 | from guided_diffusion.script_util import ( 13 | sr_model_and_diffusion_defaults, 14 | sr_create_model_and_diffusion, 15 | args_to_dict, 16 | add_dict_to_argparser, 17 | ) 18 | from guided_diffusion.train_util import TrainLoop 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model...") 28 | model, diffusion = sr_create_model_and_diffusion( 29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_superres_data( 36 | args.data_dir, 37 | args.batch_size, 38 | large_size=args.large_size, 39 | small_size=args.small_size, 40 | class_cond=args.class_cond, 41 | ) 42 | 43 | logger.log("training...") 44 | TrainLoop( 45 | model=model, 46 | diffusion=diffusion, 47 | data=data, 48 | batch_size=args.batch_size, 49 | microbatch=args.microbatch, 50 | lr=args.lr, 51 | ema_rate=args.ema_rate, 52 | log_interval=args.log_interval, 53 | save_interval=args.save_interval, 54 | resume_checkpoint=args.resume_checkpoint, 55 | use_fp16=args.use_fp16, 56 | fp16_scale_growth=args.fp16_scale_growth, 57 | schedule_sampler=schedule_sampler, 58 | weight_decay=args.weight_decay, 59 | lr_anneal_steps=args.lr_anneal_steps, 60 | ).run_loop() 61 | 62 | 63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): 64 | data = load_data( 65 | data_dir=data_dir, 66 | batch_size=batch_size, 67 | image_size=large_size, 68 | class_cond=class_cond, 69 | ) 70 | for large_batch, model_kwargs in data: 71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") 72 | yield large_batch, model_kwargs 73 | 74 | 75 | def create_argparser(): 76 | defaults = dict( 77 | data_dir="", 78 | schedule_sampler="uniform", 79 | lr=1e-4, 80 | weight_decay=0.0, 81 | lr_anneal_steps=0, 82 | batch_size=1, 83 | microbatch=-1, 84 | ema_rate="0.9999", 85 | log_interval=10, 86 | save_interval=10000, 87 | resume_checkpoint="", 88 | use_fp16=False, 89 | fp16_scale_growth=1e-3, 90 | ) 91 | defaults.update(sr_model_and_diffusion_defaults()) 92 | parser = argparse.ArgumentParser() 93 | add_dict_to_argparser(parser, defaults) 94 | return parser 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /P2_weighting/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="guided-diffusion", 5 | py_modules=["guided_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2024] One-Shot Structure-Aware Stylized Image Synthesis 2 | 3 
| [![arXiv](https://img.shields.io/badge/arXiv-2402.17275-red)](https://arxiv.org/abs/2402.17275) 4 | 5 | > **One-Shot Structure-Aware Stylized Image Synthesis**
6 | > Hansam Cho, Jonghyun Lee, Seunggyu Chang, Yonghyun Jeong
7 | > 8 | >**Abstract**:
9 | While GAN-based models have been successful in image stylization tasks, they often struggle with structure preservation while stylizing a wide range of input images. Recently, diffusion models have been adopted for image stylization but still lack the capability to maintain the original quality of input images. Building on this, we propose OSASIS: a novel one-shot stylization method that is robust in structure preservation. We show that OSASIS is able to effectively disentangle the semantics from the structure of an image, allowing it to control the level of content and style implemented to a given input. We apply OSASIS to various experimental settings, including stylization with out-of-domain reference images and stylization with text-driven manipulation. Results show that OSASIS outperforms other stylization methods, especially for input images that were rarely encountered during training, providing a promising solution to stylization via diffusion models. 10 | 11 | ## Description 12 | Official implementation of One-Shot Structure-Aware Stylized Image Synthesis 13 | 14 | ![image](imgs/teaser.jpg) 15 | 16 | ## Setup 17 | ``` 18 | conda env create -f environment.yaml 19 | conda activate osasis 20 | ``` 21 | 22 | ## Prepare Training 23 | 1. Download DDPM with [P2-weighting](https://github.com/jychoi118/P2-weighting) trained on FFHQ ([ffhq_p2.pt](https://drive.google.com/file/d/14ACthNkradJWBAL0th6z0UguzR6QjFKH/view?usp=drive_link)) and put model checkpoint in `P2_weighting/models` 24 | ``` 25 | OSASIS 26 | |--P2_weighting 27 | | |--models 28 | | | |--ffhq_p2.pt 29 | ``` 30 | 31 | 2. Generate style images in domain A (photorealistic domain) 32 | 33 | ``` 34 | DEVICE=0 35 | 36 | SAMPLE_FLAGS="--attention_resolutions 16 --class_cond False --class_cond False --diffusion_steps 1000 --dropout 0.0 \ 37 | --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 \ 38 | --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 50" 39 | 40 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 41 | python gen_style_domA.py ${SAMPLE_FLAGS} \ 42 | --model_path P2_weighting/models/ffhq_p2.pt \ 43 | --input_dir imgs_style_domB \ 44 | --sample_dir imgs_style_domA \ 45 | --img_name img1.png \ 46 | --n 1 \ 47 | --t_start_ratio 0.5 \ 48 | --seed 1 \ 49 | 50 | ``` 51 | `input_dir`: directory of style images in domain B (stylized domain)
52 | `sample_dir`: saving directory of style images in domain A (photorealistic domain)
53 | `img_name`: name of style image
54 | `n`: number of sampling images to generate style image in domain A
55 | `t_start_ratio`: noising level of the image ($t_0$) 56 | 57 | 58 | 59 | ## Training 60 | 1. Download [DiffAE](https://github.com/phizaz/diffae) trained on FFHQ ([ffhq256_autoenc](https://drive.google.com/drive/folders/1-5zfxT6Gl-GjxM7z9ZO2AHlB70tfmF6V), [ffhq256_autoenc_latent](https://drive.google.com/drive/folders/1-H8WzKc65dEONN-DQ87TnXc23nTXDTYb)) and put the model checkpoints in `diffae/checkpoints` 61 | ``` 62 | OSASIS 63 | |--diffae 64 | | |--checkpoints 65 | | | |--ffhq256_autoenc 66 | | | | |--last.ckpt 67 | | | | |--latent.pkl 68 | | | |--ffhq256_autoenc_latent 69 | | | | |--last.ckpt 70 | ``` 71 | 72 | 73 | 2. Train the model using the following script, which requires 34GB of VRAM for a batch size of 8. Training takes approximately 30 minutes on a single A100 GPU. 74 | 75 | ``` 76 | DEVICE=0 77 | 78 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 79 | python train_diffaeB.py \ 80 | --style_domA_dir imgs_style_domA \ 81 | --style_domB_dir imgs_style_domB \ 82 | --ref_img img1.png \ 83 | --work_dir exp/img1 \ 84 | --n_iter 200 \ 85 | --ckpt_freq 200 \ 86 | --batch_size 8 \ 87 | --map_net \ 88 | --map_time \ 89 | --lambda_map 0.1 \ 90 | --train 91 | ``` 92 | `style_domA_dir`: directory of style images in domain A (photorealistic domain)<br
93 | `style_domB_dir`: directory of style images in domain B (stylized domain)
94 | `ref_img`: name of style image
95 | `work_dir`: working directory
96 | `n_iter`: number of training iterations 97 | 98 | 99 | ## Testing 100 | Generate stylized images with the following script: 101 | ``` 102 | DEVICE=0 103 | 104 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 105 | python eval_diffaeB.py \ 106 | --style_domB_dir imgs_style_domB \ 107 | --infer_dir imgs_input_domA \ 108 | --ref_img img1.png \ 109 | --work_dir exp/img1 \ 110 | --map_net \ 111 | --map_time \ 112 | --lambda_map 0.1 113 | ``` 114 | `style_domB_dir`: directory of style images in domain B (stylized domain)<br
115 | `infer_dir`: directory of input images in domain A (photorealistic domain)
116 | `ref_img`: name of style image
117 | `work_dir`: working directory 118 | 119 | ## Using Pretrained Models 120 | Download pretrained weights in this [link](https://drive.google.com/drive/folders/1N0q9RBYIwc110njCsHX7uiY3BtwY5PqY?usp=sharing) and put checkpoint as shown in below 121 | ``` 122 | OSASIS 123 | |--exp 124 | | |--img1 125 | | | |--ckpt 126 | | | | |--iter_200.pt 127 | | |--img2 128 | | | |--ckpt 129 | | | | |--iter_200.pt 130 | ``` 131 | 132 | ## Acknowledgements 133 | This repository is built upon [P2-weighting](https://github.com/jychoi118/P2-weighting), [DiffAE](https://github.com/phizaz/diffae), and [MindTheGap](https://github.com/ZPdesu/MindTheGap) 134 | 135 | ## Citation 136 | ```bibtex 137 | @article{cho2024one, 138 | title={One-Shot Structure-Aware Stylized Image Synthesis}, 139 | author={Cho, Hansam and Lee, Jonghyun and Chang, Seunggyu and Jeong, Yonghyun}, 140 | journal={arXiv preprint arXiv:2402.17275}, 141 | year={2024} 142 | } 143 | ``` -------------------------------------------------------------------------------- /diffae/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | project_dir = os.getcwd() 5 | sys.path.append(os.path.join(project_dir, 'diffae')) -------------------------------------------------------------------------------- /diffae/align.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import os 3 | import os.path as osp 4 | import sys 5 | from multiprocessing import Pool 6 | 7 | import dlib 8 | import numpy as np 9 | import PIL.Image 10 | import requests 11 | import scipy.ndimage 12 | from tqdm import tqdm 13 | from argparse import ArgumentParser 14 | 15 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' 16 | 17 | 18 | def image_align(src_file, 19 | dst_file, 20 | face_landmarks, 21 | output_size=1024, 22 | transform_size=4096, 23 | enable_padding=True): 24 | # Align function from FFHQ dataset pre-processing step 25 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 26 | 27 | lm = np.array(face_landmarks) 28 | lm_chin = lm[0:17] # left-right 29 | lm_eyebrow_left = lm[17:22] # left-right 30 | lm_eyebrow_right = lm[22:27] # left-right 31 | lm_nose = lm[27:31] # top-down 32 | lm_nostrils = lm[31:36] # top-down 33 | lm_eye_left = lm[36:42] # left-clockwise 34 | lm_eye_right = lm[42:48] # left-clockwise 35 | lm_mouth_outer = lm[48:60] # left-clockwise 36 | lm_mouth_inner = lm[60:68] # left-clockwise 37 | 38 | # Calculate auxiliary vectors. 39 | eye_left = np.mean(lm_eye_left, axis=0) 40 | eye_right = np.mean(lm_eye_right, axis=0) 41 | eye_avg = (eye_left + eye_right) * 0.5 42 | eye_to_eye = eye_right - eye_left 43 | mouth_left = lm_mouth_outer[0] 44 | mouth_right = lm_mouth_outer[6] 45 | mouth_avg = (mouth_left + mouth_right) * 0.5 46 | eye_to_mouth = mouth_avg - eye_avg 47 | 48 | # Choose oriented crop rectangle. 49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 50 | x /= np.hypot(*x) 51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 52 | y = np.flipud(x) * [-1, 1] 53 | c = eye_avg + eye_to_mouth * 0.1 54 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 55 | qsize = np.hypot(*x) * 2 56 | 57 | # Load in-the-wild image. 58 | if not os.path.isfile(src_file): 59 | print( 60 | '\nCannot find source image. Please run "--wilds" before "--align".' 61 | ) 62 | return 63 | img = PIL.Image.open(src_file) 64 | img = img.convert('RGB') 65 | 66 | # Shrink. 
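    # qsize is the side length of the oriented face quad in source pixels; when the
    # detected face is much larger than output_size, the photo is pre-downscaled and
    # quad / qsize are rescaled so the later crop, pad and transform steps stay consistent.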
67 | shrink = int(np.floor(qsize / output_size * 0.5)) 68 | if shrink > 1: 69 | rsize = (int(np.rint(float(img.size[0]) / shrink)), 70 | int(np.rint(float(img.size[1]) / shrink))) 71 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 72 | quad /= shrink 73 | qsize /= shrink 74 | 75 | # Crop. 76 | border = max(int(np.rint(qsize * 0.1)), 3) 77 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 78 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 79 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), 80 | min(crop[2] + border, 81 | img.size[0]), min(crop[3] + border, img.size[1])) 82 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 83 | img = img.crop(crop) 84 | quad -= crop[0:2] 85 | 86 | # Pad. 87 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 88 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 89 | pad = (max(-pad[0] + border, 90 | 0), max(-pad[1] + border, 91 | 0), max(pad[2] - img.size[0] + border, 92 | 0), max(pad[3] - img.size[1] + border, 0)) 93 | if enable_padding and max(pad) > border - 4: 94 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 95 | img = np.pad(np.float32(img), 96 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 97 | h, w, _ = img.shape 98 | y, x, _ = np.ogrid[:h, :w, :1] 99 | mask = np.maximum( 100 | 1.0 - 101 | np.minimum(np.float32(x) / pad[0], 102 | np.float32(w - 1 - x) / pad[2]), 1.0 - 103 | np.minimum(np.float32(y) / pad[1], 104 | np.float32(h - 1 - y) / pad[3])) 105 | blur = qsize * 0.02 106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - 107 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 108 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 110 | 'RGB') 111 | quad += pad[:2] 112 | 113 | # Transform. 114 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, 115 | (quad + 0.5).flatten(), PIL.Image.BILINEAR) 116 | if output_size < transform_size: 117 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 118 | 119 | # Save aligned image. 
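    # Note: work_landmark() below calls image_align() with output_size=256, so in this
    # repo the aligned faces are written as 256x256 PNGs.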
120 | img.save(dst_file, 'PNG') 121 | 122 | 123 | class LandmarksDetector: 124 | def __init__(self, predictor_model_path): 125 | """ 126 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file 127 | """ 128 | self.detector = dlib.get_frontal_face_detector( 129 | ) # cnn_face_detection_model_v1 also can be used 130 | self.shape_predictor = dlib.shape_predictor(predictor_model_path) 131 | 132 | def get_landmarks(self, image): 133 | img = dlib.load_rgb_image(image) 134 | dets = self.detector(img, 1) 135 | 136 | for detection in dets: 137 | face_landmarks = [ 138 | (item.x, item.y) 139 | for item in self.shape_predictor(img, detection).parts() 140 | ] 141 | yield face_landmarks 142 | 143 | 144 | def unpack_bz2(src_path): 145 | dst_path = src_path[:-4] 146 | if os.path.exists(dst_path): 147 | print('cached') 148 | return dst_path 149 | data = bz2.BZ2File(src_path).read() 150 | with open(dst_path, 'wb') as fp: 151 | fp.write(data) 152 | return dst_path 153 | 154 | 155 | def work_landmark(raw_img_path, img_name, face_landmarks): 156 | face_img_name = '%s.png' % (os.path.splitext(img_name)[0], ) 157 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) 158 | if os.path.exists(aligned_face_path): 159 | return 160 | image_align(raw_img_path, 161 | aligned_face_path, 162 | face_landmarks, 163 | output_size=256) 164 | 165 | 166 | def get_file(src, tgt): 167 | if os.path.exists(tgt): 168 | print('cached') 169 | return tgt 170 | tgt_dir = os.path.dirname(tgt) 171 | if not os.path.exists(tgt_dir): 172 | os.makedirs(tgt_dir) 173 | file = requests.get(src) 174 | open(tgt, 'wb').write(file.content) 175 | return tgt 176 | 177 | 178 | if __name__ == "__main__": 179 | """ 180 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step 181 | python align_images.py /raw_images /aligned_images 182 | """ 183 | parser = ArgumentParser() 184 | parser.add_argument("-i", 185 | "--input_imgs_path", 186 | type=str, 187 | default="imgs", 188 | help="input images directory path") 189 | parser.add_argument("-o", 190 | "--output_imgs_path", 191 | type=str, 192 | default="imgs_align", 193 | help="output images directory path") 194 | 195 | args = parser.parse_args() 196 | 197 | # takes very long time ... 
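    # get_file() downloads the dlib 68-point shape predictor into temp/ and unpack_bz2()
    # decompresses it next to the archive; both steps print 'cached' and are skipped on
    # later runs if the files already exist.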
198 | landmarks_model_path = unpack_bz2( 199 | get_file( 200 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', 201 | 'temp/shape_predictor_68_face_landmarks.dat.bz2')) 202 | 203 | # RAW_IMAGES_DIR = sys.argv[1] 204 | # ALIGNED_IMAGES_DIR = sys.argv[2] 205 | RAW_IMAGES_DIR = args.input_imgs_path 206 | ALIGNED_IMAGES_DIR = args.output_imgs_path 207 | 208 | if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR) 209 | 210 | files = os.listdir(RAW_IMAGES_DIR) 211 | print(f'total img files {len(files)}') 212 | with tqdm(total=len(files)) as progress: 213 | 214 | def cb(*args): 215 | # print('update') 216 | progress.update() 217 | 218 | def err_cb(e): 219 | print('error:', e) 220 | 221 | with Pool(8) as pool: 222 | res = [] 223 | landmarks_detector = LandmarksDetector(landmarks_model_path) 224 | for img_name in files: 225 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name) 226 | # print('img_name:', img_name) 227 | for i, face_landmarks in enumerate( 228 | landmarks_detector.get_landmarks(raw_img_path), 229 | start=1): 230 | # assert i == 1, f'{i}' 231 | # print(i, face_landmarks) 232 | # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i) 233 | # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) 234 | # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256) 235 | 236 | work_landmark(raw_img_path, img_name, face_landmarks) 237 | progress.update() 238 | 239 | # job = pool.apply_async( 240 | # work_landmark, 241 | # (raw_img_path, img_name, face_landmarks), 242 | # callback=cb, 243 | # error_callback=err_cb, 244 | # ) 245 | # res.append(job) 246 | 247 | # pool.close() 248 | # pool.join() 249 | print(f"output aligned images at: {ALIGNED_IMAGES_DIR}") 250 | -------------------------------------------------------------------------------- /diffae/choices.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from torch import nn 3 | 4 | 5 | class TrainMode(Enum): 6 | # manipulate mode = training the classifier 7 | manipulate = 'manipulate' 8 | # default trainin mode! 9 | diffusion = 'diffusion' 10 | # default latent training mode! 11 | # fitting the a DDPM to a given latent 12 | latent_diffusion = 'latentdiffusion' 13 | 14 | def is_manipulate(self): 15 | return self in [ 16 | TrainMode.manipulate, 17 | ] 18 | 19 | def is_diffusion(self): 20 | return self in [ 21 | TrainMode.diffusion, 22 | TrainMode.latent_diffusion, 23 | ] 24 | 25 | def is_autoenc(self): 26 | # the network possibly does autoencoding 27 | return self in [ 28 | TrainMode.diffusion, 29 | ] 30 | 31 | def is_latent_diffusion(self): 32 | return self in [ 33 | TrainMode.latent_diffusion, 34 | ] 35 | 36 | def use_latent_net(self): 37 | return self.is_latent_diffusion() 38 | 39 | def require_dataset_infer(self): 40 | """ 41 | whether training in this mode requires the latent variables to be available? 
42 | """ 43 | # this will precalculate all the latents before hand 44 | # and the dataset will be all the predicted latents 45 | return self in [ 46 | TrainMode.latent_diffusion, 47 | TrainMode.manipulate, 48 | ] 49 | 50 | 51 | class ManipulateMode(Enum): 52 | """ 53 | how to train the classifier to manipulate 54 | """ 55 | # train on whole celeba attr dataset 56 | celebahq_all = 'celebahq_all' 57 | # celeba with D2C's crop 58 | d2c_fewshot = 'd2cfewshot' 59 | d2c_fewshot_allneg = 'd2cfewshotallneg' 60 | 61 | def is_celeba_attr(self): 62 | return self in [ 63 | ManipulateMode.d2c_fewshot, 64 | ManipulateMode.d2c_fewshot_allneg, 65 | ManipulateMode.celebahq_all, 66 | ] 67 | 68 | def is_single_class(self): 69 | return self in [ 70 | ManipulateMode.d2c_fewshot, 71 | ManipulateMode.d2c_fewshot_allneg, 72 | ] 73 | 74 | def is_fewshot(self): 75 | return self in [ 76 | ManipulateMode.d2c_fewshot, 77 | ManipulateMode.d2c_fewshot_allneg, 78 | ] 79 | 80 | def is_fewshot_allneg(self): 81 | return self in [ 82 | ManipulateMode.d2c_fewshot_allneg, 83 | ] 84 | 85 | 86 | class ModelType(Enum): 87 | """ 88 | Kinds of the backbone models 89 | """ 90 | 91 | # unconditional ddpm 92 | ddpm = 'ddpm' 93 | # autoencoding ddpm cannot do unconditional generation 94 | autoencoder = 'autoencoder' 95 | 96 | def has_autoenc(self): 97 | return self in [ 98 | ModelType.autoencoder, 99 | ] 100 | 101 | def can_sample(self): 102 | return self in [ModelType.ddpm] 103 | 104 | 105 | class ModelName(Enum): 106 | """ 107 | List of all supported model classes 108 | """ 109 | 110 | beatgans_ddpm = 'beatgans_ddpm' 111 | beatgans_autoenc = 'beatgans_autoenc' 112 | 113 | 114 | class ModelMeanType(Enum): 115 | """ 116 | Which type of output the model predicts. 117 | """ 118 | 119 | eps = 'eps' # the model predicts epsilon 120 | 121 | 122 | class ModelVarType(Enum): 123 | """ 124 | What is used as the model's output variance. 125 | 126 | The LEARNED_RANGE option has been added to allow the model to predict 127 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
128 | """ 129 | 130 | # posterior beta_t 131 | fixed_small = 'fixed_small' 132 | # beta_t 133 | fixed_large = 'fixed_large' 134 | 135 | 136 | class LossType(Enum): 137 | mse = 'mse' # use raw MSE loss (and KL when learning variances) 138 | l1 = 'l1' 139 | 140 | 141 | class GenerativeType(Enum): 142 | """ 143 | How's a sample generated 144 | """ 145 | 146 | ddpm = 'ddpm' 147 | ddim = 'ddim' 148 | 149 | 150 | class OptimizerType(Enum): 151 | adam = 'adam' 152 | adamw = 'adamw' 153 | 154 | 155 | class Activation(Enum): 156 | none = 'none' 157 | relu = 'relu' 158 | lrelu = 'lrelu' 159 | silu = 'silu' 160 | tanh = 'tanh' 161 | 162 | def get_act(self): 163 | if self == Activation.none: 164 | return nn.Identity() 165 | elif self == Activation.relu: 166 | return nn.ReLU() 167 | elif self == Activation.lrelu: 168 | return nn.LeakyReLU(negative_slope=0.2) 169 | elif self == Activation.silu: 170 | return nn.SiLU() 171 | elif self == Activation.tanh: 172 | return nn.Tanh() 173 | else: 174 | raise NotImplementedError() 175 | 176 | 177 | class ManipulateLossType(Enum): 178 | bce = 'bce' 179 | mse = 'mse' -------------------------------------------------------------------------------- /diffae/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "10.2" 3 | gpu: true 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "numpy==1.21.5" 10 | - "cmake==3.23.3" 11 | - "ipython==7.21.0" 12 | - "opencv-python==4.5.4.58" 13 | - "pandas==1.1.5" 14 | - "lmdb==1.2.1" 15 | - "lpips==0.1.4" 16 | - "pytorch-fid==0.2.0" 17 | - "ftfy==6.1.1" 18 | - "scipy==1.5.4" 19 | - "torch==1.9.1" 20 | - "torchvision==0.10.1" 21 | - "tqdm==4.62.3" 22 | - "regex==2022.7.25" 23 | - "Pillow==9.2.0" 24 | - "pytorch_lightning==1.7.0" 25 | 26 | run: 27 | - pip install dlib 28 | 29 | predict: "predict.py:Predictor" 30 | -------------------------------------------------------------------------------- /diffae/config_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from copy import deepcopy 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class BaseConfig: 9 | def clone(self): 10 | return deepcopy(self) 11 | 12 | def inherit(self, another): 13 | """inherit common keys from a given config""" 14 | common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) 15 | for k in common_keys: 16 | setattr(self, k, getattr(another, k)) 17 | 18 | def propagate(self): 19 | """push down the configuration to all members""" 20 | for k, v in self.__dict__.items(): 21 | if isinstance(v, BaseConfig): 22 | v.inherit(self) 23 | v.propagate() 24 | 25 | def save(self, save_path): 26 | """save config to json file""" 27 | dirname = os.path.dirname(save_path) 28 | if not os.path.exists(dirname): 29 | os.makedirs(dirname) 30 | conf = self.as_dict_jsonable() 31 | with open(save_path, 'w') as f: 32 | json.dump(conf, f) 33 | 34 | def load(self, load_path): 35 | """load json config""" 36 | with open(load_path) as f: 37 | conf = json.load(f) 38 | self.from_dict(conf) 39 | 40 | def from_dict(self, dict, strict=False): 41 | for k, v in dict.items(): 42 | if not hasattr(self, k): 43 | if strict: 44 | raise ValueError(f"loading extra '{k}'") 45 | else: 46 | print(f"loading extra '{k}'") 47 | continue 48 | if isinstance(self.__dict__[k], BaseConfig): 49 | self.__dict__[k].from_dict(v) 50 | else: 51 | self.__dict__[k] = v 52 | 53 | def as_dict_jsonable(self): 
54 | conf = {} 55 | for k, v in self.__dict__.items(): 56 | if isinstance(v, BaseConfig): 57 | conf[k] = v.as_dict_jsonable() 58 | else: 59 | if jsonable(v): 60 | conf[k] = v 61 | else: 62 | # ignore not jsonable 63 | pass 64 | return conf 65 | 66 | 67 | def jsonable(x): 68 | try: 69 | json.dumps(x) 70 | return True 71 | except TypeError: 72 | return False 73 | -------------------------------------------------------------------------------- /diffae/dataset_util.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from dist_utils import * 4 | 5 | 6 | def use_cached_dataset_path(source_path, cache_path): 7 | if get_rank() == 0: 8 | if not os.path.exists(cache_path): 9 | # shutil.rmtree(cache_path) 10 | print(f'copying the data: {source_path} to {cache_path}') 11 | shutil.copytree(source_path, cache_path) 12 | barrier() 13 | return cache_path -------------------------------------------------------------------------------- /diffae/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig 4 | 5 | Sampler = Union[SpacedDiffusionBeatGans] 6 | SamplerConfig = Union[SpacedDiffusionBeatGansConfig] 7 | -------------------------------------------------------------------------------- /diffae/diffusion/diffusion.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from dataclasses import dataclass 3 | 4 | 5 | def space_timesteps(num_timesteps, section_counts): 6 | """ 7 | Create a list of timesteps to use from an original diffusion process, 8 | given the number of timesteps we want to take from equally-sized portions 9 | of the original process. 10 | 11 | For example, if there's 300 timesteps and the section counts are [10,15,20] 12 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 13 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 14 | 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | 18 | :param num_timesteps: the number of diffusion steps in the original 19 | process to divide up. 20 | :param section_counts: either a list of numbers, or a string containing 21 | comma-separated numbers, indicating the step count 22 | per section. As a special case, use "ddimN" where N 23 | is a number of steps to use the striding from the 24 | DDIM paper. 25 | :return: a set of diffusion steps from the original process to use. 
26 | """ 27 | if isinstance(section_counts, str): 28 | if section_counts.startswith("ddim"): 29 | desired_count = int(section_counts[len("ddim"):]) 30 | for i in range(1, num_timesteps): 31 | if len(range(0, num_timesteps, i)) == desired_count: 32 | return set(range(0, num_timesteps, i)) 33 | raise ValueError( 34 | f"cannot create exactly {num_timesteps} steps with an integer stride" 35 | ) 36 | section_counts = [int(x) for x in section_counts.split(",")] 37 | size_per = num_timesteps // len(section_counts) 38 | extra = num_timesteps % len(section_counts) 39 | start_idx = 0 40 | all_steps = [] 41 | for i, section_count in enumerate(section_counts): 42 | size = size_per + (1 if i < extra else 0) 43 | if size < section_count: 44 | raise ValueError( 45 | f"cannot divide section of {size} steps into {section_count}") 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | @dataclass 61 | class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): 62 | use_timesteps: Tuple[int] = None 63 | 64 | def make_sampler(self): 65 | return SpacedDiffusionBeatGans(self) 66 | 67 | 68 | class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans): 69 | """ 70 | A diffusion process which can skip steps in a base diffusion process. 71 | 72 | :param use_timesteps: a collection (sequence or set) of timesteps from the 73 | original diffusion process to retain. 74 | :param kwargs: the kwargs to create the base diffusion process. 75 | """ 76 | def __init__(self, conf: SpacedDiffusionBeatGansConfig): 77 | self.conf = conf 78 | self.use_timesteps = set(conf.use_timesteps) 79 | # how the new t's mapped to the old t's 80 | self.timestep_map = [] 81 | self.original_num_steps = len(conf.betas) 82 | 83 | base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 87 | if i in self.use_timesteps: 88 | # getting the new betas of the new timesteps 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | self.timestep_map.append(i) 92 | conf.betas = np.array(new_betas) 93 | super().__init__(conf) 94 | 95 | def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs 96 | return super().p_mean_variance(self._wrap_model(model), *args, 97 | **kwargs) 98 | 99 | def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs 100 | return super().training_losses(self._wrap_model(model), *args, 101 | **kwargs) 102 | 103 | def condition_mean(self, cond_fn, *args, **kwargs): 104 | return super().condition_mean(self._wrap_model(cond_fn), *args, 105 | **kwargs) 106 | 107 | def condition_score(self, cond_fn, *args, **kwargs): 108 | return super().condition_score(self._wrap_model(cond_fn), *args, 109 | **kwargs) 110 | 111 | def _wrap_model(self, model: Model): 112 | if isinstance(model, _WrappedModel): 113 | return model 114 | return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, 115 | self.original_num_steps) 116 | 117 | def _scale_timesteps(self, t): 118 | # Scaling is done by the wrapped model. 
119 | return t 120 | 121 | 122 | class _WrappedModel: 123 | """ 124 | converting the supplied t's to the old t's scales. 125 | """ 126 | def __init__(self, model, timestep_map, rescale_timesteps, 127 | original_num_steps): 128 | self.model = model 129 | self.timestep_map = timestep_map 130 | self.rescale_timesteps = rescale_timesteps 131 | self.original_num_steps = original_num_steps 132 | 133 | def forward(self, x, t, t_cond=None, **kwargs): 134 | """ 135 | Args: 136 | t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's 137 | t_cond: the same as t but can be of different values 138 | """ 139 | map_tensor = th.tensor(self.timestep_map, 140 | device=t.device, 141 | dtype=t.dtype) 142 | 143 | def do(t): 144 | new_ts = map_tensor[t] 145 | if self.rescale_timesteps: 146 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 147 | return new_ts 148 | 149 | if t_cond is not None: 150 | # support t_cond 151 | t_cond = do(t_cond) 152 | 153 | return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) 154 | 155 | def __getattr__(self, name): 156 | # allow for calling the model's methods 157 | if hasattr(self.model, name): 158 | func = getattr(self.model, name) 159 | return func 160 | raise AttributeError(name) 161 | -------------------------------------------------------------------------------- /diffae/diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | else: 18 | raise NotImplementedError(f"unknown schedule sampler: {name}") 19 | 20 | 21 | class ScheduleSampler(ABC): 22 | """ 23 | A distribution over timesteps in the diffusion process, intended to reduce 24 | variance of the objective. 25 | 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 30 | """ 31 | @abstractmethod 32 | def weights(self): 33 | """ 34 | Get a numpy array of weights, one per diffusion step. 35 | 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | 43 | :param batch_size: the number of timesteps. 44 | :param device: the torch device to save to. 45 | :return: a tuple (timesteps, weights): 46 | - timesteps: a tensor of timestep indices. 47 | - weights: a tensor of weights to scale the resulting losses. 
48 | """ 49 | w = self.weights() 50 | p = w / np.sum(w) 51 | indices_np = np.random.choice(len(p), size=(batch_size, ), p=p) 52 | indices = th.from_numpy(indices_np).long().to(device) 53 | weights_np = 1 / (len(p) * p[indices_np]) 54 | weights = th.from_numpy(weights_np).float().to(device) 55 | return indices, weights 56 | 57 | 58 | class UniformSampler(ScheduleSampler): 59 | def __init__(self, num_timesteps): 60 | self._weights = np.ones([num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | -------------------------------------------------------------------------------- /diffae/dist_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from torch import distributed 3 | 4 | 5 | def barrier(): 6 | if distributed.is_initialized(): 7 | distributed.barrier() 8 | else: 9 | pass 10 | 11 | 12 | def broadcast(data, src): 13 | if distributed.is_initialized(): 14 | distributed.broadcast(data, src) 15 | else: 16 | pass 17 | 18 | 19 | def all_gather(data: List, src): 20 | if distributed.is_initialized(): 21 | distributed.all_gather(data, src) 22 | else: 23 | data[0] = src 24 | 25 | 26 | def get_rank(): 27 | if distributed.is_initialized(): 28 | return distributed.get_rank() 29 | else: 30 | return 0 31 | 32 | 33 | def get_world_size(): 34 | if distributed.is_initialized(): 35 | return distributed.get_world_size() 36 | else: 37 | return 1 38 | 39 | 40 | def chunk_size(size, rank, world_size): 41 | extra = rank < size % world_size 42 | return size // world_size + extra -------------------------------------------------------------------------------- /diffae/evals/church256_autoenc.txt: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /diffae/lmdb_writer.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | 6 | import torch 7 | 8 | from contextlib import contextmanager 9 | from torch.utils.data import Dataset 10 | from multiprocessing import Process, Queue 11 | import os 12 | import shutil 13 | 14 | 15 | def convert(x, format, quality=100): 16 | # to prevent locking! 
17 | torch.set_num_threads(1) 18 | 19 | buffer = BytesIO() 20 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0) 21 | x = x.to(torch.uint8) 22 | x = x.numpy() 23 | img = Image.fromarray(x) 24 | img.save(buffer, format=format, quality=quality) 25 | val = buffer.getvalue() 26 | return val 27 | 28 | 29 | @contextmanager 30 | def nullcontext(): 31 | yield 32 | 33 | 34 | class _WriterWroker(Process): 35 | def __init__(self, path, format, quality, zfill, q): 36 | super().__init__() 37 | if os.path.exists(path): 38 | shutil.rmtree(path) 39 | 40 | self.path = path 41 | self.format = format 42 | self.quality = quality 43 | self.zfill = zfill 44 | self.q = q 45 | self.i = 0 46 | 47 | def run(self): 48 | if not os.path.exists(self.path): 49 | os.makedirs(self.path) 50 | 51 | with lmdb.open(self.path, map_size=1024**4, readahead=False) as env: 52 | while True: 53 | job = self.q.get() 54 | if job is None: 55 | break 56 | with env.begin(write=True) as txn: 57 | for x in job: 58 | key = f"{str(self.i).zfill(self.zfill)}".encode( 59 | "utf-8") 60 | x = convert(x, self.format, self.quality) 61 | txn.put(key, x) 62 | self.i += 1 63 | 64 | with env.begin(write=True) as txn: 65 | txn.put("length".encode("utf-8"), str(self.i).encode("utf-8")) 66 | 67 | 68 | class LMDBImageWriter: 69 | def __init__(self, path, format='webp', quality=100, zfill=7) -> None: 70 | self.path = path 71 | self.format = format 72 | self.quality = quality 73 | self.zfill = zfill 74 | self.queue = None 75 | self.worker = None 76 | 77 | def __enter__(self): 78 | self.queue = Queue(maxsize=3) 79 | self.worker = _WriterWroker(self.path, self.format, self.quality, 80 | self.zfill, self.queue) 81 | self.worker.start() 82 | 83 | def put_images(self, tensor): 84 | """ 85 | Args: 86 | tensor: (n, c, h, w) [0-1] tensor 87 | """ 88 | self.queue.put(tensor.cpu()) 89 | # with self.env.begin(write=True) as txn: 90 | # for x in tensor: 91 | # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8") 92 | # x = convert(x, self.format, self.quality) 93 | # txn.put(key, x) 94 | # self.i += 1 95 | 96 | def __exit__(self, *args, **kwargs): 97 | self.queue.put(None) 98 | self.queue.close() 99 | self.worker.join() 100 | 101 | 102 | class LMDBImageReader(Dataset): 103 | def __init__(self, path, zfill: int = 7): 104 | self.zfill = zfill 105 | self.env = lmdb.open( 106 | path, 107 | max_readers=32, 108 | readonly=True, 109 | lock=False, 110 | readahead=False, 111 | meminit=False, 112 | ) 113 | 114 | if not self.env: 115 | raise IOError('Cannot open lmdb dataset', path) 116 | 117 | with self.env.begin(write=False) as txn: 118 | self.length = int( 119 | txn.get('length'.encode('utf-8')).decode('utf-8')) 120 | 121 | def __len__(self): 122 | return self.length 123 | 124 | def __getitem__(self, index): 125 | with self.env.begin(write=False) as txn: 126 | key = f'{str(index).zfill(self.zfill)}'.encode('utf-8') 127 | img_bytes = txn.get(key) 128 | 129 | buffer = BytesIO(img_bytes) 130 | img = Image.open(buffer) 131 | return img 132 | -------------------------------------------------------------------------------- /diffae/model/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from .unet import BeatGANsUNetModel, BeatGANsUNetConfig 3 | from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel 4 | 5 | Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel] 6 | ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig] 7 | 
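A minimal usage sketch for the `LMDBImageWriter` / `LMDBImageReader` pair from `lmdb_writer.py` above. The writer runs a background process, so batches are queued with `put_images` and the database (including its `length` key) is finalized when the context exits; note that `__enter__` does not return the writer, so keep a reference and enter it with `with writer:`. The path and tensors below are placeholders:

```python
import torch
from lmdb_writer import LMDBImageWriter, LMDBImageReader

# placeholder output path; any (n, c, h, w) float tensor in [0, 1] can be written
writer = LMDBImageWriter('datasets/example_lmdb', format='webp', quality=100)
with writer:                                 # starts the background writer process
    for _ in range(4):
        batch = torch.rand(8, 3, 256, 256)   # stand-in for real image batches
        writer.put_images(batch)             # queued; encoded and committed by the worker
# on exit the worker stores the 'length' key and the LMDB is complete

dataset = LMDBImageReader('datasets/example_lmdb')
print(len(dataset))                          # 32 images written above
img0 = dataset[0]                            # PIL.Image decoded from the stored webp bytes
```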
-------------------------------------------------------------------------------- /diffae/model/latentnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import NamedTuple, Tuple 5 | 6 | import torch 7 | from choices import * 8 | from config_base import BaseConfig 9 | from torch import nn 10 | from torch.nn import init 11 | 12 | from .blocks import * 13 | from .nn import timestep_embedding 14 | from .unet import * 15 | 16 | 17 | class LatentNetType(Enum): 18 | none = 'none' 19 | # injecting inputs into the hidden layers 20 | skip = 'skip' 21 | 22 | 23 | class LatentNetReturn(NamedTuple): 24 | pred: torch.Tensor = None 25 | 26 | 27 | @dataclass 28 | class MLPSkipNetConfig(BaseConfig): 29 | """ 30 | default MLP for the latent DPM in the paper! 31 | """ 32 | num_channels: int 33 | skip_layers: Tuple[int] 34 | num_hid_channels: int 35 | num_layers: int 36 | num_time_emb_channels: int = 64 37 | activation: Activation = Activation.silu 38 | use_norm: bool = True 39 | condition_bias: float = 1 40 | dropout: float = 0 41 | last_act: Activation = Activation.none 42 | num_time_layers: int = 2 43 | time_last_act: bool = False 44 | 45 | def make_model(self): 46 | return MLPSkipNet(self) 47 | 48 | 49 | class MLPSkipNet(nn.Module): 50 | """ 51 | concat x to hidden layers 52 | 53 | default MLP for the latent DPM in the paper! 54 | """ 55 | def __init__(self, conf: MLPSkipNetConfig): 56 | super().__init__() 57 | self.conf = conf 58 | 59 | layers = [] 60 | for i in range(conf.num_time_layers): 61 | if i == 0: 62 | a = conf.num_time_emb_channels 63 | b = conf.num_channels 64 | else: 65 | a = conf.num_channels 66 | b = conf.num_channels 67 | layers.append(nn.Linear(a, b)) 68 | if i < conf.num_time_layers - 1 or conf.time_last_act: 69 | layers.append(conf.activation.get_act()) 70 | self.time_embed = nn.Sequential(*layers) 71 | 72 | self.layers = nn.ModuleList([]) 73 | for i in range(conf.num_layers): 74 | if i == 0: 75 | act = conf.activation 76 | norm = conf.use_norm 77 | cond = True 78 | a, b = conf.num_channels, conf.num_hid_channels 79 | dropout = conf.dropout 80 | elif i == conf.num_layers - 1: 81 | act = Activation.none 82 | norm = False 83 | cond = False 84 | a, b = conf.num_hid_channels, conf.num_channels 85 | dropout = 0 86 | else: 87 | act = conf.activation 88 | norm = conf.use_norm 89 | cond = True 90 | a, b = conf.num_hid_channels, conf.num_hid_channels 91 | dropout = conf.dropout 92 | 93 | if i in conf.skip_layers: 94 | a += conf.num_channels 95 | 96 | self.layers.append( 97 | MLPLNAct( 98 | a, 99 | b, 100 | norm=norm, 101 | activation=act, 102 | cond_channels=conf.num_channels, 103 | use_cond=cond, 104 | condition_bias=conf.condition_bias, 105 | dropout=dropout, 106 | )) 107 | self.last_act = conf.last_act.get_act() 108 | 109 | def forward(self, x, t, **kwargs): 110 | t = timestep_embedding(t, self.conf.num_time_emb_channels) 111 | cond = self.time_embed(t) 112 | h = x 113 | for i in range(len(self.layers)): 114 | if i in self.conf.skip_layers: 115 | # injecting input into the hidden layers 116 | h = torch.cat([h, x], dim=1) 117 | h = self.layers[i].forward(x=h, cond=cond) 118 | h = self.last_act(h) 119 | return LatentNetReturn(h) 120 | 121 | 122 | class MLPLNAct(nn.Module): 123 | def __init__( 124 | self, 125 | in_channels: int, 126 | out_channels: int, 127 | norm: bool, 128 | use_cond: bool, 129 | activation: Activation, 130 | cond_channels: int, 131 | 
condition_bias: float = 0, 132 | dropout: float = 0, 133 | ): 134 | super().__init__() 135 | self.activation = activation 136 | self.condition_bias = condition_bias 137 | self.use_cond = use_cond 138 | 139 | self.linear = nn.Linear(in_channels, out_channels) 140 | self.act = activation.get_act() 141 | if self.use_cond: 142 | self.linear_emb = nn.Linear(cond_channels, out_channels) 143 | self.cond_layers = nn.Sequential(self.act, self.linear_emb) 144 | if norm: 145 | self.norm = nn.LayerNorm(out_channels) 146 | else: 147 | self.norm = nn.Identity() 148 | 149 | if dropout > 0: 150 | self.dropout = nn.Dropout(p=dropout) 151 | else: 152 | self.dropout = nn.Identity() 153 | 154 | self.init_weights() 155 | 156 | def init_weights(self): 157 | for module in self.modules(): 158 | if isinstance(module, nn.Linear): 159 | if self.activation == Activation.relu: 160 | init.kaiming_normal_(module.weight, 161 | a=0, 162 | nonlinearity='relu') 163 | elif self.activation == Activation.lrelu: 164 | init.kaiming_normal_(module.weight, 165 | a=0.2, 166 | nonlinearity='leaky_relu') 167 | elif self.activation == Activation.silu: 168 | init.kaiming_normal_(module.weight, 169 | a=0, 170 | nonlinearity='relu') 171 | else: 172 | # leave it as default 173 | pass 174 | 175 | def forward(self, x, cond=None): 176 | x = self.linear(x) 177 | if self.use_cond: 178 | # (n, c) or (n, c * 2) 179 | cond = self.cond_layers(cond) 180 | cond = (cond, None) 181 | 182 | # scale shift first 183 | x = x * (self.condition_bias + cond[0]) 184 | if cond[1] is not None: 185 | x = x + cond[1] 186 | # then norm 187 | x = self.norm(x) 188 | else: 189 | # no condition 190 | x = self.norm(x) 191 | x = self.act(x) 192 | x = self.dropout(x) 193 | return x -------------------------------------------------------------------------------- /diffae/model/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | from enum import Enum 6 | import math 7 | from typing import Optional 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | import torch.nn.functional as F 14 | 15 | 16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 17 | class SiLU(nn.Module): 18 | # @th.jit.script 19 | def forward(self, x): 20 | return x * th.sigmoid(x) 21 | 22 | 23 | class GroupNorm32(nn.GroupNorm): 24 | def forward(self, x): 25 | return super().forward(x.float()).type(x.dtype) 26 | 27 | 28 | def conv_nd(dims, *args, **kwargs): 29 | """ 30 | Create a 1D, 2D, or 3D convolution module. 31 | """ 32 | if dims == 1: 33 | return nn.Conv1d(*args, **kwargs) 34 | elif dims == 2: 35 | return nn.Conv2d(*args, **kwargs) 36 | elif dims == 3: 37 | return nn.Conv3d(*args, **kwargs) 38 | raise ValueError(f"unsupported dimensions: {dims}") 39 | 40 | 41 | def linear(*args, **kwargs): 42 | """ 43 | Create a linear module. 44 | """ 45 | return nn.Linear(*args, **kwargs) 46 | 47 | 48 | def avg_pool_nd(dims, *args, **kwargs): 49 | """ 50 | Create a 1D, 2D, or 3D average pooling module. 51 | """ 52 | if dims == 1: 53 | return nn.AvgPool1d(*args, **kwargs) 54 | elif dims == 2: 55 | return nn.AvgPool2d(*args, **kwargs) 56 | elif dims == 3: 57 | return nn.AvgPool3d(*args, **kwargs) 58 | raise ValueError(f"unsupported dimensions: {dims}") 59 | 60 | 61 | def update_ema(target_params, source_params, rate=0.99): 62 | """ 63 | Update target parameters to be closer to those of source parameters using 64 | an exponential moving average. 
65 | 66 | :param target_params: the target parameter sequence. 67 | :param source_params: the source parameter sequence. 68 | :param rate: the EMA rate (closer to 1 means slower). 69 | """ 70 | for targ, src in zip(target_params, source_params): 71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 72 | 73 | 74 | def zero_module(module): 75 | """ 76 | Zero out the parameters of a module and return it. 77 | """ 78 | for p in module.parameters(): 79 | p.detach().zero_() 80 | return module 81 | 82 | 83 | def scale_module(module, scale): 84 | """ 85 | Scale the parameters of a module and return it. 86 | """ 87 | for p in module.parameters(): 88 | p.detach().mul_(scale) 89 | return module 90 | 91 | 92 | def mean_flat(tensor): 93 | """ 94 | Take the mean over all non-batch dimensions. 95 | """ 96 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 97 | 98 | 99 | def normalization(channels): 100 | """ 101 | Make a standard normalization layer. 102 | 103 | :param channels: number of input channels. 104 | :return: an nn.Module for normalization. 105 | """ 106 | return GroupNorm32(min(32, channels), channels) 107 | 108 | 109 | def timestep_embedding(timesteps, dim, max_period=10000): 110 | """ 111 | Create sinusoidal timestep embeddings. 112 | 113 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 114 | These may be fractional. 115 | :param dim: the dimension of the output. 116 | :param max_period: controls the minimum frequency of the embeddings. 117 | :return: an [N x dim] Tensor of positional embeddings. 118 | """ 119 | half = dim // 2 120 | freqs = th.exp(-math.log(max_period) * 121 | th.arange(start=0, end=half, dtype=th.float32) / 122 | half).to(device=timesteps.device) 123 | args = timesteps[:, None].float() * freqs[None] 124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 125 | if dim % 2: 126 | embedding = th.cat( 127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1) 128 | return embedding 129 | 130 | 131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False): 132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 133 | if flag: 134 | return torch.utils.checkpoint.checkpoint( 135 | func, *args, preserve_rng_state=preserve_rng_state) 136 | else: 137 | return func(*args) 138 | -------------------------------------------------------------------------------- /diffae/model/unet_autoenc.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.functional import silu 6 | 7 | from .latentnet import * 8 | from .unet import * 9 | from choices import * 10 | 11 | 12 | @dataclass 13 | class BeatGANsAutoencConfig(BeatGANsUNetConfig): 14 | # number of style channels 15 | enc_out_channels: int = 512 16 | enc_attn_resolutions: Tuple[int] = None 17 | enc_pool: str = 'depthconv' 18 | enc_num_res_block: int = 2 19 | enc_channel_mult: Tuple[int] = None 20 | enc_grad_checkpoint: bool = False 21 | latent_net_conf: MLPSkipNetConfig = None 22 | 23 | def make_model(self): 24 | return BeatGANsAutoencModel(self) 25 | 26 | 27 | class BeatGANsAutoencModel(BeatGANsUNetModel): 28 | def __init__(self, conf: BeatGANsAutoencConfig): 29 | super().__init__(conf) 30 | self.conf = conf 31 | 32 | # having only time, cond 33 | self.time_embed = TimeStyleSeperateEmbed( 34 | time_channels=conf.model_channels, 35 | time_out_channels=conf.embed_channels, 36 | ) 37 | 38 | self.encoder = BeatGANsEncoderConfig( 39 | 
image_size=conf.image_size, 40 | in_channels=conf.in_channels, 41 | model_channels=conf.model_channels, 42 | out_hid_channels=conf.enc_out_channels, 43 | out_channels=conf.enc_out_channels, 44 | num_res_blocks=conf.enc_num_res_block, 45 | attention_resolutions=(conf.enc_attn_resolutions 46 | or conf.attention_resolutions), 47 | dropout=conf.dropout, 48 | channel_mult=conf.enc_channel_mult or conf.channel_mult, 49 | use_time_condition=False, 50 | conv_resample=conf.conv_resample, 51 | dims=conf.dims, 52 | use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint, 53 | num_heads=conf.num_heads, 54 | num_head_channels=conf.num_head_channels, 55 | resblock_updown=conf.resblock_updown, 56 | use_new_attention_order=conf.use_new_attention_order, 57 | pool=conf.enc_pool, 58 | ).make_model() 59 | 60 | if conf.latent_net_conf is not None: 61 | self.latent_net = conf.latent_net_conf.make_model() 62 | 63 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 64 | """ 65 | Reparameterization trick to sample from N(mu, var) from 66 | N(0,1). 67 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 68 | :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] 69 | :return: (Tensor) [B x D] 70 | """ 71 | assert self.conf.is_stochastic 72 | std = torch.exp(0.5 * logvar) 73 | eps = torch.randn_like(std) 74 | return eps * std + mu 75 | 76 | def sample_z(self, n: int, device): 77 | assert self.conf.is_stochastic 78 | return torch.randn(n, self.conf.enc_out_channels, device=device) 79 | 80 | def noise_to_cond(self, noise: Tensor): 81 | raise NotImplementedError() 82 | assert self.conf.noise_net_conf is not None 83 | return self.noise_net.forward(noise) 84 | 85 | def encode(self, x): 86 | cond = self.encoder.forward(x) 87 | return cond 88 | # return {'cond': cond} 89 | 90 | @property 91 | def stylespace_sizes(self): 92 | modules = list(self.input_blocks.modules()) + list( 93 | self.middle_block.modules()) + list(self.output_blocks.modules()) 94 | sizes = [] 95 | for module in modules: 96 | if isinstance(module, ResBlock): 97 | linear = module.cond_emb_layers[-1] 98 | sizes.append(linear.weight.shape[0]) 99 | return sizes 100 | 101 | def encode_stylespace(self, x, return_vector: bool = True): 102 | """ 103 | encode to style space 104 | """ 105 | modules = list(self.input_blocks.modules()) + list( 106 | self.middle_block.modules()) + list(self.output_blocks.modules()) 107 | # (n, c) 108 | cond = self.encoder.forward(x) 109 | S = [] 110 | for module in modules: 111 | if isinstance(module, ResBlock): 112 | # (n, c') 113 | s = module.cond_emb_layers.forward(cond) 114 | S.append(s) 115 | 116 | if return_vector: 117 | # (n, sum_c) 118 | return torch.cat(S, dim=1) 119 | else: 120 | return S 121 | 122 | def forward(self, 123 | x, 124 | t, 125 | y=None, 126 | x_start=None, 127 | cond=None, 128 | style=None, 129 | noise=None, 130 | t_cond=None, 131 | mix=False, 132 | ref_cond_scale=None, 133 | **kwargs): 134 | """ 135 | Apply the model to an input batch. 
136 | 137 | Args: 138 | x_start: the original image to encode 139 | cond: output of the encoder 140 | noise: random noise (to predict the cond) 141 | """ 142 | 143 | if mix: 144 | assert type(cond) is dict 145 | 146 | if t_cond is None: 147 | t_cond = t 148 | 149 | if noise is not None: 150 | # if the noise is given, we predict the cond from noise 151 | cond = self.noise_to_cond(noise) 152 | 153 | if cond is None: 154 | if x is not None: 155 | assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}' 156 | 157 | tmp = self.encode(x_start) 158 | cond = tmp['cond'] 159 | 160 | if t is not None: 161 | _t_emb = timestep_embedding(t, self.conf.model_channels) 162 | _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels) 163 | else: 164 | # this happens when training only autoenc 165 | _t_emb = None 166 | _t_cond_emb = None 167 | 168 | if self.conf.resnet_two_cond: 169 | if not mix: 170 | res = self.time_embed.forward( 171 | time_emb=_t_emb, 172 | cond=cond, 173 | time_cond_emb=_t_cond_emb, 174 | ) 175 | else: 176 | res = self.time_embed.forward( 177 | time_emb=_t_emb, 178 | cond=None, 179 | time_cond_emb=_t_cond_emb, 180 | ) 181 | else: 182 | raise NotImplementedError() 183 | 184 | if self.conf.resnet_two_cond: 185 | # two cond: first = time emb, second = cond_emb 186 | emb = res.time_emb 187 | cond_emb = res.emb 188 | else: 189 | # one cond = combined of both time and cond 190 | emb = res.emb 191 | cond_emb = None 192 | 193 | # override the style if given 194 | style = style or res.style 195 | 196 | assert (y is not None) == ( 197 | self.conf.num_classes is not None 198 | ), "must specify y if and only if the model is class-conditional" 199 | 200 | if self.conf.num_classes is not None: 201 | raise NotImplementedError() 202 | # assert y.shape == (x.shape[0], ) 203 | # emb = emb + self.label_emb(y) 204 | 205 | # where in the model to supply time conditions 206 | enc_time_emb = emb 207 | mid_time_emb = emb 208 | dec_time_emb = emb 209 | # where in the model to supply style conditions 210 | enc_cond_emb = cond_emb 211 | mid_cond_emb = cond_emb 212 | dec_cond_emb = cond_emb 213 | 214 | # hs = [] 215 | hs = [[] for _ in range(len(self.conf.channel_mult))] 216 | 217 | if x is not None: 218 | h = x.type(self.dtype) 219 | 220 | # input blocks 221 | k = 0 222 | for i in range(len(self.input_num_blocks)): 223 | for j in range(self.input_num_blocks[i]): 224 | if not mix: 225 | h = self.input_blocks[k](h, 226 | emb=enc_time_emb, 227 | cond=enc_cond_emb) 228 | else: 229 | if i in ref_cond_scale: 230 | h = self.input_blocks[k](h, 231 | emb=enc_time_emb, 232 | cond=cond['ref']) 233 | else: 234 | h = self.input_blocks[k](h, 235 | emb=enc_time_emb, 236 | cond=cond['input']) 237 | 238 | # print(i, j, h.shape) 239 | hs[i].append(h) 240 | k += 1 241 | assert k == len(self.input_blocks) 242 | 243 | # middle blocks 244 | if not mix: 245 | h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb) 246 | else: 247 | h = self.middle_block(h, emb=mid_time_emb, cond=cond['input']) 248 | else: 249 | # no lateral connections 250 | # happens when training only the autonecoder 251 | h = None 252 | hs = [[] for _ in range(len(self.conf.channel_mult))] 253 | 254 | # output blocks 255 | k = 0 256 | n = len(self.output_num_blocks) 257 | if mix: 258 | ref_cond_scale_out = [n-(scale+1) for scale in ref_cond_scale] 259 | for i in range(len(self.output_num_blocks)): 260 | for j in range(self.output_num_blocks[i]): 261 | # take the lateral connection from the same layer (in reserve) 262 | # until there is no more, 
use None 263 | try: 264 | lateral = hs[-i - 1].pop() 265 | # print(i, j, lateral.shape) 266 | except IndexError: 267 | lateral = None 268 | # print(i, j, lateral) 269 | 270 | if not mix: 271 | h = self.output_blocks[k](h, 272 | emb=dec_time_emb, 273 | cond=dec_cond_emb, 274 | lateral=lateral) 275 | else: 276 | if i in ref_cond_scale_out: 277 | # if k == (len(self.output_num_blocks)*self.output_num_blocks[0]-1): 278 | h = self.output_blocks[k](h, 279 | emb=dec_time_emb, 280 | cond=cond['ref'], 281 | lateral=lateral) 282 | else: 283 | h = self.output_blocks[k](h, 284 | emb=dec_time_emb, 285 | cond=cond['input'], 286 | lateral=lateral) 287 | k += 1 288 | 289 | pred = self.out(h) 290 | return AutoencReturn(pred=pred, cond=cond) 291 | 292 | 293 | class AutoencReturn(NamedTuple): 294 | pred: Tensor 295 | cond: Tensor = None 296 | 297 | 298 | class EmbedReturn(NamedTuple): 299 | # style and time 300 | emb: Tensor = None 301 | # time only 302 | time_emb: Tensor = None 303 | # style only (but could depend on time) 304 | style: Tensor = None 305 | 306 | 307 | class TimeStyleSeperateEmbed(nn.Module): 308 | # embed only style 309 | def __init__(self, time_channels, time_out_channels): 310 | super().__init__() 311 | self.time_embed = nn.Sequential( 312 | linear(time_channels, time_out_channels), 313 | nn.SiLU(), 314 | linear(time_out_channels, time_out_channels), 315 | ) 316 | self.style = nn.Identity() 317 | 318 | def forward(self, time_emb=None, cond=None, **kwargs): 319 | if time_emb is None: 320 | # happens with autoenc training mode 321 | time_emb = None 322 | else: 323 | time_emb = self.time_embed(time_emb) 324 | style = self.style(cond) 325 | return EmbedReturn(emb=style, time_emb=time_emb, style=style) 326 | -------------------------------------------------------------------------------- /diffae/predict.py: -------------------------------------------------------------------------------- 1 | # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls 2 | # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 3 | # bunzip2 shape_predictor_68_face_landmarks.dat.bz2 4 | 5 | import os 6 | import torch 7 | from torchvision.utils import save_image 8 | import tempfile 9 | from templates import * 10 | from templates_cls import * 11 | from experiment_classifier import ClsModel 12 | from align import LandmarksDetector, image_align 13 | from cog import BasePredictor, Path, Input, BaseModel 14 | 15 | 16 | class ModelOutput(BaseModel): 17 | image: Path 18 | 19 | 20 | class Predictor(BasePredictor): 21 | def setup(self): 22 | self.aligned_dir = "aligned" 23 | os.makedirs(self.aligned_dir, exist_ok=True) 24 | self.device = "cuda:0" 25 | 26 | # Model Initialization 27 | model_config = ffhq256_autoenc() 28 | self.model = LitModel(model_config) 29 | state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu") 30 | self.model.load_state_dict(state["state_dict"], strict=False) 31 | self.model.ema_model.eval() 32 | self.model.ema_model.to(self.device) 33 | 34 | # Classifier Initialization 35 | classifier_config = ffhq256_autoenc_cls() 36 | classifier_config.pretrain = None # a bit faster 37 | self.classifier = ClsModel(classifier_config) 38 | state_class = torch.load( 39 | "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu" 40 | ) 41 | print("latent step:", state_class["global_step"]) 42 | self.classifier.load_state_dict(state_class["state_dict"], strict=False) 43 | self.classifier.to(self.device) 44 | 45 | 
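        # Face alignment: dlib 68-point landmark detector (see align.py); predict() feeds
        # its landmarks to image_align to crop/align the input face before encoding.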
self.landmarks_detector = LandmarksDetector( 46 | "shape_predictor_68_face_landmarks.dat" 47 | ) 48 | 49 | def predict( 50 | self, 51 | image: Path = Input( 52 | description="Input image for face manipulation. Image will be aligned and cropped, " 53 | "output aligned and manipulated images.", 54 | ), 55 | target_class: str = Input( 56 | default="Bangs", 57 | choices=[ 58 | "5_o_Clock_Shadow", 59 | "Arched_Eyebrows", 60 | "Attractive", 61 | "Bags_Under_Eyes", 62 | "Bald", 63 | "Bangs", 64 | "Big_Lips", 65 | "Big_Nose", 66 | "Black_Hair", 67 | "Blond_Hair", 68 | "Blurry", 69 | "Brown_Hair", 70 | "Bushy_Eyebrows", 71 | "Chubby", 72 | "Double_Chin", 73 | "Eyeglasses", 74 | "Goatee", 75 | "Gray_Hair", 76 | "Heavy_Makeup", 77 | "High_Cheekbones", 78 | "Male", 79 | "Mouth_Slightly_Open", 80 | "Mustache", 81 | "Narrow_Eyes", 82 | "Beard", 83 | "Oval_Face", 84 | "Pale_Skin", 85 | "Pointy_Nose", 86 | "Receding_Hairline", 87 | "Rosy_Cheeks", 88 | "Sideburns", 89 | "Smiling", 90 | "Straight_Hair", 91 | "Wavy_Hair", 92 | "Wearing_Earrings", 93 | "Wearing_Hat", 94 | "Wearing_Lipstick", 95 | "Wearing_Necklace", 96 | "Wearing_Necktie", 97 | "Young", 98 | ], 99 | description="Choose manipulation direction.", 100 | ), 101 | manipulation_amplitude: float = Input( 102 | default=0.3, 103 | ge=-0.5, 104 | le=0.5, 105 | description="When set too strong it would result in artifact as it could dominate the original image information.", 106 | ), 107 | T_step: int = Input( 108 | default=100, 109 | choices=[50, 100, 125, 200, 250, 500], 110 | description="Number of step for generation.", 111 | ), 112 | T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]), 113 | ) -> List[ModelOutput]: 114 | 115 | img_size = 256 116 | print("Aligning image...") 117 | for i, face_landmarks in enumerate( 118 | self.landmarks_detector.get_landmarks(str(image)), start=1 119 | ): 120 | image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks) 121 | 122 | data = ImageDataset( 123 | self.aligned_dir, 124 | image_size=img_size, 125 | exts=["jpg", "jpeg", "JPG", "png"], 126 | do_augment=False, 127 | ) 128 | 129 | print("Encoding and Manipulating the aligned image...") 130 | cls_manipulation_amplitude = manipulation_amplitude 131 | interpreted_target_class = target_class 132 | if ( 133 | target_class not in CelebAttrDataset.id_to_cls 134 | and f"No_{target_class}" in CelebAttrDataset.id_to_cls 135 | ): 136 | cls_manipulation_amplitude = -manipulation_amplitude 137 | interpreted_target_class = f"No_{target_class}" 138 | 139 | batch = data[0]["img"][None] 140 | 141 | semantic_latent = self.model.encode(batch.to(self.device)) 142 | stochastic_latent = self.model.encode_stochastic( 143 | batch.to(self.device), semantic_latent, T=T_inv 144 | ) 145 | 146 | cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class] 147 | class_direction = self.classifier.classifier.weight[cls_id] 148 | normalized_class_direction = F.normalize(class_direction[None, :], dim=1) 149 | 150 | normalized_semantic_latent = self.classifier.normalize(semantic_latent) 151 | normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512) 152 | normalized_manipulated_semantic_latent = ( 153 | normalized_semantic_latent 154 | + normalized_manipulation_amp * normalized_class_direction 155 | ) 156 | 157 | manipulated_semantic_latent = self.classifier.denormalize( 158 | normalized_manipulated_semantic_latent 159 | ) 160 | 161 | # Render Manipulated image 162 | manipulated_img = self.model.render( 163 | stochastic_latent, 
manipulated_semantic_latent, T=T_step 164 | )[0] 165 | original_img = data[0]["img"] 166 | 167 | model_output = [] 168 | out_path = Path(tempfile.mkdtemp()) / "original_aligned.png" 169 | save_image(convert2rgb(original_img), str(out_path)) 170 | model_output.append(ModelOutput(image=out_path)) 171 | 172 | out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png" 173 | save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path)) 174 | model_output.append(ModelOutput(image=out_path)) 175 | return model_output 176 | 177 | 178 | def convert2rgb(img, adjust_scale=True): 179 | convert_img = torch.tensor(img) 180 | if adjust_scale: 181 | convert_img = (convert_img + 1) / 2 182 | return convert_img.cpu() 183 | -------------------------------------------------------------------------------- /diffae/renderer.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | 3 | from torch.cuda import amp 4 | 5 | 6 | def render_uncondition(conf: TrainConfig, 7 | model: BeatGANsAutoencModel, 8 | x_T, 9 | sampler: Sampler, 10 | latent_sampler: Sampler, 11 | conds_mean=None, 12 | conds_std=None, 13 | clip_latent_noise: bool = False): 14 | device = x_T.device 15 | if conf.train_mode == TrainMode.diffusion: 16 | assert conf.model_type.can_sample() 17 | return sampler.sample(model=model, noise=x_T) 18 | elif conf.train_mode.is_latent_diffusion(): 19 | model: BeatGANsAutoencModel 20 | if conf.train_mode == TrainMode.latent_diffusion: 21 | latent_noise = torch.randn(len(x_T), conf.style_ch, device=device) 22 | else: 23 | raise NotImplementedError() 24 | 25 | if clip_latent_noise: 26 | latent_noise = latent_noise.clip(-1, 1) 27 | 28 | cond = latent_sampler.sample( 29 | model=model.latent_net, 30 | noise=latent_noise, 31 | clip_denoised=conf.latent_clip_sample, 32 | ) 33 | 34 | if conf.latent_znormalize: 35 | cond = cond * conds_std.to(device) + conds_mean.to(device) 36 | 37 | # the diffusion on the model 38 | return sampler.sample(model=model, noise=x_T, cond=cond) 39 | else: 40 | raise NotImplementedError() 41 | 42 | 43 | def render_condition( 44 | conf: TrainConfig, 45 | model: BeatGANsAutoencModel, 46 | x_T, 47 | sampler: Sampler, 48 | x_start=None, 49 | cond=None, 50 | ): 51 | if conf.train_mode == TrainMode.diffusion: 52 | assert conf.model_type.has_autoenc() 53 | # returns {'cond', 'cond2'} 54 | if cond is None: 55 | cond = model.encode(x_start) 56 | return sampler.sample(model=model, 57 | noise=x_T, 58 | model_kwargs={'cond': cond}) 59 | else: 60 | raise NotImplementedError() 61 | -------------------------------------------------------------------------------- /diffae/run_afhq256-dog.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # gpus = [0, 1, 2, 3, 4, 5, 6, 7] 6 | conf = ffhq256_autoenc() 7 | 8 | conf.data_name = 'afhq256-dog' 9 | conf.name = 'afhq256-dog_autoenc_old' 10 | conf.total_samples = 90_000_000 11 | conf.sample_every_samples = 500_000 12 | conf.eval_ema_every_samples = 10_000_000 13 | conf.eval_every_samples = 10_000_000 14 | conf.eval_num_images = 1000 15 | 16 | # train(conf, gpus=gpus) 17 | 18 | # infer the latents for training the latent DPM 19 | # NOTE: not gpu heavy, but more gpus can be of use! 
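    # (the 'infer' eval program encodes the whole training set with the trained semantic
    #  encoder and saves the latents (latent.pkl under the run's checkpoint folder),
    #  which the latent-DPM stage below reads via conf.latent_infer_path)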
20 | # gpus = [3] 21 | # conf.eval_programs = ['infer'] 22 | # train(conf, gpus=gpus, mode='eval') 23 | 24 | # # train the latent DPM 25 | # # NOTE: only need a single gpu 26 | gpus = [3] 27 | conf = ffhq256_autoenc_latent() 28 | conf.data_name = 'afhq256-dog' 29 | conf.name = 'afhq256-dog_autoenc_latent_old' 30 | conf.pretrain = PretrainConfig( 31 | name='90M', 32 | path=f'checkpoints/afhq256-dog_autoenc_old/last.ckpt', 33 | ) 34 | conf.latent_infer_path = f'checkpoints/afhq256-dog_autoenc_old/latent.pkl' 35 | conf.total_samples = 40_000_000 36 | conf.sample_every_samples = 5_000_000 37 | train(conf, gpus=gpus) -------------------------------------------------------------------------------- /diffae/run_church256.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # gpus = [0, 1, 2, 3, 4, 5, 6, 7] 6 | conf = ffhq256_autoenc() 7 | 8 | conf.data_name = 'church256' 9 | conf.name = 'church256_autoenc' 10 | conf.total_samples = 90_000_000 11 | conf.sample_every_samples = 500_000 12 | conf.eval_ema_every_samples = 10_000_000 13 | conf.eval_every_samples = 10_000_000 14 | 15 | # train(conf, gpus=gpus) 16 | 17 | # infer the latents for training the latent DPM 18 | # NOTE: not gpu heavy, but more gpus can be of use! 19 | # gpus = [2] 20 | # conf.eval_programs = ['infer'] 21 | # train(conf, gpus=gpus, mode='eval') 22 | 23 | # # train the latent DPM 24 | # # NOTE: only need a single gpu 25 | gpus = [2] 26 | conf = ffhq256_autoenc_latent() 27 | conf.data_name = 'church256' 28 | conf.name = 'church256_autoenc_latent' 29 | conf.pretrain = PretrainConfig( 30 | name='90M', 31 | path=f'checkpoints/church256_autoenc/last.ckpt', 32 | ) 33 | conf.latent_infer_path = f'checkpoints/church256_autoenc/latent.pkl' 34 | conf.total_samples = 100_000_000 35 | conf.sample_every_samples = 5_000_000 36 | train(conf, gpus=gpus) -------------------------------------------------------------------------------- /diffae/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([ 10 | exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) 11 | for x in range(window_size) 12 | ]) 13 | return gauss / gauss.sum() 14 | 15 | 16 | def create_window(window_size, channel): 17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 18 | _2D_window = _1D_window.mm( 19 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | window = Variable( 21 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()) 22 | return window 23 | 24 | 25 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 26 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 27 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 28 | 29 | mu1_sq = mu1.pow(2) 30 | mu2_sq = mu2.pow(2) 31 | mu1_mu2 = mu1 * mu2 32 | 33 | sigma1_sq = F.conv2d( 34 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 35 | sigma2_sq = F.conv2d( 36 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 37 | sigma12 = F.conv2d( 38 | img1 * img2, window, padding=window_size // 2, 39 | groups=channel) - mu1_mu2 40 | 41 | C1 = 0.01**2 42 | C2 = 0.03**2 43 | 44 | ssim_map = ((2 * mu1_mu2 + C1) * 45 | (2 * sigma12 + 
C2)) / ((mu1_sq + mu2_sq + C1) * 46 | (sigma1_sq + sigma2_sq + C2)) 47 | 48 | if size_average: 49 | return ssim_map.mean() 50 | else: 51 | return ssim_map.mean(1).mean(1).mean(1) 52 | 53 | 54 | class SSIM(torch.nn.Module): 55 | def __init__(self, window_size=11, size_average=True): 56 | super(SSIM, self).__init__() 57 | self.window_size = window_size 58 | self.size_average = size_average 59 | self.channel = 1 60 | self.window = create_window(window_size, self.channel) 61 | 62 | def forward(self, img1, img2): 63 | (_, channel, _, _) = img1.size() 64 | 65 | if channel == self.channel and self.window.data.type( 66 | ) == img1.data.type(): 67 | window = self.window 68 | else: 69 | window = create_window(self.window_size, channel) 70 | 71 | if img1.is_cuda: 72 | window = window.cuda(img1.get_device()) 73 | window = window.type_as(img1) 74 | 75 | self.window = window 76 | self.channel = channel 77 | 78 | return _ssim(img1, img2, window, self.window_size, channel, 79 | self.size_average) 80 | 81 | 82 | def ssim(img1, img2, window_size=11, size_average=True): 83 | (_, channel, _, _) = img1.size() 84 | window = create_window(window_size, channel) 85 | 86 | if img1.is_cuda: 87 | window = window.cuda(img1.get_device()) 88 | window = window.type_as(img1) 89 | 90 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /diffae/templates.py: -------------------------------------------------------------------------------- 1 | from experiment import * 2 | 3 | 4 | def ddpm(): 5 | """ 6 | base configuration for all DDIM-based models. 7 | """ 8 | conf = TrainConfig() 9 | conf.batch_size = 32 10 | conf.beatgans_gen_type = GenerativeType.ddim 11 | conf.beta_scheduler = 'linear' 12 | conf.data_name = 'ffhq' 13 | conf.diffusion_type = 'beatgans' 14 | conf.eval_ema_every_samples = 200_000 15 | conf.eval_every_samples = 200_000 16 | conf.fp16 = True 17 | conf.lr = 1e-4 18 | conf.model_name = ModelName.beatgans_ddpm 19 | conf.net_attn = (16, ) 20 | conf.net_beatgans_attn_head = 1 21 | conf.net_beatgans_embed_channels = 512 22 | conf.net_ch_mult = (1, 2, 4, 8) 23 | conf.net_ch = 64 24 | conf.sample_size = 32 25 | conf.T_eval = 20 26 | conf.T = 1000 27 | conf.make_model_conf() 28 | return conf 29 | 30 | 31 | def autoenc_base(): 32 | """ 33 | base configuration for all Diff-AE models. 
34 | """ 35 | conf = TrainConfig() 36 | conf.batch_size = 32 37 | conf.beatgans_gen_type = GenerativeType.ddim 38 | conf.beta_scheduler = 'linear' 39 | conf.data_name = 'ffhq' 40 | conf.diffusion_type = 'beatgans' 41 | conf.eval_ema_every_samples = 200_000 42 | conf.eval_every_samples = 200_000 43 | conf.fp16 = True 44 | conf.lr = 1e-4 45 | conf.model_name = ModelName.beatgans_autoenc 46 | conf.net_attn = (16, ) 47 | conf.net_beatgans_attn_head = 1 48 | conf.net_beatgans_embed_channels = 512 49 | conf.net_beatgans_resnet_two_cond = True 50 | conf.net_ch_mult = (1, 2, 4, 8) 51 | conf.net_ch = 64 52 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8) 53 | conf.net_enc_pool = 'adaptivenonzero' 54 | conf.sample_size = 32 55 | conf.T_eval = 20 56 | conf.T = 1000 57 | conf.make_model_conf() 58 | return conf 59 | 60 | 61 | def ffhq64_ddpm(): 62 | conf = ddpm() 63 | conf.data_name = 'ffhqlmdb256' 64 | conf.warmup = 0 65 | conf.total_samples = 72_000_000 66 | conf.scale_up_gpus(4) 67 | return conf 68 | 69 | 70 | def ffhq64_autoenc(): 71 | conf = autoenc_base() 72 | conf.data_name = 'ffhqlmdb256' 73 | conf.warmup = 0 74 | conf.total_samples = 72_000_000 75 | conf.net_ch_mult = (1, 2, 4, 8) 76 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8) 77 | conf.eval_every_samples = 1_000_000 78 | conf.eval_ema_every_samples = 1_000_000 79 | conf.scale_up_gpus(4) 80 | conf.make_model_conf() 81 | return conf 82 | 83 | 84 | def celeba64d2c_ddpm(): 85 | conf = ffhq128_ddpm() 86 | conf.data_name = 'celebalmdb' 87 | conf.eval_every_samples = 10_000_000 88 | conf.eval_ema_every_samples = 10_000_000 89 | conf.total_samples = 72_000_000 90 | conf.name = 'celeba64d2c_ddpm' 91 | return conf 92 | 93 | 94 | def celeba64d2c_autoenc(): 95 | conf = ffhq64_autoenc() 96 | conf.data_name = 'celebalmdb' 97 | conf.eval_every_samples = 10_000_000 98 | conf.eval_ema_every_samples = 10_000_000 99 | conf.total_samples = 72_000_000 100 | conf.name = 'celeba64d2c_autoenc' 101 | return conf 102 | 103 | 104 | def ffhq128_ddpm(): 105 | conf = ddpm() 106 | conf.data_name = 'ffhqlmdb256' 107 | conf.warmup = 0 108 | conf.total_samples = 48_000_000 109 | conf.img_size = 128 110 | conf.net_ch = 128 111 | # channels: 112 | # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4 113 | # sizes: 114 | # 128 => 128 => 64 => 32 => 16 => 8 115 | conf.net_ch_mult = (1, 1, 2, 3, 4) 116 | conf.eval_every_samples = 1_000_000 117 | conf.eval_ema_every_samples = 1_000_000 118 | conf.scale_up_gpus(4) 119 | conf.eval_ema_every_samples = 10_000_000 120 | conf.eval_every_samples = 10_000_000 121 | conf.make_model_conf() 122 | return conf 123 | 124 | 125 | def ffhq128_autoenc_base(): 126 | conf = autoenc_base() 127 | conf.data_name = 'ffhqlmdb256' 128 | conf.scale_up_gpus(4) 129 | conf.img_size = 128 130 | conf.net_ch = 128 131 | # final resolution = 8x8 132 | conf.net_ch_mult = (1, 1, 2, 3, 4) 133 | # final resolution = 4x4 134 | conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4) 135 | conf.eval_ema_every_samples = 10_000_000 136 | conf.eval_every_samples = 10_000_000 137 | conf.make_model_conf() 138 | return conf 139 | 140 | 141 | def ffhq256_autoenc(): 142 | conf = ffhq128_autoenc_base() 143 | conf.img_size = 256 144 | conf.net_ch = 128 145 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4) 146 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) 147 | conf.eval_every_samples = 10_000_000 148 | conf.eval_ema_every_samples = 10_000_000 149 | conf.total_samples = 200_000_000 150 | conf.batch_size = 64 151 | conf.make_model_conf() 152 | conf.name = 'ffhq256_autoenc' 153 | return conf 
154 | 155 | 156 | def ffhq256_autoenc_eco(): 157 | conf = ffhq128_autoenc_base() 158 | conf.img_size = 256 159 | conf.net_ch = 128 160 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4) 161 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) 162 | conf.eval_every_samples = 10_000_000 163 | conf.eval_ema_every_samples = 10_000_000 164 | conf.total_samples = 200_000_000 165 | conf.batch_size = 64 166 | conf.make_model_conf() 167 | conf.name = 'ffhq256_autoenc_eco' 168 | return conf 169 | 170 | 171 | def ffhq128_ddpm_72M(): 172 | conf = ffhq128_ddpm() 173 | conf.total_samples = 72_000_000 174 | conf.name = 'ffhq128_ddpm_72M' 175 | return conf 176 | 177 | 178 | def ffhq128_autoenc_72M(): 179 | conf = ffhq128_autoenc_base() 180 | conf.total_samples = 72_000_000 181 | conf.name = 'ffhq128_autoenc_72M' 182 | return conf 183 | 184 | 185 | def ffhq128_ddpm_130M(): 186 | conf = ffhq128_ddpm() 187 | conf.total_samples = 130_000_000 188 | conf.eval_ema_every_samples = 10_000_000 189 | conf.eval_every_samples = 10_000_000 190 | conf.name = 'ffhq128_ddpm_130M' 191 | return conf 192 | 193 | 194 | def ffhq128_autoenc_130M(): 195 | conf = ffhq128_autoenc_base() 196 | conf.total_samples = 130_000_000 197 | conf.eval_ema_every_samples = 10_000_000 198 | conf.eval_every_samples = 10_000_000 199 | conf.name = 'ffhq128_autoenc_130M' 200 | return conf 201 | 202 | 203 | def horse128_ddpm(): 204 | conf = ffhq128_ddpm() 205 | conf.data_name = 'horse256' 206 | conf.total_samples = 130_000_000 207 | conf.eval_ema_every_samples = 10_000_000 208 | conf.eval_every_samples = 10_000_000 209 | conf.name = 'horse128_ddpm' 210 | return conf 211 | 212 | 213 | def horse128_autoenc(): 214 | conf = ffhq128_autoenc_base() 215 | conf.data_name = 'horse256' 216 | conf.total_samples = 130_000_000 217 | conf.eval_ema_every_samples = 10_000_000 218 | conf.eval_every_samples = 10_000_000 219 | conf.name = 'horse128_autoenc' 220 | return conf 221 | 222 | 223 | def bedroom128_ddpm(): 224 | conf = ffhq128_ddpm() 225 | conf.data_name = 'bedroom256' 226 | conf.eval_ema_every_samples = 10_000_000 227 | conf.eval_every_samples = 10_000_000 228 | conf.total_samples = 120_000_000 229 | conf.name = 'bedroom128_ddpm' 230 | return conf 231 | 232 | 233 | def bedroom128_autoenc(): 234 | conf = ffhq128_autoenc_base() 235 | conf.data_name = 'bedroom256' 236 | conf.eval_ema_every_samples = 10_000_000 237 | conf.eval_every_samples = 10_000_000 238 | conf.total_samples = 120_000_000 239 | conf.name = 'bedroom128_autoenc' 240 | return conf 241 | 242 | 243 | def pretrain_celeba64d2c_72M(): 244 | conf = celeba64d2c_autoenc() 245 | conf.pretrain = PretrainConfig( 246 | name='72M', 247 | path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt', 248 | ) 249 | conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl' 250 | return conf 251 | 252 | 253 | def pretrain_ffhq128_autoenc72M(): 254 | conf = ffhq128_autoenc_base() 255 | conf.postfix = '' 256 | conf.pretrain = PretrainConfig( 257 | name='72M', 258 | path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt', 259 | ) 260 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl' 261 | return conf 262 | 263 | 264 | def pretrain_ffhq128_autoenc130M(): 265 | conf = ffhq128_autoenc_base() 266 | conf.pretrain = PretrainConfig( 267 | name='130M', 268 | path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt', 269 | ) 270 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl' 271 | return conf 272 | 273 | 274 | def pretrain_ffhq256_autoenc(): 275 
| conf = ffhq256_autoenc() 276 | conf.pretrain = PretrainConfig( 277 | name='90M', 278 | path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt', 279 | ) 280 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' 281 | return conf 282 | 283 | 284 | def pretrain_horse128(): 285 | conf = horse128_autoenc() 286 | conf.pretrain = PretrainConfig( 287 | name='82M', 288 | path=f'checkpoints/{horse128_autoenc().name}/last.ckpt', 289 | ) 290 | conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl' 291 | return conf 292 | 293 | 294 | def pretrain_bedroom128(): 295 | conf = bedroom128_autoenc() 296 | conf.pretrain = PretrainConfig( 297 | name='120M', 298 | path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt', 299 | ) 300 | conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl' 301 | return conf 302 | -------------------------------------------------------------------------------- /diffae/templates_cls.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | 3 | 4 | def ffhq128_autoenc_cls(): 5 | conf = ffhq128_autoenc_130M() 6 | conf.train_mode = TrainMode.manipulate 7 | conf.manipulate_mode = ManipulateMode.celebahq_all 8 | conf.manipulate_znormalize = True 9 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl' 10 | conf.batch_size = 32 11 | conf.lr = 1e-3 12 | conf.total_samples = 300_000 13 | # use the pretraining trick instead of contiuning trick 14 | conf.pretrain = PretrainConfig( 15 | '130M', 16 | f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt', 17 | ) 18 | conf.name = 'ffhq128_autoenc_cls' 19 | return conf 20 | 21 | 22 | def ffhq256_autoenc_cls(): 23 | '''We first train the encoder on FFHQ dataset then use it as a pretrained to train a linear classifer on CelebA dataset with attribute labels''' 24 | conf = ffhq256_autoenc() 25 | conf.train_mode = TrainMode.manipulate 26 | conf.manipulate_mode = ManipulateMode.celebahq_all 27 | conf.manipulate_znormalize = True 28 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' # we train on Celeb dataset, not FFHQ 29 | conf.batch_size = 32 30 | conf.lr = 1e-3 31 | conf.total_samples = 300_000 32 | # use the pretraining trick instead of contiuning trick 33 | conf.pretrain = PretrainConfig( 34 | '130M', 35 | f'checkpoints/{ffhq256_autoenc().name}/last.ckpt', 36 | ) 37 | conf.name = 'ffhq256_autoenc_cls' 38 | return conf 39 | -------------------------------------------------------------------------------- /diffae/templates_latent.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | 3 | 4 | def latent_diffusion_config(conf: TrainConfig): 5 | conf.batch_size = 128 6 | conf.train_mode = TrainMode.latent_diffusion 7 | conf.latent_gen_type = GenerativeType.ddim 8 | conf.latent_loss_type = LossType.mse 9 | conf.latent_model_mean_type = ModelMeanType.eps 10 | conf.latent_model_var_type = ModelVarType.fixed_large 11 | conf.latent_rescale_timesteps = False 12 | conf.latent_clip_sample = False 13 | conf.latent_T_eval = 20 14 | conf.latent_znormalize = True 15 | conf.total_samples = 96_000_000 16 | conf.sample_every_samples = 400_000 17 | conf.eval_every_samples = 20_000_000 18 | conf.eval_ema_every_samples = 20_000_000 19 | conf.save_every_samples = 2_000_000 20 | return conf 21 | 22 | 23 | def latent_diffusion128_config(conf: TrainConfig): 24 | conf = latent_diffusion_config(conf) 25 | conf.batch_size_eval = 32 
26 | return conf 27 | 28 | 29 | def latent_mlp_2048_norm_10layers(conf: TrainConfig): 30 | conf.net_latent_net_type = LatentNetType.skip 31 | conf.net_latent_layers = 10 32 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers)) 33 | conf.net_latent_activation = Activation.silu 34 | conf.net_latent_num_hid_channels = 2048 35 | conf.net_latent_use_norm = True 36 | conf.net_latent_condition_bias = 1 37 | return conf 38 | 39 | 40 | def latent_mlp_2048_norm_20layers(conf: TrainConfig): 41 | conf = latent_mlp_2048_norm_10layers(conf) 42 | conf.net_latent_layers = 20 43 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers)) 44 | return conf 45 | 46 | 47 | def latent_256_batch_size(conf: TrainConfig): 48 | conf.batch_size = 256 49 | conf.eval_ema_every_samples = 100_000_000 50 | conf.eval_every_samples = 100_000_000 51 | conf.sample_every_samples = 1_000_000 52 | conf.save_every_samples = 2_000_000 53 | conf.total_samples = 301_000_000 54 | return conf 55 | 56 | 57 | def latent_512_batch_size(conf: TrainConfig): 58 | conf.batch_size = 512 59 | conf.eval_ema_every_samples = 100_000_000 60 | conf.eval_every_samples = 100_000_000 61 | conf.sample_every_samples = 1_000_000 62 | conf.save_every_samples = 5_000_000 63 | conf.total_samples = 501_000_000 64 | return conf 65 | 66 | 67 | def latent_2048_batch_size(conf: TrainConfig): 68 | conf.batch_size = 2048 69 | conf.eval_ema_every_samples = 200_000_000 70 | conf.eval_every_samples = 200_000_000 71 | conf.sample_every_samples = 4_000_000 72 | conf.save_every_samples = 20_000_000 73 | conf.total_samples = 1_501_000_000 74 | return conf 75 | 76 | 77 | def adamw_weight_decay(conf: TrainConfig): 78 | conf.optimizer = OptimizerType.adamw 79 | conf.weight_decay = 0.01 80 | return conf 81 | 82 | 83 | def ffhq128_autoenc_latent(): 84 | conf = pretrain_ffhq128_autoenc130M() 85 | conf = latent_diffusion128_config(conf) 86 | conf = latent_mlp_2048_norm_10layers(conf) 87 | conf = latent_256_batch_size(conf) 88 | conf = adamw_weight_decay(conf) 89 | conf.total_samples = 101_000_000 90 | conf.latent_loss_type = LossType.l1 91 | conf.latent_beta_scheduler = 'const0.008' 92 | conf.name = 'ffhq128_autoenc_latent' 93 | return conf 94 | 95 | 96 | def ffhq256_autoenc_latent(): 97 | conf = pretrain_ffhq256_autoenc() 98 | conf = latent_diffusion128_config(conf) 99 | conf = latent_mlp_2048_norm_10layers(conf) 100 | conf = latent_256_batch_size(conf) 101 | conf = adamw_weight_decay(conf) 102 | conf.total_samples = 101_000_000 103 | conf.latent_loss_type = LossType.l1 104 | conf.latent_beta_scheduler = 'const0.008' 105 | conf.eval_ema_every_samples = 200_000_000 106 | conf.eval_every_samples = 200_000_000 107 | conf.sample_every_samples = 4_000_000 108 | conf.name = 'ffhq256_autoenc_latent' 109 | return conf 110 | 111 | 112 | def horse128_autoenc_latent(): 113 | conf = pretrain_horse128() 114 | conf = latent_diffusion128_config(conf) 115 | conf = latent_2048_batch_size(conf) 116 | conf = latent_mlp_2048_norm_20layers(conf) 117 | conf.total_samples = 2_001_000_000 118 | conf.latent_beta_scheduler = 'const0.008' 119 | conf.latent_loss_type = LossType.l1 120 | conf.name = 'horse128_autoenc_latent' 121 | return conf 122 | 123 | 124 | def bedroom128_autoenc_latent(): 125 | conf = pretrain_bedroom128() 126 | conf = latent_diffusion128_config(conf) 127 | conf = latent_2048_batch_size(conf) 128 | conf = latent_mlp_2048_norm_20layers(conf) 129 | conf.total_samples = 2_001_000_000 130 | conf.latent_beta_scheduler = 'const0.008' 131 | conf.latent_loss_type 
= LossType.l1 132 | conf.name = 'bedroom128_autoenc_latent' 133 | return conf 134 | 135 | 136 | def celeba64d2c_autoenc_latent(): 137 | conf = pretrain_celeba64d2c_72M() 138 | conf = latent_diffusion_config(conf) 139 | conf = latent_512_batch_size(conf) 140 | conf = latent_mlp_2048_norm_10layers(conf) 141 | conf = adamw_weight_decay(conf) 142 | # just for the name 143 | conf.continue_from = PretrainConfig('200M', 144 | f'log-latent/{conf.name}/last.ckpt') 145 | conf.postfix = '_300M' 146 | conf.total_samples = 301_000_000 147 | conf.latent_beta_scheduler = 'const0.008' 148 | conf.latent_loss_type = LossType.l1 149 | conf.name = 'celeba64d2c_autoenc_latent' 150 | return conf 151 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: osasis 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.8.10 7 | - pip=21.2.4 8 | - mpi4py==3.1.5 9 | - pip: 10 | - blobfile==2.1.1 11 | - lpips==0.1.4 12 | - torch==1.13.0 13 | - torchvision==0.14.0 14 | - pytorch-lightning==1.8.6 15 | - tqdm==4.64.1 16 | - git+https://github.com/openai/CLIP.git 17 | - pandas==1.4.4 18 | - lmdb==1.4.0 19 | - pytorch-fid==0.2.1 -------------------------------------------------------------------------------- /eval_diffaeB.py: -------------------------------------------------------------------------------- 1 | from utils.args import make_args 2 | from utils.tester import DiffFSTester 3 | 4 | 5 | def main(args): 6 | tester = DiffFSTester(args) 7 | tester.infer_image_all() 8 | 9 | 10 | if __name__=='__main__': 11 | args = make_args() 12 | main(args) -------------------------------------------------------------------------------- /gen_style_domA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a domain A (photo-like) counterpart of a domain B style reference image by 3 | noising it and denoising it with a P2-weighted FFHQ diffusion model, keeping the closest sample.
4 | """ 5 | 6 | import argparse 7 | import os 8 | import glob 9 | import random 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | from PIL import Image 15 | import lpips 16 | 17 | import torch as th 18 | import torchvision.transforms as transforms 19 | from torchvision import utils 20 | 21 | from P2_weighting.guided_diffusion import dist_util, logger 22 | from P2_weighting.guided_diffusion.script_util import ( 23 | model_and_diffusion_defaults, 24 | create_model_and_diffusion, 25 | add_dict_to_argparser, 26 | args_to_dict, 27 | ) 28 | 29 | """ 30 | Using P2 / rtaio_0.5 / eta 1.0 / respacing 50 31 | """ 32 | 33 | 34 | def main(): 35 | args = create_argparser().parse_args() 36 | 37 | # set seed 38 | random.seed(args.seed) 39 | np.random.seed(args.seed) 40 | th.manual_seed(args.seed) 41 | 42 | dist_util.setup_dist() 43 | logger.configure(dir=args.sample_dir) 44 | 45 | logger.log("creating model and diffusion...") 46 | model, diffusion = create_model_and_diffusion( 47 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 48 | ) 49 | model.load_state_dict( 50 | dist_util.load_state_dict(args.model_path, map_location="cpu") 51 | ) 52 | model.to(dist_util.dev()) 53 | if args.use_fp16: 54 | model.convert_to_fp16() 55 | model.eval() 56 | 57 | logger.log("sampling...") 58 | 59 | device = dist_util.dev() 60 | transform_256 = transforms.Compose([ 61 | transforms.Resize((256,256)), 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 64 | ]) 65 | 66 | # imgs_ref_domainB = glob.glob(f'{args.input_dir}/*') 67 | imgs_ref_domainB = [f'{args.input_dir}/art_{args.seed}.png'] 68 | percept_loss = lpips.LPIPS(net='vgg').to(device) 69 | l1 = th.nn.L1Loss(reduction='none') 70 | 71 | # for img_file in tqdm(imgs_ref_domainB): 72 | # # file_name = img_file.split('/')[-1] 73 | # file_name = Path(img_file).name 74 | 75 | img_file = os.path.join(args.input_dir, args.img_name) 76 | img_ref_B = Image.open(img_file).convert('RGB') 77 | img_ref_B = transform_256(img_ref_B) 78 | img_ref_B = img_ref_B.unsqueeze(0).to(device) 79 | img_ref_B_all = img_ref_B.repeat(args.n,1,1,1) 80 | 81 | t_start = int(diffusion.num_timesteps*args.t_start_ratio) 82 | t_start = th.tensor([t_start], device=device) 83 | 84 | # forward DDPM 85 | xt = diffusion.q_sample(img_ref_B_all.clone(), t_start) 86 | 87 | # reverse DDPM 88 | indices = list(range(t_start))[::-1] 89 | for i in indices: 90 | t = th.tensor([i] * img_ref_B_all.shape[0], device=device) 91 | with th.no_grad(): 92 | out = diffusion.p_sample( 93 | model, 94 | xt, 95 | t, 96 | clip_denoised=True, 97 | denoised_fn=None, 98 | cond_fn=None, 99 | model_kwargs=None, 100 | ) 101 | xt = out["sample"] 102 | 103 | # # reverse DDIM 104 | # indices = list(range(t_start))[::-1] 105 | # for i in indices: 106 | # t = th.tensor([i] * img_ref_B_all.shape[0], device=device) 107 | # with th.no_grad(): 108 | # out = diffusion.ddim_sample( 109 | # model, 110 | # xt, 111 | # t, 112 | # clip_denoised=True, 113 | # denoised_fn=None, 114 | # cond_fn=None, 115 | # model_kwargs=None, 116 | # ) 117 | # xt = out["sample"] 118 | 119 | # compute loss 120 | # from torchvision.utils import save_image 121 | # save_image(xt/2+0.5, os.path.join('step1_tmp', file_name)) 122 | l1_loss = l1(xt, img_ref_B.repeat(int(args.n),1,1,1)) 123 | l1_loss = l1_loss.mean(dim=(1,2,3)) 124 | lpips_loss = percept_loss(xt, img_ref_B.repeat(int(args.n),1,1,1)).squeeze() 125 | loss = 10*l1_loss + lpips_loss 126 | 127 | # pick best image 128 | img_idx = 
th.argmin(loss) 129 | img_ref_A = xt[img_idx] 130 | 131 | os.makedirs(args.sample_dir, exist_ok=True) 132 | utils.save_image(img_ref_A/2+0.5, os.path.join(args.sample_dir, args.img_name)) 133 | 134 | def create_argparser(): 135 | defaults = dict( 136 | model_path="", 137 | input_dir="", 138 | sample_dir="", 139 | img_name="", 140 | n=1, 141 | t_start_ratio=0.5, 142 | eta=1.0, 143 | seed=1 144 | ) 145 | defaults.update(model_and_diffusion_defaults()) 146 | parser = argparse.ArgumentParser() 147 | add_dict_to_argparser(parser, defaults) 148 | return parser 149 | 150 | 151 | if __name__ == "__main__": 152 | main() -------------------------------------------------------------------------------- /imgs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs/teaser.jpg -------------------------------------------------------------------------------- /imgs_input_domA/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img1.png -------------------------------------------------------------------------------- /imgs_input_domA/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img2.png -------------------------------------------------------------------------------- /imgs_input_domA/img3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img3.png -------------------------------------------------------------------------------- /imgs_input_domA/img4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img4.png -------------------------------------------------------------------------------- /imgs_input_domA/img5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img5.png -------------------------------------------------------------------------------- /imgs_input_domA/img6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img6.png -------------------------------------------------------------------------------- /imgs_style_domB/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img1.png -------------------------------------------------------------------------------- /imgs_style_domB/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img2.png -------------------------------------------------------------------------------- /imgs_style_domB/img3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img3.png -------------------------------------------------------------------------------- /imgs_style_domB/img4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img4.png -------------------------------------------------------------------------------- /imgs_style_domB/img5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img5.png -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | # (step 3) Evaluate 2 | 3 | DEVICE=0 4 | 5 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 6 | python eval_diffaeB.py \ 7 | --style_domB_dir imgs_style_domB \ 8 | --infer_dir imgs_input_domA \ 9 | --ref_img img1.png \ 10 | --work_dir exp/img1 \ 11 | --map_net \ 12 | --map_time \ 13 | --lambda_map 0.1 -------------------------------------------------------------------------------- /scripts/prepare_train.sh: -------------------------------------------------------------------------------- 1 | # (step 1) download the P2 weighting checkpoint ffhq_p2.pt into P2_weighting/models/ and run: 2 | 3 | DEVICE=0 4 | 5 | SAMPLE_FLAGS="--attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 \ 6 | --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 \ 7 | --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 50" 8 | 9 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 10 | python gen_style_domA.py ${SAMPLE_FLAGS} \ 11 | --model_path P2_weighting/models/ffhq_p2.pt \ 12 | --input_dir imgs_style_domB \ 13 | --sample_dir imgs_style_domA \ 14 | --img_name img1.png \ 15 | --n 1 \ 16 | --t_start_ratio 0.5 \ 17 | --seed 1 18 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | # (step 2) train diffae B on the style image 2 | # download the pretrained DiffAE checkpoint first 3 | # batch size 8 -> 34GB / about 30 min on a single A100 4 | 5 | DEVICE=0 6 | 7 | CUDA_VISIBLE_DEVICES=${DEVICE} \ 8 | python train_diffaeB.py \ 9 | --style_domA_dir imgs_style_domA \ 10 | --style_domB_dir imgs_style_domB \ 11 | --ref_img img1.png \ 12 | --work_dir exp/img1 \ 13 | --n_iter 200 \ 14 | --ckpt_freq 200 \ 15 | --batch_size 8 \ 16 | --map_net \ 17 | --map_time \ 18 | --lambda_map 0.1 \ 19 | --train 20 | -------------------------------------------------------------------------------- /train_diffaeB.py: -------------------------------------------------------------------------------- 1 | from utils.args import make_args 2 | from utils.trainer import DiffFSTrainer 3 | 4 | 5 | def main(args): 6 | trainer = DiffFSTrainer(args) 7 | trainer.train() 8 | 9 | 10 | if __name__ == '__main__': 11 | args = make_args() 12 | main(args) -------------------------------------------------------------------------------- /utils/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from pathlib import Path 5 | 6 | def make_args(): 7 | parser = argparse.ArgumentParser() 8 | # directory 9 |
parser.add_argument('--diffae_ckpt', default='diffae/checkpoints', type=str, help='diffusion autoencoder checkpoints') 10 | parser.add_argument('--style_domA_dir', default='imgs_style_domA', type=str, help='style images (domain A)') 11 | parser.add_argument('--style_domB_dir', default='imgs_style_domB', type=str, help='style images (domain B)') 12 | parser.add_argument('--infer_dir', default='imgs_input_domA', type=str, help='input images (domain A)') 13 | # data 14 | parser.add_argument('--ref_img', default='digital_painting_jing.png', type=str, help='reference image file name') 15 | parser.add_argument('--work_dir', default='exp0', type=str, help='experiment working directory') 16 | # train 17 | parser.add_argument('--seed', default=0, type=int, help='random seed') 18 | parser.add_argument('--n_iter', default=200, type=int, help='number of training iterations') 19 | parser.add_argument('--batch_size', default=8, type=int, help='batch size') 20 | parser.add_argument('--lr', default=5e-6, type=float, help='learning rate') 21 | # reverse process 22 | parser.add_argument('--T_train_for', default=50, type=int, help='forward timesteps during training') 23 | parser.add_argument('--T_train_back', default=20, type=int, help='backward timesteps during training') 24 | parser.add_argument('--T_infer_for', default=100, type=int, help='forward timesteps during inference') 25 | parser.add_argument('--T_infer_back', default=50, type=int, help='backward timesteps during inference') 26 | parser.add_argument('--T_latent', default=200, type=int, help='latent DDIM timesteps') 27 | parser.add_argument('--t0_ratio', default=0.5, type=float, help='return step ratio') 28 | # loss coefficients 29 | parser.add_argument('--cross_dom', default=1, type=float, help='cross domain loss') 30 | parser.add_argument('--in_dom', default=0.5, type=float, help='in domain loss') 31 | parser.add_argument('--recon_clip', default=30, type=float, help='clip reconstruction loss') 32 | parser.add_argument('--recon_l1', default=10, type=float, help='l1 reconstruction loss') 33 | parser.add_argument('--recon_lpips', default=10, type=float, help='lpips reconstruction loss') 34 | # utils 35 | parser.add_argument('--print_freq', default=10, type=int) 36 | parser.add_argument('--train_img_freq', default=50, type=int) 37 | parser.add_argument('--ckpt_freq', default=200, type=int) 38 | parser.add_argument('--train', action='store_true', help='train flag') 39 | # mapping net 40 | parser.add_argument('--map_net', action='store_true', help='use mapping net') 41 | parser.add_argument('--map_time', action='store_true', help='use time-conditioned mapping net') 42 | parser.add_argument('--lambda_map', default=0.1, type=float, help='weight of mapping net output') 43 | 44 | args = parser.parse_args() 45 | 46 | args.device = 'cuda' 47 | args.ref_img_name = Path(args.ref_img).stem 48 | 49 | if args.train: 50 | os.makedirs(args.work_dir, exist_ok=True) 51 | with open(os.path.join(args.work_dir, 'args.txt'), 'w') as f: 52 | json.dump(args.__dict__, f, indent=2) 53 | 54 | return args 55 | 56 | -------------------------------------------------------------------------------- /utils/map_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffae.model.nn import timestep_embedding 4 | 5 | class MappingNet(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.layers = nn.Sequential( 10 | nn.Conv2d(3, 128, kernel_size=1), 11 | 12 | nn.GroupNorm(32, 128), 13 | nn.SiLU(), 14 |
nn.Conv2d(128, 256, kernel_size=1), 15 | 16 | nn.GroupNorm(32, 256), 17 | nn.SiLU(), 18 | nn.Conv2d(256, 128, kernel_size=1), 19 | 20 | nn.GroupNorm(32, 128), 21 | nn.SiLU(), 22 | nn.Conv2d(128, 3, kernel_size=1) 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.layers(x) 27 | return x 28 | 29 | 30 | class MappingNetTime(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | 34 | self.in_layer = nn.Sequential( 35 | nn.Conv2d(3, 128, kernel_size=1), 36 | ) 37 | 38 | self.block1 = MappingNetTimeBlock( 39 | in_ch = 128, 40 | out_ch = 256, 41 | t_emb_dim = 512 42 | ) 43 | 44 | self.block2 = MappingNetTimeBlock( 45 | in_ch = 256, 46 | out_ch = 128, 47 | t_emb_dim = 512 48 | ) 49 | 50 | self.out_layer = nn.Sequential( 51 | nn.Conv2d(128, 3, kernel_size=1) 52 | ) 53 | 54 | self.time_embed = nn.Sequential( 55 | nn.Linear(128, 512), 56 | nn.SiLU(), 57 | nn.Linear(512, 512), 58 | ) 59 | 60 | def forward(self, x, t): 61 | t_emb = timestep_embedding(t, 128) 62 | t_emb = self.time_embed(t_emb) 63 | 64 | x = self.in_layer(x) 65 | x = self.block1(x, t_emb) 66 | x = self.block2(x, t_emb) 67 | x = self.out_layer(x) 68 | 69 | return x 70 | 71 | 72 | class MappingNetTimeBlock(nn.Module): 73 | def __init__(self, in_ch, out_ch, t_emb_dim): 74 | super().__init__() 75 | 76 | self.pre_layer = nn.Sequential( 77 | nn.GroupNorm(32, in_ch), 78 | nn.SiLU(), 79 | nn.Conv2d(in_ch, out_ch, kernel_size=1), 80 | nn.GroupNorm(32, out_ch), 81 | ) 82 | 83 | self.post_layer = nn.Sequential( 84 | nn.SiLU(), 85 | nn.Conv2d(out_ch, out_ch, kernel_size=1) 86 | ) 87 | 88 | self.emb_layer = nn.Sequential( 89 | nn.SiLU(), 90 | nn.Linear(t_emb_dim, out_ch*2) 91 | ) 92 | 93 | def forward(self, x, t_emb, scale_bias:float=1): 94 | 95 | t_emb = self.emb_layer(t_emb) 96 | 97 | # match shape 98 | while len(t_emb.shape) < len(x.shape): 99 | t_emb = t_emb[..., None] 100 | 101 | scale, shift = torch.chunk(t_emb, 2, dim=1) 102 | 103 | x = self.pre_layer(x) 104 | x = x * (scale_bias + scale) 105 | x = x + shift 106 | x = self.post_layer(x) 107 | 108 | return x -------------------------------------------------------------------------------- /utils/tester.py: -------------------------------------------------------------------------------- 1 | ''' 2 | images are loaded through the dataloader 3 | ''' 4 | 5 | import os 6 | import glob 7 | import copy 8 | import random 9 | from pathlib import Path 10 | import numpy as np 11 | from PIL import Image 12 | 13 | import torch 14 | import torchvision.transforms as transforms 15 | from torchvision.utils import save_image 16 | 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from diffae.templates_latent import ffhq256_autoenc_latent 20 | from diffae.experiment import LitModel 21 | from utils.map_net import MappingNet, MappingNetTime 22 | 23 | 24 | class TestDataset(Dataset): 25 | def __init__(self, img_dir): 26 | self.imgs = glob.glob(os.path.join(img_dir, '*')) 27 | self.transform_img = transforms.Compose([ 28 | transforms.Resize((256,256)), 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 31 | ]) 32 | 33 | def __len__(self): 34 | return len(self.imgs) 35 | 36 | def __getitem__(self, idx): 37 | img_cont_A = Image.open(self.imgs[idx]) 38 | img_cont_A = self.transform_img(img_cont_A) 39 | return img_cont_A, Path(self.imgs[idx]).name 40 | 41 | 42 | class DiffFSTester: 43 | def __init__(self, args): 44 | self.args = args 45 | 46 | # set seed 47 | random.seed(self.args.seed) 48 | np.random.seed(self.args.seed) 49 | torch.manual_seed(self.args.seed) 50 |
51 | # load diffae 52 | conf = ffhq256_autoenc_latent() 53 | conf.pretrain.path = os.path.join(self.args.diffae_ckpt, 'ffhq256_autoenc/last.ckpt') 54 | conf.latent_infer_path = os.path.join(self.args.diffae_ckpt, 'ffhq256_autoenc/latent.pkl') 55 | 56 | model_diffae = LitModel(conf) 57 | state = torch.load(os.path.join(self.args.diffae_ckpt, f'{conf.name}/last.ckpt'), map_location='cpu') 58 | model_diffae.load_state_dict(state['state_dict'], strict=False) 59 | 60 | # make diffae for domainA (photo) / freeze 61 | self.diffae_A = copy.deepcopy(model_diffae.ema_model) 62 | self.diffae_A = self.diffae_A.to(self.args.device) 63 | self.diffae_A.eval() 64 | self.diffae_A.requires_grad_(False) 65 | 66 | # make diffae for domainB (style) / train 67 | self.diffae_B = copy.deepcopy(model_diffae.ema_model) 68 | self.diffae_B = self.diffae_B.to(self.args.device) 69 | self.diffae_B.eval() 70 | self.diffae_B.requires_grad_(False) 71 | 72 | # mapping net 73 | if self.args.map_net: 74 | if self.args.map_time: 75 | self.model_map = MappingNetTime().to(args.device) 76 | else: 77 | self.model_map = MappingNet().to(args.device) 78 | 79 | self.infer_samp_for = conf._make_diffusion_conf(self.args.T_infer_for).make_sampler() 80 | self.infer_samp_back = conf._make_diffusion_conf(self.args.T_infer_back).make_sampler() 81 | 82 | self.transform_img = transforms.Compose([ 83 | transforms.Resize((256,256)), 84 | transforms.ToTensor(), 85 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) 86 | ]) 87 | 88 | def infer_image(self, img_dir, z_style_B): 89 | 90 | test_dataset = TestDataset(img_dir) 91 | test_loader = DataLoader(test_dataset, batch_size=64) 92 | 93 | for img_cont_A, file_name in test_loader: 94 | 95 | img_cont_A = img_cont_A.to(self.args.device) 96 | z_cont_A = self.diffae_A.encode(img_cont_A) 97 | z_cont_A = z_cont_A.detach().clone() 98 | xt_cont_A = img_cont_A.clone() 99 | 100 | with torch.no_grad(): 101 | # forward ddim (content A) 102 | forward_indices = list(range(self.args.T_infer_for))\ 103 | [:int(self.args.T_infer_for*self.args.t0_ratio)] 104 | for j in forward_indices: 105 | t = torch.tensor([j]*len(img_cont_A), device=self.args.device) 106 | out = self.infer_samp_for.ddim_reverse_sample(self.diffae_A, 107 | xt_cont_A, 108 | t, 109 | model_kwargs={'cond': z_cont_A}) 110 | xt_cont_A = out['sample'] 111 | 112 | xt_cont_B = xt_cont_A.detach().clone() 113 | 114 | # reverse ddim (mix) 115 | reverse_indices = list(range(self.args.T_infer_back))[::-1]\ 116 | [int(self.args.T_infer_back*(1-self.args.t0_ratio)):] 117 | for j in reverse_indices: 118 | t = torch.tensor([j]*len(img_cont_A), device=self.args.device) 119 | 120 | if self.args.map_net: 121 | if self.args.map_time: 122 | map_cont = self.model_map(img_cont_A, t) 123 | else: 124 | map_cont = self.model_map(img_cont_A) 125 | xt_cont_B = xt_cont_B + self.args.lambda_map*map_cont 126 | 127 | out = self.infer_samp_back.ddim_sample(self.diffae_B, 128 | xt_cont_B, 129 | t, 130 | model_kwargs={'cond': {'ref':z_style_B, 'input':z_cont_A}, 131 | 'ref_cond_scale': [0, 1, 2, 3], 132 | 'mix': True}) 133 | xt_cont_B = out['sample'].detach().clone() 134 | 135 | save_dir = os.path.join(self.args.work_dir, 'imgs_test') 136 | os.makedirs(save_dir, exist_ok=True) 137 | # save_image(xt_cont_B/2+0.5, os.path.join(save_dir, Path(input_path).name)) 138 | 139 | for i, img in enumerate(xt_cont_B): 140 | save_image(img/2+0.5, os.path.join(save_dir, file_name[i])) 141 | 142 | 143 | def infer_image_all(self): 144 | ckpt_path = os.path.join(self.args.work_dir, 'ckpt',
f'iter_{self.args.n_iter}.pt') 145 | ckpt = torch.load(ckpt_path, map_location='cpu') 146 | self.diffae_B.load_state_dict(ckpt['diffae_B']) 147 | self.diffae_B = self.diffae_B.to(self.args.device) 148 | 149 | if self.args.map_net: 150 | self.model_map.load_state_dict(ckpt['model_map']) 151 | self.model_map = self.model_map.to(self.args.device) 152 | 153 | # style image 154 | img_style_B = Image.open(os.path.join(self.args.style_domB_dir, self.args.ref_img)).convert('RGB') 155 | img_style_B = self.transform_img(img_style_B).unsqueeze(0).to(self.args.device) 156 | z_style_B = self.diffae_A.encode(img_style_B) 157 | z_style_B = z_style_B.detach().clone() 158 | 159 | self.infer_image(self.args.infer_dir, z_style_B) --------------------------------------------------------------------------------