├── .gitignore
├── P2_weighting
│   ├── LICENSE
│   ├── README.md
│   ├── datasets
│   │   ├── README.md
│   │   └── lsun_bedroom.py
│   ├── evaluations
│   │   ├── README.md
│   │   ├── evaluator.py
│   │   └── requirements.txt
│   ├── guided_diffusion
│   │   ├── __init__.py
│   │   ├── dist_util.py
│   │   ├── fp16_util.py
│   │   ├── gaussian_diffusion.py
│   │   ├── image_datasets.py
│   │   ├── logger.py
│   │   ├── losses.py
│   │   ├── nn.py
│   │   ├── resample.py
│   │   ├── respace.py
│   │   ├── script_util.py
│   │   ├── train_util.py
│   │   └── unet.py
│   ├── scripts
│   │   ├── classifier_sample.py
│   │   ├── classifier_train.py
│   │   ├── image_nll.py
│   │   ├── image_sample.py
│   │   ├── image_train.py
│   │   ├── super_res_sample.py
│   │   └── super_res_train.py
│   └── setup.py
├── README.md
├── diffae
│   ├── __init__.py
│   ├── align.py
│   ├── choices.py
│   ├── cog.yaml
│   ├── config.py
│   ├── config_base.py
│   ├── dataset.py
│   ├── dataset_util.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── diffusion.py
│   │   └── resample.py
│   ├── dist_utils.py
│   ├── evals
│   │   └── church256_autoenc.txt
│   ├── experiment.py
│   ├── experiment_classifier.py
│   ├── lmdb_writer.py
│   ├── metrics.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── latentnet.py
│   │   ├── nn.py
│   │   ├── unet.py
│   │   └── unet_autoenc.py
│   ├── predict.py
│   ├── renderer.py
│   ├── run_afhq256-dog.py
│   ├── run_church256.py
│   ├── ssim.py
│   ├── templates.py
│   ├── templates_cls.py
│   └── templates_latent.py
├── environment.yaml
├── eval_diffaeB.py
├── gen_style_domA.py
├── imgs
│   └── teaser.jpg
├── imgs_input_domA
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   ├── img4.png
│   ├── img5.png
│   └── img6.png
├── imgs_style_domB
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   ├── img4.png
│   └── img5.png
├── scripts
│   ├── eval.sh
│   ├── prepare_train.sh
│   └── train.sh
├── train_diffaeB.py
└── utils
    ├── args.py
    ├── map_net.py
    ├── tester.py
    └── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | diffae/checkpoints
6 | exp
7 | imgs_style_domA
8 | P2_weighting/models
--------------------------------------------------------------------------------
/P2_weighting/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/P2_weighting/README.md:
--------------------------------------------------------------------------------
1 | # P2 weighting (CVPR 2022)
2 |
3 | This is the codebase for [Perception Prioritized Training of Diffusion Models](https://arxiv.org/abs/2204.00227).
4 |
5 | This repository is heavily based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion).
6 |
7 | P2 modifies the weighting scheme of the training objective function to improve sample quality. It encourages the diffusion model to focus on recovering signals from highly corrupted data, where the model learns global and perceptually rich concepts. The figure below shows the weighting schemes in terms of SNR; a short sketch of the weight follows it.
8 |
9 | 
10 |
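For intuition, here is a minimal sketch (not a file in this repository) of the weighting described above, following the paper's formulation λ'_t = λ_t / (k + SNR(t))^γ and assuming the linear noise schedule used in the commands below:

```
# Hedged sketch of P2 weighting: down-weight low-noise (high-SNR) steps.
# The schedule constants are assumptions matching --noise_schedule linear.
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)          # linear schedule, T = 1000
alphas_cumprod = np.cumprod(1.0 - betas)
snr = alphas_cumprod / (1.0 - alphas_cumprod)  # per-step signal-to-noise ratio

p2_k, p2_gamma = 1.0, 1.0                      # the --p2_k / --p2_gamma flags
p2_weight = 1.0 / (p2_k + snr) ** p2_gamma     # multiplies the per-step loss
```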
11 | ## Pre-trained models
12 |
13 | All models are trained at 256x256 resolution.
14 |
15 | Here are the models trained on FFHQ, CelebA-HQ, CUB, AFHQ-Dogs, Flowers, and MetFaces: [drive](https://1drv.ms/u/s!AkQjJhxDm0Fyhqp_4gkYjwVRBe8V_w?e=Et3ITH)
16 |
17 | ## Requirements
18 | We tested on PyTorch 1.7.1 with a single RTX 8000 GPU.
19 |
20 | ## Sampling from pre-trained models
21 |
22 | First, set the `PYTHONPATH` variable to point to the root of the repository. Do the same when training new models.
23 |
24 | ```
25 | export PYTHONPATH=$PYTHONPATH:$(pwd)
26 | ```
27 |
28 | Put model checkpoints into a folder `models/`.
29 |
30 | Samples will be saved in `samples/`.
31 |
32 | ```
33 | python scripts/image_sample.py --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 250 --model_path models/ffhq_p2.pt --sample_dir samples
34 | ```
35 |
36 | The command above samples for 250 timesteps without DDIM. To sample with 25 DDIM steps instead, replace `--timestep_respacing 250` with `--timestep_respacing ddim25` and add `--use_ddim True`.
37 |
38 | ## Training your models
39 |
40 | `--p2_gamma` and `--p2_k` are two hyperparameters of P2 weighting. We used `--p2_gamma 0.5 --p2_k 1` and `--p2_gamma 1 --p2_k 1` in the paper.
41 |
42 | Logs and models will be saved in `logs/`. You should modify `--data_dir`.
43 |
44 | We used a lightweight version (93M parameters) of [ADM](https://arxiv.org/abs/2105.05233) (over 500M parameters) as the default model configuration. You may modify the model.
45 |
46 | ```
47 | python scripts/image_train.py --data_dir data/DATASET_NAME --attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --lr 2e-5 --batch_size 8 --rescale_learned_sigmas True --p2_gamma 1 --p2_k 1 --log_dir logs
48 | ```
49 |
50 |
51 |
--------------------------------------------------------------------------------
/P2_weighting/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Downloading datasets
2 |
3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase.
4 |
5 | ## Class-conditional ImageNet
6 |
7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class.
8 |
9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script:
10 |
11 | ```
12 | for file in *.tar; do tar xf "$file"; rm "$file"; done
13 | ```
14 |
15 | This will extract and remove each tar file in turn.
16 |
17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with a WNID (class ID) followed by an underscore, like `n01440764_2708.JPEG`. Conveniently (but not by accident), this is how the automated data loader expects to discover class labels.
18 |
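As a minimal sketch (reusing the filename convention above and mirroring the `class_cond` logic in `guided_diffusion/image_datasets.py`), integer class labels are recovered like this:

```
# Derive integer class labels from WNID-prefixed filenames.
paths = ["n01440764_2708.JPEG", "n01443537_1012.JPEG", "n01440764_9981.JPEG"]
class_names = [p.split("_")[0] for p in paths]
sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
labels = [sorted_classes[x] for x in class_names]  # -> [0, 1, 0]
```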
19 | ## LSUN bedroom
20 |
21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so:
22 |
23 | ```
24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir
25 | ```
26 |
27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument.
28 |
--------------------------------------------------------------------------------
/P2_weighting/datasets/lsun_bedroom.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert an LSUN lmdb database into a directory of images.
3 | """
4 |
5 | import argparse
6 | import io
7 | import os
8 |
9 | from PIL import Image
10 | import lmdb
11 | import numpy as np
12 |
13 |
14 | def read_images(lmdb_path, image_size):
15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True)
16 | with env.begin(write=False) as transaction:
17 | cursor = transaction.cursor()
18 | for _, webp_data in cursor:
19 | img = Image.open(io.BytesIO(webp_data))
20 | width, height = img.size
21 | scale = image_size / min(width, height)
22 | img = img.resize(
23 | (int(round(scale * width)), int(round(scale * height))),
24 | resample=Image.BOX,
25 | )
26 | arr = np.array(img)
27 | h, w, _ = arr.shape
28 | h_off = (h - image_size) // 2
29 | w_off = (w - image_size) // 2
30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size]
31 | yield arr
32 |
33 |
34 | def dump_images(out_dir, images, prefix):
35 | if not os.path.exists(out_dir):
36 | os.mkdir(out_dir)
37 | for i, img in enumerate(images):
38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png"))
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--image-size", help="new image size", type=int, default=256)
44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom")
45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database")
46 | parser.add_argument("out_dir", help="path to output directory")
47 | args = parser.parse_args()
48 |
49 | images = read_images(args.lmdb_path, args.image_size)
50 | dump_images(args.out_dir, images, args.prefix)
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
/P2_weighting/evaluations/README.md:
--------------------------------------------------------------------------------
1 | # Evaluations
2 |
3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
4 |
5 | # Download batches
6 |
7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
8 |
9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
10 |
11 | Here are links to download all of the sample and reference batches:
12 |
13 | * LSUN
14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
25 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
26 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
27 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
28 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
29 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
30 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
31 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
32 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
33 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
34 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
35 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
36 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
37 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
38 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
39 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
40 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
41 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
42 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
43 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
44 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
45 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
46 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
47 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
48 |
49 | # Run evaluations
50 |
51 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
52 |
53 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
54 |
55 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
56 |
57 | ```
58 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
59 | ...
60 | computing reference batch activations...
61 | computing/reading reference batch statistics...
62 | computing sample batch activations...
63 | computing/reading sample batch statistics...
64 | Computing evaluations...
65 | Inception Score: 215.8370361328125
66 | FID: 3.9425574129223264
67 | sFID: 6.140433703346162
68 | Precision: 0.8265
69 | Recall: 0.5309
70 | ```
71 |
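To inspect a batch before running the evaluator, here is a hedged sketch (the `arr_0` key is an assumption based on how the sampling scripts save batches with `np.savez`):

```
import numpy as np

# Sample batches store images as uint8 arrays of shape [N, H, W, C].
with np.load("admnet_guided_upsampled_imagenet256.npz") as batch:
    images = batch["arr_0"]
    print(images.shape, images.dtype)  # e.g. (50000, 256, 256, 3) uint8
```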
--------------------------------------------------------------------------------
/P2_weighting/evaluations/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu>=2.0
2 | scipy
3 | requests
4 | tqdm
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "Improved Denoising Diffusion Probabilistic Models".
3 | """
4 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | # Change this to reflect your cluster layout.
15 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
16 | GPUS_PER_NODE = 8
17 |
18 | SETUP_RETRY_COUNT = 3
19 |
20 |
21 | def setup_dist():
22 | """
23 | Setup a distributed process group.
24 | """
25 | if dist.is_initialized():
26 | return
27 | # os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
28 |
29 | comm = MPI.COMM_WORLD
30 | backend = "gloo" if not th.cuda.is_available() else "nccl"
31 |
32 | if backend == "gloo":
33 | hostname = "localhost"
34 | else:
35 | hostname = socket.gethostbyname(socket.getfqdn())
36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
37 | os.environ["RANK"] = str(comm.rank)
38 | os.environ["WORLD_SIZE"] = str(comm.size)
39 |
40 | port = comm.bcast(_find_free_port(), root=0)
41 | os.environ["MASTER_PORT"] = str(port)
42 | dist.init_process_group(backend=backend, init_method="env://")
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 |         return th.device("cuda")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file without redundant fetches across MPI ranks.
57 | """
58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
59 | if MPI.COMM_WORLD.Get_rank() == 0:
60 | with bf.BlobFile(path, "rb") as f:
61 | data = f.read()
62 | num_chunks = len(data) // chunk_size
63 | if len(data) % chunk_size:
64 | num_chunks += 1
65 | MPI.COMM_WORLD.bcast(num_chunks)
66 | for i in range(0, len(data), chunk_size):
67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
68 | else:
69 | num_chunks = MPI.COMM_WORLD.bcast(None)
70 | data = bytes()
71 | for _ in range(num_chunks):
72 | data += MPI.COMM_WORLD.bcast(None)
73 |
74 | return th.load(io.BytesIO(data), **kwargs)
75 |
76 |
77 | def sync_params(params):
78 | """
79 | Synchronize a sequence of Tensors across ranks from rank 0.
80 | """
81 | for p in params:
82 | with th.no_grad():
83 | dist.broadcast(p, 0)
84 |
85 |
86 | def _find_free_port():
87 |     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
88 |     try:
89 |         s.bind(("", 0))
90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
91 | return s.getsockname()[1]
92 | finally:
93 | s.close()
94 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
203 | opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | from mpi4py import MPI
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | def load_data(
12 | *,
13 | data_dir,
14 | batch_size,
15 | image_size,
16 | class_cond=False,
17 | deterministic=False,
18 | random_crop=False,
19 | random_flip=True,
20 | ):
21 | """
22 | For a dataset, create a generator over (images, kwargs) pairs.
23 |
24 | The images are NCHW float tensors, and the kwargs dict contains zero or
25 | more keys, each of which maps to a batched Tensor of its own.
26 | The kwargs dict can be used for class labels, in which case the key is "y"
27 | and the values are integer tensors of class labels.
28 |
29 | :param data_dir: a dataset directory.
30 | :param batch_size: the batch size of each returned pair.
31 | :param image_size: the size to which images are resized.
32 | :param class_cond: if True, include a "y" key in returned dicts for class
33 | label. If classes are not available and this is true, an
34 | exception will be raised.
35 | :param deterministic: if True, yield results in a deterministic order.
36 | :param random_crop: if True, randomly crop the images for augmentation.
37 | :param random_flip: if True, randomly flip the images for augmentation.
38 | """
39 | if not data_dir:
40 | raise ValueError("unspecified data directory")
41 | all_files = _list_image_files_recursively(data_dir)
42 | classes = None
43 | if class_cond:
44 | # Assume classes are the first part of the filename,
45 | # before an underscore.
46 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
48 | classes = [sorted_classes[x] for x in class_names]
49 | dataset = ImageDataset(
50 | image_size,
51 | all_files,
52 | classes=classes,
53 | shard=MPI.COMM_WORLD.Get_rank(),
54 | num_shards=MPI.COMM_WORLD.Get_size(),
55 | random_crop=random_crop,
56 | random_flip=random_flip,
57 | )
58 | if deterministic:
59 | loader = DataLoader(
60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
61 | )
62 | else:
63 | loader = DataLoader(
64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
65 | )
66 | while True:
67 | yield from loader
68 |
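# Hypothetical usage of the infinite generator above:
#   data = load_data(data_dir="data/ffhq", batch_size=8, image_size=256)
#   batch, cond = next(data)  # batch: float32 tensor [8, 3, 256, 256] in [-1, 1]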
69 |
70 | def _list_image_files_recursively(data_dir):
71 | results = []
72 | for entry in sorted(bf.listdir(data_dir)):
73 | full_path = bf.join(data_dir, entry)
74 | ext = entry.split(".")[-1]
75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
76 | results.append(full_path)
77 | elif bf.isdir(full_path):
78 | results.extend(_list_image_files_recursively(full_path))
79 | return results
80 |
81 |
82 | class ImageDataset(Dataset):
83 | def __init__(
84 | self,
85 | resolution,
86 | image_paths,
87 | classes=None,
88 | shard=0,
89 | num_shards=1,
90 | random_crop=False,
91 | random_flip=True,
92 | ):
93 | super().__init__()
94 | self.resolution = resolution
95 | self.local_images = image_paths[shard:][::num_shards]
96 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
97 | self.random_crop = random_crop
98 | self.random_flip = random_flip
99 |
100 | def __len__(self):
101 | return len(self.local_images)
102 |
103 | def __getitem__(self, idx):
104 | path = self.local_images[idx]
105 | with bf.BlobFile(path, "rb") as f:
106 | pil_image = Image.open(f)
107 | pil_image.load()
108 | pil_image = pil_image.convert("RGB")
109 |
110 | if self.random_crop:
111 | arr = random_crop_arr(pil_image, self.resolution)
112 | else:
113 | arr = center_crop_arr(pil_image, self.resolution)
114 |
115 | if self.random_flip and random.random() < 0.5:
116 | arr = arr[:, ::-1]
117 |
118 | arr = arr.astype(np.float32) / 127.5 - 1
119 |
120 | out_dict = {}
121 | if self.local_classes is not None:
122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
123 | return np.transpose(arr, [2, 0, 1]), out_dict
124 |
125 |
126 | def center_crop_arr(pil_image, image_size):
127 | # We are not on a new enough PIL to support the `reducing_gap`
128 | # argument, which uses BOX downsampling at powers of two first.
129 | # Thus, we do it by hand to improve downsample quality.
130 | while min(*pil_image.size) >= 2 * image_size:
131 | pil_image = pil_image.resize(
132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
133 | )
134 |
135 | scale = image_size / min(*pil_image.size)
136 | pil_image = pil_image.resize(
137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
138 | )
139 |
140 | arr = np.array(pil_image)
141 | crop_y = (arr.shape[0] - image_size) // 2
142 | crop_x = (arr.shape[1] - image_size) // 2
143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
144 |
145 |
146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
150 |
151 | # We are not on a new enough PIL to support the `reducing_gap`
152 | # argument, which uses BOX downsampling at powers of two first.
153 | # Thus, we do it by hand to improve downsample quality.
154 | while min(*pil_image.size) >= 2 * smaller_dim_size:
155 | pil_image = pil_image.resize(
156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
157 | )
158 |
159 | scale = smaller_dim_size / min(*pil_image.size)
160 | pil_image = pil_image.resize(
161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
162 | )
163 |
164 | arr = np.array(pil_image)
165 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
166 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
168 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that this was uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
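# Hypothetical usage of timestep_embedding:
#   ts = th.tensor([0, 250, 999]); emb = timestep_embedding(ts, 128)
#   emb.shape == (3, 128); the first 64 dims are cos terms, the last 64 sin.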
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
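# Illustrative note (not in the original file): for UniformSampler below,
# p is constant at 1 / num_timesteps, so weights_np == 1 everywhere and the
# importance-sampled loss reduces to a plain mean over the batch.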
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
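# Hypothetical usage, checking the docstring's example above: 300 original
# steps with section counts [10, 15, 20] yield 45 retained steps, all < 300.
#
#   steps = space_timesteps(300, [10, 15, 20])
#   assert len(steps) == 45 and max(steps) < 300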
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
/P2_weighting/guided_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 |
16 | # For ImageNet experiments, this was a good default value.
17 | # We found that the lg_loss_scale quickly climbed to
18 | # 20-21 within the first ~1K steps of training.
19 | INITIAL_LOG_LOSS_SCALE = 20.0
20 |
21 |
22 | class TrainLoop:
23 | def __init__(
24 | self,
25 | *,
26 | model,
27 | diffusion,
28 | data,
29 | batch_size,
30 | microbatch,
31 | lr,
32 | ema_rate,
33 | log_interval,
34 | save_interval,
35 | resume_checkpoint,
36 | use_fp16=False,
37 | fp16_scale_growth=1e-3,
38 | schedule_sampler=None,
39 | weight_decay=0.0,
40 | lr_anneal_steps=0,
41 | ):
42 | self.model = model
43 | self.diffusion = diffusion
44 | self.data = data
45 | self.batch_size = batch_size
46 | self.microbatch = microbatch if microbatch > 0 else batch_size
47 | self.lr = lr
48 | self.ema_rate = (
49 | [ema_rate]
50 | if isinstance(ema_rate, float)
51 | else [float(x) for x in ema_rate.split(",")]
52 | )
53 | self.log_interval = log_interval
54 | self.save_interval = save_interval
55 | self.resume_checkpoint = resume_checkpoint
56 | self.use_fp16 = use_fp16
57 | self.fp16_scale_growth = fp16_scale_growth
58 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
59 | self.weight_decay = weight_decay
60 | self.lr_anneal_steps = lr_anneal_steps
61 |
62 | self.step = 0
63 | self.resume_step = 0
64 | self.global_batch = self.batch_size * dist.get_world_size()
65 |
66 | self.sync_cuda = th.cuda.is_available()
67 |
68 | self._load_and_sync_parameters()
69 | self.mp_trainer = MixedPrecisionTrainer(
70 | model=self.model,
71 | use_fp16=self.use_fp16,
72 | fp16_scale_growth=fp16_scale_growth,
73 | )
74 |
75 | self.opt = AdamW(
76 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
77 | )
78 | if self.resume_step:
79 | self._load_optimizer_state()
80 | # Model was resumed, either due to a restart or a checkpoint
81 | # being specified at the command line.
82 | self.ema_params = [
83 | self._load_ema_parameters(rate) for rate in self.ema_rate
84 | ]
85 | else:
86 | self.ema_params = [
87 | copy.deepcopy(self.mp_trainer.master_params)
88 | for _ in range(len(self.ema_rate))
89 | ]
90 |
91 | if th.cuda.is_available():
92 | self.use_ddp = True
93 | self.ddp_model = DDP(
94 | self.model,
95 | device_ids=[dist_util.dev()],
96 | output_device=dist_util.dev(),
97 | broadcast_buffers=False,
98 | bucket_cap_mb=128,
99 | find_unused_parameters=False,
100 | )
101 | else:
102 | if dist.get_world_size() > 1:
103 | logger.warn(
104 | "Distributed training requires CUDA. "
105 | "Gradients will not be synchronized properly!"
106 | )
107 | self.use_ddp = False
108 | self.ddp_model = self.model
109 |
110 | def _load_and_sync_parameters(self):
111 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
112 |
113 | if resume_checkpoint:
114 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
115 | if dist.get_rank() == 0:
116 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
117 | self.model.load_state_dict(
118 | dist_util.load_state_dict(
119 | resume_checkpoint, map_location=dist_util.dev()
120 | )
121 | )
122 |
123 | dist_util.sync_params(self.model.parameters())
124 |
125 | def _load_ema_parameters(self, rate):
126 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
127 |
128 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
129 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
130 | if ema_checkpoint:
131 | if dist.get_rank() == 0:
132 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
133 | state_dict = dist_util.load_state_dict(
134 | ema_checkpoint, map_location=dist_util.dev()
135 | )
136 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
137 |
138 | dist_util.sync_params(ema_params)
139 | return ema_params
140 |
141 | def _load_optimizer_state(self):
142 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
143 | opt_checkpoint = bf.join(
144 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
145 | )
146 | if bf.exists(opt_checkpoint):
147 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
148 | state_dict = dist_util.load_state_dict(
149 | opt_checkpoint, map_location=dist_util.dev()
150 | )
151 | self.opt.load_state_dict(state_dict)
152 |
153 | def run_loop(self):
154 | while (
155 | not self.lr_anneal_steps
156 | or self.step + self.resume_step < self.lr_anneal_steps
157 | ):
158 | batch, cond = next(self.data)
159 | self.run_step(batch, cond)
160 | if self.step % self.log_interval == 0:
161 | logger.dumpkvs()
162 | if self.step % self.save_interval == 0:
163 | self.save()
164 | # Run for a finite amount of time in integration tests.
165 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
166 | return
167 | self.step += 1
168 | # Save the last checkpoint if it wasn't already saved.
169 | if (self.step - 1) % self.save_interval != 0:
170 | self.save()
171 |
172 | def run_step(self, batch, cond):
173 | self.forward_backward(batch, cond)
174 | took_step = self.mp_trainer.optimize(self.opt)
175 | if took_step:
176 | self._update_ema()
177 | self._anneal_lr()
178 | self.log_step()
179 |
180 | def forward_backward(self, batch, cond):
181 | self.mp_trainer.zero_grad()
182 | for i in range(0, batch.shape[0], self.microbatch):
183 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
184 | micro_cond = {
185 | k: v[i : i + self.microbatch].to(dist_util.dev())
186 | for k, v in cond.items()
187 | }
188 | last_batch = (i + self.microbatch) >= batch.shape[0]
189 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
190 |
191 | compute_losses = functools.partial(
192 | self.diffusion.training_losses,
193 | self.ddp_model,
194 | micro,
195 | t,
196 | model_kwargs=micro_cond,
197 | )
198 |
199 | if last_batch or not self.use_ddp:
200 | losses = compute_losses()
201 | else:
202 | with self.ddp_model.no_sync():
203 | losses = compute_losses()
204 |
205 | if isinstance(self.schedule_sampler, LossAwareSampler):
206 | self.schedule_sampler.update_with_local_losses(
207 | t, losses["loss"].detach()
208 | )
209 |
210 | loss = (losses["loss"] * weights).mean()
211 | log_loss_dict(
212 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
213 | )
214 | self.mp_trainer.backward(loss)
215 |
216 | def _update_ema(self):
217 | for rate, params in zip(self.ema_rate, self.ema_params):
218 | update_ema(params, self.mp_trainer.master_params, rate=rate)
219 |
220 | def _anneal_lr(self):
221 | if not self.lr_anneal_steps:
222 | return
223 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
224 | lr = self.lr * (1 - frac_done)
225 | for param_group in self.opt.param_groups:
226 | param_group["lr"] = lr
227 |
228 | def log_step(self):
229 | logger.logkv("step", self.step + self.resume_step)
230 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
231 |
232 | def save(self):
233 | def save_checkpoint(rate, params):
234 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
235 | if dist.get_rank() == 0:
236 | logger.log(f"saving model {rate}...")
237 | if not rate:
238 | filename = f"model{(self.step+self.resume_step):06d}.pt"
239 | else:
240 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
241 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
242 | th.save(state_dict, f)
243 |
244 | save_checkpoint(0, self.mp_trainer.master_params)
245 | for rate, params in zip(self.ema_rate, self.ema_params):
246 | save_checkpoint(rate, params)
247 |
248 | if dist.get_rank() == 0:
249 | with bf.BlobFile(
250 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
251 | "wb",
252 | ) as f:
253 | th.save(self.opt.state_dict(), f)
254 |
255 | dist.barrier()
256 |
257 |
258 | def parse_resume_step_from_filename(filename):
259 | """
260 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
261 | checkpoint's number of steps.
262 | """
263 | split = filename.split("model")
264 | if len(split) < 2:
265 | return 0
266 | split1 = split[-1].split(".")[0]
267 | try:
268 | return int(split1)
269 | except ValueError:
270 | return 0
271 |
272 |
273 | def get_blob_logdir():
274 | # You can change this to be a separate path to save checkpoints to
275 | # a blobstore or some external drive.
276 | return logger.get_dir()
277 |
278 |
279 | def find_resume_checkpoint():
280 | # On your infrastructure, you may want to override this to automatically
281 | # discover the latest checkpoint on your blob storage, etc.
282 | return None
283 |
284 |
285 | def find_ema_checkpoint(main_checkpoint, step, rate):
286 | if main_checkpoint is None:
287 | return None
288 | filename = f"ema_{rate}_{(step):06d}.pt"
289 | path = bf.join(bf.dirname(main_checkpoint), filename)
290 | if bf.exists(path):
291 | return path
292 | return None
293 |
294 |
295 | def log_loss_dict(diffusion, ts, losses):
296 | for key, values in losses.items():
297 | logger.logkv_mean(key, values.mean().item())
298 | # Log the quantiles (four quartiles, in particular).
299 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
300 | quartile = int(4 * sub_t / diffusion.num_timesteps)
301 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
302 |
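303 | # Example (illustrative values): with diffusion.num_timesteps = 1000, a loss
304 | # computed at t = 640 falls into quartile int(4 * 640 / 1000) = 2 and is
305 | # averaged into the running "loss_q2" statistic, making it easy to compare
306 | # training quality on early (q0) versus late (q3) timesteps.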
--------------------------------------------------------------------------------
/P2_weighting/scripts/classifier_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Like image_sample.py, but use a noisy image classifier to guide the sampling
3 | process towards more realistic images.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 | import torch.nn.functional as F
13 |
14 | from guided_diffusion import dist_util, logger
15 | from guided_diffusion.script_util import (
16 | NUM_CLASSES,
17 | model_and_diffusion_defaults,
18 | classifier_defaults,
19 | create_model_and_diffusion,
20 | create_classifier,
21 | add_dict_to_argparser,
22 | args_to_dict,
23 | )
24 |
25 |
26 | def main():
27 | args = create_argparser().parse_args()
28 |
29 | dist_util.setup_dist()
30 | logger.configure()
31 |
32 | logger.log("creating model and diffusion...")
33 | model, diffusion = create_model_and_diffusion(
34 | **args_to_dict(args, model_and_diffusion_defaults().keys())
35 | )
36 | model.load_state_dict(
37 | dist_util.load_state_dict(args.model_path, map_location="cpu")
38 | )
39 | model.to(dist_util.dev())
40 | if args.use_fp16:
41 | model.convert_to_fp16()
42 | model.eval()
43 |
44 | logger.log("loading classifier...")
45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
46 | classifier.load_state_dict(
47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu")
48 | )
49 | classifier.to(dist_util.dev())
50 | if args.classifier_use_fp16:
51 | classifier.convert_to_fp16()
52 | classifier.eval()
53 |
54 | def cond_fn(x, t, y=None):
55 | assert y is not None
56 | with th.enable_grad():
57 | x_in = x.detach().requires_grad_(True)
58 | logits = classifier(x_in, t)
59 | log_probs = F.log_softmax(logits, dim=-1)
60 | selected = log_probs[range(len(logits)), y.view(-1)]
61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
62 |
63 | def model_fn(x, t, y=None):
64 | assert y is not None
65 | return model(x, t, y if args.class_cond else None)
66 |
67 | logger.log("sampling...")
68 | all_images = []
69 | all_labels = []
70 | while len(all_images) * args.batch_size < args.num_samples:
71 | model_kwargs = {}
72 | classes = th.randint(
73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
74 | )
75 | model_kwargs["y"] = classes
76 | sample_fn = (
77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
78 | )
79 | sample = sample_fn(
80 | model_fn,
81 | (args.batch_size, 3, args.image_size, args.image_size),
82 | clip_denoised=args.clip_denoised,
83 | model_kwargs=model_kwargs,
84 | cond_fn=cond_fn,
85 | device=dist_util.dev(),
86 | )
87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
88 | sample = sample.permute(0, 2, 3, 1)
89 | sample = sample.contiguous()
90 |
91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
95 | dist.all_gather(gathered_labels, classes)
96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
97 | logger.log(f"created {len(all_images) * args.batch_size} samples")
98 |
99 | arr = np.concatenate(all_images, axis=0)
100 | arr = arr[: args.num_samples]
101 | label_arr = np.concatenate(all_labels, axis=0)
102 | label_arr = label_arr[: args.num_samples]
103 | if dist.get_rank() == 0:
104 | shape_str = "x".join([str(x) for x in arr.shape])
105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
106 | logger.log(f"saving to {out_path}")
107 | np.savez(out_path, arr, label_arr)
108 |
109 | dist.barrier()
110 | logger.log("sampling complete")
111 |
112 |
113 | def create_argparser():
114 | defaults = dict(
115 | clip_denoised=True,
116 | num_samples=10000,
117 | batch_size=16,
118 | use_ddim=False,
119 | model_path="",
120 | classifier_path="",
121 | classifier_scale=1.0,
122 | )
123 | defaults.update(model_and_diffusion_defaults())
124 | defaults.update(classifier_defaults())
125 | parser = argparse.ArgumentParser()
126 | add_dict_to_argparser(parser, defaults)
127 | return parser
128 |
129 |
130 | if __name__ == "__main__":
131 | main()
132 |
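133 | # Note on cond_fn in main() above: it returns
134 | # classifier_scale * grad_x log p(y | x_t), the gradient of the classifier's
135 | # log-probability of the target class with respect to the noisy input. The
136 | # sampler uses this gradient to bias each denoising step toward class y, and
137 | # larger --classifier_scale values trade sample diversity for class fidelity.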
--------------------------------------------------------------------------------
/P2_weighting/scripts/classifier_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a noised image classifier on ImageNet.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import blobfile as bf
9 | import torch as th
10 | import torch.distributed as dist
11 | import torch.nn.functional as F
12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
13 | from torch.optim import AdamW
14 |
15 | from guided_diffusion import dist_util, logger
16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer
17 | from guided_diffusion.image_datasets import load_data
18 | from guided_diffusion.resample import create_named_schedule_sampler
19 | from guided_diffusion.script_util import (
20 | add_dict_to_argparser,
21 | args_to_dict,
22 | classifier_and_diffusion_defaults,
23 | create_classifier_and_diffusion,
24 | )
25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict
26 |
27 |
28 | def main():
29 | args = create_argparser().parse_args()
30 |
31 | dist_util.setup_dist()
32 | logger.configure()
33 |
34 | logger.log("creating model and diffusion...")
35 | model, diffusion = create_classifier_and_diffusion(
36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys())
37 | )
38 | model.to(dist_util.dev())
39 | if args.noised:
40 | schedule_sampler = create_named_schedule_sampler(
41 | args.schedule_sampler, diffusion
42 | )
43 |
44 | resume_step = 0
45 | if args.resume_checkpoint:
46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
47 | if dist.get_rank() == 0:
48 | logger.log(
49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
50 | )
51 | model.load_state_dict(
52 | dist_util.load_state_dict(
53 | args.resume_checkpoint, map_location=dist_util.dev()
54 | )
55 | )
56 |
57 | # Needed for creating correct EMAs and fp16 parameters.
58 | dist_util.sync_params(model.parameters())
59 |
60 | mp_trainer = MixedPrecisionTrainer(
61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0
62 | )
63 |
64 | model = DDP(
65 | model,
66 | device_ids=[dist_util.dev()],
67 | output_device=dist_util.dev(),
68 | broadcast_buffers=False,
69 | bucket_cap_mb=128,
70 | find_unused_parameters=False,
71 | )
72 |
73 | logger.log("creating data loader...")
74 | data = load_data(
75 | data_dir=args.data_dir,
76 | batch_size=args.batch_size,
77 | image_size=args.image_size,
78 | class_cond=True,
79 | random_crop=True,
80 | )
81 | if args.val_data_dir:
82 | val_data = load_data(
83 | data_dir=args.val_data_dir,
84 | batch_size=args.batch_size,
85 | image_size=args.image_size,
86 | class_cond=True,
87 | )
88 | else:
89 | val_data = None
90 |
91 | logger.log(f"creating optimizer...")
92 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay)
93 | if args.resume_checkpoint:
94 | opt_checkpoint = bf.join(
95 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
96 | )
97 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
98 | opt.load_state_dict(
99 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
100 | )
101 |
102 | logger.log("training classifier model...")
103 |
104 | def forward_backward_log(data_loader, prefix="train"):
105 | batch, extra = next(data_loader)
106 | labels = extra["y"].to(dist_util.dev())
107 |
108 | batch = batch.to(dist_util.dev())
109 | # Noisy images
110 | if args.noised:
111 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
112 | batch = diffusion.q_sample(batch, t)
113 | else:
114 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())
115 |
116 | for i, (sub_batch, sub_labels, sub_t) in enumerate(
117 | split_microbatches(args.microbatch, batch, labels, t)
118 | ):
119 | logits = model(sub_batch, timesteps=sub_t)
120 | loss = F.cross_entropy(logits, sub_labels, reduction="none")
121 |
122 | losses = {}
123 | losses[f"{prefix}_loss"] = loss.detach()
124 | losses[f"{prefix}_acc@1"] = compute_top_k(
125 | logits, sub_labels, k=1, reduction="none"
126 | )
127 | losses[f"{prefix}_acc@5"] = compute_top_k(
128 | logits, sub_labels, k=5, reduction="none"
129 | )
130 | log_loss_dict(diffusion, sub_t, losses)
131 | del losses
132 | loss = loss.mean()
133 | if loss.requires_grad:
134 | if i == 0:
135 | mp_trainer.zero_grad()
136 | mp_trainer.backward(loss * len(sub_batch) / len(batch))
137 |
138 | for step in range(args.iterations - resume_step):
139 | logger.logkv("step", step + resume_step)
140 | logger.logkv(
141 | "samples",
142 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
143 | )
144 | if args.anneal_lr:
145 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations)
146 | forward_backward_log(data)
147 | mp_trainer.optimize(opt)
148 | if val_data is not None and not step % args.eval_interval:
149 | with th.no_grad():
150 | with model.no_sync():
151 | model.eval()
152 | forward_backward_log(val_data, prefix="val")
153 | model.train()
154 | if not step % args.log_interval:
155 | logger.dumpkvs()
156 | if (
157 | step
158 | and dist.get_rank() == 0
159 | and not (step + resume_step) % args.save_interval
160 | ):
161 | logger.log("saving model...")
162 | save_model(mp_trainer, opt, step + resume_step)
163 |
164 | if dist.get_rank() == 0:
165 | logger.log("saving model...")
166 | save_model(mp_trainer, opt, step + resume_step)
167 | dist.barrier()
168 |
169 |
170 | def set_annealed_lr(opt, base_lr, frac_done):
171 | lr = base_lr * (1 - frac_done)
172 | for param_group in opt.param_groups:
173 | param_group["lr"] = lr
174 |
175 |
176 | def save_model(mp_trainer, opt, step):
177 | if dist.get_rank() == 0:
178 | th.save(
179 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params),
180 | os.path.join(logger.get_dir(), f"model{step:06d}.pt"),
181 | )
182 | th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt"))
183 |
184 |
185 | def compute_top_k(logits, labels, k, reduction="mean"):
186 | _, top_ks = th.topk(logits, k, dim=-1)
187 | if reduction == "mean":
188 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
189 | elif reduction == "none":
190 | return (top_ks == labels[:, None]).float().sum(dim=-1)
191 |
192 |
193 | def split_microbatches(microbatch, *args):
194 | bs = len(args[0])
195 | if microbatch == -1 or microbatch >= bs:
196 | yield tuple(args)
197 | else:
198 | for i in range(0, bs, microbatch):
199 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args)
200 |
201 |
202 | def create_argparser():
203 | defaults = dict(
204 | data_dir="",
205 | val_data_dir="",
206 | noised=True,
207 | iterations=150000,
208 | lr=3e-4,
209 | weight_decay=0.0,
210 | anneal_lr=False,
211 | batch_size=4,
212 | microbatch=-1,
213 | schedule_sampler="uniform",
214 | resume_checkpoint="",
215 | log_interval=10,
216 | eval_interval=5,
217 | save_interval=10000,
218 | )
219 | defaults.update(classifier_and_diffusion_defaults())
220 | parser = argparse.ArgumentParser()
221 | add_dict_to_argparser(parser, defaults)
222 | return parser
223 |
224 |
225 | if __name__ == "__main__":
226 | main()
227 |
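228 | # Example (illustrative values): split_microbatches(2, batch, labels, t) on a
229 | # batch of 5 yields slices [0:2], [2:4], [4:5]; forward_backward_log scales
230 | # each microbatch loss by len(sub_batch) / len(batch), so the accumulated
231 | # gradient matches a single full-batch backward pass.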
--------------------------------------------------------------------------------
/P2_weighting/scripts/image_nll.py:
--------------------------------------------------------------------------------
1 | """
2 | Approximate the bits/dimension for an image model.
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | import numpy as np
9 | import torch.distributed as dist
10 |
11 | from guided_diffusion import dist_util, logger
12 | from guided_diffusion.image_datasets import load_data
13 | from guided_diffusion.script_util import (
14 | model_and_diffusion_defaults,
15 | create_model_and_diffusion,
16 | add_dict_to_argparser,
17 | args_to_dict,
18 | )
19 |
20 |
21 | def main():
22 | args = create_argparser().parse_args()
23 |
24 | dist_util.setup_dist()
25 | logger.configure()
26 |
27 | logger.log("creating model and diffusion...")
28 | model, diffusion = create_model_and_diffusion(
29 | **args_to_dict(args, model_and_diffusion_defaults().keys())
30 | )
31 | model.load_state_dict(
32 | dist_util.load_state_dict(args.model_path, map_location="cpu")
33 | )
34 | model.to(dist_util.dev())
35 | model.eval()
36 |
37 | logger.log("creating data loader...")
38 | data = load_data(
39 | data_dir=args.data_dir,
40 | batch_size=args.batch_size,
41 | image_size=args.image_size,
42 | class_cond=args.class_cond,
43 | deterministic=True,
44 | )
45 |
46 | logger.log("evaluating...")
47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised)
48 |
49 |
50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised):
51 | all_bpd = []
52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []}
53 | num_complete = 0
54 | while num_complete < num_samples:
55 | batch, model_kwargs = next(data)
56 | batch = batch.to(dist_util.dev())
57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
58 | minibatch_metrics = diffusion.calc_bpd_loop(
59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs
60 | )
61 |
62 | for key, term_list in all_metrics.items():
63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size()
64 | dist.all_reduce(terms)
65 | term_list.append(terms.detach().cpu().numpy())
66 |
67 | total_bpd = minibatch_metrics["total_bpd"]
68 | total_bpd = total_bpd.mean() / dist.get_world_size()
69 | dist.all_reduce(total_bpd)
70 | all_bpd.append(total_bpd.item())
71 | num_complete += dist.get_world_size() * batch.shape[0]
72 |
73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}")
74 |
75 | if dist.get_rank() == 0:
76 | for name, terms in all_metrics.items():
77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz")
78 | logger.log(f"saving {name} terms to {out_path}")
79 | np.savez(out_path, np.mean(np.stack(terms), axis=0))
80 |
81 | dist.barrier()
82 | logger.log("evaluation complete")
83 |
84 |
85 | def create_argparser():
86 | defaults = dict(
87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path=""
88 | )
89 | defaults.update(model_and_diffusion_defaults())
90 | parser = argparse.ArgumentParser()
91 | add_dict_to_argparser(parser, defaults)
92 | return parser
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
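98 | # Note on the reduction above: each rank divides its per-batch mean by the
99 | # world size before all_reduce (a sum), so the reduced value is the mean over
100 | # ranks. Example (illustrative values): 4 ranks reporting 3.0, 3.2, 2.8, and
101 | # 3.0 bits/dim all append (3.0 + 3.2 + 2.8 + 3.0) / 4 = 3.0 to all_bpd.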
--------------------------------------------------------------------------------
/P2_weighting/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 |
13 | from guided_diffusion import dist_util, logger
14 | from guided_diffusion.script_util import (
15 | NUM_CLASSES,
16 | model_and_diffusion_defaults,
17 | create_model_and_diffusion,
18 | add_dict_to_argparser,
19 | args_to_dict,
20 | )
21 | from torchvision import utils
22 |
23 |
24 | def main():
25 | args = create_argparser().parse_args()
26 |
27 | dist_util.setup_dist()
28 | logger.configure(dir=args.sample_dir)
29 |
30 | logger.log("creating model and diffusion...")
31 | model, diffusion = create_model_and_diffusion(
32 | **args_to_dict(args, model_and_diffusion_defaults().keys())
33 | )
34 | model.load_state_dict(
35 | dist_util.load_state_dict(args.model_path, map_location="cpu")
36 | )
37 | model.to(dist_util.dev())
38 | if args.use_fp16:
39 | model.convert_to_fp16()
40 | model.eval()
41 |
42 | logger.log("sampling...")
43 | all_images = []
44 | all_labels = []
45 | count = 0
46 | while count * args.batch_size < args.num_samples:
47 | model_kwargs = {}
48 | if args.class_cond:
49 | classes = th.randint(
50 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
51 | )
52 | model_kwargs["y"] = classes
53 | sample_fn = (
54 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
55 | )
56 | sample = sample_fn(
57 | model,
58 | (args.batch_size, 3, args.image_size, args.image_size),
59 | clip_denoised=args.clip_denoised,
60 | model_kwargs=model_kwargs,
61 | )
62 | # saving png
63 | for i in range(args.batch_size):
64 | out_path = os.path.join(logger.get_dir(),
65 | f"{str(count * args.batch_size + i).zfill(5)}.png")
66 | utils.save_image(
67 | sample[i].unsqueeze(0),
68 | out_path,
69 | nrow=1,
70 | normalize=True,
71 | range=(-1, 1),
72 | )
73 | # saving npz
74 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
75 | sample = sample.permute(0, 2, 3, 1)
76 | sample = sample.contiguous()
77 |
78 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
79 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
80 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
81 | if args.class_cond:
82 | gathered_labels = [
83 | th.zeros_like(classes) for _ in range(dist.get_world_size())
84 | ]
85 | dist.all_gather(gathered_labels, classes)
86 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
87 | logger.log(f"created {len(all_images) * args.batch_size} samples")
88 |
89 | arr = np.concatenate(all_images, axis=0)
90 | arr = arr[: args.num_samples]
91 | if args.class_cond:
92 | label_arr = np.concatenate(all_labels, axis=0)
93 | label_arr = label_arr[: args.num_samples]
94 | if dist.get_rank() == 0:
95 | shape_str = "x".join([str(x) for x in arr.shape])
96 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
97 | logger.log(f"saving to {out_path}")
98 | if args.class_cond:
99 | np.savez(out_path, arr, label_arr)
100 | else:
101 | np.savez(out_path, arr)
102 |
103 | dist.barrier()
104 | logger.log("sampling complete")
105 |
106 |
107 | def create_argparser():
108 | defaults = dict(
109 | clip_denoised=True,
110 | num_samples=10000,
111 | batch_size=16,
112 | use_ddim=False,
113 | model_path="",
114 | sample_dir="",
115 | )
116 | defaults.update(model_and_diffusion_defaults())
117 | parser = argparse.ArgumentParser()
118 | add_dict_to_argparser(parser, defaults)
119 | return parser
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
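125 | # Note on the uint8 conversion above: samples are produced in [-1, 1], and
126 | # ((sample + 1) * 127.5).clamp(0, 255) maps -1 -> 0 and 1 -> 255 before the
127 | # NCHW -> NHWC permute, so the gathered arrays can be written directly to the
128 | # .npz file used for FID evaluation.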
--------------------------------------------------------------------------------
/P2_weighting/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import argparse
6 |
7 | from guided_diffusion import dist_util, logger
8 | from guided_diffusion.image_datasets import load_data
9 | from guided_diffusion.resample import create_named_schedule_sampler
10 | from guided_diffusion.script_util import (
11 | model_and_diffusion_defaults,
12 | create_model_and_diffusion,
13 | args_to_dict,
14 | add_dict_to_argparser,
15 | )
16 | from guided_diffusion.train_util import TrainLoop
17 |
18 |
19 | def main():
20 | args = create_argparser().parse_args()
21 |
22 | dist_util.setup_dist()
23 | logger.configure(dir=args.log_dir)
24 |
25 | logger.log("creating model and diffusion...")
26 | model, diffusion = create_model_and_diffusion(
27 | **args_to_dict(args, model_and_diffusion_defaults().keys())
28 | )
29 | model.to(dist_util.dev())
30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
31 |
32 | logger.log("creating data loader...")
33 | data = load_data(
34 | data_dir=args.data_dir,
35 | batch_size=args.batch_size,
36 | image_size=args.image_size,
37 | class_cond=args.class_cond,
38 | )
39 |
40 | logger.log("training...")
41 | TrainLoop(
42 | model=model,
43 | diffusion=diffusion,
44 | data=data,
45 | batch_size=args.batch_size,
46 | microbatch=args.microbatch,
47 | lr=args.lr,
48 | ema_rate=args.ema_rate,
49 | log_interval=args.log_interval,
50 | save_interval=args.save_interval,
51 | resume_checkpoint=args.resume_checkpoint,
52 | use_fp16=args.use_fp16,
53 | fp16_scale_growth=args.fp16_scale_growth,
54 | schedule_sampler=schedule_sampler,
55 | weight_decay=args.weight_decay,
56 | lr_anneal_steps=args.lr_anneal_steps,
57 | ).run_loop()
58 |
59 |
60 | def create_argparser():
61 | defaults = dict(
62 | data_dir="",
63 | log_dir="",
64 | schedule_sampler="uniform",
65 | lr=1e-4,
66 | weight_decay=0.0,
67 | lr_anneal_steps=0,
68 | batch_size=1,
69 | microbatch=-1, # -1 disables microbatches
70 | ema_rate="0.9999", # comma-separated list of EMA values
71 | log_interval=10,
72 | save_interval=10000,
73 | resume_checkpoint="",
74 | use_fp16=False,
75 | fp16_scale_growth=1e-3,
76 | )
77 | defaults.update(model_and_diffusion_defaults())
78 | parser = argparse.ArgumentParser()
79 | add_dict_to_argparser(parser, defaults)
80 | return parser
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
/P2_weighting/scripts/super_res_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of samples from a super resolution model, given a batch
3 | of samples from a regular model from image_sample.py.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import blobfile as bf
10 | import numpy as np
11 | import torch as th
12 | import torch.distributed as dist
13 |
14 | from guided_diffusion import dist_util, logger
15 | from guided_diffusion.script_util import (
16 | sr_model_and_diffusion_defaults,
17 | sr_create_model_and_diffusion,
18 | args_to_dict,
19 | add_dict_to_argparser,
20 | )
21 |
22 |
23 | def main():
24 | args = create_argparser().parse_args()
25 |
26 | dist_util.setup_dist()
27 | logger.configure()
28 |
29 | logger.log("creating model...")
30 | model, diffusion = sr_create_model_and_diffusion(
31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
32 | )
33 | model.load_state_dict(
34 | dist_util.load_state_dict(args.model_path, map_location="cpu")
35 | )
36 | model.to(dist_util.dev())
37 | if args.use_fp16:
38 | model.convert_to_fp16()
39 | model.eval()
40 |
41 | logger.log("loading data...")
42 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond)
43 |
44 | logger.log("creating samples...")
45 | all_images = []
46 | while len(all_images) * args.batch_size < args.num_samples:
47 | model_kwargs = next(data)
48 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()}
49 | sample = diffusion.p_sample_loop(
50 | model,
51 | (args.batch_size, 3, args.large_size, args.large_size),
52 | clip_denoised=args.clip_denoised,
53 | model_kwargs=model_kwargs,
54 | )
55 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
56 | sample = sample.permute(0, 2, 3, 1)
57 | sample = sample.contiguous()
58 |
59 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
60 | dist.all_gather(all_samples, sample) # gather not supported with NCCL
61 | for sample in all_samples:
62 | all_images.append(sample.cpu().numpy())
63 | logger.log(f"created {len(all_images) * args.batch_size} samples")
64 |
65 | arr = np.concatenate(all_images, axis=0)
66 | arr = arr[: args.num_samples]
67 | if dist.get_rank() == 0:
68 | shape_str = "x".join([str(x) for x in arr.shape])
69 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
70 | logger.log(f"saving to {out_path}")
71 | np.savez(out_path, arr)
72 |
73 | dist.barrier()
74 | logger.log("sampling complete")
75 |
76 |
77 | def load_data_for_worker(base_samples, batch_size, class_cond):
78 | with bf.BlobFile(base_samples, "rb") as f:
79 | obj = np.load(f)
80 | image_arr = obj["arr_0"]
81 | if class_cond:
82 | label_arr = obj["arr_1"]
83 | rank = dist.get_rank()
84 | num_ranks = dist.get_world_size()
85 | buffer = []
86 | label_buffer = []
87 | while True:
88 | for i in range(rank, len(image_arr), num_ranks):
89 | buffer.append(image_arr[i])
90 | if class_cond:
91 | label_buffer.append(label_arr[i])
92 | if len(buffer) == batch_size:
93 | batch = th.from_numpy(np.stack(buffer)).float()
94 | batch = batch / 127.5 - 1.0
95 | batch = batch.permute(0, 3, 1, 2)
96 | res = dict(low_res=batch)
97 | if class_cond:
98 | res["y"] = th.from_numpy(np.stack(label_buffer))
99 | yield res
100 | buffer, label_buffer = [], []
101 |
102 |
103 | def create_argparser():
104 | defaults = dict(
105 | clip_denoised=True,
106 | num_samples=10000,
107 | batch_size=16,
108 | use_ddim=False,
109 | base_samples="",
110 | model_path="",
111 | )
112 | defaults.update(sr_model_and_diffusion_defaults())
113 | parser = argparse.ArgumentParser()
114 | add_dict_to_argparser(parser, defaults)
115 | return parser
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
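121 | # Note on load_data_for_worker above: each rank strides through the base
122 | # samples starting at its own index (rank, rank + world_size, ...), so with
123 | # two ranks, rank 0 upsamples images 0, 2, 4, ... and rank 1 images
124 | # 1, 3, 5, ..., with no duplication across workers.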
--------------------------------------------------------------------------------
/P2_weighting/scripts/super_res_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a super-resolution model.
3 | """
4 |
5 | import argparse
6 |
7 | import torch.nn.functional as F
8 |
9 | from guided_diffusion import dist_util, logger
10 | from guided_diffusion.image_datasets import load_data
11 | from guided_diffusion.resample import create_named_schedule_sampler
12 | from guided_diffusion.script_util import (
13 | sr_model_and_diffusion_defaults,
14 | sr_create_model_and_diffusion,
15 | args_to_dict,
16 | add_dict_to_argparser,
17 | )
18 | from guided_diffusion.train_util import TrainLoop
19 |
20 |
21 | def main():
22 | args = create_argparser().parse_args()
23 |
24 | dist_util.setup_dist()
25 | logger.configure()
26 |
27 | logger.log("creating model...")
28 | model, diffusion = sr_create_model_and_diffusion(
29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys())
30 | )
31 | model.to(dist_util.dev())
32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
33 |
34 | logger.log("creating data loader...")
35 | data = load_superres_data(
36 | args.data_dir,
37 | args.batch_size,
38 | large_size=args.large_size,
39 | small_size=args.small_size,
40 | class_cond=args.class_cond,
41 | )
42 |
43 | logger.log("training...")
44 | TrainLoop(
45 | model=model,
46 | diffusion=diffusion,
47 | data=data,
48 | batch_size=args.batch_size,
49 | microbatch=args.microbatch,
50 | lr=args.lr,
51 | ema_rate=args.ema_rate,
52 | log_interval=args.log_interval,
53 | save_interval=args.save_interval,
54 | resume_checkpoint=args.resume_checkpoint,
55 | use_fp16=args.use_fp16,
56 | fp16_scale_growth=args.fp16_scale_growth,
57 | schedule_sampler=schedule_sampler,
58 | weight_decay=args.weight_decay,
59 | lr_anneal_steps=args.lr_anneal_steps,
60 | ).run_loop()
61 |
62 |
63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False):
64 | data = load_data(
65 | data_dir=data_dir,
66 | batch_size=batch_size,
67 | image_size=large_size,
68 | class_cond=class_cond,
69 | )
70 | for large_batch, model_kwargs in data:
71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area")
72 | yield large_batch, model_kwargs
73 |
74 |
75 | def create_argparser():
76 | defaults = dict(
77 | data_dir="",
78 | schedule_sampler="uniform",
79 | lr=1e-4,
80 | weight_decay=0.0,
81 | lr_anneal_steps=0,
82 | batch_size=1,
83 | microbatch=-1,
84 | ema_rate="0.9999",
85 | log_interval=10,
86 | save_interval=10000,
87 | resume_checkpoint="",
88 | use_fp16=False,
89 | fp16_scale_growth=1e-3,
90 | )
91 | defaults.update(sr_model_and_diffusion_defaults())
92 | parser = argparse.ArgumentParser()
93 | add_dict_to_argparser(parser, defaults)
94 | return parser
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
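100 | # Note on load_superres_data above: the low-resolution conditioning is built
101 | # on the fly by area-downsampling each large batch, e.g. (assuming the usual
102 | # defaults large_size=256 and small_size=64) every 256x256 image is reduced
103 | # to 64x64 and passed to the model as model_kwargs["low_res"].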
--------------------------------------------------------------------------------
/P2_weighting/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="guided-diffusion",
5 | py_modules=["guided_diffusion"],
6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2024] One-Shot Structure-Aware Stylized Image Synthesis
2 |
3 | [](https://arxiv.org/abs/2402.17275)
4 |
5 | > **One-Shot Structure-Aware Stylized Image Synthesis**
6 | > Hansam Cho, Jonghyun Lee, Seunggyu Chang, Yonghyun Jeong
7 | >
8 | >**Abstract**:
9 | While GAN-based models have been successful in image stylization tasks, they often struggle with structure preservation while stylizing a wide range of input images. Recently, diffusion models have been adopted for image stylization but still lack the capability to maintain the original quality of input images. Building on this, we propose OSASIS: a novel one-shot stylization method that is robust in structure preservation. We show that OSASIS is able to effectively disentangle the semantics from the structure of an image, allowing it to control the level of content and style implemented to a given input. We apply OSASIS to various experimental settings, including stylization with out-of-domain reference images and stylization with text-driven manipulation. Results show that OSASIS outperforms other stylization methods, especially for input images that were rarely encountered during training, providing a promising solution to stylization via diffusion models.
10 |
11 | ## Description
12 | Official implementation of One-Shot Structure-Aware Stylized Image Synthesis
13 |
14 | 
15 |
16 | ## Setup
17 | ```
18 | conda env create -f environment.yaml
19 | conda activate osasis
20 | ```
21 |
22 | ## Prepare Training
23 | 1. Download the DDPM trained on FFHQ with [P2-weighting](https://github.com/jychoi118/P2-weighting) ([ffhq_p2.pt](https://drive.google.com/file/d/14ACthNkradJWBAL0th6z0UguzR6QjFKH/view?usp=drive_link)) and put the model checkpoint in `P2_weighting/models`
24 | ```
25 | OSASIS
26 | |--P2_weighting
27 | | |--models
28 | | | |--ffhq_p2.pt
29 | ```
30 |
31 | 2. Generate style images in domain A (photorealistic domain)
32 |
33 | ```
34 | DEVICE=0
35 |
36 | SAMPLE_FLAGS="--attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 \
37 | --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 \
38 | --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 50"
39 |
40 | CUDA_VISIBLE_DEVICES=${DEVICE} \
41 | python gen_style_domA.py ${SAMPLE_FLAGS} \
42 | --model_path P2_weighting/models/ffhq_p2.pt \
43 | --input_dir imgs_style_domB \
44 | --sample_dir imgs_style_domA \
45 | --img_name img1.png \
46 | --n 1 \
47 | --t_start_ratio 0.5 \
48 | --seed 1
49 |
50 | ```
51 | - `input_dir`: directory of style images in domain B (stylized domain)
52 | - `sample_dir`: directory where the generated style images in domain A (photorealistic domain) are saved
53 | - `img_name`: file name of the style image
54 | - `n`: number of images sampled to generate the style image in domain A
55 | - `t_start_ratio`: noising level of the image ($t_0$)
56 |
57 |
58 |
59 | ## Training
60 | 1. Download [DiffAE](https://github.com/phizaz/diffae) trained on FFHQ ([ffhq256_autoenc](https://drive.google.com/drive/folders/1-5zfxT6Gl-GjxM7z9ZO2AHlB70tfmF6V), [ffhq256_autoenc_latent](https://drive.google.com/drive/folders/1-H8WzKc65dEONN-DQ87TnXc23nTXDTYb)) and put the model checkpoints in `diffae/checkpoints`
61 | ```
62 | OSASIS
63 | |--diffae
64 | | |--checkpoints
65 | | | |--ffhq256_autoenc
66 | | | | |--last.ckpt
67 | | | | |--latent.pkl
68 | | | |--ffhq256_autoenc_latent
69 | | | | |--last.ckpt
70 | ```
71 |
72 |
73 | 2. Train the model using the following script, which requires 34GB of VRAM for a batch size of 8. Training takes approximately 30 minutes on a single A100 GPU.
74 |
75 | ```
76 | DEVICE=0
77 |
78 | CUDA_VISIBLE_DEVICES=${DEVICE} \
79 | python train_diffaeB.py \
80 | --style_domA_dir imgs_style_domA \
81 | --style_domB_dir imgs_style_domB \
82 | --ref_img img1.png \
83 | --work_dir exp/img1 \
84 | --n_iter 200 \
85 | --ckpt_freq 200 \
86 | --batch_size 8 \
87 | --map_net \
88 | --map_time \
89 | --lambda_map 0.1 \
90 | --train
91 | ```
92 | - `style_domA_dir`: directory of style images in domain A (photorealistic domain)
93 | - `style_domB_dir`: directory of style images in domain B (stylized domain)
94 | - `ref_img`: file name of the style image
95 | - `work_dir`: working directory
96 | - `n_iter`: number of training iterations
97 |
98 |
99 | ## Testing
100 | Generate stylized images with the following script:
101 | ```
102 | DEVICE=0
103 |
104 | CUDA_VISIBLE_DEVICES=${DEVICE} \
105 | python eval_diffaeB.py \
106 | --style_domB_dir imgs_style_domB \
107 | --infer_dir imgs_input_domA \
108 | --ref_img img1.png \
109 | --work_dir exp/img1 \
110 | --map_net \
111 | --map_time \
112 | --lambda_map 0.1
113 | ```
114 | - `style_domB_dir`: directory of style images in domain B (stylized domain)
115 | - `infer_dir`: directory of input images in domain A (photorealistic domain)
116 | - `ref_img`: file name of the style image
117 | - `work_dir`: working directory
118 |
119 | ## Using Pretrained Models
120 | Download the pretrained weights from this [link](https://drive.google.com/drive/folders/1N0q9RBYIwc110njCsHX7uiY3BtwY5PqY?usp=sharing) and place the checkpoints as shown below
121 | ```
122 | OSASIS
123 | |--exp
124 | | |--img1
125 | | | |--ckpt
126 | | | | |--iter_200.pt
127 | | |--img2
128 | | | |--ckpt
129 | | | | |--iter_200.pt
130 | ```
131 |
132 | ## Acknowledgements
133 | This repository is built upon [P2-weighting](https://github.com/jychoi118/P2-weighting), [DiffAE](https://github.com/phizaz/diffae), and [MindTheGap](https://github.com/ZPdesu/MindTheGap).
134 |
135 | ## Citation
136 | ```bibtex
137 | @article{cho2024one,
138 | title={One-Shot Structure-Aware Stylized Image Synthesis},
139 | author={Cho, Hansam and Lee, Jonghyun and Chang, Seunggyu and Jeong, Yonghyun},
140 | journal={arXiv preprint arXiv:2402.17275},
141 | year={2024}
142 | }
143 | ```
--------------------------------------------------------------------------------
/diffae/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | project_dir = os.getcwd()
5 | sys.path.append(os.path.join(project_dir, 'diffae'))
--------------------------------------------------------------------------------
/diffae/align.py:
--------------------------------------------------------------------------------
1 | import bz2
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 |
7 | import dlib
8 | import numpy as np
9 | import PIL.Image
10 | import requests
11 | import scipy.ndimage
12 | from tqdm import tqdm
13 | from argparse import ArgumentParser
14 |
15 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
16 |
17 |
18 | def image_align(src_file,
19 | dst_file,
20 | face_landmarks,
21 | output_size=1024,
22 | transform_size=4096,
23 | enable_padding=True):
24 | # Align function from FFHQ dataset pre-processing step
25 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
26 |
27 | lm = np.array(face_landmarks)
28 | lm_chin = lm[0:17] # left-right
29 | lm_eyebrow_left = lm[17:22] # left-right
30 | lm_eyebrow_right = lm[22:27] # left-right
31 | lm_nose = lm[27:31] # top-down
32 | lm_nostrils = lm[31:36] # top-down
33 | lm_eye_left = lm[36:42] # left-clockwise
34 | lm_eye_right = lm[42:48] # left-clockwise
35 | lm_mouth_outer = lm[48:60] # left-clockwise
36 | lm_mouth_inner = lm[60:68] # left-clockwise
37 |
38 | # Calculate auxiliary vectors.
39 | eye_left = np.mean(lm_eye_left, axis=0)
40 | eye_right = np.mean(lm_eye_right, axis=0)
41 | eye_avg = (eye_left + eye_right) * 0.5
42 | eye_to_eye = eye_right - eye_left
43 | mouth_left = lm_mouth_outer[0]
44 | mouth_right = lm_mouth_outer[6]
45 | mouth_avg = (mouth_left + mouth_right) * 0.5
46 | eye_to_mouth = mouth_avg - eye_avg
47 |
48 | # Choose oriented crop rectangle.
49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
50 | x /= np.hypot(*x)
51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
52 | y = np.flipud(x) * [-1, 1]
53 | c = eye_avg + eye_to_mouth * 0.1
54 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
55 | qsize = np.hypot(*x) * 2
56 |
57 | # Load in-the-wild image.
58 | if not os.path.isfile(src_file):
59 | print(
60 | '\nCannot find source image. Please run "--wilds" before "--align".'
61 | )
62 | return
63 | img = PIL.Image.open(src_file)
64 | img = img.convert('RGB')
65 |
66 | # Shrink.
67 | shrink = int(np.floor(qsize / output_size * 0.5))
68 | if shrink > 1:
69 | rsize = (int(np.rint(float(img.size[0]) / shrink)),
70 | int(np.rint(float(img.size[1]) / shrink)))
71 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
72 | quad /= shrink
73 | qsize /= shrink
74 |
75 | # Crop.
76 | border = max(int(np.rint(qsize * 0.1)), 3)
77 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
78 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
79 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
80 | min(crop[2] + border,
81 | img.size[0]), min(crop[3] + border, img.size[1]))
82 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
83 | img = img.crop(crop)
84 | quad -= crop[0:2]
85 |
86 | # Pad.
87 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
88 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
89 | pad = (max(-pad[0] + border,
90 | 0), max(-pad[1] + border,
91 | 0), max(pad[2] - img.size[0] + border,
92 | 0), max(pad[3] - img.size[1] + border, 0))
93 | if enable_padding and max(pad) > border - 4:
94 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
95 | img = np.pad(np.float32(img),
96 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
97 | h, w, _ = img.shape
98 | y, x, _ = np.ogrid[:h, :w, :1]
99 | mask = np.maximum(
100 | 1.0 -
101 | np.minimum(np.float32(x) / pad[0],
102 | np.float32(w - 1 - x) / pad[2]), 1.0 -
103 | np.minimum(np.float32(y) / pad[1],
104 | np.float32(h - 1 - y) / pad[3]))
105 | blur = qsize * 0.02
106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
107 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
108 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)),
110 | 'RGB')
111 | quad += pad[:2]
112 |
113 | # Transform.
114 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
115 | (quad + 0.5).flatten(), PIL.Image.BILINEAR)
116 | if output_size < transform_size:
117 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
118 |
119 | # Save aligned image.
120 | img.save(dst_file, 'PNG')
121 |
122 |
123 | class LandmarksDetector:
124 | def __init__(self, predictor_model_path):
125 | """
126 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
127 | """
128 | self.detector = dlib.get_frontal_face_detector(
129 | ) # cnn_face_detection_model_v1 also can be used
130 | self.shape_predictor = dlib.shape_predictor(predictor_model_path)
131 |
132 | def get_landmarks(self, image):
133 | img = dlib.load_rgb_image(image)
134 | dets = self.detector(img, 1)
135 |
136 | for detection in dets:
137 | face_landmarks = [
138 | (item.x, item.y)
139 | for item in self.shape_predictor(img, detection).parts()
140 | ]
141 | yield face_landmarks
142 |
143 |
144 | def unpack_bz2(src_path):
145 | dst_path = src_path[:-4]
146 | if os.path.exists(dst_path):
147 | print('cached')
148 | return dst_path
149 | data = bz2.BZ2File(src_path).read()
150 | with open(dst_path, 'wb') as fp:
151 | fp.write(data)
152 | return dst_path
153 |
154 |
155 | def work_landmark(raw_img_path, img_name, face_landmarks):
156 | face_img_name = '%s.png' % (os.path.splitext(img_name)[0], )
157 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
158 | if os.path.exists(aligned_face_path):
159 | return
160 | image_align(raw_img_path,
161 | aligned_face_path,
162 | face_landmarks,
163 | output_size=256)
164 |
165 |
166 | def get_file(src, tgt):
167 | if os.path.exists(tgt):
168 | print('cached')
169 | return tgt
170 | tgt_dir = os.path.dirname(tgt)
171 | if not os.path.exists(tgt_dir):
172 | os.makedirs(tgt_dir)
173 | file = requests.get(src)
174 | open(tgt, 'wb').write(file.content)
175 | return tgt
176 |
177 |
178 | if __name__ == "__main__":
179 | """
180 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step
181 | python align_images.py /raw_images /aligned_images
182 | """
183 | parser = ArgumentParser()
184 | parser.add_argument("-i",
185 | "--input_imgs_path",
186 | type=str,
187 | default="imgs",
188 | help="input images directory path")
189 | parser.add_argument("-o",
190 | "--output_imgs_path",
191 | type=str,
192 | default="imgs_align",
193 | help="output images directory path")
194 |
195 | args = parser.parse_args()
196 |
197 | # takes a very long time ...
198 | landmarks_model_path = unpack_bz2(
199 | get_file(
200 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',
201 | 'temp/shape_predictor_68_face_landmarks.dat.bz2'))
202 |
203 | # RAW_IMAGES_DIR = sys.argv[1]
204 | # ALIGNED_IMAGES_DIR = sys.argv[2]
205 | RAW_IMAGES_DIR = args.input_imgs_path
206 | ALIGNED_IMAGES_DIR = args.output_imgs_path
207 |
208 | if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR)
209 |
210 | files = os.listdir(RAW_IMAGES_DIR)
211 | print(f'total img files {len(files)}')
212 | with tqdm(total=len(files)) as progress:
213 |
214 | def cb(*args):
215 | # print('update')
216 | progress.update()
217 |
218 | def err_cb(e):
219 | print('error:', e)
220 |
221 | with Pool(8) as pool:
222 | res = []
223 | landmarks_detector = LandmarksDetector(landmarks_model_path)
224 | for img_name in files:
225 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
226 | # print('img_name:', img_name)
227 | for i, face_landmarks in enumerate(
228 | landmarks_detector.get_landmarks(raw_img_path),
229 | start=1):
230 | # assert i == 1, f'{i}'
231 | # print(i, face_landmarks)
232 | # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
233 | # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
234 | # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256)
235 |
236 | work_landmark(raw_img_path, img_name, face_landmarks)
237 | progress.update()
238 |
239 | # job = pool.apply_async(
240 | # work_landmark,
241 | # (raw_img_path, img_name, face_landmarks),
242 | # callback=cb,
243 | # error_callback=err_cb,
244 | # )
245 | # res.append(job)
246 |
247 | # pool.close()
248 | # pool.join()
249 | print(f"output aligned images at: {ALIGNED_IMAGES_DIR}")
250 |
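251 | # Usage sketch (paths are examples): align every face found under ./imgs and
252 | # write 256x256 crops to ./imgs_align, matching the FFHQ preprocessing:
253 | #
254 | #   python diffae/align.py -i imgs -o imgs_align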
--------------------------------------------------------------------------------
/diffae/choices.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from torch import nn
3 |
4 |
5 | class TrainMode(Enum):
6 | # manipulate mode = training the classifier
7 | manipulate = 'manipulate'
8 | # default training mode!
9 | diffusion = 'diffusion'
10 | # default latent training mode!
11 | # fitting a DDPM to a given latent
12 | latent_diffusion = 'latentdiffusion'
13 |
14 | def is_manipulate(self):
15 | return self in [
16 | TrainMode.manipulate,
17 | ]
18 |
19 | def is_diffusion(self):
20 | return self in [
21 | TrainMode.diffusion,
22 | TrainMode.latent_diffusion,
23 | ]
24 |
25 | def is_autoenc(self):
26 | # the network possibly does autoencoding
27 | return self in [
28 | TrainMode.diffusion,
29 | ]
30 |
31 | def is_latent_diffusion(self):
32 | return self in [
33 | TrainMode.latent_diffusion,
34 | ]
35 |
36 | def use_latent_net(self):
37 | return self.is_latent_diffusion()
38 |
39 | def require_dataset_infer(self):
40 | """
41 | Whether training in this mode requires the latent variables to be available.
42 | """
43 | # this will precalculate all the latents beforehand
44 | # and the dataset will be all the predicted latents
45 | return self in [
46 | TrainMode.latent_diffusion,
47 | TrainMode.manipulate,
48 | ]
49 |
50 |
51 | class ManipulateMode(Enum):
52 | """
53 | How to train the classifier for attribute manipulation.
54 | """
55 | # train on whole celeba attr dataset
56 | celebahq_all = 'celebahq_all'
57 | # celeba with D2C's crop
58 | d2c_fewshot = 'd2cfewshot'
59 | d2c_fewshot_allneg = 'd2cfewshotallneg'
60 |
61 | def is_celeba_attr(self):
62 | return self in [
63 | ManipulateMode.d2c_fewshot,
64 | ManipulateMode.d2c_fewshot_allneg,
65 | ManipulateMode.celebahq_all,
66 | ]
67 |
68 | def is_single_class(self):
69 | return self in [
70 | ManipulateMode.d2c_fewshot,
71 | ManipulateMode.d2c_fewshot_allneg,
72 | ]
73 |
74 | def is_fewshot(self):
75 | return self in [
76 | ManipulateMode.d2c_fewshot,
77 | ManipulateMode.d2c_fewshot_allneg,
78 | ]
79 |
80 | def is_fewshot_allneg(self):
81 | return self in [
82 | ManipulateMode.d2c_fewshot_allneg,
83 | ]
84 |
85 |
86 | class ModelType(Enum):
87 | """
88 | Kinds of the backbone models
89 | """
90 |
91 | # unconditional ddpm
92 | ddpm = 'ddpm'
93 | # autoencoding ddpm cannot do unconditional generation
94 | autoencoder = 'autoencoder'
95 |
96 | def has_autoenc(self):
97 | return self in [
98 | ModelType.autoencoder,
99 | ]
100 |
101 | def can_sample(self):
102 | return self in [ModelType.ddpm]
103 |
104 |
105 | class ModelName(Enum):
106 | """
107 | List of all supported model classes
108 | """
109 |
110 | beatgans_ddpm = 'beatgans_ddpm'
111 | beatgans_autoenc = 'beatgans_autoenc'
112 |
113 |
114 | class ModelMeanType(Enum):
115 | """
116 | Which type of output the model predicts.
117 | """
118 |
119 | eps = 'eps' # the model predicts epsilon
120 |
121 |
122 | class ModelVarType(Enum):
123 | """
124 | What is used as the model's output variance.
125 |
126 | The LEARNED_RANGE option has been added to allow the model to predict
127 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128 | """
129 |
130 | # posterior beta_t
131 | fixed_small = 'fixed_small'
132 | # beta_t
133 | fixed_large = 'fixed_large'
134 |
135 |
136 | class LossType(Enum):
137 | mse = 'mse' # use raw MSE loss (and KL when learning variances)
138 | l1 = 'l1'
139 |
140 |
141 | class GenerativeType(Enum):
142 | """
143 | How a sample is generated
144 | """
145 |
146 | ddpm = 'ddpm'
147 | ddim = 'ddim'
148 |
149 |
150 | class OptimizerType(Enum):
151 | adam = 'adam'
152 | adamw = 'adamw'
153 |
154 |
155 | class Activation(Enum):
156 | none = 'none'
157 | relu = 'relu'
158 | lrelu = 'lrelu'
159 | silu = 'silu'
160 | tanh = 'tanh'
161 |
162 | def get_act(self):
163 | if self == Activation.none:
164 | return nn.Identity()
165 | elif self == Activation.relu:
166 | return nn.ReLU()
167 | elif self == Activation.lrelu:
168 | return nn.LeakyReLU(negative_slope=0.2)
169 | elif self == Activation.silu:
170 | return nn.SiLU()
171 | elif self == Activation.tanh:
172 | return nn.Tanh()
173 | else:
174 | raise NotImplementedError()
175 |
176 |
177 | class ManipulateLossType(Enum):
178 | bce = 'bce'
179 | mse = 'mse'
--------------------------------------------------------------------------------
/diffae/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | cuda: "10.2"
3 | gpu: true
4 | python_version: "3.8"
5 | system_packages:
6 | - "libgl1-mesa-glx"
7 | - "libglib2.0-0"
8 | python_packages:
9 | - "numpy==1.21.5"
10 | - "cmake==3.23.3"
11 | - "ipython==7.21.0"
12 | - "opencv-python==4.5.4.58"
13 | - "pandas==1.1.5"
14 | - "lmdb==1.2.1"
15 | - "lpips==0.1.4"
16 | - "pytorch-fid==0.2.0"
17 | - "ftfy==6.1.1"
18 | - "scipy==1.5.4"
19 | - "torch==1.9.1"
20 | - "torchvision==0.10.1"
21 | - "tqdm==4.62.3"
22 | - "regex==2022.7.25"
23 | - "Pillow==9.2.0"
24 | - "pytorch_lightning==1.7.0"
25 |
26 | run:
27 | - pip install dlib
28 |
29 | predict: "predict.py:Predictor"
30 |
--------------------------------------------------------------------------------
/diffae/config_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from copy import deepcopy
4 | from dataclasses import dataclass
5 |
6 |
7 | @dataclass
8 | class BaseConfig:
9 | def clone(self):
10 | return deepcopy(self)
11 |
12 | def inherit(self, another):
13 | """inherit common keys from a given config"""
14 | common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15 | for k in common_keys:
16 | setattr(self, k, getattr(another, k))
17 |
18 | def propagate(self):
19 | """push down the configuration to all members"""
20 | for k, v in self.__dict__.items():
21 | if isinstance(v, BaseConfig):
22 | v.inherit(self)
23 | v.propagate()
24 |
25 | def save(self, save_path):
26 | """save config to json file"""
27 | dirname = os.path.dirname(save_path)
28 | if not os.path.exists(dirname):
29 | os.makedirs(dirname)
30 | conf = self.as_dict_jsonable()
31 | with open(save_path, 'w') as f:
32 | json.dump(conf, f)
33 |
34 | def load(self, load_path):
35 | """load json config"""
36 | with open(load_path) as f:
37 | conf = json.load(f)
38 | self.from_dict(conf)
39 |
40 | def from_dict(self, dict, strict=False):
41 | for k, v in dict.items():
42 | if not hasattr(self, k):
43 | if strict:
44 |                     raise ValueError(f"unexpected key '{k}'")
45 |                 else:
46 |                     print(f"ignoring extra key '{k}'")
47 | continue
48 | if isinstance(self.__dict__[k], BaseConfig):
49 | self.__dict__[k].from_dict(v)
50 | else:
51 | self.__dict__[k] = v
52 |
53 | def as_dict_jsonable(self):
54 | conf = {}
55 | for k, v in self.__dict__.items():
56 | if isinstance(v, BaseConfig):
57 | conf[k] = v.as_dict_jsonable()
58 | else:
59 | if jsonable(v):
60 | conf[k] = v
61 | else:
62 | # ignore not jsonable
63 | pass
64 | return conf
65 |
66 |
67 | def jsonable(x):
68 | try:
69 | json.dumps(x)
70 | return True
71 | except TypeError:
72 | return False
73 |
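74 |
75 | # Minimal usage sketch (illustrative; `_Parent`/`_Child` are made-up configs):
76 | # propagate() pushes shared keys down into nested configs, and
77 | # as_dict_jsonable() serializes the nesting recursively.
78 | if __name__ == '__main__':
79 |     from dataclasses import field
80 |
81 |     @dataclass
82 |     class _Child(BaseConfig):
83 |         lr: float = 1e-4
84 |
85 |     @dataclass
86 |     class _Parent(BaseConfig):
87 |         lr: float = 3e-4
88 |         child: _Child = field(default_factory=_Child)
89 |
90 |     p = _Parent()
91 |     p.propagate()  # _Child inherits the common key 'lr' from _Parent
92 |     assert p.child.lr == 3e-4
93 |     assert p.as_dict_jsonable() == {'lr': 3e-4, 'child': {'lr': 3e-4}}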
--------------------------------------------------------------------------------
/diffae/dataset_util.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from dist_utils import *
4 |
5 |
6 | def use_cached_dataset_path(source_path, cache_path):
7 | if get_rank() == 0:
8 | if not os.path.exists(cache_path):
9 | # shutil.rmtree(cache_path)
10 | print(f'copying the data: {source_path} to {cache_path}')
11 | shutil.copytree(source_path, cache_path)
12 | barrier()
13 | return cache_path
--------------------------------------------------------------------------------
/diffae/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
4 |
5 | Sampler = Union[SpacedDiffusionBeatGans]
6 | SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
7 |
--------------------------------------------------------------------------------
/diffae/diffusion/diffusion.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from dataclasses import dataclass
3 |
4 |
5 | def space_timesteps(num_timesteps, section_counts):
6 | """
7 | Create a list of timesteps to use from an original diffusion process,
8 | given the number of timesteps we want to take from equally-sized portions
9 | of the original process.
10 |
11 | For example, if there's 300 timesteps and the section counts are [10,15,20]
12 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
13 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
14 |
15 | If the stride is a string starting with "ddim", then the fixed striding
16 | from the DDIM paper is used, and only one section is allowed.
17 |
18 | :param num_timesteps: the number of diffusion steps in the original
19 | process to divide up.
20 | :param section_counts: either a list of numbers, or a string containing
21 | comma-separated numbers, indicating the step count
22 | per section. As a special case, use "ddimN" where N
23 | is a number of steps to use the striding from the
24 | DDIM paper.
25 | :return: a set of diffusion steps from the original process to use.
26 | """
27 | if isinstance(section_counts, str):
28 | if section_counts.startswith("ddim"):
29 | desired_count = int(section_counts[len("ddim"):])
30 | for i in range(1, num_timesteps):
31 | if len(range(0, num_timesteps, i)) == desired_count:
32 | return set(range(0, num_timesteps, i))
33 | raise ValueError(
34 |                 f"cannot create exactly {desired_count} steps with an integer stride"
35 | )
36 | section_counts = [int(x) for x in section_counts.split(",")]
37 | size_per = num_timesteps // len(section_counts)
38 | extra = num_timesteps % len(section_counts)
39 | start_idx = 0
40 | all_steps = []
41 | for i, section_count in enumerate(section_counts):
42 | size = size_per + (1 if i < extra else 0)
43 | if size < section_count:
44 | raise ValueError(
45 | f"cannot divide section of {size} steps into {section_count}")
46 | if section_count <= 1:
47 | frac_stride = 1
48 | else:
49 | frac_stride = (size - 1) / (section_count - 1)
50 | cur_idx = 0.0
51 | taken_steps = []
52 | for _ in range(section_count):
53 | taken_steps.append(start_idx + round(cur_idx))
54 | cur_idx += frac_stride
55 | all_steps += taken_steps
56 | start_idx += size
57 | return set(all_steps)
58 |
59 |
60 | @dataclass
61 | class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62 | use_timesteps: Tuple[int] = None
63 |
64 | def make_sampler(self):
65 | return SpacedDiffusionBeatGans(self)
66 |
67 |
68 | class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69 | """
70 | A diffusion process which can skip steps in a base diffusion process.
71 |
72 | :param use_timesteps: a collection (sequence or set) of timesteps from the
73 | original diffusion process to retain.
74 | :param kwargs: the kwargs to create the base diffusion process.
75 | """
76 | def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77 | self.conf = conf
78 | self.use_timesteps = set(conf.use_timesteps)
79 | # how the new t's mapped to the old t's
80 | self.timestep_map = []
81 | self.original_num_steps = len(conf.betas)
82 |
83 | base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84 | last_alpha_cumprod = 1.0
85 | new_betas = []
86 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87 | if i in self.use_timesteps:
88 | # getting the new betas of the new timesteps
89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90 | last_alpha_cumprod = alpha_cumprod
91 | self.timestep_map.append(i)
92 | conf.betas = np.array(new_betas)
93 | super().__init__(conf)
94 |
95 | def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96 | return super().p_mean_variance(self._wrap_model(model), *args,
97 | **kwargs)
98 |
99 | def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100 | return super().training_losses(self._wrap_model(model), *args,
101 | **kwargs)
102 |
103 | def condition_mean(self, cond_fn, *args, **kwargs):
104 | return super().condition_mean(self._wrap_model(cond_fn), *args,
105 | **kwargs)
106 |
107 | def condition_score(self, cond_fn, *args, **kwargs):
108 | return super().condition_score(self._wrap_model(cond_fn), *args,
109 | **kwargs)
110 |
111 | def _wrap_model(self, model: Model):
112 | if isinstance(model, _WrappedModel):
113 | return model
114 | return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115 | self.original_num_steps)
116 |
117 | def _scale_timesteps(self, t):
118 | # Scaling is done by the wrapped model.
119 | return t
120 |
121 |
122 | class _WrappedModel:
123 | """
124 |     Converts the supplied (spaced) t's back to the original t scale.
125 | """
126 | def __init__(self, model, timestep_map, rescale_timesteps,
127 | original_num_steps):
128 | self.model = model
129 | self.timestep_map = timestep_map
130 | self.rescale_timesteps = rescale_timesteps
131 | self.original_num_steps = original_num_steps
132 |
133 | def forward(self, x, t, t_cond=None, **kwargs):
134 | """
135 | Args:
136 |             t: t's in the spaced range (can be << T when evaluating with fewer steps); they are mapped back to the original t's
137 |             t_cond: same semantics as t, but may hold different values
138 | """
139 | map_tensor = th.tensor(self.timestep_map,
140 | device=t.device,
141 | dtype=t.dtype)
142 |
143 | def do(t):
144 | new_ts = map_tensor[t]
145 | if self.rescale_timesteps:
146 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147 | return new_ts
148 |
149 | if t_cond is not None:
150 | # support t_cond
151 | t_cond = do(t_cond)
152 |
153 | return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
154 |
155 | def __getattr__(self, name):
156 | # allow for calling the model's methods
157 | if hasattr(self.model, name):
158 | func = getattr(self.model, name)
159 | return func
160 | raise AttributeError(name)
161 |
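162 |
163 | # Minimal usage sketch for space_timesteps() (illustrative): "ddim50" keeps
164 | # 50 evenly strided steps out of 1000, while "10,10" splits 100 steps into
165 | # two halves and keeps 10 steps from each.
166 | if __name__ == '__main__':
167 |     assert space_timesteps(1000, 'ddim50') == set(range(0, 1000, 20))
168 |     assert len(space_timesteps(100, '10,10')) == 20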
--------------------------------------------------------------------------------
/diffae/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 |         return UniformSampler(diffusion.num_timesteps)  # the sampler expects a step count
17 | else:
18 | raise NotImplementedError(f"unknown schedule sampler: {name}")
19 |
20 |
21 | class ScheduleSampler(ABC):
22 | """
23 | A distribution over timesteps in the diffusion process, intended to reduce
24 | variance of the objective.
25 |
26 | By default, samplers perform unbiased importance sampling, in which the
27 | objective's mean is unchanged.
28 | However, subclasses may override sample() to change how the resampled
29 | terms are reweighted, allowing for actual changes in the objective.
30 | """
31 | @abstractmethod
32 | def weights(self):
33 | """
34 | Get a numpy array of weights, one per diffusion step.
35 |
36 | The weights needn't be normalized, but must be positive.
37 | """
38 |
39 | def sample(self, batch_size, device):
40 | """
41 | Importance-sample timesteps for a batch.
42 |
43 |         :param batch_size: the number of timesteps to draw (one per batch element).
44 |         :param device: the torch device to place the tensors on.
45 | :return: a tuple (timesteps, weights):
46 | - timesteps: a tensor of timestep indices.
47 | - weights: a tensor of weights to scale the resulting losses.
48 | """
49 | w = self.weights()
50 | p = w / np.sum(w)
51 | indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52 | indices = th.from_numpy(indices_np).long().to(device)
53 | weights_np = 1 / (len(p) * p[indices_np])
54 | weights = th.from_numpy(weights_np).float().to(device)
55 | return indices, weights
56 |
57 |
58 | class UniformSampler(ScheduleSampler):
59 | def __init__(self, num_timesteps):
60 | self._weights = np.ones([num_timesteps])
61 |
62 | def weights(self):
63 | return self._weights
64 |
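65 |
66 | # Minimal usage sketch (illustrative): with uniform weights, the importance
67 | # weights returned by sample() are exactly 1, so the objective is unchanged.
68 | if __name__ == '__main__':
69 |     sampler = UniformSampler(num_timesteps=1000)
70 |     t, w = sampler.sample(batch_size=4, device='cpu')
71 |     assert t.shape == (4, ) and bool(th.allclose(w, th.ones(4)))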
--------------------------------------------------------------------------------
/diffae/dist_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from torch import distributed
3 |
4 |
5 | def barrier():
6 | if distributed.is_initialized():
7 | distributed.barrier()
8 | else:
9 | pass
10 |
11 |
12 | def broadcast(data, src):
13 | if distributed.is_initialized():
14 | distributed.broadcast(data, src)
15 | else:
16 | pass
17 |
18 |
19 | def all_gather(data: List, src):
20 | if distributed.is_initialized():
21 | distributed.all_gather(data, src)
22 | else:
23 | data[0] = src
24 |
25 |
26 | def get_rank():
27 | if distributed.is_initialized():
28 | return distributed.get_rank()
29 | else:
30 | return 0
31 |
32 |
33 | def get_world_size():
34 | if distributed.is_initialized():
35 | return distributed.get_world_size()
36 | else:
37 | return 1
38 |
39 |
40 | def chunk_size(size, rank, world_size):
41 | extra = rank < size % world_size
42 | return size // world_size + extra
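43 |
44 |
45 | # Minimal sanity check (illustrative): 10 items over 3 ranks split as 4/3/3,
46 | # so the per-rank chunks always sum back to the full size.
47 | if __name__ == '__main__':
48 |     assert [chunk_size(10, r, 3) for r in range(3)] == [4, 3, 3]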
--------------------------------------------------------------------------------
/diffae/evals/church256_autoenc.txt:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/diffae/lmdb_writer.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | from contextlib import contextmanager
9 | from torch.utils.data import Dataset
10 | from multiprocessing import Process, Queue
11 | import os
12 | import shutil
13 |
14 |
15 | def convert(x, format, quality=100):
16 | # to prevent locking!
17 | torch.set_num_threads(1)
18 |
19 | buffer = BytesIO()
20 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
21 | x = x.to(torch.uint8)
22 | x = x.numpy()
23 | img = Image.fromarray(x)
24 | img.save(buffer, format=format, quality=quality)
25 | val = buffer.getvalue()
26 | return val
27 |
28 |
29 | @contextmanager
30 | def nullcontext():
31 | yield
32 |
33 |
34 | class _WriterWorker(Process):
35 | def __init__(self, path, format, quality, zfill, q):
36 | super().__init__()
37 | if os.path.exists(path):
38 | shutil.rmtree(path)
39 |
40 | self.path = path
41 | self.format = format
42 | self.quality = quality
43 | self.zfill = zfill
44 | self.q = q
45 | self.i = 0
46 |
47 | def run(self):
48 | if not os.path.exists(self.path):
49 | os.makedirs(self.path)
50 |
51 | with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
52 | while True:
53 | job = self.q.get()
54 | if job is None:
55 | break
56 | with env.begin(write=True) as txn:
57 | for x in job:
58 | key = f"{str(self.i).zfill(self.zfill)}".encode(
59 | "utf-8")
60 | x = convert(x, self.format, self.quality)
61 | txn.put(key, x)
62 | self.i += 1
63 |
64 | with env.begin(write=True) as txn:
65 | txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
66 |
67 |
68 | class LMDBImageWriter:
69 | def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
70 | self.path = path
71 | self.format = format
72 | self.quality = quality
73 | self.zfill = zfill
74 | self.queue = None
75 | self.worker = None
76 |
77 | def __enter__(self):
78 | self.queue = Queue(maxsize=3)
79 |         self.worker = _WriterWorker(self.path, self.format, self.quality,
80 | self.zfill, self.queue)
81 |         self.worker.start()
82 |         return self  # so `with LMDBImageWriter(...) as w:` binds the writer
83 | def put_images(self, tensor):
84 | """
85 | Args:
86 | tensor: (n, c, h, w) [0-1] tensor
87 | """
88 | self.queue.put(tensor.cpu())
89 | # with self.env.begin(write=True) as txn:
90 | # for x in tensor:
91 | # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8")
92 | # x = convert(x, self.format, self.quality)
93 | # txn.put(key, x)
94 | # self.i += 1
95 |
96 | def __exit__(self, *args, **kwargs):
97 | self.queue.put(None)
98 | self.queue.close()
99 | self.worker.join()
100 |
101 |
102 | class LMDBImageReader(Dataset):
103 | def __init__(self, path, zfill: int = 7):
104 | self.zfill = zfill
105 | self.env = lmdb.open(
106 | path,
107 | max_readers=32,
108 | readonly=True,
109 | lock=False,
110 | readahead=False,
111 | meminit=False,
112 | )
113 |
114 | if not self.env:
115 | raise IOError('Cannot open lmdb dataset', path)
116 |
117 | with self.env.begin(write=False) as txn:
118 | self.length = int(
119 | txn.get('length'.encode('utf-8')).decode('utf-8'))
120 |
121 | def __len__(self):
122 | return self.length
123 |
124 | def __getitem__(self, index):
125 | with self.env.begin(write=False) as txn:
126 | key = f'{str(index).zfill(self.zfill)}'.encode('utf-8')
127 | img_bytes = txn.get(key)
128 |
129 | buffer = BytesIO(img_bytes)
130 | img = Image.open(buffer)
131 | return img
132 |
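133 |
134 | # Minimal round-trip sketch (illustrative; '/tmp/_lmdb_demo' is a throwaway
135 | # path). The writer encodes in a background process; the reader sees the
136 | # images once the `with` block has joined it.
137 | if __name__ == '__main__':
138 |     with LMDBImageWriter('/tmp/_lmdb_demo', format='png') as writer:
139 |         writer.put_images(torch.rand(2, 3, 8, 8))  # two tiny RGB images in [0, 1]
140 |     reader = LMDBImageReader('/tmp/_lmdb_demo')
141 |     assert len(reader) == 2 and reader[0].size == (8, 8)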
--------------------------------------------------------------------------------
/diffae/model/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3 | from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4 |
5 | Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6 | ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
7 |
--------------------------------------------------------------------------------
/diffae/model/latentnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from typing import NamedTuple, Tuple
5 |
6 | import torch
7 | from choices import *
8 | from config_base import BaseConfig
9 | from torch import nn
10 | from torch.nn import init
11 |
12 | from .blocks import *
13 | from .nn import timestep_embedding
14 | from .unet import *
15 |
16 |
17 | class LatentNetType(Enum):
18 | none = 'none'
19 | # injecting inputs into the hidden layers
20 | skip = 'skip'
21 |
22 |
23 | class LatentNetReturn(NamedTuple):
24 | pred: torch.Tensor = None
25 |
26 |
27 | @dataclass
28 | class MLPSkipNetConfig(BaseConfig):
29 | """
30 | default MLP for the latent DPM in the paper!
31 | """
32 | num_channels: int
33 | skip_layers: Tuple[int]
34 | num_hid_channels: int
35 | num_layers: int
36 | num_time_emb_channels: int = 64
37 | activation: Activation = Activation.silu
38 | use_norm: bool = True
39 | condition_bias: float = 1
40 | dropout: float = 0
41 | last_act: Activation = Activation.none
42 | num_time_layers: int = 2
43 | time_last_act: bool = False
44 |
45 | def make_model(self):
46 | return MLPSkipNet(self)
47 |
48 |
49 | class MLPSkipNet(nn.Module):
50 | """
51 | concat x to hidden layers
52 |
53 | default MLP for the latent DPM in the paper!
54 | """
55 | def __init__(self, conf: MLPSkipNetConfig):
56 | super().__init__()
57 | self.conf = conf
58 |
59 | layers = []
60 | for i in range(conf.num_time_layers):
61 | if i == 0:
62 | a = conf.num_time_emb_channels
63 | b = conf.num_channels
64 | else:
65 | a = conf.num_channels
66 | b = conf.num_channels
67 | layers.append(nn.Linear(a, b))
68 | if i < conf.num_time_layers - 1 or conf.time_last_act:
69 | layers.append(conf.activation.get_act())
70 | self.time_embed = nn.Sequential(*layers)
71 |
72 | self.layers = nn.ModuleList([])
73 | for i in range(conf.num_layers):
74 | if i == 0:
75 | act = conf.activation
76 | norm = conf.use_norm
77 | cond = True
78 | a, b = conf.num_channels, conf.num_hid_channels
79 | dropout = conf.dropout
80 | elif i == conf.num_layers - 1:
81 | act = Activation.none
82 | norm = False
83 | cond = False
84 | a, b = conf.num_hid_channels, conf.num_channels
85 | dropout = 0
86 | else:
87 | act = conf.activation
88 | norm = conf.use_norm
89 | cond = True
90 | a, b = conf.num_hid_channels, conf.num_hid_channels
91 | dropout = conf.dropout
92 |
93 | if i in conf.skip_layers:
94 | a += conf.num_channels
95 |
96 | self.layers.append(
97 | MLPLNAct(
98 | a,
99 | b,
100 | norm=norm,
101 | activation=act,
102 | cond_channels=conf.num_channels,
103 | use_cond=cond,
104 | condition_bias=conf.condition_bias,
105 | dropout=dropout,
106 | ))
107 | self.last_act = conf.last_act.get_act()
108 |
109 | def forward(self, x, t, **kwargs):
110 | t = timestep_embedding(t, self.conf.num_time_emb_channels)
111 | cond = self.time_embed(t)
112 | h = x
113 | for i in range(len(self.layers)):
114 | if i in self.conf.skip_layers:
115 | # injecting input into the hidden layers
116 | h = torch.cat([h, x], dim=1)
117 | h = self.layers[i].forward(x=h, cond=cond)
118 | h = self.last_act(h)
119 | return LatentNetReturn(h)
120 |
121 |
122 | class MLPLNAct(nn.Module):
123 | def __init__(
124 | self,
125 | in_channels: int,
126 | out_channels: int,
127 | norm: bool,
128 | use_cond: bool,
129 | activation: Activation,
130 | cond_channels: int,
131 | condition_bias: float = 0,
132 | dropout: float = 0,
133 | ):
134 | super().__init__()
135 | self.activation = activation
136 | self.condition_bias = condition_bias
137 | self.use_cond = use_cond
138 |
139 | self.linear = nn.Linear(in_channels, out_channels)
140 | self.act = activation.get_act()
141 | if self.use_cond:
142 | self.linear_emb = nn.Linear(cond_channels, out_channels)
143 | self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144 | if norm:
145 | self.norm = nn.LayerNorm(out_channels)
146 | else:
147 | self.norm = nn.Identity()
148 |
149 | if dropout > 0:
150 | self.dropout = nn.Dropout(p=dropout)
151 | else:
152 | self.dropout = nn.Identity()
153 |
154 | self.init_weights()
155 |
156 | def init_weights(self):
157 | for module in self.modules():
158 | if isinstance(module, nn.Linear):
159 | if self.activation == Activation.relu:
160 | init.kaiming_normal_(module.weight,
161 | a=0,
162 | nonlinearity='relu')
163 | elif self.activation == Activation.lrelu:
164 | init.kaiming_normal_(module.weight,
165 | a=0.2,
166 | nonlinearity='leaky_relu')
167 | elif self.activation == Activation.silu:
168 | init.kaiming_normal_(module.weight,
169 | a=0,
170 | nonlinearity='relu')
171 | else:
172 | # leave it as default
173 | pass
174 |
175 | def forward(self, x, cond=None):
176 | x = self.linear(x)
177 | if self.use_cond:
178 | # (n, c) or (n, c * 2)
179 | cond = self.cond_layers(cond)
180 | cond = (cond, None)
181 |
182 | # scale shift first
183 | x = x * (self.condition_bias + cond[0])
184 | if cond[1] is not None:
185 | x = x + cond[1]
186 | # then norm
187 | x = self.norm(x)
188 | else:
189 | # no condition
190 | x = self.norm(x)
191 | x = self.act(x)
192 | x = self.dropout(x)
193 | return x
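194 |
195 |
196 | # Minimal shape check (illustrative): a 3-layer skip MLP over 512-d latents,
197 | # re-injecting the input at layer 1 as described in the docstring above.
198 | if __name__ == '__main__':
199 |     conf = MLPSkipNetConfig(num_channels=512,
200 |                             skip_layers=(1, ),
201 |                             num_hid_channels=1024,
202 |                             num_layers=3)
203 |     net = conf.make_model()
204 |     z = torch.randn(4, 512)
205 |     t = torch.randint(0, 1000, (4, ))
206 |     assert net(z, t).pred.shape == (4, 512)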
--------------------------------------------------------------------------------
/diffae/model/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | from enum import Enum
6 | import math
7 | from typing import Optional
8 |
9 | import torch as th
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | import torch.nn.functional as F
14 |
15 |
16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17 | class SiLU(nn.Module):
18 | # @th.jit.script
19 | def forward(self, x):
20 | return x * th.sigmoid(x)
21 |
22 |
23 | class GroupNorm32(nn.GroupNorm):
24 | def forward(self, x):
25 | return super().forward(x.float()).type(x.dtype)
26 |
27 |
28 | def conv_nd(dims, *args, **kwargs):
29 | """
30 | Create a 1D, 2D, or 3D convolution module.
31 | """
32 | if dims == 1:
33 | return nn.Conv1d(*args, **kwargs)
34 | elif dims == 2:
35 | return nn.Conv2d(*args, **kwargs)
36 | elif dims == 3:
37 | return nn.Conv3d(*args, **kwargs)
38 | raise ValueError(f"unsupported dimensions: {dims}")
39 |
40 |
41 | def linear(*args, **kwargs):
42 | """
43 | Create a linear module.
44 | """
45 | return nn.Linear(*args, **kwargs)
46 |
47 |
48 | def avg_pool_nd(dims, *args, **kwargs):
49 | """
50 | Create a 1D, 2D, or 3D average pooling module.
51 | """
52 | if dims == 1:
53 | return nn.AvgPool1d(*args, **kwargs)
54 | elif dims == 2:
55 | return nn.AvgPool2d(*args, **kwargs)
56 | elif dims == 3:
57 | return nn.AvgPool3d(*args, **kwargs)
58 | raise ValueError(f"unsupported dimensions: {dims}")
59 |
60 |
61 | def update_ema(target_params, source_params, rate=0.99):
62 | """
63 | Update target parameters to be closer to those of source parameters using
64 | an exponential moving average.
65 |
66 | :param target_params: the target parameter sequence.
67 | :param source_params: the source parameter sequence.
68 | :param rate: the EMA rate (closer to 1 means slower).
69 | """
70 | for targ, src in zip(target_params, source_params):
71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72 |
73 |
74 | def zero_module(module):
75 | """
76 | Zero out the parameters of a module and return it.
77 | """
78 | for p in module.parameters():
79 | p.detach().zero_()
80 | return module
81 |
82 |
83 | def scale_module(module, scale):
84 | """
85 | Scale the parameters of a module and return it.
86 | """
87 | for p in module.parameters():
88 | p.detach().mul_(scale)
89 | return module
90 |
91 |
92 | def mean_flat(tensor):
93 | """
94 | Take the mean over all non-batch dimensions.
95 | """
96 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
97 |
98 |
99 | def normalization(channels):
100 | """
101 | Make a standard normalization layer.
102 |
103 | :param channels: number of input channels.
104 | :return: an nn.Module for normalization.
105 | """
106 | return GroupNorm32(min(32, channels), channels)
107 |
108 |
109 | def timestep_embedding(timesteps, dim, max_period=10000):
110 | """
111 | Create sinusoidal timestep embeddings.
112 |
113 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
114 | These may be fractional.
115 | :param dim: the dimension of the output.
116 | :param max_period: controls the minimum frequency of the embeddings.
117 | :return: an [N x dim] Tensor of positional embeddings.
118 | """
119 | half = dim // 2
120 | freqs = th.exp(-math.log(max_period) *
121 | th.arange(start=0, end=half, dtype=th.float32) /
122 | half).to(device=timesteps.device)
123 | args = timesteps[:, None].float() * freqs[None]
124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
125 | if dim % 2:
126 | embedding = th.cat(
127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128 | return embedding
129 |
130 |
131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False):
132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
133 | if flag:
134 | return torch.utils.checkpoint.checkpoint(
135 | func, *args, preserve_rng_state=preserve_rng_state)
136 | else:
137 | return func(*args)
138 |
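139 |
140 | # Minimal usage sketch (illustrative): timestep_embedding() packs [cos | sin]
141 | # halves, so t = 0 embeds to ones in the cosine half and zeros in the sine half.
142 | if __name__ == '__main__':
143 |     t = th.tensor([0, 10, 100, 999])
144 |     emb = timestep_embedding(t, dim=128)
145 |     assert emb.shape == (4, 128)
146 |     assert th.allclose(emb[0, :64], th.ones(64))   # cos(0) = 1
147 |     assert th.allclose(emb[0, 64:], th.zeros(64))  # sin(0) = 0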
--------------------------------------------------------------------------------
/diffae/model/unet_autoenc.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.functional import silu
6 |
7 | from .latentnet import *
8 | from .unet import *
9 | from choices import *
10 |
11 |
12 | @dataclass
13 | class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14 | # number of style channels
15 | enc_out_channels: int = 512
16 | enc_attn_resolutions: Tuple[int] = None
17 | enc_pool: str = 'depthconv'
18 | enc_num_res_block: int = 2
19 | enc_channel_mult: Tuple[int] = None
20 | enc_grad_checkpoint: bool = False
21 | latent_net_conf: MLPSkipNetConfig = None
22 |
23 | def make_model(self):
24 | return BeatGANsAutoencModel(self)
25 |
26 |
27 | class BeatGANsAutoencModel(BeatGANsUNetModel):
28 | def __init__(self, conf: BeatGANsAutoencConfig):
29 | super().__init__(conf)
30 | self.conf = conf
31 |
32 |         # embeds only time and cond (no label embedding)
33 | self.time_embed = TimeStyleSeperateEmbed(
34 | time_channels=conf.model_channels,
35 | time_out_channels=conf.embed_channels,
36 | )
37 |
38 | self.encoder = BeatGANsEncoderConfig(
39 | image_size=conf.image_size,
40 | in_channels=conf.in_channels,
41 | model_channels=conf.model_channels,
42 | out_hid_channels=conf.enc_out_channels,
43 | out_channels=conf.enc_out_channels,
44 | num_res_blocks=conf.enc_num_res_block,
45 | attention_resolutions=(conf.enc_attn_resolutions
46 | or conf.attention_resolutions),
47 | dropout=conf.dropout,
48 | channel_mult=conf.enc_channel_mult or conf.channel_mult,
49 | use_time_condition=False,
50 | conv_resample=conf.conv_resample,
51 | dims=conf.dims,
52 | use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53 | num_heads=conf.num_heads,
54 | num_head_channels=conf.num_head_channels,
55 | resblock_updown=conf.resblock_updown,
56 | use_new_attention_order=conf.use_new_attention_order,
57 | pool=conf.enc_pool,
58 | ).make_model()
59 |
60 | if conf.latent_net_conf is not None:
61 | self.latent_net = conf.latent_net_conf.make_model()
62 |
63 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64 | """
65 |         Reparameterization trick: sample from N(mu, var) using draws
66 |         from N(0, 1).
67 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68 |         :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
69 | :return: (Tensor) [B x D]
70 | """
71 | assert self.conf.is_stochastic
72 | std = torch.exp(0.5 * logvar)
73 | eps = torch.randn_like(std)
74 | return eps * std + mu
75 |
76 | def sample_z(self, n: int, device):
77 | assert self.conf.is_stochastic
78 | return torch.randn(n, self.conf.enc_out_channels, device=device)
79 |
80 | def noise_to_cond(self, noise: Tensor):
81 | raise NotImplementedError()
82 |         # assert self.conf.noise_net_conf is not None  (unreachable)
83 |         # return self.noise_net.forward(noise)  (unreachable)
84 |
85 | def encode(self, x):
86 | cond = self.encoder.forward(x)
87 | return cond
88 | # return {'cond': cond}
89 |
90 | @property
91 | def stylespace_sizes(self):
92 | modules = list(self.input_blocks.modules()) + list(
93 | self.middle_block.modules()) + list(self.output_blocks.modules())
94 | sizes = []
95 | for module in modules:
96 | if isinstance(module, ResBlock):
97 | linear = module.cond_emb_layers[-1]
98 | sizes.append(linear.weight.shape[0])
99 | return sizes
100 |
101 | def encode_stylespace(self, x, return_vector: bool = True):
102 | """
103 | encode to style space
104 | """
105 | modules = list(self.input_blocks.modules()) + list(
106 | self.middle_block.modules()) + list(self.output_blocks.modules())
107 | # (n, c)
108 | cond = self.encoder.forward(x)
109 | S = []
110 | for module in modules:
111 | if isinstance(module, ResBlock):
112 | # (n, c')
113 | s = module.cond_emb_layers.forward(cond)
114 | S.append(s)
115 |
116 | if return_vector:
117 | # (n, sum_c)
118 | return torch.cat(S, dim=1)
119 | else:
120 | return S
121 |
122 | def forward(self,
123 | x,
124 | t,
125 | y=None,
126 | x_start=None,
127 | cond=None,
128 | style=None,
129 | noise=None,
130 | t_cond=None,
131 | mix=False,
132 | ref_cond_scale=None,
133 | **kwargs):
134 | """
135 | Apply the model to an input batch.
136 |
137 | Args:
138 | x_start: the original image to encode
139 | cond: output of the encoder
140 | noise: random noise (to predict the cond)
141 | """
142 |
143 | if mix:
144 | assert type(cond) is dict
145 |
146 | if t_cond is None:
147 | t_cond = t
148 |
149 | if noise is not None:
150 | # if the noise is given, we predict the cond from noise
151 | cond = self.noise_to_cond(noise)
152 |
153 | if cond is None:
154 | if x is not None:
155 | assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
156 |
157 |             # encode() returns the cond tensor directly
158 |             cond = self.encode(x_start)
159 |
160 | if t is not None:
161 | _t_emb = timestep_embedding(t, self.conf.model_channels)
162 | _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
163 | else:
164 | # this happens when training only autoenc
165 | _t_emb = None
166 | _t_cond_emb = None
167 |
168 | if self.conf.resnet_two_cond:
169 | if not mix:
170 | res = self.time_embed.forward(
171 | time_emb=_t_emb,
172 | cond=cond,
173 | time_cond_emb=_t_cond_emb,
174 | )
175 | else:
176 | res = self.time_embed.forward(
177 | time_emb=_t_emb,
178 | cond=None,
179 | time_cond_emb=_t_cond_emb,
180 | )
181 | else:
182 | raise NotImplementedError()
183 |
184 | if self.conf.resnet_two_cond:
185 | # two cond: first = time emb, second = cond_emb
186 | emb = res.time_emb
187 | cond_emb = res.emb
188 | else:
189 | # one cond = combined of both time and cond
190 | emb = res.emb
191 | cond_emb = None
192 |
193 | # override the style if given
194 |         style = style if style is not None else res.style  # `or` is ambiguous on tensors
195 |
196 | assert (y is not None) == (
197 | self.conf.num_classes is not None
198 | ), "must specify y if and only if the model is class-conditional"
199 |
200 | if self.conf.num_classes is not None:
201 | raise NotImplementedError()
202 | # assert y.shape == (x.shape[0], )
203 | # emb = emb + self.label_emb(y)
204 |
205 | # where in the model to supply time conditions
206 | enc_time_emb = emb
207 | mid_time_emb = emb
208 | dec_time_emb = emb
209 | # where in the model to supply style conditions
210 | enc_cond_emb = cond_emb
211 | mid_cond_emb = cond_emb
212 | dec_cond_emb = cond_emb
213 |
214 | # hs = []
215 | hs = [[] for _ in range(len(self.conf.channel_mult))]
216 |
217 | if x is not None:
218 | h = x.type(self.dtype)
219 |
220 | # input blocks
221 | k = 0
222 | for i in range(len(self.input_num_blocks)):
223 | for j in range(self.input_num_blocks[i]):
224 | if not mix:
225 | h = self.input_blocks[k](h,
226 | emb=enc_time_emb,
227 | cond=enc_cond_emb)
228 | else:
229 | if i in ref_cond_scale:
230 | h = self.input_blocks[k](h,
231 | emb=enc_time_emb,
232 | cond=cond['ref'])
233 | else:
234 | h = self.input_blocks[k](h,
235 | emb=enc_time_emb,
236 | cond=cond['input'])
237 |
238 | # print(i, j, h.shape)
239 | hs[i].append(h)
240 | k += 1
241 | assert k == len(self.input_blocks)
242 |
243 | # middle blocks
244 | if not mix:
245 | h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
246 | else:
247 | h = self.middle_block(h, emb=mid_time_emb, cond=cond['input'])
248 | else:
249 | # no lateral connections
250 |             # happens when training only the autoencoder
251 | h = None
252 | hs = [[] for _ in range(len(self.conf.channel_mult))]
253 |
254 | # output blocks
255 | k = 0
256 | n = len(self.output_num_blocks)
257 | if mix:
258 | ref_cond_scale_out = [n-(scale+1) for scale in ref_cond_scale]
259 | for i in range(len(self.output_num_blocks)):
260 | for j in range(self.output_num_blocks[i]):
261 |                 # take the lateral connection from the same layer (in reverse)
262 |                 # until there are none left, then use None
263 | try:
264 | lateral = hs[-i - 1].pop()
265 | # print(i, j, lateral.shape)
266 | except IndexError:
267 | lateral = None
268 | # print(i, j, lateral)
269 |
270 | if not mix:
271 | h = self.output_blocks[k](h,
272 | emb=dec_time_emb,
273 | cond=dec_cond_emb,
274 | lateral=lateral)
275 | else:
276 | if i in ref_cond_scale_out:
277 | # if k == (len(self.output_num_blocks)*self.output_num_blocks[0]-1):
278 | h = self.output_blocks[k](h,
279 | emb=dec_time_emb,
280 | cond=cond['ref'],
281 | lateral=lateral)
282 | else:
283 | h = self.output_blocks[k](h,
284 | emb=dec_time_emb,
285 | cond=cond['input'],
286 | lateral=lateral)
287 | k += 1
288 |
289 | pred = self.out(h)
290 | return AutoencReturn(pred=pred, cond=cond)
291 |
292 |
293 | class AutoencReturn(NamedTuple):
294 | pred: Tensor
295 | cond: Tensor = None
296 |
297 |
298 | class EmbedReturn(NamedTuple):
299 | # style and time
300 | emb: Tensor = None
301 | # time only
302 | time_emb: Tensor = None
303 | # style only (but could depend on time)
304 | style: Tensor = None
305 |
306 |
307 | class TimeStyleSeperateEmbed(nn.Module):
308 |     # embeds time and style separately (the style branch is an identity)
309 | def __init__(self, time_channels, time_out_channels):
310 | super().__init__()
311 | self.time_embed = nn.Sequential(
312 | linear(time_channels, time_out_channels),
313 | nn.SiLU(),
314 | linear(time_out_channels, time_out_channels),
315 | )
316 | self.style = nn.Identity()
317 |
318 | def forward(self, time_emb=None, cond=None, **kwargs):
319 | if time_emb is None:
320 | # happens with autoenc training mode
321 |             pass  # time_emb stays None
322 | else:
323 | time_emb = self.time_embed(time_emb)
324 | style = self.style(cond)
325 | return EmbedReturn(emb=style, time_emb=time_emb, style=style)
326 |
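327 |
328 | # Minimal sketch of the reparameterization trick used above (illustrative,
329 | # standalone): z = mu + sigma * eps with eps ~ N(0, I), where
330 | # sigma = exp(0.5 * logvar), exactly what reparameterize() computes.
331 | if __name__ == '__main__':
332 |     mu, logvar = torch.zeros(4, 512), torch.zeros(4, 512)
333 |     z = torch.exp(0.5 * logvar) * torch.randn_like(mu) + mu
334 |     assert z.shape == (4, 512)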
--------------------------------------------------------------------------------
/diffae/predict.py:
--------------------------------------------------------------------------------
1 | # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls
2 | # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
3 | # bunzip2 shape_predictor_68_face_landmarks.dat.bz2
4 |
5 | import os
6 | import torch
7 | from torchvision.utils import save_image
8 | import tempfile
9 | from templates import *
10 | from templates_cls import *
11 | from experiment_classifier import ClsModel
12 | from align import LandmarksDetector, image_align
13 | from cog import BasePredictor, Path, Input, BaseModel
14 |
15 |
16 | class ModelOutput(BaseModel):
17 | image: Path
18 |
19 |
20 | class Predictor(BasePredictor):
21 | def setup(self):
22 | self.aligned_dir = "aligned"
23 | os.makedirs(self.aligned_dir, exist_ok=True)
24 | self.device = "cuda:0"
25 |
26 | # Model Initialization
27 | model_config = ffhq256_autoenc()
28 | self.model = LitModel(model_config)
29 | state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
30 | self.model.load_state_dict(state["state_dict"], strict=False)
31 | self.model.ema_model.eval()
32 | self.model.ema_model.to(self.device)
33 |
34 | # Classifier Initialization
35 | classifier_config = ffhq256_autoenc_cls()
36 | classifier_config.pretrain = None # a bit faster
37 | self.classifier = ClsModel(classifier_config)
38 | state_class = torch.load(
39 | "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
40 | )
41 | print("latent step:", state_class["global_step"])
42 | self.classifier.load_state_dict(state_class["state_dict"], strict=False)
43 | self.classifier.to(self.device)
44 |
45 | self.landmarks_detector = LandmarksDetector(
46 | "shape_predictor_68_face_landmarks.dat"
47 | )
48 |
49 | def predict(
50 | self,
51 | image: Path = Input(
52 |             description="Input image for face manipulation. The image will be aligned and cropped; "
53 |             "the outputs are the aligned and manipulated images.",
54 | ),
55 | target_class: str = Input(
56 | default="Bangs",
57 | choices=[
58 | "5_o_Clock_Shadow",
59 | "Arched_Eyebrows",
60 | "Attractive",
61 | "Bags_Under_Eyes",
62 | "Bald",
63 | "Bangs",
64 | "Big_Lips",
65 | "Big_Nose",
66 | "Black_Hair",
67 | "Blond_Hair",
68 | "Blurry",
69 | "Brown_Hair",
70 | "Bushy_Eyebrows",
71 | "Chubby",
72 | "Double_Chin",
73 | "Eyeglasses",
74 | "Goatee",
75 | "Gray_Hair",
76 | "Heavy_Makeup",
77 | "High_Cheekbones",
78 | "Male",
79 | "Mouth_Slightly_Open",
80 | "Mustache",
81 | "Narrow_Eyes",
82 | "Beard",
83 | "Oval_Face",
84 | "Pale_Skin",
85 | "Pointy_Nose",
86 | "Receding_Hairline",
87 | "Rosy_Cheeks",
88 | "Sideburns",
89 | "Smiling",
90 | "Straight_Hair",
91 | "Wavy_Hair",
92 | "Wearing_Earrings",
93 | "Wearing_Hat",
94 | "Wearing_Lipstick",
95 | "Wearing_Necklace",
96 | "Wearing_Necktie",
97 | "Young",
98 | ],
99 | description="Choose manipulation direction.",
100 | ),
101 | manipulation_amplitude: float = Input(
102 | default=0.3,
103 | ge=-0.5,
104 | le=0.5,
105 |             description="Setting this too high can introduce artifacts, as the manipulation can dominate the original image content.",
106 | ),
107 | T_step: int = Input(
108 | default=100,
109 | choices=[50, 100, 125, 200, 250, 500],
110 |             description="Number of steps for generation.",
111 | ),
112 | T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
113 | ) -> List[ModelOutput]:
114 |
115 | img_size = 256
116 | print("Aligning image...")
117 | for i, face_landmarks in enumerate(
118 | self.landmarks_detector.get_landmarks(str(image)), start=1
119 | ):
120 | image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)
121 |
122 | data = ImageDataset(
123 | self.aligned_dir,
124 | image_size=img_size,
125 | exts=["jpg", "jpeg", "JPG", "png"],
126 | do_augment=False,
127 | )
128 |
129 | print("Encoding and Manipulating the aligned image...")
130 | cls_manipulation_amplitude = manipulation_amplitude
131 | interpreted_target_class = target_class
132 | if (
133 | target_class not in CelebAttrDataset.id_to_cls
134 | and f"No_{target_class}" in CelebAttrDataset.id_to_cls
135 | ):
136 | cls_manipulation_amplitude = -manipulation_amplitude
137 | interpreted_target_class = f"No_{target_class}"
138 |
139 | batch = data[0]["img"][None]
140 |
141 | semantic_latent = self.model.encode(batch.to(self.device))
142 | stochastic_latent = self.model.encode_stochastic(
143 | batch.to(self.device), semantic_latent, T=T_inv
144 | )
145 |
146 | cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
147 | class_direction = self.classifier.classifier.weight[cls_id]
148 | normalized_class_direction = F.normalize(class_direction[None, :], dim=1)
149 |
150 | normalized_semantic_latent = self.classifier.normalize(semantic_latent)
151 | normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
152 | normalized_manipulated_semantic_latent = (
153 | normalized_semantic_latent
154 | + normalized_manipulation_amp * normalized_class_direction
155 | )
156 |
157 | manipulated_semantic_latent = self.classifier.denormalize(
158 | normalized_manipulated_semantic_latent
159 | )
160 |
161 | # Render Manipulated image
162 | manipulated_img = self.model.render(
163 | stochastic_latent, manipulated_semantic_latent, T=T_step
164 | )[0]
165 | original_img = data[0]["img"]
166 |
167 | model_output = []
168 | out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
169 | save_image(convert2rgb(original_img), str(out_path))
170 | model_output.append(ModelOutput(image=out_path))
171 |
172 | out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
173 | save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
174 | model_output.append(ModelOutput(image=out_path))
175 | return model_output
176 |
177 |
178 | def convert2rgb(img, adjust_scale=True):
179 |     convert_img = img.clone().detach()  # copy without the torch.tensor(tensor) warning
180 | if adjust_scale:
181 | convert_img = (convert_img + 1) / 2
182 | return convert_img.cpu()
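183 |
184 |
185 | # Minimal sketch of the latent manipulation math above (illustrative,
186 | # standalone): move the normalized semantic latent along a unit class
187 | # direction, scaled by amplitude * sqrt(d) to make the step dimension-aware.
188 | if __name__ == '__main__':
189 |     import math
190 |     d = 512
191 |     z = torch.randn(1, d)  # stands in for a normalized semantic latent
192 |     direction = torch.nn.functional.normalize(torch.randn(1, d), dim=1)
193 |     z_moved = z + 0.3 * math.sqrt(d) * direction
194 |     assert z_moved.shape == (1, d)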
183 |
--------------------------------------------------------------------------------
/diffae/renderer.py:
--------------------------------------------------------------------------------
1 | from config import *
2 |
3 | from torch.cuda import amp
4 |
5 |
6 | def render_uncondition(conf: TrainConfig,
7 | model: BeatGANsAutoencModel,
8 | x_T,
9 | sampler: Sampler,
10 | latent_sampler: Sampler,
11 | conds_mean=None,
12 | conds_std=None,
13 | clip_latent_noise: bool = False):
14 | device = x_T.device
15 | if conf.train_mode == TrainMode.diffusion:
16 | assert conf.model_type.can_sample()
17 | return sampler.sample(model=model, noise=x_T)
18 | elif conf.train_mode.is_latent_diffusion():
19 | model: BeatGANsAutoencModel
20 | if conf.train_mode == TrainMode.latent_diffusion:
21 | latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
22 | else:
23 | raise NotImplementedError()
24 |
25 | if clip_latent_noise:
26 | latent_noise = latent_noise.clip(-1, 1)
27 |
28 | cond = latent_sampler.sample(
29 | model=model.latent_net,
30 | noise=latent_noise,
31 | clip_denoised=conf.latent_clip_sample,
32 | )
33 |
34 | if conf.latent_znormalize:
35 | cond = cond * conds_std.to(device) + conds_mean.to(device)
36 |
37 | # the diffusion on the model
38 | return sampler.sample(model=model, noise=x_T, cond=cond)
39 | else:
40 | raise NotImplementedError()
41 |
42 |
43 | def render_condition(
44 | conf: TrainConfig,
45 | model: BeatGANsAutoencModel,
46 | x_T,
47 | sampler: Sampler,
48 | x_start=None,
49 | cond=None,
50 | ):
51 | if conf.train_mode == TrainMode.diffusion:
52 | assert conf.model_type.has_autoenc()
53 | # returns {'cond', 'cond2'}
54 | if cond is None:
55 | cond = model.encode(x_start)
56 | return sampler.sample(model=model,
57 | noise=x_T,
58 | model_kwargs={'cond': cond})
59 | else:
60 | raise NotImplementedError()
61 |
--------------------------------------------------------------------------------
/diffae/run_afhq256-dog.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # gpus = [0, 1, 2, 3, 4, 5, 6, 7]
6 | conf = ffhq256_autoenc()
7 |
8 | conf.data_name = 'afhq256-dog'
9 | conf.name = 'afhq256-dog_autoenc_old'
10 | conf.total_samples = 90_000_000
11 | conf.sample_every_samples = 500_000
12 | conf.eval_ema_every_samples = 10_000_000
13 | conf.eval_every_samples = 10_000_000
14 | conf.eval_num_images = 1000
15 |
16 | # train(conf, gpus=gpus)
17 |
18 | # infer the latents for training the latent DPM
19 | # NOTE: not gpu heavy, but more gpus can be of use!
20 | # gpus = [3]
21 | # conf.eval_programs = ['infer']
22 | # train(conf, gpus=gpus, mode='eval')
23 |
24 | # # train the latent DPM
25 | # # NOTE: only need a single gpu
26 | gpus = [3]
27 | conf = ffhq256_autoenc_latent()
28 | conf.data_name = 'afhq256-dog'
29 | conf.name = 'afhq256-dog_autoenc_latent_old'
30 | conf.pretrain = PretrainConfig(
31 | name='90M',
32 | path=f'checkpoints/afhq256-dog_autoenc_old/last.ckpt',
33 | )
34 | conf.latent_infer_path = f'checkpoints/afhq256-dog_autoenc_old/latent.pkl'
35 | conf.total_samples = 40_000_000
36 | conf.sample_every_samples = 5_000_000
37 | train(conf, gpus=gpus)
--------------------------------------------------------------------------------
/diffae/run_church256.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # gpus = [0, 1, 2, 3, 4, 5, 6, 7]
6 | conf = ffhq256_autoenc()
7 |
8 | conf.data_name = 'church256'
9 | conf.name = 'church256_autoenc'
10 | conf.total_samples = 90_000_000
11 | conf.sample_every_samples = 500_000
12 | conf.eval_ema_every_samples = 10_000_000
13 | conf.eval_every_samples = 10_000_000
14 |
15 | # train(conf, gpus=gpus)
16 |
17 | # infer the latents for training the latent DPM
18 | # NOTE: not gpu heavy, but more gpus can be of use!
19 | # gpus = [2]
20 | # conf.eval_programs = ['infer']
21 | # train(conf, gpus=gpus, mode='eval')
22 |
23 | # # train the latent DPM
24 | # # NOTE: only need a single gpu
25 | gpus = [2]
26 | conf = ffhq256_autoenc_latent()
27 | conf.data_name = 'church256'
28 | conf.name = 'church256_autoenc_latent'
29 | conf.pretrain = PretrainConfig(
30 | name='90M',
31 | path=f'checkpoints/church256_autoenc/last.ckpt',
32 | )
33 | conf.latent_infer_path = f'checkpoints/church256_autoenc/latent.pkl'
34 | conf.total_samples = 100_000_000
35 | conf.sample_every_samples = 5_000_000
36 | train(conf, gpus=gpus)
--------------------------------------------------------------------------------
/diffae/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([
10 | exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
11 | for x in range(window_size)
12 | ])
13 | return gauss / gauss.sum()
14 |
15 |
16 | def create_window(window_size, channel):
17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
18 | _2D_window = _1D_window.mm(
19 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
20 | window = Variable(
21 | _2D_window.expand(channel, 1, window_size, window_size).contiguous())
22 | return window
23 |
24 |
25 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
26 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
27 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
28 |
29 | mu1_sq = mu1.pow(2)
30 | mu2_sq = mu2.pow(2)
31 | mu1_mu2 = mu1 * mu2
32 |
33 | sigma1_sq = F.conv2d(
34 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
35 | sigma2_sq = F.conv2d(
36 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
37 | sigma12 = F.conv2d(
38 | img1 * img2, window, padding=window_size // 2,
39 | groups=channel) - mu1_mu2
40 |
41 | C1 = 0.01**2
42 | C2 = 0.03**2
43 |
44 | ssim_map = ((2 * mu1_mu2 + C1) *
45 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
46 | (sigma1_sq + sigma2_sq + C2))
47 |
48 | if size_average:
49 | return ssim_map.mean()
50 | else:
51 | return ssim_map.mean(1).mean(1).mean(1)
52 |
53 |
54 | class SSIM(torch.nn.Module):
55 | def __init__(self, window_size=11, size_average=True):
56 | super(SSIM, self).__init__()
57 | self.window_size = window_size
58 | self.size_average = size_average
59 | self.channel = 1
60 | self.window = create_window(window_size, self.channel)
61 |
62 | def forward(self, img1, img2):
63 | (_, channel, _, _) = img1.size()
64 |
65 | if channel == self.channel and self.window.data.type(
66 | ) == img1.data.type():
67 | window = self.window
68 | else:
69 | window = create_window(self.window_size, channel)
70 |
71 | if img1.is_cuda:
72 | window = window.cuda(img1.get_device())
73 | window = window.type_as(img1)
74 |
75 | self.window = window
76 | self.channel = channel
77 |
78 | return _ssim(img1, img2, window, self.window_size, channel,
79 | self.size_average)
80 |
81 |
82 | def ssim(img1, img2, window_size=11, size_average=True):
83 | (_, channel, _, _) = img1.size()
84 | window = create_window(window_size, channel)
85 |
86 | if img1.is_cuda:
87 | window = window.cuda(img1.get_device())
88 | window = window.type_as(img1)
89 |
90 | return _ssim(img1, img2, window, window_size, channel, size_average)
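91 |
92 |
93 | # Minimal sanity check (illustrative): SSIM of an image with itself is 1.
94 | if __name__ == '__main__':
95 |     x = torch.rand(1, 3, 64, 64)
96 |     assert abs(ssim(x, x).item() - 1.0) < 1e-4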
--------------------------------------------------------------------------------
/diffae/templates.py:
--------------------------------------------------------------------------------
1 | from experiment import *
2 |
3 |
4 | def ddpm():
5 | """
6 |     base configuration for the DDPM models (samples are generated with DDIM).
7 | """
8 | conf = TrainConfig()
9 | conf.batch_size = 32
10 | conf.beatgans_gen_type = GenerativeType.ddim
11 | conf.beta_scheduler = 'linear'
12 | conf.data_name = 'ffhq'
13 | conf.diffusion_type = 'beatgans'
14 | conf.eval_ema_every_samples = 200_000
15 | conf.eval_every_samples = 200_000
16 | conf.fp16 = True
17 | conf.lr = 1e-4
18 | conf.model_name = ModelName.beatgans_ddpm
19 | conf.net_attn = (16, )
20 | conf.net_beatgans_attn_head = 1
21 | conf.net_beatgans_embed_channels = 512
22 | conf.net_ch_mult = (1, 2, 4, 8)
23 | conf.net_ch = 64
24 | conf.sample_size = 32
25 | conf.T_eval = 20
26 | conf.T = 1000
27 | conf.make_model_conf()
28 | return conf
29 |
30 |
31 | def autoenc_base():
32 | """
33 | base configuration for all Diff-AE models.
34 | """
35 | conf = TrainConfig()
36 | conf.batch_size = 32
37 | conf.beatgans_gen_type = GenerativeType.ddim
38 | conf.beta_scheduler = 'linear'
39 | conf.data_name = 'ffhq'
40 | conf.diffusion_type = 'beatgans'
41 | conf.eval_ema_every_samples = 200_000
42 | conf.eval_every_samples = 200_000
43 | conf.fp16 = True
44 | conf.lr = 1e-4
45 | conf.model_name = ModelName.beatgans_autoenc
46 | conf.net_attn = (16, )
47 | conf.net_beatgans_attn_head = 1
48 | conf.net_beatgans_embed_channels = 512
49 | conf.net_beatgans_resnet_two_cond = True
50 | conf.net_ch_mult = (1, 2, 4, 8)
51 | conf.net_ch = 64
52 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53 | conf.net_enc_pool = 'adaptivenonzero'
54 | conf.sample_size = 32
55 | conf.T_eval = 20
56 | conf.T = 1000
57 | conf.make_model_conf()
58 | return conf
59 |
60 |
61 | def ffhq64_ddpm():
62 | conf = ddpm()
63 | conf.data_name = 'ffhqlmdb256'
64 | conf.warmup = 0
65 | conf.total_samples = 72_000_000
66 | conf.scale_up_gpus(4)
67 | return conf
68 |
69 |
70 | def ffhq64_autoenc():
71 | conf = autoenc_base()
72 | conf.data_name = 'ffhqlmdb256'
73 | conf.warmup = 0
74 | conf.total_samples = 72_000_000
75 | conf.net_ch_mult = (1, 2, 4, 8)
76 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
77 | conf.eval_every_samples = 1_000_000
78 | conf.eval_ema_every_samples = 1_000_000
79 | conf.scale_up_gpus(4)
80 | conf.make_model_conf()
81 | return conf
82 |
83 |
84 | def celeba64d2c_ddpm():
85 | conf = ffhq128_ddpm()
86 | conf.data_name = 'celebalmdb'
87 | conf.eval_every_samples = 10_000_000
88 | conf.eval_ema_every_samples = 10_000_000
89 | conf.total_samples = 72_000_000
90 | conf.name = 'celeba64d2c_ddpm'
91 | return conf
92 |
93 |
94 | def celeba64d2c_autoenc():
95 | conf = ffhq64_autoenc()
96 | conf.data_name = 'celebalmdb'
97 | conf.eval_every_samples = 10_000_000
98 | conf.eval_ema_every_samples = 10_000_000
99 | conf.total_samples = 72_000_000
100 | conf.name = 'celeba64d2c_autoenc'
101 | return conf
102 |
103 |
104 | def ffhq128_ddpm():
105 | conf = ddpm()
106 | conf.data_name = 'ffhqlmdb256'
107 | conf.warmup = 0
108 | conf.total_samples = 48_000_000
109 | conf.img_size = 128
110 | conf.net_ch = 128
111 | # channels:
112 | # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
113 | # sizes:
114 | # 128 => 128 => 64 => 32 => 16 => 8
115 | conf.net_ch_mult = (1, 1, 2, 3, 4)
116 | conf.eval_every_samples = 1_000_000
117 | conf.eval_ema_every_samples = 1_000_000
118 | conf.scale_up_gpus(4)
119 | conf.eval_ema_every_samples = 10_000_000
120 | conf.eval_every_samples = 10_000_000
121 | conf.make_model_conf()
122 | return conf
123 |
124 |
125 | def ffhq128_autoenc_base():
126 | conf = autoenc_base()
127 | conf.data_name = 'ffhqlmdb256'
128 | conf.scale_up_gpus(4)
129 | conf.img_size = 128
130 | conf.net_ch = 128
131 | # final resolution = 8x8
132 | conf.net_ch_mult = (1, 1, 2, 3, 4)
133 | # final resolution = 4x4
134 | conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
135 | conf.eval_ema_every_samples = 10_000_000
136 | conf.eval_every_samples = 10_000_000
137 | conf.make_model_conf()
138 | return conf
139 |
140 |
141 | def ffhq256_autoenc():
142 | conf = ffhq128_autoenc_base()
143 | conf.img_size = 256
144 | conf.net_ch = 128
145 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
146 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
147 | conf.eval_every_samples = 10_000_000
148 | conf.eval_ema_every_samples = 10_000_000
149 | conf.total_samples = 200_000_000
150 | conf.batch_size = 64
151 | conf.make_model_conf()
152 | conf.name = 'ffhq256_autoenc'
153 | return conf
154 |
155 |
156 | def ffhq256_autoenc_eco():
157 | conf = ffhq128_autoenc_base()
158 | conf.img_size = 256
159 | conf.net_ch = 128
160 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
161 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
162 | conf.eval_every_samples = 10_000_000
163 | conf.eval_ema_every_samples = 10_000_000
164 | conf.total_samples = 200_000_000
165 | conf.batch_size = 64
166 | conf.make_model_conf()
167 | conf.name = 'ffhq256_autoenc_eco'
168 | return conf
169 |
170 |
171 | def ffhq128_ddpm_72M():
172 | conf = ffhq128_ddpm()
173 | conf.total_samples = 72_000_000
174 | conf.name = 'ffhq128_ddpm_72M'
175 | return conf
176 |
177 |
178 | def ffhq128_autoenc_72M():
179 | conf = ffhq128_autoenc_base()
180 | conf.total_samples = 72_000_000
181 | conf.name = 'ffhq128_autoenc_72M'
182 | return conf
183 |
184 |
185 | def ffhq128_ddpm_130M():
186 | conf = ffhq128_ddpm()
187 | conf.total_samples = 130_000_000
188 | conf.eval_ema_every_samples = 10_000_000
189 | conf.eval_every_samples = 10_000_000
190 | conf.name = 'ffhq128_ddpm_130M'
191 | return conf
192 |
193 |
194 | def ffhq128_autoenc_130M():
195 | conf = ffhq128_autoenc_base()
196 | conf.total_samples = 130_000_000
197 | conf.eval_ema_every_samples = 10_000_000
198 | conf.eval_every_samples = 10_000_000
199 | conf.name = 'ffhq128_autoenc_130M'
200 | return conf
201 |
202 |
203 | def horse128_ddpm():
204 | conf = ffhq128_ddpm()
205 | conf.data_name = 'horse256'
206 | conf.total_samples = 130_000_000
207 | conf.eval_ema_every_samples = 10_000_000
208 | conf.eval_every_samples = 10_000_000
209 | conf.name = 'horse128_ddpm'
210 | return conf
211 |
212 |
213 | def horse128_autoenc():
214 | conf = ffhq128_autoenc_base()
215 | conf.data_name = 'horse256'
216 | conf.total_samples = 130_000_000
217 | conf.eval_ema_every_samples = 10_000_000
218 | conf.eval_every_samples = 10_000_000
219 | conf.name = 'horse128_autoenc'
220 | return conf
221 |
222 |
223 | def bedroom128_ddpm():
224 | conf = ffhq128_ddpm()
225 | conf.data_name = 'bedroom256'
226 | conf.eval_ema_every_samples = 10_000_000
227 | conf.eval_every_samples = 10_000_000
228 | conf.total_samples = 120_000_000
229 | conf.name = 'bedroom128_ddpm'
230 | return conf
231 |
232 |
233 | def bedroom128_autoenc():
234 | conf = ffhq128_autoenc_base()
235 | conf.data_name = 'bedroom256'
236 | conf.eval_ema_every_samples = 10_000_000
237 | conf.eval_every_samples = 10_000_000
238 | conf.total_samples = 120_000_000
239 | conf.name = 'bedroom128_autoenc'
240 | return conf
241 |
242 |
243 | def pretrain_celeba64d2c_72M():
244 | conf = celeba64d2c_autoenc()
245 | conf.pretrain = PretrainConfig(
246 | name='72M',
247 | path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
248 | )
249 | conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
250 | return conf
251 |
252 |
253 | def pretrain_ffhq128_autoenc72M():
254 | conf = ffhq128_autoenc_base()
255 | conf.postfix = ''
256 | conf.pretrain = PretrainConfig(
257 | name='72M',
258 | path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
259 | )
260 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
261 | return conf
262 |
263 |
264 | def pretrain_ffhq128_autoenc130M():
265 | conf = ffhq128_autoenc_base()
266 | conf.pretrain = PretrainConfig(
267 | name='130M',
268 | path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
269 | )
270 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
271 | return conf
272 |
273 |
274 | def pretrain_ffhq256_autoenc():
275 | conf = ffhq256_autoenc()
276 | conf.pretrain = PretrainConfig(
277 | name='90M',
278 | path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
279 | )
280 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
281 | return conf
282 |
283 |
284 | def pretrain_horse128():
285 | conf = horse128_autoenc()
286 | conf.pretrain = PretrainConfig(
287 | name='82M',
288 | path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
289 | )
290 | conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
291 | return conf
292 |
293 |
294 | def pretrain_bedroom128():
295 | conf = bedroom128_autoenc()
296 | conf.pretrain = PretrainConfig(
297 | name='120M',
298 | path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
299 | )
300 | conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
301 | return conf
302 |
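303 |
304 | # The pretrain_* templates above point at finished autoencoder checkpoints and their
305 | # precomputed latents (latent.pkl); templates_latent.py builds on them as the starting
306 | # point for latent-DDIM training.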
--------------------------------------------------------------------------------
/diffae/templates_cls.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 |
3 |
4 | def ffhq128_autoenc_cls():
5 | conf = ffhq128_autoenc_130M()
6 | conf.train_mode = TrainMode.manipulate
7 | conf.manipulate_mode = ManipulateMode.celebahq_all
8 | conf.manipulate_znormalize = True
9 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
10 | conf.batch_size = 32
11 | conf.lr = 1e-3
12 | conf.total_samples = 300_000
13 | # use the pretraining trick instead of the continuing trick
14 | conf.pretrain = PretrainConfig(
15 | '130M',
16 | f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
17 | )
18 | conf.name = 'ffhq128_autoenc_cls'
19 | return conf
20 |
21 |
22 | def ffhq256_autoenc_cls():
23 | '''We first train the encoder on the FFHQ dataset, then use it as a pretrained model to train a linear classifier on the CelebA dataset with attribute labels.'''
24 | conf = ffhq256_autoenc()
25 | conf.train_mode = TrainMode.manipulate
26 | conf.manipulate_mode = ManipulateMode.celebahq_all
27 | conf.manipulate_znormalize = True
28 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' # we train on the CelebA dataset, not FFHQ
29 | conf.batch_size = 32
30 | conf.lr = 1e-3
31 | conf.total_samples = 300_000
32 | # use the pretraining trick instead of the continuing trick
33 | conf.pretrain = PretrainConfig(
34 | '130M',
35 | f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
36 | )
37 | conf.name = 'ffhq256_autoenc_cls'
38 | return conf
39 |
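40 |
41 | # A minimal usage sketch (assuming the `train_cls` entry point exposed by diffae's
42 | # experiment_classifier.py; adapt to your own launcher):
43 | #
44 | #   conf = ffhq256_autoenc_cls()
45 | #   train_cls(conf, gpus=[0])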
--------------------------------------------------------------------------------
/diffae/templates_latent.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 |
3 |
4 | def latent_diffusion_config(conf: TrainConfig):
5 | conf.batch_size = 128
6 | conf.train_mode = TrainMode.latent_diffusion
7 | conf.latent_gen_type = GenerativeType.ddim
8 | conf.latent_loss_type = LossType.mse
9 | conf.latent_model_mean_type = ModelMeanType.eps
10 | conf.latent_model_var_type = ModelVarType.fixed_large
11 | conf.latent_rescale_timesteps = False
12 | conf.latent_clip_sample = False
13 | conf.latent_T_eval = 20
14 | conf.latent_znormalize = True
15 | conf.total_samples = 96_000_000
16 | conf.sample_every_samples = 400_000
17 | conf.eval_every_samples = 20_000_000
18 | conf.eval_ema_every_samples = 20_000_000
19 | conf.save_every_samples = 2_000_000
20 | return conf
21 |
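22 | # The helpers below are composable: latent_mlp_* choose the skip-MLP denoiser for the
23 | # semantic codes, latent_*_batch_size scale the batch size together with the
24 | # eval/sample/save cadence, and adamw_weight_decay switches the optimizer to AdamW
25 | # with weight decay 0.01.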
22 |
23 | def latent_diffusion128_config(conf: TrainConfig):
24 | conf = latent_diffusion_config(conf)
25 | conf.batch_size_eval = 32
26 | return conf
27 |
28 |
29 | def latent_mlp_2048_norm_10layers(conf: TrainConfig):
30 | conf.net_latent_net_type = LatentNetType.skip
31 | conf.net_latent_layers = 10
32 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
33 | conf.net_latent_activation = Activation.silu
34 | conf.net_latent_num_hid_channels = 2048
35 | conf.net_latent_use_norm = True
36 | conf.net_latent_condition_bias = 1
37 | return conf
38 |
39 |
40 | def latent_mlp_2048_norm_20layers(conf: TrainConfig):
41 | conf = latent_mlp_2048_norm_10layers(conf)
42 | conf.net_latent_layers = 20
43 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
44 | return conf
45 |
46 |
47 | def latent_256_batch_size(conf: TrainConfig):
48 | conf.batch_size = 256
49 | conf.eval_ema_every_samples = 100_000_000
50 | conf.eval_every_samples = 100_000_000
51 | conf.sample_every_samples = 1_000_000
52 | conf.save_every_samples = 2_000_000
53 | conf.total_samples = 301_000_000
54 | return conf
55 |
56 |
57 | def latent_512_batch_size(conf: TrainConfig):
58 | conf.batch_size = 512
59 | conf.eval_ema_every_samples = 100_000_000
60 | conf.eval_every_samples = 100_000_000
61 | conf.sample_every_samples = 1_000_000
62 | conf.save_every_samples = 5_000_000
63 | conf.total_samples = 501_000_000
64 | return conf
65 |
66 |
67 | def latent_2048_batch_size(conf: TrainConfig):
68 | conf.batch_size = 2048
69 | conf.eval_ema_every_samples = 200_000_000
70 | conf.eval_every_samples = 200_000_000
71 | conf.sample_every_samples = 4_000_000
72 | conf.save_every_samples = 20_000_000
73 | conf.total_samples = 1_501_000_000
74 | return conf
75 |
76 |
77 | def adamw_weight_decay(conf: TrainConfig):
78 | conf.optimizer = OptimizerType.adamw
79 | conf.weight_decay = 0.01
80 | return conf
81 |
82 |
83 | def ffhq128_autoenc_latent():
84 | conf = pretrain_ffhq128_autoenc130M()
85 | conf = latent_diffusion128_config(conf)
86 | conf = latent_mlp_2048_norm_10layers(conf)
87 | conf = latent_256_batch_size(conf)
88 | conf = adamw_weight_decay(conf)
89 | conf.total_samples = 101_000_000
90 | conf.latent_loss_type = LossType.l1
91 | conf.latent_beta_scheduler = 'const0.008'
92 | conf.name = 'ffhq128_autoenc_latent'
93 | return conf
94 |
95 |
96 | def ffhq256_autoenc_latent():
97 | conf = pretrain_ffhq256_autoenc()
98 | conf = latent_diffusion128_config(conf)
99 | conf = latent_mlp_2048_norm_10layers(conf)
100 | conf = latent_256_batch_size(conf)
101 | conf = adamw_weight_decay(conf)
102 | conf.total_samples = 101_000_000
103 | conf.latent_loss_type = LossType.l1
104 | conf.latent_beta_scheduler = 'const0.008'
105 | conf.eval_ema_every_samples = 200_000_000
106 | conf.eval_every_samples = 200_000_000
107 | conf.sample_every_samples = 4_000_000
108 | conf.name = 'ffhq256_autoenc_latent'
109 | return conf
110 |
111 |
112 | def horse128_autoenc_latent():
113 | conf = pretrain_horse128()
114 | conf = latent_diffusion128_config(conf)
115 | conf = latent_2048_batch_size(conf)
116 | conf = latent_mlp_2048_norm_20layers(conf)
117 | conf.total_samples = 2_001_000_000
118 | conf.latent_beta_scheduler = 'const0.008'
119 | conf.latent_loss_type = LossType.l1
120 | conf.name = 'horse128_autoenc_latent'
121 | return conf
122 |
123 |
124 | def bedroom128_autoenc_latent():
125 | conf = pretrain_bedroom128()
126 | conf = latent_diffusion128_config(conf)
127 | conf = latent_2048_batch_size(conf)
128 | conf = latent_mlp_2048_norm_20layers(conf)
129 | conf.total_samples = 2_001_000_000
130 | conf.latent_beta_scheduler = 'const0.008'
131 | conf.latent_loss_type = LossType.l1
132 | conf.name = 'bedroom128_autoenc_latent'
133 | return conf
134 |
135 |
136 | def celeba64d2c_autoenc_latent():
137 | conf = pretrain_celeba64d2c_72M()
138 | conf = latent_diffusion_config(conf)
139 | conf = latent_512_batch_size(conf)
140 | conf = latent_mlp_2048_norm_10layers(conf)
141 | conf = adamw_weight_decay(conf)
142 | # set continue_from only so the run name reflects the 200M checkpoint
143 | conf.continue_from = PretrainConfig('200M',
144 | f'log-latent/{conf.name}/last.ckpt')
145 | conf.postfix = '_300M'
146 | conf.total_samples = 301_000_000
147 | conf.latent_beta_scheduler = 'const0.008'
148 | conf.latent_loss_type = LossType.l1
149 | conf.name = 'celeba64d2c_autoenc_latent'
150 | return conf
151 |
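152 |
153 | # A minimal usage sketch (assuming the `train` entry point from diffae's experiment.py,
154 | # as used by the repo's run_*.py scripts):
155 | #
156 | #   conf = ffhq256_autoenc_latent()
157 | #   train(conf, gpus=[0])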
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: osasis
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.8.10
7 | - pip=21.2.4
8 | - mpi4py==3.1.5
9 | - pip:
10 | - blobfile==2.1.1
11 | - lpips==0.1.4
12 | - torch==1.13.0
13 | - torchvision==0.14.0
14 | - pytorch-lightning==1.8.6
15 | - tqdm==4.64.1
16 | - git+https://github.com/openai/CLIP.git
17 | - pandas==1.4.4
18 | - lmdb==1.4.0
19 | - pytorch-fid==0.2.1
--------------------------------------------------------------------------------
/eval_diffaeB.py:
--------------------------------------------------------------------------------
1 | from utils.args import make_args
2 | from utils.tester import DiffFSTester
3 |
4 |
5 | def main(args):
6 | tester = DiffFSTester(args)
7 | tester.infer_image_all()
8 |
9 |
10 | if __name__ == '__main__':
11 | args = make_args()
12 | main(args)
--------------------------------------------------------------------------------
/gen_style_domA.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 | import glob
9 | import random
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | from tqdm import tqdm
14 | from PIL import Image
15 | import lpips
16 |
17 | import torch as th
18 | import torchvision.transforms as transforms
19 | from torchvision import utils
20 |
21 | from P2_weighting.guided_diffusion import dist_util, logger
22 | from P2_weighting.guided_diffusion.script_util import (
23 | model_and_diffusion_defaults,
24 | create_model_and_diffusion,
25 | add_dict_to_argparser,
26 | args_to_dict,
27 | )
28 |
29 | """
30 | Using the P2-weighted model / t_start ratio 0.5 / eta 1.0 / timestep respacing 50
31 | """
32 |
33 |
34 | def main():
35 | args = create_argparser().parse_args()
36 |
37 | # set seed
38 | random.seed(args.seed)
39 | np.random.seed(args.seed)
40 | th.manual_seed(args.seed)
41 |
42 | dist_util.setup_dist()
43 | logger.configure(dir=args.sample_dir)
44 |
45 | logger.log("creating model and diffusion...")
46 | model, diffusion = create_model_and_diffusion(
47 | **args_to_dict(args, model_and_diffusion_defaults().keys())
48 | )
49 | model.load_state_dict(
50 | dist_util.load_state_dict(args.model_path, map_location="cpu")
51 | )
52 | model.to(dist_util.dev())
53 | if args.use_fp16:
54 | model.convert_to_fp16()
55 | model.eval()
56 |
57 | logger.log("sampling...")
58 |
59 | device = dist_util.dev()
60 | transform_256 = transforms.Compose([
61 | transforms.Resize((256,256)),
62 | transforms.ToTensor(),
63 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
64 | ])
65 |
66 | # imgs_ref_domainB = glob.glob(f'{args.input_dir}/*')
67 | # imgs_ref_domainB = [f'{args.input_dir}/art_{args.seed}.png']  # unused leftover; the image is selected via --img_name below
68 | percept_loss = lpips.LPIPS(net='vgg').to(device)
69 | l1 = th.nn.L1Loss(reduction='none')
70 |
71 | # for img_file in tqdm(imgs_ref_domainB):
72 | # # file_name = img_file.split('/')[-1]
73 | # file_name = Path(img_file).name
74 |
75 | img_file = os.path.join(args.input_dir, args.img_name)
76 | img_ref_B = Image.open(img_file).convert('RGB')
77 | img_ref_B = transform_256(img_ref_B)
78 | img_ref_B = img_ref_B.unsqueeze(0).to(device)
79 | img_ref_B_all = img_ref_B.repeat(args.n,1,1,1)
80 |
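81 | # noise the reference only up to t_start = t_start_ratio * num_timesteps, then
82 | # denoise from there; partial noising keeps the samples close to the reference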
81 | t_start = int(diffusion.num_timesteps*args.t_start_ratio)
82 | t_start = th.tensor([t_start], device=device)
83 |
84 | # forward DDPM
85 | xt = diffusion.q_sample(img_ref_B_all.clone(), t_start)
86 |
87 | # reverse DDPM
88 | indices = list(range(t_start))[::-1]
89 | for i in indices:
90 | t = th.tensor([i] * img_ref_B_all.shape[0], device=device)
91 | with th.no_grad():
92 | out = diffusion.p_sample(
93 | model,
94 | xt,
95 | t,
96 | clip_denoised=True,
97 | denoised_fn=None,
98 | cond_fn=None,
99 | model_kwargs=None,
100 | )
101 | xt = out["sample"]
102 |
103 | # # reverse DDIM
104 | # indices = list(range(t_start))[::-1]
105 | # for i in indices:
106 | # t = th.tensor([i] * img_ref_B_all.shape[0], device=device)
107 | # with th.no_grad():
108 | # out = diffusion.ddim_sample(
109 | # model,
110 | # xt,
111 | # t,
112 | # clip_denoised=True,
113 | # denoised_fn=None,
114 | # cond_fn=None,
115 | # model_kwargs=None,
116 | # )
117 | # xt = out["sample"]
118 |
119 | # score each denoised sample against the reference: L1 (weighted 10x) plus LPIPS
120 | # from torchvision.utils import save_image
121 | # save_image(xt/2+0.5, os.path.join('step1_tmp', file_name))
122 | l1_loss = l1(xt, img_ref_B.repeat(int(args.n),1,1,1))
123 | l1_loss = l1_loss.mean(dim=(1,2,3))
124 | lpips_loss = percept_loss(xt, img_ref_B.repeat(int(args.n),1,1,1)).squeeze()
125 | loss = 10*l1_loss + lpips_loss
126 |
127 | # pick best image
128 | img_idx = th.argmin(loss)
129 | img_ref_A = xt[img_idx]
130 |
131 | os.makedirs(args.sample_dir, exist_ok=True)
132 | utils.save_image(img_ref_A/2+0.5, os.path.join(args.sample_dir, args.img_name))
133 |
134 | def create_argparser():
135 | defaults = dict(
136 | model_path="",
137 | input_dir="",
138 | sample_dir="",
139 | img_name="",
140 | n=1,
141 | t_start_ratio=0.5,
142 | eta=1.0,
143 | seed=1
144 | )
145 | defaults.update(model_and_diffusion_defaults())
146 | parser = argparse.ArgumentParser()
147 | add_dict_to_argparser(parser, defaults)
148 | return parser
149 |
150 |
151 | if __name__ == "__main__":
152 | main()
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs/teaser.jpg
--------------------------------------------------------------------------------
/imgs_input_domA/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img1.png
--------------------------------------------------------------------------------
/imgs_input_domA/img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img2.png
--------------------------------------------------------------------------------
/imgs_input_domA/img3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img3.png
--------------------------------------------------------------------------------
/imgs_input_domA/img4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img4.png
--------------------------------------------------------------------------------
/imgs_input_domA/img5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img5.png
--------------------------------------------------------------------------------
/imgs_input_domA/img6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_input_domA/img6.png
--------------------------------------------------------------------------------
/imgs_style_domB/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img1.png
--------------------------------------------------------------------------------
/imgs_style_domB/img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img2.png
--------------------------------------------------------------------------------
/imgs_style_domB/img3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img3.png
--------------------------------------------------------------------------------
/imgs_style_domB/img4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img4.png
--------------------------------------------------------------------------------
/imgs_style_domB/img5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hansam95/OSASIS/122fc2940b28cd8c1effd463df0b25c83acfe15b/imgs_style_domB/img5.png
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | # (step 3) Evaluate
2 |
3 | DEVICE=0
4 |
5 | CUDA_VISIBLE_DEVICES=${DEVICE} \
6 | python eval_diffaeB.py \
7 | --style_domB_dir imgs_style_domB \
8 | --infer_dir imgs_input_domA \
9 | --ref_img img1.png \
10 | --work_dir exp/img1 \
11 | --map_net \
12 | --map_time \
13 | --lambda_map 0.1
--------------------------------------------------------------------------------
/scripts/prepare_train.sh:
--------------------------------------------------------------------------------
1 | # (step 1) download the P2-weighting checkpoint ffhq_p2.pt, then run:
2 |
3 | DEVICE=0
4 |
5 | SAMPLE_FLAGS="--attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 \
6 | --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_res_blocks 1 --num_head_channels 64 \
7 | --resblock_updown True --use_fp16 False --use_scale_shift_norm True --timestep_respacing 50"
8 |
9 | CUDA_VISIBLE_DEVICES=${DEVICE} \
10 | python gen_style_domA.py ${SAMPLE_FLAGS} \
11 | --model_path P2_weighting/models/ffhq_p2.pt \
12 | --input_dir imgs_style_domB \
13 | --sample_dir imgs_style_domA \
14 | --img_name img1.png \
15 | --n 1 \
16 | --t_start_ratio 0.5 \
17 | --seed 1
18 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | # download pretrained DiffAE
2 | # batch size 8 -> ~34GB VRAM; about 30 min on a single A100
3 | # (step 2)
4 |
5 | DEVICE=0
6 |
7 | CUDA_VISIBLE_DEVICES=${DEVICE} \
8 | python train_diffaeB.py \
9 | --style_domA_dir imgs_style_domA \
10 | --style_domB_dir imgs_style_domB \
11 | --ref_img img1.png \
12 | --work_dir exp/img1 \
13 | --n_iter 200 \
14 | --ckpt_freq 200 \
15 | --batch_size 8 \
16 | --map_net \
17 | --map_time \
18 | --lambda_map 0.1 \
19 | --train
20 |
--------------------------------------------------------------------------------
/train_diffaeB.py:
--------------------------------------------------------------------------------
1 | from utils.args import make_args
2 | from utils.trainer import DiffFSTrainer
3 |
4 |
5 | def main(args):
6 | trainer = DiffFSTrainer(args)
7 | trainer.train()
8 |
9 |
10 | if __name__ == '__main__':
11 | args = make_args()
12 | main(args)
--------------------------------------------------------------------------------
/utils/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from pathlib import Path
5 |
6 | def make_args():
7 | parser = argparse.ArgumentParser()
8 | # directory
9 | parser.add_argument('--diffae_ckpt', default='diffae/checkpoints', type=str, help='diffusion autoencoder checkpoints')
10 | parser.add_argument('--style_domA_dir', default='imgs_style_domA', type=str, help='style images (domain A)')
11 | parser.add_argument('--style_domB_dir', default='imgs_style_domB', type=str, help='style images (domain B)')
12 | parser.add_argument('--infer_dir', default='imgs_input_domA', type=str, help='input images (domain A)')
13 | # data
14 | parser.add_argument('--ref_img', default='digital_painting_jing.png', type=str, help='reference image file name')
15 | parser.add_argument('--work_dir', default='exp0', type=str, help='experiment working directory')
16 | # train
17 | parser.add_argument('--seed', default=0, type=int, help='random seed')
18 | parser.add_argument('--n_iter', default=200, type=int, help='training iteration')
19 | parser.add_argument('--batch_size', default=8, type=int, help='batch size')
20 | parser.add_argument('--lr', default=5e-6, type=float, help='learning rate')
21 | # reverse process
22 | parser.add_argument('--T_train_for', default=50, type=int, help='forward timesteps during training')
23 | parser.add_argument('--T_train_back', default=20, type=int, help='backward timesteps during training')
24 | parser.add_argument('--T_infer_for', default=100, type=int, help='forward timesteps during inference')
25 | parser.add_argument('--T_infer_back', default=50, type=int, help='backward timesteps during inference')
26 | parser.add_argument('--T_latent', default=200, type=int, help='latent DDIM timesteps')
27 | parser.add_argument('--t0_ratio', default=0.5, type=float, help='return step ratio')
28 | # loss coefficients
29 | parser.add_argument('--cross_dom', default=1, type=float, help='cross domain loss')
30 | parser.add_argument('--in_dom', default=0.5, type=float, help='in domain loss')
31 | parser.add_argument('--recon_clip', default=30, type=float, help='clip reconstruction loss')
32 | parser.add_argument('--recon_l1', default=10, type=float, help='l1 reconstruction')
33 | parser.add_argument('--recon_lpips', default=10, type=float, help='lpips reconstruction')
34 | # utils
35 | parser.add_argument('--print_freq', default=10, type=int)
36 | parser.add_argument('--train_img_freq', default=50, type=int)
37 | parser.add_argument('--ckpt_freq', default=200, type=int)
38 | parser.add_argument('--train', action='store_true', help='train flag')
39 | # mapping net
40 | parser.add_argument('--map_net', action='store_true', help='use mapping net')
41 | parser.add_argument('--map_time', action='store_true', help='use time-conditioned mapping net')
42 | parser.add_argument('--lambda_map', default=0.1, type=float, help='weight of mapping net output')
43 |
44 | args = parser.parse_args()
45 |
46 | args.device = 'cuda'
47 | args.ref_img_name = Path(args.ref_img).stem
48 |
49 | if args.train:
50 | os.makedirs(args.work_dir, exist_ok=True)
51 | with open(os.path.join(args.work_dir, 'args.txt'), 'w') as f:
52 | json.dump(args.__dict__, f, indent=2)
53 |
54 | return args
55 |
56 |
--------------------------------------------------------------------------------
/utils/map_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from diffae.model.nn import timestep_embedding
4 |
5 | class MappingNet(nn.Module):
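6 | """Pixel-wise (1x1 conv) network that predicts a content correction; its output
7 | is added to x_t scaled by --lambda_map (see utils/tester.py)."""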
6 | def __init__(self):
7 | super().__init__()
8 |
9 | self.layers = nn.Sequential(
10 | nn.Conv2d(3, 128, kernel_size=1),
11 |
12 | nn.GroupNorm(32, 128),
13 | nn.SiLU(),
14 | nn.Conv2d(128, 256, kernel_size=1),
15 |
16 | nn.GroupNorm(32, 256),
17 | nn.SiLU(),
18 | nn.Conv2d(256, 128, kernel_size=1),
19 |
20 | nn.GroupNorm(32, 128),
21 | nn.SiLU(),
22 | nn.Conv2d(128, 3, kernel_size=1)
23 | )
24 |
25 | def forward(self, x):
26 | x = self.layers(x)
27 | return x
28 |
29 |
30 | class MappingNetTime(nn.Module):
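31 | """Time-conditioned variant of MappingNet: each block modulates its features
32 | with a scale/shift derived from the timestep embedding."""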
31 | def __init__(self):
32 | super().__init__()
33 |
34 | self.in_layer = nn.Sequential(
35 | nn.Conv2d(3, 128, kernel_size=1),
36 | )
37 |
38 | self.block1 = MappingNetTimeBlock(
39 | in_ch = 128,
40 | out_ch = 256,
41 | t_emb_dim = 512
42 | )
43 |
44 | self.block2 = MappingNetTimeBlock(
45 | in_ch = 256,
46 | out_ch = 128,
47 | t_emb_dim = 512
48 | )
49 |
50 | self.out_layer = nn.Sequential(
51 | nn.Conv2d(128, 3, kernel_size=1)
52 | )
53 |
54 | self.time_embed = nn.Sequential(
55 | nn.Linear(128, 512),
56 | nn.SiLU(),
57 | nn.Linear(512, 512),
58 | )
59 |
60 | def forward(self, x, t):
61 | t_emb = timestep_embedding(t, 128)
62 | t_emb = self.time_embed(t_emb)
63 |
64 | x = self.in_layer(x)
65 | x = self.block1(x, t_emb)
66 | x = self.block2(x, t_emb)
67 | x = self.out_layer(x)
68 |
69 | return x
70 |
71 |
72 | class MappingNetTimeBlock(nn.Module):
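73 | """1x1-conv block with scale-shift time conditioning: the timestep embedding is
74 | projected to a per-channel (scale, shift) pair applied between pre_layer and post_layer."""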
73 | def __init__(self, in_ch, out_ch, t_emb_dim):
74 | super().__init__()
75 |
76 | self.pre_layer = nn.Sequential(
77 | nn.GroupNorm(32, in_ch),
78 | nn.SiLU(),
79 | nn.Conv2d(in_ch, out_ch, kernel_size=1),
80 | nn.GroupNorm(32, out_ch),
81 | )
82 |
83 | self.post_layer = nn.Sequential(
84 | nn.SiLU(),
85 | nn.Conv2d(out_ch, out_ch, kernel_size=1)
86 | )
87 |
88 | self.emb_layer = nn.Sequential(
89 | nn.SiLU(),
90 | nn.Linear(t_emb_dim, out_ch*2)
91 | )
92 |
93 | def forward(self, x, t_emb, scale_bias:float=1):
94 |
95 | t_emb = self.emb_layer(t_emb)
96 |
97 | # match shape
98 | while len(t_emb.shape) < len(x.shape):
99 | t_emb = t_emb[..., None]
100 |
101 | scale, shift = torch.chunk(t_emb, 2, dim=1)
102 |
103 | x = self.pre_layer(x)
104 | x = x * (scale_bias + scale)
105 | x = x + shift
106 | x = self.post_layer(x)
107 |
108 | return x
--------------------------------------------------------------------------------
/utils/tester.py:
--------------------------------------------------------------------------------
1 | '''
2 | Images are loaded through a dataloader (see TestDataset).
3 | '''
4 |
5 | import os
6 | import glob
7 | import copy
8 | import random
9 | from pathlib import Path
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import torch
14 | import torchvision.transforms as transforms
15 | from torchvision.utils import save_image
16 |
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | from diffae.templates_latent import ffhq256_autoenc_latent
20 | from diffae.experiment import LitModel
21 | from utils.map_net import MappingNet, MappingNetTime
22 |
23 |
24 | class TestDataset(Dataset):
25 | def __init__(self, img_dir):
26 | self.imgs = glob.glob(os.path.join(img_dir, '*'))
27 | self.transform_img = transforms.Compose([
28 | transforms.Resize((256,256)),
29 | transforms.ToTensor(),
30 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
31 | ])
32 |
33 | def __len__(self):
34 | return len(self.imgs)
35 |
36 | def __getitem__(self, idx):
37 | img_cont_A = Image.open(self.imgs[idx]).convert('RGB')
38 | img_cont_A = self.transform_img(img_cont_A)
39 | return img_cont_A, Path(self.imgs[idx]).name
40 |
41 |
42 | class DiffFSTester:
43 | def __init__(self, args):
44 | self.args = args
45 |
46 | # set seed
47 | random.seed(self.args.seed)
48 | np.random.seed(self.args.seed)
49 | torch.manual_seed(self.args.seed)
50 |
51 | # load diffae
52 | conf = ffhq256_autoenc_latent()
53 | conf.pretrain.path = os.path.join(self.args.diffae_ckpt, 'ffhq256_autoenc/last.ckpt')
54 | conf.latent_infer_path = os.path.join(self.args.diffae_ckpt, 'ffhq256_autoenc/latent.pkl')
55 |
56 | model_diffae = LitModel(conf)
57 | state = torch.load(os.path.join(self.args.diffae_ckpt, f'{conf.name}/last.ckpt'), map_location='cpu')
58 | model_diffae.load_state_dict(state['state_dict'], strict=False)
59 |
60 | # make diffae for domain A (photo) / frozen
61 | self.diffae_A = copy.deepcopy(model_diffae.ema_model)
62 | self.diffae_A = self.diffae_A.to(self.args.device)
63 | self.diffae_A.eval()
64 | self.diffae_A.requires_grad_(False)
65 |
66 | # make diffae for domain B (style) / fine-tuned weights are loaded in infer_image_all
67 | self.diffae_B = copy.deepcopy(model_diffae.ema_model)
68 | self.diffae_B = self.diffae_B.to(self.args.device)
69 | self.diffae_B.eval()
70 | self.diffae_B.requires_grad_(False)
71 |
72 | # mapping net
73 | if self.args.map_net:
74 | if self.args.map_time:
75 | self.model_map = MappingNetTime().to(args.device)
76 | else:
77 | self.model_map = MappingNet().to(args.device)
78 |
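79 | # two respaced DDIM samplers: T_infer_for steps for forward inversion,
80 | # T_infer_back steps for the stylized reverse process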
79 | self.infer_samp_for = conf._make_diffusion_conf(self.args.T_infer_for).make_sampler()
80 | self.infer_samp_back = conf._make_diffusion_conf(self.args.T_infer_back).make_sampler()
81 |
82 | self.transform_img = transforms.Compose([
83 | transforms.Resize((256,256)),
84 | transforms.ToTensor(),
85 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
86 | ])
87 |
88 | def infer_image(self, img_dir, z_style_B):
89 |
90 | test_dataset = TestDataset(img_dir)
91 | test_loader = DataLoader(test_dataset, batch_size=64)
92 |
93 | for img_cont_A, file_name in test_loader:
94 |
95 | img_cont_A = img_cont_A.to(self.args.device)
96 | z_cont_A = self.diffae_A.encode(img_cont_A)
97 | z_cont_A = z_cont_A.detach().clone()
98 | xt_cont_A = img_cont_A.clone()
99 |
100 | with torch.no_grad():
101 | # forward DDIM: deterministically invert the content image up to t0
102 | forward_indices = list(range(self.args.T_infer_for))\
103 | [:int(self.args.T_infer_for*self.args.t0_ratio)]
104 | for j in forward_indices:
105 | t = torch.tensor([j]*len(img_cont_A), device=self.args.device)
106 | out = self.infer_samp_for.ddim_reverse_sample(self.diffae_A,
107 | xt_cont_A,
108 | t,
109 | model_kwargs={'cond': z_cont_A})
110 | xt_cont_A = out['sample']
111 |
112 | xt_cont_B = xt_cont_A.detach().clone()
113 |
114 | # reverse DDIM (mix): denoise with diffae_B, conditioned on both the style and content codes
115 | reverse_indices = list(range(self.args.T_infer_back))[::-1]\
116 | [int(self.args.T_infer_back*(1-self.args.t0_ratio)):]
117 | for j in reverse_indices:
118 | t = torch.tensor([j]*len(img_cont_A), device=self.args.device)
119 |
120 | if self.args.map_net:
121 | if self.args.map_time:
122 | map_cont = self.model_map(img_cont_A, t)
123 | else:
124 | map_cont = self.model_map(img_cont_A)
125 | xt_cont_B = xt_cont_B + self.args.lambda_map*map_cont
126 |
127 | out = self.infer_samp_back.ddim_sample(self.diffae_B,
128 | xt_cont_B,
129 | t,
130 | model_kwargs={'cond': {'ref':z_style_B, 'input':z_cont_A},
131 | 'ref_cond_scale': [0, 1, 2, 3],
132 | 'mix': True})
133 | xt_cont_B = out['sample'].detach().clone()
134 |
135 | save_dir = os.path.join(self.args.work_dir, 'imgs_test')
136 | os.makedirs(save_dir, exist_ok=True)
137 | # save_image(xt_cont_B/2+0.5, os.path.join(save_dir, Path(input_path).name))
138 |
139 | for i, img in enumerate(xt_cont_B):
140 | save_image(img/2+0.5, os.path.join(save_dir, file_name[i]))
141 |
142 |
143 | def infer_image_all(self):
144 | ckpt_path = os.path.join(self.args.work_dir, 'ckpt', f'iter_{self.args.n_iter}.pt')
145 | ckpt = torch.load(ckpt_path, map_location='cpu')
146 | self.diffae_B.load_state_dict(ckpt['diffae_B'])
147 | self.diffae_B = self.diffae_B.to(self.args.device)
148 |
149 | if self.args.map_net:
150 | self.model_map.load_state_dict(ckpt['model_map'])
151 | self.model_map = self.model_map.to(self.args.device)
152 |
153 | # style image
154 | img_style_B = Image.open(os.path.join(self.args.style_domB_dir, self.args.ref_img)).convert('RGB')
155 | img_style_B = self.transform_img(img_style_B).unsqueeze(0).to(self.args.device)
156 | z_style_B = self.diffae_A.encode(img_style_B)
157 | z_style_B = z_style_B.detach().clone()
158 |
159 | self.infer_image(self.args.infer_dir, z_style_B)
--------------------------------------------------------------------------------