├── .gitignore ├── .gitmodules ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── binary_distill.py ├── data ├── README.md └── image_datasets.py ├── fid ├── README.md ├── compute_fid_stats.py ├── fid_model.py └── fid_zip.py ├── lib ├── __init__.py ├── data │ ├── __init__.py │ ├── data.py │ └── data_torch.py ├── distributed.py ├── eval │ ├── __init__.py │ ├── fid.py │ └── inception_net.py ├── io.py ├── nn │ ├── __init__.py │ ├── functional │ │ ├── __init__.py │ │ └── functional.py │ ├── ncsnpp │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── layers.py │ │ ├── layerspp.py │ │ └── up_or_down_sampling.py │ └── nn.py ├── optim.py ├── train.py ├── util.py └── zoo │ ├── __init__.py │ └── unet.py ├── requirements.txt ├── tc_distill.py ├── tc_distill_edm.py └── teacher ├── README.md └── download_and_convert_jax.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Teacher ckpts 2 | ckpts/ 3 | # Local logdir 4 | e/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # Images 35 | *.png 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Sphinx documentation 62 | docs/_build/ 63 | 64 | # PyBuilder 65 | target/ 66 | 67 | # Jupyter Notebook 68 | .ipynb_checkpoints 69 | 70 | # IPython 71 | profile_default/ 72 | ipython_config.py 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 78 | __pypackages__/ 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # mypy 93 | .mypy_cache/ 94 | .dmypy.json 95 | dmypy.json 96 | 97 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "edm"] 2 | path = edm 3 | url = https://github.com/NVlabs/edm.git 4 | hexsha = 62072d2612c7da05165d6233d13d17d71f213fee 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to enable reproducible and continuing research. The intention is to clearly communicate research progress. We hope the community tries out the code and any suggestions are welcome. 4 | 5 | ## Before you get started 6 | 7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. 
APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation 2 | This software project accompanies the research paper [TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation](https://arxiv.org/abs/2303.04248) 3 | 4 | # Citation: 5 | ``` 6 | @article{berthelot2023tract, 7 | title={TRACT: Denoising Diffusion Models with Transitive Closure Time-Distillation}, 8 | author={Berthelot, David and Autef, Arnaud and Lin, Jierui and Yap, Dian Ang and Zhai, Shuangfei and Hu, Siyuan and Zheng, Daniel and Talbott, Walter and Gu, Eric}, 9 | journal={arXiv preprint arXiv:2303.04248}, 10 | year={2023} 11 | } 12 | ``` 13 | 14 | # Setup 15 | 16 | Git clone with `--recurse-submodules` to initialize EDM submodule properly. 17 | 18 | Setup environment variables: 19 | 20 | ```bash 21 | export ML_DATA=~/Data/DDPM-Images 22 | export PYTHONPATH=$PYTHONPATH:. 23 | ``` 24 | 25 | Then run 26 | ```bash 27 | sudo apt install python3.8-dev python3.8-venv python3-dev -y 28 | ``` 29 | 30 | Set up a virtualenv 31 | 32 | ```bash 33 | python3.8 -m venv ~/tract_venv 34 | source ~/tract_venv/bin/activate 35 | ``` 36 | 37 | or via `pyenv` 38 | 39 | ```bash 40 | pyenv install 3.8.0 41 | pyenv virtualenv 3.8.0 tract_venv 42 | pyenv local tract_venv 43 | ``` 44 | 45 | then upgrade pip 46 | ``` 47 | pip install --upgrade pip 48 | ``` 49 | 50 | Install pip pkgs 51 | ``` 52 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 53 | ``` 54 | 55 | # Setup data 56 | 57 | Please follow [README](data/README.md) to setup datasets. 58 | 59 | # Set up teacher models 60 | 61 | Please follow [README](teacher/README.md) to setup teacher checkpoints. 62 | 63 | # Set up EDM 64 | For running with NVIDIA's Elucidated model (EDM), ensure the `edm/` 65 | submodule has been initialized properly. 66 | 67 | # Real activation stats for FID 68 | 69 | Please follow the `Real activation statistics` section in [README](fid/README.md) in order to compute and save the real activation statistics to be used in FID evaluation. 70 | 71 | # Training 72 | 73 | The below commands will reproduce results from our paper when run on a cluster of 8 NVIDIA A100 or V100 GPUs. 74 | 75 | Example: Run TC distillation on Cifar10 using distillation time schedule: 1024, 32, 1. 
76 | 77 | ```bash 78 | python tc_distill.py --dataset=cifar10 --time_schedule=1024,32,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M 79 | ``` 80 | 81 | Example: Run TC distillation on Cifar10 using EDM teacher 82 | ```bash 83 | python tc_distill_edm.py --dataset=cifar10 --time_schedule=40,1 --fid_len=50000 --report_fid_len=8M --report_img_len=8M --train_len=96M --batch=512 84 | ``` 85 | 86 | Getting help 87 | 88 | ```bash 89 | python tc_distill.py --help 90 | ``` 91 | 92 | Tensorboard outputs are generated in a dir like `e/DATASET/MODEL/EXP_NAME/tb/`. For example, you can start tensorboard to view metrics like 93 | ```bash 94 | tensorboard --logdir e/cifar10/EluDDIM05TCMultiStepx0\(EluUNet\)/aug_prob@0.0_batch@8_dropout@0.0_ema@0.9996_lr@0.001_lr_warmup@None_res@32_sema@0.1_time_schedule@40,1_timesteps@40/tb/ --bind_all 95 | ``` 96 | 97 | # Evaluation 98 | 99 | Please follow [README](fid/README.md) to run FID evaluation. 100 | -------------------------------------------------------------------------------- /binary_distill.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import copy 6 | import functools 7 | import math 8 | import os 9 | import pathlib 10 | import shutil 11 | from typing import Callable, Dict, Optional 12 | 13 | import torch 14 | import torch.nn.functional 15 | from absl import app, flags 16 | 17 | import lib 18 | from lib.distributed import device, device_id, print 19 | from lib.util import FLAGS 20 | from lib.zoo.unet import UNet 21 | 22 | 23 | def get_model(name: str): 24 | if name == 'cifar10': 25 | net = UNet(in_channel=3, 26 | channel=256, 27 | emb_channel=1024, 28 | channel_multiplier=[1, 1, 1], 29 | n_res_blocks=3, 30 | attn_rezs=[8, 16], 31 | attn_heads=1, 32 | head_dim=None, 33 | use_affine_time=True, 34 | dropout=0.2, 35 | num_output=1, 36 | resample=True, 37 | num_classes=1) 38 | elif name == 'imagenet64': 39 | # imagenet model is class conditional 40 | net = UNet(in_channel=3, 41 | channel=192, 42 | emb_channel=768, 43 | channel_multiplier=[1, 2, 3, 4], 44 | n_res_blocks=3, 45 | init_rez=64, 46 | attn_rezs=[8, 16, 32], 47 | attn_heads=None, 48 | head_dim=64, 49 | use_affine_time=True, 50 | dropout=0., 51 | num_output=2, # predict signal and noise 52 | resample=True, 53 | num_classes=1000) 54 | else: 55 | raise NotImplementedError(name) 56 | return net 57 | 58 | 59 | class BinaryDistillGoogleModel(lib.train.TrainModel): 60 | R_NONE, R_STEP, R_PHASE = 'none', 'step', 'phase' 61 | R_ALL = R_NONE, R_STEP, R_PHASE 62 | 63 | def __init__(self, name: str, res: int, timesteps: int, **params): 64 | super().__init__("GoogleUNet", res, timesteps, **params) 65 | self.num_classes = 1 66 | self.shape = 3, res, res 67 | self.timesteps = timesteps 68 | model = get_model(name) 69 | if 'cifar' in name: 70 | self.ckpt_path = 'ckpts/cifar_original.pt' 71 | self.predict_both = False 72 | elif 'imagenet' in name: 73 | self.ckpt_path = 'ckpts/imagenet_original.pt' 74 | self.predict_both = False 75 | elif 'imagenet' in name: 76 | self.ckpt_path = 'ckpts/imagenet_original.pt' 77 | self.num_classes = 1000 78 | self.predict_both = True 79 | self.EVAL_COLUMNS = self.EVAL_ROWS = 8 80 | else: 81 | raise NotImplementedError(name) 82 | 83 | model.apply(functools.partial(lib.nn.functional.set_bn_momentum, momentum=1 - self.params.ema)) 84 | model.apply(functools.partial(lib.nn.functional.set_dropout, p=0)) 85 | 
self.model = lib.distributed.wrap(model) 86 | self.model_eval = lib.optim.ModuleEMA(model, momentum=self.params.ema).to(device_id()) 87 | self.teacher = copy.deepcopy(model).to(device_id()) 88 | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.params.lr) 89 | self.register_buffer('phase', torch.zeros((), dtype=torch.long)) 90 | self.cur_step = self.timesteps // 2 91 | 92 | def initialize_weights_from_teacher(self, logdir: pathlib.Path, teacher_ckpt: Optional[str] = None): 93 | teacher_ckpt_path = logdir / 'ckpt/teacher.ckpt' 94 | if device_id() == 0: 95 | os.makedirs(logdir / 'ckpt', exist_ok=True) 96 | shutil.copy2(self.ckpt_path, teacher_ckpt_path) 97 | 98 | lib.distributed.barrier() 99 | self.model.module.load_state_dict(torch.load(teacher_ckpt_path)) 100 | self.model_eval.module.load_state_dict(torch.load(teacher_ckpt_path)) 101 | self.self_teacher.module.load_state_dict(torch.load(teacher_ckpt_path)) 102 | self.teacher.load_state_dict(torch.load(teacher_ckpt_path)) 103 | 104 | def randn(self, n: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 105 | if generator is not None: 106 | assert generator.device == torch.device('cpu') 107 | return torch.randn((n, *self.shape), device='cpu', generator=generator, dtype=torch.double).to(self.device) 108 | 109 | def call_model(self, model: Callable, xt: torch.Tensor, index: torch.Tensor, 110 | y: Optional[torch.Tensor] = None) -> torch.Tensor: 111 | if y is None: 112 | return model(xt.float(), index.float()).double() 113 | else: 114 | return model(xt.float(), index.float(), y.long()).double() 115 | 116 | def forward(self, samples: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 117 | step = self.timesteps // self.cur_step 118 | xt = self.randn(samples, generator).to(device_id()) 119 | if self.num_classes > 1: 120 | y = torch.randint(0, self.num_classes, (samples,)).to(xt) 121 | else: 122 | y = None 123 | 124 | for t in reversed(range(0, self.timesteps, step)): 125 | ix = torch.Tensor([t + step]).long().to(device_id()), torch.Tensor([t]).long().to(device_id()) 126 | logsnr = tuple(self.logsnr_schedule_cosine(i / self.timesteps).to(xt.double()) for i in ix) 127 | g = tuple(torch.sigmoid(l).view(-1, 1, 1, 1) for l in logsnr) # Get gamma values 128 | x0 = self.call_model(self.model_eval, xt, logsnr[0].repeat(xt.shape[0]), y) 129 | xt = self.post_xt_x0(xt, x0, g[0], g[1], clip_x=True) 130 | return xt 131 | 132 | @staticmethod 133 | def logsnr_schedule_cosine(t, logsnr_min=torch.Tensor([-20.]), logsnr_max=torch.Tensor([20.])): 134 | b = torch.arctan(torch.exp(-0.5 * logsnr_max)).to(t) 135 | a = torch.arctan(torch.exp(-0.5 * logsnr_min)).to(t) - b 136 | return -2. * torch.log(torch.tan(a * t + b)) 137 | 138 | @staticmethod 139 | def predict_eps_from_x(z, x, logsnr): 140 | """eps = (z - alpha*x)/sigma.""" 141 | assert logsnr.ndim == x.ndim 142 | return torch.sqrt(1. + torch.exp(logsnr)) * (z - x * torch.rsqrt(1. + torch.exp(-logsnr))) 143 | 144 | def post_xt_x0(self, xt: torch.Tensor, out: torch.Tensor, g: torch.Tensor, g1: torch.Tensor, clip_x=False) -> torch.Tensor: 145 | if self.predict_both: 146 | assert out.shape[1] == 6 147 | model_x, model_eps = out[:, :3], out[:, 3:] 148 | # reconcile the two predictions 149 | model_x_eps = (xt - model_eps * (1 - g).sqrt()) * g.rsqrt() 150 | wx = 1 - g 151 | x0 = wx * model_x + (1. - wx) * model_x_eps 152 | else: 153 | x0 = out 154 | if clip_x: 155 | x0 = torch.clip(x0, -1., 1.) 
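        # Deterministic DDIM-style update: recover eps from xt = sqrt(g)*x0 + sqrt(1-g)*eps at the
        # current gamma g, then re-noise the (optionally clipped) x0 prediction at the target gamma
        # g1, giving x_s = sqrt(g1)*x0 + sqrt(1-g1)*eps.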
156 | eps = (xt - x0 * g.sqrt()) * (1 - g).rsqrt() 157 | return torch.nan_to_num(x0 * g1.sqrt() + eps * (1 - g1).sqrt()) 158 | 159 | def train_op(self, info: lib.train.TrainInfo, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]: 160 | if self.num_classes == 1: 161 | y = None 162 | else: 163 | y = y[:, 0] 164 | with torch.no_grad(): 165 | phase = int(info.progress * (1 - 1e-9) * math.log(self.timesteps, 2)) 166 | if phase != self.phase: 167 | print(f'Refreshing teacher {phase}') 168 | self.phase.add_(1) 169 | self.teacher.load_state_dict(self.model_eval.module.state_dict()) 170 | if self.params.reset == self.R_PHASE: 171 | self.model_eval.step.mul_(0) 172 | self.cur_step = self.cur_step // 2 173 | assert self.cur_step >= 1 174 | 175 | step = self.timesteps // self.cur_step 176 | index = torch.randint(1, 1 + (self.timesteps // step), (x.shape[0],), device=device()) * step 177 | ix = index, index - step // 2, index - step 178 | logsnr = tuple(self.logsnr_schedule_cosine(i.double() / self.timesteps).to(x.double()) for i in ix) 179 | g = tuple(torch.sigmoid(l).view(-1, 1, 1, 1) for l in logsnr) # Get gamma values 180 | noise = torch.randn_like(x) 181 | xt0 = x.double() * g[0].sqrt() + noise * (1 - g[0]).sqrt() 182 | xt1 = self.post_xt_x0(xt0, self.call_model(self.teacher, xt0, logsnr[0], y), g[0], g[1]) 183 | x_hat = self.call_model(self.teacher, xt1, logsnr[1], y) 184 | xt2 = self.post_xt_x0(xt1, x_hat, g[1], g[2]) 185 | # Find target such that self.post_xt_x0(xt0, target, g[0], g[2]) == xt2 186 | target = ((xt0 * (1 - g[2]).sqrt() - xt2 * (1 - g[0]).sqrt()) / 187 | ((g[0] * (1 - g[2])).sqrt() - (g[2] * (1 - g[0])).sqrt())) 188 | # use predicted x0 as target when t=0 189 | target += (index == step).view(-1, 1, 1, 1) * (x_hat[:, :3] - target) 190 | 191 | self.opt.zero_grad(set_to_none=True) 192 | pred = self.call_model(self.model, xt0, logsnr[0], y) 193 | if self.predict_both: 194 | assert pred.shape[1] == 6 195 | model_x, model_eps = pred[:, :3], pred[:, 3:] 196 | # reconcile the two predictions 197 | model_x_eps = (xt0 - model_eps * (1 - g[0]).sqrt()) * g[0].rsqrt() 198 | wx = 1 - g[0] 199 | pred_x = wx * model_x + (1. - wx) * model_x_eps 200 | else: 201 | pred_x = pred 202 | 203 | loss = ((g[0] / (1 - g[0])).clamp(1) * (pred_x - target.detach()).square()).mean(0).sum() 204 | loss.backward() 205 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 
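        # Gradient-clipped optimizer step; the EMA copy (model_eval) is what sampling uses and what
        # is loaded into the teacher at the next phase change.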
206 | self.opt.step() 207 | self.model_eval.update(self.model) 208 | return {'loss/global': loss, 'stat/timestep': self.cur_step} 209 | 210 | 211 | @ lib.distributed.auto_distribute 212 | def main(_): 213 | data = lib.data.DATASETS[FLAGS.dataset]() 214 | model = BinaryDistillGoogleModel(FLAGS.dataset, data.res, FLAGS.timesteps, reset=FLAGS.reset, 215 | batch=FLAGS.batch, lr=FLAGS.lr, ema=FLAGS.ema) 216 | logdir = lib.util.artifact_dir(FLAGS.dataset, model.logdir) 217 | # resume from previous run (ckpt will be loaded in train.py) 218 | if FLAGS.restart_ckpt: 219 | lib.distributed.barrier() 220 | 221 | if FLAGS.eval: 222 | if not FLAGS.restart_ckpt: 223 | model.initialize_weights_from_teacher(logdir, FLAGS.teacher_ckpt) 224 | model.eval() 225 | with torch.no_grad(): 226 | generator = torch.Generator(device='cpu') 227 | generator.manual_seed(123623113456) 228 | model(4, generator) 229 | else: 230 | train, fid = data.make_dataloaders() 231 | if not FLAGS.restart_ckpt: 232 | model.initialize_weights_from_teacher(logdir, FLAGS.teacher_ckpt) 233 | model.train_loop(train, fid, FLAGS.batch, FLAGS.train_len, FLAGS.report_len, logdir, fid_len=FLAGS.fid_len) 234 | 235 | 236 | if __name__ == '__main__': 237 | flags.DEFINE_bool('eval', False, help='Whether to run model evaluation.') 238 | flags.DEFINE_enum('reset', BinaryDistillGoogleModel.R_NONE, BinaryDistillGoogleModel.R_ALL, help='EMA reset mode.') 239 | flags.DEFINE_float('ema', 0.9995, help='Exponential Moving Average of model.') 240 | flags.DEFINE_float('lr', 2e-4, help='Learning rate.') 241 | flags.DEFINE_integer('fid_len', 4096, help='Number of samples for FID evaluation.') 242 | flags.DEFINE_integer('timesteps', 1024, help='Sampling timesteps.') 243 | flags.DEFINE_string('dataset', 'cifar10', help='Training dataset.') 244 | flags.DEFINE_string('train_len', '64M', help='Training duration in samples per distillation logstep.') 245 | flags.DEFINE_string('report_len', '1M', help='Reporting interval in samples.') 246 | flags.FLAGS.set_default('report_img_len', '1M') 247 | flags.FLAGS.set_default('report_fid_len', '4M') 248 | flags.DEFINE_string('restart_ckpt', None, 249 | help='Trainer checkpoint in the form : with of the form "ckpt/*.pth" .') 250 | flags.DEFINE_string('teacher_ckpt', None, 251 | help='Teacher checkpoint in the form : with of the form "ckpt/model_*.ckpt".') 252 | app.run(lib.distributed.main(main)) 253 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | ## Cifar10 4 | Download Cifar10 data through `torchvision`. 5 | ```bash 6 | python -c "import torchvision; import os; torchvision.datasets.CIFAR10(os.getenv('ML_DATA'), train=True, download=True)" 7 | ``` 8 | 9 | ## Class-conditional 64x64 ImageNet 10 | 11 | we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to the [2012 challenge page](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php) and download the data in "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 12 | 13 | ## Preparation for Torch dataloaders 14 | We need to prepare `$ML_DATA/imagenet/train/`. 
These instructions are adapted from a [pytorch example script](https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh) 15 | 16 | ```bash 17 | # Create train directory; move .tar file; change directory 18 | mkdir -p $ML_DATA/imagenet/train && mv ILSVRC2012_img_train.tar $ML_DATA/imagenet/train/ && cd $ML_DATA/imagenet/train 19 | # Extract training set; remove compressed file 20 | tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar 21 | # 22 | # At this stage imagenet/train will contain 1000 compressed .tar files, one for each category 23 | # 24 | # For each .tar file: 25 | # 1. create directory with same name as .tar file 26 | # 2. extract and copy contents of .tar file into directory 27 | # 3. remove .tar file 28 | find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done 29 | ``` 30 | 31 | Our data downloading and pre-processing pipeline is largely adapted from https://github.com/openai/guided-diffusion. Thanks for open-sourcing! 32 | -------------------------------------------------------------------------------- /data/image_datasets.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # Adapted from OpenAI https://github.com/openai/guided-diffusion 6 | import math 7 | import random 8 | 9 | from PIL import Image 10 | import blobfile as bf 11 | import numpy as np 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def _list_image_files_recursively(data_dir): 16 | results = [] 17 | for entry in sorted(bf.listdir(data_dir)): 18 | full_path = bf.join(data_dir, entry) 19 | ext = entry.split(".")[-1] 20 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 21 | results.append(full_path) 22 | elif bf.isdir(full_path): 23 | results.extend(_list_image_files_recursively(full_path)) 24 | return results 25 | 26 | 27 | class ImageDataset(Dataset): 28 | def __init__( 29 | self, 30 | resolution, 31 | data_dir, 32 | shard=0, 33 | num_shards=1, 34 | random_crop=False, 35 | ): 36 | super().__init__() 37 | image_paths = _list_image_files_recursively(data_dir) 38 | # Assume classes are the first part of the filename, 39 | # before an underscore. 
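        # e.g. a file named like "n01440764_10026.JPEG" (ImageNet-style) gets class name "n01440764".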
40 | class_names = [bf.basename(path).split("_")[0] for path in image_paths] 41 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 42 | classes = [sorted_classes[x] for x in class_names] 43 | 44 | self.resolution = resolution 45 | self.local_images = image_paths[shard:][::num_shards] 46 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 47 | self.random_crop = random_crop 48 | 49 | def __len__(self): 50 | return len(self.local_images) 51 | 52 | def __getitem__(self, idx): 53 | path = self.local_images[idx] 54 | with bf.BlobFile(path, "rb") as f: 55 | pil_image = Image.open(f) 56 | pil_image.load() 57 | pil_image = pil_image.convert("RGB") 58 | 59 | if self.random_crop: 60 | arr = random_crop_arr(pil_image, self.resolution) 61 | else: 62 | arr = center_crop_arr(pil_image, self.resolution) 63 | 64 | out_dict = {} 65 | if self.local_classes is not None: 66 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 67 | return arr, out_dict["y"] 68 | 69 | 70 | def center_crop_arr(pil_image, image_size): 71 | # We are not on a new enough PIL to support the `reducing_gap` 72 | # argument, which uses BOX downsampling at powers of two first. 73 | # Thus, we do it by hand to improve downsample quality. 74 | while min(*pil_image.size) >= 2 * image_size: 75 | pil_image = pil_image.resize( 76 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 77 | ) 78 | 79 | scale = image_size / min(*pil_image.size) 80 | pil_image = pil_image.resize( 81 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 82 | ) 83 | 84 | arr = np.array(pil_image) 85 | crop_y = (arr.shape[0] - image_size) // 2 86 | crop_x = (arr.shape[1] - image_size) // 2 87 | return arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size] 88 | 89 | 90 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 91 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 92 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 93 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 94 | 95 | # We are not on a new enough PIL to support the `reducing_gap` 96 | # argument, which uses BOX downsampling at powers of two first. 97 | # Thus, we do it by hand to improve downsample quality. 98 | while min(*pil_image.size) >= 2 * smaller_dim_size: 99 | pil_image = pil_image.resize( 100 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 101 | ) 102 | 103 | scale = smaller_dim_size / min(*pil_image.size) 104 | pil_image = pil_image.resize( 105 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 106 | ) 107 | 108 | arr = np.array(pil_image) 109 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 110 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 111 | return arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size] 112 | -------------------------------------------------------------------------------- /fid/README.md: -------------------------------------------------------------------------------- 1 | # FID 2 | 3 | ## Real activation statistics 4 | 5 | In order to compute FID, we first need to compute the activation statistics over all real images. 6 | 7 | Run 8 | 9 | ```bash 10 | python compute_fid_stats.py --dataset cifar10 11 | python compute_fid_stats.py --dataset imagenet64 12 | ``` 13 | 14 | to compute and save the mean and std of real activation on CIFAR10 and 64x64 ImageNet. 
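The saved arrays are the InceptionV3 feature mean and covariance that `lib/eval/fid.py` later loads. As a rough, self-contained sketch (not part of this repo; the fake activations below are a random placeholder standing in for features of generated samples), this is how such statistics combine into the Fréchet distance:

```python
import os

import numpy as np
import scipy.linalg

ml_data = os.getenv('ML_DATA')
# Real-data statistics written by compute_fid_stats.py (the "*_std.npy" file holds a covariance matrix).
mu_real = np.load(f'{ml_data}/cifar10_activation_mean.npy')
sigma_real = np.load(f'{ml_data}/cifar10_activation_std.npy')

# Placeholder (N, 2048) activations; in practice these come from InceptionV3 run on generated images.
fake = np.random.randn(4096, 2048)
mu_fake, sigma_fake = fake.mean(axis=0), np.cov(fake, rowvar=False)

# Frechet distance between the two Gaussians: ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)).
diff = mu_fake - mu_real
covmean = scipy.linalg.sqrtm(sigma_fake.dot(sigma_real), disp=False)[0].real
fid = diff.dot(diff) + np.trace(sigma_fake) + np.trace(sigma_real) - 2 * np.trace(covmean)
print(f'FID ~= {fid:.2f}')
```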
15 | 16 | ## Compute FID with a trained model 17 | 18 | ```bash 19 | python fid_model.py --dataset cifar10 --ckpt={the path to your model} 20 | ``` 21 | 22 | ## Compute FID from generated samples 23 | 24 | ```bash 25 | python fid_zip.py {the path to your zip file} --dataset cifar10 26 | ``` -------------------------------------------------------------------------------- /fid/compute_fid_stats.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import os 6 | import pathlib 7 | 8 | import lib 9 | import numpy as np 10 | import torch 11 | from absl import app, flags 12 | from lib.distributed import auto_distribute 13 | from lib.util import FLAGS, artifact_dir 14 | 15 | ML_DATA = pathlib.Path(os.getenv('ML_DATA')) 16 | 17 | @auto_distribute 18 | def main(argv): 19 | data = lib.data.DATASETS[FLAGS.dataset]() 20 | real = data.make_dataloaders()[1] 21 | num_samples = len(real) * FLAGS.batch 22 | with torch.no_grad(): 23 | fid = lib.eval.FID(FLAGS.dataset, (3, data.res, data.res)) 24 | real_activations = fid.data_activations(real, num_samples, cpu=True) 25 | m_real, s_real = fid.calculate_activation_statistics(real_activations) 26 | np.save(f'{ML_DATA}/{FLAGS.dataset}_activation_mean.npy', m_real.numpy()) 27 | np.save(f'{ML_DATA}/{FLAGS.dataset}_activation_std.npy', s_real.numpy()) 28 | 29 | 30 | if __name__ == '__main__': 31 | flags.DEFINE_string('dataset', 'cifar10', help='Training dataset.') 32 | app.run(lib.distributed.main(main)) 33 | -------------------------------------------------------------------------------- /fid/fid_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """Compute FID and approximation at 50,000 for zip file of samples.""" 6 | import pathlib 7 | import time 8 | from types import SimpleNamespace 9 | from typing import Optional 10 | 11 | import lib 12 | import torch 13 | from absl import app, flags 14 | from lib.distributed import auto_distribute, device_id 15 | from lib.eval.fid import FID 16 | from lib.io import Summary, SummaryWriter, zip_batch_as_png 17 | from lib.util import FLAGS 18 | from lib.zoo.unet import UNet 19 | 20 | 21 | def logsnr_schedule_cosine(t, logsnr_min=torch.Tensor([-20.]), logsnr_max=torch.Tensor([20.])): 22 | b = torch.arctan(torch.exp(-0.5 * logsnr_max)) 23 | a = torch.arctan(torch.exp(-0.5 * logsnr_min)) - b 24 | return -2. * torch.log(torch.tan(a * t + b)) 25 | 26 | 27 | def predict_eps_from_x(z, x, logsnr): 28 | """eps = (z - alpha*x)/sigma.""" 29 | assert logsnr.ndim == x.ndim 30 | return torch.sqrt(1. + torch.exp(logsnr)) * (z - x * torch.rsqrt(1. + torch.exp(-logsnr))) 31 | 32 | 33 | def predict_x_from_eps(z, eps, logsnr): 34 | """x = (z - sigma*eps)/alpha.""" 35 | assert logsnr.ndim == eps.ndim 36 | return torch.sqrt(1. + torch.exp(-logsnr)) * (z - eps * torch.rsqrt(1. 
+ torch.exp(logsnr))) 37 | 38 | 39 | class ModelFID(torch.nn.Module): 40 | COLORS = 3 41 | 42 | def __init__(self, name: str, res: int, timesteps: int, **params): 43 | super().__init__() 44 | self.name = name 45 | if name == 'cifar10': 46 | self.model = UNet(in_channel=3, 47 | channel=256, 48 | emb_channel=1024, 49 | channel_multiplier=[1, 1, 1], 50 | n_res_blocks=3, 51 | attn_rezs=[8, 16], 52 | attn_heads=1, 53 | head_dim=None, 54 | use_affine_time=True, 55 | dropout=0.2, 56 | num_output=1, 57 | resample=True, 58 | num_classes=1).to(device_id()) 59 | self.shape = 3, 32, 32 60 | self.mean_type = 'x' 61 | self.ckpt_name = 'cifar_original.pt' 62 | self.num_classes = 1 63 | elif name == 'imagenet64': 64 | # imagenet model is class conditional 65 | self.model = UNet(in_channel=3, 66 | channel=192, 67 | emb_channel=768, 68 | channel_multiplier=[1, 2, 3, 4], 69 | n_res_blocks=3, 70 | init_rez=64, 71 | attn_rezs=[8, 16, 32], 72 | attn_heads=None, 73 | head_dim=64, 74 | use_affine_time=True, 75 | dropout=0., 76 | num_output=2, # predict signal and noise 77 | resample=True, 78 | num_classes=1000).to(device_id()) 79 | self.shape = 3, 64, 64 80 | self.mean_type = 'both' 81 | self.ckpt_name = 'imagenet_original.pt' 82 | self.num_classes = 1000 83 | else: 84 | raise NotImplementedError(name) 85 | self.params = SimpleNamespace(res=res, timesteps=timesteps, **params) 86 | self.timesteps = timesteps 87 | self.logstep = 0 88 | self.clip_x = True 89 | 90 | @property 91 | def logdir(self) -> str: 92 | params = ','.join(f'{k}={v}' for k, v in sorted(vars(self.params).items())) 93 | return f'{self.__class__.__name__}({params})' 94 | 95 | def initialize_weights(self, logdir: pathlib.Path): 96 | self.model.load_state_dict(torch.load(self.params.ckpt)) 97 | 98 | def run_model(self, z, logsnr, y=None): 99 | if self.mean_type == 'x': 100 | model_x = self.model(z.float(), logsnr.float(), y).double() 101 | logsnr = logsnr[:, None, None, None] 102 | elif self.mean_type == 'both': 103 | output = self.model(z.float(), logsnr.float(), y).double() 104 | model_x, model_eps = output[:, :3], output[:, 3:] 105 | # reconcile the two predictions 106 | logsnr = logsnr[:, None, None, None] 107 | model_x_eps = predict_x_from_eps(z=z, eps=model_eps, logsnr=logsnr) 108 | wx = torch.sigmoid(-logsnr) 109 | model_x = wx * model_x + (1. - wx) * model_x_eps 110 | else: 111 | raise NotImplementedError(self.mean_type) 112 | 113 | # clipping 114 | if self.clip_x: 115 | model_x = torch.clip(model_x, -1., 1.) 
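        # Re-derive the noise prediction from the (clipped) x prediction so that model_x and
        # model_eps remain consistent with z at this logsnr.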
116 | 117 | model_eps = predict_eps_from_x(z=z, x=model_x, logsnr=logsnr) 118 | return {'model_x': model_x, 119 | 'model_eps': model_eps} 120 | 121 | def ddim_step(self, t, z_t, y=None, step=1024): 122 | logsnr_t = logsnr_schedule_cosine((t+step) / self.timesteps).to(z_t) 123 | logsnr_s = logsnr_schedule_cosine(t / self.timesteps).to(z_t) 124 | model_out = self.run_model(z=z_t, logsnr=logsnr_t.repeat( 125 | z_t.shape[0]), y=y.to(z_t).long() if y is not None else None) 126 | x_pred_t = model_out['model_x'] 127 | eps_pred_t = model_out['model_eps'] 128 | stdv_s = torch.sqrt(torch.sigmoid(-logsnr_s)) 129 | alpha_s = torch.sqrt(torch.sigmoid(logsnr_s)) 130 | z_s_pred = alpha_s * x_pred_t + stdv_s * eps_pred_t 131 | return torch.where(torch.Tensor([t]).to(x_pred_t) == 0, x_pred_t, z_s_pred) 132 | 133 | def sample_loop(self, init_x, y=None, step=1024): 134 | # loop over t = num_steps-1, ..., 0 135 | image = init_x 136 | for t in reversed(range(self.timesteps // step)): 137 | image = self.ddim_step(t * step, image, y, step=step) 138 | return image 139 | 140 | def forward(self, samples: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 141 | if generator is not None: 142 | assert generator.device == torch.device('cpu') 143 | init_x = torch.randn((samples, *self.shape), device='cpu', generator=generator, dtype=torch.double).to(device_id()) 144 | else: 145 | init_x = torch.randn((samples, *self.shape), dtype=torch.double).to(device_id()) 146 | if self.name == 'imagenet64': 147 | y = torch.randint(0, self.num_classes, (samples,)).to(device_id()) 148 | else: 149 | y = None 150 | return self.sample_loop(init_x, y, step=1 << self.logstep) 151 | 152 | 153 | @auto_distribute 154 | def main(_): 155 | data = lib.data.DATASETS[FLAGS.dataset]() 156 | model = ModelFID(FLAGS.dataset, data.res, FLAGS.timesteps, 157 | batch=FLAGS.batch, fid_len=FLAGS.fid_len, ckpt=FLAGS.ckpt) 158 | logdir = lib.util.artifact_dir(FLAGS.dataset, model.logdir) 159 | 160 | 161 | model.initialize_weights(logdir) 162 | model.eval() 163 | 164 | if FLAGS.eval: 165 | model.eval() 166 | with torch.no_grad(): 167 | generator = torch.Generator(device='cpu') 168 | generator.manual_seed(123623113456) 169 | x = model(4, generator) 170 | open('debug_fid_model.png', 'wb').write(lib.util.to_png(x.view(2, 2, *x.shape[1:]))) 171 | import numpy as np 172 | np.save('debug_arr_fid_model.npy', x.detach().cpu().numpy()) 173 | return 174 | 175 | def eval(logstep: int): 176 | model.logstep = logstep 177 | summary = Summary() 178 | t0 = time.time() 179 | with torch.no_grad(): 180 | fid = FID(FLAGS.dataset, (model.COLORS, model.params.res, model.params.res)) 181 | fake_activations, fake_samples = fid.generate_activations_and_samples(model, FLAGS.fid_len) 182 | timesteps = model.params.timesteps >> model.logstep 183 | zip_batch_as_png(fake_samples, logdir / f'samples_{FLAGS.fid_len}_timesteps_{timesteps}.zip') 184 | fidn, fid50 = fid.approximate_fid(fake_activations) 185 | summary.scalar('eval/logstep', logstep) 186 | summary.scalar('eval/timesteps', timesteps) 187 | summary.scalar(f'eval/fid({FLAGS.fid_len})', fidn) 188 | summary.scalar('eval/fid(50000)', fid50) 189 | summary.scalar('system/eval_time', time.time() - t0) 190 | data_logger.write(summary, logstep) 191 | if lib.distributed.is_master(): 192 | print(f'Logstep {logstep} Timesteps {timesteps}') 193 | print(summary) 194 | 195 | with SummaryWriter.create(logdir) as data_logger: 196 | if FLAGS.denoise_steps: 197 | logstep = lib.util.ilog2(FLAGS.timesteps // FLAGS.denoise_steps) 198 | 
eval(logstep) 199 | else: 200 | for logstep in range(lib.util.ilog2(FLAGS.timesteps) + 1): 201 | eval(logstep) 202 | 203 | 204 | if __name__ == '__main__': 205 | flags.DEFINE_bool('eval', False, help='Whether to run model evaluation.') 206 | flags.DEFINE_integer('fid_len', 4096, help='Number of samples for FID evaluation.') 207 | flags.DEFINE_integer('timesteps', 1024, help='Sampling timesteps.') 208 | flags.DEFINE_string('dataset', 'cifar10', help='Dataset.') 209 | flags.DEFINE_integer('denoise_steps', None, help='Denoising timesteps.') 210 | flags.DEFINE_string('ckpt', None, help='Path to the model checkpoint.') 211 | app.run(lib.distributed.main(main)) 212 | -------------------------------------------------------------------------------- /fid/fid_zip.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """Compute FID and approximation at 50,000 for zip file of samples.""" 6 | import time 7 | import zipfile 8 | 9 | import lib 10 | import torch 11 | import torchvision.transforms.functional 12 | from absl import app, flags 13 | from lib.distributed import auto_distribute, device_id, is_master, world_size 14 | from lib.util import FLAGS 15 | from PIL import Image 16 | 17 | 18 | @auto_distribute 19 | def main(argv): 20 | def zip_iterator(filename: str, batch: int): 21 | with zipfile.ZipFile(filename, 'r') as fzip: 22 | x = [] 23 | fn_list = [fn for fn in fzip.namelist() if fn.endswith('.png')] 24 | assert len(fn_list) >= FLAGS.fid_len 25 | for fn in fn_list[device_id()::world_size()]: 26 | with fzip.open(fn, 'r') as f: 27 | y = torchvision.transforms.functional.to_tensor(Image.open(f)) 28 | x.append(2 * y - 1) 29 | if len(x) == batch: 30 | yield torch.stack(x), None 31 | x = [] 32 | 33 | t0 = time.time() 34 | data = lib.data.DATASETS[FLAGS.dataset]() 35 | fake = (x for x in zip_iterator(argv[1], FLAGS.batch // world_size())) 36 | with torch.no_grad(): 37 | fid = lib.eval.FID(FLAGS.dataset, (3, data.res, data.res)) 38 | fake_activations = fid.data_activations(fake, FLAGS.fid_len) 39 | fid, fid50 = fid.approximate_fid(fake_activations) 40 | if is_master(): 41 | print(f'dataset={FLAGS.dataset}') 42 | print(f'fid{FLAGS.fid_len}={fid}') 43 | print(f'fid(50000)={fid50}') 44 | print(f'time={time.time() - t0}') 45 | 46 | 47 | if __name__ == '__main__': 48 | flags.DEFINE_integer('fid_len', 4096, help='Number of samples for FID evaluation.') 49 | flags.DEFINE_string('dataset', 'cifar10', help='Training dataset.') 50 | app.run(lib.distributed.main(main)) 51 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from . import data # noqa 6 | from . import distributed # noqa 7 | from . import eval # noqa 8 | from . import io # noqa 9 | from . import nn # noqa 10 | from . import optim # noqa 11 | from . import train # noqa 12 | from . import util # noqa 13 | -------------------------------------------------------------------------------- /lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | import pathlib 7 | 8 | from .data import * 9 | from .data_torch import * 10 | 11 | 12 | ML_DATA = pathlib.Path(os.getenv('ML_DATA')) 13 | -------------------------------------------------------------------------------- /lib/data/data.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | __all__ = ['Dataset'] 7 | 8 | from typing import Callable 9 | 10 | from absl import flags 11 | 12 | flags.DEFINE_integer('batch', 256, help='Batch size.') 13 | 14 | 15 | class Dataset: 16 | def __init__(self, res: int, make_train: Callable, make_fid: Callable): 17 | self.res = res 18 | self.make_train = make_train 19 | self.make_fid = make_fid 20 | -------------------------------------------------------------------------------- /lib/data/data_torch.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 3 | # 4 | 5 | __all__ = ['make_cifar10', 'make_imagenet64', 'DATASETS'] 6 | 7 | import os 8 | import pathlib 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.distributed 13 | import torch.nn.functional 14 | import torchvision.datasets 15 | import torchvision.transforms.functional 16 | from lib.util import FLAGS 17 | from torch.utils.data import DataLoader 18 | from torchvision.transforms import Compose 19 | 20 | from . import data 21 | 22 | ML_DATA = pathlib.Path(os.getenv('ML_DATA')) 23 | 24 | 25 | class DatasetTorch(data.Dataset): 26 | def make_dataloaders(self, **kwargs) -> Tuple[DataLoader, DataLoader]: 27 | batch = FLAGS.batch 28 | if torch.distributed.is_initialized(): 29 | assert batch % torch.distributed.get_world_size() == 0 30 | batch //= torch.distributed.get_world_size() 31 | return (DataLoader(self.make_train(), shuffle=True, drop_last=True, batch_size=batch, 32 | num_workers=4, prefetch_factor=8, persistent_workers=True, **kwargs), 33 | DataLoader(self.make_fid(), shuffle=True, drop_last=True, batch_size=batch, 34 | num_workers=4, prefetch_factor=8, persistent_workers=True, **kwargs)) 35 | 36 | 37 | def normalize(x: torch.Tensor) -> torch.Tensor: 38 | return 2 * x - 1 39 | 40 | 41 | def make_cifar10() -> DatasetTorch: 42 | transforms = [ 43 | torchvision.transforms.ToTensor(), 44 | normalize, 45 | ] 46 | transforms_fid = Compose(transforms) 47 | transforms_train = Compose(transforms + [torchvision.transforms.RandomHorizontalFlip()]) 48 | fid = lambda: torchvision.datasets.CIFAR10(str(ML_DATA), train=True, transform=transforms_fid, download=True) 49 | train = lambda: torchvision.datasets.CIFAR10(str(ML_DATA), train=True, transform=transforms_train, download=True) 50 | return DatasetTorch(32, train, fid) 51 | 52 | def make_imagenet64() -> DatasetTorch: 53 | transforms = [ 54 | torchvision.transforms.ToTensor(), 55 | torchvision.transforms.CenterCrop(64), 56 | normalize, 57 | ] 58 | transforms_fid = Compose(transforms) 59 | transforms_train = Compose(transforms + [torchvision.transforms.RandomHorizontalFlip()]) 60 | fid = lambda: torchvision.datasets.ImageFolder(str(ML_DATA / "imagenet" / "train"), transform=transforms_fid) 61 | train = lambda: torchvision.datasets.ImageFolder(str(ML_DATA / "imagenet" / "train"), transform=transforms_train) 62 | return DatasetTorch(64, train, fid) 63 | 64 | 65 | DATASETS = { 66 | 'cifar10': make_cifar10, 67 | 'imagenet64': 
make_imagenet64, 68 | } 69 | -------------------------------------------------------------------------------- /lib/distributed.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | """ 6 | Single machine, multi GPU training support 7 | """ 8 | 9 | __all__ = ['WrapModel', 'auto_distribute', 'barrier', 'device', 'device_id', 'gather_tensor', 'is_master', 'main', 10 | 'print', 'reduce_dict_mean', 'tqdm', 'tqdm_module', 'tqdm_with', 'trange', 'world_size', 'wrap'] 11 | 12 | import builtins 13 | import contextlib 14 | import functools 15 | import os 16 | import time 17 | from types import SimpleNamespace 18 | from typing import Callable, Dict, Iterable, Optional 19 | 20 | import torch 21 | import torch.distributed 22 | import tqdm as tqdm_module 23 | 24 | from .util import FLAGS, setup 25 | 26 | 27 | class WrapModel(torch.nn.Module): 28 | def __init__(self, m: torch.nn.Module): 29 | super().__init__() 30 | self.module = m 31 | 32 | def forward(self, *args, **kwargs): 33 | return self.module(*args, **kwargs) 34 | 35 | def auto_distribute(f: Callable) -> Callable: 36 | """Automatically make a function distributed""" 37 | 38 | @functools.wraps(f) 39 | def wrapped(node_rank: Optional[int], world_size: Optional[int], flag_values: SimpleNamespace, *args): 40 | if node_rank is None: 41 | return f(*args) 42 | setup(quiet=True, flags_values=flag_values) 43 | os.environ['MASTER_ADDR'] = 'localhost' 44 | os.environ['MASTER_PORT'] = '12359' 45 | 46 | rank = node_rank 47 | torch.distributed.init_process_group('nccl', rank=rank, world_size=torch.cuda.device_count()) 48 | time.sleep(1) 49 | try: 50 | return f(*args) 51 | finally: 52 | torch.distributed.destroy_process_group() 53 | 54 | return wrapped 55 | 56 | 57 | def barrier(): 58 | if torch.distributed.is_initialized(): 59 | torch.distributed.barrier() 60 | 61 | 62 | def device() -> str: 63 | return f'cuda:{device_id()}' 64 | 65 | 66 | def device_id() -> int: 67 | if not torch.distributed.is_initialized(): 68 | return 0 69 | return torch.distributed.get_rank() % 8 70 | 71 | 72 | def gather_tensor(x: torch.Tensor) -> torch.Tensor: 73 | """Returns a concatenated tensor from all the devices.""" 74 | if not torch.distributed.is_initialized(): 75 | return x 76 | x_list = [torch.empty_like(x) for _ in range(world_size())] 77 | torch.distributed.all_gather(x_list, x, async_op=False) 78 | return torch.cat(x_list, dim=0) 79 | 80 | 81 | def is_master() -> bool: 82 | return device_id() == 0 83 | 84 | 85 | def main(main_fn: Callable) -> Callable: 86 | """Main function that automatically handle multiprocessing""" 87 | 88 | @functools.wraps(main_fn) 89 | def wrapped(*args): 90 | setup() 91 | if torch.cuda.device_count() == 1: 92 | return main_fn(None, None, FLAGS, *args) 93 | num_gpus = torch.cuda.device_count() 94 | torch.multiprocessing.spawn(main_fn, args=(num_gpus, FLAGS, *args), nprocs=num_gpus, join=True) 95 | 96 | return wrapped 97 | 98 | 99 | def print(*args, **kwargs): 100 | if is_master(): 101 | builtins.print(*args, **kwargs) 102 | 103 | 104 | def reduce_dict_mean(d: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 105 | """Mean reduce the tensor in a dict.""" 106 | if not torch.distributed.is_initialized(): 107 | return d 108 | d = {k: (v if isinstance(v, torch.Tensor) else torch.tensor(v)).to(device_id()) for k, v in d.items()} 109 | e = {k: [torch.empty_like(v) for _ in range(world_size())] for k, v 
in d.items()} 110 | # Ideally we should be using all_reduce, but it mysteriously returns incorrect results for the loss 111 | [v.wait() for v in [torch.distributed.all_gather(e[k], d[k], async_op=True) for k in d]] 112 | return {k: sum(v) / len(v) for k, v in e.items()} 113 | 114 | 115 | def tqdm(iterable: Iterable, **kwargs) -> Iterable: 116 | return tqdm_module.tqdm(iterable, **kwargs) 117 | 118 | 119 | def tqdm_with(**kwargs) -> Iterable: 120 | class Noop: 121 | def update(self, *args, **kwargs): 122 | pass 123 | 124 | @contextlib.contextmanager 125 | def noop(): 126 | yield Noop() 127 | 128 | return tqdm_module.tqdm(**kwargs) 129 | 130 | 131 | def trange(*args, **kwargs): 132 | return tqdm_module.trange(*args, **kwargs) 133 | 134 | 135 | def rank() -> int: 136 | if not torch.distributed.is_initialized(): 137 | return 1 138 | return torch.distributed.get_rank() 139 | 140 | 141 | def world_size() -> int: 142 | if not torch.distributed.is_initialized(): 143 | return 1 144 | return torch.distributed.get_world_size() 145 | 146 | 147 | def wrap(m: torch.nn.Module): 148 | if not torch.distributed.is_initialized(): 149 | return WrapModel(m.to(device())) 150 | return torch.nn.parallel.DistributedDataParallel(m.to(device_id()), device_ids=[device_id()]) 151 | -------------------------------------------------------------------------------- /lib/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from .fid import * # noqa 6 | -------------------------------------------------------------------------------- /lib/eval/fid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import pathlib 8 | from typing import Iterable, Tuple 9 | 10 | import numpy as np 11 | import scipy 12 | import torch 13 | import torch.nn.functional 14 | from lib.distributed import (barrier, device_id, gather_tensor, is_master, 15 | trange, world_size) 16 | from lib.util import FLAGS, to_numpy 17 | 18 | from .inception_net import InceptionV3 19 | 20 | ML_DATA = pathlib.Path(os.getenv('ML_DATA')) 21 | 22 | 23 | class FID: 24 | def __init__(self, dataset: str, shape: Tuple[int, int, int], dims: int = 2048): 25 | assert dataset in ('cifar10', 'imagenet64') 26 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 27 | self.dims = dims 28 | self.shape = shape 29 | self.model = InceptionV3([block_idx]).eval().to(device_id()) 30 | self.post = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten()) 31 | if pathlib.Path(f'{ML_DATA}/{dataset}_activation_mean.npy').exists(): 32 | self.real_activations_mean = torch.from_numpy(np.load(f'{ML_DATA}/{dataset}_activation_mean.npy')) 33 | self.real_activations_std = torch.from_numpy(np.load(f'{ML_DATA}/{dataset}_activation_std.npy')) 34 | 35 | def generate_activations_and_samples(self, model: torch.nn.Module, n: int) -> Tuple[torch.Tensor, torch.Tensor]: 36 | barrier() 37 | samples = torch.empty((n, *self.shape)) 38 | activations = torch.empty((n, self.dims), dtype=torch.double).to(device_id()) 39 | k = world_size() 40 | assert FLAGS.batch % k == 0 41 | for i in trange(0, n, FLAGS.batch, desc='Generating FID samples'): 42 | p = min(n - i, FLAGS.batch) 43 | x = model(FLAGS.batch // k).float() 44 | # Discretize to {0,...,255} and project back to [-1,1] 45 | x = torch.round(127.5 * (x + 1)).clamp(0, 255) / 127.5 - 1 46 | y = self.post(self.model(x)[0]) 47 | samples[i: i + p] = gather_tensor(x)[:p] 48 | activations[i: i + p] = gather_tensor(y)[:p] 49 | return activations, samples 50 | 51 | def data_activations(self, iterator: Iterable, n: int, cpu: bool = False) -> torch.Tensor: 52 | activations = torch.empty((n, self.dims), dtype=torch.double) 53 | if not cpu: 54 | activations = activations.to(device_id()) 55 | k = world_size() 56 | it = iter(iterator) 57 | for i in trange(0, n, FLAGS.batch, desc='Calculating activations'): 58 | x = next(it)[0] 59 | p = min((n - i) // k, x.shape[0]) 60 | y = self.post(self.model(x.to(device_id()))[0]) 61 | activations[i: i + k * p] = gather_tensor(y[:p]).cpu() if cpu else gather_tensor(y[:p]) 62 | return activations 63 | 64 | @staticmethod 65 | def calculate_activation_statistics(activations: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 66 | return activations.mean(0), torch.cov(activations.T) 67 | 68 | def calculate_fid(self, fake_activations: torch.Tensor) -> float: 69 | m_fake, s_fake = self.calculate_activation_statistics(fake_activations) 70 | m_real = self.real_activations_mean.to(m_fake) 71 | s_real = self.real_activations_std.to(s_fake) 72 | return self.calculate_frechet_distance(m_fake, s_fake, m_real, s_real) 73 | 74 | def approximate_fid(self, fake_activations: torch.Tensor, n: int = 50_000) -> Tuple[float, float]: 75 | k = fake_activations.shape[0] 76 | fid = self.calculate_fid(fake_activations) 77 | fid_half = [] 78 | for it in range(5): 79 | sel_fake = np.random.choice(k, k // 2, replace=False) 80 | fid_half.append(self.calculate_fid(fake_activations[sel_fake])) 81 | fid_half = np.median(fid_half) 82 | return fid, fid + (fid_half - fid) * (k / n - 1) 83 | 84 | def calculate_frechet_distance(self, mu1: torch.Tensor, sigma1: torch.Tensor, 85 | mu2: torch.Tensor, 
sigma2: torch.Tensor, eps: float = 1e-6) -> float: 86 | """Numpy implementation of the Frechet Distance. 87 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 88 | and X_2 ~ N(mu_2, C_2) is 89 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 90 | Stable version by Dougal J. Sutherland. 91 | Params: 92 | -- mu1 : Numpy array containing the activations of a layer of the 93 | inception net (like returned by the function 'get_predictions') 94 | for generated samples. 95 | -- mu2 : The sample mean over activations, precalculated on an 96 | representative data set. 97 | -- sigma1: The covariance matrix over activations for generated samples. 98 | -- sigma2: The covariance matrix over activations, precalculated on an 99 | representative data set. 100 | Returns: 101 | -- : The Frechet Distance. 102 | """ 103 | if not is_master(): 104 | return 0 105 | mu1, mu2, sigma1, sigma2 = (to_numpy(x) for x in (mu1, mu2, sigma1, sigma2)) 106 | mu1 = np.atleast_1d(mu1) 107 | mu2 = np.atleast_1d(mu2) 108 | sigma1 = np.atleast_2d(sigma1) 109 | sigma2 = np.atleast_2d(sigma2) 110 | assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' 111 | assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' 112 | diff = mu1 - mu2 113 | 114 | # Product might be almost singular 115 | covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)[0] 116 | if not np.isfinite(covmean).all(): 117 | print(f'fid calculation produces singular product; adding {eps} to diagonal of cov estimates') 118 | offset = np.eye(sigma1.shape[0]) * eps 119 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 120 | 121 | # Numerical error might give slight imaginary component 122 | if np.iscomplexobj(covmean): 123 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 124 | m = np.max(np.abs(covmean.imag)) 125 | raise ValueError(f'Imaginary component {m}') 126 | covmean = covmean.real 127 | 128 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean) 129 | -------------------------------------------------------------------------------- /lib/eval/inception_net.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | import os 6 | import pathlib 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision 12 | 13 | import lib 14 | 15 | try: 16 | from torchvision.models.utils import load_state_dict_from_url 17 | except ImportError: 18 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 19 | 20 | # Inception weights ported to Pytorch from 21 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 23 | FID_WEIGHTS_FILE = 'pt_inception-2015-12-05-6726825d.pth' 24 | 25 | 26 | class InceptionV3(nn.Module): 27 | """Pretrained InceptionV3 network returning feature maps""" 28 | 29 | # Index of default block of inception to return, 30 | # corresponds to output of final average pooling 31 | DEFAULT_BLOCK_INDEX = 3 32 | 33 | # Maps feature dimensionality to their output blocks indices 34 | BLOCK_INDEX_BY_DIM = { 35 | 64: 0, # First max pooling features 36 | 192: 1, # Second max pooling featurs 37 | 768: 2, # Pre-aux classifier features 38 | 2048: 3 # Final average pooling features 39 | } 40 | 41 | def __init__(self, 42 | output_blocks=(DEFAULT_BLOCK_INDEX,), 43 | resize_input=True, 44 | normalize_input=False, 45 | requires_grad=False, 46 | use_fid_inception=True): 47 | """Build pretrained InceptionV3 48 | 49 | Parameters 50 | ---------- 51 | output_blocks : list of int 52 | Indices of blocks to return features of. Possible values are: 53 | - 0: corresponds to output of first max pooling 54 | - 1: corresponds to output of second max pooling 55 | - 2: corresponds to output which is fed to aux classifier 56 | - 3: corresponds to output of final average pooling 57 | resize_input : bool 58 | If true, bilinearly resizes input to width and height 299 before 59 | feeding input to model. As the network without fully connected 60 | layers is fully convolutional, it should be able to handle inputs 61 | of arbitrary size, so resizing might not be strictly needed 62 | normalize_input : bool 63 | If true, scales the input from range (0, 1) to the range the 64 | pretrained Inception network expects, namely (-1, 1) 65 | requires_grad : bool 66 | If true, parameters of the model require gradients. Possibly useful 67 | for finetuning the network 68 | use_fid_inception : bool 69 | If true, uses the pretrained Inception model used in Tensorflow's 70 | FID implementation. If false, uses the pretrained Inception model 71 | available in torchvision. The FID Inception model has different 72 | weights and a slightly different structure from torchvision's 73 | Inception model. If you want to compute FID scores, you are 74 | strongly advised to set this parameter to true to get comparable 75 | results. 
76 | """ 77 | super(InceptionV3, self).__init__() 78 | 79 | self.resize_input = resize_input 80 | self.normalize_input = normalize_input 81 | self.output_blocks = sorted(output_blocks) 82 | self.last_needed_block = max(output_blocks) 83 | 84 | assert self.last_needed_block <= 3, \ 85 | 'Last possible output block index is 3' 86 | 87 | self.blocks = nn.ModuleList() 88 | 89 | if use_fid_inception: 90 | inception = fid_inception_v3() 91 | else: 92 | inception = _inception_v3(pretrained=True) 93 | 94 | # Block 0: input to maxpool1 95 | block0 = [ 96 | inception.Conv2d_1a_3x3, 97 | inception.Conv2d_2a_3x3, 98 | inception.Conv2d_2b_3x3, 99 | nn.MaxPool2d(kernel_size=3, stride=2) 100 | ] 101 | self.blocks.append(nn.Sequential(*block0)) 102 | 103 | # Block 1: maxpool1 to maxpool2 104 | if self.last_needed_block >= 1: 105 | block1 = [ 106 | inception.Conv2d_3b_1x1, 107 | inception.Conv2d_4a_3x3, 108 | nn.MaxPool2d(kernel_size=3, stride=2) 109 | ] 110 | self.blocks.append(nn.Sequential(*block1)) 111 | 112 | # Block 2: maxpool2 to aux classifier 113 | if self.last_needed_block >= 2: 114 | block2 = [ 115 | inception.Mixed_5b, 116 | inception.Mixed_5c, 117 | inception.Mixed_5d, 118 | inception.Mixed_6a, 119 | inception.Mixed_6b, 120 | inception.Mixed_6c, 121 | inception.Mixed_6d, 122 | inception.Mixed_6e, 123 | ] 124 | self.blocks.append(nn.Sequential(*block2)) 125 | 126 | # Block 3: aux classifier to final avgpool 127 | if self.last_needed_block >= 3: 128 | block3 = [ 129 | inception.Mixed_7a, 130 | inception.Mixed_7b, 131 | inception.Mixed_7c, 132 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 133 | ] 134 | self.blocks.append(nn.Sequential(*block3)) 135 | 136 | for param in self.parameters(): 137 | param.requires_grad = requires_grad 138 | 139 | def forward(self, inp): 140 | """Get Inception feature maps 141 | 142 | Parameters 143 | ---------- 144 | inp : torch.autograd.Variable 145 | Input tensor of shape Bx3xHxW. Values are expected to be in 146 | range (0, 1) 147 | 148 | Returns 149 | ------- 150 | List of torch.autograd.Variable, corresponding to the selected output 151 | block, sorted ascending by index 152 | """ 153 | outp = [] 154 | x = inp 155 | 156 | if self.resize_input: 157 | x = F.interpolate(x, 158 | size=(299, 299), 159 | mode='bilinear', 160 | align_corners=False) 161 | 162 | if self.normalize_input: 163 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 164 | 165 | for idx, block in enumerate(self.blocks): 166 | x = block(x) 167 | if idx in self.output_blocks: 168 | outp.append(x) 169 | 170 | if idx == self.last_needed_block: 171 | break 172 | 173 | return outp 174 | 175 | 176 | def _inception_v3(*args, **kwargs): 177 | """Wraps `torchvision.models.inception_v3` 178 | 179 | Skips default weight inititialization if supported by torchvision version. 180 | See https://github.com/mseitzer/pytorch-fid/issues/28. 181 | """ 182 | try: 183 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 184 | except ValueError: 185 | # Just a caution against weird version strings 186 | version = (0,) 187 | 188 | if version >= (0, 6): 189 | kwargs['init_weights'] = False 190 | 191 | return torchvision.models.inception_v3(*args, **kwargs) 192 | 193 | 194 | def fid_inception_v3(): 195 | """Build pretrained Inception model for FID computation 196 | 197 | The Inception model for FID computation uses a different set of weights 198 | and has a slightly different structure than torchvision's Inception. 
199 | 200 | This method first constructs torchvision's Inception and then patches the 201 | necessary parts that are different in the FID Inception model. 202 | """ 203 | inception = _inception_v3(num_classes=1008, 204 | aux_logits=False, 205 | pretrained=False) 206 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 207 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 208 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 209 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 210 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 211 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 212 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 213 | inception.Mixed_7b = FIDInceptionE_1(1280) 214 | inception.Mixed_7c = FIDInceptionE_2(2048) 215 | 216 | local_fid_weights = pathlib.Path(lib.data.ML_DATA / os.path.basename(FID_WEIGHTS_URL)) 217 | if local_fid_weights.is_file(): 218 | state_dict = torch.load(local_fid_weights) 219 | else: 220 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 221 | inception.load_state_dict(state_dict) 222 | return inception 223 | 224 | 225 | class FIDInceptionA(torchvision.models.inception.InceptionA): 226 | """InceptionA block patched for FID computation""" 227 | 228 | def __init__(self, in_channels, pool_features): 229 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 230 | 231 | def forward(self, x): 232 | branch1x1 = self.branch1x1(x) 233 | 234 | branch5x5 = self.branch5x5_1(x) 235 | branch5x5 = self.branch5x5_2(branch5x5) 236 | 237 | branch3x3dbl = self.branch3x3dbl_1(x) 238 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 239 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 240 | 241 | # Patch: Tensorflow's average pool does not use the padded zero's in 242 | # its average calculation 243 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 244 | count_include_pad=False) 245 | branch_pool = self.branch_pool(branch_pool) 246 | 247 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 248 | return torch.cat(outputs, 1) 249 | 250 | 251 | class FIDInceptionC(torchvision.models.inception.InceptionC): 252 | """InceptionC block patched for FID computation""" 253 | 254 | def __init__(self, in_channels, channels_7x7): 255 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 256 | 257 | def forward(self, x): 258 | branch1x1 = self.branch1x1(x) 259 | 260 | branch7x7 = self.branch7x7_1(x) 261 | branch7x7 = self.branch7x7_2(branch7x7) 262 | branch7x7 = self.branch7x7_3(branch7x7) 263 | 264 | branch7x7dbl = self.branch7x7dbl_1(x) 265 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 266 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 267 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 268 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 269 | 270 | # Patch: Tensorflow's average pool does not use the padded zero's in 271 | # its average calculation 272 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 273 | count_include_pad=False) 274 | branch_pool = self.branch_pool(branch_pool) 275 | 276 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 277 | return torch.cat(outputs, 1) 278 | 279 | 280 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 281 | """First InceptionE block patched for FID computation""" 282 | 283 | def __init__(self, in_channels): 284 | super(FIDInceptionE_1, self).__init__(in_channels) 285 | 286 | def forward(self, x): 287 | branch1x1 = self.branch1x1(x) 288 | 289 | 
branch3x3 = self.branch3x3_1(x) 290 | branch3x3 = [ 291 | self.branch3x3_2a(branch3x3), 292 | self.branch3x3_2b(branch3x3), 293 | ] 294 | branch3x3 = torch.cat(branch3x3, 1) 295 | 296 | branch3x3dbl = self.branch3x3dbl_1(x) 297 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 298 | branch3x3dbl = [ 299 | self.branch3x3dbl_3a(branch3x3dbl), 300 | self.branch3x3dbl_3b(branch3x3dbl), 301 | ] 302 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 303 | 304 | # Patch: Tensorflow's average pool does not use the padded zero's in 305 | # its average calculation 306 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 307 | count_include_pad=False) 308 | branch_pool = self.branch_pool(branch_pool) 309 | 310 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 311 | return torch.cat(outputs, 1) 312 | 313 | 314 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 315 | """Second InceptionE block patched for FID computation""" 316 | 317 | def __init__(self, in_channels): 318 | super(FIDInceptionE_2, self).__init__(in_channels) 319 | 320 | def forward(self, x): 321 | branch1x1 = self.branch1x1(x) 322 | 323 | branch3x3 = self.branch3x3_1(x) 324 | branch3x3 = [ 325 | self.branch3x3_2a(branch3x3), 326 | self.branch3x3_2b(branch3x3), 327 | ] 328 | branch3x3 = torch.cat(branch3x3, 1) 329 | 330 | branch3x3dbl = self.branch3x3dbl_1(x) 331 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 332 | branch3x3dbl = [ 333 | self.branch3x3dbl_3a(branch3x3dbl), 334 | self.branch3x3dbl_3b(branch3x3dbl), 335 | ] 336 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 337 | 338 | # Patch: The FID Inception model uses max pooling instead of average 339 | # pooling. This is likely an error in this specific Inception 340 | # implementation, as other Inception models use average pooling here 341 | # (which matches the description in the paper). 342 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 343 | branch_pool = self.branch_pool(branch_pool) 344 | 345 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 346 | return torch.cat(outputs, 1) 347 | -------------------------------------------------------------------------------- /lib/io.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
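# --- Illustrative usage sketch (not part of the original sources) for the InceptionV3
# --- wrapper in lib/eval/inception_net.py above, mirroring how lib/eval/fid.py pools its
# --- output into 2048-d activations. The batch below is random data purely for shape
# --- illustration; inputs are assumed to already be in [-1, 1] (normalize_input=False),
# --- and the pretrained FID weights are assumed to be available locally or downloadable
# --- from FID_WEIGHTS_URL.
import torch
from lib.eval.inception_net import InceptionV3

_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).eval()
_post = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten())
with torch.no_grad():
    _x = torch.rand(8, 3, 32, 32) * 2 - 1    # fake image batch in [-1, 1]
    _feats = _post(_model(_x)[0])            # pooled activations of shape (8, 2048)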
4 | # 5 | 6 | __all__ = ['Checkpoint', 'Summary', 'SummaryWriter', 'zip_batch_as_png'] 7 | 8 | import enum 9 | import io 10 | import os 11 | import pathlib 12 | import zipfile 13 | from time import time 14 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union 15 | 16 | import imageio 17 | import matplotlib.figure 18 | import numpy as np 19 | import torch 20 | import torch.nn 21 | from tensorboard.compat.proto import event_pb2, summary_pb2 22 | from tensorboard.summary.writer.event_file_writer import EventFileWriter 23 | from tensorboard.util.tensor_util import make_tensor_proto 24 | 25 | from .distributed import is_master, print, reduce_dict_mean 26 | from .util import to_numpy, to_png 27 | 28 | 29 | class Checkpoint: 30 | DIR_NAME: str = 'ckpt' 31 | FILE_MATCH: str = '*.pth' 32 | FILE_FORMAT: str = '%012d.pth' 33 | 34 | def __init__(self, 35 | model: torch.nn.Module, 36 | logdir: pathlib.Path, 37 | keep_ckpts: int = 0): 38 | self.model = model 39 | self.logdir = logdir / self.DIR_NAME 40 | self.keep_ckpts = keep_ckpts 41 | 42 | @staticmethod 43 | def checkpoint_idx(filename: str) -> int: 44 | return int(os.path.basename(filename).split('.')[0]) 45 | 46 | def restore(self, idx: Optional[int] = None) -> Tuple[int, Optional[pathlib.Path]]: 47 | if idx is None: 48 | all_ckpts = self.logdir.glob(self.FILE_MATCH) 49 | try: 50 | idx = self.checkpoint_idx(max(str(x) for x in all_ckpts)) 51 | except ValueError: 52 | return 0, None 53 | ckpt = self.logdir / (self.FILE_FORMAT % idx) 54 | print(f'Resuming from: {ckpt}') 55 | with ckpt.open('rb') as f: 56 | self.model.load_state_dict(torch.load(f, map_location='cpu')) 57 | return idx, ckpt 58 | 59 | def save(self, idx: int) -> None: 60 | if not is_master(): # only save master's state 61 | return 62 | self.logdir.mkdir(exist_ok=True, parents=True) 63 | ckpt = self.logdir / (self.FILE_FORMAT % idx) 64 | with ckpt.open('wb') as f: 65 | torch.save(self.model.state_dict(), f) 66 | old_ckpts = sorted(self.logdir.glob(self.FILE_MATCH), key=str) 67 | for ckpt in old_ckpts[:-self.keep_ckpts]: 68 | ckpt.unlink() 69 | 70 | def save_file(self, model: torch.nn.Module, filename: str) -> None: 71 | if not is_master(): # only save master's state 72 | return 73 | self.logdir.mkdir(exist_ok=True, parents=True) 74 | with (self.logdir / filename).open('wb') as f: 75 | torch.save(model.state_dict(), f) 76 | 77 | class Summary(dict): 78 | """Helper to generate summary_pb2.Summary protobufs.""" 79 | 80 | # Inspired from https://github.com/google/objax/blob/master/objax/jaxboard.py 81 | 82 | class ProtoMode(enum.Flag): 83 | """Enum describing what to export to a tensorboard proto.""" 84 | 85 | IMAGES = enum.auto() 86 | VIDEOS = enum.auto() 87 | OTHERS = enum.auto() 88 | ALL = IMAGES | VIDEOS | OTHERS 89 | 90 | class Scalar: 91 | """Class for a Summary Scalar.""" 92 | 93 | def __init__(self, reduce: Callable[[Sequence[float]], float] = np.mean): 94 | self.values = [] 95 | self.reduce = reduce 96 | 97 | def __call__(self): 98 | return self.reduce(self.values) 99 | 100 | class Text: 101 | """Class for a Summary Text.""" 102 | 103 | def __init__(self, text: str): 104 | self.text = text 105 | 106 | class Image: 107 | """Class for a Summary Image.""" 108 | 109 | def __init__(self, shape: Tuple[int, int, int], image_bytes: bytes): 110 | self.shape = shape # (C, H, W) 111 | self.image_bytes = image_bytes 112 | 113 | class Video: 114 | """Class for a Summary Video.""" 115 | 116 | def __init__(self, shape: Tuple[int, int, int], image_bytes: bytes): 117 | 
self.shape = shape # (C, H, W) 118 | self.image_bytes = image_bytes 119 | 120 | def from_metrics(self, metrics: Dict[str, torch.Tensor]): 121 | metrics = reduce_dict_mean(metrics) 122 | for k, v in metrics.items(): 123 | v = to_numpy(v) 124 | if np.isnan(v): 125 | raise ValueError('NaN', k) 126 | self.scalar(k, float(v)) 127 | 128 | def gif(self, tag: str, imgs: List[np.ndarray]): 129 | assert imgs 130 | try: 131 | height, width, _ = imgs[0].shape 132 | vid_save_path = '/tmp/video.gif' 133 | imageio.mimsave(vid_save_path, [np.array(img) for i, img in enumerate(imgs) if i % 2 == 0], fps=30) 134 | with open(vid_save_path, 'rb') as f: 135 | encoded_image_string = f.read() 136 | self[tag] = Summary.Video((3, height, width), encoded_image_string) 137 | except AttributeError: 138 | # the kitchen and hand manipulation envs do not support rendering. 139 | return 140 | 141 | def plot(self, tag: str, fig: matplotlib.figure.Figure): 142 | byte_data = io.BytesIO() 143 | fig.savefig(byte_data, format='png') 144 | img_w, img_h = fig.canvas.get_width_height() 145 | self[tag] = Summary.Image((4, img_h, img_w), byte_data.getvalue()) 146 | 147 | def png(self, tag: str, img: Union[np.ndarray, torch.Tensor]): 148 | if img.ndim == 3: 149 | shape = (img.shape[2], *img.shape[:2]) 150 | elif img.ndim == 5: 151 | shape = (img.shape[2], img.shape[0] * img.shape[3], img.shape[1] * img.shape[4]) 152 | else: 153 | raise ValueError(f'Unsupported image shape {img.shape}') 154 | self[tag] = Summary.Image(shape, to_png(img)) 155 | 156 | def scalar(self, tag: str, value: float, reduce: Callable[[Sequence[float]], float] = np.mean): 157 | if tag not in self: 158 | self[tag] = Summary.Scalar(reduce) 159 | self[tag].values.append(value) 160 | 161 | def text(self, tag: str, text: str): 162 | self[tag] = Summary.Text(text) 163 | 164 | def proto(self, mode: ProtoMode = ProtoMode.ALL): 165 | entries = [] 166 | for tag, value in self.items(): 167 | if isinstance(value, Summary.Scalar): 168 | if mode & self.ProtoMode.OTHERS: 169 | entries.append(summary_pb2.Summary.Value(tag=tag, simple_value=value())) 170 | elif isinstance(value, Summary.Text): 171 | if mode & self.ProtoMode.OTHERS: 172 | metadata = summary_pb2.SummaryMetadata( 173 | plugin_data=summary_pb2.SummaryMetadata.PluginData(plugin_name='text')) 174 | entries.append(summary_pb2.Summary.Value( 175 | tag=tag, metadata=metadata, 176 | tensor=make_tensor_proto(values=value.text.encode('utf-8'), shape=(1,)))) 177 | elif isinstance(value, (Summary.Image, Summary.Video)): 178 | if mode & (self.ProtoMode.IMAGES | self.ProtoMode.VIDEOS): 179 | image_summary = summary_pb2.Summary.Image( 180 | encoded_image_string=value.image_bytes, 181 | colorspace=value.shape[0], # RGBA 182 | height=value.shape[1], 183 | width=value.shape[2]) 184 | entries.append(summary_pb2.Summary.Value(tag=tag, image=image_summary)) 185 | else: 186 | raise NotImplementedError(tag, value) 187 | return summary_pb2.Summary(value=entries) 188 | 189 | def to_dict(self) -> Dict[str, Any]: 190 | entries = {} 191 | for tag, value in self.items(): 192 | if isinstance(value, Summary.Scalar): 193 | entries[tag] = float(value()) 194 | elif isinstance(value, (Summary.Text, Summary.Image, Summary.Video)): 195 | pass 196 | else: 197 | raise NotImplementedError(tag, value) 198 | return entries 199 | 200 | def __str__(self) -> str: 201 | return '\n'.join(f' {k:40s}: {v:.6f}' for k, v in self.to_dict().items()) 202 | 203 | 204 | class SummaryForgetter: 205 | """Used as placeholder for workers, it basically does nothing.""" 206 | 
207 | def __init__(self, 208 | logdir: pathlib.Path, 209 | queue_size: int = 5, 210 | write_interval: int = 5): 211 | self.logdir = logdir 212 | 213 | def write(self, summary: Summary, step: int): 214 | pass 215 | 216 | def close(self): 217 | """Flushes the event file to disk and close the file.""" 218 | pass 219 | 220 | def __enter__(self): 221 | return self 222 | 223 | def __exit__(self, exc_type, exc_val, exc_tb): 224 | self.close() 225 | 226 | 227 | # Inspired from https://github.com/google/objax/blob/master/objax/jaxboard.py 228 | class SummaryWriter: 229 | """Writes entries to logdir to be consumed by TensorBoard and Weight & Biases.""" 230 | 231 | def __init__(self, 232 | logdir: pathlib.Path, 233 | queue_size: int = 5, 234 | write_interval: int = 5): 235 | (logdir / 'tb').mkdir(exist_ok=True, parents=True) 236 | self.logdir = logdir 237 | self.writer = EventFileWriter(logdir / 'tb', queue_size, write_interval) 238 | self.writer_image = EventFileWriter(logdir / 'tb', queue_size, write_interval, filename_suffix='images') 239 | 240 | def write(self, summary: Summary, step: int): 241 | """Add on event to the event file.""" 242 | self.writer.add_event( 243 | event_pb2.Event(step=step, summary=summary.proto(summary.ProtoMode.OTHERS), 244 | wall_time=time())) 245 | self.writer_image.add_event( 246 | event_pb2.Event(step=step, summary=summary.proto(summary.ProtoMode.IMAGES), 247 | wall_time=time())) 248 | 249 | def close(self): 250 | """Flushes the event file to disk and close the file.""" 251 | self.writer.close() 252 | self.writer_image.close() 253 | 254 | def __enter__(self): 255 | return self 256 | 257 | def __exit__(self, exc_type, exc_val, exc_tb): 258 | self.close() 259 | 260 | @classmethod 261 | def create(cls, logdir: pathlib.Path, 262 | queue_size: int = 5, 263 | write_interval: int = 5) -> Union[SummaryForgetter, 'SummaryWriter']: 264 | if is_master(): 265 | return cls(logdir, queue_size, write_interval) 266 | return SummaryForgetter(logdir, queue_size, write_interval) 267 | 268 | 269 | def zip_batch_as_png(x: Union[np.ndarray, torch.Tensor], filename: pathlib.Path): 270 | if not is_master(): 271 | return 272 | assert x.ndim == 4 273 | with zipfile.ZipFile(filename, 'w') as fzip: 274 | for i in range(x.shape[0]): 275 | with fzip.open(f'{i:06d}.png', 'w') as f: 276 | f.write(to_png(x[i])) 277 | -------------------------------------------------------------------------------- /lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from . import functional # noqa 6 | from . import ncsnpp # noqa 7 | from .nn import * # noqa 8 | -------------------------------------------------------------------------------- /lib/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from .functional import * # noqa 6 | -------------------------------------------------------------------------------- /lib/nn/functional/functional.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
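# --- Illustrative sketch (not part of the original sources): typical use of the
# --- Summary / SummaryWriter pair from lib/io.py above. The log directory, tags and
# --- values here are hypothetical; SummaryWriter.create() returns a SummaryForgetter
# --- no-op on non-master ranks.
import pathlib
from lib.io import Summary, SummaryWriter

_logdir = pathlib.Path('/tmp/logdir_example')
with SummaryWriter.create(_logdir) as _writer:
    _summary = Summary()
    _summary.scalar('loss/train', 0.123)
    _summary.text('config', 'batch=256 lr=1e-3')
    _writer.write(_summary, step=1000)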
4 | # 5 | 6 | __all__ = ['expand_to', 'float_index', 'label_smoothing', 'set_bn_momentum', 'set_cond', 'set_dropout', 7 | 'default', 'log'] 8 | 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn.functional 13 | 14 | 15 | def expand_to(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 16 | """Expand x to the number of dimensions in y.""" 17 | return x.view(x.shape + (1,) * (y.ndim - x.ndim)) 18 | 19 | 20 | def float_index(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor: 21 | a, b = x[i.long()], x[i.ceil().long()] 22 | return a + i.frac() * (b - a) 23 | 24 | 25 | def label_smoothing(x: torch.Tensor, q: float) -> torch.Tensor: 26 | u = torch.zeros_like(x) + 1 / x.shape[-1] 27 | return x + q * (u - x) 28 | 29 | 30 | def set_bn_momentum(m: torch.nn.Module, momentum: float): 31 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 32 | print('Set momentum for', m) 33 | m.momentum = momentum 34 | 35 | 36 | def set_cond(cond: Optional[torch.Tensor]): 37 | def apply_op(m: torch.nn.Module): 38 | if hasattr(m, 'set_cond'): 39 | m.set_cond(cond) 40 | 41 | return apply_op 42 | 43 | 44 | def set_dropout(m: torch.nn.Module, p: float): 45 | if isinstance(m, torch.nn.modules.dropout._DropoutNd): 46 | print(f'Set dropout to {p} for', m) 47 | m.p = p 48 | 49 | 50 | def default(val, d): 51 | if val is not None: 52 | return val 53 | return d() if callable(d) else d 54 | 55 | 56 | def log(t, eps=1e-20): 57 | return torch.log(t.clamp(min=eps)) 58 | -------------------------------------------------------------------------------- /lib/nn/ncsnpp/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /lib/nn/ncsnpp/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # -------------------------------------------------------------------------------- /lib/nn/ncsnpp/layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # This file is adapted and modified from https://github.com/yang-song/score_sde_pytorch. 6 | import math 7 | import string 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def variance_scaling(scale, mode, distribution, 15 | in_axis=1, out_axis=0, 16 | dtype=torch.float32, 17 | device='cpu'): 18 | 19 | def _compute_fans(shape, in_axis=1, out_axis=0): 20 | receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] 21 | fan_in = shape[in_axis] * receptive_field_size 22 | fan_out = shape[out_axis] * receptive_field_size 23 | return fan_in, fan_out 24 | 25 | def init(shape, dtype=dtype, device=device): 26 | fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) 27 | if mode == "fan_in": 28 | denominator = fan_in 29 | elif mode == "fan_out": 30 | denominator = fan_out 31 | elif mode == "fan_avg": 32 | denominator = (fan_in + fan_out) / 2 33 | else: 34 | raise ValueError( 35 | "invalid mode for variance scaling initializer: {}".format(mode)) 36 | variance = scale / denominator 37 | if distribution == "normal": 38 | return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) 39 | elif distribution == "uniform": 40 | return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) 41 | else: 42 | raise ValueError("invalid distribution for variance scaling initializer") 43 | 44 | return init 45 | 46 | 47 | def default_init(scale=1.): 48 | """Initialize the same way as per DDPM.""" 49 | scale = 1e-10 if scale == 0 else scale 50 | return variance_scaling(scale, 'fan_avg', 'uniform') 51 | 52 | 53 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 54 | """Return timestep embedding for positional embeddings.""" 55 | assert len(timesteps.shape) == 1 56 | half_dim = embedding_dim // 2 57 | 58 | # magic number 10000 is from transformers 59 | emb = math.log(max_positions) / (half_dim - 1) 60 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 61 | emb = timesteps.float()[:, None] * emb[None, :] 62 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 63 | 64 | if embedding_dim % 2 == 1: # zero pad 65 | emb = F.pad(emb, (0, 1), mode='constant') 66 | 67 | assert emb.shape == (timesteps.shape[0], embedding_dim) 68 | 69 | return emb 70 | 71 | 72 | def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): 73 | """Return 1x1 convolution with DDPM initialization.""" 74 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) 75 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 76 | if bias: 77 | nn.init.zeros_(conv.bias) 78 | return conv 79 | 80 | 81 | def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): 82 | """Return 3x3 convolution with DDPM initialization.""" 83 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, 84 | dilation=dilation, bias=bias) 85 | conv.weight.data = default_init(init_scale)(conv.weight.data.shape) 86 | if bias: 87 | nn.init.zeros_(conv.bias) 88 | return conv 89 | 90 | 91 | def _einsum(a, b, c, x, y): 92 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 93 | return torch.einsum(einsum_str, x, y) 94 | 95 | 96 | def contract_inner(x, y): 97 | """Return tensordot(x, y, 1).""" 98 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 99 
| y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) 100 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 101 | out_chars = x_chars[:-1] + y_chars[1:] 102 | return _einsum(x_chars, y_chars, out_chars, x, y) 103 | 104 | 105 | class NIN(nn.Module): 106 | def __init__(self, in_dim, num_units, init_scale=0.1): 107 | super().__init__() 108 | self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) 109 | self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) 110 | 111 | def forward(self, x): 112 | x = x.permute(0, 2, 3, 1) 113 | y = contract_inner(x, self.W) + self.b 114 | return y.permute(0, 3, 1, 2) 115 | -------------------------------------------------------------------------------- /lib/nn/ncsnpp/layerspp.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # This file is adapted and modified from https://github.com/yang-song/score_sde_pytorch. 6 | """Layers for defining NCSN++.""" 7 | from . import layers 8 | from . import up_or_down_sampling 9 | import torch.nn as nn 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy as np 13 | 14 | conv1x1 = layers.ddpm_conv1x1 15 | conv3x3 = layers.ddpm_conv3x3 16 | NIN = layers.NIN 17 | default_init = layers.default_init 18 | 19 | 20 | class GaussianFourierProjection(nn.Module): 21 | """Gaussian Fourier embeddings for noise levels.""" 22 | 23 | def __init__(self, embedding_size=256, scale=1.0): 24 | super().__init__() 25 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 26 | 27 | def forward(self, x): 28 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 29 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 30 | 31 | 32 | class Combine(nn.Module): 33 | """Combine information from skip connections.""" 34 | 35 | def __init__(self, dim1, dim2, method='cat'): 36 | super().__init__() 37 | self.Conv_0 = conv1x1(dim1, dim2) 38 | self.method = method 39 | 40 | def forward(self, x, y): 41 | h = self.Conv_0(x) 42 | if self.method == 'cat': 43 | return torch.cat([h, y], dim=1) 44 | elif self.method == 'sum': 45 | return h + y 46 | else: 47 | raise ValueError(f'Method {self.method} not recognized.') 48 | 49 | 50 | class AttnBlockpp(nn.Module): 51 | """Channel-wise self-attention block. Modified from DDPM.""" 52 | 53 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 54 | super().__init__() 55 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 56 | eps=1e-6) 57 | self.NIN_0 = NIN(channels, channels) 58 | self.NIN_1 = NIN(channels, channels) 59 | self.NIN_2 = NIN(channels, channels) 60 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 61 | self.skip_rescale = skip_rescale 62 | 63 | def forward(self, x): 64 | B, C, H, W = x.shape 65 | h = self.GroupNorm_0(x) 66 | q = self.NIN_0(h) 67 | k = self.NIN_1(h) 68 | v = self.NIN_2(h) 69 | 70 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 71 | w = torch.reshape(w, (B, H, W, H * W)) 72 | w = F.softmax(w, dim=-1) 73 | w = torch.reshape(w, (B, H, W, H, W)) 74 | h = torch.einsum('bhwij,bcij->bchw', w, v) 75 | h = self.NIN_3(h) 76 | if not self.skip_rescale: 77 | return x + h 78 | else: 79 | return (x + h) / np.sqrt(2.) 
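# --- Illustrative shape checks (not part of the original sources). The AttnBlockpp
# --- module above applies channel-wise self-attention and preserves the (B, C, H, W)
# --- shape of its input; the channel count and spatial size below are arbitrary.
_attn = AttnBlockpp(channels=64, skip_rescale=True)
_x = torch.randn(2, 64, 16, 16)
assert _attn(_x).shape == _x.shape

# The sinusoidal timestep embedding from layers.py above maps N scalar timesteps to an
# (N, embedding_dim) tensor:
_emb = layers.get_timestep_embedding(torch.arange(4), embedding_dim=128)
assert _emb.shape == (4, 128)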
80 | 81 | 82 | class Upsample(nn.Module): 83 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): 84 | super().__init__() 85 | out_ch = out_ch if out_ch else in_ch 86 | if not fir: 87 | if with_conv: 88 | self.Conv_0 = conv3x3(in_ch, out_ch) 89 | else: 90 | if with_conv: 91 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, kernel=3, up=True, resample_kernel=fir_kernel, use_bias=True, kernel_init=default_init()) 92 | self.fir = fir 93 | self.with_conv = with_conv 94 | self.fir_kernel = fir_kernel 95 | self.out_ch = out_ch 96 | 97 | def forward(self, x): 98 | B, C, H, W = x.shape 99 | if not self.fir: 100 | h = F.interpolate(x, (H * 2, W * 2), 'nearest') 101 | if self.with_conv: 102 | h = self.Conv_0(h) 103 | else: 104 | if not self.with_conv: 105 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 106 | else: 107 | h = self.Conv2d_0(x) 108 | 109 | return h 110 | 111 | 112 | class Downsample(nn.Module): 113 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 114 | fir_kernel=(1, 3, 3, 1)): 115 | super().__init__() 116 | out_ch = out_ch if out_ch else in_ch 117 | if not fir: 118 | if with_conv: 119 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 120 | else: 121 | if with_conv: 122 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, kernel=3, down=True, resample_kernel=fir_kernel, use_bias=True, kernel_init=default_init()) 123 | self.fir = fir 124 | self.fir_kernel = fir_kernel 125 | self.with_conv = with_conv 126 | self.out_ch = out_ch 127 | 128 | def forward(self, x): 129 | B, C, H, W = x.shape 130 | if not self.fir: 131 | if self.with_conv: 132 | x = F.pad(x, (0, 1, 0, 1)) 133 | x = self.Conv_0(x) 134 | else: 135 | x = F.avg_pool2d(x, 2, stride=2) 136 | else: 137 | if not self.with_conv: 138 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 139 | else: 140 | x = self.Conv2d_0(x) 141 | 142 | return x 143 | 144 | 145 | class ResnetBlockDDPMpp(nn.Module): 146 | """ResBlock adapted from DDPM.""" 147 | 148 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, skip_rescale=False, init_scale=0.): 149 | super().__init__() 150 | out_ch = out_ch if out_ch else in_ch 151 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 152 | self.Conv_0 = conv3x3(in_ch, out_ch) 153 | if temb_dim is not None: 154 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 155 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 156 | nn.init.zeros_(self.Dense_0.bias) 157 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 158 | self.Dropout_0 = nn.Dropout(dropout) 159 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 160 | if in_ch != out_ch: 161 | if conv_shortcut: 162 | self.Conv_2 = conv3x3(in_ch, out_ch) 163 | else: 164 | self.NIN_0 = NIN(in_ch, out_ch) 165 | 166 | self.skip_rescale = skip_rescale 167 | self.act = act 168 | self.out_ch = out_ch 169 | self.conv_shortcut = conv_shortcut 170 | 171 | def forward(self, x, temb=None): 172 | h = self.act(self.GroupNorm_0(x)) 173 | h = self.Conv_0(h) 174 | if temb is not None: 175 | h += self.Dense_0(self.act(temb))[:, :, None, None] 176 | h = self.act(self.GroupNorm_1(h)) 177 | h = self.Dropout_0(h) 178 | h = self.Conv_1(h) 179 | if x.shape[1] != self.out_ch: 180 | if self.conv_shortcut: 181 | x = self.Conv_2(x) 182 | else: 183 | x = self.NIN_0(x) 184 | if not self.skip_rescale: 185 | 
return x + h 186 | else: 187 | return (x + h) / np.sqrt(2.) 188 | 189 | 190 | class ResnetBlockBigGANpp(nn.Module): 191 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 192 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 193 | skip_rescale=True, init_scale=0.): 194 | super().__init__() 195 | 196 | out_ch = out_ch if out_ch else in_ch 197 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 198 | self.up = up 199 | self.down = down 200 | self.fir = fir 201 | self.fir_kernel = fir_kernel 202 | 203 | self.Conv_0 = conv3x3(in_ch, out_ch) 204 | if temb_dim is not None: 205 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 206 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 207 | nn.init.zeros_(self.Dense_0.bias) 208 | 209 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 210 | self.Dropout_0 = nn.Dropout(dropout) 211 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 212 | if in_ch != out_ch or up or down: 213 | self.Conv_2 = conv1x1(in_ch, out_ch) 214 | 215 | self.skip_rescale = skip_rescale 216 | self.act = act 217 | self.in_ch = in_ch 218 | self.out_ch = out_ch 219 | 220 | def forward(self, x, temb=None): 221 | h = self.act(self.GroupNorm_0(x)) 222 | 223 | if self.up: 224 | if self.fir: 225 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 226 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 227 | else: 228 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 229 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 230 | elif self.down: 231 | if self.fir: 232 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 233 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 234 | else: 235 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 236 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 237 | 238 | h = self.Conv_0(h) 239 | # Add bias to each feature map conditioned on the time embedding 240 | if temb is not None: 241 | h += self.Dense_0(self.act(temb))[:, :, None, None] 242 | h = self.act(self.GroupNorm_1(h)) 243 | h = self.Dropout_0(h) 244 | h = self.Conv_1(h) 245 | 246 | if self.in_ch != self.out_ch or self.up or self.down: 247 | x = self.Conv_2(x) 248 | 249 | if not self.skip_rescale: 250 | return x + h 251 | else: 252 | return (x + h) / np.sqrt(2.) 253 | -------------------------------------------------------------------------------- /lib/nn/ncsnpp/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # This file is adapted from https://github.com/yang-song/score_sde_pytorch. 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | 12 | def _setup_kernel(k): 13 | k = np.asarray(k, dtype=np.float32) 14 | if k.ndim == 1: 15 | k = np.outer(k, k) 16 | k /= np.sum(k) 17 | assert k.ndim == 2 18 | assert k.shape[0] == k.shape[1] 19 | return k 20 | 21 | 22 | def _shape(x, dim): 23 | return x.shape[dim] 24 | 25 | 26 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 27 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 28 | 29 | Performs the following sequence of operations for each channel: 30 | 31 | 1. Upsample the input by inserting N-1 zeros after each pixel (`up`). 32 | 33 | 2. 
Pad the image with the specified number of zeros on each side (`padding`). 34 | Negative padding corresponds to cropping the image. 35 | 36 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 37 | so that the footprint of all output pixels lies within the input image. 38 | 39 | 4. Downsample the image by keeping every Nth pixel (`down`). 40 | 41 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 42 | It supports gradients of arbitrary order. 43 | 44 | Args: 45 | input: Float32/float64/float16 input tensor of the shape 46 | `[batch_size, num_channels, in_height, in_width]`. 47 | kernel: Float32 FIR filter of the shape 48 | `[filter_height, filter_width]` called from _setup_kernel. 49 | up: Integer upsampling factor. 50 | down: Integer downsampling factor. 51 | pad: Padding with respect to the upsampled image. list/tuple `[x, y]`. 52 | 53 | Returns: 54 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 55 | """ 56 | def upfirdn2d_native( 57 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 58 | ): 59 | _, channel, in_h, in_w = input.shape 60 | input = input.reshape(-1, in_h, in_w, 1) 61 | 62 | _, in_h, in_w, minor = input.shape 63 | kernel_h, kernel_w = kernel.shape 64 | 65 | out = input.view(-1, in_h, 1, in_w, 1, minor) 66 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 67 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 68 | 69 | out = F.pad( 70 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 71 | ) 72 | out = out[ 73 | :, 74 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 75 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 76 | :, 77 | ] 78 | 79 | out = out.permute(0, 3, 1, 2) 80 | out = out.reshape( 81 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 82 | ) 83 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 84 | out = F.conv2d(out, w) 85 | out = out.reshape( 86 | -1, 87 | minor, 88 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 89 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 90 | ) 91 | out = out.permute(0, 2, 3, 1) 92 | out = out[:, ::down_y, ::down_x, :] 93 | 94 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 95 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 96 | 97 | return out.view(-1, channel, out_h, out_w) 98 | 99 | out = upfirdn2d_native( 100 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 101 | ) 102 | 103 | return out 104 | 105 | 106 | # Function ported from StyleGAN2 107 | def get_weight(module, 108 | shape, 109 | weight_var='weight', 110 | kernel_init=None): 111 | """Get/create weight tensor for a convolution or fully-connected layer.""" 112 | return module.param(weight_var, kernel_init, shape) 113 | 114 | 115 | class Conv2d(nn.Module): 116 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 117 | 118 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 119 | resample_kernel=(1, 3, 3, 1), 120 | use_bias=True, 121 | kernel_init=None): 122 | super().__init__() 123 | assert not (up and down) 124 | assert kernel >= 1 and kernel % 2 == 1 125 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 126 | if kernel_init is not None: 127 | self.weight.data = kernel_init(self.weight.data.shape) 128 | if use_bias: 129 | self.bias = nn.Parameter(torch.zeros(out_ch)) 130 | 131 | self.up = up 132 | self.down = down 133 | self.resample_kernel = resample_kernel 134 | self.kernel = 
kernel 135 | self.use_bias = use_bias 136 | 137 | def forward(self, x): 138 | if self.up: 139 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 140 | elif self.down: 141 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 142 | else: 143 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 144 | 145 | if self.use_bias: 146 | x = x + self.bias.reshape(1, -1, 1, 1) 147 | 148 | return x 149 | 150 | 151 | def naive_upsample_2d(x, factor=2): 152 | _N, C, H, W = x.shape 153 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 154 | x = x.repeat(1, 1, 1, factor, 1, factor) 155 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 156 | 157 | 158 | def naive_downsample_2d(x, factor=2): 159 | _N, C, H, W = x.shape 160 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 161 | return torch.mean(x, dim=(3, 5)) 162 | 163 | 164 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 165 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 166 | 167 | Padding is performed only once at the beginning, not between the 168 | operations. 169 | The fused op is considerably more efficient than performing the same 170 | calculation 171 | using standard TensorFlow ops. It supports gradients of arbitrary order. 172 | Args: 173 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 174 | C]`. 175 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 176 | outChannels]`. Grouped convolution can be performed by `inChannels = 177 | x.shape[0] // numGroups`. 178 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 179 | (separable). The default is `[1] * factor`, which corresponds to 180 | nearest-neighbor upsampling. 181 | factor: Integer upsampling factor (default: 2). 182 | gain: Scaling factor for signal magnitude (default: 1.0). 183 | Returns: 184 | Tensor of the shape `[N, C, H * factor, W * factor]` or 185 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 186 | """ 187 | assert isinstance(factor, int) and factor >= 1 188 | 189 | # Check weight shape. 190 | assert len(w.shape) == 4 191 | convH = w.shape[2] 192 | convW = w.shape[3] 193 | inC = w.shape[1] 194 | w.shape[0] 195 | 196 | assert convW == convH 197 | 198 | # Setup filter kernel. 199 | if k is None: 200 | k = [1] * factor 201 | k = _setup_kernel(k) * (gain * (factor ** 2)) 202 | p = (k.shape[0] - factor) - (convW - 1) 203 | 204 | stride = (factor, factor) 205 | 206 | # Determine data dimensions. 207 | stride = [1, 1, factor, factor] 208 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 209 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 210 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 211 | assert output_padding[0] >= 0 and output_padding[1] >= 0 212 | num_groups = _shape(x, 1) // inC 213 | 214 | # Transpose weights. 215 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 216 | w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) 217 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 218 | 219 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 220 | 221 | return upfirdn2d(x, torch.tensor(k, device=x.device), 222 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 223 | 224 | 225 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 226 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 227 | 228 | Padding is performed only once at the beginning, not between the operations. 
229 | The fused op is considerably more efficient than performing the same 230 | calculation 231 | using standard TensorFlow ops. It supports gradients of arbitrary order. 232 | Args: 233 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 234 | C]`. 235 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 236 | outChannels]`. Grouped convolution can be performed by `inChannels = 237 | x.shape[0] // numGroups`. 238 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 239 | (separable). The default is `[1] * factor`, which corresponds to 240 | average pooling. 241 | factor: Integer downsampling factor (default: 2). 242 | gain: Scaling factor for signal magnitude (default: 1.0). 243 | Returns: 244 | Tensor of the shape `[N, C, H // factor, W // factor]` or 245 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 246 | """ 247 | assert isinstance(factor, int) and factor >= 1 248 | _outC, _inC, convH, convW = w.shape 249 | assert convW == convH 250 | if k is None: 251 | k = [1] * factor 252 | k = _setup_kernel(k) * gain 253 | p = (k.shape[0] - factor) + (convW - 1) 254 | s = [factor, factor] 255 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 256 | pad=((p + 1) // 2, p // 2)) 257 | return F.conv2d(x, w, stride=s, padding=0) 258 | 259 | 260 | def upsample_2d(x, k=None, factor=2, gain=1): 261 | r"""Upsample a batch of 2D images with the given filter. 262 | 263 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 264 | and upsamples each image with the given filter. The filter is normalized so 265 | that 266 | if the input pixels are constant, they will be scaled by the specified 267 | `gain`. 268 | Pixels outside the image are assumed to be zero, and the filter is padded 269 | with 270 | zeros so that its shape is a multiple of the upsampling factor. 271 | Args: 272 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 273 | C]`. 274 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 275 | (separable). The default is `[1] * factor`, which corresponds to 276 | nearest-neighbor upsampling. 277 | factor: Integer upsampling factor (default: 2). 278 | gain: Scaling factor for signal magnitude (default: 1.0). 279 | Returns: 280 | Tensor of the shape `[N, C, H * factor, W * factor]` 281 | """ 282 | assert isinstance(factor, int) and factor >= 1 283 | if k is None: 284 | k = [1] * factor 285 | k = _setup_kernel(k) * (gain * (factor ** 2)) 286 | p = k.shape[0] - factor 287 | return upfirdn2d(x,torch.tensor(k, device=x.device), 288 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 289 | 290 | 291 | def downsample_2d(x, k=None, factor=2, gain=1): 292 | r"""Downsample a batch of 2D images with the given filter. 293 | 294 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 295 | and downsamples each image with the given filter. The filter is normalized 296 | so that 297 | if the input pixels are constant, they will be scaled by the specified 298 | `gain`. 299 | Pixels outside the image are assumed to be zero, and the filter is padded 300 | with 301 | zeros so that its shape is a multiple of the downsampling factor. 302 | Args: 303 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 304 | C]`. 305 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 306 | (separable). The default is `[1] * factor`, which corresponds to 307 | average pooling. 308 | factor: Integer downsampling factor (default: 2). 309 | gain: Scaling factor for signal magnitude (default: 1.0). 
310 | Returns: 311 | Tensor of the shape `[N, C, H // factor, W // factor]` 312 | """ 313 | 314 | assert isinstance(factor, int) and factor >= 1 315 | if k is None: 316 | k = [1] * factor 317 | k = _setup_kernel(k) * gain 318 | p = k.shape[0] - factor 319 | return upfirdn2d(x, torch.tensor(k, device=x.device), 320 | down=factor, pad=((p + 1) // 2, p // 2)) 321 | -------------------------------------------------------------------------------- /lib/nn/nn.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | __all__ = ['AutoNorm', 'CondAffinePost', 'CondAffineScaleThenOffset', 'CondLinearlyCombine', 'EMA', 7 | 'EmbeddingTriangle', 'Residual'] 8 | 9 | from typing import Callable, Optional, Tuple, Sequence 10 | 11 | import torch 12 | import torch.nn 13 | import torch.nn.functional 14 | 15 | from .functional import expand_to 16 | 17 | 18 | class AutoNorm(torch.nn.Module): 19 | def __init__(self, n: int, momentum: float): 20 | super().__init__() 21 | self.avg = EMA((n,), momentum) 22 | self.var = EMA((n,), momentum) 23 | 24 | def forward(self, x: torch.Tensor) -> torch.Tensor: 25 | pad = [1] * (x.ndim - 2) 26 | if self.training: 27 | reduce = tuple(i for i in range(x.ndim) if i != 1) 28 | avg = self.avg(x.mean(reduce)) 29 | var = self.var((x - x.mean(reduce, keepdims=True)).square().mean(reduce)) 30 | else: 31 | avg, var = self.avg(), self.var() 32 | 33 | return (x - avg.view(1, -1, *pad)) * var.clamp(1e-6).rsqrt().view(1, -1, *pad) 34 | 35 | def denorm(self, x: torch.Tensor) -> torch.Tensor: 36 | assert not self.training 37 | pad = [1] * (x.ndim - 2) 38 | avg, var = self.avg(), self.var() 39 | return avg.view(1, -1, *pad) + x * var.clamp(1e-6).sqrt().view(1, -1, *pad) 40 | 41 | 42 | class CondAffinePost(torch.nn.Module): 43 | def __init__(self, ncond: int, nout: int, op: torch.nn.Module, scale: bool = True): 44 | super().__init__() 45 | self.op = op 46 | self.scale = scale 47 | self.m = torch.nn.Linear(ncond, nout + (nout if scale else 0)) 48 | self.cond: Optional[torch.Tensor] = None 49 | 50 | def set_cond(self, x: Optional[torch.Tensor]): 51 | self.cond = x if x is None else self.m(x) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | if self.scale: 55 | w, b = expand_to(self.cond, x).chunk(2, dim=1) 56 | return self.op(x) * w + b 57 | return self.op(x) + expand_to(self.cond, x) 58 | 59 | 60 | class CondAffineScaleThenOffset(torch.nn.Module): 61 | def __init__(self, ncond: int, nin: int, nout: int, op: torch.nn.Module, scale: bool = True): 62 | super().__init__() 63 | self.op = op 64 | self.nin = nin 65 | self.scale = scale 66 | self.m = torch.nn.Linear(ncond, nout + (nin if scale else 0)) 67 | self.cond: Optional[torch.Tensor] = None 68 | 69 | def set_cond(self, x: Optional[torch.Tensor]): 70 | self.cond = x if x is None else self.m(x) 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | cond = expand_to(self.cond, x) 74 | if self.scale: 75 | w, b = cond[:, :self.nin], cond[:, self.nin:] 76 | return self.op(x * w) + b 77 | return self.op(x) + cond 78 | 79 | 80 | class CondLinearlyCombine(torch.nn.Module): 81 | def __init__(self, ncond: int, n: int): 82 | super().__init__() 83 | self.n = n 84 | self.mix = torch.nn.Linear(ncond, n) 85 | self.cond: Optional[torch.Tensor] = None 86 | 87 | def set_cond(self, x: Optional[torch.Tensor]): 88 | self.cond = x if x is None else self.mix(x) 89 | 90 | def forward(self, 
x: Sequence[torch.Tensor]) -> torch.Tensor: 91 | cond = expand_to(self.cond, x[0]) 92 | return sum(x[i] * cond[:, i:i + 1] for i in range(self.n)) 93 | 94 | 95 | class EMA(torch.nn.Module): 96 | def __init__(self, shape: Tuple[int, ...], momentum: float): 97 | super().__init__() 98 | self.momentum = momentum 99 | self.register_buffer('step', torch.zeros((), dtype=torch.long)) 100 | self.register_buffer('ema', torch.zeros(shape)) 101 | 102 | def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor: 103 | if self.training: 104 | self.step.add_(1) 105 | mu = 1 - (1 - self.momentum) / (1 - self.momentum ** self.step) 106 | self.ema.add_((1 - mu) * (x - self.ema)) 107 | return self.ema 108 | 109 | 110 | class EmbeddingTriangle(torch.nn.Module): 111 | def __init__(self, dim: int, delta: float): 112 | """dim number of dimensions for embedding, delta is minimum distance between two values.""" 113 | super().__init__() 114 | logres = -torch.tensor(max(2 ** -31, 2 * delta)).log2() 115 | logfreqs = torch.nn.functional.pad(torch.linspace(0, logres, dim - 1), (1, 0), mode='constant', value=-1) 116 | self.register_buffer('freq', torch.pow(2, logfreqs)) 117 | 118 | def forward(self, x: torch.Tensor) -> torch.Tensor: 119 | y = 2 * (x.view(-1, 1) * self.freq).fmod(1) 120 | return 2 * (y * (y < 1) + (2 - y) * (y >= 1)) - 1 121 | 122 | 123 | class Residual(torch.nn.Module): 124 | def __init__(self, residual: Callable, skip: Optional[Callable] = None): 125 | super().__init__() 126 | self.residual = residual 127 | self.skip = torch.nn.Identity() if skip is None else skip 128 | 129 | def forward(self, x: torch.Tensor) -> torch.Tensor: 130 | return self.skip(x) + self.residual(x) 131 | -------------------------------------------------------------------------------- /lib/optim.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
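# Note on the EMA rule shared by `EMA` (lib/nn/nn.py, above) and `ModuleEMA` below:
# the weight placed on the incoming value at step t is
#     decay_t = (1 - momentum) / (1 - momentum ** t)
# an Adam-style correction of the zero initialization: decay_1 = 1, so the average is
# seeded with the first observation, and decay_t settles to 1 - momentum as t grows.
# A minimal scalar sketch of the same recursion (illustrative only; `values` is a
# made-up sequence, not part of this module):
#
#     ema, m = 0.0, 0.999
#     for t, x in enumerate(values, start=1):
#         decay = (1 - m) / (1 - m ** t)
#         ema += decay * (x - ema)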
4 | # 5 | import copy 6 | import itertools 7 | 8 | import torch 9 | import torch.optim.swa_utils 10 | 11 | 12 | class ModuleEMA(torch.nn.Module): # Preferred to PyTorch's builtin because this is pickable 13 | def __init__(self, m: torch.nn.Module, momentum: float): 14 | super().__init__() 15 | self.module = copy.deepcopy(m) 16 | self.momentum = momentum 17 | self.register_buffer('step', torch.zeros((), dtype=torch.long)) 18 | 19 | def update(self, source: torch.nn.Module): 20 | self.step.add_(1) 21 | decay = (1 - self.momentum) / (1 - self.momentum ** self.step) 22 | with torch.no_grad(): 23 | for p_self, p_source in zip(self.module.parameters(), source.parameters()): 24 | p_self.add_(p_source - p_self, alpha=decay) 25 | for p_self, p_source in zip(self.module.buffers(), source.buffers()): 26 | if torch.is_floating_point(p_source): 27 | assert torch.is_floating_point(p_self) 28 | p_self.add_(p_source - p_self, alpha=decay) 29 | else: 30 | assert not torch.is_floating_point(p_self) 31 | p_self.add_(p_source - p_self) 32 | 33 | def forward(self, *args, **kwargs): 34 | return self.module.forward(*args, **kwargs) 35 | 36 | 37 | class AveragedModel(torch.optim.swa_utils.AveragedModel): 38 | def update_parameters(self, model): 39 | self_param = itertools.chain(self.module.parameters(), self.module.buffers()) 40 | model_param = itertools.chain(model.parameters(), model.buffers()) 41 | for p_swa, p_model in zip(self_param, model_param): 42 | device = p_swa.device 43 | p_model_ = p_model.detach().to(device) 44 | if self.n_averaged == 0: 45 | p_swa.detach().copy_(p_model_) 46 | else: 47 | p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device))) 48 | self.n_averaged += 1 49 | 50 | 51 | def module_exponential_moving_average(model: torch.nn.Module, 52 | momentum: float) -> torch.optim.swa_utils.AveragedModel: 53 | """Create an AverageModel using Stochastic Weight Averaging. 54 | 55 | Args: 56 | model: the torch Module to average. 57 | momentum: the running average momentum coefficient. I found values in 0.9, 0.99, 0.999, 58 | 0.9999, ... to give good results. The closer to 1 the better, but the longer one needs 59 | to train. 60 | Returns: 61 | torch.optim.swa_utils.AveragedModel module that replicates the model behavior with SWA 62 | weights. 63 | """ 64 | 65 | def ema(target: torch.Tensor, source: torch.Tensor, count: int) -> torch.Tensor: 66 | mu = 1 - (1 - momentum) / (1 - momentum ** (1 + count)) 67 | return mu * target + (1 - mu) * source 68 | 69 | return AveragedModel(model, avg_fn=ema) 70 | 71 | 72 | class HalfLifeEMA(torch.nn.Module): 73 | def __init__(self, m: torch.nn.Module, half_life: int = 500000, batch_size: int = 512): 74 | """ 75 | EMA Module based of half life (units of samples/images). 76 | 77 | Args: 78 | half_life : Half life of EMA in units of samples. 
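            batch_size : Number of samples consumed per update() call. Each update
                decays the stored weights by beta = 0.5 ** (batch_size / half_life),
                so after roughly half_life samples the old weights' contribution has halved.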
79 | """ 80 | super().__init__() 81 | self.module = copy.deepcopy(m) 82 | self.half_life = half_life 83 | self.batch_size = batch_size 84 | 85 | def update(self, source: torch.nn.Module): 86 | ema_beta = 0.5 ** (self.batch_size / self.half_life) 87 | 88 | with torch.no_grad(): 89 | for p_self, p_source in zip(self.module.parameters(), source.parameters()): 90 | p_self.copy_(p_source.lerp(p_self, ema_beta)) 91 | for p_self, p_source in zip(self.module.buffers(), source.buffers()): 92 | p_self.copy_(p_source) 93 | 94 | def forward(self, *args, **kwargs): 95 | return self.module.forward(*args, **kwargs) 96 | 97 | 98 | class CopyModule(torch.nn.Module): 99 | def __init__(self, m: torch.nn.Module): 100 | """Copy Module for self-conditioning.""" 101 | super().__init__() 102 | self.module = copy.deepcopy(m) 103 | 104 | def update(self, source: torch.nn.Module): 105 | 106 | with torch.no_grad(): 107 | for p_self, p_source in zip(self.module.parameters(), source.parameters()): 108 | p_self.copy_(p_source) 109 | for p_self, p_source in zip(self.module.buffers(), source.buffers()): 110 | p_self.copy_(p_source) 111 | 112 | def forward(self, *args, **kwargs): 113 | return self.module.forward(*args, **kwargs) -------------------------------------------------------------------------------- /lib/train.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | __all__ = ['TrainInfo', 'TrainModel', 'DistillModel'] 7 | 8 | import dataclasses 9 | import json 10 | import pathlib 11 | import time 12 | from types import SimpleNamespace 13 | from typing import Callable, Dict, Iterable, List, Optional 14 | 15 | import torch.distributed 16 | import torch.nn.functional 17 | from absl import flags 18 | 19 | from lib.eval.fid import FID 20 | 21 | from .distributed import (gather_tensor, is_master, print, 22 | rank, trange, world_size) 23 | from .io import Checkpoint, Summary, SummaryWriter, zip_batch_as_png 24 | from .util import (FLAGS, command_line, int_str, repeater, 25 | report_module_weights, time_format) 26 | 27 | flags.DEFINE_integer('logstart', 1, help='Logstep at which to start.') 28 | flags.DEFINE_string('report_fid_len', '16M', help='How often to compute the FID during evaluations.') 29 | flags.DEFINE_string('report_img_len', '4M', help='How often to sample images during evaluations.') 30 | 31 | 32 | @dataclasses.dataclass 33 | class TrainInfo: 34 | samples: int 35 | progress: float 36 | 37 | 38 | class TrainModel(torch.nn.Module): 39 | COLORS = 3 40 | EVAL_ROWS = 16 41 | EVAL_COLUMNS = 16 42 | model: torch.nn.Module 43 | model_eval: torch.nn.Module 44 | train_op: Callable[..., Dict[str, torch.Tensor]] 45 | 46 | def __init__(self, arch: str, res: int, timesteps: int, **params): 47 | super().__init__() 48 | self.params = SimpleNamespace(arch=arch, res=res, timesteps=timesteps, **params) 49 | self.register_buffer('logstep', torch.zeros((), dtype=torch.long)) 50 | 51 | @property 52 | def device(self) -> str: 53 | for x in self.model.parameters(): 54 | return x.device 55 | 56 | @property 57 | def logdir(self) -> str: 58 | params = '_'.join(f'{k}@{v}' for k, v in sorted(vars(self.params).items()) if k not in ('arch',)) 59 | return f'{self.__class__.__name__}({self.params.arch})/{params}' 60 | 61 | def __str__(self) -> str: 62 | return '\n'.join(( 63 | f'{" Model ":-^80}', str(self.model), 64 | f'{" Parameters ":-^80}', report_module_weights(self.model), 65 | 
f'{" Config ":-^80}', 66 | '\n'.join(f'{k:20s}: {v}' for k, v in vars(self.params).items()) 67 | )) 68 | 69 | def save_meta(self, logdir: pathlib.Path, data_logger: Optional[SummaryWriter] = None): 70 | if not is_master(): 71 | return 72 | if data_logger is not None: 73 | summary = Summary() 74 | summary.text('info', f'
{self}
') 75 | data_logger.write(summary, 0) 76 | (logdir / 'params.json').open('w').write(json.dumps(vars(self.params), indent=4)) 77 | (logdir / 'model.txt').open('w').write(str(self.model.module)) 78 | (logdir / 'cmd.txt').open('w').write(command_line()) 79 | 80 | def evaluate(self, summary: Summary, 81 | logdir: pathlib.Path, 82 | ckpt: Optional[Checkpoint] = None, 83 | data_fid: Optional[Iterable] = None, 84 | fid_len: int = 0, sample_imgs: bool = True): 85 | assert (self.EVAL_ROWS * self.EVAL_COLUMNS) % world_size() == 0 86 | self.eval() 87 | with torch.no_grad(): 88 | if sample_imgs: 89 | generator = torch.Generator(device='cpu') 90 | generator.manual_seed(123623113456 + rank()) 91 | fixed = self((self.EVAL_ROWS * self.EVAL_COLUMNS) // world_size(), generator) 92 | rand = self((self.EVAL_ROWS * self.EVAL_COLUMNS) // world_size()) 93 | fixed, rand = (gather_tensor(x) for x in (fixed, rand)) 94 | summary.png('eval/fixed', fixed.view(self.EVAL_ROWS, self.EVAL_COLUMNS, *fixed.shape[1:])) 95 | summary.png('eval/random', rand.view(self.EVAL_ROWS, self.EVAL_COLUMNS, *rand.shape[1:])) 96 | if fid_len and data_fid: 97 | fid = FID(FLAGS.dataset, (self.COLORS, self.params.res, self.params.res)) 98 | fake_activations, fake_samples = fid.generate_activations_and_samples(self, FLAGS.fid_len) 99 | timesteps = self.params.timesteps >> self.logstep.item() 100 | zip_batch_as_png(fake_samples, logdir / f'samples_{fid_len}_timesteps_{timesteps}.zip') 101 | fidn, fid50 = fid.approximate_fid(fake_activations) 102 | summary.scalar(f'eval/fid({fid_len})', fidn) 103 | summary.scalar('eval/fid(50000)', fid50) 104 | if ckpt: 105 | ckpt.save_file(self.model_eval.module, f'model_{fid50:.5f}.ckpt') 106 | 107 | def train_loop(self, 108 | data_train: Iterable, 109 | data_fid: Optional[Iterable], 110 | batch: int, 111 | train_len: str, 112 | report_len: str, 113 | logdir: pathlib.Path, 114 | *, 115 | fid_len: int = 4096, 116 | keep_ckpts: int = 2): 117 | print(self) 118 | print(f'logdir: {logdir}') 119 | train_len, report_len, report_fid_len, report_img_len = (int_str(x) for x in ( 120 | train_len, report_len, FLAGS.report_fid_len, FLAGS.report_img_len)) 121 | assert report_len % batch == 0 122 | assert train_len % report_len == 0 123 | assert report_fid_len % report_len == 0 124 | assert report_img_len % report_len == 0 125 | data_train = repeater(data_train) 126 | ckpt = Checkpoint(self, logdir, keep_ckpts) 127 | start = ckpt.restore()[0] 128 | if start: 129 | print(f'Resuming training at {start} ({start / (1 << 20):.2f}M samples)') 130 | 131 | with SummaryWriter.create(logdir) as data_logger: 132 | if start == 0: 133 | self.save_meta(logdir, data_logger) 134 | 135 | for i in range(start, train_len, report_len): 136 | self.train() 137 | summary = Summary() 138 | range_iter = trange(i, i + report_len, batch, leave=False, unit='samples', 139 | unit_scale=batch, 140 | desc=f'Training kimg {i >> 10}/{train_len >> 10}') 141 | t0 = time.time() 142 | for samples in range_iter: 143 | self.train_step(summary, TrainInfo(samples, samples / train_len), next(data_train)) 144 | 145 | samples += batch 146 | t1 = time.time() 147 | summary.scalar('sys/samples_per_sec_train', report_len / (t1 - t0)) 148 | compute_fid = (samples % report_fid_len == 0) or (samples >= train_len) 149 | self.evaluate(summary, logdir, ckpt, data_fid, fid_len=fid_len if compute_fid else 0, 150 | sample_imgs=samples % report_img_len == 0) 151 | t2 = time.time() 152 | summary.scalar('sys/eval_time', t2 - t1) 153 | data_logger.write(summary, samples) 154 | 
ckpt.save(samples) 155 | print(f'{samples / (1 << 20):.2f}M/{train_len / (1 << 20):.2f}M samples, ' 156 | f'time left {time_format((t2 - t0) * (train_len - samples) / report_len)}\n{summary}') 157 | ckpt.save_file(self.model_eval.module, 'model.ckpt') 158 | 159 | def train_step(self, summary: Summary, info: TrainInfo, batch: List[torch.Tensor]) -> None: 160 | device = self.device 161 | metrics = self.train_op(info, *[x.to(device, non_blocking=True) for x in batch]) 162 | summary.from_metrics(metrics) 163 | -------------------------------------------------------------------------------- /lib/util.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | __all__ = ['FLAGS', 'artifact_dir', 'command_line', 'convert_256_to_11', 'cpu_count', 7 | 'downcast', 'ilog2', 'int_str', 'local_kwargs', 'power_of_2', 'repeater', 'report_module_weights', 'setup', 8 | 'time_format', 'to_numpy', 'to_png', 'tqdm', 'tqdm_with', 'trange'] 9 | 10 | import contextlib 11 | import dataclasses 12 | import inspect 13 | import io 14 | import multiprocessing 15 | import os 16 | import pathlib 17 | import random 18 | import re 19 | import sys 20 | from types import SimpleNamespace 21 | from typing import Callable, Iterable, Optional, Union 22 | 23 | import absl.flags 24 | import numpy as np 25 | import torch 26 | import torch.backends.cudnn 27 | import tqdm as tqdm_module 28 | from absl import flags 29 | from PIL import Image 30 | 31 | FLAGS = SimpleNamespace() 32 | 33 | flags.DEFINE_string('logdir', 'e', help='Directory whwer to save logs.') 34 | 35 | SYSTEM_FLAGS = {'?', 'alsologtostderr', 'help', 'helpfull', 'helpshort', 'helpxml', 'log_dir', 'logger_levels', 36 | 'logtostderr', 'only_check_args', 'pdb', 'pdb_post_mortem', 'profile_file', 'run_with_pdb', 37 | 'run_with_profiling', 'showprefixforinfo', 'stderrthreshold', 'use_cprofile_for_profiling', 'v', 38 | 'verbosity'} 39 | 40 | @dataclasses.dataclass 41 | class MemInfo: 42 | total: int # KB 43 | res: int # KB 44 | shared: int # KB 45 | 46 | @classmethod 47 | def query(cls): 48 | with open(f'/proc/{os.getpid()}/statm', 'r') as f: 49 | return cls(*[int(x) for x in f.read().split(' ')[:3]]) 50 | 51 | def __str__(self): 52 | gb = 1 << 20 53 | return f'Total {self.total / gb:.4f} GB | Res {self.res / gb:.4f} GB | Shared {self.shared / gb:.4f} GB' 54 | 55 | 56 | def artifact_dir(*args) -> pathlib.Path: 57 | path = pathlib.Path(FLAGS.logdir) 58 | return path.joinpath(*args) 59 | 60 | 61 | def command_line() -> str: 62 | argv = sys.argv[:] 63 | rex = re.compile(r'([!|*$#?~&<>{}()\[\]\\ "\'])') 64 | cmd = ' '.join(rex.sub(r'\\\1', v) for v in argv) 65 | return cmd 66 | 67 | 68 | def convert_256_to_11(x: torch.Tensor) -> torch.Tensor: 69 | """Lossless conversion of 0,255 interval to -1,1 interval.""" 70 | return x / 128 - 255 / 256 71 | 72 | 73 | def cpu_count() -> int: 74 | return multiprocessing.cpu_count() 75 | 76 | 77 | def downcast(x: Union[np.ndarray, np.dtype]) -> Union[np.ndarray, np.dtype]: 78 | """Downcast numpy float64 to float32.""" 79 | if isinstance(x, np.dtype): 80 | return np.float32 if x == np.float64 else x 81 | if x.dtype == np.float64: 82 | return x.astype('f') 83 | return x 84 | 85 | 86 | def ilog2(x: int) -> int: 87 | y = x.bit_length() - 1 88 | assert 1 << y == x 89 | return y 90 | 91 | 92 | def int_str(s: str) -> int: 93 | p = 1 94 | if s.endswith('K'): 95 | s, p = s[:-1], 1 << 10 96 | elif 
s.endswith('M'): 97 | s, p = s[:-1], 1 << 20 98 | elif s.endswith('G'): 99 | s, p = s[:-1], 1 << 30 100 | return int(float(eval(s)) * p) 101 | 102 | 103 | def local_kwargs(kwargs: dict, f: Callable) -> dict: 104 | """Return the kwargs from dict that are inputs to function f.""" 105 | s = inspect.signature(f) 106 | p = s.parameters 107 | if next(reversed(p.values())).kind == inspect.Parameter.VAR_KEYWORD: 108 | return kwargs 109 | if len(kwargs) < len(p): 110 | return {k: v for k, v in kwargs.items() if k in p} 111 | return {k: kwargs[k] for k in p.keys() if k in kwargs} 112 | 113 | 114 | def power_of_2(x: int) -> int: 115 | """Return highest power of 2 <= x""" 116 | return 1 << (x.bit_length() - 1) 117 | 118 | 119 | def repeater(it: Iterable): 120 | """Helper function to repeat an iterator in a memory efficient way.""" 121 | while True: 122 | for x in it: 123 | yield x 124 | 125 | 126 | def report_module_weights(m: torch.nn.Module): 127 | weights = [(k, tuple(v.shape)) for k, v in m.named_parameters()] 128 | weights.append((f'Total ({len(weights)})', (sum(np.prod(x[1]) for x in weights),))) 129 | width = max(len(x[0]) for x in weights) 130 | return '\n'.join(f'{k:<{width}} {np.prod(s):>10} {str(s):>16}' for k, s in weights) 131 | 132 | 133 | def setup(seed: Optional[int] = None, quiet: bool = False, flags_values: Optional[SimpleNamespace] = None): 134 | if flags_values: 135 | for k, v in vars(flags_values).items(): 136 | setattr(FLAGS, k, v) 137 | else: 138 | for k in absl.flags.FLAGS: 139 | if k not in SYSTEM_FLAGS: 140 | setattr(FLAGS, k, getattr(absl.flags.FLAGS, k)) 141 | torch.backends.cudnn.benchmark = True 142 | # os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL' 143 | try: 144 | torch.multiprocessing.set_start_method('spawn') 145 | except RuntimeError: 146 | pass 147 | if seed is not None: 148 | random.seed(seed) 149 | np.random.seed(seed) 150 | torch.manual_seed(seed) 151 | if not quiet: 152 | print(f'{" Flags ":-^79s}') 153 | for k in sorted(vars(FLAGS)): 154 | print(f'{k:32s}: {getattr(FLAGS, k)}') 155 | print(f'{" System ":-^79s}') 156 | for k, v in {'cpus(system)': multiprocessing.cpu_count(), 157 | 'cpus(fixed)': cpu_count(), 158 | 'multiprocessing.start_method': torch.multiprocessing.get_start_method()}.items(): 159 | print(f'{k:32s}: {v}') 160 | 161 | 162 | def time_format(t: float) -> str: 163 | t = int(t) 164 | hours = t // 3600 165 | mins = (t // 60) % 60 166 | secs = t % 60 167 | return f'{hours:02d}:{mins:02d}:{secs:02d}' 168 | 169 | 170 | def to_numpy(x: Union[np.ndarray, torch.Tensor]): 171 | if not isinstance(x, torch.Tensor): 172 | return x 173 | return x.detach().cpu().numpy() 174 | 175 | 176 | def to_png(x: Union[np.ndarray, torch.Tensor]) -> bytes: 177 | """Converts numpy array in (C, H, W) or (Rows, Cols, C, H, W) format into PNG format.""" 178 | assert x.ndim in (3, 5) 179 | if isinstance(x, torch.Tensor): 180 | x = to_numpy(x) 181 | if x.ndim == 5: # Image grid 182 | x = np.transpose(x, (2, 0, 3, 1, 4)) 183 | x = x.reshape((x.shape[0], x.shape[1] * x.shape[2], x.shape[3] * x.shape[4])) # (C, H, W) 184 | if x.dtype in (np.float64, np.float32, np.float16): 185 | x = np.transpose(np.round(127.5 * (x + 1)), (1, 2, 0)).clip(0, 255).astype('uint8') 186 | elif x.dtype != np.uint8: 187 | raise ValueError('Unsupported array type, expecting float or uint8', x.dtype) 188 | if x.shape[2] == 1: 189 | x = np.broadcast_to(x, x.shape[:2] + (3,)) 190 | with io.BytesIO() as f: 191 | Image.fromarray(x).save(f, 'png') 192 | return f.getvalue() 193 | 194 | 195 | def tqdm(iterable: 
Iterable, **kwargs) -> Iterable: 196 | return tqdm_module.tqdm(iterable, **kwargs) 197 | 198 | 199 | def tqdm_with(**kwargs) -> Iterable: 200 | class Noop: 201 | def update(self, *args, **kwargs): 202 | pass 203 | 204 | @contextlib.contextmanager 205 | def noop(): 206 | yield Noop() 207 | 208 | return tqdm_module.tqdm(**kwargs) 209 | 210 | 211 | def trange(*args, **kwargs): 212 | return tqdm_module.trange(*args, **kwargs) 213 | 214 | -------------------------------------------------------------------------------- /lib/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | from . import unet # noqa -------------------------------------------------------------------------------- /lib/zoo/unet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | # Mostly copied from https://github.com/rosinality/denoising-diffusion-pytorch 7 | # modified to match https://github.com/google-research/google-research/blob/master/diffusion_distillation/diffusion_distillation/unet.py 8 | 9 | import math 10 | from typing import List, Tuple, Optional 11 | 12 | import torch 13 | from torch import nn 14 | from torch.nn import functional as F 15 | 16 | swish = F.silu 17 | 18 | 19 | def get_timestep_embedding(timesteps, embedding_dim, max_time=1000.): 20 | """Build sinusoidal embeddings (from Fairseq). 21 | This matches the implementation in tensor2tensor, but differs slightly 22 | from the description in Section 3.5 of "Attention Is All You Need". 23 | Args: 24 | timesteps: jnp.ndarray: generate embedding vectors at these timesteps 25 | embedding_dim: int: dimension of the embeddings to generate 26 | max_time: float: largest time input 27 | dtype: data type of the generated embeddings 28 | Returns: 29 | embedding vectors with shape `(len(timesteps), embedding_dim)` 30 | """ 31 | assert len(timesteps.shape) == 1 32 | timesteps *= (1000. 
/ max_time) 33 | 34 | half_dim = embedding_dim // 2 35 | emb = math.log(10000) / (half_dim - 1) 36 | emb = torch.exp(torch.arange(half_dim) * -emb) 37 | emb = timesteps.float()[:, None] * emb[None, :].to(timesteps) 38 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], axis=1) 39 | assert emb.shape == (timesteps.shape[0], embedding_dim) 40 | return emb 41 | 42 | 43 | @torch.no_grad() 44 | def variance_scaling_init_(tensor, scale=1, mode="fan_avg", distribution="uniform"): 45 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor) 46 | 47 | if mode == "fan_in": 48 | scale /= fan_in 49 | 50 | elif mode == "fan_out": 51 | scale /= fan_out 52 | 53 | else: 54 | scale /= (fan_in + fan_out) / 2 55 | 56 | if distribution == "normal": 57 | std = math.sqrt(scale) 58 | 59 | return tensor.normal_(0, std) 60 | 61 | else: 62 | bound = math.sqrt(3 * scale) 63 | 64 | return tensor.uniform_(-bound, bound) 65 | 66 | 67 | def conv2d( 68 | in_channel, 69 | out_channel, 70 | kernel_size, 71 | stride=1, 72 | padding=0, 73 | bias=True, 74 | scale=1, 75 | mode="fan_avg", 76 | ): 77 | conv = nn.Conv2d( 78 | in_channel, out_channel, kernel_size, stride=stride, padding=padding, bias=bias 79 | ) 80 | 81 | variance_scaling_init_(conv.weight, scale, mode=mode) 82 | 83 | if bias: 84 | nn.init.zeros_(conv.bias) 85 | 86 | return conv 87 | 88 | 89 | def linear(in_channel, out_channel, scale=1, mode="fan_avg"): 90 | lin = nn.Linear(in_channel, out_channel) 91 | 92 | variance_scaling_init_(lin.weight, scale, mode=mode) 93 | nn.init.zeros_(lin.bias) 94 | 95 | return lin 96 | 97 | 98 | class Swish(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | 102 | def forward(self, input): 103 | return swish(input) 104 | 105 | 106 | class Upsample(nn.Sequential): 107 | def __init__(self, channel): 108 | layers = [ 109 | nn.Upsample(scale_factor=2, mode="nearest"), 110 | conv2d(channel, channel, 3, padding=1), 111 | ] 112 | 113 | super().__init__(*layers) 114 | 115 | 116 | class Downsample(nn.Sequential): 117 | def __init__(self, channel): 118 | layers = [conv2d(channel, channel, 3, stride=2, padding=1)] 119 | 120 | super().__init__(*layers) 121 | 122 | 123 | class ResBlock(nn.Module): 124 | def __init__( 125 | self, in_channel, out_channel, time_dim, resample, use_affine_time=False, dropout=0, group_norm=32 126 | ): 127 | super().__init__() 128 | 129 | self.use_affine_time = use_affine_time 130 | self.resample = resample 131 | time_out_dim = out_channel 132 | time_scale = 1 133 | 134 | if self.use_affine_time: 135 | time_out_dim *= 2 136 | time_scale = 1e-10 137 | 138 | self.norm1 = nn.GroupNorm(group_norm, in_channel) 139 | self.activation1 = Swish() 140 | if self.resample: 141 | self.updown = { 142 | 'up': nn.Upsample(scale_factor=2, mode="nearest"), 143 | 'down': nn.AvgPool2d(kernel_size=2, stride=2) 144 | }[self.resample] 145 | 146 | self.conv1 = conv2d(in_channel, out_channel, 3, padding=1) 147 | 148 | self.time = nn.Sequential( 149 | Swish(), linear(time_dim, time_out_dim, scale=time_scale) 150 | ) 151 | 152 | self.norm2 = nn.GroupNorm(group_norm, out_channel) 153 | self.activation2 = Swish() 154 | self.dropout = nn.Dropout(dropout) 155 | self.conv2 = conv2d(out_channel, out_channel, 3, padding=1, scale=1e-10) 156 | 157 | if in_channel != out_channel: 158 | self.skip = conv2d(in_channel, out_channel, 1) 159 | 160 | else: 161 | self.skip = None 162 | 163 | def forward(self, input, time): 164 | batch = input.shape[0] 165 | out = self.norm1(input) 166 | out = self.activation1(out) 167 | 168 | if 
self.resample: 169 | out = self.updown(out) 170 | input = self.updown(input) 171 | 172 | out = self.conv1(out) 173 | 174 | if self.use_affine_time: 175 | gamma, beta = self.time(time).view(batch, -1, 1, 1).chunk(2, dim=1) 176 | out = (1 + gamma) * self.norm2(out) + beta 177 | else: 178 | out = out + self.time(time).view(batch, -1, 1, 1) 179 | out = self.norm2(out) 180 | 181 | out = self.conv2(self.dropout(self.activation2(out))) 182 | 183 | if self.skip is not None: 184 | input = self.skip(input) 185 | 186 | return out + input 187 | 188 | 189 | class SelfAttention(nn.Module): 190 | def __init__(self, in_channel, n_head=1, head_dim=None, group_norm=32): 191 | super().__init__() 192 | 193 | if head_dim is None: 194 | assert n_head is not None 195 | assert in_channel % n_head == 0 196 | self.n_head = n_head 197 | self.head_dim = in_channel // n_head 198 | else: 199 | assert n_head is None 200 | assert in_channel % head_dim == 0 201 | self.head_dim = head_dim 202 | self.n_head = in_channel // head_dim 203 | 204 | self.norm = nn.GroupNorm(group_norm, in_channel) 205 | self.qkv = conv2d(in_channel, in_channel * 3, 1) 206 | self.out = conv2d(in_channel, in_channel, 1, scale=1e-10) 207 | 208 | def forward(self, input): 209 | batch, channel, height, width = input.shape 210 | norm = self.norm(input) 211 | qkv = self.qkv(norm).view(batch, self.n_head, self.head_dim * 3, height, width) 212 | query, key, value = qkv.chunk(3, dim=2) # bhdyx 213 | 214 | attn = torch.einsum( 215 | "bnchw, bncyx -> bnhwyx", query, key 216 | ).contiguous() / math.sqrt(self.head_dim) 217 | attn = attn.view(batch, self.n_head, height, width, -1) 218 | attn = torch.softmax(attn, -1) 219 | attn = attn.view(batch, self.n_head, height, width, height, width) 220 | 221 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 222 | out = self.out(out.view(batch, channel, height, width)) 223 | 224 | return out + input 225 | 226 | 227 | class ResBlockWithAttention(nn.Module): 228 | def __init__( 229 | self, 230 | in_channel, 231 | out_channel, 232 | time_dim, 233 | dropout, 234 | resample, 235 | use_attention=False, 236 | attention_head: Optional[int] = 1, 237 | head_dim: Optional[int] = None, 238 | use_affine_time=False, 239 | group_norm=32, 240 | ): 241 | super().__init__() 242 | self.resblocks = ResBlock( 243 | in_channel, out_channel, time_dim, resample, use_affine_time, dropout, group_norm=group_norm 244 | ) 245 | 246 | if use_attention: 247 | self.attention = SelfAttention(out_channel, n_head=attention_head, head_dim=head_dim, group_norm=group_norm) 248 | 249 | else: 250 | self.attention = None 251 | 252 | def forward(self, input, time): 253 | out = self.resblocks(input, time) 254 | 255 | if self.attention is not None: 256 | out = self.attention(out) 257 | 258 | return out 259 | 260 | 261 | class UNet(nn.Module): 262 | def __init__( 263 | self, 264 | in_channel: int, 265 | channel: int, 266 | emb_channel: int, 267 | channel_multiplier: List[int], 268 | n_res_blocks: int, 269 | attn_rezs: List[int], 270 | attn_heads: Optional[int], 271 | head_dim: Optional[int], 272 | use_affine_time: bool = False, 273 | dropout: float = 0, 274 | num_output: int = 1, 275 | resample: bool = False, 276 | init_rez: int = 32, 277 | logsnr_input_type: str = 'inv_cos', 278 | logsnr_scale_range: Tuple[float, float] = (-10., 10.), 279 | num_classes: int = 1 280 | ): 281 | super().__init__() 282 | 283 | self.resample = resample 284 | self.channel = channel 285 | self.logsnr_input_type = logsnr_input_type 286 | self.logsnr_scale_range = 
logsnr_scale_range 287 | self.num_classes = num_classes 288 | time_dim = emb_channel 289 | group_norm = 32 290 | 291 | n_block = len(channel_multiplier) 292 | 293 | if self.num_classes > 1: 294 | self.class_emb = nn.Linear(self.num_classes, time_dim) 295 | 296 | self.time = nn.Sequential( 297 | linear(channel, time_dim), 298 | Swish(), 299 | linear(time_dim, time_dim), 300 | ) 301 | 302 | down_layers = [conv2d(in_channel, channel, 3, padding=1)] 303 | feat_channels = [channel] 304 | in_channel = channel 305 | cur_rez = init_rez 306 | for i in range(n_block): 307 | for _ in range(n_res_blocks): 308 | channel_mult = channel * channel_multiplier[i] 309 | 310 | down_layers.append( 311 | ResBlockWithAttention( 312 | in_channel, 313 | channel_mult, 314 | time_dim, 315 | dropout, 316 | resample=None, 317 | use_attention=cur_rez in attn_rezs, 318 | attention_head=attn_heads, 319 | head_dim=head_dim, 320 | use_affine_time=use_affine_time, 321 | group_norm=group_norm 322 | ) 323 | ) 324 | 325 | feat_channels.append(channel_mult) 326 | in_channel = channel_mult 327 | 328 | if i != n_block - 1: 329 | if self.resample: 330 | down_layers.append(ResBlock( 331 | in_channel, 332 | in_channel, 333 | time_dim, 334 | resample='down', 335 | use_affine_time=use_affine_time, 336 | dropout=dropout, 337 | group_norm=group_norm 338 | )) 339 | else: 340 | down_layers.append(Downsample(in_channel)) 341 | cur_rez = cur_rez // 2 342 | feat_channels.append(in_channel) 343 | 344 | self.down = nn.ModuleList(down_layers) 345 | 346 | self.mid = nn.ModuleList( 347 | [ 348 | ResBlockWithAttention( 349 | in_channel, 350 | in_channel, 351 | time_dim, 352 | resample=None, 353 | dropout=dropout, 354 | use_attention=True, 355 | attention_head=attn_heads, 356 | head_dim=head_dim, 357 | use_affine_time=use_affine_time, 358 | group_norm=group_norm 359 | ), 360 | ResBlockWithAttention( 361 | in_channel, 362 | in_channel, 363 | time_dim, 364 | resample=None, 365 | dropout=dropout, 366 | use_affine_time=use_affine_time, 367 | group_norm=group_norm 368 | ), 369 | ] 370 | ) 371 | 372 | up_layers = [] 373 | for i in reversed(range(n_block)): 374 | for _ in range(n_res_blocks + 1): 375 | channel_mult = channel * channel_multiplier[i] 376 | 377 | up_layers.append( 378 | ResBlockWithAttention( 379 | in_channel + feat_channels.pop(), 380 | channel_mult, 381 | time_dim, 382 | resample=None, 383 | dropout=dropout, 384 | use_attention=cur_rez in attn_rezs, 385 | attention_head=attn_heads, 386 | head_dim=head_dim, 387 | use_affine_time=use_affine_time, 388 | group_norm=group_norm 389 | ) 390 | ) 391 | 392 | in_channel = channel_mult 393 | 394 | if i != 0: 395 | if self.resample: 396 | up_layers.append(ResBlock( 397 | in_channel, 398 | in_channel, 399 | time_dim, 400 | resample='up', 401 | use_affine_time=use_affine_time, 402 | dropout=dropout, 403 | group_norm=group_norm 404 | )) 405 | else: 406 | up_layers.append(Upsample(in_channel)) 407 | cur_rez = cur_rez * 2 408 | 409 | self.up = nn.ModuleList(up_layers) 410 | 411 | self.out = nn.Sequential( 412 | nn.GroupNorm(group_norm, in_channel), 413 | Swish(), 414 | conv2d(in_channel, 3 * num_output, 3, padding=1, scale=1e-10), 415 | ) 416 | 417 | def get_time_embed(self, logsnr): 418 | if self.logsnr_input_type == 'linear': 419 | logsnr_input = (logsnr - self.logsnr_scale_range[0]) / (self.logsnr_scale_range[1] - self.logsnr_scale_range[0]) 420 | elif self.logsnr_input_type == 'sigmoid': 421 | logsnr_input = torch.sigmoid(logsnr) 422 | elif self.logsnr_input_type == 'inv_cos': 423 | logsnr_input = 
(torch.arctan(torch.exp(-0.5 * torch.clip(logsnr, -20., 20.))) / (0.5 * torch.pi)) 424 | else: 425 | raise NotImplementedError(self.logsnr_input_type) 426 | time_emb = get_timestep_embedding(logsnr_input, embedding_dim=self.channel, max_time=1.) 427 | time_embed = self.time(time_emb) 428 | return time_embed 429 | 430 | def forward(self, input, logsnr, y=None): 431 | time_embed = self.get_time_embed(logsnr) 432 | 433 | # Class embedding 434 | assert self.num_classes >= 1 435 | if self.num_classes > 1: 436 | y_emb = nn.functional.one_hot(y, num_classes=self.num_classes).float() 437 | y_emb = self.class_emb(y_emb) 438 | time_embed += y_emb 439 | del y 440 | 441 | feats = [] 442 | out = input 443 | for layer in self.down: 444 | if isinstance(layer, ResBlockWithAttention): 445 | out = layer(out, time_embed) 446 | elif isinstance(layer, ResBlock): 447 | out = layer(out, time_embed) 448 | else: 449 | out = layer(out) 450 | 451 | feats.append(out) 452 | 453 | for layer in self.mid: 454 | out = layer(out, time_embed) 455 | 456 | for layer in self.up: 457 | if isinstance(layer, ResBlockWithAttention): 458 | out = layer(torch.cat((out, feats.pop()), 1), time_embed) 459 | elif isinstance(layer, ResBlock): 460 | out = layer(out, time_embed) 461 | else: 462 | out = layer(out) 463 | 464 | out = self.out(out) 465 | 466 | return out 467 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | blobfile==2.0.1 3 | colorama==0.4.3 4 | einops==0.4.1 5 | flax==0.6.6 6 | imageio==2.19.3 7 | lmdb==1.4.0 8 | matplotlib==3.5.2 9 | nbconvert>=6.5.1 10 | numba==0.55.2 11 | numpy==1.22.4 12 | protobuf==3.20.3 13 | resize-right 14 | scikit-image 15 | scipy==1.8.1 16 | setuptools>=65.5.1 17 | tabulate==0.8.9 18 | tensorboard==2.12.2 19 | tensorflow==2.12.0 20 | torch==1.13.1+cu116 21 | torchvision==0.14.1+cu116 22 | tqdm==4.64.0 23 | -------------------------------------------------------------------------------- /tc_distill.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
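# Rough sketch of the setup implemented below (see TCDistillGoogleModel.train_op):
# the student `model` is trained against two frozen references, a per-phase `teacher`
# (reloaded from the EMA student whenever the time schedule moves to a coarser step
# count) and an EMA `self_teacher` updated with momentum `sema`, while `model_eval`
# keeps a slower EMA of the student that is used for sampling and evaluation.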
4 | # 5 | import copy 6 | import functools 7 | import os 8 | import pathlib 9 | import shutil 10 | from typing import Callable, Dict, Optional 11 | 12 | import torch 13 | import torch.nn.functional 14 | from absl import app, flags 15 | 16 | import lib 17 | from lib.distributed import device, device_id, print 18 | from lib.util import FLAGS, int_str 19 | from lib.zoo.unet import UNet 20 | 21 | 22 | def get_model(name: str): 23 | if name == 'cifar10': 24 | net = UNet(in_channel=3, 25 | channel=256, 26 | emb_channel=1024, 27 | channel_multiplier=[1, 1, 1], 28 | n_res_blocks=3, 29 | attn_rezs=[8, 16], 30 | attn_heads=1, 31 | head_dim=None, 32 | use_affine_time=True, 33 | dropout=0.2, 34 | num_output=1, 35 | resample=True, 36 | num_classes=1) 37 | elif name == 'imagenet64': 38 | # imagenet model is class conditional 39 | net = UNet(in_channel=3, 40 | channel=192, 41 | emb_channel=768, 42 | channel_multiplier=[1, 2, 3, 4], 43 | n_res_blocks=3, 44 | init_rez=64, 45 | attn_rezs=[8, 16, 32], 46 | attn_heads=None, 47 | head_dim=64, 48 | use_affine_time=True, 49 | dropout=0., 50 | num_output=2, # predict signal and noise 51 | resample=True, 52 | num_classes=1000) 53 | else: 54 | raise NotImplementedError(name) 55 | return net 56 | 57 | 58 | class TCDistillGoogleModel(lib.train.TrainModel): 59 | R_NONE, R_STEP, R_PHASE = 'none', 'step', 'phase' 60 | R_ALL = R_NONE, R_STEP, R_PHASE 61 | 62 | def __init__(self, name: str, res: int, timesteps: int, **params): 63 | super().__init__("GoogleUNet", res, timesteps, **params) 64 | self.num_classes = 1 65 | self.shape = 3, res, res 66 | self.timesteps = timesteps 67 | model = get_model(name) 68 | if 'cifar' in name: 69 | self.ckpt_path = 'ckpts/cifar_original.pt' 70 | self.predict_both = False 71 | elif 'imagenet' in name: 72 | self.ckpt_path = 'ckpts/imagenet_original.pt' 73 | self.num_classes = 1000 74 | self.predict_both = True 75 | self.EVAL_COLUMNS = self.EVAL_ROWS = 8 76 | else: 77 | raise NotImplementedError(name) 78 | 79 | self.time_schedule = tuple(int(x) for x in self.params.time_schedule.split(',')) 80 | steps_per_phase = int_str(FLAGS.train_len) / (FLAGS.batch * (len(self.time_schedule) - 1)) 81 | ema = self.params.ema_residual ** (1 / steps_per_phase) 82 | model.apply(functools.partial(lib.nn.functional.set_bn_momentum, momentum=1 - ema)) 83 | model.apply(functools.partial(lib.nn.functional.set_dropout, p=0)) 84 | self.model = lib.distributed.wrap(model) 85 | self.model_eval = lib.optim.ModuleEMA(model, momentum=ema).to(device_id()) 86 | self.self_teacher = lib.optim.ModuleEMA(model, momentum=self.params.sema).to(device_id()) 87 | self.teacher = copy.deepcopy(model).to(device_id()) 88 | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.params.lr) 89 | self.register_buffer('phase', torch.zeros((), dtype=torch.long)) 90 | 91 | def initialize_weights_from_teacher(self, logdir: pathlib.Path): 92 | teacher_ckpt_path = logdir / 'ckpt/teacher.ckpt' 93 | if device_id() == 0: 94 | os.makedirs(logdir / 'ckpt', exist_ok=True) 95 | shutil.copy2(self.ckpt_path, teacher_ckpt_path) 96 | 97 | lib.distributed.barrier() 98 | self.model.module.load_state_dict(torch.load(teacher_ckpt_path)) 99 | self.model_eval.module.load_state_dict(torch.load(teacher_ckpt_path)) 100 | self.self_teacher.module.load_state_dict(torch.load(teacher_ckpt_path)) 101 | self.teacher.load_state_dict(torch.load(teacher_ckpt_path)) 102 | 103 | def randn(self, n: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 104 | if generator is not None: 105 | assert 
generator.device == torch.device('cpu') 106 | return torch.randn((n, *self.shape), device='cpu', generator=generator, dtype=torch.double).to(self.device) 107 | 108 | def call_model(self, model: Callable, xt: torch.Tensor, index: torch.Tensor, 109 | y: Optional[torch.Tensor] = None) -> torch.Tensor: 110 | if y is None: 111 | return model(xt.float(), index.float()).double() 112 | else: 113 | return model(xt.float(), index.float(), y.long()).double() 114 | 115 | def forward(self, samples: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 116 | step = self.timesteps // self.time_schedule[self.phase.item() + 1] 117 | xt = self.randn(samples, generator).to(device_id()) 118 | if self.num_classes > 1: 119 | y = torch.randint(0, self.num_classes, (samples,)).to(xt) 120 | else: 121 | y = None 122 | 123 | for t in reversed(range(0, self.timesteps, step)): 124 | ix = torch.Tensor([t + step]).long().to(device_id()), torch.Tensor([t]).long().to(device_id()) 125 | logsnr = tuple(self.logsnr_schedule_cosine(i / self.timesteps).to(xt.double()) for i in ix) 126 | g = tuple(torch.sigmoid(l).view(-1, 1, 1, 1) for l in logsnr) # Get gamma values 127 | x0 = self.call_model(self.model_eval, xt, logsnr[0].repeat(xt.shape[0]), y) 128 | xt = self.post_xt_x0(xt, x0, g[0], g[1]) 129 | return xt 130 | 131 | @staticmethod 132 | def logsnr_schedule_cosine(t, logsnr_min=torch.Tensor([-20.]), logsnr_max=torch.Tensor([20.])): 133 | b = torch.arctan(torch.exp(-0.5 * logsnr_max)).to(t) 134 | a = torch.arctan(torch.exp(-0.5 * logsnr_min)).to(t) - b 135 | return -2. * torch.log(torch.tan(a * t + b)) 136 | 137 | @staticmethod 138 | def predict_eps_from_x(z, x, logsnr): 139 | """eps = (z - alpha*x)/sigma.""" 140 | assert logsnr.ndim == x.ndim 141 | return torch.sqrt(1. + torch.exp(logsnr)) * (z - x * torch.rsqrt(1. + torch.exp(-logsnr))) 142 | 143 | def post_xt_x0(self, xt: torch.Tensor, out: torch.Tensor, g: torch.Tensor, g1: torch.Tensor) -> torch.Tensor: 144 | if self.predict_both: 145 | assert out.shape[1] == 6 146 | model_x, model_eps = out[:, :3], out[:, 3:] 147 | # reconcile the two predictions 148 | model_x_eps = (xt - model_eps * (1 - g).sqrt()) * g.rsqrt() 149 | wx = 1 - g 150 | x0 = wx * model_x + (1. - wx) * model_x_eps 151 | else: 152 | x0 = out 153 | x0 = torch.clip(x0, -1., 1.) 
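        # DDIM-style update: with gamma = sigmoid(logsnr), the forward process used in
        # train_op is xt = sqrt(g) * x0 + sqrt(1 - g) * eps, so the noise implied by the
        # clipped x0 estimate is eps = (xt - sqrt(g) * x0) / sqrt(1 - g); the two lines
        # below recover that eps and re-noise the prediction at the next level g1.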
154 | eps = (xt - x0 * g.sqrt()) * (1 - g).rsqrt() 155 | return torch.nan_to_num(x0 * g1.sqrt() + eps * (1 - g1).sqrt()) 156 | 157 | def train_op(self, info: lib.train.TrainInfo, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]: 158 | if self.num_classes == 1: 159 | y = None 160 | with torch.no_grad(): 161 | phase = int(info.progress * (1 - 1e-9) * (len(self.time_schedule) - 1)) 162 | if phase != self.phase: 163 | print(f'Refreshing teacher {phase}') 164 | self.phase.add_(1) 165 | self.teacher.load_state_dict(self.model_eval.module.state_dict()) 166 | if self.params.reset == self.R_PHASE: 167 | self.model_eval.step.mul_(0) 168 | semi_range = self.time_schedule[phase] // self.time_schedule[phase + 1] 169 | semi = self.timesteps // self.time_schedule[phase] 170 | step = self.timesteps // self.time_schedule[phase + 1] 171 | index = torch.randint(1, 1 + (self.timesteps // step), (x.shape[0],), device=device()) * step 172 | semi_index = torch.randint(semi_range, index.shape, device=device()) * semi 173 | ix = index - semi_index, index - semi_index - semi, index - step 174 | logsnr = tuple(self.logsnr_schedule_cosine(i.double() / self.timesteps).to(x.double()) for i in ix) 175 | g = tuple(torch.sigmoid(l).view(-1, 1, 1, 1) for l in logsnr) # Get gamma values 176 | noise = torch.randn_like(x) 177 | xt0 = x.double() * g[0].sqrt() + noise * (1 - g[0]).sqrt() 178 | xt1 = self.post_xt_x0(xt0, self.call_model(self.teacher, xt0, logsnr[0], y), g[0], g[1]) 179 | xt2 = self.post_xt_x0(xt1, self.call_model(self.self_teacher, xt1, logsnr[1], y), g[1], g[2]) 180 | xt2 += (semi_index + semi == step).view(-1, 1, 1, 1) * (xt1 - xt2) # Only propagate inside phase semi_range 181 | # Find target such that self.post_xt_x0(xt0, target, g[0], g[2]) == xt2 182 | target = ((xt0 * (1 - g[2]).sqrt() - xt2 * (1 - g[0]).sqrt()) / 183 | ((g[0] * (1 - g[2])).sqrt() - (g[2] * (1 - g[0])).sqrt())) 184 | 185 | self.opt.zero_grad(set_to_none=True) 186 | pred = self.call_model(self.model, xt0, logsnr[0], y) 187 | if self.predict_both: 188 | assert pred.shape[1] == 6 189 | model_x, model_eps = pred[:, :3], pred[:, 3:] 190 | # reconcile the two predictions 191 | model_x_eps = (xt0 - model_eps * (1 - g[0]).sqrt()) * g[0].rsqrt() 192 | wx = 1 - g[0] 193 | pred_x = wx * model_x + (1. - wx) * model_x_eps 194 | else: 195 | pred_x = pred 196 | 197 | loss = ((g[0] / (1 - g[0])).clamp(1) * (pred_x - target.detach()).square()).mean(0).sum() 198 | loss.backward() 199 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) 
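        # After the optimizer step below, both moving averages track the student:
        # self_teacher with the fast momentum `sema` (it supplies the intermediate
        # distillation target above) and model_eval with the slow momentum derived
        # from `ema_residual` (it is the network used for sampling in forward()).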
200 | self.opt.step() 201 | self.self_teacher.update(self.model) 202 | self.model_eval.update(self.model) 203 | return {'loss/global': loss, 'stat/timestep': self.time_schedule[phase + 1]} 204 | 205 | 206 | def check_steps(): 207 | timesteps = [int(x) for x in FLAGS.time_schedule.split(',')] 208 | assert len(timesteps) > 1 209 | for i in range(len(timesteps) - 1): 210 | assert timesteps[i + 1] < timesteps[i] 211 | 212 | 213 | @lib.distributed.auto_distribute 214 | def main(_): 215 | check_steps() 216 | data = lib.data.DATASETS[FLAGS.dataset]() 217 | model = TCDistillGoogleModel(FLAGS.dataset, data.res, FLAGS.timesteps, reset=FLAGS.reset, 218 | batch=FLAGS.batch, lr=FLAGS.lr, ema_residual=FLAGS.ema_residual, 219 | sema=FLAGS.sema, time_schedule=FLAGS.time_schedule) 220 | logdir = lib.util.artifact_dir(FLAGS.dataset, model.logdir) 221 | train, fid = data.make_dataloaders() 222 | model.initialize_weights_from_teacher(logdir) 223 | model.train_loop(train, fid, FLAGS.batch, FLAGS.train_len, FLAGS.report_len, logdir, fid_len=FLAGS.fid_len) 224 | 225 | 226 | if __name__ == '__main__': 227 | flags.DEFINE_enum('reset', TCDistillGoogleModel.R_NONE, TCDistillGoogleModel.R_ALL, help='EMA reset mode.') 228 | flags.DEFINE_float('ema_residual', 1e-3, help='Residual for the Exponential Moving Average of model.') 229 | flags.DEFINE_float('sema', 0.5, help='Exponential Moving Average of self-teacher.') 230 | flags.DEFINE_float('lr', 2e-4, help='Learning rate.') 231 | flags.DEFINE_integer('fid_len', 4096, help='Number of samples for FID evaluation.') 232 | flags.DEFINE_integer('timesteps', 1024, help='Sampling timesteps.') 233 | flags.DEFINE_string('dataset', 'cifar10', help='Training dataset.') 234 | flags.DEFINE_string('time_schedule', None, required=True, 235 | help='Comma separated distillation timesteps, for example: 1024,32,1.') 236 | flags.DEFINE_string('train_len', '64M', help='Training duration in samples per distillation logstep.') 237 | flags.DEFINE_string('report_len', '1M', help='Reporting interval in samples.') 238 | flags.FLAGS.set_default('report_img_len', '1M') 239 | flags.FLAGS.set_default('report_fid_len', '4M') 240 | app.run(lib.distributed.main(main)) 241 | -------------------------------------------------------------------------------- /tc_distill_edm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | import copy 6 | import functools 7 | import pickle 8 | import sys 9 | from typing import Dict, Optional 10 | 11 | import torch 12 | import torch.nn.functional 13 | from absl import app, flags 14 | 15 | import lib 16 | from lib.distributed import device, device_id 17 | from lib.util import FLAGS, int_str 18 | 19 | # Imports within edm/ are often relative to edm/ so we do this. 20 | sys.path.append('edm') 21 | import dnnlib 22 | from torch_utils import distributed as dist 23 | from torch_utils import misc 24 | 25 | 26 | class EluDDIM05TCMultiStepx0(lib.train.TrainModel): 27 | SIGMA_DATA = 0.5 28 | SIGMA_MIN: float = 0.002 29 | SIGMA_MAX: float = 80. 30 | RHO: float = 7. 
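    # The noise levels built in __init__ follow the EDM/Karras rho schedule: linearly
    # spaced in sigma ** (1 / RHO) between SIGMA_MIN and SIGMA_MAX, then raised back
    # to the RHO-th power, roughly
    #     sigma_i = (SIGMA_MIN ** (1/RHO) + i / (timesteps - 1)
    #                * (SIGMA_MAX ** (1/RHO) - SIGMA_MIN ** (1/RHO))) ** RHO
    # with an extra sigma_0 = 0 prepended as the clean-data level.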
31 | 32 | def __init__(self, res: int, timesteps: int, **params): 33 | super().__init__("EluUNet", res, timesteps, **params) 34 | self.use_imagenet = FLAGS.dataset == "imagenet64" 35 | self.num_classes = 1000 if self.use_imagenet else 10 36 | 37 | # Setup pretrained model 38 | lib.distributed.barrier() 39 | if FLAGS.dataset == "imagenet64": 40 | pretrained_url = "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl" 41 | elif FLAGS.dataset == "cifar10": 42 | pretrained_url = "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-uncond-ve.pkl" 43 | else: 44 | raise ValueError("Only cifar10 or imagenet64 is supported for now.") 45 | with dnnlib.util.open_url(pretrained_url) as f: 46 | pretrained = pickle.load(f)['ema'] 47 | lib.distributed.barrier() 48 | 49 | network_kwargs = self.get_pretrained_cifar10_network_kwargs() 50 | if self.use_imagenet: 51 | network_kwargs = self.get_pretrained_imagenet_network_kwargs() 52 | label_dim = self.num_classes if self.use_imagenet else 0 53 | interface_kwargs = dict(img_resolution=res, img_channels=3, label_dim=label_dim) 54 | model = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) 55 | model.train().requires_grad_(True) 56 | misc.copy_params_and_buffers(src_module=pretrained, dst_module=model, require_all=False) 57 | del pretrained # save memory 58 | 59 | self.time_schedule = tuple(int(x) for x in self.params.time_schedule.split(',')) 60 | steps_per_phase = int_str(FLAGS.train_len) / (FLAGS.batch * (len(self.time_schedule) - 1)) 61 | ema = self.params.ema_residual ** (1 / steps_per_phase) 62 | model.apply(functools.partial(lib.nn.functional.set_bn_momentum, momentum=1 - ema)) 63 | model.apply(functools.partial(lib.nn.functional.set_dropout, p=self.params.dropout)) 64 | self.model = lib.distributed.wrap(model) 65 | self.model_eval = lib.optim.ModuleEMA(model, momentum=ema).eval().requires_grad_(False).to(device_id()) 66 | lib.distributed.barrier() 67 | 68 | # Disable dropout noise for teacher 69 | model.apply(functools.partial(lib.nn.functional.set_dropout, p=0)) 70 | self.self_teacher = lib.optim.ModuleEMA(model, momentum=self.params.sema).to(device_id()) 71 | self.self_teacher.eval().requires_grad_(False) 72 | self.teacher = copy.deepcopy(model).to(device_id()) 73 | self.teacher.eval().requires_grad_(False) 74 | 75 | self.opt = torch.optim.Adam(self.model.parameters(), lr=self.params.lr, weight_decay=0.0) 76 | 77 | # Setup noise schedule 78 | sigma = torch.linspace(self.SIGMA_MIN ** (1 / self.RHO), 79 | self.SIGMA_MAX ** (1 / self.RHO), timesteps, dtype=torch.double).pow(self.RHO) 80 | sigma = torch.cat([torch.zeros_like(sigma[:1]), sigma]) 81 | self.register_buffer('sigma', sigma.to(device())) 82 | self.timesteps = timesteps 83 | 84 | def get_pretrained_cifar10_network_kwargs(self): 85 | network_kwargs = dnnlib.EasyDict() 86 | network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') 87 | network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2]) 88 | network_kwargs.class_name = 'training.networks.EDMPrecond' 89 | network_kwargs.augment_dim = 0 90 | network_kwargs.update(dropout=0.0, use_fp16=False) 91 | return network_kwargs 92 | 93 | def get_pretrained_imagenet_network_kwargs(self): 94 | network_kwargs = dnnlib.EasyDict() 95 | network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4]) 96 | network_kwargs.class_name = 
'training.networks.EDMPrecond' 97 | network_kwargs.update(dropout=0.0, use_fp16=False) 98 | return network_kwargs 99 | 100 | @classmethod 101 | def c_in(cls, sigma: torch.Tensor) -> torch.Tensor: 102 | return (sigma ** 2 + cls.SIGMA_DATA ** 2) ** -0.5 103 | 104 | @classmethod 105 | def c_skip(cls, sigma: torch.Tensor) -> torch.Tensor: 106 | return (cls.SIGMA_DATA ** 2) / (sigma ** 2 + cls.SIGMA_DATA ** 2) 107 | 108 | @classmethod 109 | def c_out(cls, sigma: torch.Tensor) -> torch.Tensor: 110 | return sigma * cls.SIGMA_DATA * (cls.SIGMA_DATA ** 2 + sigma ** 2) ** -0.5 111 | 112 | @staticmethod 113 | def c_noise(sigma: torch.Tensor) -> torch.Tensor: 114 | return 0.25 * sigma.clamp(1e-20).log() 115 | 116 | def forward(self, n: int, generator: Optional[torch.Generator] = None) -> torch.Tensor: 117 | step = self.timesteps // self.time_schedule[1] 118 | shape = n, self.COLORS, self.params.res, self.params.res 119 | 120 | xt = self.sigma[-1] * torch.randn(shape, generator=generator, dtype=torch.double).to(device()) 121 | class_labels = (torch.eye(self.num_classes, device=device())[torch.randint(self.num_classes, size=[n], device=device())]) if self.use_imagenet else None 122 | 123 | for t in reversed(range(0, self.timesteps, step)): 124 | ix = torch.Tensor([t + step]).long().to(device_id()), torch.Tensor([t]).long().to(device_id()) 125 | g = tuple(self.sigma[i].view(-1, 1, 1, 1) for i in ix) 126 | x0 = self.model_eval(xt, g[0], class_labels).to(torch.float64) 127 | xt = self.post_xt_x0(xt, x0, g[0], g[1]) 128 | 129 | return xt.clamp(-1, 1).float() 130 | 131 | def post_xt_x0(self, xt: torch.Tensor, out: torch.Tensor, sigma: torch.Tensor, sigma1: torch.Tensor) -> torch.Tensor: 132 | x0 = torch.clip(out, -1., 1.) 133 | eps = (xt - x0) / sigma 134 | return torch.nan_to_num(x0 + eps * sigma1) 135 | 136 | def train_op(self, info: lib.train.TrainInfo, x: torch.Tensor, y: torch.Tensor) -> Dict[str, torch.Tensor]: 137 | if self.num_classes == 1000: # imagenet 138 | y = torch.nn.functional.one_hot(y, self.num_classes).to(y.device) 139 | else: 140 | y = None 141 | 142 | with torch.no_grad(): 143 | 144 | step = self.timesteps // self.time_schedule[1] 145 | index = torch.randint(1, 1 + (self.timesteps // step), (x.shape[0],), device=device()) * step 146 | semi_index = torch.randint(step, index.shape, device=device()) 147 | ix = index - semi_index, (index - semi_index - 1).clamp(1), index - step 148 | 149 | s = tuple(self.sigma[i].view(-1, 1, 1, 1) for i in ix) 150 | noise = torch.randn_like(x).to(device()) 151 | 152 | # RK step from teacher 153 | xt = x.double() + noise * s[0] 154 | x0 = self.teacher(xt, s[0], y) 155 | eps = (xt - x0) / s[0] 156 | xt_ = xt + (s[1] - s[0]) * eps 157 | x0_ = self.teacher(xt_, s[1], y) 158 | eps = .5 * (eps + (xt_ - x0_) / s[1]) 159 | xt_ = xt + (s[1] - s[0]) * eps # RK target from teacher; no RK needed for sigma_min 160 | 161 | # self-teacher step 162 | xt2 = self.post_xt_x0(xt_, self.self_teacher(xt_, s[1], y), s[1], s[2]) 163 | xt2 += ((semi_index + 1) == step).view(-1, 1, 1, 1) * (xt_ - xt2) # Only propagate inside phase semi_range 164 | 165 | xt2 = ((xt2 * s[0] - xt * s[2]) / (s[0] - s[2])) 166 | 167 | # Boundary and terminal condition: last time step, no RK and self-teaching needed 168 | target_without_precon = torch.where((index - semi_index - 1).view(-1, 1, 1, 1) == 0, x0.double(), xt2.double()) 169 | 170 | target = (target_without_precon - self.c_skip(s[0]) * xt) / self.c_out(s[0]) 171 | 172 | self.opt.zero_grad(set_to_none=True) 173 | pred = self.model(xt.float(), 
s[0].float(), y).double() 174 | pred = (pred - self.c_skip(s[0]) * xt) / self.c_out(s[0]) 175 | 176 | weight = (s[0] ** 2 + self.SIGMA_DATA ** 2) * (self.c_out(s[0]) ** 2) * (s[0] * self.SIGMA_DATA) ** -2 177 | loss = (torch.nn.functional.mse_loss(pred.float(), target.float(), reduction='none')).mean((1, 2, 3)) 178 | loss = (weight.float() * loss).mean() 179 | 180 | loss.backward() 181 | 182 | # LR warmup and clip gradient like EDM paper 183 | if self.params.lr_warmup is not None: 184 | for g in self.opt.param_groups: 185 | g['lr'] = self.params.lr * min(info.samples / max(int_str(self.params.lr_warmup), 1e-8), 1) 186 | for param in self.model.parameters(): 187 | if param.grad is not None: 188 | torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) 189 | 190 | self.opt.step() 191 | self.self_teacher.update(self.model) 192 | self.model_eval.update(self.model) 193 | return {'loss/global': loss} 194 | 195 | 196 | def check_steps(): 197 | timesteps = [int(x) for x in FLAGS.time_schedule.split(',')] 198 | assert len(timesteps) > 1 199 | assert timesteps[0] == FLAGS.timesteps 200 | for i in range(len(timesteps) - 1): 201 | assert timesteps[i + 1] < timesteps[i] 202 | 203 | 204 | @lib.distributed.auto_distribute 205 | def main(_): 206 | check_steps() 207 | data = lib.data.DATASETS[FLAGS.dataset]() 208 | lib.distributed.barrier() 209 | model = EluDDIM05TCMultiStepx0(data.res, FLAGS.timesteps, batch=FLAGS.batch, lr=FLAGS.lr, 210 | ema_residual=FLAGS.ema_residual, sema=FLAGS.sema, lr_warmup=FLAGS.lr_warmup, 211 | aug_prob=FLAGS.aug_prob, dropout=FLAGS.dropout, time_schedule=FLAGS.time_schedule) 212 | lib.distributed.barrier() 213 | logdir = lib.util.artifact_dir(FLAGS.dataset, model.logdir) 214 | train, fid = data.make_dataloaders() 215 | model.train_loop(train, fid, FLAGS.batch, FLAGS.train_len, FLAGS.report_len, logdir, fid_len=FLAGS.fid_len) 216 | 217 | 218 | if __name__ == '__main__': 219 | flags.DEFINE_float('ema_residual', 1e-3, help='Residual for the Exponential Moving Average of model.') 220 | flags.DEFINE_float('sema', 0.5, help='Exponential Moving Average of self-teacher.') 221 | flags.DEFINE_float('lr', 1e-3, help='Learning rate.') 222 | flags.DEFINE_string('lr_warmup', None, help='Warmup for LR in num samples, e.g. 4M') 223 | flags.DEFINE_integer('fid_len', 50000, help='Number of samples for FID evaluation.') 224 | flags.DEFINE_integer('timesteps', 40, help='Sampling timesteps.') 225 | flags.DEFINE_string('time_schedule', None, required=True, 226 | help='Comma separated distillation timesteps, for example: 36,1.') 227 | flags.DEFINE_string('dataset', 'cifar10', help='Training dataset. 
Either cifar10 or imagenet64')
228 |     flags.DEFINE_string('report_len', '1M', help='Reporting interval in samples.')
229 |     flags.DEFINE_string('train_len', '64M', help='Training duration in samples per distillation logstep.')
230 |     flags.DEFINE_float('aug_prob', 0.0, help='Probability of applying data augmentation in training.')
231 |     flags.DEFINE_float('dropout', 0.0, help='Dropout probability for training.')
232 |     flags.FLAGS.set_default('report_img_len', '1M')
233 |     flags.FLAGS.set_default('report_fid_len', '4M')
234 |     app.run(lib.distributed.main(main))
235 | 
--------------------------------------------------------------------------------
/teacher/README.md:
--------------------------------------------------------------------------------
1 | # Downloading pre-trained teacher models
2 | 
3 | To compare with our baselines, we start from their pre-trained teacher model checkpoints and distill our student model.
4 | 
5 | Follow the [guide](https://cloud.google.com/storage/docs/gsutil_install) to install gsutil, which is used to download the checkpoints from Google Cloud Storage.
6 | 
7 | Since the original model is written in JAX, we need to convert the teacher's weights from JAX to PyTorch.
8 | 
9 | Run the conversion script (the converted checkpoints are saved to `./ckpts/` by default; see the `--path` flag):
10 | 
11 | ```bash
12 | python teacher/download_and_convert_jax.py
13 | ```
14 | 
--------------------------------------------------------------------------------
/teacher/download_and_convert_jax.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | import pathlib
7 | from flax import serialization
8 | from tensorflow.compat.v2.io import gfile
9 | from lib.zoo.unet import UNet
10 | import numpy as np
11 | import torch
12 | import einops as ei
13 | from absl import app, flags
14 | 
15 | 
16 | def to_torch(x):
17 |     return torch.nn.Parameter(torch.from_numpy(x.copy()))
18 | 
19 | 
20 | def check_and_convert_gcs_filepath(filepath, raise_if_not_gcs=False):
21 |     """Utility for loading model checkpoints from GCS."""
22 |     local_filepath = filepath.split('/')[-1]
23 |     if os.path.exists(local_filepath):
24 |         print('loading from local copy of GCS file: ' + local_filepath)
25 |     else:
26 |         print('downloading file from GCS: ' + filepath)
27 |         os.system('gsutil cp ' + filepath + ' ' + local_filepath)
28 |     return local_filepath
29 | 
30 | 
31 | def restore_from_path(ckpt_path, target):
32 |     ckpt_path = check_and_convert_gcs_filepath(ckpt_path)
33 |     with gfile.GFile(ckpt_path, 'rb') as fp:
34 |         return serialization.from_bytes(target, fp.read())
35 | 
36 | 
37 | def convert_conv(module_from, module_to):
38 |     # PyTorch kernel has shape [outC, inC, kH, kW] and the Flax kernel has shape [kH, kW, inC, outC]
39 |     module_to.weight = to_torch(module_from['kernel'].transpose(3, 2, 0, 1))
40 |     module_to.bias = to_torch(module_from['bias'])
41 | 
42 | 
43 | def convert_conv_after_qkv(module_from, module_to):
44 |     module_to.weight = to_torch(ei.rearrange(module_from['kernel'], "nh h f -> f (nh h) 1 1"))
45 |     module_to.bias = to_torch(module_from['bias'])
46 | 
47 | 
48 | def convert_fc(module_from, module_to):
49 |     # PyTorch kernel has shape [outC, inC] and the Flax kernel has shape [inC, outC]
50 |     module_to.weight = to_torch(module_from['kernel'].transpose(1, 0))
51 |     module_to.bias = to_torch(module_from['bias'])
52 | 
53 | 
54 | def convert_group_norm(module_from, module_to):
55 |     module_to.weight = to_torch(module_from['scale'])
56 |     module_to.bias = to_torch(module_from['bias'])
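# Editorial note, not part of the original script: the axis permutation in
# convert_conv above can be sanity-checked on a dummy array. Flax stores a conv
# kernel as [kH, kW, inC, outC], PyTorch as [outC, inC, kH, kW], so
# transpose(3, 2, 0, 1) maps one layout onto the other:
#
#   dummy = np.zeros((3, 3, 64, 128))                             # Flax: [kH, kW, inC, outC]
#   assert dummy.transpose(3, 2, 0, 1).shape == (128, 64, 3, 3)   # PyTorch: [outC, inC, kH, kW]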
57 | 
58 | 
59 | def convert_qkv(module_from_q, module_from_k, module_from_v, module_to):
60 |     weight = np.concatenate((module_from_q['kernel'], module_from_k['kernel'], module_from_v['kernel']), 2)
61 |     module_to.weight = to_torch(ei.rearrange(weight, 'f nh h -> (nh h) f 1 1'))
62 |     bias = np.concatenate((module_from_q['bias'], module_from_k['bias'], module_from_v['bias']), 1)
63 |     module_to.bias = to_torch(ei.rearrange(bias, 'nh h -> (nh h)'))
64 | 
65 | 
66 | def convert1x1conv(module_from, module_to):
67 |     module_to.weight = to_torch(module_from['kernel'].transpose(1, 0)[:, :, None, None])
68 |     module_to.bias = to_torch(module_from['bias'])
69 | 
70 | 
71 | def convert_res_block(module_from, module_to):
72 |     convert_group_norm(module_from['norm1'], module_to.norm1)
73 |     convert_conv(module_from['conv1'], module_to.conv1)
74 |     convert_fc(module_from['temb_proj'], module_to.time[1])
75 |     convert_group_norm(module_from['norm2'], module_to.norm2)
76 |     convert_conv(module_from['conv2'], module_to.conv2)
77 |     if 'nin_shortcut' in module_from:
78 |         convert1x1conv(module_from['nin_shortcut'], module_to.skip)
79 | 
80 | 
81 | def convert_attention(module_from, module_to):
82 |     convert_group_norm(module_from['norm'], module_to.norm)
83 |     convert_qkv(module_from['q'], module_from['k'], module_from['v'], module_to.qkv)
84 |     convert_conv_after_qkv(module_from['proj_out'], module_to.out)
85 | 
86 | 
87 | def convert_down(module_from, module_to, n_down_blocks, n_res_blocks):
88 |     convert_conv(module_from['conv_in'], module_to[0])
89 |     module_to_idx = 1
90 |     for i in range(n_down_blocks):
91 |         for j in range(n_res_blocks):
92 |             convert_res_block(module_from[f'down_{i}.block_{j}'], module_to[module_to_idx].resblocks)
93 |             if f'down_{i}.attn_{j}' in module_from.keys():
94 |                 convert_attention(module_from[f'down_{i}.attn_{j}'], module_to[module_to_idx].attention)
95 |             module_to_idx += 1
96 |         # downsample layer is a res block
97 |         if f'down_{i}.downsample' in module_from.keys():
98 |             convert_res_block(module_from[f'down_{i}.downsample'], module_to[module_to_idx])
99 |             module_to_idx += 1
100 |     assert module_to_idx == len(module_to)
101 | 
102 | 
103 | def convert_mid(module_from, module_to):
104 |     convert_res_block(module_from['mid.block_1'], module_to[0].resblocks)
105 |     convert_attention(module_from['mid.attn_1'], module_to[0].attention)
106 |     convert_res_block(module_from['mid.block_2'], module_to[1].resblocks)
107 | 
108 | 
109 | def convert_up(module_from, module_to, num_up_blocks, n_res_blocks):
110 |     module_to_idx = 0
111 |     for i in reversed(range(num_up_blocks)):
112 |         for j in range(n_res_blocks + 1):
113 |             convert_res_block(module_from[f'up_{i}.block_{j}'], module_to[module_to_idx].resblocks)
114 |             if f'up_{i}.attn_{j}' in module_from.keys():
115 |                 convert_attention(module_from[f'up_{i}.attn_{j}'], module_to[module_to_idx].attention)
116 |             module_to_idx += 1
117 |         # upsample layer is a res block
118 |         if f'up_{i}.upsample' in module_from.keys():
119 |             convert_res_block(module_from[f'up_{i}.upsample'], module_to[module_to_idx])
120 |             module_to_idx += 1
121 |     assert module_to_idx == len(module_to)
122 | 
123 | 
124 | def convert_out(module_from, module_to):
125 |     convert_group_norm(module_from['norm_out'], module_to[0])
126 |     convert_conv(module_from['conv_out'], module_to[2])
127 | 
128 | 
129 | def convert_time(module_from, module_to):
130 |     convert_fc(module_from['dense0'], module_to[0])
131 |     convert_fc(module_from['dense1'], module_to[2])
132 | 
133 | 
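# Editorial note, not part of the original script: the fused attention weight built
# by convert_qkv above can be checked the same way. With nh heads of size h over f
# input features, the three Flax kernels of shape [f, nh, h] are concatenated along
# the last axis and reshaped into a single 1x1-conv weight of shape [nh * 3 * h, f, 1, 1]:
#
#   q = k = v = np.zeros((64, 4, 16))                                     # [f, nh, h]
#   w = np.concatenate((q, k, v), 2)                                      # [f, nh, 3 * h]
#   assert ei.rearrange(w, 'f nh h -> (nh h) f 1 1').shape == (192, 64, 1, 1)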
134 | def convert_class(module_from, module_to):
135 |     convert_fc(module_from['class_emb'], module_to)
136 | 
137 | 
138 | def convert(module_from, module_to, n_down_blocks, n_up_blocks, n_res_blocks, class_conditional=False):
139 |     # downsample
140 |     convert_down(module_from['ema_params'], module_to.down, n_down_blocks, n_res_blocks)
141 |     # mid
142 |     convert_mid(module_from['ema_params'], module_to.mid)
143 |     # up
144 |     convert_up(module_from['ema_params'], module_to.up, n_up_blocks, n_res_blocks)
145 |     # out
146 |     convert_out(module_from['ema_params'], module_to.out)
147 |     # time
148 |     convert_time(module_from['ema_params'], module_to.time)
149 |     # class
150 |     if class_conditional:
151 |         convert_class(module_from['ema_params'], module_to.class_emb)
152 | 
153 | 
154 | def cifar10(path: pathlib.Path):
155 |     ckpt = restore_from_path('gs://gresearch/diffusion-distillation/cifar_original', None)
156 |     net = UNet(in_channel=3,
157 |                channel=256,
158 |                emb_channel=1024,
159 |                channel_multiplier=[1, 1, 1],
160 |                n_res_blocks=3,
161 |                attn_rezs=[8, 16],
162 |                attn_heads=1,
163 |                head_dim=None,
164 |                use_affine_time=True,
165 |                dropout=0.2,
166 |                num_output=1,
167 |                resample=True,
168 |                num_classes=1)
169 |     convert(ckpt, net, n_down_blocks=3, n_up_blocks=3, n_res_blocks=3)
170 |     # save torch checkpoint
171 |     torch.save(net.state_dict(), path / 'cifar_original.pt')
172 |     return net
173 | 
174 | 
175 | def imagenet64_conditional(path: pathlib.Path):
176 |     ckpt = restore_from_path('gs://gresearch/diffusion-distillation/imagenet_original', None)
177 |     net = UNet(in_channel=3,
178 |                channel=192,
179 |                emb_channel=768,
180 |                channel_multiplier=[1, 2, 3, 4],
181 |                n_res_blocks=3,
182 |                init_rez=64,
183 |                attn_rezs=[8, 16, 32],
184 |                attn_heads=None,
185 |                head_dim=64,
186 |                use_affine_time=True,
187 |                dropout=0.,
188 |                num_output=2,  # predict signal and noise
189 |                resample=True,
190 |                num_classes=1000)
191 |     convert(ckpt, net, n_down_blocks=4, n_up_blocks=4, n_res_blocks=3, class_conditional=True)
192 |     # save torch checkpoint
193 |     torch.save(net.state_dict(), path / 'imagenet_original.pt')
194 |     return net
195 | 
196 | 
197 | def main(_):
198 |     path = pathlib.Path(flags.FLAGS.path)
199 |     os.makedirs(path, exist_ok=True)
200 |     imagenet64_conditional(path)
201 |     cifar10(path)
202 | 
203 | 
204 | if __name__ == '__main__':
205 |     flags.DEFINE_string('path', './ckpts/', help='Path to save the checkpoints.')
206 |     app.run(main)
207 | 
--------------------------------------------------------------------------------
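As a quick way to verify the conversion, the saved state dict can be loaded back into the same `UNet` configuration that `cifar10()` constructs. The sketch below is illustrative only and is not part of the repository; it assumes the default `--path` of `./ckpts/`:

```python
import pathlib

import torch

from lib.zoo.unet import UNet

# Same hyperparameters as cifar10() in teacher/download_and_convert_jax.py.
net = UNet(in_channel=3, channel=256, emb_channel=1024, channel_multiplier=[1, 1, 1],
           n_res_blocks=3, attn_rezs=[8, 16], attn_heads=1, head_dim=None,
           use_affine_time=True, dropout=0.2, num_output=1, resample=True, num_classes=1)
state = torch.load(pathlib.Path('./ckpts') / 'cifar_original.pt', map_location='cpu')
net.load_state_dict(state)  # strict by default, so any missing or misshaped key raises
```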