├── .gitignore
├── README.md
├── assets
│   └── teaser.png
├── ct_eval.py
├── ct_train.py
├── dataset_tool.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── generate.py
├── license
│   ├── LICENSE_ECT
│   ├── LICENSE_EDM
│   ├── LICENSE_EDM2
│   ├── LICENSE_STYLEGAN2
│   └── LICENSE_TCM
├── make_grid.py
├── metrics
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── inception_score.py
│   ├── kernel_inception_distance.py
│   ├── metric_main.py
│   ├── metric_utils.py
│   ├── perceptual_path_length.py
│   └── precision_recall.py
├── torch_utils
│   ├── __init__.py
│   ├── distributed.py
│   ├── misc.py
│   ├── persistence.py
│   └── training_stats.py
└── training
    ├── __init__.py
    ├── augment.py
    ├── ct_training_loop.py
    ├── dataset.py
    ├── loss.py
    ├── networks.py
    └── networks_tcm.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | .DS_Store 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | 165 | # Project-related 166 | datasets 167 | 168 | ct-runs 169 | ct-evals 170 | 171 | .idea 172 | debug.sh 173 | *.out 174 | *.json 175 | 176 | 177 | wandb/* 178 | core.* 179 | *.zip --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # TCM: Truncated Consistency Models 2 | ![TCM teaser](assets/teaser.png) 3 | This is the official implementation of our paper [Truncated Consistency Models](https://truncated-cm.github.io/) (ICLR 2025). 4 | 5 | - Tested environment: PyTorch 2.3.1, Python 3.8.13. 6 | - Dependencies: `pip install click Pillow psutil requests scipy tqdm matplotlib wandb` 7 | - We tested on NVIDIA A100 GPUs. 8 | 9 | # Dataset 10 | 11 | We use the CIFAR-10 and ImageNet datasets for training and evaluation. 12 | Prepare CIFAR-10 following [EDM](https://github.com/NVlabs/edm) and ImageNet following [EDM2](https://github.com/NVlabs/edm2); a minimal CIFAR-10 example is sketched below.
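As a concrete sketch of the CIFAR-10 step, the bundled `dataset_tool.py` can package the official tarball directly; the download location and output path below are assumptions (the training commands in this README expect `../datasets/cifar10-32x32.zip`):

```bash
# Assumed paths -- adjust to your setup. dataset_tool.py recognizes the official
# cifar-10-python.tar.gz by its filename and converts it into the ZIP/PNG dataset format.
python dataset_tool.py --source=downloads/cifar-10-python.tar.gz \
    --dest=../datasets/cifar10-32x32.zip
```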
13 | 14 | # Pre-trained models 15 | The pre-trained TCM/EDM/EDM2 checkpoints are provided at this [link](https://drive.google.com/drive/folders/1gw6OMKCKaEe3LxSSlJKNhwG-M92u9DsW?usp=sharing). 16 | 17 | 18 | | Model | Dataset | 1-step FID | 2-step FID | 19 | |------------------------------------|----------------|------------|------------| 20 | | [TCM (DDPM++)](https://drive.google.com/file/d/1M9djFtf03acsmeAlseaOu2Z2AjQDE1Uz/view?usp=drive_link) | CIFAR-10 | 2.46 | 2.05 | 21 | | [TCM (EDM2-S)](https://drive.google.com/file/d/1RhM0f-SHb_qpiMV9Br2fjE5w2CVYiHFE/view?usp=drive_link) | ImageNet 64x64 | 2.88 | 2.31 | 22 | | [TCM (EDM2-XL)](https://drive.google.com/file/d/1XE1RzrTI_dZq-jE0aZ6A5-5naslGaA_F/view?usp=drive_link) | ImageNet 64x64 | 2.20 | 1.62 | 23 | 24 | 25 | 26 | # Training 27 | 28 | 29 | Stage-1 (CIFAR-10): 30 | ```bash 31 | iteration=3000000 32 | bsize=512 33 | double_iter=25000 34 | print_iter=100 35 | save_iter=10000 36 | vis_iter=1000 37 | 38 | # Convert iterations to training duration in Mimg (iterations * batch / 1e6) 39 | duration=$(($iteration * $bsize / 1000000)) 40 | 41 | # Convert iteration intervals to ticks (1 tick = 100 iterations) 42 | double=$(($double_iter / 100)) 43 | snap=$(($save_iter / 100)) 44 | vis=$(($vis_iter / 100)) 45 | 46 | 47 | torchrun --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --outdir=ct-runs/09.26 --desc 09.26.cifar.stage1 --arch=ddpmpp \ 48 | --duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \ 49 | --dropout=0.2 --ratio_limit 0.999 --weighting default -c 1e-8 \ 50 | --transfer=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-vp.pkl \ 51 | --metrics='fid50k_full' --lr=1e-4 --cond=0 --optim=RAdam --augment=0.0 --mean=-1.1 --std=2.0 \ 52 | --data=../datasets/cifar10-32x32.zip 53 | ``` 54 | 55 | Stage-2 (CIFAR-10): 56 | ```bash 57 | iteration=3000000 58 | bsize=1024 59 | double_iter=25000 60 | print_iter=100 61 | save_iter=10000 62 | vis_iter=1000 63 | 64 | # Convert iterations to training duration in Mimg (iterations * batch / 1e6) 65 | duration=$(($iteration * $bsize / 1000000)) 66 | 67 | # Convert iteration intervals to ticks (1 tick = 100 iterations) 68 | double=$(($double_iter / 100)) 69 | snap=$(($save_iter / 100)) 70 | vis=$(($vis_iter / 100)) 71 | 72 | # Placeholders below: point --tcm_teacher_pkl and --transfer at your Stage-1 checkpoint, and set --resume_tick to the tick it was saved at. 73 | torchrun --nnodes=1 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --outdir=ct-runs/09.26 --desc 09.26.cifar.stage2 --arch=ddpmpp \ 74 | --duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \ 75 | --dropout=0.2 --ratio_limit 0.999 --weighting default -c 1e-8 \ 76 | --tcm_transition_t=1 --t_lower=1 --tcm_teacher_pkl= \ 77 | --transfer= --resume_tick= \ 78 | --metrics='fid50k_full' --lr=5e-5 --cond=0 --optim=RAdam --augment=0.0 \ 79 | --mean=0 --std=0.2 --boundary_prob=0.25 --w_boundary=0.1 --tdist=t --df=0.01 \ 80 | --data=../datasets/cifar10-32x32.zip 81 | ``` 82 |
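As a sanity check on the unit conversions above: `--duration` counts training images in millions (Mimg), and the tick-based intervals assume one tick per 100 iterations, so the Stage-1 CIFAR-10 settings work out as follows:

```bash
echo $((3000000 * 512 / 1000000))  # duration = 1536 Mimg processed in total
echo $((25000 / 100))              # double   = 250 ticks (every 25k iterations)
echo $((10000 / 100))              # snap     = 100 ticks (every 10k iterations)
echo $((1000 / 100))               # vis      = 10 ticks  (every 1k iterations)
```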
83 | Stage-1 (ImageNet 64x64): 84 | ```bash 85 | iteration=800000 86 | bsize=2048 87 | double_iter=10000 88 | print_iter=100 89 | save_iter=10000 90 | vis_iter=1000 91 | 92 | # Convert iterations to training duration in Mimg (iterations * batch / 1e6) 93 | duration=$(($iteration * $bsize / 1000000)) 94 | 95 | # Convert iteration intervals to ticks (1 tick = 100 iterations) 96 | double=$(($double_iter / 100)) 97 | snap=$(($save_iter / 100)) 98 | vis=$(($vis_iter / 100)) 99 | 100 | torchrun --nnodes=8 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --cond=1 --outdir=ct-runs/ --desc imagenet.stage1 --fp16=1 --ls=16 --arch=edm2-img64-s --augment=0 \ 101 | --duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \ 102 | --dropout=0.4 --ratio_limit 0.9961 --weighting cout_sq -c 0.06 \ 103 | --decay_iter=2000 --optim=Adam --rampup_iter=0 --lr=1e-3 --beta2=0.99 --metrics=fid50k_full --ema_gamma=16.97 --ema_type=power -q 4 \ 104 | --data=../datasets/imgnet64x64-dhariwal.zip --transfer=edm2-imgnet64-s.pth \ 105 | --mean=-0.8 --std=1.6 106 | ``` 107 | 108 | Stage-2 (ImageNet 64x64): 109 | ```bash 110 | iteration=800000 111 | bsize=1024 112 | double_iter=10000 113 | print_iter=100 114 | save_iter=10000 115 | vis_iter=1000 116 | 117 | # Convert iterations to training duration in Mimg (iterations * batch / 1e6) 118 | duration=$(($iteration * $bsize / 1000000)) 119 | 120 | # Convert iteration intervals to ticks (1 tick = 100 iterations) 121 | double=$(($double_iter / 100)) 122 | snap=$(($save_iter / 100)) 123 | vis=$(($vis_iter / 100)) 124 | # Placeholders below: point --tcm_teacher_pkl and --transfer at your Stage-1 checkpoint, and set --resume_tick to the tick it was saved at. 125 | torchrun --nnodes=8 --nproc_per_node=8 --rdzv_backend=c10d --rdzv_endpoint=localhost:8888 ct_train.py --use_wandb False --cond=1 --outdir=ct-runs/ --desc imagenet.stage2 --fp16=1 --ls=16 --arch=edm2-img64-s --augment=0 \ 126 | --duration=$duration --double=$double --batch=$bsize --snap $snap --dump $snap --ckpt $snap --eval_every $snap --sample_every $vis \ 127 | --dropout=0.4 --ratio_limit 0.9961 --weighting cout_sq -c 0.06 \ 128 | --tcm_transition_t=1 --t_lower=1 --tcm_teacher_pkl= \ 129 | --transfer= --resume_tick= \ 130 | --decay_iter=8000 --optim=Adam --rampup_iter=0 --lr=5e-4 --beta2=0.99 --metrics=fid50k_full --ema_gamma=16.97 --ema_type=power -q 4 \ 131 | --data=../datasets/imgnet64x64-dhariwal.zip --mean=0 --std=0.2 --boundary_prob=0.25 --w_boundary=0.1 --tdist=t --df=0.01 132 | ``` 133 | 134 | # Sampling 135 | 136 | ```bash 137 | # Evaluate a CIFAR-10 model with 4 GPUs 138 | torchrun --standalone --nproc_per_node=4 ct_eval.py \ 139 | --outdir=ct-evals --data=../datasets/cifar10-32x32.zip \ 140 | --cond=0 --arch=ddpmpp --metrics=fid50k_full \ 141 | --resume cifar-stage2-2.46.pkl 142 | 143 | # Evaluate an ImageNet64 model (EDM2-S) with 4 GPUs 144 | torchrun --standalone --nproc_per_node=4 ct_eval.py \ 145 | --outdir=ct-evals --data=../datasets/imgnet64x64-dhariwal.zip \ 146 | --cond=1 --arch=edm2-img64-s --metrics=fid50k_full --fp16=1 \ 147 | --resume imgnet-stage2-edm2s-2.88.pkl 148 | ``` 149 |
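The EDM2-XL checkpoint from the table above should evaluate analogously; the `--arch` value and checkpoint filename below are assumptions that mirror the EDM2-S naming, so check `ct_eval.py --help` before running:

```bash
# Hypothetical EDM2-XL evaluation -- the --arch flag and .pkl name are assumed, not confirmed.
torchrun --standalone --nproc_per_node=4 ct_eval.py \
    --outdir=ct-evals --data=../datasets/imgnet64x64-dhariwal.zip \
    --cond=1 --arch=edm2-img64-xl --metrics=fid50k_full --fp16=1 \
    --resume imgnet-stage2-edm2xl-2.20.pkl
```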
150 | # Acknowledgement 151 | This code builds on the implementations of [EDM](https://github.com/NVlabs/edm) and [EDM2](https://github.com/NVlabs/edm2). We thank the authors for their excellent implementations. 152 | 153 | # License 154 | Copyright © 2024, NVIDIA Corporation. All rights reserved. 155 | 156 | This work is made available under the NVIDIA Source Code License-NC. 157 | 158 | The model checkpoints are shared under [`Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)`](https://creativecommons.org/licenses/by-nc-sa/4.0/). If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original. 159 | 160 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). 161 | 162 | 163 | # Citation 164 | ``` 165 | @misc{lee2024truncatedconsistencymodels, 166 | title={Truncated Consistency Models}, 167 | author={Sangyun Lee and Yilun Xu and Tomas Geffner and Giulia Fanti and Karsten Kreis and Arash Vahdat and Weili Nie}, 168 | year={2024}, 169 | eprint={2410.14895}, 170 | archivePrefix={arXiv}, 171 | primaryClass={cs.LG}, 172 | url={https://arxiv.org/abs/2410.14895}, 173 | } 174 | ``` 175 | --------------------------------------------------------------------------------
/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/TCM/b132239ef96075ea5d24c32bb8af81ca76b1510f/assets/teaser.png --------------------------------------------------------------------------------
/dataset_tool.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/dataset_tool.py (EDM) 8 | # 9 | # The license for these can be found in the license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Tool for creating ZIP/PNG based datasets.""" 13 | 14 | import functools 15 | import gzip 16 | import io 17 | import json 18 | import os 19 | import pickle 20 | import re 21 | import sys 22 | import tarfile 23 | import zipfile 24 | from pathlib import Path 25 | from typing import Callable, Optional, Tuple, Union 26 | import click 27 | import numpy as np 28 | import PIL.Image 29 | from tqdm import tqdm 30 | 31 | #---------------------------------------------------------------------------- 32 | # Parse a 'M,N' or 'MxN' integer tuple. 33 | # Example: '4x2' returns (4,2) 34 | 35 | def parse_tuple(s: str) -> Tuple[int, int]: 36 | m = re.match(r'^(\d+)[x,](\d+)$', s) 37 | if m: 38 | return int(m.group(1)), int(m.group(2)) 39 | raise click.ClickException(f'cannot parse tuple {s}') 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def maybe_min(a: int, b: Optional[int]) -> int: 44 | if b is not None: 45 | return min(a, b) 46 | return a 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def file_ext(name: Union[str, Path]) -> str: 51 | return str(name).split('.')[-1] 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def is_image_ext(fname: Union[str, Path]) -> bool: 56 | ext = file_ext(fname).lower() 57 | return f'.{ext}' in PIL.Image.EXTENSION 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | def open_image_folder(source_dir, *, max_images: Optional[int]): 62 | input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)] 63 | arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images} 64 | max_idx = maybe_min(len(input_images), max_images) 65 | 66 | # Load labels. 67 | labels = dict() 68 | meta_fname = os.path.join(source_dir, 'dataset.json') 69 | if os.path.isfile(meta_fname): 70 | with open(meta_fname, 'r') as file: 71 | data = json.load(file)['labels'] 72 | if data is not None: 73 | labels = {x[0]: x[1] for x in data} 74 | 75 | # No labels available => determine from top-level directory names.
76 | if len(labels) == 0: 77 | toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()} 78 | toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))} 79 | if len(toplevel_indices) > 1: 80 | labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()} 81 | 82 | def iterate_images(): 83 | for idx, fname in enumerate(input_images): 84 | img = np.array(PIL.Image.open(fname)) 85 | yield dict(img=img, label=labels.get(arch_fnames.get(fname))) 86 | if idx >= max_idx - 1: 87 | break 88 | return max_idx, iterate_images() 89 | 90 | #---------------------------------------------------------------------------- 91 | 92 | def open_image_zip(source, *, max_images: Optional[int]): 93 | with zipfile.ZipFile(source, mode='r') as z: 94 | input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)] 95 | max_idx = maybe_min(len(input_images), max_images) 96 | 97 | # Load labels. 98 | labels = dict() 99 | if 'dataset.json' in z.namelist(): 100 | with z.open('dataset.json', 'r') as file: 101 | data = json.load(file)['labels'] 102 | if data is not None: 103 | labels = {x[0]: x[1] for x in data} 104 | 105 | def iterate_images(): 106 | with zipfile.ZipFile(source, mode='r') as z: 107 | for idx, fname in enumerate(input_images): 108 | with z.open(fname, 'r') as file: 109 | img = np.array(PIL.Image.open(file)) 110 | yield dict(img=img, label=labels.get(fname)) 111 | if idx >= max_idx - 1: 112 | break 113 | return max_idx, iterate_images() 114 | 115 | #---------------------------------------------------------------------------- 116 | 117 | def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]): 118 | import cv2 # pyright: ignore [reportMissingImports] # pip install opencv-python 119 | import lmdb # pyright: ignore [reportMissingImports] # pip install lmdb 120 | 121 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 122 | max_idx = maybe_min(txn.stat()['entries'], max_images) 123 | 124 | def iterate_images(): 125 | with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn: 126 | for idx, (_key, value) in enumerate(txn.cursor()): 127 | try: 128 | try: 129 | img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) 130 | if img is None: 131 | raise IOError('cv2.imdecode failed') 132 | img = img[:, :, ::-1] # BGR => RGB 133 | except IOError: 134 | img = np.array(PIL.Image.open(io.BytesIO(value))) 135 | yield dict(img=img, label=None) 136 | if idx >= max_idx - 1: 137 | break 138 | except: 139 | print(sys.exc_info()[1]) 140 | 141 | return max_idx, iterate_images() 142 | 143 | #---------------------------------------------------------------------------- 144 | 145 | def open_cifar10(tarball: str, *, max_images: Optional[int]): 146 | images = [] 147 | labels = [] 148 | 149 | with tarfile.open(tarball, 'r:gz') as tar: 150 | for batch in range(1, 6): 151 | member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}') 152 | with tar.extractfile(member) as file: 153 | data = pickle.load(file, encoding='latin1') 154 | images.append(data['data'].reshape(-1, 3, 32, 32)) 155 | labels.append(data['labels']) 156 | 157 | images = np.concatenate(images) 158 | labels = np.concatenate(labels) 159 | images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC 160 | assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8 161 | assert labels.shape == (50000,) and labels.dtype in [np.int32, 
np.int64] 162 | assert np.min(images) == 0 and np.max(images) == 255 163 | assert np.min(labels) == 0 and np.max(labels) == 9 164 | 165 | max_idx = maybe_min(len(images), max_images) 166 | 167 | def iterate_images(): 168 | for idx, img in enumerate(images): 169 | yield dict(img=img, label=int(labels[idx])) 170 | if idx >= max_idx - 1: 171 | break 172 | 173 | return max_idx, iterate_images() 174 | 175 | #---------------------------------------------------------------------------- 176 | 177 | def open_mnist(images_gz: str, *, max_images: Optional[int]): 178 | labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz') 179 | assert labels_gz != images_gz 180 | images = [] 181 | labels = [] 182 | 183 | with gzip.open(images_gz, 'rb') as f: 184 | images = np.frombuffer(f.read(), np.uint8, offset=16) 185 | with gzip.open(labels_gz, 'rb') as f: 186 | labels = np.frombuffer(f.read(), np.uint8, offset=8) 187 | 188 | images = images.reshape(-1, 28, 28) 189 | images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0) 190 | assert images.shape == (60000, 32, 32) and images.dtype == np.uint8 191 | assert labels.shape == (60000,) and labels.dtype == np.uint8 192 | assert np.min(images) == 0 and np.max(images) == 255 193 | assert np.min(labels) == 0 and np.max(labels) == 9 194 | 195 | max_idx = maybe_min(len(images), max_images) 196 | 197 | def iterate_images(): 198 | for idx, img in enumerate(images): 199 | yield dict(img=img, label=int(labels[idx])) 200 | if idx >= max_idx - 1: 201 | break 202 | 203 | return max_idx, iterate_images() 204 | 205 | #---------------------------------------------------------------------------- 206 | 207 | def make_transform( 208 | transform: Optional[str], 209 | output_width: Optional[int], 210 | output_height: Optional[int] 211 | ) -> Callable[[np.ndarray], Optional[np.ndarray]]: 212 | def scale(width, height, img): 213 | w = img.shape[1] 214 | h = img.shape[0] 215 | if width == w and height == h: 216 | return img 217 | img = PIL.Image.fromarray(img) 218 | ww = width if width is not None else w 219 | hh = height if height is not None else h 220 | img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS) 221 | return np.array(img) 222 | 223 | def center_crop(width, height, img): 224 | crop = np.min(img.shape[:2]) 225 | img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2] 226 | if img.ndim == 2: 227 | img = img[:, :, np.newaxis].repeat(3, axis=2) 228 | img = PIL.Image.fromarray(img, 'RGB') 229 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 230 | return np.array(img) 231 | 232 | def center_crop_wide(width, height, img): 233 | ch = int(np.round(width * img.shape[0] / img.shape[1])) 234 | if img.shape[1] < width or ch < height: 235 | return None 236 | 237 | img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2] 238 | if img.ndim == 2: 239 | img = img[:, :, np.newaxis].repeat(3, axis=2) 240 | img = PIL.Image.fromarray(img, 'RGB') 241 | img = img.resize((width, height), PIL.Image.Resampling.LANCZOS) 242 | img = np.array(img) 243 | 244 | canvas = np.zeros([width, width, 3], dtype=np.uint8) 245 | canvas[(width - height) // 2 : (width + height) // 2, :] = img 246 | return canvas 247 | 248 | if transform is None: 249 | return functools.partial(scale, output_width, output_height) 250 | if transform == 'center-crop': 251 | if output_width is None or output_height is None: 252 | raise click.ClickException('must specify --resolution=WxH when using ' + 
transform + ' transform') 253 | return functools.partial(center_crop, output_width, output_height) 254 | if transform == 'center-crop-wide': 255 | if output_width is None or output_height is None: 256 | raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform') 257 | return functools.partial(center_crop_wide, output_width, output_height) 258 | assert False, 'unknown transform' 259 | 260 | #---------------------------------------------------------------------------- 261 | 262 | def open_dataset(source, *, max_images: Optional[int]): 263 | if os.path.isdir(source): 264 | if source.rstrip('/').endswith('_lmdb'): 265 | return open_lmdb(source, max_images=max_images) 266 | else: 267 | return open_image_folder(source, max_images=max_images) 268 | elif os.path.isfile(source): 269 | if os.path.basename(source) == 'cifar-10-python.tar.gz': 270 | return open_cifar10(source, max_images=max_images) 271 | elif os.path.basename(source) == 'train-images-idx3-ubyte.gz': 272 | return open_mnist(source, max_images=max_images) 273 | elif file_ext(source) == 'zip': 274 | return open_image_zip(source, max_images=max_images) 275 | else: 276 | assert False, 'unknown archive type' 277 | else: 278 | raise click.ClickException(f'Missing input file or directory: {source}') 279 | 280 | #---------------------------------------------------------------------------- 281 | 282 | def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]: 283 | dest_ext = file_ext(dest) 284 | 285 | if dest_ext == 'zip': 286 | if os.path.dirname(dest) != '': 287 | os.makedirs(os.path.dirname(dest), exist_ok=True) 288 | zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED) 289 | def zip_write_bytes(fname: str, data: Union[bytes, str]): 290 | zf.writestr(fname, data) 291 | return '', zip_write_bytes, zf.close 292 | else: 293 | # If the output folder already exists, check that it is 294 | # empty. 295 | # 296 | # Note: creating the output directory is not strictly 297 | # necessary as folder_write_bytes() also mkdirs, but it's better 298 | # to give an error message earlier in case the dest folder 299 | # somehow cannot be created.
300 | if os.path.isdir(dest) and len(os.listdir(dest)) != 0: 301 | raise click.ClickException('--dest folder must be empty') 302 | os.makedirs(dest, exist_ok=True) 303 | 304 | def folder_write_bytes(fname: str, data: Union[bytes, str]): 305 | os.makedirs(os.path.dirname(fname), exist_ok=True) 306 | with open(fname, 'wb') as fout: 307 | if isinstance(data, str): 308 | data = data.encode('utf8') 309 | fout.write(data) 310 | return dest, folder_write_bytes, lambda: None 311 | 312 | #---------------------------------------------------------------------------- 313 | 314 | @click.command() 315 | @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True) 316 | @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True) 317 | @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int) 318 | @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide'])) 319 | @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple) 320 | 321 | def main( 322 | source: str, 323 | dest: str, 324 | max_images: Optional[int], 325 | transform: Optional[str], 326 | resolution: Optional[Tuple[int, int]] 327 | ): 328 | """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch. 329 | 330 | The input dataset format is guessed from the --source argument: 331 | 332 | \b 333 | --source *_lmdb/ Load LSUN dataset 334 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 335 | --source train-images-idx3-ubyte.gz Load MNIST dataset 336 | --source path/ Recursively load all images from path/ 337 | --source dataset.zip Recursively load all images from dataset.zip 338 | 339 | Specifying the output format and path: 340 | 341 | \b 342 | --dest /path/to/dir Save output files under /path/to/dir 343 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 344 | 345 | The output dataset format can be either an image folder or an uncompressed zip archive. 346 | Zip archives make it easier to move datasets around file servers and clusters, and may 347 | offer better training performance on network file systems. 348 | 349 | Images within the dataset archive will be stored as uncompressed PNG. 350 | Uncompressed PNGs can be efficiently decoded in the training loop. 351 | 352 | Class labels are stored in a file called 'dataset.json' that is stored at the 353 | dataset root folder. This file has the following structure: 354 | 355 | \b 356 | { 357 | "labels": [ 358 | ["00000/img00000000.png",6], 359 | ["00000/img00000001.png",9], 360 | ... repeated for every image in the dataset 361 | ["00049/img00049999.png",1] 362 | ] 363 | } 364 | 365 | If the 'dataset.json' file cannot be found, class labels are determined from 366 | top-level directory names. 367 | 368 | Image scale/crop and resolution requirements: 369 | 370 | Output images must be square-shaped and they must all have the same power-of-two 371 | dimensions. 372 | 373 | To scale arbitrary input image size to a specific width and height, use the 374 | --resolution option. Output resolution will be either the original 375 | input resolution (if resolution was not specified) or the one specified with 376 | --resolution option. 377 | 378 | Use the --transform=center-crop or --transform=center-crop-wide options to apply a 379 | center crop transform on the input image.
These options should be used with the 380 | --resolution option. For example: 381 | 382 | \b 383 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\ 384 | --transform=center-crop-wide --resolution=512x384 385 | """ 386 | 387 | PIL.Image.init() 388 | 389 | if dest == '': 390 | raise click.ClickException('--dest output filename or directory must not be an empty string') 391 | 392 | num_files, input_iter = open_dataset(source, max_images=max_images) 393 | archive_root_dir, save_bytes, close_dest = open_dest(dest) 394 | 395 | if resolution is None: resolution = (None, None) 396 | transform_image = make_transform(transform, *resolution) 397 | 398 | dataset_attrs = None 399 | 400 | labels = [] 401 | for idx, image in tqdm(enumerate(input_iter), total=num_files): 402 | idx_str = f'{idx:08d}' 403 | archive_fname = f'{idx_str[:5]}/img{idx_str}.png' 404 | 405 | # Apply crop and resize. 406 | img = transform_image(image['img']) 407 | if img is None: 408 | continue 409 | 410 | # Error check to require uniform image attributes across 411 | # the whole dataset. 412 | channels = img.shape[2] if img.ndim == 3 else 1 413 | cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0], 'channels': channels} 414 | if dataset_attrs is None: 415 | dataset_attrs = cur_image_attrs 416 | width = dataset_attrs['width'] 417 | height = dataset_attrs['height'] 418 | if width != height: 419 | raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}') 420 | if dataset_attrs['channels'] not in [1, 3]: 421 | raise click.ClickException('Input images must be stored as RGB or grayscale') 422 | if width != 2 ** int(np.floor(np.log2(width))): 423 | raise click.ClickException('Image width/height after scale and crop are required to be power-of-two') 424 | elif dataset_attrs != cur_image_attrs: 425 | err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] 426 | raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err)) 427 | 428 | # Save the image as an uncompressed PNG. 429 | img = PIL.Image.fromarray(img, {1: 'L', 3: 'RGB'}[channels]) 430 | image_bits = io.BytesIO() 431 | img.save(image_bits, format='png', compress_level=0, optimize=False) 432 | save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer()) 433 | labels.append([archive_fname, image['label']] if image['label'] is not None else None) 434 | 435 | metadata = {'labels': labels if all(x is not None for x in labels) else None} 436 | save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata)) 437 | close_dest() 438 | 439 | #---------------------------------------------------------------------------- 440 | 441 | if __name__ == "__main__": 442 | main() 443 | 444 | #---------------------------------------------------------------------------- --------------------------------------------------------------------------------
/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/dnnlib (EDM) 8 | # 9 | # The license for these can be found in the license/ directory.
10 | # --------------------------------------------------------------- 11 | 12 | from .util import EasyDict, make_cache_dir_path 13 | --------------------------------------------------------------------------------
/dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/dnnlib (EDM) 8 | # 9 | # The license for these can be found in the license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Miscellaneous utility classes and functions.""" 13 | 14 | import ctypes 15 | import fnmatch 16 | import importlib 17 | import inspect 18 | import numpy as np 19 | import os 20 | import shutil 21 | import sys 22 | import types 23 | import io 24 | import pickle 25 | import re 26 | import requests 27 | import html 28 | import hashlib 29 | import glob 30 | import tempfile 31 | import urllib 32 | import urllib.request 33 | import uuid 34 | 35 | from distutils.util import strtobool 36 | from typing import Any, List, Tuple, Union, Optional 37 | 38 | 39 | # Util classes 40 | # ------------------------------------------------------------------------------------------ 41 | 42 | 43 | class EasyDict(dict): 44 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 45 | 46 | def __getattr__(self, name: str) -> Any: 47 | try: 48 | return self[name] 49 | except KeyError: 50 | raise AttributeError(name) 51 | 52 | def __setattr__(self, name: str, value: Any) -> None: 53 | self[name] = value 54 | 55 | def __delattr__(self, name: str) -> None: 56 | del self[name] 57 | 58 | 59 | class Logger(object): 60 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 61 | 62 | def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): 63 | self.file = None 64 | 65 | if file_name is not None: 66 | self.file = open(file_name, file_mode) 67 | 68 | self.should_flush = should_flush 69 | self.stdout = sys.stdout 70 | self.stderr = sys.stderr 71 | 72 | sys.stdout = self 73 | sys.stderr = self 74 | 75 | def __enter__(self) -> "Logger": 76 | return self 77 | 78 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 79 | self.close() 80 | 81 | def write(self, text: Union[str, bytes]) -> None: 82 | """Write text to stdout (and a file) and optionally flush.""" 83 | if isinstance(text, bytes): 84 | text = text.decode() 85 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 86 | return 87 | 88 | if self.file is not None: 89 | self.file.write(text) 90 | 91 | self.stdout.write(text) 92 | 93 | if self.should_flush: 94 | self.flush() 95 | 96 | def flush(self) -> None: 97 | """Flush written text to both stdout and a file, if open.""" 98 | if self.file is not None: 99 | self.file.flush() 100 | 101 | self.stdout.flush() 102 | 103 | def close(self) -> None: 104 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 105 | self.flush() 106 | 107 | # if using multiple loggers, prevent closing in wrong order 108 | if sys.stdout is self: 109 | sys.stdout = self.stdout 110 | if sys.stderr is self: 111 | sys.stderr = self.stderr 112 | 113 | if self.file is not None: 114 |
self.file.close() 115 | self.file = None 116 | 117 | 118 | # Cache directories 119 | # ------------------------------------------------------------------------------------------ 120 | 121 | _dnnlib_cache_dir = None 122 | 123 | def set_cache_dir(path: str) -> None: 124 | global _dnnlib_cache_dir 125 | _dnnlib_cache_dir = path 126 | 127 | def make_cache_dir_path(*paths: str) -> str: 128 | if _dnnlib_cache_dir is not None: 129 | return os.path.join(_dnnlib_cache_dir, *paths) 130 | if 'DNNLIB_CACHE_DIR' in os.environ: 131 | return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) 132 | if 'HOME' in os.environ: 133 | return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) 134 | if 'USERPROFILE' in os.environ: 135 | return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) 136 | return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) 137 | 138 | # Small util functions 139 | # ------------------------------------------------------------------------------------------ 140 | 141 | 142 | def format_time(seconds: Union[int, float]) -> str: 143 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 144 | s = int(np.rint(seconds)) 145 | 146 | if s < 60: 147 | return "{0}s".format(s) 148 | elif s < 60 * 60: 149 | return "{0}m {1:02}s".format(s // 60, s % 60) 150 | elif s < 24 * 60 * 60: 151 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 152 | else: 153 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) 154 | 155 | 156 | def format_time_brief(seconds: Union[int, float]) -> str: 157 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 158 | s = int(np.rint(seconds)) 159 | 160 | if s < 60: 161 | return "{0}s".format(s) 162 | elif s < 60 * 60: 163 | return "{0}m {1:02}s".format(s // 60, s % 60) 164 | elif s < 24 * 60 * 60: 165 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 166 | else: 167 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 168 | 169 | 170 | def ask_yes_no(question: str) -> bool: 171 | """Ask the user the question until the user inputs a valid answer.""" 172 | while True: 173 | try: 174 | print("{0} [y/n]".format(question)) 175 | return strtobool(input().lower()) 176 | except ValueError: 177 | pass 178 | 179 | 180 | def tuple_product(t: Tuple) -> Any: 181 | """Calculate the product of the tuple elements.""" 182 | result = 1 183 | 184 | for v in t: 185 | result *= v 186 | 187 | return result 188 | 189 | 190 | _str_to_ctype = { 191 | "uint8": ctypes.c_ubyte, 192 | "uint16": ctypes.c_uint16, 193 | "uint32": ctypes.c_uint32, 194 | "uint64": ctypes.c_uint64, 195 | "int8": ctypes.c_byte, 196 | "int16": ctypes.c_int16, 197 | "int32": ctypes.c_int32, 198 | "int64": ctypes.c_int64, 199 | "float32": ctypes.c_float, 200 | "float64": ctypes.c_double 201 | } 202 | 203 | 204 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 205 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 206 | type_str = None 207 | 208 | if isinstance(type_obj, str): 209 | type_str = type_obj 210 | elif hasattr(type_obj, "__name__"): 211 | type_str = type_obj.__name__ 212 | elif hasattr(type_obj, "name"): 213 | type_str = type_obj.name 214 | else: 215 | raise RuntimeError("Cannot infer type name from input") 216 | 217 | assert type_str in _str_to_ctype.keys() 218 | 219 | my_dtype = 
np.dtype(type_str) 220 | my_ctype = _str_to_ctype[type_str] 221 | 222 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 223 | 224 | return my_dtype, my_ctype 225 | 226 | 227 | def is_pickleable(obj: Any) -> bool: 228 | try: 229 | with io.BytesIO() as stream: 230 | pickle.dump(obj, stream) 231 | return True 232 | except: 233 | return False 234 | 235 | 236 | # Functionality to import modules/objects by name, and call functions by name 237 | # ------------------------------------------------------------------------------------------ 238 | 239 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 240 | """Searches for the underlying module behind the name to some python object. 241 | Returns the module and the object name (original name with module part removed).""" 242 | 243 | # allow convenience shorthands, substitute them by full names 244 | obj_name = re.sub("^np.", "numpy.", obj_name) 245 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 246 | 247 | # list alternatives for (module_name, local_obj_name) 248 | parts = obj_name.split(".") 249 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] 250 | 251 | # try each alternative in turn 252 | for module_name, local_obj_name in name_pairs: 253 | try: 254 | module = importlib.import_module(module_name) # may raise ImportError 255 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 256 | return module, local_obj_name 257 | except: 258 | pass 259 | 260 | # maybe some of the modules themselves contain errors? 261 | for module_name, _local_obj_name in name_pairs: 262 | try: 263 | importlib.import_module(module_name) # may raise ImportError 264 | except ImportError: 265 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): 266 | raise 267 | 268 | # maybe the requested attribute is missing? 
269 | for module_name, local_obj_name in name_pairs: 270 | try: 271 | module = importlib.import_module(module_name) # may raise ImportError 272 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 273 | except ImportError: 274 | pass 275 | 276 | # we are out of luck, but we have no idea why 277 | raise ImportError(obj_name) 278 | 279 | 280 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 281 | """Traverses the object name and returns the last (rightmost) python object.""" 282 | if obj_name == '': 283 | return module 284 | obj = module 285 | for part in obj_name.split("."): 286 | obj = getattr(obj, part) 287 | return obj 288 | 289 | 290 | def get_obj_by_name(name: str) -> Any: 291 | """Finds the python object with the given name.""" 292 | module, obj_name = get_module_from_obj_name(name) 293 | return get_obj_from_module(module, obj_name) 294 | 295 | 296 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 297 | """Finds the python object with the given name and calls it as a function.""" 298 | assert func_name is not None 299 | func_obj = get_obj_by_name(func_name) 300 | assert callable(func_obj) 301 | return func_obj(*args, **kwargs) 302 | 303 | 304 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 305 | """Finds the python class with the given name and constructs it with the given arguments.""" 306 | return call_func_by_name(*args, func_name=class_name, **kwargs) 307 | 308 | 309 | def get_module_dir_by_obj_name(obj_name: str) -> str: 310 | """Get the directory path of the module containing the given object name.""" 311 | module, _ = get_module_from_obj_name(obj_name) 312 | return os.path.dirname(inspect.getfile(module)) 313 | 314 | 315 | def is_top_level_function(obj: Any) -> bool: 316 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 317 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 318 | 319 | 320 | def get_top_level_function_name(obj: Any) -> str: 321 | """Return the fully-qualified name of a top-level function.""" 322 | assert is_top_level_function(obj) 323 | module = obj.__module__ 324 | if module == '__main__': 325 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 326 | return module + "." + obj.__name__ 327 | 328 | 329 | # File system helpers 330 | # ------------------------------------------------------------------------------------------ 331 | 332 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: 333 | """List all files recursively in a given directory while ignoring given file and directory names. 
334 | Returns list of tuples containing both absolute and relative paths.""" 335 | assert os.path.isdir(dir_path) 336 | base_name = os.path.basename(os.path.normpath(dir_path)) 337 | 338 | if ignores is None: 339 | ignores = [] 340 | 341 | result = [] 342 | 343 | for root, dirs, files in os.walk(dir_path, topdown=True): 344 | for ignore_ in ignores: 345 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 346 | 347 | # dirs need to be edited in-place 348 | for d in dirs_to_remove: 349 | dirs.remove(d) 350 | 351 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 352 | 353 | absolute_paths = [os.path.join(root, f) for f in files] 354 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 355 | 356 | if add_base_to_relative: 357 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 358 | 359 | assert len(absolute_paths) == len(relative_paths) 360 | result += zip(absolute_paths, relative_paths) 361 | 362 | return result 363 | 364 | 365 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 366 | """Takes in a list of tuples of (src, dst) paths and copies files. 367 | Will create all necessary directories.""" 368 | for file in files: 369 | target_dir_name = os.path.dirname(file[1]) 370 | 371 | # will create all intermediate-level directories 372 | if not os.path.exists(target_dir_name): 373 | os.makedirs(target_dir_name) 374 | 375 | shutil.copyfile(file[0], file[1]) 376 | 377 | 378 | # URL helpers 379 | # ------------------------------------------------------------------------------------------ 380 | 381 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 382 | """Determine whether the given object is a valid URL string.""" 383 | if not isinstance(obj, str) or not "://" in obj: 384 | return False 385 | if allow_file_urls and obj.startswith('file://'): 386 | return True 387 | try: 388 | res = requests.compat.urlparse(obj) 389 | if not res.scheme or not res.netloc or not "." in res.netloc: 390 | return False 391 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 392 | if not res.scheme or not res.netloc or not "." in res.netloc: 393 | return False 394 | except: 395 | return False 396 | return True 397 | 398 | 399 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: 400 | """Download the given URL and return a binary-mode file object to access the data.""" 401 | assert num_attempts >= 1 402 | assert not (return_filename and (not cache)) 403 | 404 | # Doesn't look like a URL scheme so interpret it as a local filename. 405 | if not re.match('^[a-z]+://', url): 406 | return url if return_filename else open(url, "rb") 407 | 408 | # Handle file URLs. This code handles unusual file:// patterns that 409 | # arise on Windows: 410 | # 411 | # file:///c:/foo.txt 412 | # 413 | # which would translate to a local '/c:/foo.txt' filename that's 414 | # invalid. Drop the forward slash for such pathnames. 415 | # 416 | # If you touch this code path, you should test it on both Linux and 417 | # Windows. 418 | # 419 | # Some internet resources suggest using urllib.request.url2pathname(), 420 | # but that converts forward slashes to backslashes and this causes 421 | # its own set of problems.
422 | if url.startswith('file://'): 423 | filename = urllib.parse.urlparse(url).path 424 | if re.match(r'^/[a-zA-Z]:', filename): 425 | filename = filename[1:] 426 | return filename if return_filename else open(filename, "rb") 427 | 428 | assert is_url(url) 429 | 430 | # Lookup from cache. 431 | if cache_dir is None: 432 | cache_dir = make_cache_dir_path('downloads') 433 | 434 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 435 | if cache: 436 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 437 | if len(cache_files) == 1: 438 | filename = cache_files[0] 439 | return filename if return_filename else open(filename, "rb") 440 | 441 | # Download. 442 | url_name = None 443 | url_data = None 444 | with requests.Session() as session: 445 | if verbose: 446 | print("Downloading %s ..." % url, end="", flush=True) 447 | for attempts_left in reversed(range(num_attempts)): 448 | try: 449 | with session.get(url) as res: 450 | res.raise_for_status() 451 | if len(res.content) == 0: 452 | raise IOError("No data received") 453 | 454 | if len(res.content) < 8192: 455 | content_str = res.content.decode("utf-8") 456 | if "download_warning" in res.headers.get("Set-Cookie", ""): 457 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 458 | if len(links) == 1: 459 | url = requests.compat.urljoin(url, links[0]) 460 | raise IOError("Google Drive virus checker nag") 461 | if "Google Drive - Quota exceeded" in content_str: 462 | raise IOError("Google Drive download quota exceeded -- please try again later") 463 | 464 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 465 | url_name = match[1] if match else url 466 | url_data = res.content 467 | if verbose: 468 | print(" done") 469 | break 470 | except KeyboardInterrupt: 471 | raise 472 | except: 473 | if not attempts_left: 474 | if verbose: 475 | print(" failed") 476 | raise 477 | if verbose: 478 | print(".", end="", flush=True) 479 | 480 | # Save to cache. 481 | if cache: 482 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 483 | safe_name = safe_name[:min(len(safe_name), 128)] 484 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 485 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) 486 | os.makedirs(cache_dir, exist_ok=True) 487 | with open(temp_file, "wb") as f: 488 | f.write(url_data) 489 | os.replace(temp_file, cache_file) # atomic 490 | if return_filename: 491 | return cache_file 492 | 493 | # Return data as file object. 494 | assert not return_filename 495 | return io.BytesIO(url_data) 496 | -------------------------------------------------------------------------------- /license/LICENSE_ECT: -------------------------------------------------------------------------------- 1 | No license information. -------------------------------------------------------------------------------- /license/LICENSE_EDM: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. 
Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 60 | Public License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 65 | ("Public License"). To the extent this Public License may be 66 | interpreted as a contract, You are granted the Licensed Rights in 67 | consideration of Your acceptance of these terms and conditions, and the 68 | Licensor grants You such rights in consideration of benefits the 69 | Licensor receives from making the Licensed Material available under 70 | these terms and conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. 
Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. BY-NC-SA Compatible License means a license listed at 90 | creativecommons.org/compatiblelicenses, approved by Creative 91 | Commons as essentially the equivalent of this Public License. 92 | 93 | d. Copyright and Similar Rights means copyright and/or similar rights 94 | closely related to copyright including, without limitation, 95 | performance, broadcast, sound recording, and Sui Generis Database 96 | Rights, without regard to how the rights are labeled or 97 | categorized. For purposes of this Public License, the rights 98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 99 | Rights. 100 | 101 | e. Effective Technological Measures means those measures that, in the 102 | absence of proper authority, may not be circumvented under laws 103 | fulfilling obligations under Article 11 of the WIPO Copyright 104 | Treaty adopted on December 20, 1996, and/or similar international 105 | agreements. 106 | 107 | f. Exceptions and Limitations means fair use, fair dealing, and/or 108 | any other exception or limitation to Copyright and Similar Rights 109 | that applies to Your use of the Licensed Material. 110 | 111 | g. License Elements means the license attributes listed in the name 112 | of a Creative Commons Public License. The License Elements of this 113 | Public License are Attribution, NonCommercial, and ShareAlike. 114 | 115 | h. Licensed Material means the artistic or literary work, database, 116 | or other material to which the Licensor applied this Public 117 | License. 118 | 119 | i. Licensed Rights means the rights granted to You subject to the 120 | terms and conditions of this Public License, which are limited to 121 | all Copyright and Similar Rights that apply to Your use of the 122 | Licensed Material and that the Licensor has authority to license. 123 | 124 | j. Licensor means the individual(s) or entity(ies) granting rights 125 | under this Public License. 126 | 127 | k. NonCommercial means not primarily intended for or directed towards 128 | commercial advantage or monetary compensation. For purposes of 129 | this Public License, the exchange of the Licensed Material for 130 | other material subject to Copyright and Similar Rights by digital 131 | file-sharing or similar means is NonCommercial provided there is 132 | no payment of monetary compensation in connection with the 133 | exchange. 134 | 135 | l. 
Share means to provide material to the public by any means or 136 | process that requires permission under the Licensed Rights, such 137 | as reproduction, public display, public performance, distribution, 138 | dissemination, communication, or importation, and to make material 139 | available to the public including in ways that members of the 140 | public may access the material from a place and at a time 141 | individually chosen by them. 142 | 143 | m. Sui Generis Database Rights means rights other than copyright 144 | resulting from Directive 96/9/EC of the European Parliament and of 145 | the Council of 11 March 1996 on the legal protection of databases, 146 | as amended and/or succeeded, as well as other essentially 147 | equivalent rights anywhere in the world. 148 | 149 | n. You means the individual or entity exercising the Licensed Rights 150 | under this Public License. Your has a corresponding meaning. 151 | 152 | 153 | Section 2 -- Scope. 154 | 155 | a. License grant. 156 | 157 | 1. Subject to the terms and conditions of this Public License, 158 | the Licensor hereby grants You a worldwide, royalty-free, 159 | non-sublicensable, non-exclusive, irrevocable license to 160 | exercise the Licensed Rights in the Licensed Material to: 161 | 162 | a. reproduce and Share the Licensed Material, in whole or 163 | in part, for NonCommercial purposes only; and 164 | 165 | b. produce, reproduce, and Share Adapted Material for 166 | NonCommercial purposes only. 167 | 168 | 2. Exceptions and Limitations. For the avoidance of doubt, where 169 | Exceptions and Limitations apply to Your use, this Public 170 | License does not apply, and You do not need to comply with 171 | its terms and conditions. 172 | 173 | 3. Term. The term of this Public License is specified in Section 174 | 6(a). 175 | 176 | 4. Media and formats; technical modifications allowed. The 177 | Licensor authorizes You to exercise the Licensed Rights in 178 | all media and formats whether now known or hereafter created, 179 | and to make technical modifications necessary to do so. The 180 | Licensor waives and/or agrees not to assert any right or 181 | authority to forbid You from making technical modifications 182 | necessary to exercise the Licensed Rights, including 183 | technical modifications necessary to circumvent Effective 184 | Technological Measures. For purposes of this Public License, 185 | simply making modifications authorized by this Section 2(a) 186 | (4) never produces Adapted Material. 187 | 188 | 5. Downstream recipients. 189 | 190 | a. Offer from the Licensor -- Licensed Material. Every 191 | recipient of the Licensed Material automatically 192 | receives an offer from the Licensor to exercise the 193 | Licensed Rights under the terms and conditions of this 194 | Public License. 195 | 196 | b. Additional offer from the Licensor -- Adapted Material. 197 | Every recipient of Adapted Material from You 198 | automatically receives an offer from the Licensor to 199 | exercise the Licensed Rights in the Adapted Material 200 | under the conditions of the Adapter's License You apply. 201 | 202 | c. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. 
Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 3. If requested by the Licensor, You must remove any of the 282 | information required by Section 3(a)(1)(A) to the extent 283 | reasonably practicable. 284 | 285 | b. ShareAlike. 286 | 287 | In addition to the conditions in Section 3(a), if You Share 288 | Adapted Material You produce, the following conditions also apply. 289 | 290 | 1. The Adapter's License You apply must be a Creative Commons 291 | license with the same License Elements, this version or 292 | later, or a BY-NC-SA Compatible License. 293 | 294 | 2. You must include the text of, or the URI or hyperlink to, the 295 | Adapter's License You apply. 
You may satisfy this condition 296 | in any reasonable manner based on the medium, means, and 297 | context in which You Share Adapted Material. 298 | 299 | 3. You may not offer or impose any additional or different terms 300 | or conditions on, or apply any Effective Technological 301 | Measures to, Adapted Material that restrict exercise of the 302 | rights granted under the Adapter's License You apply. 303 | 304 | 305 | Section 4 -- Sui Generis Database Rights. 306 | 307 | Where the Licensed Rights include Sui Generis Database Rights that 308 | apply to Your use of the Licensed Material: 309 | 310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 311 | to extract, reuse, reproduce, and Share all or a substantial 312 | portion of the contents of the database for NonCommercial purposes 313 | only; 314 | 315 | b. if You include all or a substantial portion of the database 316 | contents in a database in which You have Sui Generis Database 317 | Rights, then the database in which You have Sui Generis Database 318 | Rights (but not its individual contents) is Adapted Material, 319 | including for purposes of Section 3(b); and 320 | 321 | c. You must comply with the conditions in Section 3(a) if You Share 322 | all or a substantial portion of the contents of the database. 323 | 324 | For the avoidance of doubt, this Section 4 supplements and does not 325 | replace Your obligations under this Public License where the Licensed 326 | Rights include other Copyright and Similar Rights. 327 | 328 | 329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 330 | 331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 341 | 342 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 351 | 352 | c. The disclaimer of warranties and limitation of liability provided 353 | above shall be interpreted in a manner that, to the extent 354 | possible, most closely approximates an absolute disclaimer and 355 | waiver of all liability. 356 | 357 | 358 | Section 6 -- Term and Termination. 359 | 360 | a. This Public License applies for the term of the Copyright and 361 | Similar Rights licensed here. However, if You fail to comply with 362 | this Public License, then Your rights under this Public License 363 | terminate automatically. 364 | 365 | b. 
Where Your right to use the Licensed Material has terminated under 366 | Section 6(a), it reinstates: 367 | 368 | 1. automatically as of the date the violation is cured, provided 369 | it is cured within 30 days of Your discovery of the 370 | violation; or 371 | 372 | 2. upon express reinstatement by the Licensor. 373 | 374 | For the avoidance of doubt, this Section 6(b) does not affect any 375 | right the Licensor may have to seek remedies for Your violations 376 | of this Public License. 377 | 378 | c. For the avoidance of doubt, the Licensor may also offer the 379 | Licensed Material under separate terms or conditions or stop 380 | distributing the Licensed Material at any time; however, doing so 381 | will not terminate this Public License. 382 | 383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 384 | License. 385 | 386 | 387 | Section 7 -- Other Terms and Conditions. 388 | 389 | a. The Licensor shall not be bound by any additional or different 390 | terms or conditions communicated by You unless expressly agreed. 391 | 392 | b. Any arrangements, understandings, or agreements regarding the 393 | Licensed Material not stated herein are separate from and 394 | independent of the terms and conditions of this Public License. 395 | 396 | 397 | Section 8 -- Interpretation. 398 | 399 | a. For the avoidance of doubt, this Public License does not, and 400 | shall not be interpreted to, reduce, limit, restrict, or impose 401 | conditions on any use of the Licensed Material that could lawfully 402 | be made without permission under this Public License. 403 | 404 | b. To the extent possible, if any provision of this Public License is 405 | deemed unenforceable, it shall be automatically reformed to the 406 | minimum extent necessary to make it enforceable. If the provision 407 | cannot be reformed, it shall be severed from this Public License 408 | without affecting the enforceability of the remaining terms and 409 | conditions. 410 | 411 | c. No term or condition of this Public License will be waived and no 412 | failure to comply consented to unless expressly agreed to by the 413 | Licensor. 414 | 415 | d. Nothing in this Public License constitutes or may be interpreted 416 | as a limitation upon, or waiver of, any privileges and immunities 417 | that apply to the Licensor or You, including from the legal 418 | processes of any jurisdiction or authority. 419 | 420 | ======================================================================= 421 | 422 | Creative Commons is not a party to its public 423 | licenses. Notwithstanding, Creative Commons may elect to apply one of 424 | its public licenses to material it publishes and in those instances 425 | will be considered the "Licensor." The text of the Creative Commons 426 | public licenses is dedicated to the public domain under the CC0 Public 427 | Domain Dedication. Except for the limited purpose of indicating that 428 | material is shared under a Creative Commons public license or as 429 | otherwise permitted by the Creative Commons policies published at 430 | creativecommons.org/policies, Creative Commons does not authorize the 431 | use of the trademark "Creative Commons" or any other trademark or logo 432 | of Creative Commons without its prior written consent including, 433 | without limitation, in connection with any unauthorized modifications 434 | to any of its public licenses or any other arrangements, 435 | understandings, or agreements concerning use of licensed material. 
For 436 | the avoidance of doubt, this paragraph does not form part of the 437 | public licenses. 438 | 439 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /license/LICENSE_EDM2: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | ======================================================================= 6 | 7 | Creative Commons Corporation ("Creative Commons") is not a law firm and 8 | does not provide legal services or legal advice. Distribution of 9 | Creative Commons public licenses does not create a lawyer-client or 10 | other relationship. Creative Commons makes its licenses and related 11 | information available on an "as-is" basis. Creative Commons gives no 12 | warranties regarding its licenses, any material licensed under their 13 | terms and conditions, or any related information. Creative Commons 14 | disclaims all liability for damages resulting from their use to the 15 | fullest extent possible. 16 | 17 | Using Creative Commons Public Licenses 18 | 19 | Creative Commons public licenses provide a standard set of terms and 20 | conditions that creators and other rights holders may use to share 21 | original works of authorship and other material subject to copyright 22 | and certain other rights specified in the public license below. The 23 | following considerations are for informational purposes only, are not 24 | exhaustive, and do not form part of our licenses. 25 | 26 | Considerations for licensors: Our public licenses are 27 | intended for use by those authorized to give the public 28 | permission to use material in ways otherwise restricted by 29 | copyright and certain other rights. Our licenses are 30 | irrevocable. Licensors should read and understand the terms 31 | and conditions of the license they choose before applying it. 32 | Licensors should also secure all rights necessary before 33 | applying our licenses so that the public can reuse the 34 | material as expected. Licensors should clearly mark any 35 | material not subject to the license. This includes other CC- 36 | licensed material, or material used under an exception or 37 | limitation to copyright. More considerations for licensors: 38 | wiki.creativecommons.org/Considerations_for_licensors 39 | 40 | Considerations for the public: By using one of our public 41 | licenses, a licensor grants the public permission to use the 42 | licensed material under specified terms and conditions. If 43 | the licensor's permission is not necessary for any reason--for 44 | example, because of any applicable exception or limitation to 45 | copyright--then that use is not regulated by the license. Our 46 | licenses grant only permissions under copyright and certain 47 | other rights that a licensor has authority to grant. Use of 48 | the licensed material may still be restricted for other 49 | reasons, including because others have copyright or other 50 | rights in the material. A licensor may make special requests, 51 | such as asking that all changes be marked or described. 52 | Although not required by our licenses, you are encouraged to 53 | respect those requests where reasonable. 
More considerations 54 | for the public: 55 | wiki.creativecommons.org/Considerations_for_licensees 56 | 57 | ======================================================================= 58 | 59 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 60 | Public License 61 | 62 | By exercising the Licensed Rights (defined below), You accept and agree 63 | to be bound by the terms and conditions of this Creative Commons 64 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 65 | ("Public License"). To the extent this Public License may be 66 | interpreted as a contract, You are granted the Licensed Rights in 67 | consideration of Your acceptance of these terms and conditions, and the 68 | Licensor grants You such rights in consideration of benefits the 69 | Licensor receives from making the Licensed Material available under 70 | these terms and conditions. 71 | 72 | 73 | Section 1 -- Definitions. 74 | 75 | a. Adapted Material means material subject to Copyright and Similar 76 | Rights that is derived from or based upon the Licensed Material 77 | and in which the Licensed Material is translated, altered, 78 | arranged, transformed, or otherwise modified in a manner requiring 79 | permission under the Copyright and Similar Rights held by the 80 | Licensor. For purposes of this Public License, where the Licensed 81 | Material is a musical work, performance, or sound recording, 82 | Adapted Material is always produced where the Licensed Material is 83 | synched in timed relation with a moving image. 84 | 85 | b. Adapter's License means the license You apply to Your Copyright 86 | and Similar Rights in Your contributions to Adapted Material in 87 | accordance with the terms and conditions of this Public License. 88 | 89 | c. BY-NC-SA Compatible License means a license listed at 90 | creativecommons.org/compatiblelicenses, approved by Creative 91 | Commons as essentially the equivalent of this Public License. 92 | 93 | d. Copyright and Similar Rights means copyright and/or similar rights 94 | closely related to copyright including, without limitation, 95 | performance, broadcast, sound recording, and Sui Generis Database 96 | Rights, without regard to how the rights are labeled or 97 | categorized. For purposes of this Public License, the rights 98 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 99 | Rights. 100 | 101 | e. Effective Technological Measures means those measures that, in the 102 | absence of proper authority, may not be circumvented under laws 103 | fulfilling obligations under Article 11 of the WIPO Copyright 104 | Treaty adopted on December 20, 1996, and/or similar international 105 | agreements. 106 | 107 | f. Exceptions and Limitations means fair use, fair dealing, and/or 108 | any other exception or limitation to Copyright and Similar Rights 109 | that applies to Your use of the Licensed Material. 110 | 111 | g. License Elements means the license attributes listed in the name 112 | of a Creative Commons Public License. The License Elements of this 113 | Public License are Attribution, NonCommercial, and ShareAlike. 114 | 115 | h. Licensed Material means the artistic or literary work, database, 116 | or other material to which the Licensor applied this Public 117 | License. 118 | 119 | i. 
Licensed Rights means the rights granted to You subject to the 120 | terms and conditions of this Public License, which are limited to 121 | all Copyright and Similar Rights that apply to Your use of the 122 | Licensed Material and that the Licensor has authority to license. 123 | 124 | j. Licensor means the individual(s) or entity(ies) granting rights 125 | under this Public License. 126 | 127 | k. NonCommercial means not primarily intended for or directed towards 128 | commercial advantage or monetary compensation. For purposes of 129 | this Public License, the exchange of the Licensed Material for 130 | other material subject to Copyright and Similar Rights by digital 131 | file-sharing or similar means is NonCommercial provided there is 132 | no payment of monetary compensation in connection with the 133 | exchange. 134 | 135 | l. Share means to provide material to the public by any means or 136 | process that requires permission under the Licensed Rights, such 137 | as reproduction, public display, public performance, distribution, 138 | dissemination, communication, or importation, and to make material 139 | available to the public including in ways that members of the 140 | public may access the material from a place and at a time 141 | individually chosen by them. 142 | 143 | m. Sui Generis Database Rights means rights other than copyright 144 | resulting from Directive 96/9/EC of the European Parliament and of 145 | the Council of 11 March 1996 on the legal protection of databases, 146 | as amended and/or succeeded, as well as other essentially 147 | equivalent rights anywhere in the world. 148 | 149 | n. You means the individual or entity exercising the Licensed Rights 150 | under this Public License. Your has a corresponding meaning. 151 | 152 | 153 | Section 2 -- Scope. 154 | 155 | a. License grant. 156 | 157 | 1. Subject to the terms and conditions of this Public License, 158 | the Licensor hereby grants You a worldwide, royalty-free, 159 | non-sublicensable, non-exclusive, irrevocable license to 160 | exercise the Licensed Rights in the Licensed Material to: 161 | 162 | a. reproduce and Share the Licensed Material, in whole or 163 | in part, for NonCommercial purposes only; and 164 | 165 | b. produce, reproduce, and Share Adapted Material for 166 | NonCommercial purposes only. 167 | 168 | 2. Exceptions and Limitations. For the avoidance of doubt, where 169 | Exceptions and Limitations apply to Your use, this Public 170 | License does not apply, and You do not need to comply with 171 | its terms and conditions. 172 | 173 | 3. Term. The term of this Public License is specified in Section 174 | 6(a). 175 | 176 | 4. Media and formats; technical modifications allowed. The 177 | Licensor authorizes You to exercise the Licensed Rights in 178 | all media and formats whether now known or hereafter created, 179 | and to make technical modifications necessary to do so. The 180 | Licensor waives and/or agrees not to assert any right or 181 | authority to forbid You from making technical modifications 182 | necessary to exercise the Licensed Rights, including 183 | technical modifications necessary to circumvent Effective 184 | Technological Measures. For purposes of this Public License, 185 | simply making modifications authorized by this Section 2(a) 186 | (4) never produces Adapted Material. 187 | 188 | 5. Downstream recipients. 189 | 190 | a. Offer from the Licensor -- Licensed Material. 
Every 191 | recipient of the Licensed Material automatically 192 | receives an offer from the Licensor to exercise the 193 | Licensed Rights under the terms and conditions of this 194 | Public License. 195 | 196 | b. Additional offer from the Licensor -- Adapted Material. 197 | Every recipient of Adapted Material from You 198 | automatically receives an offer from the Licensor to 199 | exercise the Licensed Rights in the Adapted Material 200 | under the conditions of the Adapter's License You apply. 201 | 202 | c. No downstream restrictions. You may not offer or impose 203 | any additional or different terms or conditions on, or 204 | apply any Effective Technological Measures to, the 205 | Licensed Material if doing so restricts exercise of the 206 | Licensed Rights by any recipient of the Licensed 207 | Material. 208 | 209 | 6. No endorsement. Nothing in this Public License constitutes or 210 | may be construed as permission to assert or imply that You 211 | are, or that Your use of the Licensed Material is, connected 212 | with, or sponsored, endorsed, or granted official status by, 213 | the Licensor or others designated to receive attribution as 214 | provided in Section 3(a)(1)(A)(i). 215 | 216 | b. Other rights. 217 | 218 | 1. Moral rights, such as the right of integrity, are not 219 | licensed under this Public License, nor are publicity, 220 | privacy, and/or other similar personality rights; however, to 221 | the extent possible, the Licensor waives and/or agrees not to 222 | assert any such rights held by the Licensor to the limited 223 | extent necessary to allow You to exercise the Licensed 224 | Rights, but not otherwise. 225 | 226 | 2. Patent and trademark rights are not licensed under this 227 | Public License. 228 | 229 | 3. To the extent possible, the Licensor waives any right to 230 | collect royalties from You for the exercise of the Licensed 231 | Rights, whether directly or through a collecting society 232 | under any voluntary or waivable statutory or compulsory 233 | licensing scheme. In all other cases the Licensor expressly 234 | reserves any right to collect such royalties, including when 235 | the Licensed Material is used other than for NonCommercial 236 | purposes. 237 | 238 | 239 | Section 3 -- License Conditions. 240 | 241 | Your exercise of the Licensed Rights is expressly made subject to the 242 | following conditions. 243 | 244 | a. Attribution. 245 | 246 | 1. If You Share the Licensed Material (including in modified 247 | form), You must: 248 | 249 | a. retain the following if it is supplied by the Licensor 250 | with the Licensed Material: 251 | 252 | i. identification of the creator(s) of the Licensed 253 | Material and any others designated to receive 254 | attribution, in any reasonable manner requested by 255 | the Licensor (including by pseudonym if 256 | designated); 257 | 258 | ii. a copyright notice; 259 | 260 | iii. a notice that refers to this Public License; 261 | 262 | iv. a notice that refers to the disclaimer of 263 | warranties; 264 | 265 | v. a URI or hyperlink to the Licensed Material to the 266 | extent reasonably practicable; 267 | 268 | b. indicate if You modified the Licensed Material and 269 | retain an indication of any previous modifications; and 270 | 271 | c. indicate the Licensed Material is licensed under this 272 | Public License, and include the text of, or the URI or 273 | hyperlink to, this Public License. 274 | 275 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 276 | reasonable manner based on the medium, means, and context in 277 | which You Share the Licensed Material. For example, it may be 278 | reasonable to satisfy the conditions by providing a URI or 279 | hyperlink to a resource that includes the required 280 | information. 281 | 3. If requested by the Licensor, You must remove any of the 282 | information required by Section 3(a)(1)(A) to the extent 283 | reasonably practicable. 284 | 285 | b. ShareAlike. 286 | 287 | In addition to the conditions in Section 3(a), if You Share 288 | Adapted Material You produce, the following conditions also apply. 289 | 290 | 1. The Adapter's License You apply must be a Creative Commons 291 | license with the same License Elements, this version or 292 | later, or a BY-NC-SA Compatible License. 293 | 294 | 2. You must include the text of, or the URI or hyperlink to, the 295 | Adapter's License You apply. You may satisfy this condition 296 | in any reasonable manner based on the medium, means, and 297 | context in which You Share Adapted Material. 298 | 299 | 3. You may not offer or impose any additional or different terms 300 | or conditions on, or apply any Effective Technological 301 | Measures to, Adapted Material that restrict exercise of the 302 | rights granted under the Adapter's License You apply. 303 | 304 | 305 | Section 4 -- Sui Generis Database Rights. 306 | 307 | Where the Licensed Rights include Sui Generis Database Rights that 308 | apply to Your use of the Licensed Material: 309 | 310 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 311 | to extract, reuse, reproduce, and Share all or a substantial 312 | portion of the contents of the database for NonCommercial purposes 313 | only; 314 | 315 | b. if You include all or a substantial portion of the database 316 | contents in a database in which You have Sui Generis Database 317 | Rights, then the database in which You have Sui Generis Database 318 | Rights (but not its individual contents) is Adapted Material, 319 | including for purposes of Section 3(b); and 320 | 321 | c. You must comply with the conditions in Section 3(a) if You Share 322 | all or a substantial portion of the contents of the database. 323 | 324 | For the avoidance of doubt, this Section 4 supplements and does not 325 | replace Your obligations under this Public License where the Licensed 326 | Rights include other Copyright and Similar Rights. 327 | 328 | 329 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 330 | 331 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 332 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 333 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 334 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 335 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 336 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 337 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 338 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 339 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 340 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 341 | 342 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 343 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 344 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 345 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 346 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 347 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 348 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 349 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 350 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 351 | 352 | c. The disclaimer of warranties and limitation of liability provided 353 | above shall be interpreted in a manner that, to the extent 354 | possible, most closely approximates an absolute disclaimer and 355 | waiver of all liability. 356 | 357 | 358 | Section 6 -- Term and Termination. 359 | 360 | a. This Public License applies for the term of the Copyright and 361 | Similar Rights licensed here. However, if You fail to comply with 362 | this Public License, then Your rights under this Public License 363 | terminate automatically. 364 | 365 | b. Where Your right to use the Licensed Material has terminated under 366 | Section 6(a), it reinstates: 367 | 368 | 1. automatically as of the date the violation is cured, provided 369 | it is cured within 30 days of Your discovery of the 370 | violation; or 371 | 372 | 2. upon express reinstatement by the Licensor. 373 | 374 | For the avoidance of doubt, this Section 6(b) does not affect any 375 | right the Licensor may have to seek remedies for Your violations 376 | of this Public License. 377 | 378 | c. For the avoidance of doubt, the Licensor may also offer the 379 | Licensed Material under separate terms or conditions or stop 380 | distributing the Licensed Material at any time; however, doing so 381 | will not terminate this Public License. 382 | 383 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 384 | License. 385 | 386 | 387 | Section 7 -- Other Terms and Conditions. 388 | 389 | a. The Licensor shall not be bound by any additional or different 390 | terms or conditions communicated by You unless expressly agreed. 391 | 392 | b. Any arrangements, understandings, or agreements regarding the 393 | Licensed Material not stated herein are separate from and 394 | independent of the terms and conditions of this Public License. 395 | 396 | 397 | Section 8 -- Interpretation. 398 | 399 | a. For the avoidance of doubt, this Public License does not, and 400 | shall not be interpreted to, reduce, limit, restrict, or impose 401 | conditions on any use of the Licensed Material that could lawfully 402 | be made without permission under this Public License. 403 | 404 | b. To the extent possible, if any provision of this Public License is 405 | deemed unenforceable, it shall be automatically reformed to the 406 | minimum extent necessary to make it enforceable. If the provision 407 | cannot be reformed, it shall be severed from this Public License 408 | without affecting the enforceability of the remaining terms and 409 | conditions. 410 | 411 | c. No term or condition of this Public License will be waived and no 412 | failure to comply consented to unless expressly agreed to by the 413 | Licensor. 414 | 415 | d. 
Nothing in this Public License constitutes or may be interpreted 416 | as a limitation upon, or waiver of, any privileges and immunities 417 | that apply to the Licensor or You, including from the legal 418 | processes of any jurisdiction or authority. 419 | 420 | ======================================================================= 421 | 422 | Creative Commons is not a party to its public 423 | licenses. Notwithstanding, Creative Commons may elect to apply one of 424 | its public licenses to material it publishes and in those instances 425 | will be considered the "Licensor." The text of the Creative Commons 426 | public licenses is dedicated to the public domain under the CC0 Public 427 | Domain Dedication. Except for the limited purpose of indicating that 428 | material is shared under a Creative Commons public license or as 429 | otherwise permitted by the Creative Commons policies published at 430 | creativecommons.org/policies, Creative Commons does not authorize the 431 | use of the trademark "Creative Commons" or any other trademark or logo 432 | of Creative Commons without its prior written consent including, 433 | without limitation, in connection with any unauthorized modifications 434 | to any of its public licenses or any other arrangements, 435 | understandings, or agreements concerning use of licensed material. For 436 | the avoidance of doubt, this paragraph does not form part of the 437 | public licenses. 438 | 439 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /license/LICENSE_STYLEGAN2: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. 
You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= -------------------------------------------------------------------------------- /license/LICENSE_TCM: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for TCM 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 
8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 15 | 16 | Works, including the Software, are “made available” under this License by including in or with the Work either 17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 29 | copyright, patent, trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 35 | requirements in Section 3.1) will continue to apply to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 46 | or trademarks, except as necessary to reproduce the notices described in this License. 47 | 48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 49 | grant in Section 2.1) will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 54 | WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 
58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /make_grid.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for TCM. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | from torchvision.utils import make_grid, save_image 9 | import torchvision.transforms.functional as TF 10 | import argparse 11 | import torch 12 | import os 13 | parser = argparse.ArgumentParser(description='Configs') 14 | 15 | parser.add_argument('--dir', type=str, help='input image directory') 16 | parser.add_argument('--save_dir', type=str, default="grid.png", help='output image path') 17 | parser.add_argument('--size', type=str, default='8', help='grid size: "N" for an NxN grid, or "H,W" for H rows and W columns') 18 | parser.add_argument('--shuffle', action='store_true', help='shuffle images before selecting them') 19 | 20 | 21 | arg = parser.parse_args() 22 | 23 | from PIL import Image 24 | import glob 25 | 26 | # Parse --size as "N" (NxN grid) or "H,W" (H rows, W columns). 27 | size_list = arg.size.split(',') 28 | arg.size = [int(i) for i in size_list] 29 | if len(arg.size) == 1: 30 | h, w = arg.size[0], arg.size[0] 31 | elif len(arg.size) == 2: 32 | h, w = arg.size[0], arg.size[1] 33 | else: 34 | raise ValueError('--size should have 1 or 2 comma-separated elements') 35 | 36 | files = glob.glob(os.path.join(arg.dir, "*.png")) + glob.glob(os.path.join(arg.dir, "*.jpg")) 37 | files = sorted(files) 38 | if arg.shuffle: 39 | import random 40 | random.shuffle(files) 41 | print(f"len(files): {len(files)}") 42 | files = files[:h*w] 43 | img_list = [] 44 | for file in files: 45 | img = Image.open(file) 46 | img = TF.to_tensor(img) 47 | img_list.append(img) 48 | imgs = torch.stack(img_list) 49 | grid = make_grid(imgs, nrow=w, padding=0) 50 | save_image(grid, arg.save_dir) -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 
10 | # --------------------------------------------------------------- 11 | 12 | # empty 13 | -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Frechet Inception Distance (FID) from the paper 13 | "GANs trained by a two time-scale update rule converge to a local Nash 14 | equilibrium". Matches the original implementation by Heusel et al. at 15 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 16 | 17 | import numpy as np 18 | import scipy.linalg 19 | from . import metric_utils 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def compute_fid(opts, max_real, num_gen, t_max=None): 24 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 25 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 26 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 27 | 28 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 29 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 30 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 31 | 32 | if t_max is None: 33 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 34 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 35 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 36 | else: 37 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator_with_data( 38 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 39 | rel_lo=0, rel_hi=1, t_max=t_max, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 40 | 41 | if opts.rank != 0: 42 | return float('nan') 43 | 44 | m = np.square(mu_gen - mu_real).sum() 45 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 46 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 47 | return float(fid) 48 | 49 | #---------------------------------------------------------------------------- 50 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Inception Score (IS) from the paper "Improved techniques for training 13 | GANs". Matches the original implementation by Salimans et al. 
at 14 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 15 | 16 | import numpy as np 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_is(opts, num_gen, num_splits): 22 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 23 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 24 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 25 | 26 | gen_probs = metric_utils.compute_feature_stats_for_generator( 27 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 28 | capture_all=True, max_items=num_gen).get_all() 29 | 30 | if opts.rank != 0: 31 | return float('nan'), float('nan') 32 | 33 | scores = [] 34 | for i in range(num_splits): 35 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 36 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 37 | kl = np.mean(np.sum(kl, axis=1)) 38 | scores.append(np.exp(kl)) 39 | return float(np.mean(scores)), float(np.std(scores)) 40 | 41 | #---------------------------------------------------------------------------- -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 13 | GANs". Matches the original implementation by Binkowski et al. at 14 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 15 | 16 | import numpy as np 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 22 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 23 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 24 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 
25 | 26 | real_features = metric_utils.compute_feature_stats_for_dataset( 27 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 28 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 29 | 30 | gen_features = metric_utils.compute_feature_stats_for_generator( 31 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 32 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 33 | 34 | if opts.rank != 0: 35 | return float('nan') 36 | 37 | n = real_features.shape[1] 38 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 39 | t = 0 40 | for _subset_idx in range(num_subsets): 41 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 42 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 43 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 44 | b = (x @ y.T / n + 1) ** 3 45 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 46 | kid = t / num_subsets / m 47 | return float(kid) 48 | 49 | #---------------------------------------------------------------------------- 50 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # The modifications to this file are subject to the same license. 11 | # --------------------------------------------------------------- 12 | 13 | import os 14 | import time 15 | import json 16 | import torch 17 | import dnnlib 18 | 19 | from . import metric_utils 20 | from . import frechet_inception_distance 21 | from . import kernel_inception_distance 22 | from . import precision_recall 23 | from . import perceptual_path_length 24 | from . import inception_score 25 | #---------------------------------------------------------------------------- 26 | 27 | _metric_dict = dict() # name => fn 28 | 29 | def register_metric(fn): 30 | assert callable(fn) 31 | _metric_dict[fn.__name__] = fn 32 | return fn 33 | 34 | def is_valid_metric(metric): 35 | return metric in _metric_dict 36 | 37 | def list_valid_metrics(): 38 | return list(_metric_dict.keys()) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 43 | assert is_valid_metric(metric) 44 | opts = metric_utils.MetricOptions(**kwargs) 45 | 46 | # Calculate. 47 | start_time = time.time() 48 | results = _metric_dict[metric](opts) 49 | total_time = time.time() - start_time 50 | 51 | # Broadcast results. 52 | for key, value in list(results.items()): 53 | if opts.num_gpus > 1: 54 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 55 | torch.distributed.broadcast(tensor=value, src=0) 56 | value = float(value.cpu()) 57 | results[key] = value 58 | 59 | # Decorate with metadata. 
60 | return dnnlib.EasyDict( 61 | results = dnnlib.EasyDict(results), 62 | metric = metric, 63 | total_time = total_time, 64 | total_time_str = dnnlib.util.format_time(total_time), 65 | num_gpus = opts.num_gpus, 66 | ) 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 71 | metric = result_dict['metric'] 72 | assert is_valid_metric(metric) 73 | if run_dir is not None and snapshot_pkl is not None: 74 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 75 | 76 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 77 | print(jsonl_line) 78 | if run_dir is not None and os.path.isdir(run_dir): 79 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 80 | f.write(jsonl_line + '\n') 81 | 82 | #---------------------------------------------------------------------------- 83 | # Primary metrics. 84 | 85 | @register_metric 86 | def fid50k_full(opts): 87 | opts.dataset_kwargs.update(max_size=None, xflip=False) 88 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 89 | return dict(fid50k_full=fid) 90 | 91 | @register_metric 92 | def fid50k_full_denoising(opts): 93 | opts.dataset_kwargs.update(max_size=None, xflip=False) 94 | d = {} 95 | 96 | for t_max in opts.dfid_ts: 97 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000, t_max=t_max) 98 | d[f'fid50k_full_denoising_{t_max:.3g}'] = fid 99 | return d 100 | 101 | @register_metric 102 | def kid50k_full(opts): 103 | opts.dataset_kwargs.update(max_size=None, xflip=False) 104 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 105 | return dict(kid50k_full=kid) 106 | 107 | @register_metric 108 | def pr50k3_full(opts): 109 | opts.dataset_kwargs.update(max_size=None, xflip=False) 110 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 111 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 112 | 113 | @register_metric 114 | def ppl2_wend(opts): 115 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 116 | return dict(ppl2_wend=ppl) 117 | 118 | @register_metric 119 | def is50k(opts): 120 | opts.dataset_kwargs.update(max_size=None, xflip=False) 121 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 122 | return dict(is50k_mean=mean, is50k_std=std) 123 | 124 | #---------------------------------------------------------------------------- 125 | # Legacy metrics. 
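# Compared to the *_full variants above, the legacy FID/KID/PR metrics below
# cap the number of real images at 50k (the *_full versions use far more, or
# the whole dataset), and the legacy PPL variants use center-cropped images;
# they are presumably kept for comparability with older StyleGAN2 evaluations.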
126 | 
127 | @register_metric
128 | def fid50k(opts):
129 |     opts.dataset_kwargs.update(max_size=None)
130 |     fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
131 |     return dict(fid50k=fid)
132 | 
133 | @register_metric
134 | def kid50k(opts):
135 |     opts.dataset_kwargs.update(max_size=None)
136 |     kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
137 |     return dict(kid50k=kid)
138 | 
139 | @register_metric
140 | def pr50k3(opts):
141 |     opts.dataset_kwargs.update(max_size=None)
142 |     precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
143 |     return dict(pr50k3_precision=precision, pr50k3_recall=recall)
144 | 
145 | @register_metric
146 | def ppl_zfull(opts):
147 |     ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
148 |     return dict(ppl_zfull=ppl)
149 | 
150 | @register_metric
151 | def ppl_wfull(opts):
152 |     ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
153 |     return dict(ppl_wfull=ppl)
154 | 
155 | @register_metric
156 | def ppl_zend(opts):
157 |     ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
158 |     return dict(ppl_zend=ppl)
159 | 
160 | @register_metric
161 | def ppl_wend(opts):
162 |     ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
163 |     return dict(ppl_wend=ppl)
164 | 
165 | #----------------------------------------------------------------------------
166 | 
--------------------------------------------------------------------------------
/metrics/metric_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This file has been modified from stylegan2-ada-pytorch.
5 | #
6 | # Source:
7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics
8 | #
9 | # The license for these can be found in license/ directory.
10 | # The modifications to this file are subject to the same license.
11 | # --------------------------------------------------------------- 12 | 13 | import os 14 | import time 15 | import hashlib 16 | import pickle 17 | import copy 18 | import uuid 19 | import numpy as np 20 | import torch 21 | import dnnlib 22 | from torchvision.utils import save_image 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | class MetricOptions: 27 | def __init__(self, generator_fn=None, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True, im_dir=None, latent_dir=None, dfid_ts=None): 28 | assert 0 <= rank < num_gpus 29 | 30 | self.generator_fn = generator_fn 31 | self.G = G 32 | self.G_kwargs = dnnlib.EasyDict(G_kwargs) 33 | self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs) 34 | self.num_gpus = num_gpus 35 | self.rank = rank 36 | self.device = device if device is not None else torch.device('cuda', rank) 37 | self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor() 38 | self.cache = cache 39 | self.im_dir = im_dir 40 | self.latent_dir = latent_dir 41 | self.dfid_ts = dfid_ts 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | _feature_detector_cache = dict() 46 | 47 | def get_feature_detector_name(url): 48 | return os.path.splitext(url.split('/')[-1])[0] 49 | 50 | def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False): 51 | assert 0 <= rank < num_gpus 52 | key = (url, device) 53 | if key not in _feature_detector_cache: 54 | is_leader = (rank == 0) 55 | if not is_leader and num_gpus > 1: 56 | torch.distributed.barrier() # leader goes first 57 | with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f: 58 | _feature_detector_cache[key] = torch.jit.load(f).eval().to(device) 59 | if is_leader and num_gpus > 1: 60 | torch.distributed.barrier() # others follow 61 | return _feature_detector_cache[key] 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | class FeatureStats: 66 | def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None): 67 | self.capture_all = capture_all 68 | self.capture_mean_cov = capture_mean_cov 69 | self.max_items = max_items 70 | self.num_items = 0 71 | self.num_features = None 72 | self.all_features = None 73 | self.raw_mean = None 74 | self.raw_cov = None 75 | 76 | def set_num_features(self, num_features): 77 | if self.num_features is not None: 78 | assert num_features == self.num_features 79 | else: 80 | self.num_features = num_features 81 | self.all_features = [] 82 | self.raw_mean = np.zeros([num_features], dtype=np.float64) 83 | self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64) 84 | 85 | def is_full(self): 86 | return (self.max_items is not None) and (self.num_items >= self.max_items) 87 | 88 | def append(self, x): 89 | x = np.asarray(x, dtype=np.float32) 90 | assert x.ndim == 2 91 | if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items): 92 | if self.num_items >= self.max_items: 93 | return 94 | x = x[:self.max_items - self.num_items] 95 | 96 | self.set_num_features(x.shape[1]) 97 | self.num_items += x.shape[0] 98 | if self.capture_all: 99 | self.all_features.append(x) 100 | if self.capture_mean_cov: 101 | x64 = x.astype(np.float64) 102 | self.raw_mean += x64.sum(axis=0) 103 | self.raw_cov += x64.T @ x64 104 | 105 | def append_torch(self, x, num_gpus=1, rank=0): 106 | assert isinstance(x, torch.Tensor) and x.ndim == 2 
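        # In the multi-GPU case below, every rank must call append_torch() in
        # lockstep with identically shaped batches: each rank broadcasts its
        # features to all others, and the results are interleaved so that all
        # ranks accumulate the same samples in the same order.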
107 | assert 0 <= rank < num_gpus 108 | if num_gpus > 1: 109 | ys = [] 110 | for src in range(num_gpus): 111 | y = x.clone() 112 | torch.distributed.broadcast(y, src=src) 113 | ys.append(y) 114 | x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples 115 | self.append(x.cpu().numpy()) 116 | 117 | def get_all(self): 118 | assert self.capture_all 119 | return np.concatenate(self.all_features, axis=0) 120 | 121 | def get_all_torch(self): 122 | return torch.from_numpy(self.get_all()) 123 | 124 | def get_mean_cov(self): 125 | assert self.capture_mean_cov 126 | mean = self.raw_mean / self.num_items 127 | cov = self.raw_cov / self.num_items 128 | cov = cov - np.outer(mean, mean) 129 | return mean, cov 130 | 131 | def save(self, pkl_file): 132 | with open(pkl_file, 'wb') as f: 133 | pickle.dump(self.__dict__, f) 134 | 135 | @staticmethod 136 | def load(pkl_file): 137 | with open(pkl_file, 'rb') as f: 138 | s = dnnlib.EasyDict(pickle.load(f)) 139 | obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items) 140 | obj.__dict__.update(s) 141 | return obj 142 | 143 | #---------------------------------------------------------------------------- 144 | 145 | class ProgressMonitor: 146 | def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000): 147 | self.tag = tag 148 | self.num_items = num_items 149 | self.verbose = verbose 150 | self.flush_interval = flush_interval 151 | self.progress_fn = progress_fn 152 | self.pfn_lo = pfn_lo 153 | self.pfn_hi = pfn_hi 154 | self.pfn_total = pfn_total 155 | self.start_time = time.time() 156 | self.batch_time = self.start_time 157 | self.batch_items = 0 158 | if self.progress_fn is not None: 159 | self.progress_fn(self.pfn_lo, self.pfn_total) 160 | 161 | def update(self, cur_items): 162 | assert (self.num_items is None) or (cur_items <= self.num_items) 163 | if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items): 164 | return 165 | cur_time = time.time() 166 | total_time = cur_time - self.start_time 167 | time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1) 168 | if (self.verbose) and (self.tag is not None): 169 | print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}') 170 | self.batch_time = cur_time 171 | self.batch_items = cur_items 172 | 173 | if (self.progress_fn is not None) and (self.num_items is not None): 174 | self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total) 175 | 176 | def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1): 177 | return ProgressMonitor( 178 | tag = tag, 179 | num_items = num_items, 180 | flush_interval = flush_interval, 181 | verbose = self.verbose, 182 | progress_fn = self.progress_fn, 183 | pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo, 184 | pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi, 185 | pfn_total = self.pfn_total, 186 | ) 187 | 188 | #---------------------------------------------------------------------------- 189 | 190 | def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=128, data_loader_kwargs=None, max_items=None, **stats_kwargs): 191 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 192 | if data_loader_kwargs is None: 193 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, 
prefetch_factor=2) 194 | 195 | # Try to lookup from cache. 196 | cache_file = None 197 | if opts.cache: 198 | # Choose cache file name. 199 | args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs) 200 | md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8')) 201 | cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}' 202 | cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl') 203 | 204 | # Check if the file exists (all processes must agree). 205 | flag = os.path.isfile(cache_file) if opts.rank == 0 else False 206 | if opts.num_gpus > 1: 207 | flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device) 208 | torch.distributed.broadcast(tensor=flag, src=0) 209 | flag = (float(flag.cpu()) != 0) 210 | 211 | # Load. 212 | if flag: 213 | return FeatureStats.load(cache_file) 214 | 215 | # Initialize. 216 | num_items = len(dataset) 217 | if max_items is not None: 218 | num_items = min(num_items, max_items) 219 | stats = FeatureStats(max_items=num_items, **stats_kwargs) 220 | progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi) 221 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 222 | 223 | # Main loop. 224 | item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] 225 | for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): 226 | if images.shape[1] == 1: 227 | images = images.repeat([1, 3, 1, 1]) 228 | features = detector(images.to(opts.device), **detector_kwargs) 229 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 230 | progress.update(stats.num_items) 231 | 232 | # Save to cache. 233 | if cache_file is not None and opts.rank == 0: 234 | os.makedirs(os.path.dirname(cache_file), exist_ok=True) 235 | temp_file = cache_file + '.' + uuid.uuid4().hex 236 | stats.save(temp_file) 237 | os.replace(temp_file, cache_file) # atomic 238 | return stats 239 | 240 | #---------------------------------------------------------------------------- 241 | 242 | def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=128, batch_gen=128, jit=False, **stats_kwargs): 243 | if batch_gen is None: 244 | batch_gen = min(batch_size, 4) 245 | assert batch_size % batch_gen == 0 246 | 247 | # Setup generator and load labels. 248 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 249 | generator_fn = opts.generator_fn 250 | 251 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 252 | 253 | # Image generation func. 254 | def run_generator(z, c, data=None): 255 | img = generator_fn(G, z, c, t_max=80, data=data, **opts.G_kwargs) 256 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 257 | return img 258 | 259 | # JIT. 260 | if jit: 261 | # TODO: Add an init_fn for this. 262 | z = torch.zeros([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 263 | c = torch.zeros([batch_gen, G.label_dim], device=opts.device) 264 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) 265 | 266 | # Initialize. 
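    # Note that max_items must be supplied through stats_kwargs here: the
    # sampling loop below has no termination condition other than
    # stats.is_full().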
267 | stats = FeatureStats(**stats_kwargs) 268 | assert stats.max_items is not None 269 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) 270 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 271 | 272 | # Main loop. 273 | while not stats.is_full(): 274 | images = [] 275 | for _i in range(batch_size // batch_gen): 276 | z = torch.randn([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 277 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] 278 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 279 | images.append(run_generator(z, c)) 280 | images = torch.cat(images) 281 | if images.shape[1] == 1: 282 | images = images.repeat([1, 3, 1, 1]) 283 | 284 | features = detector(images, **detector_kwargs) 285 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 286 | progress.update(stats.num_items) 287 | return stats 288 | 289 | #---------------------------------------------------------------------------- 290 | 291 | 292 | def compute_feature_stats_for_generator_with_data(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=128, batch_gen=128, data_loader_kwargs=None, jit=False, t_max=1, **stats_kwargs): 293 | 294 | 295 | # Setup generator and load labels. 296 | G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device) 297 | generator_fn = opts.generator_fn 298 | 299 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 300 | # Set the fixed seed for reproducibility 301 | np.random.seed(0) 302 | 303 | # Create a list of indices and shuffle it 304 | # indices = np.arange(min(len(dataset), 50000)) 305 | indices = np.arange(len(dataset)) 306 | np.random.shuffle(indices) 307 | indices = indices[:50000] 308 | dataset = torch.utils.data.Subset(dataset, indices) 309 | 310 | if data_loader_kwargs is None: 311 | data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2) 312 | # dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_gen, **data_loader_kwargs) 313 | 314 | # Create distributed sampler 315 | item_subset = [(i * opts.num_gpus + opts.rank) % 50000 for i in range((50000 - 1) // opts.num_gpus + 1)] 316 | dataloader = torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_gen, **data_loader_kwargs) 317 | 318 | 319 | # Image generation func. 320 | def run_generator(z, c, t_max=1., data=None): 321 | img = generator_fn(G, z, c, t_max=t_max, data=data, **opts.G_kwargs) 322 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 323 | return img 324 | 325 | # JIT. 326 | if jit: 327 | raise NotImplementedError() 328 | # TODO: Add an init_fn for this. 329 | z = torch.zeros([batch_gen, G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 330 | c = torch.zeros([batch_gen, G.label_dim], device=opts.device) 331 | run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False) 332 | 333 | 334 | # Initialize. 335 | stats = FeatureStats(**stats_kwargs) 336 | assert stats.max_items is not None 337 | progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi) 338 | detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose) 339 | 340 | # Main loop. 
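    # Each step below takes a batch of real images, rescales it from uint8
    # [0, 255] to [-1, 1], and hands it to the generator together with fresh
    # Gaussian noise; generator_fn is presumably expected to reconstruct the
    # data from noise injected up to t_max (the quantity the denoising-FID
    # metric sweeps), and detector features of the outputs are accumulated
    # until stats is full.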
341 | for batch, c in dataloader: 342 | if stats.is_full(): 343 | break 344 | 345 | batch = batch.to(opts.device) / 127.5 - 1 346 | c = c.to(opts.device) 347 | z = torch.randn([batch.shape[0], G.img_channels, G.img_resolution, G.img_resolution], device=opts.device) 348 | images = run_generator(z, c, t_max=t_max, data=batch) 349 | 350 | if images.shape[1] == 1: 351 | images = images.repeat([1, 3, 1, 1]) 352 | 353 | features = detector(images, **detector_kwargs) 354 | stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) 355 | progress.update(stats.num_items) 356 | return stats -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 13 | Architecture for Generative Adversarial Networks". Matches the original 14 | implementation by Karras et al. at 15 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 16 | 17 | import copy 18 | import numpy as np 19 | import torch 20 | import dnnlib 21 | from . import metric_utils 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | # Spherical interpolation of a batch of vectors. 26 | def slerp(a, b, t): 27 | a = a / a.norm(dim=-1, keepdim=True) 28 | b = b / b.norm(dim=-1, keepdim=True) 29 | d = (a * b).sum(dim=-1, keepdim=True) 30 | p = t * torch.acos(d) 31 | c = b - d * a 32 | c = c / c.norm(dim=-1, keepdim=True) 33 | d = a * torch.cos(p) + c * torch.sin(p) 34 | d = d / d.norm(dim=-1, keepdim=True) 35 | return d 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | class PPLSampler(torch.nn.Module): 40 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 41 | assert space in ['z', 'w'] 42 | assert sampling in ['full', 'end'] 43 | super().__init__() 44 | self.G = copy.deepcopy(G) 45 | self.G_kwargs = G_kwargs 46 | self.epsilon = epsilon 47 | self.space = space 48 | self.sampling = sampling 49 | self.crop = crop 50 | self.vgg16 = copy.deepcopy(vgg16) 51 | 52 | def forward(self, c): 53 | # Generate random latents and interpolation t-values. 54 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 55 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 56 | 57 | # Interpolate in W or Z. 58 | if self.space == 'w': 59 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 60 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 61 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 62 | else: # space == 'z' 63 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 64 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 65 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 66 | 67 | # Randomize noise buffers. 68 | for name, buf in self.G.named_buffers(): 69 | if name.endswith('.noise_const'): 70 | buf.copy_(torch.randn_like(buf)) 71 | 72 | # Generate images. 
73 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 74 | 75 | # Center crop. 76 | if self.crop: 77 | assert img.shape[2] == img.shape[3] 78 | c = img.shape[2] // 8 79 | img = img[:, :, c*3 : c*7, c*2 : c*6] 80 | 81 | # Downsample to 256x256. 82 | factor = self.G.img_resolution // 256 83 | if factor > 1: 84 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 85 | 86 | # Scale dynamic range from [-1,1] to [0,255]. 87 | img = (img + 1) * (255 / 2) 88 | if self.G.img_channels == 1: 89 | img = img.repeat([1, 3, 1, 1]) 90 | 91 | # Evaluate differential LPIPS. 92 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 93 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 94 | return dist 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False): 99 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 100 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 101 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 102 | 103 | # Setup sampler. 104 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 105 | sampler.eval().requires_grad_(False).to(opts.device) 106 | if jit: 107 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) 108 | sampler = torch.jit.trace(sampler, [c], check_trace=False) 109 | 110 | # Sampling loop. 111 | dist = [] 112 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 113 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 114 | progress.update(batch_start) 115 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] 116 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 117 | x = sampler(c) 118 | for src in range(opts.num_gpus): 119 | y = x.clone() 120 | if opts.num_gpus > 1: 121 | torch.distributed.broadcast(y, src=src) 122 | dist.append(y) 123 | progress.update(num_samples) 124 | 125 | # Compute PPL. 126 | if opts.rank != 0: 127 | return float('nan') 128 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 129 | lo = np.percentile(dist, 1, interpolation='lower') 130 | hi = np.percentile(dist, 99, interpolation='higher') 131 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 132 | return float(ppl) 133 | 134 | #---------------------------------------------------------------------------- 135 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from stylegan2-ada-pytorch. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/stylegan2-ada-pytorch/tree/main/metrics 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 12 | Metric for Assessing Generative Models". 
Matches the original implementation 13 | by Kynkaanniemi et al. at 14 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 15 | 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 22 | assert 0 <= rank < num_gpus 23 | num_cols = col_features.shape[0] 24 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 25 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 26 | dist_batches = [] 27 | for col_batch in col_batches[rank :: num_gpus]: 28 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 29 | for src in range(num_gpus): 30 | dist_broadcast = dist_batch.clone() 31 | if num_gpus > 1: 32 | torch.distributed.broadcast(dist_broadcast, src=src) 33 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 34 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 39 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 40 | detector_kwargs = dict(return_features=True) 41 | 42 | real_features = metric_utils.compute_feature_stats_for_dataset( 43 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 44 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 45 | 46 | gen_features = metric_utils.compute_feature_stats_for_generator( 47 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 48 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 49 | 50 | results = dict() 51 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 52 | kth = [] 53 | for manifold_batch in manifold.split(row_batch_size): 54 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 55 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 56 | kth = torch.cat(kth) if opts.rank == 0 else None 57 | pred = [] 58 | for probes_batch in probes.split(row_batch_size): 59 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 60 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 61 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 62 | return results['precision'], results['recall'] 63 | 64 | #---------------------------------------------------------------------------- 65 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 
5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/torch_utils (EDM) 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | # empty 13 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/torch_utils (EDM) 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | import os 13 | import torch 14 | from . import training_stats 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def init(): 19 | if 'MASTER_ADDR' not in os.environ: 20 | os.environ['MASTER_ADDR'] = 'localhost' 21 | if 'MASTER_PORT' not in os.environ: 22 | os.environ['MASTER_PORT'] = '29500' 23 | if 'RANK' not in os.environ: 24 | os.environ['RANK'] = '0' 25 | if 'LOCAL_RANK' not in os.environ: 26 | os.environ['LOCAL_RANK'] = '0' 27 | if 'WORLD_SIZE' not in os.environ: 28 | os.environ['WORLD_SIZE'] = '1' 29 | 30 | backend = 'gloo' if os.name == 'nt' else 'nccl' 31 | torch.distributed.init_process_group(backend=backend, init_method='env://') 32 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 33 | 34 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 35 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def get_rank(): 40 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def synchronize(): 45 | if not torch.distributed.is_available(): 46 | return 47 | if not torch.distributed.is_initialized(): 48 | return 49 | 50 | world_size = torch.distributed.get_world_size() 51 | if world_size == 1: 52 | return 53 | 54 | torch.distributed.barrier() 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | def get_world_size(): 59 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | def should_stop(): 64 | return False 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | def update_progress(cur, total): 69 | _ = cur, total 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | def print0(*args, **kwargs): 74 | if get_rank() == 0: 75 | print(*args, **kwargs) 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 
5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/torch_utils (EDM) 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | import re 13 | import contextlib 14 | import numpy as np 15 | import torch 16 | import warnings 17 | import dnnlib 18 | import matplotlib.pyplot as plt 19 | import scipy.stats as stats 20 | 21 | #---------------------------------------------------------------------------- 22 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 23 | # same constant is used multiple times. 24 | 25 | _constant_cache = dict() 26 | 27 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 28 | value = np.asarray(value) 29 | if shape is not None: 30 | shape = tuple(shape) 31 | if dtype is None: 32 | dtype = torch.get_default_dtype() 33 | if device is None: 34 | device = torch.device('cpu') 35 | if memory_format is None: 36 | memory_format = torch.contiguous_format 37 | 38 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 39 | tensor = _constant_cache.get(key, None) 40 | if tensor is None: 41 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 42 | if shape is not None: 43 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 44 | tensor = tensor.contiguous(memory_format=memory_format) 45 | _constant_cache[key] = tensor 46 | return tensor 47 | 48 | #---------------------------------------------------------------------------- 49 | # Replace NaN/Inf with specified numerical values. 50 | 51 | try: 52 | nan_to_num = torch.nan_to_num # 1.8.0a0 53 | except AttributeError: 54 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 55 | assert isinstance(input, torch.Tensor) 56 | if posinf is None: 57 | posinf = torch.finfo(input.dtype).max 58 | if neginf is None: 59 | neginf = torch.finfo(input.dtype).min 60 | assert nan == 0 61 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 62 | 63 | 64 | # Variant of constant() that inherits dtype and device from the given 65 | # reference tensor by default. 66 | 67 | def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None): 68 | if dtype is None: 69 | dtype = ref.dtype 70 | if device is None: 71 | device = ref.device 72 | return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format) 73 | #---------------------------------------------------------------------------- 74 | # Symbolic assert. 75 | 76 | try: 77 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 78 | except AttributeError: 79 | symbolic_assert = torch.Assert # 1.7.0 80 | 81 | #---------------------------------------------------------------------------- 82 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 83 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 84 | 85 | @contextlib.contextmanager 86 | def suppress_tracer_warnings(): 87 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 88 | warnings.filters.insert(0, flt) 89 | yield 90 | warnings.filters.remove(flt) 91 | 92 | #---------------------------------------------------------------------------- 93 | # Assert that the shape of a tensor matches the given list of integers. 94 | # None indicates that the size of a dimension is allowed to vary. 
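# For example, assert_shape(x, [None, 3, 32, 32]) accepts any batch size but
# requires each item to be a 3-channel 32x32 tensor.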
95 | # Performs symbolic assertion when used in torch.jit.trace(). 96 | 97 | def assert_shape(tensor, ref_shape): 98 | if tensor.ndim != len(ref_shape): 99 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 100 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 101 | if ref_size is None: 102 | pass 103 | elif isinstance(ref_size, torch.Tensor): 104 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 105 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 106 | elif isinstance(size, torch.Tensor): 107 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 108 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 109 | elif size != ref_size: 110 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 111 | 112 | #---------------------------------------------------------------------------- 113 | # Function decorator that calls torch.autograd.profiler.record_function(). 114 | 115 | def profiled_function(fn): 116 | def decorator(*args, **kwargs): 117 | with torch.autograd.profiler.record_function(fn.__name__): 118 | return fn(*args, **kwargs) 119 | decorator.__name__ = fn.__name__ 120 | return decorator 121 | 122 | #---------------------------------------------------------------------------- 123 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 124 | # indefinitely, shuffling items as it goes. 125 | 126 | class InfiniteSampler(torch.utils.data.Sampler): 127 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 128 | assert len(dataset) > 0 129 | assert num_replicas > 0 130 | assert 0 <= rank < num_replicas 131 | assert 0 <= window_size <= 1 132 | super().__init__(dataset) 133 | self.dataset = dataset 134 | self.rank = rank 135 | self.num_replicas = num_replicas 136 | self.shuffle = shuffle 137 | self.seed = seed 138 | self.window_size = window_size 139 | 140 | def __iter__(self): 141 | order = np.arange(len(self.dataset)) 142 | rnd = None 143 | window = 0 144 | if self.shuffle: 145 | rnd = np.random.RandomState(self.seed) 146 | rnd.shuffle(order) 147 | window = int(np.rint(order.size * self.window_size)) 148 | 149 | idx = 0 150 | while True: 151 | i = idx % order.size 152 | if idx % self.num_replicas == self.rank: 153 | yield order[i] 154 | if window >= 2: 155 | j = (i - rnd.randint(window)) % order.size 156 | order[i], order[j] = order[j], order[i] 157 | idx += 1 158 | 159 | #---------------------------------------------------------------------------- 160 | # Utilities for operating with torch.nn.Module parameters and buffers. 
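# A typical use, mirroring the example in torch_utils/persistence.py, is to
# copy the weights of an unpickled network into a freshly constructed one:
#
#   new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
#   copy_params_and_buffers(old_net, new_net, require_all=True)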
161 | 162 | def params_and_buffers(module): 163 | assert isinstance(module, torch.nn.Module) 164 | return list(module.parameters()) + list(module.buffers()) 165 | 166 | def named_params_and_buffers(module): 167 | assert isinstance(module, torch.nn.Module) 168 | return list(module.named_parameters()) + list(module.named_buffers()) 169 | 170 | @torch.no_grad() 171 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 172 | assert isinstance(src_module, torch.nn.Module) 173 | assert isinstance(dst_module, torch.nn.Module) 174 | src_tensors = dict(named_params_and_buffers(src_module)) 175 | for name, tensor in named_params_and_buffers(dst_module): 176 | assert (name in src_tensors) or (not require_all), f"Missing source tensor: {name}" 177 | if name in src_tensors: 178 | tensor.copy_(src_tensors[name]) 179 | 180 | #---------------------------------------------------------------------------- 181 | # Context manager for easily enabling/disabling DistributedDataParallel 182 | # synchronization. 183 | 184 | @contextlib.contextmanager 185 | def ddp_sync(module, sync): 186 | assert isinstance(module, torch.nn.Module) 187 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 188 | yield 189 | else: 190 | with module.no_sync(): 191 | yield 192 | 193 | #---------------------------------------------------------------------------- 194 | # Check DistributedDataParallel consistency across processes. 195 | 196 | def check_ddp_consistency(module, ignore_regex=None): 197 | assert isinstance(module, torch.nn.Module) 198 | for name, tensor in named_params_and_buffers(module): 199 | fullname = type(module).__name__ + '.' + name 200 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 201 | continue 202 | tensor = tensor.detach() 203 | if tensor.is_floating_point(): 204 | tensor = nan_to_num(tensor) 205 | other = tensor.clone() 206 | torch.distributed.broadcast(tensor=other, src=0) 207 | assert (tensor == other).all(), fullname 208 | 209 | #---------------------------------------------------------------------------- 210 | # Print summary table of module hierarchy. 211 | 212 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 213 | assert isinstance(module, torch.nn.Module) 214 | assert not isinstance(module, torch.jit.ScriptModule) 215 | assert isinstance(inputs, (tuple, list)) 216 | 217 | # Register hooks. 218 | entries = [] 219 | nesting = [0] 220 | def pre_hook(_mod, _inputs): 221 | nesting[0] += 1 222 | def post_hook(mod, _inputs, outputs): 223 | nesting[0] -= 1 224 | if nesting[0] <= max_nesting: 225 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 226 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 227 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 228 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 229 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 230 | 231 | # Run module. 232 | outputs = module(*inputs) 233 | for hook in hooks: 234 | hook.remove() 235 | 236 | # Identify unique outputs, parameters, and buffers. 
237 | tensors_seen = set() 238 | for e in entries: 239 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 240 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 241 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 242 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 243 | 244 | # Filter out redundant entries. 245 | if skip_redundant: 246 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 247 | 248 | # Construct table. 249 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 250 | rows += [['---'] * len(rows[0])] 251 | param_total = 0 252 | buffer_total = 0 253 | submodule_names = {mod: name for name, mod in module.named_modules()} 254 | for e in entries: 255 | name = '' if e.mod is module else submodule_names[e.mod] 256 | param_size = sum(t.numel() for t in e.unique_params) 257 | buffer_size = sum(t.numel() for t in e.unique_buffers) 258 | output_shapes = [str(list(t.shape)) for t in e.outputs] 259 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 260 | rows += [[ 261 | name + (':0' if len(e.outputs) >= 2 else ''), 262 | str(param_size) if param_size else '-', 263 | str(buffer_size) if buffer_size else '-', 264 | (output_shapes + ['-'])[0], 265 | (output_dtypes + ['-'])[0], 266 | ]] 267 | for idx in range(1, len(e.outputs)): 268 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 269 | param_total += param_size 270 | buffer_total += buffer_size 271 | rows += [['---'] * len(rows[0])] 272 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 273 | 274 | # Print table. 275 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 276 | print() 277 | for row in rows: 278 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 279 | print() 280 | return outputs 281 | 282 | #---------------------------------------------------------------------------- 283 | 284 | def get_t(i, sigma_max=80, sigma_min=0.002, rho=7): 285 | """ 286 | Calculate edm-style t given i in the range [0, 1]. get_t(1) = sigma_max, and get_t(0) = sigma_min 287 | """ 288 | t = (sigma_max ** (1 / rho) + (1 - i) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 289 | return t 290 | 291 | class LoggingT: 292 | def __init__(self, n_bins=100, alpha=1.): 293 | self.n_bins = n_bins 294 | self.alpha = alpha # Moving average rate 295 | self.bin_edges = [get_t(i / n_bins) for i in range(n_bins + 1)] # Store bin edges 296 | self.bin_edges[-1] = 100 297 | self.bin_edges[0] = 0. 
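        # Widen the outermost edges beyond [sigma_min, sigma_max] so that every
        # logged t value falls strictly inside some bin (get_bin_idx() raises
        # otherwise).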
298 | self.bin_edges = np.array(self.bin_edges) 299 | self.log_bin = np.zeros(n_bins) # Initialize moving average for the quantity 300 | self.iteration = 0 301 | 302 | def get_bin_idx(self, t_values, values): 303 | # If torch tensors are used, convert them to numpy arrays 304 | if isinstance(t_values, torch.Tensor): 305 | t_values = t_values.detach().cpu().numpy() 306 | if isinstance(values, torch.Tensor): 307 | values = values.detach().cpu().numpy() 308 | 309 | # Check if any t_values are outside the bounds and print them if they are 310 | if not (np.all(t_values > self.bin_edges[0]) and np.all(t_values < self.bin_edges[-1])): 311 | # Identify values outside the bounds 312 | out_of_bounds_values = t_values[(t_values < self.bin_edges[0]) | (t_values > self.bin_edges[-1])] 313 | raise ValueError(f"t_values contain elements outside the defined range of bin edges: {out_of_bounds_values}") 314 | bin_idx = np.digitize(t_values, self.bin_edges[1:]) # Determine bin indices 315 | return bin_idx 316 | 317 | def update(self, t_values, values): 318 | # If torch tensors are used, convert them to numpy arrays 319 | if isinstance(t_values, torch.Tensor): 320 | t_values = t_values.detach().cpu().numpy() 321 | if isinstance(values, torch.Tensor): 322 | values = values.detach().cpu().numpy() 323 | 324 | bin_idx = self.get_bin_idx(t_values, values) 325 | # Apply updates using np.where 326 | 327 | sums = np.bincount(bin_idx, weights=values, minlength=self.n_bins) 328 | counts = np.bincount(bin_idx, minlength=self.n_bins) 329 | 330 | avg_values = np.divide(sums, counts, out=np.zeros_like(sums), where=counts != 0) 331 | 332 | # Apply updates to the indices where counts are non-zero 333 | update_indices = np.where(counts != 0) 334 | self.log_bin[update_indices] = (1 - self.alpha) * self.log_bin[update_indices] + self.alpha * avg_values[update_indices] 335 | self.iteration += 1 336 | 337 | def plot(self, save_path): 338 | """ 339 | Plot the logged quantity as a histogram. 340 | Parameters: 341 | - save_path: Path to save the plot. 
342 | """ 343 | plt.figure(figsize=(12, 6)) # Increase figure size for better visibility 344 | 345 | # Create a histogram plot 346 | plt.hist(self.bin_edges[:-1], bins=self.bin_edges, weights=self.log_bin, edgecolor='black', alpha=0.7) 347 | 348 | plt.xlabel('t') 349 | # Set x-ticks to bin edges and format the labels 350 | plt.xscale('log') 351 | 352 | plt.ylabel('Loss') 353 | plt.grid(True, linestyle='--', alpha=0.7) 354 | plt.xticks(rotation=45) # Rotate x-axis labels for better readability 355 | plt.ticklabel_format(style='plain', axis='y', scilimits=(0, 0)) 356 | 357 | plt.tight_layout() 358 | plt.savefig(save_path, dpi=100) 359 | plt.close() 360 | def warm_up(self): 361 | return self.iteration < 100 362 | 363 | 364 | 365 | def get_log(self): 366 | return self.log_bin 367 | 368 | 369 | def get_edm_cout(sigma, sigma_data=0.5): 370 | return sigma * sigma_data / (sigma ** 2 + sigma_data ** 2).sqrt() 371 | 372 | def truncated_normal(num_samples, mu, sigma, lower, upper): 373 | # Sample from N(mu, sigma^2) truncated to the range [lower, upper] 374 | a = (lower - mu) / sigma 375 | b = (upper - mu) / sigma 376 | x = stats.truncnorm.rvs(a, b, loc=mu, scale=sigma, size=num_samples) 377 | return x 378 | 379 | def truncated_t(num_samples, mu, sigma, lower, upper, df=2): 380 | # Sample from truncated student_t 381 | lower_adjusted = (lower - mu) / sigma 382 | upper_adjusted = (upper - mu) / sigma 383 | 384 | # Get the CDF values for the adjusted bounds 385 | a = stats.t.cdf(lower_adjusted, df) 386 | b = stats.t.cdf(upper_adjusted, df) 387 | 388 | # Sample from the uniform distribution within the adjusted CDF bounds 389 | u = np.random.uniform(a, b, num_samples) 390 | 391 | # Get the PPF and then adjust it to match the desired mean and scale 392 | x = stats.t.ppf(u, df) * sigma + mu 393 | return x 394 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/torch_utils (EDM) 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Facilities for pickling Python code alongside other data. 13 | 14 | The pickled code is automatically imported into a separate Python module 15 | during unpickling. This way, any previously exported pickles will remain 16 | usable even if the original code is no longer available, or if the current 17 | version of the code is not consistent with what was originally pickled.""" 18 | 19 | import sys 20 | import pickle 21 | import io 22 | import inspect 23 | import copy 24 | import uuid 25 | import types 26 | import dnnlib 27 | 28 | #---------------------------------------------------------------------------- 29 | 30 | _version = 6 # internal version number 31 | _decorators = set() # {decorator_class, ...} 32 | _import_hooks = [] # [hook_function, ...] 33 | _module_to_src_dict = dict() # {module: src, ...} 34 | _src_to_module_dict = dict() # {src: module, ...} 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | def persistent_class(orig_class): 39 | r"""Class decorator that extends a given class to save its source code 40 | when pickled. 
41 | 42 | Example: 43 | 44 | from torch_utils import persistence 45 | 46 | @persistence.persistent_class 47 | class MyNetwork(torch.nn.Module): 48 | def __init__(self, num_inputs, num_outputs): 49 | super().__init__() 50 | self.fc = MyLayer(num_inputs, num_outputs) 51 | ... 52 | 53 | @persistence.persistent_class 54 | class MyLayer(torch.nn.Module): 55 | ... 56 | 57 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 58 | source code alongside other internal state (e.g., parameters, buffers, 59 | and submodules). This way, any previously exported pickle will remain 60 | usable even if the class definitions have been modified or are no 61 | longer available. 62 | 63 | The decorator saves the source code of the entire Python module 64 | containing the decorated class. It does *not* save the source code of 65 | any imported modules. Thus, the imported modules must be available 66 | during unpickling, also including `torch_utils.persistence` itself. 67 | 68 | It is ok to call functions defined in the same module from the 69 | decorated class. However, if the decorated class depends on other 70 | classes defined in the same module, they must be decorated as well. 71 | This is illustrated in the above example in the case of `MyLayer`. 72 | 73 | It is also possible to employ the decorator just-in-time before 74 | calling the constructor. For example: 75 | 76 | cls = MyLayer 77 | if want_to_make_it_persistent: 78 | cls = persistence.persistent_class(cls) 79 | layer = cls(num_inputs, num_outputs) 80 | 81 | As an additional feature, the decorator also keeps track of the 82 | arguments that were used to construct each instance of the decorated 83 | class. The arguments can be queried via `obj.init_args` and 84 | `obj.init_kwargs`, and they are automatically pickled alongside other 85 | object state. This feature can be disabled on a per-instance basis 86 | by setting `self._record_init_args = False` in the constructor. 
87 | 
88 |     A typical use case is to first unpickle a previous instance of a
89 |     persistent class, and then upgrade it to use the latest version of
90 |     the source code:
91 | 
92 |         with open('old_pickle.pkl', 'rb') as f:
93 |             old_net = pickle.load(f)
94 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
95 |         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
96 |     """
97 |     assert isinstance(orig_class, type)
98 |     if is_persistent(orig_class):
99 |         return orig_class
100 | 
101 |     assert orig_class.__module__ in sys.modules
102 |     orig_module = sys.modules[orig_class.__module__]
103 |     orig_module_src = _module_to_src(orig_module)
104 | 
105 |     class Decorator(orig_class):
106 |         _orig_module_src = orig_module_src
107 |         _orig_class_name = orig_class.__name__
108 | 
109 |         def __init__(self, *args, **kwargs):
110 |             super().__init__(*args, **kwargs)
111 |             record_init_args = getattr(self, '_record_init_args', True)
112 |             self._init_args = copy.deepcopy(args) if record_init_args else None
113 |             self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
114 |             assert orig_class.__name__ in orig_module.__dict__
115 |             _check_pickleable(self.__reduce__())
116 | 
117 |         @property
118 |         def init_args(self):
119 |             assert self._init_args is not None
120 |             return copy.deepcopy(self._init_args)
121 | 
122 |         @property
123 |         def init_kwargs(self):
124 |             assert self._init_kwargs is not None
125 |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
126 | 
127 |         def __reduce__(self):
128 |             fields = list(super().__reduce__())
129 |             fields += [None] * max(3 - len(fields), 0)
130 |             if fields[0] is not _reconstruct_persistent_obj:
131 |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
132 |                 fields[0] = _reconstruct_persistent_obj # reconstruct func
133 |                 fields[1] = (meta,) # reconstruct args
134 |                 fields[2] = None # state dict
135 |             return tuple(fields)
136 | 
137 |     Decorator.__name__ = orig_class.__name__
138 |     Decorator.__module__ = orig_class.__module__
139 |     _decorators.add(Decorator)
140 |     return Decorator
141 | 
142 | #----------------------------------------------------------------------------
143 | 
144 | def is_persistent(obj):
145 |     r"""Test whether the given object or class is persistent, i.e.,
146 |     whether it will save its source code when pickled.
147 |     """
148 |     try:
149 |         if obj in _decorators:
150 |             return True
151 |     except TypeError:
152 |         pass
153 |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
154 | 
155 | #----------------------------------------------------------------------------
156 | 
157 | def import_hook(hook):
158 |     r"""Register an import hook that is called whenever a persistent object
159 |     is being unpickled. A typical use case is to patch the pickled source
160 |     code to avoid errors and inconsistencies when the API of some imported
161 |     module has changed.
162 | 
163 |     The hook should have the following signature:
164 | 
165 |         hook(meta) -> modified meta
166 | 
167 |     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
168 | 
169 |         type:        Type of the persistent object, e.g. `'class'`.
170 |         version:     Internal version number of `torch_utils.persistence`.
171 |         module_src:  Original source code of the Python module.
172 |         class_name:  Class name in the original Python module.
173 |         state:       Internal state of the object.
174 | 175 | Example: 176 | 177 | @persistence.import_hook 178 | def wreck_my_network(meta): 179 | if meta.class_name == 'MyNetwork': 180 | print('MyNetwork is being imported. I will wreck it!') 181 | meta.module_src = meta.module_src.replace("True", "False") 182 | return meta 183 | """ 184 | assert callable(hook) 185 | _import_hooks.append(hook) 186 | 187 | #---------------------------------------------------------------------------- 188 | 189 | def _reconstruct_persistent_obj(meta): 190 | r"""Hook that is called internally by the `pickle` module to unpickle 191 | a persistent object. 192 | """ 193 | meta = dnnlib.EasyDict(meta) 194 | meta.state = dnnlib.EasyDict(meta.state) 195 | for hook in _import_hooks: 196 | meta = hook(meta) 197 | assert meta is not None 198 | 199 | assert meta.version == _version 200 | module = _src_to_module(meta.module_src) 201 | 202 | assert meta.type == 'class' 203 | orig_class = module.__dict__[meta.class_name] 204 | decorator_class = persistent_class(orig_class) 205 | obj = decorator_class.__new__(decorator_class) 206 | 207 | setstate = getattr(obj, '__setstate__', None) 208 | if callable(setstate): 209 | setstate(meta.state) # pylint: disable=not-callable 210 | else: 211 | obj.__dict__.update(meta.state) 212 | return obj 213 | 214 | #---------------------------------------------------------------------------- 215 | 216 | def _module_to_src(module): 217 | r"""Query the source code of a given Python module. 218 | """ 219 | src = _module_to_src_dict.get(module, None) 220 | if src is None: 221 | src = inspect.getsource(module) 222 | _module_to_src_dict[module] = src 223 | _src_to_module_dict[src] = module 224 | return src 225 | 226 | def _src_to_module(src): 227 | r"""Get or create a Python module for the given source code. 228 | """ 229 | module = _src_to_module_dict.get(src, None) 230 | if module is None: 231 | module_name = "_imported_module_" + uuid.uuid4().hex 232 | module = types.ModuleType(module_name) 233 | sys.modules[module_name] = module 234 | _module_to_src_dict[module] = src 235 | _src_to_module_dict[src] = module 236 | exec(src, module.__dict__) # pylint: disable=exec-used 237 | return module 238 | 239 | #---------------------------------------------------------------------------- 240 | 241 | def _check_pickleable(obj): 242 | r"""Check that the given object is pickleable, raising an exception if 243 | it is not. This function is expected to be considerably more efficient 244 | than actually pickling the object. 245 | """ 246 | def recurse(obj): 247 | if isinstance(obj, (list, tuple, set)): 248 | return [recurse(x) for x in obj] 249 | if isinstance(obj, dict): 250 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 251 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 252 | return None # Python primitive types are pickleable. 253 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 254 | return None # NumPy arrays and PyTorch tensors are pickleable. 255 | if is_persistent(obj): 256 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
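            # Anything unrecognized is returned as-is and will be exercised by
            # the pickle.dump() call below.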
257 | return obj 258 | with io.BytesIO() as f: 259 | pickle.dump(recurse(obj), f) 260 | 261 | #---------------------------------------------------------------------------- 262 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/torch_utils (EDM) 8 | # 9 | # The license for these can be found in license/ directory. 10 | # --------------------------------------------------------------- 11 | 12 | """Facilities for reporting and collecting training statistics across 13 | multiple processes and devices. The interface is designed to minimize 14 | synchronization overhead as well as the amount of boilerplate in user 15 | code.""" 16 | 17 | import re 18 | import numpy as np 19 | import torch 20 | import dnnlib 21 | 22 | from . import misc 23 | 24 | # ---------------------------------------------------------------------------- 25 | 26 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 27 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 28 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 29 | _rank = 0 # Rank of the current process. 30 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 31 | _sync_called = False # Has _sync() been called yet? 32 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 33 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 34 | 35 | 36 | # ---------------------------------------------------------------------------- 37 | 38 | def init_multiprocessing(rank, sync_device): 39 | r"""Initializes `torch_utils.training_stats` for collecting statistics 40 | across multiple processes. 41 | 42 | This function must be called after 43 | `torch.distributed.init_process_group()` and before `Collector.update()`. 44 | The call is not necessary if multi-process collection is not needed. 45 | 46 | Args: 47 | rank: Rank of the current process. 48 | sync_device: PyTorch device to use for inter-process 49 | communication, or None to disable multi-process 50 | collection. Typically `torch.device('cuda', rank)`. 51 | """ 52 | global _rank, _sync_device 53 | assert not _sync_called 54 | _rank = rank 55 | _sync_device = sync_device 56 | 57 | 58 | # ---------------------------------------------------------------------------- 59 | 60 | @misc.profiled_function 61 | def report(name, value): 62 | r"""Broadcasts the given set of scalars to all interested instances of 63 | `Collector`, across device and process boundaries. 64 | 65 | This function is expected to be extremely cheap and can be safely 66 | called from anywhere in the training loop, loss function, or inside a 67 | `torch.nn.Module`. 68 | 69 | Warning: The current implementation expects the set of unique names to 70 | be consistent across processes. Please make sure that `report()` is 71 | called at least once for each unique name by each process, and in the 72 | same order. If a given process has no scalars to broadcast, it can do 73 | `report(name, [])` (empty list). 
74 | 75 | Args: 76 | name: Arbitrary string specifying the name of the statistic. 77 | Averages are accumulated separately for each unique name. 78 | value: Arbitrary set of scalars. Can be a list, tuple, 79 | NumPy array, PyTorch tensor, or Python scalar. 80 | 81 | Returns: 82 | The same `value` that was passed in. 83 | """ 84 | if name not in _counters: 85 | _counters[name] = dict() 86 | 87 | elems = torch.as_tensor(value) 88 | if elems.numel() == 0: 89 | return value 90 | 91 | elems = elems.detach().flatten().to(_reduce_dtype) 92 | moments = torch.stack([ 93 | torch.ones_like(elems).sum(), 94 | elems.sum(), 95 | elems.square().sum(), 96 | ]) 97 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 98 | moments = moments.to(_counter_dtype) 99 | 100 | device = moments.device 101 | if device not in _counters[name]: 102 | _counters[name][device] = torch.zeros_like(moments) 103 | _counters[name][device].add_(moments) 104 | return value 105 | 106 | 107 | # ---------------------------------------------------------------------------- 108 | 109 | def report0(name, value): 110 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 111 | but ignores any scalars provided by the other processes. 112 | See `report()` for further details. 113 | """ 114 | report(name, value if _rank == 0 else []) 115 | return value 116 | 117 | 118 | # ---------------------------------------------------------------------------- 119 | 120 | class Collector: 121 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 122 | computes their long-term averages (mean and standard deviation) over 123 | user-defined periods of time. 124 | 125 | The averages are first collected into internal counters that are not 126 | directly visible to the user. They are then copied to the user-visible 127 | state as a result of calling `update()` and can then be queried using 128 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 129 | internal counters for the next round, so that the user-visible state 130 | effectively reflects averages collected between the last two calls to 131 | `update()`. 132 | 133 | Args: 134 | regex: Regular expression defining which statistics to 135 | collect. The default is to collect everything. 136 | keep_previous: Whether to retain the previous averages if no 137 | scalars were collected on a given round 138 | (default: True). 139 | """ 140 | 141 | def __init__(self, regex='.*', keep_previous=True): 142 | self._regex = re.compile(regex) 143 | self._keep_previous = keep_previous 144 | self._cumulative = dict() 145 | self._moments = dict() 146 | self.update() 147 | self._moments.clear() 148 | 149 | def names(self): 150 | r"""Returns the names of all statistics broadcasted so far that 151 | match the regular expression specified at construction time. 152 | """ 153 | return [name for name in _counters if self._regex.fullmatch(name)] 154 | 155 | def update(self): 156 | r"""Copies current values of the internal counters to the 157 | user-visible state and resets them for the next round. 158 | 159 | If `keep_previous=True` was specified at construction time, the 160 | operation is skipped for statistics that have received no scalars 161 | since the last update, retaining their previous averages. 162 | 163 | This method performs a number of GPU-to-CPU transfers and one 164 | `torch.distributed.all_reduce()`. It is intended to be called 165 | periodically in the main training loop, typically once every 166 | N training steps. 
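
        A minimal sketch of the intended pattern (names are illustrative):

            collector = training_stats.Collector(regex='Loss/.*')
            for step in range(num_steps):
                loss = training_step()                    # some scalar/tensor
                training_stats.report('Loss/loss', loss)
                if step % 100 == 0:
                    collector.update()
                    print(collector.mean('Loss/loss'))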
167 |         """
168 |         if not self._keep_previous:
169 |             self._moments.clear()
170 |         for name, cumulative in _sync(self.names()):
171 |             if name not in self._cumulative:
172 |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
173 |             delta = cumulative - self._cumulative[name]
174 |             self._cumulative[name].copy_(cumulative)
175 |             if float(delta[0]) != 0:
176 |                 self._moments[name] = delta
177 | 
178 |     def _get_delta(self, name):
179 |         r"""Returns the raw moments that were accumulated for the given
180 |         statistic between the last two calls to `update()`, or zero if
181 |         no scalars were collected.
182 |         """
183 |         assert self._regex.fullmatch(name)
184 |         if name not in self._moments:
185 |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
186 |         return self._moments[name]
187 | 
188 |     def num(self, name):
189 |         r"""Returns the number of scalars that were accumulated for the given
190 |         statistic between the last two calls to `update()`, or zero if
191 |         no scalars were collected.
192 |         """
193 |         delta = self._get_delta(name)
194 |         return int(delta[0])
195 | 
196 |     def mean(self, name):
197 |         r"""Returns the mean of the scalars that were accumulated for the
198 |         given statistic between the last two calls to `update()`, or NaN if
199 |         no scalars were collected.
200 |         """
201 |         delta = self._get_delta(name)
202 |         if int(delta[0]) == 0:
203 |             return float('nan')
204 |         return float(delta[1] / delta[0])
205 | 
206 |     def std(self, name):
207 |         r"""Returns the standard deviation of the scalars that were
208 |         accumulated for the given statistic between the last two calls to
209 |         `update()`, or NaN if no scalars were collected.
210 |         """
211 |         delta = self._get_delta(name)
212 |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
213 |             return float('nan')
214 |         if int(delta[0]) == 1:
215 |             return float(0)
216 |         mean = float(delta[1] / delta[0])
217 |         raw_var = float(delta[2] / delta[0])
218 |         return np.sqrt(max(raw_var - np.square(mean), 0))
219 | 
220 |     def as_dict(self):
221 |         r"""Returns the averages accumulated between the last two calls to
222 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
223 | 
224 |             dnnlib.EasyDict(
225 |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
226 |                 ...
227 |             )
228 |         """
229 |         stats = dnnlib.EasyDict()
230 |         for name in self.names():
231 |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
232 |         return stats
233 | 
234 |     def __getitem__(self, name):
235 |         r"""Convenience getter.
236 |         `collector[name]` is a synonym for `collector.mean(name)`.
237 |         """
238 |         return self.mean(name)
239 | 
240 | 
241 | # ----------------------------------------------------------------------------
242 | 
243 | def _sync(names):
244 |     r"""Synchronize the global cumulative counters across devices and
245 |     processes. Called internally by `Collector.update()`.
246 |     """
247 |     if len(names) == 0:
248 |         return []
249 |     global _sync_called
250 |     _sync_called = True
251 | 
252 |     # Collect deltas within current rank.
253 |     deltas = []
254 |     device = _sync_device if _sync_device is not None else torch.device('cpu')
255 |     for name in names:
256 |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
257 |         for counter in _counters[name].values():
258 |             delta.add_(counter.to(device))
259 |             counter.copy_(torch.zeros_like(counter))
260 |         deltas.append(delta)
261 |     deltas = torch.stack(deltas)
262 | 
263 |     # Sum deltas across ranks.
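    # (the all_reduce sums the per-name [num, sum, sum_of_squares] vectors
    # elementwise, so every rank ends up with identical cumulative totals.)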
264 | if _sync_device is not None: 265 | torch.distributed.all_reduce(deltas) 266 | 267 | # Update cumulative values. 268 | deltas = deltas.cpu() 269 | for idx, name in enumerate(names): 270 | if name not in _cumulative: 271 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 272 | _cumulative[name].add_(deltas[idx]) 273 | 274 | # Return name-value pairs. 275 | return [(name, _cumulative[name]) for name in names] 276 | 277 | 278 | # ---------------------------------------------------------------------------- 279 | # Convenience. 280 | 281 | default_collector = Collector() 282 | 283 | # ---------------------------------------------------------------------------- 284 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ECT codebase, which built upon EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/training/__init__.py (EDM) 8 | # https://github.com/locuslab/ect/blob/main/training/__init__.py (ECT) 9 | # 10 | # The license for these can be found in license/ directory. 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | # empty 15 | -------------------------------------------------------------------------------- /training/augment.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ECT codebase, which built upon EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/training/augment.py (EDM) 8 | # https://github.com/locuslab/ect/blob/main/training/augment.py (ECT) 9 | # 10 | # The license for these can be found in license/ directory. 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | """Augmentation pipeline used in the paper 15 | "Elucidating the Design Space of Diffusion-Based Generative Models". 16 | Built around the same concepts that were originally proposed in the paper 17 | "Training Generative Adversarial Networks with Limited Data".""" 18 | 19 | import numpy as np 20 | import torch 21 | from torch_utils import persistence 22 | from torch_utils import misc 23 | 24 | #---------------------------------------------------------------------------- 25 | # Coefficients of various wavelet decomposition low-pass filters. 
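# These are standard orthogonal decomposition low-pass taps (matching, e.g.,
# what pywt.Wavelet('sym6').dec_lo returns); 'sym6' is the filter actually
# used below for the anti-aliased geometric warp.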
26 | 27 | wavelets = { 28 | 'haar': [0.7071067811865476, 0.7071067811865476], 29 | 'db1': [0.7071067811865476, 0.7071067811865476], 30 | 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 31 | 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 32 | 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], 33 | 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], 34 | 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], 35 | 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], 36 | 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], 37 | 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], 38 | 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], 39 | 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], 40 | 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], 41 | 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], 42 | 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], 43 | 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], 44 | } 45 | 46 | #---------------------------------------------------------------------------- 47 | # Helpers for constructing transformation matrices. 
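# All helpers below build homogeneous-coordinate matrices: 3x3 for 2D image
# geometry, 4x4 for 3D/color. Passing tensors as the free parameters yields a
# batched [N, rows, cols] stack, so per-sample random augmentations compose
# with plain matrix products (see G_inv and M in AugmentPipe.__call__).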
48 | 49 | def matrix(*rows, device=None): 50 | assert all(len(row) == len(rows[0]) for row in rows) 51 | elems = [x for row in rows for x in row] 52 | ref = [x for x in elems if isinstance(x, torch.Tensor)] 53 | if len(ref) == 0: 54 | return misc.constant(np.asarray(rows), device=device) 55 | assert device is None or device == ref[0].device 56 | elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] 57 | return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) 58 | 59 | def translate2d(tx, ty, **kwargs): 60 | return matrix( 61 | [1, 0, tx], 62 | [0, 1, ty], 63 | [0, 0, 1], 64 | **kwargs) 65 | 66 | def translate3d(tx, ty, tz, **kwargs): 67 | return matrix( 68 | [1, 0, 0, tx], 69 | [0, 1, 0, ty], 70 | [0, 0, 1, tz], 71 | [0, 0, 0, 1], 72 | **kwargs) 73 | 74 | def scale2d(sx, sy, **kwargs): 75 | return matrix( 76 | [sx, 0, 0], 77 | [0, sy, 0], 78 | [0, 0, 1], 79 | **kwargs) 80 | 81 | def scale3d(sx, sy, sz, **kwargs): 82 | return matrix( 83 | [sx, 0, 0, 0], 84 | [0, sy, 0, 0], 85 | [0, 0, sz, 0], 86 | [0, 0, 0, 1], 87 | **kwargs) 88 | 89 | def rotate2d(theta, **kwargs): 90 | return matrix( 91 | [torch.cos(theta), torch.sin(-theta), 0], 92 | [torch.sin(theta), torch.cos(theta), 0], 93 | [0, 0, 1], 94 | **kwargs) 95 | 96 | def rotate3d(v, theta, **kwargs): 97 | vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] 98 | s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c 99 | return matrix( 100 | [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], 101 | [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], 102 | [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], 103 | [0, 0, 0, 1], 104 | **kwargs) 105 | 106 | def translate2d_inv(tx, ty, **kwargs): 107 | return translate2d(-tx, -ty, **kwargs) 108 | 109 | def scale2d_inv(sx, sy, **kwargs): 110 | return scale2d(1 / sx, 1 / sy, **kwargs) 111 | 112 | def rotate2d_inv(theta, **kwargs): 113 | return rotate2d(-theta, **kwargs) 114 | 115 | #---------------------------------------------------------------------------- 116 | # Augmentation pipeline main class. 117 | # All augmentations are disabled by default; individual augmentations can 118 | # be enabled by setting their probability multipliers to 1. 119 | 120 | @persistence.persistent_class 121 | class AugmentPipe: 122 | def __init__(self, p=1, 123 | xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125, 124 | scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125, 125 | brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, 126 | ): 127 | super().__init__() 128 | self.p = float(p) # Overall multiplier for augmentation probability. 129 | 130 | # Pixel blitting. 131 | self.xflip = float(xflip) # Probability multiplier for x-flip. 132 | self.yflip = float(yflip) # Probability multiplier for y-flip. 133 | self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation. 134 | self.translate_int = float(translate_int) # Probability multiplier for integer translation. 135 | self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions. 136 | 137 | # Geometric transformations. 138 | self.scale = float(scale) # Probability multiplier for isotropic scaling. 139 | self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation. 
140 |         self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
141 |         self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
142 |         self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
143 |         self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
144 |         self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
145 |         self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
146 |         self.translate_frac_std = float(translate_frac_std) # Standard deviation of fractional translation, relative to image dimensions.
147 | 
148 |         # Color transformations.
149 |         self.brightness = float(brightness) # Probability multiplier for brightness.
150 |         self.contrast = float(contrast) # Probability multiplier for contrast.
151 |         self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
152 |         self.hue = float(hue) # Probability multiplier for hue rotation.
153 |         self.saturation = float(saturation) # Probability multiplier for saturation.
154 |         self.brightness_std = float(brightness_std) # Standard deviation of brightness.
155 |         self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
156 |         self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
157 |         self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
158 | 
159 |     def __call__(self, images):
160 |         N, C, H, W = images.shape
161 |         device = images.device
162 |         labels = [torch.zeros([images.shape[0], 0], device=device)]
163 | 
164 |         # ---------------
165 |         # Pixel blitting.
166 |         # ---------------
167 | 
168 |         if self.xflip > 0:
169 |             w = torch.randint(2, [N, 1, 1, 1], device=device)
170 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
171 |             images = torch.where(w == 1, images.flip(3), images)
172 |             labels += [w]
173 | 
174 |         if self.yflip > 0:
175 |             w = torch.randint(2, [N, 1, 1, 1], device=device)
176 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
177 |             images = torch.where(w == 1, images.flip(2), images)
178 |             labels += [w]
179 | 
180 |         if self.rotate_int > 0:
181 |             w = torch.randint(4, [N, 1, 1, 1], device=device)
182 |             w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
183 |             images = torch.where((w == 1) | (w == 2), images.flip(3), images)
184 |             images = torch.where((w == 2) | (w == 3), images.flip(2), images)
185 |             images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
186 |             labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]
187 | 
188 |         if self.translate_int > 0:
189 |             w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
190 |             w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w))
191 |             tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
192 |             ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
193 |             b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij')
194 |             x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
195 |             y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
196 |             images = images.flatten()[(((b * C) + c) * H + y) * W + x]
197 |             labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]
198 | 
199 |         # 
------------------------------------------------ 200 | # Select parameters for geometric transformations. 201 | # ------------------------------------------------ 202 | 203 | I_3 = torch.eye(3, device=device) 204 | G_inv = I_3 205 | 206 | if self.scale > 0: 207 | w = torch.randn([N], device=device) 208 | w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w)) 209 | s = w.mul(self.scale_std).exp2() 210 | G_inv = G_inv @ scale2d_inv(s, s) 211 | labels += [w] 212 | 213 | if self.rotate_frac > 0: 214 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max) 215 | w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w)) 216 | G_inv = G_inv @ rotate2d_inv(-w) 217 | labels += [w.cos() - 1, w.sin()] 218 | 219 | if self.aniso > 0: 220 | w = torch.randn([N], device=device) 221 | r = (torch.rand([N], device=device) * 2 - 1) * np.pi 222 | w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w)) 223 | r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r)) 224 | s = w.mul(self.aniso_std).exp2() 225 | G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r) 226 | labels += [w * r.cos(), w * r.sin()] 227 | 228 | if self.translate_frac > 0: 229 | w = torch.randn([2, N], device=device) 230 | w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w)) 231 | G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std)) 232 | labels += [w[0], w[1]] 233 | 234 | # ---------------------------------- 235 | # Execute geometric transformations. 236 | # ---------------------------------- 237 | 238 | if G_inv is not I_3: 239 | cx = (W - 1) / 2 240 | cy = (H - 1) / 2 241 | cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] 242 | cp = G_inv @ cp.t() # [batch, xyz, idx] 243 | Hz = np.asarray(wavelets['sym6'], dtype=np.float32) 244 | Hz_pad = len(Hz) // 4 245 | margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] 246 | margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] 247 | margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) 248 | margin = margin.max(misc.constant([0, 0] * 2, device=device)) 249 | margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device)) 250 | mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) 251 | 252 | # Pad image and adjust origin. 253 | images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') 254 | G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv 255 | 256 | # Upsample. 
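            # (zero-stuffing 2x upsample followed by the separable low-pass
            # filter Hz, so the affine grid_sample below resamples an
            # anti-aliased, higher-resolution copy of the batch.)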
257 | conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 258 | conv_pad = (len(Hz) + 1) // 2 259 | images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1] 260 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad]) 261 | images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :] 262 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0]) 263 | G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) 264 | G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) 265 | 266 | # Execute transformation. 267 | shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2] 268 | G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) 269 | grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) 270 | images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False) 271 | 272 | # Downsample and crop. 273 | conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) 274 | conv_pad = (len(Hz) - 1) // 2 275 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad] 276 | images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :] 277 | 278 | # -------------------------------------------- 279 | # Select parameters for color transformations. 
280 | # -------------------------------------------- 281 | 282 | I_4 = torch.eye(4, device=device) 283 | M = I_4 284 | luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) 285 | 286 | if self.brightness > 0: 287 | w = torch.randn([N], device=device) 288 | w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w)) 289 | b = w * self.brightness_std 290 | M = translate3d(b, b, b) @ M 291 | labels += [w] 292 | 293 | if self.contrast > 0: 294 | w = torch.randn([N], device=device) 295 | w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w)) 296 | c = w.mul(self.contrast_std).exp2() 297 | M = scale3d(c, c, c) @ M 298 | labels += [w] 299 | 300 | if self.lumaflip > 0: 301 | w = torch.randint(2, [N, 1, 1], device=device) 302 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w)) 303 | M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M 304 | labels += [w] 305 | 306 | if self.hue > 0: 307 | w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max) 308 | w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w)) 309 | M = rotate3d(luma_axis, w) @ M 310 | labels += [w.cos() - 1, w.sin()] 311 | 312 | if self.saturation > 0: 313 | w = torch.randn([N, 1, 1], device=device) 314 | w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w)) 315 | M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M 316 | labels += [w] 317 | 318 | # ------------------------------ 319 | # Execute color transformations. 320 | # ------------------------------ 321 | 322 | if M is not I_4: 323 | images = images.reshape([N, C, H * W]) 324 | if C == 3: 325 | images = M[:, :3, :3] @ images + M[:, :3, 3:] 326 | elif C == 1: 327 | M = M[:, :3, :].mean(dim=1, keepdims=True) 328 | images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:] 329 | else: 330 | raise ValueError('Image must be RGB (3 channels) or L (1 channel)') 331 | images = images.reshape([N, C, H, W]) 332 | 333 | labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1) 334 | return images, labels 335 | 336 | #---------------------------------------------------------------------------- 337 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been taken from ECT codebase, which built upon EDM. 5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/training/dataset.py (EDM) 8 | # https://github.com/locuslab/ect/blob/main/training/dataset.py (ECT) 9 | # 10 | # The license for these can be found in license/ directory. 11 | # The modifications to this file are subject to the same license. 
12 | # --------------------------------------------------------------- 13 | 14 | """Streaming images and labels from datasets created with dataset_tool.py.""" 15 | 16 | import os 17 | import numpy as np 18 | import zipfile 19 | import PIL.Image 20 | import json 21 | import torch 22 | import dnnlib 23 | 24 | try: 25 | import pyspng 26 | except ImportError: 27 | pyspng = None 28 | 29 | #---------------------------------------------------------------------------- 30 | # Abstract base class for datasets. 31 | 32 | class Dataset(torch.utils.data.Dataset): 33 | def __init__(self, 34 | name, # Name of the dataset. 35 | raw_shape, # Shape of the raw image data (NCHW). 36 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 37 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 38 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 39 | random_seed = 0, # Random seed to use when applying max_size. 40 | cache = False, # Cache images in CPU memory? 41 | ): 42 | self._name = name 43 | self._raw_shape = list(raw_shape) 44 | self._use_labels = use_labels 45 | self._cache = cache 46 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 47 | self._raw_labels = None 48 | self._label_shape = None 49 | 50 | # Apply max_size. 51 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 52 | if (max_size is not None) and (self._raw_idx.size > max_size): 53 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 54 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 55 | 56 | # Apply xflip. 57 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 58 | if xflip: 59 | self._raw_idx = np.tile(self._raw_idx, 2) 60 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 61 | 62 | def _get_raw_labels(self): 63 | if self._raw_labels is None: 64 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 65 | if self._raw_labels is None: 66 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 67 | assert isinstance(self._raw_labels, np.ndarray) 68 | assert self._raw_labels.shape[0] == self._raw_shape[0] 69 | assert self._raw_labels.dtype in [np.float32, np.int64] 70 | if self._raw_labels.dtype == np.int64: 71 | assert self._raw_labels.ndim == 1 72 | assert np.all(self._raw_labels >= 0) 73 | return self._raw_labels 74 | 75 | def close(self): # to be overridden by subclass 76 | pass 77 | 78 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 79 | raise NotImplementedError 80 | 81 | def _load_raw_labels(self): # to be overridden by subclass 82 | raise NotImplementedError 83 | 84 | def __getstate__(self): 85 | return dict(self.__dict__, _raw_labels=None) 86 | 87 | def __del__(self): 88 | try: 89 | self.close() 90 | except: 91 | pass 92 | 93 | def __len__(self): 94 | return self._raw_idx.size 95 | 96 | def __getitem__(self, idx): 97 | raw_idx = self._raw_idx[idx] 98 | image = self._cached_images.get(raw_idx, None) 99 | if image is None: 100 | image = self._load_raw_image(raw_idx) 101 | if self._cache: 102 | self._cached_images[raw_idx] = image 103 | assert isinstance(image, np.ndarray) 104 | assert list(image.shape) == self.image_shape 105 | assert image.dtype == np.uint8 106 | if self._xflip[idx]: 107 | assert image.ndim == 3 # CHW 108 | image = image[:, :, ::-1] 109 | return image.copy(), self.get_label(idx) 110 | 111 | def get_label(self, idx): 112 | label = 
self._get_raw_labels()[self._raw_idx[idx]] 113 | if label.dtype == np.int64: 114 | onehot = np.zeros(self.label_shape, dtype=np.float32) 115 | onehot[label] = 1 116 | label = onehot 117 | return label.copy() 118 | 119 | def get_details(self, idx): 120 | d = dnnlib.EasyDict() 121 | d.raw_idx = int(self._raw_idx[idx]) 122 | d.xflip = (int(self._xflip[idx]) != 0) 123 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 124 | return d 125 | 126 | @property 127 | def name(self): 128 | return self._name 129 | 130 | @property 131 | def image_shape(self): 132 | return list(self._raw_shape[1:]) 133 | 134 | @property 135 | def num_channels(self): 136 | assert len(self.image_shape) == 3 # CHW 137 | return self.image_shape[0] 138 | 139 | @property 140 | def resolution(self): 141 | assert len(self.image_shape) == 3 # CHW 142 | assert self.image_shape[1] == self.image_shape[2] 143 | return self.image_shape[1] 144 | 145 | @property 146 | def label_shape(self): 147 | if self._label_shape is None: 148 | raw_labels = self._get_raw_labels() 149 | if raw_labels.dtype == np.int64: 150 | self._label_shape = [int(np.max(raw_labels)) + 1] 151 | else: 152 | self._label_shape = raw_labels.shape[1:] 153 | return list(self._label_shape) 154 | 155 | @property 156 | def label_dim(self): 157 | assert len(self.label_shape) == 1 158 | return self.label_shape[0] 159 | 160 | @property 161 | def has_labels(self): 162 | return any(x != 0 for x in self.label_shape) 163 | 164 | @property 165 | def has_onehot_labels(self): 166 | return self._get_raw_labels().dtype == np.int64 167 | 168 | #---------------------------------------------------------------------------- 169 | # Dataset subclass that loads images recursively from the specified directory 170 | # or ZIP file. 171 | 172 | class ImageFolderDataset(Dataset): 173 | def __init__(self, 174 | path, # Path to directory or zip. 175 | resolution = None, # Ensure specific resolution, None = highest available. 176 | use_pyspng = True, # Use pyspng if available? 177 | **super_kwargs, # Additional arguments for the Dataset base class. 
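        # (e.g. use_labels=True, xflip=True, or cache=True from the Dataset base class)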
178 | ): 179 | self._path = path 180 | self._use_pyspng = use_pyspng 181 | self._zipfile = None 182 | 183 | if os.path.isdir(self._path): 184 | self._type = 'dir' 185 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 186 | elif self._file_ext(self._path) == '.zip': 187 | self._type = 'zip' 188 | self._all_fnames = set(self._get_zipfile().namelist()) 189 | else: 190 | raise IOError('Path must point to a directory or zip') 191 | 192 | PIL.Image.init() 193 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 194 | if len(self._image_fnames) == 0: 195 | raise IOError('No image files found in the specified path') 196 | 197 | name = os.path.splitext(os.path.basename(self._path))[0] 198 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 199 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 200 | raise IOError('Image files do not match the specified resolution') 201 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 202 | 203 | @staticmethod 204 | def _file_ext(fname): 205 | return os.path.splitext(fname)[1].lower() 206 | 207 | def _get_zipfile(self): 208 | assert self._type == 'zip' 209 | if self._zipfile is None: 210 | self._zipfile = zipfile.ZipFile(self._path) 211 | return self._zipfile 212 | 213 | def _open_file(self, fname): 214 | if self._type == 'dir': 215 | return open(os.path.join(self._path, fname), 'rb') 216 | if self._type == 'zip': 217 | return self._get_zipfile().open(fname, 'r') 218 | return None 219 | 220 | def close(self): 221 | try: 222 | if self._zipfile is not None: 223 | self._zipfile.close() 224 | finally: 225 | self._zipfile = None 226 | 227 | def __getstate__(self): 228 | return dict(super().__getstate__(), _zipfile=None) 229 | 230 | def _load_raw_image(self, raw_idx): 231 | fname = self._image_fnames[raw_idx] 232 | with self._open_file(fname) as f: 233 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 234 | image = pyspng.load(f.read()) 235 | else: 236 | image = np.array(PIL.Image.open(f)) 237 | if image.ndim == 2: 238 | image = image[:, :, np.newaxis] # HW => HWC 239 | image = image.transpose(2, 0, 1) # HWC => CHW 240 | return image 241 | 242 | def _load_raw_labels(self): 243 | fname = 'dataset.json' 244 | if fname not in self._all_fnames: 245 | return None 246 | with self._open_file(fname) as f: 247 | labels = json.load(f)['labels'] 248 | if labels is None: 249 | return None 250 | labels = dict(labels) 251 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 252 | labels = np.array(labels) 253 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 254 | return labels 255 | 256 | #---------------------------------------------------------------------------- 257 | -------------------------------------------------------------------------------- /training/loss.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This file has been modified from ECT codebase, which built upon EDM. 
5 | # 6 | # Source: 7 | # https://github.com/NVlabs/edm/blob/main/training/loss.py (EDM) 8 | # https://github.com/locuslab/ect/blob/main/training/loss.py (ECT) 9 | # 10 | # The license for these can be found in license/ directory. 11 | # The modifications to this file are subject to the same license. 12 | # --------------------------------------------------------------- 13 | 14 | import math 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | from torch_utils import persistence 19 | from torch_utils import distributed as dist 20 | from torch_utils import misc 21 | from functools import partial 22 | import os 23 | from training.networks import inplace_norm_flag 24 | #---------------------------------------------------------------------------- 25 | 26 | @torch.no_grad() 27 | def ode_solver(score_net, samples, t, next_t, labels, augment_labels): 28 | x_t = samples 29 | denoiser = score_net(x_t, t, labels, augment_labels=augment_labels) 30 | d = (x_t - denoiser) / t 31 | samples = x_t + d * (next_t - t) 32 | 33 | return samples 34 | 35 | def sigmoid(t): 36 | return 1 / (1 + (-t).exp()) 37 | def gaussian_pdf(x, mu, sigma): 38 | return (1.0 / (sigma * torch.sqrt(2 * torch.tensor(torch.pi)))) * torch.exp(-0.5 * ((x - mu) / sigma) ** 2) 39 | 40 | def delta_sigmoid(t, ratio): 41 | # Sigmoid delta t proposed by ECT 42 | dt = t * (1-ratio) * (1 + 8 * sigmoid(-t)) 43 | s = t - dt 44 | s[s<1e-6] = 1e-6 45 | dt = t - s 46 | return dt 47 | 48 | 49 | @persistence.persistent_class 50 | class ECMLoss(nn.Module): 51 | def __init__(self, P_mean=-1.1, P_std=2.0, t_lower = 0.002, t_upper = 80, sigma_data=0.5, q=2, c=0.0, k=8.0, b=1.0, cut=4.0, delta='sigmoid', ratio_limit=0.999, sqrt=True, weighting='default', boundary_prob=0.05, tdist='normal', df=2): 52 | super().__init__() 53 | 54 | self.P_mean = P_mean 55 | self.P_std = P_std 56 | self.t_lower = t_lower 57 | self.t_upper = t_upper # t \in [t_lower, t_upper] 58 | self.sigma_data = sigma_data 59 | self.ratio_limit = ratio_limit 60 | self.sqrt = sqrt 61 | self.weighting = weighting 62 | self.delta = delta 63 | self.boundary_prob = boundary_prob 64 | self.tdist = tdist # normal or t 65 | self.df = df # degrees of freedom for t distribution 66 | 67 | self.q = q 68 | self.stage = 0 69 | self.ratio = 0. 
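        # update_schedule() below sets ratio = 1 - q^-(stage + 1), so the
        # consistency step delta(t, ratio) shrinks and r -> t as training proceeds.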
70 | 71 | self.k = k 72 | self.b = b 73 | self.cut = cut 74 | 75 | self.c = c 76 | dist.print0( 77 | f'P_mean: {self.P_mean}, P_std: {self.P_std}, q: {self.q}, k {self.k}, b {self.b}, cut {self.cut}, c: {self.c}') 78 | 79 | def update_schedule(self, stage): 80 | self.stage = stage 81 | self.ratio = 1 - 1 / self.q ** (stage + 1) 82 | if self.ratio > self.ratio_limit: 83 | dist.print0(f"Clipping ratio from {self.ratio} -> {self.ratio_limit}") 84 | self.ratio = self.ratio_limit 85 | 86 | 87 | 88 | 89 | 90 | def __call__(self, net, images, labels=None, augment_pipe=None, teacher_net=None, t=None): 91 | delta = delta_sigmoid 92 | if t is None: 93 | if self.tdist == 'normal': 94 | logt = misc.truncated_normal(images.shape[0], mu=self.P_mean, sigma=self.P_std, lower=np.log(self.t_lower), upper=np.log(self.t_upper)) # np.array 95 | elif self.tdist == 't': 96 | logt = misc.truncated_t(images.shape[0], mu=self.P_mean, sigma=self.P_std, lower=np.log(self.t_lower), upper=np.log(self.t_upper), df=self.df) 97 | logt = torch.tensor(logt, device=images.device).view(-1, 1, 1, 1) 98 | t = logt.exp() 99 | if hasattr(net, 'transition_t'): 100 | num_elements_to_mask = int(t.shape[0] * self.boundary_prob) 101 | indices = torch.randperm(t.shape[0]) 102 | mask_indices = indices[:num_elements_to_mask] 103 | mask_t = torch.zeros_like(t, dtype=torch.bool) 104 | mask_t[mask_indices] = True 105 | 106 | t = mask_t * (net.transition_t+1e-8) + ~mask_t * t 107 | 108 | r = t - delta(t, self.ratio) 109 | 110 | # Augmentation 111 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 112 | 113 | # Shared noise direction 114 | eps = torch.randn_like(y) 115 | y_t = y + eps * t 116 | if teacher_net is None: 117 | y_r = y + eps * r 118 | else: 119 | y_r = ode_solver(teacher_net, y_t, t, r, labels, augment_labels=augment_labels).detach() 120 | 121 | # Shared Dropout Mask 122 | rng_state = torch.cuda.get_rng_state() 123 | D_yt = net(y_t, t, labels, augment_labels=augment_labels) 124 | 125 | if r.max() > 0: 126 | torch.cuda.set_rng_state(rng_state) 127 | with torch.no_grad(): 128 | token = inplace_norm_flag.set(False) 129 | D_yr = net(y_r, r, labels, augment_labels=augment_labels) 130 | inplace_norm_flag.reset(token) 131 | mask = r > 0 132 | D_yr = torch.nan_to_num(D_yr) 133 | D_yr = mask * D_yr + (~mask) * y 134 | else: 135 | D_yr = y 136 | 137 | # L2 Loss 138 | l2_distance = torch.norm(D_yt - D_yr, dim=(1, 2, 3), p=2) 139 | 140 | 141 | # Huber Loss if needed 142 | if self.c > 0: 143 | loss_unweighted = torch.sqrt(l2_distance ** 2 + self.c ** 2) - self.c 144 | else: 145 | if self.sqrt: 146 | loss_unweighted = l2_distance 147 | else: 148 | loss_unweighted = l2_distance ** 2 149 | 150 | # Weighting fn 151 | 152 | 153 | t = t.flatten() 154 | r = r.flatten() 155 | 156 | if self.weighting == 'default': 157 | loss = loss_unweighted / delta(t, self.ratio) 158 | elif self.weighting == 'cout': 159 | loss = loss_unweighted / misc.get_edm_cout(t, sigma_data = self.sigma_data) 160 | elif self.weighting == 'cout_sq': 161 | loss = loss_unweighted / misc.get_edm_cout(t, sigma_data = self.sigma_data) ** 2 162 | elif self.weighting == 'sqrt': 163 | loss = loss_unweighted / (t-r)**0.5 164 | elif self.weighting == 'one': 165 | loss = loss_unweighted 166 | else: 167 | raise NotImplementedError(f"Weighting function {self.weighting} not implemented.") 168 | 169 | if hasattr(net, 'transition_t') and torch.any(mask_t): 170 | loss_boundary = loss[mask_t.flatten()] 171 | loss = loss[~mask_t.flatten()] 172 | 
loss_unweighted = loss_unweighted[~mask_t.flatten()] 173 | t = t[~mask_t.flatten()] 174 | l2_distance = l2_distance[~mask_t.flatten()] 175 | else: 176 | loss_boundary = torch.zeros_like(loss) 177 | return loss_unweighted, loss, t, l2_distance, loss_boundary 178 | -------------------------------------------------------------------------------- /training/networks_tcm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for TCM. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import torch 9 | 10 | # TCM parameterization 11 | class TCMPrecond: 12 | def __init__(self, 13 | model_t, # Teacher model (stage-1) 14 | model_s, # Student model (stage-2) 15 | transition_t = 2., # Transition time step (t') 16 | max_t = 80., # Maximum time step 17 | teacher_pkl = None, # Not used 18 | ): 19 | 20 | self.model_t = model_t 21 | self.model_s = model_s 22 | self.transition_t = transition_t 23 | 24 | def __call__(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): 25 | mask = sigma >= self.transition_t # If this is true, use the second-stage model. Otherwise, use the first-stage model 26 | mask = mask.squeeze() 27 | 28 | 29 | rng_state = torch.cuda.get_rng_state() 30 | if (~mask).any(): 31 | with torch.no_grad(): 32 | out_t = self.model_t( 33 | x, 34 | sigma, 35 | class_labels, 36 | force_fp32, 37 | **model_kwargs 38 | ) 39 | else: 40 | out_t = torch.zeros_like(x).to(torch.float32) 41 | torch.cuda.set_rng_state(rng_state) 42 | if mask.any(): 43 | out_s = self.model_s( 44 | x, 45 | sigma, 46 | class_labels, 47 | force_fp32, 48 | **model_kwargs 49 | ) 50 | else: 51 | out_s = torch.zeros_like(x).to(torch.float32) 52 | 53 | out = mask.view(-1,1,1,1) * out_s + (~mask.view(-1,1,1,1)) * out_t 54 | return out 55 | 56 | 57 | --------------------------------------------------------------------------------
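
A minimal usage sketch of `TCMPrecond` above (illustrative only: the checkpoint
paths, the `'ema'` key, and the CIFAR-10 tensor shapes are assumptions, not
something this file defines):

```python
import pickle
import torch
from training.networks_tcm import TCMPrecond

# Hypothetical stage-1 (teacher) and stage-2 (student) checkpoints.
with open('stage1.pkl', 'rb') as f:
    model_t = pickle.load(f)['ema'].cuda().eval()
with open('stage2.pkl', 'rb') as f:
    model_s = pickle.load(f)['ema'].cuda().eval()

net = TCMPrecond(model_t, model_s, transition_t=2.0)

# One-step consistency sampling from the maximum time step t = 80:
# sigma >= transition_t for every sample, so only the student runs.
t_max = 80.0
x = torch.randn(8, 3, 32, 32, device='cuda') * t_max
sigma = torch.full([8, 1, 1, 1], t_max, device='cuda')
with torch.no_grad():
    images = net(x, sigma)  # denoised samples, roughly in [-1, 1]
```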