├── .gitignore ├── LICENSE ├── README.md ├── ade20k_psiphi.npy ├── asset ├── framework.png └── motivation.png ├── celeba_psiphi.npy ├── coco_psiphi.npy ├── environment.yaml ├── evaluations ├── fid │ ├── LICENSE │ ├── __init__.py │ ├── __pycache__ │ │ └── inception.cpython-37.pyc │ ├── inception.py │ └── tests_with_FID.py ├── lpips │ ├── default_dataset.py │ ├── lpips.py │ ├── lpips_weights.ckpt │ └── make_lpips_list.py └── miou │ ├── celeba_config.yaml │ ├── coco_config.yaml │ ├── test_celeba.py │ └── test_coco.py ├── guided_diffusion ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── image_datasets.py ├── logger.py ├── losses.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── image_process ├── downsample.py ├── edge.py └── random.py ├── image_sample.py ├── image_train.py └── scripts ├── sample.sh └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 MLV Lab (Machine Learning and Vision Lab at Korea University) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Conditional Diffusion Models for Robust Semantic Image Synthesis 2 | 3 | Official PyTorch implementation of "[Stochastic Conditional Diffusion Models for Robust Semantic Image Synthesis](https://arxiv.org/abs/2402.16506)" ([ICML 2024](https://proceedings.mlr.press/v235/ko24e.html)). 4 | > Juyeon Ko*, Inho Kong*, Dogyun Park, Hyunwoo J. Kim†. 
5 | > 6 | > Department of Computer Science and Engineering, Korea University 7 | 8 | ![SCDM Framework](./asset/framework.png) 9 | 10 | ![SCDM Motivation](./asset/motivation.png) 11 | 12 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/conditional-image-generation-on-celebamask-hq)](https://paperswithcode.com/sota/conditional-image-generation-on-celebamask-hq?p=stochastic-conditional-diffusion-models-for) 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/image-to-image-translation-on-ade20k-labels)](https://paperswithcode.com/sota/image-to-image-translation-on-ade20k-labels?p=stochastic-conditional-diffusion-models-for) 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/image-to-image-translation-on-coco-stuff)](https://paperswithcode.com/sota/image-to-image-translation-on-coco-stuff?p=stochastic-conditional-diffusion-models-for) 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/noisy-semantic-image-synthesis-on-noisy)](https://paperswithcode.com/sota/noisy-semantic-image-synthesis-on-noisy?p=stochastic-conditional-diffusion-models-for) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/noisy-semantic-image-synthesis-on-noisy-1)](https://paperswithcode.com/sota/noisy-semantic-image-synthesis-on-noisy-1?p=stochastic-conditional-diffusion-models-for) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stochastic-conditional-diffusion-models-for/noisy-semantic-image-synthesis-on-noisy-2)](https://paperswithcode.com/sota/noisy-semantic-image-synthesis-on-noisy-2?p=stochastic-conditional-diffusion-models-for) 18 | 19 | ## Setup 20 | - Clone repository 21 | 22 | ```bash 23 | git clone https://github.com/mlvlab/SCDM.git 24 | cd SCDM 25 | ``` 26 | - Setup conda environment 27 | 28 | ```bash 29 | conda env create -f environment.yaml 30 | conda activate scdm 31 | ``` 32 | 33 | ## Dataset Preparation 34 | 35 | ### Standard (clean) SIS 36 | - **CelebAMask-HQ** can be downloaded from [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ). The dataset should be structured as below: 37 | ``` 38 | CelebAMask/ 39 | ├─ train/ 40 | │ ├─ images/ 41 | │ │ ├─ 0.jpg 42 | │ │ ├─ ... 43 | │ │ ├─ 27999.jpg 44 | │ ├─ labels/ 45 | │ │ ├─ 0.png 46 | │ │ ├─ ... 47 | │ │ ├─ 27999.png 48 | ├─ test/ 49 | │ ├─ images/ 50 | │ │ ├─ 28000.jpg 51 | │ │ ├─ ... 52 | │ │ ├─ 29999.jpg 53 | │ ├─ labels/ 54 | │ │ ├─ 28000.png 55 | │ │ ├─ ... 56 | │ │ ├─ 29999.png 57 | ``` 58 | 59 | - **ADE20K** can be downloaded from [MIT Scene Parsing Benchmark](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip), and we followed [SPADE](https://github.com/NVlabs/SPADE?tab=readme-ov-file#dataset-preparation) for preparation. The dataset should be structured as below: 60 | ``` 61 | ADE20K/ 62 | ├─ ADEChallengeData2016/ 63 | │ │ ├─ images/ 64 | │ │ │ ├─ training/ 65 | │ │ │ │ ├─ ADE_train_00000001.jpg 66 | │ │ │ │ ├─ ... 67 | │ │ │ ├─ validation/ 68 | │ │ │ │ ├─ ADE_val_00000001.jpg 69 | │ │ │ │ ├─ ... 70 | │ │ ├─ annotations/ 71 | │ │ │ ├─ training/ 72 | │ │ │ │ ├─ ADE_train_00000001.png 73 | │ │ │ │ ├─ ... 74 | │ │ │ ├─ validation/ 75 | │ │ │ │ ├─ ADE_val_00000001.png 76 | │ │ │ │ ├─ ... 
77 | ``` 78 | 79 | - **COCO-STUFF** can be downloaded from [cocostuff](https://github.com/nightrome/cocostuff), and we followed [SPADE](https://github.com/NVlabs/SPADE?tab=readme-ov-file#dataset-preparation) for preparation. The dataset should be structured as below: 80 | ``` 81 | coco/ 82 | ├─ train_img/ 83 | │ ├─ 000000000009.jpg 84 | │ ├─ ... 85 | ├─ train_label/ 86 | │ ├─ 000000000009.png 87 | │ ├─ ... 88 | ├─ train_inst/ 89 | │ ├─ 000000000009.png 90 | │ ├─ ... 91 | ├─ val_img/ 92 | │ ├─ 000000000139.jpg 93 | │ ├─ ... 94 | ├─ val_label/ 95 | │ ├─ 000000000139.png 96 | │ ├─ ... 97 | ├─ val_inst/ 98 | │ ├─ 000000000139.png 99 | │ ├─ ... 100 | ``` 101 | 102 | ### Noisy SIS dataset for evaluation 103 | Our noisy SIS dataset for three benchmark settings (DS, Edge, and Random) based on ADE20K is available at [Google Drive](https://drive.google.com/drive/folders/1KvGETbHUaqnLcslkwDxcWzoU7ixtvy9L?usp=sharing). 104 | You can also generate the same dataset by running the Python scripts in `image_process/`. 105 | 106 | ## Experiments 107 | You can set the CUDA visible devices with `VISIBLE_DEVICES=${GPU_ID}` (e.g., `VISIBLE_DEVICES=0,1,2,3`). 108 | 109 | ### Training 110 | - Run 111 | 112 | ``` 113 | sh scripts/train.sh 114 | ``` 115 | - For more details, please refer to `scripts/train.sh`. 116 | - Pretrained models are available at [Google Drive](https://drive.google.com/drive/folders/1OGrIyYuk7EtFwBuHsD1DkN9cSHCOqq-w?usp=sharing). 117 | 118 | ### Sampling 119 | - Run 120 | 121 | ``` 122 | sh scripts/sample.sh 123 | ``` 124 | - For more details, please refer to `scripts/sample.sh`. 125 | - Our samples are available at [Google Drive](https://drive.google.com/drive/folders/1bNRX1PC2Q_rH0Nudk9QSmVQ2O-E4D8D0?usp=sharing). 126 | 127 | ### Evaluation 128 | - FID (fidelity) 129 | 130 | The code is based on [OASIS](https://github.com/boschresearch/OASIS). 131 | 132 | ``` 133 | python evaluations/fid/tests_with_FID.py --path {SAMPLE_PATH} {GT_IMAGE_PATH} -b {BATCH_SIZE} --gpu {GPU_ID} 134 | ``` 135 | 136 | - LPIPS (diversity) 137 | 138 | You should generate 10 sets of samples and make `lpips_list.txt` with `evaluations/lpips/make_lpips_list.py`. The code is based on [stargan-v2](https://github.com/clovaai/stargan-v2). 139 | 140 | ``` 141 | python evaluations/lpips/lpips.py --root_path results/ade20k --test_list lpips_list.txt --batch_size 10 142 | ``` 143 | 144 | - mIoU (correspondence) 145 | - **CelebAMask-HQ**: [U-Net](https://github.com/NVlabs/imaginaire/blob/master/imaginaire/evaluation/segmentation/celebamask_hq.py). Clone the repo and set up the environment from [imaginaire](https://github.com/NVlabs/imaginaire), and add `evaluations/miou/test_celeba.py` to `imaginaire/`. Check out `evaluations/miou/celeba_config.yaml` for the config file and adjust the paths accordingly. 146 | 147 | ```bash 148 | cd imaginaire 149 | python test_celeba.py 150 | ``` 151 | 152 | - **ADE20K**: ViT-Adapter-S with UperNet. Clone the repo and set up the environment from [ViT-Adapter](https://github.com/czczup/ViT-Adapter/tree/main/segmentation). 153 | 154 | ```bash 155 | cd ViT-Adapter/segmentation 156 | bash dist_test.sh \ 157 | configs/ade20k/upernet_deit_adapter_small_512_160k_ade20k.py \ 158 | pretrained/upernet_deit_adapter_small_512_160k_ade20k.pth \ 159 | {NUM_GPUS} \ 160 | --eval mIoU \ 161 | --img_dir {SAMPLE_DIR} \ 162 | --ann_dir {LABEL_DIR} \ 163 | --root_dir {SAMPLE_ROOT_DIR} 164 | ``` 165 | - **COCO-STUFF**: [DeepLabV2](https://github.com/NVlabs/imaginaire/blob/master/imaginaire/evaluation/segmentation/cocostuff.py).
Clone the repo and set up the environment from [imaginaire](https://github.com/NVlabs/imaginaire), and add `evaluations/miou/test_coco.py` to `imaginaire/`. Check out `evaluations/miou/coco_config.yaml` for the config file and adjust the paths accordingly. 166 | 167 | 168 | ```bash 169 | cd imaginaire 170 | python test_coco.py 171 | ``` 172 | 173 | ## Acknowledgement 174 | 175 | This repository is built upon [guided-diffusion](https://github.com/openai/guided-diffusion) and [SDM](https://github.com/WeilunWang/semantic-diffusion-model). 176 | 177 | ## Citation 178 | If you use this work, please cite as: 179 | ```bibtex 180 | @InProceedings{pmlr-v235-ko24e, 181 | title = {Stochastic Conditional Diffusion Models for Robust Semantic Image Synthesis}, 182 | author = {Ko, Juyeon and Kong, Inho and Park, Dogyun and Kim, Hyunwoo J.}, 183 | booktitle = {Proceedings of the 41st International Conference on Machine Learning}, 184 | year = {2024} 185 | } 186 | 187 | ``` 188 | 189 | ## Contact 190 | 191 | Feel free to contact us if you need help or explanations! 192 | 193 | - Juyeon Ko (juyon98@korea.ac.kr) 194 | - Inho Kong (inh212@korea.ac.kr) 195 | -------------------------------------------------------------------------------- /ade20k_psiphi.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/ade20k_psiphi.npy -------------------------------------------------------------------------------- /asset/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/asset/framework.png -------------------------------------------------------------------------------- /asset/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/asset/motivation.png -------------------------------------------------------------------------------- /celeba_psiphi.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/celeba_psiphi.npy -------------------------------------------------------------------------------- /coco_psiphi.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/coco_psiphi.npy -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: scdm 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2022.10.11=h06a4308_0 10 | - certifi=2022.9.24=py37h06a4308_0 11 | - cudatoolkit=11.0.221=h6bb024c_0 12 | - flit-core=3.6.0=pyhd3eb1b0_0 13 | - freetype=2.12.1=h4a9f257_0 14 | - giflib=5.2.1=h7b6447c_0 15 | - intel-openmp=2021.4.0=h06a4308_3561 16 | - jpeg=9b=h024ee3a_2 17 | - lcms2=2.12=h3be6417_0 18 | - ld_impl_linux-64=2.38=h1181459_1 19 | - libffi=3.4.2=h6a678d5_6 20 | - libgcc-ng=11.2.0=h1234567_1 21 | - libgfortran-ng=7.5.0=h14aa051_20 22 | - libgfortran4=7.5.0=h14aa051_20 23 | - libgomp=11.2.0=h1234567_1 24 | - libpng=1.6.37=hbc83047_0 25 | - libstdcxx-ng=11.2.0=h1234567_1 26 | - 
libtiff=4.1.0=h2733197_1 27 | - libuv=1.40.0=h7b6447c_0 28 | - libwebp=1.2.0=h89dd481_0 29 | - lz4-c=1.9.3=h295c915_1 30 | - mkl=2021.4.0=h06a4308_640 31 | - mkl-service=2.4.0=py37h7f8727e_0 32 | - mkl_fft=1.3.1=py37hd3c417c_0 33 | - mkl_random=1.2.2=py37h51133e4_0 34 | - mpi=1.0=openmpi 35 | - mpi4py=3.0.3=py37hd955b32_1 36 | - ncurses=6.3=h5eee18b_3 37 | - ninja=1.10.2=h06a4308_5 38 | - ninja-base=1.10.2=hd09550d_5 39 | - openmpi=4.0.4=hdf1f1ad_0 40 | - openssl=1.1.1s=h7f8727e_0 41 | - pip=22.2.2=py37h06a4308_0 42 | - python=3.7.10=hb7a2778_101_cpython 43 | - readline=8.2=h5eee18b_0 44 | - setuptools=65.5.0=py37h06a4308_0 45 | - six=1.16.0=pyhd3eb1b0_1 46 | - sqlite=3.40.0=h5082296_0 47 | - tk=8.6.12=h1ccaba5_0 48 | - typing_extensions=4.4.0=py37h06a4308_0 49 | - wheel=0.37.1=pyhd3eb1b0_0 50 | - xz=5.2.8=h5eee18b_0 51 | - zlib=1.2.13=h5eee18b_0 52 | - zstd=1.4.9=haebb681_0 53 | - pip: 54 | - blobfile==2.0.0 55 | - cycler==0.11.0 56 | - filelock==3.8.2 57 | - fonttools==4.38.0 58 | - imageio==2.31.2 59 | - kiwisolver==1.4.5 60 | - lxml==4.9.1 61 | - matplotlib==3.5.3 62 | - networkx==2.6.3 63 | - numpy==1.21.6 64 | - packaging==24.0 65 | - pillow==9.3.0 66 | - pycryptodomex==3.16.0 67 | - pyparsing==3.1.2 68 | - python-dateutil==2.9.0.post0 69 | - pywavelets==1.3.0 70 | - scikit-image==0.19.3 71 | - scipy==1.7.3 72 | - tifffile==2021.11.2 73 | - torch==1.10.1+cu111 74 | - torchaudio==0.10.1+cu111 75 | - torchvision==0.11.2+cu111 76 | - tqdm==4.64.1 77 | - urllib3==1.26.13 78 | prefix: /home/user/.conda/envs/scdm 79 | -------------------------------------------------------------------------------- /evaluations/fid/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /evaluations/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/evaluations/fid/__init__.py -------------------------------------------------------------------------------- /evaluations/fid/__pycache__/inception.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/evaluations/fid/__pycache__/inception.cpython-37.pyc -------------------------------------------------------------------------------- /evaluations/fid/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. 
If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
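    For reference, a minimal sketch of how this patched network is consumed via
    the InceptionV3 wrapper above (illustrative only; `images` stands for any
    float tensor of shape Bx3xHxW with values in range (0, 1)):

        model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
        with torch.no_grad():
            pool3 = model(images)[0]  # final average-pool features, (B, 2048, 1, 1)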
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | import ssl 189 | ssl._create_default_https_context = ssl._create_unverified_context 190 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 191 | inception.load_state_dict(state_dict) 192 | return inception 193 | 194 | 195 | class FIDInceptionA(models.inception.InceptionA): 196 | """InceptionA block patched for FID computation""" 197 | def __init__(self, in_channels, pool_features): 198 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 199 | 200 | def forward(self, x): 201 | branch1x1 = self.branch1x1(x) 202 | 203 | branch5x5 = self.branch5x5_1(x) 204 | branch5x5 = self.branch5x5_2(branch5x5) 205 | 206 | branch3x3dbl = self.branch3x3dbl_1(x) 207 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 208 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 209 | 210 | # Patch: Tensorflow's average pool does not use the padded zero's in 211 | # its average calculation 212 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 213 | count_include_pad=False) 214 | branch_pool = self.branch_pool(branch_pool) 215 | 216 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 217 | return torch.cat(outputs, 1) 218 | 219 | 220 | class FIDInceptionC(models.inception.InceptionC): 221 | """InceptionC block patched for FID computation""" 222 | def __init__(self, in_channels, channels_7x7): 223 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 224 | 225 | def forward(self, x): 226 | branch1x1 = self.branch1x1(x) 227 | 228 | branch7x7 = self.branch7x7_1(x) 229 | branch7x7 = self.branch7x7_2(branch7x7) 230 | branch7x7 = self.branch7x7_3(branch7x7) 231 | 232 | branch7x7dbl = self.branch7x7dbl_1(x) 233 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 235 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 236 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 237 | 238 | # Patch: Tensorflow's average pool does not use the padded zero's in 239 | # its average calculation 240 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 241 | count_include_pad=False) 242 | branch_pool = self.branch_pool(branch_pool) 243 | 244 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 245 | return torch.cat(outputs, 1) 246 | 247 | 248 | class FIDInceptionE_1(models.inception.InceptionE): 249 | """First InceptionE block patched for FID computation""" 250 | def __init__(self, in_channels): 251 | super(FIDInceptionE_1, self).__init__(in_channels) 252 | 253 | def forward(self, x): 254 | branch1x1 = self.branch1x1(x) 255 | 256 | branch3x3 = self.branch3x3_1(x) 257 | branch3x3 = [ 258 | self.branch3x3_2a(branch3x3), 259 | self.branch3x3_2b(branch3x3), 260 | ] 261 | branch3x3 = torch.cat(branch3x3, 1) 262 | 263 | branch3x3dbl = self.branch3x3dbl_1(x) 264 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 265 | branch3x3dbl = [ 266 
| self.branch3x3dbl_3a(branch3x3dbl), 267 | self.branch3x3dbl_3b(branch3x3dbl), 268 | ] 269 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 270 | 271 | # Patch: Tensorflow's average pool does not use the padded zero's in 272 | # its average calculation 273 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 274 | count_include_pad=False) 275 | branch_pool = self.branch_pool(branch_pool) 276 | 277 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 278 | return torch.cat(outputs, 1) 279 | 280 | 281 | class FIDInceptionE_2(models.inception.InceptionE): 282 | """Second InceptionE block patched for FID computation""" 283 | def __init__(self, in_channels): 284 | super(FIDInceptionE_2, self).__init__(in_channels) 285 | 286 | def forward(self, x): 287 | branch1x1 = self.branch1x1(x) 288 | 289 | branch3x3 = self.branch3x3_1(x) 290 | branch3x3 = [ 291 | self.branch3x3_2a(branch3x3), 292 | self.branch3x3_2b(branch3x3), 293 | ] 294 | branch3x3 = torch.cat(branch3x3, 1) 295 | 296 | branch3x3dbl = self.branch3x3dbl_1(x) 297 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 298 | branch3x3dbl = [ 299 | self.branch3x3dbl_3a(branch3x3dbl), 300 | self.branch3x3dbl_3b(branch3x3dbl), 301 | ] 302 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 303 | 304 | # Patch: The FID Inception model uses max pooling instead of average 305 | # pooling. This is likely an error in this specific Inception 306 | # implementation, as other Inception models use average pooling here 307 | # (which matches the description in the paper). 308 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 309 | branch_pool = self.branch_pool(branch_pool) 310 | 311 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 312 | return torch.cat(outputs, 1) 313 | -------------------------------------------------------------------------------- /evaluations/fid/tests_with_FID.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectively. 15 | 16 | See --help to see further details. 17 | 18 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 19 | of Tensorflow 20 | 21 | Copyright 2018 Institute of Bioinformatics, JKU Linz 22 | 23 | Licensed under the Apache License, Version 2.0 (the "License"); 24 | you may not use this file except in compliance with the License. 25 | You may obtain a copy of the License at 26 | 27 | http://www.apache.org/licenses/LICENSE-2.0 28 | 29 | Unless required by applicable law or agreed to in writing, software 30 | distributed under the License is distributed on an "AS IS" BASIS, 31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | See the License for the specific language governing permissions and 33 | limitations under the License. 
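Example invocation (illustrative; substitute your own sample and ground-truth
image directories, as documented in the repository README):

    python evaluations/fid/tests_with_FID.py \
        --path {SAMPLE_PATH} {GT_IMAGE_PATH} -b {BATCH_SIZE} --gpu {GPU_ID}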
34 | """ 35 | import os 36 | import pathlib 37 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 38 | 39 | import numpy as np 40 | import torch 41 | from scipy import linalg 42 | from imageio import imread 43 | import imageio 44 | from torch.nn.functional import adaptive_avg_pool2d 45 | from skimage.transform import resize 46 | import matplotlib.pyplot as plt 47 | import numpy as np 48 | 49 | try: 50 | from tqdm import tqdm 51 | except ImportError: 52 | # If not tqdm is not available, provide a mock version of it 53 | def tqdm(x): return x 54 | 55 | from inception import InceptionV3 56 | 57 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 58 | parser.add_argument('--path', type=str, nargs=2, 59 | help=('Path to the generated images or ' 60 | 'to .npz statistic files')) 61 | parser.add_argument('--batch-size', '-b', type=int, default=50, 62 | help='Batch size to use') 63 | parser.add_argument('--dims', type=int, default=2048, 64 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 65 | help=('Dimensionality of Inception features to use. ' 66 | 'By default, uses pool3 features')) 67 | parser.add_argument('-c', '--gpu', default='', type=str, 68 | help='GPU to use (leave blank for CPU only)') 69 | 70 | def get_activations(files, model, batch_size=50, dims=2048, 71 | cuda=False, verbose=False): 72 | """Calculates the activations of the pool_3 layer for all images. 73 | 74 | Params: 75 | -- files : List of image files paths 76 | -- model : Instance of inception model 77 | -- batch_size : Batch size of images for the model to process at once. 78 | Make sure that the number of samples is a multiple of 79 | the batch size, otherwise some samples are ignored. This 80 | behavior is retained to match the original FID score 81 | implementation. 82 | -- dims : Dimensionality of features returned by Inception 83 | -- cuda : If set to True, use GPU 84 | -- verbose : If set to True and parameter out_step is given, the number 85 | of calculated batches is reported. 86 | Returns: 87 | -- A numpy array of dimension (num images, dims) that contains the 88 | activations of the given tensor when feeding inception with the 89 | query tensor. 90 | """ 91 | model.eval() 92 | 93 | if len(files) % batch_size != 0: 94 | print(('Warning: number of images is not a multiple of the ' 95 | 'batch size. Some samples are going to be ignored.')) 96 | if batch_size > len(files): 97 | print(('Warning: batch size is bigger than the data size. ' 98 | 'Setting batch size to data size')) 99 | batch_size = len(files) 100 | n_batches = len(files) // batch_size 101 | n_used_imgs = n_batches * batch_size 102 | 103 | pred_arr = np.empty((n_used_imgs, dims)) 104 | 105 | for i in tqdm(range(n_batches)): 106 | if verbose: 107 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 108 | end='', flush=True) 109 | start = i * batch_size 110 | end = start + batch_size 111 | images = np.array([resize(imread(str(f)).astype(np.float32), (256, 256, 3)) 112 | for f in files[start:end]]) 113 | 114 | # Reshape to (n_images, 3, height, width) 115 | images = images.transpose((0, 3, 1, 2)) 116 | images /= 255 117 | 118 | batch = torch.from_numpy(images).type(torch.FloatTensor) 119 | 120 | #import matplotlib.pyplot as plt 121 | #import numpy 122 | 123 | #plt.imshow(np.rollaxis(batch[0].detach().cpu().numpy(), 0, 3)) 124 | #plt.show() 125 | 126 | if cuda: 127 | batch = batch.cuda() 128 | 129 | pred = model(batch)[0] 130 | 131 | # If model output is not scalar, apply global spatial average pooling. 
132 | # This happens if you choose a dimensionality not equal 2048. 133 | if pred.shape[2] != 1 or pred.shape[3] != 1: 134 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 135 | 136 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 137 | 138 | if verbose: 139 | print(' done') 140 | 141 | print(pred_arr.shape) 142 | 143 | return pred_arr 144 | 145 | 146 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 147 | """Numpy implementation of the Frechet Distance. 148 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 149 | and X_2 ~ N(mu_2, C_2) is 150 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 151 | 152 | Stable version by Dougal J. Sutherland. 153 | 154 | Params: 155 | -- mu1 : Numpy array containing the activations of a layer of the 156 | inception net (like returned by the function 'get_predictions') 157 | for generated samples. 158 | -- mu2 : The sample mean over activations, precalculated on an 159 | representative data set. 160 | -- sigma1: The covariance matrix over activations for generated samples. 161 | -- sigma2: The covariance matrix over activations, precalculated on an 162 | representative data set. 163 | 164 | Returns: 165 | -- : The Frechet Distance. 166 | """ 167 | 168 | mu1 = np.atleast_1d(mu1) 169 | mu2 = np.atleast_1d(mu2) 170 | 171 | sigma1 = np.atleast_2d(sigma1) 172 | sigma2 = np.atleast_2d(sigma2) 173 | 174 | assert mu1.shape == mu2.shape, \ 175 | 'Training and test mean vectors have different lengths' 176 | assert sigma1.shape == sigma2.shape, \ 177 | 'Training and test covariances have different dimensions' 178 | 179 | diff = mu1 - mu2 180 | 181 | # Product might be almost singular 182 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 183 | if not np.isfinite(covmean).all(): 184 | msg = ('fid calculation produces singular product; ' 185 | 'adding %s to diagonal of cov estimates') % eps 186 | print(msg) 187 | offset = np.eye(sigma1.shape[0]) * eps 188 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 189 | 190 | # Numerical error might give slight imaginary component 191 | if np.iscomplexobj(covmean): 192 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 193 | m = np.max(np.abs(covmean.imag)) 194 | raise ValueError('Imaginary component {}'.format(m)) 195 | covmean = covmean.real 196 | 197 | tr_covmean = np.trace(covmean) 198 | 199 | return (diff.dot(diff) + np.trace(sigma1) + 200 | np.trace(sigma2) - 2 * tr_covmean) 201 | 202 | 203 | def calculate_activation_statistics(files, model, batch_size=50, 204 | dims=2048, cuda=False, verbose=False): 205 | """Calculation of the statistics used by the FID. 206 | Params: 207 | -- files : List of image files paths 208 | -- model : Instance of inception model 209 | -- batch_size : The images numpy array is split into batches with 210 | batch size batch_size. A reasonable batch size 211 | depends on the hardware. 212 | -- dims : Dimensionality of features returned by Inception 213 | -- cuda : If set to True, use GPU 214 | -- verbose : If set to True and parameter out_step is given, the 215 | number of calculated batches is reported. 216 | Returns: 217 | -- mu : The mean over samples of the activations of the pool_3 layer of 218 | the inception model. 219 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 220 | the inception model. 
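    These (mu, sigma) pairs are what calculate_frechet_distance above compares.
    A minimal sketch of the full pipeline (illustrative; mirrors
    calculate_fid_given_paths below, with files1/files2 as two image file lists):

        m1, s1 = calculate_activation_statistics(files1, model, cuda=cuda)
        m2, s2 = calculate_activation_statistics(files2, model, cuda=cuda)
        fid = calculate_frechet_distance(m1, s1, m2, s2)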
221 | """ 222 | act = get_activations(files, model, batch_size, dims, cuda, verbose) 223 | mu = np.mean(act, axis=0) 224 | sigma = np.cov(act, rowvar=False) 225 | #print(act.shape) 226 | #rint(mu.shape) 227 | #print(sigma.shape) 228 | return mu, sigma 229 | 230 | 231 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 232 | if path.endswith('.npz'): 233 | f = np.load(path) 234 | m, s = f['mu'][:], f['sigma'][:] 235 | f.close() 236 | else: 237 | path = pathlib.Path(path) 238 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 239 | m, s = calculate_activation_statistics(files, model, batch_size, 240 | dims, cuda) 241 | 242 | return m, s 243 | 244 | 245 | def calculate_fid_given_paths(paths, batch_size, cuda, dims): 246 | """Calculates the FID of two paths""" 247 | for p in paths: 248 | if not os.path.exists(p): 249 | raise RuntimeError('Invalid path: %s' % p) 250 | 251 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 252 | 253 | model = InceptionV3([block_idx]) 254 | if cuda: 255 | model.cuda() 256 | 257 | m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, 258 | dims, cuda) 259 | m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, 260 | dims, cuda) 261 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 262 | 263 | return fid_value 264 | 265 | 266 | if __name__ == '__main__': 267 | args = parser.parse_args() 268 | # os.environ['CUDA_VISIBLE_DEVICES'] = "0" 269 | 270 | # path1 = "./results/real_B/" 271 | # path2 = "./results/fake_B/" 272 | # batch_size = 50 273 | # dims = 2048 274 | 275 | path1 = args.path[0] 276 | path2 = args.path[1] 277 | batch_size = args.batch_size 278 | dims = args.dims 279 | cuda = False 280 | 281 | if args.gpu: 282 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 283 | cuda = True 284 | 285 | fid_value = calculate_fid_given_paths([path1, path2], 286 | batch_size, 287 | cuda, 288 | dims) 289 | print(fid_value) 290 | -------------------------------------------------------------------------------- /evaluations/lpips/default_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils import data 4 | from torchvision import transforms 5 | 6 | 7 | class DefaultDataset(data.Dataset): 8 | def __init__(self, root, transform=None, return_path=False, test_list=None): 9 | if test_list is not None: 10 | with open(test_list, "r") as f: 11 | test_files = f.readlines() 12 | test_files = [x.strip() for x in test_files] 13 | onlyroot = root.split('/')[-1] 14 | leftroot = root[:-2] 15 | self.samples = [os.path.join(leftroot, line) for line in test_files if line.startswith(onlyroot)] 16 | else: 17 | self.samples = [os.path.join(root, dir) for dir in os.listdir(root)] 18 | self.samples.sort() 19 | self.transform = transform 20 | self.return_path = return_path 21 | self.targets = None 22 | 23 | def __getitem__(self, index): 24 | fname = self.samples[index] 25 | img = Image.open(fname).convert('RGB') 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | if self.return_path: 29 | return img, os.path.basename(fname) 30 | else: 31 | return img 32 | 33 | def __len__(self): 34 | return len(self.samples) 35 | 36 | 37 | def get_eval_loader(root, img_size=256, batch_size=32, 38 | imagenet_normalize=True, shuffle=True, 39 | num_workers=4, drop_last=False, return_path=False, test_list=None): 40 | # print('Preparing DataLoader for the evaluation phase...') 41 | if imagenet_normalize: 42 | height, width = 299, 299 43 | mean = 
[0.485, 0.456, 0.406] 44 | std = [0.229, 0.224, 0.225] 45 | else: 46 | height, width = img_size, img_size 47 | mean = [0.5, 0.5, 0.5] 48 | std = [0.5, 0.5, 0.5] 49 | 50 | transform = transforms.Compose([ 51 | transforms.Resize([img_size, img_size]), 52 | transforms.Resize([height, width]), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=mean, std=std) 55 | ]) 56 | 57 | dataset = DefaultDataset(root, transform=transform, return_path=return_path, test_list=test_list) 58 | return data.DataLoader(dataset=dataset, 59 | batch_size=batch_size, 60 | shuffle=shuffle, 61 | num_workers=num_workers, 62 | pin_memory=True, 63 | drop_last=drop_last) -------------------------------------------------------------------------------- /evaluations/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | import torch 13 | import argparse 14 | import torch.nn as nn 15 | from torchvision import models 16 | from default_dataset import get_eval_loader 17 | 18 | 19 | def normalize(x, eps=1e-10): 20 | return x * torch.rsqrt(torch.sum(x**2, dim=1, keepdim=True) + eps) 21 | 22 | 23 | class AlexNet(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | self.layers = models.alexnet(pretrained=True).features 27 | self.channels = [] 28 | for layer in self.layers: 29 | if isinstance(layer, nn.Conv2d): 30 | self.channels.append(layer.out_channels) 31 | 32 | def forward(self, x): 33 | fmaps = [] 34 | for layer in self.layers: 35 | x = layer(x) 36 | if isinstance(layer, nn.ReLU): 37 | fmaps.append(x) 38 | return fmaps 39 | 40 | 41 | class Conv1x1(nn.Module): 42 | def __init__(self, in_channels, out_channels=1): 43 | super().__init__() 44 | self.main = nn.Sequential( 45 | nn.Dropout(0.5), 46 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)) 47 | 48 | def forward(self, x): 49 | return self.main(x) 50 | 51 | 52 | class LPIPS(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.alexnet = AlexNet() 56 | self.lpips_weights = nn.ModuleList() 57 | for channels in self.alexnet.channels: 58 | self.lpips_weights.append(Conv1x1(channels, 1)) 59 | self._load_lpips_weights() 60 | # imagenet normalization for range [-1, 1] 61 | self.mu = torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1).cuda() 62 | self.sigma = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1).cuda() 63 | 64 | def _load_lpips_weights(self): 65 | own_state_dict = self.state_dict() 66 | if torch.cuda.is_available(): 67 | # state_dict = torch.load('evaluations/alexnet-owt-7be5be79.pth') 68 | state_dict = torch.load('evaluations/lpips_weights.ckpt') 69 | else: 70 | state_dict = torch.load('evaluations/lpips_weights.ckpt', 71 | map_location=torch.device('cpu')) 72 | for name, param in state_dict.items(): 73 | if name in own_state_dict: 74 | own_state_dict[name].copy_(param) 75 | 76 | def forward(self, x, y): 77 | x = (x - self.mu) / self.sigma 78 | y = (y - self.mu) / self.sigma 79 | x_fmaps = self.alexnet(x) 80 | y_fmaps = self.alexnet(y) 81 | lpips_value = 0 82 | for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights): 83 | x_fmap = normalize(x_fmap) 84 | y_fmap = normalize(y_fmap) 85 
| lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2)) 86 | return lpips_value 87 | 88 | 89 | @torch.no_grad() 90 | def calculate_lpips_given_images(group_of_images): 91 | # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)] 92 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 93 | lpips = LPIPS().eval().to(device) 94 | lpips_values = [] 95 | num_rand_outputs = len(group_of_images) 96 | # num_rand_outputs = 10 97 | 98 | # calculate the average of pairwise distances among all random outputs 99 | for i in range(num_rand_outputs-1): 100 | for j in range(i+1, num_rand_outputs): 101 | lpips_values.append(lpips(group_of_images[i], group_of_images[j])) 102 | lpips_value = torch.mean(torch.stack(lpips_values, dim=0)) 103 | return lpips_value.item() 104 | 105 | 106 | @torch.no_grad() 107 | def calculate_lpips_given_paths(root_path, img_size=256, batch_size=50, test_list=None, seed_list=[]): 108 | # print('Calculating LPIPS given root path %s' % root_path) 109 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 110 | paths = [os.path.join(root_path, path) for path in os.listdir(root_path) ] #if path.startswith('seed') and path.split('/')[0][4:] in seed_list] 111 | loaders = [get_eval_loader(path, img_size, batch_size, shuffle=False, drop_last=False, test_list=test_list) for path in paths] 112 | loaders_iter = [iter(loader) for loader in loaders] 113 | num_clips = len(loaders[0]) 114 | 115 | lpips_values = [] 116 | 117 | for i in range(num_clips): 118 | group_of_images = [loader.next().cuda() for loader in loaders_iter] 119 | lpips_values.append(calculate_lpips_given_images(group_of_images)) 120 | lpips_values = torch.tensor(lpips_values) 121 | 122 | lpips_mean = torch.mean(lpips_values) 123 | 124 | return lpips_mean 125 | 126 | 127 | def make_lpips_list(seed_list, root_path): 128 | dirr = root_path +'0/samples' 129 | name = 'lpips_list.txt' 130 | filelist = os.listdir(dirr) 131 | filelist.sort() 132 | with open(name, 'w') as f: 133 | for seed in seed_list: 134 | f.writelines(seed + '/samples/' + line + '\n' for line in filelist) 135 | 136 | 137 | if __name__ == '__main__': 138 | # python -m metrics.lpips --paths PATH_REAL PATH_FAKE 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument('--root_path', type=str, default="", help='paths to real and fake images') 141 | parser.add_argument('--test_list', type=str, default="evaluations/ade20k_list.txt", help='paths to real and fake images') 142 | parser.add_argument('--img_size', type=int, default=256, help='image resolution') 143 | parser.add_argument('--batch_size', '-b', type=int, default=50, help='batch size to use') 144 | parser.add_argument('--seed_list', type=str, default='') 145 | parser.add_argument('--mini_lpips_size', type=int, default=-1) 146 | parser.add_argument('--mini_lpips_num', type=int, default=-1) 147 | args = parser.parse_args() 148 | if args.mini_lpips_num == -1: 149 | make_lpips_list(args.seed_list.split(','),args.root_path) 150 | lpips_value = calculate_lpips_given_paths(args.root_path, args.img_size, args.batch_size, test_list=args.test_list, seed_list = args.seed_list.split(',')) 151 | print('LPIPS: ', lpips_value) 152 | else: 153 | import random 154 | import numpy as np 155 | results = [] 156 | print_list = [] 157 | for i in range(args.mini_lpips_num): 158 | print(i) 159 | seed_list = random.sample(args.seed_list.split(','),args.mini_lpips_size) 160 | make_lpips_list(seed_list,args.root_path) 161 | lpips_value = calculate_lpips_given_paths(args.root_path, 
args.img_size, args.batch_size, test_list=args.test_list, seed_list = seed_list) 162 | results.append(lpips_value) 163 | print_list.append('mini batch('+str(args.mini_lpips_size)+') : [' + ','.join(seed_list) + '], LPIPS: '+ str(lpips_value)) 164 | for line in print_list: 165 | print(line) 166 | results = np.array(results) 167 | print('mean : ',results.mean(),', std : ',results.std()) 168 | np.save('mini_lpips_results/mini_lpips_results_'+str(args.mini_lpips_size)+'.npy',results) 169 | -------------------------------------------------------------------------------- /evaluations/lpips/lpips_weights.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/SCDM/dd176a177716a597b1aeac3d120ac9488821a661/evaluations/lpips/lpips_weights.ckpt -------------------------------------------------------------------------------- /evaluations/lpips/make_lpips_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | dirr = '/home/user/scdm/results/ade20k/generated_samples' 3 | name = 'lpips_list.txt' 4 | filelist = os.listdir(dirr) 5 | filelist.sort() 6 | seed_list 7 | with open(name, 'w') as f: 8 | f.writelines(dirr + line + '\n' for line in filelist) -------------------------------------------------------------------------------- /evaluations/miou/celeba_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: imaginaire.datasets.paired_images 3 | num_workers: 8 4 | input_types: 5 | - input_label: 6 | ext: png 7 | num_channels: 1 8 | normalize: False 9 | is_mask: True 10 | - synthesized_image: 11 | ext: jpg 12 | num_channels: 3 13 | normalize: True 14 | use_dont_care: False 15 | full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels 16 | use_dont_care: True 17 | one_hot_num_classes: 18 | seg_maps: 19 19 | input_labels: 20 | - input_label 21 | input_image: 22 | - synthesized_image 23 | val: 24 | # Input LMDBs. 25 | roots: 26 | - /home/user/generated_samples 27 | # Batch size per GPU. 28 | batch_size: 4 29 | # Data augmentations to be performed in given order. 30 | augmentations: 31 | # Crop size. 32 | resize_h_w: 256, 256 33 | -------------------------------------------------------------------------------- /evaluations/miou/coco_config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: imaginaire.datasets.paired_images 3 | num_workers: 8 4 | input_types: 5 | - input_label: 6 | ext: png 7 | num_channels: 1 8 | normalize: False 9 | is_mask: True 10 | - samples: 11 | ext: png 12 | num_channels: 3 13 | normalize: True 14 | use_dont_care: False 15 | full_data_ops: imaginaire.model_utils.label::make_one_hot, imaginaire.model_utils.label::concat_labels 16 | use_dont_care: True 17 | one_hot_num_classes: 18 | seg_maps: 183 19 | input_labels: 20 | - input_label 21 | input_image: 22 | - samples 23 | val: 24 | # Input LMDBs. 25 | roots: 26 | - /home/user/generated_samples 27 | # Batch size per GPU. 28 | batch_size: 4 29 | # Data augmentations to be performed in given order. 30 | augmentations: 31 | # Crop size. 
32 | resize_h_w: 512, 512 33 | -------------------------------------------------------------------------------- /evaluations/miou/test_celeba.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import argparse 5 | 6 | from imaginaire.config import Config 7 | from imaginaire.evaluation.common2 import compute_all_metrics_data 8 | from imaginaire.utils.dataset import _get_val_dataset_object, _get_data_loader 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Training') 12 | parser.add_argument('--dataset', required=True, 13 | help='dataset name') 14 | parser.add_argument('--batch_size', '-bs', default=40, type=int) 15 | args = parser.parse_args() 16 | return args 17 | 18 | def main(): 19 | args = parse_args() 20 | 21 | fake_b = f'/home/user/imaginaire/configs/celeba_miou/{args.dataset}/fake_b.yaml' 22 | fake_b_cfg = Config(fake_b) 23 | bs = args.batch_size 24 | 25 | fake_b_dataset = _get_val_dataset_object(fake_b_cfg) 26 | data_loader_b = _get_data_loader(fake_b_cfg, fake_b_dataset, bs) 27 | 28 | all_metrics = compute_all_metrics_data(data_loader_b=data_loader_b, metrics=['seg_mIOU'], dataset_name='celebamask_hq') 29 | print(args.dataset) 30 | print(all_metrics) 31 | print('---------------') 32 | 33 | if __name__ == '__main__': 34 | main() -------------------------------------------------------------------------------- /evaluations/miou/test_coco.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import argparse 5 | 6 | from imaginaire.config import Config 7 | from imaginaire.evaluation.common2 import compute_all_metrics_data 8 | from imaginaire.utils.dataset import _get_val_dataset_object, _get_data_loader 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Training') 12 | parser.add_argument('--batch_size', '-bs', default=25, type=int) 13 | args = parser.parse_args() 14 | return args 15 | 16 | def main(): 17 | args = parse_args() 18 | 19 | fake_b = f'/home/user/imaginaire/configs/coco_config.yaml' 20 | fake_b_cfg = Config(fake_b) 21 | bs = args.batch_size 22 | 23 | fake_b_dataset = _get_val_dataset_object(fake_b_cfg) 24 | data_loader_b = _get_data_loader(fake_b_cfg, fake_b_dataset, bs) 25 | 26 | all_metrics = compute_all_metrics_data(key_a='samples', key_b='samples', data_loader_b=data_loader_b, metrics=['seg_mIOU'], dataset_name='cocostuff') 27 | 28 | print(all_metrics) 29 | print('---------------') 30 | 31 | if __name__ == '__main__': 32 | main() -------------------------------------------------------------------------------- /guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | import numpy as np 14 | import random 15 | 16 | # Change this to reflect your cluster layout. 17 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 
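# setup_dist() below pins each MPI rank to a single CUDA device via
# CUDA_VISIBLE_DEVICES: if VISIBLE_DEVICES is exported, it is indexed by the rank;
# otherwise the fallback "4 + rank % gpus_per_node" is used, which looks
# machine-specific. Export VISIBLE_DEVICES (or adjust the fallback) for your cluster.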
18 | 19 | SETUP_RETRY_COUNT = 3 20 | 21 | 22 | def setup_dist(gpus_per_node=8): 23 | """ 24 | Setup a distributed process group. 25 | """ 26 | if dist.is_initialized(): 27 | return 28 | 29 | if "VISIBLE_DEVICES" in os.environ: 30 | visible_devices = os.environ["VISIBLE_DEVICES"].split(',') 31 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{visible_devices[MPI.COMM_WORLD.Get_rank()]}" 32 | else: 33 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{4 + MPI.COMM_WORLD.Get_rank() % gpus_per_node}" 34 | 35 | comm = MPI.COMM_WORLD 36 | backend = "gloo" if not th.cuda.is_available() else "nccl" 37 | 38 | if backend == "gloo": 39 | hostname = "localhost" 40 | else: 41 | hostname = socket.gethostbyname(socket.getfqdn()) 42 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 43 | os.environ["RANK"] = str(comm.rank) 44 | os.environ["WORLD_SIZE"] = str(comm.size) 45 | 46 | port = comm.bcast(_find_free_port(), root=0) 47 | os.environ["MASTER_PORT"] = str(port) 48 | dist.init_process_group(backend=backend, init_method="env://") 49 | 50 | 51 | def dev(): 52 | """ 53 | Get the device to use for torch.distributed. 54 | """ 55 | if th.cuda.is_available(): 56 | return th.device(f"cuda") 57 | return th.device("cpu") 58 | 59 | 60 | def load_state_dict(path, **kwargs): 61 | """ 62 | Load a PyTorch file without redundant fetches across MPI ranks. 63 | """ 64 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 65 | if MPI.COMM_WORLD.Get_rank() == 0: 66 | with bf.BlobFile(path, "rb") as f: 67 | data = f.read() 68 | num_chunks = len(data) // chunk_size 69 | if len(data) % chunk_size: 70 | num_chunks += 1 71 | MPI.COMM_WORLD.bcast(num_chunks) 72 | for i in range(0, len(data), chunk_size): 73 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 74 | else: 75 | num_chunks = MPI.COMM_WORLD.bcast(None) 76 | data = bytes() 77 | for _ in range(num_chunks): 78 | data += MPI.COMM_WORLD.bcast(None) 79 | 80 | return th.load(io.BytesIO(data), **kwargs) 81 | 82 | 83 | def sync_params(params): 84 | """ 85 | Synchronize a sequence of Tensors across ranks from rank 0. 86 | """ 87 | for p in params: 88 | with th.no_grad(): 89 | dist.broadcast(p.detach(), 0) 90 | 91 | 92 | def _find_free_port(): 93 | try: 94 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 95 | s.bind(("", 0)) 96 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 97 | return s.getsockname()[1] 98 | finally: 99 | s.close() 100 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 
28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = 
None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 203 | opt.step() 204 | zero_master_grads(self.master_params) 205 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 206 | self.lg_loss_scale += self.fp16_scale_growth 207 | return True 208 | 209 | def _optimize_normal(self, opt: th.optim.Optimizer): 210 | grad_norm, param_norm = self._compute_norms() 211 | logger.logkv_mean("grad_norm", grad_norm) 212 | logger.logkv_mean("param_norm", param_norm) 213 | opt.step() 214 | return True 215 | 216 | def _compute_norms(self, grad_scale=1.0): 217 | grad_norm = 0.0 218 | param_norm = 0.0 219 | for p in self.master_params: 220 | with th.no_grad(): 221 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 222 | if p.grad is not None: 223 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 224 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 225 | 226 | def master_params_to_state_dict(self, master_params): 227 | return master_params_to_state_dict( 228 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 229 | ) 230 | 231 | def state_dict_to_master_params(self, state_dict): 232 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 233 | 234 | 235 | def 
check_overflow(value): 236 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 237 | -------------------------------------------------------------------------------- /guided_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code started out as a PyTorch port of Ho et al's diffusion models: 3 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py 4 | 5 | Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. 6 | """ 7 | 8 | import enum 9 | import math 10 | 11 | import numpy as np 12 | import torch as th 13 | 14 | import torch.distributed as dist 15 | import torchvision as tv 16 | from einops import rearrange 17 | 18 | from .nn import mean_flat 19 | from .losses import normal_kl, discretized_gaussian_log_likelihood 20 | import random 21 | 22 | colors = [] 23 | for _ in range(152): 24 | color = list(random.choices(range(256), k=3)) 25 | colors.append(color) 26 | 27 | def dynamic_thresholding(x, p): 28 | s = th.quantile( 29 | rearrange(x, 'b ... -> b (...)').abs(), 30 | p, 31 | dim = -1 32 | ) 33 | s.clamp_(min = 1.) 34 | s = s.view(-1, *((1,) * (x.ndim - 1))) 35 | # clip by threshold, depending on whether static or dynamic 36 | return x.clamp(-s, s) / s 37 | 38 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 39 | """ 40 | Get a pre-defined beta schedule for the given name. 41 | 42 | The beta schedule library consists of beta schedules which remain similar 43 | in the limit of num_diffusion_timesteps. 44 | Beta schedules may be added, but should not be removed or changed once 45 | they are committed to maintain backwards compatibility. 46 | """ 47 | if schedule_name == "linear": 48 | # Linear schedule from Ho et al, extended to work for any number of 49 | # diffusion steps. 50 | scale = 1000 / num_diffusion_timesteps 51 | beta_start = scale * 0.0001 52 | beta_end = scale * 0.02 53 | return np.linspace( 54 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 55 | ) 56 | elif schedule_name == "cosine": 57 | return betas_for_alpha_bar( 58 | num_diffusion_timesteps, 59 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 60 | ) 61 | elif schedule_name.split('_')[0] == "discrete": 62 | if schedule_name.split('_')[2] == "uniform": 63 | # beta_t^-1 = T - t + 1 64 | beta_inv = np.linspace( 65 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 66 | ) 67 | return 1/beta_inv 68 | else: 69 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 70 | else: 71 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 72 | 73 | 74 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 75 | """ 76 | Create a beta schedule that discretizes the given alpha_t_bar function, 77 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 78 | 79 | :param num_diffusion_timesteps: the number of betas to produce. 80 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 81 | produces the cumulative product of (1-beta) up to that 82 | part of the diffusion process. 83 | :param max_beta: the maximum beta to use; use values lower than 1 to 84 | prevent singularities. 
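    For example, the "cosine" schedule above uses
    alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2, and each beta is
    beta_i = min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta).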
85 | """ 86 | betas = [] 87 | for i in range(num_diffusion_timesteps): 88 | t1 = i / num_diffusion_timesteps 89 | t2 = (i + 1) / num_diffusion_timesteps 90 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 91 | return np.array(betas) 92 | 93 | 94 | class ModelMeanType(enum.Enum): 95 | """ 96 | Which type of output the model predicts. 97 | """ 98 | 99 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 100 | START_X = enum.auto() # the model predicts x_0 101 | EPSILON = enum.auto() # the model predicts epsilon 102 | 103 | 104 | class ModelVarType(enum.Enum): 105 | """ 106 | What is used as the model's output variance. 107 | 108 | The LEARNED_RANGE option has been added to allow the model to predict 109 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 110 | """ 111 | 112 | LEARNED = enum.auto() 113 | FIXED_SMALL = enum.auto() 114 | FIXED_LARGE = enum.auto() 115 | LEARNED_RANGE = enum.auto() 116 | 117 | 118 | class LossType(enum.Enum): 119 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 120 | RESCALED_MSE = ( 121 | enum.auto() 122 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 123 | KL = enum.auto() # use the variational lower-bound 124 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 125 | 126 | def is_vb(self): 127 | return self == LossType.KL or self == LossType.RESCALED_KL 128 | 129 | 130 | class GaussianDiffusion: 131 | """ 132 | Utilities for training and sampling diffusion models. 133 | 134 | Ported directly from here, and then adapted over time to further experimentation. 135 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 136 | 137 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 138 | starting at T and going to 1. 139 | :param model_mean_type: a ModelMeanType determining what the model outputs. 140 | :param model_var_type: a ModelVarType determining how variance is output. 141 | :param loss_type: a LossType determining the loss function to use. 142 | :param rescale_timesteps: if True, pass floating point timesteps into the 143 | model so that they are always scaled like in the 144 | original paper (0 to 1000). 145 | """ 146 | 147 | def __init__( 148 | self, 149 | *, 150 | betas, 151 | model_mean_type, 152 | model_var_type, 153 | loss_type, 154 | rescale_timesteps=False, 155 | cond_diffuse=False, 156 | cond_opt = '', 157 | cond_betas = None, 158 | no_instance = True, 159 | dataset_mode = "ade20k" 160 | ): 161 | self.model_mean_type = model_mean_type 162 | self.model_var_type = model_var_type 163 | self.loss_type = loss_type 164 | self.rescale_timesteps = rescale_timesteps 165 | self.cond_diffuse = cond_diffuse 166 | self.cond_opt = cond_opt 167 | self.no_instance = no_instance 168 | self.dataset_mode = dataset_mode 169 | 170 | # Use float64 for accuracy. 
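        # Everything below precomputes per-timestep coefficients as float64 numpy
        # arrays: the image-noise schedule (betas, alphas_cumprod, posterior terms)
        # and, when cond_opt requests a "discrete_*_uniform" label schedule, the
        # label-keep probabilities `gammas` (with cond_betas[t] = 1/(T - t), gammas[t]
        # telescopes to (T - t - 1)/T, so the keep probability decays linearly to 0).
        # Whenever cond_opt is set, per-class statistics are also loaded from
        # `{dataset_mode}_psiphi.npy`.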
171 | betas = np.array(betas, dtype=np.float64) 172 | self.betas = betas 173 | assert len(betas.shape) == 1, "betas must be 1-D" 174 | assert (betas > 0).all() and (betas <= 1).all() 175 | 176 | self.num_timesteps = int(betas.shape[0]) 177 | 178 | alphas = 1.0 - betas 179 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 180 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 181 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 182 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 183 | 184 | if self.cond_opt.split('_')[0] == 'discrete' and self.cond_opt.split('_')[2] == 'uniform': 185 | cond_betas = get_named_beta_schedule(self.cond_opt,self.num_timesteps) 186 | cond_betas = np.array(cond_betas, dtype=np.float64) 187 | self.cond_betas = cond_betas 188 | assert len(cond_betas.shape) == 1, "betas must be 1-D" 189 | assert (cond_betas > 0).all() and (cond_betas <= 1).all() 190 | 191 | cond_alphas = 1.0 - cond_betas 192 | self.gammas = np.cumprod(cond_alphas, axis=0) 193 | 194 | if self.cond_opt != '': 195 | psiphi_path = dataset_mode + '_psiphi.npy' 196 | self.psiphi = np.load(psiphi_path) 197 | if dataset_mode == 'ade20k' or dataset_mode == 'coco': 198 | self.psiphi[-1] = 0.5 199 | 200 | # calculations for diffusion q(x_t | x_{t-1}) and others 201 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 202 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 203 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 204 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 205 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 206 | 207 | # calculations for posterior q(x_{t-1} | x_t, x_0) 208 | self.posterior_variance = ( 209 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 210 | ) 211 | # log calculation clipped because the posterior variance is 0 at the 212 | # beginning of the diffusion chain. 213 | self.posterior_log_variance_clipped = np.log( 214 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 215 | ) 216 | self.posterior_mean_coef1 = ( 217 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 218 | ) 219 | self.posterior_mean_coef2 = ( 220 | (1.0 - self.alphas_cumprod_prev) 221 | * np.sqrt(alphas) 222 | / (1.0 - self.alphas_cumprod) 223 | ) 224 | 225 | def q_mean_variance(self, x_start, t): 226 | """ 227 | Get the distribution q(x_t | x_0). 228 | 229 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 230 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 231 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 232 | """ 233 | mean = ( 234 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 235 | ) 236 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 237 | log_variance = _extract_into_tensor( 238 | self.log_one_minus_alphas_cumprod, t, x_start.shape 239 | ) 240 | return mean, variance, log_variance 241 | 242 | def q_sample(self, x_start, t, noise=None, option = '', x_next = None): 243 | """ 244 | Diffuse the data for a given number of diffusion steps. 245 | 246 | In other words, sample from q(x_t | x_0). 247 | 248 | :param x_start: the initial data batch. 249 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 250 | :param noise: if specified, the split-out normal noise. 251 | :return: A noisy version of x_start. 
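        :param option: conditioning-schedule string; when it starts with "discrete"
            (e.g. "discrete_zero_uniform" or "discrete_zero_classwise"), x_start is
            treated as a one-hot label map and each pixel's label is kept with a
            probability given by the uniform or class-wise schedule and zeroed
            otherwise, instead of being perturbed with Gaussian noise.
        For the default Gaussian branch this returns
            sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.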
252 | """ 253 | 254 | if option[:8] == 'discrete': 255 | classwise_gammas = None 256 | if option.split('_')[2] == 'classwise': 257 | t = _extract_into_tensor(self.timestep_map,t,t.shape).type(th.long) 258 | B, _, H, W = x_start.shape 259 | psiphi = th.from_numpy(self.psiphi).to(t.device).unsqueeze(0).repeat(B, 1) 260 | log_psiphi = th.log(psiphi) 261 | tmp = ((th.exp(log_psiphi*(t/1000).unsqueeze(1)) - 1.0)/(psiphi - 1.0 + 1e-7)).clamp(max=(t/1000).unsqueeze(1)).unsqueeze(1).unsqueeze(2) 262 | classwise_gammas = 1.0 - (x_start.permute(0,2,3,1)*tmp).sum(dim=3) 263 | 264 | if option.split('_')[1] == 'zero': 265 | x_start = x_start.permute(0,2,3,1) # b c w h -> b w h c 266 | if noise is None: 267 | noise = th.rand((x_start.shape[0],x_start.shape[1],x_start.shape[2],1),device=x_start.device) 268 | if option.split('_')[2] == 'classwise': 269 | noise = classwise_gammas.unsqueeze(3) + noise 270 | else: 271 | noise = _extract_into_tensor(self.gammas, t, noise.shape) + noise 272 | mask = noise.floor() 273 | return (x_start * mask).permute(0,3,1,2) 274 | 275 | else: 276 | if noise is None: 277 | noise = th.randn_like(x_start) 278 | assert noise.shape == x_start.shape 279 | return ( 280 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 281 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 282 | * noise 283 | ) 284 | 285 | def q_posterior_mean_variance(self, x_start, x_t, t): 286 | """ 287 | Compute the mean and variance of the diffusion posterior: 288 | 289 | q(x_{t-1} | x_t, x_0) 290 | 291 | """ 292 | assert x_start.shape == x_t.shape 293 | posterior_mean = ( 294 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 295 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 296 | ) 297 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 298 | posterior_log_variance_clipped = _extract_into_tensor( 299 | self.posterior_log_variance_clipped, t, x_t.shape 300 | ) 301 | assert ( 302 | posterior_mean.shape[0] 303 | == posterior_variance.shape[0] 304 | == posterior_log_variance_clipped.shape[0] 305 | == x_start.shape[0] 306 | ) 307 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 308 | 309 | def p_mean_variance( 310 | self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None 311 | ): 312 | """ 313 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 314 | the initial x, x_0. 315 | 316 | :param model: the model, which takes a signal and a batch of timesteps 317 | as input. 318 | :param x: the [N x C x ...] tensor at time t. 319 | :param t: a 1-D Tensor of timesteps. 320 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 321 | :param denoised_fn: if not None, a function which applies to the 322 | x_start prediction before it is used to sample. Applies before 323 | clip_denoised. 324 | :param model_kwargs: if not None, a dict of extra keyword arguments to 325 | pass to the model. This can be used for conditioning. 326 | :return: a dict with the following keys: 327 | - 'mean': the model mean output. 328 | - 'variance': the model variance output. 329 | - 'log_variance': the log of 'variance'. 330 | - 'pred_xstart': the prediction for x_0. 
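        In this implementation, model_kwargs may additionally contain: 'y' (the
        semantic label map, re-noised here with q_sample when cond_diffuse is set),
        'y_noise' (fixed noise reused when re-noising the label), 's' (a
        classifier-free-guidance-style scale applied when s > 1.0, with an all-zero
        label as the unconditional input), 'extrapolation', 'dynamic_threshold', and
        'mean'/'std' statistics used to re-normalize the predicted x_0.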
331 | """ 332 | 333 | if model_kwargs is None: 334 | model_kwargs = {} 335 | 336 | B, C = x.shape[:2] 337 | assert t.shape == (B,) 338 | if 'y' in model_kwargs: 339 | cond_label = model_kwargs['y'] 340 | 341 | if self.no_instance == False: 342 | cond_label, cond_edge = th.split(model_kwargs['y'], (model_kwargs['y'].shape[1]-1, 1), dim=1) 343 | 344 | if self.cond_diffuse: 345 | if self.cond_opt[:8] == 'discrete': 346 | if 'y_noise' in model_kwargs: 347 | cond_label = self.q_sample(cond_label, t, noise = model_kwargs['y_noise'], option = self.cond_opt) 348 | else: 349 | cond_label = self.q_sample(cond_label, t, option = self.cond_opt) 350 | else: 351 | cond_label = self.q_sample(cond_label, t) 352 | 353 | if self.no_instance == False: 354 | cond = th.cat((cond_label, cond_edge), dim=1) 355 | else: 356 | cond = cond_label 357 | 358 | model_output = model(x, self._scale_timesteps(t), y=cond) 359 | else: 360 | model_output = model(x, self._scale_timesteps(t), **model_kwargs) 361 | 362 | if 's' in model_kwargs and model_kwargs['s'] > 1.0: 363 | model_output_zero = model(x, self._scale_timesteps(t), y=th.zeros_like(model_kwargs['y'])) 364 | model_output[:, :3] = model_output_zero[:, :3] + model_kwargs['s'] * (model_output[:, :3] - model_output_zero[:, :3]) 365 | 366 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 367 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 368 | model_output, model_var_values = th.split(model_output, C, dim=1) 369 | if self.model_var_type == ModelVarType.LEARNED: 370 | model_log_variance = model_var_values 371 | model_variance = th.exp(model_log_variance) 372 | else: 373 | min_log = _extract_into_tensor( 374 | self.posterior_log_variance_clipped, t, x.shape 375 | ) 376 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 377 | # The model_var_values is [-1, 1] for [min_var, max_var]. 378 | frac = (model_var_values + 1) / 2 379 | model_log_variance = frac * max_log + (1 - frac) * min_log 380 | model_variance = th.exp(model_log_variance) 381 | else: 382 | model_variance, model_log_variance = { 383 | # for fixedlarge, we set the initial (log-)variance like so 384 | # to get a better decoder log likelihood. 
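                # FIXED_SMALL instead uses the true posterior variance
                # beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).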
385 | ModelVarType.FIXED_LARGE: ( 386 | np.append(self.posterior_variance[1], self.betas[1:]), 387 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 388 | ), 389 | ModelVarType.FIXED_SMALL: ( 390 | self.posterior_variance, 391 | self.posterior_log_variance_clipped, 392 | ), 393 | }[self.model_var_type] 394 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 395 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 396 | 397 | def process_xstart(x, eps = None): 398 | if denoised_fn is not None: 399 | x = denoised_fn(x) 400 | 401 | if 'extrapolation' in model_kwargs: 402 | if model_kwargs['extrapolation'] > 0.0: 403 | if 'prev_x' in model_kwargs: 404 | x = x + model_kwargs['extrapolation']*(x - model_kwargs['prev_x']) 405 | model_kwargs['prev_x'] = x 406 | else: 407 | model_kwargs['prev_x'] = x 408 | 409 | if clip_denoised: 410 | if 'dynamic_threshold' in model_kwargs: 411 | if model_kwargs['dynamic_threshold'] > 0.0: 412 | x = dynamic_thresholding(x,model_kwargs['dynamic_threshold']) 413 | else: 414 | x = x.clamp(-1, 1) 415 | 416 | if 'mean' in model_kwargs and 'std' in model_kwargs: 417 | x = (x - x.mean(dim=(2, 3), keepdim=True)) / x.std(dim=(2, 3), keepdim=True) 418 | x = x * model_kwargs["std"][None, :, None, None] + model_kwargs["mean"][None, :, None, None] 419 | 420 | return x 421 | 422 | if self.model_mean_type == ModelMeanType.PREVIOUS_X: 423 | pred_xstart = process_xstart( 424 | self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) 425 | ) 426 | model_mean = model_output 427 | elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: 428 | if self.model_mean_type == ModelMeanType.START_X: 429 | pred_xstart = process_xstart(model_output) 430 | else: 431 | pred_xstart = process_xstart( 432 | x = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output), eps = model_output 433 | ) 434 | model_mean, _, _ = self.q_posterior_mean_variance( 435 | x_start=pred_xstart, x_t=x, t=t 436 | ) 437 | else: 438 | raise NotImplementedError(self.model_mean_type) 439 | 440 | assert ( 441 | model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 442 | ) 443 | return { 444 | "mean": model_mean, 445 | "variance": model_variance, 446 | "log_variance": model_log_variance, 447 | "pred_xstart": pred_xstart, 448 | } 449 | 450 | def _predict_xstart_from_eps(self, x_t, t, eps): 451 | assert x_t.shape == eps.shape 452 | return ( 453 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 454 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 455 | ) 456 | 457 | def _predict_xstart_from_xprev(self, x_t, t, xprev): 458 | assert x_t.shape == xprev.shape 459 | return ( 460 | _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev 461 | - _extract_into_tensor( 462 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape 463 | ) 464 | * x_t 465 | ) 466 | 467 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 468 | return ( 469 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 470 | - pred_xstart 471 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 472 | 473 | def _scale_timesteps(self, t): 474 | if self.rescale_timesteps: 475 | return t.float() * (1000.0 / self.num_timesteps) 476 | return t 477 | 478 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 479 | """ 480 | Compute the mean for the previous step, given a function cond_fn that 481 | computes the 
gradient of a conditional log probability with respect to 482 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 483 | condition on y. 484 | 485 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 486 | """ 487 | gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) 488 | new_mean = ( 489 | p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 490 | ) 491 | return new_mean 492 | 493 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 494 | """ 495 | Compute what the p_mean_variance output would have been, should the 496 | model's score function be conditioned by cond_fn. 497 | 498 | See condition_mean() for details on cond_fn. 499 | 500 | Unlike condition_mean(), this instead uses the conditioning strategy 501 | from Song et al (2020). 502 | """ 503 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 504 | 505 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 506 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn( 507 | x, self._scale_timesteps(t), **model_kwargs 508 | ) 509 | 510 | out = p_mean_var.copy() 511 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 512 | out["mean"], _, _ = self.q_posterior_mean_variance( 513 | x_start=out["pred_xstart"], x_t=x, t=t 514 | ) 515 | return out 516 | 517 | def p_sample( 518 | self, 519 | model, 520 | x, 521 | t, 522 | clip_denoised=True, 523 | denoised_fn=None, 524 | cond_fn=None, 525 | model_kwargs=None, 526 | ): 527 | """ 528 | Sample x_{t-1} from the model at the given timestep. 529 | 530 | :param model: the model to sample from. 531 | :param x: the current tensor at x_{t-1}. 532 | :param t: the value of t, starting at 0 for the first diffusion step. 533 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 534 | :param denoised_fn: if not None, a function which applies to the 535 | x_start prediction before it is used to sample. 536 | :param cond_fn: if not None, this is a gradient function that acts 537 | similarly to the model. 538 | :param model_kwargs: if not None, a dict of extra keyword arguments to 539 | pass to the model. This can be used for conditioning. 540 | :return: a dict containing the following keys: 541 | - 'sample': a random sample from the model. 542 | - 'pred_xstart': a prediction of x_0. 543 | """ 544 | out = self.p_mean_variance( 545 | model, 546 | x, 547 | t, 548 | clip_denoised=clip_denoised, 549 | denoised_fn=denoised_fn, 550 | model_kwargs=model_kwargs, 551 | ) 552 | noise = th.randn_like(x) 553 | nonzero_mask = ( 554 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 555 | ) 556 | if cond_fn is not None: 557 | out["mean"] = self.condition_mean( 558 | cond_fn, out, x, t, model_kwargs=model_kwargs 559 | ) 560 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 561 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 562 | 563 | def p_sample_loop( 564 | self, 565 | model, 566 | shape, 567 | noise=None, 568 | clip_denoised=True, 569 | denoised_fn=None, 570 | cond_fn=None, 571 | model_kwargs=None, 572 | device=None, 573 | progress=False 574 | ): 575 | """ 576 | Generate samples from the model. 577 | 578 | :param model: the model module. 579 | :param shape: the shape of the samples, (N, C, H, W). 580 | :param noise: if specified, the noise from the encoder to sample. 581 | Should be of the same shape as `shape`. 582 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 
583 | :param denoised_fn: if not None, a function which applies to the 584 | x_start prediction before it is used to sample. 585 | :param cond_fn: if not None, this is a gradient function that acts 586 | similarly to the model. 587 | :param model_kwargs: if not None, a dict of extra keyword arguments to 588 | pass to the model. This can be used for conditioning. 589 | :param device: if specified, the device to create the samples on. 590 | If not specified, use a model parameter's device. 591 | :param progress: if True, show a tqdm progress bar. 592 | :return: a non-differentiable batch of samples. 593 | """ 594 | 595 | final = None 596 | 597 | for i, sample in enumerate(self.p_sample_loop_progressive( 598 | model, 599 | shape, 600 | noise=noise, 601 | clip_denoised=clip_denoised, 602 | denoised_fn=denoised_fn, 603 | cond_fn=cond_fn, 604 | model_kwargs=model_kwargs, 605 | device=device, 606 | progress=progress, 607 | )): 608 | final = sample 609 | return final["sample"] 610 | 611 | def p_sample_loop_progressive( 612 | self, 613 | model, 614 | shape, 615 | noise=None, 616 | clip_denoised=True, 617 | denoised_fn=None, 618 | cond_fn=None, 619 | model_kwargs=None, 620 | device=None, 621 | progress=False, 622 | start_indices = None, 623 | ): 624 | """ 625 | Generate samples from the model and yield intermediate samples from 626 | each timestep of diffusion. 627 | 628 | Arguments are the same as p_sample_loop(). 629 | Returns a generator over dicts, where each dict is the return value of 630 | p_sample(). 631 | """ 632 | 633 | if device is None: 634 | device = next(model.parameters()).device 635 | assert isinstance(shape, (tuple, list)) 636 | if noise is not None: 637 | img = noise 638 | else: 639 | img = th.randn(*shape, device=device) 640 | if 'y' in model_kwargs: 641 | model_kwargs['y'] = model_kwargs['y'].to(device) 642 | if 'y_t' in model_kwargs and model_kwargs['y_t'] != None: 643 | model_kwargs['y_t'] = model_kwargs['y_t'].to(device) 644 | if start_indices == None: 645 | indices = list(range(self.num_timesteps))[::-1] 646 | else: 647 | indices = list(range(start_indices))[::-1] 648 | if progress: 649 | # Lazy import so that we don't depend on tqdm. 650 | from tqdm.auto import tqdm 651 | 652 | indices = tqdm(indices) 653 | # 654 | epsilon = img 655 | 656 | for i in indices: 657 | t = th.tensor([i] * shape[0], device=device) 658 | with th.no_grad(): 659 | model_kwargs["epsilon"] = epsilon 660 | out = self.p_sample( 661 | model, 662 | img, 663 | t, 664 | clip_denoised=clip_denoised, 665 | denoised_fn=denoised_fn, 666 | cond_fn=cond_fn, 667 | model_kwargs=model_kwargs, 668 | ) 669 | yield out 670 | img = out["sample"] 671 | pred_xstart = out["pred_xstart"] 672 | epsilon = (img - _extract_into_tensor(self.sqrt_alphas_cumprod, t, img.shape)*pred_xstart)/_extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, img.shape) 673 | 674 | 675 | def _vb_terms_bpd( 676 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 677 | ): 678 | """ 679 | Get a term for the variational lower-bound. 680 | 681 | The resulting units are bits (rather than nats, as one might expect). 682 | This allows for comparison to other papers. 683 | 684 | :return: a dict with the following keys: 685 | - 'output': a shape [N] tensor of NLLs or KLs. 686 | - 'pred_xstart': the x_0 predictions. 
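        The 'output' entry is KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)) converted
        to bits, except at t == 0, where the discretized decoder NLL is returned instead.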
687 | """ 688 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 689 | x_start=x_start, x_t=x_t, t=t 690 | ) 691 | out = self.p_mean_variance( 692 | model, x_t, t,clip_denoised=clip_denoised, model_kwargs=model_kwargs 693 | ) 694 | kl = normal_kl( 695 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 696 | ) 697 | kl = mean_flat(kl) / np.log(2.0) 698 | 699 | decoder_nll = -discretized_gaussian_log_likelihood( 700 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 701 | ) 702 | assert decoder_nll.shape == x_start.shape 703 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 704 | 705 | # At the first timestep return the decoder NLL, 706 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 707 | output = th.where((t == 0), decoder_nll, kl) 708 | return {"output": output, "pred_xstart": out["pred_xstart"]} 709 | 710 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 711 | """ 712 | Compute training losses for a single timestep. 713 | 714 | :param model: the model to evaluate loss on. 715 | :param x_start: the [N x C x ...] tensor of inputs. 716 | :param t: a batch of timestep indices. 717 | :param model_kwargs: if not None, a dict of extra keyword arguments to 718 | pass to the model. This can be used for conditioning. 719 | :param noise: if specified, the specific Gaussian noise to try to remove. 720 | :return: a dict with the key "loss" containing a tensor of shape [N]. 721 | Some mean or variance settings may also have other keys. 722 | """ 723 | if model_kwargs is None: 724 | model_kwargs = {} 725 | if noise is None: 726 | noise = th.randn_like(x_start) 727 | x_t = self.q_sample(x_start, t, noise=noise) 728 | 729 | terms = {} 730 | 731 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 732 | terms["loss"] = self._vb_terms_bpd( 733 | model=model, 734 | x_start=x_start, 735 | x_t=x_t, 736 | t=t, 737 | clip_denoised=False, 738 | model_kwargs=model_kwargs, 739 | )["output"] 740 | if self.loss_type == LossType.RESCALED_KL: 741 | terms["loss"] *= self.num_timesteps 742 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 743 | cond_label = model_kwargs['y'] 744 | if self.no_instance == False: 745 | cond_label, cond_edge = th.split(model_kwargs['y'], (model_kwargs['y'].shape[1]-1, 1), dim=1) 746 | if self.cond_diffuse: 747 | if self.cond_opt[:8] == 'discrete': 748 | cond_label = self.q_sample(cond_label, t, option=self.cond_opt) 749 | else: 750 | cond_noise = th.randn_like(cond_label) 751 | cond_label = self.q_sample(cond_label, t, noise=cond_noise) 752 | 753 | if self.no_instance == False: 754 | cond = th.cat((cond_label, cond_edge), dim=1) 755 | else: 756 | cond = cond_label 757 | 758 | 759 | model_output = model(x_t, self._scale_timesteps(t), y=cond) 760 | 761 | if self.model_var_type in [ 762 | ModelVarType.LEARNED, 763 | ModelVarType.LEARNED_RANGE, 764 | ]: 765 | B, C = x_t.shape[:2] 766 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 767 | model_output, model_var_values = th.split(model_output, C, dim=1) 768 | # Learn the variance using the variational bound, but don't let 769 | # it affect our mean prediction. 
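                # frozen_out detaches the mean channels, so the VB term below only
                # trains the predicted variance; the "model" handed to _vb_terms_bpd
                # is a lambda that simply returns this frozen output again.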
770 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 771 | terms["vb"] = self._vb_terms_bpd( 772 | model=lambda *args, r=frozen_out: r, 773 | x_start=x_start, 774 | x_t=x_t, 775 | t=t, 776 | clip_denoised=False, 777 | )["output"] 778 | if self.loss_type == LossType.RESCALED_MSE: 779 | # Divide by 1000 for equivalence with initial implementation. 780 | # Without a factor of 1/1000, the VB term hurts the MSE term. 781 | terms["vb"] *= self.num_timesteps / 1000.0 782 | 783 | target = { 784 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 785 | x_start=x_start, x_t=x_t, t=t 786 | )[0], 787 | ModelMeanType.START_X: x_start, 788 | ModelMeanType.EPSILON: noise, 789 | }[self.model_mean_type] 790 | assert model_output.shape == target.shape == x_start.shape 791 | terms["mse"] = mean_flat((target - model_output) ** 2) 792 | if "vb" in terms: 793 | terms["loss"] = terms["mse"] + terms["vb"] 794 | else: 795 | terms["loss"] = terms["mse"] 796 | else: 797 | raise NotImplementedError(self.loss_type) 798 | 799 | return terms 800 | 801 | def _prior_bpd(self, x_start): 802 | """ 803 | Get the prior KL term for the variational lower-bound, measured in 804 | bits-per-dim. 805 | 806 | This term can't be optimized, as it only depends on the encoder. 807 | 808 | :param x_start: the [N x C x ...] tensor of inputs. 809 | :return: a batch of [N] KL values (in bits), one per batch element. 810 | """ 811 | batch_size = x_start.shape[0] 812 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 813 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 814 | kl_prior = normal_kl( 815 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 816 | ) 817 | return mean_flat(kl_prior) / np.log(2.0) 818 | 819 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 820 | """ 821 | Compute the entire variational lower-bound, measured in bits-per-dim, 822 | as well as other related quantities. 823 | 824 | :param model: the model to evaluate loss on. 825 | :param x_start: the [N x C x ...] tensor of inputs. 826 | :param clip_denoised: if True, clip denoised samples. 827 | :param model_kwargs: if not None, a dict of extra keyword arguments to 828 | pass to the model. This can be used for conditioning. 829 | 830 | :return: a dict containing the following keys: 831 | - total_bpd: the total variational lower-bound, per batch element. 832 | - prior_bpd: the prior term in the lower-bound. 833 | - vb: an [N x T] tensor of terms in the lower-bound. 834 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 835 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
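        Note: this evaluates the model once per timestep for the whole batch, and
        total_bpd = prior_bpd + vb.sum(dim=1).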
836 | """ 837 | device = x_start.device 838 | batch_size = x_start.shape[0] 839 | 840 | vb = [] 841 | xstart_mse = [] 842 | mse = [] 843 | for t in list(range(self.num_timesteps))[::-1]: 844 | t_batch = th.tensor([t] * batch_size, device=device) 845 | noise = th.randn_like(x_start) 846 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 847 | # Calculate VLB term at the current timestep 848 | with th.no_grad(): 849 | out = self._vb_terms_bpd( 850 | model, 851 | x_start=x_start, 852 | x_t=x_t, 853 | t=t_batch, 854 | clip_denoised=clip_denoised, 855 | model_kwargs=model_kwargs, 856 | ) 857 | vb.append(out["output"]) 858 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 859 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 860 | mse.append(mean_flat((eps - noise) ** 2)) 861 | 862 | vb = th.stack(vb, dim=1) 863 | xstart_mse = th.stack(xstart_mse, dim=1) 864 | mse = th.stack(mse, dim=1) 865 | 866 | prior_bpd = self._prior_bpd(x_start) 867 | total_bpd = vb.sum(dim=1) + prior_bpd 868 | return { 869 | "total_bpd": total_bpd, 870 | "prior_bpd": prior_bpd, 871 | "vb": vb, 872 | "xstart_mse": xstart_mse, 873 | "mse": mse, 874 | } 875 | 876 | 877 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 878 | """ 879 | Extract values from a 1-D numpy array for a batch of indices. 880 | 881 | :param arr: the 1-D numpy array. 882 | :param timesteps: a tensor of indices into the array to extract. 883 | :param broadcast_shape: a larger shape of K dimensions with the batch 884 | dimension equal to the length of timesteps. 885 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 886 | """ 887 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 888 | while len(res.shape) < len(broadcast_shape): 889 | res = res[..., None] 890 | return res.expand(broadcast_shape) -------------------------------------------------------------------------------- /guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | 5 | from PIL import Image 6 | import blobfile as bf 7 | from mpi4py import MPI 8 | import numpy as np 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | 12 | def load_data( 13 | *, 14 | dataset_mode, 15 | data_dir, 16 | batch_size, 17 | image_size, 18 | class_cond=False, 19 | deterministic=False, 20 | random_crop=True, 21 | random_flip=True, 22 | is_train=True, 23 | ): 24 | """ 25 | For a dataset, create a generator over (images, kwargs) pairs. 26 | 27 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 28 | more keys, each of which map to a batched Tensor of their own. 29 | The kwargs dict can be used for class labels, in which case the key is "y" 30 | and the values are integer tensors of class labels. 31 | 32 | :param data_dir: a dataset directory. 33 | :param batch_size: the batch size of each returned pair. 34 | :param image_size: the size to which images are resized. 35 | :param class_cond: if True, include a "y" key in returned dicts for class 36 | label. If classes are not available and this is true, an 37 | exception will be raised. 38 | :param deterministic: if True, yield results in a deterministic order. 39 | :param random_crop: if True, randomly crop the images for augmentation. 40 | :param random_flip: if True, randomly flip the images for augmentation. 
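    :param dataset_mode: one of 'cityscapes', 'ade20k', 'celeba', or 'coco'; controls
        how image, label, and (optional) instance paths are collected below.
    :param is_train: if True, load the training split; otherwise the validation/test split.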
41 | """ 42 | if not data_dir: 43 | raise ValueError("unspecified data directory") 44 | 45 | if dataset_mode == 'cityscapes': 46 | all_files = _list_image_files_recursively(os.path.join(data_dir, 'leftImg8bit', 'train' if is_train else 'val')) 47 | labels_file = _list_image_files_recursively(os.path.join(data_dir, 'gtFine', 'train' if is_train else 'val')) 48 | classes = [x for x in labels_file if x.endswith('_labelIds.png')] 49 | instances = [x for x in labels_file if x.endswith('_instanceIds.png')] 50 | elif dataset_mode == 'ade20k': 51 | all_files = _list_image_files_recursively(os.path.join(data_dir, 'images', 'training' if is_train else 'validation')) 52 | classes = _list_image_files_recursively(os.path.join(data_dir, 'annotations', 'training' if is_train else 'validation')) 53 | instances = None 54 | elif dataset_mode == 'celeba': 55 | # The edge is computed by the instances. 56 | # However, the edge get from the labels and the instances are the same on CelebA. 57 | # You can take either as instance input 58 | all_files = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'images')) 59 | classes = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'labels')) 60 | instances = _list_image_files_recursively(os.path.join(data_dir, 'train' if is_train else 'test', 'labels')) 61 | elif dataset_mode == 'coco': 62 | all_files = _list_image_files_recursively(os.path.join(data_dir, 'train_img' if is_train else 'val_img')) 63 | classes = _list_image_files_recursively(os.path.join(data_dir, 'train_label' if is_train else 'val_label')) 64 | instances = _list_image_files_recursively(os.path.join(data_dir, 'train_inst' if is_train else 'val_inst')) 65 | 66 | else: 67 | raise NotImplementedError('{} not implemented'.format(dataset_mode)) 68 | 69 | print("Len of Dataset:", len(all_files)) 70 | 71 | dataset = ImageDataset( 72 | dataset_mode, 73 | image_size, 74 | all_files, 75 | classes=classes, 76 | instances=instances, 77 | shard=MPI.COMM_WORLD.Get_rank(), 78 | num_shards=MPI.COMM_WORLD.Get_size(), 79 | random_crop=random_crop, 80 | random_flip=random_flip, 81 | is_train=is_train 82 | ) 83 | 84 | if deterministic: 85 | loader = DataLoader( 86 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 87 | ) 88 | else: 89 | loader = DataLoader( 90 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 91 | ) 92 | while True: 93 | yield from loader 94 | 95 | 96 | def _list_image_files_recursively(data_dir): 97 | results = [] 98 | for entry in sorted(bf.listdir(data_dir)): 99 | full_path = bf.join(data_dir, entry) 100 | ext = entry.split(".")[-1] 101 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 102 | results.append(full_path) 103 | elif bf.isdir(full_path): 104 | results.extend(_list_image_files_recursively(full_path)) 105 | return results 106 | 107 | 108 | class ImageDataset(Dataset): 109 | def __init__( 110 | self, 111 | dataset_mode, 112 | resolution, 113 | image_paths, 114 | classes=None, 115 | instances=None, 116 | shard=0, 117 | num_shards=1, 118 | random_crop=False, 119 | random_flip=True, 120 | is_train=True 121 | ): 122 | super().__init__() 123 | self.is_train = is_train 124 | self.dataset_mode = dataset_mode 125 | self.resolution = resolution 126 | self.local_images = image_paths[shard:][::num_shards] 127 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 128 | self.local_instances = None if instances is None else instances[shard:][::num_shards] 129 | self.random_crop = random_crop 130 | self.random_flip = random_flip 131 | 132 | def __len__(self): 133 | return len(self.local_images) 134 | 135 | def __getitem__(self, idx): 136 | path = self.local_images[idx] 137 | with bf.BlobFile(path, "rb") as f: 138 | pil_image = Image.open(f) 139 | pil_image.load() 140 | pil_image = pil_image.convert("RGB") 141 | 142 | out_dict = {} 143 | class_path = self.local_classes[idx] 144 | with bf.BlobFile(class_path, "rb") as f: 145 | pil_class = Image.open(f) 146 | pil_class.load() 147 | pil_class = pil_class.convert("L") 148 | 149 | if self.local_instances is not None: 150 | instance_path = self.local_instances[idx] # DEBUG: from classes to instances, may affect CelebA 151 | with bf.BlobFile(instance_path, "rb") as f: 152 | pil_instance = Image.open(f) 153 | pil_instance.load() 154 | pil_instance = pil_instance.convert("L") 155 | else: 156 | pil_instance = None 157 | 158 | if self.dataset_mode == 'cityscapes': 159 | arr_image, arr_class, arr_instance = resize_arr([pil_image, pil_class, pil_instance], self.resolution) 160 | else: 161 | if self.is_train: 162 | if self.random_crop: 163 | arr_image, arr_class, arr_instance = random_crop_arr([pil_image, pil_class, pil_instance], self.resolution) 164 | else: 165 | arr_image, arr_class, arr_instance = center_crop_arr([pil_image, pil_class, pil_instance], self.resolution) 166 | else: 167 | arr_image, arr_class, arr_instance = resize_arr([pil_image, pil_class, pil_instance], self.resolution, keep_aspect=False) 168 | 169 | if self.random_flip and random.random() < 0.5: 170 | arr_image = arr_image[:, ::-1].copy() 171 | arr_class = arr_class[:, ::-1].copy() 172 | arr_instance = arr_instance[:, ::-1].copy() if arr_instance is not None else None 173 | 174 | arr_image = arr_image.astype(np.float32) / 127.5 - 1 175 | 176 | out_dict['path'] = path 177 | out_dict['label_ori'] = arr_class.copy() 178 | 179 | if self.dataset_mode == 'ade20k': 180 | arr_class = arr_class - 1 181 | arr_class[arr_class == 255] = 150 182 | elif self.dataset_mode == 'coco': 183 | arr_class[arr_class == 255] = 182 184 | 185 | out_dict['label'] = arr_class[None, ] 186 | 187 | if arr_instance is not None: 188 | out_dict['instance'] = arr_instance[None, ] 189 | 190 | return np.transpose(arr_image, [2, 0, 1]), out_dict 191 | 192 | 193 | def resize_arr(pil_list, image_size, keep_aspect=True): 194 | # We are not on a new enough PIL to support the `reducing_gap` 195 | # argument, which uses BOX downsampling at powers of two first. 196 | # Thus, we do it by hand to improve downsample quality. 
197 | pil_image, pil_class, pil_instance = pil_list 198 | 199 | while min(*pil_image.size) >= 2 * image_size: 200 | pil_image = pil_image.resize( 201 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 202 | ) 203 | 204 | if keep_aspect: 205 | scale = image_size / min(*pil_image.size) 206 | pil_image = pil_image.resize( 207 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 208 | ) 209 | else: 210 | pil_image = pil_image.resize((image_size, image_size), resample=Image.BICUBIC) 211 | 212 | pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST) 213 | if pil_instance is not None: 214 | pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST) 215 | 216 | arr_image = np.array(pil_image) 217 | arr_class = np.array(pil_class) 218 | arr_instance = np.array(pil_instance) if pil_instance is not None else None 219 | return arr_image, arr_class, arr_instance 220 | 221 | 222 | def center_crop_arr(pil_list, image_size): 223 | # We are not on a new enough PIL to support the `reducing_gap` 224 | # argument, which uses BOX downsampling at powers of two first. 225 | # Thus, we do it by hand to improve downsample quality. 226 | pil_image, pil_class, pil_instance = pil_list 227 | 228 | while min(*pil_image.size) >= 2 * image_size: 229 | pil_image = pil_image.resize( 230 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 231 | ) 232 | 233 | scale = image_size / min(*pil_image.size) 234 | pil_image = pil_image.resize( 235 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 236 | ) 237 | 238 | pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST) 239 | if pil_instance is not None: 240 | pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST) 241 | 242 | arr_image = np.array(pil_image) 243 | arr_class = np.array(pil_class) 244 | arr_instance = np.array(pil_instance) if pil_instance is not None else None 245 | crop_y = (arr_image.shape[0] - image_size) // 2 246 | crop_x = (arr_image.shape[1] - image_size) // 2 247 | return arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\ 248 | arr_class[crop_y: crop_y + image_size, crop_x: crop_x + image_size],\ 249 | arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size] if arr_instance is not None else None 250 | 251 | 252 | def random_crop_arr(pil_list, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 253 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 254 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 255 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 256 | 257 | # We are not on a new enough PIL to support the `reducing_gap` 258 | # argument, which uses BOX downsampling at powers of two first. 259 | # Thus, we do it by hand to improve downsample quality. 
260 | pil_image, pil_class, pil_instance = pil_list 261 | 262 | while min(*pil_image.size) >= 2 * smaller_dim_size: 263 | pil_image = pil_image.resize( 264 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 265 | ) 266 | 267 | scale = smaller_dim_size / min(*pil_image.size) 268 | pil_image = pil_image.resize( 269 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 270 | ) 271 | 272 | pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST) 273 | if pil_instance is not None: 274 | pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST) 275 | 276 | arr_image = np.array(pil_image) 277 | arr_class = np.array(pil_class) 278 | arr_instance = np.array(pil_instance) if pil_instance is not None else None 279 | crop_y = random.randrange(arr_image.shape[0] - image_size + 1) 280 | crop_x = random.randrange(arr_image.shape[1] - image_size + 1) 281 | return arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],\ 282 | arr_class[crop_y: crop_y + image_size, crop_x: crop_x + image_size],\ 283 | arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size] if arr_instance is not None else None 284 | -------------------------------------------------------------------------------- /guided_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return 
s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
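    # No MPI-related environment variable was found, so assume a single-process
    # run and report rank 0.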
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv,tensorboard").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
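    In terms of standard deviations, the returned value equals
    log(sigma2 / sigma1) + (sigma1^2 + (mean1 - mean2)^2) / (2 * sigma2^2) - 1/2;
    the expression below computes the same quantity directly from the log-variances.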
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | self.timestep_map = np.array(self.timestep_map) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | if self.rescale_timesteps: 128 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /guided_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | from .unet import SuperResModel, UNetModel, EncoderUNetModel 7 | 8 | NUM_CLASSES = 1000 9 | 10 | 11 | def diffusion_defaults(): 12 | """ 13 | Defaults for image and classifier training. 14 | """ 15 | return dict( 16 | learn_sigma=False, 17 | diffusion_steps=1000, 18 | noise_schedule="linear", 19 | timestep_respacing="", 20 | use_kl=False, 21 | predict_xstart=False, 22 | rescale_timesteps=False, 23 | rescale_learned_sigmas=False, 24 | cond_diffuse=False, 25 | cond_opt = "", 26 | dataset_mode="ade20k" 27 | ) 28 | 29 | 30 | def classifier_defaults(): 31 | """ 32 | Defaults for classifier models. 33 | """ 34 | return dict( 35 | image_size=64, 36 | classifier_use_fp16=False, 37 | classifier_width=128, 38 | classifier_depth=2, 39 | classifier_attention_resolutions="32,16,8", # 16 40 | classifier_use_scale_shift_norm=True, # False 41 | classifier_resblock_updown=True, # False 42 | classifier_pool="attention", 43 | ) 44 | 45 | 46 | def model_and_diffusion_defaults(): 47 | """ 48 | Defaults for image training. 
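    These values are merged with diffusion_defaults() below; the default
    num_classes=151 corresponds to ADE20K's 150 semantic classes plus one extra
    label for unlabeled pixels (which image_datasets.py maps to 150).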
49 | """ 50 | res = dict( 51 | image_size=64, 52 | num_classes=151, 53 | num_channels=128, 54 | num_res_blocks=2, 55 | num_heads=4, 56 | num_heads_upsample=-1, 57 | num_head_channels=-1, 58 | attention_resolutions="16,8", 59 | channel_mult="", 60 | dropout=0.0, 61 | class_cond=False, 62 | use_checkpoint=False, 63 | use_scale_shift_norm=True, 64 | resblock_updown=False, 65 | use_fp16=False, 66 | use_new_attention_order=False, 67 | no_instance=False, 68 | ) 69 | res.update(diffusion_defaults()) 70 | return res 71 | 72 | 73 | def classifier_and_diffusion_defaults(): 74 | res = classifier_defaults() 75 | res.update(diffusion_defaults()) 76 | return res 77 | 78 | 79 | def create_model_and_diffusion( 80 | image_size, 81 | class_cond, 82 | learn_sigma, 83 | num_classes, 84 | no_instance, 85 | num_channels, 86 | num_res_blocks, 87 | channel_mult, 88 | num_heads, 89 | num_head_channels, 90 | num_heads_upsample, 91 | attention_resolutions, 92 | dropout, 93 | diffusion_steps, 94 | noise_schedule, 95 | timestep_respacing, 96 | use_kl, 97 | predict_xstart, 98 | rescale_timesteps, 99 | rescale_learned_sigmas, 100 | use_checkpoint, 101 | use_scale_shift_norm, 102 | resblock_updown, 103 | use_fp16, 104 | use_new_attention_order, 105 | cond_diffuse, 106 | cond_opt, 107 | dataset_mode 108 | ): 109 | model = create_model( 110 | image_size, 111 | num_classes, 112 | num_channels, 113 | num_res_blocks, 114 | channel_mult=channel_mult, 115 | learn_sigma=learn_sigma, 116 | class_cond=class_cond, 117 | use_checkpoint=use_checkpoint, 118 | attention_resolutions=attention_resolutions, 119 | num_heads=num_heads, 120 | num_head_channels=num_head_channels, 121 | num_heads_upsample=num_heads_upsample, 122 | use_scale_shift_norm=use_scale_shift_norm, 123 | dropout=dropout, 124 | resblock_updown=resblock_updown, 125 | use_fp16=use_fp16, 126 | use_new_attention_order=use_new_attention_order, 127 | no_instance=no_instance, 128 | ) 129 | diffusion = create_gaussian_diffusion( 130 | steps=diffusion_steps, 131 | learn_sigma=learn_sigma, 132 | noise_schedule=noise_schedule, 133 | use_kl=use_kl, 134 | predict_xstart=predict_xstart, 135 | rescale_timesteps=rescale_timesteps, 136 | rescale_learned_sigmas=rescale_learned_sigmas, 137 | timestep_respacing=timestep_respacing, 138 | cond_diffuse=cond_diffuse, 139 | cond_opt=cond_opt, 140 | no_instance=no_instance, 141 | dataset_mode=dataset_mode 142 | ) 143 | return model, diffusion 144 | 145 | 146 | def create_model( 147 | image_size, 148 | num_classes, 149 | num_channels, 150 | num_res_blocks, 151 | channel_mult="", 152 | learn_sigma=False, 153 | class_cond=False, 154 | use_checkpoint=False, 155 | attention_resolutions="16", 156 | num_heads=1, 157 | num_head_channels=-1, 158 | num_heads_upsample=-1, 159 | use_scale_shift_norm=False, 160 | dropout=0, 161 | resblock_updown=False, 162 | use_fp16=False, 163 | use_new_attention_order=False, 164 | no_instance=False, 165 | ): 166 | if channel_mult == "": 167 | if image_size == 512: 168 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 169 | elif image_size == 256: 170 | channel_mult = (1, 1, 2, 2, 4, 4) 171 | elif image_size == 128: 172 | channel_mult = (1, 1, 2, 3, 4) 173 | elif image_size == 64: 174 | channel_mult = (1, 2, 3, 4) 175 | else: 176 | raise ValueError(f"unsupported image size: {image_size}") 177 | else: 178 | channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(",")) 179 | 180 | attention_ds = [] 181 | for res in attention_resolutions.split(","): 182 | attention_ds.append(image_size // int(res)) 183 | 184 | num_classes = 
num_classes if no_instance else num_classes + 1 185 | 186 | return UNetModel( 187 | image_size=image_size, 188 | in_channels=3, 189 | model_channels=num_channels, 190 | out_channels=(3 if not learn_sigma else 6), 191 | num_res_blocks=num_res_blocks, 192 | attention_resolutions=tuple(attention_ds), 193 | dropout=dropout, 194 | channel_mult=channel_mult, 195 | num_classes=(num_classes if class_cond else None), 196 | use_checkpoint=use_checkpoint, 197 | use_fp16=use_fp16, 198 | num_heads=num_heads, 199 | num_head_channels=num_head_channels, 200 | num_heads_upsample=num_heads_upsample, 201 | use_scale_shift_norm=use_scale_shift_norm, 202 | resblock_updown=resblock_updown, 203 | use_new_attention_order=use_new_attention_order, 204 | ) 205 | 206 | 207 | def create_classifier_and_diffusion( 208 | image_size, 209 | classifier_use_fp16, 210 | classifier_width, 211 | classifier_depth, 212 | classifier_attention_resolutions, 213 | classifier_use_scale_shift_norm, 214 | classifier_resblock_updown, 215 | classifier_pool, 216 | learn_sigma, 217 | diffusion_steps, 218 | noise_schedule, 219 | timestep_respacing, 220 | use_kl, 221 | predict_xstart, 222 | rescale_timesteps, 223 | rescale_learned_sigmas, 224 | ): 225 | classifier = create_classifier( 226 | image_size, 227 | classifier_use_fp16, 228 | classifier_width, 229 | classifier_depth, 230 | classifier_attention_resolutions, 231 | classifier_use_scale_shift_norm, 232 | classifier_resblock_updown, 233 | classifier_pool, 234 | ) 235 | diffusion = create_gaussian_diffusion( 236 | steps=diffusion_steps, 237 | learn_sigma=learn_sigma, 238 | noise_schedule=noise_schedule, 239 | use_kl=use_kl, 240 | predict_xstart=predict_xstart, 241 | rescale_timesteps=rescale_timesteps, 242 | rescale_learned_sigmas=rescale_learned_sigmas, 243 | timestep_respacing=timestep_respacing, 244 | ) 245 | return classifier, diffusion 246 | 247 | 248 | def create_classifier( 249 | image_size, 250 | classifier_use_fp16, 251 | classifier_width, 252 | classifier_depth, 253 | classifier_attention_resolutions, 254 | classifier_use_scale_shift_norm, 255 | classifier_resblock_updown, 256 | classifier_pool, 257 | ): 258 | if image_size == 512: 259 | channel_mult = (0.5, 1, 1, 2, 2, 4, 4) 260 | elif image_size == 256: 261 | channel_mult = (1, 1, 2, 2, 4, 4) 262 | elif image_size == 128: 263 | channel_mult = (1, 1, 2, 3, 4) 264 | elif image_size == 64: 265 | channel_mult = (1, 2, 3, 4) 266 | else: 267 | raise ValueError(f"unsupported image size: {image_size}") 268 | 269 | attention_ds = [] 270 | for res in classifier_attention_resolutions.split(","): 271 | attention_ds.append(image_size // int(res)) 272 | 273 | return EncoderUNetModel( 274 | image_size=image_size, 275 | in_channels=3, 276 | model_channels=classifier_width, 277 | out_channels=1000, 278 | num_res_blocks=classifier_depth, 279 | attention_resolutions=tuple(attention_ds), 280 | channel_mult=channel_mult, 281 | use_fp16=classifier_use_fp16, 282 | num_head_channels=64, 283 | use_scale_shift_norm=classifier_use_scale_shift_norm, 284 | resblock_updown=classifier_resblock_updown, 285 | pool=classifier_pool, 286 | ) 287 | 288 | 289 | def sr_model_and_diffusion_defaults(): 290 | res = model_and_diffusion_defaults() 291 | res["large_size"] = 256 292 | res["small_size"] = 64 293 | arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] 294 | for k in res.copy().keys(): 295 | if k not in arg_names: 296 | del res[k] 297 | return res 298 | 299 | 300 | def sr_create_model_and_diffusion( 301 | large_size, 302 | small_size, 303 | 
class_cond, 304 | learn_sigma, 305 | num_channels, 306 | num_res_blocks, 307 | num_heads, 308 | num_head_channels, 309 | num_heads_upsample, 310 | attention_resolutions, 311 | dropout, 312 | diffusion_steps, 313 | noise_schedule, 314 | timestep_respacing, 315 | use_kl, 316 | predict_xstart, 317 | rescale_timesteps, 318 | rescale_learned_sigmas, 319 | use_checkpoint, 320 | use_scale_shift_norm, 321 | resblock_updown, 322 | use_fp16, 323 | ): 324 | model = sr_create_model( 325 | large_size, 326 | small_size, 327 | num_channels, 328 | num_res_blocks, 329 | learn_sigma=learn_sigma, 330 | class_cond=class_cond, 331 | use_checkpoint=use_checkpoint, 332 | attention_resolutions=attention_resolutions, 333 | num_heads=num_heads, 334 | num_head_channels=num_head_channels, 335 | num_heads_upsample=num_heads_upsample, 336 | use_scale_shift_norm=use_scale_shift_norm, 337 | dropout=dropout, 338 | resblock_updown=resblock_updown, 339 | use_fp16=use_fp16, 340 | ) 341 | diffusion = create_gaussian_diffusion( 342 | steps=diffusion_steps, 343 | learn_sigma=learn_sigma, 344 | noise_schedule=noise_schedule, 345 | use_kl=use_kl, 346 | predict_xstart=predict_xstart, 347 | rescale_timesteps=rescale_timesteps, 348 | rescale_learned_sigmas=rescale_learned_sigmas, 349 | timestep_respacing=timestep_respacing, 350 | ) 351 | return model, diffusion 352 | 353 | 354 | def sr_create_model( 355 | large_size, 356 | small_size, 357 | num_channels, 358 | num_res_blocks, 359 | learn_sigma, 360 | class_cond, 361 | use_checkpoint, 362 | attention_resolutions, 363 | num_heads, 364 | num_head_channels, 365 | num_heads_upsample, 366 | use_scale_shift_norm, 367 | dropout, 368 | resblock_updown, 369 | use_fp16, 370 | ): 371 | _ = small_size # hack to prevent unused variable 372 | 373 | if large_size == 512: 374 | channel_mult = (1, 1, 2, 2, 4, 4) 375 | elif large_size == 256: 376 | channel_mult = (1, 1, 2, 2, 4, 4) 377 | elif large_size == 64: 378 | channel_mult = (1, 2, 3, 4) 379 | else: 380 | raise ValueError(f"unsupported large size: {large_size}") 381 | 382 | attention_ds = [] 383 | for res in attention_resolutions.split(","): 384 | attention_ds.append(large_size // int(res)) 385 | 386 | return SuperResModel( 387 | image_size=large_size, 388 | in_channels=3, 389 | model_channels=num_channels, 390 | out_channels=(3 if not learn_sigma else 6), 391 | num_res_blocks=num_res_blocks, 392 | attention_resolutions=tuple(attention_ds), 393 | dropout=dropout, 394 | channel_mult=channel_mult, 395 | num_classes=(NUM_CLASSES if class_cond else None), 396 | use_checkpoint=use_checkpoint, 397 | num_heads=num_heads, 398 | num_head_channels=num_head_channels, 399 | num_heads_upsample=num_heads_upsample, 400 | use_scale_shift_norm=use_scale_shift_norm, 401 | resblock_updown=resblock_updown, 402 | use_fp16=use_fp16, 403 | ) 404 | 405 | 406 | def create_gaussian_diffusion( 407 | *, 408 | steps=1000, 409 | learn_sigma=False, 410 | sigma_small=False, 411 | noise_schedule="linear", 412 | use_kl=False, 413 | predict_xstart=False, 414 | rescale_timesteps=False, 415 | rescale_learned_sigmas=False, 416 | timestep_respacing="", 417 | cond_diffuse = False, 418 | cond_opt, 419 | no_instance, 420 | dataset_mode="ade20k" 421 | ): 422 | betas = gd.get_named_beta_schedule(noise_schedule, steps) 423 | if use_kl: 424 | loss_type = gd.LossType.RESCALED_KL 425 | elif rescale_learned_sigmas: 426 | loss_type = gd.LossType.RESCALED_MSE 427 | else: 428 | loss_type = gd.LossType.MSE 429 | if not timestep_respacing: 430 | timestep_respacing = [steps] 431 | return 
SpacedDiffusion( 432 | use_timesteps=space_timesteps(steps, timestep_respacing), 433 | betas=betas, 434 | model_mean_type=( 435 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 436 | ), 437 | model_var_type=( 438 | ( 439 | gd.ModelVarType.FIXED_LARGE 440 | if not sigma_small 441 | else gd.ModelVarType.FIXED_SMALL 442 | ) 443 | if not learn_sigma 444 | else gd.ModelVarType.LEARNED_RANGE 445 | ), 446 | loss_type=loss_type, 447 | rescale_timesteps=rescale_timesteps, 448 | cond_diffuse = cond_diffuse, 449 | cond_opt = cond_opt, 450 | no_instance = no_instance, 451 | dataset_mode = dataset_mode 452 | ) 453 | 454 | 455 | def add_dict_to_argparser(parser, default_dict): 456 | for k, v in default_dict.items(): 457 | v_type = type(v) 458 | if v is None: 459 | v_type = str 460 | elif isinstance(v, bool): 461 | v_type = str2bool 462 | parser.add_argument(f"--{k}", default=v, type=v_type) 463 | 464 | 465 | def args_to_dict(args, keys): 466 | return {k: getattr(args, k) for k in keys} 467 | 468 | 469 | def str2bool(v): 470 | """ 471 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 472 | """ 473 | if isinstance(v, bool): 474 | return v 475 | if v.lower() in ("yes", "true", "t", "y", "1"): 476 | return True 477 | elif v.lower() in ("no", "false", "f", "n", "0"): 478 | return False 479 | else: 480 | raise argparse.ArgumentTypeError("boolean value expected") 481 | -------------------------------------------------------------------------------- /guided_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | 16 | # For ImageNet experiments, this was a good default value. 17 | # We found that the lg_loss_scale quickly climbed to 18 | # 20-21 within the first ~1K steps of training. 
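# (lg_loss_scale is the base-2 log of the dynamic fp16 loss scale maintained by
# MixedPrecisionTrainer in fp16_util.py; fp16_scale_growth in TrainLoop below
# controls how quickly the scale grows back after an overflow forces it down.)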
19 | INITIAL_LOG_LOSS_SCALE = 20.0 20 | 21 | 22 | class TrainLoop: 23 | def __init__( 24 | self, 25 | *, 26 | model, 27 | diffusion, 28 | data, 29 | num_classes, 30 | batch_size, 31 | microbatch, 32 | lr, 33 | ema_rate, 34 | drop_rate, 35 | log_interval, 36 | save_interval, 37 | resume_checkpoint, 38 | use_fp16=False, 39 | fp16_scale_growth=1e-3, 40 | schedule_sampler=None, 41 | weight_decay=0.0, 42 | lr_anneal_steps=0, 43 | ): 44 | self.model = model 45 | self.diffusion = diffusion 46 | self.data = data 47 | self.num_classes = num_classes 48 | self.batch_size = batch_size 49 | self.microbatch = microbatch if microbatch > 0 else batch_size 50 | self.lr = lr 51 | self.ema_rate = ( 52 | [ema_rate] 53 | if isinstance(ema_rate, float) 54 | else [float(x) for x in ema_rate.split(",")] 55 | ) 56 | self.drop_rate = drop_rate 57 | self.log_interval = log_interval 58 | self.save_interval = save_interval 59 | self.resume_checkpoint = resume_checkpoint 60 | self.use_fp16 = use_fp16 61 | self.fp16_scale_growth = fp16_scale_growth 62 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) 63 | self.weight_decay = weight_decay 64 | self.lr_anneal_steps = lr_anneal_steps 65 | 66 | self.step = 0 67 | self.resume_step = 0 68 | self.global_batch = self.batch_size * dist.get_world_size() 69 | 70 | self.sync_cuda = th.cuda.is_available() 71 | 72 | self._load_and_sync_parameters() 73 | self.mp_trainer = MixedPrecisionTrainer( 74 | model=self.model, 75 | use_fp16=self.use_fp16, 76 | fp16_scale_growth=fp16_scale_growth, 77 | ) 78 | 79 | self.opt = AdamW( 80 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay 81 | ) 82 | if self.resume_step: 83 | self._load_optimizer_state() 84 | # Model was resumed, either due to a restart or a checkpoint 85 | # being specified at the command line. 86 | self.ema_params = [ 87 | self._load_ema_parameters(rate) for rate in self.ema_rate 88 | ] 89 | else: 90 | self.ema_params = [ 91 | copy.deepcopy(self.mp_trainer.master_params) 92 | for _ in range(len(self.ema_rate)) 93 | ] 94 | 95 | if th.cuda.is_available(): 96 | self.use_ddp = True 97 | self.ddp_model = DDP( 98 | self.model, 99 | device_ids=[dist_util.dev()], 100 | output_device=dist_util.dev(), 101 | broadcast_buffers=False, 102 | bucket_cap_mb=128, 103 | find_unused_parameters=False, 104 | ) 105 | else: 106 | if dist.get_world_size() > 1: 107 | logger.warn( 108 | "Distributed training requires CUDA. " 109 | "Gradients will not be synchronized properly!" 
110 | ) 111 | self.use_ddp = False 112 | self.ddp_model = self.model 113 | 114 | def _load_and_sync_parameters(self): 115 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 116 | 117 | if resume_checkpoint: 118 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint) 119 | if dist.get_rank() == 0: 120 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...") 121 | self.model.load_state_dict( 122 | th.load( 123 | resume_checkpoint, map_location=dist_util.dev() 124 | ) 125 | ) 126 | 127 | dist_util.sync_params(self.model.parameters()) 128 | 129 | def _load_ema_parameters(self, rate): 130 | ema_params = copy.deepcopy(self.mp_trainer.master_params) 131 | 132 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 133 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) 134 | if ema_checkpoint: 135 | if dist.get_rank() == 0: 136 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") 137 | state_dict = th.load( 138 | ema_checkpoint, map_location=dist_util.dev() 139 | ) 140 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) 141 | 142 | dist_util.sync_params(ema_params) 143 | return ema_params 144 | 145 | def _load_optimizer_state(self): 146 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint 147 | opt_checkpoint = bf.join( 148 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" 149 | ) 150 | if bf.exists(opt_checkpoint): 151 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 152 | state_dict = th.load( 153 | opt_checkpoint, map_location=dist_util.dev() 154 | ) 155 | self.opt.load_state_dict(state_dict) 156 | 157 | if self.opt.param_groups[0]['lr'] != self.lr: 158 | self.opt.param_groups[0]['lr'] = self.lr 159 | 160 | def run_loop(self): 161 | while ( 162 | not self.lr_anneal_steps 163 | or self.step + self.resume_step < self.lr_anneal_steps 164 | ): 165 | batch, cond = next(self.data) 166 | cond = self.preprocess_input(cond) 167 | self.run_step(batch, cond) 168 | if self.step % self.log_interval == 0: 169 | logger.dumpkvs() 170 | if self.step % self.save_interval == 0: 171 | self.save() 172 | # Run for a finite amount of time in integration tests. 173 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: 174 | return 175 | self.step += 1 176 | # Save the last checkpoint if it wasn't already saved. 
177 | if (self.step - 1) % self.save_interval != 0: 178 | self.save() 179 | 180 | def run_step(self, batch, cond): 181 | self.forward_backward(batch, cond) 182 | took_step = self.mp_trainer.optimize(self.opt) 183 | if took_step: 184 | self._update_ema() 185 | self._anneal_lr() 186 | self.log_step() 187 | 188 | def forward_backward(self, batch, cond): 189 | self.mp_trainer.zero_grad() 190 | for i in range(0, batch.shape[0], self.microbatch): 191 | micro = batch[i : i + self.microbatch].to(dist_util.dev()) 192 | micro_cond = { 193 | k: v[i : i + self.microbatch].to(dist_util.dev()) 194 | for k, v in cond.items() 195 | } 196 | last_batch = (i + self.microbatch) >= batch.shape[0] 197 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) 198 | 199 | compute_losses = functools.partial( 200 | self.diffusion.training_losses, 201 | self.ddp_model, 202 | micro, 203 | t, 204 | model_kwargs=micro_cond, 205 | ) 206 | 207 | if last_batch or not self.use_ddp: 208 | losses = compute_losses() 209 | else: 210 | with self.ddp_model.no_sync(): 211 | losses = compute_losses() 212 | 213 | if isinstance(self.schedule_sampler, LossAwareSampler): 214 | self.schedule_sampler.update_with_local_losses( 215 | t, losses["loss"].detach() 216 | ) 217 | 218 | loss = (losses["loss"] * weights).mean() 219 | log_loss_dict( 220 | self.diffusion, t, {k: v * weights for k, v in losses.items()} 221 | ) 222 | self.mp_trainer.backward(loss) 223 | 224 | def _update_ema(self): 225 | for rate, params in zip(self.ema_rate, self.ema_params): 226 | update_ema(params, self.mp_trainer.master_params, rate=rate) 227 | 228 | def _anneal_lr(self): 229 | if not self.lr_anneal_steps: 230 | return 231 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps 232 | lr = self.lr * (1 - frac_done) 233 | for param_group in self.opt.param_groups: 234 | param_group["lr"] = lr 235 | 236 | def log_step(self): 237 | logger.logkv("step", self.step + self.resume_step) 238 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) 239 | 240 | def save(self): 241 | def save_checkpoint(rate, params): 242 | state_dict = self.mp_trainer.master_params_to_state_dict(params) 243 | if dist.get_rank() == 0: 244 | logger.log(f"saving model {rate}...") 245 | if not rate: 246 | filename = f"model{(self.step+self.resume_step):06d}.pt" 247 | else: 248 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" 249 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: 250 | th.save(state_dict, f) 251 | 252 | save_checkpoint(0, self.mp_trainer.master_params) 253 | for rate, params in zip(self.ema_rate, self.ema_params): 254 | save_checkpoint(rate, params) 255 | 256 | if dist.get_rank() == 0: 257 | with bf.BlobFile( 258 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), 259 | "wb", 260 | ) as f: 261 | th.save(self.opt.state_dict(), f) 262 | 263 | dist.barrier() 264 | 265 | def preprocess_input(self, data): 266 | # move to GPU and change data types 267 | data['label'] = data['label'].long() 268 | 269 | # create one-hot label map 270 | label_map = data['label'] 271 | bs, _, h, w = label_map.size() 272 | nc = self.num_classes 273 | input_label = th.FloatTensor(bs, nc, h, w).zero_() 274 | input_semantics = input_label.scatter_(1, label_map, 1.0) 275 | 276 | # concatenate instance map if it exists 277 | if 'instance' in data: 278 | inst_map = data['instance'] 279 | instance_edge_map = self.get_edges(inst_map) 280 | input_semantics = th.cat((input_semantics, instance_edge_map), 
dim=1) 281 | 282 | if self.drop_rate > 0.0: 283 | mask = (th.rand([input_semantics.shape[0], 1, 1, 1]) > self.drop_rate).float() 284 | input_semantics = input_semantics * mask 285 | 286 | cond = {key: value for key, value in data.items() if key not in ['label', 'instance', 'path', 'label_ori']} 287 | cond['y'] = input_semantics 288 | 289 | return cond 290 | 291 | def get_edges(self, t): 292 | edge = th.ByteTensor(t.size()).zero_() 293 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 294 | edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 295 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 296 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 297 | return edge.float() 298 | 299 | 300 | def parse_resume_step_from_filename(filename): 301 | """ 302 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the 303 | checkpoint's number of steps. 304 | """ 305 | split = filename.split("model") 306 | if len(split) < 2: 307 | return 0 308 | split1 = split[-1].split(".")[0] 309 | try: 310 | return int(split1) 311 | except ValueError: 312 | return 0 313 | 314 | 315 | def get_blob_logdir(): 316 | # You can change this to be a separate path to save checkpoints to 317 | # a blobstore or some external drive. 318 | return logger.get_dir() 319 | 320 | 321 | def find_resume_checkpoint(): 322 | # On your infrastructure, you may want to override this to automatically 323 | # discover the latest checkpoint on your blob storage, etc. 324 | return None 325 | 326 | 327 | def find_ema_checkpoint(main_checkpoint, step, rate): 328 | if main_checkpoint is None: 329 | return None 330 | filename = f"ema_{rate}_{(step):06d}.pt" 331 | path = bf.join(bf.dirname(main_checkpoint), filename) 332 | if bf.exists(path): 333 | return path 334 | return None 335 | 336 | 337 | def log_loss_dict(diffusion, ts, losses): 338 | for key, values in losses.items(): 339 | logger.logkv_mean(key, values.mean().item()) 340 | # Log the quantiles (four quartiles, in particular). 
341 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): 342 | quartile = int(4 * sub_t / diffusion.num_timesteps) 343 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) 344 |
-------------------------------------------------------------------------------- /image_process/downsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | from PIL import Image 4 | 5 | def resize(target_path, save_path, size, opt): 6 | for item in os.listdir(target_path): 7 | img = Image.open(os.path.join(target_path,item)) 8 | img = img.resize(size = (size,size), resample = opt) 9 | img.save(os.path.join(save_path,item)) 10 | 11 | 12 | 13 | target_path = "data/ADE20K/ADEChallengeData2016/annotations/validation" 14 | save_path = "data/ADE20K_noisy/ADE20K_DS/ADEChallengeData2016/annotations/validation" 15 | size = 64 16 | opt = PIL.Image.NEAREST 17 | # os.makedirs(save_path) 18 | 19 | resize(target_path,save_path,size,opt)
-------------------------------------------------------------------------------- /image_process/edge.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from guided_diffusion.image_datasets import load_data 3 | from guided_diffusion import dist_util, logger 4 | import torchvision as tv 5 | import os 6 | import cv2 7 | import random 8 | 9 | def get_edges(t): 10 | edge = th.ByteTensor(t.size()).zero_() 11 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 12 | edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 13 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 14 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 15 | return edge.float() 16 | 17 | dist_util.setup_dist() 18 | 19 | dataset_mode = 'ade20k' 20 | data_dir = 'data/ADE20K/ADEChallengeData2016' 21 | new_dir = 'data/ADE20K_noisy/ADE20K_edge' 22 | batch_size = 50 23 | image_size = 256 24 | class_cond = True 25 | 26 | data = load_data( 27 | dataset_mode=dataset_mode, 28 | data_dir=data_dir, 29 | batch_size=batch_size, 30 | image_size=image_size, 31 | class_cond=class_cond, 32 | deterministic=True, 33 | random_crop=False, 34 | random_flip=False, 35 | is_train=False 36 | ) 37 | 38 | 39 | for i , (batch, cond) in enumerate(data): 40 | if 'instance' in cond: 41 | inst_map = cond['instance'] 42 | instance_edge_map = get_edges(inst_map) 43 | instance_edge_map = instance_edge_map.permute(0,2,3,1) 44 | map_shape = instance_edge_map.shape 45 | new_edge_map = th.ByteTensor(inst_map.size()).zero_().float().permute(0,2,3,1) 46 | 47 | for b in range(map_shape[0]): 48 | for x in range(map_shape[1]): 49 | for y in range(map_shape[2]): 50 | if instance_edge_map[b][x][y] == 1: 51 | rad = 2 52 | for dx in range(-rad,rad+1): 53 | for dy in range(-(rad-abs(dx)),(rad-abs(dx))+1): 54 | if x+dx >= 0 and x+dx < map_shape[1] and y+dy >= 0 and y+dy < map_shape[2]: 55 | new_edge_map[b][x+dx][y+dy] = 1 67 | if i >= (2000/batch_size): 68 | break
-------------------------------------------------------------------------------- /image_process/random.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from guided_diffusion.image_datasets import load_data 3 | from guided_diffusion import dist_util, logger 4 | import torchvision as tv 5 | import os 6 | import cv2 7 | import random 8 | 9 | def get_edges(t): 10 | edge = th.ByteTensor(t.size()).zero_() 11 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 12 | edge[:,
:, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 13 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 14 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 15 | return edge.float() 16 | 17 | dist_util.setup_dist() 18 | 19 | dataset_mode = 'ade20k' 20 | data_dir = 'data/ADE20K/ADEChallengeData2016' 21 | new_dir = 'data/ADE20K_noisy/ADE20K_random' 22 | batch_size = 50 23 | image_size = 256 24 | class_cond = True 25 | 26 | data = load_data( 27 | dataset_mode=dataset_mode, 28 | data_dir=data_dir, 29 | batch_size=batch_size, 30 | image_size=image_size, 31 | class_cond=class_cond, 32 | deterministic=True, 33 | random_crop=False, 34 | random_flip=False, 35 | is_train=False 36 | ) 37 | 38 | for i , (batch, cond) in enumerate(data): 39 | if 'instance' in cond: 40 | inst_map = cond['instance'] 41 | instance_edge_map = get_edges(inst_map) 42 | instance_edge_map = instance_edge_map.permute(0,2,3,1) 43 | map_shape = instance_edge_map.shape 44 | 45 | for b in range(map_shape[0]): 46 | for x in range(map_shape[1]): 47 | for y in range(map_shape[2]): 48 | instance_edge_map[b][x][y] = 0 49 | if random.random() < 0.1: 50 | instance_edge_map[b][x][y] = 1 51 | 52 | instance_edge_map = instance_edge_map.permute(0,3,1,2) 53 | edge_unlabeled = instance_edge_map * (-1) + 1 54 | newlabel = cond['label_ori'] 55 | newlabel = newlabel * edge_unlabeled.squeeze(1) 56 | for j in range(inst_map.shape[0]): 57 | cv2.imwrite(os.path.join(new_dir, cond['path'][j].split('/')[-1].split('.')[0] + '.png'), newlabel[j].cpu().numpy()) 58 | if i >= (2000/batch_size): 59 | break 60 | -------------------------------------------------------------------------------- /image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | import random 9 | 10 | import torch as th 11 | import torch.distributed as dist 12 | import torchvision as tv 13 | import numpy as np 14 | 15 | from guided_diffusion.image_datasets import load_data 16 | 17 | from guided_diffusion import dist_util, logger 18 | from guided_diffusion.script_util import ( 19 | model_and_diffusion_defaults, 20 | create_model_and_diffusion, 21 | add_dict_to_argparser, 22 | args_to_dict, 23 | ) 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | 29 | logger.log("seed: "+str(args.seed)) 30 | deterministic = True 31 | random.seed(args.seed) 32 | np.random.seed(args.seed) 33 | th.manual_seed(args.seed) 34 | if deterministic: 35 | th.backends.cudnn.deterministic = True 36 | th.backends.cudnn.benchmark = False 37 | 38 | 39 | dist_util.setup_dist(gpus_per_node=args.gpus_per_node) 40 | logger.configure() 41 | 42 | logger.log("creating model and diffusion...") 43 | model, diffusion = create_model_and_diffusion( 44 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 45 | ) 46 | 47 | model.load_state_dict( 48 | dist_util.load_state_dict(args.model_path, map_location="cpu") 49 | ) 50 | model.to(dist_util.dev()) 51 | 52 | logger.log("creating data loader...") 53 | data = load_data( 54 | dataset_mode=args.dataset_mode, 55 | data_dir=args.data_dir, 56 | batch_size=args.batch_size, 57 | image_size=args.image_size, 58 | class_cond=args.class_cond, 59 | deterministic=True, 60 | random_crop=False, 61 | random_flip=False, 62 | is_train=False 63 | ) 64 | 65 | if args.use_fp16: 66 | model.convert_to_fp16() 67 | model.eval() 68 | 69 | image_path = os.path.join(args.results_path, 'images') 70 | os.makedirs(image_path, exist_ok=True) 71 | label_path = os.path.join(args.results_path, 'labels') 72 | os.makedirs(label_path, exist_ok=True) 73 | sample_path = os.path.join(args.results_path, 'samples') 74 | os.makedirs(sample_path, exist_ok=True) 75 | 76 | logger.log("sampling...") 77 | all_samples = [] 78 | 79 | for i, (batch, cond) in enumerate(data): 80 | image = ((batch + 1.0) / 2.0).cuda() 81 | label = (cond['label_ori'].float() / 255.0).cuda() 82 | model_kwargs = preprocess_input(cond, num_classes=args.num_classes) 83 | 84 | model_kwargs['results_path'] = args.results_path 85 | 86 | if args.unlabeled_to_zero: 87 | y = model_kwargs['y'] 88 | mask = th.nn.functional.one_hot(th.tensor([y.shape[1]-1]),num_classes=y.shape[1]) 89 | mask = th.ones_like(mask) - mask 90 | mask = mask.unsqueeze(1).unsqueeze(2) 91 | model_kwargs['y'] = (y.permute(0,2,3,1)*mask).permute(0,3,1,2) 92 | 93 | if args.seed_align: 94 | y = model_kwargs['y'] 95 | _ = th.rand((y.shape[0],y.shape[2],y.shape[3]),device=image.device) 96 | 97 | # set hyperparameter 98 | model_kwargs['s'] = args.s 99 | model_kwargs['dynamic_threshold'] = args.dynamic_threshold 100 | model_kwargs['extrapolation'] = args.extrapolation 101 | 102 | sample_fn = ( 103 | diffusion.p_sample_loop 104 | ) 105 | 106 | sample = sample_fn( 107 | model, 108 | (args.batch_size, 3, image.shape[2], image.shape[3]), 109 | clip_denoised=args.clip_denoised, 110 | model_kwargs=model_kwargs, 111 | progress=True, 112 | noise=None 113 | ) 114 | sample = (sample + 1) / 2.0 115 | 116 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 117 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 118 | all_samples.extend([sample.cpu().numpy() for sample in gathered_samples]) 119 | 120 | for j in range(sample.shape[0]): 121 | 
tv.utils.save_image(image[j], os.path.join(image_path, cond['path'][j].split('/')[-1].split('.')[0] + '.png')) 122 | tv.utils.save_image(sample[j], os.path.join(sample_path, cond['path'][j].split('/')[-1].split('.')[0] + '.png')) 123 | tv.utils.save_image(label[j], os.path.join(label_path, cond['path'][j].split('/')[-1].split('.')[0] + '.png')) 124 | 125 | logger.log(f"created {len(all_samples) * args.batch_size} samples") 126 | 127 | if len(all_samples) * args.batch_size >= args.num_samples: 128 | break 129 | 130 | dist.barrier() 131 | logger.log("sampling complete") 132 | 133 | 134 | def preprocess_input(data, num_classes): 135 | # move to GPU and change data types 136 | data['label'] = data['label'].long() 137 | 138 | # create one-hot label map 139 | label_map = data['label'] 140 | bs, _, h, w = label_map.size() 141 | input_label = th.FloatTensor(bs, num_classes, h, w).zero_() 142 | input_semantics = input_label.scatter_(1, label_map, 1.0) 143 | 144 | # concatenate instance map if it exists 145 | if 'instance' in data: 146 | inst_map = data['instance'] 147 | instance_edge_map = get_edges(inst_map) 148 | input_semantics = th.cat((input_semantics, instance_edge_map), dim=1) 149 | 150 | return {'y': input_semantics} 151 | 152 | 153 | def get_edges(t): 154 | edge = th.ByteTensor(t.size()).zero_() 155 | edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 156 | edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) 157 | edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 158 | edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) 159 | return edge.float() 160 | 161 | 162 | def create_argparser(): 163 | defaults = dict( 164 | data_dir="", 165 | dataset_mode="", 166 | clip_denoised=True, 167 | num_samples=10000, 168 | batch_size=1, 169 | model_path="", 170 | results_path="", 171 | is_train=False, 172 | s=1.0, 173 | gpus_per_node=4, 174 | gpu=None, 175 | unlabeled_to_zero = False, 176 | dynamic_threshold = -1.0, 177 | extrapolation = -1.0, 178 | seed = 42, 179 | seed_align = False, 180 | ) 181 | defaults.update(model_and_diffusion_defaults()) 182 | parser = argparse.ArgumentParser() 183 | add_dict_to_argparser(parser, defaults) 184 | return parser 185 | 186 | 187 | if __name__ == "__main__": 188 | main() 189 | -------------------------------------------------------------------------------- /image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import os 6 | import argparse 7 | 8 | from guided_diffusion import dist_util, logger 9 | from guided_diffusion.image_datasets import load_data 10 | from guided_diffusion.resample import create_named_schedule_sampler 11 | from guided_diffusion.script_util import ( 12 | model_and_diffusion_defaults, 13 | create_model_and_diffusion, 14 | args_to_dict, 15 | add_dict_to_argparser, 16 | ) 17 | from guided_diffusion.train_util import TrainLoop 18 | 19 | 20 | def main(): 21 | args = create_argparser().parse_args() 22 | 23 | dist_util.setup_dist(gpus_per_node=args.gpus_per_node) 24 | logger.configure() 25 | 26 | logger.log("creating model and diffusion...") 27 | model, diffusion = create_model_and_diffusion( 28 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 29 | ) 30 | 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_data( 36 | dataset_mode=args.dataset_mode, 37 | data_dir=args.data_dir, 38 | batch_size=args.batch_size, 39 | image_size=args.image_size, 40 | class_cond=args.class_cond, 41 | is_train=args.is_train 42 | ) 43 | 44 | logger.log("training...") 45 | TrainLoop( 46 | model=model, 47 | diffusion=diffusion, 48 | data=data, 49 | num_classes=args.num_classes, 50 | batch_size=args.batch_size, 51 | microbatch=args.microbatch, 52 | lr=args.lr, 53 | ema_rate=args.ema_rate, 54 | drop_rate=args.drop_rate, 55 | log_interval=args.log_interval, 56 | save_interval=args.save_interval, 57 | resume_checkpoint=args.resume_checkpoint, 58 | use_fp16=args.use_fp16, 59 | fp16_scale_growth=args.fp16_scale_growth, 60 | schedule_sampler=schedule_sampler, 61 | weight_decay=args.weight_decay, 62 | lr_anneal_steps=args.lr_anneal_steps, 63 | ).run_loop() 64 | 65 | 66 | def create_argparser(): 67 | defaults = dict( 68 | data_dir="", 69 | dataset_mode="", 70 | schedule_sampler="uniform", 71 | lr=1e-4, 72 | weight_decay=0.0, 73 | lr_anneal_steps=0, 74 | batch_size=1, 75 | microbatch=-1, # -1 disables microbatches 76 | ema_rate="0.9999", # comma-separated list of EMA values 77 | drop_rate=0.0, 78 | log_interval=10, 79 | save_interval=10000, 80 | resume_checkpoint="", 81 | use_fp16=False, 82 | fp16_scale_growth=1e-3, 83 | is_train=True, 84 | gpus_per_node=4, 85 | ) 86 | defaults.update(model_and_diffusion_defaults()) 87 | parser = argparse.ArgumentParser() 88 | add_dict_to_argparser(parser, defaults) 89 | return parser 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /scripts/sample.sh: -------------------------------------------------------------------------------- 1 | export OPENAI_LOGDIR='logs/sample' 2 | export NCCL_P2P_DISABLE=1 3 | 4 | # CelebAMask-HQ 5 | # 25-step Generation 6 | mpiexec -n 4 python image_sample.py --data_dir data/celeba --dataset_mode celeba --attention_resolutions 32,16,8 \ 7 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 8 | --use_scale_shift_norm True --num_classes 19 --class_cond True --no_instance False --batch_size 10 --num_samples 2000 --model_path checkpoints/scdm_ade20k/model.pt \ 9 | --results_path results/celeba-25 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 10 | --gpus_per_node 4 --s 1.5 --dynamic_threshold 0.95 --extrapolation 0.8 --timestep_respacing ddim25; 11 | 12 | # 
1000-step Generation 13 | mpiexec -n 4 python image_sample.py --data_dir data/celeba --dataset_mode celeba --attention_resolutions 32,16,8 \ 14 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 15 | --use_scale_shift_norm True --num_classes 19 --class_cond True --no_instance False --batch_size 10 --num_samples 2000 --model_path checkpoints/scdm_coco.pt \ 16 | --results_path results/celeba-1000 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 17 | --gpus_per_node 4 --s 1.5 18 | 19 | ################## 20 | 21 | # ADE20K 22 | # 25-step Generation 23 | mpiexec -n 4 python image_sample.py --data_dir data/ade20k --dataset_mode ade20k --attention_resolutions 32,16,8 \ 24 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 25 | --use_scale_shift_norm True --num_classes 151 --class_cond True --no_instance True --batch_size 10 --num_samples 2000 --model_path checkpoints/scdm_ade20k/model.pt \ 26 | --results_path results/ade20k-25 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 27 | --gpus_per_node 4 --s 1.5 --dynamic_threshold 0.95 --extrapolation 0.8 --timestep_respacing ddim25; 28 | 29 | # 1000-step Generation 30 | mpiexec -n 4 python image_sample.py --data_dir data/ade20k --dataset_mode ade20k --attention_resolutions 32,16,8 \ 31 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 32 | --use_scale_shift_norm True --num_classes 151 --class_cond True --no_instance True --batch_size 10 --num_samples 2000 --model_path checkpoints/scdm_ade20k/model.pt \ 33 | --results_path results/ade20k-1000 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 34 | --gpus_per_node 4 --s 1.5 35 | 36 | ################## 37 | 38 | # COCO 39 | # 25-step Generation 40 | mpiexec -n 4 python image_sample.py --data_dir data/coco --dataset_mode coco --attention_resolutions 32,16,8 \ 41 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 42 | --use_scale_shift_norm True --num_classes 183 --class_cond True --no_instance False --batch_size 10 --num_samples 5000 --model_path checkpoints/scdm_coco.pt \ 43 | --results_path results/coco-25 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 44 | --gpus_per_node 4 --s 1.5 --dynamic_threshold 0.95 --extrapolation 0.8 --timestep_respacing ddim25; 45 | 46 | # 1000-step Generation 47 | mpiexec -n 4 python image_sample.py --data_dir data/coco --dataset_mode coco --attention_resolutions 32,16,8 \ 48 | --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 49 | --use_scale_shift_norm True --num_classes 183 --class_cond True --no_instance False --batch_size 10 --num_samples 5000 --model_path checkpoints/scdm_coco.pt \ 50 | --results_path results/coco-1000 --cond_diffuse True --cond_opt discrete_zero_classwise --seed 42 \ 51 | --gpus_per_node 4 --s 1.5 52 | -------------------------------------------------------------------------------- /scripts/train.sh:
-------------------------------------------------------------------------------- 1 | export NCCL_P2P_DISABLE=1 2 | 3 | # CelebAMask-HQ 4 | # scratch 5 | export OPENAI_LOGDIR='logs/train/celeba-scratch' 6 | mpiexec -n 4 python image_train.py --data_dir data/CelebAMask-HQ --dataset_mode celeba \ 7 | --lr 2e-5 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 8 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 9 | --use_scale_shift_norm True --use_checkpoint True --num_classes 19 --class_cond True --no_instance False \ 10 | --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise --drop_rate 0.2 11 | 12 | # resume 13 | export OPENAI_LOGDIR='logs/train/celeba-resume' 14 | mpiexec -n 4 python image_train.py --data_dir data/CelebAMask-HQ --dataset_mode celeba \ 15 | --lr 2e-5 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 16 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 17 | --use_scale_shift_norm True --use_checkpoint True --num_classes 19 --class_cond True --no_instance False \ 18 | --resume_checkpoint checkpoints/celeba/model.pt --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise --drop_rate 0.2 19 | 20 | 21 | ################## 22 | 23 | # ADE20K 24 | # scratch 25 | export OPENAI_LOGDIR='logs/train/ade20k-scratch' 26 | mpiexec -n 4 python image_train.py --data_dir data/ADE20K/ADEChallengeData2016 --dataset_mode ade20k \ 27 | --lr 1e-4 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 28 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 29 | --use_scale_shift_norm True --use_checkpoint True --num_classes 151 --class_cond True --no_instance True \ 30 | --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise --drop_rate 0.2 31 | 32 | # resume 33 | export OPENAI_LOGDIR='logs/train/ade20k-resume' 34 | mpiexec -n 4 python image_train.py --data_dir data/ADE20K/ADEChallengeData2016 --dataset_mode ade20k \ 35 | --lr 2e-5 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 36 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 37 | --use_scale_shift_norm True --use_checkpoint True --num_classes 151 --class_cond True --no_instance True \ 38 | --resume_checkpoint checkpoints/ade20k/model.pt --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise --drop_rate 0.2 39 | 40 | ################## 41 | 42 | # COCO 43 | # scratch 44 | export OPENAI_LOGDIR='logs/train/coco-scratch' 45 | mpiexec -n 4 python image_train.py --data_dir data/coco --dataset_mode coco \ 46 | --lr 1e-4 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 47 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 48 | --use_scale_shift_norm True --use_checkpoint True --num_classes 183 --class_cond True --no_instance False \ 49 | --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise 
--drop_rate 0.2 50 | 51 | # resume 52 | export OPENAI_LOGDIR='logs/train/coco-resume' 53 | mpiexec -n 4 python image_train.py --data_dir data/coco --dataset_mode coco \ 54 | --lr 2e-5 --batch_size 10 --attention_resolutions 32,16,8 --diffusion_steps 1000 --image_size 256 --learn_sigma True \ 55 | --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True \ 56 | --use_scale_shift_norm True --use_checkpoint True --num_classes 183 --class_cond True --no_instance False \ 57 | --resume_checkpoint checkpoints/coco/model.pt --save_interval 1000 --cond_diffuse True --gpus_per_node 4 --cond_opt discrete_zero_classwise --drop_rate 0.2 58 | --------------------------------------------------------------------------------