├── .gitignore
├── LICENSE
├── Pipfile
├── Pipfile.lock
├── README.md
├── assets
│   └── main.png
├── main
│   ├── __init__.py
│   ├── callbacks.py
│   ├── configs
│   │   └── dataset
│   │       ├── afhqv2
│   │       │   └── afhqv2128_psld.yaml
│   │       ├── celeba64
│   │       │   ├── celeba64_psld.yaml
│   │       │   └── celeba64_vpsde.yaml
│   │       └── cifar10
│   │           ├── cifar10_psld.yaml
│   │           └── cifar10_vpsde.yaml
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── afhq.py
│   │   ├── celeba.py
│   │   ├── celebahq.py
│   │   ├── cifar10.py
│   │   ├── inpaint.py
│   │   └── latent.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── class_cond_sample.py
│   │   ├── inpaint.py
│   │   └── sample.py
│   ├── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── clf_wrapper.py
│   │   ├── score_fn
│   │   │   ├── __init__.py
│   │   │   └── song_sde
│   │   │       ├── __init__.py
│   │   │       ├── layers.py
│   │   │       ├── layerspp.py
│   │   │       ├── ncsnpp.py
│   │   │       ├── ncsnpp_clf.py
│   │   │       ├── normalization.py
│   │   │       ├── op
│   │   │       │   ├── __init__.py
│   │   │       │   ├── fused_act.py
│   │   │       │   ├── fused_bias_act.cpp
│   │   │       │   ├── fused_bias_act_kernel.cu
│   │   │       │   ├── upfirdn2d.cpp
│   │   │       │   ├── upfirdn2d.py
│   │   │       │   └── upfirdn2d_kernel.cu
│   │   │       ├── up_or_down_sampling.py
│   │   │       └── utils.py
│   │   ├── sde
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── psld.py
│   │   │   └── vpsde.py
│   │   └── wrapper.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── ode.py
│   │   └── sde.py
│   ├── train_clf.py
│   ├── train_sde.py
│   └── util.py
└── scripts_psld
    ├── ablations
    │   ├── cond
    │   │   ├── afhqv2
    │   │   │   ├── sample_inpaint_psld.sh
    │   │   │   ├── sample_tclf_psld.sh
    │   │   │   └── train_tclf_psld.sh
    │   │   └── cifar10
    │   │       ├── sample_tclf_psld.sh
    │   │       └── train_tclf_psld.sh
    │   └── uncond
    │       ├── afhqv2
    │       │   ├── sample_uncond_psld.sh
    │       │   └── train_uncond_psld.sh
    │       ├── celeba64
    │       │   ├── sample_uncond_psld.sh
    │       │   └── train_uncond_psld.sh
    │       └── cifar10
    │           ├── sample_uncond_psld.sh
    │           ├── sample_uncond_psld_ode.sh
    │           ├── sample_uncond_vpsde.sh
    │           ├── sample_uncond_vpsde_ode.sh
    │           ├── train_uncond_psld.sh
    │           └── train_uncond_vpsde.sh
    ├── fid.sh
    └── sota
        ├── cond
        │   └── afhqv2
        │       └── sample_inpaint_psld.sh
        └── uncond
            ├── celeba64
            │   ├── sample_uncond_psld.sh
            │   ├── sample_uncond_psld_ode.sh
            │   └── train_uncond_psld.sh
            └── cifar10
                ├── sample_uncond_psld.sh
                ├── sample_uncond_psld_ode.sh
                └── train_uncond_psld.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | /scripts_test/
132 | /scripts_itsam/
133 | /scripts_slurm/
134 | /outputs/
135 | analysis/
136 | download.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mandt-lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | url = "https://pypi.org/simple"
3 | verify_ssl = true
4 | name = "pypi"
5 |
6 | [packages]
7 | torch = "==1.13.1+cu116"
8 | torchdiffeq = "==0.2.3"
9 | torch-fidelity = "==0.3.0"
10 | torchmetrics = "==0.11.0"
11 | torchvision = "==0.14.1+cu116"
12 | tqdm = "==4.64.1"
13 | wandb = "==0.13.5"
14 | scipy = "==1.9.3"
15 | scikit-learn = "==1.1.3"
16 | Pillow = "==9.3.0"
17 | numpy = "==1.23.5"
18 | matplotlib = "==3.6.2"
19 | hydra-core = "==1.2.0"
20 | pytorch-lightning = "==1.9.0"
21 | ninja = "==1.11.1"
22 |
23 | [dev-packages]
24 |
25 | [requires]
26 | python_version = "3.8"
27 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "7f7606f08e0544d8d012ef4d097dabdd6df6843a28793eb6551245d4b2db4242"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.8"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {},
19 | "develop": {}
20 | }
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Phase Space Langevin Diffusion (ICCV'23, Oral Presentation)
2 |
3 |
7 |
8 |
9 | Official Implementation of the paper: A Complete Recipe for Diffusion Generative Models
10 |
11 | ## Overview
12 |
13 |
14 |
15 |
16 | We propose a complete recipe for constructing novel diffusion processes that are guaranteed to converge to a specified stationary distribution. This offers two benefits:
17 |
18 | - First, a principled parameterization for constructing diffusion models allows for more flexible processes that need not rely on physical intuition (as, e.g., CLD does).
19 | - Second, given a diffusion process, our parameterization can be used to verify whether the process converges to a specified stationary distribution (see the sketch below).
20 |
21 |
22 | To instantiate this recipe, we propose a new diffusion model: Phase Space Langevin Diffusion (PSLD), which outperforms diffusion models like VP-SDE and CLD, and achieves excellent sample quality on standard image synthesis benchmarks like CIFAR-10 (FID: 2.10) and CelebA-64 (FID: 2.01). Like other diffusion models, PSLD supports classifier(-free) guidance out of the box.
23 |
24 | ## Code Overview
25 |
26 | This repo uses [PyTorch Lightning](https://www.pytorchlightning.ai/) for training and [Hydra](https://hydra.cc/docs/intro/) for config management, so basic familiarity with both tools is expected. Please clone the repo and use `PSLD` as the working directory for all downstream tasks like dependency setup, training, and inference.
27 |
28 | ## Dependency Setup
29 |
30 | We use `pipenv` for project-level dependency management. Simply [install](https://pipenv.pypa.io/en/latest/#install-pipenv-today) `pipenv` and run the following command:
31 |
32 | ```
33 | pipenv install
34 | ```
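Once the install finishes, activate the environment with `pipenv shell`, or prefix individual commands with `pipenv run`.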
35 |
36 | ## Config Management
37 | We manage Hydra configurations separately for each benchmark/dataset used in this work. All configs are in the `main/configs` directory, which contains one subfolder per dataset; each subfolder holds the configs for the diffusion models used with that dataset.
38 |
39 |
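Since the configs are plain YAML, they can also be loaded and inspected programmatically. A minimal sketch using OmegaConf, with the CIFAR-10 PSLD config as an example:

```python
from omegaconf import OmegaConf

# Load the raw Hydra config for the CIFAR-10 PSLD model
cfg = OmegaConf.load("main/configs/dataset/cifar10/cifar10_psld.yaml")

# Fields set to ??? (e.g. diffusion.data.root) are mandatory
# and must be supplied on the command line at launch time.
print(cfg.diffusion.model.sde.name)       # psld
print(cfg.diffusion.model.score_fn.name)  # ncsnpp
```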
60 |
61 | ## Training and Evaluation
62 |
63 | Below, we include sample training and inference commands for CIFAR-10. All scripts used in this work are in the `/scripts_psld/` directory.
64 |
65 | - Training a PSLD model on CIFAR-10:
66 |
67 | ```shell script
68 | python main/train_sde.py +dataset=cifar10/cifar10_psld \
69 | dataset.diffusion.data.root=\'/path/to/cifar10/\' \
70 | dataset.diffusion.data.name='cifar10' \
71 | dataset.diffusion.data.norm=True \
72 | dataset.diffusion.data.hflip=True \
73 | dataset.diffusion.model.score_fn.in_ch=6 \
74 | dataset.diffusion.model.score_fn.out_ch=6 \
75 | dataset.diffusion.model.score_fn.nf=128 \
76 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
77 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
78 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
79 | dataset.diffusion.model.score_fn.dropout=0.15 \
80 | dataset.diffusion.model.score_fn.progressive_input='residual' \
81 | dataset.diffusion.model.score_fn.fir=True \
82 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
83 | dataset.diffusion.model.sde.beta_min=8.0 \
84 | dataset.diffusion.model.sde.beta_max=8.0 \
85 | dataset.diffusion.model.sde.decomp_mode='lower' \
86 | dataset.diffusion.model.sde.nu=4.01 \
87 | dataset.diffusion.model.sde.gamma=0.01 \
88 | dataset.diffusion.model.sde.kappa=0.04 \
89 | dataset.diffusion.training.seed=0 \
90 | dataset.diffusion.training.chkpt_interval=50 \
91 | dataset.diffusion.training.mode='hsm' \
92 | dataset.diffusion.training.fp16=False \
93 | dataset.diffusion.training.use_ema=True \
94 | dataset.diffusion.training.batch_size=16 \
95 | dataset.diffusion.training.epochs=2500 \
96 | dataset.diffusion.training.accelerator='gpu' \
97 | dataset.diffusion.training.devices=8 \
98 | dataset.diffusion.training.results_dir=\'/path/to/logdir/\' \
99 | dataset.diffusion.training.workers=1
100 | ```
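Checkpoints are written under `results_dir` at the configured `chkpt_interval` (in epochs); an interrupted run should be resumable by pointing `dataset.diffusion.training.restore_path` to a saved checkpoint.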
101 |
102 | - Generating CIFAR-10 samples from a trained PSLD model:
103 |
104 | ```shell script
105 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \
106 | dataset.diffusion.model.score_fn.in_ch=6 \
107 | dataset.diffusion.model.score_fn.out_ch=6 \
108 | dataset.diffusion.model.score_fn.nf=128 \
109 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
110 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
111 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
112 | dataset.diffusion.model.score_fn.dropout=0.15 \
113 | dataset.diffusion.model.score_fn.progressive_input='residual' \
114 | dataset.diffusion.model.score_fn.fir=True \
115 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
116 | dataset.diffusion.model.sde.beta_min=8.0 \
117 | dataset.diffusion.model.sde.beta_max=8.0 \
118 | dataset.diffusion.model.sde.nu=4.01 \
119 | dataset.diffusion.model.sde.gamma=0.01 \
120 | dataset.diffusion.model.sde.kappa=0.04 \
121 | dataset.diffusion.model.sde.decomp_mode='lower' \
122 | dataset.diffusion.evaluation.seed=0 \
123 | dataset.diffusion.evaluation.sample_prefix='some_prefix_for_sample_names' \
124 | dataset.diffusion.evaluation.devices=8 \
125 | dataset.diffusion.evaluation.save_path=\'/path/to/generated/samples/\' \
126 | dataset.diffusion.evaluation.batch_size=16 \
127 | dataset.diffusion.evaluation.stride_type='uniform' \
128 | dataset.diffusion.evaluation.sample_from='target' \
129 | dataset.diffusion.evaluation.workers=1 \
130 | dataset.diffusion.evaluation.chkpt_path=\'/path/to/pretrained/chkpt.ckpt\' \
131 | dataset.diffusion.evaluation.sampler.name="em_sde" \
132 | dataset.diffusion.evaluation.n_samples=50000 \
133 | dataset.diffusion.evaluation.n_discrete_steps=50 \
134 | dataset.diffusion.evaluation.path_prefix="50"
135 | ```
136 | We evaluate sample quality using FID scores, computed with the `torch-fidelity` package.
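For instance, once the samples are on disk, FID against the CIFAR-10 training set can be computed with the `torch-fidelity` Python API as follows (a minimal sketch; the sample directory is a placeholder matching the `save_path`/`path_prefix` layout used by the image writer):

```python
import torch_fidelity

# Compare generated samples against the registered CIFAR-10 training set
metrics = torch_fidelity.calculate_metrics(
    input1="/path/to/generated/samples/50/images",  # placeholder path
    input2="cifar10-train",  # built-in registered input in torch-fidelity
    cuda=True,
    fid=True,
)
print(metrics["frechet_inception_distance"])
```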
137 |
138 | ## Pretrained Checkpoints
139 | Pre-trained PSLD checkpoints can be found [here](https://personalmicrosoftsoftware-my.sharepoint.com/:f:/g/personal/pandeyk1_personalmicrosoftsoftware_uci_edu/EhehC1yAF1pKv1Vp2rZpMTsBG9v1dyjmvmGzkJSWkjPsfQ?e=fSTVao).
140 |
141 | ## Citation
142 | If you find the code useful for your research, please consider citing our ICCV'23 paper:
143 |
144 | ```bib
145 | @InProceedings{Pandey_2023_ICCV,
146 | author = {Pandey, Kushagra and Mandt, Stephan},
147 | title = {A Complete Recipe for Diffusion Generative Models},
148 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
149 | month = {October},
150 | year = {2023},
151 | pages = {4261-4272}
152 | }
153 | ```
154 |
155 | ## License
156 | See the LICENSE file.
157 |
--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/assets/main.png
--------------------------------------------------------------------------------
/main/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/__init__.py
--------------------------------------------------------------------------------
/main/callbacks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Sequence, Union
4 |
5 | import torch
6 | from pytorch_lightning import LightningModule, Trainer
7 | from pytorch_lightning.callbacks import BasePredictionWriter
8 | from pytorch_lightning.callbacks import Callback
9 | from torch import Tensor
10 | from torch.nn import Module
11 |
12 | from util import save_as_images, save_as_np
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class EMAWeightUpdate(Callback):
18 | """EMA weight update
19 | Your model should have:
20 | - ``self.online_network``
21 | - ``self.target_network``
22 | Updates the target_network params using an exponential moving average update rule weighted by tau.
23 | BYOL claims this keeps the online_network from collapsing.
24 | .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step
25 | Example::
26 | # model must have 2 attributes
27 | model = Model()
28 | model.online_network = ...
29 | model.target_network = ...
30 | trainer = Trainer(callbacks=[EMAWeightUpdate()])
31 | """
32 |
33 | def __init__(self, tau: float = 0.9999):
34 | """
35 | Args:
36 | tau: EMA decay rate
37 | """
38 | super().__init__()
39 | self.tau = tau
40 | logger.info(f"Setup EMA callback with tau: {self.tau}")
41 |
42 | def on_train_batch_end(
43 | self,
44 | trainer: Trainer,
45 | pl_module: LightningModule,
46 | outputs: Sequence,
47 | batch: Sequence,
48 | batch_idx: int,
49 | ) -> None:
50 | # get networks
51 | online_net = pl_module.score_fn
52 | target_net = pl_module.ema_score_fn
53 |
54 | # update weights
55 | self.update_weights(online_net, target_net)
56 |
57 | def update_weights(
58 | self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]
59 | ) -> None:
60 | # apply MA weight update
61 | with torch.no_grad():
62 | for targ, src in zip(target_net.parameters(), online_net.parameters()):
63 | targ.mul_(self.tau).add_(src, alpha=1 - self.tau)
64 |
65 |
66 | # TODO: Add Support for saving momentum images
67 | class SimpleImageWriter(BasePredictionWriter):
68 | """Pytorch Lightning Callback for writing a batch of images to disk."""
69 |
70 | def __init__(
71 | self,
72 | output_dir,
73 | write_interval,
74 | sample_prefix="",
75 | path_prefix="",
76 | save_mode="image",
77 | is_norm=True,
78 | is_augmented=True,
79 | ):
80 | super().__init__(write_interval)
81 | self.output_dir = output_dir
82 | self.sample_prefix = sample_prefix
83 | self.path_prefix = path_prefix
84 | self.is_norm = is_norm
85 | self.is_augmented = is_augmented
86 | self.save_fn = save_as_images if save_mode == "image" else save_as_np
87 |
88 | def write_on_batch_end(
89 | self,
90 | trainer,
91 | pl_module,
92 | prediction,
93 | batch_indices,
94 | batch,
95 | batch_idx,
96 | dataloader_idx,
97 | ):
98 | rank = pl_module.global_rank
99 |
100 | # Write output images
101 | # NOTE: We need to use gpu rank during saving to prevent
102 | # processes from overwriting images
103 | samples = prediction.cpu()
104 |
105 | # Ignore momentum states if the SDE is augmented
106 | if self.is_augmented:
107 | samples, _ = torch.chunk(samples, 2, dim=1)
108 |
109 | # Setup save dirs
110 | if self.path_prefix != "":
111 | base_save_path = os.path.join(self.output_dir, str(self.path_prefix))
112 | else:
113 | base_save_path = self.output_dir
114 | img_save_path = os.path.join(base_save_path, "images")
115 | os.makedirs(img_save_path, exist_ok=True)
116 |
117 | # Save images
118 | self.save_fn(
119 | samples,
120 | file_name=os.path.join(
121 | img_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}"
122 | ),
123 | denorm=self.is_norm,
124 | )
125 |
126 |
127 | class InpaintingImageWriter(BasePredictionWriter):
128 | """Pytorch Lightning Callback for writing a batch of images to disk.
129 | Specifically adapted for Image inpainting.
130 | """
131 |
132 | def __init__(
133 | self,
134 | output_dir,
135 | write_interval,
136 | eval_mode="sample",
137 | sample_prefix="",
138 | path_prefix="",
139 | save_mode="image",
140 | is_norm=True,
141 | is_augmented=True,
142 | save_batch=False,
143 | ):
144 | super().__init__(write_interval)
145 | assert eval_mode in ["sample", "recons"]
146 | self.output_dir = output_dir
147 | self.eval_mode = eval_mode
148 | self.sample_prefix = sample_prefix
149 | self.path_prefix = path_prefix
150 | self.is_norm = is_norm
151 | self.is_augmented = is_augmented
152 | self.save_fn = save_as_images if save_mode == "image" else save_as_np
153 | self.save_batch = save_batch
154 |
155 | def write_on_batch_end(
156 | self,
157 | trainer,
158 | pl_module,
159 | prediction,
160 | batch_indices,
161 | batch,
162 | batch_idx,
163 | dataloader_idx,
164 | ):
165 | rank = pl_module.global_rank
166 |
167 | # Write output images
168 | # NOTE: We need to use gpu rank during saving to prevent
169 | # processes from overwriting images
170 | samples = prediction.cpu()
171 |
172 | if self.is_augmented:
173 | samples, _ = torch.chunk(samples, 2, dim=1)
174 |
175 | # Setup dirs
176 | if self.path_prefix != "":
177 | base_save_path = os.path.join(self.output_dir, str(self.path_prefix))
178 | else:
179 | base_save_path = self.output_dir
180 | img_save_path = os.path.join(base_save_path, "images")
181 | os.makedirs(img_save_path, exist_ok=True)
182 |
183 | # Save images
184 | self.save_fn(
185 | samples,
186 | file_name=os.path.join(
187 | img_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}"
188 | ),
189 | denorm=self.is_norm,
190 | )
191 |
192 | # Save batch (For inpainting)
193 | if self.save_batch:
194 | batch_save_path = os.path.join(base_save_path, "batch")
195 | corr_save_path = os.path.join(base_save_path, "corrupt")
196 | os.makedirs(batch_save_path, exist_ok=True)
197 | os.makedirs(corr_save_path, exist_ok=True)
198 | img, mask = batch
199 | img = img * 0.5 + 0.5  # map images from [-1, 1] back to [0, 1]
200 | self.save_fn(
201 | img * mask,
202 | file_name=os.path.join(
203 | corr_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}"
204 | ),
205 | denorm=False,
206 | )
207 | self.save_fn(
208 | img,
209 | file_name=os.path.join(
210 | batch_save_path, f"output_{self.sample_prefix}_{rank}_{batch_idx}"
211 | ),
212 | denorm=False,
213 | )
214 |
--------------------------------------------------------------------------------
/main/configs/dataset/afhqv2/afhqv2128_psld.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | data:
3 | root: ???
4 | name: "afhqv2"
5 | image_size: 128
6 | hflip: True
7 | num_channels: 3
8 | norm: True
9 | return_target: False
10 |
11 | model:
12 | pl_module: 'sde_wrapper'
13 | score_fn:
14 | name: "ncsnpp"
15 | in_ch: 6
16 | out_ch: 6
17 | nonlinearity: "swish"
18 | nf: 128
19 | ch_mult: [1,2,2,2,3]
20 | num_res_blocks: 2
21 | attn_resolutions: [16,]
22 | dropout: 0.2
23 | resamp_with_conv: True
24 | noise_cond: True
25 | fir: False
26 | fir_kernel: [1,3,3,1]
27 | skip_rescale: True
28 | resblock_type: "biggan"
29 | progressive: "none"
30 | progressive_input: "none"
31 | progressive_combine: "sum"
32 | embedding_type: "positional"
33 | init_scale: 0.0
34 | fourier_scale: 16
35 | sde:
36 | name: "psld"
37 | beta_min: 8.0
38 | beta_max: 8.0
39 | nu: 4.01
40 | gamma: 0.01
41 | kappa: 0.04
42 | decomp_mode: "lower"
43 | numerical_eps: 1e-9
44 | n_timesteps: 1000
45 | is_augmented: True
46 |
47 | training:
48 | seed: 0
49 | continuous: True
50 | mode: 'hsm'
51 | loss:
52 | name: "psld_score_loss"
53 | l_type: "l2"
54 | reduce_mean: True
55 | weighting: "fid"
56 | optimizer:
57 | name: "Adam"
58 | lr: 1e-4
59 | beta_1: 0.9
60 | beta_2: 0.999
61 | weight_decay: 0
62 | eps: 1e-8
63 | warmup: 5000
64 | grad_clip: 1.0
65 | train_eps: 1e-5
66 | fp16: False
67 | use_ema: True
68 | ema_decay: 0.9999
69 | batch_size: 32
70 | epochs: 5000
71 | log_step: 1
72 | accelerator: "gpu"
73 | devices: [0]
74 | chkpt_interval: 1
75 | restore_path: ""
76 | results_dir: ???
77 | workers: 1
78 | chkpt_prefix: ""
79 |
80 | evaluation:
81 | sampler:
82 | name: em_sde
83 | seed: 0
84 | chkpt_path: ???
85 | save_path: ???
86 | n_discrete_steps: 1000
87 | denoise: True
88 | eval_eps: 1e-3
89 | stride_type: uniform
90 | use_pflow: False
91 | sample_from: "target"
92 | accelerator: "gpu"
93 | devices: [0]
94 | n_samples: 50000
95 | workers: 2
96 | batch_size: 64
97 | save_mode: image
98 | sample_prefix: "gpu"
99 | path_prefix: ""
100 |
101 | clf:
102 | data:
103 | root: ???
104 | name: afhqv2
105 | image_size: 128
106 | hflip: true
107 | num_channels: 3
108 | norm: true
109 | return_target: true
110 |
111 | model:
112 | pl_module: tclf_wrapper
113 | clf_fn:
114 | name: ncsnpp_clf
115 | in_ch: 6
116 | nonlinearity: swish
117 | nf: 128
118 | ch_mult: [1,2,2,2]
119 | num_res_blocks: 2
120 | attn_resolutions: [16]
121 | dropout: 0.1
122 | resamp_with_conv: true
123 | noise_cond: true
124 | fir: false
125 | fir_kernel: [1,3,3,1]
126 | skip_rescale: true
127 | resblock_type: biggan
128 | progressive: none
129 | progressive_input: none
130 | progressive_combine: sum
131 | embedding_type: positional
132 | init_scale: 0
133 | fourier_scale: 16
134 | n_cls: ???
135 |
136 | training:
137 | seed: 0
138 | continuous: true
139 | loss:
140 | name: tce_loss
141 | l_type: l2
142 | reduce_mean: true
143 | optimizer:
144 | name: Adam
145 | lr: 0.0002
146 | beta_1: 0.9
147 | beta_2: 0.999
148 | weight_decay: 0
149 | eps: 1e-8
150 | warmup: 5000
151 | fp16: false
152 | batch_size: 32
153 | epochs: 500
154 | log_step: 1
155 | accelerator: gpu
156 | devices: [0]
157 | chkpt_interval: 1
158 | restore_path: ""
159 | results_dir: ???
160 | workers: 1
161 | chkpt_prefix: ""
162 |
163 | evaluation:
164 | seed: 0
165 | chkpt_path: ???
166 | accelerator: gpu
167 | devices: [0]
168 | workers: 1
169 | batch_size: 64
170 | clf_temp: 1.0
171 | label_to_sample: 0
172 |
--------------------------------------------------------------------------------
/main/configs/dataset/celeba64/celeba64_psld.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | data:
3 | root: ???
4 | name: "celeba64"
5 | image_size: 64
6 | hflip: True
7 | num_channels: 3
8 | norm: True
9 |
10 | model:
11 | pl_module: 'sde_wrapper'
12 | score_fn:
13 | name: "ncsnpp"
14 | in_ch: 3
15 | out_ch: 3
16 | nonlinearity: "swish"
17 | nf: 128
18 | ch_mult: [1,2,2,2,4]
19 | num_res_blocks: 2
20 | attn_resolutions: [16,]
21 | dropout: 0.1
22 | resamp_with_conv: True
23 | noise_cond: True
24 | fir: False
25 | fir_kernel: [1,3,3,1]
26 | skip_rescale: True
27 | resblock_type: "biggan"
28 | progressive: "none"
29 | progressive_input: "none"
30 | progressive_combine: "sum"
31 | embedding_type: "positional"
32 | init_scale: 0.0
33 | fourier_scale: 16
34 | sde:
35 | name: "psld"
36 | beta_min: 8.0
37 | beta_max: 8.0
38 | nu: 1
39 | gamma: 2
40 | kappa: 0.04
41 | decomp_mode: "lower"
42 | numerical_eps: 1e-9
43 | n_timesteps: 1000
44 | is_augmented: True
45 |
46 | training:
47 | seed: 0
48 | continuous: True
49 | mode: 'hsm'
50 | loss:
51 | name: "psld_score_loss"
52 | l_type: "l2"
53 | reduce_mean: True
54 | weighting: "fid"
55 | optimizer:
56 | name: "Adam"
57 | lr: 2e-4
58 | beta_1: 0.9
59 | beta_2: 0.999
60 | weight_decay: 0
61 | eps: 1e-8
62 | warmup: 5000
63 | grad_clip: 1.0
64 | train_eps: 1e-5
65 | fp16: False
66 | use_ema: True
67 | ema_decay: 0.9999
68 | batch_size: 32
69 | epochs: 5000
70 | log_step: 1
71 | accelerator: "gpu"
72 | devices: [0]
73 | chkpt_interval: 1
74 | restore_path: ""
75 | results_dir: ???
76 | workers: 1
77 | chkpt_prefix: ""
78 |
79 | evaluation:
80 | sampler:
81 | name: em_sde
82 | seed: 0
83 | chkpt_path: ???
84 | save_path: ???
85 | n_discrete_steps: 1000
86 | denoise: True
87 | eval_eps: 1e-3
88 | stride_type: uniform
89 | use_pflow: False
90 | sample_from: "target"
91 | accelerator: "gpu"
92 | devices: [0]
93 | n_samples: 50000
94 | workers: 2
95 | batch_size: 64
96 | save_mode: image
97 | sample_prefix: "gpu"
98 | path_prefix: ""
99 |
--------------------------------------------------------------------------------
/main/configs/dataset/celeba64/celeba64_vpsde.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | data:
3 | root: ???
4 | name: "celeba64"
5 | image_size: 64
6 | hflip: True
7 | num_channels: 3
8 | norm: True
9 |
10 | model:
11 | score_fn:
12 | name: "ncsnpp"
13 | in_ch: 3
14 | out_ch: 3
15 | nonlinearity: "swish"
16 | nf: 128
17 | ch_mult: [1,2,2,2,4]
18 | num_res_blocks: 2
19 | attn_resolutions: [16,]
20 | dropout: 0.1
21 | resamp_with_conv: True
22 | noise_cond: True
23 | fir: False
24 | fir_kernel: [1,3,3,1]
25 | skip_rescale: True
26 | resblock_type: "biggan"
27 | progressive: "none"
28 | progressive_input: "none"
29 | progressive_combine: "sum"
30 | embedding_type: "positional"
31 | init_scale: 0.0
32 | fourier_scale: 16
33 | # scale_by_sigma: False (The Unet always predicts epsilon)
34 | # n_heads: 8
35 | sde:
36 | name: "vpsde"
37 | beta_min: 0.1
38 | beta_max: 20
39 | n_timesteps: 1000
40 | is_augmented: False
41 |
42 | training:
43 | seed: 0
44 | continuous: True
45 | loss:
46 | name: "score_loss"
47 | l_type: "l2"
48 | reduce_mean: True
49 | weighting: "fid"
50 | optimizer:
51 | name: "Adam"
52 | lr: 2e-4
53 | beta_1: 0.9
54 | beta_2: 0.999
55 | weight_decay: 0
56 | eps: 1e-8
57 | warmup: 5000
58 | grad_clip: 1.0
59 | train_eps: 1e-5
60 | fp16: False
61 | use_ema: True
62 | ema_decay: 0.9999
63 | batch_size: 32
64 | epochs: 5000
65 | log_step: 1
66 | accelerator: "gpu"
67 | devices: [0]
68 | chkpt_interval: 1
69 | restore_path: ""
70 | results_dir: ???
71 | workers: 1
72 | chkpt_prefix: ""
73 |
74 | evaluation:
75 | # Sampler specific config goes here
76 | sampler:
77 | name: em_sde
78 | seed: 0
79 | chkpt_path: ???
80 | save_path: ???
81 | n_discrete_steps: 1000
82 | denoise: True
83 | eval_eps: 1e-3
84 | stride_type: uniform
85 | use_pflow: False
86 | sample_from: "target"
87 | accelerator: "gpu"
88 | devices: [0]
89 | n_samples: 50000
90 | workers: 2
91 | batch_size: 64
92 | save_mode: image
93 | sample_prefix: "gpu"
94 |
95 | # VAE config used for VAE training
96 | vae:
97 | data:
98 | root: ???
99 | name: "cifar10"
100 | image_size: 32
101 | n_channels: 3
102 | hflip: False
103 |
104 | model:
105 | enc_block_config : "32x7,32d2,32t16,16x4,16d2,16t8,8x4,8d2,8t4,4x3,4d4,4t1,1x3"
106 | enc_channel_config: "32:64,16:128,8:256,4:256,1:512"
107 | dec_block_config: "1x1,1u4,1t4,4x2,4u2,4t8,8x3,8u2,8t16,16x7,16u2,16t32,32x15"
108 | dec_channel_config: "32:64,16:128,8:256,4:256,1:512"
109 |
110 | training:
111 | seed: 0
112 | fp16: False
113 | batch_size: 128
114 | epochs: 1000
115 | log_step: 1
116 | device: "gpu:0"
117 | chkpt_interval: 1
118 | optimizer: "Adam"
119 | lr: 1e-4
120 | restore_path: ""
121 | results_dir: ???
122 | workers: 2
123 | chkpt_prefix: ""
124 | alpha: 1.0
125 |
--------------------------------------------------------------------------------
/main/configs/dataset/cifar10/cifar10_psld.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | data:
3 | root: ???
4 | name: cifar10
5 | image_size: 32
6 | hflip: true
7 | num_channels: 3
8 | norm: true
9 | return_target: False
10 |
11 | model:
12 | pl_module: sde_wrapper
13 | score_fn:
14 | name: ncsnpp
15 | in_ch: 6
16 | out_ch: 6
17 | nonlinearity: swish
18 | nf: 128
19 | ch_mult: [1,2,2,2]
20 | num_res_blocks: 4
21 | attn_resolutions: [16]
22 | dropout: 0.1
23 | resamp_with_conv: true
24 | noise_cond: true
25 | fir: false
26 | fir_kernel: [1,3,3,1]
27 | skip_rescale: true
28 | resblock_type: biggan
29 | progressive: none
30 | progressive_input: none
31 | progressive_combine: sum
32 | embedding_type: positional
33 | init_scale: 0
34 | fourier_scale: 16
35 | sde:
36 | name: psld
37 | beta_min: 8
38 | beta_max: 8
39 | nu: 4.01
40 | gamma: 0.01
41 | kappa: 0.04
42 | decomp_mode: lower
43 | numerical_eps: 1e-9
44 | n_timesteps: 1000
45 | is_augmented: true
46 |
47 | training:
48 | seed: 0
49 | continuous: true
50 | mode: hsm
51 | loss:
52 | name: psld_score_loss
53 | l_type: l2
54 | reduce_mean: true
55 | weighting: fid
56 | optimizer:
57 | name: Adam
58 | lr: 0.0002
59 | beta_1: 0.9
60 | beta_2: 0.999
61 | weight_decay: 0
62 | eps: 1e-8
63 | warmup: 5000
64 | grad_clip: 1
65 | train_eps: 0.00001
66 | fp16: false
67 | use_ema: true
68 | ema_decay: 0.9999
69 | batch_size: 32
70 | epochs: 5000
71 | log_step: 1
72 | accelerator: gpu
73 | devices: [0]
74 | chkpt_interval: 1
75 | restore_path: ""
76 | results_dir: ???
77 | workers: 1
78 | chkpt_prefix: ""
79 |
80 | evaluation:
81 | sampler:
82 | name: em_sde
83 | seed: 0
84 | chkpt_path: ???
85 | save_path: ???
86 | n_discrete_steps: 1000
87 | denoise: true
88 | eval_eps: 0.001
89 | stride_type: uniform
90 | use_pflow: false
91 | sample_from: target
92 | accelerator: gpu
93 | devices: [0]
94 | n_samples: 50000
95 | workers: 2
96 | batch_size: 64
97 | save_mode: image
98 | sample_prefix: gpu
99 | path_prefix: ""
100 |
101 | clf:
102 | data:
103 | root: ???
104 | name: cifar10
105 | image_size: 32
106 | hflip: true
107 | num_channels: 3
108 | norm: true
109 | return_target: true
110 |
111 | model:
112 | pl_module: tclf_wrapper
113 | clf_fn:
114 | name: ncsnpp_clf
115 | in_ch: 6
116 | nonlinearity: swish
117 | nf: 128
118 | ch_mult: [1,2,2,2]
119 | num_res_blocks: 2
120 | attn_resolutions: [16]
121 | dropout: 0.1
122 | resamp_with_conv: true
123 | noise_cond: true
124 | fir: false
125 | fir_kernel: [1,3,3,1]
126 | skip_rescale: true
127 | resblock_type: biggan
128 | progressive: none
129 | progressive_input: none
130 | progressive_combine: sum
131 | embedding_type: positional
132 | init_scale: 0
133 | fourier_scale: 16
134 | n_cls: ???
135 |
136 | training:
137 | seed: 0
138 | continuous: true
139 | loss:
140 | name: tce_loss
141 | l_type: l2
142 | reduce_mean: true
143 | optimizer:
144 | name: Adam
145 | lr: 0.0002
146 | beta_1: 0.9
147 | beta_2: 0.999
148 | weight_decay: 0
149 | eps: 1e-8
150 | warmup: 5000
151 | fp16: false
152 | batch_size: 32
153 | epochs: 500
154 | log_step: 1
155 | accelerator: gpu
156 | devices: [0]
157 | chkpt_interval: 1
158 | restore_path: ""
159 | results_dir: ???
160 | workers: 1
161 | chkpt_prefix: ""
162 |
163 | evaluation:
164 | seed: 0
165 | chkpt_path: ???
166 | accelerator: gpu
167 | devices: [0]
168 | workers: 1
169 | batch_size: 64
170 | clf_temp: 1.0
171 | label_to_sample: 0
172 |
--------------------------------------------------------------------------------
/main/configs/dataset/cifar10/cifar10_vpsde.yaml:
--------------------------------------------------------------------------------
1 | diffusion:
2 | data:
3 | root: ???
4 | name: "cifar10"
5 | image_size: 32
6 | hflip: True
7 | num_channels: 3
8 | norm: True
9 |
10 | model:
11 | score_fn:
12 | name: "ncsnpp"
13 | in_ch: 3
14 | out_ch: 3
15 | nonlinearity: "swish"
16 | nf: 128
17 | ch_mult: [1,2,2,2]
18 | num_res_blocks: 4
19 | attn_resolutions: [16,]
20 | dropout: 0.1
21 | resamp_with_conv: True
22 | noise_cond: True
23 | fir: False
24 | fir_kernel: [1,3,3,1]
25 | skip_rescale: True
26 | resblock_type: "biggan"
27 | progressive: "none"
28 | progressive_input: "none"
29 | progressive_combine: "sum"
30 | embedding_type: "positional"
31 | init_scale: 0.0
32 | fourier_scale: 16
33 | # scale_by_sigma: False (The Unet always predicts epsilon)
34 | # n_heads: 8
35 | sde:
36 | name: "vpsde"
37 | beta_min: 0.1
38 | beta_max: 20
39 | n_timesteps: 1000
40 | is_augmented: False
41 |
42 | training:
43 | seed: 0
44 | continuous: True
45 | loss:
46 | name: "score_loss"
47 | l_type: "l2"
48 | reduce_mean: True
49 | weighting: "fid"
50 | optimizer:
51 | name: "Adam"
52 | lr: 2e-4
53 | beta_1: 0.9
54 | beta_2: 0.999
55 | weight_decay: 0
56 | eps: 1e-8
57 | warmup: 5000
58 | grad_clip: 1.0
59 | train_eps: 1e-5
60 | fp16: False
61 | use_ema: True
62 | ema_decay: 0.9999
63 | batch_size: 32
64 | epochs: 5000
65 | log_step: 1
66 | accelerator: "gpu"
67 | devices: [0]
68 | chkpt_interval: 1
69 | restore_path: ""
70 | results_dir: ???
71 | workers: 1
72 | chkpt_prefix: ""
73 |
74 | evaluation:
75 | # Sampler specific config goes here
76 | sampler:
77 | name: em_sde
78 | seed: 0
79 | chkpt_path: ???
80 | save_path: ???
81 | n_discrete_steps: 1000
82 | denoise: True
83 | eval_eps: 1e-3
84 | stride_type: uniform
85 | use_pflow: False
86 | sample_from: "target"
87 | accelerator: "gpu"
88 | devices: [0]
89 | n_samples: 50000
90 | workers: 2
91 | batch_size: 64
92 | save_mode: image
93 | sample_prefix: "gpu"
94 | path_prefix: ""
95 |
--------------------------------------------------------------------------------
/main/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .celeba import CelebADataset
2 | from .cifar10 import CIFAR10Dataset
3 | from .celebahq import CelebAHQDataset
4 | from .afhq import AFHQv2Dataset
5 | from .inpaint import InpaintDataset
6 |
--------------------------------------------------------------------------------
/main/datasets/afhq.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from tqdm import tqdm
7 | from util import data_scaler, register_module
8 |
9 |
10 | @register_module(category="datasets", name="afhqv2")
11 | class AFHQv2Dataset(Dataset):
12 | """Implementation of the AFHQv2 dataset.
13 | Downloaded from https://github.com/clovaai/stargan-v2
14 | """
15 | def __init__(
16 | self,
17 | root,
18 | norm=True,
19 | subsample_size=None,
20 | return_target=False,
21 | transform=None,
22 | **kwargs,
23 | ):
24 | if not os.path.isdir(root):
25 | raise ValueError(f"The specified root: {root} does not exist")
26 | self.root = root
27 | self.transform = transform
28 | self.norm = norm
29 | self.subsample_size = subsample_size
30 | self.return_target = return_target
31 |
32 | self.images = []
33 | self.labels = []
34 |
35 | cat = kwargs.get("cat", [])
36 | is_train = kwargs.get("train", True)
37 | subfolder_list = ["dog", "cat", "wild"] if cat == [] else cat
38 | base_path = os.path.join(self.root, "train" if is_train else "test")
39 | for idx, subfolder in enumerate(subfolder_list):
40 | sub_path = os.path.join(base_path, subfolder)
41 |
42 | for img in tqdm(os.listdir(sub_path)):
43 | self.images.append(os.path.join(sub_path, img))
44 | self.labels.append(idx)
45 |
46 | def __getitem__(self, idx):
47 | img_path = self.images[idx]
48 | img = Image.open(img_path)
49 | if self.transform is not None:
50 | img = self.transform(img)
51 |
52 | # Normalize: scale images to [-1, 1] when norm=True,
53 | # otherwise to [0, 1]
54 | img = data_scaler(img, norm=self.norm)
55 |
56 | # return_target additionally returns the class label for the animal category.
57 | # This is only needed for guidance-based generation.
58 | if self.return_target:
59 | return torch.from_numpy(img).permute(2, 0, 1).float(), self.labels[idx]
60 | return torch.from_numpy(img).permute(2, 0, 1).float()
61 |
62 | def __len__(self):
63 | return len(self.images) if self.subsample_size is None else self.subsample_size
64 |
--------------------------------------------------------------------------------
/main/datasets/celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from tqdm import tqdm
7 | from util import data_scaler, register_module
8 |
9 |
10 | @register_module(category="datasets", name="celeba64")
11 | class CelebADataset(Dataset):
12 | """Implementation of the CelebA dataset.
13 | Downloaded from https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
14 | """
15 | def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwargs):
16 | if not os.path.isdir(root):
17 | raise ValueError(f"The specified root: {root} does not exist")
18 | self.root = root
19 | self.transform = transform
20 | self.norm = norm
21 |
22 | self.images = []
23 |
24 | for img in tqdm(os.listdir(root)):
25 | self.images.append(os.path.join(self.root, img))
26 |
27 | # Subsample the dataset (if enabled)
28 | if subsample_size is not None:
29 | self.images = self.images[:subsample_size]
30 |
31 | def __getitem__(self, idx):
32 | img_path = self.images[idx]
33 | img = Image.open(img_path)
34 |
35 | # Apply transform
36 | if self.transform is not None:
37 | img = self.transform(img)
38 |
39 | # Normalize
40 | img = data_scaler(img, norm=self.norm)
41 |
42 | # TODO: Add the functionality to return labels to enable
43 | # guidance-based generation.
44 | return torch.from_numpy(img).permute(2, 0, 1).float()
45 |
46 | def __len__(self):
47 | return len(self.images)
48 |
--------------------------------------------------------------------------------
/main/datasets/celebahq.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from tqdm import tqdm
7 | from util import data_scaler, register_module
8 |
9 |
10 | # NOTE: We don't use this dataset in the paper
11 | @register_module(category="datasets", name='celebahq256')
12 | class CelebAHQDataset(Dataset):
13 | def __init__(self, root, norm=True, subsample_size=None, transform=None, **kwargs):
14 | if not os.path.isdir(root):
15 | raise ValueError(f"The specified root: {root} does not exist")
16 | self.root = root
17 | self.transform = transform
18 | self.norm = norm
19 |
20 | self.images = []
21 |
22 | modes = ["train", "val"]
23 | subfolders = ["male", "female"]
24 |
25 | for mode in modes:
26 | for folder in subfolders:
27 | img_path = os.path.join(self.root, mode, folder)
28 | for img in tqdm(sorted(os.listdir(img_path))):
29 | self.images.append(os.path.join(img_path, img))
30 |
31 | # Subsample the dataset (if enabled)
32 | if subsample_size is not None:
33 | self.images = self.images[:subsample_size]
34 |
35 | def __getitem__(self, idx):
36 | img_path = self.images[idx]
37 | img = Image.open(img_path)
38 | if self.transform is not None:
39 | img = self.transform(img)
40 |
41 | # Normalize
42 | img = data_scaler(img, norm=self.norm)
43 |
44 | return torch.from_numpy(img).permute(2, 0, 1).float()
45 |
46 | def __len__(self):
47 | return len(self.images)
48 |
--------------------------------------------------------------------------------
/main/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 | from torchvision.datasets import CIFAR10
6 | from util import data_scaler, register_module
7 |
8 |
9 | @register_module(category="datasets", name="cifar10")
10 | class CIFAR10Dataset(Dataset):
11 | def __init__(
12 | self,
13 | root,
14 | norm=True,
15 | transform=None,
16 | subsample_size=None,
17 | return_target=False,
18 | **kwargs,
19 | ):
20 | if not os.path.isdir(root):
21 | raise ValueError(f"The specified root: {root} does not exist")
22 |
23 | if subsample_size is not None:
24 | assert isinstance(subsample_size, int)
25 |
26 | self.root = root
27 | self.norm = norm
28 | self.transform = transform
29 | self.dataset = CIFAR10(self.root, train=True, download=True, transform=None)
30 | self.subsample_size = subsample_size
31 | self.return_target = return_target
32 |
33 | def __getitem__(self, idx):
34 | img, target = self.dataset[idx]
35 |
36 | # Apply transform
37 | img_ = self.transform(img) if self.transform is not None else img
38 |
39 | # Normalize
40 | img = data_scaler(img_, norm=self.norm)
41 |
42 | # Return (with targets if needed for guidance-based generation)
43 | img_tensor = torch.tensor(img).permute(2, 0, 1).float()
44 | if self.return_target:
45 | return img_tensor, target
46 | return img_tensor
47 |
48 | def __len__(self):
49 | return len(self.dataset) if self.subsample_size is None else self.subsample_size
50 |
--------------------------------------------------------------------------------
/main/datasets/inpaint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms.functional import InterpolationMode
3 | from torch.utils.data import Dataset
4 |
5 | from util import register_module
6 | from torchvision.datasets import MNIST
7 | import torchvision.transforms as T
8 |
9 |
10 | @register_module(category="datasets", name="inpaint")
11 | class InpaintDataset(Dataset):
12 | """Dataset for generating corrupted images. The images from the base dataset are masked with
13 | MNIST to generate images with missing pixels.
14 | """
15 | def __init__(self, config, dataset):
16 | # Parent dataset (must return only images)
17 | self.config = config
18 | self.dataset = dataset
19 |
20 | # Used for creating masks
21 | t = T.Compose(
22 | [
23 | T.Resize(
24 | (config.data.image_size, config.data.image_size),
25 | InterpolationMode.NEAREST,
26 | ),
27 | T.ToTensor(),
28 | ]
29 | )
30 | self.mnist = MNIST(config.data.root, train=True, download=True, transform=t)
31 |
32 | def __getitem__(self, idx):
33 | img = self.dataset[idx]
34 |
35 | # Generate the mask with the mnist digit
36 | digit, _ = self.mnist[idx]
37 | digit = torch.cat([digit] * 3, dim=0)  # replicate the digit to 3 channels
38 | mask = (digit > 0).type(torch.long)
39 | mask = 1 - mask  # 0 = missing (digit) pixels, 1 = observed pixels
40 | assert mask.shape == img.shape
41 | return img, mask
42 |
43 | def __len__(self):
44 | return min(self.config.evaluation.n_samples, len(self.dataset))
45 |
--------------------------------------------------------------------------------
/main/datasets/latent.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from util import register_module
3 |
4 |
5 | @register_module(category="datasets", name="latent")
6 | class SDELatentDataset(Dataset):
7 | """A dataset for generating samples from the equilibrium distribution
8 | of the forward SDE (useful during sampling)
9 | """
10 | def __init__(self, sde, config):
11 | self.sde = sde
12 | self.num_samples = config.evaluation.n_samples
13 | self.shape = [
14 | self.num_samples,
15 | config.data.num_channels,
16 | config.data.image_size,
17 | config.data.image_size,
18 | ]
19 | self.samples = self.sde.prior_sampling(self.shape)
20 |
21 | def get_batch(self, shape):
22 | return self.sde.prior_sampling(shape)
23 |
24 | def __getitem__(self, idx):
25 | return self.samples[idx]
26 |
27 | def __len__(self):
28 | return self.num_samples
29 |
--------------------------------------------------------------------------------
/main/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/eval/__init__.py
--------------------------------------------------------------------------------
/main/eval/class_cond_sample.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | from copy import deepcopy
5 |
6 | # Add project directory to sys.path
7 | p = os.path.join(os.path.abspath("."), "main")
8 | sys.path.insert(1, p)
9 |
10 | import hydra
11 | import pytorch_lightning as pl
12 | from callbacks import SimpleImageWriter
13 | from datasets.latent import SDELatentDataset
14 | from models.clf_wrapper import TClfWrapper
15 | from models.wrapper import SDEWrapper
16 | from omegaconf import OmegaConf
17 | from pytorch_lightning.utilities.seed import seed_everything
18 | from torch.utils.data import DataLoader
19 | from util import get_module, import_modules_into_registry
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | # Import all modules into registry
25 | import_modules_into_registry()
26 |
27 |
28 | @hydra.main(config_path=os.path.join(p, "configs"))
29 | def cc_sample(config):
30 | """Evaluation script for Class conditional sampling with pre-trained score models
31 | using classifier guidance.
32 | """
33 | config = config.dataset
34 | config_sde = config.diffusion
35 | config_clf = config.clf
36 | logger.info(OmegaConf.to_yaml(config))
37 |
38 | # Set seed
39 | seed_everything(config_clf.evaluation.seed)
40 |
41 | # Setup Score SDE
42 | sde_cls = get_module(category="sde", name=config_sde.model.sde.name)
43 | sde = sde_cls(config_sde)
44 | logger.info(f"Using SDE: {sde_cls}")
45 |
46 | # Setup score predictor
47 | score_fn_cls = get_module(category="score_fn", name=config_sde.model.score_fn.name)
48 | score_fn = score_fn_cls(config_sde)
49 | logger.info(f"Using Score fn: {score_fn_cls}")
50 |
51 | ema_score_fn = deepcopy(score_fn)
52 | for p in ema_score_fn.parameters():
53 | p.requires_grad = False
54 |
55 | wrapper = SDEWrapper.load_from_checkpoint(
56 | config_sde.evaluation.chkpt_path,
57 | config=config_sde,
58 | sde=sde,
59 | score_fn=score_fn,
60 | ema_score_fn=ema_score_fn,
61 | sampler_cls=None,
62 | )
63 |
64 | score_fn = (
65 | wrapper.ema_score_fn
66 | if config_sde.evaluation.sample_from == "target"
67 | else wrapper.score_fn
68 | )
69 | score_fn.eval()
70 |
71 | # Setup sampler
72 | sampler_cls = get_module(
73 | category="samplers", name=config_sde.evaluation.sampler.name
74 | )
75 | logger.info(
76 | f"Using Sampler: {sampler_cls}. Make sure the sampler supports Class-conditional sampling"
77 | )
78 |
79 | # Setup dataset
80 | dataset = SDELatentDataset(sde, config_sde)
81 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}")
82 |
83 | # Setup classifier (for guidance)
84 | clf_fn_cls = get_module(category="clf_fn", name=config_clf.model.clf_fn.name)
85 | clf_fn = clf_fn_cls(config_clf)
86 | logger.info(f"Using Classifier fn: {clf_fn_cls}")
87 |
88 | wrapper = TClfWrapper.load_from_checkpoint(
89 | config_clf.evaluation.chkpt_path,
90 | config=config,
91 | sde=sde,
92 | clf_fn=clf_fn,
93 | score_fn=score_fn,
94 | sampler_cls=sampler_cls,
95 | strict=False,
96 | )
97 | wrapper.eval()
98 |
99 | # Setup devices
100 | test_kwargs = {}
101 | loader_kws = {}
102 | device_type = config_sde.evaluation.accelerator
103 | test_kwargs["accelerator"] = device_type
104 | if device_type == "gpu":
105 | test_kwargs["devices"] = config_sde.evaluation.devices
106 | # # Disable find_unused_parameters when using DDP training for performance reasons
107 | # loader_kws["persistent_workers"] = True
108 | elif device_type == "tpu":
109 | test_kwargs["tpu_cores"] = 8
110 |
111 | # Predict loader
112 | val_loader = DataLoader(
113 | dataset,
114 | batch_size=config_sde.evaluation.batch_size,
115 | drop_last=False,
116 | pin_memory=True,
117 | shuffle=False,
118 | num_workers=config_sde.evaluation.workers,
119 | **loader_kws,
120 | )
121 |
122 | # Setup Image writer callback trainer
123 | write_callback = SimpleImageWriter(
124 | config_sde.evaluation.save_path,
125 | write_interval="batch",
126 | sample_prefix=config_sde.evaluation.sample_prefix,
127 | path_prefix=config_sde.evaluation.path_prefix,
128 | save_mode=config_sde.evaluation.save_mode,
129 | is_augmented=config_sde.model.sde.is_augmented,
130 | )
131 |
132 | test_kwargs["callbacks"] = [write_callback]
133 | test_kwargs["default_root_dir"] = config_sde.evaluation.save_path
134 |
135 | # Setup Pl module and predict
136 | sampler = pl.Trainer(**test_kwargs)
137 | sampler.predict(wrapper, val_loader)
138 |
139 |
140 | if __name__ == "__main__":
141 | cc_sample()
142 |
--------------------------------------------------------------------------------
/main/eval/inpaint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | # Add project directory to sys.path
6 | p = os.path.join(os.path.abspath("."), "main")
7 | sys.path.insert(1, p)
8 |
9 | from copy import deepcopy
10 |
11 | import hydra
12 | import pytorch_lightning as pl
13 | from callbacks import InpaintingImageWriter
14 | from datasets import InpaintDataset
15 | from models.wrapper import SDEWrapper
16 | from omegaconf import OmegaConf
17 | from pytorch_lightning.utilities.seed import seed_everything
18 | from torch.utils.data import DataLoader
19 | from util import get_dataset, get_module, import_modules_into_registry
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | # Import all modules into registry
25 | import_modules_into_registry()
26 |
27 |
28 | @hydra.main(config_path=os.path.join(p, "configs"))
29 | def inpaint(config):
30 | """Evaluation script for inpainting with pre-trained score models using guidance."""
31 | config = config.dataset.diffusion
32 | logger.info(OmegaConf.to_yaml(config))
33 |
34 | # Set seed
35 | seed_everything(config.evaluation.seed)
36 |
37 | # Setup Score Predictor
38 | score_fn_cls = get_module(category="score_fn", name=config.model.score_fn.name)
39 | score_fn = score_fn_cls(config)
40 | logger.info(f"Using Score fn: {score_fn_cls}")
41 |
42 | ema_score_fn = deepcopy(score_fn)
43 | for p in ema_score_fn.parameters():
44 | p.requires_grad = False
45 |
46 | score_fn.eval()
47 | ema_score_fn.eval()
48 |
49 | # Setup Score SDE
50 | sde_cls = get_module(category="sde", name=config.model.sde.name)
51 | sde = sde_cls(config)
52 | logger.info(f"Using SDE: {sde_cls}")
53 |
54 | # Setup sampler
55 | sampler_cls = get_module(category="samplers", name=config.evaluation.sampler.name)
56 | logger.info(f"Using Sampler: {sampler_cls}")
57 |
58 | # Setup dataset
59 | base_dataset = get_dataset(config)
60 | dataset = InpaintDataset(config, base_dataset)
61 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}")
62 |
63 | wrapper = SDEWrapper.load_from_checkpoint(
64 | config.evaluation.chkpt_path,
65 | config=config,
66 | sde=sde,
67 | score_fn=score_fn,
68 | ema_score_fn=ema_score_fn,
69 | sampler_cls=sampler_cls,
70 | )
71 |
72 | # Setup devices
73 | test_kwargs = {}
74 | loader_kws = {}
75 | device_type = config.evaluation.accelerator
76 | test_kwargs["accelerator"] = device_type
77 | if device_type == "gpu":
78 | test_kwargs["devices"] = config.evaluation.devices
79 | # # Disable find_unused_parameters when using DDP training for performance reasons
80 | # loader_kws["persistent_workers"] = True
81 | elif device_type == "tpu":
82 | test_kwargs["tpu_cores"] = 8
83 |
84 | # Predict loader
85 | val_loader = DataLoader(
86 | dataset,
87 | batch_size=config.evaluation.batch_size,
88 | drop_last=False,
89 | pin_memory=True,
90 | shuffle=False,
91 | num_workers=config.evaluation.workers,
92 | **loader_kws,
93 | )
94 |
95 | # Setup Image writer callback trainer
96 | write_callback = InpaintingImageWriter(
97 | config.evaluation.save_path,
98 | write_interval="batch",
99 | sample_prefix=config.evaluation.sample_prefix,
100 | path_prefix=config.evaluation.path_prefix,
101 | save_mode=config.evaluation.save_mode,
102 | is_augmented=config.model.sde.is_augmented,
103 | save_batch=True,
104 | )
105 |
106 | test_kwargs["callbacks"] = [write_callback]
107 | test_kwargs["default_root_dir"] = config.evaluation.save_path
108 |
109 | # Setup Pl module and predict
110 | sampler = pl.Trainer(**test_kwargs)
111 | sampler.predict(wrapper, val_loader)
112 |
113 |
114 | if __name__ == "__main__":
115 | inpaint()
116 |
--------------------------------------------------------------------------------
/main/eval/sample.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | # Add project directory to sys.path
6 | p = os.path.join(os.path.abspath("."), "main")
7 | sys.path.insert(1, p)
8 |
9 | from copy import deepcopy
10 |
11 | import hydra
12 | import pytorch_lightning as pl
13 | from callbacks import SimpleImageWriter
14 | from datasets.latent import SDELatentDataset
15 | from models.wrapper import SDEWrapper
16 | from omegaconf import OmegaConf
17 | from pytorch_lightning.utilities.seed import seed_everything
18 | from torch.utils.data import DataLoader
19 | from util import get_module, import_modules_into_registry
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | # Import all modules into registry
25 | import_modules_into_registry()
26 |
27 |
28 | @hydra.main(config_path=os.path.join(p, "configs"))
29 | def sample(config):
30 |     """Evaluation script for unconditional sampling using pre-trained score models."""
31 |     config = config.dataset.diffusion
32 |     logger.info(OmegaConf.to_yaml(config))
33 | 
34 |     # Set seed
35 |     seed_everything(config.evaluation.seed)
36 | 
37 |     # Setup score predictor
38 |     score_fn_cls = get_module(category="score_fn", name=config.model.score_fn.name)
39 |     score_fn = score_fn_cls(config)
40 |     logger.info(f"Using Score fn: {score_fn_cls}")
41 | 
42 |     ema_score_fn = deepcopy(score_fn)
43 |     for param in ema_score_fn.parameters():
44 |         param.requires_grad = False
45 | 
46 |     score_fn.eval()
47 |     ema_score_fn.eval()
48 | 
49 |     # Setup Score SDE
50 |     sde_cls = get_module(category="sde", name=config.model.sde.name)
51 |     sde = sde_cls(config)
52 |     logger.info(f"Using SDE: {sde_cls}")
53 | 
54 |     # Setup sampler
55 |     sampler_cls = get_module(category="samplers", name=config.evaluation.sampler.name)
56 |     logger.info(f"Using Sampler: {sampler_cls}")
57 | 
58 |     # Setup dataset
59 |     dataset = SDELatentDataset(sde, config)
60 |     logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}")
61 | 
62 |     wrapper = SDEWrapper.load_from_checkpoint(
63 |         config.evaluation.chkpt_path,
64 |         config=config,
65 |         sde=sde,
66 |         score_fn=score_fn,
67 |         ema_score_fn=ema_score_fn,
68 |         sampler_cls=sampler_cls,
69 |     )
70 | 
71 |     # Setup devices
72 |     test_kwargs = {}
73 |     loader_kws = {}
74 |     device_type = config.evaluation.accelerator
75 |     test_kwargs["accelerator"] = device_type
76 |     if device_type == "gpu":
77 |         test_kwargs["devices"] = config.evaluation.devices
78 |         # # Persist dataloader workers across batches for performance
79 |         # loader_kws["persistent_workers"] = True
80 |     elif device_type == "tpu":
81 |         test_kwargs["tpu_cores"] = 8
82 | 
83 |     # Predict loader
84 |     val_loader = DataLoader(
85 |         dataset,
86 |         batch_size=config.evaluation.batch_size,
87 |         drop_last=False,
88 |         pin_memory=True,
89 |         shuffle=False,
90 |         num_workers=config.evaluation.workers,
91 |         **loader_kws,
92 |     )
93 | 
94 |     # Set up the image-writer callback for the trainer
95 |     write_callback = SimpleImageWriter(
96 |         config.evaluation.save_path,
97 |         write_interval="batch",
98 |         sample_prefix=config.evaluation.sample_prefix,
99 |         path_prefix=config.evaluation.path_prefix,
100 |         save_mode=config.evaluation.save_mode,
101 |         is_augmented=config.model.sde.is_augmented,
102 |     )
103 | 
104 |     test_kwargs["callbacks"] = [write_callback]
105 |     test_kwargs["default_root_dir"] = config.evaluation.save_path
106 | 
107 |     # Set up the PL trainer and run prediction
108 |     sampler = pl.Trainer(**test_kwargs)
109 |     sampler.predict(wrapper, val_loader)
110 | 
111 | 
112 | if __name__ == "__main__":
113 |     sample()
114 |
--------------------------------------------------------------------------------
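
`register_module`, `get_module`, and `import_modules_into_registry` come from `main/util.py`, which is outside this part of the listing. Judging from the call sites (a `(category, name)` key plus class decorators such as `@register_module(category="losses", name="score_loss")`), a plausible minimal registry looks like the sketch below; the real implementation may differ:

# Hypothetical sketch of the (category, name) registry assumed by util.py.
_REGISTRY = {}


def register_module(category, name):
    def decorator(cls):
        _REGISTRY.setdefault(category, {})[name] = cls
        return cls
    return decorator


def get_module(category, name):
    return _REGISTRY[category][name]
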
/main/losses.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from util import get_module, register_module, reshape
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def compute_top_k(logits, labels, k, reduction="mean"):
13 | _, top_ks = torch.topk(logits, k, dim=-1)
14 | if reduction == "mean":
15 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
16 | elif reduction == "none":
17 |         return (top_ks == labels[:, None]).float().sum(dim=-1)
18 |     raise ValueError(f"Unsupported reduction: {reduction}")
19 |
20 | @register_module(category="losses", name="score_loss")
21 | class ScoreLoss(nn.Module):
22 | """Loss function for training non-augmented Score based models (like VP-SDE)"""
23 |
24 | def __init__(self, config, sde):
25 | super().__init__()
26 | assert config.training.loss.weighting in ["nll", "fid"]
27 | self.sde = sde
28 | self.l_type = config.training.loss.l_type
29 | self.weighting = config.training.loss.weighting
30 |
31 | logger.info(f"Initialized Loss fn with weighting: {self.weighting}")
32 |
33 | if self.weighting == "nll" and self.l_type != "l2":
34 | # Use only MSE loss when maximizing for nll
35 | raise ValueError("l_type can only be `l2` when using nll weighting")
36 |
37 | self.reduce_strategy = "mean" if config.training.loss.reduce_mean else "sum"
38 | criterion_type = nn.MSELoss if self.l_type == "l2" else nn.L1Loss
39 | self.criterion = criterion_type(reduction=self.reduce_strategy)
40 |
41 | def forward(self, x_0, t, score_fn, eps=None):
42 | if eps is None:
43 |             eps = torch.randn_like(x_0)  # randn_like inherits device and dtype from x_0
44 |
45 | assert eps.shape == x_0.shape
46 |
47 | # Predict epsilon
48 | x_t = self.sde.perturb_data(x_0, t, noise=eps)
49 | eps_pred = score_fn(x_t.type(torch.float32), t.type(torch.float32))
50 |
51 | # Use eps-prediction when optimizing for FID
52 | loss = self.criterion(eps, eps_pred)
53 |
54 | # Use g(t)**2 weighting when optimizing for NLL
55 | if self.weighting == "nll":
56 | gt_2 = reshape(self.sde.likelihood_weighting(t), x_0)
57 | gt_score = self.sde.get_score(eps, t)
58 | pred_score = self.sde.get_score(eps_pred, t)
59 | # loss = self.criterion(pred_score, gt_score) * gt_2
60 | loss = (pred_score - gt_score) ** 2 * gt_2
61 | loss = (
62 | torch.mean(loss) if self.reduce_strategy == "mean" else torch.sum(loss)
63 | )
64 |
65 | return loss
66 |
67 |
68 | @register_module(category="losses", name="psld_score_loss")
69 | class PSLDScoreLoss(nn.Module):
70 | """Loss function for training PSLD."""
71 |
72 | def __init__(self, config, sde):
73 | super().__init__()
74 | # TODO: Add support for likelihood training
75 | assert config.training.loss.weighting in ["fid"]
76 | assert config.training.mode in ["hsm", "dsm"]
77 | assert isinstance(sde, get_module("sde", "psld"))
78 | self.sde = sde
79 | self.l_type = config.training.loss.l_type
80 | self.weighting = config.training.loss.weighting
81 | self.mode = config.training.mode
82 | self.decomp_mode = config.model.sde.decomp_mode
83 |
84 | logger.info(
85 | f"Initialized Loss fn with weighting: {self.weighting} mode: {self.mode}"
86 | )
87 |
88 |         if self.weighting == "nll" and self.l_type != "l2":
89 |             # Unreachable until likelihood training lands (weighting is asserted to be "fid" above)
90 |             raise ValueError("l_type can only be `l2` when using nll weighting")
91 |
92 | self.reduce_strategy = "mean" if config.training.loss.reduce_mean else "sum"
93 |
94 | def forward(self, x_0, t, score_fn, eps=None):
95 | # Sample momentum (DSM)
96 | m_0 = np.sqrt(self.sde.mm_0) * torch.randn_like(x_0)
97 | mm_0 = 0.0
98 |
99 | # Update momentum (if training mode is HSM)
100 | if self.mode == "hsm":
101 | m_0 = torch.zeros_like(x_0)
102 | mm_0 = self.sde.mm_0
103 |
104 | z_0 = torch.cat([x_0, m_0], dim=1)
105 | xx_0 = 0
106 |
107 | # Sample random noise
108 | if eps is None:
109 | eps = torch.randn_like(z_0)
110 | assert eps.shape == z_0.shape
111 |
112 | # Predict epsilon
113 | z_t, _, _ = self.sde.perturb_data(x_0, m_0, xx_0, mm_0, t, eps=eps)
114 | z_t = z_t.type(torch.float32)
115 | eps_pred = score_fn(z_t, t.type(torch.float32))
116 |
117 | # Use eps-prediction when optimizing for FID
118 | eps_x, eps_m = torch.chunk(eps, 2, dim=1)
119 | if self.sde.mode == "score_m" and self.decomp_mode == "lower":
120 | assert eps_pred.shape == eps_m.shape
121 | loss = (eps_m - eps_pred) ** 2
122 | elif self.sde.mode == "score_x" and self.decomp_mode == "upper":
123 | assert eps_pred.shape == eps_x.shape
124 | loss = (eps_x - eps_pred) ** 2
125 | else:
126 | assert eps_pred.shape == eps.shape
127 | loss = (eps - eps_pred) ** 2
128 |
129 | loss = torch.mean(loss) if self.reduce_strategy == "mean" else torch.sum(loss)
130 | return loss
131 |
132 |
133 | @register_module(category="losses", name="tce_loss")
134 | class PSLDTimeCELoss(nn.Module):
135 | """Loss function for training noise conditioned classifier for guidance"""
136 |
137 | def __init__(self, config, sde):
138 | super().__init__()
139 | assert config.diffusion.training.mode in ["hsm", "dsm"]
140 | assert isinstance(sde, get_module("sde", "psld"))
141 | self.sde = sde
142 | self.l_type = config.clf.training.loss.l_type
143 | self.mode = config.diffusion.training.mode
144 |
145 | self.reduce_strategy = (
146 | "mean" if config.diffusion.training.loss.reduce_mean else "sum"
147 | )
148 | self.criterion = nn.CrossEntropyLoss(reduction=self.reduce_strategy)
149 |
150 | def forward(self, x_0, y, t, clf_fn):
151 | # Sample momentum (DSM)
152 | m_0 = np.sqrt(self.sde.mm_0) * torch.randn_like(x_0)
153 | mm_0 = 0.0
154 |
155 | # Update momentum (if training mode is HSM)
156 | if self.mode == "hsm":
157 | m_0 = torch.zeros_like(x_0)
158 | mm_0 = self.sde.mm_0
159 |
160 | u_0 = torch.cat([x_0, m_0], dim=1)
161 | xx_0 = 0
162 |
163 | # Sample random noise
164 | eps = torch.randn_like(u_0)
165 |
166 | # Perturb Data
167 | u_t, _, _ = self.sde.perturb_data(x_0, m_0, xx_0, mm_0, t, eps=eps)
168 |
169 | # Predict label
170 | y_pred = clf_fn(u_t.type(torch.float32), t)
171 |
172 | # CE loss
173 | loss = self.criterion(y_pred, y)
174 |
175 | # Top-k accuracy (for debugging)
176 | acc = compute_top_k(y_pred, y, 1)
177 | return loss, acc
178 |
--------------------------------------------------------------------------------
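
The `fid`-weighted branch of both score losses above is plain epsilon prediction: perturb the data, regress the injected noise, reduce the squared error. A self-contained toy version follows, with a made-up perturbation kernel x_t = alpha_t * x_0 + sigma_t * eps; the schedule and the zero "network" below are stand-ins for shape-checking only, not the repo's SDE or score model:

import torch


def toy_perturb(x_0, t, eps):
    alpha_t = torch.exp(-0.5 * t).view(-1, 1)  # stand-in noise schedule
    sigma_t = torch.sqrt(1.0 - alpha_t ** 2)
    return alpha_t * x_0 + sigma_t * eps


x_0 = torch.randn(4, 8)
t = torch.rand(4)
eps = torch.randn_like(x_0)
x_t = toy_perturb(x_0, t, eps)
eps_pred = torch.zeros_like(x_t)            # stands in for score_fn(x_t, t)
loss = torch.mean((eps - eps_pred) ** 2)    # the "fid" branch of ScoreLoss
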
/main/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .clf_wrapper import TClfWrapper
2 | from .score_fn.song_sde.ncsnpp import NCSNpp
3 | from .score_fn.song_sde.ncsnpp_clf import NCSNppClassifier
4 | from .sde.psld import PSLD
5 | from .sde.vpsde import VPSDE
6 | from .wrapper import SDEWrapper
7 |
--------------------------------------------------------------------------------
/main/models/clf_wrapper.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning.utilities.seed import seed_everything
6 | from util import get_module, register_module
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | @register_module(category="pl_modules", name="tclf_wrapper")
12 | class TClfWrapper(pl.LightningModule):
13 | """This PL module can do the following tasks:
14 | - train: Train a classifier to predict labels given a noisy state z_t
15 | - predict: Generate class-conditional samples using classifier guidance
16 | """
17 |
18 | def __init__(
19 | self,
20 | config,
21 | sde,
22 | clf_fn,
23 | score_fn=None,
24 | criterion=None,
25 | sampler_cls=None,
26 | corrector_fn=None,
27 | ):
28 | super().__init__()
29 | self.config = config
30 | self.sde = sde
31 | self.clf_fn = clf_fn
32 |
33 | # Training
34 | self.criterion = criterion
35 | self.train_eps = self.config.diffusion.training.train_eps
36 |
37 | # Evaluation
38 | self.score_fn = score_fn
39 | self.sampler = None
40 |
41 | if sampler_cls is not None:
42 | self.sampler = sampler_cls(
43 | self.config,
44 | self.sde,
45 | self.score_fn,
46 | self.clf_fn,
47 | corrector_fn=corrector_fn,
48 | )
49 | self.eval_eps = self.config.diffusion.evaluation.eval_eps
50 | self.denoise = self.config.diffusion.evaluation.denoise
51 | n_discrete_steps = self.config.diffusion.evaluation.n_discrete_steps
52 | self.n_discrete_steps = (
53 | n_discrete_steps - 1 if self.denoise else n_discrete_steps
54 | )
55 | self.val_eps = self.config.diffusion.evaluation.eval_eps
56 | self.stride_type = self.config.diffusion.evaluation.stride_type
57 |
58 | def forward(self):
59 | pass
60 |
61 | def training_step(self, batch, batch_idx):
62 | # Images and labels
63 | x_0, y = batch
64 |
65 | # Sample timepoints (between [train_eps, 1])
66 | t_ = torch.rand(x_0.shape[0], device=x_0.device, dtype=torch.float64)
67 | t = t_ * (self.sde.T - self.train_eps) + self.train_eps
68 | assert t.shape[0] == x_0.shape[0]
69 |
70 | # Compute loss and backward
71 | loss, acc = self.criterion(x_0, y, t, self.clf_fn)
72 |
73 | self.log("loss", loss, prog_bar=True)
74 | self.log("Top1-Acc", acc, prog_bar=True)
75 | return loss
76 |
77 | def on_predict_start(self):
78 | seed = self.config.clf.evaluation.seed
79 |
80 |         # Offset the seed by global rank for prediction: a common seed
81 |         # across GPUs would generate identical samples on every device,
82 |         # which negatively affects evaluation metrics such as FID.
83 | seed_everything(seed + self.global_rank, workers=True)
84 |
85 | def predict_step(self, batch, batch_idx, dataloader_idx=None):
86 | t_final = self.sde.T - self.eval_eps
87 | ts = torch.linspace(
88 | 0,
89 | t_final,
90 | self.n_discrete_steps + 1,
91 | device=batch.device,
92 | dtype=torch.float64,
93 | )
94 |
95 | if self.stride_type == "uniform":
96 | pass
97 | elif self.stride_type == "quadratic":
98 | ts = t_final * torch.flip(1 - (ts / t_final) ** 2.0, dims=[0])
99 |
100 | return self.sampler.sample(
101 | batch, ts, self.n_discrete_steps, denoise=self.denoise, eps=self.eval_eps
102 | )
103 |
104 | def on_predict_end(self):
105 | if isinstance(self.sampler, get_module("samplers", "bb_ode")):
106 | print(self.sampler.mean_nfe)
107 |
108 | def configure_optimizers(self):
109 | opt_config = self.config.clf.training.optimizer
110 | opt_name = opt_config.name
111 | if opt_name == "Adam":
112 | optimizer = torch.optim.Adam(
113 | self.clf_fn.parameters(),
114 | lr=opt_config.lr,
115 | betas=(opt_config.beta_1, opt_config.beta_2),
116 | eps=opt_config.eps,
117 | weight_decay=opt_config.weight_decay,
118 | )
119 | else:
120 | raise NotImplementedError(f"Optimizer {opt_name} not supported yet!")
121 |
122 |         # Define the LR scheduler (as in Ho et al.)
123 | if opt_config.warmup == 0:
124 | lr_lambda = lambda step: 1.0
125 | else:
126 | lr_lambda = lambda step: min(step / opt_config.warmup, 1.0)
127 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
128 | return {
129 | "optimizer": optimizer,
130 | "lr_scheduler": {
131 | "scheduler": scheduler,
132 | "interval": "step",
133 | "strict": False,
134 | },
135 | }
136 |
--------------------------------------------------------------------------------
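
`predict_step` above supports two timestep grids. The snippet below makes the difference concrete: `uniform` keeps the `linspace` spacing, while the `quadratic` flip-and-square map concentrates grid points toward `t_final`, i.e. fine steps late in the integration:

import torch

t_final, n = 1.0 - 1e-3, 8
ts = torch.linspace(0, t_final, n + 1, dtype=torch.float64)           # "uniform"
ts_quad = t_final * torch.flip(1 - (ts / t_final) ** 2.0, dims=[0])   # "quadratic"

print(ts)        # evenly spaced in [0, t_final]
print(ts_quad)   # step size shrinks as t approaches t_final
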
/main/models/score_fn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/score_fn/__init__.py
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/score_fn/song_sde/__init__.py
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/layerspp.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # pylint: skip-file
17 | """Layers for defining NCSN++.
18 | """
19 | from . import layers
20 | from . import up_or_down_sampling
21 | import torch.nn as nn
22 | import torch
23 | import torch.nn.functional as F
24 | import numpy as np
25 |
26 | conv1x1 = layers.ddpm_conv1x1
27 | conv3x3 = layers.ddpm_conv3x3
28 | NIN = layers.NIN
29 | default_init = layers.default_init
30 |
31 |
32 | class GaussianFourierProjection(nn.Module):
33 | """Gaussian Fourier embeddings for noise levels."""
34 |
35 | def __init__(self, embedding_size=256, scale=1.0):
36 | super().__init__()
37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
38 |
39 | def forward(self, x):
40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
42 |
43 |
44 | class Combine(nn.Module):
45 | """Combine information from skip connections."""
46 |
47 | def __init__(self, dim1, dim2, method='cat'):
48 | super().__init__()
49 | self.Conv_0 = conv1x1(dim1, dim2)
50 | self.method = method
51 |
52 | def forward(self, x, y):
53 | h = self.Conv_0(x)
54 | if self.method == 'cat':
55 | return torch.cat([h, y], dim=1)
56 | elif self.method == 'sum':
57 | return h + y
58 | else:
59 | raise ValueError(f'Method {self.method} not recognized.')
60 |
61 |
62 | class AttnBlockpp(nn.Module):
63 | """Channel-wise self-attention block. Modified from DDPM."""
64 |
65 | def __init__(self, channels, skip_rescale=False, init_scale=0.):
66 | super().__init__()
67 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
68 | eps=1e-6)
69 | self.NIN_0 = NIN(channels, channels)
70 | self.NIN_1 = NIN(channels, channels)
71 | self.NIN_2 = NIN(channels, channels)
72 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
73 | self.skip_rescale = skip_rescale
74 |
75 | def forward(self, x):
76 | B, C, H, W = x.shape
77 | h = self.GroupNorm_0(x)
78 | q = self.NIN_0(h)
79 | k = self.NIN_1(h)
80 | v = self.NIN_2(h)
81 |
82 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
83 | w = torch.reshape(w, (B, H, W, H * W))
84 | w = F.softmax(w, dim=-1)
85 | w = torch.reshape(w, (B, H, W, H, W))
86 | h = torch.einsum('bhwij,bcij->bchw', w, v)
87 | h = self.NIN_3(h)
88 | if not self.skip_rescale:
89 | return x + h
90 | else:
91 | return (x + h) / np.sqrt(2.)
92 |
93 |
94 | class Upsample(nn.Module):
95 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
96 | fir_kernel=(1, 3, 3, 1)):
97 | super().__init__()
98 | out_ch = out_ch if out_ch else in_ch
99 | if not fir:
100 | if with_conv:
101 | self.Conv_0 = conv3x3(in_ch, out_ch)
102 | else:
103 | if with_conv:
104 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
105 | kernel=3, up=True,
106 | resample_kernel=fir_kernel,
107 | use_bias=True,
108 | kernel_init=default_init())
109 | self.fir = fir
110 | self.with_conv = with_conv
111 | self.fir_kernel = fir_kernel
112 | self.out_ch = out_ch
113 |
114 | def forward(self, x):
115 | B, C, H, W = x.shape
116 | if not self.fir:
117 |             h = F.interpolate(x, size=(H * 2, W * 2), mode='nearest')  # keyword args; 3rd positional is scale_factor, not mode
118 | if self.with_conv:
119 | h = self.Conv_0(h)
120 | else:
121 | if not self.with_conv:
122 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
123 | else:
124 | h = self.Conv2d_0(x)
125 |
126 | return h
127 |
128 |
129 | class Downsample(nn.Module):
130 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
131 | fir_kernel=(1, 3, 3, 1)):
132 | super().__init__()
133 | out_ch = out_ch if out_ch else in_ch
134 | if not fir:
135 | if with_conv:
136 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
137 | else:
138 | if with_conv:
139 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
140 | kernel=3, down=True,
141 | resample_kernel=fir_kernel,
142 | use_bias=True,
143 | kernel_init=default_init())
144 | self.fir = fir
145 | self.fir_kernel = fir_kernel
146 | self.with_conv = with_conv
147 | self.out_ch = out_ch
148 |
149 | def forward(self, x):
150 | B, C, H, W = x.shape
151 | if not self.fir:
152 | if self.with_conv:
153 | x = F.pad(x, (0, 1, 0, 1))
154 | x = self.Conv_0(x)
155 | else:
156 | x = F.avg_pool2d(x, 2, stride=2)
157 | else:
158 | if not self.with_conv:
159 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
160 | else:
161 | x = self.Conv2d_0(x)
162 |
163 | return x
164 |
165 |
166 | class ResnetBlockDDPMpp(nn.Module):
167 | """ResBlock adapted from DDPM."""
168 |
169 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
170 | dropout=0.1, skip_rescale=False, init_scale=0.):
171 | super().__init__()
172 | out_ch = out_ch if out_ch else in_ch
173 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
174 | self.Conv_0 = conv3x3(in_ch, out_ch)
175 | if temb_dim is not None:
176 | self.Dense_0 = nn.Linear(temb_dim, out_ch)
177 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
178 | nn.init.zeros_(self.Dense_0.bias)
179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
180 | self.Dropout_0 = nn.Dropout(dropout)
181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
182 | if in_ch != out_ch:
183 | if conv_shortcut:
184 | self.Conv_2 = conv3x3(in_ch, out_ch)
185 | else:
186 | self.NIN_0 = NIN(in_ch, out_ch)
187 |
188 | self.skip_rescale = skip_rescale
189 | self.act = act
190 | self.out_ch = out_ch
191 | self.conv_shortcut = conv_shortcut
192 |
193 | def forward(self, x, temb=None):
194 | h = self.act(self.GroupNorm_0(x))
195 | h = self.Conv_0(h)
196 | if temb is not None:
197 | h += self.Dense_0(self.act(temb))[:, :, None, None]
198 | h = self.act(self.GroupNorm_1(h))
199 | h = self.Dropout_0(h)
200 | h = self.Conv_1(h)
201 | if x.shape[1] != self.out_ch:
202 | if self.conv_shortcut:
203 | x = self.Conv_2(x)
204 | else:
205 | x = self.NIN_0(x)
206 | if not self.skip_rescale:
207 | return x + h
208 | else:
209 | return (x + h) / np.sqrt(2.)
210 |
211 |
212 | class ResnetBlockBigGANpp(nn.Module):
213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
215 | skip_rescale=True, init_scale=0.):
216 | super().__init__()
217 |
218 | out_ch = out_ch if out_ch else in_ch
219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
220 | self.up = up
221 | self.down = down
222 | self.fir = fir
223 | self.fir_kernel = fir_kernel
224 |
225 | self.Conv_0 = conv3x3(in_ch, out_ch)
226 | if temb_dim is not None:
227 | self.Dense_0 = nn.Linear(temb_dim, out_ch)
228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
229 | nn.init.zeros_(self.Dense_0.bias)
230 |
231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
232 | self.Dropout_0 = nn.Dropout(dropout)
233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
234 | if in_ch != out_ch or up or down:
235 | self.Conv_2 = conv1x1(in_ch, out_ch)
236 |
237 | self.skip_rescale = skip_rescale
238 | self.act = act
239 | self.in_ch = in_ch
240 | self.out_ch = out_ch
241 |
242 | def forward(self, x, temb=None):
243 | h = self.act(self.GroupNorm_0(x))
244 |
245 | if self.up:
246 | if self.fir:
247 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
249 | else:
250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
252 | elif self.down:
253 | if self.fir:
254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
256 | else:
257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
259 |
260 | h = self.Conv_0(h)
261 | # Add bias to each feature map conditioned on the time embedding
262 | if temb is not None:
263 | h += self.Dense_0(self.act(temb))[:, :, None, None]
264 | h = self.act(self.GroupNorm_1(h))
265 | h = self.Dropout_0(h)
266 | h = self.Conv_1(h)
267 |
268 | if self.in_ch != self.out_ch or self.up or self.down:
269 | x = self.Conv_2(x)
270 |
271 | if not self.skip_rescale:
272 | return x + h
273 | else:
274 | return (x + h) / np.sqrt(2.)
275 |
--------------------------------------------------------------------------------
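
`GaussianFourierProjection` above is random Fourier features of the noise level: frozen random frequencies `W`, output `[sin(2*pi*t*W), cos(2*pi*t*W)]`, so `embedding_size = nf` yields a `2 * nf`-dimensional embedding (matching `embed_dim = 2 * nf` in the model files). The same computation inlined for a quick shape check:

import numpy as np
import torch

embedding_size, scale = 256, 16.0
W = torch.randn(embedding_size) * scale  # frozen, mirroring the non-trainable nn.Parameter
t = torch.rand(8)                        # batch of noise levels
x_proj = t[:, None] * W[None, :] * 2 * np.pi
emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
assert emb.shape == (8, 2 * embedding_size)
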
/main/models/score_fn/song_sde/ncsnpp_clf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # pylint: skip-file
17 |
18 | from . import utils, layers, layerspp, normalization
19 | from util import register_module
20 | import torch.nn as nn
21 | import functools
22 | import torch
23 | import numpy as np
24 |
25 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
26 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
27 | Combine = layerspp.Combine
28 | conv3x3 = layerspp.conv3x3
29 | conv1x1 = layerspp.conv1x1
30 | get_act = layers.get_act
31 | get_normalization = normalization.get_normalization
32 | default_initializer = layers.default_init
33 |
34 |
35 | @register_module(category="clf_fn", name="ncsnpp_clf")
36 | class NCSNppClassifier(nn.Module):
37 | """NCSN++ Classifier model for Class conditional generation. Same as the NCSN++ but removes the upsampling stage"""
38 |
39 | def __init__(self, config):
40 | super().__init__()
41 | self.config = config.model
42 | self.act = act = get_act(self.config.clf_fn)
43 |
44 | self.nf = nf = self.config.clf_fn.nf
45 | ch_mult = self.config.clf_fn.ch_mult
46 | self.num_res_blocks = num_res_blocks = self.config.clf_fn.num_res_blocks
47 | self.attn_resolutions = attn_resolutions = self.config.clf_fn.attn_resolutions
48 | dropout = self.config.clf_fn.dropout
49 | resamp_with_conv = self.config.clf_fn.resamp_with_conv
50 | self.num_resolutions = num_resolutions = len(ch_mult)
51 | self.all_resolutions = all_resolutions = [
52 | config.data.image_size // (2**i) for i in range(num_resolutions)
53 | ]
54 |
55 | self.noise_cond = (
56 | noise_cond
57 | ) = self.config.clf_fn.noise_cond # noise-conditional
58 | fir = self.config.clf_fn.fir
59 | fir_kernel = self.config.clf_fn.fir_kernel
60 | self.skip_rescale = skip_rescale = self.config.clf_fn.skip_rescale
61 | self.resblock_type = resblock_type = self.config.clf_fn.resblock_type.lower()
62 | self.progressive = progressive = self.config.clf_fn.progressive.lower()
63 | self.progressive_input = (
64 | progressive_input
65 | ) = self.config.clf_fn.progressive_input.lower()
66 | self.embedding_type = embedding_type = self.config.clf_fn.embedding_type.lower()
67 | init_scale = self.config.clf_fn.init_scale
68 | assert progressive in ["none", "output_skip", "residual"]
69 | assert progressive_input in ["none", "input_skip", "residual"]
70 | assert embedding_type in ["fourier", "positional"]
71 | combine_method = self.config.clf_fn.progressive_combine.lower()
72 | combiner = functools.partial(Combine, method=combine_method)
73 |
74 | modules = []
75 | # timestep/noise_level embedding; only for continuous training
76 | if embedding_type == "fourier":
77 | # Gaussian Fourier features embeddings.
78 | assert (
79 | config.training.continuous
80 | ), "Fourier features are only used for continuous training."
81 |
82 | modules.append(
83 | layerspp.GaussianFourierProjection(
84 | embedding_size=nf, scale=self.config.clf_fn.fourier_scale
85 | )
86 | )
87 | embed_dim = 2 * nf
88 |
89 | elif embedding_type == "positional":
90 | embed_dim = nf
91 |
92 | else:
93 | raise ValueError(f"embedding type {embedding_type} unknown.")
94 |
95 | if noise_cond:
96 | modules.append(nn.Linear(embed_dim, nf * 4))
97 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
98 | nn.init.zeros_(modules[-1].bias)
99 | modules.append(nn.Linear(nf * 4, nf * 4))
100 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
101 | nn.init.zeros_(modules[-1].bias)
102 |
103 | AttnBlock = functools.partial(
104 | layerspp.AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale
105 | )
106 |
107 | Downsample = functools.partial(
108 | layerspp.Downsample,
109 | with_conv=resamp_with_conv,
110 | fir=fir,
111 | fir_kernel=fir_kernel,
112 | )
113 |
114 | if progressive_input == "input_skip":
115 | self.pyramid_downsample = layerspp.Downsample(
116 | fir=fir, fir_kernel=fir_kernel, with_conv=False
117 | )
118 | elif progressive_input == "residual":
119 | pyramid_downsample = functools.partial(
120 | layerspp.Downsample, fir=fir, fir_kernel=fir_kernel, with_conv=True
121 | )
122 |
123 | if resblock_type == "ddpm":
124 | ResnetBlock = functools.partial(
125 | ResnetBlockDDPM,
126 | act=act,
127 | dropout=dropout,
128 | init_scale=init_scale,
129 | skip_rescale=skip_rescale,
130 | temb_dim=nf * 4,
131 | )
132 |
133 | elif resblock_type == "biggan":
134 | ResnetBlock = functools.partial(
135 | ResnetBlockBigGAN,
136 | act=act,
137 | dropout=dropout,
138 | fir=fir,
139 | fir_kernel=fir_kernel,
140 | init_scale=init_scale,
141 | skip_rescale=skip_rescale,
142 | temb_dim=nf * 4,
143 | )
144 |
145 | else:
146 | raise ValueError(f"resblock type {resblock_type} unrecognized.")
147 |
148 | # Downsampling block
149 |
150 | channels = self.config.clf_fn.in_ch
151 | if progressive_input != "none":
152 | input_pyramid_ch = channels
153 |
154 | modules.append(conv3x3(channels, nf))
155 | hs_c = [nf]
156 |
157 | in_ch = nf
158 | for i_level in range(num_resolutions):
159 | # Residual blocks for this resolution
160 | for i_block in range(num_res_blocks):
161 | out_ch = nf * ch_mult[i_level]
162 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
163 | in_ch = out_ch
164 |
165 | if all_resolutions[i_level] in attn_resolutions:
166 | modules.append(AttnBlock(channels=in_ch))
167 | hs_c.append(in_ch)
168 |
169 | if i_level != num_resolutions - 1:
170 | if resblock_type == "ddpm":
171 | modules.append(Downsample(in_ch=in_ch))
172 | else:
173 | modules.append(ResnetBlock(down=True, in_ch=in_ch))
174 |
175 | if progressive_input == "input_skip":
176 | modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
177 | if combine_method == "cat":
178 | in_ch *= 2
179 |
180 | elif progressive_input == "residual":
181 | modules.append(
182 | pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)
183 | )
184 | input_pyramid_ch = in_ch
185 |
186 | hs_c.append(in_ch)
187 |
188 | in_ch = hs_c[-1]
189 | modules.append(ResnetBlock(in_ch=in_ch))
190 | modules.append(AttnBlock(channels=in_ch))
191 | modules.append(ResnetBlock(in_ch=in_ch))
192 |
193 | # Construct classifier
194 | self.n_cls = config.model.clf_fn.n_cls
195 | last_res = all_resolutions[-1]
196 | modules.append(nn.Linear(in_ch * last_res**2, self.n_cls, bias=False))
197 |
198 | self.all_modules = nn.ModuleList(modules)
199 |
200 | def forward(self, x, time_cond):
201 | # TODO: Add label and other forms of conditioning here!
202 | # timestep/noise_level embedding; only for continuous training
203 | modules = self.all_modules
204 | m_idx = 0
205 | if self.embedding_type == "fourier":
206 | # Gaussian Fourier features embeddings.
207 | used_sigmas = time_cond
208 | temb = modules[m_idx](torch.log(used_sigmas))
209 | m_idx += 1
210 |
211 | elif self.embedding_type == "positional":
212 | # Sinusoidal positional embeddings.
213 | timesteps = time_cond
214 | # used_sigmas = self.sigmas[time_cond.long()]
215 | temb = layers.get_timestep_embedding(timesteps, self.nf)
216 |
217 | else:
218 | raise ValueError(f"embedding type {self.embedding_type} unknown.")
219 |
220 | if self.noise_cond:
221 | temb = modules[m_idx](temb)
222 | m_idx += 1
223 | temb = modules[m_idx](self.act(temb))
224 | m_idx += 1
225 | else:
226 | temb = None
227 |
228 | # Downsampling block
229 | input_pyramid = None
230 | if self.progressive_input != "none":
231 | input_pyramid = x
232 |
233 | hs = [modules[m_idx](x)]
234 | m_idx += 1
235 | for i_level in range(self.num_resolutions):
236 | # Residual blocks for this resolution
237 | for i_block in range(self.num_res_blocks):
238 | h = modules[m_idx](hs[-1], temb)
239 | m_idx += 1
240 | if h.shape[-1] in self.attn_resolutions:
241 | h = modules[m_idx](h)
242 | m_idx += 1
243 |
244 | hs.append(h)
245 |
246 | if i_level != self.num_resolutions - 1:
247 | if self.resblock_type == "ddpm":
248 | h = modules[m_idx](hs[-1])
249 | m_idx += 1
250 | else:
251 | h = modules[m_idx](hs[-1], temb)
252 | m_idx += 1
253 |
254 | if self.progressive_input == "input_skip":
255 | input_pyramid = self.pyramid_downsample(input_pyramid)
256 | h = modules[m_idx](input_pyramid, h)
257 | m_idx += 1
258 |
259 | elif self.progressive_input == "residual":
260 | input_pyramid = modules[m_idx](input_pyramid)
261 | m_idx += 1
262 | if self.skip_rescale:
263 | input_pyramid = (input_pyramid + h) / np.sqrt(2.0)
264 | else:
265 | input_pyramid = input_pyramid + h
266 | h = input_pyramid
267 |
268 | hs.append(h)
269 |
270 | h = hs[-1]
271 | h = modules[m_idx](h, temb)
272 | m_idx += 1
273 | h = modules[m_idx](h)
274 | m_idx += 1
275 | h = modules[m_idx](h, temb)
276 | m_idx += 1
277 |
278 | # Apply Classifier
279 | h = torch.flatten(h, start_dim=1)
280 | h = modules[m_idx](h)
281 |
282 | assert h.shape[-1] == self.n_cls
283 | return h
284 |
--------------------------------------------------------------------------------
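
The classifier above is consumed by the guided samplers (in `main/samplers/`, not shown in this part of the listing). The standard way a noise-conditional classifier enters sampling is through the gradient of `log p(y | u_t, t)` with respect to the noisy state; a hedged autograd sketch, with `clf_fn` standing in for an `NCSNppClassifier` instance:

import torch


def classifier_grad(clf_fn, u_t, t, y):
    """Gradient of log p(y | u_t, t) w.r.t. u_t, for a logits-producing clf_fn."""
    u_t = u_t.detach().requires_grad_(True)
    logits = clf_fn(u_t, t)
    log_probs = torch.log_softmax(logits, dim=-1)
    selected = log_probs[torch.arange(y.shape[0], device=y.device), y].sum()
    return torch.autograd.grad(selected, u_t)[0]
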
/main/models/score_fn/song_sde/normalization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Normalization layers."""
17 | import torch.nn as nn
18 | import torch
19 | import functools
20 |
21 |
22 | def get_normalization(config, conditional=False):
23 | """Obtain normalization modules from the config file."""
24 | norm = config.model.normalization
25 | if conditional:
26 | if norm == 'InstanceNorm++':
27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
28 | else:
29 | raise NotImplementedError(f'{norm} not implemented yet.')
30 | else:
31 | if norm == 'InstanceNorm':
32 | return nn.InstanceNorm2d
33 | elif norm == 'InstanceNorm++':
34 | return InstanceNorm2dPlus
35 | elif norm == 'VarianceNorm':
36 | return VarianceNorm2d
37 | elif norm == 'GroupNorm':
38 | return nn.GroupNorm
39 | else:
40 | raise ValueError('Unknown normalization: %s' % norm)
41 |
42 |
43 | class ConditionalBatchNorm2d(nn.Module):
44 | def __init__(self, num_features, num_classes, bias=True):
45 | super().__init__()
46 | self.num_features = num_features
47 | self.bias = bias
48 | self.bn = nn.BatchNorm2d(num_features, affine=False)
49 | if self.bias:
50 | self.embed = nn.Embedding(num_classes, num_features * 2)
51 |             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
53 | else:
54 | self.embed = nn.Embedding(num_classes, num_features)
55 | self.embed.weight.data.uniform_()
56 |
57 | def forward(self, x, y):
58 | out = self.bn(x)
59 | if self.bias:
60 | gamma, beta = self.embed(y).chunk(2, dim=1)
61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
62 | else:
63 | gamma = self.embed(y)
64 | out = gamma.view(-1, self.num_features, 1, 1) * out
65 | return out
66 |
67 |
68 | class ConditionalInstanceNorm2d(nn.Module):
69 | def __init__(self, num_features, num_classes, bias=True):
70 | super().__init__()
71 | self.num_features = num_features
72 | self.bias = bias
73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
74 | if bias:
75 | self.embed = nn.Embedding(num_classes, num_features * 2)
76 |             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
78 | else:
79 | self.embed = nn.Embedding(num_classes, num_features)
80 | self.embed.weight.data.uniform_()
81 |
82 | def forward(self, x, y):
83 | h = self.instance_norm(x)
84 | if self.bias:
85 | gamma, beta = self.embed(y).chunk(2, dim=-1)
86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
87 | else:
88 | gamma = self.embed(y)
89 | out = gamma.view(-1, self.num_features, 1, 1) * h
90 | return out
91 |
92 |
93 | class ConditionalVarianceNorm2d(nn.Module):
94 | def __init__(self, num_features, num_classes, bias=False):
95 | super().__init__()
96 | self.num_features = num_features
97 | self.bias = bias
98 | self.embed = nn.Embedding(num_classes, num_features)
99 | self.embed.weight.data.normal_(1, 0.02)
100 |
101 | def forward(self, x, y):
102 | vars = torch.var(x, dim=(2, 3), keepdim=True)
103 | h = x / torch.sqrt(vars + 1e-5)
104 |
105 | gamma = self.embed(y)
106 | out = gamma.view(-1, self.num_features, 1, 1) * h
107 | return out
108 |
109 |
110 | class VarianceNorm2d(nn.Module):
111 | def __init__(self, num_features, bias=False):
112 | super().__init__()
113 | self.num_features = num_features
114 | self.bias = bias
115 | self.alpha = nn.Parameter(torch.zeros(num_features))
116 | self.alpha.data.normal_(1, 0.02)
117 |
118 | def forward(self, x):
119 | vars = torch.var(x, dim=(2, 3), keepdim=True)
120 | h = x / torch.sqrt(vars + 1e-5)
121 |
122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h
123 | return out
124 |
125 |
126 | class ConditionalNoneNorm2d(nn.Module):
127 | def __init__(self, num_features, num_classes, bias=True):
128 | super().__init__()
129 | self.num_features = num_features
130 | self.bias = bias
131 | if bias:
132 | self.embed = nn.Embedding(num_classes, num_features * 2)
133 |             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
135 | else:
136 | self.embed = nn.Embedding(num_classes, num_features)
137 | self.embed.weight.data.uniform_()
138 |
139 | def forward(self, x, y):
140 | if self.bias:
141 | gamma, beta = self.embed(y).chunk(2, dim=-1)
142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
143 | else:
144 | gamma = self.embed(y)
145 | out = gamma.view(-1, self.num_features, 1, 1) * x
146 | return out
147 |
148 |
149 | class NoneNorm2d(nn.Module):
150 | def __init__(self, num_features, bias=True):
151 | super().__init__()
152 |
153 | def forward(self, x):
154 | return x
155 |
156 |
157 | class InstanceNorm2dPlus(nn.Module):
158 | def __init__(self, num_features, bias=True):
159 | super().__init__()
160 | self.num_features = num_features
161 | self.bias = bias
162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
163 | self.alpha = nn.Parameter(torch.zeros(num_features))
164 | self.gamma = nn.Parameter(torch.zeros(num_features))
165 | self.alpha.data.normal_(1, 0.02)
166 | self.gamma.data.normal_(1, 0.02)
167 | if bias:
168 | self.beta = nn.Parameter(torch.zeros(num_features))
169 |
170 | def forward(self, x):
171 | means = torch.mean(x, dim=(2, 3))
172 | m = torch.mean(means, dim=-1, keepdim=True)
173 | v = torch.var(means, dim=-1, keepdim=True)
174 | means = (means - m) / (torch.sqrt(v + 1e-5))
175 | h = self.instance_norm(x)
176 |
177 | if self.bias:
178 | h = h + means[..., None, None] * self.alpha[..., None, None]
179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
180 | else:
181 | h = h + means[..., None, None] * self.alpha[..., None, None]
182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h
183 | return out
184 |
185 |
186 | class ConditionalInstanceNorm2dPlus(nn.Module):
187 | def __init__(self, num_features, num_classes, bias=True):
188 | super().__init__()
189 | self.num_features = num_features
190 | self.bias = bias
191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
192 | if bias:
193 | self.embed = nn.Embedding(num_classes, num_features * 3)
194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
196 | else:
197 | self.embed = nn.Embedding(num_classes, 2 * num_features)
198 | self.embed.weight.data.normal_(1, 0.02)
199 |
200 | def forward(self, x, y):
201 | means = torch.mean(x, dim=(2, 3))
202 | m = torch.mean(means, dim=-1, keepdim=True)
203 | v = torch.var(means, dim=-1, keepdim=True)
204 | means = (means - m) / (torch.sqrt(v + 1e-5))
205 | h = self.instance_norm(x)
206 |
207 | if self.bias:
208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
209 | h = h + means[..., None, None] * alpha[..., None, None]
210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
211 | else:
212 | gamma, alpha = self.embed(y).chunk(2, dim=-1)
213 | h = h + means[..., None, None] * alpha[..., None, None]
214 | out = gamma.view(-1, self.num_features, 1, 1) * h
215 | return out
216 |
--------------------------------------------------------------------------------
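
A quick sanity check for `InstanceNorm2dPlus` above (assuming the class is in scope, e.g. via an import of this module; the exact path depends on how the package is laid out): it is shape-preserving, and unlike plain instance norm it re-injects the normalized per-channel means through the learned `alpha` pathway.

import torch

norm = InstanceNorm2dPlus(num_features=32)
x = torch.randn(4, 32, 16, 16)
out = norm(x)
assert out.shape == x.shape
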
/main/models/score_fn/song_sde/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | grad_bias = grad_input.sum(dim).detach()
39 |
40 | return grad_input, grad_bias
41 |
42 | @staticmethod
43 | def backward(ctx, gradgrad_input, gradgrad_bias):
44 | out, = ctx.saved_tensors
45 | gradgrad_out = fused.fused_bias_act(
46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47 | )
48 |
49 | return gradgrad_out, None, None, None
50 |
51 |
52 | class FusedLeakyReLUFunction(Function):
53 | @staticmethod
54 | def forward(ctx, input, bias, negative_slope, scale):
55 | empty = input.new_empty(0)
56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57 | ctx.save_for_backward(out)
58 | ctx.negative_slope = negative_slope
59 | ctx.scale = scale
60 |
61 | return out
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | out, = ctx.saved_tensors
66 |
67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68 | grad_output, out, ctx.negative_slope, ctx.scale
69 | )
70 |
71 | return grad_input, grad_bias, None, None
72 |
73 |
74 | class FusedLeakyReLU(nn.Module):
75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76 | super().__init__()
77 |
78 | self.bias = nn.Parameter(torch.zeros(channel))
79 | self.negative_slope = negative_slope
80 | self.scale = scale
81 |
82 | def forward(self, input):
83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84 |
85 |
86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87 | if input.device.type == "cpu":
88 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
89 | return (
90 | F.leaky_relu(
91 |                 input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
92 | )
93 | * scale
94 | )
95 |
96 | else:
97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
98 |
--------------------------------------------------------------------------------
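
On CPU, `fused_leaky_relu` above skips the CUDA extension entirely: it is just bias-add, leaky ReLU, and a sqrt(2) rescale (the StyleGAN2 gain that roughly preserves activation scale through the nonlinearity). The same computation written out directly:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
bias = torch.zeros(8)
y = F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope=0.2) * (2 ** 0.5)
assert y.shape == x.shape
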
/main/models/score_fn/song_sde/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |     fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |         y.data_ptr<scalar_t>(),
82 |         x.data_ptr<scalar_t>(),
83 |         b.data_ptr<scalar_t>(),
84 |         ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | upfirdn2d_op = load(
11 | "upfirdn2d",
12 | sources=[
13 | os.path.join(module_path, "upfirdn2d.cpp"),
14 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
15 | ],
16 | )
17 |
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | if input.device.type == "cpu":
147 | out = upfirdn2d_native(
148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
149 | )
150 |
151 | else:
152 | out = UpFirDn2d.apply(
153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
154 | )
155 |
156 | return out
157 |
158 |
159 | def upfirdn2d_native(
160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
161 | ):
162 | _, channel, in_h, in_w = input.shape
163 | input = input.reshape(-1, in_h, in_w, 1)
164 |
165 | _, in_h, in_w, minor = input.shape
166 | kernel_h, kernel_w = kernel.shape
167 |
168 | out = input.view(-1, in_h, 1, in_w, 1, minor)
169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
171 |
172 | out = F.pad(
173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
174 | )
175 | out = out[
176 | :,
177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
179 | :,
180 | ]
181 |
182 | out = out.permute(0, 3, 1, 2)
183 | out = out.reshape(
184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
185 | )
186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
187 | out = F.conv2d(out, w)
188 | out = out.reshape(
189 | -1,
190 | minor,
191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
193 | )
194 | out = out.permute(0, 2, 3, 1)
195 | out = out[:, ::down_y, ::down_x, :]
196 |
197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
199 |
200 | return out.view(-1, channel, out_h, out_w)
201 |
--------------------------------------------------------------------------------
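
`upfirdn2d` is upsample-by-zero-stuffing, FIR filter, then downsample, fused into one pass. A toy call to the CPU-path `upfirdn2d_native` above (assuming it is in scope): with a 2x2 box filter, `up=1`, `down=2`, and `(0, 1)` padding per axis, it acts like blur-then-stride-2 pooling and halves the spatial size.

import torch

x = torch.randn(1, 3, 8, 8)
k = torch.ones(2, 2) / 4.0  # box (average) filter
out = upfirdn2d_native(x, k, 1, 1, 2, 2, 0, 1, 0, 1)
assert out.shape == (1, 3, 4, 4)  # (8*1 + 0 + 1 - 2) // 2 + 1 == 4
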
/main/models/score_fn/song_sde/up_or_down_sampling.py:
--------------------------------------------------------------------------------
1 | """Layers used for up-sampling or down-sampling images.
2 |
3 | Many functions are ported from https://github.com/NVlabs/stylegan2.
4 | """
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | import numpy as np
10 | from .op import upfirdn2d
11 |
12 |
13 | # Function ported from StyleGAN2
14 | def get_weight(module,
15 | shape,
16 | weight_var='weight',
17 | kernel_init=None):
18 | """Get/create weight tensor for a convolution or fully-connected layer."""
19 |   # NOTE: direct TF port; torch.nn.Module has no `.param` method.
20 | return module.param(weight_var, kernel_init, shape)
21 |
22 |
23 | class Conv2d(nn.Module):
24 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
25 |
26 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
27 | resample_kernel=(1, 3, 3, 1),
28 | use_bias=True,
29 | kernel_init=None):
30 | super().__init__()
31 | assert not (up and down)
32 | assert kernel >= 1 and kernel % 2 == 1
33 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
34 | if kernel_init is not None:
35 | self.weight.data = kernel_init(self.weight.data.shape)
36 | if use_bias:
37 | self.bias = nn.Parameter(torch.zeros(out_ch))
38 |
39 | self.up = up
40 | self.down = down
41 | self.resample_kernel = resample_kernel
42 | self.kernel = kernel
43 | self.use_bias = use_bias
44 |
45 | def forward(self, x):
46 | if self.up:
47 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
48 | elif self.down:
49 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
50 | else:
51 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
52 |
53 | if self.use_bias:
54 | x = x + self.bias.reshape(1, -1, 1, 1)
55 |
56 | return x
57 |
58 |
59 | def naive_upsample_2d(x, factor=2):
60 | _N, C, H, W = x.shape
61 | x = torch.reshape(x, (-1, C, H, 1, W, 1))
62 | x = x.repeat(1, 1, 1, factor, 1, factor)
63 | return torch.reshape(x, (-1, C, H * factor, W * factor))
64 |
65 |
66 | def naive_downsample_2d(x, factor=2):
67 | _N, C, H, W = x.shape
68 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
69 | return torch.mean(x, dim=(3, 5))
70 |
71 |
72 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
73 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
74 |
75 | Padding is performed only once at the beginning, not between the
76 | operations.
77 |   The fused op is considerably more efficient than performing the same
78 |   calculation using standard TensorFlow ops. It supports gradients of
79 |   arbitrary order.
80 | Args:
81 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
82 | C]`.
83 |       w: Weight tensor of the shape `[outChannels, inChannels, kernelH,
84 |         kernelW]`. Grouped convolution can be performed by `inChannels =
85 |         x.shape[1] // numGroups`.
86 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
87 | (separable). The default is `[1] * factor`, which corresponds to
88 | nearest-neighbor upsampling.
89 | factor: Integer upsampling factor (default: 2).
90 | gain: Scaling factor for signal magnitude (default: 1.0).
91 |
92 | Returns:
93 | Tensor of the shape `[N, C, H * factor, W * factor]` or
94 | `[N, H * factor, W * factor, C]`, and same datatype as `x`.
95 | """
96 |
97 | assert isinstance(factor, int) and factor >= 1
98 |
99 | # Check weight shape.
100 | assert len(w.shape) == 4
101 | convH = w.shape[2]
102 | convW = w.shape[3]
103 | inC = w.shape[1]
104 | outC = w.shape[0]
105 |
106 | assert convW == convH
107 |
108 | # Setup filter kernel.
109 | if k is None:
110 | k = [1] * factor
111 | k = _setup_kernel(k) * (gain * (factor ** 2))
112 | p = (k.shape[0] - factor) - (convW - 1)
113 |
114 | # PyTorch expects a 2-tuple stride here; the original TF code passed a
115 | # 4-element NCHW stride to `tf.nn.conv2d_transpose`.
116 | stride = (factor, factor)
117 | # Determine data dimensions.
118 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
119 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
120 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
121 | assert output_padding[0] >= 0 and output_padding[1] >= 0
122 | num_groups = _shape(x, 1) // inC
123 |
124 | # Transpose weights.
125 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
126 | w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)  # torch tensors do not support negative-step slicing
127 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
128 |
129 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
130 | ## Original TF code.
131 | # x = tf.nn.conv2d_transpose(
132 | # x,
133 | # w,
134 | # output_shape=output_shape,
135 | # strides=stride,
136 | # padding='VALID',
137 | # data_format=data_format)
138 |
139 |
140 | return upfirdn2d(x, torch.tensor(k, device=x.device),
141 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
142 |
143 |
144 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
145 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
146 |
147 |     Padding is performed only once at the beginning, not between the operations.
148 |     The fused op is considerably more efficient than performing the same
149 |     calculation using standard TensorFlow ops. It supports gradients of
150 |     arbitrary order.
151 | Args:
152 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
153 | C]`.
154 |       w: Weight tensor of the shape `[outChannels, inChannels, kernelH,
155 |         kernelW]`. Grouped convolution can be performed by `inChannels =
156 |         x.shape[1] // numGroups`.
157 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
158 | (separable). The default is `[1] * factor`, which corresponds to
159 | average pooling.
160 | factor: Integer downsampling factor (default: 2).
161 | gain: Scaling factor for signal magnitude (default: 1.0).
162 |
163 | Returns:
164 | Tensor of the shape `[N, C, H // factor, W // factor]` or
165 | `[N, H // factor, W // factor, C]`, and same datatype as `x`.
166 | """
167 |
168 | assert isinstance(factor, int) and factor >= 1
169 | _outC, _inC, convH, convW = w.shape
170 | assert convW == convH
171 | if k is None:
172 | k = [1] * factor
173 | k = _setup_kernel(k) * gain
174 | p = (k.shape[0] - factor) + (convW - 1)
175 | s = [factor, factor]
176 | x = upfirdn2d(x, torch.tensor(k, device=x.device),
177 | pad=((p + 1) // 2, p // 2))
178 | return F.conv2d(x, w, stride=s, padding=0)
179 |
180 |
181 | def _setup_kernel(k):
182 | k = np.asarray(k, dtype=np.float32)
183 | if k.ndim == 1:
184 | k = np.outer(k, k)
185 | k /= np.sum(k)
186 | assert k.ndim == 2
187 | assert k.shape[0] == k.shape[1]
188 | return k
189 |
190 |
191 | def _shape(x, dim):
192 | return x.shape[dim]
193 |
194 |
195 | def upsample_2d(x, k=None, factor=2, gain=1):
196 | r"""Upsample a batch of 2D images with the given filter.
197 |
198 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
199 | and upsamples each image with the given filter. The filter is normalized so
200 |     that if the input pixels are constant, they will be scaled by the
201 |     specified `gain`.
202 |     Pixels outside the image are assumed to be zero, and the filter is
203 |     padded with zeros so that its shape is a multiple of the upsampling
204 |     factor.
205 |
206 | Args:
207 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
208 | C]`.
209 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
210 | (separable). The default is `[1] * factor`, which corresponds to
211 | nearest-neighbor upsampling.
212 | factor: Integer upsampling factor (default: 2).
213 | gain: Scaling factor for signal magnitude (default: 1.0).
214 |
215 | Returns:
216 | Tensor of the shape `[N, C, H * factor, W * factor]`
217 | """
218 | assert isinstance(factor, int) and factor >= 1
219 | if k is None:
220 | k = [1] * factor
221 | k = _setup_kernel(k) * (gain * (factor ** 2))
222 | p = k.shape[0] - factor
223 | return upfirdn2d(x, torch.tensor(k, device=x.device),
224 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
225 |
226 |
227 | def downsample_2d(x, k=None, factor=2, gain=1):
228 | r"""Downsample a batch of 2D images with the given filter.
229 |
230 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
231 | and downsamples each image with the given filter. The filter is normalized
232 |     so that if the input pixels are constant, they will be scaled by the
233 |     specified `gain`.
234 |     Pixels outside the image are assumed to be zero, and the filter is
235 |     padded with zeros so that its shape is a multiple of the downsampling
236 |     factor.
237 |
238 | Args:
239 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
240 | C]`.
241 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
242 | (separable). The default is `[1] * factor`, which corresponds to
243 | average pooling.
244 | factor: Integer downsampling factor (default: 2).
245 | gain: Scaling factor for signal magnitude (default: 1.0).
246 |
247 | Returns:
248 | Tensor of the shape `[N, C, H // factor, W // factor]`
249 | """
250 |
251 | assert isinstance(factor, int) and factor >= 1
252 | if k is None:
253 | k = [1] * factor
254 | k = _setup_kernel(k) * gain
255 | p = k.shape[0] - factor
256 | return upfirdn2d(x, torch.tensor(k, device=x.device),
257 | down=factor, pad=((p + 1) // 2, p // 2))
258 |
--------------------------------------------------------------------------------
/main/models/score_fn/song_sde/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """All functions and modules related to model definition.
17 | """
18 |
19 | import torch
20 | import numpy as np
21 |
22 |
23 | _MODELS = {}
24 |
25 |
26 | def register_model(cls=None, *, name=None):
27 | """A decorator for registering model classes."""
28 |
29 | def _register(cls):
30 | if name is None:
31 | local_name = cls.__name__
32 | else:
33 | local_name = name
34 | if local_name in _MODELS:
35 | raise ValueError(f'Already registered model with name: {local_name}')
36 | _MODELS[local_name] = cls
37 | return cls
38 |
39 | if cls is None:
40 | return _register
41 | else:
42 | return _register(cls)
43 |
44 |
45 | def get_model(name):
46 | return _MODELS[name]
47 |
48 |
49 | def get_sigmas(config):
50 | """Get sigmas --- the set of noise levels for SMLD from config files.
51 | Args:
52 | config: A ConfigDict object parsed from the config file
53 | Returns:
54 |     sigmas: a numpy array of noise levels
55 | """
56 | sigmas = np.exp(
57 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
58 |
59 | return sigmas
60 |
61 |
62 | def get_ddpm_params(config):
63 | """Get betas and alphas --- parameters used in the original DDPM paper."""
64 | num_diffusion_timesteps = 1000
65 | # parameters need to be adapted if number of time steps differs from 1000
66 | beta_start = config.model.beta_min / config.model.num_scales
67 | beta_end = config.model.beta_max / config.model.num_scales
68 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
69 |
70 | alphas = 1. - betas
71 | alphas_cumprod = np.cumprod(alphas, axis=0)
72 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
73 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
74 |
75 | return {
76 | 'betas': betas,
77 | 'alphas': alphas,
78 | 'alphas_cumprod': alphas_cumprod,
79 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
80 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
81 | 'beta_min': beta_start * (num_diffusion_timesteps - 1),
82 | 'beta_max': beta_end * (num_diffusion_timesteps - 1),
83 | 'num_diffusion_timesteps': num_diffusion_timesteps
84 | }
85 |
86 |
87 | def create_model(config):
88 | """Create the score model."""
89 | model_name = config.model.name
90 | score_model = get_model(model_name)(config)
91 | score_model = score_model.to(config.device)
92 | score_model = torch.nn.DataParallel(score_model)
93 | return score_model
94 |
95 |
96 | def get_model_fn(model, train=False):
97 | """Create a function to give the output of the score-based model.
98 |
99 | Args:
100 | model: The score model.
101 | train: `True` for training and `False` for evaluation.
102 |
103 | Returns:
104 | A model function.
105 | """
106 |
107 | def model_fn(x, labels):
108 | """Compute the output of the score-based model.
109 |
110 | Args:
111 | x: A mini-batch of input data.
112 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
113 | for different models.
114 |
115 | Returns:
116 | A tuple of (model output, new mutable states)
117 | """
118 | if not train:
119 | model.eval()
120 | return model(x, labels)
121 | else:
122 | model.train()
123 | return model(x, labels)
124 |
125 | return model_fn
126 |
127 |
128 | # def get_score_fn(sde, model, train=False, continuous=False):
129 | # """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
130 |
131 | # Args:
132 | # sde: An `sde_lib.SDE` object that represents the forward SDE.
133 | # model: A score model.
134 | # train: `True` for training and `False` for evaluation.
135 | # continuous: If `True`, the score-based model is expected to directly take continuous time steps.
136 |
137 | # Returns:
138 | # A score function.
139 | # """
140 | # model_fn = get_model_fn(model, train=train)
141 |
142 | # if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
143 | # def score_fn(x, t):
144 | # # Scale neural network output by standard deviation and flip sign
145 | # if continuous or isinstance(sde, sde_lib.subVPSDE):
146 | # # For VP-trained models, t=0 corresponds to the lowest noise level
147 | # The maximum value of the time embedding is assumed to be 999 for
148 | # # continuously-trained models.
149 | # labels = t * 999
150 | # score = model_fn(x, labels)
151 | # std = sde.marginal_prob(torch.zeros_like(x), t)[1]
152 | # else:
153 | # # For VP-trained models, t=0 corresponds to the lowest noise level
154 | # labels = t * (sde.N - 1)
155 | # score = model_fn(x, labels)
156 | # std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
157 |
158 | # score = -score / std[:, None, None, None]
159 | # return score
160 |
161 | # elif isinstance(sde, sde_lib.VESDE):
162 | # def score_fn(x, t):
163 | # if continuous:
164 | # labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
165 | # else:
166 | # # For VE-trained models, t=0 corresponds to the highest noise level
167 | # labels = sde.T - t
168 | # labels *= sde.N - 1
169 | # labels = torch.round(labels).long()
170 |
171 | # score = model_fn(x, labels)
172 | # return score
173 |
174 | # else:
175 | # raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
176 |
177 | # return score_fn
178 |
179 |
180 | def to_flattened_numpy(x):
181 | """Flatten a torch tensor `x` and convert it to numpy."""
182 | return x.detach().cpu().numpy().reshape((-1,))
183 |
184 |
185 | def from_flattened_numpy(x, shape):
186 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
187 | return torch.from_numpy(x.reshape(shape))
--------------------------------------------------------------------------------
/main/models/sde/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mandt-lab/PSLD/24e186793acc63dc98acfbc1142ac1e0356c6aa3/main/models/sde/__init__.py
--------------------------------------------------------------------------------
/main/models/sde/base.py:
--------------------------------------------------------------------------------
1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
2 | import abc
3 |
4 |
5 | class SDE(abc.ABC):
6 | """SDE abstract class. Functions are designed for a mini-batch of inputs."""
7 |
8 | def __init__(self, N):
9 | """Construct an SDE.
10 |
11 | Args:
12 | N: number of discretization time steps.
13 | """
14 | super().__init__()
15 | self.N = N
16 |
17 | @property
18 | @abc.abstractmethod
19 | def T(self):
20 | """End time of the SDE."""
21 | raise NotImplementedError
22 |
23 | @abc.abstractmethod
24 | def get_score(self, eps, t):
25 | """Computes the score from a given epsilon value"""
26 | raise NotImplementedError
27 |
28 | @abc.abstractmethod
29 | def perturb_data(self, x_0, t, noise=None):
30 | """Add noise to a data point"""
31 | raise NotImplementedError
32 |
33 | @abc.abstractmethod
34 | def sde(self, x, t):
35 | """Return the drift and the diffusion coefficients"""
36 | raise NotImplementedError
37 |
38 | @abc.abstractmethod
39 | def reverse_sde(self, x, t, score_fn, probability_flow=False):
40 | """Return the drift and the diffusion coefficients of the reverse-sde"""
41 | raise NotImplementedError
42 |
43 | @abc.abstractmethod
44 | def cond_marginal_prob(self, x, t):
45 | """Parameters to determine the marginal distribution of the SDE, $p_t(x_t|x_0)$."""
46 | raise NotImplementedError
47 |
48 | @abc.abstractmethod
49 | def prior_sampling(self, shape):
50 | """Generate one sample from the prior distribution, $p_T(x)$."""
51 | raise NotImplementedError
52 |
53 | @abc.abstractmethod
54 | def prior_logp(self, z):
55 | """Compute log-density of the prior distribution.
56 |
57 | Useful for computing the log-likelihood via probability flow ODE.
58 |
59 | Args:
60 | z: latent code
61 | Returns:
62 | log probability density
63 | """
64 |         raise NotImplementedError
65 |
--------------------------------------------------------------------------------
/main/models/sde/vpsde.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from util import register_module, reshape
4 |
5 | from .base import SDE
6 |
7 |
8 | @register_module(category="sde", name="vpsde")
9 | class VPSDE(SDE):
10 | def __init__(self, config):
11 | """Construct a Variance Preserving (VP) SDE."""
12 | super().__init__(config.model.sde.n_timesteps)
13 | self.beta_0 = config.model.sde.beta_min
14 | self.beta_1 = config.model.sde.beta_max
15 |
16 | def beta_t(self, t):
17 | return self.beta_0 + t * (self.beta_1 - self.beta_0)
18 |
19 | @property
20 | def T(self):
21 | return 1.0
22 |
23 | @property
24 | def type(self):
25 | return "vpsde"
26 |
27 | def get_score(self, eps, t):
28 | return -eps / reshape(self._std(t), eps)
29 |
30 | def perturb_data(self, x_0, t, noise=None):
31 | """Add noise to input data point"""
32 |         noise = torch.randn_like(x_0) if noise is None else noise
33 | assert noise.shape == x_0.shape
34 |
35 | mu_t, std_t = self.cond_marginal_prob(x_0, t)
36 | assert mu_t.shape == x_0.shape
37 | assert len(std_t.shape) == len(x_0.shape)
38 | return mu_t + noise * std_t
39 |
40 | def sde(self, x_t, t):
41 | """Return the drift and diffusion coefficients"""
42 | beta_t = reshape(self.beta_t(t), x_t)
43 | drift = -0.5 * beta_t * x_t
44 | diffusion = torch.sqrt(beta_t)
45 |
46 | return drift, diffusion
47 |
48 | def reverse_sde(self, x_t, t, score_fn, probability_flow=False):
49 | """Return the drift and diffusion coefficients of the reverse sde"""
50 |         # The reverse SDE is defined on the reversed time domain (T - t)
51 | t = self.T - t
52 |
53 | # Forward drift and diffusion
54 | f, g = self.sde(x_t, t)
55 |
56 | # scale the score by 0.5 for the prob. flow formulation
57 | eps_pred = score_fn(x_t.type(torch.float32), t.type(torch.float32))
58 | score = self.get_score(eps_pred, t)
59 | score = 0.5 * score if probability_flow else score
60 |
61 | # Reverse drift
62 | f_bar = -f + g**2 * score
63 | assert f_bar.shape == f.shape
64 |
65 | # Reverse diffusion (0 for prob. flow)
66 | g_bar = g if not probability_flow else torch.zeros_like(g)
67 | return f_bar, g_bar
68 |
69 | def cond_marginal_prob(self, x_0, t):
70 | """Generate samples from the perturbation kernel p(x_t|x_0)"""
71 | log_mean_coeff = (
72 | -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
73 | )
74 | mean = torch.exp(log_mean_coeff[:, None, None, None]) * x_0
75 | assert mean.shape == x_0.shape
76 |
77 | std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
78 | std = reshape(std, x_0)
79 | return mean, std
80 |
81 | def _std(self, t):
82 | log_mean_coeff = (
83 | -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
84 | )
85 | return torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
86 |
87 | def prior_sampling(self, shape):
88 | """Generate samples from the prior p(x_T)"""
89 | return torch.randn(*shape)
90 |
91 | def prior_logp(self, z):
92 | shape = z.shape
93 | N = np.prod(shape[1:])
94 | logps = -N / 2.0 * np.log(2 * np.pi) - torch.sum(z**2, dim=(1, 2, 3)) / 2.0
95 | return logps
96 |
97 | def likelihood_weighting(self, t):
98 | # Return g(t)^2 for likelihood training and computation
99 | return self.beta_t(t)
100 |
--------------------------------------------------------------------------------
/main/models/wrapper.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning.utilities.seed import seed_everything
6 | from util import get_module, register_module
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | @register_module(category="pl_modules", name="sde_wrapper")
12 | class SDEWrapper(pl.LightningModule):
13 | """This PL module can do the following tasks:
14 | - train: Train an unconditional score model using HSM or DSM
15 | - predict: Generate unconditional samples from a pre-trained score model
16 | """
17 |
18 | def __init__(
19 | self,
20 | config,
21 | sde,
22 | score_fn,
23 | ema_score_fn=None,
24 | criterion=None,
25 | sampler_cls=None,
26 | corrector_fn=None,
27 | ):
28 | super().__init__()
29 | self.config = config
30 | self.score_fn = score_fn
31 | self.ema_score_fn = ema_score_fn
32 | self.sde = sde
33 |
34 | # Training
35 | self.criterion = criterion
36 | self.train_eps = self.config.training.train_eps
37 |
38 | # Evaluation
39 | self.sampler = None
40 | if sampler_cls is not None:
41 | self.sampler = sampler_cls(
42 | self.config,
43 | self.sde,
44 | self.ema_score_fn
45 | if config.evaluation.sample_from == "target"
46 | else self.score_fn,
47 | corrector_fn=corrector_fn,
48 | )
49 | self.eval_eps = self.config.evaluation.eval_eps
50 | self.denoise = self.config.evaluation.denoise
51 | n_discrete_steps = self.config.evaluation.n_discrete_steps
52 | self.n_discrete_steps = (
53 | n_discrete_steps - 1 if self.denoise else n_discrete_steps
54 | )
55 | self.val_eps = self.config.evaluation.eval_eps
56 | self.stride_type = self.config.evaluation.stride_type
57 |
58 | # Disable automatic optimization
59 | self.automatic_optimization = False
60 |
61 | def forward(self):
62 | pass
63 |
64 | def training_step(self, batch, batch_idx):
65 | # Optimizers
66 | optim = self.optimizers()
67 | lr_sched = self.lr_schedulers()
68 |
69 | x_0 = batch
70 |
71 |         # Sample timepoints uniformly in [train_eps, T]
72 | t_ = torch.rand(x_0.shape[0], device=x_0.device, dtype=torch.float64)
73 | t = t_ * (self.sde.T - self.train_eps) + self.train_eps
74 | assert t.shape[0] == x_0.shape[0]
75 |
76 | # Compute loss and backward
77 | loss = self.criterion(x_0, t, self.score_fn)
78 | optim.zero_grad()
79 | self.manual_backward(loss)
80 |
81 |         # Clip gradients (if enabled)
82 | if self.config.training.optimizer.grad_clip != 0:
83 | torch.nn.utils.clip_grad_norm_(
84 | self.score_fn.parameters(), self.config.training.optimizer.grad_clip
85 | )
86 | optim.step()
87 |
88 | # Scheduler step
89 | lr_sched.step()
90 | self.log("loss", loss, prog_bar=True)
91 | return loss
92 |
93 | def on_predict_start(self):
94 | seed = self.config.evaluation.seed
95 |
96 |         # Offset the seed by the global rank: a common seed across ranks
97 |         # would generate identical samples on every GPU, which negatively
98 |         # affects evaluation metrics such as FID.
99 | seed_everything(seed + self.global_rank, workers=True)
100 |
101 | def predict_step(self, batch, batch_idx, dataloader_idx=None):
102 | t_final = self.sde.T - self.eval_eps
103 | ts = torch.linspace(
104 | 0,
105 | t_final,
106 | self.n_discrete_steps + 1,
107 | device=self.device,
108 | dtype=torch.float64,
109 | )
110 |
111 | if self.stride_type == "uniform":
112 | pass
113 | elif self.stride_type == "quadratic":
114 | ts = t_final * torch.flip(1 - (ts / t_final) ** 2.0, dims=[0])
115 |
116 | return self.sampler.sample(
117 | batch,
118 | ts,
119 | self.n_discrete_steps,
120 | denoise=self.denoise,
121 | eps=self.eval_eps,
122 | )
123 |
124 | def on_predict_end(self):
125 | if isinstance(self.sampler, get_module("samplers", "bb_ode")):
126 | print(self.sampler.mean_nfe)
127 |
128 | def configure_optimizers(self):
129 | opt_config = self.config.training.optimizer
130 | opt_name = opt_config.name
131 | if opt_name == "Adam":
132 | optimizer = torch.optim.Adam(
133 | self.score_fn.parameters(),
134 | lr=opt_config.lr,
135 | betas=(opt_config.beta_1, opt_config.beta_2),
136 | eps=opt_config.eps,
137 | weight_decay=opt_config.weight_decay,
138 | )
139 | else:
140 | raise NotImplementedError(f"Optimizer {opt_name} not supported yet!")
141 |
142 | # Define the LR scheduler (As in Ho et al.)
143 | if opt_config.warmup == 0:
144 | lr_lambda = lambda step: 1.0
145 | else:
146 | lr_lambda = lambda step: min(step / opt_config.warmup, 1.0)
147 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
148 | return {
149 | "optimizer": optimizer,
150 | "lr_scheduler": {
151 | "scheduler": scheduler,
152 | "interval": "step",
153 | "strict": False,
154 | },
155 | }
156 |
--------------------------------------------------------------------------------
/main/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .sde import EulerMaruyamaSampler, ClassCondEulerMaruyamaSampler
2 | from .ode import BBODESampler
3 |
--------------------------------------------------------------------------------
/main/samplers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Sampler(abc.ABC):
5 | """The abstract class for a Sampler algorithm."""
6 |
7 | def __init__(self, config, sde, score_fn, corrector_fn=None):
8 | super().__init__()
9 | self.config = config
10 | self.sde = sde
11 | self.score_fn = score_fn
12 | self.corrector_fn = corrector_fn
13 |
14 | @property
15 | def n_steps(self):
16 | return self.config.evaluation.n_discrete_steps
17 |
18 | @abc.abstractmethod
19 | def predictor_update_fn(self):
20 | raise NotImplementedError
21 |
22 | def corrector_update_fn(self, x, t, dt):
23 | if self.corrector_fn is not None:
24 | return self.corrector_fn(x, t, dt)
25 |
26 | # If no corrector is defined, Identity is default
27 | return x, x
28 |
29 | @abc.abstractmethod
30 | def sample(self):
31 | raise NotImplementedError
32 |
--------------------------------------------------------------------------------
/main/samplers/ode.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchdiffeq import odeint
3 | from util import register_module
4 |
5 | from .base import Sampler
6 |
7 |
8 | @register_module(category="samplers", name="bb_ode")
9 | class BBODESampler(Sampler):
10 | """Black-Box ODE sampler for generating samples from the
11 | Probability Flow ODE.
12 | """
13 |
14 | def __init__(self, config, sde, score_fn, corrector_fn=None):
15 | super().__init__(config, sde, score_fn, corrector_fn=corrector_fn)
16 | self.nfe = 0
17 | self.rtol = config.evaluation.sampler.rtol
18 | self.atol = config.evaluation.sampler.atol
19 | self.solver_opts = {"solver": config.evaluation.sampler.solver}
20 |
21 | self._counter = 0
22 |
23 | @property
24 | def n_steps(self):
25 | return self.nfe
26 |
27 | @property
28 | def mean_nfe(self):
29 | if self._counter != 0:
30 | return self.nfe / self._counter
31 | raise ValueError("Run .sample() to compute mean_nfe")
32 |
33 | def predictor_update_fn(self, x, t, dt):
34 | pass
35 |
36 | def denoise_fn(self, x, t, dt):
37 | f, _ = self.sde.reverse_sde(x, t, self.score_fn, probability_flow=True)
38 | x_mean = x + f * dt
39 | return x_mean
40 |
41 | def sample(self, batch, ts, n_discrete_steps, denoise=True, eps=1e-3):
42 | def ode_fn(t, x):
43 | self.nfe += 1
44 | vec_t = torch.ones(x.shape[0], device=x.device, dtype=torch.float64) * t
45 | f, _ = self.sde.reverse_sde(x, vec_t, self.score_fn, probability_flow=True)
46 | return f
47 |
48 | x = batch
49 | self._counter += 1
50 |
51 | time_tensor = torch.tensor(
52 | [0.0, self.sde.T - eps], dtype=torch.float64, device=x.device
53 | )
54 | # BB-ODE solver
55 | solution = odeint(
56 | ode_fn,
57 | x,
58 | time_tensor,
59 | rtol=self.rtol,
60 | atol=self.atol,
61 | method="scipy_solver",
62 | options=self.solver_opts,
63 | )
64 |
65 | x = solution[-1]
66 |
67 | # Denoise
68 | if denoise:
69 | x = self.denoise_fn(
70 | x,
71 | torch.ones(x.shape[0], device=x.device, dtype=torch.float64)
72 | * (self.sde.T - eps),
73 | eps,
74 | )
75 | self.nfe += 1
76 | return x
77 |
--------------------------------------------------------------------------------
/main/train_clf.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import hydra
5 | import pytorch_lightning as pl
6 | from models.clf_wrapper import TClfWrapper
7 | from omegaconf import OmegaConf
8 | from pytorch_lightning.callbacks import ModelCheckpoint
9 | from pytorch_lightning.strategies import DDPStrategy
10 | from pytorch_lightning.utilities.seed import seed_everything
11 | from torch.utils.data import DataLoader
12 | from util import get_dataset, get_module, import_modules_into_registry
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | # Import all modules into registry
18 | import_modules_into_registry()
19 |
20 |
21 | @hydra.main(config_path="configs")
22 | def train_clf(config):
23 | """Helper script for training a noise conditioned classifier for guidance purposes"""
24 | # Get config and setup
25 | config = config.dataset
26 | logger.info(OmegaConf.to_yaml(config))
27 |
28 | # Set seed
29 | seed_everything(config.clf.training.seed, workers=True)
30 |
31 | # Setup dataset
32 | dataset = get_dataset(config.clf)
33 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}")
34 |
35 | # Setup score predictor
36 | clf_fn_cls = get_module(category="clf_fn", name=config.clf.model.clf_fn.name)
37 | clf_fn = clf_fn_cls(config.clf)
38 | logger.info(f"Using Classifier fn: {clf_fn_cls}")
39 |
40 | # Setup Score SDE
41 | sde_cls = get_module(category="sde", name=config.diffusion.model.sde.name)
42 | sde = sde_cls(config.diffusion)
43 | logger.info(f"Using SDE: {sde_cls} with type: {sde.type}")
44 | logger.info(sde)
45 |
46 | # Setup Loss fn
47 | criterion_cls = get_module(category="losses", name=config.clf.training.loss.name)
48 | criterion = criterion_cls(config, sde)
49 | logger.info(f"Using Loss: {criterion_cls}")
50 |
51 | # Setup Lightning Wrapper Module
52 | wrapper = TClfWrapper(config, sde, clf_fn, score_fn=None, criterion=criterion)
53 |
54 | # Setup Trainer
55 | train_kwargs = {}
56 |
57 | # Setup callbacks
58 | results_dir = config.clf.training.results_dir
59 | chkpt_callback = ModelCheckpoint(
60 | dirpath=os.path.join(results_dir, "checkpoints"),
61 | filename=f"{config.clf.model.clf_fn.name}-{config.diffusion.model.sde.name}-{config.clf.training.chkpt_prefix}"
62 | + "-{epoch:02d}-{loss:.4f}",
63 | every_n_epochs=config.clf.training.chkpt_interval,
64 | save_top_k=-1,
65 | )
66 |
67 | train_kwargs["default_root_dir"] = results_dir
68 | train_kwargs["max_epochs"] = config.clf.training.epochs
69 | train_kwargs["log_every_n_steps"] = config.clf.training.log_step
70 | train_kwargs["callbacks"] = [chkpt_callback]
71 |
72 | device_type = config.clf.training.accelerator
73 | train_kwargs["accelerator"] = device_type
74 | loader_kws = {}
75 | if device_type == "gpu":
76 | train_kwargs["devices"] = config.clf.training.devices
77 |
78 | # Disable find_unused_parameters when using DDP training for performance reasons
79 | train_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False)
80 | loader_kws["persistent_workers"] = True
81 | elif device_type == "tpu":
82 | train_kwargs["tpu_cores"] = 8
83 |
84 | # Half precision training
85 | if config.clf.training.fp16:
86 | train_kwargs["precision"] = 16
87 |
88 | # Loader
89 | batch_size = config.clf.training.batch_size
90 | batch_size = min(len(dataset), batch_size)
91 | loader = DataLoader(
92 | dataset,
93 | batch_size,
94 | num_workers=config.clf.training.workers,
95 | pin_memory=True,
96 | shuffle=True,
97 | drop_last=True,
98 | **loader_kws,
99 | )
100 |
101 | # Trainer
102 | logger.info(f"Running Trainer with kwargs: {train_kwargs}")
103 | trainer = pl.Trainer(**train_kwargs)
104 |
105 | # Restore checkpoint
106 | restore_path = config.clf.training.restore_path
107 | if restore_path == "":
108 | restore_path = None
109 | trainer.fit(wrapper, train_dataloaders=loader, ckpt_path=restore_path)
110 |
111 |
112 | if __name__ == "__main__":
113 | train_clf()
114 |
--------------------------------------------------------------------------------
/main/train_sde.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from copy import deepcopy
4 |
5 | import hydra
6 | import pytorch_lightning as pl
7 | from callbacks import EMAWeightUpdate
8 | from omegaconf import OmegaConf
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | from pytorch_lightning.utilities.seed import seed_everything
11 | from torch.utils.data import DataLoader
12 | from util import get_dataset, get_module, import_modules_into_registry
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | # Import all modules into registry
18 | import_modules_into_registry()
19 |
20 |
21 | @hydra.main(config_path="configs")
22 | def train(config):
23 | """Helper script for training a score-based generative model"""
24 | # Get config and setup
25 | config = config.dataset.diffusion
26 | logger.info(OmegaConf.to_yaml(config))
27 |
28 | # Set seed
29 | seed_everything(config.training.seed, workers=True)
30 |
31 | # Setup dataset
32 | dataset = get_dataset(config)
33 | logger.info(f"Using Dataset: {dataset} with size: {len(dataset)}")
34 |
35 | # Setup score predictor
36 | score_fn_cls = get_module(category="score_fn", name=config.model.score_fn.name)
37 | score_fn = score_fn_cls(config)
38 | logger.info(f"Using Score fn: {score_fn_cls}")
39 |
40 | # Setup target network for EMA
41 | ema_score_fn = deepcopy(score_fn)
42 | for p in ema_score_fn.parameters():
43 | p.requires_grad = False
44 |
45 | # Setup Score SDE
46 | sde_cls = get_module(category="sde", name=config.model.sde.name)
47 | sde = sde_cls(config)
48 | logger.info(f"Using SDE: {sde_cls} with type: {sde.type}")
49 | logger.info(sde)
50 |
51 | # Setup Loss fn
52 | criterion_cls = get_module(category="losses", name=config.training.loss.name)
53 | criterion = criterion_cls(config, sde)
54 | logger.info(f"Using Loss: {criterion_cls}")
55 |
56 | # Setup Lightning Wrapper Module
57 | wrapper_cls = get_module(category="pl_modules", name=config.model.pl_module)
58 | wrapper = wrapper_cls(
59 | config, sde, score_fn, ema_score_fn=ema_score_fn, criterion=criterion
60 | )
61 |
62 | # Setup Trainer
63 | train_kwargs = {}
64 |
65 | # Setup callbacks
66 | results_dir = config.training.results_dir
67 | chkpt_callback = ModelCheckpoint(
68 | dirpath=os.path.join(results_dir, "checkpoints"),
69 | filename=f"{config.model.sde.name}-{config.training.chkpt_prefix}"
70 | + "-{epoch:02d}-{loss:.4f}",
71 | every_n_epochs=config.training.chkpt_interval,
72 | save_top_k=-1,
73 | )
74 |
75 | train_kwargs["default_root_dir"] = results_dir
76 | train_kwargs["max_epochs"] = config.training.epochs
77 | train_kwargs["log_every_n_steps"] = config.training.log_step
78 | train_kwargs["callbacks"] = [chkpt_callback]
79 | if config.training.use_ema:
80 | ema_callback = EMAWeightUpdate(tau=config.training.ema_decay)
81 | train_kwargs["callbacks"].append(ema_callback)
82 |
83 | device_type = config.training.accelerator
84 | train_kwargs["accelerator"] = device_type
85 | loader_kws = {}
86 | if device_type == "gpu":
87 | train_kwargs["devices"] = config.training.devices
88 |
89 | # Disable find_unused_parameters when using DDP training for performance reasons
90 | # train_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False)
91 | loader_kws["persistent_workers"] = True
92 | elif device_type == "tpu":
93 | train_kwargs["tpu_cores"] = 8
94 |
95 | # Half precision training
96 | if config.training.fp16:
97 | train_kwargs["precision"] = 16
98 |
99 | # Loader
100 | batch_size = config.training.batch_size
101 | batch_size = min(len(dataset), batch_size)
102 | loader = DataLoader(
103 | dataset,
104 | batch_size,
105 | num_workers=config.training.workers,
106 | pin_memory=True,
107 | shuffle=True,
108 | drop_last=True,
109 | **loader_kws,
110 | )
111 |
112 | # Trainer
113 | logger.info(f"Running Trainer with kwargs: {train_kwargs}")
114 | trainer = pl.Trainer(**train_kwargs, strategy="ddp")
115 |
116 | # Restore checkpoint
117 | restore_path = config.training.restore_path
118 | if restore_path == "":
119 | restore_path = None
120 | trainer.fit(wrapper, train_dataloaders=loader, ckpt_path=restore_path)
121 |
122 |
123 | if __name__ == "__main__":
124 | train()
125 |
--------------------------------------------------------------------------------
/main/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as T
6 | from PIL import Image
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | _MODULES = {}
11 |
12 |
13 | def reshape(t, rt):
14 | """Adds additional dimensions corresponding to the size of the
15 | reference tensor rt.
16 | """
17 | if len(rt.shape) == len(t.shape):
18 | return t
19 | ones = [1] * len(rt.shape[1:])
20 | t_ = t.view(-1, *ones)
21 | assert len(t_.shape) == len(rt.shape)
22 | return t_
23 |
24 |
25 | def data_scaler(img, norm=True):
26 | if norm:
27 |         img = (np.asarray(img).astype(np.float64) / 127.5) - 1.0
28 |     else:
29 |         img = np.asarray(img).astype(np.float64) / 255.0
30 | return img
31 |
32 |
33 | def register_module(category=None, name=None):
34 | """A decorator for registering model classes."""
35 |
36 | def _register(cls):
37 | local_category = category
38 | if local_category is None:
39 | local_category = cls.__name__ if name is None else name
40 |
41 | # Create category (if does not exist)
42 | if local_category not in _MODULES:
43 | _MODULES[local_category] = {}
44 |
45 | # Add module to the category
46 | local_name = cls.__name__ if name is None else name
47 |         if local_name in _MODULES[local_category]:
48 | raise ValueError(
49 | f"Already registered module with name: {local_name} in category: {category}"
50 | )
51 |
52 | _MODULES[local_category][local_name] = cls
53 | return cls
54 |
55 | return _register
56 |
57 |
58 | def get_module(category, name):
59 | module = _MODULES.get(category, dict()).get(name, None)
60 | if module is None:
61 | raise ValueError(f"No module named `{name}` found in category: `{category}`")
62 | return module
63 |
64 |
65 | def configure_device(device):
66 | if device.startswith("gpu"):
67 | if not torch.cuda.is_available():
68 | raise Exception(
69 | "CUDA support is not available on your platform. Re-run using CPU or TPU mode"
70 | )
71 |         gpu_id = device.split(":")[-1]
72 |         if gpu_id in ("", "gpu"):
73 |             # Use all GPUs
74 |             return "cuda", -1
75 |         gpu_id = [int(id) for id in gpu_id.split(",")]
76 |         return f"cuda:{gpu_id}", gpu_id
77 | return device
78 |
79 |
80 | def get_dataset(config):
81 | # TODO: Add support for dynamically adding **kwargs directly via config
82 | # Parse config
83 | name = config.data.name
84 | root = config.data.root
85 | image_size = config.data.image_size
86 | norm = config.data.norm
87 | flip = config.data.hflip
88 |
89 | # Checks
90 | assert isinstance(norm, bool)
91 |
92 | if name.lower() == "cifar10":
93 | assert image_size == 32
94 |
95 | # Construct transforms
96 | t_list = [T.Resize((image_size, image_size))]
97 | if flip:
98 | t_list.append(T.RandomHorizontalFlip())
99 | transform = T.Compose(t_list)
100 |
101 | # Get dataset
102 | dataset_cls = get_module(category="datasets", name=name.lower())
103 | if dataset_cls is None:
104 | raise ValueError(
105 | f"Dataset with name: {name} not found in category: `datasets`. Ensure its properly registered"
106 | )
107 |
108 | return dataset_cls(
109 | root,
110 | norm=norm,
111 | transform=transform,
112 | return_target=config.data.return_target,
113 | )
114 |
115 |
116 | def import_modules_into_registry():
117 | logger.info("Importing modules into registry")
118 | import datasets
119 | import losses
120 | import models
121 | import samplers
122 |
123 |
124 | def convert_to_np(obj):
125 | obj = obj.permute(0, 2, 3, 1).contiguous()
126 | obj = obj.detach().cpu().numpy()
127 |
128 | obj_list = []
129 |     for out in obj:
130 | obj_list.append(out)
131 | return obj_list
132 |
133 |
134 | def normalize(obj):
135 | B, C, H, W = obj.shape
136 |     for i in range(C):
137 | channel_val = obj[:, i, :, :].view(B, -1)
138 | channel_val -= channel_val.min(1, keepdim=True)[0]
139 | channel_val /= (
140 | channel_val.max(1, keepdim=True)[0] - channel_val.min(1, keepdim=True)[0]
141 | )
142 | channel_val = channel_val.view(B, H, W)
143 | obj[:, i, :, :] = channel_val
144 | return obj
145 |
146 |
147 | def save_as_images(obj, file_name="output", denorm=True):
148 | # Saves predictions as png images (useful for Sample generation)
149 | if denorm:
150 | # obj = normalize(obj)
151 | obj = obj * 0.5 + 0.5
152 | obj_list = convert_to_np(obj)
153 |
154 | for i, out in enumerate(obj_list):
155 | out = (out * 255).clip(0, 255).astype(np.uint8)
156 | img_out = Image.fromarray(out)
157 | current_file_name = file_name + "_%d.png" % i
158 | img_out.save(current_file_name, "png")
159 |
160 |
161 | def save_as_np(obj, file_name="output", denorm=True):
162 | # Saves predictions directly as numpy arrays
163 | if denorm:
164 | obj = normalize(obj)
165 | obj_list = convert_to_np(obj)
166 |
167 | for i, out in enumerate(obj_list):
168 | current_file_name = file_name + "_%d.npy" % i
169 | np.save(current_file_name, out)
170 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/cond/afhqv2/sample_inpaint_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/inpaint.py +dataset=afhqv2/afhqv2128_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.diffusion.data.name='afhqv2' \
4 | +dataset.diffusion.data.mask_path=\'/home/pandeyk1/datasets\' \
5 | dataset.diffusion.data.norm=True \
6 | dataset.diffusion.data.hflip=True \
7 | dataset.diffusion.model.score_fn.in_ch=6 \
8 | dataset.diffusion.model.score_fn.out_ch=6 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.2 \
14 | dataset.diffusion.model.sde.beta_min=8.0 \
15 | dataset.diffusion.model.sde.beta_max=8.0 \
16 | dataset.diffusion.model.sde.nu=4.01 \
17 | dataset.diffusion.model.sde.gamma=0.01 \
18 | dataset.diffusion.model.sde.kappa=0.04 \
19 | dataset.diffusion.model.sde.decomp_mode='lower' \
20 | dataset.diffusion.evaluation.seed=0 \
21 | dataset.diffusion.evaluation.sample_prefix='gpu' \
22 | dataset.diffusion.evaluation.path_prefix="1000" \
23 | dataset.diffusion.evaluation.devices=8 \
24 | dataset.diffusion.evaluation.batch_size=1 \
25 | dataset.diffusion.evaluation.stride_type='uniform' \
26 | dataset.diffusion.evaluation.sample_from='target' \
27 | dataset.diffusion.evaluation.workers=1 \
28 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/es3sde-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \
29 | dataset.diffusion.evaluation.sampler.name="ip_em_sde" \
30 | dataset.diffusion.evaluation.n_samples=32 \
31 | dataset.diffusion.evaluation.n_discrete_steps=1000 \
32 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/inpaint/es3sde_hsm_gamma=0.01_nu=4.01_afhqv2_continuous_sfn=ncsnpp_nres=2/dummy_samples/test/wild/\'
33 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/cond/afhqv2/sample_tclf_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/class_cond_sample.py +dataset=afhqv2/afhqv2128_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.diffusion.data.name='afhqv2' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.2 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.nu=4.01 \
16 | dataset.diffusion.model.sde.gamma=0.01 \
17 | dataset.diffusion.model.sde.kappa=0.04 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.evaluation.seed=0 \
20 | dataset.diffusion.evaluation.sample_prefix='gpu' \
21 | dataset.diffusion.evaluation.path_prefix="1000" \
22 | dataset.diffusion.evaluation.devices=8 \
23 | dataset.diffusion.evaluation.batch_size=1 \
24 | dataset.diffusion.evaluation.stride_type='uniform' \
25 | dataset.diffusion.evaluation.sample_from='target' \
26 | dataset.diffusion.evaluation.workers=1 \
27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/es3sde-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \
28 | dataset.diffusion.evaluation.sampler.name="cc_em_sde" \
29 | dataset.diffusion.evaluation.n_samples=32 \
30 | dataset.diffusion.evaluation.n_discrete_steps=1000 \
31 | dataset.clf.model.clf_fn.in_ch=6 \
32 | dataset.clf.model.clf_fn.nf=128 \
33 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \
34 | dataset.clf.model.clf_fn.num_res_blocks=4 \
35 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \
36 | dataset.clf.model.clf_fn.dropout=0.1 \
37 | dataset.clf.model.clf_fn.n_cls=3 \
38 | dataset.clf.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_afhqv2_continuous_clf=ncsnppclf/checkpoints/ncsnpp_clf-es3sde-tclf_gamma=0.01_nu=4.01_afhqv2_Feb27-epoch=299-loss=0.5960.ckpt\' \
39 | dataset.clf.evaluation.clf_temp=10.0 \
40 | dataset.clf.evaluation.label_to_sample=2 \
41 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/cc/es3sde_hsm_gamma=0.01_nu=4.01_afhqv2_continuous_sfn=ncsnpp_nres=2/cc_others/\'
42 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/cond/afhqv2/train_tclf_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_clf.py +dataset=afhqv2/afhqv2128_psld \
2 | dataset.clf.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.clf.data.name='afhqv2' \
4 | dataset.clf.data.norm=True \
5 | dataset.clf.data.hflip=True \
6 | dataset.clf.data.return_target=True \
7 | dataset.clf.model.pl_module='tclf_wrapper' \
8 | dataset.clf.model.clf_fn.in_ch=6 \
9 | dataset.clf.model.clf_fn.nf=128 \
10 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \
11 | dataset.clf.model.clf_fn.num_res_blocks=4 \
12 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \
13 | dataset.clf.model.clf_fn.dropout=0.1 \
14 | dataset.clf.model.clf_fn.n_cls=3 \
15 | dataset.diffusion.model.sde.beta_min=8.0 \
16 | dataset.diffusion.model.sde.beta_max=8.0 \
17 | dataset.diffusion.model.sde.decomp_mode='lower' \
18 | dataset.diffusion.model.sde.nu=4.01 \
19 | dataset.diffusion.model.sde.gamma=0.01 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.clf.training.loss.name='tce_loss' \
22 | dataset.clf.training.seed=0 \
23 | dataset.clf.training.chkpt_interval=100 \
24 | dataset.clf.training.fp16=False \
25 | dataset.clf.training.batch_size=16 \
26 | dataset.clf.training.epochs=2000 \
27 | dataset.clf.training.accelerator='gpu' \
28 | dataset.clf.training.devices=[1,3,5,6] \
29 | dataset.clf.training.results_dir=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_afhqv2_continuous_clf=ncsnppclf/\' \
30 | dataset.clf.training.workers=1 \
31 | dataset.clf.training.chkpt_prefix=\"tclf_gamma=0.01_nu=4.01_afhqv2_Feb27\"
--------------------------------------------------------------------------------
/scripts_psld/ablations/cond/cifar10/sample_tclf_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/class_cond_sample.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.15 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.nu=4.01 \
19 | dataset.diffusion.model.sde.gamma=0.01 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.model.sde.decomp_mode='lower' \
22 | dataset.diffusion.evaluation.seed=0 \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.path_prefix="1000" \
25 | dataset.diffusion.evaluation.devices=8 \
26 | dataset.diffusion.evaluation.batch_size=16 \
27 | dataset.diffusion.evaluation.stride_type='uniform' \
28 | dataset.diffusion.evaluation.sample_from='target' \
29 | dataset.diffusion.evaluation.workers=1 \
30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/sota/uncond/es3sde_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/es3sde-hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_3rdFeb-epoch=2249-loss=0.0148.ckpt\' \
31 | dataset.diffusion.evaluation.sampler.name="cc_em_sde" \
32 | dataset.diffusion.evaluation.n_samples=64 \
33 | dataset.diffusion.evaluation.n_discrete_steps=1000 \
34 | dataset.clf.model.clf_fn.in_ch=6 \
35 | dataset.clf.model.clf_fn.nf=128 \
36 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \
37 | dataset.clf.model.clf_fn.num_res_blocks=4 \
38 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \
39 | dataset.clf.model.clf_fn.dropout=0.1 \
40 | dataset.clf.model.clf_fn.n_cls=10 \
41 | dataset.clf.evaluation.chkpt_path=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0.01_nu=4.01_cifar10_continuous_clf=ncsnppclf/checkpoints/ncsnpp_clf-es3sde-tclf_gamma=0.01_nu=4.01_cifar10_Feb27-epoch=999-loss=1.4715.ckpt\' \
42 | dataset.clf.evaluation.clf_temp=5.0 \
43 | dataset.clf.evaluation.label_to_sample=9 \
44 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/es3sde_results/sota/cond/cc/es3sde_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/dummy_samples/cc_cifar10_9\' \
45 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/cond/cifar10/train_tclf_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_clf.py +dataset=cifar10/cifar10_psld \
2 | dataset.clf.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.clf.data.name='cifar10' \
4 | dataset.clf.data.norm=True \
5 | dataset.clf.data.hflip=True \
6 | dataset.clf.data.return_target=True \
7 | dataset.clf.model.pl_module='tclf_wrapper' \
8 | dataset.clf.model.clf_fn.in_ch=6 \
9 | dataset.clf.model.clf_fn.nf=128 \
10 | dataset.clf.model.clf_fn.ch_mult=[1,2,3,4] \
11 | dataset.clf.model.clf_fn.num_res_blocks=4 \
12 | dataset.clf.model.clf_fn.attn_resolutions=[16,8] \
13 | dataset.clf.model.clf_fn.dropout=0.1 \
14 | dataset.clf.model.clf_fn.n_cls=10 \
15 | dataset.diffusion.model.sde.beta_min=8.0 \
16 | dataset.diffusion.model.sde.beta_max=8.0 \
17 | dataset.diffusion.model.sde.decomp_mode='lower' \
18 | dataset.diffusion.model.sde.nu=4.0 \
19 | dataset.diffusion.model.sde.gamma=0 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.clf.training.loss.name='tce_loss' \
22 | dataset.clf.training.seed=0 \
23 | dataset.clf.training.chkpt_interval=100 \
24 | dataset.clf.training.fp16=False \
25 | dataset.clf.training.batch_size=64 \
26 | dataset.clf.training.epochs=2000 \
27 | dataset.clf.training.accelerator='gpu' \
28 | dataset.clf.training.devices=[1,3,5,6] \
29 | dataset.clf.training.results_dir=\'/home/pandeyk1/es3sde_results/ablations/cond/controllable/es3sde_tclf_gamma=0_nu=4.0_cifar10_continuous_clf=ncsnppclf/\' \
30 | dataset.clf.training.workers=1 \
31 | dataset.clf.training.chkpt_prefix=\"tclf_gamma=0_nu=4.0_cifar10_Feb27\"
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/afhqv2/sample_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=afhqv2/afhqv2128_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.diffusion.data.name='afhqv2' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.2 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.nu=4.01 \
16 | dataset.diffusion.model.sde.gamma=0.01 \
17 | dataset.diffusion.model.sde.kappa=0.04 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.evaluation.seed=0 \
20 | dataset.diffusion.evaluation.sample_prefix='gpu' \
21 | dataset.diffusion.evaluation.devices=[0] \
22 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/psld_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/dummy_samples/\' \
23 | dataset.diffusion.evaluation.batch_size=2 \
24 | dataset.diffusion.evaluation.stride_type='quadratic' \
25 | dataset.diffusion.evaluation.sample_from='target' \
26 | dataset.diffusion.evaluation.workers=1 \
27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/psld_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/checkpoints/psld-hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23-epoch=1749-loss=0.0033.ckpt\' \
28 | dataset.diffusion.evaluation.sampler.name="em_sde" \
29 | dataset.diffusion.evaluation.n_samples=128 \
30 | dataset.diffusion.evaluation.n_discrete_steps=1000
31 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/afhqv2/train_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_sde.py +dataset=afhqv2/afhqv2128_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
3 | dataset.diffusion.data.name='afhqv2' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.pl_module='sde_wrapper' \
7 | dataset.diffusion.model.score_fn.in_ch=6 \
8 | dataset.diffusion.model.score_fn.out_ch=6 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2,3] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.2 \
14 | dataset.diffusion.model.sde.beta_min=8.0 \
15 | dataset.diffusion.model.sde.beta_max=8.0 \
16 | dataset.diffusion.model.sde.decomp_mode='lower' \
17 | dataset.diffusion.model.sde.nu=4.01 \
18 | dataset.diffusion.model.sde.gamma=0.01 \
19 | dataset.diffusion.model.sde.kappa=0.04 \
20 | dataset.diffusion.model.sde.numerical_eps=1e-9 \
21 | dataset.diffusion.training.loss.name='psld_score_loss' \
22 | dataset.diffusion.training.seed=0 \
23 | dataset.diffusion.training.chkpt_interval=50 \
24 | dataset.diffusion.training.mode='hsm' \
25 | dataset.diffusion.training.fp16=False \
26 | dataset.diffusion.training.use_ema=True \
27 | dataset.diffusion.training.batch_size=8 \
28 | dataset.diffusion.training.epochs=2000 \
29 | dataset.diffusion.training.accelerator='gpu' \
30 | dataset.diffusion.training.devices=8 \
31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ncsnpp/\' \
32 | dataset.diffusion.training.workers=1 \
33 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.01_nu=4.01_afhq128_20thFeb23\"
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/celeba64/sample_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \
3 | dataset.diffusion.data.name='celeba64' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=3 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.nu=4.0 \
16 | dataset.diffusion.model.sde.gamma=0.0 \
17 | dataset.diffusion.model.sde.kappa=0.04 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.evaluation.seed=0 \
20 | dataset.diffusion.evaluation.sample_prefix='gpu' \
21 | dataset.diffusion.evaluation.devices=8 \
22 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0_nu=4.0_celeba64_continuous_sfn=ncsnpp/speedvsquality/sscs_uniform/\' \
23 | dataset.diffusion.evaluation.batch_size=6 \
24 | dataset.diffusion.evaluation.stride_type='uniform' \
25 | dataset.diffusion.evaluation.sample_from='target' \
26 | dataset.diffusion.evaluation.workers=1 \
27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0_nu=4.0_celeba64_continuous_sfn=ncsnpp/checkpoints/psld-hsm_ablation_gamma=0_nu=4.0_celeba64_24thJan23-epoch=199-loss=0.0036.ckpt\' \
28 | dataset.diffusion.evaluation.sampler.name="sscs_sde" \
29 | dataset.diffusion.evaluation.n_samples=10000 \
30 | dataset.diffusion.evaluation.n_discrete_steps=50 \
31 | dataset.diffusion.evaluation.path_prefix='50'
32 |
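33 | # Hedged note: sampler.name="sscs_sde" presumably selects a symmetric-splitting
34 | # integrator (SSCS-style, as in CLD), which typically tolerates far fewer steps
35 | # than plain Euler-Maruyama ("em_sde"); stride_type ('uniform' vs 'quadratic')
36 | # sets how the n_discrete_steps are spaced in time, and path_prefix='50' names
37 | # the output subfolder after the step count for FID-vs-steps sweeps.
38 | # Also note out_ch=3 here vs out_ch=6 in the matching training script: under HSM,
39 | # sampling may only require the momentum-channel part of the joint score.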
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/celeba64/train_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_sde.py +dataset=celeba64/celeba64_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \
3 | dataset.diffusion.data.name='celeba64' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.pl_module='sde_wrapper' \
7 | dataset.diffusion.model.score_fn.in_ch=6 \
8 | dataset.diffusion.model.score_fn.out_ch=6 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,1,2,2,2] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.1 \
14 | dataset.diffusion.model.sde.beta_min=8.0 \
15 | dataset.diffusion.model.sde.beta_max=8.0 \
16 | dataset.diffusion.model.sde.decomp_mode='lower' \
17 | dataset.diffusion.model.sde.nu=4.25 \
18 | dataset.diffusion.model.sde.gamma=0.25 \
19 | dataset.diffusion.model.sde.kappa=0.04 \
20 | dataset.diffusion.training.loss.name='psld_score_loss' \
21 | dataset.diffusion.training.seed=0 \
22 | dataset.diffusion.training.mode='hsm' \
23 | dataset.diffusion.training.fp16=False \
24 | dataset.diffusion.training.use_ema=True \
25 | dataset.diffusion.training.batch_size=32 \
26 | dataset.diffusion.training.epochs=200 \
27 | dataset.diffusion.training.accelerator='gpu' \
28 | dataset.diffusion.training.devices=4 \
29 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/celeba64/psld_hsm_gamma=0.25_nu=4.25_celeba64_continuous_sfn=ncsnpp/\' \
30 | dataset.diffusion.training.workers=1 \
31 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.25_nu=4.25_celeba64_27thJan23\"
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/sample_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=3 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.nu=4.0 \
16 | dataset.diffusion.model.sde.gamma=0 \
17 | dataset.diffusion.model.sde.kappa=0.04 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.model.sde.numerical_eps=1e-9 \
20 | dataset.diffusion.evaluation.seed=0 \
21 | dataset.diffusion.evaluation.sample_prefix='gpu' \
22 | dataset.diffusion.evaluation.devices=8 \
23 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/speedvsquality/sscs_uniform/\' \
24 | dataset.diffusion.evaluation.batch_size=16 \
25 | dataset.diffusion.evaluation.stride_type='uniform' \
26 | dataset.diffusion.evaluation.sample_from='target' \
27 | dataset.diffusion.evaluation.workers=1 \
28 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/psld-hsm_ablation_scorem_cifar10_5thJan23-epoch=1999-loss=0.0142.ckpt\' \
29 | dataset.diffusion.evaluation.sampler.name="sscs_sde" \
30 | dataset.diffusion.evaluation.n_samples=10000 \
31 | dataset.diffusion.evaluation.n_discrete_steps=1000 \
32 | dataset.diffusion.evaluation.path_prefix="1000"
33 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/sample_uncond_psld_ode.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=3 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.nu=4.0 \
16 | dataset.diffusion.model.sde.gamma=0 \
17 | dataset.diffusion.model.sde.kappa=0.04 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.model.sde.numerical_eps=1e-9 \
20 | dataset.diffusion.evaluation.seed=0 \
21 | dataset.diffusion.evaluation.sample_prefix='gpu' \
22 | dataset.diffusion.evaluation.devices=8 \
23 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/ode_exps/solver=rk45/\' \
24 | dataset.diffusion.evaluation.batch_size=16 \
25 | dataset.diffusion.evaluation.sample_from='target' \
26 | dataset.diffusion.evaluation.workers=1 \
27 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0_nu=4.0_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/psld-hsm_ablation_scorem_cifar10_5thJan23-epoch=1999-loss=0.0142.ckpt\' \
28 | dataset.diffusion.evaluation.sampler.name="bb_ode" \
29 | +dataset.diffusion.evaluation.sampler.solver="RK45" \
30 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \
31 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \
32 | dataset.diffusion.evaluation.n_samples=10000 \
33 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-5\"
34 |
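35 | # Hedged note: "bb_ode" presumably wraps a black-box solver for the
36 | # probability-flow ODE; the added +sampler.solver="RK45" / rtol / atol overrides
37 | # mirror scipy.integrate.solve_ivp's arguments (as in Song et al.'s score_sde
38 | # code). Tighter tolerances buy accuracy at the cost of more score-network
39 | # evaluations, and the step count is adaptive, hence no n_discrete_steps here.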
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/sample_uncond_vpsde.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_vpsde \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=3 \
7 | dataset.diffusion.model.score_fn.out_ch=3 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=0.1 \
14 | dataset.diffusion.model.sde.beta_max=20 \
15 | dataset.diffusion.evaluation.seed=0 \
16 | dataset.diffusion.evaluation.sample_prefix='gpu' \
17 | dataset.diffusion.evaluation.devices=8 \
18 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/speedvsquality/em_quadratic/\' \
19 | dataset.diffusion.evaluation.batch_size=16 \
20 | dataset.diffusion.evaluation.stride_type='quadratic' \
21 | dataset.diffusion.evaluation.sample_from='target' \
22 | dataset.diffusion.evaluation.workers=1 \
23 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/vpsde-dsm_ablation_cifar10_5thJan23-epoch=1999-loss=0.0259.ckpt\' \
24 | dataset.diffusion.evaluation.sampler.name="em_sde" \
25 | dataset.diffusion.evaluation.n_samples=10000 \
26 | dataset.diffusion.evaluation.n_discrete_steps=1000 \
27 | dataset.diffusion.evaluation.path_prefix="1000"
28 |
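29 | # Hedged note: this is the VP-SDE baseline, so the score network sees only the
30 | # three image channels (in_ch=out_ch=3, no momentum channels) and uses the
31 | # standard beta_min=0.1, beta_max=20 linear schedule from Song et al.; the
32 | # PSLD-specific sde.nu/gamma/kappa overrides are accordingly absent.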
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/sample_uncond_vpsde_ode.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_vpsde \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=3 \
7 | dataset.diffusion.model.score_fn.out_ch=3 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=0.1 \
14 | dataset.diffusion.model.sde.beta_max=20 \
15 | dataset.diffusion.evaluation.seed=0 \
16 | dataset.diffusion.evaluation.sample_prefix='gpu' \
17 | dataset.diffusion.evaluation.devices=8 \
18 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/ode_exps/solver=rk45/\' \
19 | dataset.diffusion.evaluation.batch_size=16 \
20 | dataset.diffusion.evaluation.sample_from='target' \
21 | dataset.diffusion.evaluation.workers=1 \
22 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/vpsde_cifar10_continuous_sfn=ncsnpp/checkpoints/cached_chkpts/vpsde-dsm_ablation_cifar10_5thJan23-epoch=1999-loss=0.0259.ckpt\' \
23 | dataset.diffusion.evaluation.sampler.name="bb_ode" \
24 | +dataset.diffusion.evaluation.sampler.solver="RK45" \
25 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \
26 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \
27 | dataset.diffusion.evaluation.n_samples=10000 \
28 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-5\"
29 |
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/train_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_sde.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.sde.beta_min=8.0 \
14 | dataset.diffusion.model.sde.beta_max=8.0 \
15 | dataset.diffusion.model.sde.decomp_mode='lower' \
16 | dataset.diffusion.model.sde.nu=4.02 \
17 | dataset.diffusion.model.sde.gamma=0.02 \
18 | dataset.diffusion.model.sde.kappa=0.04 \
19 | dataset.diffusion.training.seed=0 \
20 | dataset.diffusion.training.chkpt_interval=50 \
21 | dataset.diffusion.training.mode='hsm' \
22 | dataset.diffusion.training.fp16=False \
23 | dataset.diffusion.training.use_ema=True \
24 | dataset.diffusion.training.batch_size=32 \
25 | dataset.diffusion.training.epochs=2000 \
26 | dataset.diffusion.training.accelerator='gpu' \
27 | dataset.diffusion.training.devices=4 \
28 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/cifar10/psld_hsm_gamma=0.02_nu=4.02_decomp=lower_cifar10_continuous_sfn=ncsnpp/\' \
29 | dataset.diffusion.training.workers=1 \
30 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.02_nu=4.02_decomp=lower_cifar10_14thFeb23\"
31 |
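32 | # Hedged note: results_dir and chkpt_prefix are only labels; Hydra does not parse
33 | # them, so the gamma/nu/decomp tokens must be kept in sync with the actual sde.*
34 | # overrides above to avoid mislabeled runs and checkpoints.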
--------------------------------------------------------------------------------
/scripts_psld/ablations/uncond/cifar10/train_uncond_vpsde.sh:
--------------------------------------------------------------------------------
1 | # CIFAR-10
2 | python main/train_sde.py +dataset=cifar10/cifar10_vpsde \
3 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
4 | dataset.diffusion.data.name='cifar10' \
5 | dataset.diffusion.data.norm=True \
6 | dataset.diffusion.data.hflip=True \
7 | dataset.diffusion.model.score_fn.in_ch=3 \
8 | dataset.diffusion.model.score_fn.out_ch=3 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.1 \
14 | dataset.diffusion.training.seed=0 \
15 | dataset.diffusion.training.fp16=False \
16 | dataset.diffusion.training.use_ema=True \
17 | dataset.diffusion.training.batch_size=32 \
18 | dataset.diffusion.training.epochs=2000 \
19 | dataset.diffusion.training.accelerator='gpu' \
20 | dataset.diffusion.training.devices=[0,1,2,3] \
21 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/ablations/uncond/vpsde_cifar10_continuous_sfn=ncsnpp_testnccl/\' \
22 | dataset.diffusion.training.workers=1 \
23 | dataset.diffusion.training.chkpt_prefix="dsm_ablation_cifar10_5thJan23"
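24 | # Hedged note: training.devices accepts either a device count (devices=4, as in
25 | # the other scripts) or an explicit GPU list as here (devices=[0,1,2,3]), matching
26 | # PyTorch Lightning's `devices` argument; the list form pins the run to specific GPUs.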
--------------------------------------------------------------------------------
/scripts_psld/fid.sh:
--------------------------------------------------------------------------------
1 | fidelity --gpu 0 --fid --input1 /home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/1000/batch/ --input2 /home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0.01_nu=4.01_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/1000/images/
2 |
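3 | # Hedged note: `fidelity` is the torch-fidelity CLI; --fid computes the Frechet
4 | # Inception Distance between the two image folders. A generic sketch with
5 | # hypothetical placeholder directories:
6 | # fidelity --gpu 0 --fid --input1 "$GENERATED_DIR" --input2 "$REFERENCE_DIR"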
--------------------------------------------------------------------------------
/scripts_psld/sota/cond/afhqv2/sample_inpaint_psld.sh:
--------------------------------------------------------------------------------
1 | ulimit -n 64000
2 | python main/eval/inpaint.py +dataset=afhqv2/afhqv2128_psld \
3 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/afhqv2/\' \
4 | dataset.diffusion.data.name='afhqv2' \
5 | +dataset.diffusion.data.mask_path=\'/home/pandeyk1/datasets\' \
6 | dataset.diffusion.data.norm=True \
7 | dataset.diffusion.data.hflip=True \
8 | dataset.diffusion.model.score_fn.in_ch=6 \
9 | dataset.diffusion.model.score_fn.out_ch=3 \
10 | dataset.diffusion.model.score_fn.nf=160 \
11 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,3,3] \
12 | dataset.diffusion.model.score_fn.num_res_blocks=2 \
13 | dataset.diffusion.model.score_fn.attn_resolutions=[8,16] \
14 | dataset.diffusion.model.score_fn.dropout=0.2 \
15 | dataset.diffusion.model.sde.beta_min=8.0 \
16 | dataset.diffusion.model.sde.beta_max=8.0 \
17 | dataset.diffusion.model.sde.nu=4.0 \
18 | dataset.diffusion.model.sde.gamma=0 \
19 | dataset.diffusion.model.sde.kappa=0.04 \
20 | dataset.diffusion.model.sde.decomp_mode='lower' \
21 | dataset.diffusion.evaluation.seed=0 \
22 | dataset.diffusion.evaluation.denoise=True \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.path_prefix="1000" \
25 | dataset.diffusion.evaluation.devices=8 \
26 | dataset.diffusion.evaluation.batch_size=16 \
27 | dataset.diffusion.evaluation.stride_type='quadratic' \
28 | dataset.diffusion.evaluation.sample_from='target' \
29 | dataset.diffusion.evaluation.workers=1 \
30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0_nu=4.0_afhq128_continuous_sfn=ddpmpp/checkpoints/cached_chkpts/es3sde-hsm_sota_gamma=0_nu=4.0_afhq128_25May23_nf=160_chmult=12233-epoch=1999-loss=0.0011.ckpt\' \
31 | dataset.diffusion.evaluation.sampler.name="ip_em_sde" \
32 | dataset.diffusion.evaluation.n_samples=50000 \
33 | dataset.diffusion.evaluation.n_discrete_steps=250 \
34 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/afhq128/es3sde_hsm_gamma=0_nu=4.0_afhq128_continuous_sfn=ddpmpp/inpaint_results_val/\'
35 |
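36 | # Hedged notes: `ulimit -n 64000` raises the shell's open-file limit, which helps
37 | # when many dataloader workers across 8 devices hold image files open at once.
38 | # The `+dataset.diffusion.data.mask_path` override adds the inpainting mask
39 | # directory to the config, and sampler.name="ip_em_sde" is presumably the
40 | # inpainting variant of the Euler-Maruyama sampler, which re-imposes the observed
41 | # pixels on the iterate at each step.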
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/celeba64/sample_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \
3 | dataset.diffusion.data.name='celeba64' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.nu=4.005 \
19 | dataset.diffusion.model.sde.gamma=0.005 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.model.sde.decomp_mode='lower' \
22 | dataset.diffusion.evaluation.seed=0 \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.devices=7 \
25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/samples_50k/epoch=500_emquad/\' \
26 | dataset.diffusion.evaluation.batch_size=4 \
27 | dataset.diffusion.evaluation.stride_type='quadratic' \
28 | dataset.diffusion.evaluation.sample_from='target' \
29 | dataset.diffusion.evaluation.workers=0 \
30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/checkpoints/psld-hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23-epoch=499-loss=0.0041.ckpt\' \
31 | dataset.diffusion.evaluation.sampler.name="em_sde" \
32 | dataset.diffusion.evaluation.n_samples=50000 \
33 | dataset.diffusion.evaluation.n_discrete_steps=250 \
34 | dataset.diffusion.evaluation.path_prefix="250"
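35 | # Hedged note: progressive_input='residual', fir=True and embedding_type='fourier'
36 | # are NCSN++ architecture options from Song et al.'s score_sde codebase (residual
37 | # progressive input connections, FIR resampling kernels, Gaussian Fourier time
38 | # embeddings); at sampling time they must match the checkpoint's training config.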
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/celeba64/sample_uncond_psld_ode.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=celeba64/celeba64_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \
3 | dataset.diffusion.data.name='celeba64' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.1 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.nu=4.005 \
19 | dataset.diffusion.model.sde.gamma=0.005 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.model.sde.decomp_mode='lower' \
22 | dataset.diffusion.evaluation.seed=0 \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.devices=7 \
25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/samples_10k/epoch=500_ode/\' \
26 | dataset.diffusion.evaluation.batch_size=4 \
27 | dataset.diffusion.evaluation.sample_from='target' \
28 | dataset.diffusion.evaluation.workers=0 \
29 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/checkpoints/psld-hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23-epoch=499-loss=0.0041.ckpt\' \
30 | dataset.diffusion.evaluation.sampler.name="bb_ode" \
31 | +dataset.diffusion.evaluation.sampler.solver="RK45" \
32 | +dataset.diffusion.evaluation.sampler.rtol=1e-5 \
33 | +dataset.diffusion.evaluation.sampler.atol=1e-5 \
34 | dataset.diffusion.evaluation.n_samples=10000 \
35 | dataset.diffusion.evaluation.path_prefix=\'tol=1e-5\'
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/celeba64/train_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_sde.py +dataset=celeba64/celeba64_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/img_align_celeba\' \
3 | dataset.diffusion.data.name='celeba64' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.pl_module='sde_wrapper' \
7 | dataset.diffusion.model.score_fn.in_ch=6 \
8 | dataset.diffusion.model.score_fn.out_ch=6 \
9 | dataset.diffusion.model.score_fn.nf=128 \
10 | dataset.diffusion.model.score_fn.ch_mult=[1,2,2,2] \
11 | dataset.diffusion.model.score_fn.num_res_blocks=4 \
12 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
13 | dataset.diffusion.model.score_fn.dropout=0.1 \
14 | dataset.diffusion.model.score_fn.progressive_input='residual' \
15 | dataset.diffusion.model.score_fn.fir=True \
16 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
17 | dataset.diffusion.model.sde.decomp_mode='lower' \
18 | dataset.diffusion.model.sde.nu=4.005 \
19 | dataset.diffusion.model.sde.gamma=0.005 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.training.loss.name='psld_score_loss' \
22 | dataset.diffusion.training.seed=0 \
23 | dataset.diffusion.training.chkpt_interval=25 \
24 | dataset.diffusion.training.mode='hsm' \
25 | dataset.diffusion.training.fp16=False \
26 | dataset.diffusion.training.use_ema=True \
27 | dataset.diffusion.training.batch_size=16 \
28 | dataset.diffusion.training.epochs=500 \
29 | dataset.diffusion.training.accelerator='gpu' \
30 | dataset.diffusion.training.devices=8 \
31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/sota/uncond/celeba64/psld_hsm_gamma=0.005_nu=4.005_celeba64_continuous_sfn=ncsnpp_chmult=1222_nres=4/\' \
32 | dataset.diffusion.training.workers=1 \
33 | dataset.diffusion.training.chkpt_prefix=\"hsm_ablation_gamma=0.005_nu=4.005_celeba64_17thFeb23\"
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/cifar10/sample_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.15 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.nu=4.02 \
19 | dataset.diffusion.model.sde.gamma=0.02 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.model.sde.decomp_mode='lower' \
22 | dataset.diffusion.evaluation.seed=0 \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.devices=8 \
25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/speedvsquality_50k/em_uniform/\' \
26 | dataset.diffusion.evaluation.batch_size=16 \
27 | dataset.diffusion.evaluation.stride_type='uniform' \
28 | dataset.diffusion.evaluation.sample_from='target' \
29 | dataset.diffusion.evaluation.workers=1 \
30 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/psld-hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_17thFeb-epoch=2349-loss=0.0201.ckpt\' \
31 | dataset.diffusion.evaluation.sampler.name="em_sde" \
32 | dataset.diffusion.evaluation.n_samples=50000 \
33 | dataset.diffusion.evaluation.n_discrete_steps=50 \
34 | dataset.diffusion.evaluation.path_prefix="50"
35 |
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/cifar10/sample_uncond_psld_ode.sh:
--------------------------------------------------------------------------------
1 | python main/eval/sample.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.15 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.nu=4.02 \
19 | dataset.diffusion.model.sde.gamma=0.02 \
20 | dataset.diffusion.model.sde.kappa=0.04 \
21 | dataset.diffusion.model.sde.decomp_mode='lower' \
22 | dataset.diffusion.evaluation.seed=0 \
23 | dataset.diffusion.evaluation.sample_prefix='gpu' \
24 | dataset.diffusion.evaluation.devices=8 \
25 | dataset.diffusion.evaluation.save_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/samples_50k_ode/epoch=2350/\' \
26 | dataset.diffusion.evaluation.batch_size=16 \
27 | dataset.diffusion.evaluation.sample_from='target' \
28 | dataset.diffusion.evaluation.workers=1 \
29 | dataset.diffusion.evaluation.chkpt_path=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/checkpoints/psld-hsm_gamma=0.02_nu=4.02_cifar10_continuous_sfn=ncsnpp_17thFeb-epoch=2349-loss=0.0201.ckpt\' \
30 | dataset.diffusion.evaluation.sampler.name="bb_ode" \
31 | +dataset.diffusion.evaluation.sampler.solver="RK45" \
32 | +dataset.diffusion.evaluation.sampler.rtol=1e-4 \
33 | +dataset.diffusion.evaluation.sampler.atol=1e-4 \
34 | dataset.diffusion.evaluation.n_samples=50000 \
35 | dataset.diffusion.evaluation.path_prefix=\"tol=1e-4\"
36 |
--------------------------------------------------------------------------------
/scripts_psld/sota/uncond/cifar10/train_uncond_psld.sh:
--------------------------------------------------------------------------------
1 | python main/train_sde.py +dataset=cifar10/cifar10_psld \
2 | dataset.diffusion.data.root=\'/home/pandeyk1/datasets/\' \
3 | dataset.diffusion.data.name='cifar10' \
4 | dataset.diffusion.data.norm=True \
5 | dataset.diffusion.data.hflip=True \
6 | dataset.diffusion.model.score_fn.in_ch=6 \
7 | dataset.diffusion.model.score_fn.out_ch=6 \
8 | dataset.diffusion.model.score_fn.nf=128 \
9 | dataset.diffusion.model.score_fn.ch_mult=[2,2,2] \
10 | dataset.diffusion.model.score_fn.num_res_blocks=8 \
11 | dataset.diffusion.model.score_fn.attn_resolutions=[16] \
12 | dataset.diffusion.model.score_fn.dropout=0.15 \
13 | dataset.diffusion.model.score_fn.progressive_input='residual' \
14 | dataset.diffusion.model.score_fn.fir=True \
15 | dataset.diffusion.model.score_fn.embedding_type='fourier' \
16 | dataset.diffusion.model.sde.beta_min=8.0 \
17 | dataset.diffusion.model.sde.beta_max=8.0 \
18 | dataset.diffusion.model.sde.decomp_mode='lower' \
19 | dataset.diffusion.model.sde.nu=4.01 \
20 | dataset.diffusion.model.sde.gamma=0.01 \
21 | dataset.diffusion.model.sde.kappa=0.04 \
22 | dataset.diffusion.training.seed=0 \
23 | dataset.diffusion.training.chkpt_interval=50 \
24 | dataset.diffusion.training.mode='hsm' \
25 | dataset.diffusion.training.fp16=False \
26 | dataset.diffusion.training.use_ema=True \
27 | dataset.diffusion.training.batch_size=16 \
28 | dataset.diffusion.training.epochs=2500 \
29 | dataset.diffusion.training.accelerator='gpu' \
30 | dataset.diffusion.training.devices=8 \
31 | dataset.diffusion.training.results_dir=\'/home/pandeyk1/psld_results/sota/uncond/psld_hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_nres=8_chmult=222/\' \
32 | dataset.diffusion.training.workers=1 \
33 | dataset.diffusion.training.chkpt_prefix=\"hsm_gamma=0.01_nu=4.01_cifar10_continuous_sfn=ncsnpp_3rdFeb\"
34 |
--------------------------------------------------------------------------------