├── .DS_Store
├── .gitignore
├── README.html
├── README.md
├── ref
│   ├── .DS_Store
│   ├── bedroom_instructions.txt
│   ├── ref_bedroom
│   │   ├── .DS_Store
│   │   ├── bedroom_train_0000125.png
│   │   ├── bedroom_train_0000126.png
│   │   ├── bedroom_train_0000192.png
│   │   ├── bedroom_train_0000274.png
│   │   ├── bedroom_train_0001285.png
│   │   ├── bedroom_train_0002219.png
│   │   ├── bedroom_train_0003606.png
│   │   └── bedroom_train_0008929.png
│   ├── ref_cat
│   │   ├── .DS_Store
│   │   ├── cat_0000024.png
│   │   ├── cat_0000325.png
│   │   ├── cat_0000383.png
│   │   ├── cat_0003829.png
│   │   └── cat_0007358.png
│   ├── ref_ffhq
│   │   ├── 69005.png
│   │   ├── 69019.png
│   │   ├── 69099.png
│   │   ├── 69109.png
│   │   ├── 69708.png
│   │   ├── 69845.png
│   │   ├── guide2.jpeg
│   │   └── style4.png
│   ├── ref_horse
│   │   ├── horse_0000023.png
│   │   ├── horse_0000052.png
│   │   ├── horse_0000338.png
│   │   ├── horse_0004057.png
│   │   └── horse_0005581.png
│   └── ref_style
│       ├── .DS_Store
│       ├── candy.jpeg
│       ├── edtaonisl.jpeg
│       ├── starnight.jpg
│       ├── sunflowers.jpeg
│       └── wave.jpeg
├── requirements.txt
├── scripts
│   ├── clip_finetune_noise_nolabel.py
│   └── sample.py
├── sdg
│   ├── __init__.py
│   ├── clip_guidance.py
│   ├── distributed.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── gpu_affinity.py
│   ├── guidance.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── logging.py
│   ├── losses.py
│   ├── misc.py
│   ├── nn.py
│   ├── parser.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── train_util.py
│   └── unet.py
├── setup.py
└── teaser.png
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | models/
2 | *.pt
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # More Control for Free! Image Synthesis with Semantic Diffusion Guidance
2 |
3 | This is the codebase for [More Control for Free! Image Synthesis with Semantic Diffusion Guidance](http://arxiv.org/abs/2112.05744).
4 |
5 | This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications for semantic guidance.
6 |
7 | 
8 |
9 | ### Installation
10 |
11 | ```bash
12 | git clone https://github.com/xh-liu/SDG_code
13 | cd SDG_code
14 | pip install -r requirements.txt
15 | pip install -e .
16 | ```
17 |
18 | ### Download pre-trained models
19 |
20 | The pretrained unconditional diffusion models are from [openai/guided-diffusion](https://github.com/openai/guided-diffusion) and [jychoi118/ilvr_adm](https://github.com/jychoi118/ilvr_adm).
21 |
22 | * LSUN bedroom unconditional diffusion: [lsun_bedroom.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_bedroom.pt)
23 | * LSUN cat unconditional diffusion: [lsun_cat.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_cat.pt)
24 | * LSUN horse unconditional diffusion: [lsun_horse.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_horse.pt)
25 | * LSUN horse (no dropout): [lsun_horse_nodropout.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_horse_nodropout.pt)
26 | * FFHQ unconditional diffusion: [ffhq.pt](https://onedrive.live.com/?authkey=%21AOIJGI8FUQXvFf8&id=72419B431C262344%21103807&cid=72419B431C262344)
27 |
28 | We fine-tune the CLIP image encoders on noisy images for semantic guidance; a short sketch of how these encoders are used at sampling time follows the checkpoint list. We provide the checkpoints as follows:
29 |
30 | * FFHQ semantic guidance: [clip_ffhq.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EQbpgLeWnZhBhNzvFXgn26IBhsveoX3V57ZoQdSsLnwrjA?e=1K0qwv)
31 | * LSUN bedroom semantic guidance: [clipbedroom.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EfVpSVSjAhlEpsBCxSwkBnQByUvgNZqr38bxnG6bDHuOZQ?e=bOgCZT)
32 | * LSUN cat semantic guidance: [clip_cat.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EQdhKa0Jte9FtaB21kRDbT0B7tI3SoZewOack9DNe8s0LQ?e=zILyOa)
33 | * LSUN horse semantic guidance: [clip_horse.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EWqcgeq4kkpCgi3S9WcsOjABScZg-gT-aSnaZyh1uHIxNg?e=qDUJWK)
34 |
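These noise-aware encoders are used inside the guidance function of `scripts/sample.py`. The sketch below is not part of the repository: it reuses the `diffusion` and `clip_ft` objects from that script and substitutes a plain cosine similarity for the repository's `image_loss` helper (defined in `sdg/guidance.py`, not shown here).

```python
import torch as th
import torch.nn.functional as F


def image_guidance_grad(x, t, ref_img, diffusion, clip_ft, image_weight=100.0):
    """Gradient of a CLIP image-similarity guidance term w.r.t. the noisy sample x."""
    with th.no_grad():
        # Noise the reference image to the same timestep as the current sample,
        # then encode it with the time-conditioned (fine-tuned) CLIP encoder.
        ref_noised = diffusion.q_sample(ref_img, t)
        ref_features = clip_ft.encode_image_list(ref_noised, t)
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        x_features = clip_ft.encode_image_list(x_in, t)
        # Stand-in for image_loss: cosine similarity of the pooled (last) features.
        sim = F.cosine_similarity(x_features[-1], ref_features[-1], dim=-1)
        guidance = image_weight * sim.sum()
    return th.autograd.grad(guidance, x_in)[0]
```
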
35 | ### Sampling with semantic diffusion guidance
36 |
37 | To sample from these models, you can use `scripts/sample.py`.
38 | Here, we provide flags for sampling from all of these models.
39 | We assume that you have downloaded the relevant model checkpoints into a folder called `models/`.
40 |
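For example, assuming `wget` is available, the LSUN bedroom diffusion checkpoint listed above can be fetched with:

```bash
mkdir -p models
wget -O models/lsun_bedroom.pt https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_bedroom.pt
```
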
41 | For LSUN cat, LSUN horse, and LSUN bedroom, the model flags are defined as:
42 |
43 | ```bash
44 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --model_path models/lsun_bedroom.pt"
45 | ```
46 |
47 | For the FFHQ dataset, the model flags are defined as:
48 | ```bash
49 | MODEL_FLAGS="--attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --model_path models/ffhq_10m.pt"
50 | ```
51 |
52 | Sampling flags:
53 |
54 | ```bash
55 | SAMPLE_FLAGS="--batch_size 8 --timestep_respacing 100"
56 | ```
57 |
58 | Sampling with image content (semantic) guidance:
59 |
60 | ```bash
61 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 0 --image_weight 100 --image_loss semantic --clip_path models/CLIP_bedroom.pt"
62 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS
63 | ```
64 |
65 | Sampling with image style guidance:
66 | ```bash
67 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 0 --image_weight 100 --image_loss style --clip_path models/CLIP_bedroom.pt"
68 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_style_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS
69 | ```
70 |
71 | Sampling with language guidance:
72 | ```bash
73 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 160 --image_weight 0 --text_instruction_file ref/bedroom_instructions.txt --clip_path models/CLIP_bedroom.pt"
74 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_language_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS
75 | ```
76 |
77 | Sampling with both language and image guidance:
78 | ```bash
79 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 160 --image_weight 100 --image_loss semantic --text_instruction_file ref/bedroom_instructions.txt --clip_path models/CLIP_bedroom.pt"
80 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_language_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS
81 | ```
82 | You may need to adjust `--text_weight` and `--image_weight` to get better visual quality in the generated samples.
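
As an illustration (the values are hypothetical, not tuned), a sweep over the text-guidance weight could look like:

```bash
for TW in 80 160 240; do
  GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight $TW --image_weight 0 --text_instruction_file ref/bedroom_instructions.txt --clip_path models/CLIP_bedroom.pt"
  CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_language_guidance_tw$TW --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS
done
```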
83 |
84 | ### Citation
85 | If you find our work useful for your research, please cite our paper.
86 | ```
87 | @inproceedings{liu2023more,
88 | title={More control for free! image synthesis with semantic diffusion guidance},
89 | author={Liu, Xihui and Park, Dong Huk and Azadi, Samaneh and Zhang, Gong and Chopikyan, Arman and Hu, Yuxiao and Shi, Humphrey and Rohrbach, Anna and Darrell, Trevor},
90 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
91 | year={2023}
92 | }
93 | ```
94 |
95 |
--------------------------------------------------------------------------------
/ref/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/.DS_Store
--------------------------------------------------------------------------------
/ref/bedroom_instructions.txt:
--------------------------------------------------------------------------------
1 | A photo of a bedroom with a painting on the wall.
2 |
--------------------------------------------------------------------------------
/ref/ref_bedroom/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/.DS_Store
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0000125.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000125.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0000126.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000126.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0000192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000192.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0000274.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000274.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0001285.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0001285.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0002219.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0002219.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0003606.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0003606.png
--------------------------------------------------------------------------------
/ref/ref_bedroom/bedroom_train_0008929.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0008929.png
--------------------------------------------------------------------------------
/ref/ref_cat/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/.DS_Store
--------------------------------------------------------------------------------
/ref/ref_cat/cat_0000024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000024.png
--------------------------------------------------------------------------------
/ref/ref_cat/cat_0000325.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000325.png
--------------------------------------------------------------------------------
/ref/ref_cat/cat_0000383.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000383.png
--------------------------------------------------------------------------------
/ref/ref_cat/cat_0003829.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0003829.png
--------------------------------------------------------------------------------
/ref/ref_cat/cat_0007358.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0007358.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69005.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69019.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69099.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69099.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69109.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69109.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69708.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69708.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/69845.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69845.png
--------------------------------------------------------------------------------
/ref/ref_ffhq/guide2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/guide2.jpeg
--------------------------------------------------------------------------------
/ref/ref_ffhq/style4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/style4.png
--------------------------------------------------------------------------------
/ref/ref_horse/horse_0000023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000023.png
--------------------------------------------------------------------------------
/ref/ref_horse/horse_0000052.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000052.png
--------------------------------------------------------------------------------
/ref/ref_horse/horse_0000338.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000338.png
--------------------------------------------------------------------------------
/ref/ref_horse/horse_0004057.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0004057.png
--------------------------------------------------------------------------------
/ref/ref_horse/horse_0005581.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0005581.png
--------------------------------------------------------------------------------
/ref/ref_style/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/.DS_Store
--------------------------------------------------------------------------------
/ref/ref_style/candy.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/candy.jpeg
--------------------------------------------------------------------------------
/ref/ref_style/edtaonisl.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/edtaonisl.jpeg
--------------------------------------------------------------------------------
/ref/ref_style/starnight.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/starnight.jpg
--------------------------------------------------------------------------------
/ref/ref_style/sunflowers.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/sunflowers.jpeg
--------------------------------------------------------------------------------
/ref/ref_style/wave.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/wave.jpeg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | git+https://github.com/openai/CLIP.git
5 | pynvml
6 | tensorboard
7 |
--------------------------------------------------------------------------------
/scripts/clip_finetune_noise_nolabel.py:
--------------------------------------------------------------------------------
1 | """
2 | Finetune a noised CLIP image encoder on the target dataset without text annotations.
3 | """
4 |
5 | import argparse
6 | import os
7 | import random
8 | import time
8 |
9 | import blobfile as bf
10 | import numpy as np
11 | import torch as th
12 | import torch.distributed as dist
13 | import torch.nn.functional as F
14 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
15 | from torch.optim import AdamW
16 | from torch.utils import tensorboard
17 |
18 | from sdg.parser import create_argparser
19 | from sdg.logging import init_logging, make_logging_dir
20 | from sdg.distributed import init_dist, is_master, get_world_size
21 | from sdg.distributed import master_only_print as print
22 | from sdg.distributed import dist_all_gather_tensor, all_gather_with_gradient
23 | from sdg.gpu_affinity import set_affinity
24 | from sdg.image_datasets import load_data
25 | from sdg.resample import create_named_schedule_sampler
26 | from sdg.script_util import (
27 |     add_dict_to_argparser,
28 |     args_to_dict,
29 |     classifier_and_diffusion_defaults,
30 |     create_clip_and_diffusion,
31 | )
32 | from sdg.train_util import parse_resume_step_from_filename, log_loss_dict
33 | from sdg.misc import set_random_seed
34 | from sdg.misc import to_cuda
35 |
36 |
37 | def main():
38 | args = create_argparser().parse_args()
39 |
40 | set_affinity(args.local_rank)
41 | if args.randomized_seed:
42 | args.seed = random.randint(0, 10000)
43 | set_random_seed(args.seed, by_rank=True)
44 | if not args.single_gpu:
45 | init_dist(args.local_rank)
46 | tb_log = None
47 | args.logdir = init_logging(args.exp_name)
48 | if is_master():
49 | tb_log = make_logging_dir(args.logdir)
50 | world_size = get_world_size()
51 |
52 | print("creating model and diffusion...")
53 | model, diffusion = create_clip_and_diffusion(
54 | args, **args_to_dict(args, classifier_and_diffusion_defaults().keys())
55 | )
56 | model.to('cuda')
57 | if args.noised:
58 | schedule_sampler = create_named_schedule_sampler(
59 | args.schedule_sampler, diffusion
60 | )
61 |
62 | resume_step = 0
63 | if args.resume_checkpoint:
64 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
65 | print(f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step")
66 | model.load_state_dict(th.load(args.resume_checkpoint, map_location=lambda storage, loc: storage))
67 | model.to('cuda')
68 |
69 | if args.use_fp16:
70 | if args.fp16_hyperparams == 'pytorch':
71 | scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)
72 | elif args.fp16_hyperparams == 'openai':
73 | scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2**0.001, backoff_factor=0.5, growth_interval=1, enabled=True)
74 |
75 |
76 | use_ddp = th.cuda.is_available() and th.distributed.is_available() and dist.is_initialized()
77 | if use_ddp:
78 | ddp_model = DDP(
79 | model,
80 | device_ids=[args.local_rank],
81 | output_device=args.local_rank,
82 | broadcast_buffers=False,
83 | bucket_cap_mb=128,
84 | find_unused_parameters=False,
85 | )
86 | else:
87 | print("Single GPU Training without DistributedDataParallel. ")
88 | ddp_model = model
89 |
90 | print("creating data loader...")
91 | args.return_text = False
92 | args.return_class = False
93 | args.return_yall = False
94 | train_dataloader = load_data(args)
95 |
96 | print(f"creating optimizer...")
97 | opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=args.weight_decay)
98 | if args.resume_checkpoint:
99 | opt_checkpoint = bf.join(
100 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt"
101 | )
102 | if bf.exists(opt_checkpoint):
103 | print(f"loading optimizer state from checkpoint: {opt_checkpoint}")
104 | opt.load_state_dict(th.load(opt_checkpoint, map_location=lambda storage, loc: storage))
105 | else:
106 | print('Warning: opt checkpoint %s not found' % opt_checkpoint)
107 | sca_checkpoint = bf.join(
108 | bf.dirname(args.resume_checkpoint), f"scaler{resume_step:06}.pt"
109 | )
110 | if bf.exists(sca_checkpoint):
111 |             print(f"loading scaler state from checkpoint: {sca_checkpoint}")
112 | scaler.load_state_dict(th.load(sca_checkpoint, map_location=lambda storage, loc: storage))
113 | else:
114 |             print('Warning: scaler checkpoint %s not found' % sca_checkpoint)
115 |
116 |
117 |
118 | print("training classifier model...")
119 |
120 | import clip
121 | clip_pretrained, _ = clip.load('RN50x16', jit=False)
122 | clip_pretrained = clip_pretrained.float()
123 | clip_pretrained.eval()
124 | clip_pretrained = clip_pretrained.cuda()
125 |
126 | def forward_backward_log(data_loader, prefix="train", step=0):
127 | batch, batch2 = next(data_loader)
128 |
129 | batch = to_cuda(batch)
130 | batch2 = to_cuda(batch2)
131 | # Noisy images
132 | if args.noised:
133 | t, _ = schedule_sampler.sample(batch.shape[0], 'cuda')
134 | batch = diffusion.q_sample(batch, t)
135 | else:
136 | t = th.zeros(batch.shape[0], dtype=th.long, device='cuda')
137 |
138 | ground_truth = th.arange(batch.shape[0] * world_size, dtype=th.long, device='cuda')
139 | for i, (sub_batch, sub_batch2, sub_t) in enumerate(
140 | split_microbatches(args.microbatch, batch, batch2, t)
141 | ):
142 | with th.cuda.amp.autocast(args.use_fp16):
143 | with th.no_grad():
144 | sub_labels = clip_pretrained.encode_image(sub_batch2)
145 | if args.structure == 'classifier':
146 | image_features = ddp_model(sub_batch, timesteps=sub_t)
147 | text_features = sub_labels
148 | else:
149 | image_features, text_features = ddp_model(sub_batch, sub_labels, sub_t)
150 |
151 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
152 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
153 | losses = {}
154 |
155 | image_features = all_gather_with_gradient(image_features)
156 | text_features = all_gather_with_gradient(text_features)
157 | logits_per_image = 100.0 * image_features @ text_features.t()
158 | logits_per_text = logits_per_image.t()
159 |
160 | loss_i2t = F.cross_entropy(logits_per_image, ground_truth, reduction='none')
161 | loss_t2i = F.cross_entropy(logits_per_text, ground_truth, reduction='none')
162 | loss = loss_i2t + loss_t2i
163 | losses[f"{prefix}_loss_i2t"] = loss_i2t.detach()
164 | losses[f"{prefix}_loss_t2i"] = loss_t2i.detach()
165 |
166 |
167 | losses[f"{prefix}_loss"] = loss.detach()
168 |
169 | log_loss_dict(diffusion, sub_t, losses, tb_log, step)
170 | del losses
171 | loss = loss.mean()
172 | if loss.requires_grad:
173 | if i == 0:
174 | opt.zero_grad()
175 | if args.use_fp16:
176 | scaler.scale(loss).backward()
177 | else:
178 | loss.backward()
179 |
180 | global_batch = args.batch_size * world_size
181 | for step in range(args.iterations - resume_step):
182 | print("***step %d " % (step + resume_step), end='')
183 | num_samples = (step + resume_step + 1) * global_batch
184 | print('samples: %d ' % num_samples, end='')
185 | if args.anneal_lr:
186 | set_annealed_lr(opt, args.lr, args.lr_anneal_steps, step + resume_step)
187 | if is_master():
188 | tb_log.add_scalar('status/step', step + resume_step, step)
189 | tb_log.add_scalar('status/samples', num_samples, step)
190 | tb_log.add_scalar('status/lr', args.lr, step)
191 | forward_backward_log(train_dataloader, step=step)
192 | if args.use_fp16:
193 | scaler.step(opt)
194 | scaler.update()
195 | else:
196 | opt.step()
197 | if not step % args.log_interval and is_master():
198 | tb_log.flush()
199 | if not (step + resume_step) % args.save_interval:
200 | print("saving model...")
201 | save_model(args.logdir, model, opt, scaler, step + resume_step)
202 |
203 | print("saving model...")
204 | if is_master():
205 |         save_model(args.logdir, model, opt, scaler, step + resume_step)
206 | tb_log.close()
207 |
208 |
209 | def set_annealed_lr(opt, base_lr, anneal_steps, current_steps):
210 | lr_decay_cnt = current_steps // anneal_steps
211 | lr = base_lr * 0.1**lr_decay_cnt
212 | for param_group in opt.param_groups:
213 | param_group["lr"] = lr
214 |
215 |
216 | def save_model(logdir, model, opt, scaler, step):
217 | if is_master():
218 | th.save(model.state_dict(), os.path.join(logdir, f"model{step:06d}.pt"))
219 | th.save(opt.state_dict(), os.path.join(logdir, f"opt{step:06d}.pt"))
220 | th.save(scaler.state_dict(), os.path.join(logdir, f"scaler{step:06d}.pt"))
221 |
222 |
223 | def compute_top_k(logits, labels, k, reduction="mean"):
224 | _, top_ks = th.topk(logits, k, dim=-1)
225 | if reduction == "mean":
226 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
227 | elif reduction == "none":
228 | return (top_ks == labels[:, None]).float().sum(dim=-1)
229 |
230 |
231 | def split_microbatches(microbatch, *args):
232 | bs = len(args[0])
233 | if microbatch == -1 or microbatch >= bs:
234 | yield tuple(args)
235 | else:
236 | for i in range(0, bs, microbatch):
237 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args)
238 |
239 |
240 | if __name__ == "__main__":
241 | main()
242 |
--------------------------------------------------------------------------------
/scripts/sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
4 |
5 | import blobfile as bf
6 | import numpy as np
7 | import torch as th
8 | import torch.distributed as dist
9 | import torchvision
10 | import torch.nn.functional as F
11 |
12 | from sdg.parser import create_argparser
13 | from sdg.logging import init_logging, make_logging_dir
14 | from sdg.distributed import master_only_print as print
15 | from sdg.distributed import is_master, init_dist, get_world_size
16 | from sdg.gpu_affinity import set_affinity
18 | from sdg.script_util import (
19 | model_and_diffusion_defaults,
20 | create_model_and_diffusion,
21 | args_to_dict,
22 | add_dict_to_argparser,
23 | )
24 | from sdg.clip_guidance import CLIP_gd
25 | from sdg.image_datasets import load_ref_data
26 | from sdg.misc import set_random_seed
27 | from sdg.guidance import image_loss, text_loss
28 | from sdg.image_datasets import _list_image_files_recursively
29 | from torchvision import utils
30 | import math
31 | import clip
32 |
33 |
34 | def main():
35 | time0 = time.time()
36 | args = create_argparser().parse_args()
37 | set_affinity(args.local_rank)
38 | if args.randomized_seed:
39 | args.seed = random.randint(0, 10000)
40 | set_random_seed(args.seed, by_rank=True)
41 | if not args.single_gpu:
42 | init_dist(args.local_rank)
43 |
44 | tb_log = None
45 | args.logdir = init_logging(args.exp_name, root_dir='results', timestamp=False)
46 | if is_master():
47 | tb_log = make_logging_dir(args.logdir, no_tb=True)
48 |
49 | print("creating model...")
50 | model, diffusion = create_model_and_diffusion(
51 | **args_to_dict(args, model_and_diffusion_defaults().keys())
52 | )
53 | model.load_state_dict(
54 | th.load(args.model_path, map_location="cpu")
55 | )
56 | model.to('cuda')
57 | model.eval()
58 |
59 | clip_model, preprocess = clip.load('RN50x16', device='cuda')
60 | if args.text_weight == 0:
61 | instructions = [""]
62 | else:
63 | with open(args.text_instruction_file, 'r') as f:
64 | instructions = f.readlines()
65 | instructions = [tmp.replace('\n', '') for tmp in instructions]
66 | # define image list
67 | if args.image_weight == 0:
68 | imgs = [None]
69 | else:
70 | imgs = _list_image_files_recursively(args.data_dir)
71 | imgs = sorted(imgs)
72 | clip_ft = CLIP_gd(args)
73 | clip_ft.load_state_dict(th.load(args.clip_path, map_location='cpu'))
74 | clip_ft.eval()
75 | clip_ft = clip_ft.cuda()
76 |
77 | def cond_fn_sdg(x, t, y, **kwargs):
78 | assert y is not None
79 | with th.no_grad():
80 | if args.text_weight != 0:
81 | text_features = clip_model.encode_text(y)
82 | if args.image_weight != 0:
83 | target_img_noised = diffusion.q_sample(kwargs['ref_img'], t, tscale1000=True)
84 | target_img_features = clip_ft.encode_image_list(target_img_noised, t)
85 | with th.enable_grad():
86 | x_in = x.detach().requires_grad_(True)
87 | image_features = clip_ft.encode_image_list(x_in, t)
88 | if args.text_weight != 0:
89 | loss_text = text_loss(image_features, text_features, args)
90 | else:
91 | loss_text = 0
92 | if args.image_weight != 0:
93 | loss_img = image_loss(image_features, target_img_features, args)
94 | else:
95 | loss_img = 0
96 | total_guidance = loss_text * args.text_weight + loss_img * args.image_weight
97 |
98 | return th.autograd.grad(total_guidance.sum(), x_in)[0]
99 |
100 |
101 | print("creating samples...")
102 | count = 0
103 | for img_cnt in range(len(imgs)):
104 | if imgs[img_cnt] is not None:
105 | print("loading data...")
106 | model_kwargs = load_ref_data(args, imgs[img_cnt])
107 | else:
108 | model_kwargs = {}
109 |
110 | for ins_cnt in range(len(instructions)):
111 | instruction = instructions[ins_cnt]
112 | text = clip.tokenize([instruction for cnt in range(args.batch_size)]).to('cuda')
113 | model_kwargs['y'] = text
114 | model_kwargs = {k: v.to('cuda') for k, v in model_kwargs.items()}
115 | if args.image_weight == 0 and args.text_weight == 0:
116 | cond_fn = None
117 | else:
118 | cond_fn = cond_fn_sdg
119 | with th.cuda.amp.autocast(True):
120 | sample = diffusion.p_sample_loop(
121 | model,
122 | (args.batch_size, 3, args.image_size, args.image_size),
123 | noise=None,
124 | clip_denoised=args.clip_denoised,
125 | model_kwargs=model_kwargs,
126 | cond_fn=cond_fn,
127 | device='cuda',
128 | )
129 |
130 | for i in range(args.batch_size):
131 | if args.text_weight == 0:
132 | out_folder = '%05d_%s' % (img_cnt, os.path.basename(imgs[img_cnt]).split('.')[0])
133 | elif args.image_weight == 0:
134 | out_folder = '%05d_%s' % (ins_cnt, instructions[ins_cnt])
135 | else:
136 | out_folder = '%05d_%05d_%s_%s' % (img_cnt, ins_cnt, os.path.basename(imgs[img_cnt]).split('.')[0], instructions[ins_cnt])
137 |
138 | out_path = os.path.join(args.logdir, out_folder,
139 | f"{str(count * args.batch_size + i).zfill(5)}.png")
140 | os.makedirs(os.path.dirname(out_path), exist_ok=True)
141 | utils.save_image(
142 | sample[i].unsqueeze(0),
143 | out_path,
144 | nrow=1,
145 | normalize=True,
146 | range=(-1, 1),
147 | )
148 |
149 | count += 1
150 | print(f"created {count * args.batch_size} samples")
151 | print(time.time() - time0)
152 |
153 | print("sampling complete")
154 |
155 |
156 | if __name__ == "__main__":
157 | main()
158 |
--------------------------------------------------------------------------------
/sdg/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Codebase for "More Control for Free! Image Synthesis with Semantic Diffusion Guidance".
3 | """
4 |
--------------------------------------------------------------------------------
/sdg/clip_guidance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import transforms
5 | from clip import clip
6 | from .nn import timestep_embedding
7 |
8 |
9 | class CLIP_gd(nn.Module):
10 | def __init__(self, args):
11 | super().__init__()
12 | self.finetune_clip_layer = getattr(args, 'finetune_clip_layer', 'all')
13 | clip_model, preprocess = clip.load('RN50x16', jit=False)
14 | self.preprocess = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
15 | clip_model = clip_model.float()
16 |
17 | # visual
18 | self.visual_frozen = nn.Sequential(
19 | clip_model.visual.conv1,
20 | clip_model.visual.bn1,
21 | clip_model.visual.relu1,
22 | clip_model.visual.conv2,
23 | clip_model.visual.bn2,
24 | clip_model.visual.relu2,
25 | clip_model.visual.conv3,
26 | clip_model.visual.bn3,
27 | clip_model.visual.relu3,
28 | clip_model.visual.avgpool,
29 | clip_model.visual.layer1,
30 | clip_model.visual.layer2,
31 | clip_model.visual.layer3,
32 | )
33 | self.attn_pool = clip_model.visual.attnpool
34 | self.layer4 = clip_model.visual.layer4
35 |
36 | self.attn_resolution = args.image_size // 32
37 | self.image_size = args.image_size
38 | self.after_load()
39 | self.define_finetune()
40 |
41 | def after_load(self):
42 | self.attn_pool.positional_embedding = nn.Parameter(torch.randn(self.attn_resolution ** 2 + 1, 3072) / 3072 ** 0.5)
43 | self.time_embed = nn.Sequential(
44 | nn.Linear(128, 512),
45 | nn.ReLU(),
46 | nn.Linear(512, 512),
47 | )
48 | emb_dim = [48, 48, 96]
49 | tmp1 = [96, 96, 384]
50 | for cnt in range(6):
51 | emb_dim.extend(tmp1)
52 | tmp2 = [192, 192, 768]
53 | for cnt in range(8):
54 | emb_dim.extend(tmp2)
55 | tmp3 = [384, 384, 1536]
56 | for cnt in range(18):
57 | emb_dim.extend(tmp3)
58 | tmp4 = [768, 768, 3072]
59 | for cnt in range(8):
60 | emb_dim.extend(tmp4)
61 | self.emb_layers = nn.Sequential(nn.ReLU(), nn.Linear(512, sum(emb_dim) * 2))
62 | self.split_idx = []
63 | cur_idx = 0
64 | for cnt in range(len(emb_dim)):
65 | self.split_idx.append(cur_idx + emb_dim[cnt])
66 | self.split_idx.append(cur_idx + 2 * emb_dim[cnt])
67 | cur_idx += 2 * emb_dim[cnt]
68 | self.split_idx = self.split_idx[:-1]
69 |
70 | def define_finetune(self):
71 | self.train()
72 |
73 | # freeze visual encoder
74 | for param in self.visual_frozen.parameters():
75 | param.requires_grad = False
76 | for param in self.layer4.parameters():
77 | param.requires_grad = False
78 | self.attn_pool.positional_embedding.requires_grad = True
79 |         self.time_embed.requires_grad_(True)
80 |         self.emb_layers.requires_grad_(True)
81 |
82 | if self.finetune_clip_layer == 'last':
83 | for param in self.layer4.parameters():
84 | param.requires_grad = True
85 | for param in self.attn_pool.parameters():
86 | param.requires_grad = True
87 | elif self.finetune_clip_layer == 'all':
88 | for param in self.parameters():
89 | param.requires_grad = True
90 |
91 |
92 | def train(self, mode=True):
93 | self.visual_frozen.eval()
94 | self.layer4.eval()
95 | self.attn_pool.eval()
96 |
97 | self.time_embed.train(mode)
98 | self.emb_layers.train(mode)
99 |
100 | if self.finetune_clip_layer == 'last':
101 | self.layer4.train(mode)
102 | self.attn_pool.train(mode)
103 | elif self.finetune_clip_layer == 'all':
104 | self.visual_frozen.train(mode)
105 | self.layer4.train(mode)
106 | self.attn_pool.train(mode)
107 |
108 | def encode_image(self, image, t):
109 | image = (image + 1) / 2.0
110 | image = self.preprocess(image)
111 | emb = self.time_embed(timestep_embedding(t, 128))
112 | emb_out = torch.tensor_split(self.emb_layers(emb).unsqueeze(-1).unsqueeze(-1), self.split_idx, dim=1)
113 | x = self.visual_frozen[1](self.visual_frozen[0](image))
114 | x = x * (1 + emb_out[0]) + emb_out[1]
115 | x = self.visual_frozen[4](self.visual_frozen[3](self.visual_frozen[2](x)))
116 | x = x * (1 + emb_out[2]) + emb_out[3]
117 | x = self.visual_frozen[7](self.visual_frozen[6](self.visual_frozen[5](x)))
118 | x = x * (1 + emb_out[4]) + emb_out[5]
119 | x = self.visual_frozen[9](self.visual_frozen[8](x))
120 | # layer1
121 | module_cnt = 10
122 | emb_cnt = 6
123 | for cnt in range(6):
124 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt)
125 |
126 | # layer2
127 | module_cnt = 11
128 | emb_cnt = 42
129 | for cnt in range(8):
130 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt)
131 | # layer3
132 | module_cnt = 12
133 | emb_cnt = 90
134 | for cnt in range(18):
135 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt)
136 |
137 | # layer4
138 | emb_cnt = 198
139 | for cnt in range(8):
140 | x = self.bottleneck_block_forward(self.layer4[cnt], x, emb_out, module_cnt, emb_cnt, cnt)
141 |
142 | x = self.attn_pool(x)
143 |
144 | return x
145 |
146 | def encode_image_list(self, image, t, return_layer=4):
147 | image = (image + 1) / 2.0
148 | image = self.preprocess(image)
149 |
150 | emb = self.time_embed(timestep_embedding(t, 128))
151 | emb_out = torch.tensor_split(self.emb_layers(emb).unsqueeze(-1).unsqueeze(-1), self.split_idx, dim=1)
152 | x = self.visual_frozen[1](self.visual_frozen[0](image))
153 | x = x * (1 + emb_out[0]) + emb_out[1]
154 | x = self.visual_frozen[4](self.visual_frozen[3](self.visual_frozen[2](x)))
155 | x = x * (1 + emb_out[2]) + emb_out[3]
156 | x = self.visual_frozen[7](self.visual_frozen[6](self.visual_frozen[5](x)))
157 | x = x * (1 + emb_out[4]) + emb_out[5]
158 | x1 = self.visual_frozen[9](self.visual_frozen[8](x))
159 | # layer1
160 | module_cnt = 10
161 | emb_cnt = 6
162 |
163 | for cnt in range(6):
164 | if cnt == 0:
165 | x2 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x1, emb_out, module_cnt, emb_cnt, cnt)
166 | else:
167 | x2 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x2, emb_out, module_cnt, emb_cnt, cnt)
168 |
169 | # layer2
170 | module_cnt = 11
171 | emb_cnt = 42
172 | for cnt in range(8):
173 | if cnt == 0:
174 | x3 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x2, emb_out, module_cnt, emb_cnt, cnt)
175 | else:
176 | x3 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x3, emb_out, module_cnt, emb_cnt, cnt)
177 |
178 | # layer3
179 | module_cnt = 12
180 | emb_cnt = 90
181 | for cnt in range(18):
182 | if cnt == 0:
183 | x4 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x3, emb_out, module_cnt, emb_cnt, cnt)
184 | else:
185 | x4 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x4, emb_out, module_cnt, emb_cnt, cnt)
186 |
187 | # layer4
188 | emb_cnt = 198
189 | for cnt in range(8):
190 | if cnt == 0:
191 | x5 = self.bottleneck_block_forward(self.layer4[cnt], x4, emb_out, module_cnt, emb_cnt, cnt)
192 | else:
193 | x5 = self.bottleneck_block_forward(self.layer4[cnt], x5, emb_out, module_cnt, emb_cnt, cnt)
194 |
195 | x6 = self.attn_pool(x5)
196 |
197 | return [x1, x2, x3, x4, x5, x6]
198 |
199 |
200 |
201 | def bottleneck_block_forward(self, net, x, emb_out, module_cnt, emb_cnt, cnt):
202 | identity = x
203 | y = net.bn1(net.conv1(x))
204 | y = y * (1 + emb_out[emb_cnt+cnt*6]) + emb_out[emb_cnt+cnt*6+1]
205 | y = net.relu1(y)
206 | y = net.bn2(net.conv2(y))
207 | y = y * (1 + emb_out[emb_cnt+cnt*6+2]) + emb_out[emb_cnt+cnt*6+3]
208 | y = net.relu2(y)
209 | y = net.avgpool(y)
210 | y = net.bn3(net.conv3(y))
211 | y = y * (1 + emb_out[emb_cnt+cnt*6+4]) + emb_out[emb_cnt+cnt*6+5]
212 | if net.downsample is not None:
213 | identity = net.downsample(x)
214 | y += identity
215 | y = net.relu3(y)
216 | return y
217 |
218 |
219 | def encode_text(self, text):
220 | with torch.no_grad():
221 | x = self.token_embedding_frozen(text) # [batch_size, n_ctx, d_model]
222 | x = x + self.positional_embedding_frozen
223 | x = x.permute(1, 0, 2) # NLD -> LND
224 | x = self.transformer_frozen(x)
225 |
226 | x = self.transformer_last_block(x)
227 | x = x.permute(1, 0, 2) # LND -> NLD
228 | x = self.ln_final(x)
229 |
230 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
231 |
232 | return x
233 |
234 | def unfreeze(self):
235 | self.attn_pool.requires_grad_(True)
236 | self.layer4.requires_grad_(True)
237 |
238 | self.transformer_last_block.requires_grad_(True)
239 | self.ln_final.requires_grad_(True)
240 | self.text_projection.requires_grad_(True)
241 | self.logit_scale.requires_grad_(True)
242 |
243 | def forward(self, image, text_features, timesteps):
244 | # to match the preprocess of clip model
245 |
246 | image_features = self.encode_image(image, timesteps)
247 | #text_features = self.encode_text(text)
248 |
249 | return image_features, text_features
250 |
251 | def training_step(self, batch, batch_idx):
252 | image, text = batch
253 |
254 | bs = image.size(0)
255 |
256 | image_features, text_features = self(image, text)
257 |
258 | # normalized features
259 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
260 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
261 |
262 | # cosine similarity as logits
263 | logit_scale = self.logit_scale.exp()
264 | logits_per_image = logit_scale * image_features @ text_features.t()
265 | logits_per_text = logit_scale * text_features @ image_features.t()
266 |
267 | label = torch.arange(bs).long()
268 | label = label.to(image.device)
269 |
270 | loss_i = F.cross_entropy(logits_per_image, label)
271 | loss_t = F.cross_entropy(logits_per_text, label)
272 |
273 | loss = (loss_i + loss_t) / 2
274 |
275 | return loss
276 |
277 | def configure_optimizers(self):
278 | lr = self.learning_rate
279 | opt = torch.optim.AdamW(list(self.attn_pool.parameters()) +
280 | list(self.layer4.parameters()) +
281 | list(self.transformer_last_block.parameters()) +
282 | list(self.ln_final.parameters()) +
283 | [self.text_projection],
284 | lr=lr)
285 | return opt
286 |
287 |
--------------------------------------------------------------------------------
/sdg/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, check out LICENSE.md
5 | import functools
6 | import ctypes
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | def init_dist(local_rank, backend='nccl', **kwargs):
13 | r"""Initialize distributed training"""
14 | if dist.is_available():
15 | if dist.is_initialized():
16 | return torch.cuda.current_device()
17 | torch.cuda.set_device(local_rank)
18 | dist.init_process_group(backend=backend, init_method='env://', **kwargs)
19 |
20 | # Increase the L2 fetch granularity for faster speed.
21 | _libcudart = ctypes.CDLL('libcudart.so')
22 | # Set device limit on the current device
23 | # cudaLimitMaxL2FetchGranularity = 0x05
24 | pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
25 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
26 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
27 | # assert pValue.contents.value == 128
28 |
29 |
30 | def get_rank():
31 | r"""Get rank of the thread."""
32 | rank = 0
33 | if dist.is_available():
34 | if dist.is_initialized():
35 | rank = dist.get_rank()
36 | return rank
37 |
38 |
39 | def get_world_size():
40 | r"""Get world size. How many GPUs are available in this job."""
41 | world_size = 1
42 | if dist.is_available():
43 | if dist.is_initialized():
44 | world_size = dist.get_world_size()
45 | return world_size
46 |
47 |
48 | def master_only(func):
49 | r"""Apply this function only to the master GPU."""
50 | @functools.wraps(func)
51 | def wrapper(*args, **kwargs):
52 | r"""Simple function wrapper for the master function"""
53 | if get_rank() == 0:
54 | return func(*args, **kwargs)
55 | else:
56 | return None
57 | return wrapper
58 |
59 |
60 | def is_master():
61 | r"""check if current process is the master"""
62 | return get_rank() == 0
63 |
64 |
65 | def is_local_master():
66 | return torch.cuda.current_device() == 0
67 |
68 |
69 | @master_only
70 | def master_only_print(*args, **kwargs):
71 | r"""master-only print"""
72 | print(*args, **kwargs)
73 |
74 | def print_all_rank(*args, **kwargs):
75 | rank = get_rank()
76 |     print('[rank %d]' % rank, *args, **kwargs)
77 |
78 |
79 | def dist_reduce_tensor(tensor, rank=0, reduce='mean'):
80 | r""" Reduce to rank 0 """
81 | world_size = get_world_size()
82 | if world_size < 2:
83 | return tensor
84 | with torch.no_grad():
85 | dist.reduce(tensor, dst=rank)
86 | if get_rank() == rank:
87 | if reduce == 'mean':
88 | tensor /= world_size
89 | elif reduce == 'sum':
90 | pass
91 | else:
92 | raise NotImplementedError
93 | return tensor
94 |
95 |
96 | def dist_all_reduce_tensor(tensor, reduce='mean'):
97 | r""" Reduce to all ranks """
98 | world_size = get_world_size()
99 | if world_size < 2:
100 | return tensor
101 | with torch.no_grad():
102 | dist.all_reduce(tensor)
103 | if reduce == 'mean':
104 | tensor /= world_size
105 | elif reduce == 'sum':
106 | pass
107 | else:
108 | raise NotImplementedError
109 | return tensor
110 |
111 |
112 | def dist_all_gather_tensor(tensor):
113 | r""" gather to all ranks """
114 | world_size = get_world_size()
115 | if world_size < 2:
116 | return [tensor]
117 | tensor_list = [
118 | torch.ones_like(tensor) for _ in range(dist.get_world_size())]
119 | with torch.no_grad():
120 | dist.all_gather(tensor_list, tensor)
121 | return tensor_list
122 |
123 | class AllGatherFunction(torch.autograd.Function):
124 | @staticmethod
125 | def forward(ctx, tensor: torch.Tensor, reduce_dtype: torch.dtype = torch.float32):
126 | ctx.reduce_dtype = reduce_dtype
127 | output = list(torch.empty_like(tensor) for _ in range(dist.get_world_size()))
128 | dist.all_gather(output, tensor)
129 | output = torch.cat(output, dim=0)
130 | return output
131 |
132 | @staticmethod
133 | def backward(ctx, grad_output: torch.Tensor):
134 | grad_dtype = grad_output.dtype
135 | input_list = list(grad_output.to(ctx.reduce_dtype).chunk(dist.get_world_size()))
136 | grad_input = torch.empty_like(input_list[dist.get_rank()])
137 | dist.reduce_scatter(grad_input, input_list)
138 | return grad_input.to(grad_dtype)
139 |
140 |
141 | def all_gather_with_gradient(tensor):
142 | world_size = get_world_size()
143 | if world_size < 2:
144 | return tensor
145 | return AllGatherFunction.apply(tensor)
146 |
--------------------------------------------------------------------------------
/sdg/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | for p in self.master_params: p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))  # unscale all master grad groups, not just the first
203 | opt.step()
204 | zero_master_grads(self.master_params)
205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206 | self.lg_loss_scale += self.fp16_scale_growth
207 | return True
208 |
209 | def _optimize_normal(self, opt: th.optim.Optimizer):
210 | grad_norm, param_norm = self._compute_norms()
211 | logger.logkv_mean("grad_norm", grad_norm)
212 | logger.logkv_mean("param_norm", param_norm)
213 | opt.step()
214 | return True
215 |
216 | def _compute_norms(self, grad_scale=1.0):
217 | grad_norm = 0.0
218 | param_norm = 0.0
219 | for p in self.master_params:
220 | with th.no_grad():
221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222 | if p.grad is not None:
223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225 |
226 | def master_params_to_state_dict(self, master_params):
227 | return master_params_to_state_dict(
228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229 | )
230 |
231 | def state_dict_to_master_params(self, state_dict):
232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233 |
234 |
235 | def check_overflow(value):
236 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
237 |
--------------------------------------------------------------------------------
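
Usage sketch (not prescribed by this repo): MixedPrecisionTrainer wraps the usual zero_grad/backward/step cycle with dynamic loss scaling. A minimal training step, assuming a placeholder `model` that exposes convert_to_fp16() (required by the trainer when use_fp16=True, see __init__ above) and placeholder `batch`/`compute_loss`:

    import torch as th
    from sdg.fp16_util import MixedPrecisionTrainer

    trainer = MixedPrecisionTrainer(model=model, use_fp16=True, fp16_scale_growth=1e-3)
    opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

    trainer.zero_grad()
    loss = compute_loss(model, batch)   # any scalar loss; placeholder
    trainer.backward(loss)              # multiplies the loss by 2**lg_loss_scale when use_fp16
    took_step = trainer.optimize(opt)   # unscales grads; returns False and shrinks the scale on overflow
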
/sdg/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | This code started out as a PyTorch port of Ho et al's diffusion models:
3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4 |
5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6 | """
7 |
8 | import enum
9 | import math
10 |
11 | import numpy as np
12 | import torch as th
13 |
14 | from .nn import mean_flat
15 | from .losses import normal_kl, discretized_gaussian_log_likelihood
16 |
17 |
18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
19 | """
20 | Get a pre-defined beta schedule for the given name.
21 |
22 | The beta schedule library consists of beta schedules which remain similar
23 | in the limit of num_diffusion_timesteps.
24 | Beta schedules may be added, but should not be removed or changed once
25 | they are committed to maintain backwards compatibility.
26 | """
27 | if schedule_name == "linear":
28 | # Linear schedule from Ho et al, extended to work for any number of
29 | # diffusion steps.
30 | scale = 1000 / num_diffusion_timesteps
31 | beta_start = scale * 0.0001
32 | beta_end = scale * 0.02
33 | return np.linspace(
34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
35 | )
36 | elif schedule_name == "cosine":
37 | return betas_for_alpha_bar(
38 | num_diffusion_timesteps,
39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
40 | )
41 | else:
42 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
43 |
44 |
45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
46 | """
47 | Create a beta schedule that discretizes the given alpha_t_bar function,
48 | which defines the cumulative product of (1-beta) over time from t = [0,1].
49 |
50 | :param num_diffusion_timesteps: the number of betas to produce.
51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
52 | produces the cumulative product of (1-beta) up to that
53 | part of the diffusion process.
54 | :param max_beta: the maximum beta to use; use values lower than 1 to
55 | prevent singularities.
56 | """
57 | betas = []
58 | for i in range(num_diffusion_timesteps):
59 | t1 = i / num_diffusion_timesteps
60 | t2 = (i + 1) / num_diffusion_timesteps
61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
62 | return np.array(betas)
63 |
64 |
65 | class ModelMeanType(enum.Enum):
66 | """
67 | Which type of output the model predicts.
68 | """
69 |
70 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
71 | START_X = enum.auto() # the model predicts x_0
72 | EPSILON = enum.auto() # the model predicts epsilon
73 |
74 |
75 | class ModelVarType(enum.Enum):
76 | """
77 | What is used as the model's output variance.
78 |
79 | The LEARNED_RANGE option has been added to allow the model to predict
80 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
81 | """
82 |
83 | LEARNED = enum.auto()
84 | FIXED_SMALL = enum.auto()
85 | FIXED_LARGE = enum.auto()
86 | LEARNED_RANGE = enum.auto()
87 |
88 |
89 | class LossType(enum.Enum):
90 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
91 | RESCALED_MSE = (
92 | enum.auto()
93 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
94 | KL = enum.auto() # use the variational lower-bound
95 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
96 |
97 | def is_vb(self):
98 | return self == LossType.KL or self == LossType.RESCALED_KL
99 |
100 |
101 | class GaussianDiffusion:
102 | """
103 | Utilities for training and sampling diffusion models.
104 |
105 | Ported directly from here, then adapted over time for further experimentation.
106 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
107 |
108 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
109 | starting at T and going to 1.
110 | :param model_mean_type: a ModelMeanType determining what the model outputs.
111 | :param model_var_type: a ModelVarType determining how variance is output.
112 | :param loss_type: a LossType determining the loss function to use.
113 | :param rescale_timesteps: if True, pass floating point timesteps into the
114 | model so that they are always scaled like in the
115 | original paper (0 to 1000).
116 | """
117 |
118 | def __init__(
119 | self,
120 | *,
121 | betas,
122 | model_mean_type,
123 | model_var_type,
124 | loss_type,
125 | rescale_timesteps=False,
126 | betas1000=None,
127 | ):
128 | self.model_mean_type = model_mean_type
129 | self.model_var_type = model_var_type
130 | self.loss_type = loss_type
131 | self.rescale_timesteps = rescale_timesteps
132 |
133 | # Use float64 for accuracy.
134 | betas = np.array(betas, dtype=np.float64)
135 | self.betas = betas
136 | assert len(betas.shape) == 1, "betas must be 1-D"
137 | assert (betas > 0).all() and (betas <= 1).all()
138 |
139 | self.num_timesteps = int(betas.shape[0])
140 |
141 | alphas = 1.0 - betas
142 |
143 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
144 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
145 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
146 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
147 |
148 | # calculations for diffusion q(x_t | x_{t-1}) and others
149 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
150 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
151 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
152 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
153 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
154 |
155 | if betas1000 is not None:
156 | betas1000 = np.array(betas1000, dtype=np.float64)
157 | alphas1000 = 1.0 - betas1000
158 | self.alphas_cumprod1000 = np.cumprod(alphas1000, axis=0)
159 | self.sqrt_alphas_cumprod1000 = np.sqrt(self.alphas_cumprod1000)
160 | self.sqrt_one_minus_alphas_cumprod1000 = np.sqrt(1.0 - self.alphas_cumprod1000)
161 |
162 | # calculations for posterior q(x_{t-1} | x_t, x_0)
163 | self.posterior_variance = (
164 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
165 | )
166 | # log calculation clipped because the posterior variance is 0 at the
167 | # beginning of the diffusion chain.
168 | self.posterior_log_variance_clipped = np.log(
169 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
170 | )
171 | self.posterior_mean_coef1 = (
172 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
173 | )
174 | self.posterior_mean_coef2 = (
175 | (1.0 - self.alphas_cumprod_prev)
176 | * np.sqrt(alphas)
177 | / (1.0 - self.alphas_cumprod)
178 | )
179 |
180 | def q_mean_variance(self, x_start, t):
181 | """
182 | Get the distribution q(x_t | x_0).
183 |
184 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
185 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
186 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
187 | """
188 | mean = (
189 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
190 | )
191 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
192 | log_variance = _extract_into_tensor(
193 | self.log_one_minus_alphas_cumprod, t, x_start.shape
194 | )
195 | return mean, variance, log_variance
196 |
197 | def q_sample(self, x_start, t, noise=None, tscale1000=False):
198 | """
199 | Diffuse the data for a given number of diffusion steps.
200 |
201 | In other words, sample from q(x_t | x_0).
202 |
203 | :param x_start: the initial data batch.
204 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
205 | :param noise: if specified, the split-out normal noise.
206 | :return: A noisy version of x_start.
207 | """
208 | if noise is None:
209 | noise = th.randn_like(x_start)
210 | assert noise.shape == x_start.shape
211 | if tscale1000:
212 | return (
213 | _extract_into_tensor(self.sqrt_alphas_cumprod1000, t, x_start.shape) * x_start
214 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod1000, t, x_start.shape)
215 | * noise
216 | )
217 | else:
218 | return (
219 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
220 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
221 | * noise
222 | )
223 |
224 | def q_posterior_mean_variance(self, x_start, x_t, t):
225 | """
226 | Compute the mean and variance of the diffusion posterior:
227 |
228 | q(x_{t-1} | x_t, x_0)
229 |
230 | """
231 | assert x_start.shape == x_t.shape
232 | posterior_mean = (
233 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
234 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
235 | )
236 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
237 | posterior_log_variance_clipped = _extract_into_tensor(
238 | self.posterior_log_variance_clipped, t, x_t.shape
239 | )
240 | assert (
241 | posterior_mean.shape[0]
242 | == posterior_variance.shape[0]
243 | == posterior_log_variance_clipped.shape[0]
244 | == x_start.shape[0]
245 | )
246 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
247 |
248 | def p_mean_variance(
249 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
250 | ):
251 | """
252 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
253 | the initial x, x_0.
254 |
255 | :param model: the model, which takes a signal and a batch of timesteps
256 | as input.
257 | :param x: the [N x C x ...] tensor at time t.
258 | :param t: a 1-D Tensor of timesteps.
259 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
260 | :param denoised_fn: if not None, a function which applies to the
261 | x_start prediction before it is used to sample. Applies before
262 | clip_denoised.
263 | :param model_kwargs: if not None, a dict of extra keyword arguments to
264 | pass to the model. This can be used for conditioning.
265 | :return: a dict with the following keys:
266 | - 'mean': the model mean output.
267 | - 'variance': the model variance output.
268 | - 'log_variance': the log of 'variance'.
269 | - 'pred_xstart': the prediction for x_0.
270 | """
271 | if model_kwargs is None:
272 | model_kwargs = {}
273 |
274 | B, C = x.shape[:2]
275 | assert t.shape == (B,)
276 | model_output = model(x, self._scale_timesteps(t), **model_kwargs)
277 |
278 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
279 | assert model_output.shape == (B, C * 2, *x.shape[2:])
280 | model_output, model_var_values = th.split(model_output, C, dim=1)
281 | if self.model_var_type == ModelVarType.LEARNED:
282 | model_log_variance = model_var_values
283 | model_variance = th.exp(model_log_variance)
284 | else:
285 | min_log = _extract_into_tensor(
286 | self.posterior_log_variance_clipped, t, x.shape
287 | )
288 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
289 | # The model_var_values is [-1, 1] for [min_var, max_var].
290 | frac = (model_var_values + 1) / 2
291 | model_log_variance = frac * max_log + (1 - frac) * min_log
292 | model_variance = th.exp(model_log_variance)
293 | else:
294 | model_variance, model_log_variance = {
295 | # for fixedlarge, we set the initial (log-)variance like so
296 | # to get a better decoder log likelihood.
297 | ModelVarType.FIXED_LARGE: (
298 | np.append(self.posterior_variance[1], self.betas[1:]),
299 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
300 | ),
301 | ModelVarType.FIXED_SMALL: (
302 | self.posterior_variance,
303 | self.posterior_log_variance_clipped,
304 | ),
305 | }[self.model_var_type]
306 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
307 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
308 |
309 | def process_xstart(x):
310 | if denoised_fn is not None:
311 | x = denoised_fn(x)
312 | if clip_denoised:
313 | return x.clamp(-1, 1)
314 | return x
315 |
316 | if self.model_mean_type == ModelMeanType.PREVIOUS_X:
317 | pred_xstart = process_xstart(
318 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
319 | )
320 | model_mean = model_output
321 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
322 | if self.model_mean_type == ModelMeanType.START_X:
323 | pred_xstart = process_xstart(model_output)
324 | else:
325 | pred_xstart = process_xstart(
326 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
327 | )
328 | model_mean, _, _ = self.q_posterior_mean_variance(
329 | x_start=pred_xstart, x_t=x, t=t
330 | )
331 | else:
332 | raise NotImplementedError(self.model_mean_type)
333 |
334 | assert (
335 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
336 | )
337 | return {
338 | "mean": model_mean,
339 | "variance": model_variance,
340 | "log_variance": model_log_variance,
341 | "pred_xstart": pred_xstart,
342 | }
343 |
344 | def _predict_xstart_from_eps(self, x_t, t, eps):
345 | assert x_t.shape == eps.shape
346 | return (
347 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
348 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
349 | )
350 |
351 | def _predict_xstart_from_xprev(self, x_t, t, xprev):
352 | assert x_t.shape == xprev.shape
353 | return ( # (xprev - coef2*x_t) / coef1
354 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
355 | - _extract_into_tensor(
356 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
357 | )
358 | * x_t
359 | )
360 |
361 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
362 | return (
363 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
364 | - pred_xstart
365 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
366 |
367 | def _scale_timesteps(self, t):
368 | if self.rescale_timesteps:
369 | return t.float() * (1000.0 / self.num_timesteps)
370 | return t
371 |
372 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
373 | """
374 | Compute the mean for the previous step, given a function cond_fn that
375 | computes the gradient of a conditional log probability with respect to
376 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
377 | condition on y.
378 |
379 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
380 | """
381 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
382 | new_mean = (
383 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
384 | )
385 | return new_mean
386 |
387 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
388 | """
389 | Compute what the p_mean_variance output would have been, should the
390 | model's score function be conditioned by cond_fn.
391 |
392 | See condition_mean() for details on cond_fn.
393 |
394 | Unlike condition_mean(), this instead uses the conditioning strategy
395 | from Song et al (2020).
396 | """
397 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
398 |
399 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
400 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
401 | x, self._scale_timesteps(t), **model_kwargs
402 | )
403 |
404 | out = p_mean_var.copy()
405 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
406 | out["mean"], _, _ = self.q_posterior_mean_variance(
407 | x_start=out["pred_xstart"], x_t=x, t=t
408 | )
409 | return out
410 |
411 | def p_sample(
412 | self,
413 | model,
414 | x,
415 | t,
416 | clip_denoised=True,
417 | denoised_fn=None,
418 | cond_fn=None,
419 | model_kwargs=None,
420 | ):
421 | """
422 | Sample x_{t-1} from the model at the given timestep.
423 |
424 | :param model: the model to sample from.
425 | :param x: the current tensor x_t.
426 | :param t: the value of t, starting at 0 for the first diffusion step.
427 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
428 | :param denoised_fn: if not None, a function which applies to the
429 | x_start prediction before it is used to sample.
430 | :param cond_fn: if not None, this is a gradient function that acts
431 | similarly to the model.
432 | :param model_kwargs: if not None, a dict of extra keyword arguments to
433 | pass to the model. This can be used for conditioning.
434 | :return: a dict containing the following keys:
435 | - 'sample': a random sample from the model.
436 | - 'pred_xstart': a prediction of x_0.
437 | """
438 | out = self.p_mean_variance(
439 | model,
440 | x,
441 | t,
442 | clip_denoised=clip_denoised,
443 | denoised_fn=denoised_fn,
444 | model_kwargs=model_kwargs,
445 | )
446 | noise = th.randn_like(x)
447 | nonzero_mask = (
448 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
449 | ) # no noise when t == 0
450 | if cond_fn is not None:
451 | out["mean"] = self.condition_mean(
452 | cond_fn, out, x, t, model_kwargs=model_kwargs
453 | )
454 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
455 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
456 |
457 | def p_sample_loop(
458 | self,
459 | model,
460 | shape,
461 | noise=None,
462 | clip_denoised=True,
463 | denoised_fn=None,
464 | cond_fn=None,
465 | model_kwargs=None,
466 | device=None,
467 | progress=False,
468 | ):
469 | """
470 | Generate samples from the model.
471 |
472 | :param model: the model module.
473 | :param shape: the shape of the samples, (N, C, H, W).
474 | :param noise: if specified, the noise from the encoder to sample.
475 | Should be of the same shape as `shape`.
476 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
477 | :param denoised_fn: if not None, a function which applies to the
478 | x_start prediction before it is used to sample.
479 | :param cond_fn: if not None, this is a gradient function that acts
480 | similarly to the model.
481 | :param model_kwargs: if not None, a dict of extra keyword arguments to
482 | pass to the model. This can be used for conditioning.
483 | :param device: if specified, the device to create the samples on.
484 | If not specified, use a model parameter's device.
485 | :param progress: if True, show a tqdm progress bar.
486 | :return: a non-differentiable batch of samples.
487 | """
488 | final = None
489 | for sample in self.p_sample_loop_progressive(
490 | model,
491 | shape,
492 | noise=noise,
493 | clip_denoised=clip_denoised,
494 | denoised_fn=denoised_fn,
495 | cond_fn=cond_fn,
496 | model_kwargs=model_kwargs,
497 | device=device,
498 | progress=progress,
499 | ):
500 | final = sample
501 | return final["sample"]
502 |
503 | def p_sample_loop_progressive(
504 | self,
505 | model,
506 | shape,
507 | noise=None,
508 | clip_denoised=True,
509 | denoised_fn=None,
510 | cond_fn=None,
511 | model_kwargs=None,
512 | device=None,
513 | progress=False,
514 | ):
515 | """
516 | Generate samples from the model and yield intermediate samples from
517 | each timestep of diffusion.
518 |
519 | Arguments are the same as p_sample_loop().
520 | Returns a generator over dicts, where each dict is the return value of
521 | p_sample().
522 | """
523 | if device is None:
524 | device = next(model.parameters()).device
525 | assert isinstance(shape, (tuple, list))
526 | if noise is not None:
527 | img = noise
528 | else:
529 | img = th.randn(*shape, device=device)
530 | indices = list(range(self.num_timesteps))[::-1]
531 |
532 | if progress:
533 | # Lazy import so that we don't depend on tqdm.
534 | from tqdm.auto import tqdm
535 |
536 | indices = tqdm(indices)
537 |
538 | for i in indices:
539 | t = th.tensor([i] * shape[0], device=device)
540 | with th.no_grad():
541 | out = self.p_sample(
542 | model,
543 | img,
544 | t,
545 | clip_denoised=clip_denoised,
546 | denoised_fn=denoised_fn,
547 | cond_fn=cond_fn,
548 | model_kwargs=model_kwargs,
549 | )
550 | yield out
551 | img = out["sample"]
552 |
553 |
554 | def ddim_sample(
555 | self,
556 | model,
557 | x,
558 | t,
559 | clip_denoised=True,
560 | denoised_fn=None,
561 | cond_fn=None,
562 | model_kwargs=None,
563 | eta=0.0,
564 | ):
565 | """
566 | Sample x_{t-1} from the model using DDIM.
567 |
568 | Same usage as p_sample().
569 | """
570 | out = self.p_mean_variance(
571 | model,
572 | x,
573 | t,
574 | clip_denoised=clip_denoised,
575 | denoised_fn=denoised_fn,
576 | model_kwargs=model_kwargs,
577 | )
578 | if cond_fn is not None:
579 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
580 |
581 | # Usually our model outputs epsilon, but we re-derive it
582 | # in case we used x_start or x_prev prediction.
583 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
584 |
585 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
586 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
587 | sigma = (
588 | eta
589 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
590 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
591 | )
592 | # Equation 12.
593 | noise = th.randn_like(x)
594 | mean_pred = (
595 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
596 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
597 | )
598 | nonzero_mask = (
599 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
600 | ) # no noise when t == 0
601 | sample = mean_pred + nonzero_mask * sigma * noise
602 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
603 |
604 | def ddim_reverse_sample(
605 | self,
606 | model,
607 | x,
608 | t,
609 | clip_denoised=True,
610 | denoised_fn=None,
611 | model_kwargs=None,
612 | eta=0.0,
613 | ):
614 | """
615 | Sample x_{t+1} from the model using DDIM reverse ODE.
616 | """
617 | assert eta == 0.0, "Reverse ODE only for deterministic path"
618 | out = self.p_mean_variance(
619 | model,
620 | x,
621 | t,
622 | clip_denoised=clip_denoised,
623 | denoised_fn=denoised_fn,
624 | model_kwargs=model_kwargs,
625 | )
626 | # Usually our model outputs epsilon, but we re-derive it
627 | # in case we used x_start or x_prev prediction.
628 | eps = (
629 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
630 | - out["pred_xstart"]
631 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
632 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
633 |
634 | # Equation 12. reversed
635 | mean_pred = (
636 | out["pred_xstart"] * th.sqrt(alpha_bar_next)
637 | + th.sqrt(1 - alpha_bar_next) * eps
638 | )
639 |
640 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
641 |
642 | def ddim_sample_loop(
643 | self,
644 | model,
645 | shape,
646 | noise=None,
647 | clip_denoised=True,
648 | denoised_fn=None,
649 | cond_fn=None,
650 | model_kwargs=None,
651 | device=None,
652 | progress=False,
653 | eta=0.0,
654 | ):
655 | """
656 | Generate samples from the model using DDIM.
657 |
658 | Same usage as p_sample_loop().
659 | """
660 | final = None
661 | for sample in self.ddim_sample_loop_progressive(
662 | model,
663 | shape,
664 | noise=noise,
665 | clip_denoised=clip_denoised,
666 | denoised_fn=denoised_fn,
667 | cond_fn=cond_fn,
668 | model_kwargs=model_kwargs,
669 | device=device,
670 | progress=progress,
671 | eta=eta,
672 | ):
673 | final = sample
674 | return final["sample"]
675 |
676 | def ddim_sample_loop_progressive(
677 | self,
678 | model,
679 | shape,
680 | noise=None,
681 | clip_denoised=True,
682 | denoised_fn=None,
683 | cond_fn=None,
684 | model_kwargs=None,
685 | device=None,
686 | progress=False,
687 | eta=0.0,
688 | ):
689 | """
690 | Use DDIM to sample from the model and yield intermediate samples from
691 | each timestep of DDIM.
692 |
693 | Same usage as p_sample_loop_progressive().
694 | """
695 | if device is None:
696 | device = next(model.parameters()).device
697 | assert isinstance(shape, (tuple, list))
698 | if noise is not None:
699 | img = noise
700 | else:
701 | img = th.randn(*shape, device=device)
702 | indices = list(range(self.num_timesteps))[::-1]
703 |
704 | if progress:
705 | # Lazy import so that we don't depend on tqdm.
706 | from tqdm.auto import tqdm
707 |
708 | indices = tqdm(indices)
709 |
710 | for i in indices:
711 | t = th.tensor([i] * shape[0], device=device)
712 | with th.no_grad():
713 | out = self.ddim_sample(
714 | model,
715 | img,
716 | t,
717 | clip_denoised=clip_denoised,
718 | denoised_fn=denoised_fn,
719 | cond_fn=cond_fn,
720 | model_kwargs=model_kwargs,
721 | eta=eta,
722 | )
723 | yield out
724 | img = out["sample"]
725 |
726 | def _vb_terms_bpd(
727 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
728 | ):
729 | """
730 | Get a term for the variational lower-bound.
731 |
732 | The resulting units are bits (rather than nats, as one might expect).
733 | This allows for comparison to other papers.
734 |
735 | :return: a dict with the following keys:
736 | - 'output': a shape [N] tensor of NLLs or KLs.
737 | - 'pred_xstart': the x_0 predictions.
738 | """
739 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
740 | x_start=x_start, x_t=x_t, t=t
741 | )
742 | out = self.p_mean_variance(
743 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
744 | )
745 | kl = normal_kl(
746 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
747 | )
748 | kl = mean_flat(kl) / np.log(2.0)
749 |
750 | decoder_nll = -discretized_gaussian_log_likelihood(
751 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
752 | )
753 | assert decoder_nll.shape == x_start.shape
754 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
755 |
756 | # At the first timestep return the decoder NLL,
757 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
758 | output = th.where((t == 0), decoder_nll, kl)
759 | return {"output": output, "pred_xstart": out["pred_xstart"]}
760 |
761 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
762 | """
763 | Compute training losses for a single timestep.
764 |
765 | :param model: the model to evaluate loss on.
766 | :param x_start: the [N x C x ...] tensor of inputs.
767 | :param t: a batch of timestep indices.
768 | :param model_kwargs: if not None, a dict of extra keyword arguments to
769 | pass to the model. This can be used for conditioning.
770 | :param noise: if specified, the specific Gaussian noise to try to remove.
771 | :return: a dict with the key "loss" containing a tensor of shape [N].
772 | Some mean or variance settings may also have other keys.
773 | """
774 | if model_kwargs is None:
775 | model_kwargs = {}
776 | if noise is None:
777 | noise = th.randn_like(x_start)
778 | x_t = self.q_sample(x_start, t, noise=noise)
779 |
780 | terms = {}
781 |
782 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
783 | terms["loss"] = self._vb_terms_bpd(
784 | model=model,
785 | x_start=x_start,
786 | x_t=x_t,
787 | t=t,
788 | clip_denoised=False,
789 | model_kwargs=model_kwargs,
790 | )["output"]
791 | if self.loss_type == LossType.RESCALED_KL:
792 | terms["loss"] *= self.num_timesteps
793 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
794 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
795 |
796 | if self.model_var_type in [
797 | ModelVarType.LEARNED,
798 | ModelVarType.LEARNED_RANGE,
799 | ]:
800 | B, C = x_t.shape[:2]
801 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
802 | model_output, model_var_values = th.split(model_output, C, dim=1)
803 | # Learn the variance using the variational bound, but don't let
804 | # it affect our mean prediction.
805 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
806 | terms["vb"] = self._vb_terms_bpd(
807 | model=lambda *args, r=frozen_out: r,
808 | x_start=x_start,
809 | x_t=x_t,
810 | t=t,
811 | clip_denoised=False,
812 | )["output"]
813 | if self.loss_type == LossType.RESCALED_MSE:
814 | # Divide by 1000 for equivalence with initial implementation.
815 | # Without a factor of 1/1000, the VB term hurts the MSE term.
816 | terms["vb"] *= self.num_timesteps / 1000.0
817 |
818 | target = {
819 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
820 | x_start=x_start, x_t=x_t, t=t
821 | )[0],
822 | ModelMeanType.START_X: x_start,
823 | ModelMeanType.EPSILON: noise,
824 | }[self.model_mean_type]
825 | assert model_output.shape == target.shape == x_start.shape
826 | terms["mse"] = mean_flat((target - model_output) ** 2)
827 | if "vb" in terms:
828 | terms["loss"] = terms["mse"] + terms["vb"]
829 | else:
830 | terms["loss"] = terms["mse"]
831 | else:
832 | raise NotImplementedError(self.loss_type)
833 |
834 | return terms
835 |
836 | def _prior_bpd(self, x_start):
837 | """
838 | Get the prior KL term for the variational lower-bound, measured in
839 | bits-per-dim.
840 |
841 | This term can't be optimized, as it only depends on the encoder.
842 |
843 | :param x_start: the [N x C x ...] tensor of inputs.
844 | :return: a batch of [N] KL values (in bits), one per batch element.
845 | """
846 | batch_size = x_start.shape[0]
847 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
848 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
849 | kl_prior = normal_kl(
850 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
851 | )
852 | return mean_flat(kl_prior) / np.log(2.0)
853 |
854 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
855 | """
856 | Compute the entire variational lower-bound, measured in bits-per-dim,
857 | as well as other related quantities.
858 |
859 | :param model: the model to evaluate loss on.
860 | :param x_start: the [N x C x ...] tensor of inputs.
861 | :param clip_denoised: if True, clip denoised samples.
862 | :param model_kwargs: if not None, a dict of extra keyword arguments to
863 | pass to the model. This can be used for conditioning.
864 |
865 | :return: a dict containing the following keys:
866 | - total_bpd: the total variational lower-bound, per batch element.
867 | - prior_bpd: the prior term in the lower-bound.
868 | - vb: an [N x T] tensor of terms in the lower-bound.
869 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
870 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
871 | """
872 | device = x_start.device
873 | batch_size = x_start.shape[0]
874 |
875 | vb = []
876 | xstart_mse = []
877 | mse = []
878 | for t in list(range(self.num_timesteps))[::-1]:
879 | t_batch = th.tensor([t] * batch_size, device=device)
880 | noise = th.randn_like(x_start)
881 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
882 | # Calculate VLB term at the current timestep
883 | with th.no_grad():
884 | out = self._vb_terms_bpd(
885 | model,
886 | x_start=x_start,
887 | x_t=x_t,
888 | t=t_batch,
889 | clip_denoised=clip_denoised,
890 | model_kwargs=model_kwargs,
891 | )
892 | vb.append(out["output"])
893 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
894 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
895 | mse.append(mean_flat((eps - noise) ** 2))
896 |
897 | vb = th.stack(vb, dim=1)
898 | xstart_mse = th.stack(xstart_mse, dim=1)
899 | mse = th.stack(mse, dim=1)
900 |
901 | prior_bpd = self._prior_bpd(x_start)
902 | total_bpd = vb.sum(dim=1) + prior_bpd
903 | return {
904 | "total_bpd": total_bpd,
905 | "prior_bpd": prior_bpd,
906 | "vb": vb,
907 | "xstart_mse": xstart_mse,
908 | "mse": mse,
909 | }
910 |
911 |
912 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
913 | """
914 | Extract values from a 1-D numpy array for a batch of indices.
915 |
916 | :param arr: the 1-D numpy array.
917 | :param timesteps: a tensor of indices into the array to extract.
918 | :param broadcast_shape: a larger shape of K dimensions with the batch
919 | dimension equal to the length of timesteps.
920 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
921 | """
922 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
923 | while len(res.shape) < len(broadcast_shape):
924 | res = res[..., None]
925 | return res.expand(broadcast_shape)
926 |
--------------------------------------------------------------------------------
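
For orientation, a minimal sketch (with a hypothetical epsilon-predicting `model(x, t)` and data batch `x_start`) of how the pieces above fit together:

    import torch as th
    from sdg.gaussian_diffusion import (
        GaussianDiffusion, ModelMeanType, ModelVarType, LossType, get_named_beta_schedule,
    )

    diffusion = GaussianDiffusion(
        betas=get_named_beta_schedule("linear", 1000),
        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.FIXED_SMALL,
        loss_type=LossType.MSE,
    )

    # Sampling: ancestral (p_sample_loop) or deterministic DDIM (ddim_sample_loop).
    samples = diffusion.p_sample_loop(model, shape=(4, 3, 64, 64), progress=True)

    # Training: per-example losses at uniformly sampled timesteps.
    t = th.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device)
    loss = diffusion.training_losses(model, x_start, t)["loss"].mean()
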
/sdg/gpu_affinity.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, check out LICENSE.md
5 | import math
6 | import os
7 | import pynvml
8 |
9 | pynvml.nvmlInit()
10 |
11 |
12 | def systemGetDriverVersion():
13 | r"""Get Driver Version"""
14 | return pynvml.nvmlSystemGetDriverVersion()
15 |
16 |
17 | def deviceGetCount():
18 | r"""Get number of devices"""
19 | return pynvml.nvmlDeviceGetCount()
20 |
21 |
22 | class device(object):
23 | r"""Device used for nvml."""
24 | _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
25 |
26 | def __init__(self, device_idx):
27 | super().__init__()
28 | self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
29 |
30 | def getName(self):
31 | r"""Get object name"""
32 | return pynvml.nvmlDeviceGetName(self.handle)
33 |
34 | def getCpuAffinity(self):
35 | r"""Get CPU affinity"""
36 | affinity_string = ''
37 | for j in pynvml.nvmlDeviceGetCpuAffinity(
38 | self.handle, device._nvml_affinity_elements):
39 | # assume nvml returns list of 64 bit ints
40 | affinity_string = '{:064b}'.format(j) + affinity_string
41 | affinity_list = [int(x) for x in affinity_string]
42 | affinity_list.reverse() # so core 0 is in 0th element of list
43 |
44 | return [i for i, e in enumerate(affinity_list) if e != 0]
45 |
46 |
47 | def set_affinity(gpu_id=None):
48 | r"""Set GPU affinity
49 |
50 | Args:
51 | gpu_id (int): Which gpu device.
52 | """
53 | if gpu_id is None:
54 | gpu_id = int(os.getenv('LOCAL_RANK', 0))
55 |
56 | dev = device(gpu_id)
57 | os.sched_setaffinity(0, dev.getCpuAffinity())
58 |
59 | # list of ints
60 | # representing the logical cores this process is now affinitied with
61 | return os.sched_getaffinity(0)
62 |
--------------------------------------------------------------------------------
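
A brief sketch of calling set_affinity in a per-GPU worker process (LOCAL_RANK is the environment variable it falls back to when no gpu_id is given):

    import os
    from sdg.gpu_affinity import set_affinity

    # Pin this process to the CPU cores NVML reports as local to its GPU.
    local_rank = int(os.getenv('LOCAL_RANK', 0))
    cores = set_affinity(local_rank)
    print(f'rank {local_rank} pinned to {len(cores)} logical cores')

Note that importing the module calls pynvml.nvmlInit(), so it assumes an NVIDIA driver is available.
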
/sdg/guidance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | def gram_matrix(input):
8 | flag = input.dtype == torch.float16  # remember whether to cast the result back to fp16
9 | if flag:
10 | input = input.to(torch.float32)
11 | a, b, c, d = input.size() # a=batch size(=1)
12 | sqrt_sum = math.sqrt(a * b * c * d) # for numerical stability
13 | features = input.view(a * b, c * d) / sqrt_sum # resize F_XL into \hat F_XL
14 | G = torch.mm(features, features.t()) # compute the gram product
15 | # we 'normalize' the values of the gram matrix by dividing by the
16 | # number of elements in each feature map (done above via sqrt_sum).
17 | result = G
18 | if flag:
19 | return result.to(torch.float16)
20 | else:
21 | return result
22 |
23 | def image_loss(source, target, args):
24 | if args.image_loss == 'semantic':
25 | source[-1] = source[-1] / source[-1].norm(dim=-1, keepdim=True)
26 | target[-1] = target[-1] / target[-1].norm(dim=-1, keepdim=True)
27 | return (source[-1] * target[-1]).sum(1)
28 | elif args.image_loss == 'style':
29 | weights = [1, 1, 1, 1, 1]
30 | loss = 0
31 | for cnt in range(5):
32 | loss += F.mse_loss(gram_matrix(source[cnt]), gram_matrix(target[cnt]))
33 | return -loss * 1e10 / sum(weights)
34 |
35 | def text_loss(source, target, args):
36 | source_feat = source[-1] / source[-1].norm(dim=-1, keepdim=True)
37 | target = target / target.norm(dim=-1, keepdim=True)
38 | return (source_feat * target).sum(1)
39 |
--------------------------------------------------------------------------------
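
For orientation: image_loss compares per-layer feature lists for the generated sample and the reference (the final entry for the semantic branch, all five 4-D feature maps for the style branch via Gram matrices), while text_loss compares the final image feature against a text embedding; args.image_loss selects the branch. A hypothetical call, with `source_feats`, `target_feats` and `text_emb` standing in for features produced by the CLIP guidance model elsewhere in the repo:

    from types import SimpleNamespace
    from sdg.guidance import image_loss, text_loss

    args = SimpleNamespace(image_loss='style')        # or 'semantic'
    # source_feats / target_feats: lists of five feature tensors, the last of shape [N, D]; placeholders here.
    style_score = image_loss(source_feats, target_feats, args)  # negative Gram-matrix MSE (higher is better)
    sem_score = text_loss(source_feats, text_emb, args)         # normalized dot product with the text embedding
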
/sdg/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import os
4 | import pickle
5 |
6 | from PIL import Image
7 | import blobfile as bf
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 | from torch.utils.data import DataLoader, Dataset
12 | import clip
13 | from sdg.distributed import master_only_print as print
14 | import torchvision.transforms as transforms
15 | import json
16 | import io
17 |
18 |
19 | def load_ref_data(args, ref_img_path=None):
20 | if ref_img_path is None:
21 | ref_img_path = args.ref_img_path
22 | with bf.BlobFile(ref_img_path, "rb") as f:
23 | pil_image = Image.open(f)
24 | pil_image.load()
25 |
26 | pil_image = pil_image.convert("RGB")
27 | arr = center_crop_arr(pil_image, args.image_size)
28 | arr = arr.astype(np.float32) / 127.5 - 1
29 | arr = np.repeat(np.expand_dims(np.transpose(arr, [2, 0, 1]), axis=0), args.batch_size, axis=0)
30 | kwargs = {}
31 | kwargs["ref_img"] = torch.tensor(arr)
32 | return kwargs
33 |
34 | def load_data(args, is_train=True, swap=False):
35 | """
36 | For a dataset, create a generator over (images, kwargs) pairs.
37 |
38 | Each images is an NCHW float tensor, and the kwargs dict contains zero or
39 | more keys, each of which map to a batched Tensor of their own.
40 | The kwargs dict can be used for class labels, in which case the key is "y"
41 | and the values are integer tensors of class labels.
42 |
43 | :param data_dir: a dataset directory.
44 | :param batch_size: the batch size of each returned pair.
45 | :param image_size: the size to which images are resized.
46 | :param class_cond: if True, include a "y" key in returned dicts for class
47 | label. If classes are not available and this is true, an
48 | exception will be raised.
49 | :param deterministic: if True, yield results in a deterministic order.
50 | :param random_crop: if True, randomly crop the images for augmentation.
51 | :param random_flip: if True, randomly flip the images for augmentation.
52 | """
53 | if not args.data_dir:
54 | raise ValueError("unspecified data directory")
55 | data_dir = getattr(args, 'data_dir')
56 | batch_size = getattr(args, 'batch_size')
57 | image_size = getattr(args, 'image_size')
58 | class_cond = getattr(args, 'class_cond', False)
59 | deterministic = getattr(args, 'deterministic', False)
60 | random_crop = getattr(args, 'random_crop', False)
61 | random_flip = getattr(args, 'random_flip', True)
62 | if not is_train:
63 | deterministic = True
64 | random_crop = False
65 | random_flip = False
66 | num_workers = getattr(args, 'num_workers', 4)
67 | all_files = _list_image_files_recursively(data_dir)
68 | classes = None
69 | if class_cond:
70 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
71 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
72 | classes = [sorted_classes[x] for x in class_names]
73 | dataset = ImageDataset(
74 | image_size,
75 | all_files,
76 | classes=classes,
77 | random_crop=random_crop,
78 | random_flip=random_flip,
79 | )
80 | not_distributed = args.single_gpu or args.debug or not dist.is_initialized()
81 | if not_distributed:
82 | sampler = None
83 | else:
84 | if deterministic:
85 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
86 | else:
87 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
88 | if deterministic or not not_distributed:
89 | loader = DataLoader(
90 | dataset, batch_size=batch_size, shuffle=False, sampler=sampler, num_workers=num_workers, drop_last=False
91 | )
92 | else:
93 | loader = DataLoader(
94 | dataset, batch_size=batch_size, shuffle=True, sampler=sampler, num_workers=num_workers, drop_last=False
95 | )
96 | while True:
97 | yield from loader
98 |
122 | def _list_image_files_recursively(data_dir):
123 | results = []
124 | for entry in sorted(bf.listdir(data_dir)):
125 | full_path = bf.join(data_dir, entry)
126 | ext = entry.split(".")[-1]
127 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
128 | results.append(full_path)
129 | elif bf.isdir(full_path):
130 | results.extend(_list_image_files_recursively(full_path))
131 | return results
132 |
133 |
134 | class ImageDataset(Dataset):
135 | def __init__(
136 | self,
137 | resolution,
138 | image_paths,
139 | classes=None,
140 | random_crop=False,
141 | random_flip=True,
142 | ):
143 | super().__init__()
144 | self.resolution = resolution
145 | self.images = image_paths
146 | self.classes = None if classes is None else classes
147 | self.random_crop = random_crop
148 | self.random_flip = random_flip
149 |
150 | def __len__(self):
151 | return len(self.images)
152 |
153 | def __getitem__(self, idx):
154 | path = self.images[idx]
155 | with bf.BlobFile(path, "rb") as f:
156 | pil_image = Image.open(f)
157 | pil_image.load()
158 | pil_image = pil_image.convert("RGB")
159 |
160 | if self.random_crop:
161 | arr = random_crop_arr(pil_image, self.resolution)
162 | else:
163 | arr = center_crop_arr(pil_image, self.resolution)
164 |
165 | if self.random_flip and random.random() < 0.5:
166 | arr = arr[:, ::-1]
167 |
168 | arr = arr.astype(np.float32) / 127.5 - 1
169 |
170 | out_dict = {}
171 | if self.classes is not None:
172 | out_dict["y"] = np.array(self.classes[idx], dtype=np.int64)
173 | return np.transpose(arr, [2, 0, 1]), out_dict
174 |
175 | def center_crop_arr(pil_image, image_size):
176 | # We are not on a new enough PIL to support the `reducing_gap`
177 | # argument, which uses BOX downsampling at powers of two first.
178 | # Thus, we do it by hand to improve downsample quality.
179 | while min(*pil_image.size) >= 2 * image_size:
180 | pil_image = pil_image.resize(
181 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
182 | )
183 |
184 | scale = image_size / min(*pil_image.size)
185 | pil_image = pil_image.resize(
186 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
187 | )
188 |
189 | arr = np.array(pil_image)
190 | crop_y = (arr.shape[0] - image_size) // 2
191 | crop_x = (arr.shape[1] - image_size) // 2
192 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
193 |
194 |
195 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.85, max_crop_frac=0.95):
196 | if min(*pil_image.size) != max(*pil_image.size):
197 | min_crop_frac = 1.0
198 | max_crop_frac = 1.0
199 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
200 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
201 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
202 |
203 | # We are not on a new enough PIL to support the `reducing_gap`
204 | # argument, which uses BOX downsampling at powers of two first.
205 | # Thus, we do it by hand to improve downsample quality.
206 | while min(*pil_image.size) >= 2 * smaller_dim_size:
207 | pil_image = pil_image.resize(
208 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
209 | )
210 |
211 | scale = smaller_dim_size / min(*pil_image.size)
212 | pil_image = pil_image.resize(
213 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
214 | )
215 |
216 | arr = np.array(pil_image)
217 |
218 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
219 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
220 |
221 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
222 |
--------------------------------------------------------------------------------
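
Usage sketch: load_data reads every option from the args namespace (the docstring's parameters correspond to attributes of args) and yields batches forever, while load_ref_data turns a single reference image into a batch-sized "ref_img" tensor. A minimal single-GPU example with a hypothetical args and placeholder paths:

    from types import SimpleNamespace
    from sdg.image_datasets import load_data, load_ref_data

    args = SimpleNamespace(
        data_dir='path/to/images', batch_size=4, image_size=256,
        class_cond=False, deterministic=False, random_crop=False, random_flip=True,
        num_workers=4, single_gpu=True, debug=False,
        ref_img_path='path/to/reference.png',
    )

    data = load_data(args, is_train=True)    # infinite generator of (images, kwargs) pairs
    images, kwargs = next(data)               # images: [N, 3, 256, 256] floats in [-1, 1]

    ref = load_ref_data(args)                 # {'ref_img': tensor of shape [N, 3, 256, 256]}
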
/sdg/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log a value of some diagnostic
215 | Call this once for each diagnostic quantity, each iteration
216 | If called many times, last value will be used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 | def logkv_mean(key, val):
221 | """
222 | The same as logkv(), but if called many times, values averaged.
223 | """
224 | get_current().logkv_mean(key, val)
225 |
226 |
227 | def logkvs(d):
228 | """
229 | Log a dictionary of key-value pairs
230 | """
231 | for (k, v) in d.items():
232 | logkv(k, v)
233 |
234 |
235 | def dumpkvs():
236 | """
237 | Write all of the diagnostics from the current iteration
238 | """
239 | return get_current().dumpkvs()
240 |
241 |
242 | def getkvs():
243 | return get_current().name2val
244 |
245 |
246 | def log(*args, level=INFO):
247 | """
248 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
249 | """
250 | get_current().log(*args, level=level)
251 |
252 |
253 | def debug(*args):
254 | log(*args, level=DEBUG)
255 |
256 |
257 | def info(*args):
258 | log(*args, level=INFO)
259 |
260 |
261 | def warn(*args):
262 | log(*args, level=WARN)
263 |
264 |
265 | def error(*args):
266 | log(*args, level=ERROR)
267 |
268 |
269 | def set_level(level):
270 | """
271 | Set logging threshold on current logger.
272 | """
273 | get_current().set_level(level)
274 |
275 |
276 | def set_comm(comm):
277 | get_current().set_comm(comm)
278 |
279 |
280 | def get_dir():
281 | """
282 | Get directory that log files are being written to.
283 | will be None if there is no output directory (i.e., if you didn't call start)
284 | """
285 | return get_current().get_dir()
286 |
287 |
288 | record_tabular = logkv
289 | dump_tabular = dumpkvs
290 |
291 |
292 | @contextmanager
293 | def profile_kv(scopename):
294 | logkey = "wait_" + scopename
295 | tstart = time.time()
296 | try:
297 | yield
298 | finally:
299 | get_current().name2val[logkey] += time.time() - tstart
300 |
301 |
302 | def profile(n):
303 | """
304 | Usage:
305 | @profile("my_func")
306 | def my_func(): code
307 | """
308 |
309 | def decorator_with_name(func):
310 | def func_wrapper(*args, **kwargs):
311 | with profile_kv(n):
312 | return func(*args, **kwargs)
313 |
314 | return func_wrapper
315 |
316 | return decorator_with_name
317 |
318 |
319 | # ================================================================
320 | # Backend
321 | # ================================================================
322 |
323 |
324 | def get_current():
325 | if Logger.CURRENT is None:
326 | _configure_default_logger()
327 |
328 | return Logger.CURRENT
329 |
330 |
331 | class Logger(object):
332 | DEFAULT = None # A logger with no output files. (See right below class definition)
333 | # So that you can still log to the terminal without setting up any output files
334 | CURRENT = None # Current logger being used by the free functions above
335 |
336 | def __init__(self, dir, output_formats, comm=None):
337 | self.name2val = defaultdict(float) # values this iteration
338 | self.name2cnt = defaultdict(int)
339 | self.level = INFO
340 | self.dir = dir
341 | self.output_formats = output_formats
342 | self.comm = comm
343 |
344 | # Logging API, forwarded
345 | # ----------------------------------------
346 | def logkv(self, key, val):
347 | self.name2val[key] = val
348 |
349 | def logkv_mean(self, key, val):
350 | oldval, cnt = self.name2val[key], self.name2cnt[key]
351 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
352 | self.name2cnt[key] = cnt + 1
353 |
354 | def dumpkvs(self):
355 | if self.comm is None:
356 | d = self.name2val
357 | else:
358 | d = mpi_weighted_mean(
359 | self.comm,
360 | {
361 | name: (val, self.name2cnt.get(name, 1))
362 | for (name, val) in self.name2val.items()
363 | },
364 | )
365 | if self.comm.rank != 0:
366 | d["dummy"] = 1 # so we don't get a warning about empty dict
367 | out = d.copy() # Return the dict for unit testing purposes
368 | for fmt in self.output_formats:
369 | if isinstance(fmt, KVWriter):
370 | fmt.writekvs(d)
371 | self.name2val.clear()
372 | self.name2cnt.clear()
373 | return out
374 |
375 | def log(self, *args, level=INFO):
376 | if self.level <= level:
377 | self._do_log(args)
378 |
379 | # Configuration
380 | # ----------------------------------------
381 | def set_level(self, level):
382 | self.level = level
383 |
384 | def set_comm(self, comm):
385 | self.comm = comm
386 |
387 | def get_dir(self):
388 | return self.dir
389 |
390 | def close(self):
391 | for fmt in self.output_formats:
392 | fmt.close()
393 |
394 | # Misc
395 | # ----------------------------------------
396 | def _do_log(self, args):
397 | for fmt in self.output_formats:
398 | if isinstance(fmt, SeqWriter):
399 | fmt.writeseq(map(str, args))
400 |
401 |
402 | def get_rank_without_mpi_import():
403 | # check environment variables here instead of importing mpi4py
404 | # to avoid calling MPI_Init() when this module is imported
405 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
406 | if varname in os.environ:
407 | return int(os.environ[varname])
408 | return 0
409 |
410 |
411 | def mpi_weighted_mean(comm, local_name2valcount):
412 | """
413 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
414 | Perform a weighted average over dicts that are each on a different node
415 | Input: local_name2valcount: dict mapping key -> (value, count)
416 | Returns: key -> mean
417 | """
418 | all_name2valcount = comm.gather(local_name2valcount)
419 | if comm.rank == 0:
420 | name2sum = defaultdict(float)
421 | name2count = defaultdict(float)
422 | for n2vc in all_name2valcount:
423 | for (name, (val, count)) in n2vc.items():
424 | try:
425 | val = float(val)
426 | except ValueError:
427 | if comm.rank == 0:
428 | warnings.warn(
429 | "WARNING: tried to compute mean on non-float {}={}".format(
430 | name, val
431 | )
432 | )
433 | else:
434 | name2sum[name] += val * count
435 | name2count[name] += count
436 | return {name: name2sum[name] / name2count[name] for name in name2sum}
437 | else:
438 | return {}
439 |
440 |
441 | def configure(dir='./checkpoints/', exp_name='debug', format_strs=None, comm=None, log_suffix=""):
442 | """
443 | If comm is provided, average all numerical stats across that comm
444 | """
445 | dir = osp.join(
446 | dir,
447 | exp_name, # + '_' + datetime.datetime.now().strftime("%m-%d-%H-%M-%S-%f"),
448 | )
449 | assert isinstance(dir, str)
450 | dir = os.path.expanduser(dir)
451 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
452 |
453 | rank = get_rank_without_mpi_import()
454 | if rank > 0:
455 | log_suffix = log_suffix + "-rank%03i" % rank
456 |
457 | if format_strs is None:
458 | if rank == 0:
459 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
460 | else:
461 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
462 | format_strs = filter(None, format_strs)
463 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
464 |
465 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
466 | if output_formats:
467 | log("Logging to %s" % dir)
468 |
469 |
470 | def _configure_default_logger():
471 | configure()
472 | Logger.DEFAULT = Logger.CURRENT
473 |
474 |
475 | def reset():
476 | if Logger.CURRENT is not Logger.DEFAULT:
477 | Logger.CURRENT.close()
478 | Logger.CURRENT = Logger.DEFAULT
479 | log("Reset logger")
480 |
481 |
482 | @contextmanager
483 | def scoped_configure(dir=None, format_strs=None, comm=None):
484 | prevlogger = Logger.CURRENT
485 | configure(dir=dir, format_strs=format_strs, comm=comm)
486 | try:
487 | yield
488 | finally:
489 | Logger.CURRENT.close()
490 | Logger.CURRENT = prevlogger
491 |
492 |
--------------------------------------------------------------------------------
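A minimal sketch of how the free functions above are typically strung together (the directory, experiment name, and keys below are illustrative, not taken from the repo's scripts):

    from sdg import logger

    logger.configure(dir='./checkpoints/', exp_name='demo')  # sets Logger.CURRENT
    for step in range(3):
        logger.logkv('step', step)               # last value wins within an iteration
        logger.logkv_mean('loss', 0.1 * step)    # running mean within an iteration
        logger.dumpkvs()                         # flush to the stdout/log/csv writers
    logger.log('training finished', level=logger.INFO)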
/sdg/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, check out LICENSE.md
5 | import datetime
6 | import os
7 | import torch
8 |
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from .distributed import master_only, is_master
12 | from .distributed import master_only_print as print
13 | from .distributed import dist_all_reduce_tensor
14 | from .misc import to_cuda
15 | import pdb
16 |
17 |
18 | def get_date_uid():
19 | """Generate a unique id based on date.
20 | Returns:
21 | str: Return uid string, e.g. '2021_1122_1713_07'.
22 | """
23 | return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S"))
24 |
25 |
26 | def init_logging(exp_name, root_dir='logs', timestamp=False):
27 | r"""Create log directory for storing checkpoints and output images.
28 |
29 | Args:
30 | exp_name (str): Experiment name used to build the log directory.
31 | root_dir (str): Root log directory; if timestamp, a date uid is appended to exp_name.
32 | Returns:
33 | str: Return log dir
34 | """
35 | #config_file = os.path.basename(config_path)
36 | if timestamp:
37 | date_uid = get_date_uid()
38 | exp_name = '_'.join([exp_name, date_uid])
39 | # example: logs/spade_cocostuff_2019_0125_1047_58
40 | #log_file = '_'.join([date_uid, os.path.splitext(config_file)[0]])
41 | #log_file = os.path.splitext(config_file)[0]
42 | logdir = os.path.join(root_dir, exp_name)
43 | return logdir
44 |
45 |
46 | @master_only
47 | def make_logging_dir(logdir, no_tb=False):
48 | r"""Create the logging directory
49 |
50 | Args:
51 | logdir (str): Log directory name
52 | """
53 | print('Make folder {}'.format(logdir))
54 | os.makedirs(logdir, exist_ok=True)
55 | if no_tb:
56 | return None
57 | tensorboard_dir = os.path.join(logdir, 'tensorboard')
58 | os.makedirs(tensorboard_dir, exist_ok=True)
59 | tb_log = SummaryWriter(log_dir=tensorboard_dir)
60 | return tb_log
61 |
62 | def write_tb(tb_log, key, value, step):
63 | if not torch.is_tensor(value):
64 | value = torch.tensor(value)
65 | if not value.is_cuda:
66 | value = to_cuda(value)
67 | value = dist_all_reduce_tensor(value.mean()).item()
68 | if is_master():
69 | tb_log.add_scalar(key, value, step)
70 | print('%s: %f ' % (key, value), end='')
71 |
--------------------------------------------------------------------------------
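A small sketch of the intended flow for these helpers. init_logging only composes the directory name; make_logging_dir (master rank only) creates it and the TensorBoard writer, and write_tb additionally assumes a CUDA device and an initialized (or effectively single-process master) distributed setup, since it moves values to GPU and all-reduces them before writing:

    from sdg.logging import init_logging, make_logging_dir, write_tb

    logdir = init_logging('sdg_demo', root_dir='logs', timestamp=True)
    tb_log = make_logging_dir(logdir)        # SummaryWriter on the master rank, else None
    if tb_log is not None:
        write_tb(tb_log, 'train/loss', 0.42, step=0)  # needs CUDA + torch.distributed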
/sdg/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 |
9 | import torch as th
10 |
11 |
12 | def normal_kl(mean1, logvar1, mean2, logvar2):
13 | """
14 | Compute the KL divergence between two gaussians.
15 |
16 | Shapes are automatically broadcasted, so batches can be compared to
17 | scalars, among other use cases.
18 | """
19 | tensor = None
20 | for obj in (mean1, logvar1, mean2, logvar2):
21 | if isinstance(obj, th.Tensor):
22 | tensor = obj
23 | break
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a Gaussian distribution discretizing to a
53 | given image.
54 |
55 | :param x: the target images. It is assumed that these were uint8 values,
56 | rescaled to the range [-1, 1].
57 | :param means: the Gaussian mean Tensor.
58 | :param log_scales: the Gaussian log stddev Tensor.
59 | :return: a tensor like x of log probabilities (in nats).
60 | """
61 | assert x.shape == means.shape == log_scales.shape
62 | centered_x = x - means
63 | inv_stdv = th.exp(-log_scales)
64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65 | cdf_plus = approx_standard_normal_cdf(plus_in)
66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67 | cdf_min = approx_standard_normal_cdf(min_in)
68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70 | cdf_delta = cdf_plus - cdf_min
71 | log_probs = th.where(
72 | x < -0.999,
73 | log_cdf_plus,
74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75 | )
76 | assert log_probs.shape == x.shape
77 | return log_probs
78 |
--------------------------------------------------------------------------------
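As a quick sanity check, normal_kl above matches the closed-form KL between two univariate Gaussians, KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2, with logvar = log(s^2) (the values below are arbitrary):

    import torch as th
    from sdg.losses import normal_kl

    m1, lv1 = th.tensor(0.5), th.tensor(0.0)
    m2, lv2 = th.tensor(0.0), th.tensor(0.2)
    kl = normal_kl(m1, lv1, m2, lv2)
    expected = 0.5 * (lv2 - lv1) + (th.exp(lv1) + (m1 - m2) ** 2) / (2 * th.exp(lv2)) - 0.5
    assert th.allclose(kl, expected)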
/sdg/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, check out LICENSE.md
5 | """Miscellaneous utils."""
6 | import collections
7 | from collections import OrderedDict
8 | import numpy as np
9 | import random
10 |
11 | import torch
12 | import torch.nn.functional as F
13 |
14 | from .distributed import get_rank
15 |
16 | string_classes = (str, bytes)
17 |
18 | def set_random_seed(seed, by_rank=False):
19 | r"""Set random seeds for everything.
20 | Args:
21 | seed (int): Random seed.
22 | by_rank (bool): If True, offset the seed by the process rank so each rank differs.
23 | """
24 | if by_rank:
25 | seed += get_rank()
26 | print(f"Using random seed {seed}")
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 | torch.cuda.manual_seed(seed)
31 | torch.cuda.manual_seed_all(seed)
32 |
33 |
34 | def split_labels(labels, label_lengths):
35 | r"""Split concatenated labels into their parts.
36 |
37 | Args:
38 | labels (torch.Tensor): Labels obtained through concatenation.
39 | label_lengths (OrderedDict): Containing order of labels & their lengths.
40 |
41 | Returns:
42 | (dict): Mapping from each data type to its slice of the concatenated labels.
43 | """
44 | assert isinstance(label_lengths, OrderedDict)
45 | start = 0
46 | outputs = {}
47 | for data_type, length in label_lengths.items():
48 | end = start + length
49 | if labels.dim() == 5:
50 | outputs[data_type] = labels[:, :, start:end]
51 | elif labels.dim() == 4:
52 | outputs[data_type] = labels[:, start:end]
53 | elif labels.dim() == 3:
54 | outputs[data_type] = labels[start:end]
55 | start = end
56 | return outputs
57 |
58 |
59 | def requires_grad(model, require=True):
60 | r""" Set a model to require gradient or not.
61 |
62 | Args:
63 | model (nn.Module): Neural network model.
64 | require (bool): Whether the network requires gradient or not.
65 |
66 | Returns:
67 |
68 | """
69 | for p in model.parameters():
70 | p.requires_grad = require
71 |
72 |
73 | def to_device(data, device):
74 | r"""Move all tensors inside data to device.
75 |
76 | Args:
77 | data (dict, list, or tensor): Input data.
78 | device (str): 'cpu' or 'cuda'.
79 | """
80 | assert device in ['cpu', 'cuda']
81 | if isinstance(data, torch.Tensor):
82 | data = data.to(torch.device(device))
83 | return data
84 | elif isinstance(data, collections.abc.Mapping):
85 | return {key: to_device(data[key], device) for key in data}
86 | elif isinstance(data, collections.abc.Sequence) and \
87 | not isinstance(data, string_classes):
88 | return [to_device(d, device) for d in data]
89 | else:
90 | return data
91 |
92 |
93 | def to_cuda(data):
94 | r"""Move all tensors inside data to gpu.
95 |
96 | Args:
97 | data (dict, list, or tensor): Input data.
98 | """
99 | return to_device(data, 'cuda')
100 |
101 |
102 | def to_cpu(data):
103 | r"""Move all tensors inside data to cpu.
104 |
105 | Args:
106 | data (dict, list, or tensor): Input data.
107 | """
108 | return to_device(data, 'cpu')
109 |
110 |
111 | def to_half(data):
112 | r"""Move all floats to half.
113 |
114 | Args:
115 | data (dict, list or tensor): Input data.
116 | """
117 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data):
118 | data = data.half()
119 | return data
120 | elif isinstance(data, collections.abc.Mapping):
121 | return {key: to_half(data[key]) for key in data}
122 | elif isinstance(data, collections.abc.Sequence) and \
123 | not isinstance(data, string_classes):
124 | return [to_half(d) for d in data]
125 | else:
126 | return data
127 |
128 |
129 | def to_float(data):
130 | r"""Move all halfs to float.
131 |
132 | Args:
133 | data (dict, list or tensor): Input data.
134 | """
135 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data):
136 | data = data.float()
137 | return data
138 | elif isinstance(data, collections.abc.Mapping):
139 | return {key: to_float(data[key]) for key in data}
140 | elif isinstance(data, collections.abc.Sequence) and \
141 | not isinstance(data, string_classes):
142 | return [to_float(d) for d in data]
143 | else:
144 | return data
145 |
146 |
147 | def to_channels_last(data):
148 | r"""Move all data to ``channels_last`` format.
149 |
150 | Args:
151 | data (dict, list or tensor): Input data.
152 | """
153 | if isinstance(data, torch.Tensor):
154 | if data.dim() == 4:
155 | data = data.to(memory_format=torch.channels_last)
156 | return data
157 | elif isinstance(data, collections.abc.Mapping):
158 | return {key: to_channels_last(data[key]) for key in data}
159 | elif isinstance(data, collections.abc.Sequence) and \
160 | not isinstance(data, string_classes):
161 | return [to_channels_last(d) for d in data]
162 | else:
163 | return data
164 |
165 |
166 | def slice_tensor(data, start, end):
167 | r"""Slice all tensors from start to end.
168 | Args:
169 | data (dict, list or tensor): Input data.
170 | """
171 | if isinstance(data, torch.Tensor):
172 | data = data[start:end]
173 | return data
174 | elif isinstance(data, collections.abc.Mapping):
175 | return {key: slice_tensor(data[key], start, end) for key in data}
176 | elif isinstance(data, collections.abc.Sequence) and \
177 | not isinstance(data, string_classes):
178 | return [slice_tensor(d, start, end) for d in data]
179 | else:
180 | return data
181 |
182 |
183 | def get_and_setattr(cfg, name, default):
184 | r"""Get attribute with default choice. If attribute does not exist, set it
185 | using the default value.
186 |
187 | Args:
188 | cfg (obj) : Config options.
189 | name (str) : Attribute name.
190 | default (obj) : Default attribute.
191 |
192 | Returns:
193 | (obj) : Desired attribute.
194 | """
195 | if not hasattr(cfg, name) or name not in cfg.__dict__:
196 | setattr(cfg, name, default)
197 | return getattr(cfg, name)
198 |
199 |
200 | def get_nested_attr(cfg, attr_name, default):
201 | r"""Iteratively try to get the attribute from cfg. If not found, return
202 | default.
203 |
204 | Args:
205 | cfg (obj): Config file.
206 | attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ).
207 | default (obj): Default return value for the attribute.
208 |
209 | Returns:
210 | (obj): Attribute value.
211 | """
212 | names = attr_name.split('.')
213 | atr = cfg
214 | for name in names:
215 | if not hasattr(atr, name):
216 | return default
217 | atr = getattr(atr, name)
218 | return atr
219 |
220 |
221 | def gradient_norm(model):
222 | r"""Return the gradient norm of model.
223 |
224 | Args:
225 | model (PyTorch module): Your network.
226 |
227 | """
228 | total_norm = 0
229 | for p in model.parameters():
230 | if p.grad is not None:
231 | param_norm = p.grad.norm(2)
232 | total_norm += param_norm.item() ** 2
233 | return total_norm ** (1. / 2)
234 |
235 |
236 | def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'):
237 | r"""Randomly shift the input tensor.
238 |
239 | Args:
240 | x (4D tensor): The input batch of images.
241 | offset (float): The maximum offset ratio, in [0, 1].
242 | The maximum shift is offset * image_size for each direction.
243 | mode (str): The resample mode for 'F.grid_sample'.
244 | padding_mode (str): The padding mode for 'F.grid_sample'.
245 |
246 | Returns:
247 | x (4D tensor) : The randomly shifted image.
248 | """
249 | assert x.dim() == 4, "Input must be a 4D tensor."
250 | batch_size = x.size(0)
251 | theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat(
252 | batch_size, 1, 1)
253 | theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset
254 | grid = F.affine_grid(theta, x.size())
255 | x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode)
256 | return x
257 |
258 |
259 | # def truncated_gaussian(threshold, size, seed=None, device=None):
260 | # r"""Apply the truncated gaussian trick to trade diversity for quality
261 | #
262 | # Args:
263 | # threshold (float): Truncation threshold.
264 | # size (list of integer): Tensor size.
265 | # seed (int): Random seed.
266 | # device:
267 | # """
268 | # state = None if seed is None else np.random.RandomState(seed)
269 | # values = truncnorm.rvs(-threshold, threshold,
270 | # size=size, random_state=state)
271 | # return torch.tensor(values, device=device).float()
272 |
273 |
274 | def apply_imagenet_normalization(input):
275 | r"""Normalize using ImageNet mean and std.
276 |
277 | Args:
278 | input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1].
279 |
280 | Returns:
281 | Normalized inputs using the ImageNet normalization.
282 | """
283 | # normalize the input back to [0, 1]
284 | normalized_input = (input + 1) / 2
285 | # normalize the input using the ImageNet mean and std
286 | mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
287 | std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
288 | output = (normalized_input - mean) / std
289 | return output
290 |
--------------------------------------------------------------------------------
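The container helpers above (to_device, to_half, to_float, to_channels_last, slice_tensor) all share the same recursion over dicts, sequences, and tensors; a small sketch with an arbitrary batch layout:

    import torch
    from sdg.misc import to_device, slice_tensor

    batch = {
        'image': torch.randn(4, 3, 8, 8),
        'meta': [torch.zeros(4, dtype=torch.long), 'keep-me'],  # strings pass through untouched
    }
    batch_cpu = to_device(batch, 'cpu')    # walks dicts/lists, moves only tensors
    first_two = slice_tensor(batch, 0, 2)  # slices every tensor along dim 0
    assert first_two['image'].shape[0] == 2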
/sdg/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
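Two of the helpers above in isolation, to make the expected shapes and the EMA update direction concrete (values are arbitrary):

    import torch as th
    from sdg.nn import timestep_embedding, update_ema

    t = th.tensor([0, 10, 999])             # one timestep index per batch element
    emb = timestep_embedding(t, dim=128)    # -> [3, 128] sinusoidal embedding
    assert emb.shape == (3, 128)

    target, source = [th.zeros(4)], [th.ones(4)]
    update_ema(target, source, rate=0.9)    # in-place: target becomes 0.9*0 + 0.1*1 = 0.1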
/sdg/parser.py:
--------------------------------------------------------------------------------
1 | """
2 | Shared argument parser for the SDG training and sampling scripts.
3 | """
4 | import os
5 | import argparse
6 |
7 | from sdg.resample import create_named_schedule_sampler
8 | from sdg.script_util import (
9 | model_and_diffusion_defaults,
10 | args_to_dict,
11 | add_dict_to_argparser,
12 | )
13 |
14 |
15 | def create_argparser():
16 | parser = argparse.ArgumentParser()
17 | # basic
18 | parser.add_argument('--exp_name', required=True)
19 | parser.add_argument('--resume_checkpoint', default=None)
20 | parser.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0))
21 | parser.add_argument('--single_gpu', action='store_true')
22 | parser.add_argument('--seed', type=int, default=2)
23 | parser.add_argument('--randomized_seed', action='store_true')
24 | parser.add_argument('--debug', action='store_true')
25 | parser.add_argument('--logdir', type=str)
26 |
27 | # data
28 | parser.add_argument('--data_dir', type=str, default='')
29 | parser.add_argument('--image_size', type=int, default=256)
30 | parser.add_argument('--deterministic', type=lambda x: (str(x).lower() == 'true'), default=False)
31 | parser.add_argument('--random_crop', type=lambda x: (str(x).lower() == 'true'), default=False)
32 | parser.add_argument('--random_flip', type=lambda x: (str(x).lower() == 'true'), default=True)
33 | parser.add_argument('--num_workers', type=int, default=4)
34 | parser.add_argument('--clip_model', type=str, default='ViT-B-16')
35 |
36 | # model
37 | parser.add_argument('--class_cond', type=lambda x: (str(x).lower() == 'true'), default=False)
38 | parser.add_argument('--text_cond', type=lambda x: (str(x).lower() == 'true'), default=False)
39 | parser.add_argument('--num_channels', type=int, default=128)
40 | parser.add_argument('--num_res_blocks', type=int, default=2)
41 | parser.add_argument('--num_heads', type=int, default=4)
42 | parser.add_argument('--num_heads_upsample', type=int, default=-1)
43 | parser.add_argument('--num_head_channels', type=int, default=-1)
44 | parser.add_argument('--attention_resolutions', type=str, default="16,8")
45 | parser.add_argument('--channel_mult', default="")
46 | parser.add_argument('--dropout', type=float, default=0.0)
47 | parser.add_argument('--use_checkpoint', type=lambda x: (str(x).lower() == 'true'), default=False)
48 | parser.add_argument('--use_scale_shift_norm', type=lambda x: (str(x).lower() == 'true'), default=True)
49 | parser.add_argument('--resblock_updown', type=lambda x: (str(x).lower() == 'true'), default=False)
50 | parser.add_argument('--use_new_attention_order', type=lambda x: (str(x).lower() == 'true'), default=False)
51 |
52 | # diffusion
53 | parser.add_argument('--learn_sigma', type=lambda x: (str(x).lower() == 'true'), default=False)
54 | parser.add_argument('--diffusion_steps', type=int, default=1000)
55 | parser.add_argument('--noise_schedule', type=str, default='linear')
56 | parser.add_argument('--timestep_respacing', default='')
57 | parser.add_argument('--use_kl', type=lambda x: (str(x).lower() == 'true'), default=False)
58 | parser.add_argument('--predict_xstart', type=lambda x: (str(x).lower() == 'true'), default=False)
59 | parser.add_argument('--rescale_timesteps',type=lambda x: (str(x).lower() == 'true'), default=False)
60 | parser.add_argument('--rescale_learned_sigmas', type=lambda x: (str(x).lower() == 'true'), default=False)
61 |
62 | # classifier
63 | parser.add_argument('--classifier_use_fp16', type=lambda x: (str(x).lower() == 'true'), default=False)
64 | parser.add_argument('--classifier_width', type=int, default=128)
65 | parser.add_argument('--classifier_depth', type=int, default=2)
66 | parser.add_argument('--classifier_attention_resolutions', type=str, default="32,16,8")
67 | parser.add_argument('--classifier_use_scale_shift_norm', type=lambda x: (str(x).lower() == 'true'), default=True)
68 | parser.add_argument('--classifier_resblock_updown', type=lambda x: (str(x).lower() == 'true'), default=True)
69 | parser.add_argument('--classifier_pool', type=str, default="attention")
70 | parser.add_argument('--num_classes', type=int, default=1000)
71 |
72 | # sr
73 | parser.add_argument('--large_size', type=int, default=256)
74 | parser.add_argument('--small_size', type=int, default=64)
75 |
76 | # train
77 | parser.add_argument('--batch_size', type=int, default=32)
78 | parser.add_argument('--microbatch', type=int, default=-1)
79 | parser.add_argument('--schedule_sampler', type=str, default='uniform')
80 | parser.add_argument('--lr', type=float, default=1e-4)
81 | parser.add_argument('--weight_decay', type=float, default=0.0)
82 | parser.add_argument('--lr_anneal_steps', type=int, default=0)
83 | parser.add_argument('--use_fp16', type=lambda x: (str(x).lower() == 'true'), default=False)
84 | parser.add_argument('--fp16_scale_growth', type=float, default=1e-3)
85 | parser.add_argument('--fp16_hyperparams', type=str, default='openai')
86 | parser.add_argument('--anneal_lr', type=lambda x: (str(x).lower() == 'true'), default=False)
87 | parser.add_argument('--iterations', type=int, default=500000)
88 |
89 | # save
90 | parser.add_argument('--ema_rate', default='0.9999')
91 | parser.add_argument('--log_interval', type=int, default=10)
92 | parser.add_argument('--save_interval', type=int, default=10000)
93 | parser.add_argument('--eval_interval', type=int, default=10)
94 |
95 | # inference
96 | parser.add_argument('--model_path', type=str, default='')
97 | parser.add_argument('--use_ddim', type=lambda x: (str(x).lower() == 'true'), default=False)
98 | parser.add_argument('--clip_denoised', type=lambda x: (str(x).lower() == 'true'), default=True)
99 | parser.add_argument('--text_weight', type=float, default=1.0)
100 | parser.add_argument('--image_weight', type=float, default=1.0)
101 | parser.add_argument('--image_loss', type=str, default='semantic')
102 | parser.add_argument('--text_instruction_file', type=str, default='ref/ffhq_instructions.txt')
103 | parser.add_argument('--clip_path', type=str, default='')
104 |
105 |
106 | # train classifier/clip
107 | parser.add_argument('--noised', type=lambda x: (str(x).lower() == 'true'), default=True)
108 | parser.add_argument('--finetune_clip_layer', type=str, default='all')
109 |
110 | # superres
111 | parser.add_argument('--base_name', type=str, default='')
112 |
113 |
114 | return parser
115 |
116 |
--------------------------------------------------------------------------------
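A sketch of consuming the parser; note that the boolean-style options are parsed with a lambda, so they expect an explicit value such as "--use_fp16 True" rather than acting as store_true flags (importing sdg.parser also pulls in torch and the CLIP guidance module, so the dependencies in requirements.txt are assumed to be installed):

    from sdg.parser import create_argparser

    args = create_argparser().parse_args(['--exp_name', 'demo', '--use_fp16', 'True'])
    assert args.use_fp16 is True and args.image_size == 256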
/sdg/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == 'linear':
18 | return LinearSampler(diffusion)
19 | elif name == 'half':
20 | return HalfSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 |
32 | By default, samplers perform unbiased importance sampling, in which the
33 | objective's mean is unchanged.
34 | However, subclasses may override sample() to change how the resampled
35 | terms are reweighted, allowing for actual changes in the objective.
36 | """
37 |
38 | @abstractmethod
39 | def weights(self):
40 | """
41 | Get a numpy array of weights, one per diffusion step.
42 |
43 | The weights needn't be normalized, but must be positive.
44 | """
45 |
46 | def sample(self, batch_size, device):
47 | """
48 | Importance-sample timesteps for a batch.
49 |
50 | :param batch_size: the number of timesteps to sample (one per batch element).
51 | :param device: the torch device to save to.
52 | :return: a tuple (timesteps, weights):
53 | - timesteps: a tensor of timestep indices.
54 | - weights: a tensor of weights to scale the resulting losses.
55 | """
56 | w = self.weights()
57 | p = w / np.sum(w)
58 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
59 | indices = th.from_numpy(indices_np).long().to(device)
60 | weights_np = 1 / (len(p) * p[indices_np])
61 | weights = th.from_numpy(weights_np).float().to(device)
62 | return indices, weights
63 |
64 |
65 | class UniformSampler(ScheduleSampler):
66 | def __init__(self, diffusion):
67 | self.diffusion = diffusion
68 | self._weights = np.ones([diffusion.num_timesteps])
69 |
70 | def weights(self):
71 | return self._weights
72 |
73 | class LinearSampler(ScheduleSampler):
74 | # higher weight on noisy images and lower weight on clean images
75 | def __init__(self, diffusion):
76 | self.diffusion = diffusion
77 | self._weights = np.linspace(1, 10, diffusion.num_timesteps)
78 |
79 | def weights(self):
80 | return self._weights
81 |
82 | class HalfSampler(ScheduleSampler):
83 | def __init__(self, diffusion):
84 | self.diffusion = diffusion
85 | self._weights = np.ones([diffusion.num_timesteps])
86 | self._weights[diffusion.num_timesteps//2:] = 0
87 |
88 | def weights(self):
89 | return self._weights
90 |
91 |
92 |
93 |
94 | class LossAwareSampler(ScheduleSampler):
95 | def update_with_local_losses(self, local_ts, local_losses):
96 | """
97 | Update the reweighting using losses from a model.
98 |
99 | Call this method from each rank with a batch of timesteps and the
100 | corresponding losses for each of those timesteps.
101 | This method will perform synchronization to make sure all of the ranks
102 | maintain the exact same reweighting.
103 |
104 | :param local_ts: an integer Tensor of timesteps.
105 | :param local_losses: a 1D Tensor of losses.
106 | """
107 | batch_sizes = [
108 | th.tensor([0], dtype=th.int32, device=local_ts.device)
109 | for _ in range(dist.get_world_size())
110 | ]
111 | dist.all_gather(
112 | batch_sizes,
113 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
114 | )
115 |
116 | # Pad all_gather batches to be the maximum batch size.
117 | batch_sizes = [x.item() for x in batch_sizes]
118 | max_bs = max(batch_sizes)
119 |
120 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
121 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
122 | dist.all_gather(timestep_batches, local_ts)
123 | dist.all_gather(loss_batches, local_losses)
124 | timesteps = [
125 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
126 | ]
127 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
128 | self.update_with_all_losses(timesteps, losses)
129 |
130 | @abstractmethod
131 | def update_with_all_losses(self, ts, losses):
132 | """
133 | Update the reweighting using losses from a model.
134 |
135 | Sub-classes should override this method to update the reweighting
136 | using losses from the model.
137 |
138 | This method directly updates the reweighting without synchronizing
139 | between workers. It is called by update_with_local_losses from all
140 | ranks with identical arguments. Thus, it should have deterministic
141 | behavior to maintain state across workers.
142 |
143 | :param ts: a list of int timesteps.
144 | :param losses: a list of float losses, one per timestep.
145 | """
146 |
147 |
148 | class LossSecondMomentResampler(LossAwareSampler):
149 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
150 | self.diffusion = diffusion
151 | self.history_per_term = history_per_term
152 | self.uniform_prob = uniform_prob
153 | self._loss_history = np.zeros(
154 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
155 | )
156 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int is removed in newer NumPy
157 |
158 | def weights(self):
159 | if not self._warmed_up():
160 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
161 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
162 | weights /= np.sum(weights)
163 | weights *= 1 - self.uniform_prob
164 | weights += self.uniform_prob / len(weights)
165 | return weights
166 |
167 | def update_with_all_losses(self, ts, losses):
168 | for t, loss in zip(ts, losses):
169 | if self._loss_counts[t] == self.history_per_term:
170 | # Shift out the oldest loss term.
171 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
172 | self._loss_history[t, -1] = loss
173 | else:
174 | self._loss_history[t, self._loss_counts[t]] = loss
175 | self._loss_counts[t] += 1
176 |
177 | def _warmed_up(self):
178 | return (self._loss_counts == self.history_per_term).all()
179 |
--------------------------------------------------------------------------------
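A sketch of how a schedule sampler is consumed in a training step; the stand-in diffusion object below is hypothetical and only provides the num_timesteps attribute the samplers read:

    import torch as th
    from sdg.resample import create_named_schedule_sampler

    class _ToyDiffusion:
        num_timesteps = 1000

    sampler = create_named_schedule_sampler('uniform', _ToyDiffusion())
    t, weights = sampler.sample(batch_size=8, device=th.device('cpu'))
    # t: 8 sampled timestep indices; weights: importance weights for rescaling per-sample losses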
/sdg/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 | f"cannot create exactly {num_timesteps} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = (new_ts.float() * (1000.0 / self.original_num_steps)).long()
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
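A worked example of space_timesteps, mirroring the docstring's 300-step case plus the DDIM-style striding:

    from sdg.respace import space_timesteps

    kept = space_timesteps(300, [10, 15, 20])
    assert len(kept) == 45                       # 10 + 15 + 20 retained timesteps
    ddim_kept = space_timesteps(1000, "ddim50")  # fixed stride of 20 -> 50 steps
    assert len(ddim_kept) == 50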
/sdg/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel
7 | from .clip_guidance import CLIP_gd
8 |
9 | NUM_CLASSES = 1000
10 |
11 |
12 | def diffusion_defaults():
13 | """
14 | Defaults for image and classifier training.
15 | """
16 | return dict(
17 | learn_sigma=False,
18 | diffusion_steps=1000,
19 | noise_schedule="linear",
20 | timestep_respacing="",
21 | use_kl=False,
22 | predict_xstart=False,
23 | rescale_timesteps=False,
24 | rescale_learned_sigmas=False,
25 | )
26 |
27 |
28 | def classifier_defaults():
29 | """
30 | Defaults for classifier models.
31 | """
32 | return dict(
33 | image_size=64,
34 | classifier_use_fp16=False,
35 | classifier_width=128,
36 | classifier_depth=2,
37 | classifier_attention_resolutions="32,16,8", # 16
38 | classifier_use_scale_shift_norm=True, # False
39 | classifier_resblock_updown=True, # False
40 | classifier_pool="attention",
41 | num_classes=1000,
42 | )
43 |
44 |
45 | def model_and_diffusion_defaults():
46 | """
47 | Defaults for image training.
48 | """
49 | res = dict(
50 | image_size=64,
51 | num_channels=128,
52 | num_res_blocks=2,
53 | num_heads=4,
54 | num_heads_upsample=-1,
55 | num_head_channels=-1,
56 | attention_resolutions="16,8",
57 | channel_mult="",
58 | dropout=0.0,
59 | class_cond=False,
60 | text_cond=False,
61 | use_checkpoint=False,
62 | use_scale_shift_norm=True,
63 | resblock_updown=False,
64 | use_fp16=False,
65 | use_new_attention_order=False,
66 | )
67 | res.update(diffusion_defaults())
68 | return res
69 |
70 |
71 | def classifier_and_diffusion_defaults():
72 | res = classifier_defaults()
73 | res.update(diffusion_defaults())
74 | return res
75 |
76 |
77 | def create_model_and_diffusion(
78 | image_size,
79 | class_cond,
80 | text_cond,
81 | learn_sigma,
82 | num_channels,
83 | num_res_blocks,
84 | channel_mult,
85 | num_heads,
86 | num_head_channels,
87 | num_heads_upsample,
88 | attention_resolutions,
89 | dropout,
90 | diffusion_steps,
91 | noise_schedule,
92 | timestep_respacing,
93 | use_kl,
94 | predict_xstart,
95 | rescale_timesteps,
96 | rescale_learned_sigmas,
97 | use_checkpoint,
98 | use_scale_shift_norm,
99 | resblock_updown,
100 | use_fp16,
101 | use_new_attention_order,
102 | ):
103 | model = create_model(
104 | image_size,
105 | num_channels,
106 | num_res_blocks,
107 | channel_mult=channel_mult,
108 | learn_sigma=learn_sigma,
109 | class_cond=class_cond,
110 | text_cond=text_cond,
111 | use_checkpoint=use_checkpoint,
112 | attention_resolutions=attention_resolutions,
113 | num_heads=num_heads,
114 | num_head_channels=num_head_channels,
115 | num_heads_upsample=num_heads_upsample,
116 | use_scale_shift_norm=use_scale_shift_norm,
117 | dropout=dropout,
118 | resblock_updown=resblock_updown,
119 | use_fp16=use_fp16,
120 | use_new_attention_order=use_new_attention_order,
121 | )
122 | diffusion = create_gaussian_diffusion(
123 | steps=diffusion_steps,
124 | learn_sigma=learn_sigma,
125 | noise_schedule=noise_schedule,
126 | use_kl=use_kl,
127 | predict_xstart=predict_xstart,
128 | rescale_timesteps=rescale_timesteps,
129 | rescale_learned_sigmas=rescale_learned_sigmas,
130 | timestep_respacing=timestep_respacing,
131 | )
132 | return model, diffusion
133 |
134 |
135 | def create_model(
136 | image_size,
137 | num_channels,
138 | num_res_blocks,
139 | channel_mult="",
140 | learn_sigma=False,
141 | class_cond=False,
142 | text_cond=False,
143 | use_checkpoint=False,
144 | attention_resolutions="16",
145 | num_heads=1,
146 | num_head_channels=-1,
147 | num_heads_upsample=-1,
148 | use_scale_shift_norm=False,
149 | dropout=0,
150 | resblock_updown=False,
151 | use_fp16=False,
152 | use_new_attention_order=False,
153 | ):
154 | if channel_mult == "":
155 | if image_size == 512:
156 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
157 | elif image_size == 256:
158 | channel_mult = (1, 1, 2, 2, 4, 4)
159 | elif image_size == 128:
160 | channel_mult = (1, 1, 2, 3, 4)
161 | elif image_size == 64:
162 | channel_mult = (1, 2, 3, 4)
163 | else:
164 | raise ValueError(f"unsupported image size: {image_size}")
165 | else:
166 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
167 |
168 | attention_ds = []
169 | for res in attention_resolutions.split(","):
170 | attention_ds.append(image_size // int(res))
171 |
172 | return UNetModel(
173 | image_size=image_size,
174 | in_channels=3,
175 | model_channels=num_channels,
176 | out_channels=(3 if not learn_sigma else 6),
177 | num_res_blocks=num_res_blocks,
178 | attention_resolutions=tuple(attention_ds),
179 | dropout=dropout,
180 | channel_mult=channel_mult,
181 | num_classes=(NUM_CLASSES if class_cond else None),
182 | text_cond=text_cond,
183 | use_checkpoint=use_checkpoint,
184 | use_fp16=use_fp16,
185 | num_heads=num_heads,
186 | num_head_channels=num_head_channels,
187 | num_heads_upsample=num_heads_upsample,
188 | use_scale_shift_norm=use_scale_shift_norm,
189 | resblock_updown=resblock_updown,
190 | use_new_attention_order=use_new_attention_order,
191 | )
192 |
193 |
194 | def create_classifier_and_diffusion(
195 | image_size,
196 | classifier_use_fp16,
197 | classifier_width,
198 | classifier_depth,
199 | classifier_attention_resolutions,
200 | classifier_use_scale_shift_norm,
201 | classifier_resblock_updown,
202 | classifier_pool,
203 | num_classes,
204 | learn_sigma,
205 | diffusion_steps,
206 | noise_schedule,
207 | timestep_respacing,
208 | use_kl,
209 | predict_xstart,
210 | rescale_timesteps,
211 | rescale_learned_sigmas,
212 | ):
213 | classifier = create_classifier(
214 | image_size,
215 | classifier_use_fp16,
216 | classifier_width,
217 | classifier_depth,
218 | classifier_attention_resolutions,
219 | classifier_use_scale_shift_norm,
220 | classifier_resblock_updown,
221 | classifier_pool,
222 | num_classes,
223 | )
224 | diffusion = create_gaussian_diffusion(
225 | steps=diffusion_steps,
226 | learn_sigma=learn_sigma,
227 | noise_schedule=noise_schedule,
228 | use_kl=use_kl,
229 | predict_xstart=predict_xstart,
230 | rescale_timesteps=rescale_timesteps,
231 | rescale_learned_sigmas=rescale_learned_sigmas,
232 | timestep_respacing=timestep_respacing,
233 | )
234 | return classifier, diffusion
235 |
236 |
237 | def create_clip_and_diffusion(
238 | args,
239 | image_size,
240 | classifier_use_fp16,
241 | classifier_width,
242 | classifier_depth,
243 | classifier_attention_resolutions,
244 | classifier_use_scale_shift_norm,
245 | classifier_resblock_updown,
246 | classifier_pool,
247 | num_classes,
248 | learn_sigma,
249 | diffusion_steps,
250 | noise_schedule,
251 | timestep_respacing,
252 | use_kl,
253 | predict_xstart,
254 | rescale_timesteps,
255 | rescale_learned_sigmas,
256 | ):
257 | clip = create_clip(args)
258 | diffusion = create_gaussian_diffusion(
259 | steps=diffusion_steps,
260 | learn_sigma=learn_sigma,
261 | noise_schedule=noise_schedule,
262 | use_kl=use_kl,
263 | predict_xstart=predict_xstart,
264 | rescale_timesteps=rescale_timesteps,
265 | rescale_learned_sigmas=rescale_learned_sigmas,
266 | timestep_respacing=timestep_respacing,
267 | )
268 | return clip, diffusion
269 |
270 |
271 | def create_classifier(
272 | image_size,
273 | classifier_use_fp16,
274 | classifier_width,
275 | classifier_depth,
276 | classifier_attention_resolutions,
277 | classifier_use_scale_shift_norm,
278 | classifier_resblock_updown,
279 | classifier_pool,
280 | num_classes,
281 | ):
282 | if image_size == 512:
283 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
284 | elif image_size == 256:
285 | channel_mult = (1, 1, 2, 2, 4, 4)
286 | elif image_size == 128:
287 | channel_mult = (1, 1, 2, 3, 4)
288 | elif image_size == 64:
289 | channel_mult = (1, 2, 3, 4)
290 | else:
291 | raise ValueError(f"unsupported image size: {image_size}")
292 |
293 | attention_ds = []
294 | for res in classifier_attention_resolutions.split(","):
295 | attention_ds.append(image_size // int(res))
296 |
297 | return EncoderUNetModel(
298 | image_size=image_size,
299 | in_channels=3,
300 | model_channels=classifier_width,
301 | out_channels=num_classes,
302 | num_res_blocks=classifier_depth,
303 | attention_resolutions=tuple(attention_ds),
304 | channel_mult=channel_mult,
305 | use_fp16=classifier_use_fp16,
306 | num_head_channels=64,
307 | use_scale_shift_norm=classifier_use_scale_shift_norm,
308 | resblock_updown=classifier_resblock_updown,
309 | pool=classifier_pool,
310 | )
311 |
312 | def create_clip(args):
313 | return CLIP_gd(args)
314 |
315 |
316 | def sr_model_and_diffusion_defaults():
317 | res = model_and_diffusion_defaults()
318 | res["large_size"] = 256
319 | res["small_size"] = 64
320 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]
321 | for k in res.copy().keys():
322 | if k not in arg_names:
323 | del res[k]
324 | return res
325 |
326 |
327 | def sr_create_model_and_diffusion(
328 | large_size,
329 | small_size,
330 | class_cond,
331 | text_cond,
332 | learn_sigma,
333 | num_channels,
334 | num_res_blocks,
335 | num_heads,
336 | num_head_channels,
337 | num_heads_upsample,
338 | attention_resolutions,
339 | dropout,
340 | diffusion_steps,
341 | noise_schedule,
342 | timestep_respacing,
343 | use_kl,
344 | predict_xstart,
345 | rescale_timesteps,
346 | rescale_learned_sigmas,
347 | use_checkpoint,
348 | use_scale_shift_norm,
349 | resblock_updown,
350 | use_fp16,
351 | ):
352 | model = sr_create_model(
353 | large_size,
354 | small_size,
355 | num_channels,
356 | num_res_blocks,
357 | learn_sigma=learn_sigma,
358 | class_cond=class_cond,
359 | text_cond=text_cond,
360 | use_checkpoint=use_checkpoint,
361 | attention_resolutions=attention_resolutions,
362 | num_heads=num_heads,
363 | num_head_channels=num_head_channels,
364 | num_heads_upsample=num_heads_upsample,
365 | use_scale_shift_norm=use_scale_shift_norm,
366 | dropout=dropout,
367 | resblock_updown=resblock_updown,
368 | use_fp16=use_fp16,
369 | )
370 | diffusion = create_gaussian_diffusion(
371 | steps=diffusion_steps,
372 | learn_sigma=learn_sigma,
373 | noise_schedule=noise_schedule,
374 | use_kl=use_kl,
375 | predict_xstart=predict_xstart,
376 | rescale_timesteps=rescale_timesteps,
377 | rescale_learned_sigmas=rescale_learned_sigmas,
378 | timestep_respacing=timestep_respacing,
379 | )
380 | return model, diffusion
381 |
382 |
383 | def sr_create_model(
384 | large_size,
385 | small_size,
386 | num_channels,
387 | num_res_blocks,
388 | learn_sigma,
389 | class_cond,
390 | text_cond,
391 | use_checkpoint,
392 | attention_resolutions,
393 | num_heads,
394 | num_head_channels,
395 | num_heads_upsample,
396 | use_scale_shift_norm,
397 | dropout,
398 | resblock_updown,
399 | use_fp16,
400 | ):
401 | _ = small_size # hack to prevent unused variable
402 |
403 | if large_size == 512:
404 | channel_mult = (1, 1, 2, 2, 4, 4)
405 | elif large_size == 256:
406 | channel_mult = (1, 1, 2, 2, 4, 4)
407 | elif large_size == 64:
408 | channel_mult = (1, 2, 3, 4)
409 | else:
410 | raise ValueError(f"unsupported large size: {large_size}")
411 |
412 | attention_ds = []
413 | for res in attention_resolutions.split(","):
414 | attention_ds.append(large_size // int(res))
415 |
416 | return SuperResModel(
417 | image_size=large_size,
418 | in_channels=3,
419 | model_channels=num_channels,
420 | out_channels=(3 if not learn_sigma else 6),
421 | num_res_blocks=num_res_blocks,
422 | attention_resolutions=tuple(attention_ds),
423 | dropout=dropout,
424 | channel_mult=channel_mult,
425 | num_classes=(NUM_CLASSES if class_cond else None),
426 | text_cond=text_cond,
427 | use_checkpoint=use_checkpoint,
428 | num_heads=num_heads,
429 | num_head_channels=num_head_channels,
430 | num_heads_upsample=num_heads_upsample,
431 | use_scale_shift_norm=use_scale_shift_norm,
432 | resblock_updown=resblock_updown,
433 | use_fp16=use_fp16,
434 | )
435 |
436 |
437 | def create_gaussian_diffusion(
438 | *,
439 | steps=1000,
440 | learn_sigma=False,
441 | sigma_small=False,
442 | noise_schedule="linear",
443 | use_kl=False,
444 | predict_xstart=False,
445 | rescale_timesteps=False,
446 | rescale_learned_sigmas=False,
447 | timestep_respacing="",
448 | ):
449 | betas = gd.get_named_beta_schedule(noise_schedule, steps)
450 | betas1000 = gd.get_named_beta_schedule(noise_schedule, 1000)
451 | if use_kl:
452 | loss_type = gd.LossType.RESCALED_KL
453 | elif rescale_learned_sigmas:
454 | loss_type = gd.LossType.RESCALED_MSE
455 | else:
456 | loss_type = gd.LossType.MSE
457 | if not timestep_respacing:
458 | timestep_respacing = [steps]
459 | return SpacedDiffusion(
460 | use_timesteps=space_timesteps(steps, timestep_respacing),
461 | betas=betas,
462 | model_mean_type=(
463 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
464 | ),
465 | model_var_type=(
466 | (
467 | gd.ModelVarType.FIXED_LARGE
468 | if not sigma_small
469 | else gd.ModelVarType.FIXED_SMALL
470 | )
471 | if not learn_sigma
472 | else gd.ModelVarType.LEARNED_RANGE
473 | ),
474 | loss_type=loss_type,
475 | rescale_timesteps=rescale_timesteps,
476 | betas1000=betas1000,
477 | )
478 |
479 |
480 | def add_dict_to_argparser(parser, default_dict):
481 | for k, v in default_dict.items():
482 | v_type = type(v)
483 | if v is None:
484 | v_type = str
485 | elif isinstance(v, bool):
486 | v_type = str2bool
487 | parser.add_argument(f"--{k}", default=v, type=v_type)
488 |
489 |
490 | def args_to_dict(args, keys):
491 | return {k: getattr(args, k) for k in keys}
492 |
493 |
494 | def str2bool(v):
495 | """
496 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
497 | """
498 | if isinstance(v, bool):
499 | return v
500 | if v.lower() in ("yes", "true", "t", "y", "1"):
501 | return True
502 | elif v.lower() in ("no", "false", "f", "n", "0"):
503 | return False
504 | else:
505 | raise argparse.ArgumentTypeError("boolean value expected")
506 |
--------------------------------------------------------------------------------
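A sketch of building a model/diffusion pair from the defaults above; the overrides are illustrative (a real run takes them from the argument parser), and constructing the UNet assumes unet.py and its dependencies import cleanly:

    from sdg.script_util import model_and_diffusion_defaults, create_model_and_diffusion

    opts = model_and_diffusion_defaults()
    opts.update(image_size=64, timestep_respacing='ddim50')
    model, diffusion = create_model_and_diffusion(**opts)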
/sdg/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import time
5 | import glob
6 |
7 | import blobfile as bf
8 | import torch as th
9 | import torch.distributed as dist
10 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
11 | from torch.optim import AdamW
12 |
13 | from .distributed import get_world_size, is_master
14 | from .distributed import master_only_print as print
15 | from .fp16_util import MixedPrecisionTrainer
16 | from .nn import update_ema
17 | from .resample import LossAwareSampler, UniformSampler
18 | from .logging import write_tb
19 | from .misc import to_cuda
20 | from . import logger
21 |
22 | # For ImageNet experiments, this was a good default value.
23 | # We found that the lg_loss_scale quickly climbed to
24 | # 20-21 within the first ~1K steps of training.
25 | INITIAL_LOG_LOSS_SCALE = 20.0
26 |
27 |
28 | class TrainLoop:
29 | def __init__(
30 | self,
31 | cfg,
32 | model,
33 | diffusion,
34 | data_train,
35 | tb_log,
36 | schedule_sampler=None,
37 | ):
38 | self.tb_log = tb_log
39 | self.model = model
40 | self.diffusion = diffusion
41 | self.data_train = data_train
42 | self.batch_size = cfg.batch_size
43 | self.microbatch = cfg.microbatch if cfg.microbatch > 0 else cfg.batch_size
44 | self.lr = cfg.lr
45 | self.ema_rate = (
46 | [cfg.ema_rate]
47 | if isinstance(cfg.ema_rate, float)
48 | else [float(x) for x in cfg.ema_rate.split(",")]
49 | )
50 | self.log_interval = cfg.log_interval
51 | self.save_interval = cfg.save_interval
52 | self.resume_checkpoint = cfg.resume_checkpoint
53 | self.use_fp16 = getattr(cfg, 'use_fp16', False)
54 | self.fp16_scale_growth = getattr(cfg, 'fp16_scale_growth', 1e-3)
55 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
56 | self.weight_decay = getattr(cfg, 'weight_decay', 0.0)
57 | self.lr_anneal_steps = getattr(cfg, 'lr_anneal_steps', 0)
58 | self.logdir = getattr(cfg, 'logdir', 'logs/debug')
59 | fp16_hyperparams = getattr(cfg, 'fp16_hyperparams', 'openai')
60 | if self.use_fp16:
61 | if fp16_hyperparams == 'pytorch':
62 | self.scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)
63 | elif fp16_hyperparams == 'openai':
64 | self.scaler = th.cuda.amp.GradScaler(init_scale=2**20 * 1.0, growth_factor=2**0.001, backoff_factor=0.5, growth_interval=1, enabled=True)
65 |
66 | self.step = 0
67 | self.resume_step = 0
68 | self.global_batch = self.batch_size * get_world_size()
69 |
70 | self.sync_cuda = True # th.cuda.is_available()
71 |
72 | self._load_and_sync_parameters()
73 |
74 | self.params = list(self.model.parameters())
75 | self.opt = AdamW(self.params, lr=self.lr, weight_decay=self.weight_decay)
76 | if self.resume_step:
77 | self._load_optimizer_state()
78 | # Model was resumed, either due to a restart or a checkpoint
79 | # being specified at the command line.
80 | self.ema_params = [
81 | self._load_ema_parameters(rate) for rate in self.ema_rate
82 | ]
83 | else:
84 | self.ema_params = [
85 | copy.deepcopy(self.params)
86 | for _ in range(len(self.ema_rate))
87 | ]
88 |
89 | self.use_ddp = th.cuda.is_available() and th.distributed.is_available() and dist.is_initialized()
90 | if self.use_ddp:
91 | self.ddp_model = DDP(
92 | self.model,
93 | device_ids=[cfg.local_rank],
94 | output_device=cfg.local_rank,
95 | bucket_cap_mb=128,
96 | broadcast_buffers=False,
97 | find_unused_parameters=False,
98 | )
99 | else:
100 |             print(
101 |                 "Single-GPU training without DistributedDataParallel."
102 |             )
103 | self.ddp_model = self.model
104 |
105 | def _load_and_sync_parameters(self):
106 | resume_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
107 |
108 | if resume_checkpoint:
109 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
110 | print(f"loading model from checkpoint: {resume_checkpoint}...")
111 | checkpoint = th.load(resume_checkpoint, map_location=lambda storage, loc: storage)
112 | self.model.load_state_dict(checkpoint)
113 |
114 | def _load_ema_parameters(self, rate):
115 | ema_params = copy.deepcopy(self.params)
116 |
117 | main_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
118 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
119 | if ema_checkpoint:
120 | print(f"loading EMA from checkpoint: {ema_checkpoint}...")
121 | state_dict = th.load(ema_checkpoint, map_location=lambda storage, loc: storage)
122 | ema_params = to_cuda([state_dict[name] for name, _ in self.model.named_parameters()])
123 |
124 | return ema_params
125 |
126 | def _load_optimizer_state(self):
127 | main_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
128 | opt_checkpoint = bf.join(
129 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
130 | )
131 | if bf.exists(opt_checkpoint):
132 | print(f"loading optimizer state from checkpoint: {opt_checkpoint}")
133 | checkpoint = th.load(opt_checkpoint, map_location=lambda storage, loc: storage)
134 | self.opt.load_state_dict(checkpoint)
135 |
136 | def _compute_norms(self, grad_scale=1.0):
137 | grad_norm = 0.0
138 | param_norm = 0.0
139 | for p in self.params:
140 | with th.no_grad():
141 | param_norm += th.norm(p, p=2, dtype=th.float32) ** 2
142 | if p.grad is not None:
143 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32) ** 2
144 | return th.sqrt(grad_norm) / grad_scale, th.sqrt(param_norm)
145 |
146 | def run_loop(self):
147 | time0 = time.time()
148 | while (
149 | not self.lr_anneal_steps
150 | or self.step + self.resume_step < self.lr_anneal_steps
151 | ):
152 | print('***step %d ' % (self.step + self.resume_step), end='')
153 | batch, cond = next(self.data_train)
154 | time2 = time.time()
155 | if self.use_fp16:
156 | self.run_step_amp(batch, cond)
157 | else:
158 | self.run_step(batch, cond)
159 | if self.step % self.log_interval == 0 and is_master():
160 | self.tb_log.flush()
161 | if self.step % self.save_interval == 0:
162 | self.save()
163 | # Run for a finite amount of time in integration tests.
164 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
165 | return
166 | num_samples = (self.step + self.resume_step + 1) * self.global_batch
167 | time1 = time.time()
168 | if is_master():
169 | self.tb_log.add_scalar('status/step', self.step+self.resume_step, self.step + self.resume_step)
170 | self.tb_log.add_scalar('status/samples', num_samples, self.step + self.resume_step)
171 | self.tb_log.add_scalar('time/time_per_iter', time1-time0, self.step + self.resume_step)
172 | self.tb_log.add_scalar('time/data_time_per_iter', time2-time0, self.step + self.resume_step)
173 | self.tb_log.add_scalar('time/model_time_per_iter', time1-time2, self.step + self.resume_step)
174 | self.tb_log.add_scalar('status/lr', self.lr, self.step + self.resume_step)
175 | print('lr: %f ' % self.lr, end='')
176 | print('samples: %d ' % num_samples, end='')
177 | print('data time: %f ' % (time2-time0), end='')
178 | print('model time: %f ' % (time1-time2), end='')
179 | print('')
180 | self.step += 1
181 | time0 = time1
182 | # Save the last checkpoint if it wasn't already saved.
183 | if (self.step - 1) % self.save_interval != 0:
184 | self.save()
185 |
186 | def run_step(self, batch, cond):
187 | self.forward_backward(batch, cond)
188 | self.opt.step()
189 | self._update_ema()
190 | self._anneal_lr()
191 |
192 | def forward_backward(self, batch, cond):
193 | self.opt.zero_grad()
194 | for i in range(0, batch.shape[0], self.microbatch):
195 | micro = to_cuda(batch[i : i + self.microbatch])
196 | micro_cond = {
197 | k: to_cuda(v[i : i + self.microbatch])
198 | for k, v in cond.items()
199 | }
200 | last_batch = (i + self.microbatch) >= batch.shape[0]
201 | t, weights = self.schedule_sampler.sample(micro.shape[0], 'cuda')
202 |
203 | compute_losses = functools.partial(
204 | self.diffusion.training_losses,
205 | self.ddp_model,
206 | micro,
207 | t,
208 | model_kwargs=micro_cond,
209 | )
210 |
211 | if last_batch or not self.use_ddp:
212 | losses = compute_losses()
213 | else:
214 | with self.ddp_model.no_sync():
215 | losses = compute_losses()
216 |
217 | if isinstance(self.schedule_sampler, LossAwareSampler):
218 | self.schedule_sampler.update_with_local_losses(
219 | t, losses["loss"].detach()
220 | )
221 |
222 | loss = (losses["loss"] * weights).mean()
223 | if self.step % 10 == 0:
224 | log_loss_dict(
225 | self.diffusion, t, {k: v * weights for k, v in losses.items()},
226 | self.tb_log, self.step + self.resume_step
227 | )
228 | loss.backward()
229 |
230 | def run_step_amp(self, batch, cond):
231 | self.forward_backward_amp(batch, cond)
232 | self.scaler.step(self.opt)
233 | self.scaler.update()
234 | self._update_ema()
235 | self._anneal_lr()
236 |
237 | def forward_backward_amp(self, batch, cond):
238 | self.opt.zero_grad()
239 | for i in range(0, batch.shape[0], self.microbatch):
240 | micro = to_cuda(batch[i : i + self.microbatch])
241 | micro_cond = {
242 | k: to_cuda(v[i : i + self.microbatch])
243 | for k, v in cond.items()
244 | }
245 | last_batch = (i + self.microbatch) >= batch.shape[0]
246 | t, weights = self.schedule_sampler.sample(micro.shape[0], 'cuda')
247 |
248 | with th.cuda.amp.autocast(True):
249 | compute_losses = functools.partial(
250 | self.diffusion.training_losses,
251 | self.ddp_model,
252 | micro,
253 | t,
254 | model_kwargs=micro_cond,
255 | )
256 |
257 | if last_batch or not self.use_ddp:
258 | losses = compute_losses()
259 | else:
260 | with self.ddp_model.no_sync():
261 | losses = compute_losses()
262 |
263 | if isinstance(self.schedule_sampler, LossAwareSampler):
264 | self.schedule_sampler.update_with_local_losses(
265 | t, losses["loss"].detach()
266 | )
267 |
268 | loss = (losses["loss"] * weights).mean()
269 | if self.step % 10 == 0:
270 | log_loss_dict(
271 | self.diffusion, t, {k: v * weights for k, v in losses.items()},
272 | self.tb_log, self.step + self.resume_step
273 | )
274 | self.scaler.scale(loss).backward()
275 |
276 |
277 |
278 | def _update_ema(self):
279 | for rate, params in zip(self.ema_rate, self.ema_params):
280 | update_ema(params, self.params, rate=rate)
281 |
282 | def _anneal_lr(self):
283 | if not self.lr_anneal_steps:
284 | return
285 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
286 | lr = self.lr * (1 - frac_done)
287 | for param_group in self.opt.param_groups:
288 | param_group["lr"] = lr
289 |
290 | def log_step(self):
291 | logger.logkv("step", self.step + self.resume_step)
292 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
293 |
294 |     def save(self):
295 |         def save_checkpoint(rate, params):
296 |             print(f"saving model {rate}...")
297 |             if is_master():
298 |                 if not rate:
299 |                     filename, state_dict = f"model{(self.step+self.resume_step):06d}.pt", self.model.state_dict()
300 |                 else:
301 |                     filename, state_dict = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt", {name: p for (name, _), p in zip(self.model.named_parameters(), params)}
302 |                 with bf.BlobFile(bf.join(self.logdir, filename), "wb") as f:
303 |                     th.save(state_dict, f)
304 | 
305 |         save_checkpoint(0, self.params)
306 |         for rate, params in zip(self.ema_rate, self.ema_params):
307 |             save_checkpoint(rate, params)
308 |
309 | if is_master():
310 | with bf.BlobFile(
311 | bf.join(self.logdir, f"opt{(self.step+self.resume_step):06d}.pt"),
312 | "wb",
313 | ) as f:
314 | th.save(self.opt.state_dict(), f)
315 |
316 | if is_master() and self.use_fp16:
317 | with bf.BlobFile(
318 | bf.join(self.logdir, f"scaler{(self.step+self.resume_step):06d}.pt"),
319 | "wb",
320 | ) as f:
321 | th.save(self.scaler.state_dict(), f)
322 |
323 |
324 | def parse_resume_step_from_filename(filename):
325 | """
326 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
327 | checkpoint's number of steps.
328 | """
329 | split = filename.split("model")
330 | if len(split) < 2:
331 | return 0
332 | split1 = split[-1].split(".")[0]
333 | try:
334 | return int(split1)
335 | except ValueError:
336 | return 0
337 |
338 |
339 | def get_blob_logdir():
340 | # You can change this to be a separate path to save checkpoints to
341 | # a blobstore or some external drive.
342 | return logger.get_dir()
343 |
344 |
345 | def find_resume_checkpoint(logdir):
346 | # On your infrastructure, you may want to override this to automatically
347 | # discover the latest checkpoint on your blob storage, etc.
348 | models = sorted(glob.glob(os.path.join(logdir, 'model*.pt')))
349 | if len(models) >= 1:
350 | return models[-1]
351 | else:
352 | return None
353 |
354 |
355 | def find_ema_checkpoint(main_checkpoint, step, rate):
356 | if main_checkpoint is None:
357 | return None
358 | filename = f"ema_{rate}_{(step):06d}.pt"
359 | path = bf.join(bf.dirname(main_checkpoint), filename)
360 | if bf.exists(path):
361 | return path
362 | return None
363 |
364 |
365 | def log_loss_dict(diffusion, ts, losses, tb_log, step, prefix='loss'):
366 | for key, values in losses.items():
367 | write_tb(tb_log, f"{prefix}/{key}", values, step)
368 | quartile_list = [[] for cnt in range(4)]
369 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
370 | quartile = int(4 * sub_t / diffusion.num_timesteps)
371 | quartile_list[quartile].append(sub_loss)
372 | for cnt in range(4):
373 | if len(quartile_list[cnt]) != 0:
374 | write_tb(tb_log, f"{prefix}/{key}_q{cnt}", sum(quartile_list[cnt])/len(quartile_list[cnt]), step)
375 | else:
376 | write_tb(tb_log, f"{prefix}/{key}_q{cnt}", 0.0, step)
377 |
--------------------------------------------------------------------------------
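TrainLoop reads everything it needs from a single cfg object (batch_size, microbatch, lr, ema_rate, the log/save intervals, logdir, plus the optional fp16/weight-decay/anneal settings) together with an infinite data iterator yielding (batch, cond) pairs. The following is a minimal single-GPU sketch, not the repository's actual training entry point (that presumably lives in scripts/clip_finetune_noise_nolabel.py): TinyEpsModel and fake_data are hypothetical stand-ins, tb_log is assumed to be a TensorBoard SummaryWriter (only add_scalar/flush are used), and the sketch assumes a CUDA device and that the sdg.distributed helpers fall back gracefully when torch.distributed is not initialized. With lr_anneal_steps unset, run_loop trains until interrupted (or returns early at the first save interval when DIFFUSION_TRAINING_TEST is exported).

    import torch as th
    from argparse import Namespace
    from torch.utils.tensorboard import SummaryWriter

    from sdg.script_util import create_gaussian_diffusion
    from sdg.train_util import TrainLoop


    class TinyEpsModel(th.nn.Module):
        """Toy stand-in for the real UNet: predicts epsilon with a single conv."""

        def __init__(self):
            super().__init__()
            self.conv = th.nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x, t, **kwargs):
            return self.conv(x)


    def fake_data(batch_size, image_size=64):
        """Endless (images, cond_dict) iterator, the shape run_loop expects."""
        while True:
            yield th.randn(batch_size, 3, image_size, image_size), {}


    cfg = Namespace(
        batch_size=4, microbatch=-1, lr=1e-4, ema_rate="0.9999",
        log_interval=100, save_interval=10000, resume_checkpoint="",
        use_fp16=False, logdir="logs/debug", local_rank=0,
    )

    TrainLoop(
        cfg=cfg,
        model=TinyEpsModel().cuda(),
        diffusion=create_gaussian_diffusion(steps=1000),
        data_train=fake_data(cfg.batch_size),
        tb_log=SummaryWriter(cfg.logdir),
    ).run_loop()

Swapping in the real UNet, dataset loader, and a LossAwareSampler only changes the constructor arguments; the loop itself is unchanged.
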
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="sdg",
5 | py_modules=["sdg"],
6 | install_requires=["blobfile>=1.0.5", "tqdm"],
7 | )
8 |
--------------------------------------------------------------------------------
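Note that install_requires above covers only blobfile and tqdm; the remaining dependencies (PyTorch among them, given the imports in sdg/) are presumably pinned in the top-level requirements.txt. With those installed, an editable install (pip install -e . from the repository root) is the usual way to make the sdg package importable from the scripts/ directory.
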
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/teaser.png
--------------------------------------------------------------------------------