├── .gitignore
├── LICENSE
├── README.md
├── edm
│   ├── Dockerfile
│   ├── LICENSE.txt
│   ├── README.md
│   ├── dataset_tool.py
│   ├── dnnlib
│   │   ├── __init__.py
│   │   └── util.py
│   ├── docs
│   │   ├── afhqv2-64x64.png
│   │   ├── cifar10-32x32.png
│   │   ├── dataset-tool-help.txt
│   │   ├── ffhq-64x64.png
│   │   ├── fid-help.txt
│   │   ├── generate-help.txt
│   │   ├── imagenet-64x64.png
│   │   ├── teaser-1280x640.jpg
│   │   ├── teaser-1920x640.jpg
│   │   ├── teaser-640x480.jpg
│   │   └── train-help.txt
│   ├── environment.yml
│   ├── example.py
│   ├── fid.py
│   ├── generate.py
│   ├── jacobian.py
│   ├── sscd.py
│   ├── torch_utils
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── misc.py
│   │   ├── persistence.py
│   │   └── training_stats.py
│   ├── train.py
│   ├── trainMoLRG.py
│   └── training
│       ├── __init__.py
│       ├── augment.py
│       ├── dataset.py
│       ├── loss.py
│       ├── networks.py
│       ├── training_loop.py
│       └── training_loop_MoLRG.py
└── figures
    ├── generalization-score.png
    ├── jacobian-MoLRG.png
    ├── jacobian-real.png
    ├── optimal-denoiser.png
    ├── reproducibility-score.png
    └── similarity.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 huijieZH
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Understanding Generalizability of Diffusion Models through Low-dimensional Distribution Learning
2 |
3 | This is an official implementation of the papers:
4 | 1. [The Emergence of Reproducibility and Consistency in Diffusion Models](https://arxiv.org/abs/2310.05264) **NeurIPS 2023 workshop Best Paper, ICML 2024**
5 | 2. [Diffusion Models Learn Low-Dimensional Distributions via Subspace Clustering](https://arxiv.org/abs/2409.02426)
6 |
7 | The codebase mainly focuses on the implementation of four main figures from these two papers:
8 | 1. "Memorization" and "Generalization" regimes for unconditional diffusion models. (Figure 2 in [Paper 1](https://arxiv.org/abs/2310.05264))
9 | 2. Convergence of the optimal denoiser. (Figure 4 Left in [Paper 1](https://arxiv.org/abs/2310.05264))
10 | 3. Similarity among different unconditional diffusion model settings in generalization regime. (Figure 6 and Figure 12 in [Paper 1](https://arxiv.org/abs/2310.05264))
11 | 4. Low-rank property of the denoising autoencoder of trained diffusion models. (Figure 3 in [Paper 2](https://arxiv.org/abs/2409.02426))
12 |
13 | For the implementation of Figure 1 (correspondence between the singular vectors of the Jacobian of the DAE and semantic
14 | image attributes) in [Paper 2](https://arxiv.org/abs/2409.02426), please refer to our concurrent work [Exploring Low-Dimensional Subspaces in Diffusion Models for Controllable Image Editing](https://arxiv.org/abs/2409.02374); its codebase can be found [here](https://github.com/ChicyChen/LOCO-Edit).
15 |
16 | ### Requirements
17 |
18 | ```bash
19 | conda env create -f edm/environment.yml -n generalizability
20 | conda activate generalizability
21 | ```
22 |
23 | ## "Memorization" and "Generalization" regimes for unconditional diffusion models.
24 |
25 |
26 | ![Reproducibility score](figures/reproducibility-score.png)
27 | ![Generalization score](figures/generalization-score.png)
28 |
29 |
30 | Slightly different from Figure 2 in [Paper 1](https://arxiv.org/abs/2310.05264), the released code uses a fine-tuning setting: the training dataset is generated by a pre-trained diffusion model (the teacher model).
31 |
32 | ### Create Dataset
33 |
34 | Create a dataset of a specific size as follows:
35 |
36 | ```bash
37 | # generate images from teacher model
38 | python edm/generate.py --outdir=out --seeds=0-49999 --batch=64 --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl
39 |
40 | # create datasets of different sizes
41 | python edm/dataset_tool.py --source=out --max-images=128 --dest=datasets/synthetic-cifar10-32x32-n128.zip
42 | ```
43 |
44 | ### Training
45 |
46 | ```bash
47 | torchrun --standalone --nproc_per_node=1 edm/train.py --outdir=training --data=datasets/synthetic-cifar10-32x32-n128.zip --cond=0 --arch=ddpmpp --duration 50 --batch 128 --snap 500 --dump 500 --precond vp --model_channels 64
48 | ```
49 |
50 | ### Evaluation
51 |
52 | All released checkpoints can be found [here](https://www.dropbox.com/scl/fo/m8tf61cengcp1qyevwiwv/AKoLuvIY5Fx0Tz1g8eRFWoI?rlkey=x7t1iqunpzofddgv533bx48q8&st=wtfeg1a9&dl=0), and all released training datasets can be found [here](https://www.dropbox.com/scl/fo/fqwgl5pvqe4jgvuw945k6/AHEH9P8AYVYhMx_ABTclVC4?rlkey=frsgki669ny9lmxiwlpafanhg&st=kbgrn9za&dl=0).
53 |
54 | ```bash
55 |
56 | ### Generate images from the trained diffusion models; these seeds differ from the ones (0-49999) used to generate the training images from the teacher model.
57 | python edm/generate.py --outdir=evaluation/ddpm-dim64-n64 --seeds=100000-109999 --batch=64 --network=training/ckpt/ddpm-dim64-n64.pkl
58 |
59 | python edm/generate.py --outdir=evaluation/ddpm-dim128-n64 --seeds=100000-109999 --batch=64 --network=training/ckpt/ddpm-dim128-n64.pkl
60 |
61 | ### Calculate SSCD feature
62 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim64-n64 --features ./evaluation/sscd-dim64-n64.npz
63 |
64 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim128-n64 --features ./evaluation/sscd-dim128-n64.npz
65 |
66 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images datasets/synthetic-cifar10-32x32-n64.zip --features ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n64.npz
67 |
68 | # Compute reproducibility score
69 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-dim64-n64.npz
70 |
71 | # Compute generalization score
72 | python edm/sscd.py mscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n64.npz
73 |
74 | ```
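
Both scores are simple functions of the saved SSCD features. A rough sketch (illustrative only; `edm/sscd.py` is the authoritative implementation, and the `features` key name and the 0.6 similarity threshold are assumptions here): the RP score is the fraction of seed-matched pairs whose cosine similarity exceeds the threshold, and the generalization score is the fraction of generated images with no close match in the training set.

```python
import numpy as np

def rp_score(src_npz, tgt_npz, threshold=0.6):
    a = np.load(src_npz)["features"]      # [N, D] L2-normalized features, one row per seed
    b = np.load(tgt_npz)["features"]      # [N, D], same seed order
    sim = (a * b).sum(axis=1)             # cosine similarity of matched pairs
    return (sim > threshold).mean()       # fraction of reproduced samples

def generalization_score(gen_npz, train_npz, threshold=0.6):
    g = np.load(gen_npz)["features"]      # features of generated images
    t = np.load(train_npz)["features"]    # features of the training set
    nearest = (g @ t.T).max(axis=1)       # best match in the training set
    return (nearest <= threshold).mean()  # fraction not memorized
```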
75 |
76 | ## Convergence of the optimal denoiser.
77 |
78 |
79 | ![Optimal denoiser](figures/optimal-denoiser.png)
80 |
81 |
82 | We implement the optimal denoiser (derived in closed form from the score function of the empirical distribution) and compare the RP score between a trained diffusion model and the optimal denoiser.
83 |
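For reference, the optimal denoiser has a closed form: for a training set $\{y_i\}_{i=1}^N$ and noise level $\sigma$ (EDM convention $x_\sigma = x_0 + \sigma\epsilon$), the posterior mean is a softmax-weighted average of the training images,

$$\hat{x}^*(x_\sigma, \sigma) = \sum_{i=1}^{N} y_i \,\mathrm{softmax}_i\!\left(-\frac{\lVert x_\sigma - y_i \rVert^2}{2\sigma^2}\right).$$

A minimal sketch (illustrative; `generate.py --optimal_denoiser` is the authoritative implementation):

```python
import torch

def optimal_denoiser(x_sigma, sigma, train_data):
    """Posterior mean E[x_0 | x_sigma] under the empirical training distribution."""
    xt = x_sigma.flatten(1)                           # [B, D] noisy inputs
    y = train_data.flatten(1)                         # [N, D] training images
    d2 = torch.cdist(xt, y).pow(2)                    # squared distances [B, N]
    w = torch.softmax(-d2 / (2 * sigma ** 2), dim=1)  # posterior weights
    return (w @ y).view_as(x_sigma)
```
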
84 | ```bash
85 | ### Generate images from the optimal denoiser
86 | python edm/generate.py --outdir=evaluation/memorization-n64 --seeds=100000-109999 --batch=64 --optimal_denoiser --dataset=datasets/synthetic-cifar10-32x32-n64.zip --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl
87 |
88 | ### Calculate SSCD feature
89 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/memorization-n64 --features ./evaluation/sscd-memorization-n64.npz
90 |
91 | ### Compute reproducibility score
92 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n64.npz --target ./evaluation/sscd-memorization-n64.npz
93 |
94 | ```
95 |
96 | ## Similarity among different unconditional diffusion model settings in generalization regime.
97 |
98 |
99 | ![Similarity across settings](figures/similarity.png)
100 |
101 |
102 | We provide generated samples from these different diffusion models [here](https://www.dropbox.com/scl/fo/xq0yvr92ohzb6ov313928/ANo8GzZ5GybCrzJRb2P1qU8?rlkey=iaf6316aezz4wznigj4ir2v29&st=psa7il3e&dl=0). To generate new samples, follow each model's own GitHub repo and use the same initial noise for generation.
103 |
104 | ```bash
105 |
106 | ### Calculate SSCD feature
107 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./samples/ddpmv4 --features ./evaluation/sscd-ddpmv4.npz
108 |
109 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./samples/ddpmv6 --features ./evaluation/sscd-ddpmv6.npz
110 |
111 |
112 | # Compute reproducibility score
113 | python edm/sscd.py rpscore --source ./evaluation/sscd-ddpmv4.npz --target ./evaluation/sscd-ddpmv6.npz
114 |
115 |
116 | ```
117 |
118 | ## Low-rank property of the denoising autoencoder of trained diffusion models.
119 |
120 |
121 | ![Jacobian rank, real data](figures/jacobian-real.png)
122 | ![Jacobian rank, MoLRG](figures/jacobian-MoLRG.png)
123 |
124 | These figures illustrate the low-dimensionality of the Jacobian of the denoising autoencoder (DAE) trained on real datasets and on the mixture of low-rank Gaussians (MoLRG) distribution.
125 |
126 | To train a diffusion model on the MoLRG distribution:
127 | ```bash
128 | torchrun --standalone --nproc_per_node=1 edm/trainMoLRG.py --outdir training --path datasets --img_res 4 --class_num 2 --per_class_dim 7 --sample_per_class 350 --embed_channels 128
129 | ```
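
For intuition, MoLRG data matching the flags above (2 classes, rank 7 per class, 350 samples per class, 4x4 resolution) lies on a union of low-dimensional subspaces. An illustrative construction (not the exact `trainMoLRG.py` generator):

```python
import numpy as np

rng = np.random.default_rng(0)
n, K, d, m = 4 * 4, 2, 7, 350                         # ambient dim, classes, rank, samples per class
samples = []
for _ in range(K):
    U, _ = np.linalg.qr(rng.standard_normal((n, d)))  # orthonormal basis of a rank-d subspace
    A = rng.standard_normal((m, d))                   # Gaussian coefficients within the subspace
    samples.append(A @ U.T)                           # [m, n] samples with rank-d covariance
X = np.concatenate(samples).reshape(-1, 4, 4)         # [K*m, 4, 4] MoLRG "images"
```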
130 |
131 | To evaluate the rank of the Jacobian:
132 | ```bash
133 | # Usage: torchrun --standalone --nproc_per_node=1 edm/jacobian.py --network_pkl <PATH|URL>
134 |
135 | # e.g.
136 | torchrun --standalone --nproc_per_node=1 edm/jacobian.py --network_pkl https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl
137 | ```
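
Conceptually, `jacobian.py` measures the numerical rank of the DAE Jacobian $\partial D(x, \sigma)/\partial x$. A rough sketch of that measurement (here `net` is a hypothetical callable wrapping a loaded denoiser and `x` a single input of shape `[C, H, W]`):

```python
import torch

def jacobian_rank(net, x, sigma, tol=1e-2):
    J = torch.autograd.functional.jacobian(lambda v: net(v, sigma), x)
    J = J.reshape(x.numel(), x.numel())   # flatten to [D, D] with D = C*H*W
    s = torch.linalg.svdvals(J)           # singular values, descending
    return int((s > tol * s[0]).sum())    # count above a relative threshold
```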
138 |
139 | Notably, since our codebase is built upon [NVlabs/edm](https://github.com/NVlabs/edm), it is compatible with all training checkpoints released in their repo, which can be found [here](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/) and [here](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/).
140 |
141 | ## Acknowledgements
142 | This repository is largely based on [NVlabs/edm](https://github.com/NVlabs/edm).
143 |
144 | ## BibTeX
145 | ```
146 | @inproceedings{
147 | zhang2024the,
148 | title={The Emergence of Reproducibility and Consistency in Diffusion Models},
149 | author={Huijie Zhang and Jinfan Zhou and Yifu Lu and Minzhe Guo and Peng Wang and Liyue Shen and Qing Qu},
150 | booktitle={Forty-first International Conference on Machine Learning},
151 | year={2024},
152 | url={https://openreview.net/forum?id=HsliOqZkc0}
153 | }
154 |
155 | @article{wang2024diffusion,
156 | title={Diffusion models learn low-dimensional distributions via subspace clustering},
157 | author={Wang, Peng and Zhang, Huijie and Zhang, Zekai and Chen, Siyi and Ma, Yi and Qu, Qing},
158 | journal={arXiv preprint arXiv:2409.02426},
159 | year={2024}
160 | }
161 | ```
162 |
--------------------------------------------------------------------------------
/edm/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | FROM nvcr.io/nvidia/pytorch:22.10-py3
9 |
10 | ENV PYTHONDONTWRITEBYTECODE 1
11 | ENV PYTHONUNBUFFERED 1
12 |
13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
14 |
15 | WORKDIR /workspace
16 |
17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
18 | ENTRYPOINT ["/entry.sh"]
19 |
--------------------------------------------------------------------------------
/edm/README.md:
--------------------------------------------------------------------------------
1 | ## Elucidating the Design Space of Diffusion-Based Generative Models (EDM)<br>Official PyTorch implementation of the NeurIPS 2022 paper
2 |
3 | ![Teaser image](./docs/teaser-1920x640.jpg)
4 |
5 | **Elucidating the Design Space of Diffusion-Based Generative Models**<br>
6 | Tero Karras, Miika Aittala, Timo Aila, Samuli Laine<br>
7 | https://arxiv.org/abs/2206.00364
8 |
9 | Abstract: *We argue that the theory and practice of diffusion-based generative models are currently unnecessarily convoluted and seek to remedy the situation by presenting a design space that clearly separates the concrete design choices. This lets us identify several changes to both the sampling and training processes, as well as preconditioning of the score networks. Together, our improvements yield new state-of-the-art FID of 1.79 for CIFAR-10 in a class-conditional setting and 1.97 in an unconditional setting, with much faster sampling (35 network evaluations per image) than prior designs. To further demonstrate their modular nature, we show that our design changes dramatically improve both the efficiency and quality obtainable with pre-trained score networks from previous work, including improving the FID of a previously trained ImageNet-64 model from 2.07 to near-SOTA 1.55, and after re-training with our proposed improvements to a new SOTA of 1.36.*
10 |
11 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
12 |
13 | ## Requirements
14 |
15 | * Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
16 | * 1+ high-end NVIDIA GPU for sampling and 8+ GPUs for training. We have done all testing and development using V100 and A100 GPUs.
17 | * 64-bit Python 3.8 and PyTorch 1.12.0 (or later). See https://pytorch.org for PyTorch install instructions.
18 | * Python libraries: See [environment.yml](./environment.yml) for exact library dependencies. You can use the following commands with Miniconda3 to create and activate your Python environment:
19 | - `conda env create -f environment.yml -n edm`
20 | - `conda activate edm`
21 | * Docker users:
22 | - Ensure you have correctly installed the [NVIDIA container runtime](https://docs.docker.com/config/containers/resource_constraints/#gpu).
23 | - Use the [provided Dockerfile](./Dockerfile) to build an image with the required library dependencies.
24 |
25 | ## Getting started
26 |
27 | To reproduce the main results from our paper, simply run:
28 |
29 | ```.bash
30 | python example.py
31 | ```
32 |
33 | This is a minimal standalone script that loads the best pre-trained model for each dataset and generates a random 8x8 grid of images using the optimal sampler settings. Expected results:
34 |
35 | | Dataset | Runtime | Reference image
36 | | :------- | :------ | :--------------
37 | | CIFAR-10 | ~6 sec | [`cifar10-32x32.png`](./docs/cifar10-32x32.png)
38 | | FFHQ | ~28 sec | [`ffhq-64x64.png`](./docs/ffhq-64x64.png)
39 | | AFHQv2 | ~28 sec | [`afhqv2-64x64.png`](./docs/afhqv2-64x64.png)
40 | | ImageNet | ~5 min | [`imagenet-64x64.png`](./docs/imagenet-64x64.png)
41 |
42 | The easiest way to explore different sampling strategies is to modify [`example.py`](./example.py) directly. You can also incorporate the pre-trained models and/or our proposed EDM sampler in your own code by simply copy-pasting the relevant bits. Note that the class definitions for the pre-trained models are stored within the pickles themselves and loaded automatically during unpickling via [`torch_utils.persistence`](./torch_utils/persistence.py). To use the models in external Python scripts, just make sure that `torch_utils` and `dnnlib` are accessible through `PYTHONPATH`.
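
For example, loading a pre-trained pickle in an external script follows the pattern used by `example.py` and `generate.py`, where the `'ema'` entry holds the exponential-moving-average weights:

```.python
import pickle
import torch
import dnnlib  # torch_utils must also be importable for unpickling

network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl'
with dnnlib.util.open_url(network_pkl) as f:
    net = pickle.load(f)['ema'].to(torch.device('cuda'))
```
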
43 |
44 | **Docker**: You can run the example script using Docker as follows:
45 |
46 | ```.bash
47 | # Build the edm:latest image
48 | docker build --tag edm:latest .
49 |
50 | # Run the generate.py script using Docker:
51 | docker run --gpus all -it --rm --user $(id -u):$(id -g) \
52 | -v `pwd`:/scratch --workdir /scratch -e HOME=/scratch \
53 | edm:latest \
54 | python example.py
55 | ```
56 |
57 | Note: The Docker image requires NVIDIA driver release `r520` or later.
58 |
59 | The `docker run` invocation may look daunting, so let's unpack its contents here:
60 |
61 | - `--gpus all -it --rm --user $(id -u):$(id -g)`: with all GPUs enabled, run an interactive session with current user's UID/GID to avoid Docker writing files as root.
62 | - ``-v `pwd`:/scratch --workdir /scratch``: mount current running dir (e.g., the top of this git repo on your host machine) to `/scratch` in the container and use that as the current working dir.
63 | - `-e HOME=/scratch`: specify where to cache temporary files. Note: if you want more fine-grained control, you can instead set `DNNLIB_CACHE_DIR` (for pre-trained model download cache). You want these cache dirs to reside on persistent volumes so that their contents are retained across multiple `docker run` invocations.
64 |
65 | ## Pre-trained models
66 |
67 | We provide pre-trained models for our proposed training configuration (config F) as well as the baseline configuration (config A):
68 |
69 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/)
70 | - [https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/](https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/)
71 |
72 | To generate a batch of images using a given model and sampler, run:
73 |
74 | ```.bash
75 | # Generate 64 images and save them as out/*.png
76 | python generate.py --outdir=out --seeds=0-63 --batch=64 \
77 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
78 | ```
79 |
80 | Generating a large number of images can be time-consuming; the workload can be distributed across multiple GPUs by launching the above command using `torchrun`:
81 |
82 | ```.bash
83 | # Generate 1024 images using 2 GPUs
84 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \
85 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
86 | ```
87 |
88 | The sampler settings can be controlled through command-line options; see [`python generate.py --help`](./docs/generate-help.txt) for more information. For best results, we recommend using the following settings for each dataset:
89 |
90 | ```.bash
91 | # For CIFAR-10 at 32x32, use deterministic sampling with 18 steps (NFE = 35)
92 | python generate.py --outdir=out --steps=18 \
93 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
94 |
95 | # For FFHQ and AFHQv2 at 64x64, use deterministic sampling with 40 steps (NFE = 79)
96 | python generate.py --outdir=out --steps=40 \
97 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-ffhq-64x64-uncond-vp.pkl
98 |
99 | # For ImageNet at 64x64, use stochastic sampling with 256 steps (NFE = 511)
100 | python generate.py --outdir=out --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 \
101 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl
102 | ```
103 |
104 | Besides our proposed EDM sampler, `generate.py` can also be used to reproduce the sampler ablations from Section 3 of our paper. For example:
105 |
106 | ```.bash
107 | # Figure 2a, "Our reimplementation"
108 | python generate.py --outdir=out --steps=512 --solver=euler --disc=vp --schedule=vp --scaling=vp \
109 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
110 |
111 | # Figure 2a, "+ Heun & our {t_i}"
112 | python generate.py --outdir=out --steps=128 --solver=heun --disc=edm --schedule=vp --scaling=vp \
113 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
114 |
115 | # Figure 2a, "+ Our sigma(t) & s(t)"
116 | python generate.py --outdir=out --steps=18 --solver=heun --disc=edm --schedule=linear --scaling=none \
117 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/baseline/baseline-cifar10-32x32-uncond-vp.pkl
118 | ```
119 |
120 | ## Calculating FID
121 |
122 | To compute Fréchet inception distance (FID) for a given model and sampler, first generate 50,000 random images and then compare them against the dataset reference statistics using `fid.py`:
123 |
124 | ```.bash
125 | # Generate 50000 images and save them as fid-tmp/*/*.png
126 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \
127 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
128 |
129 | # Calculate FID
130 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \
131 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
132 | ```
133 |
134 | Both of the above commands can be parallelized across multiple GPUs by adjusting `--nproc_per_node`. The second command typically takes 1-3 minutes in practice, but the first one can sometimes take several hours, depending on the configuration. See [`python fid.py --help`](./docs/fid-help.txt) for the full list of options.
135 |
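For reference, `fid.py` reduces each image set to the mean and covariance of Inception features and evaluates the Fréchet distance between the two Gaussian fits; a sketch of the standard formula:

```.python
import numpy as np
import scipy.linalg

def frechet_distance(mu, sigma, mu_ref, sigma_ref):
    m = np.square(mu - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)  # matrix square root
    return float(np.real(m + np.trace(sigma + sigma_ref - s * 2)))
```
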
136 | Note that the numerical value of FID varies across different random seeds and is highly sensitive to the number of images. By default, `fid.py` will always use 50,000 generated images; providing fewer images will result in an error, whereas providing more will use a random subset. To reduce the effect of random variation, we recommend repeating the calculation multiple times with different seeds, e.g., `--seeds=0-49999`, `--seeds=50000-99999`, and `--seeds=100000-149999`. In our paper, we calculated each FID three times and reported the minimum.
137 |
138 | Also note that it is important to compare the generated images against the same dataset that the model was originally trained with. To facilitate evaluation, we provide the exact reference statistics that correspond to our pre-trained models:
139 |
140 | * [https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/](https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/)
141 |
142 | For ImageNet, we provide two sets of reference statistics to enable apples-to-apples comparison: `imagenet-64x64.npz` should be used when evaluating the EDM model (`edm-imagenet-64x64-cond-adm.pkl`), whereas `imagenet-64x64-baseline.npz` should be used when evaluating the baseline model (`baseline-imagenet-64x64-cond-adm.pkl`); the latter was originally trained by Dhariwal and Nichol using slightly different training data.
143 |
144 | You can compute the reference statistics for your own datasets as follows:
145 |
146 | ```.bash
147 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
148 | ```
149 |
150 | ## Preparing datasets
151 |
152 | Datasets are stored in the same format as in [StyleGAN](https://github.com/NVlabs/stylegan3): uncompressed ZIP archives containing uncompressed PNG files and a metadata file `dataset.json` for labels. Custom datasets can be created from a folder containing images; see [`python dataset_tool.py --help`](./docs/dataset-tool-help.txt) for more information.
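
For example, the labels of a prepared archive can be inspected directly (a minimal sketch of the format described above):

```.python
import json
import zipfile

with zipfile.ZipFile('datasets/cifar10-32x32.zip') as z:
    labels = json.loads(z.read('dataset.json'))['labels']  # [[filename, class], ...]
    print(len(labels), labels[0])
```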
153 |
154 | **CIFAR-10:** Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert to ZIP archive:
155 |
156 | ```.bash
157 | python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
158 | --dest=datasets/cifar10-32x32.zip
159 | python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz
160 | ```
161 |
162 | **FFHQ:** Download the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset) as 1024x1024 images and convert to ZIP archive at 64x64 resolution:
163 |
164 | ```.bash
165 | python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
166 | --dest=datasets/ffhq-64x64.zip --resolution=64x64
167 | python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz
168 | ```
169 |
170 | **AFHQv2:** Download the updated [Animal Faces-HQ dataset](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) (`afhq-v2-dataset`) and convert to ZIP archive at 64x64 resolution:
171 |
172 | ```.bash
173 | python dataset_tool.py --source=downloads/afhqv2 \
174 | --dest=datasets/afhqv2-64x64.zip --resolution=64x64
175 | python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz
176 | ```
177 |
178 | **ImageNet:** Download the [ImageNet Object Localization Challenge](https://www.kaggle.com/competitions/imagenet-object-localization-challenge/data) and convert to ZIP archive at 64x64 resolution:
179 |
180 | ```.bash
181 | python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
182 | --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
183 | python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz
184 | ```
185 |
186 | ## Training new models
187 |
188 | You can train new models using `train.py`. For example:
189 |
190 | ```.bash
191 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
192 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \
193 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
194 | ```
195 |
196 | The above example uses the default batch size of 512 images (controlled by `--batch`) that is divided evenly among 8 GPUs (controlled by `--nproc_per_node`) to yield 64 images per GPU. Training large models may run out of GPU memory; the best way to avoid this is to limit the per-GPU batch size, e.g., `--batch-gpu=32`. This employs gradient accumulation to yield the same results as using full per-GPU batches. See [`python train.py --help`](./docs/train-help.txt) for the full list of options.
197 |
198 | The results of each training run are saved to a newly created directory, for example `training-runs/00000-cifar10-cond-ddpmpp-edm-gpus8-batch64-fp32`. The training loop exports network snapshots (`network-snapshot-*.pkl`) and training states (`training-state-*.pt`) at regular intervals (controlled by `--snap` and `--dump`). The network snapshots can be used to generate images with `generate.py`, and the training states can be used to resume the training later on (`--resume`). Other useful information is recorded in `log.txt` and `stats.jsonl`. To monitor training convergence, we recommend looking at the training loss (`"Loss/loss"` in `stats.jsonl`) as well as periodically evaluating FID for `network-snapshot-*.pkl` using `generate.py` and `fid.py`.
199 |
200 | The following table lists the exact training configurations that we used to obtain our pre-trained models:
201 |
202 | | Model | GPUs | Time | Options
203 | | :-- | :-- | :-- | :--
204 | | cifar10‑32x32‑cond‑vp | 8xV100 | ~2 days | `--cond=1 --arch=ddpmpp`
205 | | cifar10‑32x32‑cond‑ve | 8xV100 | ~2 days | `--cond=1 --arch=ncsnpp`
206 | | cifar10‑32x32‑uncond‑vp | 8xV100 | ~2 days | `--cond=0 --arch=ddpmpp`
207 | | cifar10‑32x32‑uncond‑ve | 8xV100 | ~2 days | `--cond=0 --arch=ncsnpp`
208 | | ffhq‑64x64‑uncond‑vp | 8xV100 | ~4 days | `--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`
209 | | ffhq‑64x64‑uncond‑ve | 8xV100 | ~4 days | `--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15`
210 | | afhqv2‑64x64‑uncond‑vp | 8xV100 | ~4 days | `--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`
211 | | afhqv2‑64x64‑uncond‑ve | 8xV100 | ~4 days | `--cond=0 --arch=ncsnpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.25 --augment=0.15`
212 | | imagenet‑64x64‑cond‑adm | 32xA100 | ~13 days | `--cond=1 --arch=adm --duration=2500 --batch=4096 --lr=1e-4 --ema=50 --dropout=0.10 --augment=0 --fp16=1 --ls=100 --tick=200`
213 |
214 | For ImageNet-64, we ran the training on four NVIDIA DGX A100 nodes, each containing 8 Ampere GPUs with 80 GB of memory. To reduce the GPU memory requirements, we recommend either training the model with more GPUs or limiting the per-GPU batch size with `--batch-gpu`. To set up multi-node training, please consult the [torchrun documentation](https://pytorch.org/docs/stable/elastic/run.html).
215 |
216 | ## License
217 |
218 | Copyright © 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
219 |
220 | All material, including source code and pre-trained models, is licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/).
221 |
222 | `baseline-cifar10-32x32-uncond-vp.pkl` and `baseline-cifar10-32x32-uncond-ve.pkl` are derived from the [pre-trained models](https://github.com/yang-song/score_sde_pytorch) by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. The models were originally shared under the [Apache 2.0 license](https://github.com/yang-song/score_sde_pytorch/blob/main/LICENSE).
223 |
224 | `baseline-imagenet-64x64-cond-adm.pkl` is derived from the [pre-trained model](https://github.com/openai/guided-diffusion) by Prafulla Dhariwal and Alex Nichol. The model was originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).
225 |
226 | `imagenet-64x64-baseline.npz` is derived from the [precomputed reference statistics](https://github.com/openai/guided-diffusion/tree/main/evaluations) by Prafulla Dhariwal and Alex Nichol. The statistics were
227 | originally shared under the [MIT license](https://github.com/openai/guided-diffusion/blob/main/LICENSE).
228 |
229 | ## Citation
230 |
231 | ```
232 | @inproceedings{Karras2022edm,
233 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
234 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
235 | booktitle = {Proc. NeurIPS},
236 | year = {2022}
237 | }
238 | ```
239 |
240 | ## Development
241 |
242 | This is a research reference implementation and is treated as a one-time code drop. As such, we do not accept outside code contributions in the form of pull requests.
243 |
244 | ## Acknowledgments
245 |
246 | We thank Jaakko Lehtinen, Ming-Yu Liu, Tuomas Kynkäänniemi, Axel Sauer, Arash Vahdat, and Janne Hellsten for discussions and comments, and Tero Kuosmanen, Samuel Klenberg, and Janne Hellsten for maintaining our compute infrastructure.
247 |
--------------------------------------------------------------------------------
/edm/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/edm/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
39 | class EasyDict(dict):
40 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41 |
42 | def __getattr__(self, name: str) -> Any:
43 | try:
44 | return self[name]
45 | except KeyError:
46 | raise AttributeError(name)
47 |
48 | def __setattr__(self, name: str, value: Any) -> None:
49 | self[name] = value
50 |
51 | def __delattr__(self, name: str) -> None:
52 | del self[name]
53 |
54 |
55 | class Logger(object):
56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57 |
58 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
59 | self.file = None
60 |
61 | if file_name is not None:
62 | self.file = open(file_name, file_mode)
63 |
64 | self.should_flush = should_flush
65 | self.stdout = sys.stdout
66 | self.stderr = sys.stderr
67 |
68 | sys.stdout = self
69 | sys.stderr = self
70 |
71 | def __enter__(self) -> "Logger":
72 | return self
73 |
74 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
75 | self.close()
76 |
77 | def write(self, text: Union[str, bytes]) -> None:
78 | """Write text to stdout (and a file) and optionally flush."""
79 | if isinstance(text, bytes):
80 | text = text.decode()
81 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
82 | return
83 |
84 | if self.file is not None:
85 | self.file.write(text)
86 |
87 | self.stdout.write(text)
88 |
89 | if self.should_flush:
90 | self.flush()
91 |
92 | def flush(self) -> None:
93 | """Flush written text to both stdout and a file, if open."""
94 | if self.file is not None:
95 | self.file.flush()
96 |
97 | self.stdout.flush()
98 |
99 | def close(self) -> None:
100 | """Flush, close possible files, and remove stdout/stderr mirroring."""
101 | self.flush()
102 |
103 | # if using multiple loggers, prevent closing in wrong order
104 | if sys.stdout is self:
105 | sys.stdout = self.stdout
106 | if sys.stderr is self:
107 | sys.stderr = self.stderr
108 |
109 | if self.file is not None:
110 | self.file.close()
111 | self.file = None
112 |
113 |
114 | # Cache directories
115 | # ------------------------------------------------------------------------------------------
116 |
117 | _dnnlib_cache_dir = None
118 |
119 | def set_cache_dir(path: str) -> None:
120 | global _dnnlib_cache_dir
121 | _dnnlib_cache_dir = path
122 |
123 | def make_cache_dir_path(*paths: str) -> str:
124 | if _dnnlib_cache_dir is not None:
125 | return os.path.join(_dnnlib_cache_dir, *paths)
126 | if 'DNNLIB_CACHE_DIR' in os.environ:
127 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
128 | if 'HOME' in os.environ:
129 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
130 | if 'USERPROFILE' in os.environ:
131 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
132 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
133 |
134 | # Small util functions
135 | # ------------------------------------------------------------------------------------------
136 |
137 |
138 | def format_time(seconds: Union[int, float]) -> str:
139 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
140 | s = int(np.rint(seconds))
141 |
142 | if s < 60:
143 | return "{0}s".format(s)
144 | elif s < 60 * 60:
145 | return "{0}m {1:02}s".format(s // 60, s % 60)
146 | elif s < 24 * 60 * 60:
147 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
148 | else:
149 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
150 |
151 |
152 | def format_time_brief(seconds: Union[int, float]) -> str:
153 | """Convert the seconds to human readable string with days, hours, minutes and seconds."""
154 | s = int(np.rint(seconds))
155 |
156 | if s < 60:
157 | return "{0}s".format(s)
158 | elif s < 60 * 60:
159 | return "{0}m {1:02}s".format(s // 60, s % 60)
160 | elif s < 24 * 60 * 60:
161 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
162 | else:
163 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
164 |
165 |
166 | def ask_yes_no(question: str) -> bool:
167 | """Ask the user the question until the user inputs a valid answer."""
168 | while True:
169 | try:
170 | print("{0} [y/n]".format(question))
171 | return strtobool(input().lower())
172 | except ValueError:
173 | pass
174 |
175 |
176 | def tuple_product(t: Tuple) -> Any:
177 | """Calculate the product of the tuple elements."""
178 | result = 1
179 |
180 | for v in t:
181 | result *= v
182 |
183 | return result
184 |
185 |
186 | _str_to_ctype = {
187 | "uint8": ctypes.c_ubyte,
188 | "uint16": ctypes.c_uint16,
189 | "uint32": ctypes.c_uint32,
190 | "uint64": ctypes.c_uint64,
191 | "int8": ctypes.c_byte,
192 | "int16": ctypes.c_int16,
193 | "int32": ctypes.c_int32,
194 | "int64": ctypes.c_int64,
195 | "float32": ctypes.c_float,
196 | "float64": ctypes.c_double
197 | }
198 |
199 |
200 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
201 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
202 | type_str = None
203 |
204 | if isinstance(type_obj, str):
205 | type_str = type_obj
206 | elif hasattr(type_obj, "__name__"):
207 | type_str = type_obj.__name__
208 | elif hasattr(type_obj, "name"):
209 | type_str = type_obj.name
210 | else:
211 | raise RuntimeError("Cannot infer type name from input")
212 |
213 | assert type_str in _str_to_ctype.keys()
214 |
215 | my_dtype = np.dtype(type_str)
216 | my_ctype = _str_to_ctype[type_str]
217 |
218 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
219 |
220 | return my_dtype, my_ctype
221 |
222 |
223 | def is_pickleable(obj: Any) -> bool:
224 | try:
225 | with io.BytesIO() as stream:
226 | pickle.dump(obj, stream)
227 | return True
228 | except:
229 | return False
230 |
231 |
232 | # Functionality to import modules/objects by name, and call functions by name
233 | # ------------------------------------------------------------------------------------------
234 |
235 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
236 | """Searches for the underlying module behind the name to some python object.
237 | Returns the module and the object name (original name with module part removed)."""
238 |
239 | # allow convenience shorthands, substitute them by full names
240 | obj_name = re.sub("^np.", "numpy.", obj_name)
241 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
242 |
243 | # list alternatives for (module_name, local_obj_name)
244 | parts = obj_name.split(".")
245 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
246 |
247 | # try each alternative in turn
248 | for module_name, local_obj_name in name_pairs:
249 | try:
250 | module = importlib.import_module(module_name) # may raise ImportError
251 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
252 | return module, local_obj_name
253 | except:
254 | pass
255 |
256 | # maybe some of the modules themselves contain errors?
257 | for module_name, _local_obj_name in name_pairs:
258 | try:
259 | importlib.import_module(module_name) # may raise ImportError
260 | except ImportError:
261 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
262 | raise
263 |
264 | # maybe the requested attribute is missing?
265 | for module_name, local_obj_name in name_pairs:
266 | try:
267 | module = importlib.import_module(module_name) # may raise ImportError
268 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
269 | except ImportError:
270 | pass
271 |
272 | # we are out of luck, but we have no idea why
273 | raise ImportError(obj_name)
274 |
275 |
276 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
277 | """Traverses the object name and returns the last (rightmost) python object."""
278 | if obj_name == '':
279 | return module
280 | obj = module
281 | for part in obj_name.split("."):
282 | obj = getattr(obj, part)
283 | return obj
284 |
285 |
286 | def get_obj_by_name(name: str) -> Any:
287 | """Finds the python object with the given name."""
288 | module, obj_name = get_module_from_obj_name(name)
289 | return get_obj_from_module(module, obj_name)
290 |
291 |
292 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
293 | """Finds the python object with the given name and calls it as a function."""
294 | assert func_name is not None
295 | func_obj = get_obj_by_name(func_name)
296 | assert callable(func_obj)
297 | return func_obj(*args, **kwargs)
298 |
299 |
300 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
301 | """Finds the python class with the given name and constructs it with the given arguments."""
302 | return call_func_by_name(*args, func_name=class_name, **kwargs)
303 |
304 |
305 | def get_module_dir_by_obj_name(obj_name: str) -> str:
306 | """Get the directory path of the module containing the given object name."""
307 | module, _ = get_module_from_obj_name(obj_name)
308 | return os.path.dirname(inspect.getfile(module))
309 |
310 |
311 | def is_top_level_function(obj: Any) -> bool:
312 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
313 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
314 |
315 |
316 | def get_top_level_function_name(obj: Any) -> str:
317 | """Return the fully-qualified name of a top-level function."""
318 | assert is_top_level_function(obj)
319 | module = obj.__module__
320 | if module == '__main__':
321 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
322 | return module + "." + obj.__name__
323 |
324 |
325 | # File system helpers
326 | # ------------------------------------------------------------------------------------------
327 |
328 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
329 | """List all files recursively in a given directory while ignoring given file and directory names.
330 | Returns list of tuples containing both absolute and relative paths."""
331 | assert os.path.isdir(dir_path)
332 | base_name = os.path.basename(os.path.normpath(dir_path))
333 |
334 | if ignores is None:
335 | ignores = []
336 |
337 | result = []
338 |
339 | for root, dirs, files in os.walk(dir_path, topdown=True):
340 | for ignore_ in ignores:
341 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
342 |
343 | # dirs need to be edited in-place
344 | for d in dirs_to_remove:
345 | dirs.remove(d)
346 |
347 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
348 |
349 | absolute_paths = [os.path.join(root, f) for f in files]
350 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
351 |
352 | if add_base_to_relative:
353 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
354 |
355 | assert len(absolute_paths) == len(relative_paths)
356 | result += zip(absolute_paths, relative_paths)
357 |
358 | return result
359 |
360 |
361 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
362 | """Takes in a list of tuples of (src, dst) paths and copies files.
363 | Will create all necessary directories."""
364 | for file in files:
365 | target_dir_name = os.path.dirname(file[1])
366 |
367 | # will create all intermediate-level directories
368 | if not os.path.exists(target_dir_name):
369 | os.makedirs(target_dir_name)
370 |
371 | shutil.copyfile(file[0], file[1])
372 |
373 |
374 | # URL helpers
375 | # ------------------------------------------------------------------------------------------
376 |
377 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
378 | """Determine whether the given object is a valid URL string."""
379 | if not isinstance(obj, str) or not "://" in obj:
380 | return False
381 | if allow_file_urls and obj.startswith('file://'):
382 | return True
383 | try:
384 | res = requests.compat.urlparse(obj)
385 | if not res.scheme or not res.netloc or not "." in res.netloc:
386 | return False
387 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
388 | if not res.scheme or not res.netloc or not "." in res.netloc:
389 | return False
390 | except:
391 | return False
392 | return True
393 |
394 |
395 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
396 | """Download the given URL and return a binary-mode file object to access the data."""
397 | assert num_attempts >= 1
398 | assert not (return_filename and (not cache))
399 |
400 | # Doesn't look like an URL scheme so interpret it as a local filename.
401 | if not re.match('^[a-z]+://', url):
402 | return url if return_filename else open(url, "rb")
403 |
404 | # Handle file URLs. This code handles unusual file:// patterns that
405 | # arise on Windows:
406 | #
407 | # file:///c:/foo.txt
408 | #
409 | # which would translate to a local '/c:/foo.txt' filename that's
410 | # invalid. Drop the forward slash for such pathnames.
411 | #
412 | # If you touch this code path, you should test it on both Linux and
413 | # Windows.
414 | #
415 |     # Some internet resources suggest using urllib.request.url2pathname(),
416 |     # but that converts forward slashes to backslashes and this causes
417 | # its own set of problems.
418 | if url.startswith('file://'):
419 | filename = urllib.parse.urlparse(url).path
420 | if re.match(r'^/[a-zA-Z]:', filename):
421 | filename = filename[1:]
422 | return filename if return_filename else open(filename, "rb")
423 |
424 | assert is_url(url)
425 |
426 | # Lookup from cache.
427 | if cache_dir is None:
428 | cache_dir = make_cache_dir_path('downloads')
429 |
430 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
431 | if cache:
432 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
433 | if len(cache_files) == 1:
434 | filename = cache_files[0]
435 | return filename if return_filename else open(filename, "rb")
436 |
437 | # Download.
438 | url_name = None
439 | url_data = None
440 | with requests.Session() as session:
441 | if verbose:
442 | print("Downloading %s ..." % url, end="", flush=True)
443 | for attempts_left in reversed(range(num_attempts)):
444 | try:
445 | with session.get(url) as res:
446 | res.raise_for_status()
447 | if len(res.content) == 0:
448 | raise IOError("No data received")
449 |
450 | if len(res.content) < 8192:
451 | content_str = res.content.decode("utf-8")
452 | if "download_warning" in res.headers.get("Set-Cookie", ""):
453 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
454 | if len(links) == 1:
455 | url = requests.compat.urljoin(url, links[0])
456 | raise IOError("Google Drive virus checker nag")
457 | if "Google Drive - Quota exceeded" in content_str:
458 | raise IOError("Google Drive download quota exceeded -- please try again later")
459 |
460 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
461 | url_name = match[1] if match else url
462 | url_data = res.content
463 | if verbose:
464 | print(" done")
465 | break
466 | except KeyboardInterrupt:
467 | raise
468 | except:
469 | if not attempts_left:
470 | if verbose:
471 | print(" failed")
472 | raise
473 | if verbose:
474 | print(".", end="", flush=True)
475 |
476 | # Save to cache.
477 | if cache:
478 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
479 | safe_name = safe_name[:min(len(safe_name), 128)]
480 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482 | os.makedirs(cache_dir, exist_ok=True)
483 | with open(temp_file, "wb") as f:
484 | f.write(url_data)
485 | os.replace(temp_file, cache_file) # atomic
486 | if return_filename:
487 | return cache_file
488 |
489 | # Return data as file object.
490 | assert not return_filename
491 | return io.BytesIO(url_data)
492 |
--------------------------------------------------------------------------------
/edm/docs/afhqv2-64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/afhqv2-64x64.png
--------------------------------------------------------------------------------
/edm/docs/cifar10-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/cifar10-32x32.png
--------------------------------------------------------------------------------
/edm/docs/dataset-tool-help.txt:
--------------------------------------------------------------------------------
1 | Usage: dataset_tool.py [OPTIONS]
2 |
3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
4 | PyTorch.
5 |
6 | The input dataset format is guessed from the --source argument:
7 |
8 | --source *_lmdb/ Load LSUN dataset
9 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset
10 | --source train-images-idx3-ubyte.gz Load MNIST dataset
11 | --source path/ Recursively load all images from path/
12 | --source dataset.zip Recursively load all images from dataset.zip
13 |
14 | Specifying the output format and path:
15 |
16 | --dest /path/to/dir Save output files under /path/to/dir
17 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
18 |
19 | The output dataset format can be either an image folder or an uncompressed
20 | zip archive. Zip archives make it easier to move datasets around file
21 | servers and clusters, and may offer better training performance on network
22 | file systems.
23 |
24 | Images within the dataset archive will be stored as uncompressed PNG.
25 | Uncompressed PNGs can be efficiently decoded in the training loop.
26 |
27 | Class labels are stored in a file called 'dataset.json' that is stored at
28 | the dataset root folder. This file has the following structure:
29 |
30 | {
31 | "labels": [
32 | ["00000/img00000000.png",6],
33 | ["00000/img00000001.png",9],
34 | ... repeated for every image in the dataset
35 | ["00049/img00049999.png",1]
36 | ]
37 | }
38 |
39 | If the 'dataset.json' file cannot be found, class labels are determined from
40 | top-level directory names.
41 |
42 | Image scale/crop and resolution requirements:
43 |
44 | Output images must be square-shaped and they must all have the same power-
45 | of-two dimensions.
46 |
47 | To scale arbitrary input image size to a specific width and height, use the
48 | --resolution option. Output resolution will be either the original input
49 | resolution (if resolution was not specified) or the one specified with
50 | --resolution option.
51 |
52 | Use the --transform=center-crop or --transform=center-crop-wide options to
53 | apply a center crop transform on the input image. These options should be
54 | used with the --resolution option. For example:
55 |
56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
57 | --transform=center-crop-wide --resolution=512x384
58 |
59 | Options:
60 | --source PATH Input directory or archive name [required]
61 | --dest PATH Output directory or archive name [required]
62 | --max-images INT Maximum number of images to output
63 | --transform MODE Input crop/resize mode
64 | --resolution WxH Output resolution (e.g., 512x512)
65 | --help Show this message and exit.
66 |
--------------------------------------------------------------------------------
/edm/docs/ffhq-64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/ffhq-64x64.png
--------------------------------------------------------------------------------
/edm/docs/fid-help.txt:
--------------------------------------------------------------------------------
1 | Usage: fid.py [OPTIONS] COMMAND [ARGS]...
2 |
3 | Calculate Frechet Inception Distance (FID).
4 |
5 | Examples:
6 |
7 | # Generate 50000 images and save them as fid-tmp/*/*.png
8 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \
9 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
10 |
11 | # Calculate FID
12 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \
13 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
14 |
15 | # Compute dataset reference statistics
16 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
17 |
18 | Options:
19 | --help Show this message and exit.
20 |
21 | Commands:
22 | calc Calculate FID for a given set of images.
23 | ref Calculate dataset reference statistics needed by 'calc'.
24 |
25 |
26 | Usage: fid.py calc [OPTIONS]
27 |
28 | Calculate FID for a given set of images.
29 |
30 | Options:
31 | --images PATH|ZIP Path to the images [required]
32 | --ref NPZ|URL Dataset reference statistics [required]
33 | --num INT Number of images to use [default: 50000; x>=2]
34 | --seed INT Random seed for selecting the images [default: 0]
35 | --batch INT Maximum batch size [default: 64; x>=1]
36 | --help Show this message and exit.
37 |
38 |
39 | Usage: fid.py ref [OPTIONS]
40 |
41 | Calculate dataset reference statistics needed by 'calc'.
42 |
43 | Options:
44 | --data PATH|ZIP Path to the dataset [required]
45 | --dest NPZ Destination .npz file [required]
46 | --batch INT Maximum batch size [default: 64; x>=1]
47 | --help Show this message and exit.
48 |
--------------------------------------------------------------------------------
/edm/docs/generate-help.txt:
--------------------------------------------------------------------------------
1 | Usage: generate.py [OPTIONS]
2 |
3 | Generate random images using the techniques described in the paper
4 | "Elucidating the Design Space of Diffusion-Based Generative Models".
5 |
6 | Examples:
7 |
8 | # Generate 64 images and save them as out/*.png
9 | python generate.py --outdir=out --seeds=0-63 --batch=64 \
10 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
11 |
12 | # Generate 1024 images using 2 GPUs
13 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \
14 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
15 |
16 | Options:
17 | --network PATH|URL Network pickle filename [required]
18 | --outdir DIR Where to save the output images [required]
19 | --seeds LIST Random seeds (e.g. 1,2,5-10) [default: 0-63]
20 | --subdirs Create subdirectory for every 1000 seeds
21 | --class INT Class label [default: random] [x>=0]
22 | --batch INT Maximum batch size [default: 64; x>=1]
23 | --steps INT Number of sampling steps [default: 18; x>=1]
24 | --sigma_min FLOAT Lowest noise level [default: varies] [x>0]
25 | --sigma_max FLOAT Highest noise level [default: varies] [x>0]
26 | --rho FLOAT Time step exponent [default: 7; x>0]
27 | --S_churn FLOAT Stochasticity strength [default: 0; x>=0]
28 | --S_min FLOAT Stoch. min noise level [default: 0; x>=0]
29 | --S_max FLOAT Stoch. max noise level [default: inf; x>=0]
30 | --S_noise FLOAT Stoch. noise inflation [default: 1]
31 | --solver euler|heun Ablate ODE solver
32 | --disc vp|ve|iddpm|edm Ablate time step discretization {t_i}
33 | --schedule vp|ve|linear Ablate noise schedule sigma(t)
34 | --scaling vp|none Ablate signal scaling s(t)
35 | --help Show this message and exit.
36 |
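37 | # Illustrative extra example: stochastic sampling with the ImageNet-64
38 | # settings used in example.py
39 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=out --seeds=0-63 --batch=64 \
40 | --steps=256 --S_churn=40 --S_min=0.05 --S_max=50 --S_noise=1.003 \
41 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl
42 |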
--------------------------------------------------------------------------------
/edm/docs/imagenet-64x64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/imagenet-64x64.png
--------------------------------------------------------------------------------
/edm/docs/teaser-1280x640.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-1280x640.jpg
--------------------------------------------------------------------------------
/edm/docs/teaser-1920x640.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-1920x640.jpg
--------------------------------------------------------------------------------
/edm/docs/teaser-640x480.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/edm/docs/teaser-640x480.jpg
--------------------------------------------------------------------------------
/edm/docs/train-help.txt:
--------------------------------------------------------------------------------
1 | Usage: train.py [OPTIONS]
2 |
3 | Train diffusion-based generative model using the techniques described in the
4 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".
5 |
6 | Examples:
7 |
8 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
9 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \
10 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
11 |
12 | Options:
13 | --outdir DIR Where to save the results [required]
14 | --data ZIP|DIR Path to the dataset [required]
15 | --cond BOOL Train class-conditional model [default: False]
16 | --arch ddpmpp|ncsnpp|adm Network architecture [default: ddpmpp]
17 | --precond vp|ve|edm Preconditioning & loss function [default: edm]
18 | --duration MIMG Training duration [default: 200; x>0]
19 | --batch INT Total batch size [default: 512; x>=1]
20 | --batch-gpu INT Limit batch size per GPU [x>=1]
21 | --cbase INT Channel multiplier [default: varies]
22 | --cres LIST Channels per resolution [default: varies]
23 | --lr FLOAT Learning rate [default: 0.001; x>0]
24 | --ema MIMG EMA half-life [default: 0.5; x>=0]
25 | --dropout FLOAT Dropout probability [default: 0.13; 0<=x<=1]
26 | --augment FLOAT Augment probability [default: 0.12; 0<=x<=1]
27 | --xflip BOOL Enable dataset x-flips [default: False]
28 | --fp16 BOOL Enable mixed-precision training [default: False]
29 | --ls FLOAT Loss scaling [default: 1; x>0]
30 | --bench BOOL Enable cuDNN benchmarking [default: True]
31 | --cache BOOL Cache dataset in CPU memory [default: True]
32 | --workers INT DataLoader worker processes [default: 1; x>=1]
33 | --desc STR String to include in result dir name
34 | --nosubdir Do not create a subdirectory for results
35 | --tick KIMG How often to print progress [default: 50; x>=1]
36 | --snap TICKS How often to save snapshots [default: 50; x>=1]
37 | --dump TICKS How often to dump state [default: 500; x>=1]
38 | --seed INT Random seed [default: random]
39 | --transfer PKL|URL Transfer learning from network pickle
40 | --resume PT Resume from previous training state
41 | -n, --dry-run Print training options and exit
42 | --help Show this message and exit.
43 |
--------------------------------------------------------------------------------
/edm/environment.yml:
--------------------------------------------------------------------------------
1 | name: edm
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 |   - python>=3.8,<3.10   # package build failures on 3.10
7 | - pip
8 | - numpy>=1.20
9 | - click>=8.0
10 | - pillow>=8.3.1
11 | - scipy>=1.7.1
12 | - pytorch=1.12.1
13 | - psutil
14 | - requests
15 | - tqdm
16 | - imageio
17 | - pip:
18 | - imageio-ffmpeg>=0.4.3
19 | - pyspng
20 | - wget
21 |
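22 | # Typical setup (standard conda workflow):
23 | #   conda env create -f environment.yml
24 | #   conda activate edm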
--------------------------------------------------------------------------------
/edm/example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Minimal standalone example to reproduce the main results from the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import tqdm
12 | import pickle
13 | import numpy as np
14 | import torch
15 | import PIL.Image
16 | import dnnlib
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | def generate_image_grid(
21 | network_pkl, dest_path,
22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'),
23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
25 | ):
26 | batch_size = gridw * gridh
27 | torch.manual_seed(seed)
28 |
29 | # Load network.
30 | print(f'Loading network from "{network_pkl}"...')
31 | with dnnlib.util.open_url(network_pkl) as f:
32 | net = pickle.load(f)['ema'].to(device)
33 |
34 | # Pick latents and labels.
35 | print(f'Generating {batch_size} images...')
36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
37 | class_labels = None
38 | if net.label_dim:
39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)]
40 |
41 | # Adjust noise levels based on what's supported by the network.
42 | sigma_min = max(sigma_min, net.sigma_min)
43 | sigma_max = min(sigma_max, net.sigma_max)
44 |
45 | # Time step discretization.
46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
49 |
50 | # Main sampling loop.
51 | x_next = latents.to(torch.float64) * t_steps[0]
52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1
53 | x_cur = x_next
54 |
55 | # Increase noise temporarily.
56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
57 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)
59 |
60 | # Euler step.
61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
62 | d_cur = (x_hat - denoised) / t_hat
63 | x_next = x_hat + (t_next - t_hat) * d_cur
64 |
65 | # Apply 2nd order correction.
66 | if i < num_steps - 1:
67 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
68 | d_prime = (x_next - denoised) / t_next
69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
70 |
71 | # Save image grid.
72 | print(f'Saving image grid to "{dest_path}"...')
73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels)
76 | image = image.cpu().numpy()
77 | PIL.Image.fromarray(image, 'RGB').save(dest_path)
78 | print('Done.')
79 |
80 | #----------------------------------------------------------------------------
81 |
82 | def main():
83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained'
84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35
85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79
86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79
87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511
88 |
89 | #----------------------------------------------------------------------------
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
94 | #----------------------------------------------------------------------------
95 |
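96 | #----------------------------------------------------------------------------
97 | # Illustrative sketch (not invoked by main): generate_image_grid() takes all
98 | # sampler hyperparameters as keyword arguments, so a smaller stochastic run
99 | # against the same CIFAR-10 checkpoint could look like this. The output
100 | # filename is hypothetical.
101 |
102 | def demo_small_grid():
103 |     generate_image_grid(
104 |         'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl',
105 |         'cifar10-4x4.png', gridw=4, gridh=4, num_steps=18, S_churn=10)
106 |
107 | #----------------------------------------------------------------------------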
--------------------------------------------------------------------------------
/edm/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Script for calculating Frechet Inception Distance (FID)."""
9 |
10 | import os
11 | import click
12 | import tqdm
13 | import pickle
14 | import numpy as np
15 | import scipy.linalg
16 | import torch
17 | import dnnlib
18 | from torch_utils import distributed as dist
19 | from training import dataset
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | def calculate_inception_stats(
24 | image_path, num_expected=None, seed=0, max_batch_size=64,
25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
26 | ):
27 | # Rank 0 goes first.
28 | if dist.get_rank() != 0:
29 | torch.distributed.barrier()
30 |
31 | # Load Inception-v3 model.
32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
33 | dist.print0('Loading Inception-v3 model...')
34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
35 | detector_kwargs = dict(return_features=True)
36 | feature_dim = 2048
37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
38 | detector_net = pickle.load(f).to(device)
39 |
40 | # List images.
41 | dist.print0(f'Loading images from "{image_path}"...')
42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
43 | if num_expected is not None and len(dataset_obj) < num_expected:
44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
45 | if len(dataset_obj) < 2:
46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
47 |
48 | # Other ranks follow.
49 | if dist.get_rank() == 0:
50 | torch.distributed.barrier()
51 |
52 | # Divide images into batches.
53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
57 |
58 | # Accumulate statistics.
59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...')
60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
63 | torch.distributed.barrier()
64 | if images.shape[0] == 0:
65 | continue
66 | if images.shape[1] == 1:
67 | images = images.repeat([1, 3, 1, 1])
68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
69 | mu += features.sum(0)
70 | sigma += features.T @ features
71 |
72 | # Calculate grand totals.
73 | torch.distributed.all_reduce(mu)
74 | torch.distributed.all_reduce(sigma)
75 | mu /= len(dataset_obj)
76 | sigma -= mu.ger(mu) * len(dataset_obj)
77 | sigma /= len(dataset_obj) - 1
78 | return mu.cpu().numpy(), sigma.cpu().numpy()
79 |
80 | #----------------------------------------------------------------------------
81 |
82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
83 | m = np.square(mu - mu_ref).sum()
84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
85 | fid = m + np.trace(sigma + sigma_ref - s * 2)
86 | return float(np.real(fid))
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | @click.group()
91 | def main():
92 | """Calculate Frechet Inception Distance (FID).
93 |
94 | Examples:
95 |
96 | \b
97 | # Generate 50000 images and save them as fid-tmp/*/*.png
98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
100 |
101 | \b
102 | # Calculate FID
103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
105 |
106 | \b
107 | # Compute dataset reference statistics
108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
109 | """
110 |
111 | #----------------------------------------------------------------------------
112 |
113 | @main.command()
114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics', metavar='NPZ|URL', type=str, required=True)
116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
119 |
120 | def calc(image_path, ref_path, num_expected, seed, batch):
121 | """Calculate FID for a given set of images."""
122 | torch.multiprocessing.set_start_method('spawn')
123 | dist.init()
124 |
125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
126 | ref = None
127 | if dist.get_rank() == 0:
128 | with dnnlib.util.open_url(ref_path) as f:
129 | ref = dict(np.load(f))
130 |
131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
132 | dist.print0('Calculating FID...')
133 | if dist.get_rank() == 0:
134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
135 | print(f'{fid:g}')
136 | torch.distributed.barrier()
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | @main.command()
141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
144 |
145 | def ref(dataset_path, dest_path, batch):
146 | """Calculate dataset reference statistics needed by 'calc'."""
147 | torch.multiprocessing.set_start_method('spawn')
148 | dist.init()
149 |
150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
152 | if dist.get_rank() == 0:
153 | if os.path.dirname(dest_path):
154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True)
155 | np.savez(dest_path, mu=mu, sigma=sigma)
156 |
157 | torch.distributed.barrier()
158 | dist.print0('Done.')
159 |
160 | #----------------------------------------------------------------------------
161 |
162 | if __name__ == "__main__":
163 | main()
164 |
165 | #----------------------------------------------------------------------------
166 |
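167 | #----------------------------------------------------------------------------
168 | # Illustrative sketch (not wired into the CLI): combining two statistics
169 | # files with calculate_fid_from_inception_stats(). Assumes 'gen.npz' and
170 | # 'ref.npz' were written via np.savez(..., mu=..., sigma=...) as in ref().
171 |
172 | def fid_from_npz(gen_path, ref_path):
173 |     gen, ref = dict(np.load(gen_path)), dict(np.load(ref_path))
174 |     return calculate_fid_from_inception_stats(gen['mu'], gen['sigma'], ref['mu'], ref['sigma'])
175 |
176 | #----------------------------------------------------------------------------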
--------------------------------------------------------------------------------
/edm/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Generate random images using the techniques described in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import click
14 | import tqdm
15 | import pickle
16 | import numpy as np
17 | import torch
18 | import PIL.Image
19 | import dnnlib
20 | from torch_utils import distributed as dist
21 |
22 | #----------------------------------------------------------------------------
23 | # Proposed EDM sampler (Algorithm 2).
24 |
25 | def edm_sampler(
26 | net, latents, dataloader, class_labels=None, randn_like=torch.randn_like,
27 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
28 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, optimal_denoiser=False
29 | ):
30 | # Adjust noise levels based on what's supported by the network.
31 | if optimal_denoiser:
32 | optimal_x0 = optimal_solver(dataloader)
33 | sigma_min = max(sigma_min, net.sigma_min)
34 | sigma_max = min(sigma_max, net.sigma_max)
35 |
36 | # Time step discretization.
37 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
38 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
39 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
40 |
41 | # Main sampling loop.
42 | x_next = latents.to(torch.float64) * t_steps[0]
43 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
44 | x_cur = x_next
45 |
46 | # Increase noise temporarily.
47 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
48 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
49 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
50 |
51 | # Euler step.
52 | if not optimal_denoiser:
53 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
54 | else:
55 | denoised = optimal_x0(x_hat, s = torch.tensor(1, device = t_cur.device, dtype = t_cur.dtype), sigma = t_hat)
56 | d_cur = (x_hat - denoised) / t_hat
57 | x_next = x_hat + (t_next - t_hat) * d_cur
58 |
59 | # Apply 2nd order correction.
60 | if i < num_steps - 1 and not optimal_denoiser:
61 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
62 | d_prime = (x_next - denoised) / t_next
63 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
64 |
65 | return x_next
66 |
67 |
68 | def normal_distribution_batch(x, y_batch, s, std, bias):  # 'bias' is a (bs_x,) log-domain offset tensor (see get_exp_bias_batch)
69 | bs_y = y_batch.shape[0]
70 | bs_x = x.shape[0]
71 | xb = x.unsqueeze(1)
72 | prob = torch.exp(-(((xb - s * y_batch)**2).view(bs_x, bs_y, -1).sum(dim=-1).to(torch.float64)/std**2)/2 - bias.unsqueeze(1))
73 | # prob = torch.where(prob==torch.inf, 1, 0)
74 | prob_y = prob.clone().view(bs_x, bs_y, 1, 1, 1) * y_batch # (bs_x, bs_y, 3, 32, 32)
75 | return prob.sum(dim=1, keepdim=True).squeeze(), prob_y.sum(dim=1, keepdim=True).squeeze() # (bs_x, ), (bs_x, 3, 32, 32)
76 |
77 | def get_exp_bias_batch(x, y_batch, s, std):
78 |     ## exp() can underflow to zero, so we subtract a per-sample bias (log-sum-exp trick)
79 | bs_y = y_batch.shape[0]
80 | bs_x = x.shape[0]
81 | xb = x.unsqueeze(1)
82 | return (-(((xb - s * y_batch)**2).view(bs_x, bs_y, -1).sum(dim=-1).to(torch.float64)/std**2)/2).max(dim=1)
83 |
84 | def optimal_solver(dataloader):
85 | def optimal_sol(batch, s, sigma):
86 |
87 | std = s * sigma
88 | x = batch
89 |         prob_sum = torch.zeros(x.shape[0], dtype=torch.float64).cuda()
90 | prob_y_sum = torch.zeros_like(x).to(torch.float64).cuda()
91 | exp_bias = -(torch.inf) * torch.ones(x.shape[0],).cuda()
92 | for y_batch, _ in dataloader:
93 | y_batch = y_batch.cuda().to(torch.float32) / 127.5 - 1
94 | curr_exp_bias = get_exp_bias_batch(x, y_batch, s, std)[0]
95 | exp_bias = torch.where(curr_exp_bias>exp_bias, curr_exp_bias, exp_bias)
96 | for y_batch, _ in dataloader:
97 | y_batch = y_batch.cuda().to(torch.float32) / 127.5 - 1
98 | prob, prob_y = normal_distribution_batch(x, y_batch, s, std, exp_bias)
99 | prob_sum += prob
100 | prob_y_sum += prob_y
101 |
102 | optimal_solution = prob_y_sum/prob_sum.view(-1, 1, 1, 1)
103 | return optimal_solution
104 | return optimal_sol
105 |
106 | #----------------------------------------------------------------------------
107 | # Wrapper for torch.Generator that allows specifying a different random seed
108 | # for each sample in a minibatch.
109 |
110 | class StackedRandomGenerator:
111 | def __init__(self, device, seeds):
112 | super().__init__()
113 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
114 |
115 | def randn(self, size, **kwargs):
116 | assert size[0] == len(self.generators)
117 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
118 |
119 | def randn_like(self, input):
120 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
121 |
122 | def randint(self, *args, size, **kwargs):
123 | assert size[0] == len(self.generators)
124 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
125 |
126 | #----------------------------------------------------------------------------
127 | # Parse a comma separated list of numbers or ranges and return a list of ints.
128 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
129 |
130 | def parse_int_list(s):
131 | if isinstance(s, list): return s
132 | ranges = []
133 | range_re = re.compile(r'^(\d+)-(\d+)$')
134 | for p in s.split(','):
135 | m = range_re.match(p)
136 | if m:
137 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
138 | else:
139 | ranges.append(int(p))
140 | return ranges
141 |
142 | #----------------------------------------------------------------------------
143 |
144 | @click.command()
145 | @click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True)
146 | @click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True)
147 | @click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True)
148 | @click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True)
149 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
150 | @click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
151 |
152 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
153 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
154 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True))
155 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
156 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
157 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
158 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
159 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
160 |
161 | @click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun']))
162 | @click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
163 | @click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear']))
164 | @click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none']))
165 |
166 | @click.option('--optimal_denoiser', help='Generate images from optimal denoiser', is_flag=True)
167 | @click.option('--dataset', 'dataset', help='Dataset used by the optimal denoiser', type=str)
168 |
169 |
170 | def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, optimal_denoiser, dataset, device=torch.device('cuda'), **sampler_kwargs):
171 | """Generate random images using the techniques described in the paper
172 | "Elucidating the Design Space of Diffusion-Based Generative Models".
173 |
174 | Examples:
175 |
176 | \b
177 | # Generate 64 images and save them as out/*.png
178 | python generate.py --outdir=out --seeds=0-63 --batch=64 \\
179 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
180 |
181 | \b
182 | # Generate 1024 images using 2 GPUs
183 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\
184 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
185 | """
186 | dist.init()
187 | num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
188 | all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
189 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
190 |
191 | # Rank 0 goes first.
192 | if dist.get_rank() != 0:
193 | torch.distributed.barrier()
194 |
195 | dataset_loader = None
196 | if optimal_denoiser:
197 | assert dataset is not None
198 | dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=dataset)
199 | data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1)
200 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)
201 | dataset_loader = torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=max_batch_size, **data_loader_kwargs)
202 |
203 | # Load network.
204 | assert network_pkl is not None
205 | dist.print0(f'Loading network from "{network_pkl}"...')
206 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
207 | net = pickle.load(f)['ema'].to(device)
208 |
209 | # Other ranks follow.
210 | if dist.get_rank() == 0:
211 | torch.distributed.barrier()
212 |
213 | # Loop over batches.
214 | dist.print0(f'Generating {len(seeds)} images to "{outdir}"...')
215 | for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
216 | torch.distributed.barrier()
217 | batch_size = len(batch_seeds)
218 | if batch_size == 0:
219 | continue
220 |
221 | # Pick latents and labels.
222 | rnd = StackedRandomGenerator(device, batch_seeds)
223 |
224 | latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
225 |
226 | class_labels = None
227 | if net.label_dim:
228 | class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)]
229 | if class_idx is not None:
230 | class_labels[:, :] = 0
231 | class_labels[:, class_idx] = 1
232 |
233 | # Generate images.
234 | sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
235 | images = edm_sampler(net, latents, dataset_loader, class_labels, randn_like=rnd.randn_like, optimal_denoiser=optimal_denoiser, **sampler_kwargs)
236 |
237 | # Save images.
238 | images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
239 | for seed, image_np in zip(batch_seeds, images_np):
240 | image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir
241 | os.makedirs(image_dir, exist_ok=True)
242 | image_path = os.path.join(image_dir, f'{seed:06d}.png')
243 | if image_np.shape[2] == 1:
244 | PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
245 | else:
246 | PIL.Image.fromarray(image_np, 'RGB').save(image_path)
247 |
248 | # Done.
249 | torch.distributed.barrier()
250 | dist.print0('Done.')
251 |
252 | #----------------------------------------------------------------------------
253 |
254 | if __name__ == "__main__":
255 | main()
256 |
257 | #----------------------------------------------------------------------------
258 |
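259 | #----------------------------------------------------------------------------
260 | # Note on optimal_solver() above: for training samples {y_i}, optimal_sol()
261 | # evaluates the closed-form posterior-mean denoiser of the empirical data
262 | # distribution,
263 | #
264 | #   x0*(x) = sum_i exp(-||x - s*y_i||^2 / (2*std^2)) * y_i
265 | #            / sum_i exp(-||x - s*y_i||^2 / (2*std^2)),    std = s * sigma,
266 | #
267 | # where get_exp_bias_batch() first finds the per-sample maximum exponent so
268 | # that both sums are shifted into a numerically safe range (the log-sum-exp
269 | # trick); the shift cancels in the ratio.
270 | #----------------------------------------------------------------------------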
--------------------------------------------------------------------------------
/edm/jacobian.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Evaluate the rank of the jacobian of the denoising autoencoder"""
9 |
10 | import re
11 | import click
12 | import pickle
13 | import numpy as np
14 | import torch
15 | import dnnlib
16 | from torch_utils import distributed as dist
17 | from torch import nn
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def edm_sampler(
22 | net, latents, class_labels=None, randn_like=torch.randn_like,
23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
25 | ):
26 | # Adjust noise levels based on what's supported by the network.
27 | sigma_min = max(sigma_min, net.sigma_min)
28 | sigma_max = min(sigma_max, net.sigma_max)
29 |
30 | # Time step discretization.
31 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
32 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
33 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
34 |
35 | # Main sampling loop.
36 | x_next = latents.to(torch.float64) * t_steps[0]
37 | trajectory = [x_next]
38 | ts = [t_steps[0]]
39 |
40 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
41 |
42 | x_cur = x_next
43 |
44 | # Increase noise temporarily.
45 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
46 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
47 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
48 |
49 | # Euler step.
50 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
51 | d_cur = (x_hat - denoised) / t_hat
52 | x_next = x_hat + (t_next - t_hat) * d_cur
53 |
54 | # Apply 2nd order correction.
55 | if i < num_steps - 1:
56 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
57 | d_prime = (x_next - denoised) / t_next
58 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
59 | ts.append(t_next)
60 | trajectory.append(x_next)
61 |
62 | return trajectory[:-1], ts[:-1]
63 |
64 | class StackedRandomGenerator:
65 | def __init__(self, device, seeds):
66 | super().__init__()
67 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
68 |
69 | def randn(self, size, **kwargs):
70 | assert size[0] == len(self.generators)
71 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
72 |
73 | def randn_like(self, input):
74 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
75 |
76 | def randint(self, *args, size, **kwargs):
77 | assert size[0] == len(self.generators)
78 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
79 |
80 | #----------------------------------------------------------------------------
81 | # Parse a comma separated list of numbers or ranges and return a list of ints.
82 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
83 |
84 | def parse_int_list(s):
85 | if isinstance(s, list): return s
86 | ranges = []
87 | range_re = re.compile(r'^(\d+)-(\d+)$')
88 | for p in s.split(','):
89 | m = range_re.match(p)
90 | if m:
91 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
92 | else:
93 | ranges.append(int(p))
94 | return ranges
95 |
96 | #----------------------------------------------------------------------------
97 |
98 | @click.command()
99 |
100 | @click.option('--network_pkl', help='Checkpoint of the model whose Jacobian rank is evaluated', type=str, required=True)
101 | @click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None)
102 |
103 | @click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True)
104 | @click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=0.002)
105 | @click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=80.0)
106 | @click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
107 | @click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
108 | @click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
109 | @click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True)
110 | @click.option('--S_noise', 'S_noise', help='Stoch. noise inflation', metavar='FLOAT', type=float, default=1, show_default=True)
111 |
112 |
113 |
114 | def main(network_pkl, class_idx, device=torch.device('cuda'), **sampler_kwargs):
115 | """Generate random images using the techniques described in the paper
116 | "Elucidating the Design Space of Diffusion-Based Generative Models".
117 |
118 | Examples:
119 |
120 | python edm/jacobian.py --network_pkl https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-ve.pkl --class 5
121 |
122 | """
123 | dist.init()
124 |
125 | # Rank 0 goes first.
126 | if dist.get_rank() != 0:
127 | torch.distributed.barrier()
128 |
129 | dist.print0(f'Loading network from "{network_pkl}"...')
130 | with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f:
131 | net = pickle.load(f)['ema'].to(device)
132 |
133 | # Other ranks follow.
134 | if dist.get_rank() == 0:
135 | torch.distributed.barrier()
136 |
137 | # Loop over batches.
138 | sigma_max = sampler_kwargs['sigma_max']
139 | sigma_min = sampler_kwargs['sigma_min']
140 | rho = sampler_kwargs['rho']
141 | num_steps = sampler_kwargs['num_steps']
142 |
143 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
144 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
145 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
146 | class_labels = None
147 | if net.label_dim:
148 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[1], device=device)]
149 | if class_idx is not None:
150 | class_labels[:, :] = 0
151 | class_labels[:, class_idx] = 1
152 |
153 | total_dim = net.img_channels * net.img_resolution * net.img_resolution
154 | latents = torch.randn(1, net.img_channels, net.img_resolution, net.img_resolution).to(device)
155 |
156 | trajectory, ts = edm_sampler(net, latents, num_steps=num_steps, class_labels=class_labels, sigma_max=sigma_max)
157 |
158 | for x, t in zip(trajectory, ts):
159 |         func = lambda input: net(input, t, class_labels)  # keep the class conditioning consistent with the call below
160 |
161 | jacs = torch.autograd.functional.jacobian(func, x).squeeze().permute(1, 2, 0, 4, 5, 3).reshape(total_dim, total_dim).detach().cpu()
162 | output = net(x, t, class_labels).squeeze().permute(1, 2, 0).reshape(-1, 1).detach().cpu()
163 | U, S, V = torch.svd(jacs)
164 | acc_sum = torch.cumsum((S ** 2).to(torch.float32), dim=0).sqrt()
165 | total_sum = acc_sum[-1]
166 |         threshold = total_sum * 0.99
167 |         rank = torch.where(acc_sum > threshold)[0].min() + 1
168 | print(f"SNR 1/sigma_t: {1/t.item():.3f}, rank: {rank}")
169 |
170 | # Done.
171 | torch.distributed.barrier()
172 | dist.print0('Done.')
173 |
174 | #----------------------------------------------------------------------------
175 |
176 | if __name__ == "__main__":
177 | main()
178 |
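179 | #----------------------------------------------------------------------------
180 | # Note on the rank estimate above: the reported rank is the smallest k whose
181 | # leading singular values capture 99% of the Jacobian's Frobenius energy,
182 | #
183 | #   rank = min { k : sqrt(s_1^2 + ... + s_k^2) > 0.99 * sqrt(sum_i s_i^2) },
184 | #
185 | # computed from torch.svd() via a cumulative sum over the squared spectrum.
186 | #----------------------------------------------------------------------------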
--------------------------------------------------------------------------------
/edm/sscd.py:
--------------------------------------------------------------------------------
1 | """Script for Self-Supervised Descriptor for Image Copy Detection (SSCD)."""
2 |
3 | import os
4 | import click
5 | import tqdm
6 | import pickle
7 | import numpy as np
8 | import scipy.linalg
9 | import torch
10 | import dnnlib
11 | from torch_utils import distributed as dist
12 | from training import dataset
13 | from torchvision import transforms
14 | import wget
15 |
16 | #----------------------------------------------------------------------------
17 |
18 | def TransformSamples(samples, transform):
19 | sscd_samples = []
20 | for sample in samples:
21 | sscd_samples.append(transform(sample)[None, :])
22 | return torch.cat(sscd_samples, dim=0)
23 |
24 |
25 | @click.group()
26 | def main():
27 | """Calculate Self-Supervised Descriptor for Image Copy Detection (SSCD).
28 |     The original GitHub repository is https://github.com/facebookresearch/sscd-copy-detection
29 |
30 | Examples:
31 |
32 | \b
33 | # Calculate SSCD feature
34 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim128-n16384 --features ./evaluation/sscd-dim128-n16384.npz
35 |
36 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/ddpm-dim64-n16384 --features ./evaluation/sscd-dim64-n16384.npz
37 |
38 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images ./evaluation/generalization/ --features ./evaluation/sscd-generalization.npz
39 |
40 | torchrun --standalone --nproc_per_node=1 edm/sscd.py feature --images datasets/synthetic-cifar10-32x32-n16384.zip --features ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n16384.npz
41 |
42 | \b
43 | # Compute reproducibility score
44 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-dim64-n16384.npz
45 |
46 | # Compute generalization score
47 | python edm/sscd.py rpscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-generalization.npz
48 |
49 | # Compute memorization score
50 | python edm/sscd.py mscore --source ./evaluation/sscd-dim128-n16384.npz --target ./evaluation/sscd-training-dataset-synthetic-cifar10-32x32-n16384.npz
51 | """
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | @main.command()
56 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|NPZ', type=str, required=True)
57 | @click.option('--features', 'features_path', help='Path to save features', metavar='NPZ', type=str, required=True)
58 | def feature(image_path, features_path):
59 | """Calculate SSCD features for a given set of images."""
60 | if not os.path.exists("./pretrainedmodels"):
61 | os.makedirs("./pretrainedmodels")
62 | if not os.path.exists("./pretrainedmodels/sscd_disc_large.torchscript.pt"):
63 | wget.download("https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_large.torchscript.pt", "./pretrainedmodels/sscd_disc_large.torchscript.pt")
64 | sscd_model = torch.jit.load("./pretrainedmodels/sscd_disc_large.torchscript.pt")
65 |     sscd_model = sscd_model.to(device="cuda:0")
66 | sscd_model.eval()
67 | sscd_transform = transforms.Compose([
68 | transforms.ToPILImage(),
69 | transforms.Resize((320, 320)),
70 | transforms.ToTensor(),
71 | # transform to float tensor
72 | # transforms.Lambda(lambda x: x.float()),
73 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],)
74 | ])
75 | dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=image_path)
76 | data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=1)
77 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)
78 | dataloader = torch.utils.data.DataLoader(dataset=dataset_obj, batch_size=64, **data_loader_kwargs)
79 | sscd_features = []
80 | for x_batch, _ in tqdm.tqdm(dataloader):
81 |         x_sscd = TransformSamples(x_batch, sscd_transform).to(device="cuda:0")
82 | sscd_feature = sscd_model(x_sscd).detach().cpu()
83 | sscd_features.append(sscd_feature)
84 | sscd_features = torch.cat(sscd_features, dim=0)
85 | np.savez(features_path, features=sscd_features.numpy())
86 | #----------------------------------------------------------------------------
87 |
88 | @main.command()
89 | @click.option('--source', 'source_path', help='Path to source SSCD features', metavar='NPZ', type=str, required=True)
90 | @click.option('--target', 'target_path', help='Path to target SSCD features', metavar='NPZ', type=str, required=True)
91 | @click.option('--t', 'threshold', help='Threshold for SSCD similarity', type=float, default=0.6)
92 |
93 | def rpscore(source_path, target_path, threshold):
94 | """Calculate reproducibility score between source images and targe images."""
95 | source_features = np.load(source_path)["features"]
96 | target_features = np.load(target_path)["features"]
97 | similarity = (source_features * target_features).sum(axis=1)
98 | rpscore = (similarity > threshold).mean()
99 | print('RP score = ', rpscore)
100 | return rpscore
101 |
102 | @main.command()
103 | @click.option('--source', 'source_path', help='Path to source SSCD features', metavar='NPZ', type=str, required=True)
104 | @click.option('--target', 'target_path', help='Path to target SSCD features', metavar='NPZ', type=str, required=True)
105 | @click.option('--t', 'threshold', help='Threshold for SSCD similarity', type=float, default=0.6)
106 |
107 | def mscore(source_path, target_path, threshold):
108 | """Calculate reproducibility score between source images and targe images."""
109 | bs = 128
110 | source_features = np.load(source_path)["features"][:, None, :]
111 | target_features = np.load(target_path)["features"][None, :, :]
112 | rpscore = 0
113 | total_sample = source_features.shape[0]
114 | for idx in tqdm.tqdm(range(total_sample//bs + 1)):
115 |
116 | similarity = (source_features[idx*bs: (idx + 1)*bs, :] * target_features).sum(axis=2).max(axis=1)
117 | rpscore += (similarity > threshold).sum()
118 |     mscore = 1 - rpscore/total_sample
119 |     print('M score = ', mscore)
120 |     return mscore
121 |
122 |
123 | #----------------------------------------------------------------------------
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
128 | #----------------------------------------------------------------------------
129 |
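130 | #----------------------------------------------------------------------------
131 | # Illustrative sketch of what the two scores above reduce to, assuming two
132 | # (N, D) SSCD feature arrays a and b loaded from the .npz files:
133 | #
134 | #   def rp_score(a, b, t=0.6):           # paired similarity (rpscore)
135 | #       return ((a * b).sum(axis=1) > t).mean()
136 | #
137 | #   def m_score(a, b, t=0.6):            # best-match similarity (mscore)
138 | #       return 1 - ((a @ b.T).max(axis=1) > t).mean()
139 | #----------------------------------------------------------------------------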
--------------------------------------------------------------------------------
/edm/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/edm/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 | if 'MASTER_PORT' not in os.environ:
18 | os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
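61 | #----------------------------------------------------------------------------
62 | # Illustrative usage sketch: when a script is launched without torchrun,
63 | # init() falls back to a single-process group (rank 0, world size 1,
64 | # localhost), so the helpers above also work for plain `python script.py`:
65 | #
66 | #   init()
67 | #   print0(f'rank {get_rank()} of {get_world_size()}')
68 | #----------------------------------------------------------------------------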
--------------------------------------------------------------------------------
/edm/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | import dnnlib
14 |
15 | #----------------------------------------------------------------------------
16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17 | # same constant is used multiple times.
18 |
19 | _constant_cache = dict()
20 |
21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22 | value = np.asarray(value)
23 | if shape is not None:
24 | shape = tuple(shape)
25 | if dtype is None:
26 | dtype = torch.get_default_dtype()
27 | if device is None:
28 | device = torch.device('cpu')
29 | if memory_format is None:
30 | memory_format = torch.contiguous_format
31 |
32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33 | tensor = _constant_cache.get(key, None)
34 | if tensor is None:
35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36 | if shape is not None:
37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38 | tensor = tensor.contiguous(memory_format=memory_format)
39 | _constant_cache[key] = tensor
40 | return tensor
41 |
42 | #----------------------------------------------------------------------------
43 | # Replace NaN/Inf with specified numerical values.
44 |
45 | try:
46 | nan_to_num = torch.nan_to_num # 1.8.0a0
47 | except AttributeError:
48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49 | assert isinstance(input, torch.Tensor)
50 | if posinf is None:
51 | posinf = torch.finfo(input.dtype).max
52 | if neginf is None:
53 | neginf = torch.finfo(input.dtype).min
54 | assert nan == 0
55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56 |
57 | #----------------------------------------------------------------------------
58 | # Symbolic assert.
59 |
60 | try:
61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62 | except AttributeError:
63 | symbolic_assert = torch.Assert # 1.7.0
64 |
65 | #----------------------------------------------------------------------------
66 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68 |
69 | @contextlib.contextmanager
70 | def suppress_tracer_warnings():
71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72 | warnings.filters.insert(0, flt)
73 | yield
74 | warnings.filters.remove(flt)
75 |
76 | #----------------------------------------------------------------------------
77 | # Assert that the shape of a tensor matches the given list of integers.
78 | # None indicates that the size of a dimension is allowed to vary.
79 | # Performs symbolic assertion when used in torch.jit.trace().
80 |
81 | def assert_shape(tensor, ref_shape):
82 | if tensor.ndim != len(ref_shape):
83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85 | if ref_size is None:
86 | pass
87 | elif isinstance(ref_size, torch.Tensor):
88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90 | elif isinstance(size, torch.Tensor):
91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93 | elif size != ref_size:
94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95 |
96 | #----------------------------------------------------------------------------
97 | # Function decorator that calls torch.autograd.profiler.record_function().
98 |
99 | def profiled_function(fn):
100 | def decorator(*args, **kwargs):
101 | with torch.autograd.profiler.record_function(fn.__name__):
102 | return fn(*args, **kwargs)
103 | decorator.__name__ = fn.__name__
104 | return decorator
105 |
106 | #----------------------------------------------------------------------------
107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
108 | # indefinitely, shuffling items as it goes.
109 |
110 | class InfiniteSampler(torch.utils.data.Sampler):
111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112 | assert len(dataset) > 0
113 | assert num_replicas > 0
114 | assert 0 <= rank < num_replicas
115 | assert 0 <= window_size <= 1
116 | super().__init__(dataset)
117 | self.dataset = dataset
118 | self.rank = rank
119 | self.num_replicas = num_replicas
120 | self.shuffle = shuffle
121 | self.seed = seed
122 | self.window_size = window_size
123 |
124 | def __iter__(self):
125 | order = np.arange(len(self.dataset))
126 | rnd = None
127 | window = 0
128 | if self.shuffle:
129 | rnd = np.random.RandomState(self.seed)
130 | rnd.shuffle(order)
131 | window = int(np.rint(order.size * self.window_size))
132 |
133 | idx = 0
134 | while True:
135 | i = idx % order.size
136 | if idx % self.num_replicas == self.rank:
137 | yield order[i]
138 | if window >= 2:
139 | j = (i - rnd.randint(window)) % order.size
140 | order[i], order[j] = order[j], order[i]
141 | idx += 1
142 |
143 | #----------------------------------------------------------------------------
144 | # Utilities for operating with torch.nn.Module parameters and buffers.
145 |
146 | def params_and_buffers(module):
147 | assert isinstance(module, torch.nn.Module)
148 | return list(module.parameters()) + list(module.buffers())
149 |
150 | def named_params_and_buffers(module):
151 | assert isinstance(module, torch.nn.Module)
152 | return list(module.named_parameters()) + list(module.named_buffers())
153 |
154 | @torch.no_grad()
155 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
156 | assert isinstance(src_module, torch.nn.Module)
157 | assert isinstance(dst_module, torch.nn.Module)
158 | src_tensors = dict(named_params_and_buffers(src_module))
159 | for name, tensor in named_params_and_buffers(dst_module):
160 | assert (name in src_tensors) or (not require_all)
161 | if name in src_tensors:
162 | tensor.copy_(src_tensors[name])
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
168 | @contextlib.contextmanager
169 | def ddp_sync(module, sync):
170 | assert isinstance(module, torch.nn.Module)
171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172 | yield
173 | else:
174 | with module.no_sync():
175 | yield
176 |
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
180 | def check_ddp_consistency(module, ignore_regex=None):
181 | assert isinstance(module, torch.nn.Module)
182 | for name, tensor in named_params_and_buffers(module):
183 | fullname = type(module).__name__ + '.' + name
184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185 | continue
186 | tensor = tensor.detach()
187 | if tensor.is_floating_point():
188 | tensor = nan_to_num(tensor)
189 | other = tensor.clone()
190 | torch.distributed.broadcast(tensor=other, src=0)
191 | assert (tensor == other).all(), fullname
192 |
193 | #----------------------------------------------------------------------------
194 | # Print summary table of module hierarchy.
195 |
196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197 | assert isinstance(module, torch.nn.Module)
198 | assert not isinstance(module, torch.jit.ScriptModule)
199 | assert isinstance(inputs, (tuple, list))
200 |
201 | # Register hooks.
202 | entries = []
203 | nesting = [0]
204 | def pre_hook(_mod, _inputs):
205 | nesting[0] += 1
206 | def post_hook(mod, _inputs, outputs):
207 | nesting[0] -= 1
208 | if nesting[0] <= max_nesting:
209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214 |
215 | # Run module.
216 | outputs = module(*inputs)
217 | for hook in hooks:
218 | hook.remove()
219 |
220 | # Identify unique outputs, parameters, and buffers.
221 | tensors_seen = set()
222 | for e in entries:
223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227 |
228 | # Filter out redundant entries.
229 | if skip_redundant:
230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231 |
232 | # Construct table.
233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234 | rows += [['---'] * len(rows[0])]
235 | param_total = 0
236 | buffer_total = 0
237 | submodule_names = {mod: name for name, mod in module.named_modules()}
238 | for e in entries:
239 | name = '' if e.mod is module else submodule_names[e.mod]
240 | param_size = sum(t.numel() for t in e.unique_params)
241 | buffer_size = sum(t.numel() for t in e.unique_buffers)
242 | output_shapes = [str(list(t.shape)) for t in e.outputs]
243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244 | rows += [[
245 | name + (':0' if len(e.outputs) >= 2 else ''),
246 | str(param_size) if param_size else '-',
247 | str(buffer_size) if buffer_size else '-',
248 | (output_shapes + ['-'])[0],
249 | (output_dtypes + ['-'])[0],
250 | ]]
251 | for idx in range(1, len(e.outputs)):
252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253 | param_total += param_size
254 | buffer_total += buffer_size
255 | rows += [['---'] * len(rows[0])]
256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257 |
258 | # Print table.
259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260 | print()
261 | for row in rows:
262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263 | print()
264 | return outputs
265 |
266 | #----------------------------------------------------------------------------
267 |
--------------------------------------------------------------------------------
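
A minimal usage sketch for the misc.py utilities above (illustrative only, not
repository code; the toy tensor, dataset, and network are assumptions):

    import torch
    from torch_utils import misc

    x = torch.zeros(4, 3, 32, 32)
    misc.assert_shape(x, [None, 3, 32, 32])  # None lets the batch dimension vary

    # InfiniteSampler drives a standard DataLoader indefinitely, reshuffling as it goes.
    dataset = torch.utils.data.TensorDataset(torch.arange(10))
    sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))
    batch = next(loader)  # never raises StopIteration

    # print_module_summary runs one forward pass and tabulates parameters and output shapes.
    net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1))
    misc.print_module_summary(net, [x])
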
/edm/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import dnnlib
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | _version = 6 # internal version number
27 | _decorators = set() # {decorator_class, ...}
28 | _import_hooks = [] # [hook_function, ...]
29 | _module_to_src_dict = dict() # {module: src, ...}
30 | _src_to_module_dict = dict() # {src: module, ...}
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def persistent_class(orig_class):
35 | r"""Class decorator that extends a given class to save its source code
36 | when pickled.
37 |
38 | Example:
39 |
40 | from torch_utils import persistence
41 |
42 | @persistence.persistent_class
43 | class MyNetwork(torch.nn.Module):
44 | def __init__(self, num_inputs, num_outputs):
45 | super().__init__()
46 | self.fc = MyLayer(num_inputs, num_outputs)
47 | ...
48 |
49 | @persistence.persistent_class
50 | class MyLayer(torch.nn.Module):
51 | ...
52 |
53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
54 | source code alongside other internal state (e.g., parameters, buffers,
55 | and submodules). This way, any previously exported pickle will remain
56 | usable even if the class definitions have been modified or are no
57 | longer available.
58 |
59 | The decorator saves the source code of the entire Python module
60 | containing the decorated class. It does *not* save the source code of
61 | any imported modules. Thus, the imported modules must be available
62 |     during unpickling, including `torch_utils.persistence` itself.
63 |
64 | It is ok to call functions defined in the same module from the
65 | decorated class. However, if the decorated class depends on other
66 | classes defined in the same module, they must be decorated as well.
67 | This is illustrated in the above example in the case of `MyLayer`.
68 |
69 | It is also possible to employ the decorator just-in-time before
70 | calling the constructor. For example:
71 |
72 | cls = MyLayer
73 | if want_to_make_it_persistent:
74 | cls = persistence.persistent_class(cls)
75 | layer = cls(num_inputs, num_outputs)
76 |
77 | As an additional feature, the decorator also keeps track of the
78 | arguments that were used to construct each instance of the decorated
79 | class. The arguments can be queried via `obj.init_args` and
80 | `obj.init_kwargs`, and they are automatically pickled alongside other
81 | object state. This feature can be disabled on a per-instance basis
82 | by setting `self._record_init_args = False` in the constructor.
83 |
84 | A typical use case is to first unpickle a previous instance of a
85 | persistent class, and then upgrade it to use the latest version of
86 | the source code:
87 |
88 | with open('old_pickle.pkl', 'rb') as f:
89 | old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 | """
93 | assert isinstance(orig_class, type)
94 | if is_persistent(orig_class):
95 | return orig_class
96 |
97 | assert orig_class.__module__ in sys.modules
98 | orig_module = sys.modules[orig_class.__module__]
99 | orig_module_src = _module_to_src(orig_module)
100 |
101 | class Decorator(orig_class):
102 | _orig_module_src = orig_module_src
103 | _orig_class_name = orig_class.__name__
104 |
105 | def __init__(self, *args, **kwargs):
106 | super().__init__(*args, **kwargs)
107 | record_init_args = getattr(self, '_record_init_args', True)
108 | self._init_args = copy.deepcopy(args) if record_init_args else None
109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110 | assert orig_class.__name__ in orig_module.__dict__
111 | _check_pickleable(self.__reduce__())
112 |
113 | @property
114 | def init_args(self):
115 | assert self._init_args is not None
116 | return copy.deepcopy(self._init_args)
117 |
118 | @property
119 | def init_kwargs(self):
120 | assert self._init_kwargs is not None
121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122 |
123 | def __reduce__(self):
124 | fields = list(super().__reduce__())
125 | fields += [None] * max(3 - len(fields), 0)
126 | if fields[0] is not _reconstruct_persistent_obj:
127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128 | fields[0] = _reconstruct_persistent_obj # reconstruct func
129 | fields[1] = (meta,) # reconstruct args
130 | fields[2] = None # state dict
131 | return tuple(fields)
132 |
133 | Decorator.__name__ = orig_class.__name__
134 | Decorator.__module__ = orig_class.__module__
135 | _decorators.add(Decorator)
136 | return Decorator
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | def is_persistent(obj):
141 | r"""Test whether the given object or class is persistent, i.e.,
142 | whether it will save its source code when pickled.
143 | """
144 | try:
145 | if obj in _decorators:
146 | return True
147 | except TypeError:
148 | pass
149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 |
151 | #----------------------------------------------------------------------------
152 |
153 | def import_hook(hook):
154 | r"""Register an import hook that is called whenever a persistent object
155 | is being unpickled. A typical use case is to patch the pickled source
156 | code to avoid errors and inconsistencies when the API of some imported
157 | module has changed.
158 |
159 | The hook should have the following signature:
160 |
161 | hook(meta) -> modified meta
162 |
163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164 |
165 | type: Type of the persistent object, e.g. `'class'`.
166 | version: Internal version number of `torch_utils.persistence`.
167 |         module_src: Original source code of the Python module.
168 | class_name: Class name in the original Python module.
169 | state: Internal state of the object.
170 |
171 | Example:
172 |
173 | @persistence.import_hook
174 | def wreck_my_network(meta):
175 | if meta.class_name == 'MyNetwork':
176 | print('MyNetwork is being imported. I will wreck it!')
177 | meta.module_src = meta.module_src.replace("True", "False")
178 | return meta
179 | """
180 | assert callable(hook)
181 | _import_hooks.append(hook)
182 |
183 | #----------------------------------------------------------------------------
184 |
185 | def _reconstruct_persistent_obj(meta):
186 | r"""Hook that is called internally by the `pickle` module to unpickle
187 | a persistent object.
188 | """
189 | meta = dnnlib.EasyDict(meta)
190 | meta.state = dnnlib.EasyDict(meta.state)
191 | for hook in _import_hooks:
192 | meta = hook(meta)
193 | assert meta is not None
194 |
195 | assert meta.version == _version
196 | module = _src_to_module(meta.module_src)
197 |
198 | assert meta.type == 'class'
199 | orig_class = module.__dict__[meta.class_name]
200 | decorator_class = persistent_class(orig_class)
201 | obj = decorator_class.__new__(decorator_class)
202 |
203 | setstate = getattr(obj, '__setstate__', None)
204 | if callable(setstate):
205 | setstate(meta.state) # pylint: disable=not-callable
206 | else:
207 | obj.__dict__.update(meta.state)
208 | return obj
209 |
210 | #----------------------------------------------------------------------------
211 |
212 | def _module_to_src(module):
213 | r"""Query the source code of a given Python module.
214 | """
215 | src = _module_to_src_dict.get(module, None)
216 | if src is None:
217 | src = inspect.getsource(module)
218 | _module_to_src_dict[module] = src
219 | _src_to_module_dict[src] = module
220 | return src
221 |
222 | def _src_to_module(src):
223 | r"""Get or create a Python module for the given source code.
224 | """
225 | module = _src_to_module_dict.get(src, None)
226 | if module is None:
227 | module_name = "_imported_module_" + uuid.uuid4().hex
228 | module = types.ModuleType(module_name)
229 | sys.modules[module_name] = module
230 | _module_to_src_dict[module] = src
231 | _src_to_module_dict[src] = module
232 | exec(src, module.__dict__) # pylint: disable=exec-used
233 | return module
234 |
235 | #----------------------------------------------------------------------------
236 |
237 | def _check_pickleable(obj):
238 | r"""Check that the given object is pickleable, raising an exception if
239 | it is not. This function is expected to be considerably more efficient
240 | than actually pickling the object.
241 | """
242 | def recurse(obj):
243 | if isinstance(obj, (list, tuple, set)):
244 | return [recurse(x) for x in obj]
245 | if isinstance(obj, dict):
246 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250 | return None # NumPy arrays and PyTorch tensors are pickleable.
251 | if is_persistent(obj):
252 | return None # Persistent objects are pickleable, by virtue of the constructor check.
253 | return obj
254 | with io.BytesIO() as f:
255 | pickle.dump(recurse(obj), f)
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
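
A sketch of the pickling round trip enabled by persistent_class above
(illustrative only; TinyNet and the file name are assumptions, and the code
must live in a real module file so that inspect.getsource() can read it):

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):
        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.fc = torch.nn.Linear(num_inputs, num_outputs)

    net = TinyNet(8, 2)
    with open('tiny.pkl', 'wb') as f:
        pickle.dump(net, f)        # module source is stored alongside the weights
    with open('tiny.pkl', 'rb') as f:
        restored = pickle.load(f)  # loadable even if TinyNet is later edited
    assert restored.init_args == (8, 2)  # constructor arguments are recorded
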
/edm/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from . import misc
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
25 | _rank = 0 # Rank of the current process.
26 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
27 | _sync_called = False # Has _sync() been called yet?
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | def init_multiprocessing(rank, sync_device):
34 | r"""Initializes `torch_utils.training_stats` for collecting statistics
35 | across multiple processes.
36 |
37 | This function must be called after
38 | `torch.distributed.init_process_group()` and before `Collector.update()`.
39 | The call is not necessary if multi-process collection is not needed.
40 |
41 | Args:
42 | rank: Rank of the current process.
43 | sync_device: PyTorch device to use for inter-process
44 | communication, or None to disable multi-process
45 | collection. Typically `torch.device('cuda', rank)`.
46 | """
47 | global _rank, _sync_device
48 | assert not _sync_called
49 | _rank = rank
50 | _sync_device = sync_device
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | @misc.profiled_function
55 | def report(name, value):
56 | r"""Broadcasts the given set of scalars to all interested instances of
57 | `Collector`, across device and process boundaries.
58 |
59 | This function is expected to be extremely cheap and can be safely
60 | called from anywhere in the training loop, loss function, or inside a
61 | `torch.nn.Module`.
62 |
63 | Warning: The current implementation expects the set of unique names to
64 | be consistent across processes. Please make sure that `report()` is
65 | called at least once for each unique name by each process, and in the
66 | same order. If a given process has no scalars to broadcast, it can do
67 | `report(name, [])` (empty list).
68 |
69 | Args:
70 | name: Arbitrary string specifying the name of the statistic.
71 | Averages are accumulated separately for each unique name.
72 | value: Arbitrary set of scalars. Can be a list, tuple,
73 | NumPy array, PyTorch tensor, or Python scalar.
74 |
75 | Returns:
76 | The same `value` that was passed in.
77 | """
78 | if name not in _counters:
79 | _counters[name] = dict()
80 |
81 | elems = torch.as_tensor(value)
82 | if elems.numel() == 0:
83 | return value
84 |
85 | elems = elems.detach().flatten().to(_reduce_dtype)
86 | moments = torch.stack([
87 | torch.ones_like(elems).sum(),
88 | elems.sum(),
89 | elems.square().sum(),
90 | ])
91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
92 | moments = moments.to(_counter_dtype)
93 |
94 | device = moments.device
95 | if device not in _counters[name]:
96 | _counters[name][device] = torch.zeros_like(moments)
97 | _counters[name][device].add_(moments)
98 | return value
99 |
100 | #----------------------------------------------------------------------------
101 |
102 | def report0(name, value):
103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
104 | but ignores any scalars provided by the other processes.
105 | See `report()` for further details.
106 | """
107 | report(name, value if _rank == 0 else [])
108 | return value
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | class Collector:
113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
114 | computes their long-term averages (mean and standard deviation) over
115 | user-defined periods of time.
116 |
117 | The averages are first collected into internal counters that are not
118 | directly visible to the user. They are then copied to the user-visible
119 | state as a result of calling `update()` and can then be queried using
120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121 | internal counters for the next round, so that the user-visible state
122 | effectively reflects averages collected between the last two calls to
123 | `update()`.
124 |
125 | Args:
126 | regex: Regular expression defining which statistics to
127 | collect. The default is to collect everything.
128 | keep_previous: Whether to retain the previous averages if no
129 | scalars were collected on a given round
130 | (default: True).
131 | """
132 | def __init__(self, regex='.*', keep_previous=True):
133 | self._regex = re.compile(regex)
134 | self._keep_previous = keep_previous
135 | self._cumulative = dict()
136 | self._moments = dict()
137 | self.update()
138 | self._moments.clear()
139 |
140 | def names(self):
141 | r"""Returns the names of all statistics broadcasted so far that
142 | match the regular expression specified at construction time.
143 | """
144 | return [name for name in _counters if self._regex.fullmatch(name)]
145 |
146 | def update(self):
147 | r"""Copies current values of the internal counters to the
148 | user-visible state and resets them for the next round.
149 |
150 | If `keep_previous=True` was specified at construction time, the
151 | operation is skipped for statistics that have received no scalars
152 | since the last update, retaining their previous averages.
153 |
154 | This method performs a number of GPU-to-CPU transfers and one
155 | `torch.distributed.all_reduce()`. It is intended to be called
156 | periodically in the main training loop, typically once every
157 | N training steps.
158 | """
159 | if not self._keep_previous:
160 | self._moments.clear()
161 | for name, cumulative in _sync(self.names()):
162 | if name not in self._cumulative:
163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 | delta = cumulative - self._cumulative[name]
165 | self._cumulative[name].copy_(cumulative)
166 | if float(delta[0]) != 0:
167 | self._moments[name] = delta
168 |
169 | def _get_delta(self, name):
170 | r"""Returns the raw moments that were accumulated for the given
171 | statistic between the last two calls to `update()`, or zero if
172 | no scalars were collected.
173 | """
174 | assert self._regex.fullmatch(name)
175 | if name not in self._moments:
176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 | return self._moments[name]
178 |
179 | def num(self, name):
180 | r"""Returns the number of scalars that were accumulated for the given
181 | statistic between the last two calls to `update()`, or zero if
182 | no scalars were collected.
183 | """
184 | delta = self._get_delta(name)
185 | return int(delta[0])
186 |
187 | def mean(self, name):
188 | r"""Returns the mean of the scalars that were accumulated for the
189 | given statistic between the last two calls to `update()`, or NaN if
190 | no scalars were collected.
191 | """
192 | delta = self._get_delta(name)
193 | if int(delta[0]) == 0:
194 | return float('nan')
195 | return float(delta[1] / delta[0])
196 |
197 | def std(self, name):
198 | r"""Returns the standard deviation of the scalars that were
199 | accumulated for the given statistic between the last two calls to
200 | `update()`, or NaN if no scalars were collected.
201 | """
202 | delta = self._get_delta(name)
203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204 | return float('nan')
205 | if int(delta[0]) == 1:
206 | return float(0)
207 | mean = float(delta[1] / delta[0])
208 | raw_var = float(delta[2] / delta[0])
209 | return np.sqrt(max(raw_var - np.square(mean), 0))
210 |
211 | def as_dict(self):
212 | r"""Returns the averages accumulated between the last two calls to
213 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
214 |
215 | dnnlib.EasyDict(
216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217 | ...
218 | )
219 | """
220 | stats = dnnlib.EasyDict()
221 | for name in self.names():
222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223 | return stats
224 |
225 | def __getitem__(self, name):
226 | r"""Convenience getter.
227 | `collector[name]` is a synonym for `collector.mean(name)`.
228 | """
229 | return self.mean(name)
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def _sync(names):
234 | r"""Synchronize the global cumulative counters across devices and
235 | processes. Called internally by `Collector.update()`.
236 | """
237 | if len(names) == 0:
238 | return []
239 | global _sync_called
240 | _sync_called = True
241 |
242 | # Collect deltas within current rank.
243 | deltas = []
244 | device = _sync_device if _sync_device is not None else torch.device('cpu')
245 | for name in names:
246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247 | for counter in _counters[name].values():
248 | delta.add_(counter.to(device))
249 | counter.copy_(torch.zeros_like(counter))
250 | deltas.append(delta)
251 | deltas = torch.stack(deltas)
252 |
253 | # Sum deltas across ranks.
254 | if _sync_device is not None:
255 | torch.distributed.all_reduce(deltas)
256 |
257 | # Update cumulative values.
258 | deltas = deltas.cpu()
259 | for idx, name in enumerate(names):
260 | if name not in _cumulative:
261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
262 | _cumulative[name].add_(deltas[idx])
263 |
264 | # Return name-value pairs.
265 | return [(name, _cumulative[name]) for name in names]
266 |
267 | #----------------------------------------------------------------------------
268 | # Convenience.
269 |
270 | default_collector = Collector()
271 |
272 | #----------------------------------------------------------------------------
273 |
--------------------------------------------------------------------------------
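
A minimal single-process sketch of the reporting API above (illustrative; the
loop and the loss values are assumptions):

    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(100):
        training_stats.report('Loss/total', 0.5 / (step + 1))  # cheap, call anywhere
    collector.update()  # one synchronization per period, not one per report()
    print(collector.num('Loss/total'))   # 100 scalars since the previous update()
    print(collector.mean('Loss/total'))  # their mean
    print(collector.std('Loss/total'))   # their standard deviation
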
/edm/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Train diffusion-based generative model using the techniques described in the
9 | paper "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import os
12 | import re
13 | import json
14 | import click
15 | import torch
16 | import dnnlib
17 | from torch_utils import distributed as dist
18 | from training import training_loop
19 |
20 | import warnings
21 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
22 |
23 | #----------------------------------------------------------------------------
24 | # Parse a comma-separated list of numbers or ranges and return a list of ints.
25 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
26 |
27 | def parse_int_list(s):
28 | if isinstance(s, list): return s
29 | ranges = []
30 | range_re = re.compile(r'^(\d+)-(\d+)$')
31 | for p in s.split(','):
32 | m = range_re.match(p)
33 | if m:
34 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
35 | else:
36 | ranges.append(int(p))
37 | return ranges
38 |
39 | #----------------------------------------------------------------------------
40 |
41 | @click.command()
42 |
43 | # Main options.
44 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True)
45 | @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
46 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
47 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True)
48 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
49 |
50 | # Hyperparameters.
51 |
52 | @click.option('--model_channels', help='Model channels', metavar='INT', type=int, default=128, show_default=True)
53 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True)
54 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
55 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
56 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
57 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
58 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
59 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
60 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
61 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
62 | @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
63 |
64 | # Performance-related.
65 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
66 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
67 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
68 | @click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True)
69 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
70 |
71 | # I/O-related.
72 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
73 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
74 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
75 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)
76 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
77 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
78 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
79 | @click.option('--resume', help='Resume from previous training state', metavar='PT', type=str)
80 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
81 |
82 | def main(**kwargs):
83 | """Train diffusion-based generative model using the techniques described in the
84 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".
85 |
86 | Examples:
87 |
88 | \b
89 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs
90 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\
91 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp
92 | """
93 | opts = dnnlib.EasyDict(kwargs)
94 | torch.multiprocessing.set_start_method('spawn')
95 | dist.init()
96 |
97 | # Initialize config dict.
98 | c = dnnlib.EasyDict()
99 | c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache)
100 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
101 | c.network_kwargs = dnnlib.EasyDict()
102 | c.loss_kwargs = dnnlib.EasyDict()
103 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
104 |
105 | # Validate dataset options.
106 | try:
107 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
108 | dataset_name = dataset_obj.name
109 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
110 | c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size
111 | if opts.cond and not dataset_obj.has_labels:
112 | raise click.ClickException('--cond=True requires labels specified in dataset.json')
113 | del dataset_obj # conserve memory
114 | except IOError as err:
115 | raise click.ClickException(f'--data: {err}')
116 |
117 | # Network architecture.
118 | if opts.arch == 'ddpmpp':
119 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
120 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels = opts.model_channels, channel_mult=[2,2,2])
121 | elif opts.arch == 'ncsnpp':
122 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
123 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels = opts.model_channels, channel_mult=[2,2,2])
124 | else:
125 | assert opts.arch == 'adm'
126 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
127 |
128 | # Preconditioning & loss function.
129 | if opts.precond == 'vp':
130 | c.network_kwargs.class_name = 'training.networks.VPPrecond'
131 | c.loss_kwargs.class_name = 'training.loss.VPLoss'
132 | elif opts.precond == 've':
133 | c.network_kwargs.class_name = 'training.networks.VEPrecond'
134 | c.loss_kwargs.class_name = 'training.loss.VELoss'
135 | else:
136 | assert opts.precond == 'edm'
137 | c.network_kwargs.class_name = 'training.networks.EDMPrecond'
138 | c.loss_kwargs.class_name = 'training.loss.EDMLoss'
139 |
140 | # Network options.
141 | if opts.cbase is not None:
142 | c.network_kwargs.model_channels = opts.cbase
143 | if opts.cres is not None:
144 | c.network_kwargs.channel_mult = opts.cres
145 | if opts.augment:
146 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
147 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
148 | c.network_kwargs.augment_dim = 9
149 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
150 |
151 | # Training options.
152 | c.total_kimg = max(int(opts.duration * 1000), 1)
153 | c.ema_halflife_kimg = int(opts.ema * 1000)
154 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
155 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
156 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
157 |
158 | # Random seed.
159 | if opts.seed is not None:
160 | c.seed = opts.seed
161 | else:
162 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
163 | torch.distributed.broadcast(seed, src=0)
164 | c.seed = int(seed)
165 |
166 | # Transfer learning and resume.
167 | if opts.transfer is not None:
168 | if opts.resume is not None:
169 | raise click.ClickException('--transfer and --resume cannot be specified at the same time')
170 | c.resume_pkl = opts.transfer
171 | c.ema_rampup_ratio = None
172 | elif opts.resume is not None:
173 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume))
174 | if not match or not os.path.isfile(opts.resume):
175 | raise click.ClickException('--resume must point to training-state-*.pt from a previous training run')
176 | c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl')
177 | c.resume_kimg = int(match.group(1))
178 | c.resume_state_dump = opts.resume
179 |
180 | # Description string.
181 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
182 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
183 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
184 | if opts.desc is not None:
185 | desc += f'-{opts.desc}'
186 |
187 | # Pick output directory.
188 | if dist.get_rank() != 0:
189 | c.run_dir = None
190 | elif opts.nosubdir:
191 | c.run_dir = opts.outdir
192 | else:
193 | prev_run_dirs = []
194 | if os.path.isdir(opts.outdir):
195 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
196 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
197 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
198 | cur_run_id = max(prev_run_ids, default=-1) + 1
199 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
200 | assert not os.path.exists(c.run_dir)
201 |
202 | # Print options.
203 | dist.print0()
204 | dist.print0('Training options:')
205 | dist.print0(json.dumps(c, indent=2))
206 | dist.print0()
207 | dist.print0(f'Output directory: {c.run_dir}')
208 | dist.print0(f'Dataset path: {c.dataset_kwargs.path}')
209 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
210 | dist.print0(f'Network architecture: {opts.arch}')
211 | dist.print0(f'Preconditioning & loss: {opts.precond}')
212 | dist.print0(f'Number of GPUs: {dist.get_world_size()}')
213 | dist.print0(f'Batch size: {c.batch_size}')
214 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
215 | dist.print0()
216 |
217 | # Dry run?
218 | if opts.dry_run:
219 | dist.print0('Dry run; exiting.')
220 | return
221 |
222 | # Create output directory.
223 | dist.print0('Creating output directory...')
224 | if dist.get_rank() == 0:
225 | os.makedirs(c.run_dir, exist_ok=True)
226 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
227 | json.dump(c, f, indent=2)
228 | dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)
229 |
230 | # Train.
231 | training_loop.training_loop(**c)
232 |
233 | #----------------------------------------------------------------------------
234 |
235 | if __name__ == "__main__":
236 | main()
237 |
238 | #----------------------------------------------------------------------------
239 |
--------------------------------------------------------------------------------
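
For reference, a sketch of how the --resume handling above derives its
companion files (illustrative; the file name is an assumed example):

    import os, re

    resume = 'training-runs/00000-run/training-state-002000.pt'
    match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(resume))
    resume_pkl = os.path.join(os.path.dirname(resume), f'network-snapshot-{match.group(1)}.pkl')
    resume_kimg = int(match.group(1))  # 2000
    # parse_int_list('1,2,5-10') likewise expands to [1, 2, 5, 6, 7, 8, 9, 10].
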
/edm/trainMoLRG.py:
--------------------------------------------------------------------------------
1 | """Train mixture of low-rank Gaussian Distribution (MoLRG) using diffusion model"""
2 |
3 | import os
4 | import re
5 | import json
6 | import click
7 | import torch
8 | import dnnlib
9 | from torch_utils import distributed as dist
10 | from training import training_loop_MoLRG
11 | from glob import glob
12 |
13 | import warnings
14 | warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12.
15 |
16 | #----------------------------------------------------------------------------
17 | # Parse a comma-separated list of numbers or ranges and return a list of ints.
18 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
19 |
20 | def parse_int_list(s):
21 | if isinstance(s, list): return s
22 | ranges = []
23 | range_re = re.compile(r'^(\d+)-(\d+)$')
24 | for p in s.split(','):
25 | m = range_re.match(p)
26 | if m:
27 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
28 | else:
29 | ranges.append(int(p))
30 | return ranges
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | @click.command()
35 |
36 | # Main options.
37 | @click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, default="/home/ubuntu/exp/theory/")
38 | @click.option('--img_res', help='Resolution of the MoLRG images', metavar='INT', type=int, default=32)
39 | @click.option('--class_num', help='Number of classes for the MoLRG mixture', metavar='INT', type=int, default=10)
40 | @click.option('--per_class_dim', help='Dimension for each class', metavar='INT', type=int, default=100)
41 | @click.option('--sample_per_class', help='Number of samples for each class', metavar='INT', type=int, default=5000)
42 | @click.option('--path', help='Path to the MoLRG dataset', metavar='DIR', type=str, required=True)
43 | # @click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True)
44 | @click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
45 | @click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm|mlp|param-mlp', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm', 'mlp', "param-mlp"]), default='ddpmpp', show_default=True)
46 | @click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True)
47 |
48 |
49 | # Hyperparameters.
50 | @click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=6, show_default=True)
51 | @click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
52 | @click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1))
53 | @click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int)
54 | @click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list)
55 | @click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True)
56 | @click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True)
57 | @click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True)
58 | @click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True)
59 | # @click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True)
60 | @click.option('--embed_channels', help='Channel multiplier [default: varies]', metavar='INT', type=int, default=1024)
61 |
62 | # Performance-related.
63 | @click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
64 | @click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
65 | @click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True)
66 | @click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True)
67 |
68 | # I/O-related.
69 | @click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
70 | @click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True)
71 | @click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True)
72 | @click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=10, show_default=True)
73 | @click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True)
74 | @click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int)
75 | @click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str)
76 | @click.option('--resume', help='Resume from the latest training state in the run directory', metavar='BOOL', type=bool)
77 | @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
78 | @click.option('--resumedir', help='Resume from previous training directory', metavar='DIR', type=str)
79 | @click.option('--optimizer', help='Training optimizer', metavar='adam|sgd', default="adam", type=click.Choice(['adam', 'sgd']))
80 |
81 |
82 | def main(**kwargs):
83 | """Train diffusion-based generative model using the techniques described in the
84 | paper "Elucidating the Design Space of Diffusion-Based Generative Models".
85 |
86 | Examples:
87 |
88 | \b
89 |     # Train DDPM++ model on MoLRG data using 8 GPUs
90 |     torchrun --standalone --nproc_per_node=8 trainMoLRG.py --outdir=training-runs \\
91 |         --path=<dataset-dir> --arch=ddpmpp
92 | """
93 | opts = dnnlib.EasyDict(kwargs)
94 | torch.multiprocessing.set_start_method('spawn')
95 | dist.init()
96 |
97 |
98 | # Initialize config dict.
99 | c = dnnlib.EasyDict()
100 |     c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.MoLRG', resolution=opts.img_res, class_num=opts.class_num, per_class_dim=opts.per_class_dim, sample_per_class=opts.sample_per_class, path=opts.path, use_labels=opts.cond)
101 |
102 | c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2)
103 | c.network_kwargs = dnnlib.EasyDict()
104 | c.loss_kwargs = dnnlib.EasyDict()
105 | if opts.optimizer == "adam":
106 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8)
107 | elif opts.optimizer == "sgd":
108 | c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.SGD', lr=opts.lr)
109 | opts.outdir = os.path.join(opts.outdir, f"MoLRG_dataset_resolution{opts.img_res}_classnum{opts.class_num}_perclassdim{opts.per_class_dim}_sample{opts.sample_per_class}")
110 |
111 | # Validate dataset options.
112 | try:
113 | dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs)
114 | dataset_name = dataset_obj.name
115 | c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution
116 | if opts.cond and not dataset_obj.has_labels:
117 | raise click.ClickException('--cond=True requires labels specified in dataset.json')
118 | del dataset_obj # conserve memory
119 | except IOError as err:
120 | raise click.ClickException(f'--data: {err}')
121 |
122 | # Network architecture.
123 | if opts.arch == 'ddpmpp':
124 | c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
125 | c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=opts.embed_channels, channel_mult=[2,2,2])
126 | elif opts.arch == 'ncsnpp':
127 | c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard')
128 | c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=opts.embed_channels, channel_mult=[2,2,2])
129 | elif opts.arch == 'mlp':
130 | c.network_kwargs.update(model_type='TwoLayerMLP', embedding_type='positional')
131 | c.network_kwargs.update(embed_channels=opts.embed_channels, noise_channels=256)
132 | elif opts.arch == 'param-mlp':
133 | c.network_kwargs.update(model_type='ParametericMLP')
134 | c.network_kwargs.update(class_dim=opts.class_num, latent_dim=opts.per_class_dim)
135 | else:
136 | assert opts.arch == 'adm'
137 | c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
138 |
139 | # Preconditioning & loss function.
140 | if opts.precond == 'vp':
141 | c.network_kwargs.class_name = 'training.networks.VPPrecond'
142 | c.loss_kwargs.class_name = 'training.loss.VPLoss'
143 | elif opts.precond == 've':
144 | c.network_kwargs.class_name = 'training.networks.VEPrecond'
145 | c.loss_kwargs.class_name = 'training.loss.VELoss'
146 | else:
147 | assert opts.precond == 'edm'
148 | c.network_kwargs.class_name = 'training.networks.EDMPrecond'
149 | c.loss_kwargs.class_name = 'training.loss.EDMLoss'
150 |
151 | # Network options.
152 | if opts.cbase is not None:
153 | c.network_kwargs.model_channels = opts.cbase
154 | if opts.cres is not None:
155 | c.network_kwargs.channel_mult = opts.cres
156 | if opts.augment:
157 | c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment)
158 | c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1)
159 | c.network_kwargs.augment_dim = 9
160 | c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16)
161 |
162 | # Training options.
163 | c.total_kimg = max(int(opts.duration * 1000), 1)
164 | c.ema_halflife_kimg = int(opts.ema * 1000)
165 | c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu)
166 | c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench)
167 | c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump)
168 |
169 | # Random seed.
170 | if opts.seed is not None:
171 | c.seed = opts.seed
172 | else:
173 | seed = torch.randint(1 << 31, size=[], device=torch.device('cuda'))
174 | torch.distributed.broadcast(seed, src=0)
175 | c.seed = int(seed)
176 |
177 | # Transfer learning and resume.
178 | if opts.transfer is not None:
179 | if opts.resume is not None:
180 | raise click.ClickException('--transfer and --resume cannot be specified at the same time')
181 | c.resume_pkl = opts.transfer
182 | c.ema_rampup_ratio = None
183 |
184 | elif opts.resumedir is not None:
185 | pt_files = glob(os.path.join(opts.resumedir, 'training-state-*.pt'))
186 | pt_files.sort()
187 |         if len(pt_files) != 0:
188 | latest_file = pt_files[-1]
189 |
190 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(latest_file))
191 | if not match or not os.path.isfile(latest_file):
192 |                 raise click.ClickException('--resumedir must contain training-state-*.pt from a previous training run')
193 | c.resume_pkl = os.path.join(os.path.dirname(latest_file), f'network-snapshot-{match.group(1)}.pkl')
194 | c.resume_kimg = int(match.group(1))
195 | c.resume_state_dump = latest_file
196 |
197 | # Description string.
198 | cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond'
199 | dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32'
200 | desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}'
201 | if opts.desc is not None:
202 | desc += f'-{opts.desc}'
203 |
204 |
205 | # Pick output directory.
206 | if dist.get_rank() != 0:
207 | c.run_dir = None
208 | elif opts.nosubdir:
209 | c.run_dir = opts.outdir
210 | elif opts.resumedir:
211 | c.run_dir = opts.resumedir
212 | else:
213 | prev_run_dirs = []
214 | if os.path.isdir(opts.outdir):
215 | prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))]
216 | prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
217 | prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
218 |         if opts.resume and len(prev_run_ids) != 0:
219 | cur_run_id = max(prev_run_ids, default=0)
220 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
221 |
222 | pt_files = glob(os.path.join(c.run_dir, 'training-state-*.pt'))
223 | pt_files.sort()
224 | latest_file = pt_files[-1]
225 | print(pt_files)
226 | print(latest_file)
227 |
228 | match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(latest_file))
229 | if not match or not os.path.isfile(latest_file):
230 | print(latest_file)
231 |                 raise click.ClickException('auto-resume failed: no valid training-state-*.pt found in the latest run directory')
232 | c.resume_pkl = os.path.join(os.path.dirname(latest_file), f'network-snapshot-{match.group(1)}.pkl')
233 | c.resume_kimg = int(match.group(1))
234 | c.resume_state_dump = os.path.join(os.path.dirname(latest_file), f'training-state-{match.group(1)}.pt')
235 | else:
236 | cur_run_id = max(prev_run_ids, default=0)
237 | c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}')
238 | # assert not os.path.exists(c.run_dir)
239 |
240 |
241 |
242 | # Fine Tune
243 |
244 | # Print options.
245 | dist.print0()
246 | dist.print0('Training options:')
247 | dist.print0(json.dumps(c, indent=2))
248 | dist.print0()
249 | dist.print0(f'Output directory: {c.run_dir}')
250 | dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}')
251 | dist.print0(f'Network architecture: {opts.arch}')
252 | dist.print0(f'Preconditioning & loss: {opts.precond}')
253 | dist.print0(f'Number of GPUs: {dist.get_world_size()}')
254 | dist.print0(f'Batch size: {c.batch_size}')
255 | dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}')
256 | dist.print0()
257 |
258 | # Dry run?
259 | if opts.dry_run:
260 | dist.print0('Dry run; exiting.')
261 | return
262 |
263 | # Create output directory.
264 | dist.print0('Creating output directory...')
265 | if dist.get_rank() == 0:
266 | os.makedirs(c.run_dir, exist_ok=True)
267 | with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
268 | json.dump(c, f, indent=2)
269 | # log_dir = '/home/ubuntu/sky_workdir'
270 | # dnnlib.util.Logger(file_name=os.path.join(log_dir, 'log.txt'), file_mode='a', should_flush=True)
271 |
272 | # Train.
273 | training_loop_MoLRG.training_loop(**c)
274 |
275 | #----------------------------------------------------------------------------
276 |
277 | if __name__ == "__main__":
278 | main()
279 |
280 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
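
A sketch of the auto-resume behavior above: the newest training state in a run
directory is picked by sorting the zero-padded file names (illustrative; the
directory name is an assumed example):

    import os, re
    from glob import glob

    pt_files = sorted(glob(os.path.join('training-runs/00000-run', 'training-state-*.pt')))
    if pt_files:
        latest = pt_files[-1]  # zero-padded kimg counts sort lexicographically
        kimg = int(re.fullmatch(r'training-state-(\d+).pt', os.path.basename(latest)).group(1))
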
/edm/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/edm/training/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Streaming images and labels from datasets created with dataset_tool.py."""
9 |
10 | import os
11 | import numpy as np
12 | import zipfile
13 | import PIL.Image
14 | import json
15 | import torch
16 | import dnnlib
17 | from torchvision import transforms
18 |
19 | try:
20 | import pyspng
21 | except ImportError:
22 | pyspng = None
23 |
24 | #----------------------------------------------------------------------------
25 | # Abstract base class for datasets.
26 |
27 | class Dataset(torch.utils.data.Dataset):
28 | def __init__(self,
29 | name, # Name of the dataset.
30 | raw_shape, # Shape of the raw image data (NCHW).
31 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
32 | use_labels = False, # Enable conditioning labels? False = label dimension is zero.
33 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
34 | random_seed = 0, # Random seed to use when applying max_size.
35 | cache = False, # Cache images in CPU memory?
36 | ):
37 | self._name = name
38 | self._raw_shape = list(raw_shape)
39 | self._use_labels = use_labels
40 | self._cache = cache
41 | self._cached_images = dict() # {raw_idx: np.ndarray, ...}
42 | self._raw_labels = None
43 | self._label_shape = None
44 |
45 | # Apply max_size.
46 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
47 | if (max_size is not None) and (self._raw_idx.size > max_size):
48 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
49 | self._raw_idx = np.sort(self._raw_idx[:max_size])
50 |
51 | # Apply xflip.
52 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
53 | if xflip:
54 | self._raw_idx = np.tile(self._raw_idx, 2)
55 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
56 |
57 | def _get_raw_labels(self):
58 | if self._raw_labels is None:
59 | self._raw_labels = self._load_raw_labels() if self._use_labels else None
60 | if self._raw_labels is None:
61 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
62 | assert isinstance(self._raw_labels, np.ndarray)
63 | assert self._raw_labels.shape[0] == self._raw_shape[0]
64 | assert self._raw_labels.dtype in [np.float32, np.int64]
65 | if self._raw_labels.dtype == np.int64:
66 | assert self._raw_labels.ndim == 1
67 | assert np.all(self._raw_labels >= 0)
68 | return self._raw_labels
69 |
70 | def close(self): # to be overridden by subclass
71 | pass
72 |
73 | def _load_raw_image(self, raw_idx): # to be overridden by subclass
74 | raise NotImplementedError
75 |
76 | def _load_raw_labels(self): # to be overridden by subclass
77 | raise NotImplementedError
78 |
79 | def __getstate__(self):
80 | return dict(self.__dict__, _raw_labels=None)
81 |
82 | def __del__(self):
83 | try:
84 | self.close()
85 | except:
86 | pass
87 |
88 | def __len__(self):
89 | return self._raw_idx.size
90 |
91 | def __getitem__(self, idx):
92 | raw_idx = self._raw_idx[idx]
93 | image = self._cached_images.get(raw_idx, None)
94 | if image is None:
95 | image = self._load_raw_image(raw_idx)
96 | if self._cache:
97 | self._cached_images[raw_idx] = image
98 | assert isinstance(image, np.ndarray)
99 | assert list(image.shape) == self.image_shape
100 | assert image.dtype == np.uint8
101 | if self._xflip[idx]:
102 | assert image.ndim == 3 # CHW
103 | image = image[:, :, ::-1]
104 | return image.copy(), self.get_label(idx)
105 |
106 | def get_label(self, idx):
107 | label = self._get_raw_labels()[self._raw_idx[idx]]
108 | if label.dtype == np.int64:
109 | onehot = np.zeros(self.label_shape, dtype=np.float32)
110 | onehot[label] = 1
111 | label = onehot
112 | return label.copy()
113 |
114 | def get_details(self, idx):
115 | d = dnnlib.EasyDict()
116 | d.raw_idx = int(self._raw_idx[idx])
117 | d.xflip = (int(self._xflip[idx]) != 0)
118 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
119 | return d
120 |
121 | @property
122 | def name(self):
123 | return self._name
124 |
125 | @property
126 | def image_shape(self):
127 | return list(self._raw_shape[1:])
128 |
129 | @property
130 | def num_channels(self):
131 | assert len(self.image_shape) == 3 # CHW
132 | return self.image_shape[0]
133 |
134 | @property
135 | def resolution(self):
136 | assert len(self.image_shape) == 3 # CHW
137 | assert self.image_shape[1] == self.image_shape[2]
138 | return self.image_shape[1]
139 |
140 | @property
141 | def label_shape(self):
142 | if self._label_shape is None:
143 | raw_labels = self._get_raw_labels()
144 | if raw_labels.dtype == np.int64:
145 | self._label_shape = [int(np.max(raw_labels)) + 1]
146 | else:
147 | self._label_shape = raw_labels.shape[1:]
148 | return list(self._label_shape)
149 |
150 | @property
151 | def label_dim(self):
152 | assert len(self.label_shape) == 1
153 | return self.label_shape[0]
154 |
155 | @property
156 | def has_labels(self):
157 | return any(x != 0 for x in self.label_shape)
158 |
159 | @property
160 | def has_onehot_labels(self):
161 | return self._get_raw_labels().dtype == np.int64
162 |
163 | #----------------------------------------------------------------------------
164 | # Dataset subclass that loads images recursively from the specified directory
165 | # or ZIP file.
166 |
167 | class ImageFolderDataset(Dataset):
168 | def __init__(self,
169 | path, # Path to directory or zip.
170 | resolution = None, # Ensure specific resolution, None = highest available.
171 | use_pyspng = True, # Use pyspng if available?
172 | **super_kwargs, # Additional arguments for the Dataset base class.
173 | ):
174 | self._path = path
175 | self._use_pyspng = use_pyspng
176 | self._zipfile = None
177 |
178 | if os.path.isdir(self._path):
179 | self._type = 'dir'
180 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
181 | elif self._file_ext(self._path) == '.zip':
182 | self._type = 'zip'
183 | self._all_fnames = set(self._get_zipfile().namelist())
184 | else:
185 | raise IOError('Path must point to a directory or zip')
186 |
187 | PIL.Image.init()
188 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
189 | images = None
190 | if len(self._image_fnames) == 0:
191 |             self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) == '.npz')
192 | if len(self._image_fnames) == 0:
193 | raise IOError('No image files found in the specified path')
194 | else:
195 |                 images = np.concatenate([np.load(os.path.join(self._path, fname))["samples"].transpose(0, 3, 1, 2) for fname in self._image_fnames], axis=0)
196 | raw_shape = images.shape
197 | super_kwargs['cache'] = True
198 | else:
199 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
200 | name = os.path.splitext(os.path.basename(self._path))[0]
201 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
202 | raise IOError('Image files do not match the specified resolution')
203 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
204 |
205 | if images is not None:
206 | self._cached_images = {i: images[i] for i in range(images.shape[0])}
207 |
208 | @staticmethod
209 | def _file_ext(fname):
210 | return os.path.splitext(fname)[1].lower()
211 |
212 | def _get_zipfile(self):
213 | assert self._type == 'zip'
214 | if self._zipfile is None:
215 | self._zipfile = zipfile.ZipFile(self._path)
216 | return self._zipfile
217 |
218 | def _open_file(self, fname):
219 | if self._type == 'dir':
220 | return open(os.path.join(self._path, fname), 'rb')
221 | if self._type == 'zip':
222 | return self._get_zipfile().open(fname, 'r')
223 | return None
224 |
225 | def close(self):
226 | try:
227 | if self._zipfile is not None:
228 | self._zipfile.close()
229 | finally:
230 | self._zipfile = None
231 |
232 | def __getstate__(self):
233 | return dict(super().__getstate__(), _zipfile=None)
234 |
235 | def _load_raw_image(self, raw_idx):
236 | fname = self._image_fnames[raw_idx]
237 |
238 | with self._open_file(fname) as f:
239 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png':
240 | image = pyspng.load(f.read())
241 | else:
242 | image = np.array(PIL.Image.open(f))
243 | if image.ndim == 2:
244 | image = image[:, :, np.newaxis] # HW => HWC
245 | image = image.transpose(2, 0, 1) # HWC => CHW
246 | return image
247 |
248 | def _load_raw_labels(self):
249 | fname = 'dataset.json'
250 | if fname not in self._all_fnames:
251 | return None
252 | with self._open_file(fname) as f:
253 | labels = json.load(f)['labels']
254 | if labels is None:
255 | return None
256 | labels = dict(labels)
257 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
258 | labels = np.array(labels)
259 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
260 | return labels
261 |
262 | class optimal_denoiser_dataset(torch.utils.data.Dataset):
263 | def __init__(self,
264 | path, # Path to directory or zip.
265 | transforms = transforms.Compose([]),
266 | ):
267 | self._path = path
268 | self._file_list = os.listdir(self._path)
269 | self._file_list.sort()
270 | self.images = [torch.load(os.path.join(self._path, pth)) for pth in self._file_list]
271 | self.images = torch.cat(self.images).permute((0, 3, 1, 2))
272 |
273 |         # All loaded images are used; subselect self.images here if a smaller subset is needed.
274 |         self.transforms = transforms
275 |
276 | def __len__(self):
277 | return len(self.images)
278 |
279 | def __getitem__(self, idx):
280 | image = self.transforms(self.images[idx])
281 | return image
282 |
283 | class MoLRG(torch.utils.data.Dataset):
284 | def __init__(self,
285 | resolution = 2,
286 | class_num = 2,
287 | per_class_dim = 2,
288 | sample_per_class = 500,
289 | path = "./datasets",
290 | use_labels = False,
291 | save_dataset = True,
292 | loading_dataset = True,
293 | ):
294 | img_resolution = torch.tensor([resolution, resolution, 3])
295 | dataset_path = os.path.join(path, f"MoLRG_dataset_resolution{resolution}_classnum{class_num}_perclassdim{per_class_dim}_sample{sample_per_class}.pt")
296 | if (not os.path.exists(dataset_path)) or (not loading_dataset):
297 |             print("Creating new dataset...")
298 |             dim = int(img_resolution.prod())  # total pixels times channels
299 | rand = torch.randn(dim, dim)
300 | U, _, _ = torch.linalg.svd(rand)
301 | classbasis = []
302 | ## generate basis
303 | for i in range(class_num):
304 | classbasis.append(U[:, per_class_dim * i:per_class_dim * (i+1)][None, :])
305 | classbasis = torch.cat(classbasis)
306 |
307 | ## generate sample
308 | data = []
309 | conds = []
310 | for cond in range(class_num):
311 | for idx in range(sample_per_class):
312 |                     data.append((classbasis[cond] @ torch.randn(per_class_dim, 1)).reshape((1, resolution, resolution, 3)))
313 | conds.append(cond)
314 | data = torch.cat(data)
315 | conds = torch.tensor(conds)
316 | if save_dataset:
317 | torch.save({
318 | "basis": classbasis,
319 | "space_basis": U,
320 | "data":data,
321 | "class_num": class_num,
322 | "sample_per_class": sample_per_class,
323 | "per_class_dim": per_class_dim,
324 | "resolution": resolution,
325 | "conds": conds,
326 | }, dataset_path)
327 | else:
328 | dataset = torch.load(dataset_path)
329 | resolution = dataset["resolution"]
330 | class_num = dataset["class_num"]
331 | per_class_dim = dataset["per_class_dim"]
332 | sample_per_class = dataset["sample_per_class"]
333 | data = dataset["data"]
334 | classbasis = dataset["basis"]
335 | conds = dataset["conds"]
336 | self.data = data
337 | self.conds = conds
338 | self.name = "MoLRG"
339 | self.resolution = resolution
340 |         self.num_channels = 3
341 |         self.label_dim = 0
342 | self.basis = classbasis
343 |
344 | def __len__(self):
345 | return self.data.shape[0]
346 |
347 | def __getitem__(self, idx):
348 |         return (self.data[idx].permute((2, 0, 1)) + 1) * 127.5, self.conds[idx] # pre-scaled so the training loop's x / 127.5 - 1 recovers the raw sample
349 | #----------------------------------------------------------------------------
350 |
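351 | #----------------------------------------------------------------------------
352 | # Minimal usage sketch (an illustration, not part of the pipeline): build a
353 | # small MoLRG dataset in memory and verify that each sample lies in the span
354 | # of its class basis, which is exactly what the low-rank Gaussian
355 | # construction above guarantees.
356 |
357 | if __name__ == '__main__':
358 |     ds = MoLRG(resolution=2, class_num=2, per_class_dim=2, sample_per_class=10,
359 |                save_dataset=False, loading_dataset=False)
360 |     x, cond = ds[0]
361 |     print('sample shape:', tuple(x.shape), 'label:', int(cond))  # (3, 2, 2)
362 |     flat = ds.data[0].reshape(-1, 1)          # flatten to (dim, 1)
363 |     U_k = ds.basis[int(ds.conds[0])]          # orthonormal basis of this class
364 |     residual = flat - U_k @ (U_k.T @ flat)    # remove projection onto span(U_k)
365 |     print('projection residual:', residual.norm().item())  # ~0
366 |
367 | #----------------------------------------------------------------------------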
--------------------------------------------------------------------------------
/edm/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Loss functions used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import torch
12 | from torch_utils import persistence
13 |
14 | #----------------------------------------------------------------------------
15 | # Loss function corresponding to the variance preserving (VP) formulation
16 | # from the paper "Score-Based Generative Modeling through Stochastic
17 | # Differential Equations".
18 |
19 | @persistence.persistent_class
20 | class VPLoss:
21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
22 | self.beta_d = beta_d
23 | self.beta_min = beta_min
24 | self.epsilon_t = epsilon_t
25 |
26 | def __call__(self, net, images, labels, augment_pipe=None):
27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
29 | weight = 1 / sigma ** 2
30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
31 | n = torch.randn_like(y) * sigma
32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
33 | loss = weight * ((D_yn - y) ** 2)
34 | return loss
35 |
36 | def sigma(self, t):
37 | t = torch.as_tensor(t)
38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()
39 |
40 | #----------------------------------------------------------------------------
41 | # Loss function corresponding to the variance exploding (VE) formulation
42 | # from the paper "Score-Based Generative Modeling through Stochastic
43 | # Differential Equations".
44 |
45 | @persistence.persistent_class
46 | class VELoss:
47 | def __init__(self, sigma_min=0.02, sigma_max=100):
48 | self.sigma_min = sigma_min
49 | self.sigma_max = sigma_max
50 |
51 | def __call__(self, net, images, labels, augment_pipe=None):
52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
54 | weight = 1 / sigma ** 2
55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
56 | n = torch.randn_like(y) * sigma
57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
58 | loss = weight * ((D_yn - y) ** 2)
59 | return loss
60 |
61 | #----------------------------------------------------------------------------
62 | # Improved loss function proposed in the paper "Elucidating the Design Space
63 | # of Diffusion-Based Generative Models" (EDM).
64 |
65 | @persistence.persistent_class
66 | class EDMLoss:
67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
68 | self.P_mean = P_mean
69 | self.P_std = P_std
70 | self.sigma_data = sigma_data
71 |
72 | def __call__(self, net, images, labels=None, augment_pipe=None):
73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
77 | n = torch.randn_like(y) * sigma
78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
79 | loss = weight * ((D_yn - y) ** 2)
80 | return loss
81 |
82 | #----------------------------------------------------------------------------
83 |
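84 | #----------------------------------------------------------------------------
85 | # Sanity-check sketch (illustrative only): EDMLoss draws noise levels from a
86 | # log-normal distribution, ln(sigma) ~ N(P_mean, P_std^2), and weights each
87 | # sample by (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. A stub
88 | # "denoiser" that echoes its noisy input therefore incurs a strictly positive
89 | # weighted loss, while one that returns the clean image scores zero.
90 |
91 | if __name__ == '__main__':
92 |     loss_fn = EDMLoss()
93 |     images = torch.randn(4, 3, 8, 8) * loss_fn.sigma_data
94 |     echo = lambda x, sigma, labels, augment_labels=None: x       # stub net
95 |     print('echo loss:  ', loss_fn(echo, images).mean().item())   # > 0
96 |     clean = lambda x, sigma, labels, augment_labels=None: images # oracle net
97 |     print('oracle loss:', loss_fn(clean, images).mean().item())  # == 0
98 |
99 | #----------------------------------------------------------------------------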
--------------------------------------------------------------------------------
/edm/training/training_loop.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Main training loop."""
9 |
10 | import os
11 | import time
12 | import copy
13 | import json
14 | import pickle
15 | import psutil
16 | import numpy as np
17 | import torch
18 | import dnnlib
19 | from torch_utils import distributed as dist
20 | from torch_utils import training_stats
21 | from torch_utils import misc
22 |
23 | #----------------------------------------------------------------------------
24 |
25 | def training_loop(
26 | run_dir = '.', # Output directory.
27 | dataset_kwargs = {}, # Options for training set.
28 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
29 | network_kwargs = {}, # Options for model and preconditioning.
30 | loss_kwargs = {}, # Options for loss function.
31 | optimizer_kwargs = {}, # Options for optimizer.
32 | augment_kwargs = None, # Options for augmentation pipeline, None = disable.
33 | seed = 0, # Global random seed.
34 | batch_size = 512, # Total batch size for one training iteration.
35 | batch_gpu = None, # Limit batch size per GPU, None = no limit.
36 | total_kimg = 200000, # Training duration, measured in thousands of training images.
37 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
38 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
39 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
40 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
41 | kimg_per_tick = 50, # Interval of progress prints.
42 | snapshot_ticks = 50, # How often to save network snapshots, None = disable.
43 | state_dump_ticks = 500, # How often to dump training state, None = disable.
44 | resume_pkl = None, # Start from the given network snapshot, None = random initialization.
45 | resume_state_dump = None, # Start from the given training state, None = reset training state.
46 | resume_kimg = 0, # Start from the given training progress.
47 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
48 | device = torch.device('cuda'),
49 | ):
50 | # Initialize.
51 | start_time = time.time()
52 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
53 | torch.manual_seed(np.random.randint(1 << 31))
54 | torch.backends.cudnn.benchmark = cudnn_benchmark
55 | torch.backends.cudnn.allow_tf32 = False
56 | torch.backends.cuda.matmul.allow_tf32 = False
57 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
58 |
59 | # Select batch size per GPU.
60 | batch_gpu_total = batch_size // dist.get_world_size()
61 | if batch_gpu is None or batch_gpu > batch_gpu_total:
62 | batch_gpu = batch_gpu_total
63 | num_accumulation_rounds = batch_gpu_total // batch_gpu
64 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
65 |
66 | # Load dataset.
67 | dist.print0('Loading dataset...')
68 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
69 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
70 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
71 |
72 | # Construct network.
73 | dist.print0('Constructing network...')
74 | interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
75 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
76 | net.train().requires_grad_(True).to(device)
77 | if dist.get_rank() == 0:
78 | with torch.no_grad():
79 | images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
80 | sigma = torch.ones([batch_gpu], device=device)
81 | labels = torch.zeros([batch_gpu, net.label_dim], device=device)
82 | misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)
83 |
84 | # Setup optimizer.
85 | dist.print0('Setting up optimizer...')
86 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
87 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
88 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
89 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device])
90 | ema = copy.deepcopy(net).eval().requires_grad_(False)
91 |
92 | # Resume training from previous snapshot.
93 | if resume_pkl is not None:
94 | dist.print0(f'Loading network weights from "{resume_pkl}"...')
95 | if dist.get_rank() != 0:
96 | torch.distributed.barrier() # rank 0 goes first
97 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
98 | data = pickle.load(f)
99 | if dist.get_rank() == 0:
100 | torch.distributed.barrier() # other ranks follow
101 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
102 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
103 | del data # conserve memory
104 | if resume_state_dump:
105 | dist.print0(f'Loading training state from "{resume_state_dump}"...')
106 | data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
107 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
108 | optimizer.load_state_dict(data['optimizer_state'])
109 | del data # conserve memory
110 |
111 | # Train.
112 | dist.print0(f'Training for {total_kimg} kimg...')
113 | dist.print0()
114 | cur_nimg = resume_kimg * 1000
115 | cur_tick = 0
116 | tick_start_nimg = cur_nimg
117 | tick_start_time = time.time()
118 | maintenance_time = tick_start_time - start_time
119 | dist.update_progress(cur_nimg // 1000, total_kimg)
120 | stats_jsonl = None
121 | while True:
122 |
123 | # Accumulate gradients.
124 | optimizer.zero_grad(set_to_none=True)
125 | for round_idx in range(num_accumulation_rounds):
126 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
127 | images, labels = next(dataset_iterator)
128 | images = images.to(device).to(torch.float32) / 127.5 - 1
129 | labels = labels.to(device)
130 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
131 | training_stats.report('Loss/loss', loss)
132 | loss.sum().mul(loss_scaling / batch_gpu_total).backward()
133 |
134 | # Update weights.
135 | for g in optimizer.param_groups:
136 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
137 | for param in net.parameters():
138 | if param.grad is not None:
139 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
140 | optimizer.step()
141 |
142 | # Update EMA.
143 | ema_halflife_nimg = ema_halflife_kimg * 1000
144 | if ema_rampup_ratio is not None:
145 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
146 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
147 | for p_ema, p_net in zip(ema.parameters(), net.parameters()):
148 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
149 |
150 | # Perform maintenance tasks once per tick.
151 | cur_nimg += batch_size
152 | done = (cur_nimg >= total_kimg * 1000)
153 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
154 | continue
155 |
156 | # Print status line, accumulating the same information in training_stats.
157 | tick_end_time = time.time()
158 | fields = []
159 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
160 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
161 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
162 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
163 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
164 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
165 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
166 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
167 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
168 | torch.cuda.reset_peak_memory_stats()
169 | dist.print0(' '.join(fields))
170 |
171 | # Check for abort.
172 | if (not done) and dist.should_stop():
173 | done = True
174 | dist.print0()
175 | dist.print0('Aborting...')
176 |
177 | # Save network snapshot.
178 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
179 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
180 | for key, value in data.items():
181 | if isinstance(value, torch.nn.Module):
182 | value = copy.deepcopy(value).eval().requires_grad_(False)
183 | misc.check_ddp_consistency(value)
184 | data[key] = value.cpu()
185 | del value # conserve memory
186 | if dist.get_rank() == 0:
187 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
188 | pickle.dump(data, f)
189 | del data # conserve memory
190 |
191 | # Save full dump of the training state.
192 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
193 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
194 |
195 | # Update logs.
196 | training_stats.default_collector.update()
197 | if dist.get_rank() == 0:
198 | if stats_jsonl is None:
199 | stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at')
200 | stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
201 | stats_jsonl.flush()
202 | dist.update_progress(cur_nimg // 1000, total_kimg)
203 |
204 | # Update state.
205 | cur_tick += 1
206 | tick_start_nimg = cur_nimg
207 | tick_start_time = time.time()
208 | maintenance_time = tick_start_time - tick_end_time
209 | if done:
210 | break
211 |
212 | # Done.
213 | dist.print0()
214 | dist.print0('Exiting...')
215 |
216 | #----------------------------------------------------------------------------
217 |
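218 | #----------------------------------------------------------------------------
219 | # EMA arithmetic sketch (standalone, using this file's default values): each
220 | # step blends the net weights into the EMA with beta = 0.5 ** (batch_size /
221 | # ema_halflife_nimg), so the EMA forgets half of its history every
222 | # ema_halflife_kimg thousand images regardless of batch size.
223 |
224 | if __name__ == '__main__':
225 |     batch_size, ema_halflife_kimg = 512, 500
226 |     ema_beta = 0.5 ** (batch_size / (ema_halflife_kimg * 1000))
227 |     steps_per_halflife = ema_halflife_kimg * 1000 / batch_size
228 |     print(f'beta per step: {ema_beta:.6f}')                                   # ~0.99929
229 |     print(f'decay over one half-life: {ema_beta ** steps_per_halflife:.3f}')  # 0.500
230 |
231 | #----------------------------------------------------------------------------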
--------------------------------------------------------------------------------
/edm/training/training_loop_MoLRG.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Main training loop."""
9 |
10 | import os
11 | import time
12 | import copy
13 | import json
14 | import pickle
15 | import psutil
16 | import numpy as np
17 | import torch
18 | import dnnlib
19 | from torch_utils import distributed as dist
20 | from torch_utils import training_stats
21 | from torch_utils import misc
22 | from training.networks import UNetBlock
23 | from training.networks import EDMPrecond
24 | #----------------------------------------------------------------------------
25 |
26 | def training_loop(
27 | run_dir = '.', # Output directory.
28 | dataset_kwargs = {}, # Options for training set.
29 | data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
30 | network_kwargs = {}, # Options for model and preconditioning.
31 | loss_kwargs = {}, # Options for loss function.
32 | optimizer_kwargs = {}, # Options for optimizer.
33 | augment_kwargs = None, # Options for augmentation pipeline, None = disable.
34 | seed = 0, # Global random seed.
35 | batch_size = 512, # Total batch size for one training iteration.
36 | batch_gpu = None, # Limit batch size per GPU, None = no limit.
37 | total_kimg = 200000, # Training duration, measured in thousands of training images.
38 | ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights.
39 | ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup.
40 | lr_rampup_kimg = 10000, # Learning rate ramp-up duration.
41 | loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows.
42 | kimg_per_tick = 50, # Interval of progress prints.
43 | snapshot_ticks = 50, # How often to save network snapshots, None = disable.
44 | state_dump_ticks = 500, # How often to dump training state, None = disable.
45 | resume_pkl = None, # Start from the given network snapshot, None = random initialization.
46 | resume_state_dump = None, # Start from the given training state, None = reset training state.
47 | resume_kimg = 0, # Start from the given training progress.
48 | cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
49 | device = torch.device('cuda'),
50 |     pretrained_model_path = None,       # Initialize weights from the given pre-trained snapshot, None = random initialization.
51 | ):
52 |
53 | # Initialize.
54 | start_time = time.time()
55 | np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
56 | torch.manual_seed(np.random.randint(1 << 31))
57 | torch.backends.cudnn.benchmark = cudnn_benchmark
58 | torch.backends.cudnn.allow_tf32 = False
59 | torch.backends.cuda.matmul.allow_tf32 = False
60 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
61 |
62 | # Select batch size per GPU.
63 | batch_gpu_total = batch_size // dist.get_world_size()
64 | if batch_gpu is None or batch_gpu > batch_gpu_total:
65 | batch_gpu = batch_gpu_total
66 | num_accumulation_rounds = batch_gpu_total // batch_gpu
67 | assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size()
68 |
69 | # Load dataset.
70 | dist.print0('Loading dataset...')
71 |
72 | dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
73 | dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
74 | dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))
75 |
76 | # Construct network.
77 | dist.print0('Constructing network...')
78 |     # CIFAR-10 runs fix the interface to 32x32 RGB without labels; MoLRG and
79 |     # other datasets expose their geometry through the dataset object.
80 |     if 'cifar10' in dataset_kwargs.get('path', ''):
81 |         interface_kwargs = dict(img_resolution=32, img_channels=3, label_dim=0)
82 |     else:
83 |         interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
84 |
85 | net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
86 |
87 | net.train().requires_grad_(True).to(device)
88 |
89 |
90 | # Setup optimizer.
91 | dist.print0('Setting up optimizer...')
92 | loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
93 | optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
94 | augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
95 | ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
96 | ema = copy.deepcopy(net).eval().requires_grad_(False)
97 |
98 |     # Initialize from a pre-trained model, unless resuming from a snapshot below.
99 | if pretrained_model_path and resume_pkl is None:
100 | if dist.get_rank() != 0:
101 | torch.distributed.barrier() # rank 0 goes first
102 | with dnnlib.util.open_url(pretrained_model_path, verbose=(dist.get_rank() == 0)) as f:
103 | data = pickle.load(f)
104 | if dist.get_rank() == 0:
105 | torch.distributed.barrier() # other ranks follow
106 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
107 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
108 | # Resume training from previous snapshot.
109 | if resume_pkl is not None:
110 | dist.print0(f'Loading network weights from "{resume_pkl}"...')
111 | if dist.get_rank() != 0:
112 | torch.distributed.barrier() # rank 0 goes first
113 | with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
114 | data = pickle.load(f)
115 | if dist.get_rank() == 0:
116 | torch.distributed.barrier() # other ranks follow
117 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
118 | misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
119 | del data # conserve memory
120 | if resume_state_dump:
121 | dist.print0(f'Loading training state from "{resume_state_dump}"...')
122 | data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
123 | misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
124 | optimizer.load_state_dict(data['optimizer_state'])
125 | del data # conserve memory
126 | dist.print0(f'Training for {total_kimg} kimg...')
127 | dist.print0()
128 | cur_nimg = resume_kimg * 1000
129 | cur_tick = 0
130 | tick_start_nimg = cur_nimg
131 | tick_start_time = time.time()
132 | maintenance_time = tick_start_time - start_time
133 | dist.update_progress(cur_nimg // 1000, total_kimg)
134 | stats_jsonl = None
135 |
136 | while True:
137 |
138 | # Accumulate gradients.
139 | optimizer.zero_grad(set_to_none=True)
140 | for round_idx in range(num_accumulation_rounds):
141 | with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
142 | images, labels = next(dataset_iterator)
143 | images = images.to(device).to(torch.float32) / 127.5 - 1
144 | labels = labels.to(device)
145 | loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
146 | training_stats.report('Loss/loss', loss)
147 | loss.sum().mul(loss_scaling / batch_gpu_total).backward()
148 |
149 | # Update weights.
150 | for g in optimizer.param_groups:
151 | g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
152 | for param in net.parameters():
153 | if param.grad is not None:
154 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
155 | optimizer.step()
156 |
157 | # Update EMA.
158 | ema_halflife_nimg = ema_halflife_kimg * 1000
159 | if ema_rampup_ratio is not None:
160 | ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
161 | ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
162 | for p_ema, p_net in zip(ema.parameters(), net.parameters()):
163 | p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))
164 |
165 | # Perform maintenance tasks once per tick.
166 | cur_nimg += batch_size
167 | done = (cur_nimg >= total_kimg * 1000)
168 | if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
169 | continue
170 |
171 | # Print status line, accumulating the same information in training_stats.
172 | tick_end_time = time.time()
173 | fields = []
174 | fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
175 | fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
176 | fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
177 | fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
178 | fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
179 | fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
180 | fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
181 | fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
182 | fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
183 | fields += [f"loss {training_stats.report('Loss/loss', loss.mean().mul(loss_scaling).item()):<6.5f}"]
184 |
185 | torch.cuda.reset_peak_memory_stats()
186 | dist.print0(' '.join(fields))
187 |
188 | # Check for abort.
189 | if (not done) and dist.should_stop():
190 | done = True
191 | dist.print0()
192 | dist.print0('Aborting...')
193 |
194 | # Save network snapshot.
195 | if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
196 | data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
197 | for key, value in data.items():
198 | if isinstance(value, torch.nn.Module):
199 | value = copy.deepcopy(value).eval().requires_grad_(False)
200 | misc.check_ddp_consistency(value)
201 | data[key] = value.cpu()
202 | del value # conserve memory
203 | if dist.get_rank() == 0:
204 | with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
205 | pickle.dump(data, f)
206 | del data # conserve memory
207 |
208 | # Save full dump of the training state.
209 | if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
210 | torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))
211 |
212 | # Update logs.
213 | training_stats.default_collector.update()
214 | dist.update_progress(cur_nimg // 1000, total_kimg)
215 |
216 | # Update state.
217 | cur_tick += 1
218 | tick_start_nimg = cur_nimg
219 | tick_start_time = time.time()
220 | maintenance_time = tick_start_time - tick_end_time
221 | if done:
222 | break
223 |
224 | # Done.
225 | dist.print0()
226 | dist.print0('Exiting...')
227 |
228 | #----------------------------------------------------------------------------
229 |
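230 | #----------------------------------------------------------------------------
231 | # Learning-rate ramp-up sketch (illustrative values): the schedule above
232 | # scales the base LR linearly with the number of images seen until
233 | # lr_rampup_kimg, then holds it constant.
234 |
235 | if __name__ == '__main__':
236 |     base_lr, lr_rampup_kimg = 1e-3, 10000
237 |     for cur_nimg in [0, 2_500_000, 5_000_000, 10_000_000, 20_000_000]:
238 |         lr = base_lr * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
239 |         print(f'{cur_nimg / 1e6:>5.1f}M imgs -> lr = {lr:.2e}')
240 |
241 | #----------------------------------------------------------------------------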
--------------------------------------------------------------------------------
/figures/generalization-score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/generalization-score.png
--------------------------------------------------------------------------------
/figures/jacobian-MoLRG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/jacobian-MoLRG.png
--------------------------------------------------------------------------------
/figures/jacobian-real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/jacobian-real.png
--------------------------------------------------------------------------------
/figures/optimal-denoiser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/optimal-denoiser.png
--------------------------------------------------------------------------------
/figures/reproducibility-score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/reproducibility-score.png
--------------------------------------------------------------------------------
/figures/similarity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huijieZH/Diffusion-Model-Generalizability/6c7576c8d2d9868013fc037744536a60cfb239fd/figures/similarity.png
--------------------------------------------------------------------------------