├── .gitignore
├── LICENSE
├── README.md
├── edm
│   ├── Dockerfile
│   ├── LICENSE.txt
│   ├── README.md
│   ├── dataset_tool.py
│   ├── dnnlib
│   │   ├── __init__.py
│   │   └── util.py
│   ├── docs
│   │   ├── afhqv2-64x64.png
│   │   ├── cifar10-32x32.png
│   │   ├── dataset-tool-help.txt
│   │   ├── ffhq-64x64.png
│   │   ├── fid-help.txt
│   │   ├── generate-help.txt
│   │   ├── imagenet-64x64.png
│   │   ├── teaser-1280x640.jpg
│   │   ├── teaser-1920x640.jpg
│   │   ├── teaser-640x480.jpg
│   │   └── train-help.txt
│   ├── environment.yml
│   ├── example.py
│   ├── fid.py
│   ├── generate.py
│   ├── jacobian.py
│   ├── sscd.py
│   ├── torch_utils
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── misc.py
│   │   ├── persistence.py
│   │   └── training_stats.py
│   ├── train.py
│   ├── trainMoLRG.py
│   └── training
│       ├── __init__.py
│       ├── augment.py
│       ├── dataset.py
│       ├── loss.py
│       ├── networks.py
│       ├── training_loop.py
│       └── training_loop_MoLRG.py
└── figures
    ├── generalization-score.png
    ├── jacobian-MoLRG.png
    ├── jacobian-real.png
    ├── optimal-denoiser.png
    ├── reproducibility-score.png
    └── similarity.png

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 | 
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 | 
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 | 
121 | # SageMath parsed files
122 | *.sage.py
123 | 
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 | 
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 | 
137 | # Rope project settings
138 | .ropeproject
139 | 
140 | # mkdocs documentation
141 | /site
142 | 
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 | 
148 | # Pyre type checker
149 | .pyre/
150 | 
151 | # pytype static type analyzer
152 | .pytype/
153 | 
154 | # Cython debug symbols
155 | cython_debug/
156 | 
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 huijieZH
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Understanding Generalizability of Diffusion Models through Low-dimensional Distribution Learning
2 | 
3 | This is an official implementation of the following papers:
4 | 1. [The Emergence of Reproducibility and Consistency in Diffusion Models](https://arxiv.org/abs/2310.05264) **NeurIPS 2023 workshop Best Paper, ICML 2024**
5 | 2. [Diffusion Models Learn Low-Dimensional Distributions via Subspace Clustering](https://arxiv.org/abs/2409.02426)
6 | 
7 | The codebase mainly focuses on the implementation of four main figures from these two papers:
8 | 1. "Memorization" and "Generalization" regimes for unconditional diffusion models. (Figure 2 in [Paper 1](https://arxiv.org/abs/2310.05264))
9 | 2. Convergence of the optimal denoiser. (Figure 4 Left in [Paper 1](https://arxiv.org/abs/2310.05264))
10 | 3. Similarity among different unconditional diffusion model settings in the generalization regime. (Figure 6 and Figure 12 in [Paper 1](https://arxiv.org/abs/2310.05264))
11 | 4. Low-rank property of the denoising autoencoder of trained diffusion models. (Figure 3 in [Paper 2](https://arxiv.org/abs/2409.02426))
12 | 
13 | For the implementation of Figure 1 (correspondence between the singular vectors of the Jacobian of the DAE and semantic
14 | image attributes) in [Paper 2](https://arxiv.org/abs/2409.02426), please see our concurrent work [Exploring Low-Dimensional Subspaces in Diffusion Models for Controllable Image Editing](https://arxiv.org/abs/2409.02374); the codebase can be found [here](https://github.com/ChicyChen/LOCO-Edit).
15 | 
16 | ### Requirements
17 | 
18 | ```bash
19 | conda env create -f edm/environment.yml -n generalizability
20 | conda activate generalizability
21 | ```
22 | 
23 | ## "Memorization" and "Generalization" regimes for unconditional diffusion models.
24 | 
25 | 
26 | *(figure: reproducibility and generalization scores; see figures/reproducibility-score.png and figures/generalization-score.png)*
27 | 
28 | 
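Both scores in this section are computed from SSCD (self-supervised copy detection) features by `edm/sscd.py`. Roughly, two samples generated from the same seed by two different models count as "reproduced" when their SSCD features have high cosine similarity, and a generated sample counts as "memorized" when its nearest training image is highly similar. A minimal NumPy sketch of this logic (illustrative only: the function names, the threshold `tau`, and the exact matching rule are stand-ins for what `sscd.py` actually implements):

```python
import numpy as np

def rp_score(feat_a, feat_b, tau):
    # feat_a, feat_b: (N, d) SSCD features for the SAME seeds from two models.
    a = feat_a / np.linalg.norm(feat_a, axis=1, keepdims=True)
    b = feat_b / np.linalg.norm(feat_b, axis=1, keepdims=True)
    cos = (a * b).sum(axis=1)   # cosine similarity of matched pairs
    return (cos > tau).mean()   # fraction of reproduced samples

def memorization_score(feat_gen, feat_train, tau):
    # Similarity of each generated sample to its nearest training image.
    g = feat_gen / np.linalg.norm(feat_gen, axis=1, keepdims=True)
    t = feat_train / np.linalg.norm(feat_train, axis=1, keepdims=True)
    nearest = (g @ t.T).max(axis=1)
    return (nearest > tau).mean()
```

The generalization score reported here is derived from the memorization statistics (the `mscore` command below).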
29 | 
30 | Slightly different from Figure 2 in [Paper 1](https://arxiv.org/abs/2310.05264), the released code uses a fine-tuning setting: the training dataset is generated from a pre-trained diffusion model (the teacher model).
31 | 
32 | ### Create Dataset
33 | 
34 | Create a dataset of a specific size as follows:
35 | 
36 | ```bash
37 | # generate images from the teacher model
38 | python edm/generate.py --outdir=out --seeds=0-49999 --batch=64 --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl
39 | 
40 | # create a dataset of the desired size
41 | python edm/dataset_tool.py --source=out --max-images=128 --dest=datasets/synthetic-cifar10-32x32-n128.zip
42 | ```
43 | 
44 | ### Training
45 | 
46 | ```bash
47 | torchrun --standalone --nproc_per_node=1 edm/train.py --outdir=training --data=datasets/synthetic-cifar10-32x32-n128.zip --cond=0 --arch=ddpmpp --duration 50 --batch 128 --snap 500 --dump 500 --precond vp --model_channels 64
48 | ```
49 | 
50 | ### Evaluation
51 | 
52 | All checkpoints we released can be found [here](https://www.dropbox.com/scl/fo/m8tf61cengcp1qyevwiwv/AKoLuvIY5Fx0Tz1g8eRFWoI?rlkey=x7t1iqunpzofddgv533bx48q8&st=wtfeg1a9&dl=0), and all training datasets we released can be found [here](https://www.dropbox.com/scl/fo/fqwgl5pvqe4jgvuw945k6/AHEH9P8AYVYhMx_ABTclVC4?rlkey=frsgki669ny9lmxiwlpafanhg&st=kbgrn9za&dl=0).
53 | 
54 | ```bash
55 | 
56 | ### Generate images from the trained diffusion models; these seeds (100000-109999) differ from the seeds (0-49999) used to generate the training images from the teacher model.
57 | python edm/generate.py --outdir=evaluation/ddpm-dim64-n64 --seeds=100000-109999 --batch=64 --network=training/ckpt/ddpm-dim64-n64.pkl
58 | 
59 | python edm/generate.py --outdir=evaluation/ddpm-dim128-n64 --seeds=100000-109999 --batch=64 --network=training/ckpt/ddpm-dim128-n64.pkl
60 | 
61 | ### Calculate SSCD features
62 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim64-n64 --features ./evaluation/sscd-dim64-n64.npz
63 | 
64 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim128-n64 --features ./evaluation/sscd-dim128-n64.npz
65 | 
66 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images datasets/synthetic-cifar10-32x32-n64.zip --features ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n64.npz
67 | 
68 | # Compute reproducibility score
69 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-dim64-n64.npz
70 | 
71 | # Compute generalization score
72 | python edm/sscd.py mscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n64.npz
73 | 
74 | ```
75 | 
76 | ## Convergence of the optimal denoiser.
77 | 
78 | 
79 | *(figure: convergence of the optimal denoiser; see figures/optimal-denoiser.png)*
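For reference, when the data distribution is the empirical distribution over a finite training set, the optimal (MMSE) denoiser has a closed form: a softmax-weighted average of the training images. A minimal NumPy sketch of that closed form (illustrative only; `Y` and `sigma` are hypothetical inputs, and the actual implementation is what `edm/generate.py` runs under the `--optimal_denoiser` flag):

```python
import numpy as np

def optimal_denoiser(x, Y, sigma):
    # Posterior mean E[y | x] for x = y + sigma * n, n ~ N(0, I),
    # with y uniform over the N training images Y (shape (N, d)); x has shape (d,).
    logits = -np.sum((x[None, :] - Y) ** 2, axis=1) / (2.0 * sigma ** 2)
    w = np.exp(logits - logits.max())   # numerically stable softmax weights
    w /= w.sum()
    return w @ Y                        # convex combination of training images
```

As `sigma` decreases, the weights collapse onto the nearest training image, which is why sampling with this denoiser reproduces the training set.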
81 | 82 | We implement the optimal denoiser (derived from the score function of the empirial distribution). And compare the RP score between real diffusion model and the optimal denoiser. 83 | 84 | ```bash 85 | ### generate image from optimal denoiser 86 | python edm/generate.py --outdir=evaluation/memorization-n64 --seeds=100000-109999 --batch=64 --optimal_denoiser --dataset=datasets/synthetic-cifar10-32x32-n64.zip --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl 87 | 88 | ### Calculate SSCD feature 89 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/memorization-n64 --features ./evaluation/sscd-memorization-n64.npz 90 | 91 | ### Compute reproducibility score 92 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-memorization-n64.npz 93 | 94 | ``` 95 | 96 | ## Similarity among different unconditional diffusion model settings in generalization regime. 97 | 98 |
99 | *(figure: similarity across model settings; see figures/similarity.png)*
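The comparison in this section requires feeding every model the same initial noise. With seed-based samplers this amounts to reusing the same seed list across repos, since the starting latents are derived deterministically from the per-image seed; a minimal sketch of the idea (assuming a PyTorch-style sampler; the exact latent construction varies per codebase):

```python
import torch

def initial_latents(seed, shape=(3, 32, 32)):
    # Derive the starting noise deterministically from a per-image seed,
    # so that different models can be fed identical latents.
    g = torch.Generator().manual_seed(seed)
    return torch.randn(shape, generator=g)

z0 = initial_latents(100000)  # identical across models for the same seed
```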
101 | 102 | We provide generated samples from those different diffusion models [here](https://www.dropbox.com/scl/fo/xq0yvr92ohzb6ov313928/ANo8GzZ5GybCrzJRb2P1qU8?rlkey=iaf6316aezz4wznigj4ir2v29&st=psa7il3e&dl=0). To generate new samples, you need to go through their own github repo and use the same initial noise for generation. 103 | 104 | ```bash 105 | 106 | ### Calculate SSCD feature 107 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./samples/ddpmv4 --features ./evaluation/sscd-ddpmv4.npz 108 | 109 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./samples/ddpmv6 --features ./evaluation/sscd-ddpmv6.npz 110 | 111 | 112 | # Compute reproducibility score 113 | python edm/sscd.py rpscore --source ./evaluation/sscd-ddpmv4.npz --target ./evaluation/sscd-ddpmv6.npz 114 | 115 | 116 | ``` 117 | 118 | ## Low-rank property of the denoising autoencoder of trained diffusion models. 119 | 120 |
121 | *(figures: Jacobian rank of the DAE; see figures/jacobian-real.png and figures/jacobian-MoLRG.png)*
122 | 
123 | 
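The rank measurements in this section come from the Jacobian of the denoising autoencoder (DAE) x -> D(x; sigma) at a fixed noise level. A minimal PyTorch sketch of how such a numerical rank can be computed (illustrative only; `denoiser`, `sigma`, and the relative threshold `tol` are hypothetical stand-ins for what `edm/jacobian.py` actually does):

```python
import torch

def jacobian_numerical_rank(denoiser, x, sigma, tol=1e-2):
    # Jacobian of the DAE output w.r.t. its input, flattened to a 2-D matrix.
    J = torch.autograd.functional.jacobian(lambda inp: denoiser(inp, sigma), x)
    J = J.reshape(x.numel(), x.numel())
    s = torch.linalg.svdvals(J)              # singular values, descending
    return int((s > tol * s[0]).sum())       # count above a relative threshold
```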
124 | The figures above illustrate the low-dimensionality of the Jacobian of the denoising autoencoder (DAE) trained on a real dataset and on the Mixture of Low-Rank Gaussians (MoLRG) distribution.
125 | 
126 | To train a diffusion model on MoLRG data:
127 | ```bash
128 | torchrun --standalone --nproc_per_node=1 edm/trainMoLRG.py --outdir training --path datasets --img_res 4 --class_num 2 --per_class_dim 7 --sample_per_class 350 --embed_channels 128
129 | ```
130 | 
131 | To evaluate the rank of the Jacobian:
132 | ```bash
133 | torchrun --standalone --nproc_per_node=1 edm/jacobian.py --network_pkl <network_pkl>
134 | 
135 | e.g.
136 | torchrun --standalone --nproc_per_node=1 edm/jacobian.py --network_pkl https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl
137 | ```
138 | 
139 | Notably, since it is built upon [NVlabs/edm](https://github.com/NVlabs/edm), our codebase is compatible with all training checkpoints released from their repo, which you can find [here](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/) and [here](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/).
140 | 
141 | ## Acknowledgements
142 | This repository is heavily based on [NVlabs/edm](https://github.com/NVlabs/edm).
143 | 
144 | ## BibTeX
145 | ```
146 | @inproceedings{
147 | zhang2024the,
148 |   title={The Emergence of Reproducibility and Consistency in Diffusion Models},
149 |   author={Huijie Zhang and Jinfan Zhou and Yifu Lu and Minzhe Guo and Peng Wang and Liyue Shen and Qing Qu},
150 |   booktitle={Forty-first International Conference on Machine Learning},
151 |   year={2024},
152 |   url={https://openreview.net/forum?id=HsliOqZkc0}
153 | }
154 | 
155 | @article{wang2024diffusion,
156 |   title={Diffusion models learn low-dimensional distributions via subspace clustering},
157 |   author={Wang, Peng and Zhang, Huijie and Zhang, Zekai and Chen, Siyi and Ma, Yi and Qu, Qing},
158 |   journal={arXiv preprint arXiv:2409.02426},
159 |   year={2024}
160 | }
161 | ```
162 | 
--------------------------------------------------------------------------------
/edm/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | FROM nvcr.io/nvidia/pytorch:22.10-py3
9 | 
10 | ENV PYTHONDONTWRITEBYTECODE 1
11 | ENV PYTHONUNBUFFERED 1
12 | 
13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
14 | 
15 | WORKDIR /workspace
16 | 
17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
18 | ENTRYPOINT ["/entry.sh"]
19 | 
--------------------------------------------------------------------------------
/edm/README.md:
--------------------------------------------------------------------------------
1 | ## Elucidating the Design Space of Diffusion-Based Generative Models (EDM)
Official PyTorch implementation of the NeurIPS 2022 paper 2 | 3 | ![Teaser image](./docs/teaser-1920x640.jpg) 4 | 5 | **Elucidating the Design Space of Diffusion-Based Generative Models**
6 | Tero Karras, Miika Aittala, Timo Aila, Samuli Laine 7 |
https://arxiv.org/abs/2206.00364
8 | 
9 | Abstract: *We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
10 | 
11 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
12 | 
13 | ## Requirements
14 | 
15 | * Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
16 | * 1+ high-end NVIDIA GPU for sampling and 8+ GPUs for training. We have done all testing and development using V100 and A100 GPUs.
17 | * 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch install instructions.
18 | * Python libraries: See [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
19 |   - `conda env create -f environment.yml -n edm`
20 |   - `conda activate edm`
21 | * Docker users:
22 |   - Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
23 |   - Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.
24 | 
25 | ## Getting started
26 | 
27 | To reproduce the main results from our paper, simply run:
28 | 
29 | ```.bash
30 | python example.py
31 | ```
32 | 
33 | This is a minimal standalone script that loads the best pre-trained model for each dataset and generates a random 8x8 grid of images using the optimal sampler settings. Expected results:
34 | 
35 | | Dataset  | Runtime | Reference image
36 | | :------- | :------ | :--------------
37 | | CIFAR-10 | ~6 sec  | [`cifar10-32x32.png`](./docs/cifar10-32x32.png)
38 | | FFHQ     | ~28 sec | [`ffhq-64x64.png`](./docs/ffhq-64x64.png)
39 | | AFHQv2   | ~28 sec | [`afhqv2-64x64.png`](./docs/afhqv2-64x64.png)
40 | | ImageNet | ~5 min  | [`imagenet-64x64.png`](./docs/imagenet-64x64.png)
41 | 
42 | The easiest way to explore different sampling strategies is to modify [`example.py`](./example.py) directly. You can also incorporate the pre-trained models and/or our proposed EDM sampler in your own code by simply copy-pasting the relevant bits. Note that the class definitions for the pre-trained models are stored within the pickles themselves and loaded automatically during unpickling via [`torch_utils.persistence`](./torch_utils/persistence.py). To use the models in external Python scripts, just make sure that `torch_utils` and `dnnlib` are accessible through `PYTHONPATH`.
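For instance, loading a pre-trained pickle in an external script looks roughly like this (a minimal sketch following the pattern used by `generate.py`; the `'ema'` key holds the exponential-moving-average weights):

```python
import pickle
import torch
import dnnlib  # requires torch_utils and dnnlib on PYTHONPATH

network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl'
device = torch.device('cuda')

with dnnlib.util.open_url(network_pkl) as f:
    net = pickle.load(f)['ema'].to(device)

# The network acts as a denoiser D(x; sigma): noisy image batch in, denoised estimate out.
sigma = torch.tensor(1.0, device=device)
x = torch.randn(1, net.img_channels, net.img_resolution, net.img_resolution, device=device) * sigma
denoised = net(x, sigma)
```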
43 | 44 | **Docker**: You can run the example script using Docker as follows: 45 | 46 | ```.bash 47 | # Build the edm:latest image 48 | docker build --tag edm:latest . 49 | 50 | # Run the generate.py script using Docker: 51 | docker run --gpus all -it --rm --user $(id -u):$(id -g) \ 52 | -v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \ 53 | edm:latest \ 54 | python example.py 55 | ``` 56 | 57 | Note: The Docker image requires NVIDIA driver release `r520` or later. 58 | 59 | The `docker run` invocation may look daunting, so let's unpack its contents here: 60 | 61 | - `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with current user's UID/GID to avoid Docker writing files as root. 62 | - ``-v `pwd`:/scratch --workdir /scratch``: mount current running dir (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working dir. 63 | - `-e HOME=/scratch`: specify where to cache temporary files. Note: if you want more fine-grained control, you can instead set `DNNLIB_CACHE_DIR` (for pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations. 64 | 65 | ## Pre-trained models 66 | 67 | We provide pre-trained models for our proposed training configuration (config F) as well as the baseline configuration (config A): 68 | 69 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/) 70 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/) 71 | 72 | To generate a batch of images using a given model and sampler, run: 73 | 74 | ```.bash 75 | # Generate 64 images and save them as out/*.png 76 | python generate.py --outdir=out --seeds=0-63 --batch=64 \ 77 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 78 | ``` 79 | 80 | Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the above command using `torchrun`: 81 | 82 | ```.bash 83 | # Generate 1024 images using 2 GPUs 84 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \ 85 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 86 | ``` 87 | 88 | The sampler settings can be controlled through command-line options; see [`python generate.py --help`](./docs/generate-help.txt) for more information. 
For best results, we recommend using the following settings for each dataset: 89 | 90 | ```.bash 91 | # For CIFAR-10 at 32x32, use deterministic sampling with 18 steps (NFE = 35) 92 | python generate.py --outdir=out --steps=18 \ 93 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 94 | 95 | # For FFHQ and AFHQv2 at 64x64, use deterministic sampling with 40 steps (NFE = 79) 96 | python generate.py --outdir=out --steps=40 \ 97 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl 98 | 99 | # For ImageNet at 64x64, use stochastic sampling with 256 steps (NFE = 511) 100 | python generate.py --outdir=out --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 \ 101 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl 102 | ``` 103 | 104 | Besides our proposed EDM sampler, `generate.py` can also be used to reproduce the sampler ablations from Section 3 of our paper. For example: 105 | 106 | ```.bash 107 | # Figure 2a, "Our reimplementation" 108 | python generate.py --outdir=out --steps=512 --solver=euler --disc=vp --schedule=vp --scaling=vp \ 109 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl 110 | 111 | # Figure 2a, "+ Heun & our {t_i}" 112 | python generate.py --outdir=out --steps=128 --solver=heun --disc=edm --schedule=vp --scaling=vp \ 113 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl 114 | 115 | # Figure 2a, "+ Our sigma(t) & s(t)" 116 | python generate.py --outdir=out --steps=18 --solver=heun --disc=edm --schedule=linear --scaling=none \ 117 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl 118 | ``` 119 | 120 | ## Calculating FID 121 | 122 | To compute Fréchet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`: 123 | 124 | ```.bash 125 | # Generate 50000 images and save them as fid-tmp/*/*.png 126 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \ 127 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 128 | 129 | # Calculate FID 130 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \ 131 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 132 | ``` 133 | 134 | Both of the above commands can be parallelized across multiple GPUs by adjusting `--nproc_per_node`. The second command typically takes 1-3 minutes in practice, but the first one can sometimes take several hours, depending on the configuration. See [`python fid.py --help`](./docs/fid-help.txt) for the full list of options. 135 | 136 | Note that the numerical value of FID varies across different random seeds and is highly sensitive to the number of images. By default, `fid.py` will always use 50,000 generated images; providing fewer images will result in an error, whereas providing more will use a random subset. To reduce the effect of random variation, we recommend repeating the calculation multiple times with different seeds, e.g., `--seeds=0-49999`, `--seeds=50000-99999`, and `--seeds=100000-149999`. In our paper, we calculated each FID three times and reported the minimum. 
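For example, the three-run protocol described above can be scripted with the commands already shown (a sketch; note the output directory is cleared between runs so that each FID is computed on exactly 50,000 fresh images):

```.bash
for SEEDS in 0-49999 50000-99999 100000-149999; do
    rm -rf fid-tmp
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=$SEEDS --subdirs \
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
done
```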
137 | 138 | Also note that it is important to compare the generated images against the same dataset that the model was originally trained with. To facilitate evaluation, we provide the exact reference statistics that correspond to our pre-trained models: 139 | 140 | * [https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/) 141 | 142 | For ImageNet, we provide two sets of reference statistics to enable apples-to-apples comparison: `imagenet-64x64.npz` should be used when evaluating the EDM model (`edm-imagenet-64x64-cond-adm.pkl`), whereas `imagenet-64x64-baseline.npz` should be used when evaluating the baseline model (`baseline-imagenet-64x64-cond-adm.pkl`); the latter was originally trained by Dhariwal and Nichol using slightly different training data. 143 | 144 | You can compute the reference statistics for your own datasets as follows: 145 | 146 | ```.bash 147 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 148 | ``` 149 | 150 | ## Preparing datasets 151 | 152 | Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information. 153 | 154 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive: 155 | 156 | ```.bash 157 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \ 158 | --dest=datasets/cifar10-32x32.zip 159 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz 160 | ``` 161 | 162 | **FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at 64x64 resolution: 163 | 164 | ```.bash 165 | python dataset_tool.py --source=downloads/ffhq/images1024x1024 \ 166 | --dest=datasets/ffhq-64x64.zip --resolution=64x64 167 | python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz 168 | ``` 169 | 170 | **AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert to ZIP archive at 64x64 resolution: 171 | 172 | ```.bash 173 | python dataset_tool.py --source=downloads/afhqv2 \ 174 | --dest=datasets/afhqv2-64x64.zip --resolution=64x64 175 | python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz 176 | ``` 177 | 178 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution: 179 | 180 | ```.bash 181 | python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \ 182 | --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop 183 | python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz 184 | ``` 185 | 186 | ## Training new models 187 | 188 | You can train new models using `train.py`. 
For example: 189 | 190 | ```.bash 191 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 192 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \ 193 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 194 | ``` 195 | 196 | The above example uses the default batch size of 512 images (controlled by `--batch`) that is divided evenly among 8 GPUs (controlled by `--nproc_per_node`) to yield 64 images per GPU. Training large models may run out of GPU memory; the best way to avoid this is to limit the per-GPU batch size, e.g., `--batch-gpu=32`. This employs gradient accumulation to yield the same results as using full per-GPU batches. See [`python train.py --help`](./docs/train-help.txt) for the full list of options. 197 | 198 | The results of each training run are saved to a newly created directory, for example `training-runs/00000-cifar10-cond-ddpmpp-edm-gpus8-batch64-fp32`. The training loop exports network snapshots (`network-snapshot-*.pkl`) and training states (`training-state-*.pt`) at regular intervals (controlled by `--snap` and `--dump`). The network snapshots can be used to generate images with `generate.py`, and the training states can be used to resume the training later on (`--resume`). Other useful information is recorded in `log.txt` and `stats.jsonl`. To monitor training convergence, we recommend looking at the training loss (`"Loss/loss"` in `stats.jsonl`) as well as periodically evaluating FID for `network-snapshot-*.pkl` using `generate.py` and `fid.py`. 199 | 200 | The following table lists the exact training configurations that we used to obtain our pre-trained models: 201 | 202 | | Model | GPUs | Time | Options 203 | | :-- | :-- | :-- | :-- 204 | | cifar10‑32x32‑cond‑vp | 8xV100 | ~2 days | `--cond=1 --arch=ddpmpp` 205 | | cifar10‑32x32‑cond‑ve | 8xV100 | ~2 days | `--cond=1 --arch=ncsnpp` 206 | | cifar10‑32x32‑uncond‑vp | 8xV100 | ~2 days | `--cond=0 --arch=ddpmpp` 207 | | cifar10‑32x32‑uncond‑ve | 8xV100 | ~2 days | `--cond=0 --arch=ncsnpp` 208 | | ffhq‑64x64‑uncond‑vp | 8xV100 | ~4 days | `--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15` 209 | | ffhq‑64x64‑uncond‑ve | 8xV100 | ~4 days | `--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15` 210 | | afhqv2‑64x64‑uncond‑vp | 8xV100 | ~4 days | `--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15` 211 | | afhqv2‑64x64‑uncond‑ve | 8xV100 | ~4 days | `--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15` 212 | | imagenet‑64x64‑cond‑adm | 32xA100 | ~13 days | `--cond=1 --arch=adm --duration=2500 --batch=4096 --lr=1e-4 --ema=50 --dropout=0.10 --augment=0 --fp16=1 --ls=100 --tick=200` 213 | 214 | For ImageNet-64, we ran the training on four NVIDIA DGX A100 nodes, each containing 8 Ampere GPUs with 80 GB of memory. To reduce the GPU memory requirements, we recommend either training the model with more GPUs or limiting the per-GPU batch size with `--batch-gpu`. To set up multi-node training, please consult the [torchrun documentation](https://pytorch.org/docs/stable/elastic/run.html). 215 | 216 | ## License 217 | 218 | Copyright © 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 219 | 220 | All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/). 
221 | 222 | `baseline-cifar10-32x32-uncond-vp.pkl` and `baseline-cifar10-32x32-uncond-ve.pkl` are derived from the [pre-trained models](https://github.com/yang-song/score_sde_pytorch) by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. The models were originally shared under the [Apache 2.0 license](https://github.com/yang-song/score_sde_pytorch/blob/main/LICENSE). 223 | 224 | `baseline-imagenet-64x64-cond-adm.pkl` is derived from the [pre-trained model](https://github.com/openai/guided-diffusion) by Prafulla Dhariwal and Alex Nichol. The model was originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE). 225 | 226 | `imagenet-64x64-baseline.npz` is derived from the [precomputed reference statistics](https://github.com/openai/guided-diffusion/tree/main/evaluations) by Prafulla Dhariwal and Alex Nichol. The statistics were 227 | originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE). 228 | 229 | ## Citation 230 | 231 | ``` 232 | @inproceedings{Karras2022edm, 233 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 234 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 235 | booktitle = {Proc. NeurIPS}, 236 | year = {2022} 237 | } 238 | ``` 239 | 240 | ## Development 241 | 242 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests. 243 | 244 | ## Acknowledgments 245 | 246 | We thank Jaakko Lehtinen, Ming-Yu Liu, Tuomas Kynkäänniemi, Axel Sauer, Arash Vahdat, and Janne Hellsten for discussions and comments, and Tero Kuosmanen, Samuel Klenberg, and Janne Hellsten for maintaining our compute infrastructure. 247 | -------------------------------------------------------------------------------- /edm/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /edm/dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 59 | self.file = None 60 | 61 | if file_name is not None: 62 | self.file = open(file_name, file_mode) 63 | 64 | self.should_flush = should_flush 65 | self.stdout = sys.stdout 66 | self.stderr = sys.stderr 67 | 68 | sys.stdout = self 69 | sys.stderr = self 70 | 71 | def __enter__(self) -> "Logger": 72 | return self 73 | 74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 75 | self.close() 76 | 77 | def write(self, text: Union[str, bytes]) -> None: 78 | """Write text to stdout (and a file) and optionally flush.""" 79 | if isinstance(text, bytes): 80 | text = text.decode() 81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 82 | return 83 | 84 | if self.file is not None: 85 | self.file.write(text) 86 | 87 | self.stdout.write(text) 88 | 89 | if self.should_flush: 90 | self.flush() 91 | 92 | def flush(self) -> None: 93 | """Flush written text to both stdout and a file, if open.""" 94 | if self.file is not None: 95 | self.file.flush() 96 | 97 | self.stdout.flush() 98 | 99 | def close(self) -> None: 100 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 101 | self.flush() 102 | 103 | # if using multiple loggers, prevent closing in wrong order 104 | if sys.stdout is self: 105 | sys.stdout = self.stdout 106 | if sys.stderr is self: 107 | sys.stderr = self.stderr 108 | 109 | if self.file is not None: 110 | self.file.close() 111 | self.file = None 112 | 113 | 114 | # Cache directories 115 | # ------------------------------------------------------------------------------------------ 116 | 117 | _dnnlib_cache_dir = None 118 | 119 | def set_cache_dir(path: str) -> None: 120 | global _dnnlib_cache_dir 121 | _dnnlib_cache_dir = path 122 | 123 | def make_cache_dir_path(*paths: str) -> str: 124 | if _dnnlib_cache_dir is not None: 125 | return os.path.join(_dnnlib_cache_dir, *paths) 126 | if 'DNNLIB_CACHE_DIR' in os.environ: 127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 128 | if 'HOME' in os.environ: 129 | return os.path.join(os.environ['HOME'], 
'.cache', 'dnnlib', *paths) 130 | if 'USERPROFILE' in os.environ: 131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 133 | 134 | # Small util functions 135 | # ------------------------------------------------------------------------------------------ 136 | 137 | 138 | def format_time(seconds: Union[int, float]) -> str: 139 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 140 | s = int(np.rint(seconds)) 141 | 142 | if s < 60: 143 | return "{0}s".format(s) 144 | elif s < 60 * 60: 145 | return "{0}m {1:02}s".format(s // 60, s % 60) 146 | elif s < 24 * 60 * 60: 147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 148 | else: 149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 150 | 151 | 152 | def format_time_brief(seconds: Union[int, float]) -> str: 153 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 154 | s = int(np.rint(seconds)) 155 | 156 | if s < 60: 157 | return "{0}s".format(s) 158 | elif s < 60 * 60: 159 | return "{0}m {1:02}s".format(s // 60, s % 60) 160 | elif s < 24 * 60 * 60: 161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 162 | else: 163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 164 | 165 | 166 | def ask_yes_no(question: str) -> bool: 167 | """Ask the user the question until the user inputs a valid answer.""" 168 | while True: 169 | try: 170 | print("{0} [y/n]".format(question)) 171 | return strtobool(input().lower()) 172 | except ValueError: 173 | pass 174 | 175 | 176 | def tuple_product(t: Tuple) -> Any: 177 | """Calculate the product of the tuple elements.""" 178 | result = 1 179 | 180 | for v in t: 181 | result *= v 182 | 183 | return result 184 | 185 | 186 | _str_to_ctype = { 187 | "uint8": ctypes.c_ubyte, 188 | "uint16": ctypes.c_uint16, 189 | "uint32": ctypes.c_uint32, 190 | "uint64": ctypes.c_uint64, 191 | "int8": ctypes.c_byte, 192 | "int16": ctypes.c_int16, 193 | "int32": ctypes.c_int32, 194 | "int64": ctypes.c_int64, 195 | "float32": ctypes.c_float, 196 | "float64": ctypes.c_double 197 | } 198 | 199 | 200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 202 | type_str = None 203 | 204 | if isinstance(type_obj, str): 205 | type_str = type_obj 206 | elif hasattr(type_obj, "__name__"): 207 | type_str = type_obj.__name__ 208 | elif hasattr(type_obj, "name"): 209 | type_str = type_obj.name 210 | else: 211 | raise RuntimeError("Cannot infer type name from input") 212 | 213 | assert type_str in _str_to_ctype.keys() 214 | 215 | my_dtype = np.dtype(type_str) 216 | my_ctype = _str_to_ctype[type_str] 217 | 218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 219 | 220 | return my_dtype, my_ctype 221 | 222 | 223 | def is_pickleable(obj: Any) -> bool: 224 | try: 225 | with io.BytesIO() as stream: 226 | pickle.dump(obj, stream) 227 | return True 228 | except: 229 | return False 230 | 231 | 232 | # Functionality to import modules/objects by name, and call functions by name 233 | # ------------------------------------------------------------------------------------------ 234 | 235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 236 | """Searches for the 
underlying module behind the name to some python object. 237 | Returns the module and the object name (original name with module part removed).""" 238 | 239 | # allow convenience shorthands, substitute them by full names 240 | obj_name = re.sub("^np.", "numpy.", obj_name) 241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 242 | 243 | # list alternatives for (module_name, local_obj_name) 244 | parts = obj_name.split(".") 245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 246 | 247 | # try each alternative in turn 248 | for module_name, local_obj_name in name_pairs: 249 | try: 250 | module = importlib.import_module(module_name) # may raise ImportError 251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 252 | return module, local_obj_name 253 | except: 254 | pass 255 | 256 | # maybe some of the modules themselves contain errors? 257 | for module_name, _local_obj_name in name_pairs: 258 | try: 259 | importlib.import_module(module_name) # may raise ImportError 260 | except ImportError: 261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 262 | raise 263 | 264 | # maybe the requested attribute is missing? 265 | for module_name, local_obj_name in name_pairs: 266 | try: 267 | module = importlib.import_module(module_name) # may raise ImportError 268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 269 | except ImportError: 270 | pass 271 | 272 | # we are out of luck, but we have no idea why 273 | raise ImportError(obj_name) 274 | 275 | 276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 277 | """Traverses the object name and returns the last (rightmost) python object.""" 278 | if obj_name == '': 279 | return module 280 | obj = module 281 | for part in obj_name.split("."): 282 | obj = getattr(obj, part) 283 | return obj 284 | 285 | 286 | def get_obj_by_name(name: str) -> Any: 287 | """Finds the python object with the given name.""" 288 | module, obj_name = get_module_from_obj_name(name) 289 | return get_obj_from_module(module, obj_name) 290 | 291 | 292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 293 | """Finds the python object with the given name and calls it as a function.""" 294 | assert func_name is not None 295 | func_obj = get_obj_by_name(func_name) 296 | assert callable(func_obj) 297 | return func_obj(*args, **kwargs) 298 | 299 | 300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 301 | """Finds the python class with the given name and constructs it with the given arguments.""" 302 | return call_func_by_name(*args, func_name=class_name, **kwargs) 303 | 304 | 305 | def get_module_dir_by_obj_name(obj_name: str) -> str: 306 | """Get the directory path of the module containing the given object name.""" 307 | module, _ = get_module_from_obj_name(obj_name) 308 | return os.path.dirname(inspect.getfile(module)) 309 | 310 | 311 | def is_top_level_function(obj: Any) -> bool: 312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 314 | 315 | 316 | def get_top_level_function_name(obj: Any) -> str: 317 | """Return the fully-qualified name of a top-level function.""" 318 | assert is_top_level_function(obj) 319 | module = obj.__module__ 320 | if module == '__main__': 321 | module = 
os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 322 | return module + "." + obj.__name__ 323 | 324 | 325 | # File system helpers 326 | # ------------------------------------------------------------------------------------------ 327 | 328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 329 | """List all files recursively in a given directory while ignoring given file and directory names. 330 | Returns list of tuples containing both absolute and relative paths.""" 331 | assert os.path.isdir(dir_path) 332 | base_name = os.path.basename(os.path.normpath(dir_path)) 333 | 334 | if ignores is None: 335 | ignores = [] 336 | 337 | result = [] 338 | 339 | for root, dirs, files in os.walk(dir_path, topdown=True): 340 | for ignore_ in ignores: 341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 342 | 343 | # dirs need to be edited in-place 344 | for d in dirs_to_remove: 345 | dirs.remove(d) 346 | 347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 348 | 349 | absolute_paths = [os.path.join(root, f) for f in files] 350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 351 | 352 | if add_base_to_relative: 353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 354 | 355 | assert len(absolute_paths) == len(relative_paths) 356 | result += zip(absolute_paths, relative_paths) 357 | 358 | return result 359 | 360 | 361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 362 | """Takes in a list of tuples of (src, dst) paths and copies files. 363 | Will create all necessary directories.""" 364 | for file in files: 365 | target_dir_name = os.path.dirname(file[1]) 366 | 367 | # will create all intermediate-level directories 368 | if not os.path.exists(target_dir_name): 369 | os.makedirs(target_dir_name) 370 | 371 | shutil.copyfile(file[0], file[1]) 372 | 373 | 374 | # URL helpers 375 | # ------------------------------------------------------------------------------------------ 376 | 377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 378 | """Determine whether the given object is a valid URL string.""" 379 | if not isinstance(obj, str) or not "://" in obj: 380 | return False 381 | if allow_file_urls and obj.startswith('file://'): 382 | return True 383 | try: 384 | res = requests.compat.urlparse(obj) 385 | if not res.scheme or not res.netloc or not "." in res.netloc: 386 | return False 387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 388 | if not res.scheme or not res.netloc or not "." in res.netloc: 389 | return False 390 | except: 391 | return False 392 | return True 393 | 394 | 395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 396 | """Download the given URL and return a binary-mode file object to access the data.""" 397 | assert num_attempts >= 1 398 | assert not (return_filename and (not cache)) 399 | 400 | # Doesn't look like an URL scheme so interpret it as a local filename. 401 | if not re.match('^[a-z]+://', url): 402 | return url if return_filename else open(url, "rb") 403 | 404 | # Handle file URLs. This code handles unusual file:// patterns that 405 | # arise on Windows: 406 | # 407 | # file:///c:/foo.txt 408 | # 409 | # which would translate to a local '/c:/foo.txt' filename that's 410 | # invalid. Drop the forward slash for such pathnames. 
411 | # 412 | # If you touch this code path, you should test it on both Linux and 413 | # Windows. 414 | # 415 | # Some internet resources suggest using urllib.request.url2pathname() but 416 | # but that converts forward slashes to backslashes and this causes 417 | # its own set of problems. 418 | if url.startswith('file://'): 419 | filename = urllib.parse.urlparse(url).path 420 | if re.match(r'^/[a-zA-Z]:', filename): 421 | filename = filename[1:] 422 | return filename if return_filename else open(filename, "rb") 423 | 424 | assert is_url(url) 425 | 426 | # Lookup from cache. 427 | if cache_dir is None: 428 | cache_dir = make_cache_dir_path('downloads') 429 | 430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 431 | if cache: 432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 433 | if len(cache_files) == 1: 434 | filename = cache_files[0] 435 | return filename if return_filename else open(filename, "rb") 436 | 437 | # Download. 438 | url_name = None 439 | url_data = None 440 | with requests.Session() as session: 441 | if verbose: 442 | print("Downloading %s ..." % url, end="", flush=True) 443 | for attempts_left in reversed(range(num_attempts)): 444 | try: 445 | with session.get(url) as res: 446 | res.raise_for_status() 447 | if len(res.content) == 0: 448 | raise IOError("No data received") 449 | 450 | if len(res.content) < 8192: 451 | content_str = res.content.decode("utf-8") 452 | if "download_warning" in res.headers.get("Set-Cookie", ""): 453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 454 | if len(links) == 1: 455 | url = requests.compat.urljoin(url, links[0]) 456 | raise IOError("Google Drive virus checker nag") 457 | if "Google Drive - Quota exceeded" in content_str: 458 | raise IOError("Google Drive download quota exceeded -- please try again later") 459 | 460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 461 | url_name = match[1] if match else url 462 | url_data = res.content 463 | if verbose: 464 | print(" done") 465 | break 466 | except KeyboardInterrupt: 467 | raise 468 | except: 469 | if not attempts_left: 470 | if verbose: 471 | print(" failed") 472 | raise 473 | if verbose: 474 | print(".", end="", flush=True) 475 | 476 | # Save to cache. 477 | if cache: 478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 479 | safe_name = safe_name[:min(len(safe_name), 128)] 480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 482 | os.makedirs(cache_dir, exist_ok=True) 483 | with open(temp_file, "wb") as f: 484 | f.write(url_data) 485 | os.replace(temp_file, cache_file) # atomic 486 | if return_filename: 487 | return cache_file 488 | 489 | # Return data as file object. 
490 |     assert not return_filename
491 |     return io.BytesIO(url_data)
492 | 
--------------------------------------------------------------------------------
/edm/docs/afhqv2-64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/afhqv2-64x64.png
--------------------------------------------------------------------------------
/edm/docs/cifar10-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/cifar10-32x32.png
--------------------------------------------------------------------------------
/edm/docs/dataset-tool-help.txt:
--------------------------------------------------------------------------------
1 | Usage: dataset_tool.py [OPTIONS]
2 | 
3 |   Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
4 |   PyTorch.
5 | 
6 |   The input dataset format is guessed from the --source argument:
7 | 
8 |   --source *_lmdb/                    Load LSUN dataset
9 |   --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
10 |   --source train-images-idx3-ubyte.gz Load MNIST dataset
11 |   --source path/                      Recursively load all images from path/
12 |   --source dataset.zip                Recursively load all images from dataset.zip
13 | 
14 |   Specifying the output format and path:
15 | 
16 |   --dest /path/to/dir                 Save output files under /path/to/dir
17 |   --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip
18 | 
19 |   The output dataset format can be either an image folder or an uncompressed
20 |   zip archive. Zip archives make it easier to move datasets around file
21 |   servers and clusters, and may offer better training performance on network
22 |   file systems.
23 | 
24 |   Images within the dataset archive will be stored as uncompressed PNG.
25 |   Uncompressed PNGs can be efficiently decoded in the training loop.
26 | 
27 |   Class labels are stored in a file called 'dataset.json' that is stored at
28 |   the dataset root folder. This file has the following structure:
29 | 
30 |   {
31 |       "labels": [
32 |           ["00000/img00000000.png",6],
33 |           ["00000/img00000001.png",9],
34 |           ... repeated for every image in the dataset
35 |           ["00049/img00049999.png",1]
36 |       ]
37 |   }
38 | 
39 |   If the 'dataset.json' file cannot be found, class labels are determined from
40 |   top-level directory names.
41 | 
42 |   Image scale/crop and resolution requirements:
43 | 
44 |   Output images must be square-shaped and they must all have the same power-
45 |   of-two dimensions.
46 | 
47 |   To scale arbitrary input image size to a specific width and height, use the
48 |   --resolution option. Output resolution will be either the original input
49 |   resolution (if resolution was not specified) or the one specified with
50 |   --resolution option.
51 | 
52 |   Use the --transform=center-crop or --transform=center-crop-wide options to
53 |   apply a center crop transform on the input image. These options should be
54 |   used with the --resolution option.
For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Input directory or archive name [required] 61 | --dest PATH Output directory or archive name [required] 62 | --max-images INT Maximum number of images to output 63 | --transform MODE Input crop/resize mode 64 | --resolution WxH Output resolution (e.g., 512x512) 65 | --help Show this message and exit. 66 | -------------------------------------------------------------------------------- /edm/docs/ffhq-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/ffhq-64x64.png -------------------------------------------------------------------------------- /edm/docs/fid-help.txt: -------------------------------------------------------------------------------- 1 | Usage: fid.py [OPTIONS] COMMAND [ARGS]... 2 | 3 | Calculate Frechet Inception Distance (FID). 4 | 5 | Examples: 6 | 7 | # Generate 50000 images and save them as fid-tmp/*/*.png 8 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \ 9 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 10 | 11 | # Calculate FID 12 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \ 13 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 14 | 15 | # Compute dataset reference statistics 16 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 17 | 18 | Options: 19 | --help Show this message and exit. 20 | 21 | Commands: 22 | calc Calculate FID for a given set of images. 23 | ref Calculate dataset reference statistics needed by 'calc'. 24 | 25 | 26 | Usage: fid.py calc [OPTIONS] 27 | 28 | Calculate FID for a given set of images. 29 | 30 | Options: 31 | --images PATH|ZIP Path to the images [required] 32 | --ref NPZ|URL Dataset reference statistics [required] 33 | --num INT Number of images to use [default: 50000; x>=2] 34 | --seed INT Random seed for selecting the images [default: 0] 35 | --batch INT Maximum batch size [default: 64; x>=1] 36 | --help Show this message and exit. 37 | 38 | 39 | Usage: fid.py ref [OPTIONS] 40 | 41 | Calculate dataset reference statistics needed by 'calc'. 42 | 43 | Options: 44 | --data PATH|ZIP Path to the dataset [required] 45 | --dest NPZ Destination .npz file [required] 46 | --batch INT Maximum batch size [default: 64; x>=1] 47 | --help Show this message and exit. 48 | -------------------------------------------------------------------------------- /edm/docs/generate-help.txt: -------------------------------------------------------------------------------- 1 | Usage: generate.py [OPTIONS] 2 | 3 | Generate random images using the techniques described in the paper 4 | "Elucidating the Design Space of Diffusion-Based Generative Models". 
5 | 6 | Examples: 7 | 8 | # Generate 64 images and save them as out/*.png 9 | python generate.py --outdir=out --seeds=0-63 --batch=64 \ 10 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 11 | 12 | # Generate 1024 images using 2 GPUs 13 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \ 14 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 15 | 16 | Options: 17 | --network PATH|URL Network pickle filename [required] 18 | --outdir DIR Where to save the output images [required] 19 | --seeds LIST Random seeds (e.g. 1,2,5-10) [default: 0-63] 20 | --subdirs Create subdirectory for every 1000 seeds 21 | --class INT Class label [default: random] [x>=0] 22 | --batch INT Maximum batch size [default: 64; x>=1] 23 | --steps INT Number of sampling steps [default: 18; x>=1] 24 | --sigma_min FLOAT Lowest noise level [default: varies] [x>0] 25 | --sigma_max FLOAT Highest noise level [default: varies] [x>0] 26 | --rho FLOAT Time step exponent [default: 7; x>0] 27 | --S_churn FLOAT Stochasticity strength [default: 0; x>=0] 28 | --S_min FLOAT Stoch. min noise level [default: 0; x>=0] 29 | --S_max FLOAT Stoch. max noise level [default: inf; x>=0] 30 | --S_noise FLOAT Stoch. noise inflation [default: 1] 31 | --solver euler|heun Ablate ODE solver 32 | --disc vp|ve|iddpm|edm Ablate time step discretization {t_i} 33 | --schedule vp|ve|linear Ablate noise schedule sigma(t) 34 | --scaling vp|none Ablate signal scaling s(t) 35 | --help Show this message and exit. 36 | -------------------------------------------------------------------------------- /edm/docs/imagenet-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/imagenet-64x64.png -------------------------------------------------------------------------------- /edm/docs/teaser-1280x640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-1280x640.jpg -------------------------------------------------------------------------------- /edm/docs/teaser-1920x640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-1920x640.jpg -------------------------------------------------------------------------------- /edm/docs/teaser-640x480.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-640x480.jpg -------------------------------------------------------------------------------- /edm/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train diffusion-based generative model using the techniques described in the 4 | paper "Elucidating the Design Space of Diffusion-Based Generative Models". 
5 | 6 | Examples: 7 | 8 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 9 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \ 10 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 11 | 12 | Options: 13 | --outdir DIR Where to save the results [required] 14 | --data ZIP|DIR Path to the dataset [required] 15 | --cond BOOL Train class-conditional model [default: False] 16 | --arch ddpmpp|ncsnpp|adm Network architecture [default: ddpmpp] 17 | --precond vp|ve|edm Preconditioning & loss function [default: edm] 18 | --duration MIMG Training duration [default: 200; x>0] 19 | --batch INT Total batch size [default: 512; x>=1] 20 | --batch-gpu INT Limit batch size per GPU [x>=1] 21 | --cbase INT Channel multiplier [default: varies] 22 | --cres LIST Channels per resolution [default: varies] 23 | --lr FLOAT Learning rate [default: 0.001; x>0] 24 | --ema MIMG EMA half-life [default: 0.5; x>=0] 25 | --dropout FLOAT Dropout probability [default: 0.13; 0<=x<=1] 26 | --augment FLOAT Augment probability [default: 0.12; 0<=x<=1] 27 | --xflip BOOL Enable dataset x-flips [default: False] 28 | --fp16 BOOL Enable mixed-precision training [default: False] 29 | --ls FLOAT Loss scaling [default: 1; x>0] 30 | --bench BOOL Enable cuDNN benchmarking [default: True] 31 | --cache BOOL Cache dataset in CPU memory [default: True] 32 | --workers INT DataLoader worker processes [default: 1; x>=1] 33 | --desc STR String to include in result dir name 34 | --nosubdir Do not create a subdirectory for results 35 | --tick KIMG How often to print progress [default: 50; x>=1] 36 | --snap TICKS How often to save snapshots [default: 50; x>=1] 37 | --dump TICKS How often to dump state [default: 500; x>=1] 38 | --seed INT Random seed [default: random] 39 | --transfer PKL|URL Transfer learning from network pickle 40 | --resume PT Resume from previous training state 41 | -n, --dry-run Print training options and exit 42 | --help Show this message and exit. 43 | -------------------------------------------------------------------------------- /edm/environment.yml: -------------------------------------------------------------------------------- 1 | name: edm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python>=3.8, < 3.10 # package build failures on 3.10 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow>=8.3.1 11 | - scipy>=1.7.1 12 | - pytorch=1.12.1 13 | - psutil 14 | - requests 15 | - tqdm 16 | - imageio 17 | - pip: 18 | - imageio-ffmpeg>=0.4.3 19 | - pyspng 20 | - wget 21 | -------------------------------------------------------------------------------- /edm/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Minimal standalone example to reproduce the main results from the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import tqdm 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import PIL.Image 16 | import dnnlib 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def generate_image_grid( 21 | network_pkl, dest_path, 22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'), 23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 25 | ): 26 | batch_size = gridw * gridh 27 | torch.manual_seed(seed) 28 | 29 | # Load network. 30 | print(f'Loading network from "{network_pkl}"...') 31 | with dnnlib.util.open_url(network_pkl) as f: 32 | net = pickle.load(f)['ema'].to(device) 33 | 34 | # Pick latents and labels. 35 | print(f'Generating {batch_size} images...') 36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 37 | class_labels = None 38 | if net.label_dim: 39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)] 40 | 41 | # Adjust noise levels based on what's supported by the network. 42 | sigma_min = max(sigma_min, net.sigma_min) 43 | sigma_max = min(sigma_max, net.sigma_max) 44 | 45 | # Time step discretization. 46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) 47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 49 | 50 | # Main sampling loop. 51 | x_next = latents.to(torch.float64) * t_steps[0] 52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1 53 | x_cur = x_next 54 | 55 | # Increase noise temporarily. 56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 57 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 59 | 60 | # Euler step. 61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 62 | d_cur = (x_hat - denoised) / t_hat 63 | x_next = x_hat + (t_next - t_hat) * d_cur 64 | 65 | # Apply 2nd order correction. 66 | if i < num_steps - 1: 67 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 68 | d_prime = (x_next - denoised) / t_next 69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 70 | 71 | # Save image grid. 
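# (The reshape below turns the [batch, C, H, W] samples into a single
# [gridh*H, gridw*W, C] image: the batch axis is split into a gridh x gridw
# grid, and permute interleaves grid and spatial axes before flattening.)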
72 | print(f'Saving image grid to "{dest_path}"...') 73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8) 74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2) 75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels) 76 | image = image.cpu().numpy() 77 | PIL.Image.fromarray(image, 'RGB').save(dest_path) 78 | print('Done.') 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def main(): 83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained' 84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35 85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79 86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79 87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | if __name__ == "__main__": 92 | main() 93 | 94 | #---------------------------------------------------------------------------- 95 | -------------------------------------------------------------------------------- /edm/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import click 12 | import tqdm 13 | import pickle 14 | import numpy as np 15 | import scipy.linalg 16 | import torch 17 | import dnnlib 18 | from torch_utils import distributed as dist 19 | from training import dataset 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def calculate_inception_stats( 24 | image_path, num_expected=None, seed=0, max_batch_size=64, 25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'), 26 | ): 27 | # Rank 0 goes first. 28 | if dist.get_rank() != 0: 29 | torch.distributed.barrier() 30 | 31 | # Load Inception-v3 model. 32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_kwargs = dict(return_features=True) 36 | feature_dim = 2048 37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 38 | detector_net = pickle.load(f).to(device) 39 | 40 | # List images. 
41 | dist.print0(f'Loading images from "{image_path}"...') 42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 43 | if num_expected is not None and len(dataset_obj) < num_expected: 44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 45 | if len(dataset_obj) < 2: 46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 47 | 48 | # Other ranks follow. 49 | if dist.get_rank() == 0: 50 | torch.distributed.barrier() 51 | 52 | # Divide images into batches. 53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 57 | 58 | # Accumulate statistics. 59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 63 | torch.distributed.barrier() 64 | if images.shape[0] == 0: 65 | continue 66 | if images.shape[1] == 1: 67 | images = images.repeat([1, 3, 1, 1]) 68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 69 | mu += features.sum(0) 70 | sigma += features.T @ features 71 | 72 | # Calculate grand totals. 73 | torch.distributed.all_reduce(mu) 74 | torch.distributed.all_reduce(sigma) 75 | mu /= len(dataset_obj) 76 | sigma -= mu.ger(mu) * len(dataset_obj) 77 | sigma /= len(dataset_obj) - 1 78 | return mu.cpu().numpy(), sigma.cpu().numpy() 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 83 | m = np.square(mu - mu_ref).sum() 84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 85 | fid = m + np.trace(sigma + sigma_ref - s * 2) 86 | return float(np.real(fid)) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.group() 91 | def main(): 92 | """Calculate Frechet Inception Distance (FID). 
93 | 94 | Examples: 95 | 96 | \b 97 | # Generate 50000 images and save them as fid-tmp/*/*.png 98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 100 | 101 | \b 102 | # Calculate FID 103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 105 | 106 | \b 107 | # Compute dataset reference statistics 108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 109 | """ 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | @main.command() 114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 119 | 120 | def calc(image_path, ref_path, num_expected, seed, batch): 121 | """Calculate FID for a given set of images.""" 122 | torch.multiprocessing.set_start_method('spawn') 123 | dist.init() 124 | 125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 126 | ref = None 127 | if dist.get_rank() == 0: 128 | with dnnlib.util.open_url(ref_path) as f: 129 | ref = dict(np.load(f)) 130 | 131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 132 | dist.print0('Calculating FID...') 133 | if dist.get_rank() == 0: 134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 135 | print(f'{fid:g}') 136 | torch.distributed.barrier() 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @main.command() 141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 144 | 145 | def ref(dataset_path, dest_path, batch): 146 | """Calculate dataset reference statistics needed by 'calc'.""" 147 | torch.multiprocessing.set_start_method('spawn') 148 | dist.init() 149 | 150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 152 | if dist.get_rank() == 0: 153 | if os.path.dirname(dest_path): 154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 155 | np.savez(dest_path, mu=mu, sigma=sigma) 156 | 157 | torch.distributed.barrier() 158 | dist.print0('Done.') 159 | 160 | #---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | main() 164 | 165 | #---------------------------------------------------------------------------- 166 | 
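#----------------------------------------------------------------------------
# Editorial sketch, not part of the original script: a minimal sanity check
# of calculate_fid_from_inception_stats() on toy Gaussian statistics. FID is
# the squared 2-Wasserstein distance between the two fitted Gaussians, so
# identical statistics give 0, and shifting an 8-dimensional mean by 1 in
# every coordinate gives exactly 8. The mu/sigma below are placeholders, not
# real Inception statistics.

def _fid_sanity_check():
    rng = np.random.default_rng(0)
    mu = rng.standard_normal(8)
    a = rng.standard_normal((8, 8))
    sigma = a @ a.T + np.eye(8)  # symmetric positive-definite covariance
    # Identical statistics: sqrtm(sigma @ sigma) = sigma, so FID vanishes.
    assert abs(calculate_fid_from_inception_stats(mu, sigma, mu, sigma)) < 1e-6
    # Mean shifted by 1 in each of 8 coordinates: FID = ||delta_mu||^2 = 8.
    assert abs(calculate_fid_from_inception_stats(mu + 1.0, sigma, mu, sigma) - 8.0) < 1e-6

#----------------------------------------------------------------------------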
-------------------------------------------------------------------------------- /edm/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Generate random images using the techniques described in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import click 14 | import tqdm 15 | import pickle 16 | import numpy as np 17 | import torch 18 | import PIL.Image 19 | import dnnlib 20 | from torch_utils import distributed as dist 21 | 22 | #---------------------------------------------------------------------------- 23 | # Proposed EDM sampler (Algorithm 2). 24 | 25 | def edm_sampler( 26 | net, latents, dataloader, class_labels=None, randn_like=torch.randn_like, 27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, optimal_denoiser=False 29 | ): 30 | # Adjust noise levels based on what's supported by the network. 31 | if optimal_denoiser: 32 | optimal_x0 = optimal_solver(dataloader) 33 | sigma_min = max(sigma_min, net.sigma_min) 34 | sigma_max = min(sigma_max, net.sigma_max) 35 | 36 | # Time step discretization. 37 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 38 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 39 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 40 | 41 | # Main sampling loop. 42 | x_next = latents.to(torch.float64) * t_steps[0] 43 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 44 | x_cur = x_next 45 | 46 | # Increase noise temporarily. 47 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 48 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 49 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 50 | 51 | # Euler step. 52 | if not optimal_denoiser: 53 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 54 | else: 55 | denoised = optimal_x0(x_hat, s = torch.tensor(1, device = t_cur.device, dtype = t_cur.dtype), sigma = t_hat) 56 | d_cur = (x_hat - denoised) / t_hat 57 | x_next = x_hat + (t_next - t_hat) * d_cur 58 | 59 | # Apply 2nd order correction. 
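# (Heun / trapezoidal correction, Algorithm 2 in the EDM paper: the slope is
# re-evaluated at the tentative endpoint and averaged with d_cur. It is
# skipped on the final step, where t_next = 0, and in optimal-denoiser mode,
# presumably because each extra denoiser evaluation would cost additional
# full passes over the training set.)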
60 | if i < num_steps - 1 and not optimal_denoiser:
61 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
62 | d_prime = (x_next - denoised) / t_next
63 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
64 |
65 | return x_next
66 |
67 |
68 | def normal_distribution_batch(x, y_batch, s, std, bias):
69 | bs_y = y_batch.shape[0]
70 | bs_x = x.shape[0]
71 | xb = x.unsqueeze(1)
72 | prob = torch.exp(-(((xb - s * y_batch)**2).view(bs_x, bs_y, -1).sum(dim=-1).to(torch.float64)/std**2)/2 - bias.unsqueeze(1))
73 | # (bias is the per-sample max log-weight from get_exp_bias_batch(); subtracting it keeps exp() in range, and the shift cancels in the final ratio.)
74 | prob_y = prob.clone().view(bs_x, bs_y, 1, 1, 1) * y_batch # (bs_x, bs_y, 3, 32, 32)
75 | return prob.sum(dim=1, keepdim=True).squeeze(), prob_y.sum(dim=1, keepdim=True).squeeze() # (bs_x, ), (bs_x, 3, 32, 32)
76 |
77 | def get_exp_bias_batch(x, y_batch, s, std):
78 | # exp() would underflow for large squared distances, so the maximum log-probability is used as a bias (log-sum-exp trick).
79 | bs_y = y_batch.shape[0]
80 | bs_x = x.shape[0]
81 | xb = x.unsqueeze(1)
82 | return (-(((xb - s * y_batch)**2).view(bs_x, bs_y, -1).sum(dim=-1).to(torch.float64)/std**2)/2).max(dim=1)
83 | # Closed-form optimal denoiser: the posterior mean E[y | x] under the empirical training distribution.
84 | def optimal_solver(dataloader):
85 | def optimal_sol(batch, s, sigma):
86 |
87 | std = s * sigma
88 | x = batch
89 | prob_sum = torch.zeros(x.shape[0], dtype=torch.float64).cuda() # float64 to match prob; an in-place float32 += float64 would raise a cast error
90 | prob_y_sum = torch.zeros_like(x).to(torch.float64).cuda()
91 | exp_bias = -(torch.inf) * torch.ones(x.shape[0],).cuda()
92 | for y_batch, _ in dataloader: # first pass: find the per-sample maximum log-weight
93 | y_batch = y_batch.cuda().to(torch.float32) / 127.5 - 1
94 | curr_exp_bias = get_exp_bias_batch(x, y_batch, s, std)[0]
95 | exp_bias = torch.where(curr_exp_bias>exp_bias, curr_exp_bias, exp_bias)
96 | for y_batch, _ in dataloader: # second pass: accumulate bias-shifted weights and weighted images
97 | y_batch = y_batch.cuda().to(torch.float32) / 127.5 - 1
98 | prob, prob_y = normal_distribution_batch(x, y_batch, s, std, exp_bias)
99 | prob_sum += prob
100 | prob_y_sum += prob_y
101 |
102 | optimal_solution = prob_y_sum/prob_sum.view(-1, 1, 1, 1)
103 | return optimal_solution
104 | return optimal_sol
105 |
106 | #----------------------------------------------------------------------------
107 | # Wrapper for torch.Generator that allows specifying a different random seed
108 | # for each sample in a minibatch.
109 |
110 | class StackedRandomGenerator:
111 | def __init__(self, device, seeds):
112 | super().__init__()
113 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
114 |
115 | def randn(self, size, **kwargs):
116 | assert size[0] == len(self.generators)
117 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
118 |
119 | def randn_like(self, input):
120 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
121 |
122 | def randint(self, *args, size, **kwargs):
123 | assert size[0] == len(self.generators)
124 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
125 |
126 | #----------------------------------------------------------------------------
127 | # Parse a comma separated list of numbers or ranges and return a list of ints.
128 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 129 | 130 | def parse_int_list(s): 131 | if isinstance(s, list): return s 132 | ranges = [] 133 | range_re = re.compile(r'^(\d+)-(\d+)$') 134 | for p in s.split(','): 135 | m = range_re.match(p) 136 | if m: 137 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 138 | else: 139 | ranges.append(int(p)) 140 | return ranges 141 | 142 | #---------------------------------------------------------------------------- 143 | 144 | @click.command() 145 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) 146 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) 147 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) 148 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) 149 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) 150 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 151 | 152 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) 153 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 154 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) 155 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) 156 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 157 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) 158 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) 159 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) 160 | 161 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun'])) 162 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm'])) 163 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear'])) 164 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none'])) 165 | 166 | @click.option('--optimal_denoiser', help='Generate images from optimal denoiser', is_flag=True) 167 | @click.option('--dataset', 'dataset', help='Dataset used by the optimal denoiser', type=str) 168 | 169 | 170 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, optimal_denoiser, dataset, device=torch.device('cuda'), **sampler_kwargs): 171 | """Generate random images using the techniques described in the paper 172 | "Elucidating the Design Space of Diffusion-Based Generative Models". 
173 | 174 | Examples: 175 | 176 | \b 177 | # Generate 64 images and save them as out/*.png 178 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\ 179 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 180 | 181 | \b 182 | # Generate 1024 images using 2 GPUs 183 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\ 184 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 185 | """ 186 | dist.init() 187 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 188 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches) 189 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 190 | 191 | # Rank 0 goes first. 192 | if dist.get_rank() != 0: 193 | torch.distributed.barrier() 194 | 195 | dataset_loader = None 196 | if optimal_denoiser: 197 | assert dataset is not None 198 | dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=dataset) 199 | data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1) 200 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) 201 | dataset_loader = torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=max_batch_size, **data_loader_kwargs) 202 | 203 | # Load network. 204 | assert network_pkl is not None 205 | dist.print0(f'Loading network from "{network_pkl}"...') 206 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: 207 | net = pickle.load(f)['ema'].to(device) 208 | 209 | # Other ranks follow. 210 | if dist.get_rank() == 0: 211 | torch.distributed.barrier() 212 | 213 | # Loop over batches. 214 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') 215 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): 216 | torch.distributed.barrier() 217 | batch_size = len(batch_seeds) 218 | if batch_size == 0: 219 | continue 220 | 221 | # Pick latents and labels. 222 | rnd = StackedRandomGenerator(device, batch_seeds) 223 | 224 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 225 | 226 | class_labels = None 227 | if net.label_dim: 228 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] 229 | if class_idx is not None: 230 | class_labels[:, :] = 0 231 | class_labels[:, class_idx] = 1 232 | 233 | # Generate images. 234 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} 235 | images = edm_sampler(net, latents, dataset_loader, class_labels, randn_like=rnd.randn_like, optimal_denoiser=optimal_denoiser, **sampler_kwargs) 236 | 237 | # Save images. 238 | images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy() 239 | for seed, image_np in zip(batch_seeds, images_np): 240 | image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir 241 | os.makedirs(image_dir, exist_ok=True) 242 | image_path = os.path.join(image_dir, f'{seed:06d}.png') 243 | if image_np.shape[2] == 1: 244 | PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path) 245 | else: 246 | PIL.Image.fromarray(image_np, 'RGB').save(image_path) 247 | 248 | # Done. 
249 | torch.distributed.barrier() 250 | dist.print0('Done.') 251 | 252 | #---------------------------------------------------------------------------- 253 | 254 | if __name__ == "__main__": 255 | main() 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /edm/jacobian.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Evaluate the rank of the jacobian of the denoising autoencoder""" 9 | 10 | import re 11 | import click 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | from torch_utils import distributed as dist 17 | from torch import nn 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def edm_sampler( 22 | net, latents, class_labels=None, randn_like=torch.randn_like, 23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 25 | ): 26 | # Adjust noise levels based on what's supported by the network. 27 | sigma_min = max(sigma_min, net.sigma_min) 28 | sigma_max = min(sigma_max, net.sigma_max) 29 | 30 | # Time step discretization. 31 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) 32 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 33 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 34 | 35 | # Main sampling loop. 36 | x_next = latents.to(torch.float64) * t_steps[0] 37 | trajectory = [x_next] 38 | ts = [t_steps[0]] 39 | 40 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 41 | 42 | x_cur = x_next 43 | 44 | # Increase noise temporarily. 45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 46 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) 48 | 49 | # Euler step. 50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 51 | d_cur = (x_hat - denoised) / t_hat 52 | x_next = x_hat + (t_next - t_hat) * d_cur 53 | 54 | # Apply 2nd order correction. 
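# (Same Heun correction as in generate.py; in addition, every intermediate
# state x_i and noise level t_i is recorded so the denoiser's Jacobian can
# later be probed at each noise level along the sampling trajectory.)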
55 | if i < num_steps - 1:
56 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
57 | d_prime = (x_next - denoised) / t_next
58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59 | ts.append(t_next)
60 | trajectory.append(x_next)
61 |
62 | return trajectory[:-1], ts[:-1]
63 |
64 | class StackedRandomGenerator:
65 | def __init__(self, device, seeds):
66 | super().__init__()
67 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
68 |
69 | def randn(self, size, **kwargs):
70 | assert size[0] == len(self.generators)
71 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
72 |
73 | def randn_like(self, input):
74 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
75 |
76 | def randint(self, *args, size, **kwargs):
77 | assert size[0] == len(self.generators)
78 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
79 |
80 | #----------------------------------------------------------------------------
81 | # Parse a comma separated list of numbers or ranges and return a list of ints.
82 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
83 |
84 | def parse_int_list(s):
85 | if isinstance(s, list): return s
86 | ranges = []
87 | range_re = re.compile(r'^(\d+)-(\d+)$')
88 | for p in s.split(','):
89 | m = range_re.match(p)
90 | if m:
91 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
92 | else:
93 | ranges.append(int(p))
94 | return ranges
95 |
96 | #----------------------------------------------------------------------------
97 |
98 | @click.command()
99 |
100 | @click.option('--network_pkl', help='Network pickle of the model whose denoiser Jacobian rank is evaluated', type=str, required=True)
101 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
102 |
103 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
104 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=0.002)
105 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=80.0)
106 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
107 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
108 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
109 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
110 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
111 |
112 |
113 |
114 | def main(network_pkl, class_idx, device=torch.device('cuda'), **sampler_kwargs):
115 | """Evaluate the numerical rank of the denoiser's Jacobian along the sampling
116 | trajectory of the paper "Elucidating the Design Space of Diffusion-Based Generative Models".
117 |
118 | Examples:
119 |
120 | python edm/jacobian.py --network_pkl https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-ve.pkl --class 5
121 |
122 | """
123 | dist.init()
124 |
125 | # Rank 0 goes first.
126 | if dist.get_rank() != 0:
127 | torch.distributed.barrier()
128 |
129 | dist.print0(f'Loading network from "{network_pkl}"...')
130 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
131 | net = pickle.load(f)['ema'].to(device)
132 |
133 | # Other ranks follow.
134 | if dist.get_rank() == 0:
135 | torch.distributed.barrier()
136 |
137 | # Set up the time step discretization and the class label.
138 | sigma_max = sampler_kwargs['sigma_max']
139 | sigma_min = sampler_kwargs['sigma_min']
140 | rho = sampler_kwargs['rho']
141 | num_steps = sampler_kwargs['num_steps']
142 |
143 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
144 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
145 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
146 | class_labels = None
147 | if net.label_dim:
148 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[1], device=device)]
149 | if class_idx is not None:
150 | class_labels[:, :] = 0
151 | class_labels[:, class_idx] = 1
152 |
153 | total_dim = net.img_channels * net.img_resolution * net.img_resolution
154 | latents = torch.randn(1, net.img_channels, net.img_resolution, net.img_resolution).to(device)
155 |
156 | trajectory, ts = edm_sampler(net, latents, num_steps=num_steps, class_labels=class_labels, sigma_max=sigma_max)
157 | # Probe the denoiser's Jacobian at every recorded noise level.
158 | for x, t in zip(trajectory, ts):
159 | func = lambda x_in: net(x_in, t)
160 |
161 | jacs = torch.autograd.functional.jacobian(func, x).squeeze().permute(1, 2, 0, 4, 5, 3).reshape(total_dim, total_dim).detach().cpu()
162 | # Numerical rank: the smallest k whose top-k singular values capture 99% of the Frobenius norm.
163 | S = torch.linalg.svdvals(jacs)
164 | acc_sum = torch.cumsum((S ** 2).to(torch.float32), dim=0).sqrt()
165 | total_sum = acc_sum[-1]
166 | energy_threshold = total_sum * 0.99
167 | rank = torch.where(acc_sum > energy_threshold)[0].min() + 1
168 | print(f"SNR 1/sigma_t: {1/t.item():.3f}, rank: {rank}")
169 |
170 | # Done.
171 | torch.distributed.barrier()
172 | dist.print0('Done.')
173 |
174 | #----------------------------------------------------------------------------
175 |
176 | if __name__ == "__main__":
177 | main()
178 |
--------------------------------------------------------------------------------
/edm/sscd.py:
--------------------------------------------------------------------------------
1 | """Script for Self-Supervised Descriptor for Image Copy Detection (SSCD)."""
2 |
3 | import os
4 | import click
5 | import tqdm
6 | import pickle
7 | import numpy as np
8 | import scipy.linalg
9 | import torch
10 | import dnnlib
11 | from torch_utils import distributed as dist
12 | from training import dataset
13 | from torchvision import transforms
14 | import wget
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def TransformSamples(samples, transform):
19 | sscd_samples = []
20 | for sample in samples:
21 | sscd_samples.append(transform(sample)[None, :])
22 | return torch.cat(sscd_samples, dim=0)
23 |
24 |
25 | @click.group()
26 | def main():
27 | """Calculate Self-Supervised Descriptor for Image Copy Detection (SSCD).
28 | The original implementation is at https://github.com/facebookresearch/sscd-copy-detection
29 |
30 | Examples:
31 |
32 | \b
33 | # Calculate SSCD features
34 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim128-n16384 --features ./evaluation/sscd-dim128-n16384.npz
35 |
36 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim64-n16384 --features ./evaluation/sscd-dim64-n16384.npz
37 |
38 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/generalization/ --features ./evaluation/sscd-generalization.npz
39 |
40 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images datasets/synthetic-cifar10-32x32-n16384.zip --features ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n16384.npz
41 |
42 | \b
43 | # Compute reproducibility score
44 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-dim64-n16384.npz
45 |
46 | # Compute generalization score
47 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-generalization.npz
48 |
49 | # Compute memorization score
50 | python edm/sscd.py mscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n16384.npz
51 | """
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @main.command()
56 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|NPZ', type=str, required=True)
57 | @click.option('--features', 'features_path', help='Path to save features', metavar='NPZ', type=str, required=True)
58 | def feature(image_path, features_path):
59 | """Calculate SSCD features for a given set of images."""
60 | if not os.path.exists("./pretrainedmodels"):
61 | os.makedirs("./pretrainedmodels")
62 | if not os.path.exists("./pretrainedmodels/sscd_disc_large.torchscript.pt"):
63 | wget.download("https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_large.torchscript.pt", "./pretrainedmodels/sscd_disc_large.torchscript.pt")
64 | sscd_model = torch.jit.load("./pretrainedmodels/sscd_disc_large.torchscript.pt")
65 | sscd_model = sscd_model.to(device="cuda:0")
66 | sscd_model.eval()
67 | sscd_transform = transforms.Compose([
68 | transforms.ToPILImage(),
69 | transforms.Resize((320, 320)),
70 | transforms.ToTensor(),
71 | # (ToTensor() already converts the PIL image to a float tensor in [0, 1],
72 | # so no extra cast is needed before Normalize.)
73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],)
74 | ])
75 | dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=image_path)
76 | data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1)
77 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)
78 | dataloader = torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=64, **data_loader_kwargs)
79 | sscd_features = []
80 | for x_batch, _ in tqdm.tqdm(dataloader):
81 | x_sscd = TransformSamples(x_batch, sscd_transform).to(device="cuda:0")
82 | sscd_feature = sscd_model(x_sscd).detach().cpu()
83 | sscd_features.append(sscd_feature)
84 | sscd_features = torch.cat(sscd_features, dim=0)
85 | np.savez(features_path, features=sscd_features.numpy())
86 | #----------------------------------------------------------------------------
87 |
88 | @main.command()
89 | @click.option('--source', 'source_path', help='Path to source sscd feature', metavar='NPZ', type=str, required=True)
90 | @click.option('--target', 'target_path', help='Path to target sscd feature', metavar='NPZ', type=str, required=True)
91 | @click.option('--t', 'threshold', help='Threshold for sscd similarity', type=float, default=0.6)
92 |
93 | def rpscore(source_path, target_path, threshold):
94 | """Calculate the reproducibility score between source images and target images."""
95 | source_features = np.load(source_path)["features"]
96 | target_features = np.load(target_path)["features"]
97 | similarity = (source_features * target_features).sum(axis=1)
98 | rpscore = (similarity > threshold).mean()
99 | print('RP score = ', rpscore)
100 | return rpscore
101 |
102 | @main.command()
103 | @click.option('--source', 'source_path', help='Path to source sscd feature', metavar='NPZ', type=str, required=True)
104 | @click.option('--target', 'target_path', help='Path to target sscd feature', metavar='NPZ', type=str, required=True)
105 | @click.option('--t', 'threshold', help='Threshold for sscd similarity', type=float, default=0.6)
106 |
107 | def mscore(source_path, target_path, threshold):
108 | """Calculate the memorization score between generated source images and training target images."""
109 | bs = 128
110 | source_features = np.load(source_path)["features"][:, None, :]
111 | target_features = np.load(target_path)["features"][None, :, :]
112 | matched = 0
113 | total_sample = source_features.shape[0]
114 | for idx in tqdm.tqdm(range(total_sample//bs + 1)):
115 | # For each source image, take the maximum similarity over all target images.
116 | similarity = (source_features[idx*bs: (idx + 1)*bs, :] * target_features).sum(axis=2).max(axis=1)
117 | matched += (similarity > threshold).sum()
118 | match_rate = matched/total_sample
119 | print('M score = ', 1 - match_rate)
120 | return 1 - match_rate
121 |
122 |
123 | #----------------------------------------------------------------------------
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
128 | #----------------------------------------------------------------------------
129 |
--------------------------------------------------------------------------------
/edm/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/edm/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from .
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /edm/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 
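# (Editorial usage sketch: a call such as
#     misc.constant(np.sqrt(2), device=x.device)
# repeated inside a forward pass returns the same cached GPU tensor instead
# of re-uploading the value each time; the cache key covers the value's
# bytes together with the requested shape, dtype, device, and memory format.)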
18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
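# (Editorial example: for an image batch x of shape [N, 3, 32, 32],
#     misc.assert_shape(x, [None, 3, 32, 32])
# passes for any batch size N, while a mismatched channel or spatial size
# raises an AssertionError naming the offending dimension.)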
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
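# (Editorial usage sketch: copy_params_and_buffers(src_module=net, dst_module=ema)
# overwrites each of ema's parameters and buffers with the identically named
# tensor from net; with require_all=False, names missing from the source are
# simply left untouched.)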
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
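# (Each tensor is attributed to the first module that reports it;
# deduplicating by id() keeps parameters shared across submodules from being
# double-counted in the totals below.)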
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /edm/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 
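# Bidirectional caches ensuring that each unique module source is exec'd at
# most once during unpickling; see _module_to_src() and _src_to_module() below.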
29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
83 | 
84 |     A typical use case is to first unpickle a previous instance of a
85 |     persistent class, and then upgrade it to use the latest version of
86 |     the source code:
87 | 
88 |         with open('old_pickle.pkl', 'rb') as f:
89 |             old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 |         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 |     """
93 |     assert isinstance(orig_class, type)
94 |     if is_persistent(orig_class):
95 |         return orig_class
96 | 
97 |     assert orig_class.__module__ in sys.modules
98 |     orig_module = sys.modules[orig_class.__module__]
99 |     orig_module_src = _module_to_src(orig_module)
100 | 
101 |     class Decorator(orig_class):
102 |         _orig_module_src = orig_module_src
103 |         _orig_class_name = orig_class.__name__
104 | 
105 |         def __init__(self, *args, **kwargs):
106 |             super().__init__(*args, **kwargs)
107 |             record_init_args = getattr(self, '_record_init_args', True)
108 |             self._init_args = copy.deepcopy(args) if record_init_args else None
109 |             self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110 |             assert orig_class.__name__ in orig_module.__dict__
111 |             _check_pickleable(self.__reduce__())
112 | 
113 |         @property
114 |         def init_args(self):
115 |             assert self._init_args is not None
116 |             return copy.deepcopy(self._init_args)
117 | 
118 |         @property
119 |         def init_kwargs(self):
120 |             assert self._init_kwargs is not None
121 |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122 | 
123 |         def __reduce__(self):
124 |             fields = list(super().__reduce__())
125 |             fields += [None] * max(3 - len(fields), 0)
126 |             if fields[0] is not _reconstruct_persistent_obj:
127 |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128 |                 fields[0] = _reconstruct_persistent_obj # reconstruct func
129 |                 fields[1] = (meta,) # reconstruct args
130 |                 fields[2] = None # state dict
131 |             return tuple(fields)
132 | 
133 |     Decorator.__name__ = orig_class.__name__
134 |     Decorator.__module__ = orig_class.__module__
135 |     _decorators.add(Decorator)
136 |     return Decorator
137 | 
138 | #----------------------------------------------------------------------------
139 | 
140 | def is_persistent(obj):
141 |     r"""Test whether the given object or class is persistent, i.e.,
142 |     whether it will save its source code when pickled.
143 |     """
144 |     try:
145 |         if obj in _decorators:
146 |             return True
147 |     except TypeError:
148 |         pass
149 |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 | 
151 | #----------------------------------------------------------------------------
152 | 
153 | def import_hook(hook):
154 |     r"""Register an import hook that is called whenever a persistent object
155 |     is being unpickled. A typical use case is to patch the pickled source
156 |     code to avoid errors and inconsistencies when the API of some imported
157 |     module has changed.
158 | 
159 |     The hook should have the following signature:
160 | 
161 |         hook(meta) -> modified meta
162 | 
163 |     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164 | 
165 |         type:       Type of the persistent object, e.g. `'class'`.
166 |         version:    Internal version number of `torch_utils.persistence`.
167 |         module_src: Original source code of the Python module.
168 |         class_name: Class name in the original Python module.
169 |         state:      Internal state of the object.
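    Hooks are applied in registration order, and each hook receives the
    `meta` returned by the previous one.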
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
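        # Anything else falls through to the actual pickle test below; if it
        # is not pickleable, pickle.dump() raises.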
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /edm/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 
71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 
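        Example (illustrative; the statistic name is a placeholder):

            collector = Collector(regex='Loss/.*')
            while training:
                ...                                  # e.g. report('Loss/loss', loss)
                collector.update()
                print(collector.mean('Loss/loss'))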
158 |         """
159 |         if not self._keep_previous:
160 |             self._moments.clear()
161 |         for name, cumulative in _sync(self.names()):
162 |             if name not in self._cumulative:
163 |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 |             delta = cumulative - self._cumulative[name]
165 |             self._cumulative[name].copy_(cumulative)
166 |             if float(delta[0]) != 0:
167 |                 self._moments[name] = delta
168 | 
169 |     def _get_delta(self, name):
170 |         r"""Returns the raw moments that were accumulated for the given
171 |         statistic between the last two calls to `update()`, or zero if
172 |         no scalars were collected.
173 |         """
174 |         assert self._regex.fullmatch(name)
175 |         if name not in self._moments:
176 |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 |         return self._moments[name]
178 | 
179 |     def num(self, name):
180 |         r"""Returns the number of scalars that were accumulated for the given
181 |         statistic between the last two calls to `update()`, or zero if
182 |         no scalars were collected.
183 |         """
184 |         delta = self._get_delta(name)
185 |         return int(delta[0])
186 | 
187 |     def mean(self, name):
188 |         r"""Returns the mean of the scalars that were accumulated for the
189 |         given statistic between the last two calls to `update()`, or NaN if
190 |         no scalars were collected.
191 |         """
192 |         delta = self._get_delta(name)
193 |         if int(delta[0]) == 0:
194 |             return float('nan')
195 |         return float(delta[1] / delta[0])
196 | 
197 |     def std(self, name):
198 |         r"""Returns the standard deviation of the scalars that were
199 |         accumulated for the given statistic between the last two calls to
200 |         `update()`, or NaN if no scalars were collected.
201 |         """
202 |         delta = self._get_delta(name)
203 |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204 |             return float('nan')
205 |         if int(delta[0]) == 1:
206 |             return float(0)
207 |         mean = float(delta[1] / delta[0])
208 |         raw_var = float(delta[2] / delta[0])
209 |         return np.sqrt(max(raw_var - np.square(mean), 0))
210 | 
211 |     def as_dict(self):
212 |         r"""Returns the averages accumulated between the last two calls to
213 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
214 | 
215 |             dnnlib.EasyDict(
216 |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217 |                 ...
218 |             )
219 |         """
220 |         stats = dnnlib.EasyDict()
221 |         for name in self.names():
222 |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223 |         return stats
224 | 
225 |     def __getitem__(self, name):
226 |         r"""Convenience getter.
227 |         `collector[name]` is a synonym for `collector.mean(name)`.
228 |         """
229 |         return self.mean(name)
230 | 
231 | #----------------------------------------------------------------------------
232 | 
233 | def _sync(names):
234 |     r"""Synchronize the global cumulative counters across devices and
235 |     processes. Called internally by `Collector.update()`.
236 |     """
237 |     if len(names) == 0:
238 |         return []
239 |     global _sync_called
240 |     _sync_called = True
241 | 
242 |     # Collect deltas within current rank.
243 |     deltas = []
244 |     device = _sync_device if _sync_device is not None else torch.device('cpu')
245 |     for name in names:
246 |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247 |         for counter in _counters[name].values():
248 |             delta.add_(counter.to(device))
249 |             counter.copy_(torch.zeros_like(counter))
250 |         deltas.append(delta)
251 |     deltas = torch.stack(deltas)
252 | 
253 |     # Sum deltas across ranks.
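    # (all_reduce() sums the stacked [len(names), _num_moments] tensor in place
    # across processes; single-process runs skip this and keep local deltas.)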
254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | -------------------------------------------------------------------------------- /edm/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Train diffusion-based generative model using the techniques described in the 9 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import os 12 | import re 13 | import json 14 | import click 15 | import torch 16 | import dnnlib 17 | from torch_utils import distributed as dist 18 | from training import training_loop 19 | 20 | import warnings 21 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. 22 | 23 | #---------------------------------------------------------------------------- 24 | # Parse a comma separated list of numbers or ranges and return a list of ints. 25 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 26 | 27 | def parse_int_list(s): 28 | if isinstance(s, list): return s 29 | ranges = [] 30 | range_re = re.compile(r'^(\d+)-(\d+)$') 31 | for p in s.split(','): 32 | m = range_re.match(p) 33 | if m: 34 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 35 | else: 36 | ranges.append(int(p)) 37 | return ranges 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @click.command() 42 | 43 | # Main options. 44 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True) 45 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True) 46 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True) 47 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True) 48 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True) 49 | 50 | # Hyperparameters. 
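# (Unit conventions: MIMG = millions of training images, KIMG = thousands;
# e.g. the default --duration=200 corresponds to 200M images seen in total.)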
51 | 
52 | @click.option('--model_channels', help='Number of model channels', metavar='INT', type=int, default=128, show_default=True)
53 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
54 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
55 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
56 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
57 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
58 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
59 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
60 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
61 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
62 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
63 | 
64 | # Performance-related.
65 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
66 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
67 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
68 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
69 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
70 | 
71 | # I/O-related.
72 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
73 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
74 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
75 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
76 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
77 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
78 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
79 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
80 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
81 | 
82 | def main(**kwargs):
83 |     """Train diffusion-based generative model using the techniques described in the
84 |     paper "Elucidating the Design Space of Diffusion-Based Generative Models".
85 | 86 | Examples: 87 | 88 | \b 89 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 90 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\ 91 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 92 | """ 93 | opts = dnnlib.EasyDict(kwargs) 94 | torch.multiprocessing.set_start_method('spawn') 95 | dist.init() 96 | 97 | # Initialize config dict. 98 | c = dnnlib.EasyDict() 99 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache) 100 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2) 101 | c.network_kwargs = dnnlib.EasyDict() 102 | c.loss_kwargs = dnnlib.EasyDict() 103 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8) 104 | 105 | # Validate dataset options. 106 | try: 107 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs) 108 | dataset_name = dataset_obj.name 109 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution 110 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size 111 | if opts.cond and not dataset_obj.has_labels: 112 | raise click.ClickException('--cond=True requires labels specified in dataset.json') 113 | del dataset_obj # conserve memory 114 | except IOError as err: 115 | raise click.ClickException(f'--data: {err}') 116 | 117 | # Network architecture. 118 | if opts.arch == 'ddpmpp': 119 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') 120 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels = opts.model_channels, channel_mult=[2,2,2]) 121 | elif opts.arch == 'ncsnpp': 122 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') 123 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels = opts.model_channels, channel_mult=[2,2,2]) 124 | else: 125 | assert opts.arch == 'adm' 126 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4]) 127 | 128 | # Preconditioning & loss function. 129 | if opts.precond == 'vp': 130 | c.network_kwargs.class_name = 'training.networks.VPPrecond' 131 | c.loss_kwargs.class_name = 'training.loss.VPLoss' 132 | elif opts.precond == 've': 133 | c.network_kwargs.class_name = 'training.networks.VEPrecond' 134 | c.loss_kwargs.class_name = 'training.loss.VELoss' 135 | else: 136 | assert opts.precond == 'edm' 137 | c.network_kwargs.class_name = 'training.networks.EDMPrecond' 138 | c.loss_kwargs.class_name = 'training.loss.EDMLoss' 139 | 140 | # Network options. 141 | if opts.cbase is not None: 142 | c.network_kwargs.model_channels = opts.cbase 143 | if opts.cres is not None: 144 | c.network_kwargs.channel_mult = opts.cres 145 | if opts.augment: 146 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment) 147 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) 148 | c.network_kwargs.augment_dim = 9 149 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16) 150 | 151 | # Training options. 
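    # (MIMG-valued options are converted to KIMG here: --duration=200 gives
    # total_kimg=200000, and --ema=0.5 gives ema_halflife_kimg=500.)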
152 | c.total_kimg = max(int(opts.duration * 1000), 1) 153 | c.ema_halflife_kimg = int(opts.ema * 1000) 154 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) 155 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench) 156 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump) 157 | 158 | # Random seed. 159 | if opts.seed is not None: 160 | c.seed = opts.seed 161 | else: 162 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 163 | torch.distributed.broadcast(seed, src=0) 164 | c.seed = int(seed) 165 | 166 | # Transfer learning and resume. 167 | if opts.transfer is not None: 168 | if opts.resume is not None: 169 | raise click.ClickException('--transfer and --resume cannot be specified at the same time') 170 | c.resume_pkl = opts.transfer 171 | c.ema_rampup_ratio = None 172 | elif opts.resume is not None: 173 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume)) 174 | if not match or not os.path.isfile(opts.resume): 175 | raise click.ClickException('--resume must point to training-state-*.pt from a previous training run') 176 | c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl') 177 | c.resume_kimg = int(match.group(1)) 178 | c.resume_state_dump = opts.resume 179 | 180 | # Description string. 181 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' 182 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' 183 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}' 184 | if opts.desc is not None: 185 | desc += f'-{opts.desc}' 186 | 187 | # Pick output directory. 188 | if dist.get_rank() != 0: 189 | c.run_dir = None 190 | elif opts.nosubdir: 191 | c.run_dir = opts.outdir 192 | else: 193 | prev_run_dirs = [] 194 | if os.path.isdir(opts.outdir): 195 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] 196 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] 197 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] 198 | cur_run_id = max(prev_run_ids, default=-1) + 1 199 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}') 200 | assert not os.path.exists(c.run_dir) 201 | 202 | # Print options. 203 | dist.print0() 204 | dist.print0('Training options:') 205 | dist.print0(json.dumps(c, indent=2)) 206 | dist.print0() 207 | dist.print0(f'Output directory: {c.run_dir}') 208 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}') 209 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}') 210 | dist.print0(f'Network architecture: {opts.arch}') 211 | dist.print0(f'Preconditioning & loss: {opts.precond}') 212 | dist.print0(f'Number of GPUs: {dist.get_world_size()}') 213 | dist.print0(f'Batch size: {c.batch_size}') 214 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}') 215 | dist.print0() 216 | 217 | # Dry run? 218 | if opts.dry_run: 219 | dist.print0('Dry run; exiting.') 220 | return 221 | 222 | # Create output directory. 223 | dist.print0('Creating output directory...') 224 | if dist.get_rank() == 0: 225 | os.makedirs(c.run_dir, exist_ok=True) 226 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: 227 | json.dump(c, f, indent=2) 228 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) 229 | 230 | # Train. 
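    # (The accumulated config dict `c` expands directly into the keyword
    # arguments of training_loop(), so every key set above must correspond to
    # one of its parameters.)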
231 |     training_loop.training_loop(**c)
232 | 
233 | #----------------------------------------------------------------------------
234 | 
235 | if __name__ == "__main__":
236 |     main()
237 | 
238 | #----------------------------------------------------------------------------
239 | 
--------------------------------------------------------------------------------
/edm/trainMoLRG.py:
--------------------------------------------------------------------------------
1 | """Train a mixture of low-rank Gaussians (MoLRG) distribution using a diffusion model."""
2 | 
3 | import os
4 | import re
5 | import json
6 | import click
7 | import torch
8 | import dnnlib
9 | from torch_utils import distributed as dist
10 | from training import training_loop_MoLRG
11 | from glob import glob
12 | 
13 | import warnings
14 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
15 | 
16 | #----------------------------------------------------------------------------
17 | # Parse a comma separated list of numbers or ranges and return a list of ints.
18 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
19 | 
20 | def parse_int_list(s):
21 |     if isinstance(s, list): return s
22 |     ranges = []
23 |     range_re = re.compile(r'^(\d+)-(\d+)$')
24 |     for p in s.split(','):
25 |         m = range_re.match(p)
26 |         if m:
27 |             ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
28 |         else:
29 |             ranges.append(int(p))
30 |     return ranges
31 | 
32 | #----------------------------------------------------------------------------
33 | 
34 | @click.command()
35 | 
36 | # Main options.
37 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, default="/home/ubuntu/exp/theory/")
38 | @click.option('--img_res', help='Resolution of the MoLRG images', metavar='INT', type=int, default=32)
39 | @click.option('--class_num', help='Number of classes for the MoLRG', metavar='INT', type=int, default=10)
40 | @click.option('--per_class_dim', help='Subspace dimension for each class', metavar='INT', type=int, default=100)
41 | @click.option('--sample_per_class', help='Number of samples for each class', metavar='INT', type=int, default=5000)
42 | @click.option('--path', help='Directory to save/load the MoLRG dataset', metavar='DIR', type=str, required=True)
43 | # @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
44 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
45 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm|mlp|param-mlp', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm', 'mlp', "param-mlp"]), default='ddpmpp', show_default=True)
46 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
47 | 
48 | 
49 | # Hyperparameters.
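# (Note: several defaults below differ from train.py, e.g. --duration=6 MIMG
# rather than 200, presumably because the synthetic MoLRG task is much smaller
# than the image datasets targeted by train.py.)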
50 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=6, show_default=True)
51 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
52 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
53 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
54 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
55 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
56 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
57 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
58 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
59 | # @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
60 | @click.option('--embed_channels', help='Embedding channels', metavar='INT', type=int, default=1024)
61 | 
62 | # Performance-related.
63 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
64 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
65 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
66 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
67 | 
68 | # I/O-related.
69 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
70 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
71 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
72 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=10, show_default=True)
73 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
74 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
75 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
76 | @click.option('--resume', help='Resume from the latest training state in the run directory', metavar='BOOL', type=bool)
77 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
78 | @click.option('--resumedir', help='Resume from previous training directory', metavar='DIR', type=str)
79 | @click.option('--optimizer', help='Training optimizer', metavar='adam|sgd', default="adam", type=click.Choice(['adam', 'sgd']))
80 | 
81 | 
82 | def main(**kwargs):
83 |     """Train a diffusion model on a mixture of low-rank Gaussians (MoLRG) using the
84 |     techniques of the paper "Elucidating the Design Space of Diffusion-Based Generative Models".
85 | 
86 |     Examples:
87 | 
88 |     \b
89 |     # Train an MLP-based diffusion model on MoLRG data using 8 GPUs
90 |     torchrun --standalone --nproc_per_node=8 trainMoLRG.py --outdir=training-runs \\
91 |         --path=datasets --arch=mlp
92 |     """
93 |     opts = dnnlib.EasyDict(kwargs)
94 |     torch.multiprocessing.set_start_method('spawn')
95 |     dist.init()
96 | 
97 | 
98 |     # Initialize config dict.
99 |     c = dnnlib.EasyDict()
100 |     c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.MoLRG', resolution=opts.img_res, class_num=opts.class_num, per_class_dim=opts.per_class_dim, sample_per_class=opts.sample_per_class, path=opts.path, use_labels=opts.cond)
101 | 
102 |     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
103 |     c.network_kwargs = dnnlib.EasyDict()
104 |     c.loss_kwargs = dnnlib.EasyDict()
105 |     if opts.optimizer == "adam":
106 |         c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
107 |     elif opts.optimizer == "sgd":
108 |         c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.SGD', lr=opts.lr)
109 |     opts.outdir = os.path.join(opts.outdir, f"MoLRG_dataset_resolution{opts.img_res}_classnum{opts.class_num}_perclassdim{opts.per_class_dim}_sample{opts.sample_per_class}")
110 | 
111 |     # Validate dataset options.
112 |     try:
113 |         dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
114 |         dataset_name = dataset_obj.name
115 |         c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
116 |         if opts.cond and not dataset_obj.has_labels:
117 |             raise click.ClickException('--cond=True requires labels specified in dataset.json')
118 |         del dataset_obj # conserve memory
119 |     except IOError as err:
120 |         raise click.ClickException(f'--path: {err}')
121 | 
122 |     # Network architecture.
123 |     if opts.arch == 'ddpmpp':
124 |         c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
125 |         c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=opts.embed_channels, channel_mult=[2,2,2])
126 |     elif opts.arch == 'ncsnpp':
127 |         c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
128 |         c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=opts.embed_channels, channel_mult=[2,2,2])
129 |     elif opts.arch == 'mlp':
130 |         c.network_kwargs.update(model_type='TwoLayerMLP', embedding_type='positional')
131 |         c.network_kwargs.update(embed_channels=opts.embed_channels, noise_channels=256)
132 |     elif opts.arch == 'param-mlp':
133 |         c.network_kwargs.update(model_type='ParametericMLP')
134 |         c.network_kwargs.update(class_dim=opts.class_num, latent_dim=opts.per_class_dim)
135 |     else:
136 |         assert opts.arch == 'adm'
137 |         c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
138 | 
139 |     # Preconditioning & loss function.
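    # (Each choice below pairs a preconditioning wrapper from
    # training/networks.py with the matching loss from training/loss.py.)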
140 | if opts.precond == 'vp': 141 | c.network_kwargs.class_name = 'training.networks.VPPrecond' 142 | c.loss_kwargs.class_name = 'training.loss.VPLoss' 143 | elif opts.precond == 've': 144 | c.network_kwargs.class_name = 'training.networks.VEPrecond' 145 | c.loss_kwargs.class_name = 'training.loss.VELoss' 146 | else: 147 | assert opts.precond == 'edm' 148 | c.network_kwargs.class_name = 'training.networks.EDMPrecond' 149 | c.loss_kwargs.class_name = 'training.loss.EDMLoss' 150 | 151 | # Network options. 152 | if opts.cbase is not None: 153 | c.network_kwargs.model_channels = opts.cbase 154 | if opts.cres is not None: 155 | c.network_kwargs.channel_mult = opts.cres 156 | if opts.augment: 157 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment) 158 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) 159 | c.network_kwargs.augment_dim = 9 160 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16) 161 | 162 | # Training options. 163 | c.total_kimg = max(int(opts.duration * 1000), 1) 164 | c.ema_halflife_kimg = int(opts.ema * 1000) 165 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) 166 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench) 167 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump) 168 | 169 | # Random seed. 170 | if opts.seed is not None: 171 | c.seed = opts.seed 172 | else: 173 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) 174 | torch.distributed.broadcast(seed, src=0) 175 | c.seed = int(seed) 176 | 177 | # Transfer learning and resume. 178 | if opts.transfer is not None: 179 | if opts.resume is not None: 180 | raise click.ClickException('--transfer and --resume cannot be specified at the same time') 181 | c.resume_pkl = opts.transfer 182 | c.ema_rampup_ratio = None 183 | 184 | elif opts.resumedir is not None: 185 | pt_files = glob(os.path.join(opts.resumedir, 'training-state-*.pt')) 186 | pt_files.sort() 187 | if len(pt_files)!=0: 188 | latest_file = pt_files[-1] 189 | 190 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(latest_file)) 191 | if not match or not os.path.isfile(latest_file): 192 | raise click.ClickException('--resume must point to training-state-*.pt from a previous training run') 193 | c.resume_pkl = os.path.join(os.path.dirname(latest_file), f'network-snapshot-{match.group(1)}.pkl') 194 | c.resume_kimg = int(match.group(1)) 195 | c.resume_state_dump = latest_file 196 | 197 | # Description string. 198 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' 199 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' 200 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}' 201 | if opts.desc is not None: 202 | desc += f'-{opts.desc}' 203 | 204 | 205 | # Pick output directory. 
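    # (Only rank 0 owns a run directory; with --resume, the most recent
    # numbered run under --outdir is reused and training continues from its
    # latest training-state-*.pt snapshot.)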
206 |     if dist.get_rank() != 0:
207 |         c.run_dir = None
208 |     elif opts.nosubdir:
209 |         c.run_dir = opts.outdir
210 |     elif opts.resumedir:
211 |         c.run_dir = opts.resumedir
212 |     else:
213 |         prev_run_dirs = []
214 |         if os.path.isdir(opts.outdir):
215 |             prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
216 |         prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
217 |         prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
218 |         if opts.resume and len(prev_run_ids) != 0:
219 |             cur_run_id = max(prev_run_ids, default=0)
220 |             c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
221 | 
222 |             pt_files = glob(os.path.join(c.run_dir, 'training-state-*.pt'))
223 |             pt_files.sort()
224 |             if len(pt_files) == 0:
225 |                 raise click.ClickException(f'--resume specified but no training-state-*.pt files found in {c.run_dir}')
226 |             latest_file = pt_files[-1]
227 | 
228 |             match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(latest_file))
229 |             if not match or not os.path.isfile(latest_file):
230 |                 raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
231 |             c.resume_pkl = os.path.join(os.path.dirname(latest_file), f'network-snapshot-{match.group(1)}.pkl')
232 |             c.resume_kimg = int(match.group(1))
233 |             c.resume_state_dump = os.path.join(os.path.dirname(latest_file), f'training-state-{match.group(1)}.pt')
234 |         else:
235 |             cur_run_id = max(prev_run_ids, default=0)
236 |             c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
237 | 
238 |     # Print options.
239 |     dist.print0()
240 |     dist.print0('Training options:')
241 |     dist.print0(json.dumps(c, indent=2))
242 |     dist.print0()
243 |     dist.print0(f'Output directory: {c.run_dir}')
244 |     dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
245 |     dist.print0(f'Network architecture: {opts.arch}')
246 |     dist.print0(f'Preconditioning & loss: {opts.precond}')
247 |     dist.print0(f'Number of GPUs: {dist.get_world_size()}')
248 |     dist.print0(f'Batch size: {c.batch_size}')
249 |     dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
250 |     dist.print0()
251 | 
252 |     # Dry run?
253 |     if opts.dry_run:
254 |         dist.print0('Dry run; exiting.')
255 |         return
256 | 
257 |     # Create output directory.
258 |     dist.print0('Creating output directory...')
259 |     if dist.get_rank() == 0:
260 |         os.makedirs(c.run_dir, exist_ok=True)
261 |         with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
262 |             json.dump(c, f, indent=2)
263 |         # log_dir = '/home/ubuntu/sky_workdir'
264 |         # dnnlib.util.Logger(file_name=os.path.join(log_dir, 'log.txt'), file_mode='a', should_flush=True)
265 | 
266 |     # Train.
267 |     training_loop_MoLRG.training_loop(**c)
268 | 
269 | #----------------------------------------------------------------------------
270 | 
271 | if __name__ == "__main__":
272 |     main()
273 | 
274 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/edm/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work.
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /edm/training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | from torchvision import transforms 18 | 19 | try: 20 | import pyspng 21 | except ImportError: 22 | pyspng = None 23 | 24 | #---------------------------------------------------------------------------- 25 | # Abstract base class for datasets. 26 | 27 | class Dataset(torch.utils.data.Dataset): 28 | def __init__(self, 29 | name, # Name of the dataset. 30 | raw_shape, # Shape of the raw image data (NCHW). 31 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 32 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 33 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 34 | random_seed = 0, # Random seed to use when applying max_size. 35 | cache = False, # Cache images in CPU memory? 36 | ): 37 | self._name = name 38 | self._raw_shape = list(raw_shape) 39 | self._use_labels = use_labels 40 | self._cache = cache 41 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 42 | self._raw_labels = None 43 | self._label_shape = None 44 | 45 | # Apply max_size. 46 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 47 | if (max_size is not None) and (self._raw_idx.size > max_size): 48 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 49 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 50 | 51 | # Apply xflip. 
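        # (Indices are tiled so the second half of the dataset maps to the
        # same raw images, marked for horizontal flipping in __getitem__.)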
52 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 53 | if xflip: 54 | self._raw_idx = np.tile(self._raw_idx, 2) 55 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 56 | 57 | def _get_raw_labels(self): 58 | if self._raw_labels is None: 59 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 60 | if self._raw_labels is None: 61 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 62 | assert isinstance(self._raw_labels, np.ndarray) 63 | assert self._raw_labels.shape[0] == self._raw_shape[0] 64 | assert self._raw_labels.dtype in [np.float32, np.int64] 65 | if self._raw_labels.dtype == np.int64: 66 | assert self._raw_labels.ndim == 1 67 | assert np.all(self._raw_labels >= 0) 68 | return self._raw_labels 69 | 70 | def close(self): # to be overridden by subclass 71 | pass 72 | 73 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 74 | raise NotImplementedError 75 | 76 | def _load_raw_labels(self): # to be overridden by subclass 77 | raise NotImplementedError 78 | 79 | def __getstate__(self): 80 | return dict(self.__dict__, _raw_labels=None) 81 | 82 | def __del__(self): 83 | try: 84 | self.close() 85 | except: 86 | pass 87 | 88 | def __len__(self): 89 | return self._raw_idx.size 90 | 91 | def __getitem__(self, idx): 92 | raw_idx = self._raw_idx[idx] 93 | image = self._cached_images.get(raw_idx, None) 94 | if image is None: 95 | image = self._load_raw_image(raw_idx) 96 | if self._cache: 97 | self._cached_images[raw_idx] = image 98 | assert isinstance(image, np.ndarray) 99 | assert list(image.shape) == self.image_shape 100 | assert image.dtype == np.uint8 101 | if self._xflip[idx]: 102 | assert image.ndim == 3 # CHW 103 | image = image[:, :, ::-1] 104 | return image.copy(), self.get_label(idx) 105 | 106 | def get_label(self, idx): 107 | label = self._get_raw_labels()[self._raw_idx[idx]] 108 | if label.dtype == np.int64: 109 | onehot = np.zeros(self.label_shape, dtype=np.float32) 110 | onehot[label] = 1 111 | label = onehot 112 | return label.copy() 113 | 114 | def get_details(self, idx): 115 | d = dnnlib.EasyDict() 116 | d.raw_idx = int(self._raw_idx[idx]) 117 | d.xflip = (int(self._xflip[idx]) != 0) 118 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 119 | return d 120 | 121 | @property 122 | def name(self): 123 | return self._name 124 | 125 | @property 126 | def image_shape(self): 127 | return list(self._raw_shape[1:]) 128 | 129 | @property 130 | def num_channels(self): 131 | assert len(self.image_shape) == 3 # CHW 132 | return self.image_shape[0] 133 | 134 | @property 135 | def resolution(self): 136 | assert len(self.image_shape) == 3 # CHW 137 | assert self.image_shape[1] == self.image_shape[2] 138 | return self.image_shape[1] 139 | 140 | @property 141 | def label_shape(self): 142 | if self._label_shape is None: 143 | raw_labels = self._get_raw_labels() 144 | if raw_labels.dtype == np.int64: 145 | self._label_shape = [int(np.max(raw_labels)) + 1] 146 | else: 147 | self._label_shape = raw_labels.shape[1:] 148 | return list(self._label_shape) 149 | 150 | @property 151 | def label_dim(self): 152 | assert len(self.label_shape) == 1 153 | return self.label_shape[0] 154 | 155 | @property 156 | def has_labels(self): 157 | return any(x != 0 for x in self.label_shape) 158 | 159 | @property 160 | def has_onehot_labels(self): 161 | return self._get_raw_labels().dtype == np.int64 162 | 163 | #---------------------------------------------------------------------------- 164 | # Dataset subclass that 
loads images recursively from the specified directory 165 | # or ZIP file. 166 | 167 | class ImageFolderDataset(Dataset): 168 | def __init__(self, 169 | path, # Path to directory or zip. 170 | resolution = None, # Ensure specific resolution, None = highest available. 171 | use_pyspng = True, # Use pyspng if available? 172 | **super_kwargs, # Additional arguments for the Dataset base class. 173 | ): 174 | self._path = path 175 | self._use_pyspng = use_pyspng 176 | self._zipfile = None 177 | 178 | if os.path.isdir(self._path): 179 | self._type = 'dir' 180 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 181 | elif self._file_ext(self._path) == '.zip': 182 | self._type = 'zip' 183 | self._all_fnames = set(self._get_zipfile().namelist()) 184 | else: 185 | raise IOError('Path must point to a directory or zip') 186 | 187 | PIL.Image.init() 188 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 189 | images = None 190 | if len(self._image_fnames) == 0: 191 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) == '.npz') 192 | if len(self._image_fnames) == 0: 193 | raise IOError('No image files found in the specified path') 194 | else: 195 | images = np.concatenate([np.load(os.path.join(self._path, fname))["samples"].transpose(0, 3, 1, 2) for fname in self._image_fnames], axis=0) 196 | raw_shape = images.shape 197 | super_kwargs['cache'] = True 198 | else: 199 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 200 | name = os.path.splitext(os.path.basename(self._path))[0] 201 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 202 | raise IOError('Image files do not match the specified resolution') 203 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 204 | 205 | if images is not None: 206 | self._cached_images = {i: images[i] for i in range(images.shape[0])} 207 | 208 | @staticmethod 209 | def _file_ext(fname): 210 | return os.path.splitext(fname)[1].lower() 211 | 212 | def _get_zipfile(self): 213 | assert self._type == 'zip' 214 | if self._zipfile is None: 215 | self._zipfile = zipfile.ZipFile(self._path) 216 | return self._zipfile 217 | 218 | def _open_file(self, fname): 219 | if self._type == 'dir': 220 | return open(os.path.join(self._path, fname), 'rb') 221 | if self._type == 'zip': 222 | return self._get_zipfile().open(fname, 'r') 223 | return None 224 | 225 | def close(self): 226 | try: 227 | if self._zipfile is not None: 228 | self._zipfile.close() 229 | finally: 230 | self._zipfile = None 231 | 232 | def __getstate__(self): 233 | return dict(super().__getstate__(), _zipfile=None) 234 | 235 | def _load_raw_image(self, raw_idx): 236 | fname = self._image_fnames[raw_idx] 237 | # Decode PNGs with pyspng when available; it is faster than PIL. 238 | with self._open_file(fname) as f: 239 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 240 | image = pyspng.load(f.read()) 241 | else: 242 | image = np.array(PIL.Image.open(f)) 243 | if image.ndim == 2: 244 | image = image[:, :, np.newaxis] # HW => HWC 245 | image = image.transpose(2, 0, 1) # HWC => CHW 246 | return image 247 | 248 | def _load_raw_labels(self): 249 | fname = 'dataset.json' 250 | if fname not in self._all_fnames: 251 | return None 252 | with self._open_file(fname) as f: 253 | labels = json.load(f)['labels'] 254 | if labels is None: 255 | return None 256 | labels = dict(labels)
257 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 258 | labels = np.array(labels) 259 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 260 | return labels 261 | 262 | class optimal_denoiser_dataset(torch.utils.data.Dataset): 263 | def __init__(self, 264 | path, # Path to directory or zip. 265 | transforms = transforms.Compose([]), 266 | ): 267 | self._path = path 268 | self._file_list = os.listdir(self._path) 269 | self._file_list.sort() 270 | self.images = [torch.load(os.path.join(self._path, pth)) for pth in self._file_list] 271 | self.images = torch.cat(self.images).permute((0, 3, 1, 2)) 272 | 273 | # Optional torchvision transforms applied to each item in __getitem__. 274 | self.transforms = transforms 275 | 276 | def __len__(self): 277 | return len(self.images) 278 | 279 | def __getitem__(self, idx): 280 | image = self.transforms(self.images[idx]) 281 | return image 282 | 283 | class MoLRG(torch.utils.data.Dataset): 284 | def __init__(self, 285 | resolution = 2, 286 | class_num = 2, 287 | per_class_dim = 2, 288 | sample_per_class = 500, 289 | path = "./datasets", 290 | use_labels = False, 291 | save_dataset = True, 292 | loading_dataset = True, 293 | ): 294 | img_resolution = torch.tensor([resolution, resolution, 3]) 295 | dataset_path = os.path.join(path, f"MoLRG_dataset_resolution{resolution}_classnum{class_num}_perclassdim{per_class_dim}_sample{sample_per_class}.pt") 296 | if (not os.path.exists(dataset_path)) or (not loading_dataset): 297 | print("Creating new dataset...") 298 | dim = img_resolution.prod() 299 | rand = torch.randn(dim, dim) 300 | U, _, _ = torch.linalg.svd(rand) 301 | classbasis = [] 302 | ## Generate an orthonormal basis for each class subspace. 303 | for i in range(class_num): 304 | classbasis.append(U[:, per_class_dim * i:per_class_dim * (i+1)][None, :]) 305 | classbasis = torch.cat(classbasis) 306 | 307 | ## Draw Gaussian samples within each class subspace. 308 | data = [] 309 | conds = [] 310 | for cond in range(class_num): 311 | for idx in range(sample_per_class): 312 | data.append((classbasis[cond] @ torch.randn(per_class_dim, 1)).reshape((1, resolution, resolution, 3))) 313 | conds.append(cond) 314 | data = torch.cat(data) 315 | conds = torch.tensor(conds) 316 | if save_dataset: 317 | torch.save({ 318 | "basis": classbasis, 319 | "space_basis": U, 320 | "data": data, 321 | "class_num": class_num, 322 | "sample_per_class": sample_per_class, 323 | "per_class_dim": per_class_dim, 324 | "resolution": resolution, 325 | "conds": conds, 326 | }, dataset_path) 327 | else: 328 | dataset = torch.load(dataset_path) 329 | resolution = dataset["resolution"] 330 | class_num = dataset["class_num"] 331 | per_class_dim = dataset["per_class_dim"] 332 | sample_per_class = dataset["sample_per_class"] 333 | data = dataset["data"] 334 | classbasis = dataset["basis"] 335 | conds = dataset["conds"] 336 | self.data = data 337 | self.conds = conds 338 | self.name = "MoLRG" 339 | self.resolution = resolution 340 | self.num_channels = 3 341 | self.label_dim = 0 342 | self.basis = classbasis 343 | 344 | def __len__(self): 345 | return self.data.shape[0] 346 | 347 | def __getitem__(self, idx): 348 | return (self.data[idx].permute((2, 0, 1)) + 1) * 127.5, self.conds[idx] 349 | #---------------------------------------------------------------------------- 350 | -------------------------------------------------------------------------------- /edm/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c)
2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | 14 | #---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 18 | 19 | @persistence.persistent_class 20 | class VPLoss: 21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 22 | self.beta_d = beta_d 23 | self.beta_min = beta_min 24 | self.epsilon_t = epsilon_t 25 | 26 | def __call__(self, net, images, labels, augment_pipe=None): 27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 29 | weight = 1 / sigma ** 2 30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 31 | n = torch.randn_like(y) * sigma 32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 33 | loss = weight * ((D_yn - y) ** 2) 34 | return loss 35 | 36 | def sigma(self, t): 37 | t = torch.as_tensor(t) 38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 39 | 40 | #---------------------------------------------------------------------------- 41 | # Loss function corresponding to the variance exploding (VE) formulation 42 | # from the paper "Score-Based Generative Modeling through Stochastic 43 | # Differential Equations". 44 | 45 | @persistence.persistent_class 46 | class VELoss: 47 | def __init__(self, sigma_min=0.02, sigma_max=100): 48 | self.sigma_min = sigma_min 49 | self.sigma_max = sigma_max 50 | 51 | def __call__(self, net, images, labels, augment_pipe=None): 52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 54 | weight = 1 / sigma ** 2 55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 56 | n = torch.randn_like(y) * sigma 57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 58 | loss = weight * ((D_yn - y) ** 2) 59 | return loss 60 | 61 | #---------------------------------------------------------------------------- 62 | # Improved loss function proposed in the paper "Elucidating the Design Space 63 | # of Diffusion-Based Generative Models" (EDM). 
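# The weight below equals 1 / c_out(sigma)^2 from the EDM preconditioning. As a worked example (values chosen for illustration only, not from this repo): with sigma_data = 0.5 and sigma = 0.5, weight = (0.25 + 0.25) / (0.25 * 0.25) = 8.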
64 | 65 | @persistence.persistent_class 66 | class EDMLoss: 67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 68 | self.P_mean = P_mean 69 | self.P_std = P_std 70 | self.sigma_data = sigma_data 71 | 72 | def __call__(self, net, images, labels=None, augment_pipe=None): 73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 77 | n = torch.randn_like(y) * sigma 78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 79 | loss = weight * ((D_yn - y) ** 2) 80 | return loss 81 | 82 | #---------------------------------------------------------------------------- 83 | -------------------------------------------------------------------------------- /edm/training/training_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Main training loop.""" 9 | 10 | import os 11 | import time 12 | import copy 13 | import json 14 | import pickle 15 | import psutil 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | from torch_utils import distributed as dist 20 | from torch_utils import training_stats 21 | from torch_utils import misc 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def training_loop( 26 | run_dir = '.', # Output directory. 27 | dataset_kwargs = {}, # Options for training set. 28 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 29 | network_kwargs = {}, # Options for model and preconditioning. 30 | loss_kwargs = {}, # Options for loss function. 31 | optimizer_kwargs = {}, # Options for optimizer. 32 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 33 | seed = 0, # Global random seed. 34 | batch_size = 512, # Total batch size for one training iteration. 35 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 36 | total_kimg = 200000, # Training duration, measured in thousands of training images. 37 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 38 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 39 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 40 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 41 | kimg_per_tick = 50, # Interval of progress prints. 42 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 43 | state_dump_ticks = 500, # How often to dump training state, None = disable. 44 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 45 | resume_state_dump = None, # Start from the given training state, None = reset training state. 46 | resume_kimg = 0, # Start from the given training progress. 47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 48 | device = torch.device('cuda'), 49 | ): 50 | # Initialize. 
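# Note: seeds are decorrelated across ranks, and the TF32 / reduced-precision matmul fast paths are disabled so matmuls and convolutions run in full FP32.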
51 | start_time = time.time() 52 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 53 | torch.manual_seed(np.random.randint(1 << 31)) 54 | torch.backends.cudnn.benchmark = cudnn_benchmark 55 | torch.backends.cudnn.allow_tf32 = False 56 | torch.backends.cuda.matmul.allow_tf32 = False 57 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 58 | 59 | # Select batch size per GPU. 60 | batch_gpu_total = batch_size // dist.get_world_size() 61 | if batch_gpu is None or batch_gpu > batch_gpu_total: 62 | batch_gpu = batch_gpu_total 63 | num_accumulation_rounds = batch_gpu_total // batch_gpu 64 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 65 | 66 | # Load dataset. 67 | dist.print0('Loading dataset...') 68 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 69 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) 70 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) 71 | 72 | # Construct network. 73 | dist.print0('Constructing network...') 74 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) 75 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 76 | net.train().requires_grad_(True).to(device) 77 | if dist.get_rank() == 0: 78 | with torch.no_grad(): 79 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) 80 | sigma = torch.ones([batch_gpu], device=device) 81 | labels = torch.zeros([batch_gpu, net.label_dim], device=device) 82 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) 83 | 84 | # Setup optimizer. 85 | dist.print0('Setting up optimizer...') 86 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss 87 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer 88 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe 89 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device]) 90 | ema = copy.deepcopy(net).eval().requires_grad_(False) 91 | 92 | # Resume training from previous snapshot. 93 | if resume_pkl is not None: 94 | dist.print0(f'Loading network weights from "{resume_pkl}"...') 95 | if dist.get_rank() != 0: 96 | torch.distributed.barrier() # rank 0 goes first 97 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: 98 | data = pickle.load(f) 99 | if dist.get_rank() == 0: 100 | torch.distributed.barrier() # other ranks follow 101 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) 102 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) 103 | del data # conserve memory 104 | if resume_state_dump: 105 | dist.print0(f'Loading training state from "{resume_state_dump}"...') 106 | data = torch.load(resume_state_dump, map_location=torch.device('cpu')) 107 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) 108 | optimizer.load_state_dict(data['optimizer_state']) 109 | del data # conserve memory 110 | 111 | # Train. 
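# Each iteration accumulates gradients over num_accumulation_rounds micro-batches, ramps the learning rate up over lr_rampup_kimg, replaces NaN/Inf gradients with finite values, steps the optimizer, and updates the EMA weights; the per-tick maintenance below handles logging and checkpointing.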
112 | dist.print0(f'Training for {total_kimg} kimg...') 113 | dist.print0() 114 | cur_nimg = resume_kimg * 1000 115 | cur_tick = 0 116 | tick_start_nimg = cur_nimg 117 | tick_start_time = time.time() 118 | maintenance_time = tick_start_time - start_time 119 | dist.update_progress(cur_nimg // 1000, total_kimg) 120 | stats_jsonl = None 121 | while True: 122 | 123 | # Accumulate gradients. 124 | optimizer.zero_grad(set_to_none=True) 125 | for round_idx in range(num_accumulation_rounds): 126 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): 127 | images, labels = next(dataset_iterator) 128 | images = images.to(device).to(torch.float32) / 127.5 - 1 129 | labels = labels.to(device) 130 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) 131 | training_stats.report('Loss/loss', loss) 132 | loss.sum().mul(loss_scaling / batch_gpu_total).backward() 133 | 134 | # Update weights. 135 | for g in optimizer.param_groups: 136 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) 137 | for param in net.parameters(): 138 | if param.grad is not None: 139 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 140 | optimizer.step() 141 | 142 | # Update EMA. 143 | ema_halflife_nimg = ema_halflife_kimg * 1000 144 | if ema_rampup_ratio is not None: 145 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 146 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 147 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 148 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 149 | 150 | # Perform maintenance tasks once per tick. 151 | cur_nimg += batch_size 152 | done = (cur_nimg >= total_kimg * 1000) 153 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 154 | continue 155 | 156 | # Print status line, accumulating the same information in training_stats. 157 | tick_end_time = time.time() 158 | fields = [] 159 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 160 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 161 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 162 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 163 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 164 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 165 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 166 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 167 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 168 | torch.cuda.reset_peak_memory_stats() 169 | dist.print0(' '.join(fields)) 170 | 171 | # Check for abort. 172 | if (not done) and dist.should_stop(): 173 | done = True 174 | dist.print0() 175 | dist.print0('Aborting...') 176 | 177 | # Save network snapshot. 
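# Snapshots store CPU copies of the EMA network, loss_fn, and augment_pipe; each module is first checked for cross-rank consistency, and only rank 0 writes the pickle.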
178 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): 179 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) 180 | for key, value in data.items(): 181 | if isinstance(value, torch.nn.Module): 182 | value = copy.deepcopy(value).eval().requires_grad_(False) 183 | misc.check_ddp_consistency(value) 184 | data[key] = value.cpu() 185 | del value # conserve memory 186 | if dist.get_rank() == 0: 187 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 188 | pickle.dump(data, f) 189 | del data # conserve memory 190 | 191 | # Save full dump of the training state. 192 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 193 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 194 | 195 | # Update logs. 196 | training_stats.default_collector.update() 197 | if dist.get_rank() == 0: 198 | if stats_jsonl is None: 199 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') 200 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n') 201 | stats_jsonl.flush() 202 | dist.update_progress(cur_nimg // 1000, total_kimg) 203 | 204 | # Update state. 205 | cur_tick += 1 206 | tick_start_nimg = cur_nimg 207 | tick_start_time = time.time() 208 | maintenance_time = tick_start_time - tick_end_time 209 | if done: 210 | break 211 | 212 | # Done. 213 | dist.print0() 214 | dist.print0('Exiting...') 215 | 216 | #---------------------------------------------------------------------------- 217 | -------------------------------------------------------------------------------- /edm/training/training_loop_MoLRG.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Main training loop.""" 9 | 10 | import os 11 | import time 12 | import copy 13 | import json 14 | import pickle 15 | import psutil 16 | import numpy as np 17 | import torch 18 | import dnnlib 19 | from torch_utils import distributed as dist 20 | from torch_utils import training_stats 21 | from torch_utils import misc 22 | from training.networks import UNetBlock 23 | from training.networks import EDMPrecond 24 | #---------------------------------------------------------------------------- 25 | 26 | def training_loop( 27 | run_dir = '.', # Output directory. 28 | dataset_kwargs = {}, # Options for training set. 29 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. 30 | network_kwargs = {}, # Options for model and preconditioning. 31 | loss_kwargs = {}, # Options for loss function. 32 | optimizer_kwargs = {}, # Options for optimizer. 33 | augment_kwargs = None, # Options for augmentation pipeline, None = disable. 34 | seed = 0, # Global random seed. 35 | batch_size = 512, # Total batch size for one training iteration. 36 | batch_gpu = None, # Limit batch size per GPU, None = no limit. 37 | total_kimg = 200000, # Training duration, measured in thousands of training images. 
38 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. 39 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. 40 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration. 41 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. 42 | kimg_per_tick = 50, # Interval of progress prints. 43 | snapshot_ticks = 50, # How often to save network snapshots, None = disable. 44 | state_dump_ticks = 500, # How often to dump training state, None = disable. 45 | resume_pkl = None, # Start from the given network snapshot, None = random initialization. 46 | resume_state_dump = None, # Start from the given training state, None = reset training state. 47 | resume_kimg = 0, # Start from the given training progress. 48 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? 49 | device = torch.device('cuda'), 50 | pretrained_model_path = None, # Pretrained EDM checkpoint to warm-start from, None = train from scratch. 51 | ): 52 | 53 | # Initialize. 54 | start_time = time.time() 55 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) 56 | torch.manual_seed(np.random.randint(1 << 31)) 57 | torch.backends.cudnn.benchmark = cudnn_benchmark 58 | torch.backends.cudnn.allow_tf32 = False 59 | torch.backends.cuda.matmul.allow_tf32 = False 60 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 61 | 62 | # Select batch size per GPU. 63 | batch_gpu_total = batch_size // dist.get_world_size() 64 | if batch_gpu is None or batch_gpu > batch_gpu_total: 65 | batch_gpu = batch_gpu_total 66 | num_accumulation_rounds = batch_gpu_total // batch_gpu 67 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() 68 | 69 | # Load dataset. 70 | dist.print0('Loading dataset...') 71 | 72 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset 73 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) 74 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) 75 | 76 | # Construct network. 77 | dist.print0('Constructing network...') 78 | # Infer the network interface from the training dataset. 79 | if 'cifar10' in dataset_kwargs['path']: 80 | interface_kwargs = dict(img_resolution=32, img_channels=3, label_dim=0) 81 | else: 82 | # MoLRG and image-folder datasets expose these attributes directly. 83 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) 84 | 85 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module 86 | 87 | net.train().requires_grad_(True).to(device) 88 | 89 | 90 | # Setup optimizer.
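# Unlike training_loop.py, this MoLRG variant constructs DDP with broadcast_buffers=False and can warm-start from a pretrained EDM checkpoint via pretrained_model_path.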
91 | dist.print0('Setting up optimizer...') 92 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss 93 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer 94 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe 95 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False) 96 | ema = copy.deepcopy(net).eval().requires_grad_(False) 97 | 98 | # Warm-start from a pretrained EDM checkpoint (skipped when resuming from a snapshot). 99 | if pretrained_model_path and resume_pkl is None: 100 | if dist.get_rank() != 0: 101 | torch.distributed.barrier() # rank 0 goes first 102 | with dnnlib.util.open_url(pretrained_model_path, verbose=(dist.get_rank() == 0)) as f: 103 | data = pickle.load(f) 104 | if dist.get_rank() == 0: 105 | torch.distributed.barrier() # other ranks follow 106 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) 107 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) 108 | # Resume training from previous snapshot. 109 | if resume_pkl is not None: 110 | dist.print0(f'Loading network weights from "{resume_pkl}"...') 111 | if dist.get_rank() != 0: 112 | torch.distributed.barrier() # rank 0 goes first 113 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: 114 | data = pickle.load(f) 115 | if dist.get_rank() == 0: 116 | torch.distributed.barrier() # other ranks follow 117 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) 118 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) 119 | del data # conserve memory 120 | if resume_state_dump: 121 | dist.print0(f'Loading training state from "{resume_state_dump}"...') 122 | data = torch.load(resume_state_dump, map_location=torch.device('cpu')) 123 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) 124 | optimizer.load_state_dict(data['optimizer_state']) 125 | del data # conserve memory 126 | # Train. 127 | dist.print0(f'Training for {total_kimg} kimg...') 128 | dist.print0() 129 | cur_nimg = resume_kimg * 1000 130 | cur_tick = 0 131 | tick_start_nimg = cur_nimg 132 | tick_start_time = time.time() 133 | maintenance_time = tick_start_time - start_time 134 | dist.update_progress(cur_nimg // 1000, total_kimg) 135 | stats_jsonl = None 136 | while True: 137 | 138 | # Accumulate gradients. 139 | optimizer.zero_grad(set_to_none=True) 140 | for round_idx in range(num_accumulation_rounds): 141 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): 142 | images, labels = next(dataset_iterator) 143 | images = images.to(device).to(torch.float32) / 127.5 - 1 144 | labels = labels.to(device) 145 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) 146 | training_stats.report('Loss/loss', loss) 147 | loss.sum().mul(loss_scaling / batch_gpu_total).backward() 148 | 149 | # Update weights. 150 | for g in optimizer.param_groups: 151 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) 152 | for param in net.parameters(): 153 | if param.grad is not None: 154 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 155 | optimizer.step() 156 | 157 | # Update EMA.
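# EMA decay follows beta = 0.5 ** (batch_size / ema_halflife_nimg), so a weight's influence halves every ema_halflife_kimg thousand images. With the defaults above (batch_size = 512, ema_halflife_kimg = 500), beta = 0.5 ** (512 / 500000) ≈ 0.99929 once the ramp-up has passed.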
158 | ema_halflife_nimg = ema_halflife_kimg * 1000 159 | if ema_rampup_ratio is not None: 160 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) 161 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) 162 | for p_ema, p_net in zip(ema.parameters(), net.parameters()): 163 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) 164 | 165 | # Perform maintenance tasks once per tick. 166 | cur_nimg += batch_size 167 | done = (cur_nimg >= total_kimg * 1000) 168 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): 169 | continue 170 | 171 | # Print status line, accumulating the same information in training_stats. 172 | tick_end_time = time.time() 173 | fields = [] 174 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] 175 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] 176 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] 177 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] 178 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] 179 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] 180 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] 181 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] 182 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] 183 | fields += [f"loss {training_stats.report('Loss/loss', loss.mean().mul(loss_scaling).item()):<6.5f}"] 184 | 185 | torch.cuda.reset_peak_memory_stats() 186 | dist.print0(' '.join(fields)) 187 | 188 | # Check for abort. 189 | if (not done) and dist.should_stop(): 190 | done = True 191 | dist.print0() 192 | dist.print0('Aborting...') 193 | 194 | # Save network snapshot. 195 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): 196 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) 197 | for key, value in data.items(): 198 | if isinstance(value, torch.nn.Module): 199 | value = copy.deepcopy(value).eval().requires_grad_(False) 200 | misc.check_ddp_consistency(value) 201 | data[key] = value.cpu() 202 | del value # conserve memory 203 | if dist.get_rank() == 0: 204 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: 205 | pickle.dump(data, f) 206 | del data # conserve memory 207 | 208 | # Save full dump of the training state. 209 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: 210 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) 211 | 212 | # Update logs. 213 | training_stats.default_collector.update() 214 | dist.update_progress(cur_nimg // 1000, total_kimg) 215 | 216 | # Update state. 217 | cur_tick += 1 218 | tick_start_nimg = cur_nimg 219 | tick_start_time = time.time() 220 | maintenance_time = tick_start_time - tick_end_time 221 | if done: 222 | break 223 | 224 | # Done. 
225 | dist.print0() 226 | dist.print0('Exiting...') 227 | 228 | #---------------------------------------------------------------------------- 229 | -------------------------------------------------------------------------------- /figures/generalization-score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/generalization-score.png -------------------------------------------------------------------------------- /figures/jacobian-MoLRG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/jacobian-MoLRG.png -------------------------------------------------------------------------------- /figures/jacobian-real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/jacobian-real.png -------------------------------------------------------------------------------- /figures/optimal-denoiser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/optimal-denoiser.png -------------------------------------------------------------------------------- /figures/reproducibility-score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/reproducibility-score.png -------------------------------------------------------------------------------- /figures/similarity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/similarity.png --------------------------------------------------------------------------------