├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── assets └── main.png ├── main ├── __init__.py ├── callbacks.py ├── configs │ └── dataset │ │ ├── afhqv2 │ │ └── afhqv2128_psld.yaml │ │ ├── celeba64 │ │ ├── celeba64_psld.yaml │ │ └── celeba64_vpsde.yaml │ │ └── cifar10 │ │ ├── cifar10_psld.yaml │ │ └── cifar10_vpsde.yaml ├── datasets │ ├── __init__.py │ ├── afhq.py │ ├── celeba.py │ ├── celebahq.py │ ├── cifar10.py │ ├── inpaint.py │ └── latent.py ├── eval │ ├── __init__.py │ ├── class_cond_sample.py │ ├── inpaint.py │ └── sample.py ├── losses.py ├── models │ ├── __init__.py │ ├── clf_wrapper.py │ ├── score_fn │ │ ├── __init__.py │ │ └── song_sde │ │ │ ├── __init__.py │ │ │ ├── layers.py │ │ │ ├── layerspp.py │ │ │ ├── ncsnpp.py │ │ │ ├── ncsnpp_clf.py │ │ │ ├── normalization.py │ │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ │ ├── up_or_down_sampling.py │ │ │ └── utils.py │ ├── sde │ │ ├── __init__.py │ │ ├── base.py │ │ ├── psld.py │ │ └── vpsde.py │ └── wrapper.py ├── samplers │ ├── __init__.py │ ├── base.py │ ├── ode.py │ └── sde.py ├── train_clf.py ├── train_sde.py └── util.py └── scripts_psld ├── ablations ├── cond │ ├── afhqv2 │ │ ├── sample_inpaint_psld.sh │ │ ├── sample_tclf_psld.sh │ │ └── train_tclf_psld.sh │ └── cifar10 │ │ ├── sample_tclf_psld.sh │ │ └── train_tclf_psld.sh └── uncond │ ├── afhqv2 │ ├── sample_uncond_psld.sh │ └── train_uncond_psld.sh │ ├── celeba64 │ ├── sample_uncond_psld.sh │ └── train_uncond_psld.sh │ └── cifar10 │ ├── sample_uncond_psld.sh │ ├── sample_uncond_psld_ode.sh │ ├── sample_uncond_vpsde.sh │ ├── sample_uncond_vpsde_ode.sh │ ├── train_uncond_psld.sh │ └── train_uncond_vpsde.sh ├── fid.sh └── sota ├── cond └── afhqv2 │ └── sample_inpaint_psld.sh └── uncond ├── celeba64 ├── sample_uncond_psld.sh ├── sample_uncond_psld_ode.sh └── train_uncond_psld.sh └── cifar10 ├── sample_uncond_psld.sh ├── sample_uncond_psld_ode.sh └── train_uncond_psld.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | /scripts_test/ 132 | /scripts_itsam/ 133 | /scripts_slurm/ 134 | /outputs/ 135 | analysis/ 136 | download.sh -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mandt-lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | torch = "==1.13.1+cu116" 8 | torchdiffeq = "==0.2.3" 9 | torch-fidelity = "==0.3.0" 10 | torchmetrics = "==0.11.0" 11 | torchvision = "==0.14.1+cu116" 12 | tqdm = "==4.64.1" 13 | wandb = "==0.13.5" 14 | scipy = "==1.9.3" 15 | scikit-learn = "==1.1.3" 16 | Pillow = "==9.3.0" 17 | numpy = "==1.23.5" 18 | matplotlib = "==3.6.2" 19 | hydra-core = "==1.2.0" 20 | pytorch-lightning = "==1.9.0" 21 | ninja = "==1.11.1" 22 | 23 | [dev-packages] 24 | 25 | [requires] 26 | python_version = "3.8" 27 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "7f7606f08e0544d8d012ef4d097dabdd6df6843a28793eb6551245d4b2db4242" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.8" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": {}, 19 | "develop": {} 20 | } 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Phase Space Langevin Diffusion (ICCV'23 Oral Presentation)
2 | 3 | Kushagra Pandey · Stephan Mandt 4 | 5 | 6 | 7 | 8 | 9 | Official Implementation of the paper: A Complete Recipe for Diffusion Generative Models 10 | 11 | ## Overview 12 |
13 | ![PSLD overview](./assets/main.png) 14 |
15 | 16 | We propose a complete recipe for constructing novel diffusion processes that are guaranteed to converge to a specified stationary distribution. This offers two benefits: 17 | 21 | 22 | To instantiate this recipe, we propose a new diffusion model, Phase Space Langevin Diffusion (PSLD), which outperforms diffusion models like VP-SDE and CLD and achieves excellent sample quality on standard image synthesis benchmarks like CIFAR-10 (FID: 2.10) and CelebA-64 (FID: 2.01). Like other diffusion models, PSLD supports classifier(-free) guidance out of the box. 23 | 24 | ## Code Overview 25 | 26 | This repo uses [PyTorch Lightning](https://www.pytorchlightning.ai/) for training and [Hydra](https://hydra.cc/docs/intro/) for config management, so basic familiarity with both tools is expected. Please clone the repo and use `PSLD` as the working directory for all downstream tasks like dependency setup, training, and inference. 27 | 28 | ## Dependency Setup 29 | 30 | We use `pipenv` for project-level dependency management. Simply [install](https://pipenv.pypa.io/en/latest/#install-pipenv-today) `pipenv` and run the following command: 31 | 32 | ``` 33 | pipenv install 34 | ``` 35 | 36 | ## Config Management 37 | We manage Hydra configurations separately for each benchmark/dataset used in this work. All configs are in the `main/configs` directory, which has one subfolder per dataset; each subfolder contains the configs associated with the diffusion models used for that dataset. 38 | 39 | 60 | 61 | ## Training and Evaluation 62 | 63 | We include sample training and inference commands for CIFAR-10 below; all scripts used in this work are in the `scripts_psld/` directory. 64 | 65 | - Training a PSLD model on CIFAR-10: 66 | 67 | ```shell script 68 | python main/train_sde.py +dataset=cifar10/cifar10_psld \ 69 | dataset.diffusion.data.root=\'/path/to/cifar10/\' \ 70 | dataset.diffusion.data.name='cifar10' \ 71 | dataset.diffusion.data.norm=True \ 72 | dataset.diffusion.data.hflip=True \ 73 | dataset.diffusion.model.score_fn.in_ch=6 \ 74 | dataset.diffusion.model.score_fn.out_ch=6 \ 75 | dataset.diffusion.model.score_fn.nf=128 \ 76 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 77 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 78 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 79 | dataset.diffusion.model.score_fn.dropout=0.15 \ 80 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 81 | dataset.diffusion.model.score_fn.fir=True \ 82 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 83 | dataset.diffusion.model.sde.beta_min=8.0 \ 84 | dataset.diffusion.model.sde.beta_max=8.0 \ 85 | dataset.diffusion.model.sde.decomp_mode='lower' \ 86 | dataset.diffusion.model.sde.nu=4.01 \ 87 | dataset.diffusion.model.sde.gamma=0.01 \ 88 | dataset.diffusion.model.sde.kappa=0.04 \ 89 | dataset.diffusion.training.seed=0 \ 90 | dataset.diffusion.training.chkpt_interval=50 \ 91 | dataset.diffusion.training.mode='hsm' \ 92 | dataset.diffusion.training.fp16=False \ 93 | dataset.diffusion.training.use_ema=True \ 94 | dataset.diffusion.training.batch_size=16 \ 95 | dataset.diffusion.training.epochs=2500 \ 96 | dataset.diffusion.training.accelerator='gpu' \ 97 | dataset.diffusion.training.devices=8 \ 98 | dataset.diffusion.training.results_dir=\'/path/to/logdir/\' \ 99 | dataset.diffusion.training.workers=1 100 | ``` 101 | 102 | - Generating CIFAR-10 samples from the PSLD model: 103 | 104 | ```shell script 105 | python
main/eval/sample.py +dataset=cifar10/cifar10_psld \ 106 | dataset.diffusion.model.score_fn.in_ch=6 \ 107 | dataset.diffusion.model.score_fn.out_ch=6 \ 108 | dataset.diffusion.model.score_fn.nf=128 \ 109 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 110 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 111 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 112 | dataset.diffusion.model.score_fn.dropout=0.15 \ 113 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 114 | dataset.diffusion.model.score_fn.fir=True \ 115 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 116 | dataset.diffusion.model.sde.beta_min=8.0 \ 117 | dataset.diffusion.model.sde.beta_max=8.0 \ 118 | dataset.diffusion.model.sde.nu=4.01 \ 119 | dataset.diffusion.model.sde.gamma=0.01 \ 120 | dataset.diffusion.model.sde.kappa=0.04 \ 121 | dataset.diffusion.model.sde.decomp_mode='lower' \ 122 | dataset.diffusion.evaluation.seed=0 \ 123 | dataset.diffusion.evaluation.sample_prefix='some_prefix_for_sample_names' \ 124 | dataset.diffusion.evaluation.devices=8 \ 125 | dataset.diffusion.evaluation.save_path=\'/path/to/generated/samples/\' \ 126 | dataset.diffusion.evaluation.batch_size=16 \ 127 | dataset.diffusion.evaluation.stride_type='uniform' \ 128 | dataset.diffusion.evaluation.sample_from='target' \ 129 | dataset.diffusion.evaluation.workers=1 \ 130 | dataset.diffusion.evaluation.chkpt_path=\'/path/to/pretrained/chkpt.ckpt\' \ 131 | dataset.diffusion.evaluation.sampler.name="em_sde" \ 132 | dataset.diffusion.evaluation.n_samples=50000 \ 133 | dataset.diffusion.evaluation.n_discrete_steps=50 \ 134 | dataset.diffusion.evaluation.path_prefix="50" 135 | ``` 136 | We evaluate sample quality using FID scores, computed with the `torch-fidelity` package; a minimal usage sketch is included at the end of this README. 137 | 138 | ## Pretrained Checkpoints 139 | Pre-trained PSLD checkpoints can be found [here](https://personalmicrosoftsoftware-my.sharepoint.com/:f:/g/personal/pandeyk1_personalmicrosoftsoftware_uci_edu/EhehC1yAF1pKv1Vp2rZpMTsBG9v1dyjmvmGzkJSWkjPsfQ?e=fSTVao). 140 | 141 | ## Citation 142 | If you find the code useful for your research, please consider citing our ICCV'23 paper: 143 | 144 | ```bib 145 | @InProceedings{Pandey_2023_ICCV, 146 | author = {Pandey, Kushagra and Mandt, Stephan}, 147 | title = {A Complete Recipe for Diffusion Generative Models}, 148 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 149 | month = {October}, 150 | year = {2023}, 151 | pages = {4261-4272} 152 | } 153 | ``` 154 | 155 | ## License 156 | See the LICENSE file.
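
## Example: Computing FID

The snippet below is a minimal sketch, not the exact pipeline from `scripts_psld/fid.sh`, of scoring a directory of generated samples against the CIFAR-10 training set with the `torch-fidelity` Python API. The sample directory path is a placeholder for `<save_path>/<path_prefix>/images`, which is where the sampling command above writes images.

```python
from torch_fidelity import calculate_metrics

# Placeholder path: the images directory produced by main/eval/sample.py,
# i.e. <evaluation.save_path>/<evaluation.path_prefix>/images
metrics = calculate_metrics(
    input1="/path/to/generated/samples/50/images",
    input2="cifar10-train",  # dataset registered by torch-fidelity
    cuda=True,
    fid=True,
)
print(metrics["frechet_inception_distance"])
```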
157 | -------------------------------------------------------------------------------- /assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/assets/main.png -------------------------------------------------------------------------------- /main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/__init__.py -------------------------------------------------------------------------------- /main/callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Sequence, Union 4 | 5 | import torch 6 | from pytorch_lightning import LightningModule, Trainer 7 | from pytorch_lightning.callbacks import BasePredictionWriter 8 | from pytorch_lightning.callbacks import Callback 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | from util import save_as_images, save_as_np 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class EMAWeightUpdate(Callback): 18 | """EMA weight update callback. 19 | The wrapped LightningModule should have: 20 | - ``self.score_fn`` (the online network) 21 | - ``self.ema_score_fn`` (the EMA target network) 22 | Updates the target network params after every training batch using an exponential moving average of the online network params, weighted by tau. 23 | The EMA weights are the ones used for sampling when ``evaluation.sample_from`` is set to ``target``. 24 | .. note:: ``tau`` is kept fixed throughout training. 25 | Example:: 26 | # model must have 2 attributes 27 | model = Model() 28 | model.score_fn = ... 29 | model.ema_score_fn = ... 30 | trainer = Trainer(callbacks=[EMAWeightUpdate()]) 31 | """ 32 | 33 | def __init__(self, tau: float = 0.9999): 34 | """ 35 | Args: 36 | tau: EMA decay rate 37 | """ 38 | super().__init__() 39 | self.tau = tau 40 | logger.info(f"Setup EMA callback with tau: {self.tau}") 41 | 42 | def on_train_batch_end( 43 | self, 44 | trainer: Trainer, 45 | pl_module: LightningModule, 46 | outputs: Sequence, 47 | batch: Sequence, 48 | batch_idx: int, 49 | ) -> None: 50 | # get networks 51 | online_net = pl_module.score_fn 52 | target_net = pl_module.ema_score_fn 53 | 54 | # update weights 55 | self.update_weights(online_net, target_net) 56 | 57 | def update_weights( 58 | self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor] 59 | ) -> None: 60 | # apply EMA weight update 61 | with torch.no_grad(): 62 | for targ, src in zip(target_net.parameters(), online_net.parameters()): 63 | targ.mul_(self.tau).add_(src, alpha=1 - self.tau) 64 | 65 | 66 | # TODO: Add Support for saving momentum images 67 | class SimpleImageWriter(BasePredictionWriter): 68 | """Pytorch Lightning Callback for writing a batch of images to disk.""" 69 | 70 | def __init__( 71 | self, 72 | output_dir, 73 | write_interval, 74 | sample_prefix="", 75 | path_prefix="", 76 | save_mode="image", 77 | is_norm=True, 78 | is_augmented=True, 79 | ): 80 | super().__init__(write_interval) 81 | self.output_dir = output_dir 82 | self.sample_prefix = sample_prefix 83 | self.path_prefix = path_prefix 84 | self.is_norm = is_norm 85 | self.is_augmented = is_augmented 86 | self.save_fn = save_as_images if save_mode == "image" else save_as_np 87 | 88 | def write_on_batch_end( 89 | self, 90 | trainer, 91 | pl_module, 92 | prediction, 93 | batch_indices, 94 | batch, 95 | batch_idx, 96 | dataloader_idx, 97 |
): 98 | rank = pl_module.global_rank 99 | 100 | # Write output images 101 | # NOTE: We need to use gpu rank during saving to prevent 102 | # processes from overwriting each other's images 103 | samples = prediction.cpu() 104 | 105 | # Ignore momentum states if the SDE is augmented 106 | if self.is_augmented: 107 | samples, _ = torch.chunk(samples, 2, dim=1) 108 | 109 | # Setup save dirs 110 | if self.path_prefix != "": 111 | base_save_path = os.path.join(self.output_dir, str(self.path_prefix)) 112 | else: 113 | base_save_path = self.output_dir 114 | img_save_path = os.path.join(base_save_path, "images") 115 | os.makedirs(img_save_path, exist_ok=True) 116 | 117 | # Save images 118 | self.save_fn( 119 | samples, 120 | file_name=os.path.join( 121 | img_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}" 122 | ), 123 | denorm=self.is_norm, 124 | ) 125 | 126 | 127 | class InpaintingImageWriter(BasePredictionWriter): 128 | """Pytorch Lightning Callback for writing a batch of images to disk. 129 | Specifically adapted for image inpainting. 130 | """ 131 | 132 | def __init__( 133 | self, 134 | output_dir, 135 | write_interval, 136 | eval_mode="sample", 137 | sample_prefix="", 138 | path_prefix="", 139 | save_mode="image", 140 | is_norm=True, 141 | is_augmented=True, 142 | save_batch=False, 143 | ): 144 | super().__init__(write_interval) 145 | assert eval_mode in ["sample", "recons"] 146 | self.output_dir = output_dir 147 | self.eval_mode = eval_mode 148 | self.sample_prefix = sample_prefix 149 | self.path_prefix = path_prefix 150 | self.is_norm = is_norm 151 | self.is_augmented = is_augmented 152 | self.save_fn = save_as_images if save_mode == "image" else save_as_np 153 | self.save_batch = save_batch 154 | 155 | def write_on_batch_end( 156 | self, 157 | trainer, 158 | pl_module, 159 | prediction, 160 | batch_indices, 161 | batch, 162 | batch_idx, 163 | dataloader_idx, 164 | ): 165 | rank = pl_module.global_rank 166 | 167 | # Write output images 168 | # NOTE: We need to use gpu rank during saving to prevent 169 | # processes from overwriting each other's images 170 | samples = prediction.cpu() 171 | 172 | if self.is_augmented: 173 | samples, _ = torch.chunk(samples, 2, dim=1) 174 | 175 | # Setup dirs 176 | if self.path_prefix != "": 177 | base_save_path = os.path.join(self.output_dir, str(self.path_prefix)) 178 | else: 179 | base_save_path = self.output_dir 180 | img_save_path = os.path.join(base_save_path, "images") 181 | os.makedirs(img_save_path, exist_ok=True) 182 | 183 | # Save images 184 | self.save_fn( 185 | samples, 186 | file_name=os.path.join( 187 | img_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}" 188 | ), 189 | denorm=self.is_norm, 190 | ) 191 | 192 | # Save batch (For inpainting) 193 | if self.save_batch: 194 | batch_save_path = os.path.join(base_save_path, "batch") 195 | corr_save_path = os.path.join(base_save_path, "corrupt") 196 | os.makedirs(batch_save_path, exist_ok=True) 197 | os.makedirs(corr_save_path, exist_ok=True) 198 | img, mask = batch 199 | img = img * 0.5 + 0.5  # denormalize from [-1, 1] to [0, 1] 200 | self.save_fn( 201 | img * mask, 202 | file_name=os.path.join( 203 | corr_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}" 204 | ), 205 | denorm=False, 206 | ) 207 | self.save_fn( 208 | img, 209 | file_name=os.path.join( 210 | batch_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}" 211 | ), 212 | denorm=False, 213 | ) 214 | -------------------------------------------------------------------------------- /main/configs/dataset/afhqv2/afhqv2128_psld.yaml:
-------------------------------------------------------------------------------- 1 | diffusion: 2 | data: 3 | root: ??? 4 | name: "afhqv2" 5 | image_size: 128 6 | hflip: True 7 | num_channels: 3 8 | norm: True 9 | return_target: False 10 | 11 | model: 12 | pl_module: 'sde_wrapper' 13 | score_fn: 14 | name: "ncsnpp" 15 | in_ch: 6 16 | out_ch: 6 17 | nonlinearity: "swish" 18 | nf : 128 19 | ch_mult: [1,2,2,2,3] 20 | num_res_blocks: 2 21 | attn_resolutions: [16,] 22 | dropout: 0.2 23 | resamp_with_conv: True 24 | noise_cond: True 25 | fir: False 26 | fir_kernel: [1,3,3,1] 27 | skip_rescale: True 28 | resblock_type: "biggan" 29 | progressive: "none" 30 | progressive_input: "none" 31 | progressive_combine: "sum" 32 | embedding_type: "positional" 33 | init_scale: 0.0 34 | fourier_scale: 16 35 | sde: 36 | name: "psld" 37 | beta_min: 8.0 38 | beta_max: 8.0 39 | nu: 4.01 40 | gamma: 0.01 41 | kappa: 0.04 42 | decomp_mode: "lower" 43 | numerical_eps: 1e-9 44 | n_timesteps: 1000 45 | is_augmented: True 46 | 47 | training: 48 | seed: 0 49 | continuous: True 50 | mode: 'hsm' 51 | loss: 52 | name: "psld_score_loss" 53 | l_type: "l2" 54 | reduce_mean: True 55 | weighting: "fid" 56 | optimizer: 57 | name: "Adam" 58 | lr: 1e-4 59 | beta_1: 0.9 60 | beta_2: 0.999 61 | weight_decay: 0 62 | eps: 1e-8 63 | warmup: 5000 64 | grad_clip: 1.0 65 | train_eps: 1e-5 66 | fp16: False 67 | use_ema: True 68 | ema_decay: 0.9999 69 | batch_size: 32 70 | epochs: 5000 71 | log_step: 1 72 | accelerator: "gpu" 73 | devices: [0] 74 | chkpt_interval: 1 75 | restore_path: "" 76 | results_dir: ??? 77 | workers: 1 78 | chkpt_prefix: "" 79 | 80 | evaluation: 81 | sampler: 82 | name: em_sde 83 | seed: 0 84 | chkpt_path: ??? 85 | save_path: ??? 86 | n_discrete_steps: 1000 87 | denoise: True 88 | eval_eps: 1e-3 89 | stride_type: uniform 90 | use_pflow: False 91 | sample_from: "target" 92 | accelerator: "gpu" 93 | devices: [0] 94 | n_samples: 50000 95 | workers: 2 96 | batch_size: 64 97 | save_mode: image 98 | sample_prefix: "gpu" 99 | path_prefix: "" 100 | 101 | clf: 102 | data: 103 | root: ??? 104 | name: afhqv2 105 | image_size: 128 106 | hflip: true 107 | num_channels: 3 108 | norm: true 109 | return_target: true 110 | 111 | model: 112 | pl_module: tclf_wrapper 113 | clf_fn: 114 | name: ncsnpp_clf 115 | in_ch: 6 116 | nonlinearity: swish 117 | nf: 128 118 | ch_mult: [1,2,2,2] 119 | num_res_blocks: 2 120 | attn_resolutions: [16] 121 | dropout: 0.1 122 | resamp_with_conv: true 123 | noise_cond: true 124 | fir: false 125 | fir_kernel: [1,3,3,1] 126 | skip_rescale: true 127 | resblock_type: biggan 128 | progressive: none 129 | progressive_input: none 130 | progressive_combine: sum 131 | embedding_type: positional 132 | init_scale: 0 133 | fourier_scale: 16 134 | n_cls: ??? 135 | 136 | training: 137 | seed: 0 138 | continuous: true 139 | loss: 140 | name: tce_loss 141 | l_type: l2 142 | reduce_mean: true 143 | optimizer: 144 | name: Adam 145 | lr: 0.0002 146 | beta_1: 0.9 147 | beta_2: 0.999 148 | weight_decay: 0 149 | eps: 1e-8 150 | warmup: 5000 151 | fp16: false 152 | batch_size: 32 153 | epochs: 500 154 | log_step: 1 155 | accelerator: gpu 156 | devices: [0] 157 | chkpt_interval: 1 158 | restore_path: "" 159 | results_dir: ??? 160 | workers: 1 161 | chkpt_prefix: "" 162 | 163 | evaluation: 164 | seed: 0 165 | chkpt_path: ??? 
166 | accelerator: gpu 167 | devices: [0] 168 | workers: 1 169 | batch_size: 64 170 | clf_temp: 1.0 171 | label_to_sample: 0 172 | -------------------------------------------------------------------------------- /main/configs/dataset/celeba64/celeba64_psld.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | data: 3 | root: ??? 4 | name: "celeba64" 5 | image_size: 64 6 | hflip: True 7 | num_channels: 3 8 | norm: True 9 | 10 | model: 11 | pl_module: 'sde_wrapper' 12 | score_fn: 13 | name: "ncsnpp" 14 | in_ch: 3 15 | out_ch: 3 16 | nonlinearity: "swish" 17 | nf : 128 18 | ch_mult: [1,2,2,2,4] 19 | num_res_blocks: 2 20 | attn_resolutions: [16,] 21 | dropout: 0.1 22 | resamp_with_conv: True 23 | noise_cond: True 24 | fir: False 25 | fir_kernel: [1,3,3,1] 26 | skip_rescale: True 27 | resblock_type: "biggan" 28 | progressive: "none" 29 | progressive_input: "none" 30 | progressive_combine: "sum" 31 | embedding_type: "positional" 32 | init_scale: 0.0 33 | fourier_scale: 16 34 | sde: 35 | name: "psld" 36 | beta_min: 8.0 37 | beta_max: 8.0 38 | nu: 1 39 | gamma: 2 40 | kappa: 0.04 41 | decomp_mode: "lower" 42 | numerical_eps: 1e-9 43 | n_timesteps: 1000 44 | is_augmented: True 45 | 46 | training: 47 | seed: 0 48 | continuous: True 49 | mode: 'hsm' 50 | loss: 51 | name: "psld_score_loss" 52 | l_type: "l2" 53 | reduce_mean: True 54 | weighting: "fid" 55 | optimizer: 56 | name: "Adam" 57 | lr: 2e-4 58 | beta_1: 0.9 59 | beta_2: 0.999 60 | weight_decay: 0 61 | eps: 1e-8 62 | warmup: 5000 63 | grad_clip: 1.0 64 | train_eps: 1e-5 65 | fp16: False 66 | use_ema: True 67 | ema_decay: 0.9999 68 | batch_size: 32 69 | epochs: 5000 70 | log_step: 1 71 | accelerator: "gpu" 72 | devices: [0] 73 | chkpt_interval: 1 74 | restore_path: "" 75 | results_dir: ??? 76 | workers: 1 77 | chkpt_prefix: "" 78 | 79 | evaluation: 80 | sampler: 81 | name: em_sde 82 | seed: 0 83 | chkpt_path: ??? 84 | save_path: ??? 85 | n_discrete_steps: 1000 86 | denoise: True 87 | eval_eps: 1e-3 88 | stride_type: uniform 89 | use_pflow: False 90 | sample_from: "target" 91 | accelerator: "gpu" 92 | devices: [0] 93 | n_samples: 50000 94 | workers: 2 95 | batch_size: 64 96 | save_mode: image 97 | sample_prefix: "gpu" 98 | path_prefix: "" 99 | -------------------------------------------------------------------------------- /main/configs/dataset/celeba64/celeba64_vpsde.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | data: 3 | root: ??? 
4 | name: "celeba64" 5 | image_size: 64 6 | hflip: True 7 | num_channels: 3 8 | norm: True 9 | 10 | model: 11 | score_fn: 12 | name: "ncsnpp" 13 | in_ch: 3 14 | out_ch: 3 15 | nonlinearity: "swish" 16 | nf : 128 17 | ch_mult: [1,2,2,2,4] 18 | num_res_blocks: 2 19 | attn_resolutions: [16,] 20 | dropout: 0.1 21 | resamp_with_conv: True 22 | noise_cond: True 23 | fir: False 24 | fir_kernel: [1,3,3,1] 25 | skip_rescale: True 26 | resblock_type: "biggan" 27 | progressive: "none" 28 | progressive_input: "none" 29 | progressive_combine: "sum" 30 | embedding_type: "positional" 31 | init_scale: 0.0 32 | fourier_scale: 16 33 | # scale_by_sigma: False (The Unet always predicts epsilon) 34 | # n_heads: 8 35 | sde: 36 | name: "vpsde" 37 | beta_min: 0.1 38 | beta_max: 20 39 | n_timesteps: 1000 40 | is_augmented: False 41 | 42 | training: 43 | seed: 0 44 | continuous: True 45 | loss: 46 | name: "score_loss" 47 | l_type: "l2" 48 | reduce_mean: True 49 | weighting: "fid" 50 | optimizer: 51 | name: "Adam" 52 | lr: 2e-4 53 | beta_1: 0.9 54 | beta_2: 0.999 55 | weight_decay: 0 56 | eps: 1e-8 57 | warmup: 5000 58 | grad_clip: 1.0 59 | train_eps: 1e-5 60 | fp16: False 61 | use_ema: True 62 | ema_decay: 0.9999 63 | batch_size: 32 64 | epochs: 5000 65 | log_step: 1 66 | accelerator: "gpu" 67 | devices: [0] 68 | chkpt_interval: 1 69 | restore_path: "" 70 | results_dir: ??? 71 | workers: 1 72 | chkpt_prefix: "" 73 | 74 | evaluation: 75 | # Sampler specific config goes here 76 | sampler: 77 | name: em_sde 78 | seed: 0 79 | chkpt_path: ??? 80 | save_path: ??? 81 | n_discrete_steps: 1000 82 | denoise: True 83 | eval_eps: 1e-3 84 | stride_type: uniform 85 | use_pflow: False 86 | sample_from: "target" 87 | accelerator: "gpu" 88 | devices: [0] 89 | n_samples: 50000 90 | workers: 2 91 | batch_size: 64 92 | save_mode: image 93 | sample_prefix: "gpu" 94 | path_prefix: "" 95 | # VAE config used for VAE training 96 | vae: 97 | data: 98 | root: ??? 99 | name: "cifar10" 100 | image_size: 32 101 | n_channels: 3 102 | hflip: False 103 | 104 | model: 105 | enc_block_config : "32x7,32d2,32t16,16x4,16d2,16t8,8x4,8d2,8t4,4x3,4d4,4t1,1x3" 106 | enc_channel_config: "32:64,16:128,8:256,4:256,1:512" 107 | dec_block_config: "1x1,1u4,1t4,4x2,4u2,4t8,8x3,8u2,8t16,16x7,16u2,16t32,32x15" 108 | dec_channel_config: "32:64,16:128,8:256,4:256,1:512" 109 | 110 | training: 111 | seed: 0 112 | fp16: False 113 | batch_size: 128 114 | epochs: 1000 115 | log_step: 1 116 | device: "gpu:0" 117 | chkpt_interval: 1 118 | optimizer: "Adam" 119 | lr: 1e-4 120 | restore_path: "" 121 | results_dir: ??? 122 | workers: 2 123 | chkpt_prefix: "" 124 | alpha: 1.0 125 | -------------------------------------------------------------------------------- /main/configs/dataset/cifar10/cifar10_psld.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | data: 3 | root: ???
4 | name: cifar10 5 | image_size: 32 6 | hflip: true 7 | num_channels: 3 8 | norm: true 9 | return_target: False 10 | 11 | model: 12 | pl_module: sde_wrapper 13 | score_fn: 14 | name: ncsnpp 15 | in_ch: 6 16 | out_ch: 6 17 | nonlinearity: swish 18 | nf: 128 19 | ch_mult: [1,2,2,2] 20 | num_res_blocks: 4 21 | attn_resolutions: [16] 22 | dropout: 0.1 23 | resamp_with_conv: true 24 | noise_cond: true 25 | fir: false 26 | fir_kernel: [1,3,3,1] 27 | skip_rescale: true 28 | resblock_type: biggan 29 | progressive: none 30 | progressive_input: none 31 | progressive_combine: sum 32 | embedding_type: positional 33 | init_scale: 0 34 | fourier_scale: 16 35 | sde: 36 | name: psld 37 | beta_min: 8 38 | beta_max: 8 39 | nu: 4.01 40 | gamma: 0.01 41 | kappa: 0.04 42 | decomp_mode: lower 43 | numerical_eps: 1e-9 44 | n_timesteps: 1000 45 | is_augmented: true 46 | 47 | training: 48 | seed: 0 49 | continuous: true 50 | mode: hsm 51 | loss: 52 | name: psld_score_loss 53 | l_type: l2 54 | reduce_mean: true 55 | weighting: fid 56 | optimizer: 57 | name: Adam 58 | lr: 0.0002 59 | beta_1: 0.9 60 | beta_2: 0.999 61 | weight_decay: 0 62 | eps: 1e-8 63 | warmup: 5000 64 | grad_clip: 1 65 | train_eps: 0.00001 66 | fp16: false 67 | use_ema: true 68 | ema_decay: 0.9999 69 | batch_size: 32 70 | epochs: 5000 71 | log_step: 1 72 | accelerator: gpu 73 | devices: [0] 74 | chkpt_interval: 1 75 | restore_path: "" 76 | results_dir: ??? 77 | workers: 1 78 | chkpt_prefix: "" 79 | 80 | evaluation: 81 | sampler: 82 | name: em_sde 83 | seed: 0 84 | chkpt_path: ??? 85 | save_path: ??? 86 | n_discrete_steps: 1000 87 | denoise: true 88 | eval_eps: 0.001 89 | stride_type: uniform 90 | use_pflow: false 91 | sample_from: target 92 | accelerator: gpu 93 | devices: [0] 94 | n_samples: 50000 95 | workers: 2 96 | batch_size: 64 97 | save_mode: image 98 | sample_prefix: gpu 99 | path_prefix: "" 100 | 101 | clf: 102 | data: 103 | root: ??? 104 | name: cifar10 105 | image_size: 32 106 | hflip: true 107 | num_channels: 3 108 | norm: true 109 | return_target: true 110 | 111 | model: 112 | pl_module: tclf_wrapper 113 | clf_fn: 114 | name: ncsnpp_clf 115 | in_ch: 6 116 | nonlinearity: swish 117 | nf: 128 118 | ch_mult: [1,2,2,2] 119 | num_res_blocks: 2 120 | attn_resolutions: [16] 121 | dropout: 0.1 122 | resamp_with_conv: true 123 | noise_cond: true 124 | fir: false 125 | fir_kernel: [1,3,3,1] 126 | skip_rescale: true 127 | resblock_type: biggan 128 | progressive: none 129 | progressive_input: none 130 | progressive_combine: sum 131 | embedding_type: positional 132 | init_scale: 0 133 | fourier_scale: 16 134 | n_cls: ??? 135 | 136 | training: 137 | seed: 0 138 | continuous: true 139 | loss: 140 | name: tce_loss 141 | l_type: l2 142 | reduce_mean: true 143 | optimizer: 144 | name: Adam 145 | lr: 0.0002 146 | beta_1: 0.9 147 | beta_2: 0.999 148 | weight_decay: 0 149 | eps: 1e-8 150 | warmup: 5000 151 | fp16: false 152 | batch_size: 32 153 | epochs: 500 154 | log_step: 1 155 | accelerator: gpu 156 | devices: [0] 157 | chkpt_interval: 1 158 | restore_path: "" 159 | results_dir: ??? 160 | workers: 1 161 | chkpt_prefix: "" 162 | 163 | evaluation: 164 | seed: 0 165 | chkpt_path: ??? 
166 | accelerator: gpu 167 | devices: [0] 168 | workers: 1 169 | batch_size: 64 170 | clf_temp: 1.0 171 | label_to_sample: 0 172 | -------------------------------------------------------------------------------- /main/configs/dataset/cifar10/cifar10_vpsde.yaml: -------------------------------------------------------------------------------- 1 | diffusion: 2 | data: 3 | root: ??? 4 | name: "cifar10" 5 | image_size: 32 6 | hflip: True 7 | num_channels: 3 8 | norm: True 9 | 10 | model: 11 | score_fn: 12 | name: "ncsnpp" 13 | in_ch: 3 14 | out_ch: 3 15 | nonlinearity: "swish" 16 | nf : 128 17 | ch_mult: [1,2,2,2] 18 | num_res_blocks: 4 19 | attn_resolutions: [16,] 20 | dropout: 0.1 21 | resamp_with_conv: True 22 | noise_cond: True 23 | fir: False 24 | fir_kernel: [1,3,3,1] 25 | skip_rescale: True 26 | resblock_type: "biggan" 27 | progressive: "none" 28 | progressive_input: "none" 29 | progressive_combine: "sum" 30 | embedding_type: "positional" 31 | init_scale: 0.0 32 | fourier_scale: 16 33 | # scale_by_sigma: False (The Unet always predicts epsilon) 34 | # n_heads: 8 35 | sde: 36 | name: "vpsde" 37 | beta_min: 0.1 38 | beta_max: 20 39 | n_timesteps: 1000 40 | is_augmented: False 41 | 42 | training: 43 | seed: 0 44 | continuous: True 45 | loss: 46 | name: "score_loss" 47 | l_type: "l2" 48 | reduce_mean: True 49 | weighting: "fid" 50 | optimizer: 51 | name: "Adam" 52 | lr: 2e-4 53 | beta_1: 0.9 54 | beta_2: 0.999 55 | weight_decay: 0 56 | eps: 1e-8 57 | warmup: 5000 58 | grad_clip: 1.0 59 | train_eps: 1e-5 60 | fp16: False 61 | use_ema: True 62 | ema_decay: 0.9999 63 | batch_size: 32 64 | epochs: 5000 65 | log_step: 1 66 | accelerator: "gpu" 67 | devices: [0] 68 | chkpt_interval: 1 69 | restore_path: "" 70 | results_dir: ??? 71 | workers: 1 72 | chkpt_prefix: "" 73 | 74 | evaluation: 75 | # Sampler specific config goes here 76 | sampler: 77 | name: em_sde 78 | seed: 0 79 | chkpt_path: ??? 80 | save_path: ??? 81 | n_discrete_steps: 1000 82 | denoise: True 83 | eval_eps: 1e-3 84 | stride_type: uniform 85 | use_pflow: False 86 | sample_from: "target" 87 | accelerator: "gpu" 88 | devices: [0] 89 | n_samples: 50000 90 | workers: 2 91 | batch_size: 64 92 | save_mode: image 93 | sample_prefix: "gpu" 94 | path_prefix: "" 95 | -------------------------------------------------------------------------------- /main/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .celeba import CelebADataset 2 | from .cifar10 import CIFAR10Dataset 3 | from .celebahq import CelebAHQDataset 4 | from .afhq import AFHQv2Dataset 5 | from .inpaint import InpaintDataset 6 | -------------------------------------------------------------------------------- /main/datasets/afhq.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from tqdm import tqdm 7 | from util import data_scaler, register_module 8 | 9 | 10 | @register_module(category="datasets", name="afhqv2") 11 | class AFHQv2Dataset(Dataset): 12 | """Implementation of the AFHQv2 dataset. 
13 | Downloaded from https://github.com/clovaai/stargan-v2 14 | """ 15 | def __init__( 16 | self, 17 | root, 18 | norm=True, 19 | subsample_size=None, 20 | return_target=False, 21 | transform=None, 22 | **kwargs, 23 | ): 24 | if not os.path.isdir(root): 25 | raise ValueError(f"The specified root: {root} does not exist") 26 | self.root = root 27 | self.transform = transform 28 | self.norm = norm 29 | self.subsample_size = subsample_size 30 | self.return_target = return_target 31 | 32 | self.images = [] 33 | self.labels = [] 34 | 35 | cat = kwargs.get("cat", []) 36 | is_train = kwargs.get("train", True) 37 | subfolder_list = ["dog", "cat", "wild"] if cat == [] else cat 38 | base_path = os.path.join(self.root, "train" if is_train else "test") 39 | for idx, subfolder in enumerate(subfolder_list): 40 | sub_path = os.path.join(base_path, subfolder) 41 | 42 | for img in tqdm(os.listdir(sub_path)): 43 | self.images.append(os.path.join(sub_path, img)) 44 | self.labels.append(idx) 45 | 46 | def __getitem__(self, idx): 47 | img_path = self.images[idx] 48 | img = Image.open(img_path) 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | # Normalize: scale images to [-1, 1] when norm=True, 53 | # or to [0, 1] otherwise 54 | img = data_scaler(img, norm=self.norm) 55 | 56 | # If return_target is set, also return the class label based on the animal category. 57 | # This is only needed for guidance-based generation. 58 | if self.return_target: 59 | return torch.from_numpy(img).permute(2, 0, 1).float(), self.labels[idx] 60 | return torch.from_numpy(img).permute(2, 0, 1).float() 61 | 62 | def __len__(self): 63 | return len(self.images) if self.subsample_size is None else self.subsample_size 64 | -------------------------------------------------------------------------------- /main/datasets/celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from tqdm import tqdm 7 | from util import data_scaler, register_module 8 | 9 | 10 | @register_module(category="datasets", name="celeba64") 11 | class CelebADataset(Dataset): 12 | """Implementation of the CelebA dataset. 13 | Downloaded from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html 14 | """ 15 | def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwargs): 16 | if not os.path.isdir(root): 17 | raise ValueError(f"The specified root: {root} does not exist") 18 | self.root = root 19 | self.transform = transform 20 | self.norm = norm 21 | 22 | self.images = [] 23 | 24 | for img in tqdm(os.listdir(root)): 25 | self.images.append(os.path.join(self.root, img)) 26 | 27 | # Subsample the dataset (if enabled) 28 | if subsample_size is not None: 29 | self.images = self.images[:subsample_size] 30 | 31 | def __getitem__(self, idx): 32 | img_path = self.images[idx] 33 | img = Image.open(img_path) 34 | 35 | # Apply transform 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | # Normalize 40 | img = data_scaler(img, norm=self.norm) 41 | 42 | # TODO: Add the functionality to return labels to enable 43 | # guidance-based generation.
44 | return torch.from_numpy(img).permute(2, 0, 1).float() 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | -------------------------------------------------------------------------------- /main/datasets/celebahq.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from tqdm import tqdm 7 | from util import data_scaler, register_module 8 | 9 | 10 | # NOTE: We don't use this dataset in the paper 11 | @register_module(category="datasets", name='celebahq256') 12 | class CelebAHQDataset(Dataset): 13 | def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwargs): 14 | if not os.path.isdir(root): 15 | raise ValueError(f"The specified root: {root} does not exist") 16 | self.root = root 17 | self.transform = transform 18 | self.norm = norm 19 | 20 | self.images = [] 21 | 22 | modes = ["train", "val"] 23 | subfolders = ["male", "female"] 24 | 25 | for mode in modes: 26 | for folder in subfolders: 27 | img_path = os.path.join(self.root, mode, folder) 28 | for img in tqdm(sorted(os.listdir(img_path))): 29 | self.images.append(os.path.join(img_path, img)) 30 | 31 | # Subsample the dataset (if enabled) 32 | if subsample_size is not None: 33 | self.images = self.images[:subsample_size] 34 | 35 | def __getitem__(self, idx): 36 | img_path = self.images[idx] 37 | img = Image.open(img_path) 38 | if self.transform is not None: 39 | img = self.transform(img) 40 | 41 | # Normalize 42 | img = data_scaler(img, norm=self.norm) 43 | 44 | return torch.from_numpy(img).permute(2, 0, 1).float() 45 | 46 | def __len__(self): 47 | return len(self.images) 48 | -------------------------------------------------------------------------------- /main/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets import CIFAR10 6 | from util import data_scaler, register_module 7 | 8 | 9 | @register_module(category="datasets", name="cifar10") 10 | class CIFAR10Dataset(Dataset): 11 | def __init__( 12 | self, 13 | root, 14 | norm=True, 15 | transform=None, 16 | subsample_size=None, 17 | return_target=False, 18 | **kwargs, 19 | ): 20 | if not os.path.isdir(root): 21 | raise ValueError(f"The specified root: {root} does not exist") 22 | 23 | if subsample_size is not None: 24 | assert isinstance(subsample_size, int) 25 | 26 | self.root = root 27 | self.norm = norm 28 | self.transform = transform 29 | self.dataset = CIFAR10(self.root, train=True, download=True, transform=None) 30 | self.subsample_size = subsample_size 31 | self.return_target = return_target 32 | 33 | def __getitem__(self, idx): 34 | img, target = self.dataset[idx] 35 | 36 | # Apply transform 37 | img_ = self.transform(img) if self.transform is not None else img 38 | 39 | # Normalize 40 | img = data_scaler(img_, norm=self.norm) 41 | 42 | # Return (with targets if needed for guidance-based generation) 43 | img_tensor = torch.tensor(img).permute(2, 0, 1).float() 44 | if self.return_target: 45 | return img_tensor, target 46 | return img_tensor 47 | 48 | def __len__(self): 49 | return len(self.dataset) if self.subsample_size is None else self.subsample_size 50 | -------------------------------------------------------------------------------- /main/datasets/inpaint.py: -------------------------------------------------------------------------------- 1 | import
torch 2 | from torchvision.transforms.functional import InterpolationMode 3 | from torch.utils.data import Dataset 4 | 5 | from util import register_module 6 | from torchvision.datasets import MNIST 7 | import torchvision.transforms as T 8 | 9 | 10 | @register_module(category="datasets", name="inpaint") 11 | class InpaintDataset(Dataset): 12 | """Dataset for generating corrupted images. The images from the base dataset are masked with 13 | MNIST to generate images with missing pixels. 14 | """ 15 | def __init__(self, config, dataset): 16 | # Parent dataset (must return only images) 17 | self.config = config 18 | self.dataset = dataset 19 | 20 | # Used for creating masks 21 | t = T.Compose( 22 | [ 23 | T.Resize( 24 | (config.data.image_size, config.data.image_size), 25 | InterpolationMode.NEAREST, 26 | ), 27 | T.ToTensor(), 28 | ] 29 | ) 30 | self.mnist = MNIST(config.data.root, train=True, download=True, transform=t) 31 | 32 | def __getitem__(self, idx): 33 | img = self.dataset[idx] 34 | 35 | # Generate the mask with the mnist digit 36 | digit, _ = self.mnist[idx] 37 | digit = torch.cat([digit] * 3, dim=0) 38 | mask = (digit > 0).type(torch.long) 39 | mask = 1 - mask 40 | assert mask.shape == img.shape 41 | return img, mask 42 | 43 | def __len__(self): 44 | return min(self.config.evaluation.n_samples, len(self.dataset)) 45 | -------------------------------------------------------------------------------- /main/datasets/latent.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from util import register_module 3 | 4 | 5 | @register_module(category="datasets", name="latent") 6 | class SDELatentDataset(Dataset): 7 | """A dataset for generating samples from the equilibrium distribution 8 | of the forward SDE (useful during sampling) 9 | """ 10 | def __init__(self, sde, config): 11 | self.sde = sde 12 | self.num_samples = config.evaluation.n_samples 13 | self.shape = [ 14 | self.num_samples, 15 | config.data.num_channels, 16 | config.data.image_size, 17 | config.data.image_size, 18 | ] 19 | self.samples = self.sde.prior_sampling(self.shape) 20 | 21 | def get_batch(self, shape): 22 | return self.sde.prior_sampling(shape) 23 | 24 | def __getitem__(self, idx): 25 | return self.samples[idx] 26 | 27 | def __len__(self): 28 | return self.num_samples 29 | -------------------------------------------------------------------------------- /main/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/eval/__init__.py -------------------------------------------------------------------------------- /main/eval/class_cond_sample.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | from copy import deepcopy 5 | 6 | # Add project directory to sys.path 7 | p = os.path.join(os.path.abspath("."), "main") 8 | sys.path.insert(1, p) 9 | 10 | import hydra 11 | import pytorch_lightning as pl 12 | from callbacks import SimpleImageWriter 13 | from datasets.latent import SDELatentDataset 14 | from models.clf_wrapper import TClfWrapper 15 | from models.wrapper import SDEWrapper 16 | from omegaconf import OmegaConf 17 | from pytorch_lightning.utilities.seed import seed_everything 18 | from torch.utils.data import DataLoader 19 | from util import get_module, import_modules_into_registry 20 | 21 | logger = 
logging.getLogger(__name__) 22 | 23 | 24 | # Import all modules into registry 25 | import_modules_into_registry() 26 | 27 | 28 | @hydra.main(config_path=os.path.join(p, "configs")) 29 | def cc_sample(config): 30 | """Evaluation script for Class conditional sampling with pre-trained score models 31 | using classifier guidance. 32 | """ 33 | config = config.dataset 34 | config_sde = config.diffusion 35 | config_clf = config.clf 36 | logger.info(OmegaConf.to_yaml(config)) 37 | 38 | # Set seed 39 | seed_everything(config_clf.evaluation.seed) 40 | 41 | # Setup Score SDE 42 | sde_cls = get_module(category="sde", name=config_sde.model.sde.name) 43 | sde = sde_cls(config_sde) 44 | logger.info(f"Using SDE: {sde_cls}") 45 | 46 | # Setup score predictor 47 | score_fn_cls = get_module(category="score_fn", name=config_sde.model.score_fn.name) 48 | score_fn = score_fn_cls(config_sde) 49 | logger.info(f"Using Score fn: {score_fn_cls}") 50 | 51 | ema_score_fn = deepcopy(score_fn) 52 | for p in ema_score_fn.parameters(): 53 | p.requires_grad = False 54 | 55 | wrapper = SDEWrapper.load_from_checkpoint( 56 | config_sde.evaluation.chkpt_path, 57 | config=config_sde, 58 | sde=sde, 59 | score_fn=score_fn, 60 | ema_score_fn=ema_score_fn, 61 | sampler_cls=None, 62 | ) 63 | 64 | score_fn = ( 65 | wrapper.ema_score_fn 66 | if config_sde.evaluation.sample_from == "target" 67 | else wrapper.score_fn 68 | ) 69 | score_fn.eval() 70 | 71 | # Setup sampler 72 | sampler_cls = get_module( 73 | category="samplers", name=config_sde.evaluation.sampler.name 74 | ) 75 | logger.info( 76 | f"Using Sampler: {sampler_cls}. Make sure the sampler supports Class-conditional sampling" 77 | ) 78 | 79 | # Setup dataset 80 | dataset = SDELatentDataset(sde, config_sde) 81 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}") 82 | 83 | # Setup classifier (for guidance) 84 | clf_fn_cls = get_module(category="clf_fn", name=config_clf.model.clf_fn.name) 85 | clf_fn = clf_fn_cls(config_clf) 86 | logger.info(f"Using Classifier fn: {clf_fn_cls}") 87 | 88 | wrapper = TClfWrapper.load_from_checkpoint( 89 | config_clf.evaluation.chkpt_path, 90 | config=config, 91 | sde=sde, 92 | clf_fn=clf_fn, 93 | score_fn=score_fn, 94 | sampler_cls=sampler_cls, 95 | strict=False, 96 | ) 97 | wrapper.eval() 98 | 99 | # Setup devices 100 | test_kwargs = {} 101 | loader_kws = {} 102 | device_type = config_sde.evaluation.accelerator 103 | test_kwargs["accelerator"] = device_type 104 | if device_type == "gpu": 105 | test_kwargs["devices"] = config_sde.evaluation.devices 106 | # # Disable find_unused_parameters when using DDP training for performance reasons 107 | # loader_kws["persistent_workers"] = True 108 | elif device_type == "tpu": 109 | test_kwargs["tpu_cores"] = 8 110 | 111 | # Predict loader 112 | val_loader = DataLoader( 113 | dataset, 114 | batch_size=config_sde.evaluation.batch_size, 115 | drop_last=False, 116 | pin_memory=True, 117 | shuffle=False, 118 | num_workers=config_sde.evaluation.workers, 119 | **loader_kws, 120 | ) 121 | 122 | # Setup Image writer callback trainer 123 | write_callback = SimpleImageWriter( 124 | config_sde.evaluation.save_path, 125 | write_interval="batch", 126 | sample_prefix=config_sde.evaluation.sample_prefix, 127 | path_prefix=config_sde.evaluation.path_prefix, 128 | save_mode=config_sde.evaluation.save_mode, 129 | is_augmented=config_sde.model.sde.is_augmented, 130 | ) 131 | 132 | test_kwargs["callbacks"] = [write_callback] 133 | test_kwargs["default_root_dir"] = config_sde.evaluation.save_path 134 | 135 | # 
Setup Pl module and predict 136 | sampler = pl.Trainer(**test_kwargs) 137 | sampler.predict(wrapper, val_loader) 138 | 139 | 140 | if __name__ == "__main__": 141 | cc_sample() 142 | -------------------------------------------------------------------------------- /main/eval/inpaint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | # Add project directory to sys.path 6 | p = os.path.join(os.path.abspath("."), "main") 7 | sys.path.insert(1, p) 8 | 9 | from copy import deepcopy 10 | 11 | import hydra 12 | import pytorch_lightning as pl 13 | from callbacks import InpaintingImageWriter 14 | from datasets import InpaintDataset 15 | from models.wrapper import SDEWrapper 16 | from omegaconf import OmegaConf 17 | from pytorch_lightning.utilities.seed import seed_everything 18 | from torch.utils.data import DataLoader 19 | from util import get_dataset, get_module, import_modules_into_registry 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | # Import all modules into registry 25 | import_modules_into_registry() 26 | 27 | 28 | @hydra.main(config_path=os.path.join(p, "configs")) 29 | def inpaint(config): 30 | """Evaluation script for inpainting with pre-trained score models using guidance.""" 31 | config = config.dataset.diffusion 32 | logger.info(OmegaConf.to_yaml(config)) 33 | 34 | # Set seed 35 | seed_everything(config.evaluation.seed) 36 | 37 | # Setup Score Predictor 38 | score_fn_cls = get_module(category="score_fn", name=config.model.score_fn.name) 39 | score_fn = score_fn_cls(config) 40 | logger.info(f"Using Score fn: {score_fn_cls}") 41 | 42 | ema_score_fn = deepcopy(score_fn) 43 | for p in ema_score_fn.parameters(): 44 | p.requires_grad = False 45 | 46 | score_fn.eval() 47 | ema_score_fn.eval() 48 | 49 | # Setup Score SDE 50 | sde_cls = get_module(category="sde", name=config.model.sde.name) 51 | sde = sde_cls(config) 52 | logger.info(f"Using SDE: {sde_cls}") 53 | 54 | # Setup sampler 55 | sampler_cls = get_module(category="samplers", name=config.evaluation.sampler.name) 56 | logger.info(f"Using Sampler: {sampler_cls}") 57 | 58 | # Setup dataset 59 | base_dataset = get_dataset(config) 60 | dataset = InpaintDataset(config, base_dataset) 61 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}") 62 | 63 | wrapper = SDEWrapper.load_from_checkpoint( 64 | config.evaluation.chkpt_path, 65 | config=config, 66 | sde=sde, 67 | score_fn=score_fn, 68 | ema_score_fn=ema_score_fn, 69 | sampler_cls=sampler_cls, 70 | ) 71 | 72 | # Setup devices 73 | test_kwargs = {} 74 | loader_kws = {} 75 | device_type = config.evaluation.accelerator 76 | test_kwargs["accelerator"] = device_type 77 | if device_type == "gpu": 78 | test_kwargs["devices"] = config.evaluation.devices 79 | # # Disable find_unused_parameters when using DDP training for performance reasons 80 | # loader_kws["persistent_workers"] = True 81 | elif device_type == "tpu": 82 | test_kwargs["tpu_cores"] = 8 83 | 84 | # Predict loader 85 | val_loader = DataLoader( 86 | dataset, 87 | batch_size=config.evaluation.batch_size, 88 | drop_last=False, 89 | pin_memory=True, 90 | shuffle=False, 91 | num_workers=config.evaluation.workers, 92 | **loader_kws, 93 | ) 94 | 95 | # Setup Image writer callback trainer 96 | write_callback = InpaintingImageWriter( 97 | config.evaluation.save_path, 98 | write_interval="batch", 99 | sample_prefix=config.evaluation.sample_prefix, 100 | path_prefix=config.evaluation.path_prefix, 101 | 
save_mode=config.evaluation.save_mode, 102 | is_augmented=config.model.sde.is_augmented, 103 | save_batch=True, 104 | ) 105 | 106 | test_kwargs["callbacks"] = [write_callback] 107 | test_kwargs["default_root_dir"] = config.evaluation.save_path 108 | 109 | # Setup Pl module and predict 110 | sampler = pl.Trainer(**test_kwargs) 111 | sampler.predict(wrapper, val_loader) 112 | 113 | 114 | if __name__ == "__main__": 115 | inpaint() 116 | -------------------------------------------------------------------------------- /main/eval/sample.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | # Add project directory to sys.path 6 | p = os.path.join(os.path.abspath("."), "main") 7 | sys.path.insert(1, p) 8 | 9 | from copy import deepcopy 10 | 11 | import hydra 12 | import pytorch_lightning as pl 13 | from callbacks import SimpleImageWriter 14 | from datasets.latent import SDELatentDataset 15 | from models.wrapper import SDEWrapper 16 | from omegaconf import OmegaConf 17 | from pytorch_lightning.utilities.seed import seed_everything 18 | from torch.utils.data import DataLoader 19 | from util import get_module, import_modules_into_registry 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | # Import all modules into registry 25 | import_modules_into_registry() 26 | 27 | 28 | @hydra.main(config_path=os.path.join(p, "configs")) 29 | def sample(config): 30 | """Evaluation script for Unconditional sampling using pre-trained score models""" 31 | config = config.dataset.diffusion 32 | logger.info(OmegaConf.to_yaml(config)) 33 | 34 | # Set seed 35 | seed_everything(config.evaluation.seed) 36 | 37 | # Setup score predictor 38 | score_fn_cls = get_module(category="score_fn", name=config.model.score_fn.name) 39 | score_fn = score_fn_cls(config) 40 | logger.info(f"Using Score fn: {score_fn_cls}") 41 | 42 | ema_score_fn = deepcopy(score_fn) 43 | for p in ema_score_fn.parameters(): 44 | p.requires_grad = False 45 | 46 | score_fn.eval() 47 | ema_score_fn.eval() 48 | 49 | # Setup Score SDE 50 | sde_cls = get_module(category="sde", name=config.model.sde.name) 51 | sde = sde_cls(config) 52 | logger.info(f"Using SDE: {sde_cls}") 53 | 54 | # Setup sampler 55 | sampler_cls = get_module(category="samplers", name=config.evaluation.sampler.name) 56 | logger.info(f"Using Sampler: {sampler_cls}") 57 | 58 | # Setup dataset 59 | dataset = SDELatentDataset(sde, config) 60 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}") 61 | 62 | wrapper = SDEWrapper.load_from_checkpoint( 63 | config.evaluation.chkpt_path, 64 | config=config, 65 | sde=sde, 66 | score_fn=score_fn, 67 | ema_score_fn=ema_score_fn, 68 | sampler_cls=sampler_cls, 69 | ) 70 | 71 | # Setup devices 72 | test_kwargs = {} 73 | loader_kws = {} 74 | device_type = config.evaluation.accelerator 75 | test_kwargs["accelerator"] = device_type 76 | if device_type == "gpu": 77 | test_kwargs["devices"] = config.evaluation.devices 78 | # # Disable find_unused_parameters when using DDP training for performance reasons 79 | # loader_kws["persistent_workers"] = True 80 | elif device_type == "tpu": 81 | test_kwargs["tpu_cores"] = 8 82 | 83 | # Predict loader 84 | val_loader = DataLoader( 85 | dataset, 86 | batch_size=config.evaluation.batch_size, 87 | drop_last=False, 88 | pin_memory=True, 89 | shuffle=False, 90 | num_workers=config.evaluation.workers, 91 | **loader_kws, 92 | ) 93 | 94 | # Setup Image writer callback trainer 95 | write_callback = SimpleImageWriter( 96 | 
config.evaluation.save_path, 97 | write_interval="batch", 98 | sample_prefix=config.evaluation.sample_prefix, 99 | path_prefix=config.evaluation.path_prefix, 100 | save_mode=config.evaluation.save_mode, 101 | is_augmented=config.model.sde.is_augmented, 102 | ) 103 | 104 | test_kwargs["callbacks"] = [write_callback] 105 | test_kwargs["default_root_dir"] = config.evaluation.save_path 106 | 107 | # Setup Pl module and predict 108 | sampler = pl.Trainer(**test_kwargs) 109 | sampler.predict(wrapper, val_loader) 110 | 111 | 112 | if __name__ == "__main__": 113 | sample() 114 | -------------------------------------------------------------------------------- /main/losses.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from util import get_module, register_module, reshape 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def compute_top_k(logits, labels, k, reduction="mean"): 13 | _, top_ks = torch.topk(logits, k, dim=-1) 14 | if reduction == "mean": 15 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 16 | elif reduction == "none": 17 | return (top_ks == labels[:, None]).float().sum(dim=-1) 18 | 19 | 20 | @register_module(category="losses", name="score_loss") 21 | class ScoreLoss(nn.Module): 22 | """Loss function for training non-augmented Score based models (like VP-SDE)""" 23 | 24 | def __init__(self, config, sde): 25 | super().__init__() 26 | assert config.training.loss.weighting in ["nll", "fid"] 27 | self.sde = sde 28 | self.l_type = config.training.loss.l_type 29 | self.weighting = config.training.loss.weighting 30 | 31 | logger.info(f"Initialized Loss fn with weighting: {self.weighting}") 32 | 33 | if self.weighting == "nll" and self.l_type != "l2": 34 | # Use only MSE loss when maximizing for nll 35 | raise ValueError("l_type can only be `l2` when using nll weighting") 36 | 37 | self.reduce_strategy = "mean" if config.training.loss.reduce_mean else "sum" 38 | criterion_type = nn.MSELoss if self.l_type == "l2" else nn.L1Loss 39 | self.criterion = criterion_type(reduction=self.reduce_strategy) 40 | 41 | def forward(self, x_0, t, score_fn, eps=None): 42 | if eps is None: 43 | eps = torch.randn_like(x_0, device=x_0.device) 44 | 45 | assert eps.shape == x_0.shape 46 | 47 | # Predict epsilon 48 | x_t = self.sde.perturb_data(x_0, t, noise=eps) 49 | eps_pred = score_fn(x_t.type(torch.float32), t.type(torch.float32)) 50 | 51 | # Use eps-prediction when optimizing for FID 52 | loss = self.criterion(eps, eps_pred) 53 | 54 | # Use g(t)**2 weighting when optimizing for NLL 55 | if self.weighting == "nll": 56 | gt_2 = reshape(self.sde.likelihood_weighting(t), x_0) 57 | gt_score = self.sde.get_score(eps, t) 58 | pred_score = self.sde.get_score(eps_pred, t) 59 | # loss = self.criterion(pred_score, gt_score) * gt_2 60 | loss = (pred_score - gt_score) ** 2 * gt_2 61 | loss = ( 62 | torch.mean(loss) if self.reduce_strategy == "mean" else torch.sum(loss) 63 | ) 64 | 65 | return loss 66 | 67 | 68 | @register_module(category="losses", name="psld_score_loss") 69 | class PSLDScoreLoss(nn.Module): 70 | """Loss function for training PSLD.""" 71 | 72 | def __init__(self, config, sde): 73 | super().__init__() 74 | # TODO: Add support for likelihood training 75 | assert config.training.loss.weighting in ["fid"] 76 | assert config.training.mode in ["hsm", "dsm"] 77 | assert isinstance(sde, get_module("sde", "psld")) 78 | 
self.sde = sde 79 | self.l_type = config.training.loss.l_type 80 | self.weighting = config.training.loss.weighting 81 | self.mode = config.training.mode 82 | self.decomp_mode = config.model.sde.decomp_mode 83 | 84 | logger.info( 85 | f"Initialized Loss fn with weighting: {self.weighting} mode: {self.mode}" 86 | ) 87 | 88 | if self.weighting == "nll" and self.l_type != "l2": 89 | # Use only MSE loss when maximizing for nll 90 | raise ValueError("l_type can only be `l2` when using nll weighting") 91 | 92 | self.reduce_strategy = "mean" if config.training.loss.reduce_mean else "sum" 93 | 94 | def forward(self, x_0, t, score_fn, eps=None): 95 | # Sample momentum (DSM) 96 | m_0 = np.sqrt(self.sde.mm_0) * torch.randn_like(x_0) 97 | mm_0 = 0.0 98 | 99 | # Update momentum (if training mode is HSM) 100 | if self.mode == "hsm": 101 | m_0 = torch.zeros_like(x_0) 102 | mm_0 = self.sde.mm_0 103 | 104 | z_0 = torch.cat([x_0, m_0], dim=1) 105 | xx_0 = 0 106 | 107 | # Sample random noise 108 | if eps is None: 109 | eps = torch.randn_like(z_0) 110 | assert eps.shape == z_0.shape 111 | 112 | # Predict epsilon 113 | z_t, _, _ = self.sde.perturb_data(x_0, m_0, xx_0, mm_0, t, eps=eps) 114 | z_t = z_t.type(torch.float32) 115 | eps_pred = score_fn(z_t, t.type(torch.float32)) 116 | 117 | # Use eps-prediction when optimizing for FID 118 | eps_x, eps_m = torch.chunk(eps, 2, dim=1) 119 | if self.sde.mode == "score_m" and self.decomp_mode == "lower": 120 | assert eps_pred.shape == eps_m.shape 121 | loss = (eps_m - eps_pred) ** 2 122 | elif self.sde.mode == "score_x" and self.decomp_mode == "upper": 123 | assert eps_pred.shape == eps_x.shape 124 | loss = (eps_x - eps_pred) ** 2 125 | else: 126 | assert eps_pred.shape == eps.shape 127 | loss = (eps - eps_pred) ** 2 128 | 129 | loss = torch.mean(loss) if self.reduce_strategy == "mean" else torch.sum(loss) 130 | return loss 131 | 132 | 133 | @register_module(category="losses", name="tce_loss") 134 | class PSLDTimeCELoss(nn.Module): 135 | """Loss function for training noise conditioned classifier for guidance""" 136 | 137 | def __init__(self, config, sde): 138 | super().__init__() 139 | assert config.diffusion.training.mode in ["hsm", "dsm"] 140 | assert isinstance(sde, get_module("sde", "psld")) 141 | self.sde = sde 142 | self.l_type = config.clf.training.loss.l_type 143 | self.mode = config.diffusion.training.mode 144 | 145 | self.reduce_strategy = ( 146 | "mean" if config.diffusion.training.loss.reduce_mean else "sum" 147 | ) 148 | self.criterion = nn.CrossEntropyLoss(reduction=self.reduce_strategy) 149 | 150 | def forward(self, x_0, y, t, clf_fn): 151 | # Sample momentum (DSM) 152 | m_0 = np.sqrt(self.sde.mm_0) * torch.randn_like(x_0) 153 | mm_0 = 0.0 154 | 155 | # Update momentum (if training mode is HSM) 156 | if self.mode == "hsm": 157 | m_0 = torch.zeros_like(x_0) 158 | mm_0 = self.sde.mm_0 159 | 160 | u_0 = torch.cat([x_0, m_0], dim=1) 161 | xx_0 = 0 162 | 163 | # Sample random noise 164 | eps = torch.randn_like(u_0) 165 | 166 | # Perturb Data 167 | u_t, _, _ = self.sde.perturb_data(x_0, m_0, xx_0, mm_0, t, eps=eps) 168 | 169 | # Predict label 170 | y_pred = clf_fn(u_t.type(torch.float32), t) 171 | 172 | # CE loss 173 | loss = self.criterion(y_pred, y) 174 | 175 | # Top-k accuracy (for debugging) 176 | acc = compute_top_k(y_pred, y, 1) 177 | return loss, acc 178 | -------------------------------------------------------------------------------- /main/models/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.clf_wrapper import TClfWrapper 2 | from .score_fn.song_sde.ncsnpp import NCSNpp 3 | from .score_fn.song_sde.ncsnpp_clf import NCSNppClassifier 4 | from .sde.psld import PSLD 5 | from .sde.vpsde import VPSDE 6 | from .wrapper import SDEWrapper 7 | -------------------------------------------------------------------------------- /main/models/clf_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from pytorch_lightning.utilities.seed import seed_everything 6 | from util import get_module, register_module 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @register_module(category="pl_modules", name="tclf_wrapper") 12 | class TClfWrapper(pl.LightningModule): 13 | """This PL module can do the following tasks: 14 | - train: Train a classifier to predict labels given a noisy state z_t 15 | - predict: Generate class-conditional samples using classifier guidance 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config, 21 | sde, 22 | clf_fn, 23 | score_fn=None, 24 | criterion=None, 25 | sampler_cls=None, 26 | corrector_fn=None, 27 | ): 28 | super().__init__() 29 | self.config = config 30 | self.sde = sde 31 | self.clf_fn = clf_fn 32 | 33 | # Training 34 | self.criterion = criterion 35 | self.train_eps = self.config.diffusion.training.train_eps 36 | 37 | # Evaluation 38 | self.score_fn = score_fn 39 | self.sampler = None 40 | 41 | if sampler_cls is not None: 42 | self.sampler = sampler_cls( 43 | self.config, 44 | self.sde, 45 | self.score_fn, 46 | self.clf_fn, 47 | corrector_fn=corrector_fn, 48 | ) 49 | self.eval_eps = self.config.diffusion.evaluation.eval_eps 50 | self.denoise = self.config.diffusion.evaluation.denoise 51 | n_discrete_steps = self.config.diffusion.evaluation.n_discrete_steps 52 | self.n_discrete_steps = ( 53 | n_discrete_steps - 1 if self.denoise else n_discrete_steps 54 | ) 55 | self.val_eps = self.config.diffusion.evaluation.eval_eps 56 | self.stride_type = self.config.diffusion.evaluation.stride_type 57 | 58 | def forward(self): 59 | pass 60 | 61 | def training_step(self, batch, batch_idx): 62 | # Images and labels 63 | x_0, y = batch 64 | 65 | # Sample timepoints (between [train_eps, 1]) 66 | t_ = torch.rand(x_0.shape[0], device=x_0.device, dtype=torch.float64) 67 | t = t_ * (self.sde.T - self.train_eps) + self.train_eps 68 | assert t.shape[0] == x_0.shape[0] 69 | 70 | # Compute loss (Lightning handles the backward pass) 71 | loss, acc = self.criterion(x_0, y, t, self.clf_fn) 72 | 73 | self.log("loss", loss, prog_bar=True) 74 | self.log("Top1-Acc", acc, prog_bar=True) 75 | return loss 76 | 77 | def on_predict_start(self): 78 | seed = self.config.clf.evaluation.seed 79 | 80 | # Offset the seed by global rank for prediction: a common seed would 81 | # generate identical samples on every GPU, which skews sample-diversity 82 | # metrics such as FID.
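        # (With workers=True, Lightning also derives distinct seeds for each dataloader worker.)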
83 | seed_everything(seed + self.global_rank, workers=True) 84 | 85 | def predict_step(self, batch, batch_idx, dataloader_idx=None): 86 | t_final = self.sde.T - self.eval_eps 87 | ts = torch.linspace( 88 | 0, 89 | t_final, 90 | self.n_discrete_steps + 1, 91 | device=batch.device, 92 | dtype=torch.float64, 93 | ) 94 | 95 | if self.stride_type == "uniform": 96 | pass 97 | elif self.stride_type == "quadratic": 98 | ts = t_final * torch.flip(1 - (ts / t_final) ** 2.0, dims=[0]) 99 | 100 | return self.sampler.sample( 101 | batch, ts, self.n_discrete_steps, denoise=self.denoise, eps=self.eval_eps 102 | ) 103 | 104 | def on_predict_end(self): 105 | if isinstance(self.sampler, get_module("samplers", "bb_ode")): 106 | print(self.sampler.mean_nfe) 107 | 108 | def configure_optimizers(self): 109 | opt_config = self.config.clf.training.optimizer 110 | opt_name = opt_config.name 111 | if opt_name == "Adam": 112 | optimizer = torch.optim.Adam( 113 | self.clf_fn.parameters(), 114 | lr=opt_config.lr, 115 | betas=(opt_config.beta_1, opt_config.beta_2), 116 | eps=opt_config.eps, 117 | weight_decay=opt_config.weight_decay, 118 | ) 119 | else: 120 | raise NotImplementedError(f"Optimizer {opt_name} not supported yet!") 121 | 122 | # Define the LR scheduler (As in Ho et al.) 123 | if opt_config.warmup == 0: 124 | lr_lambda = lambda step: 1.0 125 | else: 126 | lr_lambda = lambda step: min(step / opt_config.warmup, 1.0) 127 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 128 | return { 129 | "optimizer": optimizer, 130 | "lr_scheduler": { 131 | "scheduler": scheduler, 132 | "interval": "step", 133 | "strict": False, 134 | }, 135 | } 136 | -------------------------------------------------------------------------------- /main/models/score_fn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/score_fn/__init__.py -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/score_fn/song_sde/__init__.py -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | from . import layers 20 | from . 
import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. Modified from DDPM.""" 64 | 65 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 66 | super().__init__() 67 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 68 | eps=1e-6) 69 | self.NIN_0 = NIN(channels, channels) 70 | self.NIN_1 = NIN(channels, channels) 71 | self.NIN_2 = NIN(channels, channels) 72 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 73 | self.skip_rescale = skip_rescale 74 | 75 | def forward(self, x): 76 | B, C, H, W = x.shape 77 | h = self.GroupNorm_0(x) 78 | q = self.NIN_0(h) 79 | k = self.NIN_1(h) 80 | v = self.NIN_2(h) 81 | 82 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 83 | w = torch.reshape(w, (B, H, W, H * W)) 84 | w = F.softmax(w, dim=-1) 85 | w = torch.reshape(w, (B, H, W, H, W)) 86 | h = torch.einsum('bhwij,bcij->bchw', w, v) 87 | h = self.NIN_3(h) 88 | if not self.skip_rescale: 89 | return x + h 90 | else: 91 | return (x + h) / np.sqrt(2.) 
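# A minimal shape-check sketch for AttnBlockpp (illustrative only; not part of
# the original file). The block is resolution-preserving:
#
#   attn = AttnBlockpp(channels=64)
#   out = attn(torch.randn(2, 64, 16, 16))
#   assert out.shape == (2, 64, 16, 16)
#
# Note that the einsum above materializes a (B, H, W, H*W) attention weight
# tensor, so memory grows roughly with the fourth power of the spatial
# resolution; NCSN++ therefore applies this block only at the coarse
# `attn_resolutions`.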
92 | 93 | 94 | class Upsample(nn.Module): 95 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 96 | fir_kernel=(1, 3, 3, 1)): 97 | super().__init__() 98 | out_ch = out_ch if out_ch else in_ch 99 | if not fir: 100 | if with_conv: 101 | self.Conv_0 = conv3x3(in_ch, out_ch) 102 | else: 103 | if with_conv: 104 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 105 | kernel=3, up=True, 106 | resample_kernel=fir_kernel, 107 | use_bias=True, 108 | kernel_init=default_init()) 109 | self.fir = fir 110 | self.with_conv = with_conv 111 | self.fir_kernel = fir_kernel 112 | self.out_ch = out_ch 113 | 114 | def forward(self, x): 115 | B, C, H, W = x.shape 116 | if not self.fir: 117 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest') 118 | if self.with_conv: 119 | h = self.Conv_0(h) 120 | else: 121 | if not self.with_conv: 122 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 123 | else: 124 | h = self.Conv2d_0(x) 125 | 126 | return h 127 | 128 | 129 | class Downsample(nn.Module): 130 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 131 | fir_kernel=(1, 3, 3, 1)): 132 | super().__init__() 133 | out_ch = out_ch if out_ch else in_ch 134 | if not fir: 135 | if with_conv: 136 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 137 | else: 138 | if with_conv: 139 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 140 | kernel=3, down=True, 141 | resample_kernel=fir_kernel, 142 | use_bias=True, 143 | kernel_init=default_init()) 144 | self.fir = fir 145 | self.fir_kernel = fir_kernel 146 | self.with_conv = with_conv 147 | self.out_ch = out_ch 148 | 149 | def forward(self, x): 150 | B, C, H, W = x.shape 151 | if not self.fir: 152 | if self.with_conv: 153 | x = F.pad(x, (0, 1, 0, 1)) 154 | x = self.Conv_0(x) 155 | else: 156 | x = F.avg_pool2d(x, 2, stride=2) 157 | else: 158 | if not self.with_conv: 159 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 160 | else: 161 | x = self.Conv2d_0(x) 162 | 163 | return x 164 | 165 | 166 | class ResnetBlockDDPMpp(nn.Module): 167 | """ResBlock adapted from DDPM.""" 168 | 169 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 170 | dropout=0.1, skip_rescale=False, init_scale=0.): 171 | super().__init__() 172 | out_ch = out_ch if out_ch else in_ch 173 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 174 | self.Conv_0 = conv3x3(in_ch, out_ch) 175 | if temb_dim is not None: 176 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 177 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 178 | nn.init.zeros_(self.Dense_0.bias) 179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 180 | self.Dropout_0 = nn.Dropout(dropout) 181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 182 | if in_ch != out_ch: 183 | if conv_shortcut: 184 | self.Conv_2 = conv3x3(in_ch, out_ch) 185 | else: 186 | self.NIN_0 = NIN(in_ch, out_ch) 187 | 188 | self.skip_rescale = skip_rescale 189 | self.act = act 190 | self.out_ch = out_ch 191 | self.conv_shortcut = conv_shortcut 192 | 193 | def forward(self, x, temb=None): 194 | h = self.act(self.GroupNorm_0(x)) 195 | h = self.Conv_0(h) 196 | if temb is not None: 197 | h += self.Dense_0(self.act(temb))[:, :, None, None] 198 | h = self.act(self.GroupNorm_1(h)) 199 | h = self.Dropout_0(h) 200 | h = self.Conv_1(h) 201 | if x.shape[1] != self.out_ch: 202 | if self.conv_shortcut: 203 | x = self.Conv_2(x) 204 
| else: 205 | x = self.NIN_0(x) 206 | if not self.skip_rescale: 207 | return x + h 208 | else: 209 | return (x + h) / np.sqrt(2.) 210 | 211 | 212 | class ResnetBlockBigGANpp(nn.Module): 213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 215 | skip_rescale=True, init_scale=0.): 216 | super().__init__() 217 | 218 | out_ch = out_ch if out_ch else in_ch 219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 220 | self.up = up 221 | self.down = down 222 | self.fir = fir 223 | self.fir_kernel = fir_kernel 224 | 225 | self.Conv_0 = conv3x3(in_ch, out_ch) 226 | if temb_dim is not None: 227 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 229 | nn.init.zeros_(self.Dense_0.bias) 230 | 231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 232 | self.Dropout_0 = nn.Dropout(dropout) 233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 234 | if in_ch != out_ch or up or down: 235 | self.Conv_2 = conv1x1(in_ch, out_ch) 236 | 237 | self.skip_rescale = skip_rescale 238 | self.act = act 239 | self.in_ch = in_ch 240 | self.out_ch = out_ch 241 | 242 | def forward(self, x, temb=None): 243 | h = self.act(self.GroupNorm_0(x)) 244 | 245 | if self.up: 246 | if self.fir: 247 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 249 | else: 250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 252 | elif self.down: 253 | if self.fir: 254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 259 | 260 | h = self.Conv_0(h) 261 | # Add bias to each feature map conditioned on the time embedding 262 | if temb is not None: 263 | h += self.Dense_0(self.act(temb))[:, :, None, None] 264 | h = self.act(self.GroupNorm_1(h)) 265 | h = self.Dropout_0(h) 266 | h = self.Conv_1(h) 267 | 268 | if self.in_ch != self.out_ch or self.up or self.down: 269 | x = self.Conv_2(x) 270 | 271 | if not self.skip_rescale: 272 | return x + h 273 | else: 274 | return (x + h) / np.sqrt(2.) 275 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/ncsnpp_clf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from . 
import utils, layers, layerspp, normalization 19 | from util import register_module 20 | import torch.nn as nn 21 | import functools 22 | import torch 23 | import numpy as np 24 | 25 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp 26 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp 27 | Combine = layerspp.Combine 28 | conv3x3 = layerspp.conv3x3 29 | conv1x1 = layerspp.conv1x1 30 | get_act = layers.get_act 31 | get_normalization = normalization.get_normalization 32 | default_initializer = layers.default_init 33 | 34 | 35 | @register_module(category="clf_fn", name="ncsnpp_clf") 36 | class NCSNppClassifier(nn.Module): 37 | """NCSN++ Classifier model for Class conditional generation. Same as the NCSN++ but removes the upsampling stage""" 38 | 39 | def __init__(self, config): 40 | super().__init__() 41 | self.config = config.model 42 | self.act = act = get_act(self.config.clf_fn) 43 | 44 | self.nf = nf = self.config.clf_fn.nf 45 | ch_mult = self.config.clf_fn.ch_mult 46 | self.num_res_blocks = num_res_blocks = self.config.clf_fn.num_res_blocks 47 | self.attn_resolutions = attn_resolutions = self.config.clf_fn.attn_resolutions 48 | dropout = self.config.clf_fn.dropout 49 | resamp_with_conv = self.config.clf_fn.resamp_with_conv 50 | self.num_resolutions = num_resolutions = len(ch_mult) 51 | self.all_resolutions = all_resolutions = [ 52 | config.data.image_size // (2**i) for i in range(num_resolutions) 53 | ] 54 | 55 | self.noise_cond = ( 56 | noise_cond 57 | ) = self.config.clf_fn.noise_cond # noise-conditional 58 | fir = self.config.clf_fn.fir 59 | fir_kernel = self.config.clf_fn.fir_kernel 60 | self.skip_rescale = skip_rescale = self.config.clf_fn.skip_rescale 61 | self.resblock_type = resblock_type = self.config.clf_fn.resblock_type.lower() 62 | self.progressive = progressive = self.config.clf_fn.progressive.lower() 63 | self.progressive_input = ( 64 | progressive_input 65 | ) = self.config.clf_fn.progressive_input.lower() 66 | self.embedding_type = embedding_type = self.config.clf_fn.embedding_type.lower() 67 | init_scale = self.config.clf_fn.init_scale 68 | assert progressive in ["none", "output_skip", "residual"] 69 | assert progressive_input in ["none", "input_skip", "residual"] 70 | assert embedding_type in ["fourier", "positional"] 71 | combine_method = self.config.clf_fn.progressive_combine.lower() 72 | combiner = functools.partial(Combine, method=combine_method) 73 | 74 | modules = [] 75 | # timestep/noise_level embedding; only for continuous training 76 | if embedding_type == "fourier": 77 | # Gaussian Fourier features embeddings. 78 | assert ( 79 | config.training.continuous 80 | ), "Fourier features are only used for continuous training." 
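        # Fourier features embed the (log of the) continuous noise level, while
        # the "positional" branch reuses the discrete sinusoidal timestep
        # embedding; in both cases the resulting vector conditions every
        # ResNet block below.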
81 | 82 | modules.append( 83 | layerspp.GaussianFourierProjection( 84 | embedding_size=nf, scale=self.config.clf_fn.fourier_scale 85 | ) 86 | ) 87 | embed_dim = 2 * nf 88 | 89 | elif embedding_type == "positional": 90 | embed_dim = nf 91 | 92 | else: 93 | raise ValueError(f"embedding type {embedding_type} unknown.") 94 | 95 | if noise_cond: 96 | modules.append(nn.Linear(embed_dim, nf * 4)) 97 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 98 | nn.init.zeros_(modules[-1].bias) 99 | modules.append(nn.Linear(nf * 4, nf * 4)) 100 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 101 | nn.init.zeros_(modules[-1].bias) 102 | 103 | AttnBlock = functools.partial( 104 | layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale 105 | ) 106 | 107 | Downsample = functools.partial( 108 | layerspp.Downsample, 109 | with_conv=resamp_with_conv, 110 | fir=fir, 111 | fir_kernel=fir_kernel, 112 | ) 113 | 114 | if progressive_input == "input_skip": 115 | self.pyramid_downsample = layerspp.Downsample( 116 | fir=fir, fir_kernel=fir_kernel, with_conv=False 117 | ) 118 | elif progressive_input == "residual": 119 | pyramid_downsample = functools.partial( 120 | layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True 121 | ) 122 | 123 | if resblock_type == "ddpm": 124 | ResnetBlock = functools.partial( 125 | ResnetBlockDDPM, 126 | act=act, 127 | dropout=dropout, 128 | init_scale=init_scale, 129 | skip_rescale=skip_rescale, 130 | temb_dim=nf * 4, 131 | ) 132 | 133 | elif resblock_type == "biggan": 134 | ResnetBlock = functools.partial( 135 | ResnetBlockBigGAN, 136 | act=act, 137 | dropout=dropout, 138 | fir=fir, 139 | fir_kernel=fir_kernel, 140 | init_scale=init_scale, 141 | skip_rescale=skip_rescale, 142 | temb_dim=nf * 4, 143 | ) 144 | 145 | else: 146 | raise ValueError(f"resblock type {resblock_type} unrecognized.") 147 | 148 | # Downsampling block 149 | 150 | channels = self.config.clf_fn.in_ch 151 | if progressive_input != "none": 152 | input_pyramid_ch = channels 153 | 154 | modules.append(conv3x3(channels, nf)) 155 | hs_c = [nf] 156 | 157 | in_ch = nf 158 | for i_level in range(num_resolutions): 159 | # Residual blocks for this resolution 160 | for i_block in range(num_res_blocks): 161 | out_ch = nf * ch_mult[i_level] 162 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 163 | in_ch = out_ch 164 | 165 | if all_resolutions[i_level] in attn_resolutions: 166 | modules.append(AttnBlock(channels=in_ch)) 167 | hs_c.append(in_ch) 168 | 169 | if i_level != num_resolutions - 1: 170 | if resblock_type == "ddpm": 171 | modules.append(Downsample(in_ch=in_ch)) 172 | else: 173 | modules.append(ResnetBlock(down=True, in_ch=in_ch)) 174 | 175 | if progressive_input == "input_skip": 176 | modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) 177 | if combine_method == "cat": 178 | in_ch *= 2 179 | 180 | elif progressive_input == "residual": 181 | modules.append( 182 | pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch) 183 | ) 184 | input_pyramid_ch = in_ch 185 | 186 | hs_c.append(in_ch) 187 | 188 | in_ch = hs_c[-1] 189 | modules.append(ResnetBlock(in_ch=in_ch)) 190 | modules.append(AttnBlock(channels=in_ch)) 191 | modules.append(ResnetBlock(in_ch=in_ch)) 192 | 193 | # Construct classifier 194 | self.n_cls = config.model.clf_fn.n_cls 195 | last_res = all_resolutions[-1] 196 | modules.append(nn.Linear(in_ch * last_res**2, self.n_cls, bias=False)) 197 | 198 | self.all_modules = nn.ModuleList(modules) 199 | 200 | def forward(self, x, 
time_cond): 201 | # TODO: Add label and other forms of conditioning here! 202 | # timestep/noise_level embedding; only for continuous training 203 | modules = self.all_modules 204 | m_idx = 0 205 | if self.embedding_type == "fourier": 206 | # Gaussian Fourier features embeddings. 207 | used_sigmas = time_cond 208 | temb = modules[m_idx](torch.log(used_sigmas)) 209 | m_idx += 1 210 | 211 | elif self.embedding_type == "positional": 212 | # Sinusoidal positional embeddings. 213 | timesteps = time_cond 214 | # used_sigmas = self.sigmas[time_cond.long()] 215 | temb = layers.get_timestep_embedding(timesteps, self.nf) 216 | 217 | else: 218 | raise ValueError(f"embedding type {self.embedding_type} unknown.") 219 | 220 | if self.noise_cond: 221 | temb = modules[m_idx](temb) 222 | m_idx += 1 223 | temb = modules[m_idx](self.act(temb)) 224 | m_idx += 1 225 | else: 226 | temb = None 227 | 228 | # Downsampling block 229 | input_pyramid = None 230 | if self.progressive_input != "none": 231 | input_pyramid = x 232 | 233 | hs = [modules[m_idx](x)] 234 | m_idx += 1 235 | for i_level in range(self.num_resolutions): 236 | # Residual blocks for this resolution 237 | for i_block in range(self.num_res_blocks): 238 | h = modules[m_idx](hs[-1], temb) 239 | m_idx += 1 240 | if h.shape[-1] in self.attn_resolutions: 241 | h = modules[m_idx](h) 242 | m_idx += 1 243 | 244 | hs.append(h) 245 | 246 | if i_level != self.num_resolutions - 1: 247 | if self.resblock_type == "ddpm": 248 | h = modules[m_idx](hs[-1]) 249 | m_idx += 1 250 | else: 251 | h = modules[m_idx](hs[-1], temb) 252 | m_idx += 1 253 | 254 | if self.progressive_input == "input_skip": 255 | input_pyramid = self.pyramid_downsample(input_pyramid) 256 | h = modules[m_idx](input_pyramid, h) 257 | m_idx += 1 258 | 259 | elif self.progressive_input == "residual": 260 | input_pyramid = modules[m_idx](input_pyramid) 261 | m_idx += 1 262 | if self.skip_rescale: 263 | input_pyramid = (input_pyramid + h) / np.sqrt(2.0) 264 | else: 265 | input_pyramid = input_pyramid + h 266 | h = input_pyramid 267 | 268 | hs.append(h) 269 | 270 | h = hs[-1] 271 | h = modules[m_idx](h, temb) 272 | m_idx += 1 273 | h = modules[m_idx](h) 274 | m_idx += 1 275 | h = modules[m_idx](h, temb) 276 | m_idx += 1 277 | 278 | # Apply Classifier 279 | h = torch.flatten(h, start_dim=1) 280 | h = modules[m_idx](h) 281 | 282 | assert h.shape[-1] == self.n_cls 283 | return h 284 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = 
nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale uniformly in [0, 1) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, 
dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return 
FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 
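        # Double-backward for upfirdn2d: the second-order gradient re-applies the
        # forward FIR op with the original (un-flipped) kernel and the forward
        # up/down/pad geometry saved on the ctx.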
65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | 
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /main/models/score_fn/song_sde/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | """Layers used for up-sampling or down-sampling images. 2 | 3 | Many functions are ported from https://github.com/NVlabs/stylegan2. 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from .op import upfirdn2d 11 | 12 | 13 | # Function ported from StyleGAN2 14 | def get_weight(module, 15 | shape, 16 | weight_var='weight', 17 | kernel_init=None): 18 | """Get/create weight tensor for a convolution or fully-connected layer.""" 19 | 20 | return module.param(weight_var, kernel_init, shape) 21 | 22 | 23 | class Conv2d(nn.Module): 24 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 25 | 26 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 27 | resample_kernel=(1, 3, 3, 1), 28 | use_bias=True, 29 | kernel_init=None): 30 | super().__init__() 31 | assert not (up and down) 32 | assert kernel >= 1 and kernel % 2 == 1 33 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 34 | if kernel_init is not None: 35 | self.weight.data = kernel_init(self.weight.data.shape) 36 | if use_bias: 37 | self.bias = nn.Parameter(torch.zeros(out_ch)) 38 | 39 | self.up = up 40 | self.down = down 41 | self.resample_kernel = resample_kernel 42 | self.kernel = kernel 43 | self.use_bias = use_bias 44 | 45 | def forward(self, x): 46 | if self.up: 47 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 48 | elif self.down: 49 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 50 | else: 51 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 52 | 53 | if self.use_bias: 54 | x = x + self.bias.reshape(1, -1, 1, 1) 55 | 56 | return x 57 | 58 | 59 | def naive_upsample_2d(x, factor=2): 60 | _N, C, H, W = x.shape 61 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 62 | x = x.repeat(1, 1, 1, factor, 1, factor) 63 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 64 | 65 | 66 | def naive_downsample_2d(x, factor=2): 67 | _N, C, H, W = x.shape 68 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 69 | return torch.mean(x, dim=(3, 5)) 70 | 71 | 72 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 73 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 74 | 75 | Padding is performed only once at the beginning, not between the 76 | operations. 77 | The fused op is considerably more efficient than performing the same 78 | calculation 79 | using standard TensorFlow ops. It supports gradients of arbitrary order. 80 | Args: 81 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 82 | C]`. 83 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 84 | outChannels]`. Grouped convolution can be performed by `inChannels = 85 | x.shape[0] // numGroups`. 
86 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 87 | (separable). The default is `[1] * factor`, which corresponds to 88 | nearest-neighbor upsampling. 89 | factor: Integer upsampling factor (default: 2). 90 | gain: Scaling factor for signal magnitude (default: 1.0). 91 | 92 | Returns: 93 | Tensor of the shape `[N, C, H * factor, W * factor]` or 94 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 95 | """ 96 | 97 | assert isinstance(factor, int) and factor >= 1 98 | 99 | # Check weight shape. 100 | assert len(w.shape) == 4 101 | convH = w.shape[2] 102 | convW = w.shape[3] 103 | inC = w.shape[1] 104 | outC = w.shape[0] 105 | 106 | assert convW == convH 107 | 108 | # Setup filter kernel. 109 | if k is None: 110 | k = [1] * factor 111 | k = _setup_kernel(k) * (gain * (factor ** 2)) 112 | p = (k.shape[0] - factor) - (convW - 1) 113 | 114 | stride = (factor, factor) 115 | 116 | # Determine data dimensions. 117 | stride = [factor, factor] 118 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 119 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 120 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 121 | assert output_padding[0] >= 0 and output_padding[1] >= 0 122 | num_groups = _shape(x, 1) // inC 123 | 124 | # Transpose weights. 125 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 126 | w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4) 127 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 128 | 129 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 130 | ## Original TF code. 131 | # x = tf.nn.conv2d_transpose( 132 | # x, 133 | # w, 134 | # output_shape=output_shape, 135 | # strides=stride, 136 | # padding='VALID', 137 | # data_format=data_format) 138 | ## JAX equivalent 139 | 140 | return upfirdn2d(x, torch.tensor(k, device=x.device), 141 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 142 | 143 | 144 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 145 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 146 | 147 | Padding is performed only once at the beginning, not between the operations. 148 | The fused op is considerably more efficient than performing the same 149 | calculation 150 | using standard TensorFlow ops. It supports gradients of arbitrary order. 151 | Args: 152 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 153 | C]`. 154 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 155 | outChannels]`. Grouped convolution can be performed by `inChannels = 156 | x.shape[0] // numGroups`. 157 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 158 | (separable). The default is `[1] * factor`, which corresponds to 159 | average pooling. 160 | factor: Integer downsampling factor (default: 2). 161 | gain: Scaling factor for signal magnitude (default: 1.0). 162 | 163 | Returns: 164 | Tensor of the shape `[N, C, H // factor, W // factor]` or 165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`.
166 | """ 167 | 168 | assert isinstance(factor, int) and factor >= 1 169 | _outC, _inC, convH, convW = w.shape 170 | assert convW == convH 171 | if k is None: 172 | k = [1] * factor 173 | k = _setup_kernel(k) * gain 174 | p = (k.shape[0] - factor) + (convW - 1) 175 | s = [factor, factor] 176 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 177 | pad=((p + 1) // 2, p // 2)) 178 | return F.conv2d(x, w, stride=s, padding=0) 179 | 180 | 181 | def _setup_kernel(k): 182 | k = np.asarray(k, dtype=np.float32) 183 | if k.ndim == 1: 184 | k = np.outer(k, k) 185 | k /= np.sum(k) 186 | assert k.ndim == 2 187 | assert k.shape[0] == k.shape[1] 188 | return k 189 | 190 | 191 | def _shape(x, dim): 192 | return x.shape[dim] 193 | 194 | 195 | def upsample_2d(x, k=None, factor=2, gain=1): 196 | r"""Upsample a batch of 2D images with the given filter. 197 | 198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 199 | and upsamples each image with the given filter. The filter is normalized so 200 | that 201 | if the input pixels are constant, they will be scaled by the specified 202 | `gain`. 203 | Pixels outside the image are assumed to be zero, and the filter is padded 204 | with 205 | zeros so that its shape is a multiple of the upsampling factor. 206 | Args: 207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 208 | C]`. 209 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 210 | (separable). The default is `[1] * factor`, which corresponds to 211 | nearest-neighbor upsampling. 212 | factor: Integer upsampling factor (default: 2). 213 | gain: Scaling factor for signal magnitude (default: 1.0). 214 | 215 | Returns: 216 | Tensor of the shape `[N, C, H * factor, W * factor]` 217 | """ 218 | assert isinstance(factor, int) and factor >= 1 219 | if k is None: 220 | k = [1] * factor 221 | k = _setup_kernel(k) * (gain * (factor ** 2)) 222 | p = k.shape[0] - factor 223 | return upfirdn2d(x, torch.tensor(k, device=x.device), 224 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 225 | 226 | 227 | def downsample_2d(x, k=None, factor=2, gain=1): 228 | r"""Downsample a batch of 2D images with the given filter. 229 | 230 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 231 | and downsamples each image with the given filter. The filter is normalized 232 | so that 233 | if the input pixels are constant, they will be scaled by the specified 234 | `gain`. 235 | Pixels outside the image are assumed to be zero, and the filter is padded 236 | with 237 | zeros so that its shape is a multiple of the downsampling factor. 238 | Args: 239 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 240 | C]`. 241 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 242 | (separable). The default is `[1] * factor`, which corresponds to 243 | average pooling. 244 | factor: Integer downsampling factor (default: 2). 245 | gain: Scaling factor for signal magnitude (default: 1.0). 
246 | 
247 |     Returns:
248 |       Tensor of the shape `[N, C, H // factor, W // factor]`
249 |   """
250 | 
251 |   assert isinstance(factor, int) and factor >= 1
252 |   if k is None:
253 |     k = [1] * factor
254 |   k = _setup_kernel(k) * gain
255 |   p = k.shape[0] - factor
256 |   return upfirdn2d(x, torch.tensor(k, device=x.device),
257 |                    down=factor, pad=((p + 1) // 2, p // 2))
258 | 
-------------------------------------------------------------------------------- /main/models/score_fn/song_sde/utils.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """All functions and modules related to model definition.
17 | """
18 | 
19 | import torch
20 | import numpy as np
21 | 
22 | 
23 | _MODELS = {}
24 | 
25 | 
26 | def register_model(cls=None, *, name=None):
27 |   """A decorator for registering model classes."""
28 | 
29 |   def _register(cls):
30 |     if name is None:
31 |       local_name = cls.__name__
32 |     else:
33 |       local_name = name
34 |     if local_name in _MODELS:
35 |       raise ValueError(f'Already registered model with name: {local_name}')
36 |     _MODELS[local_name] = cls
37 |     return cls
38 | 
39 |   if cls is None:
40 |     return _register
41 |   else:
42 |     return _register(cls)
43 | 
44 | 
45 | def get_model(name):
46 |   return _MODELS[name]
47 | 
48 | 
49 | def get_sigmas(config):
50 |   """Get sigmas --- the set of noise levels for SMLD from config files.
51 |   Args:
52 |     config: A ConfigDict object parsed from the config file
53 |   Returns:
54 |     sigmas: a numpy array of noise levels
55 |   """
56 |   sigmas = np.exp(
57 |     np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
58 | 
59 |   return sigmas
60 | 
61 | 
62 | def get_ddpm_params(config):
63 |   """Get betas and alphas --- parameters used in the original DDPM paper."""
64 |   num_diffusion_timesteps = 1000
65 |   # parameters need to be adapted if number of time steps differs from 1000
66 |   beta_start = config.model.beta_min / config.model.num_scales
67 |   beta_end = config.model.beta_max / config.model.num_scales
68 |   betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
69 | 
70 |   alphas = 1. - betas
71 |   alphas_cumprod = np.cumprod(alphas, axis=0)
72 |   sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
73 |   sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
74 | 
75 |   return {
76 |     'betas': betas,
77 |     'alphas': alphas,
78 |     'alphas_cumprod': alphas_cumprod,
79 |     'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
80 |     'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
81 |     'beta_min': beta_start * (num_diffusion_timesteps - 1),
82 |     'beta_max': beta_end * (num_diffusion_timesteps - 1),
83 |     'num_diffusion_timesteps': num_diffusion_timesteps
84 |   }
85 | 
86 | 
87 | def create_model(config):
88 |   """Create the score model."""
89 |   model_name = config.model.name
90 |   score_model = get_model(model_name)(config)
91 |   score_model = score_model.to(config.device)
92 |   score_model = torch.nn.DataParallel(score_model)
93 |   return score_model
94 | 
95 | 
96 | def get_model_fn(model, train=False):
97 |   """Create a function to give the output of the score-based model.
98 | 
99 |   Args:
100 |     model: The score model.
101 |     train: `True` for training and `False` for evaluation.
102 | 
103 |   Returns:
104 |     A model function.
105 |   """
106 | 
107 |   def model_fn(x, labels):
108 |     """Compute the output of the score-based model.
109 | 
110 |     Args:
111 |       x: A mini-batch of input data.
112 |       labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
113 |         for different models.
114 | 
115 |     Returns:
116 |       The model output (unlike the original JAX version, this port has no mutable states to return).
117 |     """
118 |     if not train:
119 |       model.eval()
120 |       return model(x, labels)
121 |     else:
122 |       model.train()
123 |       return model(x, labels)
124 | 
125 |   return model_fn
126 | 
127 | 
128 | # def get_score_fn(sde, model, train=False, continuous=False):
129 | #   """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
130 | 
131 | #   Args:
132 | #     sde: An `sde_lib.SDE` object that represents the forward SDE.
133 | #     model: A score model.
134 | #     train: `True` for training and `False` for evaluation.
135 | #     continuous: If `True`, the score-based model is expected to directly take continuous time steps.
136 | 
137 | #   Returns:
138 | #     A score function.
139 | #   """
140 | #   model_fn = get_model_fn(model, train=train)
141 | 
142 | #   if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
143 | #     def score_fn(x, t):
144 | #       # Scale neural network output by standard deviation and flip sign
145 | #       if continuous or isinstance(sde, sde_lib.subVPSDE):
146 | #         # For VP-trained models, t=0 corresponds to the lowest noise level
147 | #         # The maximum value of the time embedding is assumed to be 999 for
148 | #         # continuously-trained models.
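#         (E.g., a continuous time t = 0.5 is mapped to the embedding label
#          0.5 * 999 = 499.5, matching the [0, 999] label range of
#          discretely-trained DDPM checkpoints.)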
149 | # labels = t * 999 150 | # score = model_fn(x, labels) 151 | # std = sde.marginal_prob(torch.zeros_like(x), t)[1] 152 | # else: 153 | # # For VP-trained models, t=0 corresponds to the lowest noise level 154 | # labels = t * (sde.N - 1) 155 | # score = model_fn(x, labels) 156 | # std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()] 157 | 158 | # score = -score / std[:, None, None, None] 159 | # return score 160 | 161 | # elif isinstance(sde, sde_lib.VESDE): 162 | # def score_fn(x, t): 163 | # if continuous: 164 | # labels = sde.marginal_prob(torch.zeros_like(x), t)[1] 165 | # else: 166 | # # For VE-trained models, t=0 corresponds to the highest noise level 167 | # labels = sde.T - t 168 | # labels *= sde.N - 1 169 | # labels = torch.round(labels).long() 170 | 171 | # score = model_fn(x, labels) 172 | # return score 173 | 174 | # else: 175 | # raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 176 | 177 | # return score_fn 178 | 179 | 180 | def to_flattened_numpy(x): 181 | """Flatten a torch tensor `x` and convert it to numpy.""" 182 | return x.detach().cpu().numpy().reshape((-1,)) 183 | 184 | 185 | def from_flattened_numpy(x, shape): 186 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 187 | return torch.from_numpy(x.reshape(shape)) -------------------------------------------------------------------------------- /main/models/sde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/sde/__init__.py -------------------------------------------------------------------------------- /main/models/sde/base.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | import abc 3 | 4 | 5 | class SDE(abc.ABC): 6 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 7 | 8 | def __init__(self, N): 9 | """Construct an SDE. 10 | 11 | Args: 12 | N: number of discretization time steps. 13 | """ 14 | super().__init__() 15 | self.N = N 16 | 17 | @property 18 | @abc.abstractmethod 19 | def T(self): 20 | """End time of the SDE.""" 21 | raise NotImplementedError 22 | 23 | @abc.abstractmethod 24 | def get_score(self, eps, t): 25 | """Computes the score from a given epsilon value""" 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def perturb_data(self, x_0, t, noise=None): 30 | """Add noise to a data point""" 31 | raise NotImplementedError 32 | 33 | @abc.abstractmethod 34 | def sde(self, x, t): 35 | """Return the drift and the diffusion coefficients""" 36 | raise NotImplementedError 37 | 38 | @abc.abstractmethod 39 | def reverse_sde(self, x, t, score_fn, probability_flow=False): 40 | """Return the drift and the diffusion coefficients of the reverse-sde""" 41 | raise NotImplementedError 42 | 43 | @abc.abstractmethod 44 | def cond_marginal_prob(self, x, t): 45 | """Parameters to determine the marginal distribution of the SDE, $p_t(x_t|x_0)$.""" 46 | raise NotImplementedError 47 | 48 | @abc.abstractmethod 49 | def prior_sampling(self, shape): 50 | """Generate one sample from the prior distribution, $p_T(x)$.""" 51 | raise NotImplementedError 52 | 53 | @abc.abstractmethod 54 | def prior_logp(self, z): 55 | """Compute log-density of the prior distribution. 56 | 57 | Useful for computing the log-likelihood via probability flow ODE. 
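        For a VP-type forward SDE the prior is a standard Gaussian, so concrete
        subclasses typically return the N(0, I) log-density (cf. `prior_logp`
        in `vpsde.py` below).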
58 | 
59 |         Args:
60 |             z: latent code
61 |         Returns:
62 |             log probability density
63 |         """
64 |         pass
65 | 
-------------------------------------------------------------------------------- /main/models/sde/vpsde.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from util import register_module, reshape
4 | 
5 | from .base import SDE
6 | 
7 | 
8 | @register_module(category="sde", name="vpsde")
9 | class VPSDE(SDE):
10 |     def __init__(self, config):
11 |         """Construct a Variance Preserving (VP) SDE."""
12 |         super().__init__(config.model.sde.n_timesteps)
13 |         self.beta_0 = config.model.sde.beta_min
14 |         self.beta_1 = config.model.sde.beta_max
15 | 
16 |     def beta_t(self, t):
17 |         return self.beta_0 + t * (self.beta_1 - self.beta_0)
18 | 
19 |     @property
20 |     def T(self):
21 |         return 1.0
22 | 
23 |     @property
24 |     def type(self):
25 |         return "vpsde"
26 | 
27 |     def get_score(self, eps, t):
28 |         return -eps / reshape(self._std(t), eps)
29 | 
30 |     def perturb_data(self, x_0, t, noise=None):
31 |         """Add noise to input data point"""
32 |         noise = torch.randn_like(x_0) if noise is None else noise  # was `if None`, which never drew fresh noise
33 |         assert noise.shape == x_0.shape
34 | 
35 |         mu_t, std_t = self.cond_marginal_prob(x_0, t)
36 |         assert mu_t.shape == x_0.shape
37 |         assert len(std_t.shape) == len(x_0.shape)
38 |         return mu_t + noise * std_t
39 | 
40 |     def sde(self, x_t, t):
41 |         """Return the drift and diffusion coefficients"""
42 |         beta_t = reshape(self.beta_t(t), x_t)
43 |         drift = -0.5 * beta_t * x_t
44 |         diffusion = torch.sqrt(beta_t)
45 | 
46 |         return drift, diffusion
47 | 
48 |     def reverse_sde(self, x_t, t, score_fn, probability_flow=False):
49 |         """Return the drift and diffusion coefficients of the reverse sde"""
50 |         # The reverse SDE is defined on the time domain (T - t)
51 |         t = self.T - t
52 | 
53 |         # Forward drift and diffusion
54 |         f, g = self.sde(x_t, t)
55 | 
56 |         # Scale the score by 0.5 for the prob. flow formulation
57 |         eps_pred = score_fn(x_t.type(torch.float32), t.type(torch.float32))
58 |         score = self.get_score(eps_pred, t)
59 |         score = 0.5 * score if probability_flow else score
60 | 
61 |         # Reverse drift
62 |         f_bar = -f + g**2 * score
63 |         assert f_bar.shape == f.shape
64 | 
65 |         # Reverse diffusion (0 for prob. flow)
66 |         g_bar = g if not probability_flow else torch.zeros_like(g)
67 |         return f_bar, g_bar
68 | 
69 |     def cond_marginal_prob(self, x_0, t):
70 |         """Return the mean and std of the perturbation kernel p(x_t|x_0)"""
71 |         log_mean_coeff = (
72 |             -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
73 |         )
74 |         mean = torch.exp(log_mean_coeff[:, None, None, None]) * x_0
75 |         assert mean.shape == x_0.shape
76 | 
77 |         std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
78 |         std = reshape(std, x_0)
79 |         return mean, std
80 | 
81 |     def _std(self, t):
82 |         log_mean_coeff = (
83 |             -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
84 |         )
85 |         return torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
86 | 
87 |     def prior_sampling(self, shape):
88 |         """Generate samples from the prior p(x_T)"""
89 |         return torch.randn(*shape)
90 | 
91 |     def prior_logp(self, z):
92 |         shape = z.shape
93 |         N = np.prod(shape[1:])
94 |         logps = -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=(1, 2, 3)) / 2.0
95 |         return logps
96 | 
97 |     def likelihood_weighting(self, t):
98 |         # Return g(t)^2 for likelihood training and computation
99 |         return self.beta_t(t)
100 | 
-------------------------------------------------------------------------------- /main/models/wrapper.py: --------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning.utilities.seed import seed_everything
6 | from util import get_module, register_module
7 | 
8 | logger = logging.getLogger(__name__)
9 | 
10 | 
11 | @register_module(category="pl_modules", name="sde_wrapper")
12 | class SDEWrapper(pl.LightningModule):
13 |     """This PL module can do the following tasks:
14 |     - train: Train an unconditional score model using HSM or DSM
15 |     - predict: Generate unconditional samples from a pre-trained score model
16 |     """
17 | 
18 |     def __init__(
19 |         self,
20 |         config,
21 |         sde,
22 |         score_fn,
23 |         ema_score_fn=None,
24 |         criterion=None,
25 |         sampler_cls=None,
26 |         corrector_fn=None,
27 |     ):
28 |         super().__init__()
29 |         self.config = config
30 |         self.score_fn = score_fn
31 |         self.ema_score_fn = ema_score_fn
32 |         self.sde = sde
33 | 
34 |         # Training
35 |         self.criterion = criterion
36 |         self.train_eps = self.config.training.train_eps
37 | 
38 |         # Evaluation
39 |         self.sampler = None
40 |         if sampler_cls is not None:
41 |             self.sampler = sampler_cls(
42 |                 self.config,
43 |                 self.sde,
44 |                 self.ema_score_fn
45 |                 if config.evaluation.sample_from == "target"
46 |                 else self.score_fn,
47 |                 corrector_fn=corrector_fn,
48 |             )
49 |             self.eval_eps = self.config.evaluation.eval_eps
50 |             self.denoise = self.config.evaluation.denoise
51 |             n_discrete_steps = self.config.evaluation.n_discrete_steps
52 |             self.n_discrete_steps = (
53 |                 n_discrete_steps - 1 if self.denoise else n_discrete_steps
54 |             )
55 |             self.val_eps = self.config.evaluation.eval_eps
56 |             self.stride_type = self.config.evaluation.stride_type
57 | 
58 |         # Disable automatic optimization
59 |         self.automatic_optimization = False
60 | 
61 |     def forward(self):
62 |         pass
63 | 
64 |     def training_step(self, batch, batch_idx):
65 |         # Optimizers
66 |         optim = self.optimizers()
67 |         lr_sched = self.lr_schedulers()
68 | 
69 |         x_0 = batch
70 | 
71 |         # Sample timepoints t uniformly in [eps, T]
72 |         t_ = torch.rand(x_0.shape[0], device=x_0.device, dtype=torch.float64)
73 |         t = t_ * (self.sde.T - self.train_eps) + self.train_eps
74 |         assert t.shape[0] == x_0.shape[0]
75 | 
76 |         # Compute loss and backward
77 |         loss = self.criterion(x_0, t, self.score_fn)
78 |         optim.zero_grad()
79 |         self.manual_backward(loss)
80 | 
81 |         # Clip gradients (if enabled)
82 |         if self.config.training.optimizer.grad_clip != 0:
83 |             torch.nn.utils.clip_grad_norm_(
84 |                 self.score_fn.parameters(), self.config.training.optimizer.grad_clip
85 |             )
86 |         optim.step()
87 | 
88 |         # Scheduler step
89 |         lr_sched.step()
90 |         self.log("loss", loss, prog_bar=True)
91 |         return loss
92 | 
93 |     def on_predict_start(self):
94 |         seed = self.config.evaluation.seed
95 | 
96 |         # Offset the seed by process rank: a common seed across devices would
97 |         # generate identical samples on every GPU, which negatively affects
98 |         # evaluation metrics like FID.
99 |         seed_everything(seed + self.global_rank, workers=True)
100 | 
101 |     def predict_step(self, batch, batch_idx, dataloader_idx=None):
102 |         t_final = self.sde.T - self.eval_eps
103 |         ts = torch.linspace(
104 |             0,
105 |             t_final,
106 |             self.n_discrete_steps + 1,
107 |             device=self.device,
108 |             dtype=torch.float64,
109 |         )
110 | 
111 |         if self.stride_type == "uniform":
112 |             pass
113 |         elif self.stride_type == "quadratic":
114 |             ts = t_final * torch.flip(1 - (ts / t_final) ** 2.0, dims=[0])
115 | 
116 |         return self.sampler.sample(
117 |             batch,
118 |             ts,
119 |             self.n_discrete_steps,
120 |             denoise=self.denoise,
121 |             eps=self.eval_eps,
122 |         )
123 | 
124 |     def on_predict_end(self):
125 |         if isinstance(self.sampler, get_module("samplers", "bb_ode")):
126 |             print(self.sampler.mean_nfe)
127 | 
128 |     def configure_optimizers(self):
129 |         opt_config = self.config.training.optimizer
130 |         opt_name = opt_config.name
131 |         if opt_name == "Adam":
132 |             optimizer = torch.optim.Adam(
133 |                 self.score_fn.parameters(),
134 |                 lr=opt_config.lr,
135 |                 betas=(opt_config.beta_1, opt_config.beta_2),
136 |                 eps=opt_config.eps,
137 |                 weight_decay=opt_config.weight_decay,
138 |             )
139 |         else:
140 |             raise NotImplementedError(f"Optimizer {opt_name} not supported yet!")
141 | 
142 |         # Define the LR scheduler (as in Ho et al.)
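        # (Linear warmup sketch: with, e.g., `warmup=5000`, the multiplier
        #  below ramps the LR from 0 to the base LR over the first 5000
        #  optimizer steps and then holds it at 1.0; `warmup=0` disables
        #  the ramp.)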
143 | if opt_config.warmup == 0: 144 | lr_lambda = lambda step: 1.0 145 | else: 146 | lr_lambda = lambda step: min(step / opt_config.warmup, 1.0) 147 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 148 | return { 149 | "optimizer": optimizer, 150 | "lr_scheduler": { 151 | "scheduler": scheduler, 152 | "interval": "step", 153 | "strict": False, 154 | }, 155 | } 156 | -------------------------------------------------------------------------------- /main/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .sde import EulerMaruyamaSampler, ClassCondEulerMaruyamaSampler 2 | from .ode import BBODESampler 3 | -------------------------------------------------------------------------------- /main/samplers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Sampler(abc.ABC): 5 | """The abstract class for a Sampler algorithm.""" 6 | 7 | def __init__(self, config, sde, score_fn, corrector_fn=None): 8 | super().__init__() 9 | self.config = config 10 | self.sde = sde 11 | self.score_fn = score_fn 12 | self.corrector_fn = corrector_fn 13 | 14 | @property 15 | def n_steps(self): 16 | return self.config.evaluation.n_discrete_steps 17 | 18 | @abc.abstractmethod 19 | def predictor_update_fn(self): 20 | raise NotImplementedError 21 | 22 | def corrector_update_fn(self, x, t, dt): 23 | if self.corrector_fn is not None: 24 | return self.corrector_fn(x, t, dt) 25 | 26 | # If no corrector is defined, Identity is default 27 | return x, x 28 | 29 | @abc.abstractmethod 30 | def sample(self): 31 | raise NotImplementedError 32 | -------------------------------------------------------------------------------- /main/samplers/ode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchdiffeq import odeint 3 | from util import register_module 4 | 5 | from .base import Sampler 6 | 7 | 8 | @register_module(category="samplers", name="bb_ode") 9 | class BBODESampler(Sampler): 10 | """Black-Box ODE sampler for generating samples from the 11 | Probability Flow ODE. 
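    As implemented below, this wraps `torchdiffeq.odeint` with
    `method="scipy_solver"`, so `config.evaluation.sampler.solver` may name any
    SciPy integrator (e.g. "RK45"); `rtol`/`atol` set the adaptive-step
    tolerances, and `self.nfe` counts score-function evaluations.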
12 | """ 13 | 14 | def __init__(self, config, sde, score_fn, corrector_fn=None): 15 | super().__init__(config, sde, score_fn, corrector_fn=corrector_fn) 16 | self.nfe = 0 17 | self.rtol = config.evaluation.sampler.rtol 18 | self.atol = config.evaluation.sampler.atol 19 | self.solver_opts = {"solver": config.evaluation.sampler.solver} 20 | 21 | self._counter = 0 22 | 23 | @property 24 | def n_steps(self): 25 | return self.nfe 26 | 27 | @property 28 | def mean_nfe(self): 29 | if self._counter != 0: 30 | return self.nfe / self._counter 31 | raise ValueError("Run .sample() to compute mean_nfe") 32 | 33 | def predictor_update_fn(self, x, t, dt): 34 | pass 35 | 36 | def denoise_fn(self, x, t, dt): 37 | f, _ = self.sde.reverse_sde(x, t, self.score_fn, probability_flow=True) 38 | x_mean = x + f * dt 39 | return x_mean 40 | 41 | def sample(self, batch, ts, n_discrete_steps, denoise=True, eps=1e-3): 42 | def ode_fn(t, x): 43 | self.nfe += 1 44 | vec_t = torch.ones(x.shape[0], device=x.device, dtype=torch.float64) * t 45 | f, _ = self.sde.reverse_sde(x, vec_t, self.score_fn, probability_flow=True) 46 | return f 47 | 48 | x = batch 49 | self._counter += 1 50 | 51 | time_tensor = torch.tensor( 52 | [0.0, self.sde.T - eps], dtype=torch.float64, device=x.device 53 | ) 54 | # BB-ODE solver 55 | solution = odeint( 56 | ode_fn, 57 | x, 58 | time_tensor, 59 | rtol=self.rtol, 60 | atol=self.atol, 61 | method="scipy_solver", 62 | options=self.solver_opts, 63 | ) 64 | 65 | x = solution[-1] 66 | 67 | # Denoise 68 | if denoise: 69 | x = self.denoise_fn( 70 | x, 71 | torch.ones(x.shape[0], device=x.device, dtype=torch.float64) 72 | * (self.sde.T - eps), 73 | eps, 74 | ) 75 | self.nfe += 1 76 | return x 77 | -------------------------------------------------------------------------------- /main/train_clf.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | from models.clf_wrapper import TClfWrapper 7 | from omegaconf import OmegaConf 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | from pytorch_lightning.strategies import DDPStrategy 10 | from pytorch_lightning.utilities.seed import seed_everything 11 | from torch.utils.data import DataLoader 12 | from util import get_dataset, get_module, import_modules_into_registry 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # Import all modules into registry 18 | import_modules_into_registry() 19 | 20 | 21 | @hydra.main(config_path="configs") 22 | def train_clf(config): 23 | """Helper script for training a noise conditioned classifier for guidance purposes""" 24 | # Get config and setup 25 | config = config.dataset 26 | logger.info(OmegaConf.to_yaml(config)) 27 | 28 | # Set seed 29 | seed_everything(config.clf.training.seed, workers=True) 30 | 31 | # Setup dataset 32 | dataset = get_dataset(config.clf) 33 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}") 34 | 35 | # Setup score predictor 36 | clf_fn_cls = get_module(category="clf_fn", name=config.clf.model.clf_fn.name) 37 | clf_fn = clf_fn_cls(config.clf) 38 | logger.info(f"Using Classifier fn: {clf_fn_cls}") 39 | 40 | # Setup Score SDE 41 | sde_cls = get_module(category="sde", name=config.diffusion.model.sde.name) 42 | sde = sde_cls(config.diffusion) 43 | logger.info(f"Using SDE: {sde_cls} with type: {sde.type}") 44 | logger.info(sde) 45 | 46 | # Setup Loss fn 47 | criterion_cls = get_module(category="losses", name=config.clf.training.loss.name) 48 | 
criterion = criterion_cls(config, sde) 49 | logger.info(f"Using Loss: {criterion_cls}") 50 | 51 | # Setup Lightning Wrapper Module 52 | wrapper = TClfWrapper(config, sde, clf_fn, score_fn=None, criterion=criterion) 53 | 54 | # Setup Trainer 55 | train_kwargs = {} 56 | 57 | # Setup callbacks 58 | results_dir = config.clf.training.results_dir 59 | chkpt_callback = ModelCheckpoint( 60 | dirpath=os.path.join(results_dir, "checkpoints"), 61 | filename=f"{config.clf.model.clf_fn.name}-{config.diffusion.model.sde.name}-{config.clf.training.chkpt_prefix}" 62 | + "-{epoch:02d}-{loss:.4f}", 63 | every_n_epochs=config.clf.training.chkpt_interval, 64 | save_top_k=-1, 65 | ) 66 | 67 | train_kwargs["default_root_dir"] = results_dir 68 | train_kwargs["max_epochs"] = config.clf.training.epochs 69 | train_kwargs["log_every_n_steps"] = config.clf.training.log_step 70 | train_kwargs["callbacks"] = [chkpt_callback] 71 | 72 | device_type = config.clf.training.accelerator 73 | train_kwargs["accelerator"] = device_type 74 | loader_kws = {} 75 | if device_type == "gpu": 76 | train_kwargs["devices"] = config.clf.training.devices 77 | 78 | # Disable find_unused_parameters when using DDP training for performance reasons 79 | train_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False) 80 | loader_kws["persistent_workers"] = True 81 | elif device_type == "tpu": 82 | train_kwargs["tpu_cores"] = 8 83 | 84 | # Half precision training 85 | if config.clf.training.fp16: 86 | train_kwargs["precision"] = 16 87 | 88 | # Loader 89 | batch_size = config.clf.training.batch_size 90 | batch_size = min(len(dataset), batch_size) 91 | loader = DataLoader( 92 | dataset, 93 | batch_size, 94 | num_workers=config.clf.training.workers, 95 | pin_memory=True, 96 | shuffle=True, 97 | drop_last=True, 98 | **loader_kws, 99 | ) 100 | 101 | # Trainer 102 | logger.info(f"Running Trainer with kwargs: {train_kwargs}") 103 | trainer = pl.Trainer(**train_kwargs) 104 | 105 | # Restore checkpoint 106 | restore_path = config.clf.training.restore_path 107 | if restore_path == "": 108 | restore_path = None 109 | trainer.fit(wrapper, train_dataloaders=loader, ckpt_path=restore_path) 110 | 111 | 112 | if __name__ == "__main__": 113 | train_clf() 114 | -------------------------------------------------------------------------------- /main/train_sde.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from copy import deepcopy 4 | 5 | import hydra 6 | import pytorch_lightning as pl 7 | from callbacks import EMAWeightUpdate 8 | from omegaconf import OmegaConf 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | from pytorch_lightning.utilities.seed import seed_everything 11 | from torch.utils.data import DataLoader 12 | from util import get_dataset, get_module, import_modules_into_registry 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | # Import all modules into registry 18 | import_modules_into_registry() 19 | 20 | 21 | @hydra.main(config_path="configs") 22 | def train(config): 23 | """Helper script for training a score-based generative model""" 24 | # Get config and setup 25 | config = config.dataset.diffusion 26 | logger.info(OmegaConf.to_yaml(config)) 27 | 28 | # Set seed 29 | seed_everything(config.training.seed, workers=True) 30 | 31 | # Setup dataset 32 | dataset = get_dataset(config) 33 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}") 34 | 35 | # Setup score predictor 36 | score_fn_cls = get_module(category="score_fn", 
name=config.model.score_fn.name) 37 | score_fn = score_fn_cls(config) 38 | logger.info(f"Using Score fn: {score_fn_cls}") 39 | 40 | # Setup target network for EMA 41 | ema_score_fn = deepcopy(score_fn) 42 | for p in ema_score_fn.parameters(): 43 | p.requires_grad = False 44 | 45 | # Setup Score SDE 46 | sde_cls = get_module(category="sde", name=config.model.sde.name) 47 | sde = sde_cls(config) 48 | logger.info(f"Using SDE: {sde_cls} with type: {sde.type}") 49 | logger.info(sde) 50 | 51 | # Setup Loss fn 52 | criterion_cls = get_module(category="losses", name=config.training.loss.name) 53 | criterion = criterion_cls(config, sde) 54 | logger.info(f"Using Loss: {criterion_cls}") 55 | 56 | # Setup Lightning Wrapper Module 57 | wrapper_cls = get_module(category="pl_modules", name=config.model.pl_module) 58 | wrapper = wrapper_cls( 59 | config, sde, score_fn, ema_score_fn=ema_score_fn, criterion=criterion 60 | ) 61 | 62 | # Setup Trainer 63 | train_kwargs = {} 64 | 65 | # Setup callbacks 66 | results_dir = config.training.results_dir 67 | chkpt_callback = ModelCheckpoint( 68 | dirpath=os.path.join(results_dir, "checkpoints"), 69 | filename=f"{config.model.sde.name}-{config.training.chkpt_prefix}" 70 | + "-{epoch:02d}-{loss:.4f}", 71 | every_n_epochs=config.training.chkpt_interval, 72 | save_top_k=-1, 73 | ) 74 | 75 | train_kwargs["default_root_dir"] = results_dir 76 | train_kwargs["max_epochs"] = config.training.epochs 77 | train_kwargs["log_every_n_steps"] = config.training.log_step 78 | train_kwargs["callbacks"] = [chkpt_callback] 79 | if config.training.use_ema: 80 | ema_callback = EMAWeightUpdate(tau=config.training.ema_decay) 81 | train_kwargs["callbacks"].append(ema_callback) 82 | 83 | device_type = config.training.accelerator 84 | train_kwargs["accelerator"] = device_type 85 | loader_kws = {} 86 | if device_type == "gpu": 87 | train_kwargs["devices"] = config.training.devices 88 | 89 | # Disable find_unused_parameters when using DDP training for performance reasons 90 | # train_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False) 91 | loader_kws["persistent_workers"] = True 92 | elif device_type == "tpu": 93 | train_kwargs["tpu_cores"] = 8 94 | 95 | # Half precision training 96 | if config.training.fp16: 97 | train_kwargs["precision"] = 16 98 | 99 | # Loader 100 | batch_size = config.training.batch_size 101 | batch_size = min(len(dataset), batch_size) 102 | loader = DataLoader( 103 | dataset, 104 | batch_size, 105 | num_workers=config.training.workers, 106 | pin_memory=True, 107 | shuffle=True, 108 | drop_last=True, 109 | **loader_kws, 110 | ) 111 | 112 | # Trainer 113 | logger.info(f"Running Trainer with kwargs: {train_kwargs}") 114 | trainer = pl.Trainer(**train_kwargs, strategy="ddp") 115 | 116 | # Restore checkpoint 117 | restore_path = config.training.restore_path 118 | if restore_path == "": 119 | restore_path = None 120 | trainer.fit(wrapper, train_dataloaders=loader, ckpt_path=restore_path) 121 | 122 | 123 | if __name__ == "__main__": 124 | train() 125 | -------------------------------------------------------------------------------- /main/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision.transforms as T 6 | from PIL import Image 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | _MODULES = {} 11 | 12 | 13 | def reshape(t, rt): 14 | """Adds additional dimensions corresponding to the size of the 15 | reference tensor rt. 
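    E.g., a `(B,)` tensor of timesteps is returned as a `(B, 1, 1, 1)` view so
    that it broadcasts against a `(B, C, H, W)` batch.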
16 | """ 17 | if len(rt.shape) == len(t.shape): 18 | return t 19 | ones = [1] * len(rt.shape[1:]) 20 | t_ = t.view(-1, *ones) 21 | assert len(t_.shape) == len(rt.shape) 22 | return t_ 23 | 24 | 25 | def data_scaler(img, norm=True): 26 | if norm: 27 | img = (np.asarray(img).astype(np.float) / 127.5) - 1.0 28 | else: 29 | img = np.asarray(img).astype(np.float) / 255.0 30 | return img 31 | 32 | 33 | def register_module(category=None, name=None): 34 | """A decorator for registering model classes.""" 35 | 36 | def _register(cls): 37 | local_category = category 38 | if local_category is None: 39 | local_category = cls.__name__ if name is None else name 40 | 41 | # Create category (if does not exist) 42 | if local_category not in _MODULES: 43 | _MODULES[local_category] = {} 44 | 45 | # Add module to the category 46 | local_name = cls.__name__ if name is None else name 47 | if name in _MODULES[local_category]: 48 | raise ValueError( 49 | f"Already registered module with name: {local_name} in category: {category}" 50 | ) 51 | 52 | _MODULES[local_category][local_name] = cls 53 | return cls 54 | 55 | return _register 56 | 57 | 58 | def get_module(category, name): 59 | module = _MODULES.get(category, dict()).get(name, None) 60 | if module is None: 61 | raise ValueError(f"No module named `{name}` found in category: `{category}`") 62 | return module 63 | 64 | 65 | def configure_device(device): 66 | if device.startswith("gpu"): 67 | if not torch.cuda.is_available(): 68 | raise Exception( 69 | "CUDA support is not available on your platform. Re-run using CPU or TPU mode" 70 | ) 71 | gpu_id = device.split(":")[-1] 72 | if gpu_id == "": 73 | # Use all GPU's 74 | gpu_id = -1 75 | gpu_id = [int(id) for id in gpu_id.split(",")] 76 | return f"cuda:{gpu_id}", gpu_id 77 | return device 78 | 79 | 80 | def get_dataset(config): 81 | # TODO: Add support for dynamically adding **kwargs directly via config 82 | # Parse config 83 | name = config.data.name 84 | root = config.data.root 85 | image_size = config.data.image_size 86 | norm = config.data.norm 87 | flip = config.data.hflip 88 | 89 | # Checks 90 | assert isinstance(norm, bool) 91 | 92 | if name.lower() == "cifar10": 93 | assert image_size == 32 94 | 95 | # Construct transforms 96 | t_list = [T.Resize((image_size, image_size))] 97 | if flip: 98 | t_list.append(T.RandomHorizontalFlip()) 99 | transform = T.Compose(t_list) 100 | 101 | # Get dataset 102 | dataset_cls = get_module(category="datasets", name=name.lower()) 103 | if dataset_cls is None: 104 | raise ValueError( 105 | f"Dataset with name: {name} not found in category: `datasets`. 
106 |         )
107 | 
108 |     return dataset_cls(
109 |         root,
110 |         norm=norm,
111 |         transform=transform,
112 |         return_target=config.data.return_target,
113 |     )
114 | 
115 | 
116 | def import_modules_into_registry():
117 |     logger.info("Importing modules into registry")
118 |     import datasets  # imported (with the ones below) only for their registration side effects
119 |     import losses
120 |     import models
121 |     import samplers
122 | 
123 | 
124 | def convert_to_np(obj):
125 |     obj = obj.permute(0, 2, 3, 1).contiguous()
126 |     obj = obj.detach().cpu().numpy()
127 | 
128 |     obj_list = []
129 |     for _, out in enumerate(obj):
130 |         obj_list.append(out)
131 |     return obj_list
132 | 
133 | 
134 | def normalize(obj):
135 |     B, C, H, W = obj.shape
136 |     for i in range(3):
137 |         channel_val = obj[:, i, :, :].view(B, -1)
138 |         channel_val -= channel_val.min(1, keepdim=True)[0]
139 |         channel_val /= (
140 |             channel_val.max(1, keepdim=True)[0] - channel_val.min(1, keepdim=True)[0]
141 |         )
142 |         channel_val = channel_val.view(B, H, W)
143 |         obj[:, i, :, :] = channel_val
144 |     return obj
145 | 
146 | 
147 | def save_as_images(obj, file_name="output", denorm=True):
148 |     # Saves predictions as png images (useful for sample generation)
149 |     if denorm:
150 |         # obj = normalize(obj)
151 |         obj = obj * 0.5 + 0.5
152 |     obj_list = convert_to_np(obj)
153 | 
154 |     for i, out in enumerate(obj_list):
155 |         out = (out * 255).clip(0, 255).astype(np.uint8)
156 |         img_out = Image.fromarray(out)
157 |         current_file_name = file_name + "_%d.png" % i
158 |         img_out.save(current_file_name, "png")
159 | 
160 | 
161 | def save_as_np(obj, file_name="output", denorm=True):
162 |     # Saves predictions directly as numpy arrays
163 |     if denorm:
164 |         obj = normalize(obj)
165 |     obj_list = convert_to_np(obj)
166 | 
167 |     for i, out in enumerate(obj_list):
168 |         current_file_name = file_name + "_%d.npy" % i
169 |         np.save(current_file_name, out)
-------------------------------------------------------------------------------- /scripts_psld/ablations/cond/afhqv2/sample_inpaint_psld.sh: --------------------------------------------------------------------------------
1 | python main/eval/inpaint.py +dataset=afhqv2/afhqv2128_es3sde \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.diffusion.data.name='afhqv2' \
4 | +dataset.diffusion.data.mask_path=\'/home/pandeyk1/datasets\' \
5 | dataset.diffusion.data.norm=True \
6 | dataset.diffusion.data.hflip=True \
7 | dataset.diffusion.model.score_fn.in_ch=6 \
8 | dataset.diffusion.model.score_fn.out_ch=6 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.2 \
14 | dataset.diffusion.model.sde.beta_min=8.0 \
15 | dataset.diffusion.model.sde.beta_max=8.0 \
16 | dataset.diffusion.model.sde.nu=4.01 \
17 | dataset.diffusion.model.sde.gamma=0.01 \
18 | dataset.diffusion.model.sde.kappa=0.04 \
19 | dataset.diffusion.model.sde.decomp_mode='lower' \
20 | dataset.diffusion.evaluation.seed=0 \
21 | dataset.diffusion.evaluation.sample_prefix='gpu' \
22 | dataset.diffusion.evaluation.path_prefix="1000" \
23 | dataset.diffusion.evaluation.devices=8 \
24 | dataset.diffusion.evaluation.batch_size=1 \
25 | dataset.diffusion.evaluation.stride_type='uniform' \
26 | dataset.diffusion.evaluation.sample_from='target' \
27 | dataset.diffusion.evaluation.workers=1 \
28 | 
dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/es3sde-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \ 29 | dataset.diffusion.evaluation.sampler.name="ip_em_sde" \ 30 | dataset.diffusion.evaluation.n_samples=32 \ 31 | dataset.diffusion.evaluation.n_discrete_steps=1000 \ 32 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/inpaint/es3sde_hsm_gamma=0.01_nu=4.01_afhqv2_continuous_sfn=ncsnpp_nres=2/dummy_samples/test/wild/\' 33 | -------------------------------------------------------------------------------- /scripts_psld/ablations/cond/afhqv2/sample_tclf_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/class_cond_sample.py +dataset=afhqv2/afhqv2128_es3sde \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \ 3 | dataset.diffusion.data.name='afhqv2' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.2 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.nu=4.01 \ 16 | dataset.diffusion.model.sde.gamma=0.01 \ 17 | dataset.diffusion.model.sde.kappa=0.04 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.evaluation.seed=0 \ 20 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 21 | dataset.diffusion.evaluation.path_prefix="1000" \ 22 | dataset.diffusion.evaluation.devices=8 \ 23 | dataset.diffusion.evaluation.batch_size=1 \ 24 | dataset.diffusion.evaluation.stride_type='uniform' \ 25 | dataset.diffusion.evaluation.sample_from='target' \ 26 | dataset.diffusion.evaluation.workers=1 \ 27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/es3sde-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \ 28 | dataset.diffusion.evaluation.sampler.name="cc_em_sde" \ 29 | dataset.diffusion.evaluation.n_samples=32 \ 30 | dataset.diffusion.evaluation.n_discrete_steps=1000 \ 31 | dataset.clf.model.clf_fn.in_ch=6 \ 32 | dataset.clf.model.clf_fn.nf=128 \ 33 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \ 34 | dataset.clf.model.clf_fn.num_res_blocks=4 \ 35 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \ 36 | dataset.clf.model.clf_fn.dropout=0.1 \ 37 | dataset.clf.model.clf_fn.n_cls=3 \ 38 | dataset.clf.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_afhqv2_continuous_clf=ncsnppclf/checkpoints/ncsnpp_clf-es3sde-tclf_gamma=0.01_nu=4.01_afhqv2_Feb27-epoch=299-loss=0.5960.ckpt\' \ 39 | dataset.clf.evaluation.clf_temp=10.0 \ 40 | dataset.clf.evaluation.label_to_sample=2 \ 41 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/cc/es3sde_hsm_gamma=0.01_nu=4.01_afhqv2_continuous_sfn=ncsnpp_nres=2/cc_others/\' 42 | -------------------------------------------------------------------------------- 
/scripts_psld/ablations/cond/afhqv2/train_tclf_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_clf.py +dataset=afhqv2/afhqv2128_es3sde \ 2 | dataset.clf.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \ 3 | dataset.clf.data.name='afhqv2' \ 4 | dataset.clf.data.norm=True \ 5 | dataset.clf.data.hflip=True \ 6 | dataset.clf.data.return_target=True \ 7 | dataset.clf.model.pl_module='tclf_wrapper' \ 8 | dataset.clf.model.clf_fn.in_ch=6 \ 9 | dataset.clf.model.clf_fn.nf=128 \ 10 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \ 11 | dataset.clf.model.clf_fn.num_res_blocks=4 \ 12 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \ 13 | dataset.clf.model.clf_fn.dropout=0.1 \ 14 | dataset.clf.model.clf_fn.n_cls=3 \ 15 | dataset.diffusion.model.sde.beta_min=8.0 \ 16 | dataset.diffusion.model.sde.beta_max=8.0 \ 17 | dataset.diffusion.model.sde.decomp_mode='lower' \ 18 | dataset.diffusion.model.sde.nu=4.01 \ 19 | dataset.diffusion.model.sde.gamma=0.01 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.clf.training.loss.name='tce_loss' \ 22 | dataset.clf.training.seed=0 \ 23 | dataset.clf.training.chkpt_interval=100 \ 24 | dataset.clf.training.fp16=False \ 25 | dataset.clf.training.batch_size=16 \ 26 | dataset.clf.training.epochs=2000 \ 27 | dataset.clf.training.accelerator='gpu' \ 28 | dataset.clf.training.devices=[1,3,5,6] \ 29 | dataset.clf.training.results_dir=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_afhqv2_continuous_clf=ncsnppclf/\' \ 30 | dataset.clf.training.workers=1 \ 31 | dataset.clf.training.chkpt_prefix=\"tclf_gamma=0.01_nu=4.01_afhqv2_Feb27\" -------------------------------------------------------------------------------- /scripts_psld/ablations/cond/cifar10/sample_tclf_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/class_cond_sample.py +dataset=cifar10/cifar10_es3sde \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.15 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.nu=4.01 \ 19 | dataset.diffusion.model.sde.gamma=0.01 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.model.sde.decomp_mode='lower' \ 22 | dataset.diffusion.evaluation.seed=0 \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.path_prefix="1000" \ 25 | dataset.diffusion.evaluation.devices=8 \ 26 | dataset.diffusion.evaluation.batch_size=16 \ 27 | dataset.diffusion.evaluation.stride_type='uniform' \ 28 | dataset.diffusion.evaluation.sample_from='target' \ 29 | dataset.diffusion.evaluation.workers=1 \ 30 | 
dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/sota/uncond/es3sde_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/es3sde-hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_3rdFeb-epoch=2249-loss=0.0148.ckpt\' \ 31 | dataset.diffusion.evaluation.sampler.name="cc_em_sde" \ 32 | dataset.diffusion.evaluation.n_samples=64 \ 33 | dataset.diffusion.evaluation.n_discrete_steps=1000 \ 34 | dataset.clf.model.clf_fn.in_ch=6 \ 35 | dataset.clf.model.clf_fn.nf=128 \ 36 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \ 37 | dataset.clf.model.clf_fn.num_res_blocks=4 \ 38 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \ 39 | dataset.clf.model.clf_fn.dropout=0.1 \ 40 | dataset.clf.model.clf_fn.n_cls=10 \ 41 | dataset.clf.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_cifar10_continuous_clf=ncsnppclf/checkpoints/ncsnpp_clf-es3sde-tclf_gamma=0.01_nu=4.01_cifar10_Feb27-epoch=999-loss=1.4715.ckpt\' \ 42 | dataset.clf.evaluation.clf_temp=5.0 \ 43 | dataset.clf.evaluation.label_to_sample=9 \ 44 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/cc/es3sde_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/dummy_samples/cc_cifar10_9\' \ 45 | -------------------------------------------------------------------------------- /scripts_psld/ablations/cond/cifar10/train_tclf_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_clf.py +dataset=cifar10/cifar10_es3sde \ 2 | dataset.clf.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.clf.data.name='cifar10' \ 4 | dataset.clf.data.norm=True \ 5 | dataset.clf.data.hflip=True \ 6 | dataset.clf.data.return_target=True \ 7 | dataset.clf.model.pl_module='tclf_wrapper' \ 8 | dataset.clf.model.clf_fn.in_ch=6 \ 9 | dataset.clf.model.clf_fn.nf=128 \ 10 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \ 11 | dataset.clf.model.clf_fn.num_res_blocks=4 \ 12 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \ 13 | dataset.clf.model.clf_fn.dropout=0.1 \ 14 | dataset.clf.model.clf_fn.n_cls=10 \ 15 | dataset.diffusion.model.sde.beta_min=8.0 \ 16 | dataset.diffusion.model.sde.beta_max=8.0 \ 17 | dataset.diffusion.model.sde.decomp_mode='lower' \ 18 | dataset.diffusion.model.sde.nu=4.0 \ 19 | dataset.diffusion.model.sde.gamma=0 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.clf.training.loss.name='tce_loss' \ 22 | dataset.clf.training.seed=0 \ 23 | dataset.clf.training.chkpt_interval=100 \ 24 | dataset.clf.training.fp16=False \ 25 | dataset.clf.training.batch_size=64 \ 26 | dataset.clf.training.epochs=2000 \ 27 | dataset.clf.training.accelerator='gpu' \ 28 | dataset.clf.training.devices=[1,3,5,6] \ 29 | dataset.clf.training.results_dir=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0_nu=4.0_cifar10_continuous_clf=ncsnppclf/\' \ 30 | dataset.clf.training.workers=1 \ 31 | dataset.clf.training.chkpt_prefix=\"tclf_gamma=0_nu=4.0_cifar10_Feb27\" -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/afhqv2/sample_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=afhqv2/afhqv2128_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \ 3 | dataset.diffusion.data.name='afhqv2' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | 
dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.2 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.nu=4.01 \ 16 | dataset.diffusion.model.sde.gamma=0.01 \ 17 | dataset.diffusion.model.sde.kappa=0.04 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.evaluation.seed=0 \ 20 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 21 | dataset.diffusion.evaluation.devices=[0] \ 22 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/psld_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/dummy_samples/\' \ 23 | dataset.diffusion.evaluation.batch_size=2 \ 24 | dataset.diffusion.evaluation.stride_type='quadratic' \ 25 | dataset.diffusion.evaluation.sample_from='target' \ 26 | dataset.diffusion.evaluation.workers=1 \ 27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/psld_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/psld-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \ 28 | dataset.diffusion.evaluation.sampler.name="em_sde" \ 29 | dataset.diffusion.evaluation.n_samples=128 \ 30 | dataset.diffusion.evaluation.n_discrete_steps=1000 31 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/afhqv2/train_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_sde.py +dataset=afhqv2/afhqv2128_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \ 3 | dataset.diffusion.data.name='afhqv2' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.pl_module='sde_wrapper' \ 7 | dataset.diffusion.model.score_fn.in_ch=6 \ 8 | dataset.diffusion.model.score_fn.out_ch=6 \ 9 | dataset.diffusion.model.score_fn.nf=128 \ 10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \ 11 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 13 | dataset.diffusion.model.score_fn.dropout=0.2 \ 14 | dataset.diffusion.model.sde.beta_min=8.0 \ 15 | dataset.diffusion.model.sde.beta_max=8.0 \ 16 | dataset.diffusion.model.sde.decomp_mode='lower' \ 17 | dataset.diffusion.model.sde.nu=4.01 \ 18 | dataset.diffusion.model.sde.gamma=0.01 \ 19 | dataset.diffusion.model.sde.kappa=0.04 \ 20 | dataset.diffusion.model.sde.numerical_eps=1e-9 \ 21 | dataset.diffusion.training.loss.name='psld_score_loss' \ 22 | dataset.diffusion.training.seed=0 \ 23 | dataset.diffusion.training.chkpt_interval=50 \ 24 | dataset.diffusion.training.mode='hsm' \ 25 | dataset.diffusion.training.fp16=False \ 26 | dataset.diffusion.training.use_ema=True \ 27 | dataset.diffusion.training.batch_size=8 \ 28 | dataset.diffusion.training.epochs=2000 \ 29 | dataset.diffusion.training.accelerator='gpu' \ 30 | dataset.diffusion.training.devices=8 \ 31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/\' \ 32 | dataset.diffusion.training.workers=1 \ 33 | 
dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23\" -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/celeba64/sample_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \ 3 | dataset.diffusion.data.name='celeba64' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=3 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.nu=4.0 \ 16 | dataset.diffusion.model.sde.gamma=0.0 \ 17 | dataset.diffusion.model.sde.kappa=0.04 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.evaluation.seed=0 \ 20 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 21 | dataset.diffusion.evaluation.devices=8 \ 22 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0_nu=4.0_celeba64_continuous_sfn=ncsnpp/speedvsquality/sscs_uniform/\' \ 23 | dataset.diffusion.evaluation.batch_size=6 \ 24 | dataset.diffusion.evaluation.stride_type='uniform' \ 25 | dataset.diffusion.evaluation.sample_from='target' \ 26 | dataset.diffusion.evaluation.workers=1 \ 27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0_nu=4.0_celeba64_continuous_sfn=ncsnpp/checkpoints/psld-hsm_ablation_gamma=0_nu=4.0_celeba64_24thJan23-epoch=199-loss=0.0036.ckpt\' \ 28 | dataset.diffusion.evaluation.sampler.name="sscs_sde" \ 29 | dataset.diffusion.evaluation.n_samples=10000 \ 30 | dataset.diffusion.evaluation.n_discrete_steps=50 \ 31 | dataset.diffusion.evaluation.path_prefix='50' 32 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/celeba64/train_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_sde.py +dataset=celeba64/celeba64_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \ 3 | dataset.diffusion.data.name='celeba64' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.pl_module='sde_wrapper' \ 7 | dataset.diffusion.model.score_fn.in_ch=6 \ 8 | dataset.diffusion.model.score_fn.out_ch=6 \ 9 | dataset.diffusion.model.score_fn.nf=128 \ 10 | dataset.diffusion.model.score_fn.ch_mult=[1,1,2,2,2] \ 11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 13 | dataset.diffusion.model.score_fn.dropout=0.1 \ 14 | dataset.diffusion.model.sde.beta_min=8.0 \ 15 | dataset.diffusion.model.sde.beta_max=8.0 \ 16 | dataset.diffusion.model.sde.decomp_mode='lower' \ 17 | dataset.diffusion.model.sde.nu=4.25 \ 18 | dataset.diffusion.model.sde.gamma=0.25 \ 19 | dataset.diffusion.model.sde.kappa=0.04 \ 20 | dataset.diffusion.training.loss.name='psld_score_loss' \ 21 | 
dataset.diffusion.training.seed=0 \ 22 | dataset.diffusion.training.mode='hsm' \ 23 | dataset.diffusion.training.fp16=False \ 24 | dataset.diffusion.training.use_ema=True \ 25 | dataset.diffusion.training.batch_size=32 \ 26 | dataset.diffusion.training.epochs=200 \ 27 | dataset.diffusion.training.accelerator='gpu' \ 28 | dataset.diffusion.training.devices=4 \ 29 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0.01_nu=4.01_celeba64_continuous_sfn=ncsnpp/\' \ 30 | dataset.diffusion.training.workers=1 \ 31 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.01_nu=4.01_celeba64_27thJan23\" -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/sample_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=3 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.nu=4.0 \ 16 | dataset.diffusion.model.sde.gamma=0 \ 17 | dataset.diffusion.model.sde.kappa=0.04 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.model.sde.numerical_eps=1e-9 \ 20 | dataset.diffusion.evaluation.seed=0 \ 21 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 22 | dataset.diffusion.evaluation.devices=8 \ 23 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/speedvsquality/sscs_uniform/\' \ 24 | dataset.diffusion.evaluation.batch_size=16 \ 25 | dataset.diffusion.evaluation.stride_type='uniform' \ 26 | dataset.diffusion.evaluation.sample_from='target' \ 27 | dataset.diffusion.evaluation.workers=1 \ 28 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/psld-hsm_ablation_scorem_cifar10_5thJan23-epoch=1999-loss=0.0142.ckpt\' \ 29 | dataset.diffusion.evaluation.sampler.name="sscs_sde" \ 30 | dataset.diffusion.evaluation.n_samples=10000 \ 31 | dataset.diffusion.evaluation.n_discrete_steps=1000 \ 32 | dataset.diffusion.evaluation.path_prefix="1000" 33 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/sample_uncond_psld_ode.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=3 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | 
dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.nu=4.0 \ 16 | dataset.diffusion.model.sde.gamma=0 \ 17 | dataset.diffusion.model.sde.kappa=0.04 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.model.sde.numerical_eps=1e-9 \ 20 | dataset.diffusion.evaluation.seed=0 \ 21 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 22 | dataset.diffusion.evaluation.devices=8 \ 23 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/ode_exps/solver=rk45/\' \ 24 | dataset.diffusion.evaluation.batch_size=16 \ 25 | dataset.diffusion.evaluation.sample_from='target' \ 26 | dataset.diffusion.evaluation.workers=1 \ 27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/psld-hsm_ablation_scorem_cifar10_5thJan23-epoch=1999-loss=0.0142.ckpt\' \ 28 | dataset.diffusion.evaluation.sampler.name="bb_ode" \ 29 | +dataset.diffusion.evaluation.sampler.solver="RK45" \ 30 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \ 31 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \ 32 | dataset.diffusion.evaluation.n_samples=10000 \ 33 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-5\" 34 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/sample_uncond_vpsde.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=cifar10/cifar10_vpsde \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=3 \ 7 | dataset.diffusion.model.score_fn.out_ch=3 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=0.1 \ 14 | dataset.diffusion.model.sde.beta_max=20 \ 15 | dataset.diffusion.evaluation.seed=0 \ 16 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 17 | dataset.diffusion.evaluation.devices=8 \ 18 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/speedvsquality/em_quadratic/\' \ 19 | dataset.diffusion.evaluation.batch_size=16 \ 20 | dataset.diffusion.evaluation.stride_type='quadratic' \ 21 | dataset.diffusion.evaluation.sample_from='target' \ 22 | dataset.diffusion.evaluation.workers=1 \ 23 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/vpsde-dsm_ablation_cifar10_5thJan23-epoch=1999-loss=0.0259.ckpt\' \ 24 | dataset.diffusion.evaluation.sampler.name="em_sde" \ 25 | dataset.diffusion.evaluation.n_samples=10000 \ 26 | dataset.diffusion.evaluation.n_discrete_steps=1000 \ 27 | dataset.diffusion.evaluation.path_prefix="1000" 28 | 
-------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/sample_uncond_vpsde_ode.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=cifar10/cifar10_vpsde \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=3 \ 7 | dataset.diffusion.model.score_fn.out_ch=3 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=0.1 \ 14 | dataset.diffusion.model.sde.beta_max=20 \ 15 | dataset.diffusion.evaluation.seed=0 \ 16 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 17 | dataset.diffusion.evaluation.devices=8 \ 18 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/ode_exps/solver=rk45/\' \ 19 | dataset.diffusion.evaluation.batch_size=16 \ 20 | dataset.diffusion.evaluation.sample_from='target' \ 21 | dataset.diffusion.evaluation.workers=1 \ 22 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/vpsde-dsm_ablation_cifar10_5thJan23-epoch=1999-loss=0.0259.ckpt\' \ 23 | dataset.diffusion.evaluation.sampler.name="bb_ode" \ 24 | +dataset.diffusion.evaluation.sampler.solver="RK45" \ 25 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \ 26 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \ 27 | dataset.diffusion.evaluation.n_samples=10000 \ 28 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-5\" 29 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/train_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_sde.py +dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.sde.beta_min=8.0 \ 14 | dataset.diffusion.model.sde.beta_max=8.0 \ 15 | dataset.diffusion.model.sde.decomp_mode='lower' \ 16 | dataset.diffusion.model.sde.nu=4.02 \ 17 | dataset.diffusion.model.sde.gamma=0.02 \ 18 | dataset.diffusion.model.sde.kappa=0.04 \ 19 | dataset.diffusion.training.seed=0 \ 20 | dataset.diffusion.training.chkpt_interval=50 \ 21 | dataset.diffusion.training.mode='hsm' \ 22 | dataset.diffusion.training.fp16=False \ 23 | dataset.diffusion.training.use_ema=True \ 24 | dataset.diffusion.training.batch_size=32 \ 25 | dataset.diffusion.training.epochs=2000 \ 26 | dataset.diffusion.training.accelerator='gpu' \ 27 | dataset.diffusion.training.devices=4 \ 28 | 
dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0.02_nu=4.02_decomp=lower_cifar10_continuous_sfn=ncsnpp/\' \ 29 | dataset.diffusion.training.workers=1 \ 30 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.02_nu=4.02_decomp=lower_cifar10_14thFeb23\" 31 | -------------------------------------------------------------------------------- /scripts_psld/ablations/uncond/cifar10/train_uncond_vpsde.sh: -------------------------------------------------------------------------------- 1 | # CIFAR-10 2 | python main/train_sde.py +dataset=cifar10/cifar10_vpsde \ 3 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 4 | dataset.diffusion.data.name='cifar10' \ 5 | dataset.diffusion.data.norm=True \ 6 | dataset.diffusion.data.hflip=True \ 7 | dataset.diffusion.model.score_fn.in_ch=3 \ 8 | dataset.diffusion.model.score_fn.out_ch=3 \ 9 | dataset.diffusion.model.score_fn.nf=128 \ 10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 13 | dataset.diffusion.model.score_fn.dropout=0.1 \ 14 | dataset.diffusion.training.seed=0 \ 15 | dataset.diffusion.training.fp16=False \ 16 | dataset.diffusion.training.use_ema=True \ 17 | dataset.diffusion.training.batch_size=32 \ 18 | dataset.diffusion.training.epochs=2000 \ 19 | dataset.diffusion.training.accelerator='gpu' \ 20 | dataset.diffusion.training.devices=[0,1,2,3] \ 21 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/vpsde_cifar10_continuous_sfn=ncsnpp_testnccl/\' \ 22 | dataset.diffusion.training.workers=1 \ 23 | dataset.diffusion.training.chkpt_prefix="dsm_ablation_cifar10_5thJan23" -------------------------------------------------------------------------------- /scripts_psld/fid.sh: -------------------------------------------------------------------------------- 1 | fidelity --gpu 0 --fid --input1 /home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/1000/batch/ --input2 /home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/1000/images/ 2 | -------------------------------------------------------------------------------- /scripts_psld/sota/cond/afhqv2/sample_inpaint_psld.sh: -------------------------------------------------------------------------------- 1 | ulimit -n 64000 2 | python main/eval/inpaint.py +dataset=afhqv2/afhqv2128_psld \ 3 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \ 4 | dataset.diffusion.data.name='afhqv2' \ 5 | +dataset.diffusion.data.mask_path=\'/home/pandeyk1/datasets\' \ 6 | dataset.diffusion.data.norm=True \ 7 | dataset.diffusion.data.hflip=True \ 8 | dataset.diffusion.model.score_fn.in_ch=6 \ 9 | dataset.diffusion.model.score_fn.out_ch=3 \ 10 | dataset.diffusion.model.score_fn.nf=160 \ 11 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,3,3] \ 12 | dataset.diffusion.model.score_fn.num_res_blocks=2 \ 13 | dataset.diffusion.model.score_fn.attn_resolutions=[8,16] \ 14 | dataset.diffusion.model.score_fn.dropout=0.2 \ 15 | dataset.diffusion.model.sde.beta_min=8.0 \ 16 | dataset.diffusion.model.sde.beta_max=8.0 \ 17 | dataset.diffusion.model.sde.nu=4.0 \ 18 | dataset.diffusion.model.sde.gamma=0 \ 19 | dataset.diffusion.model.sde.kappa=0.04 \ 20 | dataset.diffusion.model.sde.decomp_mode='lower' \ 21 | dataset.diffusion.evaluation.seed=0
\ 22 | dataset.diffusion.evaluation.denoise=True \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.path_prefix="1000" \ 25 | dataset.diffusion.evaluation.devices=8 \ 26 | dataset.diffusion.evaluation.batch_size=16 \ 27 | dataset.diffusion.evaluation.stride_type='quadratic' \ 28 | dataset.diffusion.evaluation.sample_from='target' \ 29 | dataset.diffusion.evaluation.workers=1 \ 30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0_nu=4.0_afhq128_continuous_sfn=ddpmpp/checkpoints/cached_chkpts/es3sde-hsm_sota_gamma=0_nu=4.0_afhq128_25May23_nf=160_chmult=12233-epoch=1999-loss=0.0011.ckpt\' \ 31 | dataset.diffusion.evaluation.sampler.name="ip_em_sde" \ 32 | dataset.diffusion.evaluation.n_samples=50000 \ 33 | dataset.diffusion.evaluation.n_discrete_steps=250 \ 34 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0_nu=4.0_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/\' 35 | -------------------------------------------------------------------------------- /scripts_psld/sota/uncond/celeba64/sample_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \ 3 | dataset.diffusion.data.name='celeba64' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.nu=4.005 \ 19 | dataset.diffusion.model.sde.gamma=0.005 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.model.sde.decomp_mode='lower' \ 22 | dataset.diffusion.evaluation.seed=0 \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.devices=7 \ 25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/samples_50k/epoch=500_emquad/\' \ 26 | dataset.diffusion.evaluation.batch_size=4 \ 27 | dataset.diffusion.evaluation.stride_type='quadratic' \ 28 | dataset.diffusion.evaluation.sample_from='target' \ 29 | dataset.diffusion.evaluation.workers=0 \ 30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/checkpoints/psld-hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23-epoch=499-loss=0.0041.ckpt\' \ 31 | dataset.diffusion.evaluation.sampler.name="em_sde" \ 32 | dataset.diffusion.evaluation.n_samples=50000 \ 33 | dataset.diffusion.evaluation.n_discrete_steps=250 \ 34 | dataset.diffusion.evaluation.path_prefix="250" -------------------------------------------------------------------------------- 
/scripts_psld/sota/uncond/celeba64/sample_uncond_psld_ode.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \ 3 | dataset.diffusion.data.name='celeba64' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.1 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.nu=4.005 \ 19 | dataset.diffusion.model.sde.gamma=0.005 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.model.sde.decomp_mode='lower' \ 22 | dataset.diffusion.evaluation.seed=0 \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.devices=7 \ 25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/samples_10k/epoch=500_ode/\' \ 26 | dataset.diffusion.evaluation.batch_size=4 \ 27 | dataset.diffusion.evaluation.sample_from='target' \ 28 | dataset.diffusion.evaluation.workers=0 \ 29 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/checkpoints/psld-hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23-epoch=499-loss=0.0041.ckpt\' \ 30 | dataset.diffusion.evaluation.sampler.name="bb_ode" \ 31 | +dataset.diffusion.evaluation.sampler.solver="RK45" \ 32 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \ 33 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \ 34 | dataset.diffusion.evaluation.n_samples=10000 \ 35 | dataset.diffusion.evaluation.path_prefix=\'tol=1e-5\' -------------------------------------------------------------------------------- /scripts_psld/sota/uncond/celeba64/train_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_sde.py +dataset=celeba64/celeba64_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \ 3 | dataset.diffusion.data.name='celeba64' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.pl_module='sde_wrapper' \ 7 | dataset.diffusion.model.score_fn.in_ch=6 \ 8 | dataset.diffusion.model.score_fn.out_ch=6 \ 9 | dataset.diffusion.model.score_fn.nf=128 \ 10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \ 11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \ 12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 13 | dataset.diffusion.model.score_fn.dropout=0.1 \ 14 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 15 | dataset.diffusion.model.score_fn.fir=True \ 16 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 17 | dataset.diffusion.model.sde.decomp_mode='lower' \ 18 | 
dataset.diffusion.model.sde.nu=4.005 \ 19 | dataset.diffusion.model.sde.gamma=0.005 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.training.loss.name='psld_score_loss' \ 22 | dataset.diffusion.training.seed=0 \ 23 | dataset.diffusion.training.chkpt_interval=25 \ 24 | dataset.diffusion.training.mode='hsm' \ 25 | dataset.diffusion.training.fp16=False \ 26 | dataset.diffusion.training.use_ema=True \ 27 | dataset.diffusion.training.batch_size=16 \ 28 | dataset.diffusion.training.epochs=500 \ 29 | dataset.diffusion.training.accelerator='gpu' \ 30 | dataset.diffusion.training.devices=8 \ 31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/\' \ 32 | dataset.diffusion.training.workers=1 \ 33 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23\" -------------------------------------------------------------------------------- /scripts_psld/sota/uncond/cifar10/sample_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.15 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.nu=4.02 \ 19 | dataset.diffusion.model.sde.gamma=0.02 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.model.sde.decomp_mode='lower' \ 22 | dataset.diffusion.evaluation.seed=0 \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.devices=8 \ 25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/speedvsquality_50k/em_uniform/\' \ 26 | dataset.diffusion.evaluation.batch_size=16 \ 27 | dataset.diffusion.evaluation.stride_type='uniform' \ 28 | dataset.diffusion.evaluation.sample_from='target' \ 29 | dataset.diffusion.evaluation.workers=1 \ 30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/psld-hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_17thFeb-epoch=2349-loss=0.0201.ckpt\' \ 31 | dataset.diffusion.evaluation.sampler.name="em_sde" \ 32 | dataset.diffusion.evaluation.n_samples=50000 \ 33 | dataset.diffusion.evaluation.n_discrete_steps=50 \ 34 | dataset.diffusion.evaluation.path_prefix="50" 35 | -------------------------------------------------------------------------------- /scripts_psld/sota/uncond/cifar10/sample_uncond_psld_ode.sh: -------------------------------------------------------------------------------- 1 | python main/eval/sample.py 
+dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.15 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.nu=4.02 \ 19 | dataset.diffusion.model.sde.gamma=0.02 \ 20 | dataset.diffusion.model.sde.kappa=0.04 \ 21 | dataset.diffusion.model.sde.decomp_mode='lower' \ 22 | dataset.diffusion.evaluation.seed=0 \ 23 | dataset.diffusion.evaluation.sample_prefix='gpu' \ 24 | dataset.diffusion.evaluation.devices=8 \ 25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/samples_50k_ode/epoch=2350/\' \ 26 | dataset.diffusion.evaluation.batch_size=16 \ 27 | dataset.diffusion.evaluation.sample_from='target' \ 28 | dataset.diffusion.evaluation.workers=1 \ 29 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/psld-hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_17thFeb-epoch=2349-loss=0.0201.ckpt\' \ 30 | dataset.diffusion.evaluation.sampler.name="bb_ode" \ 31 | +dataset.diffusion.evaluation.sampler.solver="RK45" \ 32 | +dataset.diffusion.evaluation.sampler.rtol=1e-4 \ 33 | +dataset.diffusion.evaluation.sampler.atol=1e-4 \ 34 | dataset.diffusion.evaluation.n_samples=50000 \ 35 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-4\" 36 | -------------------------------------------------------------------------------- /scripts_psld/sota/uncond/cifar10/train_uncond_psld.sh: -------------------------------------------------------------------------------- 1 | python main/train_sde.py +dataset=cifar10/cifar10_psld \ 2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \ 3 | dataset.diffusion.data.name='cifar10' \ 4 | dataset.diffusion.data.norm=True \ 5 | dataset.diffusion.data.hflip=True \ 6 | dataset.diffusion.model.score_fn.in_ch=6 \ 7 | dataset.diffusion.model.score_fn.out_ch=6 \ 8 | dataset.diffusion.model.score_fn.nf=128 \ 9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \ 10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \ 11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \ 12 | dataset.diffusion.model.score_fn.dropout=0.15 \ 13 | dataset.diffusion.model.score_fn.progressive_input='residual' \ 14 | dataset.diffusion.model.score_fn.fir=True \ 15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \ 16 | dataset.diffusion.model.sde.beta_min=8.0 \ 17 | dataset.diffusion.model.sde.beta_max=8.0 \ 18 | dataset.diffusion.model.sde.decomp_mode='lower' \ 19 | dataset.diffusion.model.sde.nu=4.01 \ 20 | dataset.diffusion.model.sde.gamma=0.01 \ 21 | dataset.diffusion.model.sde.kappa=0.04 \ 22 | dataset.diffusion.training.seed=0 \ 23 | 
dataset.diffusion.training.chkpt_interval=50 \ 24 | dataset.diffusion.training.mode='hsm' \ 25 | dataset.diffusion.training.fp16=False \ 26 | dataset.diffusion.training.use_ema=True \ 27 | dataset.diffusion.training.batch_size=16 \ 28 | dataset.diffusion.training.epochs=2500 \ 29 | dataset.diffusion.training.accelerator='gpu' \ 30 | dataset.diffusion.training.devices=8 \ 31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/\' \ 32 | dataset.diffusion.training.workers=1 \ 33 | dataset.diffusion.training.chkpt_prefix=\"hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_3rdFeb\" 34 | --------------------------------------------------------------------------------