├── .gitignore ├── asset ├── ddpm │ ├── cat.png │ ├── mnist.png │ ├── cifar10.png │ └── fashion_mnist.png ├── rect_flow │ ├── cat.png │ ├── mnist.png │ ├── cifar10.png │ └── fashion_mnist.png └── score-sde │ ├── cat.png │ ├── mnist.png │ ├── cifar10.png │ └── fashion_mnist.png ├── run.sh ├── cfm ├── scripts │ └── run.sh ├── configs │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── fashion_mnist.yaml │ └── huggan │ │ └── AFHQv2.yaml └── src │ ├── train.py │ └── model.py ├── ddpm ├── scripts │ └── run.sh ├── configs │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── fashion_mnist.yaml │ └── huggan │ │ └── AFHQv2.yaml └── src │ ├── schedule.py │ ├── train.py │ └── model.py ├── consistency ├── scripts │ └── run.sh ├── configs │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── fashion_mnist.yaml │ └── huggan │ │ └── AFHQv2.yaml └── src │ ├── train.py │ └── model.py ├── rect_flow ├── scripts │ └── run.sh ├── configs │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── fashion_mnist.yaml │ └── huggan │ │ └── AFHQv2.yaml └── src │ ├── train.py │ └── model.py ├── score-sde ├── scripts │ └── run.sh ├── configs │ ├── mnist.yaml │ ├── cifar10.yaml │ ├── fashion_mnist.yaml │ └── huggan │ │ └── AFHQv2.yaml └── src │ ├── sde.py │ ├── train.py │ └── model.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | **/out 2 | **/__pycache__ 3 | ./run.sh -------------------------------------------------------------------------------- /asset/ddpm/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/ddpm/cat.png -------------------------------------------------------------------------------- /asset/ddpm/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/ddpm/mnist.png -------------------------------------------------------------------------------- /asset/ddpm/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/ddpm/cifar10.png -------------------------------------------------------------------------------- /asset/rect_flow/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/rect_flow/cat.png -------------------------------------------------------------------------------- /asset/score-sde/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/score-sde/cat.png -------------------------------------------------------------------------------- /asset/rect_flow/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/rect_flow/mnist.png -------------------------------------------------------------------------------- /asset/score-sde/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/score-sde/mnist.png -------------------------------------------------------------------------------- /asset/ddpm/fashion_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/ddpm/fashion_mnist.png 
-------------------------------------------------------------------------------- /asset/rect_flow/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/rect_flow/cifar10.png -------------------------------------------------------------------------------- /asset/score-sde/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/score-sde/cifar10.png -------------------------------------------------------------------------------- /asset/rect_flow/fashion_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/rect_flow/fashion_mnist.png -------------------------------------------------------------------------------- /asset/score-sde/fashion_mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/reppy4620/diffusion/HEAD/asset/score-sde/fashion_mnist.png -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cur=$(pwd) 4 | 5 | cd ddpm/scripts && ./run.sh 6 | cd $cur 7 | 8 | cd score-sde/scripts && ./run.sh 9 | cd $cur 10 | 11 | cd rect_flow/scripts && ./run.sh 12 | cd $cur 13 | -------------------------------------------------------------------------------- /cfm/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=(mnist fashion_mnist "huggan/AFHQv2") 4 | 5 | for ds in ${datasets[@]}; 6 | do 7 | config_path=../configs/$ds.yaml 8 | echo $config_path 9 | python ../src/train.py --config $config_path 10 | done 11 | -------------------------------------------------------------------------------- /ddpm/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=(mnist fashion_mnist cifar10 "huggan/AFHQv2") 4 | 5 | for ds in ${datasets[@]}; 6 | do 7 | config_path=../configs/$ds.yaml 8 | echo $config_path 9 | python ../src/train.py --config $config_path 10 | done 11 | -------------------------------------------------------------------------------- /consistency/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=("huggan/AFHQv2" mnist fashion_mnist cifar10) 4 | 5 | for ds in ${datasets[@]}; 6 | do 7 | config_path=../configs/$ds.yaml 8 | echo $config_path 9 | python ../src/train.py --config $config_path 10 | done 11 | -------------------------------------------------------------------------------- /rect_flow/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=(mnist fashion_mnist cifar10 "huggan/AFHQv2") 4 | 5 | for ds in ${datasets[@]}; 6 | do 7 | config_path=../configs/$ds.yaml 8 | echo $config_path 9 | python ../src/train.py --config $config_path 10 | done 11 | -------------------------------------------------------------------------------- /score-sde/scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=(mnist fashion_mnist cifar10 "huggan/AFHQv2") 4 | 5 | for ds in ${datasets[@]}; 6 | do 7 | 
config_path=../configs/$ds.yaml 8 | echo $config_path 9 | python ../src/train.py --config $config_path 10 | done 11 | -------------------------------------------------------------------------------- /cfm/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 50 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /ddpm/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /rect_flow/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /cfm/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: cifar10 4 | img_key: img 5 | img_convert: RGB 6 | img_size: 32 7 | 8 | epochs: 300 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2, 4] 16 | channels: 3 17 | -------------------------------------------------------------------------------- /cfm/configs/fashion_mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: fashion_mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 50 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /ddpm/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: cifar10 4 | img_key: img 5 | img_convert: RGB 6 | img_size: 32 7 | 8 | epochs: 100 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2, 4] 16 | channels: 3 17 | -------------------------------------------------------------------------------- /ddpm/configs/fashion_mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: fashion_mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /rect_flow/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: cifar10 4 | img_key: img 5 
| img_convert: RGB 6 | img_size: 32 7 | 8 | epochs: 100 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2, 4] 16 | channels: 3 17 | -------------------------------------------------------------------------------- /rect_flow/configs/fashion_mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: fashion_mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | model: 14 | dim: 32 15 | dim_mults: [1, 2] 16 | channels: 1 -------------------------------------------------------------------------------- /score-sde/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | sde_type: vp 14 | 15 | model: 16 | dim: 32 17 | dim_mults: [1, 2] 18 | channels: 1 -------------------------------------------------------------------------------- /score-sde/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: cifar10 4 | img_key: img 5 | img_convert: RGB 6 | img_size: 32 7 | 8 | epochs: 100 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | sde_type: vp 14 | 15 | model: 16 | dim: 32 17 | dim_mults: [1, 2, 4] 18 | channels: 3 19 | -------------------------------------------------------------------------------- /score-sde/configs/fashion_mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: fashion_mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 20 9 | image_interval: 10 10 | ckpt_interval: 25 11 | batch_size: 128 12 | 13 | sde_type: vp 14 | 15 | model: 16 | dim: 32 17 | dim_mults: [1, 2] 18 | channels: 1 -------------------------------------------------------------------------------- /cfm/configs/huggan/AFHQv2.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: huggan/AFHQv2 4 | img_key: image 5 | img_convert: RGB 6 | img_size: 128 7 | 8 | filter: cat 9 | 10 | epochs: 300 11 | image_interval: 10 12 | ckpt_interval: 50 13 | batch_size: 32 14 | 15 | model: 16 | dim: 32 17 | dim_mults: [1, 2, 4] 18 | channels: 3 19 | -------------------------------------------------------------------------------- /ddpm/configs/huggan/AFHQv2.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: huggan/AFHQv2 4 | img_key: image 5 | img_convert: RGB 6 | img_size: 128 7 | 8 | filter: cat 9 | 10 | epochs: 300 11 | image_interval: 10 12 | ckpt_interval: 50 13 | batch_size: 32 14 | 15 | model: 16 | dim: 32 17 | dim_mults: [1, 2, 4] 18 | channels: 3 19 | -------------------------------------------------------------------------------- /rect_flow/configs/huggan/AFHQv2.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: huggan/AFHQv2 4 | img_key: image 5 | img_convert: RGB 6 | img_size: 128 7 | 8 | filter: cat 9 | 10 | epochs: 300 11 | image_interval: 10 12 | ckpt_interval: 50 13 | batch_size: 32 14 | 15 | model: 16 | dim: 32 
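# base channel width of the Unet; the widths at the following resolutions are dim * dim_mults (see src/model.py)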
17 | dim_mults: [1, 2, 4] 18 | channels: 3 19 | -------------------------------------------------------------------------------- /score-sde/configs/huggan/AFHQv2.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: huggan/AFHQv2 4 | img_key: image 5 | img_convert: RGB 6 | img_size: 128 7 | 8 | filter: cat 9 | 10 | epochs: 300 11 | image_interval: 10 12 | ckpt_interval: 50 13 | batch_size: 32 14 | 15 | sde_type: vp 16 | 17 | model: 18 | dim: 32 19 | dim_mults: [1, 2, 4] 20 | channels: 3 21 | -------------------------------------------------------------------------------- /consistency/configs/mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 100 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | ema: 14 | mu: 0.9999 15 | 16 | model: 17 | dim: 32 18 | dim_mults: [1, 2, 4] 19 | channels: 1 20 | s_data: 0.5 21 | eps: 2e-3 22 | 23 | schedule: 24 | s_1: 150 25 | s_0: 2 26 | mu_0: 0.9 27 | 28 | fun: 29 | rho: 7.0 30 | eps: ${model.eps} 31 | T: 80.0 32 | -------------------------------------------------------------------------------- /consistency/configs/cifar10.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: cifar10 4 | img_key: img 5 | img_convert: RGB 6 | img_size: 32 7 | 8 | epochs: 200 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | ema: 14 | mu: 0.9999 15 | 16 | model: 17 | dim: 32 18 | dim_mults: [1, 2, 4] 19 | channels: 3 20 | s_data: 0.5 21 | eps: 2e-3 22 | 23 | schedule: 24 | s_1: 150 25 | s_0: 2 26 | mu_0: 0.9 27 | 28 | fun: 29 | rho: 7.0 30 | eps: ${model.eps} 31 | T: 80.0 32 | -------------------------------------------------------------------------------- /consistency/configs/fashion_mnist.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: fashion_mnist 4 | img_key: image 5 | img_convert: L 6 | img_size: 28 7 | 8 | epochs: 100 9 | image_interval: 10 10 | ckpt_interval: 50 11 | batch_size: 128 12 | 13 | ema: 14 | mu: 0.9999 15 | 16 | model: 17 | dim: 32 18 | dim_mults: [1, 2, 4] 19 | channels: 1 20 | s_data: 0.5 21 | eps: 2e-3 22 | 23 | schedule: 24 | s_1: 150 25 | s_0: 2 26 | mu_0: 0.9 27 | 28 | fun: 29 | rho: 7.0 30 | eps: ${model.eps} 31 | T: 80.0 32 | -------------------------------------------------------------------------------- /consistency/configs/huggan/AFHQv2.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ../out 2 | 3 | dataset: huggan/AFHQv2 4 | img_key: image 5 | img_convert: RGB 6 | img_size: 128 7 | 8 | filter: cat 9 | 10 | epochs: 300 11 | image_interval: 10 12 | ckpt_interval: 50 13 | batch_size: 16 14 | 15 | ema: 16 | mu: 0.9999 17 | 18 | model: 19 | dim: 32 20 | dim_mults: [1, 2, 4] 21 | channels: 3 22 | s_data: 0.5 23 | eps: 2e-3 24 | 25 | schedule: 26 | s_1: 150 27 | s_0: 2 28 | mu_0: 0.9 29 | 30 | fun: 31 | rho: 7.0 32 | eps: ${model.eps} 33 | T: 80.0 34 | -------------------------------------------------------------------------------- /ddpm/src/schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cosine_beta_schedule(timesteps, s=0.008): 5 | """ 6 | cosine schedule as proposed in 
https://arxiv.org/abs/2102.09672 7 | """ 8 | steps = timesteps + 1 9 | x = torch.linspace(0, timesteps, steps) 10 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 11 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 12 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 13 | return torch.clip(betas, 0.0001, 0.9999) 14 | 15 | def linear_beta_schedule(timesteps): 16 | beta_start = 0.0001 17 | beta_end = 0.02 18 | return torch.linspace(beta_start, beta_end, timesteps) 19 | 20 | def quadratic_beta_schedule(timesteps): 21 | beta_start = 0.0001 22 | beta_end = 0.02 23 | return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 24 | 25 | def sigmoid_beta_schedule(timesteps): 26 | beta_start = 0.0001 27 | beta_end = 0.02 28 | betas = torch.linspace(-6, 6, timesteps) 29 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | diffusion 2 | === 3 | 4 | Implementations of diffusion(-like) models for fun. 5 | 6 | - DDPM([https://arxiv.org/abs/2006.11239](https://arxiv.org/abs/2006.11239)) 7 | - Score-SDE([https://arxiv.org/abs/2011.13456](https://arxiv.org/abs/2011.13456)) 8 | - Rectified Flow([https://arxiv.org/abs/2209.03003](https://arxiv.org/abs/2209.03003)) 9 | - Conditional Flow Matching (simplified version)([https://arxiv.org/abs/2302.00482](https://arxiv.org/abs/2302.00482)) 10 | - Consistency Models([https://arxiv.org/abs/2303.01469](https://arxiv.org/abs/2303.01469)) 11 | 12 | # Requirements 13 | 14 | ```sh 15 | pip install torch torchvision einops numpy datasets tqdm scipy omegaconf 16 | ``` 17 | 18 | # Results 19 | 20 | - All the datasets are loaded with huggingface/datasets. 21 | 22 | - MNIST 23 | - 20 epochs 24 | - Unet with 1 up/down block and a mid block 25 | - FASHION MNIST 26 | - 20 epochs 27 | - Unet with 1 up/down block and a mid block 28 | - CIFAR10 29 | - 100 epochs 30 | - Unet with 2 up/down blocks and a mid block 31 | - AFHQ - CAT 32 | - AFHQ dataset filtered to the cat label 33 | - 300 epochs 34 | - Unet with 2 up/down blocks and a mid block 35 | 36 | | Model | MNIST | FASHION MNIST | CIFAR10 | CAT | 37 | | --- | --- | --- | --- | --- | 38 | | DDPM | ![MNIST](./asset/ddpm/mnist.png) | ![FASHION_MNIST](./asset/ddpm/fashion_mnist.png) | ![CIFAR10](./asset/ddpm/cifar10.png) | ![CAT](./asset/ddpm/cat.png) | 39 | | Score-SDE | ![MNIST](./asset/score-sde/mnist.png) | ![FASHION_MNIST](./asset/score-sde/fashion_mnist.png) | ![CIFAR10](./asset/score-sde/cifar10.png) | ![CAT](./asset/score-sde/cat.png) | 40 | | Rectified Flow | ![MNIST](./asset/rect_flow/mnist.png) | ![FASHION_MNIST](./asset/rect_flow/fashion_mnist.png) | ![CIFAR10](./asset/rect_flow/cifar10.png) | ![CAT](./asset/rect_flow/cat.png) | 41 | 42 | 43 | - Score-SDE is not stable with regard to color tones on the CAT dataset. 44 | - Since a previous experiment sometimes worked well, this may depend on the model size. 45 | - This phenomenon also occurs occasionally with DDPM.
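
# Usage

Each model directory follows a similar layout (`scripts/run.sh`, `configs/*.yaml`, `src/train.py`), and the top-level `run.sh` chains the DDPM, Score-SDE, and Rectified Flow scripts. As a minimal sketch of a single run (mirroring `ddpm/scripts/run.sh`, with paths relative to the `scripts` directory):

```sh
cd ddpm/scripts
python ../src/train.py --config ../configs/mnist.yaml
```

Sample grids and checkpoints are written under each model's `out/<dataset>/images` and `out/<dataset>/ckpt` directories, as set by `output_dir` in the YAML configs.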
46 | -------------------------------------------------------------------------------- /score-sde/src/sde.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SDE: 4 | def sde(self, x, t): 5 | pass 6 | 7 | def reverse_sde(self, score, x, t): 8 | drift, diffusion = self.sde(x, t) 9 | drift = drift - (diffusion ** 2)[:, None, None, None] * score 10 | return drift, diffusion 11 | 12 | def probability_flow(self, score, x, t): 13 | drift, diffusion = self.sde(x, t) 14 | drift = drift - 0.5 * (diffusion ** 2)[:, None, None, None] * score 15 | diffusion = torch.zeros_like(diffusion) 16 | return drift, diffusion 17 | 18 | class VESDE(SDE): 19 | def __init__(self, sigma_min=0.01, sigma_max=50.): 20 | self.sigma_min = sigma_min 21 | self.sigma_max = sigma_max 22 | 23 | def sde(self, x, t): 24 | drift = torch.zeros_like(x) 25 | sigma_t = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 26 | diffusion = sigma_t * torch.sqrt(2 * torch.log(torch.tensor(self.sigma_max / self.sigma_min))) 27 | return drift, diffusion 28 | 29 | def marginal_prob(self, x, t): 30 | mean = x 31 | std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 32 | return mean, std 33 | 34 | class VPSDE(SDE): 35 | def __init__(self, beta_min=0.1, beta_max=20): 36 | self.beta_0 = beta_min 37 | self.beta_1 = beta_max 38 | 39 | def sde(self, x, t): 40 | beta_t = self.beta_0 + (self.beta_1 - self.beta_0) * t 41 | drift = -0.5 * beta_t[:, None, None, None] * x 42 | diffusion = torch.sqrt(beta_t) 43 | return drift, diffusion 44 | 45 | def marginal_prob(self, x, t): 46 | beta_int = self.beta_0 * t + 0.5 * (self.beta_1 - self.beta_0) * t ** 2 47 | mean = torch.exp(-0.5 * beta_int)[:, None, None, None] * x 48 | std = torch.sqrt(1. - torch.exp(-beta_int)) 49 | return mean, std 50 | 51 | class SubVPSDE(SDE): 52 | def __init__(self, beta_min=0.1, beta_max=20): 53 | self.beta_0 = beta_min 54 | self.beta_1 = beta_max 55 | 56 | def sde(self, x, t): 57 | beta_t = self.beta_0 + (self.beta_1 - self.beta_0) * t 58 | drift = -0.5 * beta_t[:, None, None, None] * x 59 | beta_int = self.beta_0 * t + 0.5 * (self.beta_1 - self.beta_0) * t ** 2 60 | diffusion = torch.sqrt(beta_t * (1. - torch.exp(-2. * beta_int))) 61 | return drift, diffusion 62 | 63 | def marginal_prob(self, x, t): 64 | beta_int = self.beta_0 * t + 0.5 * (self.beta_1 - self.beta_0) * t ** 2 65 | log_mean_coeff = -0.5 * beta_int 66 | mean = torch.exp(log_mean_coeff)[:, None, None, None] * x 67 | std = 1. - torch.exp(2.
* log_mean_coeff) 68 | return mean, std 69 | 70 | -------------------------------------------------------------------------------- /rect_flow/src/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from datasets import load_dataset 6 | from torchvision import transforms as T 7 | from torchvision.utils import make_grid 8 | from torch.utils.data import DataLoader 9 | from torch.nn.utils import clip_grad_norm_ 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from scipy import integrate 13 | from argparse import ArgumentParser 14 | from omegaconf import OmegaConf 15 | 16 | from model import Unet 17 | 18 | eps = 1e-3 19 | 20 | @torch.no_grad() 21 | def sample_ode(model, image_size, batch_size=16, channels=1): 22 | shape = (batch_size, channels, image_size, image_size) 23 | device = next(model.parameters()).device 24 | 25 | b = shape[0] 26 | x = torch.randn(shape, device=device) 27 | 28 | def ode_func(t, x): 29 | x = torch.tensor(x, device=device, dtype=torch.float).reshape(shape) 30 | t = torch.full(size=(b,), fill_value=t, device=device, dtype=torch.float).reshape((b,)) 31 | v = model(x, t) 32 | return v.cpu().numpy().reshape((-1,)).astype(np.float64) 33 | 34 | res = integrate.solve_ivp(ode_func, (eps, 1.), x.reshape((-1,)).cpu().numpy(), method='RK45') 35 | x = torch.tensor(res.y[:, -1], device=device).reshape(shape) 36 | return x.clamp(-1, 1) 37 | 38 | def loss_fn(model, x_1, t): 39 | x_0 = torch.randn_like(x_1) 40 | x_t = t[:, None, None, None] * x_1 + (1 - t[:, None, None, None]) * x_0 41 | v = model(x_t, t) 42 | loss = F.mse_loss(x_1 - x_0, v) 43 | return loss 44 | 45 | def main(): 46 | parser = ArgumentParser() 47 | parser.add_argument('--config', type=str, required=True) 48 | args = parser.parse_args() 49 | 50 | config = OmegaConf.load(args.config) 51 | 52 | torch.manual_seed(42) 53 | 54 | output_dir = Path(f'{config.output_dir}/{config.dataset}') 55 | img_dir = output_dir / 'images' 56 | img_dir.mkdir(parents=True, exist_ok=True) 57 | ckpt_dir = output_dir / 'ckpt' 58 | ckpt_dir.mkdir(parents=True, exist_ok=True) 59 | 60 | if config.dataset.startswith('huggan'): 61 | transform = T.Compose([ 62 | T.Resize((config.img_size, config.img_size)), 63 | T.RandomHorizontalFlip(), 64 | T.ToTensor(), 65 | T.Lambda(lambda t: (t * 2) - 1) 66 | ]) 67 | else: 68 | transform = T.Compose([ 69 | T.RandomHorizontalFlip(), 70 | T.ToTensor(), 71 | T.Lambda(lambda t: (t * 2) - 1) 72 | ]) 73 | 74 | def transforms(examples): 75 | examples["pixel_values"] = [transform(image.convert(config.img_convert)) for image in examples[config.img_key]] 76 | del examples[config.img_key] 77 | 78 | return examples 79 | 80 | ds = load_dataset(config.dataset) 81 | if hasattr(config, 'filter'): 82 | if config.filter == 'cat': 83 | ds = ds.filter(lambda x: x["label"] == 0) 84 | transformed_ds = ds.with_transform(transforms).remove_columns("label") 85 | 86 | # create dataloader 87 | dl = DataLoader(transformed_ds["train"], batch_size=config.batch_size, shuffle=True, num_workers=16) 88 | 89 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 90 | model = Unet(**config.model).to(device) 91 | 92 | optimizer = optim.AdamW(model.parameters(), lr=1e-4) 93 | 94 | def handle_batch(batch): 95 | batch_size = batch["pixel_values"].shape[0] 96 | x = batch["pixel_values"].to(device) 97 | 98 | t = torch.empty(size=(batch_size,), device=device).uniform_(eps, 1) 99 
| loss = loss_fn(model, x, t) 100 | return loss 101 | 102 | train_losses = list() 103 | for epoch in range(1, config.epochs + 1): 104 | losses = list() 105 | bar = tqdm(dl, total=len(dl), desc=f'Epoch {epoch}: ') 106 | for batch in bar: 107 | optimizer.zero_grad() 108 | loss = handle_batch(batch) 109 | loss.backward() 110 | clip_grad_norm_(model.parameters(), 5.0) 111 | optimizer.step() 112 | losses.append(loss.item()) 113 | bar.set_postfix_str(f'Loss: {np.mean(losses):.6f}') 114 | train_losses.append(np.mean(losses)) 115 | if epoch % config.image_interval == 0: 116 | images = sample_ode(model, config.img_size, channels=config.model.channels) 117 | img = make_grid(images, nrow=4, normalize=True) 118 | img = T.ToPILImage()(img) 119 | img.save(img_dir / f'epoch_{epoch}.png') 120 | 121 | if epoch % config.ckpt_interval == 0: 122 | torch.save({ 123 | 'epoch': epoch, 124 | 'model': model.state_dict(), 125 | 'optimizer': optimizer.state_dict(), 126 | }, ckpt_dir / f'epoch_{epoch:05d}.pth') 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /cfm/src/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from datasets import load_dataset 6 | from torchvision import transforms as T 7 | from torchvision.utils import make_grid 8 | from torch.utils.data import DataLoader 9 | from torch.nn.utils import clip_grad_norm_ 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from scipy import integrate 13 | from argparse import ArgumentParser 14 | from omegaconf import OmegaConf 15 | 16 | from model import Unet 17 | 18 | eps = 1e-3 19 | sigma = 0.1 20 | 21 | @torch.no_grad() 22 | def sample_ode(model, image_size, batch_size=16, channels=1): 23 | shape = (batch_size, channels, image_size, image_size) 24 | device = next(model.parameters()).device 25 | 26 | b = shape[0] 27 | x = torch.randn(shape, device=device) 28 | 29 | def ode_func(t, x): 30 | x = torch.tensor(x, device=device, dtype=torch.float).reshape(shape) 31 | t = torch.full(size=(b,), fill_value=t, device=device, dtype=torch.float).reshape((b,)) 32 | v = model(x, t) 33 | return v.cpu().numpy().reshape((-1,)).astype(np.float64) 34 | 35 | res = integrate.solve_ivp(ode_func, (eps, 1.), x.reshape((-1,)).cpu().numpy(), method='RK45') 36 | x = torch.tensor(res.y[:, -1], device=device).reshape(shape) 37 | return x 38 | 39 | def loss_fn(model, x_1, t, sigma): 40 | x_0 = torch.randn_like(x_1) 41 | mu_t = t[:, None, None, None] * x_1 + (1 - t[:, None, None, None]) * x_0 42 | sigma_t = sigma 43 | x_t = mu_t + sigma_t * torch.randn_like(x_1) 44 | u_t = x_1 - x_0 45 | v_t = model(x_t, t) 46 | loss = F.mse_loss(u_t, v_t) 47 | return loss 48 | 49 | def main(): 50 | parser = ArgumentParser() 51 | parser.add_argument('--config', type=str, required=True) 52 | args = parser.parse_args() 53 | 54 | config = OmegaConf.load(args.config) 55 | 56 | torch.manual_seed(42) 57 | 58 | output_dir = Path(f'{config.output_dir}/{config.dataset}') 59 | img_dir = output_dir / 'images' 60 | img_dir.mkdir(parents=True, exist_ok=True) 61 | ckpt_dir = output_dir / 'ckpt' 62 | ckpt_dir.mkdir(parents=True, exist_ok=True) 63 | 64 | if config.dataset.startswith('huggan'): 65 | transform = T.Compose([ 66 | T.Resize((config.img_size, config.img_size)), 67 | T.RandomHorizontalFlip(), 68 | T.ToTensor(), 69 | T.Lambda(lambda t: (t * 2) - 1) 70 | ]) 71 
| else: 72 | transform = T.Compose([ 73 | T.RandomHorizontalFlip(), 74 | T.ToTensor(), 75 | T.Lambda(lambda t: (t * 2) - 1) 76 | ]) 77 | 78 | def transforms(examples): 79 | examples["pixel_values"] = [transform(image.convert(config.img_convert)) for image in examples[config.img_key]] 80 | del examples[config.img_key] 81 | 82 | return examples 83 | 84 | ds = load_dataset(config.dataset) 85 | if hasattr(config, 'filter'): 86 | if config.filter == 'cat': 87 | ds = ds.filter(lambda x: x["label"] == 0) 88 | transformed_ds = ds.with_transform(transforms).remove_columns("label") 89 | 90 | # create dataloader 91 | dl = DataLoader(transformed_ds["train"], batch_size=config.batch_size, shuffle=True, num_workers=16) 92 | 93 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 94 | model = Unet(**config.model).to(device) 95 | 96 | optimizer = optim.AdamW(model.parameters(), lr=1e-4) 97 | 98 | def handle_batch(batch): 99 | batch_size = batch["pixel_values"].shape[0] 100 | x = batch["pixel_values"].to(device) 101 | 102 | t = torch.empty(size=(batch_size,), device=device).uniform_(eps, 1) 103 | loss = loss_fn(model, x, t, sigma) 104 | return loss 105 | 106 | train_losses = list() 107 | for epoch in range(1, config.epochs + 1): 108 | losses = list() 109 | bar = tqdm(dl, total=len(dl), desc=f'Epoch {epoch}: ') 110 | for batch in bar: 111 | optimizer.zero_grad() 112 | loss = handle_batch(batch) 113 | loss.backward() 114 | clip_grad_norm_(model.parameters(), 5.0) 115 | optimizer.step() 116 | losses.append(loss.item()) 117 | bar.set_postfix_str(f'Loss: {np.mean(losses):.6f}') 118 | train_losses.append(np.mean(losses)) 119 | if epoch % config.image_interval == 0: 120 | images = sample_ode(model, config.img_size, channels=config.model.channels) 121 | img = make_grid(images, nrow=4, normalize=True) 122 | img = T.ToPILImage()(img) 123 | img.save(img_dir / f'epoch_{epoch}.png') 124 | 125 | if epoch % config.ckpt_interval == 0: 126 | torch.save({ 127 | 'epoch': epoch, 128 | 'model': model.state_dict(), 129 | 'optimizer': optimizer.state_dict(), 130 | }, ckpt_dir / f'epoch_{epoch:05d}.pth') 131 | 132 | 133 | if __name__ == '__main__': 134 | main() 135 | -------------------------------------------------------------------------------- /consistency/src/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from datasets import load_dataset 7 | from torchvision import transforms as T 8 | from torchvision.utils import make_grid 9 | from torch.utils.data import DataLoader 10 | from torch.nn.utils import clip_grad_norm_ 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | from argparse import ArgumentParser 14 | from omegaconf import OmegaConf 15 | 16 | from model import ConsistencyModel 17 | 18 | def t_schedule(rho, eps, N, T): 19 | # paper p.4 20 | return torch.tensor([ 21 | ( eps ** (1 / rho) + (i - 1) / (N - 1) * (T ** (1 / rho) - eps ** (1 / rho)) ) ** rho 22 | for i in range(N) 23 | ]) 24 | 25 | def main(): 26 | parser = ArgumentParser() 27 | parser.add_argument('--config', type=str, required=True) 28 | args = parser.parse_args() 29 | 30 | config = OmegaConf.load(args.config) 31 | 32 | torch.manual_seed(42) 33 | 34 | output_dir = Path(f'{config.output_dir}/{config.dataset}') 35 | img_dir = output_dir / 'images' 36 | img_dir.mkdir(parents=True, exist_ok=True) 37 | ckpt_dir = output_dir / 'ckpt' 38 | ckpt_dir.mkdir(parents=True, 
exist_ok=True) 39 | 40 | if config.dataset.startswith('huggan'): 41 | transform = T.Compose([ 42 | T.Resize((config.img_size, config.img_size)), 43 | T.RandomHorizontalFlip(), 44 | T.ToTensor(), 45 | T.Lambda(lambda t: (t * 2) - 1) 46 | ]) 47 | else: 48 | transform = T.Compose([ 49 | T.RandomHorizontalFlip(), 50 | T.ToTensor(), 51 | T.Lambda(lambda t: (t * 2) - 1) 52 | ]) 53 | 54 | def transforms(examples): 55 | examples["pixel_values"] = [transform(image.convert(config.img_convert)) for image in examples[config.img_key]] 56 | del examples[config.img_key] 57 | 58 | return examples 59 | 60 | ds = load_dataset(config.dataset) 61 | if hasattr(config, 'filter'): 62 | if config.filter == 'cat': 63 | ds = ds.filter(lambda x: x["label"] == 0) 64 | transformed_ds = ds.with_transform(transforms).remove_columns("label") 65 | 66 | # create dataloader 67 | dl = DataLoader(transformed_ds["train"], batch_size=config.batch_size, shuffle=True, num_workers=16) 68 | 69 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 70 | model = ConsistencyModel(**config.model).to(device) 71 | ema_model = ConsistencyModel(**config.model).to(device) 72 | ema_model.load_state_dict(model.state_dict()) 73 | 74 | optimizer = optim.AdamW(model.parameters(), lr=1e-4) 75 | 76 | train_losses = list() 77 | for epoch in range(1, config.epochs + 1): 78 | losses = list() 79 | bar = tqdm(dl, total=len(dl), desc=f'Epoch {epoch}: ') 80 | model.train() 81 | for i, batch in enumerate(bar): 82 | optimizer.zero_grad() 83 | 84 | batch_size = batch["pixel_values"].shape[0] 85 | x = batch["pixel_values"].to(device) 86 | 87 | # paper p.25 88 | N = math.ceil( 89 | math.sqrt( 90 | (epoch * len(dl) + i) / (config.epochs * len(dl)) * ((config.schedule.s_1 + 1) ** 2 - config.schedule.s_0 ** 2) 91 | + config.schedule.s_0 ** 2 92 | ) 93 | ) + 1 94 | t_boundaries = t_schedule(**config.schedule.fun, N=N).to(device) 95 | t_indices = torch.randint(low=0, high=N-1, size=(batch_size,)) 96 | t1 = t_boundaries[t_indices + 1] 97 | t2 = t_boundaries[t_indices] 98 | z = torch.randn_like(x) 99 | loss = model.loss(x, z, t1, t2, ema_model) 100 | 101 | loss.backward() 102 | clip_grad_norm_(model.parameters(), 1.0) 103 | optimizer.step() 104 | with torch.no_grad(): 105 | mu = math.exp(2 * math.log(config.schedule.mu_0) / N) 106 | for p, ema_p in zip(model.parameters(), ema_model.parameters()): 107 | ema_p.mul_(mu).add_(p, alpha=1 - mu) 108 | 109 | losses.append(loss.item()) 110 | bar.set_postfix_str(f'N: {N}, Loss: {np.mean(losses):.6f}') 111 | train_losses.append(np.mean(losses)) 112 | if epoch % config.image_interval == 0: 113 | model.eval() 114 | x = torch.randn(16, config.model.channels, config.img_size, config.img_size, device=device) * config.schedule.fun.T 115 | images = model.sample( 116 | x, 117 | ts=[80.0, 20.0, 5.0, 1.0] 118 | ) 119 | img = make_grid(images, nrow=4, normalize=True) 120 | img = T.ToPILImage()(img) 121 | img.save(img_dir / f'epoch_{epoch}.png') 122 | 123 | if epoch % config.ckpt_interval == 0: 124 | torch.save({ 125 | 'epoch': epoch, 126 | 'model': model.state_dict(), 127 | 'optimizer': optimizer.state_dict(), 128 | }, ckpt_dir / f'epoch_{epoch:05d}.pth') 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /score-sde/src/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | from datasets import load_dataset 5 | from 
torchvision import transforms as T 6 | from torchvision.utils import make_grid 7 | from torch.utils.data import DataLoader 8 | from torch.nn.utils import clip_grad_norm_ 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | from scipy import integrate 12 | from argparse import ArgumentParser 13 | from omegaconf import OmegaConf 14 | 15 | from model import Unet 16 | from sde import ( 17 | VESDE, 18 | VPSDE, 19 | SubVPSDE 20 | ) 21 | 22 | eps = 1e-3 23 | 24 | @torch.no_grad() 25 | def sample_ode(sde, model, image_size, batch_size=16, channels=1): 26 | shape = (batch_size, channels, image_size, image_size) 27 | device = next(model.parameters()).device 28 | 29 | b = shape[0] 30 | t_1 = torch.ones((b,), device=device) 31 | _, std = sde.marginal_prob(torch.zeros(shape, device=device), t_1) 32 | x = torch.randn(shape, device=device) * std[:, None, None, None] 33 | 34 | def ode_func(t, x): 35 | x = torch.tensor(x, device=device, dtype=torch.float).reshape(shape) 36 | t = torch.full(size=(b,), fill_value=t, device=device, dtype=torch.float).reshape((b,)) 37 | score = model(x, t) 38 | drift, _ = sde.probability_flow(score, x, t) 39 | return drift.cpu().numpy().reshape((-1,)).astype(np.float64) 40 | 41 | res = integrate.solve_ivp(ode_func, (1., eps), x.reshape((-1,)).cpu().numpy(), rtol=1e-5, atol=1e-5, method='RK45') 42 | x = torch.tensor(res.y[:, -1], device=device).reshape(shape) 43 | return x.clamp(-1, 1) 44 | 45 | # forward diffusion (using the nice property) 46 | def q_sample(sde, x_0, t, noise): 47 | mean, std = sde.marginal_prob(x_0, t) 48 | perturb_x = mean + std[:, None, None, None] * noise 49 | return perturb_x, mean, std 50 | 51 | def p_losses(sde, model, x_0, t): 52 | z = torch.randn_like(x_0) 53 | perturb_x, _, std = q_sample(sde=sde, x_0=x_0, t=t, noise=z) 54 | score = model(perturb_x, t) 55 | loss = torch.mean((score * std[:, None, None, None] + z) ** 2) 56 | return loss 57 | 58 | def main(): 59 | parser = ArgumentParser() 60 | parser.add_argument('--config', type=str, required=True) 61 | args = parser.parse_args() 62 | 63 | config = OmegaConf.load(args.config) 64 | 65 | torch.manual_seed(42) 66 | 67 | output_dir = Path(f'{config.output_dir}/{config.dataset}') 68 | img_dir = output_dir / 'images' 69 | img_dir.mkdir(parents=True, exist_ok=True) 70 | ckpt_dir = output_dir / 'ckpt' 71 | ckpt_dir.mkdir(parents=True, exist_ok=True) 72 | 73 | if config.dataset.startswith('huggan'): 74 | transform = T.Compose([ 75 | T.Resize((config.img_size, config.img_size)), 76 | T.RandomHorizontalFlip(), 77 | T.ToTensor(), 78 | T.Lambda(lambda t: (t * 2) - 1) 79 | ]) 80 | else: 81 | transform = T.Compose([ 82 | T.RandomHorizontalFlip(), 83 | T.ToTensor(), 84 | T.Lambda(lambda t: (t * 2) - 1) 85 | ]) 86 | 87 | def transforms(examples): 88 | examples["pixel_values"] = [transform(image.convert(config.img_convert)) for image in examples[config.img_key]] 89 | del examples[config.img_key] 90 | 91 | return examples 92 | 93 | ds = load_dataset(config.dataset) 94 | if hasattr(config, 'filter'): 95 | if config.filter == 'cat': 96 | ds = ds.filter(lambda x: x["label"] == 0) 97 | transformed_ds = ds.with_transform(transforms).remove_columns("label") 98 | 99 | # create dataloader 100 | dl = DataLoader(transformed_ds["train"], batch_size=config.batch_size, shuffle=True, num_workers=16) 101 | 102 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 103 | model = Unet(**config.model).to(device) 104 | 105 | if config.sde_type == 've': 106 | sde = VESDE() 107 | elif config.sde_type == 'vp': 108 | 
sde = VPSDE() 109 | elif config.sde_type == 'subvp': 110 | sde = SubVPSDE() 111 | else: 112 | raise NotImplementedError() 113 | 114 | optimizer = optim.Adam(model.parameters(), lr=1e-4) 115 | 116 | def handle_batch(batch): 117 | batch_size = batch["pixel_values"].shape[0] 118 | x = batch["pixel_values"].to(device) 119 | 120 | t = torch.empty(size=(batch_size,), device=device).uniform_(eps, 1) 121 | loss = p_losses(sde, model, x, t) 122 | return loss 123 | 124 | train_losses = list() 125 | for epoch in range(1, config.epochs + 1): 126 | losses = list() 127 | bar = tqdm(dl, total=len(dl), desc=f'Epoch {epoch}: ') 128 | for batch in bar: 129 | optimizer.zero_grad() 130 | loss = handle_batch(batch) 131 | loss.backward() 132 | clip_grad_norm_(model.parameters(), 5.0) 133 | optimizer.step() 134 | losses.append(loss.item()) 135 | bar.set_postfix_str(f'Loss: {np.mean(losses):.6f}') 136 | train_losses.append(np.mean(losses)) 137 | if epoch % config.image_interval == 0: 138 | images = sample_ode(sde, model, config.img_size, channels=config.model.channels) 139 | img = make_grid(images, nrow=4, normalize=True) 140 | img = T.ToPILImage()(img) 141 | img.save(img_dir / f'epoch_{epoch}.png') 142 | 143 | if epoch % config.ckpt_interval == 0: 144 | torch.save({ 145 | 'epoch': epoch, 146 | 'model': model.state_dict(), 147 | 'optimizer': optimizer.state_dict(), 148 | }, ckpt_dir / f'epoch_{epoch:05d}.pth') 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /ddpm/src/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from datasets import load_dataset 6 | from torchvision import transforms as T 7 | from torchvision.utils import make_grid 8 | from torch.utils.data import DataLoader 9 | from torch.nn.utils import clip_grad_norm_ 10 | from tqdm import tqdm 11 | from argparse import ArgumentParser 12 | from pathlib import Path 13 | from omegaconf import OmegaConf 14 | 15 | from schedule import linear_beta_schedule 16 | from model import Unet 17 | 18 | timesteps = 1000 19 | 20 | # define beta schedule 21 | betas = linear_beta_schedule(timesteps=timesteps) 22 | 23 | # define alphas 24 | alphas = 1. - betas 25 | alphas_cumprod = torch.cumprod(alphas, axis=0) 26 | alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) 27 | sqrt_recip_alphas = torch.sqrt(1.0 / alphas) 28 | 29 | # calculations for diffusion q(x_t | x_{t-1}) and others 30 | sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 31 | sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) 32 | 33 | # calculations for posterior q(x_{t-1} | x_t, x_0) 34 | posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. 
- alphas_cumprod) 35 | 36 | def extract(a, t, x_shape): 37 | batch_size = t.shape[0] 38 | out = a.gather(-1, t.cpu()) 39 | return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) 40 | 41 | def q_sample(x_start, t, noise=None): 42 | if noise is None: 43 | noise = torch.randn_like(x_start) 44 | 45 | sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) 46 | sqrt_one_minus_alphas_cumprod_t = extract( 47 | sqrt_one_minus_alphas_cumprod, t, x_start.shape 48 | ) 49 | 50 | return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise 51 | 52 | @torch.no_grad() 53 | def p_sample(model, x, t, t_index): 54 | betas_t = extract(betas, t, x.shape) 55 | sqrt_one_minus_alphas_cumprod_t = extract( 56 | sqrt_one_minus_alphas_cumprod, t, x.shape 57 | ) 58 | sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape) 59 | 60 | model_mean = sqrt_recip_alphas_t * ( 61 | x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t 62 | ) 63 | 64 | if t_index == 0: 65 | return model_mean 66 | else: 67 | posterior_variance_t = extract(posterior_variance, t, x.shape) 68 | noise = torch.randn_like(x) 69 | # Algorithm 2 line 4: 70 | return model_mean + torch.sqrt(posterior_variance_t) * noise 71 | 72 | # Algorithm 2 (including returning all images) 73 | @torch.no_grad() 74 | def p_sample_loop(model, shape): 75 | device = next(model.parameters()).device 76 | 77 | b = shape[0] 78 | # start from pure noise (for each example in the batch) 79 | img = torch.randn(shape, device=device) 80 | 81 | for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps): 82 | t = torch.full((b,), i, device=device, dtype=torch.long) 83 | img = p_sample(model, img, t, i) 84 | return img.clamp(-1, 1) 85 | 86 | @torch.no_grad() 87 | def sample(model, image_size, batch_size=16, channels=3): 88 | return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size)) 89 | 90 | def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"): 91 | if noise is None: 92 | noise = torch.randn_like(x_start) 93 | 94 | x_noisy = q_sample(x_start=x_start, t=t, noise=noise) 95 | predicted_noise = denoise_model(x_noisy, t) 96 | 97 | if loss_type == 'l1': 98 | loss = F.l1_loss(noise, predicted_noise) 99 | elif loss_type == 'l2': 100 | loss = F.mse_loss(noise, predicted_noise) 101 | elif loss_type == "huber": 102 | loss = F.smooth_l1_loss(noise, predicted_noise) 103 | else: 104 | raise NotImplementedError() 105 | 106 | return loss 107 | 108 | def main(): 109 | parser = ArgumentParser() 110 | parser.add_argument('--config', type=str, required=True) 111 | args = parser.parse_args() 112 | 113 | config = OmegaConf.load(args.config) 114 | 115 | torch.manual_seed(42) 116 | 117 | output_dir = Path(f'{config.output_dir}/{config.dataset}') 118 | img_dir = output_dir / 'images' 119 | img_dir.mkdir(parents=True, exist_ok=True) 120 | ckpt_dir = output_dir / 'ckpt' 121 | ckpt_dir.mkdir(parents=True, exist_ok=True) 122 | 123 | if config.dataset.startswith('huggan'): 124 | transform = T.Compose([ 125 | T.Resize((config.img_size, config.img_size)), 126 | T.RandomHorizontalFlip(), 127 | T.ToTensor(), 128 | T.Lambda(lambda t: (t * 2) - 1) 129 | ]) 130 | else: 131 | transform = T.Compose([ 132 | T.RandomHorizontalFlip(), 133 | T.ToTensor(), 134 | T.Lambda(lambda t: (t * 2) - 1) 135 | ]) 136 | 137 | def transforms(examples): 138 | examples["pixel_values"] = [transform(image.convert(config.img_convert)) for image in examples[config.img_key]] 139 | del examples[config.img_key] 
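# drop the raw image column; the dataloader only needs the transformed "pixel_values"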
140 | 141 | return examples 142 | 143 | ds = load_dataset(config.dataset) 144 | if hasattr(config, 'filter'): 145 | if config.filter == 'cat': 146 | ds = ds.filter(lambda x: x["label"] == 0) 147 | transformed_ds = ds.with_transform(transforms).remove_columns("label") 148 | 149 | # create dataloader 150 | dl = DataLoader(transformed_ds["train"], batch_size=config.batch_size, shuffle=True, num_workers=16) 151 | 152 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 153 | model = Unet(**config.model).to(device) 154 | 155 | optimizer = optim.AdamW(model.parameters(), lr=1e-3) 156 | 157 | def handle_batch(batch): 158 | batch_size = batch["pixel_values"].shape[0] 159 | batch = batch["pixel_values"].to(device) 160 | 161 | t = torch.randint(0, timesteps, (batch_size,), device=device).long() 162 | loss = p_losses(model, batch, t, loss_type='l2') 163 | return loss 164 | 165 | train_losses = list() 166 | for epoch in range(1, config.epochs + 1): 167 | losses = list() 168 | bar = tqdm(dl, total=len(dl), desc=f'Epoch {epoch}: ') 169 | for batch in bar: 170 | optimizer.zero_grad() 171 | loss = handle_batch(batch) 172 | loss.backward() 173 | clip_grad_norm_(model.parameters(), 5.0) 174 | optimizer.step() 175 | losses.append(loss.item()) 176 | bar.set_postfix_str(f'Loss: {np.mean(losses):.6f}') 177 | train_losses.append(np.mean(losses)) 178 | if epoch % config.image_interval == 0: 179 | images = sample(model, config.img_size, channels=config.model.channels) 180 | img = make_grid(images, nrow=4, normalize=True) 181 | img = T.ToPILImage()(img) 182 | img.save(img_dir / f'epoch_{epoch}.png') 183 | 184 | if epoch % config.ckpt_interval == 0: 185 | torch.save({ 186 | 'epoch': epoch, 187 | 'model': model.state_dict(), 188 | 'optimizer': optimizer.state_dict(), 189 | }, ckpt_dir / f'epoch_{epoch:05d}.pth') 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /ddpm/src/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from functools import partial 6 | from inspect import isfunction 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | def num_to_groups(num, divisor): 22 | groups = num // divisor 23 | remainder = num % divisor 24 | arr = [divisor] * groups 25 | if remainder > 0: 26 | arr.append(remainder) 27 | return arr 28 | 29 | 30 | class Residual(nn.Module): 31 | def __init__(self, fn): 32 | super().__init__() 33 | self.fn = fn 34 | 35 | def forward(self, x, *args, **kwargs): 36 | return self.fn(x, *args, **kwargs) + x 37 | 38 | 39 | def Upsample(dim, dim_out=None): 40 | return nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode="nearest"), 42 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), 43 | ) 44 | 45 | 46 | def Downsample(dim, dim_out=None): 47 | # No More Strided Convolutions or Pooling 48 | return nn.Sequential( 49 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 50 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 51 | ) 52 | 53 | class SinusoidalPositionEmbeddings(nn.Module): 54 | def __init__(self, dim): 55 | super().__init__() 56 | self.dim = dim 57 | 58 | def forward(self, time): 59 | device = time.device 60 | 
half_dim = self.dim // 2 61 | embeddings = math.log(10000) / (half_dim - 1) 62 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 63 | embeddings = time[:, None] * embeddings[None, :] 64 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 65 | return embeddings 66 | 67 | class WeightStandardizedConv2d(nn.Conv2d): 68 | """ 69 | https://arxiv.org/abs/1903.10520 70 | weight standardization purportedly works synergistically with group normalization 71 | """ 72 | 73 | def forward(self, x): 74 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 75 | 76 | weight = self.weight 77 | mean = reduce(weight, "o ... -> o 1 1 1", "mean") 78 | var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False)) 79 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 80 | 81 | return F.conv2d( 82 | x, 83 | normalized_weight, 84 | self.bias, 85 | self.stride, 86 | self.padding, 87 | self.dilation, 88 | self.groups, 89 | ) 90 | 91 | 92 | class Block(nn.Module): 93 | def __init__(self, dim, dim_out, groups=8): 94 | super().__init__() 95 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 96 | self.norm = nn.GroupNorm(groups, dim_out) 97 | self.act = nn.SiLU() 98 | 99 | def forward(self, x, scale_shift=None): 100 | x = self.proj(x) 101 | x = self.norm(x) 102 | 103 | if exists(scale_shift): 104 | scale, shift = scale_shift 105 | x = x * (scale + 1) + shift 106 | 107 | x = self.act(x) 108 | return x 109 | 110 | 111 | class ResnetBlock(nn.Module): 112 | """https://arxiv.org/abs/1512.03385""" 113 | 114 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 115 | super().__init__() 116 | self.mlp = ( 117 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) 118 | if exists(time_emb_dim) 119 | else None 120 | ) 121 | 122 | self.block1 = Block(dim, dim_out, groups=groups) 123 | self.block2 = Block(dim_out, dim_out, groups=groups) 124 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 125 | 126 | def forward(self, x, time_emb=None): 127 | scale_shift = None 128 | if exists(self.mlp) and exists(time_emb): 129 | time_emb = self.mlp(time_emb) 130 | time_emb = rearrange(time_emb, "b c -> b c 1 1") 131 | scale_shift = time_emb.chunk(2, dim=1) 132 | 133 | h = self.block1(x, scale_shift=scale_shift) 134 | h = self.block2(h) 135 | return h + self.res_conv(x) 136 | 137 | class Attention(nn.Module): 138 | def __init__(self, dim, heads=4, dim_head=32): 139 | super().__init__() 140 | self.scale = dim_head**-0.5 141 | self.heads = heads 142 | hidden_dim = dim_head * heads 143 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 144 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 145 | 146 | def forward(self, x): 147 | b, c, h, w = x.shape 148 | qkv = self.to_qkv(x).chunk(3, dim=1) 149 | q, k, v = map( 150 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 151 | ) 152 | q = q * self.scale 153 | 154 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 155 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 156 | attn = sim.softmax(dim=-1) 157 | 158 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 159 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 160 | return self.to_out(out) 161 | 162 | class LinearAttention(nn.Module): 163 | def __init__(self, dim, heads=4, dim_head=32): 164 | super().__init__() 165 | self.scale = dim_head**-0.5 166 | self.heads = heads 167 | hidden_dim = dim_head * heads 168 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, 
bias=False) 169 | 170 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 171 | nn.GroupNorm(1, dim)) 172 | 173 | def forward(self, x): 174 | b, c, h, w = x.shape 175 | qkv = self.to_qkv(x).chunk(3, dim=1) 176 | q, k, v = map( 177 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 178 | ) 179 | 180 | q = q.softmax(dim=-2) 181 | k = k.softmax(dim=-1) 182 | 183 | q = q * self.scale 184 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 185 | 186 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 187 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 188 | return self.to_out(out) 189 | 190 | class PreNorm(nn.Module): 191 | def __init__(self, dim, fn): 192 | super().__init__() 193 | self.fn = fn 194 | self.norm = nn.GroupNorm(1, dim) 195 | 196 | def forward(self, x): 197 | x = self.norm(x) 198 | return self.fn(x) 199 | 200 | class Unet(nn.Module): 201 | def __init__( 202 | self, 203 | dim, 204 | init_dim=None, 205 | out_dim=None, 206 | dim_mults=(1, 2, 4, 8), 207 | channels=3, 208 | self_condition=False, 209 | resnet_block_groups=4, 210 | ): 211 | super().__init__() 212 | 213 | # determine dimensions 214 | self.channels = channels 215 | self.self_condition = self_condition 216 | input_channels = channels * (2 if self_condition else 1) 217 | 218 | init_dim = default(init_dim, dim) 219 | self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3 220 | 221 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 222 | in_out = list(zip(dims[:-1], dims[1:])) 223 | 224 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 225 | 226 | # time embeddings 227 | time_dim = dim * 4 228 | 229 | self.time_mlp = nn.Sequential( 230 | SinusoidalPositionEmbeddings(dim), 231 | nn.Linear(dim, time_dim), 232 | nn.GELU(), 233 | nn.Linear(time_dim, time_dim), 234 | ) 235 | 236 | # layers 237 | self.downs = nn.ModuleList([]) 238 | self.ups = nn.ModuleList([]) 239 | num_resolutions = len(in_out) 240 | 241 | for ind, (dim_in, dim_out) in enumerate(in_out): 242 | is_last = ind >= (num_resolutions - 1) 243 | 244 | self.downs.append( 245 | nn.ModuleList( 246 | [ 247 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 248 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 249 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 250 | Downsample(dim_in, dim_out) 251 | if not is_last 252 | else nn.Conv2d(dim_in, dim_out, 3, padding=1), 253 | ] 254 | ) 255 | ) 256 | 257 | mid_dim = dims[-1] 258 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 259 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 260 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 261 | 262 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 263 | is_last = ind == (len(in_out) - 1) 264 | 265 | self.ups.append( 266 | nn.ModuleList( 267 | [ 268 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 269 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 270 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 271 | Upsample(dim_out, dim_in) 272 | if not is_last 273 | else nn.Conv2d(dim_out, dim_in, 3, padding=1), 274 | ] 275 | ) 276 | ) 277 | 278 | self.out_dim = default(out_dim, channels) 279 | 280 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 281 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 282 | 283 | def forward(self, x, time, x_self_cond=None): 284 | if self.self_condition: 285 | x_self_cond 
= default(x_self_cond, lambda: torch.zeros_like(x)) 286 | x = torch.cat((x_self_cond, x), dim=1) 287 | 288 | x = self.init_conv(x) 289 | r = x.clone() 290 | 291 | t = self.time_mlp(time) 292 | 293 | h = [] 294 | 295 | for block1, block2, attn, downsample in self.downs: 296 | x = block1(x, t) 297 | h.append(x) 298 | 299 | x = block2(x, t) 300 | x = attn(x) 301 | h.append(x) 302 | 303 | x = downsample(x) 304 | 305 | x = self.mid_block1(x, t) 306 | x = self.mid_attn(x) 307 | x = self.mid_block2(x, t) 308 | 309 | for block1, block2, attn, upsample in self.ups: 310 | x = torch.cat((x, h.pop()), dim=1) 311 | x = block1(x, t) 312 | 313 | x = torch.cat((x, h.pop()), dim=1) 314 | x = block2(x, t) 315 | x = attn(x) 316 | 317 | x = upsample(x) 318 | 319 | x = torch.cat((x, r), dim=1) 320 | 321 | x = self.final_res_block(x, t) 322 | return self.final_conv(x) 323 | -------------------------------------------------------------------------------- /cfm/src/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from functools import partial 6 | from inspect import isfunction 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | def num_to_groups(num, divisor): 22 | groups = num // divisor 23 | remainder = num % divisor 24 | arr = [divisor] * groups 25 | if remainder > 0: 26 | arr.append(remainder) 27 | return arr 28 | 29 | 30 | class Residual(nn.Module): 31 | def __init__(self, fn): 32 | super().__init__() 33 | self.fn = fn 34 | 35 | def forward(self, x, *args, **kwargs): 36 | return self.fn(x, *args, **kwargs) + x 37 | 38 | 39 | def Upsample(dim, dim_out=None): 40 | return nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode="nearest"), 42 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), 43 | ) 44 | 45 | 46 | def Downsample(dim, dim_out=None): 47 | # No More Strided Convolutions or Pooling 48 | return nn.Sequential( 49 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 50 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 51 | ) 52 | 53 | class SinusoidalPositionEmbeddings(nn.Module): 54 | def __init__(self, dim, scale=1000): 55 | super().__init__() 56 | self.dim = dim 57 | self.scale = scale 58 | 59 | def forward(self, time): 60 | device = time.device 61 | half_dim = self.dim // 2 62 | embeddings = math.log(10000) / (half_dim - 1) 63 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 64 | embeddings = self.scale * time[:, None] * embeddings[None, :] 65 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 66 | return embeddings 67 | 68 | class WeightStandardizedConv2d(nn.Conv2d): 69 | """ 70 | https://arxiv.org/abs/1903.10520 71 | weight standardization purportedly works synergistically with group normalization 72 | """ 73 | 74 | def forward(self, x): 75 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 76 | 77 | weight = self.weight 78 | mean = reduce(weight, "o ... -> o 1 1 1", "mean") 79 | var = reduce(weight, "o ... 
-> o 1 1 1", partial(torch.var, unbiased=False)) 80 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 81 | 82 | return F.conv2d( 83 | x, 84 | normalized_weight, 85 | self.bias, 86 | self.stride, 87 | self.padding, 88 | self.dilation, 89 | self.groups, 90 | ) 91 | 92 | 93 | class Block(nn.Module): 94 | def __init__(self, dim, dim_out, groups=8): 95 | super().__init__() 96 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 97 | self.norm = nn.GroupNorm(groups, dim_out) 98 | self.act = nn.SiLU() 99 | 100 | def forward(self, x, scale_shift=None): 101 | x = self.proj(x) 102 | x = self.norm(x) 103 | 104 | if exists(scale_shift): 105 | scale, shift = scale_shift 106 | x = x * (scale + 1) + shift 107 | 108 | x = self.act(x) 109 | return x 110 | 111 | 112 | class ResnetBlock(nn.Module): 113 | """https://arxiv.org/abs/1512.03385""" 114 | 115 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 116 | super().__init__() 117 | self.mlp = ( 118 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) 119 | if exists(time_emb_dim) 120 | else None 121 | ) 122 | 123 | self.block1 = Block(dim, dim_out, groups=groups) 124 | self.block2 = Block(dim_out, dim_out, groups=groups) 125 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 126 | 127 | def forward(self, x, time_emb=None): 128 | scale_shift = None 129 | if exists(self.mlp) and exists(time_emb): 130 | time_emb = self.mlp(time_emb) 131 | time_emb = rearrange(time_emb, "b c -> b c 1 1") 132 | scale_shift = time_emb.chunk(2, dim=1) 133 | 134 | h = self.block1(x, scale_shift=scale_shift) 135 | h = self.block2(h) 136 | return h + self.res_conv(x) 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, dim, heads=4, dim_head=32): 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 146 | 147 | def forward(self, x): 148 | b, c, h, w = x.shape 149 | qkv = self.to_qkv(x).chunk(3, dim=1) 150 | q, k, v = map( 151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 152 | ) 153 | q = q * self.scale 154 | 155 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 157 | attn = sim.softmax(dim=-1) 158 | 159 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 161 | return self.to_out(out) 162 | 163 | class LinearAttention(nn.Module): 164 | def __init__(self, dim, heads=4, dim_head=32): 165 | super().__init__() 166 | self.scale = dim_head**-0.5 167 | self.heads = heads 168 | hidden_dim = dim_head * heads 169 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 170 | 171 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 172 | nn.GroupNorm(1, dim)) 173 | 174 | def forward(self, x): 175 | b, c, h, w = x.shape 176 | qkv = self.to_qkv(x).chunk(3, dim=1) 177 | q, k, v = map( 178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 179 | ) 180 | 181 | q = q.softmax(dim=-2) 182 | k = k.softmax(dim=-1) 183 | 184 | q = q * self.scale 185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 186 | 187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 189 | return self.to_out(out) 190 | 191 | class PreNorm(nn.Module): 192 | def 
__init__(self, dim, fn): 193 | super().__init__() 194 | self.fn = fn 195 | self.norm = nn.GroupNorm(1, dim) 196 | 197 | def forward(self, x): 198 | x = self.norm(x) 199 | return self.fn(x) 200 | 201 | class Unet(nn.Module): 202 | def __init__( 203 | self, 204 | dim, 205 | init_dim=None, 206 | out_dim=None, 207 | dim_mults=(1, 2, 4, 8), 208 | channels=3, 209 | self_condition=False, 210 | resnet_block_groups=4, 211 | ): 212 | super().__init__() 213 | 214 | # determine dimensions 215 | self.channels = channels 216 | self.self_condition = self_condition 217 | input_channels = channels * (2 if self_condition else 1) 218 | 219 | init_dim = default(init_dim, dim) 220 | self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3 221 | 222 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 223 | in_out = list(zip(dims[:-1], dims[1:])) 224 | 225 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 226 | 227 | # time embeddings 228 | time_dim = dim * 4 229 | 230 | self.time_mlp = nn.Sequential( 231 | SinusoidalPositionEmbeddings(dim), 232 | nn.Linear(dim, time_dim), 233 | nn.GELU(), 234 | nn.Linear(time_dim, time_dim), 235 | ) 236 | 237 | # layers 238 | self.downs = nn.ModuleList([]) 239 | self.ups = nn.ModuleList([]) 240 | num_resolutions = len(in_out) 241 | 242 | for ind, (dim_in, dim_out) in enumerate(in_out): 243 | is_last = ind >= (num_resolutions - 1) 244 | 245 | self.downs.append( 246 | nn.ModuleList( 247 | [ 248 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 249 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 250 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 251 | Downsample(dim_in, dim_out) 252 | if not is_last 253 | else nn.Conv2d(dim_in, dim_out, 3, padding=1), 254 | ] 255 | ) 256 | ) 257 | 258 | mid_dim = dims[-1] 259 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 260 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 261 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 262 | 263 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 264 | is_last = ind == (len(in_out) - 1) 265 | 266 | self.ups.append( 267 | nn.ModuleList( 268 | [ 269 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 270 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 271 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 272 | Upsample(dim_out, dim_in) 273 | if not is_last 274 | else nn.Conv2d(dim_out, dim_in, 3, padding=1), 275 | ] 276 | ) 277 | ) 278 | 279 | self.out_dim = default(out_dim, channels) 280 | 281 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 282 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 283 | 284 | def forward(self, x, time, x_self_cond=None): 285 | if self.self_condition: 286 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 287 | x = torch.cat((x_self_cond, x), dim=1) 288 | 289 | x = self.init_conv(x) 290 | r = x.clone() 291 | 292 | t = self.time_mlp(time) 293 | 294 | h = [] 295 | 296 | for block1, block2, attn, downsample in self.downs: 297 | x = block1(x, t) 298 | h.append(x) 299 | 300 | x = block2(x, t) 301 | x = attn(x) 302 | h.append(x) 303 | 304 | x = downsample(x) 305 | 306 | x = self.mid_block1(x, t) 307 | x = self.mid_attn(x) 308 | x = self.mid_block2(x, t) 309 | 310 | for block1, block2, attn, upsample in self.ups: 311 | x = torch.cat((x, h.pop()), dim=1) 312 | x = block1(x, t) 313 | 314 | x = torch.cat((x, h.pop()), dim=1) 315 | x = block2(x, t) 316 | x = 
attn(x) 317 | 318 | x = upsample(x) 319 | 320 | x = torch.cat((x, r), dim=1) 321 | 322 | x = self.final_res_block(x, t) 323 | return self.final_conv(x) 324 | -------------------------------------------------------------------------------- /rect_flow/src/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from functools import partial 6 | from inspect import isfunction 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | def num_to_groups(num, divisor): 22 | groups = num // divisor 23 | remainder = num % divisor 24 | arr = [divisor] * groups 25 | if remainder > 0: 26 | arr.append(remainder) 27 | return arr 28 | 29 | 30 | class Residual(nn.Module): 31 | def __init__(self, fn): 32 | super().__init__() 33 | self.fn = fn 34 | 35 | def forward(self, x, *args, **kwargs): 36 | return self.fn(x, *args, **kwargs) + x 37 | 38 | 39 | def Upsample(dim, dim_out=None): 40 | return nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode="nearest"), 42 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), 43 | ) 44 | 45 | 46 | def Downsample(dim, dim_out=None): 47 | # No More Strided Convolutions or Pooling 48 | return nn.Sequential( 49 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 50 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 51 | ) 52 | 53 | class SinusoidalPositionEmbeddings(nn.Module): 54 | def __init__(self, dim, scale=1000): 55 | super().__init__() 56 | self.dim = dim 57 | self.scale = scale 58 | 59 | def forward(self, time): 60 | device = time.device 61 | half_dim = self.dim // 2 62 | embeddings = math.log(10000) / (half_dim - 1) 63 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 64 | embeddings = self.scale * time[:, None] * embeddings[None, :] 65 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 66 | return embeddings 67 | 68 | class WeightStandardizedConv2d(nn.Conv2d): 69 | """ 70 | https://arxiv.org/abs/1903.10520 71 | weight standardization purportedly works synergistically with group normalization 72 | """ 73 | 74 | def forward(self, x): 75 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 76 | 77 | weight = self.weight 78 | mean = reduce(weight, "o ... -> o 1 1 1", "mean") 79 | var = reduce(weight, "o ... 
-> o 1 1 1", partial(torch.var, unbiased=False)) 80 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 81 | 82 | return F.conv2d( 83 | x, 84 | normalized_weight, 85 | self.bias, 86 | self.stride, 87 | self.padding, 88 | self.dilation, 89 | self.groups, 90 | ) 91 | 92 | 93 | class Block(nn.Module): 94 | def __init__(self, dim, dim_out, groups=8): 95 | super().__init__() 96 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 97 | self.norm = nn.GroupNorm(groups, dim_out) 98 | self.act = nn.SiLU() 99 | 100 | def forward(self, x, scale_shift=None): 101 | x = self.proj(x) 102 | x = self.norm(x) 103 | 104 | if exists(scale_shift): 105 | scale, shift = scale_shift 106 | x = x * (scale + 1) + shift 107 | 108 | x = self.act(x) 109 | return x 110 | 111 | 112 | class ResnetBlock(nn.Module): 113 | """https://arxiv.org/abs/1512.03385""" 114 | 115 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 116 | super().__init__() 117 | self.mlp = ( 118 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) 119 | if exists(time_emb_dim) 120 | else None 121 | ) 122 | 123 | self.block1 = Block(dim, dim_out, groups=groups) 124 | self.block2 = Block(dim_out, dim_out, groups=groups) 125 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 126 | 127 | def forward(self, x, time_emb=None): 128 | scale_shift = None 129 | if exists(self.mlp) and exists(time_emb): 130 | time_emb = self.mlp(time_emb) 131 | time_emb = rearrange(time_emb, "b c -> b c 1 1") 132 | scale_shift = time_emb.chunk(2, dim=1) 133 | 134 | h = self.block1(x, scale_shift=scale_shift) 135 | h = self.block2(h) 136 | return h + self.res_conv(x) 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, dim, heads=4, dim_head=32): 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 146 | 147 | def forward(self, x): 148 | b, c, h, w = x.shape 149 | qkv = self.to_qkv(x).chunk(3, dim=1) 150 | q, k, v = map( 151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 152 | ) 153 | q = q * self.scale 154 | 155 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 157 | attn = sim.softmax(dim=-1) 158 | 159 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 161 | return self.to_out(out) 162 | 163 | class LinearAttention(nn.Module): 164 | def __init__(self, dim, heads=4, dim_head=32): 165 | super().__init__() 166 | self.scale = dim_head**-0.5 167 | self.heads = heads 168 | hidden_dim = dim_head * heads 169 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 170 | 171 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 172 | nn.GroupNorm(1, dim)) 173 | 174 | def forward(self, x): 175 | b, c, h, w = x.shape 176 | qkv = self.to_qkv(x).chunk(3, dim=1) 177 | q, k, v = map( 178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 179 | ) 180 | 181 | q = q.softmax(dim=-2) 182 | k = k.softmax(dim=-1) 183 | 184 | q = q * self.scale 185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 186 | 187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 189 | return self.to_out(out) 190 | 191 | class PreNorm(nn.Module): 192 | def 
__init__(self, dim, fn): 193 | super().__init__() 194 | self.fn = fn 195 | self.norm = nn.GroupNorm(1, dim) 196 | 197 | def forward(self, x): 198 | x = self.norm(x) 199 | return self.fn(x) 200 | 201 | class Unet(nn.Module): 202 | def __init__( 203 | self, 204 | dim, 205 | init_dim=None, 206 | out_dim=None, 207 | dim_mults=(1, 2, 4, 8), 208 | channels=3, 209 | self_condition=False, 210 | resnet_block_groups=4, 211 | ): 212 | super().__init__() 213 | 214 | # determine dimensions 215 | self.channels = channels 216 | self.self_condition = self_condition 217 | input_channels = channels * (2 if self_condition else 1) 218 | 219 | init_dim = default(init_dim, dim) 220 | self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3 221 | 222 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 223 | in_out = list(zip(dims[:-1], dims[1:])) 224 | 225 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 226 | 227 | # time embeddings 228 | time_dim = dim * 4 229 | 230 | self.time_mlp = nn.Sequential( 231 | SinusoidalPositionEmbeddings(dim), 232 | nn.Linear(dim, time_dim), 233 | nn.GELU(), 234 | nn.Linear(time_dim, time_dim), 235 | ) 236 | 237 | # layers 238 | self.downs = nn.ModuleList([]) 239 | self.ups = nn.ModuleList([]) 240 | num_resolutions = len(in_out) 241 | 242 | for ind, (dim_in, dim_out) in enumerate(in_out): 243 | is_last = ind >= (num_resolutions - 1) 244 | 245 | self.downs.append( 246 | nn.ModuleList( 247 | [ 248 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 249 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 250 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 251 | Downsample(dim_in, dim_out) 252 | if not is_last 253 | else nn.Conv2d(dim_in, dim_out, 3, padding=1), 254 | ] 255 | ) 256 | ) 257 | 258 | mid_dim = dims[-1] 259 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 260 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 261 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 262 | 263 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 264 | is_last = ind == (len(in_out) - 1) 265 | 266 | self.ups.append( 267 | nn.ModuleList( 268 | [ 269 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 270 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 271 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 272 | Upsample(dim_out, dim_in) 273 | if not is_last 274 | else nn.Conv2d(dim_out, dim_in, 3, padding=1), 275 | ] 276 | ) 277 | ) 278 | 279 | self.out_dim = default(out_dim, channels) 280 | 281 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 282 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 283 | 284 | def forward(self, x, time, x_self_cond=None): 285 | if self.self_condition: 286 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 287 | x = torch.cat((x_self_cond, x), dim=1) 288 | 289 | x = self.init_conv(x) 290 | r = x.clone() 291 | 292 | t = self.time_mlp(time) 293 | 294 | h = [] 295 | 296 | for block1, block2, attn, downsample in self.downs: 297 | x = block1(x, t) 298 | h.append(x) 299 | 300 | x = block2(x, t) 301 | x = attn(x) 302 | h.append(x) 303 | 304 | x = downsample(x) 305 | 306 | x = self.mid_block1(x, t) 307 | x = self.mid_attn(x) 308 | x = self.mid_block2(x, t) 309 | 310 | for block1, block2, attn, upsample in self.ups: 311 | x = torch.cat((x, h.pop()), dim=1) 312 | x = block1(x, t) 313 | 314 | x = torch.cat((x, h.pop()), dim=1) 315 | x = block2(x, t) 316 | x = 
attn(x) 317 | 318 | x = upsample(x) 319 | 320 | x = torch.cat((x, r), dim=1) 321 | 322 | x = self.final_res_block(x, t) 323 | return self.final_conv(x) 324 | -------------------------------------------------------------------------------- /score-sde/src/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from functools import partial 6 | from inspect import isfunction 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | def num_to_groups(num, divisor): 22 | groups = num // divisor 23 | remainder = num % divisor 24 | arr = [divisor] * groups 25 | if remainder > 0: 26 | arr.append(remainder) 27 | return arr 28 | 29 | 30 | class Residual(nn.Module): 31 | def __init__(self, fn): 32 | super().__init__() 33 | self.fn = fn 34 | 35 | def forward(self, x, *args, **kwargs): 36 | return self.fn(x, *args, **kwargs) + x 37 | 38 | 39 | def Upsample(dim, dim_out=None): 40 | return nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode="nearest"), 42 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), 43 | ) 44 | 45 | 46 | def Downsample(dim, dim_out=None): 47 | # No More Strided Convolutions or Pooling 48 | return nn.Sequential( 49 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 50 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 51 | ) 52 | 53 | class SinusoidalPositionEmbeddings(nn.Module): 54 | def __init__(self, dim, scale=1000): 55 | super().__init__() 56 | self.dim = dim 57 | self.scale = scale 58 | 59 | def forward(self, time): 60 | device = time.device 61 | half_dim = self.dim // 2 62 | embeddings = math.log(10000) / (half_dim - 1) 63 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 64 | embeddings = self.scale * time[:, None] * embeddings[None, :] 65 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 66 | return embeddings 67 | 68 | class WeightStandardizedConv2d(nn.Conv2d): 69 | """ 70 | https://arxiv.org/abs/1903.10520 71 | weight standardization purportedly works synergistically with group normalization 72 | """ 73 | 74 | def forward(self, x): 75 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 76 | 77 | weight = self.weight 78 | mean = reduce(weight, "o ... -> o 1 1 1", "mean") 79 | var = reduce(weight, "o ... 
-> o 1 1 1", partial(torch.var, unbiased=False)) 80 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 81 | 82 | return F.conv2d( 83 | x, 84 | normalized_weight, 85 | self.bias, 86 | self.stride, 87 | self.padding, 88 | self.dilation, 89 | self.groups, 90 | ) 91 | 92 | 93 | class Block(nn.Module): 94 | def __init__(self, dim, dim_out, groups=8): 95 | super().__init__() 96 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 97 | self.norm = nn.GroupNorm(groups, dim_out) 98 | self.act = nn.SiLU() 99 | 100 | def forward(self, x, scale_shift=None): 101 | x = self.proj(x) 102 | x = self.norm(x) 103 | 104 | if exists(scale_shift): 105 | scale, shift = scale_shift 106 | x = x * (scale + 1) + shift 107 | 108 | x = self.act(x) 109 | return x 110 | 111 | 112 | class ResnetBlock(nn.Module): 113 | """https://arxiv.org/abs/1512.03385""" 114 | 115 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 116 | super().__init__() 117 | self.mlp = ( 118 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) 119 | if exists(time_emb_dim) 120 | else None 121 | ) 122 | 123 | self.block1 = Block(dim, dim_out, groups=groups) 124 | self.block2 = Block(dim_out, dim_out, groups=groups) 125 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 126 | 127 | def forward(self, x, time_emb=None): 128 | scale_shift = None 129 | if exists(self.mlp) and exists(time_emb): 130 | time_emb = self.mlp(time_emb) 131 | time_emb = rearrange(time_emb, "b c -> b c 1 1") 132 | scale_shift = time_emb.chunk(2, dim=1) 133 | 134 | h = self.block1(x, scale_shift=scale_shift) 135 | h = self.block2(h) 136 | return h + self.res_conv(x) 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, dim, heads=4, dim_head=32): 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 146 | 147 | def forward(self, x): 148 | b, c, h, w = x.shape 149 | qkv = self.to_qkv(x).chunk(3, dim=1) 150 | q, k, v = map( 151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 152 | ) 153 | q = q * self.scale 154 | 155 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 157 | attn = sim.softmax(dim=-1) 158 | 159 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 161 | return self.to_out(out) 162 | 163 | class LinearAttention(nn.Module): 164 | def __init__(self, dim, heads=4, dim_head=32): 165 | super().__init__() 166 | self.scale = dim_head**-0.5 167 | self.heads = heads 168 | hidden_dim = dim_head * heads 169 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 170 | 171 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 172 | nn.GroupNorm(1, dim)) 173 | 174 | def forward(self, x): 175 | b, c, h, w = x.shape 176 | qkv = self.to_qkv(x).chunk(3, dim=1) 177 | q, k, v = map( 178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 179 | ) 180 | 181 | q = q.softmax(dim=-2) 182 | k = k.softmax(dim=-1) 183 | 184 | q = q * self.scale 185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 186 | 187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 189 | return self.to_out(out) 190 | 191 | class PreNorm(nn.Module): 192 | def 
__init__(self, dim, fn): 193 | super().__init__() 194 | self.fn = fn 195 | self.norm = nn.GroupNorm(1, dim) 196 | 197 | def forward(self, x): 198 | x = self.norm(x) 199 | return self.fn(x) 200 | 201 | class Unet(nn.Module): 202 | def __init__( 203 | self, 204 | dim, 205 | init_dim=None, 206 | out_dim=None, 207 | dim_mults=(1, 2, 4, 8), 208 | channels=3, 209 | self_condition=False, 210 | resnet_block_groups=4, 211 | ): 212 | super().__init__() 213 | 214 | # determine dimensions 215 | self.channels = channels 216 | self.self_condition = self_condition 217 | input_channels = channels * (2 if self_condition else 1) 218 | 219 | init_dim = default(init_dim, dim) 220 | self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3 221 | 222 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 223 | in_out = list(zip(dims[:-1], dims[1:])) 224 | 225 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 226 | 227 | # time embeddings 228 | time_dim = dim * 4 229 | 230 | self.time_mlp = nn.Sequential( 231 | SinusoidalPositionEmbeddings(dim), 232 | nn.Linear(dim, time_dim), 233 | nn.GELU(), 234 | nn.Linear(time_dim, time_dim), 235 | ) 236 | 237 | # layers 238 | self.downs = nn.ModuleList([]) 239 | self.ups = nn.ModuleList([]) 240 | num_resolutions = len(in_out) 241 | 242 | for ind, (dim_in, dim_out) in enumerate(in_out): 243 | is_last = ind >= (num_resolutions - 1) 244 | 245 | self.downs.append( 246 | nn.ModuleList( 247 | [ 248 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 249 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 250 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 251 | Downsample(dim_in, dim_out) 252 | if not is_last 253 | else nn.Conv2d(dim_in, dim_out, 3, padding=1), 254 | ] 255 | ) 256 | ) 257 | 258 | mid_dim = dims[-1] 259 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 260 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 261 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 262 | 263 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 264 | is_last = ind == (len(in_out) - 1) 265 | 266 | self.ups.append( 267 | nn.ModuleList( 268 | [ 269 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 270 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 271 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 272 | Upsample(dim_out, dim_in) 273 | if not is_last 274 | else nn.Conv2d(dim_out, dim_in, 3, padding=1), 275 | ] 276 | ) 277 | ) 278 | 279 | self.out_dim = default(out_dim, channels) 280 | 281 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 282 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 283 | 284 | def forward(self, x, time, x_self_cond=None): 285 | if self.self_condition: 286 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 287 | x = torch.cat((x_self_cond, x), dim=1) 288 | 289 | x = self.init_conv(x) 290 | r = x.clone() 291 | 292 | t = self.time_mlp(time) 293 | 294 | h = [] 295 | 296 | for block1, block2, attn, downsample in self.downs: 297 | x = block1(x, t) 298 | h.append(x) 299 | 300 | x = block2(x, t) 301 | x = attn(x) 302 | h.append(x) 303 | 304 | x = downsample(x) 305 | 306 | x = self.mid_block1(x, t) 307 | x = self.mid_attn(x) 308 | x = self.mid_block2(x, t) 309 | 310 | for block1, block2, attn, upsample in self.ups: 311 | x = torch.cat((x, h.pop()), dim=1) 312 | x = block1(x, t) 313 | 314 | x = torch.cat((x, h.pop()), dim=1) 315 | x = block2(x, t) 316 | x = 
attn(x) 317 | 318 | x = upsample(x) 319 | 320 | x = torch.cat((x, r), dim=1) 321 | 322 | x = self.final_res_block(x, t) 323 | return self.final_conv(x) 324 | -------------------------------------------------------------------------------- /consistency/src/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from functools import partial 6 | from inspect import isfunction 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | 15 | def default(val, d): 16 | if exists(val): 17 | return val 18 | return d() if isfunction(d) else d 19 | 20 | 21 | def num_to_groups(num, divisor): 22 | groups = num // divisor 23 | remainder = num % divisor 24 | arr = [divisor] * groups 25 | if remainder > 0: 26 | arr.append(remainder) 27 | return arr 28 | 29 | 30 | class Residual(nn.Module): 31 | def __init__(self, fn): 32 | super().__init__() 33 | self.fn = fn 34 | 35 | def forward(self, x, *args, **kwargs): 36 | return self.fn(x, *args, **kwargs) + x 37 | 38 | 39 | def Upsample(dim, dim_out=None): 40 | return nn.Sequential( 41 | nn.Upsample(scale_factor=2, mode="nearest"), 42 | nn.Conv2d(dim, default(dim_out, dim), 3, padding=1), 43 | ) 44 | 45 | 46 | def Downsample(dim, dim_out=None): 47 | # No More Strided Convolutions or Pooling 48 | return nn.Sequential( 49 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 50 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 51 | ) 52 | 53 | class SinusoidalPositionEmbeddings(nn.Module): 54 | def __init__(self, dim, scale=1000): 55 | super().__init__() 56 | self.dim = dim 57 | self.scale = scale 58 | 59 | def forward(self, time): 60 | device = time.device 61 | half_dim = self.dim // 2 62 | embeddings = math.log(10000) / (half_dim - 1) 63 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 64 | embeddings = self.scale * time[:, None] * embeddings[None, :] 65 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 66 | return embeddings 67 | 68 | class WeightStandardizedConv2d(nn.Conv2d): 69 | """ 70 | https://arxiv.org/abs/1903.10520 71 | weight standardization purportedly works synergistically with group normalization 72 | """ 73 | 74 | def forward(self, x): 75 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 76 | 77 | weight = self.weight 78 | mean = reduce(weight, "o ... -> o 1 1 1", "mean") 79 | var = reduce(weight, "o ... 
-> o 1 1 1", partial(torch.var, unbiased=False)) 80 | normalized_weight = (weight - mean) * (var + eps).rsqrt() 81 | 82 | return F.conv2d( 83 | x, 84 | normalized_weight, 85 | self.bias, 86 | self.stride, 87 | self.padding, 88 | self.dilation, 89 | self.groups, 90 | ) 91 | 92 | 93 | class Block(nn.Module): 94 | def __init__(self, dim, dim_out, groups=8): 95 | super().__init__() 96 | self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) 97 | self.norm = nn.GroupNorm(groups, dim_out) 98 | self.act = nn.SiLU() 99 | 100 | def forward(self, x, scale_shift=None): 101 | x = self.proj(x) 102 | x = self.norm(x) 103 | 104 | if exists(scale_shift): 105 | scale, shift = scale_shift 106 | x = x * (scale + 1) + shift 107 | 108 | x = self.act(x) 109 | return x 110 | 111 | 112 | class ResnetBlock(nn.Module): 113 | """https://arxiv.org/abs/1512.03385""" 114 | 115 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 116 | super().__init__() 117 | self.mlp = ( 118 | nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) 119 | if exists(time_emb_dim) 120 | else None 121 | ) 122 | 123 | self.block1 = Block(dim, dim_out, groups=groups) 124 | self.block2 = Block(dim_out, dim_out, groups=groups) 125 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 126 | 127 | def forward(self, x, time_emb=None): 128 | scale_shift = None 129 | if exists(self.mlp) and exists(time_emb): 130 | time_emb = self.mlp(time_emb) 131 | time_emb = rearrange(time_emb, "b c -> b c 1 1") 132 | scale_shift = time_emb.chunk(2, dim=1) 133 | 134 | h = self.block1(x, scale_shift=scale_shift) 135 | h = self.block2(h) 136 | return h + self.res_conv(x) 137 | 138 | class Attention(nn.Module): 139 | def __init__(self, dim, heads=4, dim_head=32): 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 146 | 147 | def forward(self, x): 148 | b, c, h, w = x.shape 149 | qkv = self.to_qkv(x).chunk(3, dim=1) 150 | q, k, v = map( 151 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 152 | ) 153 | q = q * self.scale 154 | 155 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 156 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 157 | attn = sim.softmax(dim=-1) 158 | 159 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 160 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 161 | return self.to_out(out) 162 | 163 | class LinearAttention(nn.Module): 164 | def __init__(self, dim, heads=4, dim_head=32): 165 | super().__init__() 166 | self.scale = dim_head**-0.5 167 | self.heads = heads 168 | hidden_dim = dim_head * heads 169 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 170 | 171 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 172 | nn.GroupNorm(1, dim)) 173 | 174 | def forward(self, x): 175 | b, c, h, w = x.shape 176 | qkv = self.to_qkv(x).chunk(3, dim=1) 177 | q, k, v = map( 178 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 179 | ) 180 | 181 | q = q.softmax(dim=-2) 182 | k = k.softmax(dim=-1) 183 | 184 | q = q * self.scale 185 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 186 | 187 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 188 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 189 | return self.to_out(out) 190 | 191 | class PreNorm(nn.Module): 192 | def 
__init__(self, dim, fn): 193 | super().__init__() 194 | self.fn = fn 195 | self.norm = nn.GroupNorm(1, dim) 196 | 197 | def forward(self, x): 198 | x = self.norm(x) 199 | return self.fn(x) 200 | 201 | class Unet(nn.Module): 202 | def __init__( 203 | self, 204 | dim, 205 | init_dim=None, 206 | out_dim=None, 207 | dim_mults=(1, 2, 4), 208 | channels=3, 209 | self_condition=False, 210 | resnet_block_groups=4, 211 | ): 212 | super().__init__() 213 | 214 | # determine dimensions 215 | self.channels = channels 216 | self.self_condition = self_condition 217 | input_channels = channels * (2 if self_condition else 1) 218 | 219 | init_dim = default(init_dim, dim) 220 | self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0) # changed to 1 and 0 from 7,3 221 | 222 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 223 | in_out = list(zip(dims[:-1], dims[1:])) 224 | 225 | block_klass = partial(ResnetBlock, groups=resnet_block_groups) 226 | 227 | # time embeddings 228 | time_dim = dim * 4 229 | 230 | self.time_mlp = nn.Sequential( 231 | SinusoidalPositionEmbeddings(dim), 232 | nn.Linear(dim, time_dim), 233 | nn.GELU(), 234 | nn.Linear(time_dim, time_dim), 235 | ) 236 | 237 | # layers 238 | self.downs = nn.ModuleList([]) 239 | self.ups = nn.ModuleList([]) 240 | num_resolutions = len(in_out) 241 | 242 | for ind, (dim_in, dim_out) in enumerate(in_out): 243 | is_last = ind >= (num_resolutions - 1) 244 | 245 | self.downs.append( 246 | nn.ModuleList( 247 | [ 248 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 249 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 250 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 251 | Downsample(dim_in, dim_out) 252 | if not is_last 253 | else nn.Conv2d(dim_in, dim_out, 3, padding=1), 254 | ] 255 | ) 256 | ) 257 | 258 | mid_dim = dims[-1] 259 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 260 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 261 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 262 | 263 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 264 | is_last = ind == (len(in_out) - 1) 265 | 266 | self.ups.append( 267 | nn.ModuleList( 268 | [ 269 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 270 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), 271 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 272 | Upsample(dim_out, dim_in) 273 | if not is_last 274 | else nn.Conv2d(dim_out, dim_in, 3, padding=1), 275 | ] 276 | ) 277 | ) 278 | 279 | self.out_dim = default(out_dim, channels) 280 | 281 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim) 282 | self.final_conv = nn.Conv2d(dim, self.out_dim, 1) 283 | 284 | def forward(self, x, time, x_self_cond=None): 285 | if self.self_condition: 286 | x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) 287 | x = torch.cat((x_self_cond, x), dim=1) 288 | 289 | x = self.init_conv(x) 290 | r = x.clone() 291 | 292 | t = self.time_mlp(time) 293 | 294 | h = [] 295 | 296 | for block1, block2, attn, downsample in self.downs: 297 | x = block1(x, t) 298 | h.append(x) 299 | 300 | x = block2(x, t) 301 | x = attn(x) 302 | h.append(x) 303 | 304 | x = downsample(x) 305 | 306 | x = self.mid_block1(x, t) 307 | x = self.mid_attn(x) 308 | x = self.mid_block2(x, t) 309 | 310 | for block1, block2, attn, upsample in self.ups: 311 | x = torch.cat((x, h.pop()), dim=1) 312 | x = block1(x, t) 313 | 314 | x = torch.cat((x, h.pop()), dim=1) 315 | x = block2(x, t) 316 | x = attn(x) 
317 | 318 | x = upsample(x) 319 | 320 | x = torch.cat((x, r), dim=1) 321 | 322 | x = self.final_res_block(x, t) 323 | return self.final_conv(x) 324 | 325 | 326 | class ConsistencyModel(nn.Module): 327 | def __init__(self, dim, channels, dim_mults=(1, 2, 4), s_data=0.5, eps=2e-3): 328 | super().__init__() 329 | self.unet = Unet(dim=dim, dim_mults=dim_mults, channels=channels) 330 | self.channels = channels 331 | self.s_data = s_data 332 | self.eps = eps 333 | 334 | def forward(self, x, t, x_self_cond=None): 335 | f = self.unet(x, t, x_self_cond=x_self_cond) 336 | c_skip = (self.s_data ** 2) / ((t - self.eps) ** 2 + self.s_data ** 2) 337 | c_out = self.s_data * (t - self.eps) / torch.sqrt(self.s_data **2 + t ** 2) 338 | return c_skip[:, None, None, None] * x + c_out[:, None, None, None] * f 339 | 340 | def loss(self, x, z, t1, t2, ema_model): 341 | # t1 : t_{n+1} 342 | # t2 : t_n 343 | 344 | x1 = x + t1[:, None, None, None] * z 345 | x1 = self(x1, t1) 346 | 347 | with torch.no_grad(): 348 | x2 = x + t2[:, None, None, None] * z 349 | x2 = ema_model(x2, t2) 350 | 351 | return F.mse_loss(x1, x2) 352 | 353 | def sample(self, x, ts): 354 | for t in ts[1:]: 355 | z = torch.randn_like(x) 356 | x = x + math.sqrt(t ** 2 - self.eps ** 2) * z 357 | t = torch.tensor([t], dtype=torch.float, device=x.device) 358 | x = self(x, t) 359 | return x 360 | --------------------------------------------------------------------------------
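# --- Illustrative usage sketch (not a file in this repository) ---
# A minimal, hedged example of how the ConsistencyModel defined above could be
# exercised end to end: one consistency-training loss step against an EMA copy,
# followed by multistep sampling. The image size, batch size, time values, and
# the coarse time grid `ts` below are assumptions chosen for illustration only;
# the repository's actual training loop lives in consistency/src/train.py and
# its schedule may differ.
import copy
import torch

# Small model: 1-channel 32x32 inputs, three resolutions (32 -> 16 -> 8).
model = ConsistencyModel(dim=64, channels=1, dim_mults=(1, 2, 4))
ema_model = copy.deepcopy(model).requires_grad_(False)  # frozen target network

# One training step on random data (stand-in for a real batch).
x = torch.randn(8, 1, 32, 32)            # batch of normalized images
z = torch.randn_like(x)                  # shared noise for both discretization points
t2 = torch.rand(x.size(0)) * 0.5 + model.eps   # t_n       (assumed values)
t1 = t2 + 0.1                                  # t_{n+1} > t_n (assumed step size)

loss = model.loss(x, z, t1, t2, ema_model)     # MSE between f(x + t1*z) and EMA f(x + t2*z)
loss.backward()

# Multistep sampling: start from noise scaled to the largest time, take a
# one-step estimate there, then refine along the (assumed) coarse grid.
ts = [1.0, 0.5, model.eps]
x_T = torch.randn(8, 1, 32, 32) * ts[0]
with torch.no_grad():
    t0 = torch.full((x_T.size(0),), ts[0])
    x0 = model(x_T, t0)                  # initial estimate at the largest time
    samples = model.sample(x0, ts)       # shape (8, 1, 32, 32)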