├── .DS_Store ├── .gitignore ├── README.html ├── README.md ├── ref ├── .DS_Store ├── bedroom_instructions.txt ├── ref_bedroom │ ├── .DS_Store │ ├── bedroom_train_0000125.png │ ├── bedroom_train_0000126.png │ ├── bedroom_train_0000192.png │ ├── bedroom_train_0000274.png │ ├── bedroom_train_0001285.png │ ├── bedroom_train_0002219.png │ ├── bedroom_train_0003606.png │ └── bedroom_train_0008929.png ├── ref_cat │ ├── .DS_Store │ ├── cat_0000024.png │ ├── cat_0000325.png │ ├── cat_0000383.png │ ├── cat_0003829.png │ └── cat_0007358.png ├── ref_ffhq │ ├── 69005.png │ ├── 69019.png │ ├── 69099.png │ ├── 69109.png │ ├── 69708.png │ ├── 69845.png │ ├── guide2.jpeg │ └── style4.png ├── ref_horse │ ├── horse_0000023.png │ ├── horse_0000052.png │ ├── horse_0000338.png │ ├── horse_0004057.png │ └── horse_0005581.png └── ref_style │ ├── .DS_Store │ ├── candy.jpeg │ ├── edtaonisl.jpeg │ ├── starnight.jpg │ ├── sunflowers.jpeg │ └── wave.jpeg ├── requirements.txt ├── scripts ├── clip_finetune_noise_nolabel.py └── sample.py ├── sdg ├── __init__.py ├── clip_guidance.py ├── distributed.py ├── fp16_util.py ├── gaussian_diffusion.py ├── gpu_affinity.py ├── guidance.py ├── image_datasets.py ├── logger.py ├── logging.py ├── losses.py ├── misc.py ├── nn.py ├── parser.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── setup.py └── teaser.png /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | *.pt 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # More Control for Free! Image Synthesis with Semantic Diffusion Guidance 2 | 3 | This is the codebase for [More Control for Free! Image Synthesis with Semantic Diffusion Guidance](http://arxiv.org/abs/2112.05744). 4 | 5 | This repository is based on [openai/guided-diffusion](https://github.com/openai/guided-diffusion), with modifications for semantic guidance. 6 | 7 | ![results](teaser.png) 8 | 9 | ### Installation 10 | 11 | ```bash 12 | git clone https://github.com/xh-liu/SDG_code 13 | cd SDG_code 14 | pip install -r requirements.txt 15 | pip install -e . 16 | ``` 17 | 18 | ### Download pre-trained models 19 | 20 | The pretrained unconditional diffusion models are from [openai/guided-diffusion](https://github.com/openai/guided-diffusion) and [jychoi118/ilvr_adm](https://github.com/jychoi118/ilvr_adm). 21 | 22 | * LSUN bedroom unconditional diffusion: [lsun_bedroom.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_bedroom.pt) 23 | * LSUN cat unconditional diffusion: [lsun_cat.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_cat.pt) 24 | * LSUN horse unconditional diffusion: [lsun_horse.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_horse.pt) 25 | * LSUN horse (no dropout): [lsun_horse_nodropout.pt](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/lsun_horse_nodropout.pt) 26 | * FFHQ unconditional diffusion: [ffhq.pt](https://onedrive.live.com/?authkey=%21AOIJGI8FUQXvFf8&id=72419B431C262344%21103807&cid=72419B431C262344) 27 | 28 | We finetune the CLIP image encoders on noisy images for semantic guidance. We provide the checkpoints as follows: 29 | 30 | * FFHQ semantic guidance: [clip_ffhq.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EQbpgLeWnZhBhNzvFXgn26IBhsveoX3V57ZoQdSsLnwrjA?e=1K0qwv) 31 | * LSUN bedroom semantic guidance: [clip_bedroom.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EfVpSVSjAhlEpsBCxSwkBnQByUvgNZqr38bxnG6bDHuOZQ?e=bOgCZT) 32 | * LSUN cat semantic guidance: [clip_cat.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EQdhKa0Jte9FtaB21kRDbT0B7tI3SoZewOack9DNe8s0LQ?e=zILyOa) 33 | * LSUN horse semantic guidance: [clip_horse.pt](https://hkuhk-my.sharepoint.com/:u:/g/personal/xihuiliu_hku_hk/EWqcgeq4kkpCgi3S9WcsOjABScZg-gT-aSnaZyh1uHIxNg?e=qDUJWK) 34 | 35 | ### Sampling with semantic diffusion guidance 36 | 37 | To sample from these models, you can use `scripts/sample.py`.
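Internally, `scripts/sample.py` turns the model, sampling, and guidance flags described below into a `cond_fn` that is passed to the diffusion sampler: at every denoising step it computes the gradient of a CLIP-based loss with respect to the current noisy sample and uses it to steer generation toward the text instruction and/or the reference image. The following is a condensed sketch of that logic, paraphrased from the script; `args`, `model`, `diffusion`, and `model_kwargs` are placeholders for objects the script builds from the flags:

```python
import torch as th
import clip

from sdg.clip_guidance import CLIP_gd
from sdg.guidance import image_loss, text_loss

# Assumed to already exist (see scripts/sample.py): `args` holds the parsed
# flags, `model` and `diffusion` come from create_model_and_diffusion(...),
# and `model_kwargs` carries the tokenized instruction ('y') and the
# reference image ('ref_img').
clip_model, _ = clip.load('RN50x16', device='cuda')  # frozen CLIP text encoder
clip_ft = CLIP_gd(args)                               # CLIP image encoder finetuned on noisy images
clip_ft.load_state_dict(th.load(args.clip_path, map_location='cpu'))
clip_ft.eval()
clip_ft = clip_ft.cuda()

def cond_fn_sdg(x, t, y, **kwargs):
    with th.no_grad():
        if args.text_weight != 0:
            text_features = clip_model.encode_text(y)
        if args.image_weight != 0:
            ref_noised = diffusion.q_sample(kwargs['ref_img'], t, tscale1000=True)
            ref_features = clip_ft.encode_image_list(ref_noised, t)
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        image_features = clip_ft.encode_image_list(x_in, t)
        loss_text = text_loss(image_features, text_features, args) if args.text_weight != 0 else 0
        loss_img = image_loss(image_features, ref_features, args) if args.image_weight != 0 else 0
        total = loss_text * args.text_weight + loss_img * args.image_weight
        # Gradient of the guidance loss w.r.t. the noisy sample x_t.
        return th.autograd.grad(total.sum(), x_in)[0]

sample = diffusion.p_sample_loop(
    model,
    (args.batch_size, 3, args.image_size, args.image_size),
    clip_denoised=args.clip_denoised,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn_sdg if (args.text_weight or args.image_weight) else None,
    device='cuda',
)
```

The `--image_loss` flag used in the commands below switches the image term between its `semantic` and `style` variants.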
38 | Here, we provide flags for sampling from all of these models. 39 | We assume that you have downloaded the relevant model checkpoints into a folder called `models/`. 40 | 41 | For LSUN cat, LSUN horse, and LSUN bedroom, the model flags are defined as follows (set `--model_path` to the checkpoint of the dataset you want to sample from, e.g. `models/lsun_cat.pt` or `models/lsun_horse.pt`): 42 | 43 | ```bash 44 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --model_path models/lsun_bedroom.pt" 45 | ``` 46 | 47 | For the FFHQ dataset, the model flags are defined as: 48 | ```bash 49 | MODEL_FLAGS="--attention_resolutions 16 --class_cond False --diffusion_steps 1000 --dropout 0.0 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 1 --resblock_updown True --use_fp16 False --use_scale_shift_norm True --model_path models/ffhq_10m.pt" 50 | ``` 51 | 52 | Sampling flags: 53 | 54 | ```bash 55 | SAMPLE_FLAGS="--batch_size 8 --timestep_respacing 100" 56 | ``` 57 | 58 | Sampling with image content (semantics) guidance: 59 | 60 | ```bash 61 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 0 --image_weight 100 --image_loss semantic --clip_path models/CLIP_bedroom.pt" 62 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS 63 | ``` 64 | 65 | Sampling with image style guidance (the `ref/ref_style` folder contains example style reference images that can also be used as `--data_dir`): 66 | ```bash 67 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 0 --image_weight 100 --image_loss style --clip_path models/CLIP_bedroom.pt" 68 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_style_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS 69 | ``` 70 | 71 | Sampling with language guidance: 72 | ```bash 73 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 160 --image_weight 0 --text_instruction_file ref/bedroom_instructions.txt --clip_path models/CLIP_bedroom.pt" 74 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_language_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS 75 | ``` 76 | 77 | Sampling with both language and image guidance: 78 | ```bash 79 | GUIDANCE_FLAGS="--data_dir ref/ref_bedroom --text_weight 160 --image_weight 100 --image_loss semantic --text_instruction_file ref/bedroom_instructions.txt --clip_path models/CLIP_bedroom.pt" 80 | CUDA_VISIBLE_DEVICES=0 python -u scripts/sample.py --exp_name bedroom_image_language_guidance --single_gpu $MODEL_FLAGS $SAMPLE_FLAGS $GUIDANCE_FLAGS 81 | ``` 82 | You may need to adjust `--text_weight` and `--image_weight` to balance guidance strength against the visual quality of the generated samples. 83 | 84 | ### Citation 85 | If you find our work useful for your research, please cite our paper. 86 | ``` 87 | @inproceedings{liu2023more, 88 | title={More control for free!
image synthesis with semantic diffusion guidance}, 89 | author={Liu, Xihui and Park, Dong Huk and Azadi, Samaneh and Zhang, Gong and Chopikyan, Arman and Hu, Yuxiao and Shi, Humphrey and Rohrbach, Anna and Darrell, Trevor}, 90 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 91 | year={2023} 92 | } 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /ref/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/.DS_Store -------------------------------------------------------------------------------- /ref/bedroom_instructions.txt: -------------------------------------------------------------------------------- 1 | A photo of a bedroom with a painting on the wall. 2 | -------------------------------------------------------------------------------- /ref/ref_bedroom/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/.DS_Store -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0000125.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000125.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0000126.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000126.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0000192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000192.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0000274.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0000274.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0001285.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0001285.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0002219.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0002219.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0003606.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0003606.png -------------------------------------------------------------------------------- /ref/ref_bedroom/bedroom_train_0008929.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_bedroom/bedroom_train_0008929.png -------------------------------------------------------------------------------- /ref/ref_cat/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/.DS_Store -------------------------------------------------------------------------------- /ref/ref_cat/cat_0000024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000024.png -------------------------------------------------------------------------------- /ref/ref_cat/cat_0000325.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000325.png -------------------------------------------------------------------------------- /ref/ref_cat/cat_0000383.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0000383.png -------------------------------------------------------------------------------- /ref/ref_cat/cat_0003829.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0003829.png -------------------------------------------------------------------------------- /ref/ref_cat/cat_0007358.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_cat/cat_0007358.png -------------------------------------------------------------------------------- /ref/ref_ffhq/69005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69005.png -------------------------------------------------------------------------------- /ref/ref_ffhq/69019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69019.png -------------------------------------------------------------------------------- /ref/ref_ffhq/69099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69099.png -------------------------------------------------------------------------------- /ref/ref_ffhq/69109.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69109.png 
-------------------------------------------------------------------------------- /ref/ref_ffhq/69708.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69708.png -------------------------------------------------------------------------------- /ref/ref_ffhq/69845.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/69845.png -------------------------------------------------------------------------------- /ref/ref_ffhq/guide2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/guide2.jpeg -------------------------------------------------------------------------------- /ref/ref_ffhq/style4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_ffhq/style4.png -------------------------------------------------------------------------------- /ref/ref_horse/horse_0000023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000023.png -------------------------------------------------------------------------------- /ref/ref_horse/horse_0000052.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000052.png -------------------------------------------------------------------------------- /ref/ref_horse/horse_0000338.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0000338.png -------------------------------------------------------------------------------- /ref/ref_horse/horse_0004057.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0004057.png -------------------------------------------------------------------------------- /ref/ref_horse/horse_0005581.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_horse/horse_0005581.png -------------------------------------------------------------------------------- /ref/ref_style/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/.DS_Store -------------------------------------------------------------------------------- /ref/ref_style/candy.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/candy.jpeg -------------------------------------------------------------------------------- /ref/ref_style/edtaonisl.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/edtaonisl.jpeg -------------------------------------------------------------------------------- /ref/ref_style/starnight.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/starnight.jpg -------------------------------------------------------------------------------- /ref/ref_style/sunflowers.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/sunflowers.jpeg -------------------------------------------------------------------------------- /ref/ref_style/wave.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/ref/ref_style/wave.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | git+https://github.com/openai/CLIP.git 5 | pynvml 6 | tensorboard 7 | -------------------------------------------------------------------------------- /scripts/clip_finetune_noise_nolabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Finetune a noised CLIP image encoder on the target dataset without text annotations. 3 | """ 4 | 5 | import argparse 6 | import os 7 | import time 8 | import random 9 | import blobfile as bf 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | import torch.nn.functional as F 14 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 15 | from torch.optim import AdamW 16 | from torch.utils import tensorboard 17 | 18 | from sdg.parser import create_argparser 19 | from sdg.logging import init_logging, make_logging_dir 20 | from sdg.distributed import init_dist, is_master, get_world_size 21 | from sdg.distributed import master_only_print as print 22 | from sdg.distributed import dist_all_gather_tensor, all_gather_with_gradient 23 | from sdg.gpu_affinity import set_affinity 24 | from sdg.image_datasets import load_data 25 | from sdg.resample import create_named_schedule_sampler 26 | from sdg.script_util import ( 27 | add_dict_to_argparser, 28 | args_to_dict, 29 | classifier_and_diffusion_defaults, 30 | create_clip_and_diffusion, 31 | ) 32 | from sdg.train_util import parse_resume_step_from_filename, log_loss_dict 33 | from sdg.misc import set_random_seed 34 | from sdg.misc import to_cuda 35 | 36 | 37 | def main(): 38 | args = create_argparser().parse_args() 39 | 40 | set_affinity(args.local_rank) 41 | if args.randomized_seed: 42 | args.seed = random.randint(0, 10000) 43 | set_random_seed(args.seed, by_rank=True) 44 | if not args.single_gpu: 45 | init_dist(args.local_rank) 46 | tb_log = None 47 | args.logdir = init_logging(args.exp_name) 48 | if is_master(): 49 | tb_log = make_logging_dir(args.logdir) 50 | world_size = get_world_size() 51 | 52 | print("creating model and diffusion...") 53 | model, diffusion = create_clip_and_diffusion( 54 | args, **args_to_dict(args, classifier_and_diffusion_defaults().keys()) 55 | ) 56 | model.to('cuda') 57 | if
args.noised: 58 | schedule_sampler = create_named_schedule_sampler( 59 | args.schedule_sampler, diffusion 60 | ) 61 | 62 | resume_step = 0 63 | if args.resume_checkpoint: 64 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint) 65 | print(f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step") 66 | model.load_state_dict(th.load(args.resume_checkpoint, map_location=lambda storage, loc: storage)) 67 | model.to('cuda') 68 | 69 | if args.use_fp16: 70 | if args.fp16_hyperparams == 'pytorch': 71 | scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True) 72 | elif args.fp16_hyperparams == 'openai': 73 | scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2**0.001, backoff_factor=0.5, growth_interval=1, enabled=True) 74 | 75 | 76 | use_ddp = th.cuda.is_available() and th.distributed.is_available() and dist.is_initialized() 77 | if use_ddp: 78 | ddp_model = DDP( 79 | model, 80 | device_ids=[args.local_rank], 81 | output_device=args.local_rank, 82 | broadcast_buffers=False, 83 | bucket_cap_mb=128, 84 | find_unused_parameters=False, 85 | ) 86 | else: 87 | print("Single GPU Training without DistributedDataParallel. ") 88 | ddp_model = model 89 | 90 | print("creating data loader...") 91 | args.return_text = False 92 | args.return_class = False 93 | args.return_yall = False 94 | train_dataloader = load_data(args) 95 | 96 | print(f"creating optimizer...") 97 | opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=args.lr, weight_decay=args.weight_decay) 98 | if args.resume_checkpoint: 99 | opt_checkpoint = bf.join( 100 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt" 101 | ) 102 | if bf.exists(opt_checkpoint): 103 | print(f"loading optimizer state from checkpoint: {opt_checkpoint}") 104 | opt.load_state_dict(th.load(opt_checkpoint, map_location=lambda storage, loc: storage)) 105 | else: 106 | print('Warning: opt checkpoint %s not found' % opt_checkpoint) 107 | sca_checkpoint = bf.join( 108 | bf.dirname(args.resume_checkpoint), f"scaler{resume_step:06}.pt" 109 | ) 110 | if bf.exists(sca_checkpoint): 111 | print(f"loading optimizer state from checkpoint: {sca_checkpoint}") 112 | scaler.load_state_dict(th.load(sca_checkpoint, map_location=lambda storage, loc: storage)) 113 | else: 114 | print('Warning: opt checkpoint %s not found' % opt_checkpoint) 115 | 116 | 117 | 118 | print("training classifier model...") 119 | 120 | import clip 121 | clip_pretrained, _ = clip.load('RN50x16', jit=False) 122 | clip_pretrained = clip_pretrained.float() 123 | clip_pretrained.eval() 124 | clip_pretrained = clip_pretrained.cuda() 125 | 126 | def forward_backward_log(data_loader, prefix="train", step=0): 127 | batch, batch2 = next(data_loader) 128 | 129 | batch = to_cuda(batch) 130 | batch2 = to_cuda(batch2) 131 | # Noisy images 132 | if args.noised: 133 | t, _ = schedule_sampler.sample(batch.shape[0], 'cuda') 134 | batch = diffusion.q_sample(batch, t) 135 | else: 136 | t = th.zeros(batch.shape[0], dtype=th.long, device='cuda') 137 | 138 | ground_truth = th.arange(batch.shape[0] * world_size, dtype=th.long, device='cuda') 139 | for i, (sub_batch, sub_batch2, sub_t) in enumerate( 140 | split_microbatches(args.microbatch, batch, batch2, t) 141 | ): 142 | with th.cuda.amp.autocast(args.use_fp16): 143 | with th.no_grad(): 144 | sub_labels = clip_pretrained.encode_image(sub_batch2) 145 | if args.structure == 'classifier': 146 | image_features = ddp_model(sub_batch, 
timesteps=sub_t) 147 | text_features = sub_labels 148 | else: 149 | image_features, text_features = ddp_model(sub_batch, sub_labels, sub_t) 150 | 151 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 152 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 153 | losses = {} 154 | 155 | image_features = all_gather_with_gradient(image_features) 156 | text_features = all_gather_with_gradient(text_features) 157 | logits_per_image = 100.0 * image_features @ text_features.t() 158 | logits_per_text = logits_per_image.t() 159 | 160 | loss_i2t = F.cross_entropy(logits_per_image, ground_truth, reduction='none') 161 | loss_t2i = F.cross_entropy(logits_per_text, ground_truth, reduction='none') 162 | loss = loss_i2t + loss_t2i 163 | losses[f"{prefix}_loss_i2t"] = loss_i2t.detach() 164 | losses[f"{prefix}_loss_t2i"] = loss_t2i.detach() 165 | 166 | 167 | losses[f"{prefix}_loss"] = loss.detach() 168 | 169 | log_loss_dict(diffusion, sub_t, losses, tb_log, step) 170 | del losses 171 | loss = loss.mean() 172 | if loss.requires_grad: 173 | if i == 0: 174 | opt.zero_grad() 175 | if args.use_fp16: 176 | scaler.scale(loss).backward() 177 | else: 178 | loss.backward() 179 | 180 | global_batch = args.batch_size * world_size 181 | for step in range(args.iterations - resume_step): 182 | print("***step %d " % (step + resume_step), end='') 183 | num_samples = (step + resume_step + 1) * global_batch 184 | print('samples: %d ' % num_samples, end='') 185 | if args.anneal_lr: 186 | set_annealed_lr(opt, args.lr, args.lr_anneal_steps, step + resume_step) 187 | if is_master(): 188 | tb_log.add_scalar('status/step', step + resume_step, step) 189 | tb_log.add_scalar('status/samples', num_samples, step) 190 | tb_log.add_scalar('status/lr', args.lr, step) 191 | forward_backward_log(train_dataloader, step=step) 192 | if args.use_fp16: 193 | scaler.step(opt) 194 | scaler.update() 195 | else: 196 | opt.step() 197 | if not step % args.log_interval and is_master(): 198 | tb_log.flush() 199 | if not (step + resume_step) % args.save_interval: 200 | print("saving model...") 201 | save_model(args.logdir, model, opt, scaler, step + resume_step) 202 | 203 | print("saving model...") 204 | if is_master(): 205 | save_model(args.logdir, model, opt, step + resume_step) 206 | tb_log.close() 207 | 208 | 209 | def set_annealed_lr(opt, base_lr, anneal_steps, current_steps): 210 | lr_decay_cnt = current_steps // anneal_steps 211 | lr = base_lr * 0.1**lr_decay_cnt 212 | for param_group in opt.param_groups: 213 | param_group["lr"] = lr 214 | 215 | 216 | def save_model(logdir, model, opt, scaler, step): 217 | if is_master(): 218 | th.save(model.state_dict(), os.path.join(logdir, f"model{step:06d}.pt")) 219 | th.save(opt.state_dict(), os.path.join(logdir, f"opt{step:06d}.pt")) 220 | th.save(scaler.state_dict(), os.path.join(logdir, f"scaler{step:06d}.pt")) 221 | 222 | 223 | def compute_top_k(logits, labels, k, reduction="mean"): 224 | _, top_ks = th.topk(logits, k, dim=-1) 225 | if reduction == "mean": 226 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 227 | elif reduction == "none": 228 | return (top_ks == labels[:, None]).float().sum(dim=-1) 229 | 230 | 231 | def split_microbatches(microbatch, *args): 232 | bs = len(args[0]) 233 | if microbatch == -1 or microbatch >= bs: 234 | yield tuple(args) 235 | else: 236 | for i in range(0, bs, microbatch): 237 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args) 238 | 239 | 240 | if __name__ == "__main__": 241 | 
main() 242 | -------------------------------------------------------------------------------- /scripts/sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import blobfile as bf 6 | import numpy as np 7 | import torch as th 8 | import torch.distributed as dist 9 | import torchvision 10 | import torch.nn.functional as F 11 | 12 | from sdg.parser import create_argparser 13 | from sdg.logging import init_logging, make_logging_dir 14 | from sdg.distributed import master_only_print as print 15 | from sdg.distributed import is_master, init_dist, get_world_size 16 | from sdg.gpu_affinity import set_affinity 17 | from sdg.logging import init_logging, make_logging_dir 18 | from sdg.script_util import ( 19 | model_and_diffusion_defaults, 20 | create_model_and_diffusion, 21 | args_to_dict, 22 | add_dict_to_argparser, 23 | ) 24 | from sdg.clip_guidance import CLIP_gd 25 | from sdg.image_datasets import load_ref_data 26 | from sdg.misc import set_random_seed 27 | from sdg.guidance import image_loss, text_loss 28 | from sdg.image_datasets import _list_image_files_recursively 29 | from torchvision import utils 30 | import math 31 | import clip 32 | 33 | 34 | def main(): 35 | time0 = time.time() 36 | args = create_argparser().parse_args() 37 | set_affinity(args.local_rank) 38 | if args.randomized_seed: 39 | args.seed = random.randint(0, 10000) 40 | set_random_seed(args.seed, by_rank=True) 41 | if not args.single_gpu: 42 | init_dist(args.local_rank) 43 | 44 | tb_log = None 45 | args.logdir = init_logging(args.exp_name, root_dir='results', timestamp=False) 46 | if is_master(): 47 | tb_log = make_logging_dir(args.logdir, no_tb=True) 48 | 49 | print("creating model...") 50 | model, diffusion = create_model_and_diffusion( 51 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 52 | ) 53 | model.load_state_dict( 54 | th.load(args.model_path, map_location="cpu") 55 | ) 56 | model.to('cuda') 57 | model.eval() 58 | 59 | clip_model, preprocess = clip.load('RN50x16', device='cuda') 60 | if args.text_weight == 0: 61 | instructions = [""] 62 | else: 63 | with open(args.text_instruction_file, 'r') as f: 64 | instructions = f.readlines() 65 | instructions = [tmp.replace('\n', '') for tmp in instructions] 66 | # define image list 67 | if args.image_weight == 0: 68 | imgs = [None] 69 | else: 70 | imgs = _list_image_files_recursively(args.data_dir) 71 | imgs = sorted(imgs) 72 | clip_ft = CLIP_gd(args) 73 | clip_ft.load_state_dict(th.load(args.clip_path, map_location='cpu')) 74 | clip_ft.eval() 75 | clip_ft = clip_ft.cuda() 76 | 77 | def cond_fn_sdg(x, t, y, **kwargs): 78 | assert y is not None 79 | with th.no_grad(): 80 | if args.text_weight != 0: 81 | text_features = clip_model.encode_text(y) 82 | if args.image_weight != 0: 83 | target_img_noised = diffusion.q_sample(kwargs['ref_img'], t, tscale1000=True) 84 | target_img_features = clip_ft.encode_image_list(target_img_noised, t) 85 | with th.enable_grad(): 86 | x_in = x.detach().requires_grad_(True) 87 | image_features = clip_ft.encode_image_list(x_in, t) 88 | if args.text_weight != 0: 89 | loss_text = text_loss(image_features, text_features, args) 90 | else: 91 | loss_text = 0 92 | if args.image_weight != 0: 93 | loss_img = image_loss(image_features, target_img_features, args) 94 | else: 95 | loss_img = 0 96 | total_guidance = loss_text * args.text_weight + loss_img * args.image_weight 97 | 98 | return th.autograd.grad(total_guidance.sum(), x_in)[0] 99 | 100 | 101 | 
print("creating samples...") 102 | count = 0 103 | for img_cnt in range(len(imgs)): 104 | if imgs[img_cnt] is not None: 105 | print("loading data...") 106 | model_kwargs = load_ref_data(args, imgs[img_cnt]) 107 | else: 108 | model_kwargs = {} 109 | 110 | for ins_cnt in range(len(instructions)): 111 | instruction = instructions[ins_cnt] 112 | text = clip.tokenize([instruction for cnt in range(args.batch_size)]).to('cuda') 113 | model_kwargs['y'] = text 114 | model_kwargs = {k: v.to('cuda') for k, v in model_kwargs.items()} 115 | if args.image_weight == 0 and args.text_weight == 0: 116 | cond_fn = None 117 | else: 118 | cond_fn = cond_fn_sdg 119 | with th.cuda.amp.autocast(True): 120 | sample = diffusion.p_sample_loop( 121 | model, 122 | (args.batch_size, 3, args.image_size, args.image_size), 123 | noise=None, 124 | clip_denoised=args.clip_denoised, 125 | model_kwargs=model_kwargs, 126 | cond_fn=cond_fn, 127 | device='cuda', 128 | ) 129 | 130 | for i in range(args.batch_size): 131 | if args.text_weight == 0: 132 | out_folder = '%05d_%s' % (img_cnt, os.path.basename(imgs[img_cnt]).split('.')[0]) 133 | elif args.image_weight == 0: 134 | out_folder = '%05d_%s' % (ins_cnt, instructions[ins_cnt]) 135 | else: 136 | out_folder = '%05d_%05d_%s_%s' % (img_cnt, ins_cnt, os.path.basename(imgs[img_cnt]).split('.')[0], instructions[ins_cnt]) 137 | 138 | out_path = os.path.join(args.logdir, out_folder, 139 | f"{str(count * args.batch_size + i).zfill(5)}.png") 140 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 141 | utils.save_image( 142 | sample[i].unsqueeze(0), 143 | out_path, 144 | nrow=1, 145 | normalize=True, 146 | range=(-1, 1), 147 | ) 148 | 149 | count += 1 150 | print(f"created {count * args.batch_size} samples") 151 | print(time.time() - time0) 152 | 153 | print("sampling complete") 154 | 155 | 156 | if __name__ == "__main__": 157 | main() 158 | -------------------------------------------------------------------------------- /sdg/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "More Control for Free! Image Synthesis with Semantic Diffusion Guidance". 
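It bundles the diffusion model and sampler (`gaussian_diffusion.py`, `respace.py`, `unet.py`), the noise-aware CLIP guidance modules (`clip_guidance.py`, `guidance.py`), and the training, data, and distributed-execution utilities (`train_util.py`, `image_datasets.py`, `distributed.py`, `fp16_util.py`).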
3 | """ 4 | -------------------------------------------------------------------------------- /sdg/clip_guidance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | from clip import clip 6 | from .nn import timestep_embedding 7 | 8 | 9 | class CLIP_gd(nn.Module): 10 | def __init__(self, args): 11 | super().__init__() 12 | self.finetune_clip_layer = getattr(args, 'finetune_clip_layer', 'all') 13 | clip_model, preprocess = clip.load('RN50x16', jit=False) 14 | self.preprocess = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 15 | clip_model = clip_model.float() 16 | 17 | # visual 18 | self.visual_frozen = nn.Sequential( 19 | clip_model.visual.conv1, 20 | clip_model.visual.bn1, 21 | clip_model.visual.relu1, 22 | clip_model.visual.conv2, 23 | clip_model.visual.bn2, 24 | clip_model.visual.relu2, 25 | clip_model.visual.conv3, 26 | clip_model.visual.bn3, 27 | clip_model.visual.relu3, 28 | clip_model.visual.avgpool, 29 | clip_model.visual.layer1, 30 | clip_model.visual.layer2, 31 | clip_model.visual.layer3, 32 | ) 33 | self.attn_pool = clip_model.visual.attnpool 34 | self.layer4 = clip_model.visual.layer4 35 | 36 | self.attn_resolution = args.image_size // 32 37 | self.image_size = args.image_size 38 | self.after_load() 39 | self.define_finetune() 40 | 41 | def after_load(self): 42 | self.attn_pool.positional_embedding = nn.Parameter(torch.randn(self.attn_resolution ** 2 + 1, 3072) / 3072 ** 0.5) 43 | self.time_embed = nn.Sequential( 44 | nn.Linear(128, 512), 45 | nn.ReLU(), 46 | nn.Linear(512, 512), 47 | ) 48 | emb_dim = [48, 48, 96] 49 | tmp1 = [96, 96, 384] 50 | for cnt in range(6): 51 | emb_dim.extend(tmp1) 52 | tmp2 = [192, 192, 768] 53 | for cnt in range(8): 54 | emb_dim.extend(tmp2) 55 | tmp3 = [384, 384, 1536] 56 | for cnt in range(18): 57 | emb_dim.extend(tmp3) 58 | tmp4 = [768, 768, 3072] 59 | for cnt in range(8): 60 | emb_dim.extend(tmp4) 61 | self.emb_layers = nn.Sequential(nn.ReLU(), nn.Linear(512, sum(emb_dim) * 2)) 62 | self.split_idx = [] 63 | cur_idx = 0 64 | for cnt in range(len(emb_dim)): 65 | self.split_idx.append(cur_idx + emb_dim[cnt]) 66 | self.split_idx.append(cur_idx + 2 * emb_dim[cnt]) 67 | cur_idx += 2 * emb_dim[cnt] 68 | self.split_idx = self.split_idx[:-1] 69 | 70 | def define_finetune(self): 71 | self.train() 72 | 73 | # freeze visual encoder 74 | for param in self.visual_frozen.parameters(): 75 | param.requires_grad = False 76 | for param in self.layer4.parameters(): 77 | param.requires_grad = False 78 | self.attn_pool.positional_embedding.requires_grad = True 79 | self.time_embed.requires_grad = True 80 | self.emb_layers.requires_grad = True 81 | 82 | if self.finetune_clip_layer == 'last': 83 | for param in self.layer4.parameters(): 84 | param.requires_grad = True 85 | for param in self.attn_pool.parameters(): 86 | param.requires_grad = True 87 | elif self.finetune_clip_layer == 'all': 88 | for param in self.parameters(): 89 | param.requires_grad = True 90 | 91 | 92 | def train(self, mode=True): 93 | self.visual_frozen.eval() 94 | self.layer4.eval() 95 | self.attn_pool.eval() 96 | 97 | self.time_embed.train(mode) 98 | self.emb_layers.train(mode) 99 | 100 | if self.finetune_clip_layer == 'last': 101 | self.layer4.train(mode) 102 | self.attn_pool.train(mode) 103 | elif self.finetune_clip_layer == 'all': 104 | self.visual_frozen.train(mode) 105 | 
self.layer4.train(mode) 106 | self.attn_pool.train(mode) 107 | 108 | def encode_image(self, image, t): 109 | image = (image + 1) / 2.0 110 | image = self.preprocess(image) 111 | emb = self.time_embed(timestep_embedding(t, 128)) 112 | emb_out = torch.tensor_split(self.emb_layers(emb).unsqueeze(-1).unsqueeze(-1), self.split_idx, dim=1) 113 | x = self.visual_frozen[1](self.visual_frozen[0](image)) 114 | x = x * (1 + emb_out[0]) + emb_out[1] 115 | x = self.visual_frozen[4](self.visual_frozen[3](self.visual_frozen[2](x))) 116 | x = x * (1 + emb_out[2]) + emb_out[3] 117 | x = self.visual_frozen[7](self.visual_frozen[6](self.visual_frozen[5](x))) 118 | x = x * (1 + emb_out[4]) + emb_out[5] 119 | x = self.visual_frozen[9](self.visual_frozen[8](x)) 120 | # layer1 121 | module_cnt = 10 122 | emb_cnt = 6 123 | for cnt in range(6): 124 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt) 125 | 126 | # layer2 127 | module_cnt = 11 128 | emb_cnt = 42 129 | for cnt in range(8): 130 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt) 131 | # layer3 132 | module_cnt = 12 133 | emb_cnt = 90 134 | for cnt in range(18): 135 | x = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x, emb_out, module_cnt, emb_cnt, cnt) 136 | 137 | # layer4 138 | emb_cnt = 198 139 | for cnt in range(8): 140 | x = self.bottleneck_block_forward(self.layer4[cnt], x, emb_out, module_cnt, emb_cnt, cnt) 141 | 142 | x = self.attn_pool(x) 143 | 144 | return x 145 | 146 | def encode_image_list(self, image, t, return_layer=4): 147 | image = (image + 1) / 2.0 148 | image = self.preprocess(image) 149 | 150 | emb = self.time_embed(timestep_embedding(t, 128)) 151 | emb_out = torch.tensor_split(self.emb_layers(emb).unsqueeze(-1).unsqueeze(-1), self.split_idx, dim=1) 152 | x = self.visual_frozen[1](self.visual_frozen[0](image)) 153 | x = x * (1 + emb_out[0]) + emb_out[1] 154 | x = self.visual_frozen[4](self.visual_frozen[3](self.visual_frozen[2](x))) 155 | x = x * (1 + emb_out[2]) + emb_out[3] 156 | x = self.visual_frozen[7](self.visual_frozen[6](self.visual_frozen[5](x))) 157 | x = x * (1 + emb_out[4]) + emb_out[5] 158 | x1 = self.visual_frozen[9](self.visual_frozen[8](x)) 159 | # layer1 160 | module_cnt = 10 161 | emb_cnt = 6 162 | 163 | for cnt in range(6): 164 | if cnt == 0: 165 | x2 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x1, emb_out, module_cnt, emb_cnt, cnt) 166 | else: 167 | x2 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x2, emb_out, module_cnt, emb_cnt, cnt) 168 | 169 | # layer2 170 | module_cnt = 11 171 | emb_cnt = 42 172 | for cnt in range(8): 173 | if cnt == 0: 174 | x3 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x2, emb_out, module_cnt, emb_cnt, cnt) 175 | else: 176 | x3 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x3, emb_out, module_cnt, emb_cnt, cnt) 177 | 178 | # layer3 179 | module_cnt = 12 180 | emb_cnt = 90 181 | for cnt in range(18): 182 | if cnt == 0: 183 | x4 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x3, emb_out, module_cnt, emb_cnt, cnt) 184 | else: 185 | x4 = self.bottleneck_block_forward(self.visual_frozen[module_cnt][cnt], x4, emb_out, module_cnt, emb_cnt, cnt) 186 | 187 | # layer4 188 | emb_cnt = 198 189 | for cnt in range(8): 190 | if cnt == 0: 191 | x5 = self.bottleneck_block_forward(self.layer4[cnt], x4, emb_out, module_cnt, emb_cnt, cnt) 192 | else: 193 | 
x5 = self.bottleneck_block_forward(self.layer4[cnt], x5, emb_out, module_cnt, emb_cnt, cnt) 194 | 195 | x6 = self.attn_pool(x5) 196 | 197 | return [x1, x2, x3, x4, x5, x6] 198 | 199 | 200 | 201 | def bottleneck_block_forward(self, net, x, emb_out, module_cnt, emb_cnt, cnt): 202 | identity = x 203 | y = net.bn1(net.conv1(x)) 204 | y = y * (1 + emb_out[emb_cnt+cnt*6]) + emb_out[emb_cnt+cnt*6+1] 205 | y = net.relu1(y) 206 | y = net.bn2(net.conv2(y)) 207 | y = y * (1 + emb_out[emb_cnt+cnt*6+2]) + emb_out[emb_cnt+cnt*6+3] 208 | y = net.relu2(y) 209 | y = net.avgpool(y) 210 | y = net.bn3(net.conv3(y)) 211 | y = y * (1 + emb_out[emb_cnt+cnt*6+4]) + emb_out[emb_cnt+cnt*6+5] 212 | if net.downsample is not None: 213 | identity = net.downsample(x) 214 | y += identity 215 | y = net.relu3(y) 216 | return y 217 | 218 | 219 | def encode_text(self, text): 220 | with torch.no_grad(): 221 | x = self.token_embedding_frozen(text) # [batch_size, n_ctx, d_model] 222 | x = x + self.positional_embedding_frozen 223 | x = x.permute(1, 0, 2) # NLD -> LND 224 | x = self.transformer_frozen(x) 225 | 226 | x = self.transformer_last_block(x) 227 | x = x.permute(1, 0, 2) # LND -> NLD 228 | x = self.ln_final(x) 229 | 230 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 231 | 232 | return x 233 | 234 | def unfreeze(self): 235 | self.attn_pool.requires_grad_(True) 236 | self.layer4.requires_grad_(True) 237 | 238 | self.transformer_last_block.requires_grad_(True) 239 | self.ln_final.requires_grad_(True) 240 | self.text_projection.requires_grad_(True) 241 | self.logit_scale.requires_grad_(True) 242 | 243 | def forward(self, image, text_features, timesteps): 244 | # to match the preprocess of clip model 245 | 246 | image_features = self.encode_image(image, timesteps) 247 | #text_features = self.encode_text(text) 248 | 249 | return image_features, text_features 250 | 251 | def training_step(self, batch, batch_idx): 252 | image, text = batch 253 | 254 | bs = image.size(0) 255 | 256 | image_features, text_features = self(image, text) 257 | 258 | # normalized features 259 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 260 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 261 | 262 | # cosine similarity as logits 263 | logit_scale = self.logit_scale.exp() 264 | logits_per_image = logit_scale * image_features @ text_features.t() 265 | logits_per_text = logit_scale * text_features @ image_features.t() 266 | 267 | label = torch.arange(bs).long() 268 | label = label.to(image.device) 269 | 270 | loss_i = F.cross_entropy(logits_per_image, label) 271 | loss_t = F.cross_entropy(logits_per_text, label) 272 | 273 | loss = (loss_i + loss_t) / 2 274 | 275 | return loss 276 | 277 | def configure_optimizers(self): 278 | lr = self.learning_rate 279 | opt = torch.optim.AdamW(list(self.attn_pool.parameters()) + 280 | list(self.layer4.parameters()) + 281 | list(self.transformer_last_block.parameters()) + 282 | list(self.ln_final.parameters()) + 283 | [self.text_projection], 284 | lr=lr) 285 | return opt 286 | 287 | -------------------------------------------------------------------------------- /sdg/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, check out LICENSE.md 5 | import functools 6 | import ctypes 7 | 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def init_dist(local_rank, backend='nccl', **kwargs): 13 | r"""Initialize distributed training""" 14 | if dist.is_available(): 15 | if dist.is_initialized(): 16 | return torch.cuda.current_device() 17 | torch.cuda.set_device(local_rank) 18 | dist.init_process_group(backend=backend, init_method='env://', **kwargs) 19 | 20 | # Increase the L2 fetch granularity for faster speed. 21 | _libcudart = ctypes.CDLL('libcudart.so') 22 | # Set device limit on the current device 23 | # cudaLimitMaxL2FetchGranularity = 0x05 24 | pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) 25 | _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) 26 | _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) 27 | # assert pValue.contents.value == 128 28 | 29 | 30 | def get_rank(): 31 | r"""Get rank of the thread.""" 32 | rank = 0 33 | if dist.is_available(): 34 | if dist.is_initialized(): 35 | rank = dist.get_rank() 36 | return rank 37 | 38 | 39 | def get_world_size(): 40 | r"""Get world size. How many GPUs are available in this job.""" 41 | world_size = 1 42 | if dist.is_available(): 43 | if dist.is_initialized(): 44 | world_size = dist.get_world_size() 45 | return world_size 46 | 47 | 48 | def master_only(func): 49 | r"""Apply this function only to the master GPU.""" 50 | @functools.wraps(func) 51 | def wrapper(*args, **kwargs): 52 | r"""Simple function wrapper for the master function""" 53 | if get_rank() == 0: 54 | return func(*args, **kwargs) 55 | else: 56 | return None 57 | return wrapper 58 | 59 | 60 | def is_master(): 61 | r"""check if current process is the master""" 62 | return get_rank() == 0 63 | 64 | 65 | def is_local_master(): 66 | return torch.cuda.current_device() == 0 67 | 68 | 69 | @master_only 70 | def master_only_print(*args, **kwargs): 71 | r"""master-only print""" 72 | print(*args, **kwargs) 73 | 74 | def print_all_rank(*args, **kwargs): 75 | rank = get_rank() 76 | print('[rank %d] %s' % (rank, *args), **kwargs) 77 | 78 | 79 | def dist_reduce_tensor(tensor, rank=0, reduce='mean'): 80 | r""" Reduce to rank 0 """ 81 | world_size = get_world_size() 82 | if world_size < 2: 83 | return tensor 84 | with torch.no_grad(): 85 | dist.reduce(tensor, dst=rank) 86 | if get_rank() == rank: 87 | if reduce == 'mean': 88 | tensor /= world_size 89 | elif reduce == 'sum': 90 | pass 91 | else: 92 | raise NotImplementedError 93 | return tensor 94 | 95 | 96 | def dist_all_reduce_tensor(tensor, reduce='mean'): 97 | r""" Reduce to all ranks """ 98 | world_size = get_world_size() 99 | if world_size < 2: 100 | return tensor 101 | with torch.no_grad(): 102 | dist.all_reduce(tensor) 103 | if reduce == 'mean': 104 | tensor /= world_size 105 | elif reduce == 'sum': 106 | pass 107 | else: 108 | raise NotImplementedError 109 | return tensor 110 | 111 | 112 | def dist_all_gather_tensor(tensor): 113 | r""" gather to all ranks """ 114 | world_size = get_world_size() 115 | if world_size < 2: 116 | return [tensor] 117 | tensor_list = [ 118 | torch.ones_like(tensor) for _ in range(dist.get_world_size())] 119 | with torch.no_grad(): 120 | dist.all_gather(tensor_list, tensor) 121 | return tensor_list 122 | 123 | class AllGatherFunction(torch.autograd.Function): 124 | @staticmethod 125 | def forward(ctx, tensor: torch.Tensor, reduce_dtype: torch.dtype = torch.float32): 126 | ctx.reduce_dtype = reduce_dtype 127 | output = 
list(torch.empty_like(tensor) for _ in range(dist.get_world_size())) 128 | dist.all_gather(output, tensor) 129 | output = torch.cat(output, dim=0) 130 | return output 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output: torch.Tensor): 134 | grad_dtype = grad_output.dtype 135 | input_list = list(grad_output.to(ctx.reduce_dtype).chunk(dist.get_world_size())) 136 | grad_input = torch.empty_like(input_list[dist.get_rank()]) 137 | dist.reduce_scatter(grad_input, input_list) 138 | return grad_input.to(grad_dtype) 139 | 140 | 141 | def all_gather_with_gradient(tensor): 142 | world_size = get_world_size() 143 | if world_size < 2: 144 | return tensor 145 | return AllGatherFunction.apply(tensor) 146 | -------------------------------------------------------------------------------- /sdg/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
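    # (zip() stops at the shortest iterable, so an already-consumed generator
    #  would make the loop below a silent no-op; param_groups_and_shapes is
    #  expected to be the list returned by get_param_groups_and_shapes().)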
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /sdg/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | from .nn import mean_flat 15 | from .losses import normal_kl, discretized_gaussian_log_likelihood 16 | 17 | 18 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 19 | """ 20 | Get a pre-defined beta schedule for the given name. 21 | 22 | The beta schedule library consists of beta schedules which remain similar 23 | in the limit of num_diffusion_timesteps. 24 | Beta schedules may be added, but should not be removed or changed once 25 | they are committed to maintain backwards compatibility. 26 | """ 27 | if schedule_name == "linear": 28 | # Linear schedule from Ho et al, extended to work for any number of 29 | # diffusion steps. 
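        # For the original 1000-step setting this gives beta_start = 1e-4 and
        # beta_end = 0.02; shorter schedules scale both endpoints up accordingly.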
30 | scale = 1000 / num_diffusion_timesteps 31 | beta_start = scale * 0.0001 32 | beta_end = scale * 0.02 33 | return np.linspace( 34 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 35 | ) 36 | elif schedule_name == "cosine": 37 | return betas_for_alpha_bar( 38 | num_diffusion_timesteps, 39 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 40 | ) 41 | else: 42 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 43 | 44 | 45 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 46 | """ 47 | Create a beta schedule that discretizes the given alpha_t_bar function, 48 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 49 | 50 | :param num_diffusion_timesteps: the number of betas to produce. 51 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 52 | produces the cumulative product of (1-beta) up to that 53 | part of the diffusion process. 54 | :param max_beta: the maximum beta to use; use values lower than 1 to 55 | prevent singularities. 56 | """ 57 | betas = [] 58 | for i in range(num_diffusion_timesteps): 59 | t1 = i / num_diffusion_timesteps 60 | t2 = (i + 1) / num_diffusion_timesteps 61 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 62 | return np.array(betas) 63 | 64 | 65 | class ModelMeanType(enum.Enum): 66 | """ 67 | Which type of output the model predicts. 68 | """ 69 | 70 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 71 | START_X = enum.auto() # the model predicts x_0 72 | EPSILON = enum.auto() # the model predicts epsilon 73 | 74 | 75 | class ModelVarType(enum.Enum): 76 | """ 77 | What is used as the model's output variance. 78 | 79 | The LEARNED_RANGE option has been added to allow the model to predict 80 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 81 | """ 82 | 83 | LEARNED = enum.auto() 84 | FIXED_SMALL = enum.auto() 85 | FIXED_LARGE = enum.auto() 86 | LEARNED_RANGE = enum.auto() 87 | 88 | 89 | class LossType(enum.Enum): 90 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 91 | RESCALED_MSE = ( 92 | enum.auto() 93 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 94 | KL = enum.auto() # use the variational lower-bound 95 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 96 | 97 | def is_vb(self): 98 | return self == LossType.KL or self == LossType.RESCALED_KL 99 | 100 | 101 | class GaussianDiffusion: 102 | """ 103 | Utilities for training and sampling diffusion models. 104 | 105 | Ported directly from here, and then adapted over time to further experimentation. 106 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 107 | 108 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 109 | starting at T and going to 1. 110 | :param model_mean_type: a ModelMeanType determining what the model outputs. 111 | :param model_var_type: a ModelVarType determining how variance is output. 112 | :param loss_type: a LossType determining the loss function to use. 113 | :param rescale_timesteps: if True, pass floating point timesteps into the 114 | model so that they are always scaled like in the 115 | original paper (0 to 1000). 
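    :param betas1000: (optional) a beta schedule on the original, unspaced
                      timestep scale; if given, its cumulative products are
                      precomputed so that q_sample(..., tscale1000=True) can
                      add noise on that scale.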
116 | """ 117 | 118 | def __init__( 119 | self, 120 | *, 121 | betas, 122 | model_mean_type, 123 | model_var_type, 124 | loss_type, 125 | rescale_timesteps=False, 126 | betas1000=None, 127 | ): 128 | self.model_mean_type = model_mean_type 129 | self.model_var_type = model_var_type 130 | self.loss_type = loss_type 131 | self.rescale_timesteps = rescale_timesteps 132 | 133 | # Use float64 for accuracy. 134 | betas = np.array(betas, dtype=np.float64) 135 | self.betas = betas 136 | assert len(betas.shape) == 1, "betas must be 1-D" 137 | assert (betas > 0).all() and (betas <= 1).all() 138 | 139 | self.num_timesteps = int(betas.shape[0]) 140 | 141 | alphas = 1.0 - betas 142 | 143 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 144 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 145 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 146 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 147 | 148 | # calculations for diffusion q(x_t | x_{t-1}) and others 149 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 150 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 151 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 152 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 153 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 154 | 155 | if betas1000 is not None: 156 | betas1000 = np.array(betas1000, dtype=np.float64) 157 | alphas1000 = 1.0 - betas1000 158 | self.alphas_cumprod1000 = np.cumprod(alphas1000, axis=0) 159 | self.sqrt_alphas_cumprod1000 = np.sqrt(self.alphas_cumprod1000) 160 | self.sqrt_one_minus_alphas_cumprod1000 = np.sqrt(1.0 - self.alphas_cumprod1000) 161 | 162 | # calculations for posterior q(x_{t-1} | x_t, x_0) 163 | self.posterior_variance = ( 164 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 165 | ) 166 | # log calculation clipped because the posterior variance is 0 at the 167 | # beginning of the diffusion chain. 168 | self.posterior_log_variance_clipped = np.log( 169 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 170 | ) 171 | self.posterior_mean_coef1 = ( 172 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 173 | ) 174 | self.posterior_mean_coef2 = ( 175 | (1.0 - self.alphas_cumprod_prev) 176 | * np.sqrt(alphas) 177 | / (1.0 - self.alphas_cumprod) 178 | ) 179 | 180 | def q_mean_variance(self, x_start, t): 181 | """ 182 | Get the distribution q(x_t | x_0). 183 | 184 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 185 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 186 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 187 | """ 188 | mean = ( 189 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 190 | ) 191 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 192 | log_variance = _extract_into_tensor( 193 | self.log_one_minus_alphas_cumprod, t, x_start.shape 194 | ) 195 | return mean, variance, log_variance 196 | 197 | def q_sample(self, x_start, t, noise=None, tscale1000=False): 198 | """ 199 | Diffuse the data for a given number of diffusion steps. 200 | 201 | In other words, sample from q(x_t | x_0). 202 | 203 | :param x_start: the initial data batch. 204 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 205 | :param noise: if specified, the split-out normal noise. 
206 | :return: A noisy version of x_start. 207 | """ 208 | if noise is None: 209 | noise = th.randn_like(x_start) 210 | assert noise.shape == x_start.shape 211 | if tscale1000: 212 | return ( 213 | _extract_into_tensor(self.sqrt_alphas_cumprod1000, t, x_start.shape) * x_start 214 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod1000, t, x_start.shape) 215 | * noise 216 | ) 217 | else: 218 | return ( 219 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 220 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 221 | * noise 222 | ) 223 | 224 | def q_posterior_mean_variance(self, x_start, x_t, t): 225 | """ 226 | Compute the mean and variance of the diffusion posterior: 227 | 228 | q(x_{t-1} | x_t, x_0) 229 | 230 | """ 231 | assert x_start.shape == x_t.shape 232 | posterior_mean = ( 233 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 234 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 235 | ) 236 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 237 | posterior_log_variance_clipped = _extract_into_tensor( 238 | self.posterior_log_variance_clipped, t, x_t.shape 239 | ) 240 | assert ( 241 | posterior_mean.shape[0] 242 | == posterior_variance.shape[0] 243 | == posterior_log_variance_clipped.shape[0] 244 | == x_start.shape[0] 245 | ) 246 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 247 | 248 | def p_mean_variance( 249 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 250 | ): 251 | """ 252 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 253 | the initial x, x_0. 254 | 255 | :param model: the model, which takes a signal and a batch of timesteps 256 | as input. 257 | :param x: the [N x C x ...] tensor at time t. 258 | :param t: a 1-D Tensor of timesteps. 259 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 260 | :param denoised_fn: if not None, a function which applies to the 261 | x_start prediction before it is used to sample. Applies before 262 | clip_denoised. 263 | :param model_kwargs: if not None, a dict of extra keyword arguments to 264 | pass to the model. This can be used for conditioning. 265 | :return: a dict with the following keys: 266 | - 'mean': the model mean output. 267 | - 'variance': the model variance output. 268 | - 'log_variance': the log of 'variance'. 269 | - 'pred_xstart': the prediction for x_0. 270 | """ 271 | if model_kwargs is None: 272 | model_kwargs = {} 273 | 274 | B, C = x.shape[:2] 275 | assert t.shape == (B,) 276 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 277 | 278 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 279 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 280 | model_output, model_var_values = th.split(model_output, C, dim=1) 281 | if self.model_var_type == ModelVarType.LEARNED: 282 | model_log_variance = model_var_values 283 | model_variance = th.exp(model_log_variance) 284 | else: 285 | min_log = _extract_into_tensor( 286 | self.posterior_log_variance_clipped, t, x.shape 287 | ) 288 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 289 | # The model_var_values is [-1, 1] for [min_var, max_var]. 
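                # i.e. the predicted log-variance interpolates between the two:
                # frac * log(beta_t) + (1 - frac) * log(clipped posterior variance).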
290 | frac = (model_var_values + 1) / 2 291 | model_log_variance = frac * max_log + (1 - frac) * min_log 292 | model_variance = th.exp(model_log_variance) 293 | else: 294 | model_variance, model_log_variance = { 295 | # for fixedlarge, we set the initial (log-)variance like so 296 | # to get a better decoder log likelihood. 297 | ModelVarType.FIXED_LARGE: ( 298 | np.append(self.posterior_variance[1], self.betas[1:]), 299 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 300 | ), 301 | ModelVarType.FIXED_SMALL: ( 302 | self.posterior_variance, 303 | self.posterior_log_variance_clipped, 304 | ), 305 | }[self.model_var_type] 306 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 307 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 308 | 309 | def process_xstart(x): 310 | if denoised_fn is not None: 311 | x = denoised_fn(x) 312 | if clip_denoised: 313 | return x.clamp(-1, 1) 314 | return x 315 | 316 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 317 | pred_xstart = process_xstart( 318 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 319 | ) 320 | model_mean = model_output 321 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 322 | if self.model_mean_type == ModelMeanType.START_X: 323 | pred_xstart = process_xstart(model_output) 324 | else: 325 | pred_xstart = process_xstart( 326 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 327 | ) 328 | model_mean, _, _ = self.q_posterior_mean_variance( 329 | x_start=pred_xstart, x_t=x, t=t 330 | ) 331 | else: 332 | raise NotImplementedError(self.model_mean_type) 333 | 334 | assert ( 335 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 336 | ) 337 | return { 338 | "mean": model_mean, 339 | "variance": model_variance, 340 | "log_variance": model_log_variance, 341 | "pred_xstart": pred_xstart, 342 | } 343 | 344 | def _predict_xstart_from_eps(self, x_t, t, eps): 345 | assert x_t.shape == eps.shape 346 | return ( 347 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 348 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 349 | ) 350 | 351 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 352 | assert x_t.shape == xprev.shape 353 | return ( # (xprev - coef2*x_t) / coef1 354 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 355 | - _extract_into_tensor( 356 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 357 | ) 358 | * x_t 359 | ) 360 | 361 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 362 | return ( 363 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 364 | - pred_xstart 365 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 366 | 367 | def _scale_timesteps(self, t): 368 | if self.rescale_timesteps: 369 | return t.float() * (1000.0 / self.num_timesteps) 370 | return t 371 | 372 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 373 | """ 374 | Compute the mean for the previous step, given a function cond_fn that 375 | computes the gradient of a conditional log probability with respect to 376 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 377 | condition on y. 378 | 379 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
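        Concretely, the returned mean is
        p_mean_var["mean"] + p_mean_var["variance"] * cond_fn(x, t, **model_kwargs),
        i.e. the guidance gradient scaled by the model variance.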
380 | """ 381 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) 382 | new_mean = ( 383 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 384 | ) 385 | return new_mean 386 | 387 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 388 | """ 389 | Compute what the p_mean_variance output would have been, should the 390 | model's score function be conditioned by cond_fn. 391 | 392 | See condition_mean() for details on cond_fn. 393 | 394 | Unlike condition_mean(), this instead uses the conditioning strategy 395 | from Song et al (2020). 396 | """ 397 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 398 | 399 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 400 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn( 401 | x, self._scale_timesteps(t), **model_kwargs 402 | ) 403 | 404 | out = p_mean_var.copy() 405 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 406 | out["mean"], _, _ = self.q_posterior_mean_variance( 407 | x_start=out["pred_xstart"], x_t=x, t=t 408 | ) 409 | return out 410 | 411 | def p_sample( 412 | self, 413 | model, 414 | x, 415 | t, 416 | clip_denoised=True, 417 | denoised_fn=None, 418 | cond_fn=None, 419 | model_kwargs=None, 420 | ): 421 | """ 422 | Sample x_{t-1} from the model at the given timestep. 423 | 424 | :param model: the model to sample from. 425 | :param x: the current tensor at x_{t-1}. 426 | :param t: the value of t, starting at 0 for the first diffusion step. 427 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 428 | :param denoised_fn: if not None, a function which applies to the 429 | x_start prediction before it is used to sample. 430 | :param cond_fn: if not None, this is a gradient function that acts 431 | similarly to the model. 432 | :param model_kwargs: if not None, a dict of extra keyword arguments to 433 | pass to the model. This can be used for conditioning. 434 | :return: a dict containing the following keys: 435 | - 'sample': a random sample from the model. 436 | - 'pred_xstart': a prediction of x_0. 437 | """ 438 | out = self.p_mean_variance( 439 | model, 440 | x, 441 | t, 442 | clip_denoised=clip_denoised, 443 | denoised_fn=denoised_fn, 444 | model_kwargs=model_kwargs, 445 | ) 446 | noise = th.randn_like(x) 447 | nonzero_mask = ( 448 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 449 | ) # no noise when t == 0 450 | if cond_fn is not None: 451 | out["mean"] = self.condition_mean( 452 | cond_fn, out, x, t, model_kwargs=model_kwargs 453 | ) 454 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 455 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 456 | 457 | def p_sample_loop( 458 | self, 459 | model, 460 | shape, 461 | noise=None, 462 | clip_denoised=True, 463 | denoised_fn=None, 464 | cond_fn=None, 465 | model_kwargs=None, 466 | device=None, 467 | progress=False, 468 | ): 469 | """ 470 | Generate samples from the model. 471 | 472 | :param model: the model module. 473 | :param shape: the shape of the samples, (N, C, H, W). 474 | :param noise: if specified, the noise from the encoder to sample. 475 | Should be of the same shape as `shape`. 476 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 477 | :param denoised_fn: if not None, a function which applies to the 478 | x_start prediction before it is used to sample. 479 | :param cond_fn: if not None, this is a gradient function that acts 480 | similarly to the model. 
481 | :param model_kwargs: if not None, a dict of extra keyword arguments to 482 | pass to the model. This can be used for conditioning. 483 | :param device: if specified, the device to create the samples on. 484 | If not specified, use a model parameter's device. 485 | :param progress: if True, show a tqdm progress bar. 486 | :return: a non-differentiable batch of samples. 487 | """ 488 | final = None 489 | for sample in self.p_sample_loop_progressive( 490 | model, 491 | shape, 492 | noise=noise, 493 | clip_denoised=clip_denoised, 494 | denoised_fn=denoised_fn, 495 | cond_fn=cond_fn, 496 | model_kwargs=model_kwargs, 497 | device=device, 498 | progress=progress, 499 | ): 500 | final = sample 501 | return final["sample"] 502 | 503 | def p_sample_loop_progressive( 504 | self, 505 | model, 506 | shape, 507 | noise=None, 508 | clip_denoised=True, 509 | denoised_fn=None, 510 | cond_fn=None, 511 | model_kwargs=None, 512 | device=None, 513 | progress=False, 514 | ): 515 | """ 516 | Generate samples from the model and yield intermediate samples from 517 | each timestep of diffusion. 518 | 519 | Arguments are the same as p_sample_loop(). 520 | Returns a generator over dicts, where each dict is the return value of 521 | p_sample(). 522 | """ 523 | if device is None: 524 | device = next(model.parameters()).device 525 | assert isinstance(shape, (tuple, list)) 526 | if noise is not None: 527 | img = noise 528 | else: 529 | img = th.randn(*shape, device=device) 530 | indices = list(range(self.num_timesteps))[::-1] 531 | 532 | if progress: 533 | # Lazy import so that we don't depend on tqdm. 534 | from tqdm.auto import tqdm 535 | 536 | indices = tqdm(indices) 537 | 538 | for i in indices: 539 | t = th.tensor([i] * shape[0], device=device) 540 | with th.no_grad(): 541 | out = self.p_sample( 542 | model, 543 | img, 544 | t, 545 | clip_denoised=clip_denoised, 546 | denoised_fn=denoised_fn, 547 | cond_fn=cond_fn, 548 | model_kwargs=model_kwargs, 549 | ) 550 | yield out 551 | img = out["sample"] 552 | 553 | 554 | def ddim_sample( 555 | self, 556 | model, 557 | x, 558 | t, 559 | clip_denoised=True, 560 | denoised_fn=None, 561 | cond_fn=None, 562 | model_kwargs=None, 563 | eta=0.0, 564 | ): 565 | """ 566 | Sample x_{t-1} from the model using DDIM. 567 | 568 | Same usage as p_sample(). 569 | """ 570 | out = self.p_mean_variance( 571 | model, 572 | x, 573 | t, 574 | clip_denoised=clip_denoised, 575 | denoised_fn=denoised_fn, 576 | model_kwargs=model_kwargs, 577 | ) 578 | if cond_fn is not None: 579 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 580 | 581 | # Usually our model outputs epsilon, but we re-derive it 582 | # in case we used x_start or x_prev prediction. 583 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 584 | 585 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 586 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 587 | sigma = ( 588 | eta 589 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 590 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 591 | ) 592 | # Equation 12. 
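        # x_{t-1} = sqrt(alpha_bar_prev) * pred_xstart
        #           + sqrt(1 - alpha_bar_prev - sigma^2) * eps
        #           + sigma * noise   (the noise term is dropped when t == 0)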
593 | noise = th.randn_like(x) 594 | mean_pred = ( 595 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 596 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 597 | ) 598 | nonzero_mask = ( 599 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 600 | ) # no noise when t == 0 601 | sample = mean_pred + nonzero_mask * sigma * noise 602 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 603 | 604 | def ddim_reverse_sample( 605 | self, 606 | model, 607 | x, 608 | t, 609 | clip_denoised=True, 610 | denoised_fn=None, 611 | model_kwargs=None, 612 | eta=0.0, 613 | ): 614 | """ 615 | Sample x_{t+1} from the model using DDIM reverse ODE. 616 | """ 617 | assert eta == 0.0, "Reverse ODE only for deterministic path" 618 | out = self.p_mean_variance( 619 | model, 620 | x, 621 | t, 622 | clip_denoised=clip_denoised, 623 | denoised_fn=denoised_fn, 624 | model_kwargs=model_kwargs, 625 | ) 626 | # Usually our model outputs epsilon, but we re-derive it 627 | # in case we used x_start or x_prev prediction. 628 | eps = ( 629 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 630 | - out["pred_xstart"] 631 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 632 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 633 | 634 | # Equation 12. reversed 635 | mean_pred = ( 636 | out["pred_xstart"] * th.sqrt(alpha_bar_next) 637 | + th.sqrt(1 - alpha_bar_next) * eps 638 | ) 639 | 640 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 641 | 642 | def ddim_sample_loop( 643 | self, 644 | model, 645 | shape, 646 | noise=None, 647 | clip_denoised=True, 648 | denoised_fn=None, 649 | cond_fn=None, 650 | model_kwargs=None, 651 | device=None, 652 | progress=False, 653 | eta=0.0, 654 | ): 655 | """ 656 | Generate samples from the model using DDIM. 657 | 658 | Same usage as p_sample_loop(). 659 | """ 660 | final = None 661 | for sample in self.ddim_sample_loop_progressive( 662 | model, 663 | shape, 664 | noise=noise, 665 | clip_denoised=clip_denoised, 666 | denoised_fn=denoised_fn, 667 | cond_fn=cond_fn, 668 | model_kwargs=model_kwargs, 669 | device=device, 670 | progress=progress, 671 | eta=eta, 672 | ): 673 | final = sample 674 | return final["sample"] 675 | 676 | def ddim_sample_loop_progressive( 677 | self, 678 | model, 679 | shape, 680 | noise=None, 681 | clip_denoised=True, 682 | denoised_fn=None, 683 | cond_fn=None, 684 | model_kwargs=None, 685 | device=None, 686 | progress=False, 687 | eta=0.0, 688 | ): 689 | """ 690 | Use DDIM to sample from the model and yield intermediate samples from 691 | each timestep of DDIM. 692 | 693 | Same usage as p_sample_loop_progressive(). 694 | """ 695 | if device is None: 696 | device = next(model.parameters()).device 697 | assert isinstance(shape, (tuple, list)) 698 | if noise is not None: 699 | img = noise 700 | else: 701 | img = th.randn(*shape, device=device) 702 | indices = list(range(self.num_timesteps))[::-1] 703 | 704 | if progress: 705 | # Lazy import so that we don't depend on tqdm. 
706 | from tqdm.auto import tqdm 707 | 708 | indices = tqdm(indices) 709 | 710 | for i in indices: 711 | t = th.tensor([i] * shape[0], device=device) 712 | with th.no_grad(): 713 | out = self.ddim_sample( 714 | model, 715 | img, 716 | t, 717 | clip_denoised=clip_denoised, 718 | denoised_fn=denoised_fn, 719 | cond_fn=cond_fn, 720 | model_kwargs=model_kwargs, 721 | eta=eta, 722 | ) 723 | yield out 724 | img = out["sample"] 725 | 726 | def _vb_terms_bpd( 727 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 728 | ): 729 | """ 730 | Get a term for the variational lower-bound. 731 | 732 | The resulting units are bits (rather than nats, as one might expect). 733 | This allows for comparison to other papers. 734 | 735 | :return: a dict with the following keys: 736 | - 'output': a shape [N] tensor of NLLs or KLs. 737 | - 'pred_xstart': the x_0 predictions. 738 | """ 739 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 740 | x_start=x_start, x_t=x_t, t=t 741 | ) 742 | out = self.p_mean_variance( 743 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 744 | ) 745 | kl = normal_kl( 746 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 747 | ) 748 | kl = mean_flat(kl) / np.log(2.0) 749 | 750 | decoder_nll = -discretized_gaussian_log_likelihood( 751 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 752 | ) 753 | assert decoder_nll.shape == x_start.shape 754 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 755 | 756 | # At the first timestep return the decoder NLL, 757 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 758 | output = th.where((t == 0), decoder_nll, kl) 759 | return {"output": output, "pred_xstart": out["pred_xstart"]} 760 | 761 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 762 | """ 763 | Compute training losses for a single timestep. 764 | 765 | :param model: the model to evaluate loss on. 766 | :param x_start: the [N x C x ...] tensor of inputs. 767 | :param t: a batch of timestep indices. 768 | :param model_kwargs: if not None, a dict of extra keyword arguments to 769 | pass to the model. This can be used for conditioning. 770 | :param noise: if specified, the specific Gaussian noise to try to remove. 771 | :return: a dict with the key "loss" containing a tensor of shape [N]. 772 | Some mean or variance settings may also have other keys. 773 | """ 774 | if model_kwargs is None: 775 | model_kwargs = {} 776 | if noise is None: 777 | noise = th.randn_like(x_start) 778 | x_t = self.q_sample(x_start, t, noise=noise) 779 | 780 | terms = {} 781 | 782 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 783 | terms["loss"] = self._vb_terms_bpd( 784 | model=model, 785 | x_start=x_start, 786 | x_t=x_t, 787 | t=t, 788 | clip_denoised=False, 789 | model_kwargs=model_kwargs, 790 | )["output"] 791 | if self.loss_type == LossType.RESCALED_KL: 792 | terms["loss"] *= self.num_timesteps 793 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 794 | model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) 795 | 796 | if self.model_var_type in [ 797 | ModelVarType.LEARNED, 798 | ModelVarType.LEARNED_RANGE, 799 | ]: 800 | B, C = x_t.shape[:2] 801 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 802 | model_output, model_var_values = th.split(model_output, C, dim=1) 803 | # Learn the variance using the variational bound, but don't let 804 | # it affect our mean prediction. 
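                # Because model_output is detached here, the VB term below only
                # trains the variance channels; the mean is trained by the MSE
                # term further down.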
805 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 806 | terms["vb"] = self._vb_terms_bpd( 807 | model=lambda *args, r=frozen_out: r, 808 | x_start=x_start, 809 | x_t=x_t, 810 | t=t, 811 | clip_denoised=False, 812 | )["output"] 813 | if self.loss_type == LossType.RESCALED_MSE: 814 | # Divide by 1000 for equivalence with initial implementation. 815 | # Without a factor of 1/1000, the VB term hurts the MSE term. 816 | terms["vb"] *= self.num_timesteps / 1000.0 817 | 818 | target = { 819 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 820 | x_start=x_start, x_t=x_t, t=t 821 | )[0], 822 | ModelMeanType.START_X: x_start, 823 | ModelMeanType.EPSILON: noise, 824 | }[self.model_mean_type] 825 | assert model_output.shape == target.shape == x_start.shape 826 | terms["mse"] = mean_flat((target - model_output) ** 2) 827 | if "vb" in terms: 828 | terms["loss"] = terms["mse"] + terms["vb"] 829 | else: 830 | terms["loss"] = terms["mse"] 831 | else: 832 | raise NotImplementedError(self.loss_type) 833 | 834 | return terms 835 | 836 | def _prior_bpd(self, x_start): 837 | """ 838 | Get the prior KL term for the variational lower-bound, measured in 839 | bits-per-dim. 840 | 841 | This term can't be optimized, as it only depends on the encoder. 842 | 843 | :param x_start: the [N x C x ...] tensor of inputs. 844 | :return: a batch of [N] KL values (in bits), one per batch element. 845 | """ 846 | batch_size = x_start.shape[0] 847 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 848 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 849 | kl_prior = normal_kl( 850 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 851 | ) 852 | return mean_flat(kl_prior) / np.log(2.0) 853 | 854 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 855 | """ 856 | Compute the entire variational lower-bound, measured in bits-per-dim, 857 | as well as other related quantities. 858 | 859 | :param model: the model to evaluate loss on. 860 | :param x_start: the [N x C x ...] tensor of inputs. 861 | :param clip_denoised: if True, clip denoised samples. 862 | :param model_kwargs: if not None, a dict of extra keyword arguments to 863 | pass to the model. This can be used for conditioning. 864 | 865 | :return: a dict containing the following keys: 866 | - total_bpd: the total variational lower-bound, per batch element. 867 | - prior_bpd: the prior term in the lower-bound. 868 | - vb: an [N x T] tensor of terms in the lower-bound. 869 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 870 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
871 | """ 872 | device = x_start.device 873 | batch_size = x_start.shape[0] 874 | 875 | vb = [] 876 | xstart_mse = [] 877 | mse = [] 878 | for t in list(range(self.num_timesteps))[::-1]: 879 | t_batch = th.tensor([t] * batch_size, device=device) 880 | noise = th.randn_like(x_start) 881 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 882 | # Calculate VLB term at the current timestep 883 | with th.no_grad(): 884 | out = self._vb_terms_bpd( 885 | model, 886 | x_start=x_start, 887 | x_t=x_t, 888 | t=t_batch, 889 | clip_denoised=clip_denoised, 890 | model_kwargs=model_kwargs, 891 | ) 892 | vb.append(out["output"]) 893 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 894 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 895 | mse.append(mean_flat((eps - noise) ** 2)) 896 | 897 | vb = th.stack(vb, dim=1) 898 | xstart_mse = th.stack(xstart_mse, dim=1) 899 | mse = th.stack(mse, dim=1) 900 | 901 | prior_bpd = self._prior_bpd(x_start) 902 | total_bpd = vb.sum(dim=1) + prior_bpd 903 | return { 904 | "total_bpd": total_bpd, 905 | "prior_bpd": prior_bpd, 906 | "vb": vb, 907 | "xstart_mse": xstart_mse, 908 | "mse": mse, 909 | } 910 | 911 | 912 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 913 | """ 914 | Extract values from a 1-D numpy array for a batch of indices. 915 | 916 | :param arr: the 1-D numpy array. 917 | :param timesteps: a tensor of indices into the array to extract. 918 | :param broadcast_shape: a larger shape of K dimensions with the batch 919 | dimension equal to the length of timesteps. 920 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 921 | """ 922 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 923 | while len(res.shape) < len(broadcast_shape): 924 | res = res[..., None] 925 | return res.expand(broadcast_shape) 926 | -------------------------------------------------------------------------------- /sdg/gpu_affinity.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, check out LICENSE.md
5 | import math
6 | import os
7 | import pynvml
8 | 
9 | pynvml.nvmlInit()
10 | 
11 | 
12 | def systemGetDriverVersion():
13 |     r"""Get Driver Version"""
14 |     return pynvml.nvmlSystemGetDriverVersion()
15 | 
16 | 
17 | def deviceGetCount():
18 |     r"""Get number of devices"""
19 |     return pynvml.nvmlDeviceGetCount()
20 | 
21 | 
22 | class device(object):
23 |     r"""Device used for nvml."""
24 |     _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
25 | 
26 |     def __init__(self, device_idx):
27 |         super().__init__()
28 |         self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
29 | 
30 |     def getName(self):
31 |         r"""Get object name"""
32 |         return pynvml.nvmlDeviceGetName(self.handle)
33 | 
34 |     def getCpuAffinity(self):
35 |         r"""Get CPU affinity"""
36 |         affinity_string = ''
37 |         for j in pynvml.nvmlDeviceGetCpuAffinity(
38 |                 self.handle, device._nvml_affinity_elements):
39 |             # assume nvml returns list of 64 bit ints
40 |             affinity_string = '{:064b}'.format(j) + affinity_string
41 |         affinity_list = [int(x) for x in affinity_string]
42 |         affinity_list.reverse()  # so core 0 is in 0th element of list
43 | 
44 |         return [i for i, e in enumerate(affinity_list) if e != 0]
45 | 
46 | 
47 | def set_affinity(gpu_id=None):
48 |     r"""Set GPU affinity
49 | 
50 |     Args:
51 |         gpu_id (int): Which gpu device.
52 |     """
53 |     if gpu_id is None:
54 |         gpu_id = int(os.getenv('LOCAL_RANK', 0))
55 | 
56 |     dev = device(gpu_id)
57 |     os.sched_setaffinity(0, dev.getCpuAffinity())
58 | 
59 |     # list of ints
60 |     # representing the logical cores this process is now affinitied with
61 |     return os.sched_getaffinity(0)
62 | --------------------------------------------------------------------------------
/sdg/guidance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | 
7 | def gram_matrix(input):
8 |     flag = input.dtype == torch.float16  # remember whether to cast back at the end
9 |     if flag:
10 |         input = input.to(torch.float32)
11 |     a, b, c, d = input.size()  # a=batch size(=1)
12 |     sqrt_sum = math.sqrt(a * b * c * d)  # for numerical stability
13 |     features = input.view(a * b, c * d) / sqrt_sum  # resize F_XL into \hat F_XL
14 |     G = torch.mm(features, features.t())  # compute the gram product
15 |     # we 'normalize' the values of the gram matrix
16 |     # by dividing by the number of elements in each feature map.
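    # (The normalization already happened above: features were divided by
    # sqrt(a * b * c * d) before the matmul, so G carries a 1/(a*b*c*d) factor.)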
17 | result = G 18 | if flag: 19 | return result.to(torch.float16) 20 | else: 21 | return result 22 | 23 | def image_loss(source, target, args): 24 | if args.image_loss == 'semantic': 25 | source[-1] = source[-1] / source[-1].norm(dim=-1, keepdim=True) 26 | target[-1] = target[-1] / target[-1].norm(dim=-1, keepdim=True) 27 | return (source[-1] * target[-1]).sum(1) 28 | elif args.image_loss == 'style': 29 | weights = [1, 1, 1, 1, 1] 30 | loss = 0 31 | for cnt in range(5): 32 | loss += F.mse_loss(gram_matrix(source[cnt]), gram_matrix(target[cnt])) 33 | return -loss * 1e10 / sum(weights) 34 | 35 | def text_loss(source, target, args): 36 | source_feat = source[-1] / source[-1].norm(dim=-1, keepdim=True) 37 | target = target / target.norm(dim=-1, keepdim=True) 38 | return (source_feat * target).sum(1) 39 | -------------------------------------------------------------------------------- /sdg/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import os 4 | import pickle 5 | 6 | from PIL import Image 7 | import blobfile as bf 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | from torch.utils.data import DataLoader, Dataset 12 | import clip 13 | from sdg.distributed import master_only_print as print 14 | import torchvision.transforms as transforms 15 | import json 16 | import io 17 | 18 | 19 | def load_ref_data(args, ref_img_path=None): 20 | if ref_img_path is None: 21 | ref_img_path = args.ref_img_path 22 | with bf.BlobFile(ref_img_path, "rb") as f: 23 | pil_image = Image.open(f) 24 | pil_image.load() 25 | 26 | pil_image = pil_image.convert("RGB") 27 | arr = center_crop_arr(pil_image, args.image_size) 28 | arr = arr.astype(np.float32) / 127.5 - 1 29 | arr = np.repeat(np.expand_dims(np.transpose(arr, [2, 0, 1]), axis=0), args.batch_size, axis=0) 30 | kwargs = {} 31 | kwargs["ref_img"] = torch.tensor(arr) 32 | return kwargs 33 | 34 | def load_data(args, is_train=True, swap=False): 35 | """ 36 | For a dataset, create a generator over (images, kwargs) pairs. 37 | 38 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 39 | more keys, each of which map to a batched Tensor of their own. 40 | The kwargs dict can be used for class labels, in which case the key is "y" 41 | and the values are integer tensors of class labels. 42 | 43 | :param data_dir: a dataset directory. 44 | :param batch_size: the batch size of each returned pair. 45 | :param image_size: the size to which images are resized. 46 | :param class_cond: if True, include a "y" key in returned dicts for class 47 | label. If classes are not available and this is true, an 48 | exception will be raised. 49 | :param deterministic: if True, yield results in a deterministic order. 50 | :param random_crop: if True, randomly crop the images for augmentation. 51 | :param random_flip: if True, randomly flip the images for augmentation. 
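    Note: in this codebase these options are not separate arguments; they are
    read from the ``args`` namespace (args.data_dir, args.batch_size, ...) below.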
52 | """ 53 | if not args.data_dir: 54 | raise ValueError("unspecified data directory") 55 | data_dir = getattr(args, 'data_dir') 56 | batch_size = getattr(args, 'batch_size') 57 | image_size = getattr(args, 'image_size') 58 | class_cond = getattr(args, 'class_cond', False) 59 | deterministic = getattr(args, 'deterministic', False) 60 | random_crop = getattr(args, 'random_crop', False) 61 | random_flip = getattr(args, 'random_flip', True) 62 | if not is_train: 63 | deterministic = True 64 | random_crop = False 65 | random_flip = False 66 | num_workers = getattr(args, 'num_workers', 4) 67 | all_files = _list_image_files_recursively(data_dir) 68 | classes = None 69 | if class_cond: 70 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 71 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 72 | classes = [sorted_classes[x] for x in class_names] 73 | dataset = ImageDataset( 74 | image_size, 75 | all_files, 76 | classes=classes, 77 | random_crop=random_crop, 78 | random_flip=random_flip, 79 | ) 80 | not_distributed = args.single_gpu or args.debug or not dist.is_initialized() 81 | if not_distributed: 82 | sampler = None 83 | else: 84 | if deterministic: 85 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False) 86 | else: 87 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) 88 | if deterministic or not not_distributed: 89 | loader = DataLoader( 90 | dataset, batch_size=batch_size, shuffle=False, sampler=sampler, num_workers=num_workers, drop_last=False 91 | ) 92 | else: 93 | loader = DataLoader( 94 | dataset, batch_size=batch_size, shuffle=True, sampler=sampler, num_workers=num_workers, drop_last=False 95 | ) 96 | while True: 97 | yield from loader 98 | 99 | 100 | 101 | def load_ref_data(args, ref_img_path=None): 102 | if ref_img_path is None: 103 | ref_img_path = args.ref_img_path 104 | with bf.BlobFile(ref_img_path, "rb") as f: 105 | pil_image = Image.open(f) 106 | pil_image.load() 107 | 108 | pil_image = pil_image.convert("RGB") 109 | 110 | arr = center_crop_arr(pil_image, args.image_size) 111 | 112 | arr = arr.astype(np.float32) / 127.5 - 1 113 | 114 | arr = np.repeat(np.expand_dims(np.transpose(arr, [2, 0, 1]), axis=0), args.batch_size, axis=0) 115 | 116 | kwargs = {} 117 | 118 | kwargs["ref_img"] = torch.tensor(arr) 119 | 120 | return kwargs 121 | 122 | def _list_image_files_recursively(data_dir): 123 | results = [] 124 | for entry in sorted(bf.listdir(data_dir)): 125 | full_path = bf.join(data_dir, entry) 126 | ext = entry.split(".")[-1] 127 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 128 | results.append(full_path) 129 | elif bf.isdir(full_path): 130 | results.extend(_list_image_files_recursively(full_path)) 131 | return results 132 | 133 | 134 | class ImageDataset(Dataset): 135 | def __init__( 136 | self, 137 | resolution, 138 | image_paths, 139 | classes=None, 140 | random_crop=False, 141 | random_flip=True, 142 | ): 143 | super().__init__() 144 | self.resolution = resolution 145 | self.images = image_paths 146 | self.classes = None if classes is None else classes 147 | self.random_crop = random_crop 148 | self.random_flip = random_flip 149 | 150 | def __len__(self): 151 | return len(self.images) 152 | 153 | def __getitem__(self, idx): 154 | path = self.images[idx] 155 | with bf.BlobFile(path, "rb") as f: 156 | pil_image = Image.open(f) 157 | pil_image.load() 158 | pil_image = pil_image.convert("RGB") 159 | 160 | if self.random_crop: 161 | arr = random_crop_arr(pil_image, self.resolution) 162 | else: 163 | arr = center_crop_arr(pil_image, self.resolution) 164 | 165 | if self.random_flip and random.random() < 0.5: 166 | arr = arr[:, ::-1] 167 | 168 | arr = arr.astype(np.float32) / 127.5 - 1 169 | 170 | out_dict = {} 171 | if self.classes is not None: 172 | out_dict["y"] = np.array(self.classes[idx], dtype=np.int64) 173 | return np.transpose(arr, [2, 0, 1]), out_dict 174 | 175 | def center_crop_arr(pil_image, image_size): 176 | # We are not on a new enough PIL to support the `reducing_gap` 177 | # argument, which uses BOX downsampling at powers of two first. 178 | # Thus, we do it by hand to improve downsample quality. 179 | while min(*pil_image.size) >= 2 * image_size: 180 | pil_image = pil_image.resize( 181 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 182 | ) 183 | 184 | scale = image_size / min(*pil_image.size) 185 | pil_image = pil_image.resize( 186 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 187 | ) 188 | 189 | arr = np.array(pil_image) 190 | crop_y = (arr.shape[0] - image_size) // 2 191 | crop_x = (arr.shape[1] - image_size) // 2 192 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 193 | 194 | 195 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.85, max_crop_frac=0.95): 196 | if min(*pil_image.size) != max(*pil_image.size): 197 | min_crop_frac = 1.0 198 | max_crop_frac = 1.0 199 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 200 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 201 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 202 | 203 | # We are not on a new enough PIL to support the `reducing_gap` 204 | # argument, which uses BOX downsampling at powers of two first. 205 | # Thus, we do it by hand to improve downsample quality. 
206 | while min(*pil_image.size) >= 2 * smaller_dim_size: 207 | pil_image = pil_image.resize( 208 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 209 | ) 210 | 211 | scale = smaller_dim_size / min(*pil_image.size) 212 | pil_image = pil_image.resize( 213 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 214 | ) 215 | 216 | arr = np.array(pil_image) 217 | 218 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 219 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 220 | 221 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 222 | -------------------------------------------------------------------------------- /sdg/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | def logkv_mean(key, val): 221 | """ 222 | The same as logkv(), but if called many times, values averaged. 223 | """ 224 | get_current().logkv_mean(key, val) 225 | 226 | 227 | def logkvs(d): 228 | """ 229 | Log a dictionary of key-value pairs 230 | """ 231 | for (k, v) in d.items(): 232 | logkv(k, v) 233 | 234 | 235 | def dumpkvs(): 236 | """ 237 | Write all of the diagnostics from the current iteration 238 | """ 239 | return get_current().dumpkvs() 240 | 241 | 242 | def getkvs(): 243 | return get_current().name2val 244 | 245 | 246 | def log(*args, level=INFO): 247 | """ 248 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 249 | """ 250 | get_current().log(*args, level=level) 251 | 252 | 253 | def debug(*args): 254 | log(*args, level=DEBUG) 255 | 256 | 257 | def info(*args): 258 | log(*args, level=INFO) 259 | 260 | 261 | def warn(*args): 262 | log(*args, level=WARN) 263 | 264 | 265 | def error(*args): 266 | log(*args, level=ERROR) 267 | 268 | 269 | def set_level(level): 270 | """ 271 | Set logging threshold on current logger. 272 | """ 273 | get_current().set_level(level) 274 | 275 | 276 | def set_comm(comm): 277 | get_current().set_comm(comm) 278 | 279 | 280 | def get_dir(): 281 | """ 282 | Get directory that log files are being written to. 
283 | will be None if there is no output directory (i.e., if you didn't call start) 284 | """ 285 | return get_current().get_dir() 286 | 287 | 288 | record_tabular = logkv 289 | dump_tabular = dumpkvs 290 | 291 | 292 | @contextmanager 293 | def profile_kv(scopename): 294 | logkey = "wait_" + scopename 295 | tstart = time.time() 296 | try: 297 | yield 298 | finally: 299 | get_current().name2val[logkey] += time.time() - tstart 300 | 301 | 302 | def profile(n): 303 | """ 304 | Usage: 305 | @profile("my_func") 306 | def my_func(): code 307 | """ 308 | 309 | def decorator_with_name(func): 310 | def func_wrapper(*args, **kwargs): 311 | with profile_kv(n): 312 | return func(*args, **kwargs) 313 | 314 | return func_wrapper 315 | 316 | return decorator_with_name 317 | 318 | 319 | # ================================================================ 320 | # Backend 321 | # ================================================================ 322 | 323 | 324 | def get_current(): 325 | if Logger.CURRENT is None: 326 | _configure_default_logger() 327 | 328 | return Logger.CURRENT 329 | 330 | 331 | class Logger(object): 332 | DEFAULT = None # A logger with no output files. (See right below class definition) 333 | # So that you can still log to the terminal without setting up any output files 334 | CURRENT = None # Current logger being used by the free functions above 335 | 336 | def __init__(self, dir, output_formats, comm=None): 337 | self.name2val = defaultdict(float) # values this iteration 338 | self.name2cnt = defaultdict(int) 339 | self.level = INFO 340 | self.dir = dir 341 | self.output_formats = output_formats 342 | self.comm = comm 343 | 344 | # Logging API, forwarded 345 | # ---------------------------------------- 346 | def logkv(self, key, val): 347 | self.name2val[key] = val 348 | 349 | def logkv_mean(self, key, val): 350 | oldval, cnt = self.name2val[key], self.name2cnt[key] 351 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 352 | self.name2cnt[key] = cnt + 1 353 | 354 | def dumpkvs(self): 355 | if self.comm is None: 356 | d = self.name2val 357 | else: 358 | d = mpi_weighted_mean( 359 | self.comm, 360 | { 361 | name: (val, self.name2cnt.get(name, 1)) 362 | for (name, val) in self.name2val.items() 363 | }, 364 | ) 365 | if self.comm.rank != 0: 366 | d["dummy"] = 1 # so we don't get a warning about empty dict 367 | out = d.copy() # Return the dict for unit testing purposes 368 | for fmt in self.output_formats: 369 | if isinstance(fmt, KVWriter): 370 | fmt.writekvs(d) 371 | self.name2val.clear() 372 | self.name2cnt.clear() 373 | return out 374 | 375 | def log(self, *args, level=INFO): 376 | if self.level <= level: 377 | self._do_log(args) 378 | 379 | # Configuration 380 | # ---------------------------------------- 381 | def set_level(self, level): 382 | self.level = level 383 | 384 | def set_comm(self, comm): 385 | self.comm = comm 386 | 387 | def get_dir(self): 388 | return self.dir 389 | 390 | def close(self): 391 | for fmt in self.output_formats: 392 | fmt.close() 393 | 394 | # Misc 395 | # ---------------------------------------- 396 | def _do_log(self, args): 397 | for fmt in self.output_formats: 398 | if isinstance(fmt, SeqWriter): 399 | fmt.writeseq(map(str, args)) 400 | 401 | 402 | def get_rank_without_mpi_import(): 403 | # check environment variables here instead of importing mpi4py 404 | # to avoid calling MPI_Init() when this module is imported 405 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 406 | if varname in os.environ: 407 | return int(os.environ[varname]) 
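    # Fall back to rank 0 if neither PMI_RANK nor OMPI_COMM_WORLD_RANK is set
    # (e.g. single-process runs launched without mpiexec).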
408 | return 0 409 | 410 | 411 | def mpi_weighted_mean(comm, local_name2valcount): 412 | """ 413 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 414 | Perform a weighted average over dicts that are each on a different node 415 | Input: local_name2valcount: dict mapping key -> (value, count) 416 | Returns: key -> mean 417 | """ 418 | all_name2valcount = comm.gather(local_name2valcount) 419 | if comm.rank == 0: 420 | name2sum = defaultdict(float) 421 | name2count = defaultdict(float) 422 | for n2vc in all_name2valcount: 423 | for (name, (val, count)) in n2vc.items(): 424 | try: 425 | val = float(val) 426 | except ValueError: 427 | if comm.rank == 0: 428 | warnings.warn( 429 | "WARNING: tried to compute mean on non-float {}={}".format( 430 | name, val 431 | ) 432 | ) 433 | else: 434 | name2sum[name] += val * count 435 | name2count[name] += count 436 | return {name: name2sum[name] / name2count[name] for name in name2sum} 437 | else: 438 | return {} 439 | 440 | 441 | def configure(dir='./checkpoints/', exp_name='debug', format_strs=None, comm=None, log_suffix=""): 442 | """ 443 | If comm is provided, average all numerical stats across that comm 444 | """ 445 | dir = osp.join( 446 | dir, 447 | exp_name, # + '_' + datetime.datetime.now().strftime("%m-%d-%H-%M-%S-%f"), 448 | ) 449 | assert isinstance(dir, str) 450 | dir = os.path.expanduser(dir) 451 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 452 | 453 | rank = get_rank_without_mpi_import() 454 | if rank > 0: 455 | log_suffix = log_suffix + "-rank%03i" % rank 456 | 457 | if format_strs is None: 458 | if rank == 0: 459 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 460 | else: 461 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 462 | format_strs = filter(None, format_strs) 463 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 464 | 465 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 466 | if output_formats: 467 | log("Logging to %s" % dir) 468 | 469 | 470 | def _configure_default_logger(): 471 | configure() 472 | Logger.DEFAULT = Logger.CURRENT 473 | 474 | 475 | def reset(): 476 | if Logger.CURRENT is not Logger.DEFAULT: 477 | Logger.CURRENT.close() 478 | Logger.CURRENT = Logger.DEFAULT 479 | log("Reset logger") 480 | 481 | 482 | @contextmanager 483 | def scoped_configure(dir=None, format_strs=None, comm=None): 484 | prevlogger = Logger.CURRENT 485 | configure(dir=dir, format_strs=format_strs, comm=comm) 486 | try: 487 | yield 488 | finally: 489 | Logger.CURRENT.close() 490 | Logger.CURRENT = prevlogger 491 | 492 | -------------------------------------------------------------------------------- /sdg/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, check out LICENSE.md 5 | import datetime 6 | import os 7 | import torch 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from .distributed import master_only, is_master 12 | from .distributed import master_only_print as print 13 | from .distributed import dist_all_reduce_tensor 14 | from .misc import to_cuda 15 | import pdb 16 | 17 | 18 | def get_date_uid(): 19 | """Generate a unique id based on date. 20 | Returns: 21 | str: Return uid string, e.g. 
'20171122171307111552'. 22 | """ 23 | return str(datetime.datetime.now().strftime("%Y_%m%d_%H%M_%S")) 24 | 25 | 26 | def init_logging(exp_name, root_dir='logs', timestamp=False): 27 | r"""Create log directory for storing checkpoints and output images. 28 | 29 | Args: 30 | config_path (str): Path to the configuration file. 31 | logdir (str): Log directory name 32 | Returns: 33 | str: Return log dir 34 | """ 35 | #config_file = os.path.basename(config_path) 36 | if timestamp: 37 | date_uid = get_date_uid() 38 | exp_name = '_'.join([exp_name, date_uid]) 39 | # example: logs/2019_0125_1047_58_spade_cocostuff 40 | #log_file = '_'.join([date_uid, os.path.splitext(config_file)[0]]) 41 | #log_file = os.path.splitext(config_file)[0] 42 | logdir = os.path.join(root_dir, exp_name) 43 | return logdir 44 | 45 | 46 | @master_only 47 | def make_logging_dir(logdir, no_tb=False): 48 | r"""Create the logging directory 49 | 50 | Args: 51 | logdir (str): Log directory name 52 | """ 53 | print('Make folder {}'.format(logdir)) 54 | os.makedirs(logdir, exist_ok=True) 55 | if no_tb: 56 | return None 57 | tensorboard_dir = os.path.join(logdir, 'tensorboard') 58 | os.makedirs(tensorboard_dir, exist_ok=True) 59 | tb_log = SummaryWriter(log_dir=tensorboard_dir) 60 | return tb_log 61 | 62 | def write_tb(tb_log, key, value, step): 63 | if not torch.is_tensor(value): 64 | value = torch.tensor(value) 65 | if not value.is_cuda: 66 | value = to_cuda(value) 67 | value = dist_all_reduce_tensor(value.mean()).item() 68 | if is_master(): 69 | tb_log.add_scalar(key, value, step) 70 | print('%s: %f ' % (key, value), end='') 71 | -------------------------------------------------------------------------------- /sdg/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 
57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /sdg/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, check out LICENSE.md 5 | """Miscellaneous utils.""" 6 | import collections 7 | from collections import OrderedDict 8 | import numpy as np 9 | import random 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from .distributed import get_rank 15 | 16 | string_classes = (str, bytes) 17 | 18 | def set_random_seed(seed, by_rank=False): 19 | r"""Set random seeds for everything. 20 | Args: 21 | seed (int): Random seed. 22 | by_rank (bool): 23 | """ 24 | if by_rank: 25 | seed += get_rank() 26 | print(f"Using random seed {seed}") 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | 33 | 34 | def split_labels(labels, label_lengths): 35 | r"""Split concatenated labels into their parts. 36 | 37 | Args: 38 | labels (torch.Tensor): Labels obtained through concatenation. 39 | label_lengths (OrderedDict): Containing order of labels & their lengths. 40 | 41 | Returns: 42 | 43 | """ 44 | assert isinstance(label_lengths, OrderedDict) 45 | start = 0 46 | outputs = {} 47 | for data_type, length in label_lengths.items(): 48 | end = start + length 49 | if labels.dim() == 5: 50 | outputs[data_type] = labels[:, :, start:end] 51 | elif labels.dim() == 4: 52 | outputs[data_type] = labels[:, start:end] 53 | elif labels.dim() == 3: 54 | outputs[data_type] = labels[start:end] 55 | start = end 56 | return outputs 57 | 58 | 59 | def requires_grad(model, require=True): 60 | r""" Set a model to require gradient or not. 61 | 62 | Args: 63 | model (nn.Module): Neural network model. 64 | require (bool): Whether the network requires gradient or not. 65 | 66 | Returns: 67 | 68 | """ 69 | for p in model.parameters(): 70 | p.requires_grad = require 71 | 72 | 73 | def to_device(data, device): 74 | r"""Move all tensors inside data to device. 75 | 76 | Args: 77 | data (dict, list, or tensor): Input data. 78 | device (str): 'cpu' or 'cuda'. 
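    Example (illustrative): to_device({'image': torch.zeros(2, 3, 64, 64)}, 'cuda')
    returns the dict with the tensor moved to the GPU; strings and other
    non-tensor values are returned unchanged.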
79 | """ 80 | assert device in ['cpu', 'cuda'] 81 | if isinstance(data, torch.Tensor): 82 | data = data.to(torch.device(device)) 83 | return data 84 | elif isinstance(data, collections.abc.Mapping): 85 | return {key: to_device(data[key], device) for key in data} 86 | elif isinstance(data, collections.abc.Sequence) and \ 87 | not isinstance(data, string_classes): 88 | return [to_device(d, device) for d in data] 89 | else: 90 | return data 91 | 92 | 93 | def to_cuda(data): 94 | r"""Move all tensors inside data to gpu. 95 | 96 | Args: 97 | data (dict, list, or tensor): Input data. 98 | """ 99 | return to_device(data, 'cuda') 100 | 101 | 102 | def to_cpu(data): 103 | r"""Move all tensors inside data to cpu. 104 | 105 | Args: 106 | data (dict, list, or tensor): Input data. 107 | """ 108 | return to_device(data, 'cpu') 109 | 110 | 111 | def to_half(data): 112 | r"""Move all floats to half. 113 | 114 | Args: 115 | data (dict, list or tensor): Input data. 116 | """ 117 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data): 118 | data = data.half() 119 | return data 120 | elif isinstance(data, collections.abc.Mapping): 121 | return {key: to_half(data[key]) for key in data} 122 | elif isinstance(data, collections.abc.Sequence) and \ 123 | not isinstance(data, string_classes): 124 | return [to_half(d) for d in data] 125 | else: 126 | return data 127 | 128 | 129 | def to_float(data): 130 | r"""Move all halfs to float. 131 | 132 | Args: 133 | data (dict, list or tensor): Input data. 134 | """ 135 | if isinstance(data, torch.Tensor) and torch.is_floating_point(data): 136 | data = data.float() 137 | return data 138 | elif isinstance(data, collections.abc.Mapping): 139 | return {key: to_float(data[key]) for key in data} 140 | elif isinstance(data, collections.abc.Sequence) and \ 141 | not isinstance(data, string_classes): 142 | return [to_float(d) for d in data] 143 | else: 144 | return data 145 | 146 | 147 | def to_channels_last(data): 148 | r"""Move all data to ``channels_last`` format. 149 | 150 | Args: 151 | data (dict, list or tensor): Input data. 152 | """ 153 | if isinstance(data, torch.Tensor): 154 | if data.dim() == 4: 155 | data = data.to(memory_format=torch.channels_last) 156 | return data 157 | elif isinstance(data, collections.abc.Mapping): 158 | return {key: to_channels_last(data[key]) for key in data} 159 | elif isinstance(data, collections.abc.Sequence) and \ 160 | not isinstance(data, string_classes): 161 | return [to_channels_last(d) for d in data] 162 | else: 163 | return data 164 | 165 | 166 | def slice_tensor(data, start, end): 167 | r"""Slice all tensors from start to end. 168 | Args: 169 | data (dict, list or tensor): Input data. 170 | """ 171 | if isinstance(data, torch.Tensor): 172 | data = data[start:end] 173 | return data 174 | elif isinstance(data, collections.abc.Mapping): 175 | return {key: slice_tensor(data[key], start, end) for key in data} 176 | elif isinstance(data, collections.abc.Sequence) and \ 177 | not isinstance(data, string_classes): 178 | return [slice_tensor(d, start, end) for d in data] 179 | else: 180 | return data 181 | 182 | 183 | def get_and_setattr(cfg, name, default): 184 | r"""Get attribute with default choice. If attribute does not exist, set it 185 | using the default value. 186 | 187 | Args: 188 | cfg (obj) : Config options. 189 | name (str) : Attribute name. 190 | default (obj) : Default attribute. 191 | 192 | Returns: 193 | (obj) : Desired attribute. 
194 | """ 195 | if not hasattr(cfg, name) or name not in cfg.__dict__: 196 | setattr(cfg, name, default) 197 | return getattr(cfg, name) 198 | 199 | 200 | def get_nested_attr(cfg, attr_name, default): 201 | r"""Iteratively try to get the attribute from cfg. If not found, return 202 | default. 203 | 204 | Args: 205 | cfg (obj): Config file. 206 | attr_name (str): Attribute name (e.g. XXX.YYY.ZZZ). 207 | default (obj): Default return value for the attribute. 208 | 209 | Returns: 210 | (obj): Attribute value. 211 | """ 212 | names = attr_name.split('.') 213 | atr = cfg 214 | for name in names: 215 | if not hasattr(atr, name): 216 | return default 217 | atr = getattr(atr, name) 218 | return atr 219 | 220 | 221 | def gradient_norm(model): 222 | r"""Return the gradient norm of model. 223 | 224 | Args: 225 | model (PyTorch module): Your network. 226 | 227 | """ 228 | total_norm = 0 229 | for p in model.parameters(): 230 | if p.grad is not None: 231 | param_norm = p.grad.norm(2) 232 | total_norm += param_norm.item() ** 2 233 | return total_norm ** (1. / 2) 234 | 235 | 236 | def random_shift(x, offset=0.05, mode='bilinear', padding_mode='reflection'): 237 | r"""Randomly shift the input tensor. 238 | 239 | Args: 240 | x (4D tensor): The input batch of images. 241 | offset (int): The maximum offset ratio that is between [0, 1]. 242 | The maximum shift is offset * image_size for each direction. 243 | mode (str): The resample mode for 'F.grid_sample'. 244 | padding_mode (str): The padding mode for 'F.grid_sample'. 245 | 246 | Returns: 247 | x (4D tensor) : The randomly shifted image. 248 | """ 249 | assert x.dim() == 4, "Input must be a 4D tensor." 250 | batch_size = x.size(0) 251 | theta = torch.eye(2, 3, device=x.device).unsqueeze(0).repeat( 252 | batch_size, 1, 1) 253 | theta[:, :, 2] = 2 * offset * torch.rand(batch_size, 2) - offset 254 | grid = F.affine_grid(theta, x.size()) 255 | x = F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode) 256 | return x 257 | 258 | 259 | # def truncated_gaussian(threshold, size, seed=None, device=None): 260 | # r"""Apply the truncated gaussian trick to trade diversity for quality 261 | # 262 | # Args: 263 | # threshold (float): Truncation threshold. 264 | # size (list of integer): Tensor size. 265 | # seed (int): Random seed. 266 | # device: 267 | # """ 268 | # state = None if seed is None else np.random.RandomState(seed) 269 | # values = truncnorm.rvs(-threshold, threshold, 270 | # size=size, random_state=state) 271 | # return torch.tensor(values, device=device).float() 272 | 273 | 274 | def apply_imagenet_normalization(input): 275 | r"""Normalize using ImageNet mean and std. 276 | 277 | Args: 278 | input (4D tensor NxCxHxW): The input images, assuming to be [-1, 1]. 279 | 280 | Returns: 281 | Normalized inputs using the ImageNet normalization. 282 | """ 283 | # normalize the input back to [0, 1] 284 | normalized_input = (input + 1) / 2 285 | # normalize the input using the ImageNet mean and std 286 | mean = normalized_input.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 287 | std = normalized_input.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 288 | output = (normalized_input - mean) / std 289 | return output 290 | -------------------------------------------------------------------------------- /sdg/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 
3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 
112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /sdg/parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | import os 5 | import argparse 6 | 7 | from sdg.resample import create_named_schedule_sampler 8 | from sdg.script_util import ( 9 | model_and_diffusion_defaults, 10 | args_to_dict, 11 | add_dict_to_argparser, 12 | ) 13 | 14 | 15 | def create_argparser(): 16 | parser = argparse.ArgumentParser() 17 | # basic 18 | parser.add_argument('--exp_name', required=True) 19 | parser.add_argument('--resume_checkpoint', default=None) 20 | parser.add_argument('--local_rank', type=int, default=os.getenv('LOCAL_RANK', 0)) 21 | parser.add_argument('--single_gpu', action='store_true') 22 | parser.add_argument('--seed', type=int, default=2) 23 | parser.add_argument('--randomized_seed', action='store_true') 24 | parser.add_argument('--debug', action='store_true') 25 | parser.add_argument('--logdir', type=str) 26 | 27 | # data 28 | parser.add_argument('--data_dir', type=str, default='') 29 | parser.add_argument('--image_size', type=int, default=256) 30 | parser.add_argument('--deterministic', type=lambda x: (str(x).lower() == 'true'), default=False) 31 | parser.add_argument('--random_crop', type=lambda x: (str(x).lower() == 'true'), default=False) 32 | parser.add_argument('--random_flip', type=lambda x: (str(x).lower() == 'true'), default=True) 33 | parser.add_argument('--num_workers', type=int, default=4) 34 | parser.add_argument('--clip_model', type=str, default='ViT-B-16') 35 | 36 | # model 37 | parser.add_argument('--class_cond', type=lambda x: (str(x).lower() == 'true'), default=False) 38 | parser.add_argument('--text_cond', type=lambda x: (str(x).lower() == 'true'), default=False) 39 | parser.add_argument('--num_channels', type=int, default=128) 40 | parser.add_argument('--num_res_blocks', type=int, default=2) 41 | parser.add_argument('--num_heads', type=int, default=4) 42 | parser.add_argument('--num_heads_upsample', type=int, default=-1) 43 | parser.add_argument('--num_head_channels', type=int, default=-1) 44 | parser.add_argument('--attention_resolutions', type=str, default="16,8") 45 | parser.add_argument('--channel_mult', default="") 46 | parser.add_argument('--dropout', type=float, default=0.0) 47 | parser.add_argument('--use_checkpoint', type=lambda x: (str(x).lower() == 'true'), default=False) 48 | parser.add_argument('--use_scale_shift_norm', type=lambda x: (str(x).lower() == 'true'), default=True) 49 | parser.add_argument('--resblock_updown', type=lambda x: (str(x).lower() == 'true'), default=False) 50 | parser.add_argument('--use_new_attention_order', type=lambda x: (str(x).lower() == 'true'), default=False) 51 | 52 | # diffusion 53 | parser.add_argument('--learn_sigma', type=lambda x: (str(x).lower() == 'true'), default=False) 54 | parser.add_argument('--diffusion_steps', type=int, default=1000) 55 | parser.add_argument('--noise_schedule', type=str, default='linear') 56 | parser.add_argument('--timestep_respacing', default='') 57 | parser.add_argument('--use_kl', type=lambda x: (str(x).lower() == 'true'), default=False) 58 | parser.add_argument('--predict_xstart', type=lambda x: (str(x).lower() == 'true'), default=False) 59 | parser.add_argument('--rescale_timesteps',type=lambda x: (str(x).lower() == 'true'), default=False) 60 | parser.add_argument('--rescale_learned_sigmas', type=lambda x: (str(x).lower() == 'true'), default=False) 61 | 62 | # classifier 63 | parser.add_argument('--classifier_use_fp16', type=lambda x: (str(x).lower() == 'true'), default=False) 64 | parser.add_argument('--classifier_width', type=int, default=128) 65 | parser.add_argument('--classifier_depth', 
type=int, default=2) 66 | parser.add_argument('--classifier_attention_resolutions', type=str, default="32,16,8") 67 | parser.add_argument('--classifier_use_scale_shift_norm', type=lambda x: (str(x).lower() == 'true'), default=True) 68 | parser.add_argument('--classifier_resblock_updown', type=lambda x: (str(x).lower() == 'true'), default=True) 69 | parser.add_argument('--classifier_pool', type=str, default="attention") 70 | parser.add_argument('--num_classes', type=int, default=1000) 71 | 72 | # sr 73 | parser.add_argument('--large_size', type=int, default=256) 74 | parser.add_argument('--small_size', type=int, default=64) 75 | 76 | # train 77 | parser.add_argument('--batch_size', type=int, default=32) 78 | parser.add_argument('--microbatch', type=int, default=-1) 79 | parser.add_argument('--schedule_sampler', type=str, default='uniform') 80 | parser.add_argument('--lr', type=float, default=1e-4) 81 | parser.add_argument('--weight_decay', type=float, default=0.0) 82 | parser.add_argument('--lr_anneal_steps', type=int, default=0) 83 | parser.add_argument('--use_fp16', type=lambda x: (str(x).lower() == 'true'), default=False) 84 | parser.add_argument('--fp16_scale_growth', type=float, default=1e-3) 85 | parser.add_argument('--fp16_hyperparams', type=str, default='openai') 86 | parser.add_argument('--anneal_lr', type=lambda x: (str(x).lower() == 'true'), default=False) 87 | parser.add_argument('--iterations', type=int, default=500000) 88 | 89 | # save 90 | parser.add_argument('--ema_rate', default='0.9999') 91 | parser.add_argument('--log_interval', type=int, default=10) 92 | parser.add_argument('--save_interval', type=int, default=10000) 93 | parser.add_argument('--eval_interval', type=int, default=10) 94 | 95 | # inference 96 | parser.add_argument('--model_path', type=str, default='') 97 | parser.add_argument('--use_ddim', type=lambda x: (str(x).lower() == 'true'), default=False) 98 | parser.add_argument('--clip_denoised', type=lambda x: (str(x).lower() == 'true'), default=True) 99 | parser.add_argument('--text_weight', type=float, default=1.0) 100 | parser.add_argument('--image_weight', type=float, default=1.0) 101 | parser.add_argument('--image_loss', type=str, default='semantic') 102 | parser.add_argument('--text_instruction_file', type=str, default='ref/ffhq_instructions.txt') 103 | parser.add_argument('--clip_path', type=str, default='') 104 | 105 | 106 | # train classifier/clip 107 | parser.add_argument('--noised', type=lambda x: (str(x).lower() == 'true'), default=True) 108 | parser.add_argument('--finetune_clip_layer', type=str, default='all') 109 | 110 | # superres 111 | parser.add_argument('--base_name', type=str, default='') 112 | 113 | 114 | return parser 115 | 116 | -------------------------------------------------------------------------------- /sdg/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
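    Example (illustrative): create_named_schedule_sampler("uniform", diffusion)
    returns a UniformSampler; "linear", "half", and "loss-second-moment" select
    the other samplers defined below.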
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == 'linear': 18 | return LinearSampler(diffusion) 19 | elif name == 'half': 20 | return HalfSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | 32 | By default, samplers perform unbiased importance sampling, in which the 33 | objective's mean is unchanged. 34 | However, subclasses may override sample() to change how the resampled 35 | terms are reweighted, allowing for actual changes in the objective. 36 | """ 37 | 38 | @abstractmethod 39 | def weights(self): 40 | """ 41 | Get a numpy array of weights, one per diffusion step. 42 | 43 | The weights needn't be normalized, but must be positive. 44 | """ 45 | 46 | def sample(self, batch_size, device): 47 | """ 48 | Importance-sample timesteps for a batch. 49 | 50 | :param batch_size: the number of timesteps. 51 | :param device: the torch device to save to. 52 | :return: a tuple (timesteps, weights): 53 | - timesteps: a tensor of timestep indices. 54 | - weights: a tensor of weights to scale the resulting losses. 55 | """ 56 | w = self.weights() 57 | p = w / np.sum(w) 58 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 59 | indices = th.from_numpy(indices_np).long().to(device) 60 | weights_np = 1 / (len(p) * p[indices_np]) 61 | weights = th.from_numpy(weights_np).float().to(device) 62 | return indices, weights 63 | 64 | 65 | class UniformSampler(ScheduleSampler): 66 | def __init__(self, diffusion): 67 | self.diffusion = diffusion 68 | self._weights = np.ones([diffusion.num_timesteps]) 69 | 70 | def weights(self): 71 | return self._weights 72 | 73 | class LinearSampler(ScheduleSampler): 74 | # higher weight on noisy images and lower weight on clean images 75 | def __init__(self, diffusion): 76 | self.diffusion = diffusion 77 | self._weights = np.linspace(1, 10, diffusion.num_timesteps) 78 | 79 | def weights(self): 80 | return self._weights 81 | 82 | class HalfSampler(ScheduleSampler): 83 | def __init__(self, diffusion): 84 | self.diffusion = diffusion 85 | self._weights = np.ones([diffusion.num_timesteps]) 86 | self._weights[diffusion.num_timesteps//2:] = 0 87 | 88 | def weights(self): 89 | return self._weights 90 | 91 | 92 | 93 | 94 | class LossAwareSampler(ScheduleSampler): 95 | def update_with_local_losses(self, local_ts, local_losses): 96 | """ 97 | Update the reweighting using losses from a model. 98 | 99 | Call this method from each rank with a batch of timesteps and the 100 | corresponding losses for each of those timesteps. 101 | This method will perform synchronization to make sure all of the ranks 102 | maintain the exact same reweighting. 103 | 104 | :param local_ts: an integer Tensor of timesteps. 105 | :param local_losses: a 1D Tensor of losses. 106 | """ 107 | batch_sizes = [ 108 | th.tensor([0], dtype=th.int32, device=local_ts.device) 109 | for _ in range(dist.get_world_size()) 110 | ] 111 | dist.all_gather( 112 | batch_sizes, 113 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 114 | ) 115 | 116 | # Pad all_gather batches to be the maximum batch size. 
117 | batch_sizes = [x.item() for x in batch_sizes] 118 | max_bs = max(batch_sizes) 119 | 120 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 121 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 122 | dist.all_gather(timestep_batches, local_ts) 123 | dist.all_gather(loss_batches, local_losses) 124 | timesteps = [ 125 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 126 | ] 127 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 128 | self.update_with_all_losses(timesteps, losses) 129 | 130 | @abstractmethod 131 | def update_with_all_losses(self, ts, losses): 132 | """ 133 | Update the reweighting using losses from a model. 134 | 135 | Sub-classes should override this method to update the reweighting 136 | using losses from the model. 137 | 138 | This method directly updates the reweighting without synchronizing 139 | between workers. It is called by update_with_local_losses from all 140 | ranks with identical arguments. Thus, it should have deterministic 141 | behavior to maintain state across workers. 142 | 143 | :param ts: a list of int timesteps. 144 | :param losses: a list of float losses, one per timestep. 145 | """ 146 | 147 | 148 | class LossSecondMomentResampler(LossAwareSampler): 149 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 150 | self.diffusion = diffusion 151 | self.history_per_term = history_per_term 152 | self.uniform_prob = uniform_prob 153 | self._loss_history = np.zeros( 154 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 155 | ) 156 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 157 | 158 | def weights(self): 159 | if not self._warmed_up(): 160 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 161 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 162 | weights /= np.sum(weights) 163 | weights *= 1 - self.uniform_prob 164 | weights += self.uniform_prob / len(weights) 165 | return weights 166 | 167 | def update_with_all_losses(self, ts, losses): 168 | for t, loss in zip(ts, losses): 169 | if self._loss_counts[t] == self.history_per_term: 170 | # Shift out the oldest loss term. 171 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 172 | self._loss_history[t, -1] = loss 173 | else: 174 | self._loss_history[t, self._loss_counts[t]] = loss 175 | self._loss_counts[t] += 1 176 | 177 | def _warmed_up(self): 178 | return (self._loss_counts == self.history_per_term).all() 179 | -------------------------------------------------------------------------------- /sdg/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 
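    For example (illustrative), "ddim25" with 1000 original steps keeps the 25
    timesteps 0, 40, 80, ..., 960.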
19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 
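    Example (illustrative): constructing SpacedDiffusion with
    use_timesteps=space_timesteps(1000, "250") keeps 250 of the 1000 base
    timesteps and recomputes the betas so the shortened chain matches the
    retained alphas_cumprod values.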
70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = (new_ts.float() * (1000.0 / self.original_num_steps)).long() 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /sdg/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | from .clip_guidance import CLIP_gd 8 | 9 | NUM_CLASSES = 1000 10 | 11 | 12 | def diffusion_defaults(): 13 | """ 14 | Defaults for image and classifier training. 15 | """ 16 | return dict( 17 | learn_sigma=False, 18 | diffusion_steps=1000, 19 | noise_schedule="linear", 20 | timestep_respacing="", 21 | use_kl=False, 22 | predict_xstart=False, 23 | rescale_timesteps=False, 24 | rescale_learned_sigmas=False, 25 | ) 26 | 27 | 28 | def classifier_defaults(): 29 | """ 30 | Defaults for classifier models. 
31 | """ 32 | return dict( 33 | image_size=64, 34 | classifier_use_fp16=False, 35 | classifier_width=128, 36 | classifier_depth=2, 37 | classifier_attention_resolutions="32,16,8", # 16 38 | classifier_use_scale_shift_norm=True, # False 39 | classifier_resblock_updown=True, # False 40 | classifier_pool="attention", 41 | num_classes=1000, 42 | ) 43 | 44 | 45 | def model_and_diffusion_defaults(): 46 | """ 47 | Defaults for image training. 48 | """ 49 | res = dict( 50 | image_size=64, 51 | num_channels=128, 52 | num_res_blocks=2, 53 | num_heads=4, 54 | num_heads_upsample=-1, 55 | num_head_channels=-1, 56 | attention_resolutions="16,8", 57 | channel_mult="", 58 | dropout=0.0, 59 | class_cond=False, 60 | text_cond=False, 61 | use_checkpoint=False, 62 | use_scale_shift_norm=True, 63 | resblock_updown=False, 64 | use_fp16=False, 65 | use_new_attention_order=False, 66 | ) 67 | res.update(diffusion_defaults()) 68 | return res 69 | 70 | 71 | def classifier_and_diffusion_defaults(): 72 | res = classifier_defaults() 73 | res.update(diffusion_defaults()) 74 | return res 75 | 76 | 77 | def create_model_and_diffusion( 78 | image_size, 79 | class_cond, 80 | text_cond, 81 | learn_sigma, 82 | num_channels, 83 | num_res_blocks, 84 | channel_mult, 85 | num_heads, 86 | num_head_channels, 87 | num_heads_upsample, 88 | attention_resolutions, 89 | dropout, 90 | diffusion_steps, 91 | noise_schedule, 92 | timestep_respacing, 93 | use_kl, 94 | predict_xstart, 95 | rescale_timesteps, 96 | rescale_learned_sigmas, 97 | use_checkpoint, 98 | use_scale_shift_norm, 99 | resblock_updown, 100 | use_fp16, 101 | use_new_attention_order, 102 | ): 103 | model = create_model( 104 | image_size, 105 | num_channels, 106 | num_res_blocks, 107 | channel_mult=channel_mult, 108 | learn_sigma=learn_sigma, 109 | class_cond=class_cond, 110 | text_cond=text_cond, 111 | use_checkpoint=use_checkpoint, 112 | attention_resolutions=attention_resolutions, 113 | num_heads=num_heads, 114 | num_head_channels=num_head_channels, 115 | num_heads_upsample=num_heads_upsample, 116 | use_scale_shift_norm=use_scale_shift_norm, 117 | dropout=dropout, 118 | resblock_updown=resblock_updown, 119 | use_fp16=use_fp16, 120 | use_new_attention_order=use_new_attention_order, 121 | ) 122 | diffusion = create_gaussian_diffusion( 123 | steps=diffusion_steps, 124 | learn_sigma=learn_sigma, 125 | noise_schedule=noise_schedule, 126 | use_kl=use_kl, 127 | predict_xstart=predict_xstart, 128 | rescale_timesteps=rescale_timesteps, 129 | rescale_learned_sigmas=rescale_learned_sigmas, 130 | timestep_respacing=timestep_respacing, 131 | ) 132 | return model, diffusion 133 | 134 | 135 | def create_model( 136 | image_size, 137 | num_channels, 138 | num_res_blocks, 139 | channel_mult="", 140 | learn_sigma=False, 141 | class_cond=False, 142 | text_cond=False, 143 | use_checkpoint=False, 144 | attention_resolutions="16", 145 | num_heads=1, 146 | num_head_channels=-1, 147 | num_heads_upsample=-1, 148 | use_scale_shift_norm=False, 149 | dropout=0, 150 | resblock_updown=False, 151 | use_fp16=False, 152 | use_new_attention_order=False, 153 | ): 154 | if channel_mult == "": 155 | if image_size == 512: 156 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 157 | elif image_size == 256: 158 | channel_mult = (1, 1, 2, 2, 4, 4) 159 | elif image_size == 128: 160 | channel_mult = (1, 1, 2, 3, 4) 161 | elif image_size == 64: 162 | channel_mult = (1, 2, 3, 4) 163 | else: 164 | raise ValueError(f"unsupported image size: {image_size}") 165 | else: 166 | channel_mult = tuple(int(ch_mult) for ch_mult in 
channel_mult.split(",")) 167 | 168 | attention_ds = [] 169 | for res in attention_resolutions.split(","): 170 | attention_ds.append(image_size // int(res)) 171 | 172 | return UNetModel( 173 | image_size=image_size, 174 | in_channels=3, 175 | model_channels=num_channels, 176 | out_channels=(3 if not learn_sigma else 6), 177 | num_res_blocks=num_res_blocks, 178 | attention_resolutions=tuple(attention_ds), 179 | dropout=dropout, 180 | channel_mult=channel_mult, 181 | num_classes=(NUM_CLASSES if class_cond else None), 182 | text_cond=text_cond, 183 | use_checkpoint=use_checkpoint, 184 | use_fp16=use_fp16, 185 | num_heads=num_heads, 186 | num_head_channels=num_head_channels, 187 | num_heads_upsample=num_heads_upsample, 188 | use_scale_shift_norm=use_scale_shift_norm, 189 | resblock_updown=resblock_updown, 190 | use_new_attention_order=use_new_attention_order, 191 | ) 192 | 193 | 194 | def create_classifier_and_diffusion( 195 | image_size, 196 | classifier_use_fp16, 197 | classifier_width, 198 | classifier_depth, 199 | classifier_attention_resolutions, 200 | classifier_use_scale_shift_norm, 201 | classifier_resblock_updown, 202 | classifier_pool, 203 | num_classes, 204 | learn_sigma, 205 | diffusion_steps, 206 | noise_schedule, 207 | timestep_respacing, 208 | use_kl, 209 | predict_xstart, 210 | rescale_timesteps, 211 | rescale_learned_sigmas, 212 | ): 213 | classifier = create_classifier( 214 | image_size, 215 | classifier_use_fp16, 216 | classifier_width, 217 | classifier_depth, 218 | classifier_attention_resolutions, 219 | classifier_use_scale_shift_norm, 220 | classifier_resblock_updown, 221 | classifier_pool, 222 | num_classes, 223 | ) 224 | diffusion = create_gaussian_diffusion( 225 | steps=diffusion_steps, 226 | learn_sigma=learn_sigma, 227 | noise_schedule=noise_schedule, 228 | use_kl=use_kl, 229 | predict_xstart=predict_xstart, 230 | rescale_timesteps=rescale_timesteps, 231 | rescale_learned_sigmas=rescale_learned_sigmas, 232 | timestep_respacing=timestep_respacing, 233 | ) 234 | return classifier, diffusion 235 | 236 | 237 | def create_clip_and_diffusion( 238 | args, 239 | image_size, 240 | classifier_use_fp16, 241 | classifier_width, 242 | classifier_depth, 243 | classifier_attention_resolutions, 244 | classifier_use_scale_shift_norm, 245 | classifier_resblock_updown, 246 | classifier_pool, 247 | num_classes, 248 | learn_sigma, 249 | diffusion_steps, 250 | noise_schedule, 251 | timestep_respacing, 252 | use_kl, 253 | predict_xstart, 254 | rescale_timesteps, 255 | rescale_learned_sigmas, 256 | ): 257 | clip = create_clip(args) 258 | diffusion = create_gaussian_diffusion( 259 | steps=diffusion_steps, 260 | learn_sigma=learn_sigma, 261 | noise_schedule=noise_schedule, 262 | use_kl=use_kl, 263 | predict_xstart=predict_xstart, 264 | rescale_timesteps=rescale_timesteps, 265 | rescale_learned_sigmas=rescale_learned_sigmas, 266 | timestep_respacing=timestep_respacing, 267 | ) 268 | return clip, diffusion 269 | 270 | 271 | def create_classifier( 272 | image_size, 273 | classifier_use_fp16, 274 | classifier_width, 275 | classifier_depth, 276 | classifier_attention_resolutions, 277 | classifier_use_scale_shift_norm, 278 | classifier_resblock_updown, 279 | classifier_pool, 280 | num_classes, 281 | ): 282 | if image_size == 512: 283 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 284 | elif image_size == 256: 285 | channel_mult = (1, 1, 2, 2, 4, 4) 286 | elif image_size == 128: 287 | channel_mult = (1, 1, 2, 3, 4) 288 | elif image_size == 64: 289 | channel_mult = (1, 2, 3, 4) 290 | else: 291 | raise 
ValueError(f"unsupported image size: {image_size}") 292 | 293 | attention_ds = [] 294 | for res in classifier_attention_resolutions.split(","): 295 | attention_ds.append(image_size // int(res)) 296 | 297 | return EncoderUNetModel( 298 | image_size=image_size, 299 | in_channels=3, 300 | model_channels=classifier_width, 301 | out_channels=num_classes, 302 | num_res_blocks=classifier_depth, 303 | attention_resolutions=tuple(attention_ds), 304 | channel_mult=channel_mult, 305 | use_fp16=classifier_use_fp16, 306 | num_head_channels=64, 307 | use_scale_shift_norm=classifier_use_scale_shift_norm, 308 | resblock_updown=classifier_resblock_updown, 309 | pool=classifier_pool, 310 | ) 311 | 312 | def create_clip(args): 313 | return CLIP_gd(args) 314 | 315 | 316 | def sr_model_and_diffusion_defaults(): 317 | res = model_and_diffusion_defaults() 318 | res["large_size"] = 256 319 | res["small_size"] = 64 320 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 321 | for k in res.copy().keys(): 322 | if k not in arg_names: 323 | del res[k] 324 | return res 325 | 326 | 327 | def sr_create_model_and_diffusion( 328 | large_size, 329 | small_size, 330 | class_cond, 331 | text_cond, 332 | learn_sigma, 333 | num_channels, 334 | num_res_blocks, 335 | num_heads, 336 | num_head_channels, 337 | num_heads_upsample, 338 | attention_resolutions, 339 | dropout, 340 | diffusion_steps, 341 | noise_schedule, 342 | timestep_respacing, 343 | use_kl, 344 | predict_xstart, 345 | rescale_timesteps, 346 | rescale_learned_sigmas, 347 | use_checkpoint, 348 | use_scale_shift_norm, 349 | resblock_updown, 350 | use_fp16, 351 | ): 352 | model = sr_create_model( 353 | large_size, 354 | small_size, 355 | num_channels, 356 | num_res_blocks, 357 | learn_sigma=learn_sigma, 358 | class_cond=class_cond, 359 | text_cond=text_cond, 360 | use_checkpoint=use_checkpoint, 361 | attention_resolutions=attention_resolutions, 362 | num_heads=num_heads, 363 | num_head_channels=num_head_channels, 364 | num_heads_upsample=num_heads_upsample, 365 | use_scale_shift_norm=use_scale_shift_norm, 366 | dropout=dropout, 367 | resblock_updown=resblock_updown, 368 | use_fp16=use_fp16, 369 | ) 370 | diffusion = create_gaussian_diffusion( 371 | steps=diffusion_steps, 372 | learn_sigma=learn_sigma, 373 | noise_schedule=noise_schedule, 374 | use_kl=use_kl, 375 | predict_xstart=predict_xstart, 376 | rescale_timesteps=rescale_timesteps, 377 | rescale_learned_sigmas=rescale_learned_sigmas, 378 | timestep_respacing=timestep_respacing, 379 | ) 380 | return model, diffusion 381 | 382 | 383 | def sr_create_model( 384 | large_size, 385 | small_size, 386 | num_channels, 387 | num_res_blocks, 388 | learn_sigma, 389 | class_cond, 390 | text_cond, 391 | use_checkpoint, 392 | attention_resolutions, 393 | num_heads, 394 | num_head_channels, 395 | num_heads_upsample, 396 | use_scale_shift_norm, 397 | dropout, 398 | resblock_updown, 399 | use_fp16, 400 | ): 401 | _ = small_size # hack to prevent unused variable 402 | 403 | if large_size == 512: 404 | channel_mult = (1, 1, 2, 2, 4, 4) 405 | elif large_size == 256: 406 | channel_mult = (1, 1, 2, 2, 4, 4) 407 | elif large_size == 64: 408 | channel_mult = (1, 2, 3, 4) 409 | else: 410 | raise ValueError(f"unsupported large size: {large_size}") 411 | 412 | attention_ds = [] 413 | for res in attention_resolutions.split(","): 414 | attention_ds.append(large_size // int(res)) 415 | 416 | return SuperResModel( 417 | image_size=large_size, 418 | in_channels=3, 419 | model_channels=num_channels, 420 | out_channels=(3 if not 
learn_sigma else 6), 421 | num_res_blocks=num_res_blocks, 422 | attention_resolutions=tuple(attention_ds), 423 | dropout=dropout, 424 | channel_mult=channel_mult, 425 | num_classes=(NUM_CLASSES if class_cond else None), 426 | text_cond=text_cond, 427 | use_checkpoint=use_checkpoint, 428 | num_heads=num_heads, 429 | num_head_channels=num_head_channels, 430 | num_heads_upsample=num_heads_upsample, 431 | use_scale_shift_norm=use_scale_shift_norm, 432 | resblock_updown=resblock_updown, 433 | use_fp16=use_fp16, 434 | ) 435 | 436 | 437 | def create_gaussian_diffusion( 438 | *, 439 | steps=1000, 440 | learn_sigma=False, 441 | sigma_small=False, 442 | noise_schedule="linear", 443 | use_kl=False, 444 | predict_xstart=False, 445 | rescale_timesteps=False, 446 | rescale_learned_sigmas=False, 447 | timestep_respacing="", 448 | ): 449 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 450 | betas1000 = gd.get_named_beta_schedule(noise_schedule, 1000) 451 | if use_kl: 452 | loss_type = gd.LossType.RESCALED_KL 453 | elif rescale_learned_sigmas: 454 | loss_type = gd.LossType.RESCALED_MSE 455 | else: 456 | loss_type = gd.LossType.MSE 457 | if not timestep_respacing: 458 | timestep_respacing = [steps] 459 | return SpacedDiffusion( 460 | use_timesteps=space_timesteps(steps, timestep_respacing), 461 | betas=betas, 462 | model_mean_type=( 463 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 464 | ), 465 | model_var_type=( 466 | ( 467 | gd.ModelVarType.FIXED_LARGE 468 | if not sigma_small 469 | else gd.ModelVarType.FIXED_SMALL 470 | ) 471 | if not learn_sigma 472 | else gd.ModelVarType.LEARNED_RANGE 473 | ), 474 | loss_type=loss_type, 475 | rescale_timesteps=rescale_timesteps, 476 | betas1000=betas1000, 477 | ) 478 | 479 | 480 | def add_dict_to_argparser(parser, default_dict): 481 | for k, v in default_dict.items(): 482 | v_type = type(v) 483 | if v is None: 484 | v_type = str 485 | elif isinstance(v, bool): 486 | v_type = str2bool 487 | parser.add_argument(f"--{k}", default=v, type=v_type) 488 | 489 | 490 | def args_to_dict(args, keys): 491 | return {k: getattr(args, k) for k in keys} 492 | 493 | 494 | def str2bool(v): 495 | """ 496 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 497 | """ 498 | if isinstance(v, bool): 499 | return v 500 | if v.lower() in ("yes", "true", "t", "y", "1"): 501 | return True 502 | elif v.lower() in ("no", "false", "f", "n", "0"): 503 | return False 504 | else: 505 | raise argparse.ArgumentTypeError("boolean value expected") 506 | -------------------------------------------------------------------------------- /sdg/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | import time 5 | import glob 6 | 7 | import blobfile as bf 8 | import torch as th 9 | import torch.distributed as dist 10 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 11 | from torch.optim import AdamW 12 | 13 | from .distributed import get_world_size, is_master 14 | from .distributed import master_only_print as print 15 | from .fp16_util import MixedPrecisionTrainer 16 | from .nn import update_ema 17 | from .resample import LossAwareSampler, UniformSampler 18 | from .logging import write_tb 19 | from .misc import to_cuda 20 | from . import logger 21 | 22 | # For ImageNet experiments, this was a good default value. 
23 | # We found that the lg_loss_scale quickly climbed to
24 | # 20-21 within the first ~1K steps of training.
25 | INITIAL_LOG_LOSS_SCALE = 20.0
26 | 
27 | 
28 | class TrainLoop:
29 |     def __init__(
30 |         self,
31 |         cfg,
32 |         model,
33 |         diffusion,
34 |         data_train,
35 |         tb_log,
36 |         schedule_sampler=None,
37 |     ):
38 |         self.tb_log = tb_log
39 |         self.model = model
40 |         self.diffusion = diffusion
41 |         self.data_train = data_train
42 |         self.batch_size = cfg.batch_size
43 |         self.microbatch = cfg.microbatch if cfg.microbatch > 0 else cfg.batch_size
44 |         self.lr = cfg.lr
45 |         self.ema_rate = (
46 |             [cfg.ema_rate]
47 |             if isinstance(cfg.ema_rate, float)
48 |             else [float(x) for x in cfg.ema_rate.split(",")]
49 |         )
50 |         self.log_interval = cfg.log_interval
51 |         self.save_interval = cfg.save_interval
52 |         self.resume_checkpoint = cfg.resume_checkpoint
53 |         self.use_fp16 = getattr(cfg, 'use_fp16', False)
54 |         self.fp16_scale_growth = getattr(cfg, 'fp16_scale_growth', 1e-3)
55 |         self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
56 |         self.weight_decay = getattr(cfg, 'weight_decay', 0.0)
57 |         self.lr_anneal_steps = getattr(cfg, 'lr_anneal_steps', 0)
58 |         self.logdir = getattr(cfg, 'logdir', 'logs/debug')
59 |         fp16_hyperparams = getattr(cfg, 'fp16_hyperparams', 'openai')
60 |         if self.use_fp16:
61 |             if fp16_hyperparams == 'pytorch':
62 |                 self.scaler = th.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)
63 |             elif fp16_hyperparams == 'openai':
64 |                 self.scaler = th.cuda.amp.GradScaler(init_scale=2**20 * 1.0, growth_factor=2**0.001, backoff_factor=0.5, growth_interval=1, enabled=True)
65 | 
66 |         self.step = 0
67 |         self.resume_step = 0
68 |         self.global_batch = self.batch_size * get_world_size()
69 | 
70 |         self.sync_cuda = True # th.cuda.is_available()
71 | 
72 |         self._load_and_sync_parameters()
73 | 
74 |         self.params = list(self.model.parameters())
75 |         self.opt = AdamW(self.params, lr=self.lr, weight_decay=self.weight_decay)
76 |         if self.resume_step:
77 |             self._load_optimizer_state()
78 |             # Model was resumed, either due to a restart or a checkpoint
79 |             # being specified at the command line.
80 |             self.ema_params = [
81 |                 self._load_ema_parameters(rate) for rate in self.ema_rate
82 |             ]
83 |         else:
84 |             self.ema_params = [
85 |                 copy.deepcopy(self.params)
86 |                 for _ in range(len(self.ema_rate))
87 |             ]
88 | 
89 |         self.use_ddp = th.cuda.is_available() and th.distributed.is_available() and dist.is_initialized()
90 |         if self.use_ddp:
91 |             self.ddp_model = DDP(
92 |                 self.model,
93 |                 device_ids=[cfg.local_rank],
94 |                 output_device=cfg.local_rank,
95 |                 bucket_cap_mb=128,
96 |                 broadcast_buffers=False,
97 |                 find_unused_parameters=False,
98 |             )
99 |         else:
100 |             print(
101 |                 "Single GPU Training without DistributedDataParallel. "
102 |             )
103 |             self.ddp_model = self.model
104 | 
105 |     def _load_and_sync_parameters(self):
106 |         resume_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
107 | 
108 |         if resume_checkpoint:
109 |             self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
110 |             print(f"loading model from checkpoint: {resume_checkpoint}...")
111 |             checkpoint = th.load(resume_checkpoint, map_location=lambda storage, loc: storage)
112 |             self.model.load_state_dict(checkpoint)
113 | 
114 |     def _load_ema_parameters(self, rate):
115 |         ema_params = copy.deepcopy(self.params)
116 | 
117 |         main_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
118 |         ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
119 |         if ema_checkpoint:
120 |             print(f"loading EMA from checkpoint: {ema_checkpoint}...")
121 |             state_dict = th.load(ema_checkpoint, map_location=lambda storage, loc: storage)
122 |             ema_params = to_cuda([state_dict[name] for name, _ in self.model.named_parameters()])
123 | 
124 |         return ema_params
125 | 
126 |     def _load_optimizer_state(self):
127 |         main_checkpoint = find_resume_checkpoint(self.logdir) or self.resume_checkpoint
128 |         opt_checkpoint = bf.join(
129 |             bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
130 |         )
131 |         if bf.exists(opt_checkpoint):
132 |             print(f"loading optimizer state from checkpoint: {opt_checkpoint}")
133 |             checkpoint = th.load(opt_checkpoint, map_location=lambda storage, loc: storage)
134 |             self.opt.load_state_dict(checkpoint)
135 | 
136 |     def _compute_norms(self, grad_scale=1.0):
137 |         grad_norm = 0.0
138 |         param_norm = 0.0
139 |         for p in self.params:
140 |             with th.no_grad():
141 |                 param_norm += th.norm(p, p=2, dtype=th.float32) ** 2
142 |                 if p.grad is not None:
143 |                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32) ** 2
144 |         return th.sqrt(grad_norm) / grad_scale, th.sqrt(param_norm)
145 | 
146 |     def run_loop(self):
147 |         time0 = time.time()
148 |         while (
149 |             not self.lr_anneal_steps
150 |             or self.step + self.resume_step < self.lr_anneal_steps
151 |         ):
152 |             print('***step %d ' % (self.step + self.resume_step), end='')
153 |             batch, cond = next(self.data_train)
154 |             time2 = time.time()
155 |             if self.use_fp16:
156 |                 self.run_step_amp(batch, cond)
157 |             else:
158 |                 self.run_step(batch, cond)
159 |             if self.step % self.log_interval == 0 and is_master():
160 |                 self.tb_log.flush()
161 |             if self.step % self.save_interval == 0:
162 |                 self.save()
163 |             # Run for a finite amount of time in integration tests.
164 |             if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
165 |                 return
166 |             num_samples = (self.step + self.resume_step + 1) * self.global_batch
167 |             time1 = time.time()
168 |             if is_master():
169 |                 self.tb_log.add_scalar('status/step', self.step+self.resume_step, self.step + self.resume_step)
170 |                 self.tb_log.add_scalar('status/samples', num_samples, self.step + self.resume_step)
171 |                 self.tb_log.add_scalar('time/time_per_iter', time1-time0, self.step + self.resume_step)
172 |                 self.tb_log.add_scalar('time/data_time_per_iter', time2-time0, self.step + self.resume_step)
173 |                 self.tb_log.add_scalar('time/model_time_per_iter', time1-time2, self.step + self.resume_step)
174 |                 self.tb_log.add_scalar('status/lr', self.lr, self.step + self.resume_step)
175 |             print('lr: %f ' % self.lr, end='')
176 |             print('samples: %d ' % num_samples, end='')
177 |             print('data time: %f ' % (time2-time0), end='')
178 |             print('model time: %f ' % (time1-time2), end='')
179 |             print('')
180 |             self.step += 1
181 |             time0 = time1
182 |         # Save the last checkpoint if it wasn't already saved.
183 |         if (self.step - 1) % self.save_interval != 0:
184 |             self.save()
185 | 
186 |     def run_step(self, batch, cond):
187 |         self.forward_backward(batch, cond)
188 |         self.opt.step()
189 |         self._update_ema()
190 |         self._anneal_lr()
191 | 
192 |     def forward_backward(self, batch, cond):
193 |         self.opt.zero_grad()
194 |         for i in range(0, batch.shape[0], self.microbatch):
195 |             micro = to_cuda(batch[i : i + self.microbatch])
196 |             micro_cond = {
197 |                 k: to_cuda(v[i : i + self.microbatch])
198 |                 for k, v in cond.items()
199 |             }
200 |             last_batch = (i + self.microbatch) >= batch.shape[0]
201 |             t, weights = self.schedule_sampler.sample(micro.shape[0], 'cuda')
202 | 
203 |             compute_losses = functools.partial(
204 |                 self.diffusion.training_losses,
205 |                 self.ddp_model,
206 |                 micro,
207 |                 t,
208 |                 model_kwargs=micro_cond,
209 |             )
210 | 
211 |             if last_batch or not self.use_ddp:
212 |                 losses = compute_losses()
213 |             else:
214 |                 with self.ddp_model.no_sync():
215 |                     losses = compute_losses()
216 | 
217 |             if isinstance(self.schedule_sampler, LossAwareSampler):
218 |                 self.schedule_sampler.update_with_local_losses(
219 |                     t, losses["loss"].detach()
220 |                 )
221 | 
222 |             loss = (losses["loss"] * weights).mean()
223 |             if self.step % 10 == 0:
224 |                 log_loss_dict(
225 |                     self.diffusion, t, {k: v * weights for k, v in losses.items()},
226 |                     self.tb_log, self.step + self.resume_step
227 |                 )
228 |             loss.backward()
229 | 
230 |     def run_step_amp(self, batch, cond):
231 |         self.forward_backward_amp(batch, cond)
232 |         self.scaler.step(self.opt)
233 |         self.scaler.update()
234 |         self._update_ema()
235 |         self._anneal_lr()
236 | 
237 |     def forward_backward_amp(self, batch, cond):
238 |         self.opt.zero_grad()
239 |         for i in range(0, batch.shape[0], self.microbatch):
240 |             micro = to_cuda(batch[i : i + self.microbatch])
241 |             micro_cond = {
242 |                 k: to_cuda(v[i : i + self.microbatch])
243 |                 for k, v in cond.items()
244 |             }
245 |             last_batch = (i + self.microbatch) >= batch.shape[0]
246 |             t, weights = self.schedule_sampler.sample(micro.shape[0], 'cuda')
247 | 
248 |             with th.cuda.amp.autocast(True):
249 |                 compute_losses = functools.partial(
250 |                     self.diffusion.training_losses,
251 |                     self.ddp_model,
252 |                     micro,
253 |                     t,
254 |                     model_kwargs=micro_cond,
255 |                 )
256 | 
257 |                 if last_batch or not self.use_ddp:
258 |                     losses = compute_losses()
259 |                 else:
260 |                     with self.ddp_model.no_sync():
261 |                         losses = compute_losses()
262 | 
263 |                 if isinstance(self.schedule_sampler, LossAwareSampler):
264 |                     self.schedule_sampler.update_with_local_losses(
265 |                         t, losses["loss"].detach()
266 |                     )
267 | 
268 |                 loss = (losses["loss"] * weights).mean()
269 |                 if self.step % 10 == 0:
270 |                     log_loss_dict(
271 |                         self.diffusion, t, {k: v * weights for k, v in losses.items()},
272 |                         self.tb_log, self.step + self.resume_step
273 |                     )
274 |             self.scaler.scale(loss).backward()
275 | 
276 | 
277 | 
278 |     def _update_ema(self):
279 |         for rate, params in zip(self.ema_rate, self.ema_params):
280 |             update_ema(params, self.params, rate=rate)
281 | 
282 |     def _anneal_lr(self):
283 |         if not self.lr_anneal_steps:
284 |             return
285 |         frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
286 |         lr = self.lr * (1 - frac_done)
287 |         for param_group in self.opt.param_groups:
288 |             param_group["lr"] = lr
289 | 
290 |     def log_step(self):
291 |         logger.logkv("step", self.step + self.resume_step)
292 |         logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
293 | 
294 |     def save(self):
295 |         def save_checkpoint(rate, params):
296 |             print(f"saving model {rate}...")
297 |             if is_master():
298 |                 state_dict = self.model.state_dict()  # start from the full state dict (incl. buffers)
299 |                 for (name, _), p in zip(self.model.named_parameters(), params):
300 |                     state_dict[name] = p  # overwrite with the given (raw or EMA-averaged) parameters
301 |                 filename = f"model{(self.step+self.resume_step):06d}.pt" if not rate else f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
302 |                 with bf.BlobFile(bf.join(self.logdir, filename), "wb") as f:
303 |                     th.save(state_dict, f)
304 | 
305 |         save_checkpoint(0, self.params)
306 |         for rate, params in zip(self.ema_rate, self.ema_params):
307 |             save_checkpoint(rate, params)
308 | 
309 |         if is_master():
310 |             with bf.BlobFile(
311 |                 bf.join(self.logdir, f"opt{(self.step+self.resume_step):06d}.pt"),
312 |                 "wb",
313 |             ) as f:
314 |                 th.save(self.opt.state_dict(), f)
315 | 
316 |         if is_master() and self.use_fp16:
317 |             with bf.BlobFile(
318 |                 bf.join(self.logdir, f"scaler{(self.step+self.resume_step):06d}.pt"),
319 |                 "wb",
320 |             ) as f:
321 |                 th.save(self.scaler.state_dict(), f)
322 | 
323 | 
324 | def parse_resume_step_from_filename(filename):
325 |     """
326 |     Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
327 |     checkpoint's number of steps.
328 |     """
329 |     split = filename.split("model")
330 |     if len(split) < 2:
331 |         return 0
332 |     split1 = split[-1].split(".")[0]
333 |     try:
334 |         return int(split1)
335 |     except ValueError:
336 |         return 0
337 | 
338 | 
339 | def get_blob_logdir():
340 |     # You can change this to be a separate path to save checkpoints to
341 |     # a blobstore or some external drive.
342 |     return logger.get_dir()
343 | 
344 | 
345 | def find_resume_checkpoint(logdir):
346 |     # On your infrastructure, you may want to override this to automatically
347 |     # discover the latest checkpoint on your blob storage, etc.
348 |     models = sorted(glob.glob(os.path.join(logdir, 'model*.pt')))
349 |     if len(models) >= 1:
350 |         return models[-1]
351 |     else:
352 |         return None
353 | 
354 | 
355 | def find_ema_checkpoint(main_checkpoint, step, rate):
356 |     if main_checkpoint is None:
357 |         return None
358 |     filename = f"ema_{rate}_{(step):06d}.pt"
359 |     path = bf.join(bf.dirname(main_checkpoint), filename)
360 |     if bf.exists(path):
361 |         return path
362 |     return None
363 | 
364 | 
365 | def log_loss_dict(diffusion, ts, losses, tb_log, step, prefix='loss'):
366 |     for key, values in losses.items():
367 |         write_tb(tb_log, f"{prefix}/{key}", values, step)
368 |         quartile_list = [[] for cnt in range(4)]
369 |         for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
370 |             quartile = int(4 * sub_t / diffusion.num_timesteps)
371 |             quartile_list[quartile].append(sub_loss)
372 |         for cnt in range(4):
373 |             if len(quartile_list[cnt]) != 0:
374 |                 write_tb(tb_log, f"{prefix}/{key}_q{cnt}", sum(quartile_list[cnt])/len(quartile_list[cnt]), step)
375 |             else:
376 |                 write_tb(tb_log, f"{prefix}/{key}_q{cnt}", 0.0, step)
377 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name="sdg",
5 |     packages=["sdg"],
6 |     install_requires=["blobfile>=1.0.5", "tqdm"],
7 | )
8 | 
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xh-liu/SDG_code/62fb2725035f2cc327bcfc3710384d9cec0dac3c/teaser.png
--------------------------------------------------------------------------------