├── .gitignore ├── DSBM-Gaussian.py ├── README.md ├── bridge.def ├── bridge ├── __init__.py ├── data │ ├── __init__.py │ ├── afhq.py │ ├── cacheloader.py │ ├── downscaler.py │ ├── emnist.py │ ├── metrics.py │ └── utils.py ├── models │ ├── __init__.py │ ├── basic │ │ ├── __init__.py │ │ ├── basic_cond.py │ │ ├── layers.py │ │ └── time_embedding.py │ ├── ddpmpp │ │ ├── layers.py │ │ ├── layerspp.py │ │ ├── ncsnpp.py │ │ ├── normalization.py │ │ ├── op │ │ │ ├── __init__.py │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── up_or_down_sampling.py │ │ └── utils.py │ └── unet │ │ ├── __init__.py │ │ ├── fp16_util.py │ │ ├── layers.py │ │ └── unet.py ├── runners │ ├── __init__.py │ ├── config_getters.py │ ├── ema.py │ ├── logger.py │ ├── plotters.py │ └── repeater.py ├── sde │ ├── __init__.py │ ├── diffusion_bridge.py │ ├── discrete_langevin.py │ └── optimal_transport.py ├── trainer_dbdsb.py ├── trainer_dsb.py └── trainer_rf.py ├── conf ├── config.yaml ├── dataset │ ├── afhq_transfer.yaml │ ├── cifar10.yaml │ ├── downscaler_transfer.yaml │ └── mnist_transfer.yaml ├── gaussian.yaml ├── job.yaml ├── launcher │ ├── slurm_cpu.yaml │ └── slurm_gpu.yaml ├── method │ ├── bm.yaml │ ├── dbdsb.yaml │ ├── dbdsb_vp.yaml │ ├── dsb.yaml │ ├── otcfm.yaml │ └── rf.yaml ├── model │ ├── DDPMpp_32.yaml │ ├── DDPMpp_RF.yaml │ ├── DownscalerUNET.yaml │ └── UNET.yaml ├── test_config.yaml └── test_job.yaml ├── main.py ├── run_dbdsb.py ├── run_dsb.py ├── run_rf.py ├── test.py ├── test_dbdsb.py └── test_rf.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | experiments/
132 | 
133 | .idea/
134 | 
135 | /data/
136 | 
137 | /slurm/
138 | 
139 | slurm*.sh
140 | 
141 | bridge.sif
142 | 
143 | sylabs-token
144 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Schrödinger Bridge Matching
2 | 
3 | This repository contains the `PyTorch` implementation for the submission Diffusion Schrödinger Bridge Matching.
4 | 
5 | 
6 | ## Introduction to Schrödinger Bridges and links with diffusion models
7 | 
8 | The goal of learning Schrödinger Bridges is to build a bridge between two distributions $\pi_ 0$ and $\pi_ T$ that is optimal in some sense.
9 | This transport setting covers many applications:
10 | * Generative modeling: Gaussian $\rightarrow$ Data distribution.
11 | * Data translation: Data distribution 1 $\rightarrow$ Data distribution 2.
12 | 
13 | The bridge is represented by a (stochastic) process $(\mathbf{X}_ t)_ {t \in [0,T]}$ such that $\mathbf{X}_ 0 \sim \pi_ 0$ and $\mathbf{X}_ T \sim \pi_ T$.
14 | 
15 | Schrödinger bridges not only impose the extremal constraints that the bridge must have the right distributions at times $0$ and $T$, but also require that the *energy* of the displacement is minimized in some sense.
16 | As a result, Schrödinger Bridges correspond to solutions of (regularized) Optimal Transport problems.
17 | 
18 | Minimizing the energy of the path can also be interpreted as minimizing the Kullback-Leibler divergence between the measure of the bridge $\mathbb{P}$ and the measure of a *reference* process $\mathbb{Q}$, usually associated with a Brownian motion $(\mathbf{B}_ t)_ {t \in [0,T]}$. The Schrödinger bridge $\mathbb{P}$ is the solution to the following minimization problem.
19 | 
20 | $$
21 | \mathbb{P}^\star = \arg\min \{ \mathrm{KL}(\mathbb{P}|\mathbb{Q}), \ \mathbb{P}_ 0 = \pi_0, \ \mathbb{P}_ T = \pi_T \} .
22 | $$
23 | 
24 | The solution $\mathbb{P}^\star$ has the following properties:
25 | 1. $\mathbb{P}^\star_0 = \pi_0$.
26 | 2. $\mathbb{P}^\star_T = \pi_T$.
27 | 3. $\mathbb{P}^\star$ is Markov.
28 | 4. $\mathbb{P}^\star$ is in the reciprocal class of $\mathbb{Q}$, i.e. $\mathbb{P}^\star_ {|0,T} = \mathbb{Q}_ {|0,T}$ (the measures $\mathbb{P}^\star$ and $\mathbb{Q}$ are the same when conditioned on the initial and terminal states).
29 | 
30 | The Iterative Proportional Fitting (IPF) procedure proceeds by alternately projecting the measure onto conditions 1 and 2; conditions 3 and 4 are satisfied by all iterates. The new **Iterative Markovian Fitting** (IMF) procedure we propose alternately projects onto conditions 3 and 4, while preserving conditions 1 and 2.
31 | We denote by $\mathrm{proj}_ {\mathcal{M}}$ the projection onto Markov processes and by $\mathrm{proj}_ {\mathcal{R}(\mathbb{Q})}$ the projection onto the reciprocal class of $\mathbb{Q}$.
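To make the reciprocal projection concrete: for the reference measure $\mathbb{Q}$ induced by $\mathrm{d}\mathbf{X}_ t = \sigma \mathrm{d}\mathbf{B}_ t$, projecting a coupling onto $\mathcal{R}(\mathbb{Q})$ amounts to keeping the endpoint pairs and resampling the interior of each path from a Brownian bridge. A minimal sketch (illustration only; `sigma` and the horizon `T` below are assumed names, not taken from this codebase):

```python
import torch

def sample_reciprocal_projection(x0, xT, sigma=1.0, T=1.0):
    # x0, xT: (batch, dim) endpoint pairs from the current coupling.
    # Keep the coupling and draw (t, X_t) from the Brownian bridge of Q
    # pinned at X_0 = x0 and X_T = xT.
    t = torch.rand(x0.shape[0], 1) * T
    mean = x0 + (t / T) * (xT - x0)            # bridge mean at time t
    std = sigma * torch.sqrt(t * (T - t) / T)  # bridge standard deviation
    return t, mean + std * torch.randn_like(x0)
```

The Markovian projection is then approximated by regressing a drift network on such bridge samples, which is the Bridge Matching step mentioned below.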
32 | The IMF procedure defines a sequence $(\mathbb{P}^n)_ {n \in \mathbb{N}}$ given by
33 | 
34 | $$
35 | \mathbb{P}^{2n+1} = \mathrm{proj}_ {\mathcal{M}}(\mathbb{P}^{2n}) , \\
36 | \mathbb{P}^{2n+2} = \mathrm{proj}_ {\mathcal{R}(\mathbb{Q})}(\mathbb{P}^{2n+1}).
37 | $$
38 | 
39 | We refer to our paper for details on the implementation of these projections. The practical algorithm associated with IMF leverages Flow and Bridge Matching. We call this practical algorithm Diffusion Schrödinger Bridge Matching (DSBM).
40 | 
41 | ## Reproducing experiments
42 | ### Setting up
43 | We provide a Singularity container recipe in `bridge.def`, which can be used to build the container. Alternatively, a conda environment can be set up manually using the conda installation commands in `bridge.def`.
44 | 
45 | ### Gaussian experiment
46 | A self-contained Gaussian experiment benchmark is provided in `DSBM-Gaussian.py`.
47 | 
48 | DSB: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsb seed=1,2,3,4,5 inner_iters=10000 -m`
49 | 
50 | IMF-b: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm first_coupling=ind seed=1,2,3,4,5 inner_iters=10000 fb_sequence=['b'] -m`
51 | 
52 | DSBM-IPF: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm seed=1,2,3,4,5 inner_iters=10000 -m`
53 | 
54 | DSBM-IMF: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm first_coupling=ind seed=1,2,3,4,5 inner_iters=10000 -m`
55 | 
56 | Rectified Flow: `python DSBM-Gaussian.py dim=5,20,50 model_name=rectifiedflow seed=1,2,3,4,5 inner_iters=10000 fb_sequence=[b] -m`
57 | 
58 | SB-CFM: `python DSBM-Gaussian.py dim=5,20,50 model_name=sbcfm seed=1,2,3,4,5 inner_iters=10000 -m`
59 | 
60 | 
61 | ### MNIST experiment
62 | DSBM-IPF: `python main.py num_steps=30 num_iter=5000 method=dbdsb gamma_min=0.034 gamma_max=0.034`
63 | 
64 | DSBM-IMF: `python main.py num_steps=30 num_iter=5000 method=dbdsb first_num_iter=100000 gamma_min=0.034 gamma_max=0.034 first_coupling=ind`
65 | 
66 | ### Geophysical downscaling experiment
67 | The dataset can be downloaded and processed using the script `https://github.com/CliMA/diffusion-bridge-downscaling/blob/main/CliMAgen/examples/utils_data.jl`, then saved as numpy arrays in `./data/downscaler`.
68 | 
69 | DSBM-IPF: `python main.py dataset=downscaler_transfer num_steps=30 num_iter=5000 gamma_min=0.01 gamma_max=0.01 model=DownscalerUNET`
70 | 
71 | DSBM-IMF: `python main.py dataset=downscaler_transfer num_steps=30 num_iter=5000 gamma_min=0.01 gamma_max=0.01 model=DownscalerUNET first_coupling=ind`
72 | 
--------------------------------------------------------------------------------
/bridge.def:
--------------------------------------------------------------------------------
1 | ################# Header: Define the base system you want to use ################
2 | # Reference the kind of base system you want to use (e.g., docker, debootstrap, shub).
3 | Bootstrap: docker
4 | # Select the docker image you want to use
5 | From: nvidia/cuda:11.7.0-cudnn8-devel-ubuntu18.04
6 | 
7 | ################# Section: Defining the system #################################
8 | # Commands in the %post section are executed within the container.
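# Example usage (assumed invocation, adjust to your setup; `bridge.sif` is the
# image name ignored by .gitignore):
#   singularity build bridge.sif bridge.def
#   singularity exec --nv bridge.sif python main.py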
9 | %post
10 |     # Linux packages
11 |     apt -y update && \
12 |     apt install -y \
13 |     wget \
14 |     git \
15 |     unzip && \
16 |     apt clean && \
17 |     rm -rf /var/lib/apt/lists/*
18 | 
19 |     # Install conda
20 |     wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
21 |     bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/miniconda3 && \
22 |     rm Miniconda3-latest-Linux-x86_64.sh
23 |     . /opt/miniconda3/etc/profile.d/conda.sh
24 | 
25 |     # Install packages
26 |     conda create --name bridge python=3.9
27 |     conda activate bridge
28 | 
29 |     conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
30 |     conda install -c conda-forge cudatoolkit-dev
31 | 
32 |     conda install scipy pandas scikit-learn tqdm matplotlib seaborn python-lmdb h5py
33 |     conda install -c conda-forge pytorch-lightning hydra-core wandb gdown p7zip pot ninja
34 |     conda install -c conda-forge accelerate
35 |     conda install -c conda-forge torch-fidelity lpips torchdiffeq
36 |     conda install -c conda-forge jupyterlab
37 |     pip install slurm-gpustat hydra-submitit-launcher
38 | 
39 |     echo ". /opt/miniconda3/etc/profile.d/conda.sh" >> $SINGULARITY_ENVIRONMENT
40 |     echo "conda activate /opt/miniconda3/envs/bridge" >> $SINGULARITY_ENVIRONMENT
41 | 
--------------------------------------------------------------------------------
/bridge/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuyang-shi/dsbm-pytorch/0f3ca754cb4dbdfa2eb1c095450676602f481d54/bridge/__init__.py
--------------------------------------------------------------------------------
/bridge/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .cacheloader import *
--------------------------------------------------------------------------------
/bridge/data/afhq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from glob import glob
3 | from pathlib import Path
4 | from PIL import Image
5 | import numpy as np
6 | 
7 | 
8 | class AFHQ(torch.utils.data.Dataset):
9 |     """
10 |     root_dir points to the train set directory that contains the cat, dog, and wild folders.
11 |     animal_type is either cat, dog, or wild.
12 |     """
13 |     def __init__(self, root_dir, animal_type):
14 |         self.root_dir = root_dir
15 |         self.animal_type = animal_type
16 |         assert animal_type in ['cat', 'dog', 'wild']
17 |         self.all_image_paths = list(sorted(Path(self.root_dir).joinpath(animal_type).glob('*.png')))
18 | 
19 |     def __len__(self):
20 |         return len(self.all_image_paths)
21 | 
22 |     def __getitem__(self, index):
23 |         path = self.all_image_paths[index]
24 | 
25 |         pil_image = Image.open(path)
26 |         np_image = np.array(pil_image)  # 0 to 255 integer
27 | 
28 |         # scale floats between -1 and 1
29 |         tensor_image = (torch.tensor(np_image, dtype=torch.float32) / 255.0) * 2 - 1
30 | 
31 |         # current shape is (H, W, C)
32 |         # transpose to (C, H, W)
33 |         tensor_image = tensor_image.permute(2, 0, 1)
34 | 
35 |         return tensor_image, torch.zeros((1,))
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
--------------------------------------------------------------------------------
/bridge/data/cacheloader.py:
--------------------------------------------------------------------------------
1 | import os, time, shutil
2 | import glob
3 | import numpy as np
4 | from numpy.lib.format import open_memmap
5 | import torch
6 | from torch.utils.data import Dataset, TensorDataset
7 | from bridge.data.utils import save_image
8 | 
9 | 
10 | class MemMapTensorDataset(Dataset):
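    """Dataset backed by memory-mapped .npy cache files.

    Each file in `npy_file_list` is opened with np.load(..., mmap_mode='r'),
    and __getitem__ unpacks the slice at `index` along its leading axis, so
    a cache of shape (N, 2, ...) yields the pair (x0, x1) for each index.
    """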
11 | def __init__(self, npy_file_list) -> None: 12 | self.npy_file_list = npy_file_list 13 | self.data_file_list = [np.load(npy_file, mmap_mode='r') for npy_file in self.npy_file_list] 14 | 15 | def __getitem__(self, index): 16 | out = [] 17 | for data_file in self.data_file_list: 18 | data = torch.from_numpy(data_file[index]) 19 | out = out + [d for d in data] 20 | return out 21 | 22 | def __len__(self): 23 | return len(self.data_file_list[0]) 24 | 25 | 26 | def CacheLoader(fb, sample_net, init_dl, final_dl, num_batches, langevin, ipf, n, device='cpu'): 27 | start = time.time() 28 | all_x = [] 29 | # all_y = [] 30 | all_out = [] 31 | all_steps = [] 32 | 33 | sample_direction = 'f' if fb == 'b' else 'b' 34 | 35 | for b in range(num_batches): 36 | init_batch_x, batch_y, final_batch_x, mean_final, var_final = ipf.sample_batch(init_dl, final_dl) 37 | if sample_direction == "f": 38 | batch_x = init_batch_x 39 | else: 40 | batch_x = final_batch_x 41 | 42 | with torch.no_grad(): 43 | if (n == 1) & (fb == 'b'): 44 | x, y, out, steps_expanded = langevin.record_init_langevin(batch_x, batch_y, 45 | mean_final=mean_final, 46 | var_final=var_final) 47 | else: 48 | x, y, out, steps_expanded = langevin.record_langevin_seq(sample_net, batch_x, batch_y, sample_direction, var_final=var_final) 49 | 50 | # store x, y, out, steps 51 | x = x.flatten(start_dim=0, end_dim=1).to(device) 52 | # y = y.flatten(start_dim=0, end_dim=1).to(device) 53 | out = out.flatten(start_dim=0, end_dim=1).to(device) 54 | steps_expanded = steps_expanded.flatten(start_dim=0, end_dim=1).to(device) 55 | 56 | all_x.append(x) 57 | # all_y.append(y) 58 | all_out.append(out) 59 | all_steps.append(steps_expanded) 60 | 61 | all_x = torch.cat(all_x, dim=0) 62 | # all_y = torch.cat(all_y, dim=0) 63 | all_out = torch.cat(all_out, dim=0) 64 | all_steps = torch.cat(all_steps, dim=0) 65 | 66 | stop = time.time() 67 | ipf.accelerator.print('Cache size: {0}'.format(all_x.shape)) 68 | ipf.accelerator.print("Load time: {0}".format(stop-start)) 69 | ipf.accelerator.print("Out mean: {0}".format(all_out.mean().item())) 70 | ipf.accelerator.print("Out std: {0}".format(all_out.std().item())) 71 | 72 | return TensorDataset(all_x, all_out, all_steps) 73 | 74 | 75 | def DBDSB_CacheLoader(sample_direction, sample_fn, init_dl, final_dl, num_batches, langevin, ipf, n, refresh_idx=0, refresh_tot=1, device='cpu'): 76 | start = time.time() 77 | 78 | # New method, saving as npy 79 | cache_filename_npy = f'cache_{sample_direction}_{n:03}.npy' 80 | cache_filepath_npy = os.path.join(ipf.cache_dir, cache_filename_npy) 81 | 82 | cache_filename_txt = f'cache_{sample_direction}_{n:03}.txt' 83 | cache_filepath_txt = os.path.join(ipf.cache_dir, cache_filename_txt) 84 | 85 | if ipf.cdsb: 86 | cache_y_filename_npy = f'cache_y_{sample_direction}_{n:03}.npy' 87 | cache_y_filepath_npy = os.path.join(ipf.cache_dir, cache_y_filename_npy) 88 | 89 | # Temporary cache of each batch 90 | temp_cache_dir = os.path.join(ipf.cache_dir, f"temp_{sample_direction}_{n:03}_{refresh_idx:03}") 91 | os.makedirs(temp_cache_dir, exist_ok=True) 92 | 93 | npar = num_batches * ipf.cache_batch_size 94 | num_batches_dist = num_batches * ipf.accelerator.num_processes # In distributed mode 95 | cache_batch_size_dist = ipf.cache_batch_size // ipf.accelerator.num_processes # In distributed mode 96 | 97 | use_existing_cache = False 98 | if os.path.isfile(cache_filepath_txt): 99 | f = open(cache_filepath_txt, 'r') 100 | input = f.readline() 101 | f.close() 102 | input_list = input.split("/") 103 | if 
int(input_list[0]) == refresh_idx and int(input_list[1]) == refresh_tot: 104 | use_existing_cache = True 105 | 106 | if not use_existing_cache: 107 | sample = ((sample_direction == 'b') or ipf.transfer) 108 | normalize_x1 = ((not sample) and ipf.normalize_x1) 109 | 110 | x1_mean_list, x1_mse_list = [], [] 111 | 112 | for b in range(num_batches): 113 | b_dist = b * ipf.accelerator.num_processes + ipf.accelerator.process_index 114 | 115 | try: 116 | batch_x0, batch_x1 = torch.load(os.path.join(temp_cache_dir, f"{b_dist}.pt")) 117 | if ipf.cdsb: 118 | batch_y = torch.load(os.path.join(temp_cache_dir, f"{b_dist}_y.pt"))[0] 119 | assert len(batch_x0) == len(batch_x1) == cache_batch_size_dist 120 | batch_x0, batch_x1 = batch_x0.to(ipf.device), batch_x1.to(ipf.device) 121 | except: 122 | ipf.set_seed(seed=ipf.compute_current_step(0, n+1)*num_batches_dist*refresh_tot + num_batches_dist*refresh_idx + b_dist) 123 | 124 | init_batch_x, init_batch_y, final_batch_x, _, _ = ipf.sample_batch(init_dl, final_dl) 125 | 126 | with torch.no_grad(): 127 | batch_x0, batch_y, batch_x1 = langevin.generate_new_dataset(init_batch_x, init_batch_y, final_batch_x, sample_fn, sample_direction, sample=sample, num_steps=ipf.cache_num_steps) 128 | batch_x0, batch_x1 = batch_x0.contiguous(), batch_x1.contiguous() 129 | torch.save([batch_x0, batch_x1], os.path.join(temp_cache_dir, f"{b_dist}.pt")) 130 | if ipf.cdsb: 131 | torch.save([batch_y], os.path.join(temp_cache_dir, f"{b_dist}_y.pt")) 132 | 133 | if normalize_x1: 134 | x1_mean_list.append(batch_x1.mean(0)) 135 | x1_mse_list.append(batch_x1.square().mean(0)) 136 | 137 | if normalize_x1: 138 | x1_mean = torch.stack(x1_mean_list).mean(0) 139 | x1_mse = torch.stack(x1_mse_list).mean(0) 140 | reduced_x1_mean = ipf.accelerator.reduce(x1_mean, reduction='mean') 141 | reduced_x1_mse = ipf.accelerator.reduce(x1_mse, reduction='mean') 142 | reduced_x1_std = (reduced_x1_mse - reduced_x1_mean.square()).sqrt() 143 | 144 | ipf.accelerator.wait_for_everyone() 145 | 146 | stop = time.time() 147 | ipf.accelerator.print("Load time: {0}".format(stop-start)) 148 | 149 | # Aggregate temporary caches into central cache file 150 | if ipf.accelerator.is_main_process: 151 | fp = open_memmap(cache_filepath_npy, dtype='float32', mode='w+', shape=(npar, 2, *batch_x0.shape[1:])) 152 | if ipf.cdsb: 153 | fp_y = open_memmap(cache_y_filepath_npy, dtype='float32', mode='w+', shape=(npar, 1, *batch_y.shape[1:])) 154 | for b_dist in range(num_batches_dist): 155 | temp_cache_filepath_b_dist = os.path.join(temp_cache_dir, f"{b_dist}.pt") 156 | loaded = False 157 | while not loaded: 158 | if not os.path.isfile(temp_cache_filepath_b_dist): 159 | print(f"Index {ipf.accelerator.process_index} did not find temp cache file {b_dist}, retrying in 5 seconds") 160 | time.sleep(5) 161 | else: 162 | try: 163 | batch_x0, batch_x1 = torch.load(temp_cache_filepath_b_dist) 164 | batch_x0, batch_x1 = batch_x0.to(ipf.device), batch_x1.to(ipf.device) 165 | loaded = True 166 | except: 167 | print(f"Index {ipf.accelerator.process_index} failed to load cache file {b_dist}, retrying in 5 seconds") 168 | time.sleep(5) 169 | 170 | assert len(batch_x0) == len(batch_x1) == cache_batch_size_dist 171 | 172 | if ipf.cdsb: 173 | temp_cache_y_filepath_b_dist = os.path.join(temp_cache_dir, f"{b_dist}_y.pt") 174 | loaded = False 175 | while not loaded: 176 | if not os.path.isfile(temp_cache_y_filepath_b_dist): 177 | print(f"Index {ipf.accelerator.process_index} did not find temp cache file {b_dist}_y, retrying in 5 seconds") 178 | 
time.sleep(5)
179 |                     else:
180 |                         try:
181 |                             batch_y = torch.load(temp_cache_y_filepath_b_dist)[0]
182 |                             loaded = True
183 |                         except:
184 |                             print(f"Index {ipf.accelerator.process_index} failed to load cache file {b_dist}_y, retrying in 5 seconds")
185 |                             time.sleep(5)
186 |                 assert len(batch_y) == cache_batch_size_dist
187 | 
188 |             if normalize_x1:
189 |                 batch_x1 = (batch_x1 - reduced_x1_mean) / reduced_x1_std
190 | 
191 |             batch = torch.stack([batch_x0, batch_x1], dim=1).float().cpu().numpy()
192 |             fp[b_dist*cache_batch_size_dist:(b_dist+1)*cache_batch_size_dist] = batch
193 |             fp.flush()
194 | 
195 |             if ipf.cdsb:
196 |                 batch_y = batch_y.unsqueeze(1).float().cpu().numpy()
197 |                 fp_y[b_dist*cache_batch_size_dist:(b_dist+1)*cache_batch_size_dist] = batch_y
198 |                 fp_y.flush()
199 | 
200 |         del fp
201 |         if ipf.cdsb:
202 |             del fp_y
203 | 
204 |         f = open(cache_filepath_txt, 'w')  # writing mode
205 |         f.write(f'{refresh_idx}/{refresh_tot}')
206 |         f.close()
207 | 
208 |         shutil.rmtree(temp_cache_dir)
209 | 
210 |     ipf.accelerator.wait_for_everyone()
211 | 
212 |     # All processes check that the cache is accessible
213 |     loaded = False
214 |     while not loaded:
215 |         if not os.path.isfile(cache_filepath_npy):
216 |             print("Index", ipf.accelerator.process_index, "did not find cache file, retrying in 5 seconds")
217 |             time.sleep(5)
218 |         else:
219 |             try:
220 |                 fp = np.load(cache_filepath_npy, mmap_mode='r')
221 |                 loaded = True
222 |             except:
223 |                 print("Index", ipf.accelerator.process_index, "failed to load cache file, retrying in 5 seconds")
224 |                 time.sleep(5)
225 | 
226 |     if ipf.cdsb:
227 |         loaded = False
228 |         while not loaded:
229 |             if not os.path.isfile(cache_y_filepath_npy):
230 |                 print("Index", ipf.accelerator.process_index, "did not find cache_y file, retrying in 5 seconds")
231 |                 time.sleep(5)
232 |             else:
233 |                 try:
234 |                     fp_y = np.load(cache_y_filepath_npy, mmap_mode='r')
235 |                     loaded = True
236 |                 except:
237 |                     print("Index", ipf.accelerator.process_index, "failed to load cache_y file, retrying in 5 seconds")
238 |                     time.sleep(5)
239 | 
240 |     ipf.accelerator.wait_for_everyone()
241 |     ipf.accelerator.print(f'Cache size: {fp.shape}')
242 | 
243 |     if ipf.accelerator.is_main_process:
244 |         # Visualize first entries
245 |         num_plots_grid = 100
246 |         ipf.plotter.save_image(torch.from_numpy(fp[:num_plots_grid, 0]), f'cache_{sample_direction}_{n:03}_x0', "./", domain=0)
247 |         ipf.plotter.save_image(torch.from_numpy(fp[:num_plots_grid, 1]), f'cache_{sample_direction}_{n:03}_x1', "./", domain=1)
248 | 
249 |         # Automatically delete old cache files, keeping the current ones
250 |         for fb in ['f', 'b']:
251 |             existing_cache_files = sorted(glob.glob(os.path.join(ipf.cache_dir, f"cache_{fb}_**.npy")))
252 |             for ckpt_i in range(max(len(existing_cache_files)-1, 0)):
253 |                 if not os.path.samefile(existing_cache_files[ckpt_i], cache_filepath_npy):
254 |                     os.remove(existing_cache_files[ckpt_i])
255 | 
256 |             if ipf.cdsb:
257 |                 existing_cache_files = sorted(glob.glob(os.path.join(ipf.cache_dir, f"cache_y_{fb}_**.npy")))
258 |                 for ckpt_i in range(max(len(existing_cache_files)-1, 0)):
259 |                     if not os.path.samefile(existing_cache_files[ckpt_i], cache_y_filepath_npy):
260 |                         os.remove(existing_cache_files[ckpt_i])
261 | 
262 |     del fp
263 | 
264 |     if ipf.cdsb:
265 |         del fp_y
266 |         return MemMapTensorDataset([cache_filepath_npy, cache_y_filepath_npy])
267 | 
268 |     return MemMapTensorDataset([cache_filepath_npy])
--------------------------------------------------------------------------------
/bridge/data/downscaler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import h5py
5 | 
6 | 
7 | class DownscalerDataset(torch.utils.data.Dataset):
8 |     def __init__(self, root, resolution=512, wavenumber=0, split="train", transform=None, target_transform=None):
9 |         assert resolution in [64, 512]
10 |         assert split in ["train", "test"]
11 |         self.root = root
12 |         # self.data = h5py.File(os.path.join(self.root, f"x{split}_{resolution}.jld2"), 'r')['single_stored_object']
13 |         self.data = np.load(os.path.join(self.root, f"x{split}_{resolution}.npy"), mmap_mode='r')
14 |         if resolution == 64:
15 |             self.indices = np.arange(len(self.data))
16 |         elif resolution == 512:
17 |             if wavenumber == 0:
18 |                 self.indices = np.arange(len(self.data))
19 |             else:
20 |                 assert wavenumber in [1, 2, 4, 8, 16]
21 |                 seg = int(np.log2(wavenumber))
22 |                 ndata_seg = 802 if split == 'train' else 200
23 |                 self.indices = np.arange(seg*ndata_seg, (seg+1)*ndata_seg)
24 |         else:
25 |             raise ValueError
26 |         self.randperm = np.random.default_rng(42).permutation(len(self.indices))
27 | 
28 |         self.transform = transform
29 |         self.target_transform = target_transform
30 | 
31 |         self.scaling = np.load(os.path.join(self.root, f"scaling_{resolution}.npz"))
32 |         mintrain_mean, Delta_, mintrain_p, Delta_p = self.scaling['mintrain_mean'], self.scaling['Delta_'], self.scaling['mintrain_p'], self.scaling['Deltap']
33 |         self.mintrain_mean, self.Delta_, self.mintrain_p, self.Delta_p = \
34 |             mintrain_mean.transpose()[0], Delta_.transpose()[0], mintrain_p.transpose()[0], Delta_p.transpose()[0]
35 | 
36 |     def __len__(self):
37 |         return len(self.indices)
38 | 
39 |     def __getitem__(self, index):
40 |         rindex = self.indices[self.randperm[index]]
41 |         data = self.data[rindex]
42 |         img, targets = torch.from_numpy(data[..., :2, :, :]), torch.from_numpy(data[..., 2:, :, :])
43 | 
44 |         if self.transform is not None:
45 |             img = self.transform(img)
46 | 
47 |         if self.target_transform is not None:
48 |             targets = self.target_transform(targets)
49 | 
50 |         return img, targets
51 | 
52 |     def invert_preprocessing(self, x_tilde, y_tilde=None):
53 |         # x_tilde: (B,) C, H, W
54 |         if y_tilde is None:
55 |             y_tilde = torch.zeros_like(x_tilde)[..., 0:1, :, :]
56 |         x_tilde = torch.cat([x_tilde, y_tilde], dim=-3)
57 |         assert x_tilde.shape[-3] == 3
58 |         tmp = (x_tilde + 2) / 2 * self.Delta_p + self.mintrain_p
59 |         xp = tmp - tmp.mean((-2, -1), keepdims=True)
60 |         x_bar = (tmp - xp) / self.Delta_p * self.Delta_ + self.mintrain_mean
61 |         out = xp + x_bar
62 |         return out[..., :2, :, :], out[..., 2:, :, :]
63 | 
64 |     def apply_preprocessing(self, x, y=None):
65 |         # x: (B,) C, H, W
66 |         if y is None:
67 |             y = torch.zeros_like(x)[..., 0:1, :, :]
68 |         x = torch.cat([x, y], dim=-3)
69 |         assert x.shape[-3] == 3
70 |         x_bar = x.mean((-2, -1), keepdims=True)
71 |         xp = x - x_bar
72 |         x_tilde = 2 * (x_bar - self.mintrain_mean) / self.Delta_ - 1
73 |         x_tilde_p = 2 * (xp - self.mintrain_p) / self.Delta_p - 1
74 |         out = x_tilde + x_tilde_p
75 |         return out[..., :2, :, :], out[..., 2:, :, :]
76 | 
77 | 
--------------------------------------------------------------------------------
/bridge/data/emnist.py:
--------------------------------------------------------------------------------
1 | import os, shutil
2 | import urllib
3 | import torch
4 | import torchvision.datasets
5 | import torchvision.transforms as transforms
6 | import torchvision.utils as vutils
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 | from tqdm import tqdm 10 | from torchvision.utils import save_image 11 | 12 | 13 | class FiveClassEMNIST(torchvision.datasets.EMNIST): 14 | def __init__(self, root="./data/emnist", train=True, download=True, transform=None, target_transform=None): 15 | super().__init__(root=root, split="letters", train=train, download=download, transform=transform, target_transform=target_transform) 16 | self.custom_indices = (self.targets<=5).nonzero(as_tuple=True)[0] 17 | self.data, self.targets = self.data[self.custom_indices].transpose(1, 2), self.targets[self.custom_indices] 18 | -------------------------------------------------------------------------------- /bridge/data/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import PeakSignalNoiseRatio as _PSNR, StructuralSimilarityIndexMeasure as _SSIM 3 | from torchmetrics.image.fid import FrechetInceptionDistance as _FID 4 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as _LPIPS 5 | from .utils import from_uint8_tensor 6 | 7 | class PSNR(_PSNR): 8 | def update(self, preds, target): 9 | super().update(preds.float(), target.float()) 10 | 11 | 12 | class SSIM(_SSIM): 13 | def update(self, preds, target): 14 | super().update(preds.float(), target.float()) 15 | 16 | 17 | class FID(_FID): 18 | def update(self, preds, target): 19 | if self.reset_real_features: 20 | super().update(target.expand(-1, 3, -1, -1), real=True) 21 | super().update(preds.expand(-1, 3, -1, -1), real=False) 22 | 23 | 24 | class LPIPS(_LPIPS): 25 | def update(self, preds, target): 26 | preds = from_uint8_tensor(preds) 27 | target = from_uint8_tensor(target) 28 | super().update(target.expand(-1, 3, -1, -1), preds.expand(-1, 3, -1, -1)) 29 | 30 | 31 | if __name__ == "__main__": 32 | PSNR() 33 | SSIM() 34 | FID() -------------------------------------------------------------------------------- /bridge/data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.utils as vutils 4 | 5 | 6 | def save_image(tensor, fp, format=None, **kwargs): 7 | normalized = normalize_tensor(tensor) 8 | vutils.save_image(normalized, fp, format=format, **kwargs) 9 | 10 | 11 | def to_uint8_tensor(tensor): 12 | normalized = normalize_tensor(tensor) 13 | return normalized.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8) 14 | 15 | 16 | def from_uint8_tensor(tensor): 17 | normalized = tensor.float() / 255 18 | return unnormalize_tensor(normalized) 19 | 20 | 21 | def normalize_tensor(tensor): 22 | normalized = tensor / 2 + 0.5 23 | return normalized.clamp_(0, 1) 24 | 25 | 26 | def unnormalize_tensor(tensor): 27 | unnormalized = (tensor - 0.5) * 2 28 | return unnormalized.clamp_(-1, 1) 29 | 30 | 31 | def _list_image_files_recursively(data_dir): 32 | results = [] 33 | for entry in sorted(os.listdir(data_dir)): 34 | full_path = os.path.join(data_dir, entry) 35 | ext = entry.split(".")[-1] 36 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 37 | results.append(full_path) 38 | elif os.path.isdir(full_path): 39 | results.extend(_list_image_files_recursively(full_path)) 40 | return results -------------------------------------------------------------------------------- /bridge/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic.basic_cond import BasicNetworkCond, ScoreNetworkCond 2 | from .unet import * 3 | from .ddpmpp.ncsnpp import NCSNpp -------------------------------------------------------------------------------- /bridge/models/basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuyang-shi/dsbm-pytorch/0f3ca754cb4dbdfa2eb1c095450676602f481d54/bridge/models/basic/__init__.py -------------------------------------------------------------------------------- /bridge/models/basic/basic_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .layers import MLP 3 | from .time_embedding import get_timestep_embedding 4 | 5 | 6 | class BasicNetworkCond(torch.nn.Module): 7 | def __init__(self, encoder_layers=[16], temb_dim=16, decoder_layers=[128,128], x_dim=1, y_dim=1): 8 | super().__init__() 9 | self.temb_dim = temb_dim 10 | t_enc_dim = temb_dim * 2 11 | self.locals = [encoder_layers, temb_dim, decoder_layers, x_dim, y_dim] 12 | 13 | self.net = MLP(t_enc_dim, 14 | layer_widths=decoder_layers + [x_dim], 15 | activate_final = False, 16 | activation_fn=torch.nn.LeakyReLU()) 17 | 18 | self.y_encoder = MLP(y_dim, 19 | layer_widths=encoder_layers + [t_enc_dim], 20 | activate_final = True, 21 | activation_fn=torch.nn.LeakyReLU()) 22 | 23 | 24 | def forward(self, y): 25 | if len(y.shape) == 1: 26 | y = y.unsqueeze(0) 27 | h = self.y_encoder(y) 28 | out = self.net(h) 29 | return out 30 | 31 | 32 | class ScoreNetworkCond(torch.nn.Module): 33 | def __init__(self, encoder_layers=[16], temb_dim=16, decoder_layers=[128,128], x_dim=1, y_dim=1, temb_max_period=10000): 34 | super().__init__() 35 | self.temb_dim = temb_dim 36 | t_enc_dim = temb_dim * 2 37 | self.locals = [encoder_layers, temb_dim, decoder_layers, x_dim, y_dim, temb_max_period] 38 | 39 | self.net = MLP(3 * t_enc_dim, 40 | layer_widths=decoder_layers + [x_dim], 41 | activate_final = False, 42 | activation_fn=torch.nn.LeakyReLU()) 43 | 44 | self.t_encoder = MLP(temb_dim, 45 | layer_widths=encoder_layers + [t_enc_dim], 46 | activate_final = True, 47 | activation_fn=torch.nn.LeakyReLU()) 48 | 49 | self.xy_encoder = MLP(x_dim + y_dim, 50 | layer_widths=[enc_dim*2 for enc_dim in encoder_layers] + [2*t_enc_dim], 51 | activate_final = True, 52 | activation_fn=torch.nn.LeakyReLU()) 53 | 54 | self.temb_max_period = temb_max_period 55 | 56 | def forward(self, x, y, t): 57 | if len(x.shape) == 1: 58 | x = x.unsqueeze(0) 59 | if len(y.shape) == 1: 60 | y = y.unsqueeze(0) 61 | 62 | t_emb = get_timestep_embedding(t, self.temb_dim, self.temb_max_period) 63 | t_emb = self.t_encoder(t_emb) 64 | xy_emb = self.xy_encoder(torch.cat([x, y], -1)) 65 | h = torch.cat([xy_emb, t_emb], -1) 66 | out = self.net(h) 67 | return out 68 | -------------------------------------------------------------------------------- /bridge/models/basic/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | from functools import partial 
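# Usage sketch (illustrative note, not from the original file):
#   MLP(input_dim=2, layer_widths=[128, 128, 1]) builds a 2 -> 128 -> 128 -> 1
#   network, applying `activation_fn` between hidden layers; the final layer
#   output is activated only if `activate_final=True`.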
6 | 7 | 8 | class MLP(torch.nn.Module): 9 | def __init__(self, input_dim, layer_widths, activate_final = False, activation_fn=F.relu): 10 | super(MLP, self).__init__() 11 | layers = [] 12 | prev_width = input_dim 13 | for layer_width in layer_widths: 14 | layers.append(torch.nn.Linear(prev_width, layer_width)) 15 | # # same init for everyone 16 | # torch.nn.init.constant_(layers[-1].weight, 0) 17 | prev_width = layer_width 18 | self.input_dim = input_dim 19 | self.layer_widths = layer_widths 20 | self.layers = torch.nn.ModuleList(layers) 21 | self.activate_final = activate_final 22 | self.activation_fn = activation_fn 23 | 24 | def forward(self, x): 25 | for i, layer in enumerate(self.layers[:-1]): 26 | x = self.activation_fn(layer(x)) 27 | x = self.layers[-1](x) 28 | if self.activate_final: 29 | x = self.activation_fn(x) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /bridge/models/basic/time_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | def get_timestep_embedding(timesteps, embedding_dim=128, max_period=10000): 7 | """ 8 | From Fairseq. 9 | Build sinusoidal embeddings. 10 | This matches the implementation in tensor2tensor, but differs slightly 11 | from the description in Section 3.5 of "Attention Is All You Need". 12 | https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py 13 | """ 14 | half_dim = embedding_dim // 2 15 | emb = math.log(max_period) / (half_dim - 1) 16 | emb = torch.exp(torch.arange(half_dim, device=timesteps.device) * -emb) 17 | 18 | emb = timesteps * emb.unsqueeze(0) 19 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 20 | if embedding_dim % 2 == 1: # zero pad 21 | emb = F.pad(emb, [0,1]) 22 | 23 | return emb 24 | 25 | 26 | if __name__ == "__main__": 27 | import torch 28 | import matplotlib.pyplot as plt 29 | test = get_timestep_embedding(torch.linspace(0, 0.1, 10).reshape(-1, 1)) 30 | plt.subplot(1, 2, 1) 31 | plt.plot(test.T[:test.shape[1]//2]) 32 | plt.subplot(1, 2, 2) 33 | plt.plot(test.T[test.shape[1]//2:]) 34 | plt.show() -------------------------------------------------------------------------------- /bridge/models/ddpmpp/layerspp.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/yang-song/score_sde_pytorch 2 | 3 | # coding=utf-8 4 | # Copyright 2020 The Google Research Authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # pylint: skip-file 19 | """Layers for defining NCSN++. 20 | """ 21 | from . import layers 22 | from . 
import up_or_down_sampling 23 | import torch.nn as nn 24 | import torch 25 | import torch.nn.functional as F 26 | import numpy as np 27 | 28 | conv1x1 = layers.ddpm_conv1x1 29 | conv3x3 = layers.ddpm_conv3x3 30 | NIN = layers.NIN 31 | default_init = layers.default_init 32 | 33 | 34 | class GaussianFourierProjection(nn.Module): 35 | """Gaussian Fourier embeddings for noise levels.""" 36 | 37 | def __init__(self, embedding_size=256, scale=1.0): 38 | super().__init__() 39 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 40 | 41 | def forward(self, x): 42 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 43 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 44 | 45 | 46 | class Combine(nn.Module): 47 | """Combine information from skip connections.""" 48 | 49 | def __init__(self, dim1, dim2, method='cat'): 50 | super().__init__() 51 | self.Conv_0 = conv1x1(dim1, dim2) 52 | self.method = method 53 | 54 | def forward(self, x, y): 55 | h = self.Conv_0(x) 56 | if self.method == 'cat': 57 | return torch.cat([h, y], dim=1) 58 | elif self.method == 'sum': 59 | return h + y 60 | else: 61 | raise ValueError(f'Method {self.method} not recognized.') 62 | 63 | 64 | class AttnBlockpp(nn.Module): 65 | """Channel-wise self-attention block. Modified from DDPM.""" 66 | 67 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 68 | super().__init__() 69 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 70 | eps=1e-6) 71 | self.NIN_0 = NIN(channels, channels) 72 | self.NIN_1 = NIN(channels, channels) 73 | self.NIN_2 = NIN(channels, channels) 74 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 75 | self.skip_rescale = skip_rescale 76 | 77 | def forward(self, x): 78 | B, C, H, W = x.shape 79 | h = self.GroupNorm_0(x) 80 | q = self.NIN_0(h) 81 | k = self.NIN_1(h) 82 | v = self.NIN_2(h) 83 | 84 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 85 | w = torch.reshape(w, (B, H, W, H * W)) 86 | w = F.softmax(w, dim=-1) 87 | w = torch.reshape(w, (B, H, W, H, W)) 88 | h = torch.einsum('bhwij,bcij->bchw', w, v) 89 | h = self.NIN_3(h) 90 | if not self.skip_rescale: 91 | return x + h 92 | else: 93 | return (x + h) / np.sqrt(2.) 
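# Shape-preserving usage sketch (illustrative, not part of the original file):
#   blk = AttnBlockpp(channels=64)
#   out = blk(torch.randn(2, 64, 16, 16))  # out.shape == (2, 64, 16, 16)
# The einsum above scores every spatial position against all H*W others,
# i.e. full (non-windowed) self-attention over the feature map.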
94 | 95 | 96 | class Upsample(nn.Module): 97 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 98 | fir_kernel=(1, 3, 3, 1)): 99 | super().__init__() 100 | out_ch = out_ch if out_ch else in_ch 101 | if not fir: 102 | if with_conv: 103 | self.Conv_0 = conv3x3(in_ch, out_ch) 104 | else: 105 | if with_conv: 106 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 107 | kernel=3, up=True, 108 | resample_kernel=fir_kernel, 109 | use_bias=True, 110 | kernel_init=default_init()) 111 | self.fir = fir 112 | self.with_conv = with_conv 113 | self.fir_kernel = fir_kernel 114 | self.out_ch = out_ch 115 | 116 | def forward(self, x): 117 | B, C, H, W = x.shape 118 | if not self.fir: 119 | h = F.interpolate(x, (H * 2, W * 2), mode='nearest') 120 | if self.with_conv: 121 | h = self.Conv_0(h) 122 | else: 123 | if not self.with_conv: 124 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 125 | else: 126 | h = self.Conv2d_0(x) 127 | 128 | return h 129 | 130 | 131 | class Downsample(nn.Module): 132 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 133 | fir_kernel=(1, 3, 3, 1)): 134 | super().__init__() 135 | out_ch = out_ch if out_ch else in_ch 136 | if not fir: 137 | if with_conv: 138 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 139 | else: 140 | if with_conv: 141 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 142 | kernel=3, down=True, 143 | resample_kernel=fir_kernel, 144 | use_bias=True, 145 | kernel_init=default_init()) 146 | self.fir = fir 147 | self.fir_kernel = fir_kernel 148 | self.with_conv = with_conv 149 | self.out_ch = out_ch 150 | 151 | def forward(self, x): 152 | B, C, H, W = x.shape 153 | if not self.fir: 154 | if self.with_conv: 155 | x = F.pad(x, (0, 1, 0, 1)) 156 | x = self.Conv_0(x) 157 | else: 158 | x = F.avg_pool2d(x, 2, stride=2) 159 | else: 160 | if not self.with_conv: 161 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 162 | else: 163 | x = self.Conv2d_0(x) 164 | 165 | return x 166 | 167 | 168 | class ResnetBlockDDPMpp(nn.Module): 169 | """ResBlock adapted from DDPM.""" 170 | 171 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 172 | dropout=0.1, skip_rescale=False, init_scale=0.): 173 | super().__init__() 174 | out_ch = out_ch if out_ch else in_ch 175 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 176 | self.Conv_0 = conv3x3(in_ch, out_ch) 177 | if temb_dim is not None: 178 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 179 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 180 | nn.init.zeros_(self.Dense_0.bias) 181 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 182 | self.Dropout_0 = nn.Dropout(dropout) 183 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 184 | if in_ch != out_ch: 185 | if conv_shortcut: 186 | self.Conv_2 = conv3x3(in_ch, out_ch) 187 | else: 188 | self.NIN_0 = NIN(in_ch, out_ch) 189 | 190 | self.skip_rescale = skip_rescale 191 | self.act = act 192 | self.out_ch = out_ch 193 | self.conv_shortcut = conv_shortcut 194 | 195 | def forward(self, x, temb=None): 196 | h = self.act(self.GroupNorm_0(x)) 197 | h = self.Conv_0(h) 198 | if temb is not None: 199 | h += self.Dense_0(self.act(temb))[:, :, None, None] 200 | h = self.act(self.GroupNorm_1(h)) 201 | h = self.Dropout_0(h) 202 | h = self.Conv_1(h) 203 | if x.shape[1] != self.out_ch: 204 | if self.conv_shortcut: 205 | x = 
self.Conv_2(x) 206 | else: 207 | x = self.NIN_0(x) 208 | if not self.skip_rescale: 209 | return x + h 210 | else: 211 | return (x + h) / np.sqrt(2.) 212 | 213 | 214 | class ResnetBlockBigGANpp(nn.Module): 215 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 216 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 217 | skip_rescale=True, init_scale=0.): 218 | super().__init__() 219 | 220 | out_ch = out_ch if out_ch else in_ch 221 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 222 | self.up = up 223 | self.down = down 224 | self.fir = fir 225 | self.fir_kernel = fir_kernel 226 | 227 | self.Conv_0 = conv3x3(in_ch, out_ch) 228 | if temb_dim is not None: 229 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 230 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 231 | nn.init.zeros_(self.Dense_0.bias) 232 | 233 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 234 | self.Dropout_0 = nn.Dropout(dropout) 235 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 236 | if in_ch != out_ch or up or down: 237 | self.Conv_2 = conv1x1(in_ch, out_ch) 238 | 239 | self.skip_rescale = skip_rescale 240 | self.act = act 241 | self.in_ch = in_ch 242 | self.out_ch = out_ch 243 | 244 | def forward(self, x, temb=None): 245 | h = self.act(self.GroupNorm_0(x)) 246 | 247 | if self.up: 248 | if self.fir: 249 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 250 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 251 | else: 252 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 253 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 254 | elif self.down: 255 | if self.fir: 256 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 257 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 258 | else: 259 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 260 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 261 | 262 | h = self.Conv_0(h) 263 | # Add bias to each feature map conditioned on the time embedding 264 | if temb is not None: 265 | h += self.Dense_0(self.act(temb))[:, :, None, None] 266 | h = self.act(self.GroupNorm_1(h)) 267 | h = self.Dropout_0(h) 268 | h = self.Conv_1(h) 269 | 270 | if self.in_ch != self.out_ch or self.up or self.down: 271 | x = self.Conv_2(x) 272 | 273 | if not self.skip_rescale: 274 | return x + h 275 | else: 276 | return (x + h) / np.sqrt(2.) 277 | 278 | -------------------------------------------------------------------------------- /bridge/models/ddpmpp/ncsnpp.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/yang-song/score_sde_pytorch 2 | 3 | 4 | # coding=utf-8 5 | # Copyright 2020 The Google Research Authors. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | # pylint: skip-file 20 | 21 | from . 
import utils, layers, layerspp, normalization 22 | import torch.nn as nn 23 | import functools 24 | import torch 25 | import numpy as np 26 | 27 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp 28 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp 29 | Combine = layerspp.Combine 30 | conv3x3 = layerspp.conv3x3 31 | conv1x1 = layerspp.conv1x1 32 | get_act = layers.get_act 33 | get_normalization = normalization.get_normalization 34 | default_initializer = layers.default_init 35 | 36 | 37 | # even though this is called NCSNpp, with certain config settings, it is known 38 | # as DDPMpp 39 | class NCSNpp(nn.Module): 40 | """NCSN++ model""" 41 | 42 | def __init__(self, config): 43 | super().__init__() 44 | 45 | self.locals = [config] 46 | 47 | self.config = config 48 | self.act = act = get_act(config) 49 | # self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) 50 | 51 | self.nf = nf = config.model.nf 52 | ch_mult = config.model.ch_mult 53 | self.num_res_blocks = num_res_blocks = config.model.num_res_blocks 54 | self.attn_resolutions = attn_resolutions = config.model.attn_resolutions 55 | dropout = config.model.dropout 56 | resamp_with_conv = config.model.resamp_with_conv 57 | self.num_resolutions = num_resolutions = len(ch_mult) 58 | self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] 59 | 60 | # self.conditional = conditional = config.model.conditional # noise-conditional 61 | self.conditional = True # we definitely want to condition on the time step (this doesn't mean conditioning on some other data) 62 | 63 | 64 | fir = config.model.fir 65 | fir_kernel = config.model.fir_kernel 66 | self.skip_rescale = skip_rescale = config.model.skip_rescale 67 | self.resblock_type = resblock_type = config.model.resblock_type.lower() 68 | self.progressive = progressive = config.model.progressive.lower() 69 | self.progressive_input = progressive_input = config.model.progressive_input.lower() 70 | self.embedding_type = embedding_type = config.model.embedding_type.lower() 71 | init_scale = config.model.init_scale 72 | assert progressive in ['none', 'output_skip', 'residual'] 73 | assert progressive_input in ['none', 'input_skip', 'residual'] 74 | assert embedding_type in ['fourier', 'positional'] 75 | combine_method = config.model.progressive_combine.lower() 76 | combiner = functools.partial(Combine, method=combine_method) 77 | 78 | modules = [] 79 | # timestep/noise_level embedding; only for continuous training 80 | if embedding_type == 'fourier': 81 | # Gaussian Fourier features embeddings. 82 | assert config.training.continuous, "Fourier features are only used for continuous training." 
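            # Note (added for clarity): GaussianFourierProjection concatenates
            # sin and cos features, so embedding_size=nf yields a 2 * nf-dim
            # embedding, matching embed_dim = 2 * nf below.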
83 | 84 | modules.append(layerspp.GaussianFourierProjection( 85 | embedding_size=nf, scale=config.model.fourier_scale 86 | )) 87 | embed_dim = 2 * nf 88 | 89 | elif embedding_type == 'positional': 90 | embed_dim = nf 91 | 92 | else: 93 | raise ValueError(f'embedding type {embedding_type} unknown.') 94 | 95 | if self.conditional: 96 | modules.append(nn.Linear(embed_dim, nf * 4)) 97 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 98 | nn.init.zeros_(modules[-1].bias) 99 | modules.append(nn.Linear(nf * 4, nf * 4)) 100 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 101 | nn.init.zeros_(modules[-1].bias) 102 | 103 | AttnBlock = functools.partial(layerspp.AttnBlockpp, 104 | init_scale=init_scale, 105 | skip_rescale=skip_rescale) 106 | 107 | Upsample = functools.partial(layerspp.Upsample, 108 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 109 | 110 | if progressive == 'output_skip': 111 | self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 112 | elif progressive == 'residual': 113 | pyramid_upsample = functools.partial(layerspp.Upsample, 114 | fir=fir, fir_kernel=fir_kernel, with_conv=True) 115 | 116 | Downsample = functools.partial(layerspp.Downsample, 117 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 118 | 119 | if progressive_input == 'input_skip': 120 | self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 121 | elif progressive_input == 'residual': 122 | pyramid_downsample = functools.partial(layerspp.Downsample, 123 | fir=fir, fir_kernel=fir_kernel, with_conv=True) 124 | 125 | if resblock_type == 'ddpm': 126 | ResnetBlock = functools.partial(ResnetBlockDDPM, 127 | act=act, 128 | dropout=dropout, 129 | init_scale=init_scale, 130 | skip_rescale=skip_rescale, 131 | temb_dim=nf * 4) 132 | 133 | elif resblock_type == 'biggan': 134 | ResnetBlock = functools.partial(ResnetBlockBigGAN, 135 | act=act, 136 | dropout=dropout, 137 | fir=fir, 138 | fir_kernel=fir_kernel, 139 | init_scale=init_scale, 140 | skip_rescale=skip_rescale, 141 | temb_dim=nf * 4) 142 | 143 | else: 144 | raise ValueError(f'resblock type {resblock_type} unrecognized.') 145 | 146 | # Downsampling block 147 | 148 | channels = config.data.num_channels 149 | if progressive_input != 'none': 150 | input_pyramid_ch = channels 151 | 152 | modules.append(conv3x3(channels, nf)) 153 | hs_c = [nf] 154 | 155 | in_ch = nf 156 | for i_level in range(num_resolutions): 157 | # Residual blocks for this resolution 158 | for i_block in range(num_res_blocks): 159 | out_ch = nf * ch_mult[i_level] 160 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 161 | in_ch = out_ch 162 | 163 | if all_resolutions[i_level] in attn_resolutions: 164 | modules.append(AttnBlock(channels=in_ch)) 165 | hs_c.append(in_ch) 166 | 167 | if i_level != num_resolutions - 1: 168 | if resblock_type == 'ddpm': 169 | modules.append(Downsample(in_ch=in_ch)) 170 | else: 171 | modules.append(ResnetBlock(down=True, in_ch=in_ch)) 172 | 173 | if progressive_input == 'input_skip': 174 | modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) 175 | if combine_method == 'cat': 176 | in_ch *= 2 177 | 178 | elif progressive_input == 'residual': 179 | modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) 180 | input_pyramid_ch = in_ch 181 | 182 | hs_c.append(in_ch) 183 | 184 | in_ch = hs_c[-1] 185 | modules.append(ResnetBlock(in_ch=in_ch)) 186 | modules.append(AttnBlock(channels=in_ch)) 187 | 
modules.append(ResnetBlock(in_ch=in_ch)) 188 | 189 | pyramid_ch = 0 190 | # Upsampling block 191 | for i_level in reversed(range(num_resolutions)): 192 | for i_block in range(num_res_blocks + 1): 193 | out_ch = nf * ch_mult[i_level] 194 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), 195 | out_ch=out_ch)) 196 | in_ch = out_ch 197 | 198 | if all_resolutions[i_level] in attn_resolutions: 199 | modules.append(AttnBlock(channels=in_ch)) 200 | 201 | if progressive != 'none': 202 | if i_level == num_resolutions - 1: 203 | if progressive == 'output_skip': 204 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 205 | num_channels=in_ch, eps=1e-6)) 206 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 207 | pyramid_ch = channels 208 | elif progressive == 'residual': 209 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 210 | num_channels=in_ch, eps=1e-6)) 211 | modules.append(conv3x3(in_ch, in_ch, bias=True)) 212 | pyramid_ch = in_ch 213 | else: 214 | raise ValueError(f'{progressive} is not a valid name.') 215 | else: 216 | if progressive == 'output_skip': 217 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 218 | num_channels=in_ch, eps=1e-6)) 219 | modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) 220 | pyramid_ch = channels 221 | elif progressive == 'residual': 222 | modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) 223 | pyramid_ch = in_ch 224 | else: 225 | raise ValueError(f'{progressive} is not a valid name') 226 | 227 | if i_level != 0: 228 | if resblock_type == 'ddpm': 229 | modules.append(Upsample(in_ch=in_ch)) 230 | else: 231 | modules.append(ResnetBlock(in_ch=in_ch, up=True)) 232 | 233 | assert not hs_c 234 | 235 | if progressive != 'output_skip': 236 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 237 | num_channels=in_ch, eps=1e-6)) 238 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 239 | 240 | self.all_modules = nn.ModuleList(modules) 241 | 242 | def forward(self, x, y, time_cond): 243 | # make sure y is not used 244 | y = None 245 | 246 | # flatten the time 247 | time_cond = time_cond.view(-1) 248 | 249 | # timestep/noise_level embedding; only for continuous training 250 | modules = self.all_modules 251 | m_idx = 0 252 | if self.embedding_type == 'fourier': 253 | raise NotImplementedError 254 | # Gaussian Fourier features embeddings. 255 | # used_sigmas = time_cond 256 | # temb = modules[m_idx](torch.log(used_sigmas)) 257 | m_idx += 1 258 | 259 | elif self.embedding_type == 'positional': 260 | # Sinusoidal positional embeddings. 261 | timesteps = time_cond 262 | # used_sigmas = self.sigmas[time_cond.long()] 263 | 264 | # assume timesteps is between [0, 1], because then we need to multiply by 1000 265 | # to get a proper spread of time step embeddings 266 | assert timesteps.max() <= 1.1 267 | assert timesteps.min() >= -0.1 268 | temb = layers.get_timestep_embedding(timesteps * 1000, self.nf) 269 | 270 | else: 271 | raise ValueError(f'embedding type {self.embedding_type} unknown.') 272 | 273 | if self.conditional: 274 | temb = modules[m_idx](temb) 275 | m_idx += 1 276 | temb = modules[m_idx](self.act(temb)) 277 | m_idx += 1 278 | else: 279 | temb = None 280 | 281 | if not self.config.data.centered: 282 | # If input data is in [0, 1] 283 | x = 2 * x - 1. 
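            # (added note: rescales [0, 1] inputs to [-1, 1] so the network
            # always operates on centered data)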
284 | 285 | # Downsampling block 286 | input_pyramid = None 287 | if self.progressive_input != 'none': 288 | input_pyramid = x 289 | 290 | hs = [modules[m_idx](x)] 291 | m_idx += 1 292 | for i_level in range(self.num_resolutions): 293 | # Residual blocks for this resolution 294 | for i_block in range(self.num_res_blocks): 295 | h = modules[m_idx](hs[-1], temb) 296 | m_idx += 1 297 | if h.shape[-1] in self.attn_resolutions: 298 | h = modules[m_idx](h) 299 | m_idx += 1 300 | 301 | hs.append(h) 302 | 303 | if i_level != self.num_resolutions - 1: 304 | if self.resblock_type == 'ddpm': 305 | h = modules[m_idx](hs[-1]) 306 | m_idx += 1 307 | else: 308 | h = modules[m_idx](hs[-1], temb) 309 | m_idx += 1 310 | 311 | if self.progressive_input == 'input_skip': 312 | input_pyramid = self.pyramid_downsample(input_pyramid) 313 | h = modules[m_idx](input_pyramid, h) 314 | m_idx += 1 315 | 316 | elif self.progressive_input == 'residual': 317 | input_pyramid = modules[m_idx](input_pyramid) 318 | m_idx += 1 319 | if self.skip_rescale: 320 | input_pyramid = (input_pyramid + h) / np.sqrt(2.) 321 | else: 322 | input_pyramid = input_pyramid + h 323 | h = input_pyramid 324 | 325 | hs.append(h) 326 | 327 | h = hs[-1] 328 | h = modules[m_idx](h, temb) 329 | m_idx += 1 330 | h = modules[m_idx](h) 331 | m_idx += 1 332 | h = modules[m_idx](h, temb) 333 | m_idx += 1 334 | 335 | pyramid = None 336 | 337 | # Upsampling block 338 | for i_level in reversed(range(self.num_resolutions)): 339 | for i_block in range(self.num_res_blocks + 1): 340 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 341 | m_idx += 1 342 | 343 | if h.shape[-1] in self.attn_resolutions: 344 | h = modules[m_idx](h) 345 | m_idx += 1 346 | 347 | if self.progressive != 'none': 348 | if i_level == self.num_resolutions - 1: 349 | if self.progressive == 'output_skip': 350 | pyramid = self.act(modules[m_idx](h)) 351 | m_idx += 1 352 | pyramid = modules[m_idx](pyramid) 353 | m_idx += 1 354 | elif self.progressive == 'residual': 355 | pyramid = self.act(modules[m_idx](h)) 356 | m_idx += 1 357 | pyramid = modules[m_idx](pyramid) 358 | m_idx += 1 359 | else: 360 | raise ValueError(f'{self.progressive} is not a valid name.') 361 | else: 362 | if self.progressive == 'output_skip': 363 | pyramid = self.pyramid_upsample(pyramid) 364 | pyramid_h = self.act(modules[m_idx](h)) 365 | m_idx += 1 366 | pyramid_h = modules[m_idx](pyramid_h) 367 | m_idx += 1 368 | pyramid = pyramid + pyramid_h 369 | elif self.progressive == 'residual': 370 | pyramid = modules[m_idx](pyramid) 371 | m_idx += 1 372 | if self.skip_rescale: 373 | pyramid = (pyramid + h) / np.sqrt(2.) 
374 | else:
375 | pyramid = pyramid + h
376 | h = pyramid
377 | else:
378 | raise ValueError(f'{self.progressive} is not a valid name.')
379 |
380 | if i_level != 0:
381 | if self.resblock_type == 'ddpm':
382 | h = modules[m_idx](h)
383 | m_idx += 1
384 | else:
385 | h = modules[m_idx](h, temb)
386 | m_idx += 1
387 |
388 | assert not hs
389 |
390 | if self.progressive == 'output_skip':
391 | h = pyramid
392 | else:
393 | h = self.act(modules[m_idx](h))
394 | m_idx += 1
395 | h = modules[m_idx](h)
396 | m_idx += 1
397 |
398 | assert m_idx == len(modules)
399 | if self.config.model.scale_by_sigma:
400 | raise NotImplementedError
401 | used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
402 | h = h / used_sigmas
403 |
404 | return h
405 |
-------------------------------------------------------------------------------- /bridge/models/ddpmpp/normalization.py: --------------------------------------------------------------------------------
1 | # from https://github.com/yang-song/score_sde_pytorch
2 |
3 | # coding=utf-8
4 | # Copyright 2020 The Google Research Authors.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | """Normalization layers."""
19 | import torch.nn as nn
20 | import torch
21 | import functools
22 |
23 |
24 | def get_normalization(config, conditional=False):
25 | """Obtain normalization modules from the config file."""
26 | norm = config.model.normalization
27 | if conditional:
28 | if norm == 'InstanceNorm++':
29 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
30 | else:
31 | raise NotImplementedError(f'{norm} not implemented yet.')
32 | else:
33 | if norm == 'InstanceNorm':
34 | return nn.InstanceNorm2d
35 | elif norm == 'InstanceNorm++':
36 | return InstanceNorm2dPlus
37 | elif norm == 'VarianceNorm':
38 | return VarianceNorm2d
39 | elif norm == 'GroupNorm':
40 | return nn.GroupNorm
41 | else:
42 | raise ValueError('Unknown normalization: %s' % norm)
43 |
44 |
45 | class ConditionalBatchNorm2d(nn.Module):
46 | def __init__(self, num_features, num_classes, bias=True):
47 | super().__init__()
48 | self.num_features = num_features
49 | self.bias = bias
50 | self.bn = nn.BatchNorm2d(num_features, affine=False)
51 | if self.bias:
52 | self.embed = nn.Embedding(num_classes, num_features * 2)
53 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale from U(0, 1)
54 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
55 | else:
56 | self.embed = nn.Embedding(num_classes, num_features)
57 | self.embed.weight.data.uniform_()
58 |
59 | def forward(self, x, y):
60 | out = self.bn(x)
61 | if self.bias:
62 | gamma, beta = self.embed(y).chunk(2, dim=1)
63 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
64 | else:
65 | gamma = self.embed(y)
66 | out = gamma.view(-1, self.num_features, 1, 1) * out
67 | return out
68 |
69 |
70 | class ConditionalInstanceNorm2d(nn.Module):
71 | def
__init__(self, num_features, num_classes, bias=True):
72 | super().__init__()
73 | self.num_features = num_features
74 | self.bias = bias
75 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
76 | if bias:
77 | self.embed = nn.Embedding(num_classes, num_features * 2)
78 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale from U(0, 1)
79 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
80 | else:
81 | self.embed = nn.Embedding(num_classes, num_features)
82 | self.embed.weight.data.uniform_()
83 |
84 | def forward(self, x, y):
85 | h = self.instance_norm(x)
86 | if self.bias:
87 | gamma, beta = self.embed(y).chunk(2, dim=-1)
88 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
89 | else:
90 | gamma = self.embed(y)
91 | out = gamma.view(-1, self.num_features, 1, 1) * h
92 | return out
93 |
94 |
95 | class ConditionalVarianceNorm2d(nn.Module):
96 | def __init__(self, num_features, num_classes, bias=False):
97 | super().__init__()
98 | self.num_features = num_features
99 | self.bias = bias
100 | self.embed = nn.Embedding(num_classes, num_features)
101 | self.embed.weight.data.normal_(1, 0.02)
102 |
103 | def forward(self, x, y):
104 | vars = torch.var(x, dim=(2, 3), keepdim=True)
105 | h = x / torch.sqrt(vars + 1e-5)
106 |
107 | gamma = self.embed(y)
108 | out = gamma.view(-1, self.num_features, 1, 1) * h
109 | return out
110 |
111 |
112 | class VarianceNorm2d(nn.Module):
113 | def __init__(self, num_features, bias=False):
114 | super().__init__()
115 | self.num_features = num_features
116 | self.bias = bias
117 | self.alpha = nn.Parameter(torch.zeros(num_features))
118 | self.alpha.data.normal_(1, 0.02)
119 |
120 | def forward(self, x):
121 | vars = torch.var(x, dim=(2, 3), keepdim=True)
122 | h = x / torch.sqrt(vars + 1e-5)
123 |
124 | out = self.alpha.view(-1, self.num_features, 1, 1) * h
125 | return out
126 |
127 |
128 | class ConditionalNoneNorm2d(nn.Module):
129 | def __init__(self, num_features, num_classes, bias=True):
130 | super().__init__()
131 | self.num_features = num_features
132 | self.bias = bias
133 | if bias:
134 | self.embed = nn.Embedding(num_classes, num_features * 2)
135 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale from U(0, 1)
136 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
137 | else:
138 | self.embed = nn.Embedding(num_classes, num_features)
139 | self.embed.weight.data.uniform_()
140 |
141 | def forward(self, x, y):
142 | if self.bias:
143 | gamma, beta = self.embed(y).chunk(2, dim=-1)
144 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
145 | else:
146 | gamma = self.embed(y)
147 | out = gamma.view(-1, self.num_features, 1, 1) * x
148 | return out
149 |
150 |
151 | class NoneNorm2d(nn.Module):
152 | def __init__(self, num_features, bias=True):
153 | super().__init__()
154 |
155 | def forward(self, x):
156 | return x
157 |
158 |
159 | class InstanceNorm2dPlus(nn.Module):
160 | def __init__(self, num_features, bias=True):
161 | super().__init__()
162 | self.num_features = num_features
163 | self.bias = bias
164 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
165 | self.alpha = nn.Parameter(torch.zeros(num_features))
166 | self.gamma = nn.Parameter(torch.zeros(num_features))
167 | self.alpha.data.normal_(1, 0.02)
168 | self.gamma.data.normal_(1, 0.02)
169 | if bias:
170 | self.beta =
nn.Parameter(torch.zeros(num_features))
171 |
172 | def forward(self, x):
173 | means = torch.mean(x, dim=(2, 3))
174 | m = torch.mean(means, dim=-1, keepdim=True)
175 | v = torch.var(means, dim=-1, keepdim=True)
176 | means = (means - m) / (torch.sqrt(v + 1e-5))
177 | h = self.instance_norm(x)
178 |
179 | if self.bias:
180 | h = h + means[..., None, None] * self.alpha[..., None, None]
181 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
182 | else:
183 | h = h + means[..., None, None] * self.alpha[..., None, None]
184 | out = self.gamma.view(-1, self.num_features, 1, 1) * h
185 | return out
186 |
187 |
188 | class ConditionalInstanceNorm2dPlus(nn.Module):
189 | def __init__(self, num_features, num_classes, bias=True):
190 | super().__init__()
191 | self.num_features = num_features
192 | self.bias = bias
193 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
194 | if bias:
195 | self.embed = nn.Embedding(num_classes, num_features * 3)
196 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
197 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
198 | else:
199 | self.embed = nn.Embedding(num_classes, 2 * num_features)
200 | self.embed.weight.data.normal_(1, 0.02)
201 |
202 | def forward(self, x, y):
203 | means = torch.mean(x, dim=(2, 3))
204 | m = torch.mean(means, dim=-1, keepdim=True)
205 | v = torch.var(means, dim=-1, keepdim=True)
206 | means = (means - m) / (torch.sqrt(v + 1e-5))
207 | h = self.instance_norm(x)
208 |
209 | if self.bias:
210 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
211 | h = h + means[..., None, None] * alpha[..., None, None]
212 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
213 | else:
214 | gamma, alpha = self.embed(y).chunk(2, dim=-1)
215 | h = h + means[..., None, None] * alpha[..., None, None]
216 | out = gamma.view(-1, self.num_features, 1, 1) * h
217 | return out
218 |
-------------------------------------------------------------------------------- /bridge/models/ddpmpp/op/__init__.py: --------------------------------------------------------------------------------
1 | # from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
-------------------------------------------------------------------------------- /bridge/models/ddpmpp/op/upfirdn2d.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
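The binding above is CUDA-only; the Python wrapper in `upfirdn2d.py` below adds a native CPU fallback. Conceptually, `upfirdn2d` UPsamples by zero-insertion, applies an FIR filter, then DOWNsamples. The following is a minimal reference sketch of the same operation in plain PyTorch, assuming square up/down factors, a 2-D filter tensor `k`, and symmetric padding; it is illustrative only and not part of this repository (`upfirdn2d_reference` is a name invented here).

import torch
import torch.nn.functional as F

def upfirdn2d_reference(x, k, up=1, down=1, pad=(0, 0)):
    # x: [N, C, H, W] input; k: [kh, kw] FIR filter
    n, c, h, w = x.shape
    # 1) UPsample: insert (up - 1) zeros after every pixel along H and W
    z = x.new_zeros(n, c, h, up, w, up)
    z[:, :, :, 0, :, 0] = x
    z = z.reshape(n, c, h * up, w * up)
    # 2) FIR: pad, then filter each channel with the kernel (true convolution,
    #    hence the flip, since F.conv2d computes cross-correlation)
    z = F.pad(z, [pad[0], pad[1], pad[0], pad[1]])
    weight = torch.flip(k.to(x), [0, 1])[None, None].repeat(c, 1, 1, 1)
    z = F.conv2d(z, weight, groups=c)
    # 3) DOWNsample: keep every down-th sample, starting at index 0
    return z[:, :, ::down, ::down]

Under this reading, `up=factor` with the kernel scaled by `gain * factor ** 2` reproduces `upsample_2d` defined later in `up_or_down_sampling.py`, and `down=factor` with scale `gain` reproduces `downsample_2d`.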
-------------------------------------------------------------------------------- /bridge/models/ddpmpp/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, 
out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /bridge/models/ddpmpp/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z *
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
-------------------------------------------------------------------------------- /bridge/models/ddpmpp/up_or_down_sampling.py:
-------------------------------------------------------------------------------- 1 | """Layers used for up-sampling or down-sampling images.
2 |
3 | Many functions are ported from https://github.com/NVlabs/stylegan2.
4 | """
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | import numpy as np
10 | try:
11 | from .op import upfirdn2d
12 | except Exception:
13 | print('Failed to import upfirdn2d')
14 |
15 |
16 | # Function ported from StyleGAN2
17 | def get_weight(module,
18 | shape,
19 | weight_var='weight',
20 | kernel_init=None):
21 | """Get/create weight tensor for a convolution or fully-connected layer."""
22 |
23 | return module.param(weight_var, kernel_init, shape)
24 |
25 |
26 | class Conv2d(nn.Module):
27 | """Conv2d layer with optional upsampling or downsampling (StyleGAN2)."""
28 |
29 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
30 | resample_kernel=(1, 3, 3, 1),
31 | use_bias=True,
32 | kernel_init=None):
33 | super().__init__()
34 | assert not (up and down)
35 | assert kernel >= 1 and kernel % 2 == 1
36 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
37 | if kernel_init is not None:
38 | self.weight.data = kernel_init(self.weight.data.shape)
39 | if use_bias:
40 | self.bias = nn.Parameter(torch.zeros(out_ch))
41 |
42 | self.up = up
43 | self.down = down
44 | self.resample_kernel = resample_kernel
45 | self.kernel = kernel
46 | self.use_bias = use_bias
47 |
48 | def forward(self, x):
49 | if self.up:
50 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
51 | elif self.down:
52 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
53 | else:
54 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
55 |
56 | if self.use_bias:
57 | x = x + self.bias.reshape(1, -1, 1, 1)
58 |
59 | return x
60 |
61 |
62 | def naive_upsample_2d(x, factor=2):
63 | _N, C, H, W = x.shape
64 | x = torch.reshape(x, (-1, C, H, 1, W, 1))
65 | x = x.repeat(1, 1, 1, factor, 1, factor)
66 | return torch.reshape(x, (-1, C, H * factor, W * factor))
67 |
68 |
69 | def naive_downsample_2d(x, factor=2):
70 | _N, C, H, W = x.shape
71 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
72 | return torch.mean(x, dim=(3, 5))
73 |
74 |
75 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
76 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
77 |
78 | Padding is performed only once at the beginning, not between the
79 | operations.
80 | The fused op is considerably more efficient than performing the same
81 | calculation
82 | using standard PyTorch ops. It supports gradients of arbitrary order.
83 | Args:
84 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
85 | C]`.
86 | w: Weight tensor of the shape `[filterH, filterW, inChannels,
87 | outChannels]`. Grouped convolution can be performed by `inChannels =
88 | x.shape[0] // numGroups`.
89 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
90 | (separable). The default is `[1] * factor`, which corresponds to
91 | nearest-neighbor upsampling.
92 | factor: Integer upsampling factor (default: 2).
93 | gain: Scaling factor for signal magnitude (default: 1.0).
94 |
95 | Returns:
96 | Tensor of the shape `[N, C, H * factor, W * factor]` or
97 | `[N, H * factor, W * factor, C]`, and same datatype as `x`.
98 | """
99 |
100 | assert isinstance(factor, int) and factor >= 1
101 |
102 | # Check weight shape.
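# (Here w uses PyTorch's [outC, inC, kH, kW] layout, as set up in Conv2d above;
# the docstring keeps StyleGAN2's original TF [kH, kW, inC, outC] wording.)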
103 | assert len(w.shape) == 4
104 | convH = w.shape[2]
105 | convW = w.shape[3]
106 | inC = w.shape[1]
107 | outC = w.shape[0]
108 |
109 | assert convW == convH
110 |
111 | # Setup filter kernel.
112 | if k is None:
113 | k = [1] * factor
114 | k = _setup_kernel(k) * (gain * (factor ** 2))
115 | p = (k.shape[0] - factor) - (convW - 1)
116 |
117 | stride = (factor, factor)
118 |
119 | # Determine data dimensions.
120 | # (TF uses strides `[1, 1, factor, factor]` here; PyTorch's conv_transpose2d takes `(factor, factor)`.)
121 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
122 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
123 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
124 | assert output_padding[0] >= 0 and output_padding[1] >= 0
125 | num_groups = _shape(x, 1) // inC
126 |
127 | # Transpose weights.
128 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
129 | w = torch.flip(w, [3, 4]).permute(0, 2, 1, 3, 4) # torch has no negative-step slicing
130 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
131 |
132 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
133 | ## Original TF code.
134 | # x = tf.nn.conv2d_transpose(
135 | # x,
136 | # w,
137 | # output_shape=output_shape,
138 | # strides=stride,
139 | # padding='VALID',
140 | # data_format=data_format)
141 | ## JAX equivalent
142 |
143 | return upfirdn2d(x, torch.tensor(k, device=x.device),
144 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
145 |
146 |
147 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
148 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
149 |
150 | Padding is performed only once at the beginning, not between the operations.
151 | The fused op is considerably more efficient than performing the same
152 | calculation
153 | using standard PyTorch ops. It supports gradients of arbitrary order.
154 | Args:
155 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
156 | C]`.
157 | w: Weight tensor of the shape `[filterH, filterW, inChannels,
158 | outChannels]`. Grouped convolution can be performed by `inChannels =
159 | x.shape[0] // numGroups`.
160 | k: FIR filter of the shape `[firH, firW]` or `[firN]`
161 | (separable). The default is `[1] * factor`, which corresponds to
162 | average pooling.
163 | factor: Integer downsampling factor (default: 2).
164 | gain: Scaling factor for signal magnitude (default: 1.0).
165 |
166 | Returns:
167 | Tensor of the shape `[N, C, H // factor, W // factor]` or
168 | `[N, H // factor, W // factor, C]`, and same datatype as `x`.
169 | """
170 |
171 | assert isinstance(factor, int) and factor >= 1
172 | _outC, _inC, convH, convW = w.shape
173 | assert convW == convH
174 | if k is None:
175 | k = [1] * factor
176 | k = _setup_kernel(k) * gain
177 | p = (k.shape[0] - factor) + (convW - 1)
178 | s = [factor, factor]
179 | x = upfirdn2d(x, torch.tensor(k, device=x.device),
180 | pad=((p + 1) // 2, p // 2))
181 | return F.conv2d(x, w, stride=s, padding=0)
182 |
183 |
184 | def _setup_kernel(k):
185 | k = np.asarray(k, dtype=np.float32)
186 | if k.ndim == 1:
187 | k = np.outer(k, k)
188 | k /= np.sum(k)
189 | assert k.ndim == 2
190 | assert k.shape[0] == k.shape[1]
191 | return k
192 |
193 |
194 | def _shape(x, dim):
195 | return x.shape[dim]
196 |
197 |
198 | def upsample_2d(x, k=None, factor=2, gain=1):
199 | r"""Upsample a batch of 2D images with the given filter.
200 |
201 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
202 | and upsamples each image with the given filter.
The filter is normalized so 203 | that 204 | if the input pixels are constant, they will be scaled by the specified 205 | `gain`. 206 | Pixels outside the image are assumed to be zero, and the filter is padded 207 | with 208 | zeros so that its shape is a multiple of the upsampling factor. 209 | Args: 210 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 211 | C]`. 212 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 213 | (separable). The default is `[1] * factor`, which corresponds to 214 | nearest-neighbor upsampling. 215 | factor: Integer upsampling factor (default: 2). 216 | gain: Scaling factor for signal magnitude (default: 1.0). 217 | 218 | Returns: 219 | Tensor of the shape `[N, C, H * factor, W * factor]` 220 | """ 221 | assert isinstance(factor, int) and factor >= 1 222 | if k is None: 223 | k = [1] * factor 224 | k = _setup_kernel(k) * (gain * (factor ** 2)) 225 | p = k.shape[0] - factor 226 | return upfirdn2d(x, torch.tensor(k, device=x.device), 227 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 228 | 229 | 230 | def downsample_2d(x, k=None, factor=2, gain=1): 231 | r"""Downsample a batch of 2D images with the given filter. 232 | 233 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 234 | and downsamples each image with the given filter. The filter is normalized 235 | so that 236 | if the input pixels are constant, they will be scaled by the specified 237 | `gain`. 238 | Pixels outside the image are assumed to be zero, and the filter is padded 239 | with 240 | zeros so that its shape is a multiple of the downsampling factor. 241 | Args: 242 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 243 | C]`. 244 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 245 | (separable). The default is `[1] * factor`, which corresponds to 246 | average pooling. 247 | factor: Integer downsampling factor (default: 2). 248 | gain: Scaling factor for signal magnitude (default: 1.0). 249 | 250 | Returns: 251 | Tensor of the shape `[N, C, H // factor, W // factor]` 252 | """ 253 | 254 | assert isinstance(factor, int) and factor >= 1 255 | if k is None: 256 | k = [1] * factor 257 | k = _setup_kernel(k) * gain 258 | p = k.shape[0] - factor 259 | return upfirdn2d(x, torch.tensor(k, device=x.device), 260 | down=factor, pad=((p + 1) // 2, p // 2)) 261 | -------------------------------------------------------------------------------- /bridge/models/ddpmpp/utils.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/yang-song/score_sde_pytorch 2 | 3 | 4 | # coding=utf-8 5 | # Copyright 2020 The Google Research Authors. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | """All functions and modules related to model definition. 
20 | """ 21 | 22 | import torch 23 | # import sde_lib 24 | import numpy as np 25 | 26 | 27 | _MODELS = {} 28 | 29 | 30 | def register_model(cls=None, *, name=None): 31 | """A decorator for registering model classes.""" 32 | 33 | def _register(cls): 34 | if name is None: 35 | local_name = cls.__name__ 36 | else: 37 | local_name = name 38 | if local_name in _MODELS: 39 | raise ValueError(f'Already registered model with name: {local_name}') 40 | _MODELS[local_name] = cls 41 | return cls 42 | 43 | if cls is None: 44 | return _register 45 | else: 46 | return _register(cls) 47 | 48 | 49 | def get_model(name): 50 | return _MODELS[name] 51 | 52 | 53 | def get_sigmas(config): 54 | """Get sigmas --- the set of noise levels for SMLD from config files. 55 | Args: 56 | config: A ConfigDict object parsed from the config file 57 | Returns: 58 | sigmas: a jax numpy arrary of noise levels 59 | """ 60 | sigmas = np.exp( 61 | np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales)) 62 | 63 | return sigmas 64 | 65 | 66 | def get_ddpm_params(config): 67 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 68 | num_diffusion_timesteps = 1000 69 | # parameters need to be adapted if number of time steps differs from 1000 70 | beta_start = config.model.beta_min / config.model.num_scales 71 | beta_end = config.model.beta_max / config.model.num_scales 72 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 73 | 74 | alphas = 1. - betas 75 | alphas_cumprod = np.cumprod(alphas, axis=0) 76 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 77 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 78 | 79 | return { 80 | 'betas': betas, 81 | 'alphas': alphas, 82 | 'alphas_cumprod': alphas_cumprod, 83 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 84 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 85 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 86 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 87 | 'num_diffusion_timesteps': num_diffusion_timesteps 88 | } 89 | 90 | 91 | def create_model(config): 92 | """Create the score model.""" 93 | model_name = config.model.name 94 | score_model = get_model(model_name)(config) 95 | score_model = score_model.to(config.device) 96 | score_model = torch.nn.DataParallel(score_model) 97 | return score_model 98 | 99 | 100 | def get_model_fn(model, train=False): 101 | """Create a function to give the output of the score-based model. 102 | 103 | Args: 104 | model: The score model. 105 | train: `True` for training and `False` for evaluation. 106 | 107 | Returns: 108 | A model function. 109 | """ 110 | 111 | def model_fn(x, labels): 112 | """Compute the output of the score-based model. 113 | 114 | Args: 115 | x: A mini-batch of input data. 116 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 117 | for different models. 118 | 119 | Returns: 120 | A tuple of (model output, new mutable states) 121 | """ 122 | if not train: 123 | model.eval() 124 | return model(x, labels) 125 | else: 126 | model.train() 127 | return model(x, labels) 128 | 129 | return model_fn 130 | 131 | 132 | def get_score_fn(sde, model, train=False, continuous=False): 133 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 134 | 135 | Args: 136 | sde: An `sde_lib.SDE` object that represents the forward SDE. 137 | model: A score model. 138 | train: `True` for training and `False` for evaluation. 
139 | continuous: If `True`, the score-based model is expected to directly take continuous time steps.
140 |
141 | Returns:
142 | A score function.
143 | """
144 | raise NotImplementedError
145 | model_fn = get_model_fn(model, train=train)
146 |
147 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
148 | def score_fn(x, t):
149 | # Scale neural network output by standard deviation and flip sign
150 | if continuous or isinstance(sde, sde_lib.subVPSDE):
151 | # For VP-trained models, t=0 corresponds to the lowest noise level
152 | # The maximum value of the time embedding is assumed to be 999 for
153 | # continuously-trained models.
154 | labels = t * 999
155 | score = model_fn(x, labels)
156 | std = sde.marginal_prob(torch.zeros_like(x), t)[1]
157 | else:
158 | # For VP-trained models, t=0 corresponds to the lowest noise level
159 | labels = t * (sde.N - 1)
160 | score = model_fn(x, labels)
161 | std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
162 |
163 | score = -score / std[:, None, None, None]
164 | return score
165 |
166 | elif isinstance(sde, sde_lib.VESDE):
167 | def score_fn(x, t):
168 | if continuous:
169 | labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
170 | else:
171 | # For VE-trained models, t=0 corresponds to the highest noise level
172 | labels = sde.T - t
173 | labels *= sde.N - 1
174 | labels = torch.round(labels).long()
175 |
176 | score = model_fn(x, labels)
177 | return score
178 |
179 | else:
180 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
181 |
182 | return score_fn
183 |
184 |
185 | def to_flattened_numpy(x):
186 | """Flatten a torch tensor `x` and convert it to numpy."""
187 | return x.detach().cpu().numpy().reshape((-1,))
188 |
189 |
190 | def from_flattened_numpy(x, shape):
191 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
192 | return torch.from_numpy(x.reshape(shape))
-------------------------------------------------------------------------------- /bridge/models/unet/__init__.py: --------------------------------------------------------------------------------
1 | from .unet import *
2 |
3 |
-------------------------------------------------------------------------------- /bridge/models/unet/fp16_util.py: --------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 |
9 | def convert_module_to_f16(l):
10 | """
11 | Convert primitive modules to float16.
12 | """
13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14 | l.weight.data = l.weight.data.half()
15 | l.bias.data = l.bias.data.half()
16 |
17 |
18 | def convert_module_to_f32(l):
19 | """
20 | Convert primitive modules to float32, undoing convert_module_to_f16().
21 | """
22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23 | l.weight.data = l.weight.data.float()
24 | l.bias.data = l.bias.data.float()
25 |
26 |
27 | def make_master_params(model_params):
28 | """
29 | Copy model parameters into a (differently-shaped) list of full-precision
30 | parameters.
31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() -------------------------------------------------------------------------------- /bridge/models/unet/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from abc import abstractmethod 4 | import torch as th 5 | import torch.nn as nn 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | class GroupNorm32(nn.GroupNorm): 11 | def forward(self, x): 12 | return super().forward(x.float()).type(x.dtype) 13 | 14 | 15 | def conv_nd(dims, *args, **kwargs): 16 | """ 17 | Create a 1D, 2D, or 3D convolution module. 18 | """ 19 | if dims == 1: 20 | return nn.Conv1d(*args, **kwargs) 21 | elif dims == 2: 22 | return nn.Conv2d(*args, **kwargs) 23 | elif dims == 3: 24 | return nn.Conv3d(*args, **kwargs) 25 | raise ValueError(f"unsupported dimensions: {dims}") 26 | 27 | 28 | def linear(*args, **kwargs): 29 | """ 30 | Create a linear module. 31 | """ 32 | return nn.Linear(*args, **kwargs) 33 | 34 | 35 | def avg_pool_nd(dims, *args, **kwargs): 36 | """ 37 | Create a 1D, 2D, or 3D average pooling module. 38 | """ 39 | if dims == 1: 40 | return nn.AvgPool1d(*args, **kwargs) 41 | elif dims == 2: 42 | return nn.AvgPool2d(*args, **kwargs) 43 | elif dims == 3: 44 | return nn.AvgPool3d(*args, **kwargs) 45 | raise ValueError(f"unsupported dimensions: {dims}") 46 | 47 | 48 | def update_ema(target_params, source_params, rate=0.99): 49 | """ 50 | Update target parameters to be closer to those of source parameters using 51 | an exponential moving average. 52 | :param target_params: the target parameter sequence. 53 | :param source_params: the source parameter sequence. 54 | :param rate: the EMA rate (closer to 1 means slower). 55 | """ 56 | for targ, src in zip(target_params, source_params): 57 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 58 | 59 | 60 | def zero_module(module, active=True): 61 | """ 62 | Zero out the parameters of a module and return it. 
63 | """ 64 | if active: 65 | for p in module.parameters(): 66 | p.detach().zero_() 67 | return module 68 | 69 | 70 | def scale_module(module, scale): 71 | """ 72 | Scale the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().mul_(scale) 76 | return module 77 | 78 | 79 | def mean_flat(tensor): 80 | """ 81 | Take the mean over all non-batch dimensions. 82 | """ 83 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 84 | 85 | 86 | def normalization(channels, num_groups=32): 87 | """ 88 | Make a standard normalization layer. 89 | :param channels: number of input channels. 90 | :return: an nn.Module for normalization. 91 | """ 92 | return GroupNorm32(num_groups, channels) 93 | # return nn.GroupNorm(32, channels) 94 | 95 | def normalization_act(channels, num_groups=32): 96 | return nn.Sequential(GroupNorm32(num_groups, channels), nn.SiLU(inplace=True)) 97 | 98 | def timestep_embedding(timesteps, dim, max_period=10000): 99 | """ 100 | Create sinusoidal timestep embeddings. 101 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an [N x dim] Tensor of positional embeddings. 106 | """ 107 | half = dim // 2 108 | freqs = th.exp( 109 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / (half - 1) 110 | ).to(device=timesteps.device) 111 | args = timesteps[:, None].float() * freqs[None] 112 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 113 | if dim % 2: 114 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 115 | return embedding 116 | 117 | 118 | def checkpoint(func, inputs, params, flag): 119 | """ 120 | Evaluate a function without caching intermediate activations, allowing for 121 | reduced memory at the expense of extra compute in the backward pass. 122 | :param func: the function to evaluate. 123 | :param inputs: the argument sequence to pass to `func`. 124 | :param params: a sequence of parameters `func` depends on but does not 125 | explicitly take as arguments. 126 | :param flag: if False, disable gradient checkpointing. 127 | """ 128 | if flag: 129 | args = tuple(inputs) + tuple(params) 130 | return CheckpointFunction.apply(func, len(inputs), *args) 131 | else: 132 | return func(*inputs) 133 | 134 | 135 | class CheckpointFunction(th.autograd.Function): 136 | @staticmethod 137 | def forward(ctx, run_function, length, *args): 138 | ctx.run_function = run_function 139 | ctx.input_tensors = list(args[:length]) 140 | ctx.input_params = list(args[length:]) 141 | with th.no_grad(): 142 | output_tensors = ctx.run_function(*ctx.input_tensors) 143 | return output_tensors 144 | 145 | @staticmethod 146 | def backward(ctx, *output_grads): 147 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 148 | with th.enable_grad(): 149 | # Fixes a bug where the first op in run_function modifies the 150 | # Tensor storage in place, which is not allowed for detach()'d 151 | # Tensors. 
152 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 153 | output_tensors = ctx.run_function(*shallow_copies) 154 | input_grads = th.autograd.grad( 155 | output_tensors, 156 | ctx.input_tensors + ctx.input_params, 157 | output_grads, 158 | allow_unused=True, 159 | ) 160 | del ctx.input_tensors 161 | del ctx.input_params 162 | del output_tensors 163 | return (None, None) + input_grads 164 | 165 | 166 | class TimestepBlock(nn.Module): 167 | """ 168 | Any module where forward() takes timestep embeddings as a second argument. 169 | """ 170 | 171 | @abstractmethod 172 | def forward(self, x, emb): 173 | """ 174 | Apply the module to `x` given `emb` timestep embeddings. 175 | """ 176 | 177 | 178 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 179 | """ 180 | A sequential module that passes timestep embeddings to the children that 181 | support it as an extra input. 182 | """ 183 | 184 | def forward(self, x, emb): 185 | for layer in self: 186 | if isinstance(layer, TimestepBlock): 187 | x = layer(x, emb) 188 | else: 189 | x = layer(x) 190 | return x 191 | 192 | 193 | class Upsample(nn.Module): 194 | """ 195 | An upsampling layer with an optional convolution. 196 | :param channels: channels in the inputs and outputs. 197 | :param use_conv: a bool determining if a convolution is applied. 198 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 199 | upsampling occurs in the inner-two dimensions. 200 | """ 201 | 202 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 203 | super().__init__() 204 | self.channels = channels 205 | self.out_channels = out_channels or channels 206 | self.use_conv = use_conv 207 | self.dims = dims 208 | if use_conv: 209 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) 210 | 211 | def forward(self, x): 212 | assert x.shape[1] == self.channels 213 | if self.dims == 3: 214 | x = F.interpolate( 215 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 216 | ) 217 | else: 218 | x = F.interpolate(x, scale_factor=2, mode="nearest") 219 | if self.use_conv: 220 | x = self.conv(x) 221 | return x 222 | 223 | 224 | class Downsample(nn.Module): 225 | """ 226 | A downsampling layer with an optional convolution. 227 | :param channels: channels in the inputs and outputs. 228 | :param use_conv: a bool determining if a convolution is applied. 229 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 230 | downsampling occurs in the inner-two dimensions. 231 | """ 232 | 233 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 234 | super().__init__() 235 | self.channels = channels 236 | self.out_channels = out_channels or channels 237 | self.use_conv = use_conv 238 | self.dims = dims 239 | stride = 2 if dims != 3 else (1, 2, 2) 240 | if use_conv: 241 | self.op = conv_nd( 242 | dims, self.channels, self.out_channels, 3, stride=stride, padding=1 243 | ) 244 | else: 245 | assert self.channels == self.out_channels 246 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 247 | 248 | def forward(self, x): 249 | assert x.shape[1] == self.channels 250 | return self.op(x) 251 | 252 | 253 | class ResBlock(TimestepBlock): 254 | """ 255 | A residual block that can optionally change the number of channels. 256 | :param channels: the number of input channels. 257 | :param emb_channels: the number of timestep embedding channels. 258 | :param dropout: the rate of dropout. 259 | :param out_channels: if specified, the number of out channels. 
260 | :param use_conv: if True and out_channels is specified, use a spatial 261 | convolution instead of a smaller 1x1 convolution to change the 262 | channels in the skip connection. 263 | :param dims: determines if the signal is 1D, 2D, or 3D. 264 | :param use_checkpoint: if True, use gradient checkpointing on this module. 265 | :param up: if True, use this block for upsampling. 266 | :param down: if True, use this block for downsampling. 267 | """ 268 | 269 | def __init__( 270 | self, 271 | channels, 272 | emb_channels, 273 | dropout, 274 | out_channels=None, 275 | use_conv=False, 276 | use_scale_shift_norm=False, 277 | dims=2, 278 | use_checkpoint=False, 279 | up=False, 280 | down=False, 281 | num_groups=32 282 | ): 283 | super().__init__() 284 | self.channels = channels 285 | self.emb_channels = emb_channels 286 | self.dropout = dropout 287 | self.out_channels = out_channels or channels 288 | self.use_conv = use_conv 289 | self.use_checkpoint = use_checkpoint 290 | self.use_scale_shift_norm = use_scale_shift_norm 291 | 292 | self.in_layers = nn.Sequential( 293 | normalization(channels, num_groups), 294 | nn.SiLU(inplace=True), 295 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 296 | ) 297 | 298 | self.updown = up or down 299 | 300 | if up: 301 | self.h_upd = Upsample(channels, False, dims) 302 | self.x_upd = Upsample(channels, False, dims) 303 | elif down: 304 | self.h_upd = Downsample(channels, False, dims) 305 | self.x_upd = Downsample(channels, False, dims) 306 | else: 307 | self.h_upd = self.x_upd = nn.Identity() 308 | 309 | self.emb_layers = nn.Sequential( 310 | nn.SiLU(), 311 | linear( 312 | emb_channels, 313 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 314 | ), 315 | ) 316 | self.out_layers = nn.Sequential( 317 | normalization(self.out_channels, num_groups), 318 | nn.SiLU(inplace=True), 319 | nn.Dropout(p=dropout, inplace=True), 320 | zero_module( 321 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 322 | ), 323 | ) 324 | 325 | if self.out_channels == channels: 326 | self.skip_connection = nn.Identity() 327 | elif use_conv: 328 | self.skip_connection = conv_nd( 329 | dims, channels, self.out_channels, 3, padding=1 330 | ) 331 | else: 332 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 333 | 334 | def forward(self, x, emb): 335 | """ 336 | Apply the block to a Tensor, conditioned on a timestep embedding. 337 | :param x: an [N x C x ...] Tensor of features. 338 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 339 | :return: an [N x C x ...] Tensor of outputs. 
340 | """ 341 | return checkpoint( 342 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 343 | ) 344 | 345 | def _forward(self, x, emb): 346 | if self.updown: 347 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 348 | h = in_rest(x) 349 | h = self.h_upd(h) 350 | x = self.x_upd(x) 351 | h = in_conv(h) 352 | else: 353 | h = self.in_layers(x) 354 | emb_out = self.emb_layers(emb).type(h.dtype) 355 | while len(emb_out.shape) < len(h.shape): 356 | emb_out = emb_out[..., None] 357 | if self.use_scale_shift_norm: 358 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 359 | scale, shift = th.chunk(emb_out, 2, dim=1) 360 | h = out_norm(h) * (1 + scale) + shift 361 | h = out_rest(h) 362 | else: 363 | h = h + emb_out 364 | h = self.out_layers(h) 365 | return self.skip_connection(x) + h 366 | 367 | 368 | class AttentionBlock(nn.Module): 369 | """ 370 | An attention block that allows spatial positions to attend to each other. 371 | Originally ported from here, but adapted to the N-d case. 372 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 373 | """ 374 | 375 | def __init__(self, channels, num_heads=1, use_checkpoint=False, num_groups=32): 376 | super().__init__() 377 | self.channels = channels 378 | self.num_heads = num_heads 379 | self.use_checkpoint = use_checkpoint 380 | 381 | self.norm = normalization(channels, num_groups) 382 | self.qkv = conv_nd(1, channels, channels * 3, 1) 383 | self.attention = QKVAttention() 384 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 385 | 386 | def forward(self, x): 387 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 388 | 389 | def _forward(self, x): 390 | b, c, *spatial = x.shape 391 | x = x.reshape(b, c, -1) 392 | qkv = self.qkv(self.norm(x)) 393 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 394 | h = self.attention(qkv) 395 | h = h.reshape(b, -1, h.shape[-1]) 396 | h = self.proj_out(h) 397 | return (x + h).reshape(b, c, *spatial) 398 | 399 | 400 | class QKVAttention(nn.Module): 401 | """ 402 | A module which performs QKV attention. 403 | """ 404 | 405 | def forward(self, qkv): 406 | """ 407 | Apply QKV attention. 408 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 409 | :return: an [N x C x T] tensor after attention. 410 | """ 411 | ch = qkv.shape[1] // 3 412 | q, k, v = th.split(qkv, ch, dim=1) 413 | scale = 1 / math.sqrt(math.sqrt(ch)) 414 | weight = th.einsum( 415 | "bct,bcs->bts", q * scale, k * scale 416 | ) # More stable with f16 than dividing afterwards 417 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 418 | return th.einsum("bts,bcs->bct", weight, v) 419 | 420 | @staticmethod 421 | def count_flops(model, _x, y): 422 | """ 423 | A counter for the `thop` package to count the operations in an 424 | attention operation. 425 | Meant to be used like: 426 | macs, params = thop.profile( 427 | model, 428 | inputs=(inputs, timestamps), 429 | custom_ops={QKVAttention: QKVAttention.count_flops}, 430 | ) 431 | """ 432 | b, c, *spatial = y[0].shape 433 | num_spatial = int(np.prod(spatial)) 434 | # We perform two matmuls with the same number of ops. 435 | # The first computes the weight matrix, the second computes 436 | # the combination of the value vectors. 
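        # With T = num_spatial tokens and C = c channels, each matmul is a
        # batched [T x C] @ [C x T] (or [T x T] @ [T x C]) product over the b
        # batch entries, i.e. b * T**2 * C multiply-accumulates each, which is
        # where the factor of 2 below comes from.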
437 | matmul_ops = 2 * b * (num_spatial ** 2) * c 438 | model.total_ops += th.DoubleTensor([matmul_ops]) 439 | 440 | 441 | class BasicResBlock(nn.Module): 442 | """ 443 | A residual block that can optionally change the number of channels. 444 | :param channels: the number of input channels. 445 | :param dropout: the rate of dropout. 446 | :param out_channels: if specified, the number of out channels. 447 | :param use_conv: if True and out_channels is specified, use a spatial 448 | convolution instead of a smaller 1x1 convolution to change the 449 | channels in the skip connection. 450 | :param dims: determines if the signal is 1D, 2D, or 3D. 451 | :param use_checkpoint: if True, use gradient checkpointing on this module. 452 | :param up: if True, use this block for upsampling. 453 | :param down: if True, use this block for downsampling. 454 | """ 455 | 456 | def __init__( 457 | self, 458 | channels, 459 | dropout, 460 | out_channels=None, 461 | use_conv=False, 462 | dims=2, 463 | use_checkpoint=False, 464 | up=False, 465 | down=False, 466 | ): 467 | super().__init__() 468 | self.channels = channels 469 | self.dropout = dropout 470 | self.out_channels = out_channels or channels 471 | self.use_conv = use_conv 472 | self.use_checkpoint = use_checkpoint 473 | 474 | self.in_layers = nn.Sequential( 475 | normalization(channels), 476 | nn.SiLU(inplace=True), 477 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 478 | ) 479 | 480 | self.updown = up or down 481 | 482 | if up: 483 | self.h_upd = Upsample(channels, False, dims) 484 | self.x_upd = Upsample(channels, False, dims) 485 | elif down: 486 | self.h_upd = Downsample(channels, False, dims) 487 | self.x_upd = Downsample(channels, False, dims) 488 | else: 489 | self.h_upd = self.x_upd = nn.Identity() 490 | 491 | self.out_layers = nn.Sequential( 492 | normalization(self.out_channels), 493 | nn.SiLU(inplace=True), 494 | nn.Dropout(p=dropout, inplace=True), 495 | zero_module( 496 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 497 | ), 498 | ) 499 | 500 | if self.out_channels == channels: 501 | self.skip_connection = nn.Identity() 502 | elif use_conv: 503 | self.skip_connection = conv_nd( 504 | dims, channels, self.out_channels, 3, padding=1 505 | ) 506 | else: 507 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 508 | 509 | def forward(self, x): 510 | """ 511 | Apply the block to a Tensor. 512 | :param x: an [N x C x ...] Tensor of features. 513 | :return: an [N x C x ...] Tensor of outputs. 
514 | """ 515 | return checkpoint( 516 | self._forward, (x, ), self.parameters(), self.use_checkpoint 517 | ) 518 | 519 | def _forward(self, x): 520 | if self.updown: 521 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 522 | h = in_rest(x) 523 | h = self.h_upd(h) 524 | x = self.x_upd(x) 525 | h = in_conv(h) 526 | else: 527 | h = self.in_layers(x) 528 | 529 | h = self.out_layers(h) 530 | return self.skip_connection(x) + h 531 | 532 | def expand_dims(t, target_len): 533 | assert target_len >= len(t.shape) 534 | out = t[(..., ) + (None, ) * (target_len - len(t.shape))] 535 | return out -------------------------------------------------------------------------------- /bridge/models/unet/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .fp16_util import convert_module_to_f16, convert_module_to_f32 11 | from .layers import * 12 | 13 | 14 | class UNetModel(nn.Module): 15 | """ 16 | The full UNet model with attention and timestep embedding. 17 | :param in_channels: channels in the input Tensor. 18 | :param model_channels: base channel count for the model. 19 | :param out_channels: channels in the output Tensor. 20 | :param num_res_blocks: number of residual blocks per downsample. 21 | :param attention_resolutions: a collection of downsample rates at which 22 | attention will take place. May be a set, list, or tuple. 23 | For example, if this contains 4, then at 4x downsampling, attention 24 | will be used. 25 | :param dropout: the dropout probability. 26 | :param channel_mult: channel multiplier for each level of the UNet. 27 | :param conv_resample: if True, use learned convolutions for upsampling and 28 | downsampling. 29 | :param dims: determines if the signal is 1D, 2D, or 3D. 30 | :param num_classes: if specified (as an int), then this model will be 31 | class-conditional with `num_classes` classes. 32 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 33 | :param num_heads: the number of attention heads in each attention layer. 34 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 35 | :param resblock_updown: use residual blocks for up/downsampling. 
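    :param temb_scale: multiplier applied to the raw timesteps before the
        sinusoidal embedding is computed (see `forward`).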
36 | """ 37 | 38 | def __init__( 39 | self, 40 | in_channels, 41 | model_channels, 42 | out_channels, 43 | num_res_blocks, 44 | attention_resolutions, 45 | dropout=0, 46 | channel_mult=(1, 2, 4, 8), 47 | conv_resample=True, 48 | dims=2, 49 | num_classes=None, 50 | use_checkpoint=False, 51 | num_heads=1, 52 | use_scale_shift_norm=False, 53 | resblock_updown=False, 54 | temb_scale=1 55 | ): 56 | super().__init__() 57 | 58 | self.locals = [ in_channels, 59 | model_channels, 60 | out_channels, 61 | num_res_blocks, 62 | attention_resolutions, 63 | dropout, 64 | channel_mult, 65 | conv_resample, 66 | dims, 67 | num_classes, 68 | use_checkpoint, 69 | num_heads, 70 | use_scale_shift_norm, 71 | resblock_updown, 72 | temb_scale 73 | ] 74 | self.in_channels = in_channels 75 | self.model_channels = model_channels 76 | self.out_channels = out_channels 77 | self.num_res_blocks = num_res_blocks 78 | self.attention_resolutions = attention_resolutions 79 | self.dropout = dropout 80 | self.channel_mult = channel_mult 81 | self.conv_resample = conv_resample 82 | self.num_classes = num_classes 83 | self.use_checkpoint = use_checkpoint 84 | self.num_heads = num_heads 85 | self.temb_scale = temb_scale 86 | 87 | # some hacky logic to allow small unets 88 | if self.model_channels <= 32: 89 | self.num_groups = 8 90 | else: 91 | self.num_groups = 32 92 | 93 | self.input_ch = int(channel_mult[0] * model_channels) 94 | ch = self.input_ch 95 | 96 | time_embed_dim = self.input_ch * 4 97 | self.time_embed = nn.Sequential( 98 | linear(self.input_ch, time_embed_dim), 99 | nn.SiLU(inplace=True), 100 | linear(time_embed_dim, time_embed_dim), 101 | ) 102 | 103 | if self.num_classes is not None: 104 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 105 | 106 | self.input_blocks = nn.ModuleList( 107 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] 108 | ) 109 | self._feature_size = ch 110 | input_block_chans = [ch] 111 | ds = 1 112 | for level, mult in enumerate(channel_mult): 113 | for _ in range(num_res_blocks): 114 | layers = [ 115 | ResBlock( 116 | ch, 117 | time_embed_dim, 118 | dropout, 119 | out_channels=int(mult * model_channels), 120 | dims=dims, 121 | use_checkpoint=use_checkpoint, 122 | use_scale_shift_norm=use_scale_shift_norm, 123 | num_groups=self.num_groups 124 | ) 125 | ] 126 | ch = int(mult * model_channels) 127 | if ds in attention_resolutions: 128 | layers.append( 129 | AttentionBlock( 130 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_groups=self.num_groups 131 | ) 132 | ) 133 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 134 | self._feature_size += ch 135 | input_block_chans.append(ch) 136 | if level != len(channel_mult) - 1: 137 | out_ch = ch 138 | self.input_blocks.append( 139 | TimestepEmbedSequential( 140 | ResBlock( 141 | ch, 142 | time_embed_dim, 143 | dropout, 144 | out_channels=out_ch, 145 | dims=dims, 146 | use_checkpoint=use_checkpoint, 147 | use_scale_shift_norm=use_scale_shift_norm, 148 | down=True, 149 | num_groups=self.num_groups 150 | ) 151 | if resblock_updown 152 | else Downsample( 153 | ch, conv_resample, dims=dims, out_channels=out_ch 154 | ) 155 | ) 156 | ) 157 | input_block_chans.append(ch) 158 | ds *= 2 159 | self._feature_size += ch 160 | 161 | self.middle_block = TimestepEmbedSequential( 162 | ResBlock( 163 | ch, 164 | time_embed_dim, 165 | dropout, 166 | dims=dims, 167 | use_checkpoint=use_checkpoint, 168 | use_scale_shift_norm=use_scale_shift_norm, 169 | num_groups=self.num_groups 170 | ), 171 | AttentionBlock(ch, 
use_checkpoint=use_checkpoint, num_heads=num_heads, num_groups=self.num_groups), 172 | ResBlock( 173 | ch, 174 | time_embed_dim, 175 | dropout, 176 | dims=dims, 177 | use_checkpoint=use_checkpoint, 178 | use_scale_shift_norm=use_scale_shift_norm, 179 | num_groups=self.num_groups 180 | ), 181 | ) 182 | self._feature_size += ch 183 | 184 | self.output_blocks = nn.ModuleList([]) 185 | for level, mult in list(enumerate(channel_mult))[::-1]: 186 | for i in range(num_res_blocks + 1): 187 | ich = input_block_chans.pop() 188 | layers = [ 189 | ResBlock( 190 | ch + ich, 191 | time_embed_dim, 192 | dropout, 193 | out_channels=int(model_channels * mult), 194 | dims=dims, 195 | use_checkpoint=use_checkpoint, 196 | use_scale_shift_norm=use_scale_shift_norm, 197 | num_groups=self.num_groups 198 | ) 199 | ] 200 | ch = int(model_channels * mult) 201 | if ds in attention_resolutions: 202 | layers.append( 203 | AttentionBlock( 204 | ch, 205 | use_checkpoint=use_checkpoint, 206 | num_heads=num_heads, 207 | num_groups=self.num_groups 208 | ) 209 | ) 210 | if level and i == num_res_blocks: 211 | out_ch = ch 212 | layers.append( 213 | ResBlock( 214 | ch, 215 | time_embed_dim, 216 | dropout, 217 | out_channels=out_ch, 218 | dims=dims, 219 | use_checkpoint=use_checkpoint, 220 | use_scale_shift_norm=use_scale_shift_norm, 221 | up=True, 222 | num_groups=self.num_groups 223 | ) 224 | if resblock_updown 225 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 226 | ) 227 | ds //= 2 228 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 229 | self._feature_size += ch 230 | 231 | self.out = nn.Sequential( 232 | normalization(ch, self.num_groups), 233 | nn.SiLU(inplace=True), 234 | zero_module(conv_nd(dims, self.input_ch, out_channels, 3, padding=1)), 235 | ) 236 | 237 | 238 | def convert_to_fp16(self): 239 | """ 240 | Convert the torso of the model to float16. 241 | """ 242 | self.input_blocks.apply(convert_module_to_f16) 243 | self.middle_block.apply(convert_module_to_f16) 244 | self.output_blocks.apply(convert_module_to_f16) 245 | 246 | def convert_to_fp32(self): 247 | """ 248 | Convert the torso of the model to float32. 249 | """ 250 | self.input_blocks.apply(convert_module_to_f32) 251 | self.middle_block.apply(convert_module_to_f32) 252 | self.output_blocks.apply(convert_module_to_f32) 253 | 254 | 255 | def forward(self, x, y, timesteps): 256 | 257 | """ 258 | Apply the model to an input batch. 259 | :param x: an [N x C x ...] Tensor of inputs. 260 | :param timesteps: a 1-D batch of timesteps. 261 | :param y: an [N] Tensor of labels, if class-conditional. 262 | :return: an [N x C x ...] Tensor of outputs. 263 | """ 264 | timesteps = timesteps.squeeze() 265 | assert (y is not None) == ( 266 | self.num_classes is not None 267 | ), "must specify y if and only if the model is class-conditional" 268 | 269 | hs = [] 270 | emb = self.time_embed(timestep_embedding(timesteps * self.temb_scale + 1, self.input_ch)) 271 | 272 | if self.num_classes is not None: 273 | assert y.shape == (x.shape[0],) 274 | emb = emb + self.label_emb(y) 275 | 276 | h = x # .type(self.dtype) 277 | for module in self.input_blocks: 278 | h = module(h, emb) 279 | hs.append(h) 280 | h = self.middle_block(h, emb) 281 | for module in self.output_blocks: 282 | h = th.cat([h, hs.pop()], dim=1) 283 | h = module(h, emb) 284 | h = h.type(x.dtype) 285 | return self.out(h) 286 | 287 | 288 | class SuperResModel(UNetModel): 289 | """ 290 | A UNetModel that performs super-resolution. 
291 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 292 | """ 293 | 294 | def __init__(self, in_channels, cond_channels, *args, **kwargs): 295 | super().__init__(in_channels + cond_channels, *args, **kwargs) 296 | self.locals[0] = in_channels 297 | self.locals.insert(1, cond_channels) 298 | 299 | def forward(self, x, low_res, timesteps, **kwargs): 300 | _, _, new_height, new_width = x.shape 301 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 302 | x = th.cat([x, upsampled], dim=1) 303 | return super().forward(x, None, timesteps, **kwargs) 304 | 305 | 306 | class DownscalerUNetModel(nn.Module): 307 | def __init__( 308 | self, 309 | in_channels, 310 | cond_channels, 311 | model_channels, 312 | out_channels, 313 | num_res_blocks, 314 | dropout=0, 315 | channel_mult=(1, 2, 4, 8), 316 | dims=2, 317 | temb_scale=1, 318 | mean_bypass=False, 319 | scale_mean_bypass=False, 320 | shift_input=False, 321 | shift_output=False, 322 | **kwargs 323 | ): 324 | super().__init__() 325 | 326 | self.locals = [ in_channels, 327 | cond_channels, 328 | model_channels, 329 | out_channels, 330 | num_res_blocks, 331 | dropout, 332 | channel_mult, 333 | dims, 334 | temb_scale, 335 | mean_bypass, 336 | scale_mean_bypass, 337 | shift_input, 338 | shift_output 339 | ] 340 | 341 | in_channels = in_channels + cond_channels 342 | self.in_channels = in_channels 343 | self.model_channels = model_channels 344 | self.out_channels = out_channels 345 | self.num_res_blocks = num_res_blocks 346 | self.dropout = dropout 347 | self.channel_mult = channel_mult 348 | self.temb_scale = temb_scale 349 | 350 | self.mean_bypass = mean_bypass 351 | self.scale_mean_bypass = scale_mean_bypass 352 | self.shift_input = shift_input 353 | self.shift_output = shift_output 354 | 355 | assert len(channel_mult) == 4 356 | self.input_ch = int(channel_mult[0] * model_channels) 357 | ch = self.input_ch 358 | 359 | embed_dim = time_embed_dim = int(channel_mult[-1] * model_channels) 360 | self.time_embed_dim = time_embed_dim 361 | self.time_embed = nn.Sequential( 362 | linear(time_embed_dim, time_embed_dim), 363 | nn.SiLU(inplace=True), 364 | # linear(time_embed_dim, time_embed_dim), 365 | ) 366 | 367 | if self.mean_bypass: 368 | self.mean_skip_1 = conv_nd(dims, in_channels, embed_dim, 1) # Conv((1, 1), inchannels => embed_dim) 369 | self.mean_skip_2 = conv_nd(dims, embed_dim, embed_dim, 1) # Conv((1, 1), embed_dim => embed_dim) 370 | self.mean_skip_3 = conv_nd(dims, embed_dim, out_channels, 1) # Conv((1, 1), embed_dim => outchannels) 371 | self.mean_dense_1 = linear(embed_dim, embed_dim) 372 | self.mean_dense_2 = linear(embed_dim, embed_dim) 373 | self.mean_gnorm_1 = normalization_act(embed_dim, 32) # GroupNorm(embed_dim, 32, swish) 374 | self.mean_gnorm_2 = normalization_act(embed_dim, 32) # GroupNorm(embed_dim, 32, swish) 375 | 376 | self.conv1 = conv_nd(dims, in_channels, ch, 3, padding=1) # 3 -> 32 # Conv((3, 3), inchannels => channels[1], stride=1, pad=SamePad()) 377 | self.dense1 = linear(time_embed_dim, ch) # Dense(embed_dim, channels[1]), 378 | self.gnorm1 = normalization_act(ch, 4) # GroupNorm(channels[1], 4, swish), 379 | 380 | # Encoding 381 | out_ch = int(channel_mult[1] * model_channels) 382 | self.conv2 = Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 32 -> 64 383 | self.dense2 = linear(time_embed_dim, out_ch) 384 | self.gnorm2 = normalization_act(out_ch, 32) 385 | 386 | ch = out_ch 387 | out_ch = int(channel_mult[2] * model_channels) 388 | self.conv3 = 
Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 64 -> 128
389 |         self.dense3 = linear(time_embed_dim, out_ch)
390 |         self.gnorm3 = normalization_act(out_ch, 32)
391 | 
392 |         ch = out_ch
393 |         out_ch = int(channel_mult[3] * model_channels)
394 |         self.conv4 = Downsample(ch, use_conv=True, dims=dims, out_channels=out_ch) # 128 -> 256
395 |         self.dense4 = linear(time_embed_dim, out_ch)
396 | 
397 |         self.middle_block = TimestepEmbedSequential(
398 |             *[
399 |                 ResBlock(
400 |                     out_ch,
401 |                     time_embed_dim,
402 |                     dropout,
403 |                     dims=dims,
404 |                     num_groups=min(out_ch//4, 32)
405 |                 ) for _ in range(num_res_blocks)
406 |             ]
407 |         )
408 | 
409 |         # Decoding
410 |         self.gnorm4 = normalization_act(out_ch, 32)
411 |         self.tconv4 = Upsample(out_ch, use_conv=True, dims=dims, out_channels=ch) # 256 -> 128
412 |         self.denset4 = linear(time_embed_dim, ch)
413 |         self.tgnorm4 = normalization_act(ch, 32)
414 | 
415 |         out_ch = ch
416 |         ch = int(channel_mult[1] * model_channels)
417 |         self.tconv3 = Upsample(out_ch*2, use_conv=True, dims=dims, out_channels=ch) # 128 + 128 -> 64
418 |         self.denset3 = linear(time_embed_dim, ch)
419 |         self.tgnorm3 = normalization_act(ch, 32)
420 | 
421 |         out_ch = ch
422 |         ch = int(channel_mult[0] * model_channels)
423 |         self.tconv2 = Upsample(out_ch*2, use_conv=True, dims=dims, out_channels=ch) # 64 + 64 -> 32
424 |         self.denset2 = linear(time_embed_dim, ch)
425 |         self.tgnorm2 = normalization_act(ch, 32)
426 | 
427 |         self.tconv1 = zero_module(conv_nd(dims, self.input_ch*2, out_channels, 3, padding=1))
428 | 
429 | 
430 |     def forward(self, x, y, timesteps):
431 |         timesteps = timesteps.squeeze()
432 |         embed = self.time_embed(timestep_embedding(timesteps * self.temb_scale + 1, self.time_embed_dim))
433 | 
434 |         # Encoder
435 |         if self.shift_input:
436 |             h1 = x - th.mean(x, dim=(-1,-2), keepdim=True) # remove mean of noised variables before input
437 |         else:
438 |             h1 = x
439 | 
440 |         h1 = th.cat([h1, y], dim=1) # concatenate the (possibly mean-shifted) noised channels with the conditioning
441 |         h1 = self.conv1(h1)
442 |         h1 = h1 + expand_dims(self.dense1(embed), len(h1.shape))
443 |         h1 = self.gnorm1(h1)
444 |         h2 = self.conv2(h1)
445 |         h2 = h2 + expand_dims(self.dense2(embed), len(h2.shape))
446 |         h2 = self.gnorm2(h2)
447 |         h3 = self.conv3(h2)
448 |         h3 = h3 + expand_dims(self.dense3(embed), len(h3.shape))
449 |         h3 = self.gnorm3(h3)
450 |         h4 = self.conv4(h3)
451 |         h4 = h4 + expand_dims(self.dense4(embed), len(h4.shape))
452 | 
453 |         # middle
454 |         h = h4
455 |         h = self.middle_block(h, embed)
456 | 
457 |         # Decoder
458 |         h = self.gnorm4(h)
459 |         h = self.tconv4(h)
460 |         h = h + expand_dims(self.denset4(embed), len(h.shape))
461 |         h = self.tgnorm4(h)
462 |         h = self.tconv3(th.cat([h, h3], dim=1))
463 |         h = h + expand_dims(self.denset3(embed), len(h.shape))
464 |         h = self.tgnorm3(h)
465 |         h = self.tconv2(th.cat([h, h2], dim=1))
466 |         h = h + expand_dims(self.denset2(embed), len(h.shape))
467 |         h = self.tgnorm2(h)
468 |         h = self.tconv1(th.cat([h, h1], dim=1))
469 | 
470 |         if self.shift_output:
471 |             h = h - th.mean(h, dim=(-1,-2), keepdim=True) # remove mean after output
472 | 
473 |         # Mean processing of noised variable channels
474 |         if self.mean_bypass:
475 |             hm = self.mean_skip_1(th.mean(th.cat([x, y], dim=1), dim=(-1,-2), keepdim=True))
476 |             hm = hm + expand_dims(self.mean_dense_1(embed), len(hm.shape))
477 |             hm = self.mean_gnorm_1(hm)
478 |             hm = self.mean_skip_2(hm)
479 |             hm = hm + expand_dims(self.mean_dense_2(embed), len(hm.shape))
480 |             hm = self.mean_gnorm_2(hm)
481 |             hm = self.mean_skip_3(hm)
482 |             if self.scale_mean_bypass:
483 |                 scale = np.sqrt(np.prod(x.shape[2:]))
484
| hm = hm / scale
485 |             # Add back the noised channel mean to the noised channel spatial variations
486 |             return h + hm
487 |         else:
488 |             return h
489 | 
--------------------------------------------------------------------------------
/bridge/runners/__init__.py:
--------------------------------------------------------------------------------
1 | from .repeater import repeater
--------------------------------------------------------------------------------
/bridge/runners/config_getters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from omegaconf import OmegaConf
3 | import hydra
4 | from ..models import *
5 | from .plotters import *
6 | import torchvision.datasets
7 | import torchvision.transforms as transforms
8 | import os
9 | from functools import partial
10 | from .logger import CSVLogger, WandbLogger, Logger
11 | from torch.utils.data import DataLoader
12 | from bridge.data.afhq import AFHQ
13 | from bridge.data.downscaler import DownscalerDataset
14 | 
15 | cmp = lambda x: transforms.Compose([*x])
16 | 
17 | def worker_init_fn(worker_id):
18 |     np.random.seed(worker_id)
19 |     torch.manual_seed(worker_id)
20 |     torch.cuda.manual_seed_all(worker_id)
21 | 
22 | 
23 | def get_plotter(runner, args):
24 |     dataset_tag = getattr(args, DATASET)
25 |     if dataset_tag in [DATASET_MNIST, DATASET_EMNIST, DATASET_CIFAR10] or dataset_tag.startswith(DATASET_AFHQ):
26 |         return ImPlotter(runner, args)
27 |     elif dataset_tag in [DATASET_DOWNSCALER_LOW, DATASET_DOWNSCALER_HIGH]:
28 |         return DownscalerPlotter(runner, args)
29 |     else:
30 |         return Plotter(runner, args)
31 | 
32 | 
33 | # Model
34 | # --------------------------------------------------------------------------------
35 | 
36 | MODEL = 'Model'
37 | BASIC_MODEL = 'Basic'
38 | UNET_MODEL = 'UNET'
39 | DOWNSCALER_UNET_MODEL = 'DownscalerUNET'
40 | DDPMPP_MODEL = 'DDPMpp'
41 | 
42 | NAPPROX = 2000
43 | 
44 | 
45 | def get_model(args):
46 |     model_tag = getattr(args, MODEL)
47 | 
48 |     if model_tag == UNET_MODEL:
49 |         image_size = args.data.image_size
50 | 
51 |         if args.model.channel_mult is not None:
52 |             channel_mult = args.model.channel_mult
53 |         else:
54 |             if image_size == 256:
55 |                 channel_mult = (1, 1, 2, 2, 4, 4)
56 |             elif image_size == 160:
57 |                 channel_mult = (1, 2, 2, 4)
58 |             elif image_size == 64:
59 |                 channel_mult = (1, 2, 2, 2)
60 |             elif image_size == 32:
61 |                 channel_mult = (1, 2, 2, 2)
62 |             elif image_size == 28:
63 |                 channel_mult = (0.5, 1, 1)
64 |             else:
65 |                 raise ValueError(f"unsupported image size: {image_size}")
66 | 
67 |         attention_ds = []
68 |         for res in args.model.attention_resolutions.split(","):
69 |             if image_size % int(res) == 0:
70 |                 attention_ds.append(image_size // int(res))
71 | 
72 |         kwargs = {
73 |             "in_channels": args.data.channels,
74 |             "model_channels": args.model.num_channels,
75 |             "out_channels": args.data.channels,
76 |             "num_res_blocks": args.model.num_res_blocks,
77 |             "attention_resolutions": tuple(attention_ds),
78 |             "dropout": args.model.dropout,
79 |             "channel_mult": channel_mult,
80 |             "num_classes": None,
81 |             "use_checkpoint": args.model.use_checkpoint,
82 |             "num_heads": args.model.num_heads,
83 |             "use_scale_shift_norm": args.model.use_scale_shift_norm,
84 |             "resblock_updown": args.model.resblock_updown,
85 |             "temb_scale": args.model.temb_scale
86 |         }
87 | 
88 |         net = UNetModel(**kwargs)
89 | 
90 |     elif model_tag == DOWNSCALER_UNET_MODEL:
91 |         image_size = args.data.image_size
92 |         channel_mult = args.model.channel_mult
93 | 
94 |         kwargs = {
95 |             "in_channels":
args.data.channels, 96 | "cond_channels": args.data.cond_channels, 97 | "model_channels": args.model.num_channels, 98 | "out_channels": args.data.channels, 99 | "num_res_blocks": args.model.num_res_blocks, 100 | "dropout": args.model.dropout, 101 | "channel_mult": channel_mult, 102 | "temb_scale": args.model.temb_scale, 103 | "mean_bypass": args.model.mean_bypass, 104 | "scale_mean_bypass": args.model.scale_mean_bypass, 105 | "shift_input": args.model.shift_input, 106 | "shift_output": args.model.shift_output, 107 | } 108 | 109 | net = DownscalerUNetModel(**kwargs) 110 | 111 | elif model_tag == DDPMPP_MODEL: 112 | # assert args.data.image_size == 512 113 | class Config(): 114 | pass 115 | config = Config() 116 | config.model = Config() 117 | config.data = Config() 118 | config.model.scale_by_sigma = args.model.scale_by_sigma 119 | config.model.normalization = args.model.normalization 120 | config.model.nonlinearity = args.model.nonlinearity 121 | config.model.nf = args.model.nf 122 | config.model.ch_mult = args.model.ch_mult 123 | config.model.num_res_blocks = args.model.num_res_blocks 124 | config.model.attn_resolutions = args.model.attn_resolutions 125 | config.model.dropout = args.model.dropout 126 | config.model.resamp_with_conv = args.model.resamp_with_conv 127 | config.model.conditional = args.model.conditional 128 | config.model.fir = args.model.fir 129 | config.model.fir_kernel = args.model.fir_kernel 130 | config.model.skip_rescale = args.model.skip_rescale 131 | config.model.resblock_type = args.model.resblock_type 132 | config.model.progressive = args.model.progressive 133 | config.model.progressive_input = args.model.progressive_input 134 | config.model.progressive_combine = args.model.progressive_combine 135 | config.model.attention_type = args.model.attention_type 136 | config.model.init_scale = args.model.init_scale 137 | config.model.fourier_scale = args.model.fourier_scale 138 | config.model.conv_size = args.model.conv_size 139 | config.model.embedding_type = args.model.embedding_type 140 | 141 | config.data.image_size = args.data.image_size 142 | config.data.num_channels = args.data.channels 143 | config.data.centered = True # assumes data is within -1, 1 and so the model will do no adjustments to it 144 | 145 | 146 | net = NCSNpp(config) 147 | 148 | 149 | return net 150 | 151 | # Optimizer 152 | # -------------------------------------------------------------------------------- 153 | 154 | def get_optimizer(net, args): 155 | lr = args.lr 156 | optimizer = args.optimizer 157 | if optimizer == 'Adam': 158 | return torch.optim.Adam(net.parameters(), lr=lr) 159 | elif optimizer == 'AdamW': 160 | return torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=args.weight_decay) 161 | 162 | 163 | # Dataset 164 | # -------------------------------------------------------------------------------- 165 | 166 | DATASET = 'Dataset' 167 | DATASET_TRANSFER = 'Dataset_transfer' 168 | DATASET_MNIST = 'mnist' 169 | DATASET_EMNIST = 'emnist' 170 | DATASET_CIFAR10 = 'cifar10' 171 | DATASET_AFHQ = 'afhq' 172 | DATASET_DOWNSCALER_LOW = 'downscaler_low' 173 | DATASET_DOWNSCALER_HIGH = 'downscaler_high' 174 | 175 | def get_datasets(args): 176 | dataset_tag = getattr(args, DATASET) 177 | 178 | # INITIAL (DATA) DATASET 179 | 180 | data_dir = hydra.utils.to_absolute_path(args.paths.data_dir_name) 181 | 182 | # MNIST DATASET 183 | if dataset_tag == DATASET_MNIST: 184 | # data_tag = args.data.dataset 185 | root = os.path.join(data_dir, 'mnist') 186 | load = args.load 187 | assert args.data.channels == 
1 188 | assert args.data.image_size == 28 189 | train_transform = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 190 | init_ds = torchvision.datasets.MNIST(root=root, train=True, transform=cmp(train_transform), download=True) 191 | 192 | # CIFAR10 DATASET 193 | if dataset_tag == DATASET_CIFAR10: 194 | # data_tag = args.data.dataset 195 | root = os.path.join(data_dir, 'cifar10') 196 | load = args.load 197 | assert args.data.channels == 3 198 | assert args.data.image_size == 32 199 | train_transform = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 200 | if args.data.random_flip: 201 | train_transform.insert(0, transforms.RandomHorizontalFlip()) 202 | 203 | init_ds = torchvision.datasets.CIFAR10(root=root, train=True, transform=cmp(train_transform), download=True) 204 | 205 | # AFHQ DATASET 206 | if dataset_tag.startswith(DATASET_AFHQ): 207 | assert args.data.image_size == 512 208 | animal_type = dataset_tag.split('_')[1] 209 | init_ds = AFHQ(root_dir=os.path.join(args.paths.afhq_path, 'train'), animal_type=animal_type) 210 | 211 | # Downscaler dataset 212 | if dataset_tag == DATASET_DOWNSCALER_HIGH: 213 | root = os.path.join(data_dir, 'downscaler') 214 | train_transform = [transforms.Normalize((0.,), (1.,))] 215 | assert not args.data.random_flip 216 | # if args.data.random_flip: 217 | # train_transform = train_transform + [ 218 | # transforms.RandomHorizontalFlip(p=0.5), 219 | # transforms.RandomVerticalFlip(p=0.5), 220 | # transforms.RandomApply([transforms.RandomRotation((90, 90))], p=0.5), 221 | # ] 222 | wavenumber = args.data.get('wavenumber', 0) 223 | split = args.data.get('split', "train") 224 | 225 | init_ds = DownscalerDataset(root=root, resolution=512, wavenumber=wavenumber, split=split, transform=cmp(train_transform)) 226 | 227 | # FINAL DATASET 228 | 229 | final_ds, mean_final, var_final = get_final_dataset(args, init_ds) 230 | return init_ds, final_ds, mean_final, var_final 231 | 232 | 233 | def get_final_dataset(args, init_ds): 234 | if args.transfer: 235 | data_dir = hydra.utils.to_absolute_path(args.paths.data_dir_name) 236 | dataset_transfer_tag = getattr(args, DATASET_TRANSFER) 237 | mean_final = torch.tensor(0.) 
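        # With transfer=True the terminal marginal comes from a dataset, and the
        # Gaussian (mean_final, var_final) acts only as a reference; the large
        # variance on the next line makes that reference essentially flat.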
238 | var_final = torch.tensor(1.*10**3) # infty like 239 | 240 | if dataset_transfer_tag == DATASET_EMNIST: 241 | from ..data.emnist import FiveClassEMNIST 242 | # data_tag = args.data.dataset 243 | root = os.path.join(data_dir, 'emnist') 244 | load = args.load 245 | assert args.data.channels == 1 246 | assert args.data.image_size == 28 247 | train_transform = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 248 | final_ds = FiveClassEMNIST(root=root, train=True, download=True, transform=cmp(train_transform)) 249 | 250 | # AFHQ DATASET 251 | if dataset_transfer_tag.startswith(DATASET_AFHQ): 252 | assert args.data.image_size == 512 253 | animal_type = dataset_transfer_tag.split('_')[1] 254 | final_ds = AFHQ(root_dir=os.path.join(args.paths.afhq_path, 'train'), animal_type=animal_type) 255 | 256 | if dataset_transfer_tag == DATASET_DOWNSCALER_LOW: 257 | root = os.path.join(data_dir, 'downscaler') 258 | train_transform = [transforms.Normalize((0.,), (1.,))] 259 | if args.data.random_flip: 260 | train_transform = train_transform + [ 261 | transforms.RandomHorizontalFlip(p=0.5), 262 | transforms.RandomVerticalFlip(p=0.5), 263 | transforms.RandomApply([transforms.RandomRotation((90, 90))], p=0.5), 264 | ] 265 | 266 | split = args.data.get('split', "train") 267 | 268 | final_ds = DownscalerDataset(root=root, resolution=64, split=split, transform=cmp(train_transform)) 269 | 270 | else: 271 | if args.adaptive_mean: 272 | vec = next(iter(DataLoader(init_ds, batch_size=NAPPROX, num_workers=args.num_workers, worker_init_fn=worker_init_fn)))[0] 273 | mean_final = vec.mean(axis=0) 274 | var_final = eval(args.var_final) if isinstance(args.var_final, str) else torch.tensor([args.var_final]) 275 | elif args.final_adaptive: 276 | vec = next(iter(DataLoader(init_ds, batch_size=NAPPROX, num_workers=args.num_workers, worker_init_fn=worker_init_fn)))[0] 277 | mean_final = vec.mean(axis=0) 278 | var_final = vec.var(axis=0) 279 | else: 280 | mean_final = eval(args.mean_final) if isinstance(args.mean_final, str) else torch.tensor([args.mean_final]) 281 | var_final = eval(args.var_final) if isinstance(args.var_final, str) else torch.tensor([args.var_final]) 282 | final_ds = None 283 | 284 | return final_ds, mean_final, var_final 285 | 286 | 287 | def get_valid_test_datasets(args): 288 | valid_ds, test_ds = None, None 289 | 290 | dataset_tag = getattr(args, DATASET) 291 | data_dir = hydra.utils.to_absolute_path(args.paths.data_dir_name) 292 | 293 | # MNIST DATASET 294 | if dataset_tag == DATASET_MNIST: 295 | # data_tag = args.data.dataset 296 | root = os.path.join(data_dir, 'mnist') 297 | load = args.load 298 | assert args.data.channels == 1 299 | assert args.data.image_size == 28 300 | test_transform = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 301 | valid_ds = None 302 | test_ds = torchvision.datasets.MNIST(root=root, train=False, transform=cmp(test_transform), download=True) 303 | 304 | # # CIFAR10 DATASET 305 | # if dataset_tag == DATASET_CIFAR10: 306 | # # data_tag = args.data.dataset 307 | # root = os.path.join(data_dir, 'cifar10') 308 | # load = args.load 309 | # assert args.data.channels == 3 310 | # assert args.data.image_size == 32 311 | # test_transform = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 312 | # valid_ds = None 313 | # test_ds = torchvision.datasets.CIFAR10(root=root, train=False, transform=cmp(test_transform), download=True) 314 | 315 | return valid_ds, test_ds 316 | 317 | 318 | # Logger 319 | # 
-------------------------------------------------------------------------------- 320 | 321 | LOGGER = 'LOGGER' 322 | CSV_TAG = 'CSV' 323 | WANDB_TAG = 'Wandb' 324 | NOLOG_TAG = 'NONE' 325 | 326 | 327 | def get_logger(args, name): 328 | logger_tag = getattr(args, LOGGER) 329 | 330 | if logger_tag == CSV_TAG: 331 | kwargs = {'save_dir': args.CSV_log_dir, 'name': name, 'flush_logs_every_n_steps': 1} 332 | return CSVLogger(**kwargs) 333 | 334 | if logger_tag == WANDB_TAG: 335 | log_dir = os.getcwd() 336 | if not args.use_default_wandb_name: 337 | run_name = os.path.normpath(os.path.relpath(log_dir, os.path.join( 338 | hydra.utils.to_absolute_path(args.paths.experiments_dir_name), args.name))).replace("\\", "/") 339 | else: 340 | run_name = None 341 | data_tag = args.data.dataset 342 | config = OmegaConf.to_container(args, resolve=True) 343 | 344 | wandb_entity = os.environ['WANDB_ENTITY'] 345 | assert len(wandb_entity) > 0, "WANDB_ENTITY not set" 346 | 347 | kwargs = {'name': run_name, 'project': 'dsbm_' + args.name, 'prefix': name, 'entity': wandb_entity, 348 | 'tags': [data_tag], 'config': config, 'id': str(args.wandb_id) if args.wandb_id is not None else None} 349 | return WandbLogger(**kwargs) 350 | 351 | if logger_tag == NOLOG_TAG: 352 | return Logger() 353 | -------------------------------------------------------------------------------- /bridge/runners/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999, device="cpu"): 6 | self.mu = mu 7 | self.shadow = {} 8 | self.device = device 9 | 10 | def register(self, module): 11 | if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel): 12 | module = module.module 13 | for name, param in module.named_parameters(): 14 | if param.requires_grad: 15 | self.shadow[name] = param.data.clone() 16 | 17 | def update(self, module): 18 | if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel): 19 | module = module.module 20 | for name, param in module.named_parameters(): 21 | if param.requires_grad: 22 | self.shadow[name].data = (1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.distributed.DistributedDataParallel): 33 | inner_module = module.module 34 | locs = inner_module.locals 35 | module_copy = type(inner_module)(*locs).to(self.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | if isinstance(module, nn.DataParallel): 38 | module_copy = nn.DataParallel(module_copy) 39 | else: 40 | locs = module.locals 41 | module_copy = type(module)(*locs).to(self.device) 42 | module_copy.load_state_dict(module.state_dict()) 43 | self.ema(module_copy) 44 | return module_copy 45 | 46 | def state_dict(self): 47 | return self.shadow 48 | 49 | def load_state_dict(self, state_dict): 50 | self.shadow = state_dict 51 | -------------------------------------------------------------------------------- /bridge/runners/logger.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.loggers import CSVLogger as _CSVLogger, WandbLogger as _WandbLogger 2 | import wandb 3 | 4 | class Logger: 5 | def log_metrics(self, metric_dict, step=None): 6 | pass 7 | 8 | def log_hyperparams(self, params): 9 | pass 10 | 11 | def log_image(self, key, images, **kwargs): 12 | pass 13 | 14 | 15 | class CSVLogger(_CSVLogger): 16 | def log_image(self, key, images, **kwargs): 17 | pass 18 | 19 | 20 | class WandbLogger(_WandbLogger): 21 | LOGGER_JOIN_CHAR = '/' 22 | 23 | def log_metrics(self, metrics, step=None, fb=None): 24 | if fb is not None: 25 | metrics.pop('fb', None) 26 | else: 27 | fb = metrics.pop('fb', None) 28 | if fb is not None: 29 | metrics = {fb + '/' + k: v for k, v in metrics.items()} 30 | super().log_metrics(metrics, step=step) 31 | 32 | def log_image(self, key, images, **kwargs): 33 | if not isinstance(images, list): 34 | raise TypeError(f'Expected a list as "images", found {type(images)}') 35 | step = kwargs.pop("step", None) 36 | fb = kwargs.pop("fb", None) 37 | n = len(images) 38 | for k, v in kwargs.items(): 39 | if len(v) != n: 40 | raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") 41 | kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)] 42 | if n == 1: 43 | metrics = {key: wandb.Image(images[0], **kwarg_list[0])} 44 | else: 45 | metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]} 46 | self.log_metrics(metrics, step=step, fb=fb) -------------------------------------------------------------------------------- /bridge/runners/repeater.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | 3 | def repeater(data_loader): 4 | for loader in repeat(data_loader): 5 | for data in loader: 6 | yield data -------------------------------------------------------------------------------- /bridge/sde/__init__.py: -------------------------------------------------------------------------------- 1 | from .discrete_langevin import Langevin 2 | from .diffusion_bridge import DBDSB_VE, DBDSB_VP -------------------------------------------------------------------------------- /bridge/sde/diffusion_bridge.py: 
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from .optimal_transport import OTPlanSampler
5 | 
6 | class DBDSB_VE:
7 |     def __init__(self, sig, num_steps, timesteps, shape_x, shape_y, first_coupling, mean_match=False, ot_sampler=None, eps=1e-4, **kwargs):
8 |         self.device = timesteps.device
9 | 
10 |         self.sig = sig # total sigma from time 0 to T=1
11 |         self.num_steps = num_steps # num diffusion steps
12 |         self.timesteps = timesteps # schedule of timesteps
13 |         assert len(self.timesteps) == self.num_steps
14 |         assert torch.allclose(self.timesteps.sum(), torch.tensor(self.T)) # sum of timesteps is T=1
15 |         assert (self.timesteps > 0).all()
16 |         self.gammas = self.timesteps * self.sig**2 # schedule of variance steps
17 | 
18 |         self.d_x = shape_x # dimension of object to diffuse
19 |         self.d_y = shape_y # dimension of conditioning
20 | 
21 |         self.first_coupling = first_coupling
22 |         self.eps = eps
23 | 
24 |         self.ot_sampler = None
25 |         if ot_sampler is not None:
26 |             self.ot_sampler = OTPlanSampler(ot_sampler, reg=2*self.sig**2)
27 |         self.mean_match = mean_match
28 | 
29 |     @property
30 |     def T(self):
31 |         return 1.
32 | 
33 |     @property
34 |     def alpha(self):
35 |         return 0.
36 | 
37 |     @torch.no_grad()
38 |     def marginal_prob(self, x, t, fb):
39 |         if fb == "f":
40 |             std = self.sig * torch.sqrt(t)
41 |         else:
42 |             std = self.sig * torch.sqrt(self.T - t)
43 |         mean = x
44 |         return mean, std
45 | 
46 |     @torch.no_grad()
47 |     def record_langevin_seq(self, net, samples_x, init_samples_y, fb, sample=False, num_steps=None, **kwargs):
48 |         if fb == 'b':
49 |             gammas = torch.flip(self.gammas, (0,))
50 |             timesteps = torch.flip(self.timesteps, (0,))
51 |             t = torch.ones((samples_x.shape[0], 1), device=self.device)
52 |             sign = -1.
53 |         elif fb == 'f':
54 |             gammas = self.gammas
55 |             timesteps = self.timesteps
56 |             t = torch.zeros((samples_x.shape[0], 1), device=self.device)
57 |             sign = 1.
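        # fb selects the integration direction: 'b' starts from t=T and steps the
        # flipped schedules with sign=-1, 'f' starts from t=0 and steps forward
        # with sign=+1. The asserts after the loop check that t ends at 0 or T.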
58 | 59 | x = samples_x 60 | N = x.shape[0] 61 | 62 | if num_steps is None: 63 | num_steps = self.num_steps 64 | else: 65 | timesteps = np.interp(np.arange(1, num_steps+1)/num_steps, np.arange(self.num_steps+1)/self.num_steps, [0, *np.cumsum(timesteps.cpu())]) 66 | timesteps = torch.from_numpy(np.diff(timesteps, prepend=[0])).to(self.device) 67 | gammas = timesteps * self.sig**2 68 | 69 | x_tot = torch.Tensor(N, num_steps, *self.d_x).to(x.device) 70 | y_tot = None 71 | steps_expanded = torch.Tensor(N, num_steps, 1).to(x.device) 72 | 73 | drift_fn = self.get_drift_fn_pred(fb) 74 | 75 | for k in range(num_steps): 76 | gamma = gammas[k] 77 | timestep = timesteps[k] 78 | 79 | pred = net(x, init_samples_y, t) # Raw prediction of the network 80 | 81 | if sample and (k==num_steps-1) and self.mean_match: 82 | x = pred 83 | else: 84 | drift = drift_fn(t, x, pred) 85 | x = x + drift * timestep 86 | if not (sample and (k==num_steps-1)): 87 | x = x + torch.sqrt(gamma) * torch.randn_like(x) 88 | 89 | x_tot[:, k, :] = x 90 | # y_tot[:, k, :] = y 91 | steps_expanded[:, k, :] = t 92 | t = t + sign * timestep 93 | 94 | if fb == 'b': 95 | assert torch.allclose(t, torch.zeros(1, device=self.device), atol=1e-4, rtol=1e-4), f"{t} != 0" 96 | else: 97 | assert torch.allclose(t, torch.ones(1, device=self.device) * self.T, atol=1e-4, rtol=1e-4), f"{t} != 1" 98 | 99 | return x_tot, y_tot, None, steps_expanded 100 | 101 | @torch.no_grad() 102 | def generate_new_dataset(self, x0, y0, x1, sample_fn, sample_direction, sample=False, num_steps=None): 103 | if sample_direction == 'f': 104 | zstart = x0 105 | else: 106 | zstart = x1 107 | zend = self.record_langevin_seq(sample_fn, zstart, y0, sample_direction, sample=sample, num_steps=num_steps)[0][:, -1] 108 | if sample_direction == 'f': 109 | z0, z1 = zstart, zend 110 | else: 111 | z0, z1 = zend, zstart 112 | return z0, y0, z1 113 | 114 | @torch.no_grad() 115 | def probability_flow_ode(self, net_f=None, net_b=None, y=None): 116 | get_drift_fn_net = self.get_drift_fn_net 117 | 118 | class ODEfunc(nn.Module): 119 | def __init__(self, net_f=None, net_b=None): 120 | super().__init__() 121 | self.net_f = net_f 122 | self.net_b = net_b 123 | self.nfe = 0 124 | if self.net_f is not None: 125 | self.drift_fn_f = get_drift_fn_net(self.net_f, 'f', y=y) 126 | self.drift_fn_b = get_drift_fn_net(self.net_b, 'b', y=y) 127 | 128 | def forward(self, t, x): 129 | self.nfe += 1 130 | t = torch.ones((x.shape[0], 1), device=x.device) * t.item() 131 | if self.net_f is None: 132 | return - self.drift_fn_b(t, x) 133 | return (self.drift_fn_f(t, x) - self.drift_fn_b(t, x)) / 2 134 | 135 | return ODEfunc(net_f=net_f, net_b=net_b) 136 | 137 | @torch.no_grad() 138 | def get_train_tuple(self, x0, x1, fb='', first_it=False): 139 | if first_it and fb == 'b': 140 | z0 = x0 141 | if self.first_coupling == "ref": 142 | # First coupling is x_0, x_0 perturbed 143 | z1 = z0 + torch.randn_like(z0) * self.sig 144 | elif self.first_coupling == "ind": 145 | z1 = x1 146 | else: 147 | raise NotImplementedError 148 | elif first_it and fb == 'f': 149 | assert self.first_coupling == "ind" 150 | z0, z1 = x0, x1 151 | else: 152 | z0, z1 = x0, x1 153 | 154 | if self.ot_sampler is not None: 155 | assert z0.shape == z1.shape 156 | original_shape = z0.shape 157 | z0, z1 = self.ot_sampler.sample_plan(z0.flatten(start_dim=1), z1.flatten(start_dim=1)) 158 | z0, z1 = z0.view(original_shape), z1.view(original_shape) 159 | 160 | t = torch.rand(z1.shape[0], device=self.device) * (1-2*self.eps) + self.eps 161 | t = t[:, None, 
None, None] 162 | z_t = t * z1 + (1.-t) * z0 163 | z = torch.randn_like(z_t) 164 | z_t = z_t + self.sig * torch.sqrt(t*(1.-t)) * z 165 | if self.mean_match: 166 | if fb == 'f': 167 | target = z1 168 | else: 169 | target = z0 170 | else: 171 | if fb == 'f': 172 | # (z1 - z_t) / (1-t) 173 | # target = z1 - z0 174 | # target = target - self.sig * torch.sqrt(t/(1.-t)) * z 175 | # target = self.A_f(t) * z_t + self.M_f(t) * z1 176 | drift_f = self.drift_f(t, z_t, z0, z1) 177 | target = drift_f + self.alpha * z_t 178 | else: 179 | # (z0 - z_t) / t 180 | # target = - (z1 - z0) 181 | # target = target - self.sig * torch.sqrt((1.-t)/t) * z 182 | drift_b = self.drift_b(t, z_t, z0, z1) 183 | target = drift_b - self.alpha * z_t 184 | return z_t, t, target 185 | 186 | def A_f(self, t): 187 | return -1./(self.T-t) 188 | 189 | def M_f(self, t): 190 | return 1./(self.T-t) 191 | 192 | def A_b(self, t): 193 | return -1./t 194 | 195 | def M_b(self, t): 196 | return 1./t 197 | 198 | def drift_f(self, t, x, init, final): 199 | t = t.view(t.shape[0], 1, 1, 1) 200 | return self.A_f(t) * x + self.M_f(t) * final 201 | 202 | def drift_b(self, t, x, init, final): 203 | t = t.view(t.shape[0], 1, 1, 1) 204 | return self.A_b(t) * x + self.M_b(t) * init 205 | 206 | def get_drift_fn_net(self, net, fb, y=None): 207 | drift_fn_pred = self.get_drift_fn_pred(fb) 208 | def drift_fn(t, x): 209 | pred = net(x, y, t) # Raw prediction of the network 210 | return drift_fn_pred(t, x, pred) 211 | return drift_fn 212 | 213 | def get_drift_fn_pred(self, fb): 214 | def drift_fn(t, x, pred): 215 | if self.mean_match: 216 | if fb == 'f': 217 | drift = self.drift_f(t, x, None, pred) 218 | else: 219 | drift = self.drift_b(t, x, pred, None) 220 | else: 221 | if fb == 'f': 222 | drift = pred - self.alpha * x 223 | else: 224 | drift = pred + self.alpha * x 225 | return drift 226 | return drift_fn 227 | 228 | 229 | class DBDSB_VP(DBDSB_VE): 230 | def __init__(self, sig, num_steps, timesteps, shape_x, shape_y, first_coupling, mean_match=False, ot_sampler=None, eps=1e-4, **kwargs): 231 | assert ot_sampler is None 232 | super().__init__(sig, num_steps, timesteps, shape_x, shape_y, first_coupling, mean_match=mean_match, ot_sampler=ot_sampler, eps=eps, **kwargs) 233 | 234 | @property 235 | def alpha(self): 236 | return 0.5 237 | 238 | @torch.no_grad() 239 | def marginal_prob(self, x, t, fb): 240 | if fb == "f": 241 | mean = torch.exp(-0.5 * t) * x 242 | std = self.sig * torch.sqrt(1 - torch.exp(-t)) 243 | else: 244 | raise NotImplementedError 245 | return mean, std 246 | 247 | def A_f(self, t: float) -> float: 248 | return -self.alpha / torch.tanh(self.alpha * (self.T - t)) 249 | 250 | def M_f(self, t: float) -> float: 251 | return self.alpha / torch.sinh(self.alpha * (self.T - t)) 252 | 253 | def A_b(self, t: float) -> float: 254 | return -self.alpha / torch.tanh(self.alpha * t) 255 | 256 | def M_b(self, t: float) -> float: 257 | return self.alpha / torch.sinh(self.alpha * t) 258 | -------------------------------------------------------------------------------- /bridge/sde/discrete_langevin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def grad_gauss(x, m, var): 4 | xout = (m - x) / var 5 | return xout 6 | 7 | # def ornstein_ulhenbeck(x, gradx, gamma): 8 | # xout = x + gamma * gradx + torch.sqrt(2 * gamma) * torch.randn(x.shape, device=x.device) 9 | # return xout 10 | 11 | class Langevin: 12 | 13 | def __init__(self, num_steps, shape_x, shape_y, gammas, time_sampler, 14 | 
mean_final=torch.tensor([0., 0.]), var_final=torch.tensor([.5, .5]), 15 | mean_match=True, out_scale=1, var_final_gamma_scale=False): 16 | self.device = gammas.device 17 | 18 | self.mean_match = mean_match 19 | self.mean_final = mean_final.to(self.device) if mean_final is not None else None 20 | self.var_final = var_final.to(self.device) if var_final is not None else None 21 | 22 | self.num_steps = num_steps # num diffusion steps 23 | self.d_x = shape_x # dimension of object to diffuse 24 | self.d_y = shape_y # dimension of conditioning 25 | self.gammas = gammas # schedule 26 | 27 | self.steps = torch.arange(self.num_steps).to(self.device) 28 | self.time = torch.cumsum(self.gammas, 0).to(self.device) 29 | # self.time_sampler = time_sampler 30 | self.out_scale = out_scale 31 | self.var_final_gamma_scale = var_final_gamma_scale 32 | 33 | 34 | def record_init_langevin(self, init_samples_x, init_samples_y, mean_final=None, var_final=None): 35 | if mean_final is None: 36 | mean_final = self.mean_final 37 | if var_final is None: 38 | var_final = self.var_final 39 | 40 | x = init_samples_x 41 | # y = init_samples_y 42 | N = x.shape[0] 43 | steps = self.steps.reshape((1,self.num_steps,1)).repeat((N,1,1)) 44 | 45 | 46 | x_tot = torch.Tensor(N, self.num_steps, *self.d_x).to(x.device) 47 | # y_tot = torch.Tensor(N, self.num_steps, *self.d_y).to(x.device) 48 | y_tot = None 49 | out = torch.Tensor(N, self.num_steps, *self.d_x).to(x.device) 50 | num_iter = self.num_steps 51 | steps_expanded = steps 52 | 53 | for k in range(num_iter): 54 | gamma = self.gammas[k] 55 | 56 | if self.var_final_gamma_scale: 57 | var_gamma_ratio = 1 / gamma 58 | scaled_gamma = gamma * var_final 59 | else: 60 | var_gamma_ratio = var_final / gamma 61 | scaled_gamma = gamma 62 | 63 | gradx = grad_gauss(x, mean_final, var_gamma_ratio) 64 | t_old = x + gradx / 2 65 | z = torch.randn(x.shape, device=x.device) 66 | x = t_old + torch.sqrt(scaled_gamma)*z 67 | gradx = grad_gauss(x, mean_final, var_gamma_ratio) 68 | t_new = x + gradx / 2 69 | x_tot[:, k, :] = x 70 | # y_tot[:, k, :] = y 71 | if self.mean_match: 72 | out[:, k, :] = (t_old - t_new) #/ (2 * gamma) 73 | else: 74 | out_scale = eval(self.out_scale).to(self.device) if isinstance(self.out_scale, str) else self.out_scale 75 | out[:, k, :] = (t_old - t_new) / out_scale 76 | 77 | return x_tot, y_tot, out, steps_expanded 78 | 79 | def record_langevin_seq(self, net, samples_x, init_samples_y, fb, sample=False, var_final=None): 80 | if var_final is None: 81 | var_final = self.var_final 82 | if fb == 'b': 83 | gammas = torch.flip(self.gammas, (0,)) 84 | elif fb == 'f': 85 | gammas = self.gammas 86 | 87 | x = samples_x 88 | # y = init_samples_y 89 | N = x.shape[0] 90 | steps = self.steps.reshape((1,self.num_steps,1)).repeat((N,1,1)) 91 | 92 | 93 | x_tot = torch.Tensor(N, self.num_steps, *self.d_x).to(x.device) 94 | # y_tot = torch.Tensor(N, self.num_steps, *self.d_y).to(x.device) 95 | y_tot = None 96 | out = torch.Tensor(N, self.num_steps, *self.d_x).to(x.device) 97 | steps_expanded = steps 98 | num_iter = self.num_steps 99 | 100 | if self.mean_match: 101 | for k in range(num_iter): 102 | gamma = gammas[k] 103 | 104 | scaled_gamma = gamma 105 | if self.var_final_gamma_scale: 106 | scaled_gamma = scaled_gamma * var_final 107 | 108 | t_old = net(x, None, steps[:, k, :]) 109 | 110 | if sample & (k==num_iter-1): 111 | x = t_old 112 | else: 113 | z = torch.randn(x.shape, device=x.device) 114 | x = t_old + torch.sqrt(scaled_gamma) * z 115 | 116 | t_new = net(x, None, steps[:, k, :]) 117 | 
x_tot[:, k, :] = x
118 |                 # y_tot[:, k, :] = y
119 |                 out[:, k, :] = (t_old - t_new)
120 |         else:
121 |             for k in range(num_iter):
122 |                 gamma = gammas[k]
123 | 
124 |                 scaled_gamma = gamma
125 |                 if self.var_final_gamma_scale:
126 |                     scaled_gamma = scaled_gamma * var_final
127 |                 out_scale = eval(self.out_scale).to(self.device) if isinstance(self.out_scale, str) else self.out_scale
128 | 
129 |                 t_old = x + out_scale * net(x, None, steps[:, k, :])
130 | 
131 |                 if sample & (k==num_iter-1):
132 |                     x = t_old
133 |                 else:
134 |                     z = torch.randn(x.shape, device=x.device)
135 |                     x = t_old + torch.sqrt(scaled_gamma) * z
136 |                 t_new = x + out_scale * net(x, None, steps[:, k, :])
137 | 
138 |                 x_tot[:, k, :] = x
139 |                 # y_tot[:, k, :] = y
140 |                 out[:, k, :] = (t_old - t_new) / out_scale
141 | 
142 | 
143 |         return x_tot, y_tot, out, steps_expanded
--------------------------------------------------------------------------------
/bridge/sde/optimal_transport.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | from typing import Optional
4 | 
5 | import numpy as np
6 | import ot as pot
7 | import torch
8 | 
9 | 
10 | class OTPlanSampler:
11 |     """OTPlanSampler implements sampling coordinates according to a squared L2 OT plan with
12 |     different implementations of the plan calculation."""
13 | 
14 |     def __init__(
15 |         self,
16 |         method: str,
17 |         reg: float = 0.05,
18 |         reg_m: float = 1.0,
19 |         normalize_cost=False,
20 |         **kwargs,
21 |     ):
22 |         # ot_fn should take (a, b, M) as arguments where a, b are marginals and
23 |         # M is a cost matrix
24 |         if method == "exact":
25 |             self.ot_fn = pot.emd
26 |         elif method == "sinkhorn":
27 |             self.ot_fn = partial(pot.sinkhorn, reg=reg)
28 |         # elif method == "unbalanced":
29 |         #     self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)
30 |         # elif method == "partial":
31 |         #     self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)
32 |         else:
33 |             raise ValueError(f"Unknown method: {method}")
34 |         self.reg = reg
35 |         self.reg_m = reg_m
36 |         self.normalize_cost = normalize_cost
37 |         self.kwargs = kwargs
38 | 
39 |     def get_map(self, x0, x1):
40 |         a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
41 |         M = torch.cdist(x0, x1) ** 2
42 |         if self.normalize_cost:
43 |             M = M / M.max()
44 |         p = self.ot_fn(a, b, M.detach().cpu().numpy())
45 |         if not np.all(np.isfinite(p)):
46 |             print("ERROR: p is not finite")
47 |             print(p)
48 |             print("Cost mean, max", M.mean(), M.max())
49 |             print(x0, x1)
50 |         return p
51 | 
52 |     def sample_map(self, pi, batch_size):
53 |         p = pi.flatten()
54 |         p = p / p.sum()
55 |         choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
56 |         return np.divmod(choices, pi.shape[1])
57 | 
58 |     def sample_plan(self, x0, x1):
59 |         pi = self.get_map(x0, x1)
60 |         i, j = self.sample_map(pi, x0.shape[0])
61 |         return x0[i], x1[j]
62 | 
63 |     def sample_trajectory(self, X):
64 |         # Assume X is [batch, times, dim]
65 |         times = X.shape[1]
66 |         pis = []
67 |         for t in range(times - 1):
68 |             pis.append(self.get_map(X[:, t], X[:, t + 1]))
69 | 
70 |         indices = [np.arange(X.shape[0])]
71 |         for pi in pis:
72 |             j = []
73 |             for i in indices[-1]:
74 |                 j.append(np.random.choice(pi.shape[1], p=pi[i] / pi[i].sum()))
75 |             indices.append(np.array(j))
76 | 
77 |         to_return = []
78 |         for t in range(times):
79 |             to_return.append(X[:, t][indices[t]])
80 |         to_return = np.stack(to_return, axis=1)
81 |         return to_return
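
# Usage note: `sample_plan` is the entry point `DBDSB_VE.get_train_tuple` uses to
# re-pair a minibatch before bridge interpolation. A minimal sketch of that
# pattern (the toy tensors and the `sig` value here are hypothetical; `reg=2*sig**2`
# mirrors the constructor call in diffusion_bridge.py and only matters for the
# "sinkhorn" method, since "exact" ignores `reg`):
#
#     import torch
#     from bridge.sde.optimal_transport import OTPlanSampler
#
#     sig = 0.5
#     sampler = OTPlanSampler(method="exact", reg=2 * sig**2)
#     z0, z1 = torch.randn(8, 1, 28, 28), torch.randn(8, 1, 28, 28)
#     original_shape = z0.shape
#     # sample_plan expects 2-D inputs: flatten, sample the coupling, reshape.
#     z0, z1 = sampler.sample_plan(z0.flatten(start_dim=1), z1.flatten(start_dim=1))
#     z0, z1 = z0.view(original_shape), z1.view(original_shape)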
-------------------------------------------------------------------------------- /bridge/trainer_rf.py: -------------------------------------------------------------------------------- 1 | import os, sys, warnings, time 2 | import re 3 | from collections import OrderedDict 4 | from functools import partial 5 | 6 | import hydra 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | import glob 13 | 14 | from .data import DBDSB_CacheLoader 15 | from .sde import * 16 | from .runners import * 17 | from .runners.config_getters import get_model, get_optimizer, get_plotter, get_logger 18 | from .runners.ema import EMAHelper 19 | from .trainer_dbdsb import IPF_DBDSB 20 | 21 | # from torchdyn.core import NeuralODE 22 | from torchdiffeq import odeint 23 | 24 | 25 | class IPF_RF(IPF_DBDSB): 26 | def __init__(self, init_ds, final_ds, mean_final, var_final, args, accelerator=None, final_cond_model=None, 27 | valid_ds=None, test_ds=None): 28 | super().__init__(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, final_cond_model=final_cond_model, 29 | valid_ds=valid_ds, test_ds=test_ds) 30 | self.langevin = DBDSB_VE(0., self.num_steps, self.timesteps, self.shape_x, self.shape_y, first_coupling="ind", ot_sampler=self.args.ot_sampler) 31 | 32 | def build_checkpoints(self): 33 | self.first_pass = True # Load and use checkpointed networks during first pass 34 | self.ckpt_dir = './checkpoints/' 35 | self.ckpt_prefixes = ["net_b", "sample_net_b", "optimizer_b"] 36 | self.cache_dir='./cache/' 37 | if self.accelerator.is_main_process: 38 | os.makedirs(self.ckpt_dir, exist_ok=True) 39 | os.makedirs(self.cache_dir, exist_ok=True) 40 | 41 | if self.args.get('checkpoint_run', False): 42 | self.resume, self.checkpoint_it, self.checkpoint_pass, self.step = \ 43 | True, self.args.checkpoint_it, self.args.checkpoint_pass, self.args.checkpoint_iter 44 | print(f"Resuming training at iter {self.checkpoint_it} {self.checkpoint_pass} step {self.step}") 45 | 46 | self.checkpoint_b = hydra.utils.to_absolute_path(self.args.checkpoint_b) 47 | self.sample_checkpoint_b = hydra.utils.to_absolute_path(self.args.sample_checkpoint_b) 48 | self.optimizer_checkpoint_b = hydra.utils.to_absolute_path(self.args.optimizer_checkpoint_b) 49 | 50 | else: 51 | self.ckpt_dir_load = os.path.abspath(self.ckpt_dir) 52 | ckpt_dir_load_list = os.path.normpath(self.ckpt_dir_load).split(os.sep) 53 | if 'test' in ckpt_dir_load_list: 54 | self.ckpt_dir_load = os.path.join(*ckpt_dir_load_list[:ckpt_dir_load_list.index('test')], "checkpoints/") 55 | self.resume, self.checkpoint_it, self.checkpoint_pass, self.step, ckpt_b_suffix = self.find_last_ckpt() 56 | 57 | if self.resume: 58 | if not self.args.autostart_next_it and self.step == 1 and not (self.checkpoint_it == 1 and self.checkpoint_pass == 'b'): 59 | self.checkpoint_pass, self.checkpoint_it = self.compute_prev_it(self.checkpoint_pass, self.checkpoint_it) 60 | self.step = self.compute_max_iter(self.checkpoint_pass, self.checkpoint_it) + 1 61 | 62 | print(f"Resuming training at iter {self.checkpoint_it} {self.checkpoint_pass} step {self.step}") 63 | self.checkpoint_b, self.sample_checkpoint_b, self.optimizer_checkpoint_b = [os.path.join(self.ckpt_dir_load, f"{ckpt_prefix}_{ckpt_b_suffix}.ckpt") for ckpt_prefix in self.ckpt_prefixes[:3]] 64 | 65 | def build_models(self, forward_or_backward=None): 66 | # running network 67 | net_b = get_model(self.args) 68 | 69 | if self.first_pass and 
self.resume: 70 | if self.resume: 71 | try: 72 | net_b.load_state_dict(torch.load(self.checkpoint_b)) 73 | except: 74 | state_dict = torch.load(self.checkpoint_b) 75 | new_state_dict = OrderedDict() 76 | for k, v in state_dict.items(): 77 | name = k.replace("module.", "") # remove "module." 78 | new_state_dict[name] = v 79 | net_b.load_state_dict(new_state_dict) 80 | 81 | if forward_or_backward is None: 82 | net_b = self.accelerator.prepare(net_b) 83 | self.net = torch.nn.ModuleDict({'b': net_b}) 84 | if forward_or_backward == 'b': 85 | net_b = self.accelerator.prepare(net_b) 86 | self.net.update({'b': net_b}) 87 | 88 | def build_ema(self): 89 | if self.args.ema: 90 | self.ema_helpers = {} 91 | 92 | if self.first_pass and self.resume: 93 | # sample network 94 | sample_net_b = get_model(self.args) 95 | 96 | if self.resume: 97 | sample_net_b.load_state_dict( 98 | torch.load(self.sample_checkpoint_b)) 99 | sample_net_b = sample_net_b.to(self.device) 100 | self.update_ema('b') 101 | self.ema_helpers['b'].register(sample_net_b) 102 | 103 | def train(self): 104 | for n in range(self.checkpoint_it, self.n_ipf + 1): 105 | self.accelerator.print('RF iteration: ' + str(n) + '/' + str(self.n_ipf)) 106 | # BACKWARD OPTIMISATION 107 | self.ipf_iter('b', n) 108 | 109 | def build_optimizers(self, forward_or_backward=None): 110 | optimizer_b = get_optimizer(self.net['b'], self.args) 111 | 112 | if self.first_pass and self.resume: 113 | if self.resume: 114 | optimizer_b.load_state_dict(torch.load(self.optimizer_checkpoint_b)) 115 | 116 | if forward_or_backward is None: 117 | self.optimizer = {'b': optimizer_b} 118 | if forward_or_backward == 'b': 119 | self.optimizer.update({'b': optimizer_b}) 120 | 121 | def find_last_ckpt(self): 122 | existing_ckpts_dict = {} 123 | for ckpt_prefix in self.ckpt_prefixes: 124 | existing_ckpts = sorted(glob.glob(os.path.join(self.ckpt_dir_load, f"{ckpt_prefix}_**.ckpt"))) 125 | existing_ckpts_dict[ckpt_prefix] = set([os.path.basename(existing_ckpt)[len(ckpt_prefix)+1:-5] for existing_ckpt in existing_ckpts]) 126 | 127 | existing_ckpts_b = sorted(list(existing_ckpts_dict["net_b"].intersection(existing_ckpts_dict["sample_net_b"], existing_ckpts_dict["optimizer_b"])), reverse=True) 128 | 129 | if len(existing_ckpts_b) == 0: 130 | return False, 1, 'b', 1, None 131 | 132 | def return_valid_ckpt_combi(b_i, b_n): 133 | # Return is_valid, checkpoint_it, checkpoint_pass, checkpoint_step 134 | if (b_n == 1 and b_i != self.first_num_iter) or (b_n > 1 and b_i != self.num_iter): # during b pass 135 | return True, b_n, 'b', b_i + 1 136 | else: 137 | return True, b_n + 1, 'b', 1 138 | 139 | for existing_ckpt_b in existing_ckpts_b: 140 | ckpt_b_n, ckpt_b_i = existing_ckpt_b.split("_") 141 | ckpt_b_n, ckpt_b_i = int(ckpt_b_n), int(ckpt_b_i) 142 | 143 | is_valid, checkpoint_it, checkpoint_pass, checkpoint_step = return_valid_ckpt_combi(ckpt_b_i, ckpt_b_n) 144 | if is_valid: 145 | break 146 | 147 | if not is_valid: 148 | return False, 1, 'b', 1, None 149 | else: 150 | return True, checkpoint_it, checkpoint_pass, checkpoint_step, existing_ckpt_b 151 | 152 | def apply_net(self, x, y, t, net, fb, return_scale=False): 153 | out = net.forward(x, y, t) 154 | 155 | if return_scale: 156 | return out, 1 157 | else: 158 | return out 159 | 160 | def compute_prev_it(self, forward_or_backward, n): 161 | assert forward_or_backward == 'b' 162 | prev_direction = 'b' 163 | prev_n = n - 1 164 | return prev_direction, prev_n 165 | 166 | def compute_next_it(self, forward_or_backward, n): 167 | assert 
forward_or_backward == 'b' 168 | next_direction = 'b' 169 | next_n = n+1 170 | return next_direction, next_n -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | # - launcher: slurm_gpu 6 | - job 7 | - model: UNET 8 | - method: dbdsb #dsb, dbdsb 9 | - dataset: mnist_transfer 10 | # - override hydra/launcher: submitit_slurm 11 | 12 | name: ${Dataset}_${data.dataset} 13 | run: 0 14 | seed: 42 15 | use_default_wandb_name: False 16 | wandb_id: null # leave as null for wandb to assign this 17 | 18 | # logging 19 | LOGGER: Wandb # CSV 20 | CSV_log_dir: ./ 21 | 22 | # training 23 | optimizer: Adam 24 | test_batch_size: 1000 25 | plot_level: 2 26 | cache_refresh_stride: ${num_iter} 27 | cache_num_steps: ${num_steps} 28 | test_num_steps: ${num_steps} 29 | normalize_x1: False 30 | 31 | paths: 32 | experiments_dir_name: experiments 33 | data_dir_name: data 34 | 35 | # checkpoint 36 | autostart_next_it: False 37 | 38 | checkpoint_run: False 39 | checkpoint_it: 1 40 | checkpoint_pass: b # b or f (skip b ipf run) 41 | checkpoint_iter: 0 42 | checkpoint_dir: null 43 | sample_checkpoint_f: null 44 | sample_checkpoint_b: ${checkpoint_dir}/ 45 | checkpoint_f: null 46 | checkpoint_b: ${checkpoint_dir}/ 47 | optimizer_checkpoint_f: null 48 | optimizer_checkpoint_b: ${checkpoint_dir}/ 49 | -------------------------------------------------------------------------------- /conf/dataset/afhq_transfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | cdsb: False 4 | 5 | # data 6 | Dataset: afhq_wild 7 | data: 8 | dataset: afhq_wild_cat 9 | image_size: 512 10 | channels: 3 11 | 12 | # transfer 13 | transfer: True 14 | Dataset_transfer: afhq_cat 15 | 16 | 17 | final_adaptive: False 18 | adaptive_mean: False 19 | mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}]) 20 | var_final: 1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}]) 21 | 22 | # device 23 | device: cuda 24 | num_workers: 4 25 | pin_memory: True 26 | 27 | # logging 28 | log_stride: 100 29 | gif_stride: 5000 30 | plot_npar: 4 31 | test_npar: 100 32 | test_batch_size: 4 33 | cache_npar: 400 34 | cache_batch_size: 10 35 | num_repeat_data: 1 36 | cache_refresh_stride: 1000 37 | 38 | # training 39 | use_prev_net: True 40 | ema: True 41 | ema_rate: 0.999 42 | grad_clipping: True 43 | grad_clip: 1.0 44 | batch_size: 4 45 | num_iter: 25000 46 | n_ipf: 30 47 | lr: 0.0001 48 | 49 | num_steps: 100 50 | -------------------------------------------------------------------------------- /conf/dataset/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | cdsb: False 4 | 5 | # data 6 | Dataset: cifar10 7 | data: 8 | dataset: "CIFAR10" 9 | image_size: 32 10 | channels: 3 11 | random_flip: true 12 | 13 | # transfer 14 | transfer: False 15 | Dataset_transfer: mnist 16 | 17 | 18 | final_adaptive: False 19 | adaptive_mean: False 20 | mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}]) 21 | var_final: 1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}]) 22 | load: True 23 | 24 | # device 25 | device: cuda 26 | num_workers: 2 27 | pin_memory: True 28 | 29 | # logging 30 | log_stride: 100 31 | gif_stride: 100000 32 | plot_npar: 100 33 | test_npar: 
50000 34 | test_batch_size: 500 35 | cache_npar: 250000 36 | cache_batch_size: 1250 37 | num_repeat_data: 1 # 4 38 | cache_refresh_stride: ${num_iter} 39 | 40 | # training 41 | optimizer: AdamW 42 | use_prev_net: True 43 | ema: True 44 | ema_rate: 0.9999 45 | grad_clipping: True 46 | grad_clip: 1.0 47 | batch_size: 128 48 | num_iter: 500000 49 | n_ipf: 100 50 | lr: 0.0001 51 | weight_decay: 0.01 52 | 53 | num_steps: 200 54 | -------------------------------------------------------------------------------- /conf/dataset/downscaler_transfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | cdsb: True 4 | 5 | # data 6 | Dataset: downscaler_high 7 | data: 8 | dataset: "downscaler_transfer" 9 | image_size: 512 10 | channels: 2 11 | cond_channels: 1 12 | random_flip: False 13 | wavenumber: 0 14 | split: train 15 | 16 | # transfer 17 | transfer: True 18 | Dataset_transfer: downscaler_low 19 | 20 | 21 | final_adaptive: False 22 | adaptive_mean: False 23 | mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}]) 24 | var_final: 1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}]) 25 | 26 | # device 27 | device: cuda 28 | num_workers: 2 29 | pin_memory: True 30 | 31 | # logging 32 | log_stride: 100 33 | gif_stride: 10000 34 | plot_npar: 16 35 | test_npar: 16 36 | test_batch_size: 8 37 | cache_npar: 1250 38 | cache_batch_size: 10 39 | num_repeat_data: 1 40 | cache_refresh_stride: 2500 41 | 42 | # training 43 | use_prev_net: True 44 | ema: True 45 | ema_rate: 0.999 46 | grad_clipping: True 47 | grad_clip: 1.0 48 | batch_size: 4 49 | num_iter: 10000 50 | n_ipf: 30 51 | lr: 0.0002 52 | 53 | num_steps: 100 54 | -------------------------------------------------------------------------------- /conf/dataset/mnist_transfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | cdsb: False 4 | 5 | # data 6 | Dataset: mnist 7 | data: 8 | dataset: "MNIST_EMNIST" 9 | image_size: 28 10 | channels: 1 11 | 12 | # transfer 13 | transfer: True 14 | Dataset_transfer: emnist 15 | 16 | 17 | final_adaptive: False 18 | adaptive_mean: False 19 | mean_final: torch.zeros([${data.channels}, ${data.image_size}, ${data.image_size}]) 20 | var_final: 1 * torch.ones([${data.channels}, ${data.image_size}, ${data.image_size}]) 21 | load: True 22 | 23 | # device 24 | device: cuda 25 | num_workers: 2 26 | pin_memory: True 27 | 28 | # logging 29 | log_stride: 100 30 | gif_stride: 5000 31 | plot_npar: 100 32 | test_npar: 10000 33 | test_batch_size: 250 34 | cache_npar: 10000 35 | cache_batch_size: 1250 36 | num_repeat_data: 1 37 | cache_refresh_stride: 1000 38 | 39 | # training 40 | use_prev_net: True 41 | ema: True 42 | ema_rate: 0.999 43 | grad_clipping: True 44 | grad_clip: 1.0 45 | batch_size: 128 46 | num_iter: 500000 47 | n_ipf: 50 48 | lr: 0.0001 49 | 50 | num_steps: 30 51 | -------------------------------------------------------------------------------- /conf/gaussian.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - job 6 | 7 | name: gaussian 8 | seed: null 9 | a: 0.1 10 | dim: 5 11 | sigma: 1 12 | num_steps: 20 13 | net_name: mlp_large 14 | activation_fn: torch.nn.SiLU 15 | model_name: dsb 16 | first_coupling: ref 17 | inner_iters: 10000 18 | outer_iters: 40 19 | fb_sequence: ['b','f'] 20 | 21 | paths: 22 | experiments_dir_name: experiments 
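
The dataset configs above store mean_final and var_final as Python expressions in strings, with the ${data.*} interpolations resolved by Hydra. Below is a hedged sketch of how such entries are presumably materialised into tensors at load time, mirroring the eval(self.out_scale) convention visible in bridge/sde/discrete_langevin.py (the actual helper lives in bridge/runners/config_getters.py and is not shown here):

import torch

# resolved values, e.g. from conf/dataset/mnist_transfer.yaml (channels=1, image_size=28)
cfg_mean_final = "torch.zeros([1, 28, 28])"
cfg_var_final = "1 * torch.ones([1, 28, 28])"

mean_final = eval(cfg_mean_final)  # zero mean of the terminal Gaussian, shape (1, 28, 28)
var_final = eval(cfg_var_final)    # unit variance of the terminal Gaussian
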
-------------------------------------------------------------------------------- /conf/job.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra 2 | 3 | job: 4 | config: 5 | # configuration for the ${hydra.job.override_dirname} runtime variable 6 | override_dirname: 7 | exclude_keys: [name, launcher, run, training, device, data, data_dir, dataset, load, Dataset, data.dataset, test_batch_size, y_cond, x_cond_true, LOGGER, plot_npar, paths.data_dir_name, seed, autostart_next_it, checkpoint_run] 8 | 9 | run: 10 | # Output directory for normal runs 11 | dir: ./${paths.experiments_dir_name}/${name}/${hydra.job.override_dirname}/${seed} 12 | 13 | sweep: 14 | # Output directory for sweep runs 15 | dir: ./${paths.experiments_dir_name}/${name}/${hydra.job.override_dirname} 16 | subdir: ${seed} 17 | 18 | job_logging: 19 | formatters: 20 | simple: 21 | format: '[%(levelname)s] - %(message)s' 22 | handlers: 23 | file: 24 | filename: run.log 25 | root: 26 | handlers: [console, file] -------------------------------------------------------------------------------- /conf/launcher/slurm_cpu.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.launcher 2 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 3 | timeout_min: 10000 4 | cpus_per_task: 1 5 | tasks_per_node: 1 6 | mem_gb: 8 7 | name: ${hydra.job.name} 8 | partition: partition_name 9 | max_num_timeout: 0 10 | array_parallelism: 30 11 | additional_parameters: { 12 | } -------------------------------------------------------------------------------- /conf/launcher/slurm_gpu.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.launcher 2 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 3 | timeout_min: 10000 4 | cpus_per_task: 1 5 | tasks_per_node: 1 6 | mem_gb: 14 7 | name: ${hydra.job.name} 8 | partition: gpu_partition_name 9 | max_num_timeout: 0 10 | array_parallelism: 5 11 | additional_parameters: { 12 | "gres": "gpu:1", 13 | } -------------------------------------------------------------------------------- /conf/method/bm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: DBDSB 4 | 5 | first_num_iter: 500000 6 | n_ipf: 1 7 | 8 | # schedule 9 | sde: ve 10 | gamma_max: 0.2 11 | gamma_min: 0.001 12 | gamma_space: linspace 13 | 14 | symmetric_gamma: False 15 | 16 | first_coupling: ind 17 | 18 | mean_match: False 19 | loss_scale: True 20 | std_trick: False -------------------------------------------------------------------------------- /conf/method/dbdsb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: DBDSB 4 | 5 | first_num_iter: ${num_iter} 6 | 7 | # schedule 8 | sde: ve 9 | gamma_max: 0.2 10 | gamma_min: 0.001 11 | gamma_space: linspace 12 | 13 | symmetric_gamma: False 14 | 15 | first_coupling: ref 16 | 17 | mean_match: False 18 | loss_scale: True 19 | std_trick: False -------------------------------------------------------------------------------- /conf/method/dbdsb_vp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: DBDSB 4 | 5 | first_num_iter: ${num_iter} 6 | 7 | # schedule 8 | sde: vp 9 | gamma_max: 0.2 10 | gamma_min: 0.001 11 | gamma_space: linspace 12 | 13 | symmetric_gamma: False 14 | 15 | first_coupling: ref 16 | 17 | mean_match: True 18 | 
loss_scale: False 19 | std_trick: False -------------------------------------------------------------------------------- /conf/method/dsb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: DSB 4 | 5 | mean_match: False 6 | 7 | # schedule 8 | gamma_max: 0.2 9 | gamma_min: 0.001 10 | gamma_space: linspace 11 | 12 | symmetric_gamma: True 13 | 14 | var_final_gamma_scale: False 15 | langevin_scale: 2*torch.sqrt(gamma) 16 | loss_scale: 1. 17 | 18 | -------------------------------------------------------------------------------- /conf/method/otcfm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: RF 4 | 5 | first_num_iter: 500000 6 | n_ipf: 1 7 | 8 | # schedule 9 | sde: ve 10 | gamma_max: 0.01 11 | gamma_min: 0.01 12 | gamma_space: linspace 13 | 14 | symmetric_gamma: True 15 | 16 | first_coupling: ind 17 | ot_sampler: exact 18 | 19 | # Not used but needed for __init__ in the IPF_DBDSB parent class 20 | mean_match: False 21 | loss_scale: False 22 | std_trick: False -------------------------------------------------------------------------------- /conf/method/rf.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Method: RF 4 | 5 | first_num_iter: 500000 6 | 7 | # schedule 8 | sde: ve 9 | gamma_max: 0.01 10 | gamma_min: 0.01 11 | gamma_space: linspace 12 | 13 | symmetric_gamma: True 14 | 15 | first_coupling: ind 16 | ot_sampler: null 17 | 18 | # Not used but needed for __init__ in the IPF_DBDSB parent class 19 | mean_match: False 20 | loss_scale: False 21 | std_trick: False -------------------------------------------------------------------------------- /conf/model/DDPMpp_32.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Model: DDPMpp 4 | model: 5 | # config from https://github.com/gnobitab/RectifiedFlow/blob/main/ImageGeneration/configs/rectified_flow/afhq_cat_pytorch_rf_gaussian.py 6 | scale_by_sigma: False 7 | normalization: 'GroupNorm' 8 | nonlinearity: 'swish' 9 | nf: 128 10 | ch_mult: [1, 2, 2, 2] 11 | num_res_blocks: 4 12 | attn_resolutions: [16] 13 | dropout: 0.15 14 | resamp_with_conv: True 15 | conditional: True 16 | fir: False 17 | fir_kernel: [1, 3, 3, 1] 18 | skip_rescale: True 19 | resblock_type: 'biggan' 20 | progressive: 'none' 21 | progressive_input: 'none' 22 | progressive_combine: 'sum' 23 | attention_type: 'ddpm' 24 | init_scale: 0. 25 | embedding_type: 'positional' 26 | fourier_scale: 16 27 | conv_size: 3 -------------------------------------------------------------------------------- /conf/model/DDPMpp_RF.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Model: DDPMpp 4 | model: 5 | # config from https://github.com/gnobitab/RectifiedFlow/blob/main/ImageGeneration/configs/rectified_flow/afhq_cat_pytorch_rf_gaussian.py 6 | scale_by_sigma: False 7 | normalization: 'GroupNorm' 8 | nonlinearity: 'swish' 9 | nf: 128 10 | ch_mult: [1, 1, 2, 2, 2, 2, 2] 11 | num_res_blocks: 2 12 | attn_resolutions: [16] 13 | dropout: 0. 14 | resamp_with_conv: True 15 | conditional: True 16 | fir: True 17 | fir_kernel: [1, 3, 3, 1] 18 | skip_rescale: True 19 | resblock_type: 'biggan' 20 | progressive: 'output_skip' 21 | progressive_input: 'input_skip' 22 | progressive_combine: 'sum' 23 | attention_type: 'ddpm' 24 | init_scale: 0. 
25 | fourier_scale: 16 26 | conv_size: 3 27 | embedding_type: 'positional' -------------------------------------------------------------------------------- /conf/model/DownscalerUNET.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Model: DownscalerUNET 4 | model: 5 | num_channels: 32 6 | channel_mult: 7 | - 1 8 | - 2 9 | - 4 10 | - 8 11 | num_res_blocks: 8 12 | dropout: 0.5 13 | temb_scale: 30 14 | mean_bypass: True 15 | scale_mean_bypass: True 16 | shift_input: True 17 | shift_output: True -------------------------------------------------------------------------------- /conf/model/UNET.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | Model: UNET 4 | model: 5 | num_channels: 128 6 | channel_mult: null 7 | num_res_blocks: 2 8 | num_heads: 4 9 | attention_resolutions: "16,8" 10 | dropout: 0.1 # 0.0 # 11 | use_checkpoint: False 12 | use_scale_shift_norm: True 13 | resblock_updown: False 14 | temb_scale: 1000 -------------------------------------------------------------------------------- /conf/test_config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | # - launcher: slurm_gpu 6 | - test_job 7 | - model: UNET 8 | - method: dbdsb #dsb, dbdsb 9 | - dataset: mnist_transfer 10 | # - override hydra/launcher: submitit_slurm 11 | 12 | name: ${Dataset}_${data.dataset} 13 | run: 0 14 | seed: 42 15 | 16 | # logging 17 | LOGGER: CSV 18 | CSV_log_dir: ./ 19 | 20 | # training 21 | optimizer: Adam 22 | test_batch_size: 1000 23 | plot_level: 2 24 | cache_refresh_stride: ${num_iter} 25 | cache_num_steps: ${num_steps} 26 | test_num_steps: ${num_steps} 27 | normalize_x1: False 28 | 29 | test_ode_sampler: true 30 | ode_sampler: dopri5 31 | ode_tol: 1e-5 32 | ode_euler_step_size: 1e-2 33 | 34 | paths: 35 | experiments_dir_name: experiments 36 | data_dir_name: data 37 | 38 | # checkpoint 39 | autostart_next_it: False 40 | 41 | checkpoint_run: False 42 | checkpoint_it: 1 43 | checkpoint_pass: b # b or f (skip b ipf run) 44 | checkpoint_iter: 0 45 | checkpoint_dir: null 46 | sample_checkpoint_f: null 47 | sample_checkpoint_b: ${checkpoint_dir}/ 48 | checkpoint_f: null 49 | checkpoint_b: ${checkpoint_dir}/ 50 | optimizer_checkpoint_f: null 51 | optimizer_checkpoint_b: ${checkpoint_dir}/ 52 | -------------------------------------------------------------------------------- /conf/test_job.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra 2 | 3 | job: 4 | config: 5 | # configuration for the ${hydra.job.override_dirname} runtime variable 6 | override_dirname: 7 | exclude_keys: [name, launcher, run, training, device, data, data_dir, dataset, load, Dataset, data.dataset, test_batch_size, y_cond, x_cond_true, LOGGER, plot_npar, test_npar, paths.data_dir_name, seed, autostart_next_it, checkpoint_run, test_ode_sampler, plot_level, data.wavenumber, data.split] 8 | 9 | run: 10 | # Output directory for normal runs 11 | dir: ./${paths.experiments_dir_name}/${name}/${hydra.job.override_dirname}/${seed}/test/${now:%H-%M-%S} 12 | 13 | sweep: 14 | # Output directory for sweep runs 15 | dir: ./${paths.experiments_dir_name}/${name}/${hydra.job.override_dirname}/${seed}/test/${now:%H-%M-%S} 16 | 17 | job_logging: 18 | formatters: 19 | simple: 20 | format: '[%(levelname)s] - %(message)s' 21 | handlers: 22 | file: 23 | filename: run.log 24 | root: 
25 | handlers: [console, file] -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | 4 | 5 | @hydra.main(config_path="conf", config_name="config") 6 | def main(cfg): 7 | if cfg.Method == "DSB": 8 | from run_dsb import run 9 | return run(cfg) 10 | elif cfg.Method == "DBDSB": 11 | from run_dbdsb import run 12 | return run(cfg) 13 | elif cfg.Method == "RF": 14 | from run_rf import run 15 | return run(cfg) 16 | else: 17 | raise NotImplementedError 18 | 19 | if __name__ == "__main__": 20 | main() -------------------------------------------------------------------------------- /run_dbdsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import os 4 | 5 | from bridge.trainer_dbdsb import IPF_DBDSB 6 | from bridge.runners.config_getters import get_datasets, get_valid_test_datasets 7 | from accelerate import Accelerator 8 | 9 | 10 | def run(args): 11 | accelerator = Accelerator(cpu=args.device == 'cpu', split_batches=True) 12 | accelerator.print('Directory: ' + os.getcwd()) 13 | 14 | init_ds, final_ds, mean_final, var_final = get_datasets(args) 15 | valid_ds, test_ds = get_valid_test_datasets(args) 16 | 17 | final_cond_model = None 18 | ipf = IPF_DBDSB(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, 19 | final_cond_model=final_cond_model, valid_ds=valid_ds, test_ds=test_ds) 20 | accelerator.print(accelerator.state) 21 | accelerator.print(ipf.net['b']) 22 | accelerator.print('Number of parameters:', sum(p.numel() for p in ipf.net['b'].parameters() if p.requires_grad)) 23 | ipf.train() 24 | 25 | -------------------------------------------------------------------------------- /run_dsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import os 4 | 5 | from bridge.trainer_dsb import IPF_DSB 6 | from bridge.runners.config_getters import get_datasets, get_valid_test_datasets 7 | from accelerate import Accelerator 8 | 9 | 10 | def run(args): 11 | accelerator = Accelerator(cpu=args.device == 'cpu', split_batches=True) 12 | accelerator.print('Directory: ' + os.getcwd()) 13 | 14 | init_ds, final_ds, mean_final, var_final = get_datasets(args) 15 | valid_ds, test_ds = get_valid_test_datasets(args) 16 | 17 | final_cond_model = None 18 | ipf = IPF_DSB(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, 19 | final_cond_model=final_cond_model, valid_ds=valid_ds, test_ds=test_ds) 20 | accelerator.print(accelerator.state) 21 | accelerator.print(ipf.net['b']) 22 | accelerator.print('Number of parameters:', sum(p.numel() for p in ipf.net['b'].parameters() if p.requires_grad)) 23 | ipf.train() 24 | 25 | -------------------------------------------------------------------------------- /run_rf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import os 4 | 5 | from bridge.trainer_rf import IPF_RF 6 | from bridge.runners.config_getters import get_datasets, get_valid_test_datasets 7 | from accelerate import Accelerator 8 | 9 | 10 | def run(args): 11 | accelerator = Accelerator(cpu=args.device == 'cpu', split_batches=True) 12 | accelerator.print('Directory: ' + os.getcwd()) 13 | 14 | init_ds, final_ds, mean_final, var_final = get_datasets(args) 15 | valid_ds, test_ds = get_valid_test_datasets(args) 16 | 17 | 
final_cond_model = None 18 | ipf = IPF_RF(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, 19 | final_cond_model=final_cond_model, valid_ds=valid_ds, test_ds=test_ds) 20 | accelerator.print(accelerator.state) 21 | accelerator.print(ipf.net['b']) 22 | accelerator.print('Number of parameters:', sum(p.numel() for p in ipf.net['b'].parameters() if p.requires_grad)) 23 | ipf.train() 24 | 25 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | 4 | 5 | @hydra.main(config_path="conf", config_name="test_config") 6 | def main(cfg): 7 | if cfg.Method == "DSB": 8 | from test_dsb import test 9 | return test(cfg) 10 | elif cfg.Method == "DBDSB": 11 | from test_dbdsb import test 12 | return test(cfg) 13 | elif cfg.Method == "RF": 14 | from test_rf import test 15 | return test(cfg) 16 | else: 17 | raise NotImplementedError 18 | 19 | if __name__ == "__main__": 20 | main() -------------------------------------------------------------------------------- /test_dbdsb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import os 4 | 5 | from bridge.trainer_dbdsb import IPF_DBDSB 6 | from bridge.runners.config_getters import get_datasets, get_valid_test_datasets 7 | from accelerate import Accelerator 8 | 9 | 10 | def test(args): 11 | accelerator = Accelerator(cpu=args.device == 'cpu', split_batches=True) 12 | accelerator.print('Directory: ' + os.getcwd()) 13 | 14 | init_ds, final_ds, mean_final, var_final = get_datasets(args) 15 | valid_ds, test_ds = get_valid_test_datasets(args) 16 | 17 | final_cond_model = None 18 | ipf = IPF_DBDSB(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, 19 | final_cond_model=final_cond_model, valid_ds=valid_ds, test_ds=test_ds) 20 | accelerator.print(accelerator.state) 21 | accelerator.print(ipf.net['b']) 22 | accelerator.print('Number of parameters:', sum(p.numel() for p in ipf.net['b'].parameters() if p.requires_grad)) 23 | test_metrics = ipf.plot_and_test_step(ipf.step, ipf.checkpoint_it, "b", sampler='sde') 24 | accelerator.print("SDE: ", test_metrics) 25 | 26 | if args.test_ode_sampler: 27 | test_metrics = ipf.plot_and_test_step(ipf.step, ipf.checkpoint_it, "b", sampler='ode') 28 | accelerator.print("ODE: ", test_metrics) 29 | 30 | -------------------------------------------------------------------------------- /test_rf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hydra 3 | import os 4 | 5 | from bridge.trainer_rf import IPF_RF 6 | from bridge.runners.config_getters import get_datasets, get_valid_test_datasets 7 | from accelerate import Accelerator 8 | 9 | 10 | def test(args): 11 | accelerator = Accelerator(cpu=args.device == 'cpu', split_batches=True) 12 | accelerator.print('Directory: ' + os.getcwd()) 13 | 14 | init_ds, final_ds, mean_final, var_final = get_datasets(args) 15 | valid_ds, test_ds = get_valid_test_datasets(args) 16 | 17 | final_cond_model = None 18 | ipf = IPF_RF(init_ds, final_ds, mean_final, var_final, args, accelerator=accelerator, 19 | final_cond_model=final_cond_model, valid_ds=valid_ds, test_ds=test_ds) 20 | accelerator.print(accelerator.state) 21 | accelerator.print(ipf.net['b']) 22 | accelerator.print('Number of parameters:', sum(p.numel() for p in ipf.net['b'].parameters() if p.requires_grad)) 23 | test_metrics = 
ipf.plot_and_test_step(ipf.step, ipf.checkpoint_it, "b", sampler='sde') 24 | accelerator.print("SDE: ", test_metrics) 25 | 26 | if args.test_ode_sampler: 27 | test_metrics = ipf.plot_and_test_step(ipf.step, ipf.checkpoint_it, "b", sampler='ode') 28 | accelerator.print("ODE: ", test_metrics) 29 | 30 | --------------------------------------------------------------------------------
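
test_rf.py above evaluates both an SDE and an ODE sampler; conf/test_config.yaml selects ode_sampler: dopri5 with ode_tol: 1e-5, and bridge/trainer_rf.py imports torchdiffeq.odeint. The sketch below shows what such ODE-based sampling presumably looks like; the drift wrapper and the net(x, t) call signature are illustrative assumptions, not the repository's exact interface.

import torch
from torchdiffeq import odeint


def ode_sample(net, x0, rtol=1e-5, atol=1e-5):
    # Integrate the learned velocity field from t=0 to t=1 with the adaptive
    # dopri5 solver, matching ode_sampler / ode_tol in conf/test_config.yaml.
    def drift(t, x):
        # broadcast the scalar integration time to a per-sample conditioning vector
        return net(x, t.expand(x.shape[0], 1))

    ts = torch.linspace(0.0, 1.0, 2)  # only the two endpoints are required
    xs = odeint(drift, x0, ts, method="dopri5", rtol=rtol, atol=atol)
    return xs[-1]  # samples at t=1
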