├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── models ├── utils.py ├── guided_diffusion │ ├── nn.py │ └── unet.py ├── pix2pix.py ├── attention_unet.py ├── wrapper.py ├── trans_unet.py ├── res_unet.py └── palette.py ├── callbacks └── ema.py ├── .gitignore ├── README.md ├── dataset.py ├── main.py └── report.py /requirements.txt: -------------------------------------------------------------------------------- 1 | fvcore==0.1.5.post20221221 2 | pytorch_lightning==2.0.2 3 | PyYAML==6.0 4 | torch==2.0.0 5 | torch_ema==0.3 6 | torchmetrics==0.11.4 7 | torchvision==0.15.1 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report 4 | title: "[BUG]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to `...` 16 | 2. Run `...` 17 | 3. See error 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Additional context** 26 | Add any other context about the problem here. 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEAT]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | from torchmetrics.functional import ( 5 | structural_similarity_index_measure, 6 | peak_signal_noise_ratio, 7 | mean_squared_error, 8 | ) 9 | 10 | 11 | denormalize = transforms.Lambda(lambda x: torch.clamp(x * 0.5 + 0.5, 0, 1)) 12 | to_int = transforms.ConvertImageDtype(torch.uint8) 13 | 14 | 15 | def init_weights(module: nn.Module): 16 | if isinstance( 17 | module, 18 | (nn.Conv1d, nn.Conv2d, nn.ConvTranspose2d, nn.Linear), 19 | ): 20 | nn.init.normal_(module.weight, 0.0, 0.02) 21 | 22 | if isinstance( 23 | module, 24 | (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm), 25 | ): 26 | nn.init.constant_(module.weight, 1.0) 27 | 28 | nn.init.constant_(module.bias, 0.0) 29 | 30 | 31 | def get_parameter_count(model: nn.Module): 32 | if isinstance(model, nn.Module): 33 | return sum(p.numel() for p in model.parameters()) 34 | 35 | return 0 36 | 37 | 38 | def ssim(pred: torch.Tensor, target: torch.Tensor): 39 | return structural_similarity_index_measure(pred, target, data_range=1.0) 40 | 41 | 42 | def psnr(pred: torch.Tensor, target: torch.Tensor): 43 | return peak_signal_noise_ratio(pred, target, data_range=1.0) 44 | 45 | 46 | def rmse(pred: torch.Tensor, target: torch.Tensor): 47 | return mean_squared_error(pred, target, squared=False) 48 | -------------------------------------------------------------------------------- /callbacks/ema.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch_ema import ExponentialMovingAverage 3 | 4 | 5 | class EMACallback(pl.callbacks.Callback): 6 | """ 7 | Exponential Moving Average callback to be used with any pytorch lightning 8 | module. 
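Wraps torch_ema.ExponentialMovingAverage; add an instance to the trainer's
    callback list to enable it (as done in main.py, e.g.
    pl.Trainer(callbacks=[EMACallback(0.9999)])).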
9 | 10 | """ 11 | 12 | def __init__(self, decay=0.9999): 13 | self.decay = decay 14 | self.ema = None 15 | 16 | def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 17 | """Initialize EMA.""" 18 | 19 | self.ema = ExponentialMovingAverage( 20 | pl_module.parameters(), 21 | decay=self.decay, 22 | ) 23 | 24 | def on_train_batch_end( 25 | self, 26 | trainer: pl.Trainer, 27 | pl_module: pl.LightningModule, 28 | *args, 29 | **kwargs, 30 | ): 31 | """Update the stored parameters using a moving average.""" 32 | 33 | self.ema.update() 34 | 35 | def on_validation_start( 36 | self, 37 | trainer: pl.Trainer, 38 | pl_module: pl.LightningModule, 39 | ): 40 | """Do validation using the stored parameters.""" 41 | 42 | self.ema.store() 43 | self.ema.copy_to() 44 | 45 | def on_validation_end( 46 | self, 47 | trainer: pl.Trainer, 48 | pl_module: pl.LightningModule, 49 | ): 50 | """Restore original parameters to resume training later.""" 51 | 52 | self.ema.restore() 53 | 54 | def on_save_checkpoint( 55 | self, 56 | trainer: pl.Trainer, 57 | pl_module: pl.LightningModule, 58 | checkpoint: dict[str, any], 59 | ): 60 | """Save state dict on checkpoint.""" 61 | 62 | return self.ema.state_dict() 63 | 64 | def on_load_checkpoint( 65 | self, 66 | trainer: pl.Trainer, 67 | pl_module: pl.LightningModule, 68 | callback_state: dict[str, any], 69 | ): 70 | """Load state dict on checkpoint.""" 71 | 72 | self.ema.load_state_dict(callback_state) 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | data/ 132 | .DS_Store 133 | 134 | thesis_data.zip 135 | lightning_logs/ 136 | *.zip 137 | 138 | checkpoints/ 139 | logs/ 140 | output/ 141 | 142 | slurm* 143 | test* 144 | wandb/ 145 | reports/ 146 | reports-experimental/ 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Learning for Photoacoustic Imaging Reconstruction 2 | 3 | Implementation of models for the work presented in the following articles. 4 | 5 | Cristian P. Jensen, Kalloor Joseph Francis, and Navchetan Awasthi "Image depth improvement in photoacoustic imaging using transformer based generative adversarial networks", Proc. SPIE PC12842, Photons Plus Ultrasound: Imaging and Sensing 2024, PC128421V (13 March 2024); https://doi.org/10.1117/12.3001537 6 | 7 | Please cite this if you find this useful in your work. 8 | 9 | ## Models 10 | 11 | The following image-to-image translation models are implemented using PyTorch 12 | and PyTorch Lightning: 13 | - Pix2Pix [(Isola et al. 2018)](https://arxiv.org/abs/1611.07004); 14 | - Attention U-net [(Oktay et al. 2018)](https://arxiv.org/abs/1804.03999); 15 | - Residual U-net with the following basic blocks: 16 | - Res18 [(He et al. 2015)](https://arxiv.org/abs/1512.03385); 17 | - Res50 [(He et al. 2015)](https://arxiv.org/abs/1512.03385); 18 | - ResV2 [(He et al. 2016)](https://arxiv.org/abs/1603.05027); 19 | - ResNeXt [(Xie et al. 2017)](https://arxiv.org/abs/1611.05431). 20 | - Trans U-net [(Chen et al. 2021)](https://arxiv.org/abs/2102.04306); 21 | - Palette [(Saharia et al. 2022)](https://arxiv.org/abs/2111.05826). 22 | 23 | More models can easily be added by using the `UnetWrapper` class. 24 | 25 | ## Loss functions 26 | 27 | The following loss functions are implemented: 28 | - GAN loss, using Pix2Pix' discriminator (to change the used adversarial 29 | network, you must change the `Discriminator` class in `models/wrapper.py`); 30 | - MSE loss; 31 | - SSIM loss; 32 | - PSNR loss. 33 | - Combination of SSIM and PSNR loss. 34 | 35 | ## Data organisation 36 | 37 | The organisation of your data does not matter. The only important thing is 38 | the data file, a [YAML](https://yaml.org/) file containing a list of 39 | input-ground truth entries. The input and ground truth files must be relative 40 | to the directory of the data file. For example: 41 | ```yaml 42 | - input: input/00001.png 43 | ground_truth: ground_truth/00001.png 44 | - input: input/00002.png 45 | ground_truth: ground_truth/00002.png 46 | - input: input/00003.png 47 | ground_truth: ground_truth/00003.png 48 | ``` 49 | 50 | ## Training a model 51 | 52 | To train a model, run the following: 53 | ```bash 54 | python main.py 55 | ``` 56 | 57 | When training, the model with the highest SSIM on the validation dataset will 58 | be selected as the "best" checkpoint. 
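For example, to train Pix2Pix with the GAN loss on the data files described
above (the run name and paths below are placeholders; see `main.py` for all
available flags):
```bash
python main.py my_run \
    --data data/train.yaml \
    --val-data data/val.yaml \
    --model pix2pix \
    --loss-type gan \
    --batch-size 8 \
    --epochs 200
```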
59 | 60 | ## Testing a model 61 | 62 | To test a trained model, run the following: 63 | ```bash 64 | python report.py 65 | ``` 66 | It essentially takes a model checkpoint and test data file as input and outputs 67 | metrics and information about the model. The following metrics are reported: 68 | - SSIM per image; 69 | - PSNR per image; 70 | - Mean SSIM; 71 | - Mean PSNR; 72 | - Mean RMSE; 73 | - FLOPs; 74 | - Parameter count; 75 | - SSIM over depth (vertically) of the image (this is only relevant for PAI 76 | reconstruction). 77 | - Outputs of the model. 78 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision.io import read_image, ImageReadMode 4 | from torchvision import transforms 5 | import pytorch_lightning as pl 6 | import yaml 7 | import os 8 | from typing import Optional 9 | 10 | 11 | class ImageDataModule(pl.LightningDataModule): 12 | def __init__( 13 | self, 14 | data_list_file: str, 15 | val_list_file: Optional[str] = None, 16 | batch_size: int = 1, 17 | normalize: bool = True, 18 | ): 19 | super().__init__() 20 | 21 | # Load data 22 | with open(data_list_file, "r") as f: 23 | data_list = yaml.safe_load(f) 24 | 25 | data_dir = os.path.dirname(data_list_file) 26 | self.data_tuples: list[(str, str, int)] = list(map( 27 | lambda x: ( 28 | os.path.join(data_dir, x["input"]), 29 | os.path.join(data_dir, x["ground_truth"]), 30 | ), 31 | data_list, 32 | )) 33 | 34 | # Load validation data 35 | if val_list_file is not None: 36 | with open(val_list_file, "r") as f: 37 | val_list = yaml.safe_load(f) 38 | 39 | val_dir = os.path.dirname(val_list_file) 40 | self.val_tuples: list[(str, str, int)] = list(map( 41 | lambda x: ( 42 | os.path.join(val_dir, x["input"]), 43 | os.path.join(val_dir, x["ground_truth"]), 44 | ), 45 | val_list, 46 | )) 47 | 48 | self.batch_size = batch_size 49 | self.normalize = normalize 50 | 51 | trans = [ 52 | transforms.Resize((256, 256), antialias=True), 53 | transforms.ConvertImageDtype(torch.float32), 54 | ] 55 | 56 | if normalize: 57 | trans.append( 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 59 | ) 60 | 61 | self.transform = transforms.Compose(trans) 62 | 63 | def setup(self, stage: str): 64 | if stage == "fit": 65 | self.train_split = self.data_tuples 66 | self.val_split = self.val_tuples 67 | 68 | if stage == "validate": 69 | self.val_split = self.data_tuples 70 | 71 | if stage == "test": 72 | self.test_split = self.data_tuples 73 | 74 | if stage == "predict": 75 | self.pred_split = self.data_tuples 76 | 77 | def train_dataloader(self): 78 | return DataLoader( 79 | ImageDataset(self.train_split, transform=self.transform), 80 | batch_size=self.batch_size, 81 | shuffle=True, 82 | drop_last=False, 83 | ) 84 | 85 | def val_dataloader(self): 86 | return DataLoader( 87 | ImageDataset(self.val_split, transform=self.transform), 88 | batch_size=self.batch_size, 89 | shuffle=False, 90 | drop_last=False, 91 | ) 92 | 93 | def test_dataloader(self): 94 | return DataLoader( 95 | ImageDataset(self.test_split, transform=self.transform), 96 | batch_size=self.batch_size, 97 | shuffle=False, 98 | drop_last=False, 99 | ) 100 | 101 | def predict_dataloader(self): 102 | return DataLoader( 103 | ImageDataset(self.pred_split, transform=self.transform), 104 | batch_size=self.batch_size, 105 | shuffle=False, 106 | drop_last=False, 107 | ) 108 | 109 | 110 | 
class ImageDataset(Dataset): 111 | """Dataset class used for image data that has input and targets in separate 112 | directories in the same order. Supports batches.""" 113 | 114 | def __init__( 115 | self, 116 | data_tuples: list[(str, str, int)], 117 | transform=None, 118 | ): 119 | super().__init__() 120 | self.data_tuples = data_tuples 121 | self.transform = transform 122 | 123 | def __len__(self): 124 | return len(self.data_tuples) 125 | 126 | def __getitem__(self, idx): 127 | (input_, gt) = self.data_tuples[idx] 128 | 129 | input_tensor = read_image(input_, mode=ImageReadMode.GRAY) 130 | input_tensor = self.transform(input_tensor) 131 | gt_tensor = read_image(gt, mode=ImageReadMode.GRAY) 132 | gt_tensor = self.transform(gt_tensor) 133 | 134 | return input_tensor, gt_tensor 135 | -------------------------------------------------------------------------------- /models/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class GroupNorm(nn.GroupNorm): 12 | def forward(self, x): 13 | return super().forward(x.float()).type(x.dtype) 14 | 15 | 16 | class BatchNorm2d(nn.BatchNorm2d): 17 | def forward(self, x): 18 | return super().forward(x.float()).type(x.dtype) 19 | 20 | 21 | class BatchNorm1d(nn.BatchNorm1d): 22 | def forward(self, x): 23 | return super().forward(x.float()).type(x.dtype) 24 | 25 | 26 | def zero_module(module): 27 | """ 28 | Zero out the parameters of a module and return it. 29 | """ 30 | for p in module.parameters(): 31 | p.detach().zero_() 32 | return module 33 | 34 | 35 | def scale_module(module, scale): 36 | """ 37 | Scale the parameters of a module and return it. 38 | """ 39 | for p in module.parameters(): 40 | p.detach().mul_(scale) 41 | return module 42 | 43 | 44 | def mean_flat(tensor): 45 | """ 46 | Take the mean over all non-batch dimensions. 47 | """ 48 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 49 | 50 | 51 | def normalization1d(channels): 52 | """ 53 | Make a standard normalization layer. 54 | 55 | :param channels: number of input channels. 56 | :return: an nn.Module for normalization. 57 | """ 58 | return BatchNorm1d(channels) 59 | 60 | 61 | def normalization2d(channels): 62 | """ 63 | Make a standard normalization layer. 64 | 65 | :param channels: number of input channels. 66 | :return: an nn.Module for normalization. 67 | """ 68 | return BatchNorm2d(channels) 69 | 70 | 71 | def checkpoint(func, inputs, params, flag): 72 | """ 73 | Evaluate a function without caching intermediate activations, allowing for 74 | reduced memory at the expense of extra compute in the backward pass. 75 | 76 | :param func: the function to evaluate. 77 | :param inputs: the argument sequence to pass to `func`. 78 | :param params: a sequence of parameters `func` depends on but does not 79 | explicitly take as arguments. 80 | :param flag: if False, disable gradient checkpointing. 
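:return: the output of func(*inputs).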
81 | """ 82 | if flag: 83 | args = tuple(inputs) + tuple(params) 84 | return CheckpointFunction.apply(func, len(inputs), *args) 85 | else: 86 | return func(*inputs) 87 | 88 | 89 | class CheckpointFunction(torch.autograd.Function): 90 | @staticmethod 91 | def forward(ctx, run_function, length, *args): 92 | ctx.run_function = run_function 93 | ctx.input_tensors = list(args[:length]) 94 | ctx.input_params = list(args[length:]) 95 | with torch.no_grad(): 96 | output_tensors = ctx.run_function(*ctx.input_tensors) 97 | return output_tensors 98 | 99 | @staticmethod 100 | def backward(ctx, *output_grads): 101 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 102 | with torch.enable_grad(): 103 | # Fixes a bug where the first op in run_function modifies the 104 | # Tensor storage in place, which is not allowed for detach()'d 105 | # Tensors. 106 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 107 | output_tensors = ctx.run_function(*shallow_copies) 108 | input_grads = torch.autograd.grad( 109 | output_tensors, 110 | ctx.input_tensors + ctx.input_params, 111 | output_grads, 112 | allow_unused=True, 113 | ) 114 | del ctx.input_tensors 115 | del ctx.input_params 116 | del output_tensors 117 | return (None, None) + input_grads 118 | 119 | 120 | def count_flops_attn(model, _x, y): 121 | """ 122 | A counter for the `thop` package to count the operations in an 123 | attention operation. 124 | Meant to be used like: 125 | macs, params = thop.profile( 126 | model, 127 | inputs=(inputs, timestamps), 128 | custom_ops={QKVAttention: QKVAttention.count_flops}, 129 | ) 130 | """ 131 | b, c, *spatial = y[0].shape 132 | num_spatial = int(np.prod(spatial)) 133 | # We perform two matmuls with the same number of ops. 134 | # The first computes the weight matrix, the second computes 135 | # the combination of the value vectors. 136 | matmul_ops = 2 * b * (num_spatial ** 2) * c 137 | model.total_ops += torch.DoubleTensor([matmul_ops]) 138 | 139 | 140 | def gamma_embedding(gammas, dim, max_period=10000): 141 | """ 142 | Create sinusoidal timestep embeddings. 143 | :param gammas: a 1-D Tensor of N indices, one per batch element. 144 | These may be fractional. 145 | :param dim: the dimension of the output. 146 | :param max_period: controls the minimum frequency of the embeddings. 147 | :return: an [N x dim] Tensor of positional embeddings. 148 | """ 149 | half = dim // 2 150 | freqs = torch.exp( 151 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 152 | ).to(device=gammas.device) 153 | args = gammas[:, None].float() * freqs[None] 154 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 155 | if dim % 2: 156 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 157 | return embedding 158 | -------------------------------------------------------------------------------- /models/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Literal 4 | from .wrapper import UnetWrapper 5 | 6 | 7 | class Pix2Pix(UnetWrapper): 8 | """Implementation of pix2pix (Isola et al. 2018). 9 | 10 | :param in_channels: Input channels that can vary if the images are 11 | grayscale or color. 12 | :param out_channels: Input channels that can vary if the images are 13 | grayscale or color. 14 | :param channel_mults: Channel multiples that define the depth and width of 15 | the U-net architecture. 
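With the default of (1, 2, 4, 8, 8, 8, 8, 8), the encoder downsamples a
        256 x 256 input eight times, down to a 1 x 1 bottleneck.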
16 | :param dropout: Dropout percentage used in some of the decoder blocks. 17 | :param loss_type: Loss type. One of "gan", "ssim", "psnr", "mse", 18 | "ssim+psnr". 19 | 20 | :input: [N x in_channels x H x W] 21 | :output: [N x out_channels x H x W] 22 | 23 | """ 24 | 25 | def __init__( 26 | self, 27 | in_channels: int = 3, 28 | out_channels: int = 3, 29 | channel_mults: tuple[int] = (1, 2, 4, 8, 8, 8, 8, 8), 30 | dropout: float = 0.5, 31 | loss_type: Literal["gan", "ssim", "psnr", "ssim+psnr" "mse"] = "gan", 32 | ): 33 | unet = Unet( 34 | in_channels, 35 | out_channels, 36 | channel_mults=channel_mults, 37 | dropout=dropout, 38 | ) 39 | 40 | super().__init__(unet, loss_type=loss_type) 41 | 42 | self.example_input_array = torch.Tensor(2, in_channels, 256, 256) 43 | self.save_hyperparameters() 44 | 45 | 46 | class EncoderBlock(nn.Module): 47 | """Encoder block that downsamples the input by 2. 48 | 49 | :param in_channels: Input channels. 50 | :param out_channels: Output channels. 51 | :param norm: Whether to use batch normalization or not. 52 | 53 | :input: [N x in_channels x H x W] 54 | :output: [N x out_channels x (H / 2) x (W / 2)] 55 | 56 | """ 57 | 58 | def __init__(self, in_channels: int, out_channels: int, norm: bool = True): 59 | super().__init__() 60 | 61 | self.encode = nn.Sequential( 62 | nn.LeakyReLU(0.2), 63 | nn.Conv2d( 64 | in_channels, 65 | out_channels, 66 | kernel_size=4, 67 | stride=2, 68 | padding=1 69 | ), 70 | nn.BatchNorm2d(out_channels) if norm else nn.Identity(), 71 | ) 72 | 73 | def forward(self, x): 74 | return self.encode(x) 75 | 76 | 77 | class DecoderBlock(nn.Module): 78 | """Decoder block that upsamples the input by 2. 79 | 80 | :param in_channels: Input channels. 81 | :param out_channels: Output channels. 82 | :param dropout: Dropout percentage. 83 | 84 | :input: [N x in_channels x H x W] 85 | :output: [N x out_channels x (H * 2) x (W * 2)] 86 | 87 | """ 88 | 89 | def __init__( 90 | self, 91 | in_channels: int, 92 | out_channels: int, 93 | dropout: float = 0.5, 94 | ): 95 | super().__init__() 96 | 97 | self.decode = nn.Sequential( 98 | nn.ReLU(), 99 | nn.ConvTranspose2d( 100 | in_channels, 101 | out_channels, 102 | kernel_size=4, 103 | stride=2, 104 | padding=1 105 | ), 106 | nn.BatchNorm2d(out_channels), 107 | nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(), 108 | ) 109 | 110 | def forward(self, x): 111 | return self.decode(x) 112 | 113 | 114 | class Unet(nn.Module): 115 | """U-net used as the generator in pix2pix GAN. 116 | 117 | :param in_channels: Input channels that can vary if the images are 118 | grayscale or color. 119 | :param out_channels: Input channels that can vary if the images are 120 | grayscale or color. 121 | :param channel_mults: Channel multiples that define the depth and width of 122 | the U-net architecture. 123 | :param dropout: Dropout percentage used in some of the decoder blocks. 
124 | 125 | :input: [N x in_channels x H x W] 126 | :output: [N x out_channels x H x W] 127 | 128 | """ 129 | 130 | def __init__( 131 | self, 132 | in_channels: int = 3, 133 | out_channels: int = 3, 134 | channel_mults: tuple[int] = (1, 2, 4, 8, 8, 8, 8, 8), 135 | dropout: float = 0.5, 136 | ): 137 | super().__init__() 138 | 139 | # Encoder blocks 140 | encoders = [ 141 | nn.Conv2d( 142 | in_channels, 143 | channel_mults[0] * 64, 144 | kernel_size=4, 145 | stride=2, 146 | padding=1 147 | ), 148 | ] 149 | in_channels = channel_mults[0] * 64 150 | for level, mult in enumerate(channel_mults[1:], 1): 151 | channels = mult * 64 152 | 153 | encoders.append( 154 | EncoderBlock( 155 | in_channels, 156 | channels, 157 | norm=level != len(channel_mults) - 1, 158 | ) 159 | ) 160 | 161 | in_channels = channels 162 | 163 | self.encoders = nn.ModuleList(encoders) 164 | 165 | # Decoder blocks 166 | decoders = [] 167 | for level, mult in reversed(list(enumerate(channel_mults[:-1]))): 168 | channels = mult * 64 169 | 170 | decoders.append( 171 | DecoderBlock( 172 | in_channels, 173 | channels, 174 | # Only dropout in the lowest three decoder blocks that are 175 | # at the widest part 176 | dropout=dropout if ( 177 | mult == max(channel_mults) and 178 | level > len(channel_mults) - 5 179 | ) else 0, 180 | ) 181 | ) 182 | 183 | in_channels = channels * 2 184 | 185 | decoders.append( 186 | nn.ConvTranspose2d( 187 | in_channels, 188 | out_channels, 189 | kernel_size=4, 190 | stride=2, 191 | padding=1, 192 | ) 193 | ) 194 | 195 | self.decoders = nn.ModuleList(decoders) 196 | self.out = nn.Tanh() 197 | 198 | def forward(self, x): 199 | h = x.type(torch.float32) 200 | 201 | feats = [] 202 | for encoder in self.encoders: 203 | h = encoder(h) 204 | feats.append(h) 205 | 206 | # Remove last feature map, since that should not be used in 207 | # skip-connection 208 | feats.pop() 209 | 210 | for index, decoder in enumerate(self.decoders): 211 | if index != 0: 212 | h = torch.cat([h, feats.pop()], dim=1) 213 | 214 | h = decoder(h) 215 | 216 | return self.out(h) 217 | -------------------------------------------------------------------------------- /models/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Literal 4 | from .wrapper import UnetWrapper 5 | from .pix2pix import EncoderBlock, DecoderBlock 6 | 7 | 8 | class AttentionUnetGAN(UnetWrapper): 9 | """The same model as pix2pix modified to use attention in the skip 10 | connections (Oktay et al. 2018). 11 | 12 | :param in_channels: Input channels that can vary if the images are 13 | grayscale or color. 14 | :param out_channels: Input channels that can vary if the images are 15 | grayscale or color. 16 | :param channel_mults: Channel multiples that define the depth and width of 17 | the U-net architecture. 18 | :param dropout: Dropout percentage used in some of the decoder blocks. 19 | :param loss_type: Loss type. One of "gan", "ssim", "psnr", "mse", 20 | "ssim+psnr". 
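Each attention gate weights an encoder feature map (using the decoder
    activations as the gating signal) before it is concatenated in the skip
    connection; see AttentionBlock below.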
21 | 22 | :input: [N x in_channels x H x W] 23 | :output: [N x out_channels x H x W] 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int = 3, 30 | out_channels: int = 3, 31 | channel_mults: tuple[int] = (1, 2, 4, 8, 8, 8, 8, 8), 32 | dropout: float = 0.5, 33 | loss_type: Literal["gan", "ssim", "psnr", "ssim+psnr", "mse"] = "gan", 34 | ): 35 | unet = AttentionUnet( 36 | in_channels, 37 | out_channels, 38 | channel_mults=channel_mults, 39 | dropout=dropout, 40 | ) 41 | 42 | super().__init__(unet, loss_type=loss_type) 43 | 44 | self.example_input_array = torch.Tensor(2, in_channels, 256, 256) 45 | self.save_hyperparameters() 46 | 47 | 48 | class AttentionBlock(nn.Module): 49 | """Attention block used in the skip connections of the attention U-net. 50 | 51 | :param input_channels: Amount of channels that the encoder layer that is 52 | being skipped has. 53 | :param signal_channels: Amount of channels that the signal has, which is 54 | the output of the previous decoder layer. 55 | :param attention_channels: Amount of channels that the input and signal are 56 | mapped to. 57 | 58 | :input x: [N x input_channels x H x W] 59 | :input signal: [N x signal_channels x H x W] 60 | :output: [N x input_channels x H x W] 61 | 62 | """ 63 | 64 | def __init__( 65 | self, 66 | input_channels: int, 67 | signal_channels: int, 68 | attention_channels: int, 69 | ): 70 | super().__init__() 71 | 72 | self.input_gate = nn.Sequential( 73 | nn.Conv2d(input_channels, attention_channels, kernel_size=1), 74 | nn.BatchNorm2d(attention_channels), 75 | ) 76 | 77 | self.signal_gate = nn.Sequential( 78 | nn.Conv2d(signal_channels, attention_channels, kernel_size=1), 79 | nn.BatchNorm2d(attention_channels), 80 | ) 81 | 82 | self.attention = nn.Sequential( 83 | nn.Conv2d(attention_channels, 1, kernel_size=1), 84 | nn.BatchNorm2d(1), 85 | nn.Sigmoid(), 86 | ) 87 | 88 | self.relu = nn.ReLU() 89 | 90 | def forward(self, x, signal): 91 | h_input = self.input_gate(x) 92 | h_signal = self.signal_gate(signal) 93 | h = self.relu(h_signal + h_input) 94 | attention = self.attention(h) 95 | 96 | return x * attention 97 | 98 | 99 | class AttentionUnet(nn.Module): 100 | """U-net with attention used in the skip-connections. 101 | 102 | :param in_channels: Input channels that can vary if the images are 103 | grayscale or color. 104 | :param out_channels: Input channels that can vary if the images are 105 | grayscale or color. 106 | :param channel_mults: Channel multiples that define the depth and width of 107 | the U-net architecture. 108 | :param dropout: Dropout percentage used in some of the decoder blocks. 
109 | 110 | :input: [N x in_channels x H x W] 111 | :output: [N x out_channels x H x W] 112 | 113 | """ 114 | 115 | def __init__( 116 | self, 117 | in_channels: int = 3, 118 | out_channels: int = 3, 119 | channel_mults: tuple[int] = (1, 2, 4, 8, 8, 8, 8, 8), 120 | dropout: float = 0.5, 121 | ): 122 | super().__init__() 123 | 124 | # Encoder blocks 125 | encoders = [ 126 | nn.Conv2d( 127 | in_channels, 128 | channel_mults[0] * 64, 129 | kernel_size=4, 130 | stride=2, 131 | padding=1 132 | ), 133 | ] 134 | in_channels = channel_mults[0] * 64 135 | for level, mult in enumerate(channel_mults[1:], 1): 136 | channels = mult * 64 137 | 138 | encoders.append( 139 | EncoderBlock( 140 | in_channels, 141 | channels, 142 | norm=level != len(channel_mults) - 1, 143 | ) 144 | ) 145 | 146 | in_channels = channels 147 | 148 | self.encoders = nn.ModuleList(encoders) 149 | 150 | # Decoder and attention blocks 151 | decoders = [] 152 | attention_blocks = [] 153 | for level, mult in reversed(list(enumerate(channel_mults[:-1]))): 154 | channels = mult * 64 155 | 156 | decoders.append( 157 | DecoderBlock( 158 | in_channels, 159 | channels, 160 | # Only dropout in the lowest three decoder blocks that are 161 | # at the widest part 162 | dropout=dropout if ( 163 | mult == max(channel_mults) and 164 | level > len(channel_mults) - 5 165 | ) else 0, 166 | ) 167 | ) 168 | attention_blocks.append( 169 | AttentionBlock(channels, channels, channels // 2) 170 | ) 171 | 172 | in_channels = channels * 2 173 | 174 | decoders.append( 175 | nn.ConvTranspose2d( 176 | in_channels, 177 | out_channels, 178 | kernel_size=4, 179 | stride=2, 180 | padding=1, 181 | ) 182 | ) 183 | 184 | self.decoders = nn.ModuleList(decoders) 185 | self.attention_blocks = nn.ModuleList(attention_blocks) 186 | self.out = nn.Tanh() 187 | 188 | def forward(self, x): 189 | h = x.type(torch.float32) 190 | 191 | feats = [] 192 | for encoder in self.encoders: 193 | h = encoder(h) 194 | feats.append(h) 195 | 196 | # Remove last feature map, since that should not be used in 197 | # skip-connection 198 | feats.pop() 199 | 200 | for index, decoder in enumerate(self.decoders): 201 | if index != 0: 202 | attention = self.attention_blocks[index - 1] 203 | s = attention(feats.pop(), h) 204 | h = torch.cat([h, s], dim=1) 205 | 206 | h = decoder(h) 207 | 208 | return self.out(h) 209 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import argparse 4 | from argparse import ArgumentParser 5 | import pathlib 6 | from models.pix2pix import Pix2Pix 7 | from models.palette import Palette 8 | from models.attention_unet import AttentionUnetGAN 9 | from models.res_unet import ResUnetGAN 10 | from models.trans_unet import TransUnetGAN 11 | from dataset import ImageDataModule 12 | from callbacks.ema import EMACallback 13 | 14 | 15 | torch.set_float32_matmul_precision("medium") 16 | 17 | 18 | def main(hparams): 19 | channel_mults = [int(x) for x in hparams.channel_mults.split(",")] 20 | att_res = [int(x) for x in hparams.attention_res.split(",")] 21 | 22 | model = None 23 | match hparams.model: 24 | case "pix2pix": 25 | model = Pix2Pix( 26 | in_channels=1, 27 | out_channels=1, 28 | channel_mults=channel_mults, 29 | dropout=hparams.dropout, 30 | loss_type=hparams.loss_type, 31 | ) 32 | 33 | case "attention_unet": 34 | model = AttentionUnetGAN( 35 | in_channels=1, 36 | out_channels=1, 37 | 
channel_mults=channel_mults, 38 | dropout=hparams.dropout, 39 | loss_type=hparams.loss_type, 40 | ) 41 | 42 | case "palette": 43 | model = Palette( 44 | in_channels=1, 45 | out_channels=1, 46 | channel_mults=channel_mults, 47 | attention_res=att_res, 48 | dropout=hparams.dropout, 49 | schedule_type=hparams.schedule_type, 50 | learn_var=hparams.learn_variance, 51 | ) 52 | 53 | case "res18_unet": 54 | model = ResUnetGAN( 55 | in_channels=1, 56 | out_channels=1, 57 | res_type="18", 58 | channel_mults=channel_mults, 59 | dropout=hparams.dropout, 60 | loss_type=hparams.loss_type, 61 | ) 62 | 63 | case "res50_unet": 64 | model = ResUnetGAN( 65 | in_channels=1, 66 | out_channels=1, 67 | res_type="50", 68 | channel_mults=channel_mults, 69 | dropout=hparams.dropout, 70 | loss_type=hparams.loss_type, 71 | ) 72 | 73 | case "resv2_unet": 74 | model = ResUnetGAN( 75 | in_channels=1, 76 | out_channels=1, 77 | res_type="v2", 78 | channel_mults=channel_mults, 79 | dropout=hparams.dropout, 80 | loss_type=hparams.loss_type, 81 | ) 82 | 83 | case "resnext_unet": 84 | model = ResUnetGAN( 85 | in_channels=1, 86 | out_channels=1, 87 | res_type="next", 88 | channel_mults=channel_mults, 89 | dropout=hparams.dropout, 90 | loss_type=hparams.loss_type, 91 | ) 92 | 93 | case "trans_unet": 94 | model = TransUnetGAN( 95 | in_channels=1, 96 | out_channels=1, 97 | patch_size=4, 98 | channel_mults=channel_mults, 99 | dropout=hparams.dropout, 100 | loss_type=hparams.loss_type, 101 | ) 102 | 103 | case _: 104 | raise ValueError(f"Incorrect model name ({hparams.model})") 105 | 106 | data_module = ImageDataModule( 107 | hparams.data, 108 | hparams.val_data, 109 | batch_size=hparams.batch_size, 110 | normalize=True, 111 | ) 112 | 113 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 114 | save_top_k=1, 115 | monitor="val_ssim", 116 | mode="max", 117 | filename="best", 118 | save_last=model == "palette", 119 | ) 120 | 121 | csv_logger = pl.loggers.CSVLogger("logs", name=hparams.name) 122 | 123 | trainer = pl.Trainer( 124 | max_epochs=hparams.epochs, 125 | max_steps=hparams.steps, 126 | log_every_n_steps=10, 127 | check_val_every_n_epoch=hparams.val_epochs, 128 | logger=[csv_logger], 129 | precision=hparams.precision, 130 | callbacks=[ 131 | EMACallback(0.9999), 132 | checkpoint_callback, 133 | ] if hparams.ema else [checkpoint_callback], 134 | benchmark=True, 135 | ) 136 | trainer.fit(model, data_module) 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = ArgumentParser() 141 | parser.add_argument("name") 142 | parser.add_argument( 143 | "-d", 144 | "--data", 145 | type=pathlib.Path, 146 | help=""" 147 | YAML file containing filenames of images that make up the training 148 | data. 149 | """, 150 | ) 151 | parser.add_argument( 152 | "-vd", 153 | "--val-data", 154 | type=pathlib.Path, 155 | help=""" 156 | YAML file containing filenames of images that make up the 157 | validation data. 158 | """, 159 | ) 160 | parser.add_argument("-e", "--epochs", default=200, type=int) 161 | parser.add_argument("-s", "--steps", default=-1, type=int) 162 | parser.add_argument("--batch-size", default=8, type=int) 163 | parser.add_argument( 164 | "--val-epochs", 165 | default=10, 166 | help="Validation run every n epochs.", 167 | type=int 168 | ) 169 | parser.add_argument( 170 | "--precision", 171 | default="32", 172 | help="Floating-point precision" 173 | ) 174 | parser.add_argument( 175 | "--ema", 176 | default=False, 177 | action=argparse.BooleanOptionalAction, 178 | help="Whether to use EMA weight updating." 
179 | ) 180 | parser.add_argument( 181 | "--channel-mults", 182 | default="1,2,4,8,8,8,8,8", 183 | help=""" 184 | Defines the U-net architecture's depth and width. Should be 185 | comma-separated powers of 2. 186 | """, 187 | ) 188 | parser.add_argument( 189 | "--attention-res", 190 | default="8,4,2", 191 | help=""" 192 | At what downsample multiples attention should be used, if the model 193 | supports it. Should be comma-separated powers of 2. 194 | """, 195 | ) 196 | parser.add_argument( 197 | "--dropout", 198 | default=0.0, 199 | type=float, 200 | ) 201 | parser.add_argument( 202 | "--loss-type", 203 | default="gan", 204 | choices=["gan", "ssim", "psnr", "ssim+psnr", "mse"], 205 | ) 206 | parser.add_argument( 207 | "--schedule-type", 208 | default="linear", 209 | choices=["linear", "cosine"], 210 | ) 211 | parser.add_argument( 212 | "--learn-variance", 213 | default=False, 214 | action=argparse.BooleanOptionalAction, 215 | ) 216 | parser.add_argument( 217 | "-m", 218 | "--model", 219 | default="pix2pix", 220 | choices=[ 221 | "pix2pix", 222 | "attention_unet", 223 | "res18_unet", 224 | "res50_unet", 225 | "resv2_unet", 226 | "resnext_unet", 227 | "trans_unet", 228 | "palette", 229 | ], 230 | ) 231 | args = parser.parse_args() 232 | 233 | main(args) 234 | -------------------------------------------------------------------------------- /models/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pytorch_lightning as pl 5 | from typing import Literal 6 | from .utils import denormalize, init_weights, ssim, psnr, rmse 7 | 8 | 9 | class UnetWrapper(pl.LightningModule): 10 | """U-net wrapper with different loss functions. 11 | 12 | :param unet: U-net model. 13 | :param loss_type: Loss function for the U-net. One of "gan", "ssim", 14 | "psnr", "mse", "ssim+psnr". 15 | 16 | :input: [N x C x H x W] 17 | :output: [N x C x H x W] 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | unet: nn.Module, 24 | loss_type: Literal["gan", "ssim", "psnr", "ssim+psnr" "mse"] = "gan", 25 | ): 26 | super().__init__() 27 | self.automatic_optimization = False 28 | 29 | self.unet = unet 30 | self.loss_type = loss_type 31 | 32 | self.discriminator = None 33 | if loss_type == "gan": 34 | self.discriminator = Discriminator() 35 | self.discriminator.apply(init_weights) 36 | 37 | self.unet.apply(init_weights) 38 | 39 | def forward(self, x): 40 | return self.unet(x) 41 | 42 | def loss(self, x, pred, target): 43 | if self.loss_type == "gan": 44 | pred_label = self.discriminator(x, pred) 45 | bce_loss = F.binary_cross_entropy_with_logits( 46 | pred_label, 47 | torch.ones_like(pred_label), 48 | ) 49 | l1_loss = F.l1_loss(pred, target) 50 | 51 | return bce_loss + 50 * l1_loss 52 | 53 | if self.loss_type == "ssim": 54 | return -ssim(denormalize(pred), denormalize(target)) 55 | 56 | if self.loss_type == "psnr": 57 | return -psnr(denormalize(pred), denormalize(target)) 58 | 59 | if self.loss_type == "ssim+psnr": 60 | return -( 61 | 30 * ssim(denormalize(pred), denormalize(target)) + 62 | psnr(denormalize(pred), denormalize(target)) 63 | ) 64 | 65 | if self.loss_type == "mse": 66 | return F.mse_loss(pred, target) 67 | 68 | def discriminator_loss( 69 | self, 70 | pred_label: torch.Tensor, 71 | target_label: torch.Tensor, 72 | ) -> torch.Tensor: 73 | """ 74 | Loss function for discriminator. 75 | 76 | :param pred_label: predicted label of generated image by discriminator. 
77 | :param target_label: predicted label of real target image by 78 | discriminator. 79 | :returns: Loss. 80 | 81 | """ 82 | 83 | # The discriminator should predict all zeros for "fake" images 84 | pred_loss = F.binary_cross_entropy_with_logits( 85 | pred_label, 86 | torch.zeros_like(pred_label), 87 | ) 88 | 89 | # The discriminator should predict all ones for "real" images 90 | target_loss = F.binary_cross_entropy_with_logits( 91 | target_label, 92 | torch.ones_like(pred_label), 93 | ) 94 | 95 | return pred_loss + target_loss 96 | 97 | def configure_optimizers(self): 98 | opt_g = torch.optim.Adam( 99 | self.unet.parameters(), 100 | lr=2e-4, 101 | betas=(0.5, 0.999), 102 | eps=1e-7, 103 | ) 104 | 105 | if self.discriminator is not None: 106 | opt_d = torch.optim.Adam( 107 | self.discriminator.parameters(), 108 | lr=2e-4, 109 | betas=(0.5, 0.999), 110 | eps=1e-7, 111 | ) 112 | 113 | return opt_g, opt_d 114 | 115 | return opt_g 116 | 117 | def training_step(self, batch, batch_idx): 118 | x, target = batch 119 | 120 | if self.loss_type == "gan": 121 | opt_d = self.optimizers()[1] 122 | 123 | # Train discriminator. 124 | self.toggle_optimizer(opt_d) 125 | 126 | pred = self.unet(x) 127 | 128 | target_label = self.discriminator(x, target) 129 | pred_label = self.discriminator(x, pred) 130 | d_loss = self.discriminator_loss(pred_label, target_label) 131 | 132 | self.log("d_loss", d_loss, prog_bar=True) 133 | 134 | self.discriminator.zero_grad(set_to_none=True) 135 | self.manual_backward(d_loss) 136 | opt_d.step() 137 | 138 | self.untoggle_optimizer(opt_d) 139 | 140 | opt_g = self.optimizers() 141 | if isinstance(opt_g, list): 142 | opt_g = opt_g[0] 143 | 144 | # Train U-net 145 | self.toggle_optimizer(opt_g) 146 | 147 | pred = self.unet(x) 148 | loss = self.loss(x, pred, target) 149 | 150 | den_pred = denormalize(pred) 151 | den_target = denormalize(target) 152 | 153 | self.log("loss", loss, prog_bar=True) 154 | self.log("train_ssim", ssim(den_pred, den_target), prog_bar=True) 155 | self.log("train_psnr", psnr(den_pred, den_target), prog_bar=True) 156 | self.log("train_rmse", rmse(den_pred, den_target), prog_bar=True) 157 | 158 | self.unet.zero_grad(set_to_none=True) 159 | self.manual_backward(loss) 160 | opt_g.step() 161 | 162 | self.untoggle_optimizer(opt_g) 163 | 164 | def validation_step(self, batch, batch_idx): 165 | x, target = batch 166 | pred = self.forward(x) 167 | 168 | den_pred = denormalize(pred) 169 | den_target = denormalize(target) 170 | 171 | self.log("val_ssim", ssim(den_pred, den_target), prog_bar=True) 172 | self.log("val_psnr", psnr(den_pred, den_target), prog_bar=True) 173 | self.log("val_rmse", rmse(den_pred, den_target), prog_bar=True) 174 | 175 | 176 | class DiscriminatorBlock(nn.Module): 177 | """An encoder block that is used in the discriminator. 178 | 179 | :param in_channels: Input channels. 180 | :param out_channels: Output channels. 181 | :param norm: Whether to use normalization or not. 
182 | 183 | :input: [N x in_channels x H x W] 184 | :output: [N x out_channels x (H / 2) x (W / 2)] 185 | 186 | """ 187 | 188 | def __init__( 189 | self, 190 | in_channels: int, 191 | out_channels: int, 192 | norm: bool = False, 193 | ): 194 | super().__init__() 195 | 196 | self.block = nn.Sequential( 197 | nn.Conv2d( 198 | in_channels, 199 | out_channels, 200 | kernel_size=4, 201 | stride=2, 202 | padding=1 203 | ), 204 | nn.InstanceNorm2d(out_channels) if norm else nn.Identity(), 205 | nn.LeakyReLU(0.2), 206 | ) 207 | 208 | def forward(self, x): 209 | return self.block(x) 210 | 211 | 212 | class Discriminator(nn.Module): 213 | """Discriminator that distinguishes between real and fake images. This 214 | particular discriminator is used in all GANs in this repository. 215 | 216 | :param in_channels: Input channels for both the input and conditional 217 | image. 218 | 219 | :input x: [N x in_channels x H x W] 220 | :input y: [N x in_channels x H x W] 221 | :output: [1 x OUT x OUT] 222 | 223 | """ 224 | 225 | def __init__(self, in_channels: int = 3): 226 | super().__init__() 227 | 228 | self.discriminator = nn.Sequential( 229 | DiscriminatorBlock(in_channels * 2, 64, norm=False), 230 | DiscriminatorBlock(64, 128), 231 | DiscriminatorBlock(128, 256), 232 | DiscriminatorBlock(256, 512), 233 | nn.Conv2d(512, 1, kernel_size=4, padding=1, bias=False), 234 | ) 235 | 236 | def forward(self, x, y): 237 | h = torch.cat([x, y], dim=1) 238 | return self.discriminator(h) 239 | -------------------------------------------------------------------------------- /models/trans_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | import math 5 | from typing import Literal 6 | from .wrapper import UnetWrapper 7 | 8 | 9 | class TransUnetGAN(UnetWrapper): 10 | def __init__( 11 | self, 12 | in_channels: int = 3, 13 | out_channels: int = 3, 14 | channel_mults: tuple[int] = (1, 2, 2, 4, 4), 15 | patch_size: int = 2, 16 | dropout: float = 0.5, 17 | loss_type: Literal["gan", "ssim", "psnr", "ssim+psnr", "mse"] = "gan", 18 | ): 19 | unet = TransUnet( 20 | in_channels, 21 | out_channels, 22 | image_size=256, 23 | channel_mults=channel_mults, 24 | patch_size=patch_size, 25 | num_heads=8, 26 | dropout=dropout, 27 | ) 28 | 29 | super().__init__(unet, loss_type=loss_type) 30 | 31 | self.example_input_array = torch.Tensor(2, in_channels, 256, 256) 32 | self.save_hyperparameters() 33 | 34 | 35 | class TransUnet(nn.Module): 36 | """Trans U-net. 37 | 38 | :param in_channels: Input channels that can vary if the images are 39 | grayscale or color. 40 | :param out_channels: Input channels that can vary if the images are 41 | grayscale or color. 42 | :param image_size: Input image size. 43 | :param channel_mults: Define how deep and wide the U-net is. 
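:param patch_size: Side length of the square patches fed to the ViT bottleneck.
    :param num_heads: Number of attention heads in each transformer encoder layer.
    :param dropout: Dropout used inside the transformer encoder layers.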
44 | 45 | :input: [N x in_channels x image_size x image_size] 46 | :output: [N x out_channels x image_size x image_size] 47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | in_channels: int = 3, 53 | out_channels: int = 3, 54 | image_size: int = 256, 55 | channel_mults: tuple[int] = (1, 2, 4, 8), 56 | patch_size: int = 16, 57 | num_heads: int = 8, 58 | dropout: float = 0.5, 59 | ): 60 | super().__init__() 61 | 62 | self.in_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1) 63 | in_channels = 64 64 | 65 | # Encoder blocks 66 | encoders = [] 67 | for mult in channel_mults: 68 | channels = mult * 64 69 | encoders.append(EncoderBlock(in_channels, channels)) 70 | in_channels = channels 71 | 72 | self.encoders = nn.ModuleList(encoders) 73 | 74 | # Vision transformer bottleneck 75 | self.vit_bottleneck = VisionTransformer( 76 | channels=channel_mults[-1] * 64, 77 | input_size=image_size // (2 ** len(channel_mults)), 78 | patch_size=patch_size, 79 | num_heads=num_heads, 80 | dropout=dropout, 81 | transformer_layers=12, 82 | ) 83 | 84 | # Decoder blocks 85 | decoders = [] 86 | for mult in reversed(list(channel_mults[:-1])): 87 | channels = mult * 64 88 | decoders.append(DecoderBlock(in_channels, channels)) 89 | in_channels = channels * 2 90 | 91 | decoders.append(DecoderBlock(in_channels, 64)) 92 | self.decoders = nn.ModuleList(decoders) 93 | 94 | self.out = nn.Sequential( 95 | nn.Conv2d(64, out_channels, kernel_size=3, padding=1), 96 | nn.Tanh(), 97 | ) 98 | 99 | def forward(self, x): 100 | h = self.in_conv(x.type(torch.float32)) 101 | 102 | skips = [] 103 | for encoder in self.encoders: 104 | h = encoder(h) 105 | skips.append(h) 106 | 107 | skips.pop() 108 | 109 | h = self.vit_bottleneck(h) 110 | 111 | for index, decoder in enumerate(self.decoders): 112 | if index != 0: 113 | h = torch.cat([h, skips.pop()], dim=1) 114 | 115 | h = decoder(h) 116 | 117 | return self.out(h) 118 | 119 | 120 | class VisionTransformer(nn.Module): 121 | def __init__( 122 | self, 123 | channels: int, 124 | input_size: int, 125 | patch_size: int = 16, 126 | num_heads: int = 8, 127 | dropout: float = 0.5, 128 | transformer_layers: int = 12, 129 | ): 130 | super().__init__() 131 | 132 | patch_dim = channels * patch_size * patch_size 133 | num_patches = (input_size ** 2) // (patch_size ** 2) 134 | 135 | self.to_patch_embedding = nn.Sequential( 136 | # Get flattened patches 137 | Rearrange( 138 | "n c (h p1) (w p2) -> n (h w) (p1 p2 c)", 139 | p1=patch_size, 140 | p2=patch_size, 141 | ), 142 | nn.LayerNorm(patch_dim), 143 | nn.Linear(patch_dim, patch_dim), 144 | nn.LayerNorm(patch_dim), 145 | ) 146 | 147 | self.pos_embedding = nn.Parameter( 148 | torch.randn(1, num_patches, patch_dim) 149 | ) 150 | 151 | trans_enc_layer = nn.TransformerEncoderLayer( 152 | patch_dim, 153 | num_heads, 154 | dropout=dropout, 155 | activation="gelu", 156 | ) 157 | 158 | self.transformer = nn.TransformerEncoder( 159 | trans_enc_layer, 160 | transformer_layers, 161 | ) 162 | 163 | self.to_image = Rearrange( 164 | "n (h w) (p1 p2 c) -> n c (h p1) (w p2)", 165 | h=int(math.sqrt(num_patches)), 166 | w=int(math.sqrt(num_patches)), 167 | p1=patch_size, 168 | p2=patch_size, 169 | ) 170 | 171 | def forward(self, x): 172 | patch_emb = self.to_patch_embedding(x) 173 | patch_emb += self.pos_embedding 174 | patch_emb = self.transformer(patch_emb) 175 | return self.to_image(patch_emb) 176 | 177 | 178 | class EncoderBlock(nn.Module): 179 | """Encoder block that downsamples the input by 2. 
Basically just a residual 180 | block as used in ResNet-50, ResNet-101, and ResNet-152 with a downsample. 181 | 182 | :param in_channels: Input channels. 183 | :param out_channels: Output channels. 184 | 185 | :input: [N x in_channels x image_size x image_size] 186 | :output: [N x out_channels x image_size x image_size] 187 | 188 | """ 189 | 190 | def __init__(self, in_channels: int, out_channels: int): 191 | super().__init__() 192 | 193 | bottleneck = in_channels // 4 194 | 195 | self.decode = nn.Sequential( 196 | nn.Conv2d(in_channels, bottleneck, kernel_size=1, bias=False), 197 | nn.BatchNorm2d(bottleneck), 198 | nn.ReLU(), 199 | nn.Conv2d( 200 | bottleneck, 201 | bottleneck, 202 | kernel_size=3, 203 | stride=2, 204 | padding=1, 205 | bias=False, 206 | ), 207 | nn.BatchNorm2d(bottleneck), 208 | nn.ReLU(), 209 | nn.Conv2d(bottleneck, out_channels, kernel_size=1, bias=False), 210 | nn.BatchNorm2d(out_channels), 211 | ) 212 | 213 | self.skip = nn.Sequential( 214 | nn.Conv2d( 215 | in_channels, 216 | out_channels, 217 | kernel_size=1, 218 | stride=2, 219 | bias=False, 220 | ), 221 | nn.BatchNorm2d(out_channels), 222 | ) 223 | 224 | self.out = nn.ReLU() 225 | 226 | def forward(self, x): 227 | return self.out(self.decode(x) + self.skip(x)) 228 | 229 | 230 | class DecoderBlock(nn.Module): 231 | """Decoder block that upsamples the input by 2. 232 | 233 | :param in_channels: Input channels. 234 | :param out_channels: Output channels. 235 | 236 | :input: [N x in_channels x H x W] 237 | :output: [N x out_channels x H x W] 238 | 239 | """ 240 | 241 | def __init__(self, in_channels: int, out_channels: int): 242 | super().__init__() 243 | 244 | self.decode = nn.Sequential( 245 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 246 | nn.BatchNorm2d(out_channels), 247 | nn.ReLU(), 248 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 249 | nn.BatchNorm2d(out_channels), 250 | nn.ReLU(), 251 | nn.Upsample(scale_factor=2), 252 | ) 253 | 254 | def forward(self, x): 255 | return self.decode(x) 256 | -------------------------------------------------------------------------------- /report.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchmetrics.functional import ( 4 | peak_signal_noise_ratio as psnr, 5 | structural_similarity_index_measure as ssim, 6 | mean_squared_error as mse, 7 | ) 8 | from matplotlib import colormaps 9 | from torchvision.io import write_png 10 | from argparse import ArgumentParser 11 | import pathlib 12 | import os 13 | from fvcore.nn import FlopCountAnalysis 14 | from models.pix2pix import Pix2Pix 15 | from models.palette import Palette 16 | from models.attention_unet import AttentionUnetGAN 17 | from models.res_unet import ResUnetGAN 18 | from models.trans_unet import TransUnetGAN 19 | from models.utils import denormalize, to_int, get_parameter_count 20 | from dataset import ImageDataModule 21 | 22 | 23 | def main(hparams): 24 | match hparams.model: 25 | case "pix2pix": 26 | model = Pix2Pix.load_from_checkpoint(hparams.checkpoint) 27 | model.freeze() 28 | 29 | case "palette": 30 | model = Palette.load_from_checkpoint(hparams.checkpoint) 31 | model.freeze() 32 | 33 | case "attention_unet": 34 | model = AttentionUnetGAN.load_from_checkpoint(hparams.checkpoint) 35 | model.freeze() 36 | 37 | case "res18_unet" | "res50_unet" | "resv2_unet" | "resnext_unet": 38 | model = ResUnetGAN.load_from_checkpoint(hparams.checkpoint) 39 | model.freeze() 40 | 41 | case "trans_unet": 
42 | model = TransUnetGAN.load_from_checkpoint(hparams.checkpoint) 43 | model.freeze() 44 | 45 | case "identity": 46 | def model(x): return x 47 | 48 | case _: 49 | raise ValueError(f"Incorrect model name ({hparams.model})") 50 | 51 | if isinstance(model, nn.Module): 52 | device = model.device 53 | else: 54 | device = "cpu" 55 | 56 | data_module = ImageDataModule( 57 | hparams.data, 58 | batch_size=hparams.batch_size, 59 | ) 60 | data_module.setup("predict") 61 | dataloader = data_module.predict_dataloader() 62 | 63 | preds = [denormalize(model(batch[0].to(device))) for batch in dataloader] 64 | preds = torch.cat(preds, axis=0) 65 | preds = preds.cpu() 66 | # preds = denormalize(preds).cpu() 67 | 68 | targets = [denormalize(batch[1]) for batch in dataloader] 69 | targets = torch.cat(targets, axis=0) 70 | targets = targets.cpu() 71 | 72 | # Compute SSIM, PSNR, and MSE per image 73 | psnrs = [] 74 | ssims = [] 75 | mses = [] 76 | ssim_images = [] 77 | for pred, target in zip(preds.split(64), targets.split(64)): 78 | current_ssim, current_ssim_images = ssim( 79 | pred, 80 | target, 81 | data_range=1.0, 82 | return_full_image=True, 83 | reduction="none", 84 | ) 85 | ssims.append(current_ssim) 86 | ssim_images.append(current_ssim_images) 87 | 88 | current_psnr = torch.tensor([ 89 | psnr(p, t, data_range=1.0) for p, t in zip(pred, target) 90 | ]) 91 | psnrs.append(current_psnr) 92 | 93 | current_mse = torch.tensor([ 94 | mse(p, t) for p, t in zip(pred, target) 95 | ]) 96 | mses.append(current_mse) 97 | 98 | ssims = torch.cat(ssims) 99 | ssim_images = torch.cat(ssim_images) 100 | psnrs = torch.cat(psnrs) 101 | mses = torch.cat(mses) 102 | 103 | # Output average SSIM over depth and standard deviation 104 | ssim_over_depth = depth_ssim(preds, targets) 105 | ssim_over_depth_string = "depth,mean,std\n" 106 | for depth, (mean, std) in enumerate(ssim_over_depth, 1): 107 | ssim_over_depth_string += f"{depth},{mean},{std}\n" 108 | 109 | report_dir = os.path.join("reports", hparams.name) 110 | 111 | if not os.path.isdir(report_dir): 112 | os.mkdir(report_dir) 113 | 114 | with open(os.path.join(report_dir, "depth_ssim.csv"), "w") as f: 115 | f.write(ssim_over_depth_string) 116 | 117 | # Output prediction images 118 | outputs_dir = os.path.join(report_dir, "outputs") 119 | if not os.path.isdir(outputs_dir): 120 | os.mkdir(outputs_dir) 121 | 122 | for index, pred in enumerate(preds): 123 | output_hot_image( 124 | pred, 125 | os.path.join(outputs_dir, f"{str(index).zfill(5)}.png"), 126 | ) 127 | 128 | # Output SSIM maps 129 | ssim_images_dir = os.path.join(report_dir, "ssim_images") 130 | if not os.path.isdir(ssim_images_dir): 131 | os.mkdir(ssim_images_dir) 132 | 133 | for index, ssim_image in enumerate(ssim_images): 134 | write_png( 135 | to_int(ssim_image), 136 | os.path.join( 137 | report_dir, 138 | "ssim_images", 139 | f"{str(index).zfill(5)}.png", 140 | ) 141 | ) 142 | 143 | # Output mean statistics over entire test dataset 144 | ssim_stat = ssims.mean() 145 | psnr_stat = psnrs.mean() 146 | rmse_stat = mse(preds, targets, squared=False) 147 | parameter_count = get_parameter_count(model) 148 | 149 | # Count FLOPs 150 | flops = 0 151 | if isinstance(model, nn.Module): 152 | input_ = torch.randn(1, 3, 256, 256).to(device) 153 | flops = FlopCountAnalysis(model, input_) 154 | flops = flops.total() 155 | 156 | with open(os.path.join(report_dir, "stats.txt"), "w") as f: 157 | f.write(f"SSIM: {ssim_stat}\n") 158 | f.write(f"PSNR: {psnr_stat}\n") 159 | f.write(f"RMSE: {rmse_stat}\n") 160 | f.write(f"FLOPs: 
{flops}\n")
161 |         f.write(f"Parameter count: {parameter_count}\n")
162 | 
163 |     # Output SSIM per image
164 |     ssim_per_image_string = "image,ssim\n"
165 |     for index, image_ssim in enumerate(ssims):
166 |         ssim_per_image_string += f"{str(index).zfill(5)},{image_ssim}\n"
167 | 
168 |     with open(os.path.join(report_dir, "ssim_per_image.csv"), "w") as f:
169 |         f.write(ssim_per_image_string)
170 | 
171 |     # Output PSNR per image
172 |     psnr_per_image_string = "image,psnr\n"
173 |     for index, image_psnr in enumerate(psnrs):
174 |         psnr_per_image_string += f"{str(index).zfill(5)},{image_psnr}\n"
175 | 
176 |     with open(os.path.join(report_dir, "psnr_per_image.csv"), "w") as f:
177 |         f.write(psnr_per_image_string)
178 | 
179 |     # Output MSE per image
180 |     mse_per_image_string = "image,mse\n"
181 |     for index, image_mse in enumerate(mses):
182 |         mse_per_image_string += f"{str(index).zfill(5)},{image_mse}\n"
183 | 
184 |     with open(os.path.join(report_dir, "mse_per_image.csv"), "w") as f:
185 |         f.write(mse_per_image_string)
186 | 
187 | 
188 | def depth_ssim(
189 |     preds: torch.Tensor,
190 |     targets: torch.Tensor,
191 |     num_depths: int = 16
192 | ) -> torch.Tensor:
193 |     """Compute the mean and standard deviation of SSIM over the depth of the
194 |     images. The depth axis corresponds to the y-axis (height) of the image.
195 | 
196 |     :param preds: [N x C x H x W]
197 |     :param targets: [N x C x H x W]
198 |     :returns: [num_depths x 2] of (mean, std) pairs, one per depth
199 | 
200 |     """
201 | 
202 |     x_depths = preds.chunk(num_depths, dim=2)
203 |     y_depths = targets.chunk(num_depths, dim=2)
204 | 
205 |     ssims = []
206 |     for depth in range(num_depths):
207 |         ssim_values = ssim(
208 |             x_depths[depth],
209 |             y_depths[depth],
210 |             data_range=1.0,
211 |             reduction="none",
212 |         )
213 |         mean = ssim_values.mean()
214 |         std = ssim_values.std()
215 |         ssims.append((mean, std))
216 | 
217 |     return torch.tensor(ssims)
218 | 
219 | 
220 | def output_hot_image(img: torch.Tensor, filename: str):
221 |     """Writes a heat-map image using the matplotlib afmhot colormap.
222 | 
223 |     :arg img: [1 x H x W]
224 |     :arg filename: File location to save output.
225 | 
226 |     """
227 | 
228 |     colormap = colormaps["afmhot"]
229 |     img = colormap(img)
230 |     img = img[0, :, :, :3]
231 |     img = torch.Tensor(img)
232 |     img = torch.permute(img, (2, 0, 1))
233 |     write_png(to_int(img), filename)
234 | 
235 | 
236 | if __name__ == "__main__":
237 |     parser = ArgumentParser()
238 |     parser.add_argument("name")
239 |     parser.add_argument(
240 |         "-c",
241 |         "--checkpoint",
242 |         type=pathlib.Path,
243 |         help="Path to checkpoint",
244 |     )
245 |     parser.add_argument(
246 |         "-d",
247 |         "--data",
248 |         type=pathlib.Path,
249 |         help="YAML file of all data points",
250 |     )
251 |     parser.add_argument("-bs", "--batch-size", default=2, type=int)
252 |     parser.add_argument(
253 |         "-m",
254 |         "--model",
255 |         default="pix2pix",
256 |         choices=[
257 |             "pix2pix",
258 |             "attention_unet",
259 |             "res18_unet",
260 |             "res50_unet",
261 |             "resv2_unet",
262 |             "resnext_unet",
263 |             "trans_unet",
264 |             "palette",
265 |             "identity",
266 |         ],
267 |     )
268 |     args = parser.parse_args()
269 | 
270 |     main(args)
271 | 
--------------------------------------------------------------------------------
/models/res_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Literal
4 | from .wrapper import UnetWrapper
5 | 
6 | 
7 | ResType = Literal["18", "50", "v2", "next"]
8 | 
9 | 
10 | class ResUnetGAN(UnetWrapper):
11 |     """Implementation of residual U-net.
12 | 
13 |     :param in_channels: Input channels that can vary if the images are
14 |         grayscale or color.
15 |     :param out_channels: Output channels that can vary if the images are
16 |         grayscale or color.
17 |     :param res_type: Which residual block to use.
18 |     :param channel_mults: Channel multipliers that define the depth and width of
19 |         the U-net architecture.
20 |     :param dropout: Dropout percentage used in some of the decoder blocks.
21 |     :param loss_type: Loss type. One of "gan", "ssim", "psnr", "mse",
22 |         "ssim+psnr".
23 | 
24 |     :input: [N x in_channels x H x W]
25 |     :output: [N x out_channels x H x W]
26 | 
27 |     """
28 | 
29 |     def __init__(
30 |         self,
31 |         in_channels: int = 3,
32 |         out_channels: int = 3,
33 |         res_type: ResType = "18",
34 |         channel_mults: tuple[int, ...] = (1, 2, 4, 8, 8, 8, 8, 8),
35 |         dropout: float = 0.5,
36 |         loss_type: Literal["gan", "ssim", "psnr", "ssim+psnr", "mse"] = "gan",
37 |     ):
38 |         unet = ResUnet(
39 |             in_channels,
40 |             out_channels,
41 |             res_type,
42 |             channel_mults=channel_mults,
43 |             dropout=dropout,
44 |         )
45 | 
46 |         super().__init__(unet, loss_type=loss_type)
47 | 
48 |         self.example_input_array = torch.Tensor(2, in_channels, 256, 256)
49 |         self.save_hyperparameters()
50 | 
51 | 
52 | class ResidualBlock18(nn.Module):
53 |     """Residual block as used in ResNet-18 and ResNet-34 (He et al. 2015)."""
54 | 
55 |     def __init__(self, in_channels: int, out_channels: int):
56 |         super().__init__()
57 | 
58 |         self.conv_block = nn.Sequential(
59 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
60 |             nn.BatchNorm2d(out_channels),
61 |             nn.ReLU(),
62 |             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
63 |             nn.BatchNorm2d(out_channels),
64 |         )
65 | 
66 |         self.conv_skip = nn.Sequential(
67 |             nn.Conv2d(in_channels, out_channels, kernel_size=1),
68 |             nn.BatchNorm2d(out_channels),
69 |         ) if in_channels != out_channels else nn.Identity()
70 | 
71 |         self.out = nn.ReLU()
72 | 
73 |     def forward(self, x):
74 |         return self.out(self.conv_block(x) + self.conv_skip(x))
75 | 
76 | 
77 | class ResidualBlock50(nn.Module):
78 |     """Residual block as used in ResNet-50, ResNet-101, and ResNet-152
79 |     (He et al. 2015)."""
80 | 
81 |     def __init__(self, in_channels: int, out_channels: int):
82 |         super().__init__()
83 | 
84 |         bottleneck = in_channels // 4
85 | 
86 |         self.conv_block = nn.Sequential(
87 |             nn.Conv2d(in_channels, bottleneck, kernel_size=1),
88 |             nn.BatchNorm2d(bottleneck),
89 |             nn.ReLU(),
90 |             nn.Conv2d(bottleneck, bottleneck, kernel_size=3, padding=1),
91 |             nn.BatchNorm2d(bottleneck),
92 |             nn.ReLU(),
93 |             nn.Conv2d(bottleneck, out_channels, kernel_size=1),
94 |             nn.BatchNorm2d(out_channels),
95 |         )
96 | 
97 |         self.conv_skip = nn.Sequential(
98 |             nn.Conv2d(in_channels, out_channels, kernel_size=1),
99 |             nn.BatchNorm2d(out_channels),
100 |         ) if in_channels != out_channels else nn.Identity()
101 | 
102 |         self.out = nn.ReLU()
103 | 
104 |     def forward(self, x):
105 |         return self.out(self.conv_block(x) + self.conv_skip(x))
106 | 
107 | 
108 | class ResidualBlockV2(nn.Module):
109 |     """Residual block as used in ResNet V2 (He et al.
2016).""" 110 | 111 | def __init__(self, in_channels: int, out_channels: int): 112 | super().__init__() 113 | 114 | self.conv_block = nn.Sequential( 115 | nn.BatchNorm2d(in_channels), 116 | nn.ReLU(), 117 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 118 | nn.BatchNorm2d(out_channels), 119 | nn.ReLU(), 120 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 121 | ) 122 | 123 | self.conv_skip = nn.Sequential( 124 | nn.BatchNorm2d(in_channels), 125 | nn.ReLU(), 126 | nn.Conv2d(in_channels, out_channels, kernel_size=1), 127 | ) if in_channels != out_channels else nn.Identity() 128 | 129 | def forward(self, x): 130 | return self.conv_block(x) + self.conv_skip(x) 131 | 132 | 133 | class ResidualBlockNeXt(nn.Module): 134 | """Residual block as used in ResNeXt (Xie et al. 2017).""" 135 | 136 | def __init__( 137 | self, 138 | in_channels: int, 139 | out_channels: int, 140 | cardinality: int = 32, 141 | bottleneck: int = 4, 142 | ): 143 | super().__init__() 144 | 145 | inner_width = bottleneck * cardinality 146 | 147 | self.conv_block = nn.Sequential( 148 | nn.Conv2d(in_channels, inner_width, kernel_size=1), 149 | nn.BatchNorm2d(inner_width), 150 | nn.ReLU(), 151 | nn.Conv2d( 152 | inner_width, 153 | inner_width, 154 | kernel_size=3, 155 | padding=1, 156 | groups=cardinality, 157 | ), 158 | nn.BatchNorm2d(inner_width), 159 | nn.ReLU(), 160 | nn.Conv2d(inner_width, out_channels, kernel_size=1), 161 | nn.BatchNorm2d(out_channels), 162 | nn.ReLU() 163 | ) 164 | 165 | self.conv_skip = nn.Sequential( 166 | nn.Conv2d(in_channels, out_channels, kernel_size=1), 167 | nn.BatchNorm2d(out_channels), 168 | ) if in_channels != out_channels else nn.Identity() 169 | 170 | def forward(self, x): 171 | return self.conv_block(x) + self.conv_skip(x) 172 | 173 | 174 | res_blocks: dict[ResType, nn.Module] = { 175 | "18": ResidualBlock18, 176 | "50": ResidualBlock50, 177 | "v2": ResidualBlockV2, 178 | "next": ResidualBlockNeXt, 179 | } 180 | 181 | 182 | class EncoderBlock(nn.Module): 183 | """Encoder block that downsamples the input by 2. 184 | 185 | :param in_channels: Input channels. 186 | :param out_channels: Output channels. 187 | :param res_type: Which residual block to use. 188 | 189 | :input: [N x in_channels x H x W] 190 | :output: [N x out_channels x (H / 2) x (W / 2)] 191 | 192 | """ 193 | 194 | def __init__(self, in_channels: int, out_channels: int, res_type: ResType): 195 | super().__init__() 196 | 197 | self.encode = nn.Sequential( 198 | res_blocks[res_type](in_channels, out_channels), 199 | nn.MaxPool2d(2), 200 | ) 201 | 202 | def forward(self, x): 203 | return self.encode(x) 204 | 205 | 206 | class DecoderBlock(nn.Module): 207 | """Decoder block that upsamples the input by 2. 208 | 209 | :param in_channels: Input channels. 210 | :param out_channels: Output channels. 211 | :param dropout: Dropout percentage. 212 | :param res_type: Which residual block to use. 
213 | 
214 |     :input: [N x in_channels x H x W]
215 |     :output: [N x out_channels x (H * 2) x (W * 2)]
216 | 
217 |     """
218 | 
219 |     def __init__(
220 |         self,
221 |         in_channels: int,
222 |         out_channels: int,
223 |         res_type: ResType,
224 |         dropout: float = 0.0,
225 |     ):
226 |         super().__init__()
227 | 
228 |         self.decode = nn.Sequential(
229 |             res_blocks[res_type](in_channels, out_channels),
230 |             nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
231 |             nn.Upsample(scale_factor=2),
232 |         )
233 | 
234 |     def forward(self, x):
235 |         return self.decode(x)
236 | 
237 | 
238 | class ResUnet(nn.Module):
239 |     """Residual U-net used as the generator in a pix2pix-style GAN.
240 | 
241 |     :param in_channels: Input channels that can vary if the images are
242 |         grayscale or color.
243 |     :param out_channels: Output channels that can vary if the images are
244 |         grayscale or color.
245 |     :param res_type: Which residual block to use.
246 |     :param channel_mults: Channel multipliers that define the depth and width of
247 |         the U-net architecture.
248 |     :param dropout: Dropout percentage used in some of the decoder blocks.
249 | 
250 |     :input: [N x in_channels x H x W]
251 |     :output: [N x out_channels x H x W]
252 | 
253 |     """
254 | 
255 |     def __init__(
256 |         self,
257 |         in_channels: int = 3,
258 |         out_channels: int = 3,
259 |         res_type: ResType = "18",
260 |         channel_mults: tuple[int, ...] = (1, 2, 4, 8, 8, 8, 8, 8),
261 |         dropout: float = 0.5,
262 |     ):
263 |         super().__init__()
264 | 
265 |         self.in_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
266 |         in_channels = 64
267 | 
268 |         # Encoder blocks
269 |         encoders = []
270 |         for level, mult in enumerate(channel_mults):
271 |             channels = mult * 64
272 |             encoders.append(EncoderBlock(in_channels, channels, res_type))
273 |             in_channels = channels
274 | 
275 |         self.encoders = nn.ModuleList(encoders)
276 | 
277 |         # Decoder blocks
278 |         decoders = []
279 |         for level, mult in reversed(list(enumerate(channel_mults[:-1]))):
280 |             channels = mult * 64
281 | 
282 |             decoders.append(
283 |                 DecoderBlock(
284 |                     in_channels,
285 |                     channels,
286 |                     res_type,
287 |                     # Only dropout in the lowest three decoder blocks that are
288 |                     # at the widest part
289 |                     dropout=dropout if (
290 |                         mult == max(channel_mults) and
291 |                         level > len(channel_mults) - 5
292 |                     ) else 0,
293 |                 )
294 |             )
295 | 
296 |             in_channels = channels * 2
297 | 
298 |         decoders.append(
299 |             DecoderBlock(
300 |                 in_channels,
301 |                 channel_mults[0] * 64,
302 |                 res_type,
303 |             )
304 |         )
305 | 
306 |         self.decoders = nn.ModuleList(decoders)
307 |         self.out = nn.Sequential(
308 |             nn.Conv2d(
309 |                 channel_mults[0] * 64,
310 |                 out_channels,
311 |                 kernel_size=3,
312 |                 padding=1,
313 |             ),
314 |             nn.Tanh(),
315 |         )
316 | 
317 |     def forward(self, x):
318 |         h = self.in_conv(x.type(torch.float32))
319 | 
320 |         skips = []
321 |         for encoder in self.encoders:
322 |             h = encoder(h)
323 |             skips.append(h)
324 | 
325 |         # Remove the last feature map, since it should not be used as a
326 |         # skip-connection
327 |         skips.pop()
328 | 
329 |         for index, decoder in enumerate(self.decoders):
330 |             if index != 0:
331 |                 h = torch.cat([h, skips.pop()], dim=1)
332 | 
333 |             h = decoder(h)
334 | 
335 |         return self.out(h)
336 | 
--------------------------------------------------------------------------------
/models/palette.py:
--------------------------------------------------------------------------------
1 | """Implementation of Palette: Image-to-Image Diffusion Models (Saharia et al.,
2 | 2022)."""
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision.io import write_png
8 | import pytorch_lightning as pl
9 | from tqdm import tqdm
10 | import os
11 | import math
12 | from typing import Literal
13 | from .guided_diffusion.unet import UNet
14 | from .utils import denormalize, to_int, ssim, psnr, rmse
15 | 
16 | 
17 | class Palette(pl.LightningModule):
18 |     """
19 |     Palette image-to-image diffusion model.
20 | 
21 |     :param in_channels: Input channels.
22 |     :param out_channels: Output channels.
23 |     :param channel_mults: Channel multipliers for each level of the U-net.
24 |     :param attention_res: Downsample rates at which attention blocks should be
25 |         added after the residual blocks.
26 |     :param num_heads: Number of heads used by all attention layers.
27 |     :param dropout: Dropout percentage.
28 |     :param schedule_type: Noise schedule type. Either cosine or linear.
29 |     :param learn_var: Learn the variance as well.
30 | 
31 |     """
32 | 
33 |     def __init__(
34 |         self,
35 |         in_channels: int = 3,
36 |         out_channels: int = 3,
37 |         channel_mults: tuple[int, ...] = (1, 1, 2, 2, 4, 4),
38 |         attention_res: tuple[int, ...] = (16, 8),
39 |         dropout: float = 0.1,
40 |         schedule_type: Literal["linear", "cosine"] = "linear",
41 |         learn_var: bool = False,
42 |     ):
43 |         super().__init__()
44 |         self.save_hyperparameters()
45 | 
46 |         self.in_channels = in_channels
47 |         self.out_channels = out_channels
48 |         self.learn_var = learn_var
49 | 
50 |         self.unet = UNet(
51 |             in_channel=in_channels * 2,
52 |             out_channel=out_channels * 2 if learn_var else out_channels,
53 |             res_blocks=2,
54 |             inner_channel=128,
55 |             channel_mults=channel_mults,
56 |             attn_res=attention_res,
57 |             num_heads=4,
58 |             dropout=dropout,
59 |             conv_resample=True,
60 |             image_size=256,
61 |         )
62 | 
63 |         # Noise schedules
64 |         self.diffusion = DiffusionModel(
65 |             schedule_type,
66 |             2000,
67 |             1e-6,
68 |             0.01,
69 |             learn_var=learn_var,
70 |             device=self.device,
71 |         )
72 |         self.diffusion_inf = DiffusionModel(
73 |             "cosine",
74 |             100,
75 |             learn_var=learn_var,
76 |             device=self.device,
77 |         )
78 | 
79 |     def forward(self, x, output_process=False):
80 |         batch_size = x.shape[0]
81 | 
82 |         y_t = torch.randn_like(x)
83 |         process_array = torch.unsqueeze(y_t, dim=1)
84 |         for i in tqdm(reversed(range(self.diffusion_inf.timesteps))):
85 |             t = torch.full((batch_size,), i, device=x.device)
86 |             y_t = self.diffusion_inf.backward(x, y_t, t, self.unet)
87 | 
88 |             if (
89 |                 output_process
90 |                 and i % (self.diffusion_inf.timesteps // 7) == 0
91 |             ):
92 |                 process_array = torch.cat(
93 |                     [process_array, y_t.unsqueeze(1)],
94 |                     dim=1,
95 |                 )
96 | 
97 |         if output_process:
98 |             return y_t, process_array
99 | 
100 |         return y_t
101 | 
102 |     def configure_optimizers(self):
103 |         optimizer = torch.optim.Adam(self.unet.parameters(), lr=1e-4)
104 |         scheduler = torch.optim.lr_scheduler.LinearLR(
105 |             optimizer,
106 |             total_iters=10000,
107 |         )
108 |         return [optimizer], [scheduler]
109 | 
110 |     def training_step(self, batch):
111 |         x, y_0 = batch
112 | 
113 |         # Sample from p(gamma)
114 |         t = torch.randint(
115 |             0,
116 |             self.diffusion.timesteps,
117 |             size=(y_0.shape[0],),
118 |             device=y_0.device,
119 |         )
120 |         y_t, noise, gamma = self.diffusion.forward(y_0, t)
121 | 
122 |         # Predict the added noise (and optionally variance) and compute loss
123 |         model_output = self.unet(x, y_t, gamma)
124 | 
125 |         noise_pred = model_output
126 |         if self.learn_var:
127 |             noise_pred, _ = model_output.split(x.shape[1], dim=1)
128 | 
129 |         loss = F.mse_loss(noise_pred, noise)
130 |         vlb_loss = self.diffusion.vlb_term(model_output, y_0, y_t, t).mean()
131 | 
132 | 
self.log("mse_loss", loss, prog_bar=True) 133 | self.log("vlb_loss", vlb_loss, prog_bar=True) 134 | 135 | if self.learn_var: 136 | loss += 0.001 * vlb_loss 137 | 138 | self.log("loss", loss, prog_bar=True) 139 | 140 | return loss 141 | 142 | def on_validation_start(self): 143 | # Make dirs to save log video and output to 144 | epoch_dir = os.path.join( 145 | self.logger.log_dir, 146 | str(self.current_epoch + 1), 147 | ) 148 | 149 | if not os.path.exists(epoch_dir): 150 | os.mkdir(epoch_dir) 151 | 152 | def validation_step(self, batch, batch_idx): 153 | x, y_0 = batch 154 | batch_size = x.shape[0] 155 | y_pred = self.forward(x) 156 | 157 | # Write outputs of the model 158 | for ind, y_tx in enumerate(y_pred): 159 | write_png( 160 | to_int(denormalize(y_tx)).cpu(), 161 | os.path.join( 162 | self.logger.log_dir, 163 | str(self.current_epoch + 1), 164 | f"output_{batch_size * batch_idx + ind}.png", 165 | ), 166 | compression_level=0, 167 | ) 168 | 169 | den_y_pred = denormalize(y_pred) 170 | den_y_0 = denormalize(y_0) 171 | 172 | self.log("val_ssim", ssim(den_y_pred, den_y_0), prog_bar=True) 173 | self.log("val_psnr", psnr(den_y_pred, den_y_0), prog_bar=True) 174 | self.log("val_rmse", rmse(den_y_pred, den_y_0), prog_bar=True) 175 | 176 | 177 | class DiffusionModel(nn.Module): 178 | def __init__( 179 | self, 180 | schedule_type: Literal["linear", "cosine"], 181 | timesteps: int, 182 | start: float = 1e-6, 183 | end: float = 0.01, 184 | learn_var: bool = False, 185 | device="cpu", 186 | ): 187 | super().__init__() 188 | 189 | self.timesteps = timesteps 190 | self.learn_var = learn_var 191 | 192 | match schedule_type: 193 | case "linear": 194 | betas = linear_beta_schedule(timesteps, start, end) 195 | 196 | case "cosine": 197 | betas = cosine_beta_schedule(timesteps) 198 | 199 | case _: 200 | raise ValueError(f"{schedule_type} is not supported.") 201 | 202 | betas = betas.to(device) 203 | 204 | self.register_buffer("alphas", 1 - betas) 205 | self.register_buffer("gammas", torch.cumprod(self.alphas, axis=0)) 206 | self.register_buffer( 207 | "gammas_prev", 208 | torch.cat([ 209 | torch.ones((1,), device=self.gammas.device), 210 | self.gammas[:-1], 211 | ]), 212 | ) 213 | 214 | def forward(self, y_0, t): 215 | """ 216 | :param y_0: [N x C x H x W] 217 | :param t: [N] 218 | :returns: y_noised [N x C x H x W], noise [N x C x H x W], gamma [N] 219 | 220 | """ 221 | 222 | noise = torch.randn_like(y_0) * (t > 0).view(-1, 1, 1, 1) 223 | gamma_prev = self.get_value(self.gammas_prev, t) 224 | gamma_cur = self.get_value(self.gammas, t) 225 | gamma = (gamma_cur-gamma_prev) * \ 226 | torch.rand_like(gamma_cur) + gamma_prev 227 | 228 | mean = torch.sqrt(gamma) * y_0 229 | variance = torch.sqrt(1 - gamma) * noise 230 | 231 | return mean + variance, noise, gamma.view(-1) 232 | 233 | def backward(self, x, y_t, t, model): 234 | """ 235 | :param x: [N x C x H x W] 236 | :param y_t: [N x C x H x W] 237 | :param t: [N] 238 | :param model: U-net model that predicts the noise and optionally the 239 | variance. 
240 | :returns: [N x C X H x W] 241 | 242 | """ 243 | 244 | gamma = self.gammas[t] 245 | model_output = model(x, y_t, gamma) 246 | 247 | mean, log_variance = self.p_mean_variance(model_output, y_t, t) 248 | sqrt_variance = torch.exp(0.5 * log_variance) 249 | 250 | noise = torch.randn_like(y_t) * (t > 1).view(-1, 1, 1, 1) 251 | 252 | return mean + sqrt_variance * noise 253 | 254 | def q_mean_variance(self, y_0, y_t, t): 255 | """Compute q(y_{t-1} | y_t, y_0) parameters.""" 256 | 257 | alpha = self.get_value(self.alphas, t) 258 | gamma = self.get_value(self.gammas, t) 259 | gamma_prev = self.get_value(self.gammas_prev, t) 260 | 261 | mean = ( 262 | (torch.sqrt(gamma_prev) * (1 - alpha) / (1 - gamma)) * y_0 + 263 | (torch.sqrt(alpha) * (1 - gamma_prev) / (1 - gamma)) * y_t 264 | ) 265 | var_lower_bound = (1 - alpha) * (1 - gamma_prev) / (1 - gamma) 266 | var_lower_bound = torch.clamp(var_lower_bound, min=1e-20) 267 | log_variance = torch.log(var_lower_bound) 268 | 269 | return mean, log_variance 270 | 271 | def p_mean_variance(self, model_output, y_t, t): 272 | """Compute p(y_{t-1} | y_t) parameters.""" 273 | 274 | alpha = self.get_value(self.alphas, t) 275 | gamma = self.get_value(self.gammas, t) 276 | gamma_prev = self.get_value(self.gammas_prev, t) 277 | 278 | # If the variance is not learned, we want the lower bound of the 279 | # variance, so fix var_interp to 0 280 | var_interp = 0 281 | noise_pred = model_output 282 | if self.learn_var: 283 | noise_pred, var_interp = model_output.split(y_t.shape[1], dim=1) 284 | # The range of the U-net is [-1, 1] 285 | var_interp = (var_interp + 1) / 2 286 | 287 | var_lower_bound = (1 - alpha) * (1 - gamma_prev) / (1 - gamma) 288 | var_lower_bound = torch.clamp(var_lower_bound, min=1e-20) 289 | var_upper_bound = 1 - alpha 290 | 291 | log_variance = ( 292 | var_interp * torch.log(var_upper_bound) + 293 | (1-var_interp) * torch.log(var_lower_bound) 294 | ) 295 | 296 | y_0_hat = (1 / torch.sqrt(gamma)) * ( 297 | y_t - 298 | torch.sqrt(1 - gamma) * noise_pred 299 | ) 300 | y_0_hat = torch.clamp(y_0_hat, -1, 1) 301 | 302 | mean = ( 303 | (torch.sqrt(gamma_prev) * (1 - alpha) / (1 - gamma)) * y_0_hat + 304 | (torch.sqrt(alpha) * (1 - gamma_prev) / (1 - gamma)) * y_t 305 | ) 306 | return mean, log_variance 307 | 308 | def vlb_term(self, model_output, y_0, y_t, t): 309 | """Compute a term for the variational lower-bound.""" 310 | 311 | # Learn the variance using the variational bound, but do not let it 312 | # affect the mean prediction 313 | if self.learn_var: 314 | noise_pred, var_interp = model_output.split(y_t.shape[1], dim=1) 315 | model_output = torch.cat([noise_pred.detach(), var_interp], dim=1) 316 | 317 | true_mean, true_log_variance = self.q_mean_variance(y_0, y_t, t) 318 | pred_mean, pred_log_variance = self.p_mean_variance( 319 | model_output, y_t, t) 320 | 321 | kl = normal_kl(true_mean, true_log_variance, 322 | pred_mean, pred_log_variance) 323 | # Take mean for each item in batch 324 | kl = kl.mean(dim=[1, 2, 3]) / math.log(2.0) 325 | 326 | nll = -discretized_gaussian_log_likelihood( 327 | y_0, 328 | means=pred_mean, 329 | log_scales=0.5 * pred_log_variance, 330 | ) 331 | nll = nll.mean(dim=[1, 2, 3]) / math.log(2.0) 332 | 333 | return torch.where(t == 0, nll, kl) 334 | 335 | def get_value(self, values, t): 336 | """ 337 | Reshapes the value to be multiplied with a batch of images. 
338 | 
339 |     :param values: [timesteps] (full noise-schedule tensor, indexed by t)
340 |     :param t: [N]
341 |     :returns: [N x 1 x 1 x 1]
342 | 
343 |     """
344 | 
345 |     return values[t].view(-1, 1, 1, 1)
346 | 
347 | 
348 | def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
349 |     """Cosine schedule as proposed in (Nichol and Dhariwal, 2021)."""
350 | 
351 |     steps = timesteps + 1
352 |     x = torch.linspace(0, timesteps, steps)
353 |     gammas = torch.cos((torch.pi / 2) * ((x / timesteps) + s) / (1 + s)) ** 2
354 |     gammas = gammas / gammas[0]
355 |     betas = 1 - (gammas[1:] / gammas[:-1])
356 | 
357 |     return torch.clamp(betas, 0.0001, 0.9999)
358 | 
359 | 
360 | def linear_beta_schedule(
361 |     timesteps: int,
362 |     start: float = 1e-6,
363 |     end: float = 0.01,
364 | ) -> torch.Tensor:
365 |     return torch.linspace(start, end, timesteps)
366 | 
367 | 
368 | def normal_kl(mean1, log_var1, mean2, log_var2):
369 |     """Compute the KL divergence between two Gaussians."""
370 | 
371 |     # Set variances to be tensors
372 |     if not isinstance(log_var1, torch.Tensor):
373 |         log_var1 = torch.tensor(log_var1).to(mean1.device)
374 | 
375 |     if not isinstance(log_var2, torch.Tensor):
376 |         log_var2 = torch.tensor(log_var2).to(mean2.device)
377 | 
378 |     return 0.5 * (
379 |         -1.0 +
380 |         (log_var2 - log_var1) +
381 |         torch.exp(log_var1 - log_var2) +
382 |         ((mean1 - mean2) ** 2) * torch.exp(-log_var2)
383 |     )
384 | 
385 | 
386 | def approx_standard_normal_cdf(x):
387 |     """A fast approximation of the cumulative distribution function of the
388 |     standard normal."""
389 | 
390 |     return 0.5 * (
391 |         1.0 +
392 |         torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
393 |     )
394 | 
395 | 
396 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
397 |     """Compute the log-likelihood of a Gaussian distribution discretizing to a
398 |     given image.
399 | 
400 |     :param x: The target images. It is assumed that this was uint8 values,
401 |         rescaled to the range [-1, 1].
402 |     :param means: The Gaussian mean Tensor.
403 |     :param log_scales: The Gaussian log stddev Tensor.
404 |     :returns: A tensor like x of log probabilities (in nats).
405 | 
406 |     """
407 | 
408 |     centered_x = x - means
409 |     inv_stdv = torch.exp(-log_scales)
410 |     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
411 |     cdf_plus = approx_standard_normal_cdf(plus_in)
412 |     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
413 |     cdf_min = approx_standard_normal_cdf(min_in)
414 |     log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
415 |     log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
416 |     cdf_delta = cdf_plus - cdf_min
417 |     log_probs = torch.where(
418 |         x < -0.999,
419 |         log_cdf_plus,
420 |         torch.where(
421 |             x > 0.999,
422 |             log_one_minus_cdf_min,
423 |             torch.log(cdf_delta.clamp(min=1e-12)),
424 |         ),
425 |     )
426 | 
427 |     return log_probs
428 | 
--------------------------------------------------------------------------------
/models/guided_diffusion/unet.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import math
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | from .nn import (
9 |     checkpoint,
10 |     zero_module,
11 |     normalization1d,
12 |     normalization2d,
13 |     count_flops_attn,
14 |     gamma_embedding
15 | )
16 | 
17 | 
18 | class SiLU(nn.Module):
19 |     def forward(self, x):
20 |         return x * torch.sigmoid(x)
21 | 
22 | 
23 | class EmbedBlock(nn.Module):
24 |     """
25 |     Any module where forward() takes embeddings as a second argument.
26 | """ 27 | 28 | @abstractmethod 29 | def forward(self, x, emb): 30 | """ 31 | Apply the module to `x` given `emb` embeddings. 32 | """ 33 | 34 | 35 | class EmbedSequential(nn.Sequential, EmbedBlock): 36 | """ 37 | A sequential module that passes embeddings to the children that 38 | support it as an extra input. 39 | """ 40 | 41 | def forward(self, x, emb): 42 | for layer in self: 43 | if isinstance(layer, EmbedBlock): 44 | x = layer(x, emb) 45 | else: 46 | x = layer(x) 47 | return x 48 | 49 | 50 | class Upsample(nn.Module): 51 | """ 52 | An upsampling layer with an optional convolution. 53 | :param channels: channels in the inputs and outputs. 54 | :param use_conv: a bool determining if a convolution is applied. 55 | 56 | """ 57 | 58 | def __init__(self, channels, use_conv, out_channel=None): 59 | super().__init__() 60 | self.channels = channels 61 | self.out_channel = out_channel or channels 62 | self.use_conv = use_conv 63 | if use_conv: 64 | self.conv = nn.Conv2d( 65 | self.channels, 66 | self.out_channel, 67 | kernel_size=3, 68 | padding=1, 69 | ) 70 | 71 | def forward(self, x): 72 | assert x.shape[1] == self.channels 73 | x = F.interpolate(x, scale_factor=2, mode="nearest") 74 | if self.use_conv: 75 | x = self.conv(x) 76 | return x 77 | 78 | 79 | class Downsample(nn.Module): 80 | """ 81 | A downsampling layer with an optional convolution. 82 | :param channels: channels in the inputs and outputs. 83 | :param use_conv: a bool determining if a convolution is applied. 84 | """ 85 | 86 | def __init__(self, channels, use_conv, out_channel=None): 87 | super().__init__() 88 | self.channels = channels 89 | self.out_channel = out_channel or channels 90 | self.use_conv = use_conv 91 | stride = 2 92 | if use_conv: 93 | self.op = nn.Conv2d( 94 | self.channels, self.out_channel, 3, stride=stride, padding=1 95 | ) 96 | else: 97 | assert self.channels == self.out_channel 98 | self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) 99 | 100 | def forward(self, x): 101 | assert x.shape[1] == self.channels 102 | return self.op(x) 103 | 104 | 105 | class ResBlock(EmbedBlock): 106 | """ 107 | A residual block that can optionally change the number of channels. 108 | :param channels: the number of input channels. 109 | :param emb_channels: the number of embedding channels. 110 | :param dropout: the rate of dropout. 111 | :param out_channel: if specified, the number of out channels. 112 | :param use_conv: if True and out_channel is specified, use a spatial 113 | convolution instead of a smaller 1x1 convolution to change the 114 | channels in the skip connection. 115 | :param use_checkpoint: if True, use gradient checkpointing on this module. 116 | :param up: if True, use this block for upsampling. 117 | :param down: if True, use this block for downsampling. 
118 |     """
119 | 
120 |     def __init__(
121 |         self,
122 |         channels,
123 |         emb_channels,
124 |         dropout,
125 |         out_channel=None,
126 |         use_conv=False,
127 |         use_scale_shift_norm=False,
128 |         use_checkpoint=False,
129 |         up=False,
130 |         down=False,
131 |     ):
132 |         super().__init__()
133 |         self.channels = channels
134 |         self.emb_channels = emb_channels
135 |         self.dropout = dropout
136 |         self.out_channel = out_channel or channels
137 |         self.use_conv = use_conv
138 |         self.use_checkpoint = use_checkpoint
139 |         self.use_scale_shift_norm = use_scale_shift_norm
140 | 
141 |         self.in_layers = nn.Sequential(
142 |             normalization2d(channels),
143 |             SiLU(),
144 |             nn.Conv2d(channels, self.out_channel, 3, padding=1),
145 |         )
146 | 
147 |         self.updown = up or down
148 | 
149 |         if up:
150 |             self.h_upd = Upsample(channels, False)
151 |             self.x_upd = Upsample(channels, False)
152 |         elif down:
153 |             self.h_upd = Downsample(channels, False)
154 |             self.x_upd = Downsample(channels, False)
155 |         else:
156 |             self.h_upd = self.x_upd = nn.Identity()
157 | 
158 |         self.emb_layers = nn.Sequential(
159 |             SiLU(),
160 |             nn.Linear(
161 |                 emb_channels,
162 |                 2 * self.out_channel if use_scale_shift_norm else self.out_channel,
163 |             ),
164 |         )
165 |         self.out_layers = nn.Sequential(
166 |             normalization2d(self.out_channel),
167 |             SiLU(),
168 |             nn.Dropout(p=dropout),
169 |             zero_module(
170 |                 nn.Conv2d(self.out_channel, self.out_channel, 3, padding=1)
171 |             ),
172 |         )
173 | 
174 |         if self.out_channel == channels:
175 |             self.skip_connection = nn.Identity()
176 |         elif use_conv:
177 |             self.skip_connection = nn.Conv2d(
178 |                 channels, self.out_channel, 3, padding=1
179 |             )
180 |         else:
181 |             self.skip_connection = nn.Conv2d(channels, self.out_channel, 1)
182 | 
183 |     def forward(self, x, emb):
184 |         """
185 |         Apply the block to a Tensor, conditioned on an embedding.
186 |         :param x: an [N x C x ...] Tensor of features.
187 |         :param emb: an [N x emb_channels] Tensor of embeddings.
188 |         :return: an [N x C x ...] Tensor of outputs.
189 |         """
190 |         return checkpoint(
191 |             self._forward, (x, emb), self.parameters(), self.use_checkpoint
192 |         )
193 | 
194 |     def _forward(self, x, emb):
195 |         if self.updown:
196 |             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
197 |             h = in_rest(x)
198 |             h = self.h_upd(h)
199 |             x = self.x_upd(x)
200 |             h = in_conv(h)
201 |         else:
202 |             h = self.in_layers(x)
203 |         emb_out = self.emb_layers(emb).type(h.dtype)
204 |         while len(emb_out.shape) < len(h.shape):
205 |             emb_out = emb_out[..., None]
206 |         if self.use_scale_shift_norm:
207 |             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
208 |             scale, shift = torch.chunk(emb_out, 2, dim=1)
209 |             h = out_norm(h) * (1 + scale) + shift
210 |             h = out_rest(h)
211 |         else:
212 |             h = h + emb_out
213 |             h = self.out_layers(h)
214 |         return self.skip_connection(x) + h
215 | 
216 | 
217 | class AttentionBlock(nn.Module):
218 |     """
219 |     An attention block that allows spatial positions to attend to each other.
220 |     Originally ported from here, but adapted to the N-d case.
221 |     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
222 |     """
223 | 
224 |     def __init__(
225 |         self,
226 |         channels,
227 |         num_heads=1,
228 |         num_head_channels=-1,
229 |         use_checkpoint=False,
230 |         use_new_attention_order=False,
231 |     ):
232 |         super().__init__()
233 |         self.channels = channels
234 |         if num_head_channels == -1:
235 |             self.num_heads = num_heads
236 |         else:
237 |             assert (
238 |                 channels % num_head_channels == 0
239 |             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
240 |             self.num_heads = channels // num_head_channels
241 |         self.use_checkpoint = use_checkpoint
242 |         self.norm = normalization1d(channels)
243 |         self.qkv = nn.Conv1d(channels, channels * 3, 1)
244 |         if use_new_attention_order:
245 |             # split qkv before split heads
246 |             self.attention = QKVAttention(self.num_heads)
247 |         else:
248 |             # split heads before split qkv
249 |             self.attention = QKVAttentionLegacy(self.num_heads)
250 | 
251 |         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
252 | 
253 |     def forward(self, x):
254 |         return checkpoint(self._forward, (x,), self.parameters(), True)
255 | 
256 |     def _forward(self, x):
257 |         b, c, *spatial = x.shape
258 |         x = x.reshape(b, c, -1)
259 |         qkv = self.qkv(self.norm(x))
260 |         h = self.attention(qkv)
261 |         h = self.proj_out(h)
262 |         return (x + h).reshape(b, c, *spatial)
263 | 
264 | 
265 | class QKVAttentionLegacy(nn.Module):
266 |     """
267 |     A module which performs QKV attention. Matches legacy QKVAttention +
268 |     input/output heads shaping.
269 | 
270 |     """
271 | 
272 |     def __init__(self, n_heads):
273 |         super().__init__()
274 |         self.n_heads = n_heads
275 | 
276 |     def forward(self, qkv):
277 |         """
278 |         Apply QKV attention.
279 |         :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
280 |         :return: an [N x (H * C) x T] tensor after attention.
281 |         """
282 | 
283 |         bs, width, length = qkv.shape
284 |         assert width % (3 * self.n_heads) == 0
285 |         ch = width // (3 * self.n_heads)
286 |         q, k, v = qkv.reshape(
287 |             bs * self.n_heads,
288 |             ch * 3,
289 |             length
290 |         ).split(ch, dim=1)
291 |         scale = 1 / math.sqrt(math.sqrt(ch))
292 |         weight = torch.einsum(
293 |             "bct,bcs->bts", q * scale, k * scale
294 |         )  # More stable with f16 than dividing afterwards
295 |         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
296 |         a = torch.einsum("bts,bcs->bct", weight, v)
297 |         return a.reshape(bs, -1, length)
298 | 
299 |     @staticmethod
300 |     def count_flops(model, _x, y):
301 |         return count_flops_attn(model, _x, y)
302 | 
303 | 
304 | class QKVAttention(nn.Module):
305 |     """
306 |     A module which performs QKV attention and splits in a different order.
307 |     """
308 | 
309 |     def __init__(self, n_heads):
310 |         super().__init__()
311 |         self.n_heads = n_heads
312 | 
313 |     def forward(self, qkv):
314 |         """
315 |         Apply QKV attention.
316 |         :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
317 |         :return: an [N x (H * C) x T] tensor after attention.
318 |         """
319 |         bs, width, length = qkv.shape
320 |         assert width % (3 * self.n_heads) == 0
321 |         ch = width // (3 * self.n_heads)
322 |         q, k, v = qkv.chunk(3, dim=1)
323 |         scale = 1 / math.sqrt(math.sqrt(ch))
324 |         weight = torch.einsum(
325 |             "bct,bcs->bts",
326 |             (q * scale).view(bs * self.n_heads, ch, length),
327 |             (k * scale).view(bs * self.n_heads, ch, length),
328 |         )  # More stable with f16 than dividing afterwards
329 |         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
330 |         a = torch.einsum(
331 |             "bts,bcs->bct",
332 |             weight,
333 |             v.reshape(bs * self.n_heads, ch, length),
334 |         )
335 |         return a.reshape(bs, -1, length)
336 | 
337 |     @staticmethod
338 |     def count_flops(model, _x, y):
339 |         return count_flops_attn(model, _x, y)
340 | 
341 | 
342 | class UNet(nn.Module):
343 |     """
344 |     The full UNet model with attention and embedding.
345 |     :param in_channel: channels in the input Tensor, for image colorization:
346 |         Y_channels + X_channels.
347 |     :param inner_channel: base channel count for the model.
348 |     :param out_channel: channels in the output Tensor.
349 |     :param res_blocks: number of residual blocks per downsample.
350 |     :param attn_res: a collection of downsample rates at which
351 |         attention will take place. May be a set, list, or tuple.
352 |         For example, if this contains 4, then at 4x downsampling, attention
353 |         will be used.
354 |     :param dropout: the dropout probability.
355 |     :param channel_mults: channel multiplier for each level of the UNet.
356 |     :param conv_resample: if True, use learned convolutions for upsampling and
357 |         downsampling.
358 |     :param use_checkpoint: use gradient checkpointing to reduce memory usage.
359 |     :param num_heads: the number of attention heads in each attention layer.
360 |     :param num_head_channels: if specified, ignore num_heads and instead use
361 |         a fixed channel width per attention head.
362 |     :param num_heads_upsample: works with num_heads to set a different number
363 |         of heads for upsampling. Deprecated.
364 |     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
365 |     :param resblock_updown: use residual blocks for up/downsampling.
366 |     :param use_new_attention_order: use a different attention pattern for
367 |         potentially increased efficiency.
368 | """ 369 | 370 | def __init__( 371 | self, 372 | image_size, 373 | in_channel, 374 | inner_channel, 375 | out_channel, 376 | res_blocks, 377 | attn_res, 378 | dropout=0, 379 | channel_mults=(1, 2, 4, 8), 380 | conv_resample=True, 381 | use_checkpoint=False, 382 | use_fp16=False, 383 | num_heads=1, 384 | num_head_channels=-1, 385 | num_heads_upsample=-1, 386 | use_scale_shift_norm=True, 387 | resblock_updown=True, 388 | use_new_attention_order=False, 389 | ): 390 | 391 | super().__init__() 392 | 393 | if num_heads_upsample == -1: 394 | num_heads_upsample = num_heads 395 | 396 | self.image_size = image_size 397 | self.in_channel = in_channel 398 | self.inner_channel = inner_channel 399 | self.out_channel = out_channel 400 | self.res_blocks = res_blocks 401 | self.attn_res = attn_res 402 | self.dropout = dropout 403 | self.channel_mults = channel_mults 404 | self.conv_resample = conv_resample 405 | self.use_checkpoint = use_checkpoint 406 | self.dtype = torch.float16 if use_fp16 else torch.float32 407 | self.num_heads = num_heads 408 | self.num_head_channels = num_head_channels 409 | self.num_heads_upsample = num_heads_upsample 410 | 411 | cond_embed_dim = inner_channel * 4 412 | self.cond_embed = nn.Sequential( 413 | nn.Linear(inner_channel, cond_embed_dim), 414 | SiLU(), 415 | nn.Linear(cond_embed_dim, cond_embed_dim), 416 | ) 417 | 418 | ch = input_ch = int(channel_mults[0] * inner_channel) 419 | self.input_blocks = nn.ModuleList( 420 | [EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))] 421 | ) 422 | self._feature_size = ch 423 | input_block_chans = [ch] 424 | ds = 1 425 | for level, mult in enumerate(channel_mults): 426 | for _ in range(res_blocks): 427 | layers = [ 428 | ResBlock( 429 | ch, 430 | cond_embed_dim, 431 | dropout, 432 | out_channel=int(mult * inner_channel), 433 | use_checkpoint=use_checkpoint, 434 | use_scale_shift_norm=use_scale_shift_norm, 435 | ) 436 | ] 437 | ch = int(mult * inner_channel) 438 | if ds in attn_res: 439 | layers.append( 440 | AttentionBlock( 441 | ch, 442 | use_checkpoint=use_checkpoint, 443 | num_heads=num_heads, 444 | num_head_channels=num_head_channels, 445 | use_new_attention_order=use_new_attention_order, 446 | ) 447 | ) 448 | self.input_blocks.append(EmbedSequential(*layers)) 449 | self._feature_size += ch 450 | input_block_chans.append(ch) 451 | if level != len(channel_mults) - 1: 452 | out_ch = ch 453 | self.input_blocks.append( 454 | EmbedSequential( 455 | ResBlock( 456 | ch, 457 | cond_embed_dim, 458 | dropout, 459 | out_channel=out_ch, 460 | use_checkpoint=use_checkpoint, 461 | use_scale_shift_norm=use_scale_shift_norm, 462 | down=True, 463 | ) 464 | if resblock_updown 465 | else Downsample( 466 | ch, conv_resample, out_channel=out_ch 467 | ) 468 | ) 469 | ) 470 | ch = out_ch 471 | input_block_chans.append(ch) 472 | ds *= 2 473 | self._feature_size += ch 474 | 475 | self.middle_block = EmbedSequential( 476 | ResBlock( 477 | ch, 478 | cond_embed_dim, 479 | dropout, 480 | use_checkpoint=use_checkpoint, 481 | use_scale_shift_norm=use_scale_shift_norm, 482 | ), 483 | AttentionBlock( 484 | ch, 485 | use_checkpoint=use_checkpoint, 486 | num_heads=num_heads, 487 | num_head_channels=num_head_channels, 488 | use_new_attention_order=use_new_attention_order, 489 | ), 490 | ResBlock( 491 | ch, 492 | cond_embed_dim, 493 | dropout, 494 | use_checkpoint=use_checkpoint, 495 | use_scale_shift_norm=use_scale_shift_norm, 496 | ), 497 | ) 498 | self._feature_size += ch 499 | 500 | self.output_blocks = nn.ModuleList([]) 501 | for level, mult in 
list(enumerate(channel_mults))[::-1]: 502 | for i in range(res_blocks + 1): 503 | ich = input_block_chans.pop() 504 | layers = [ 505 | ResBlock( 506 | ch + ich, 507 | cond_embed_dim, 508 | dropout, 509 | out_channel=int(inner_channel * mult), 510 | use_checkpoint=use_checkpoint, 511 | use_scale_shift_norm=use_scale_shift_norm, 512 | ) 513 | ] 514 | ch = int(inner_channel * mult) 515 | if ds in attn_res: 516 | layers.append( 517 | AttentionBlock( 518 | ch, 519 | use_checkpoint=use_checkpoint, 520 | num_heads=num_heads_upsample, 521 | num_head_channels=num_head_channels, 522 | use_new_attention_order=use_new_attention_order, 523 | ) 524 | ) 525 | if level and i == res_blocks: 526 | out_ch = ch 527 | layers.append( 528 | ResBlock( 529 | ch, 530 | cond_embed_dim, 531 | dropout, 532 | out_channel=out_ch, 533 | use_checkpoint=use_checkpoint, 534 | use_scale_shift_norm=use_scale_shift_norm, 535 | up=True, 536 | ) 537 | if resblock_updown 538 | else Upsample(ch, conv_resample, out_channel=out_ch) 539 | ) 540 | ds //= 2 541 | self.output_blocks.append(EmbedSequential(*layers)) 542 | self._feature_size += ch 543 | 544 | self.out = nn.Sequential( 545 | normalization2d(ch), 546 | SiLU(), 547 | zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), 548 | ) 549 | 550 | def forward(self, x, y, gammas): 551 | """ 552 | Apply the model to an input batch. 553 | :param x: [N x C x ...] 554 | :param y: [N x C x ...] 555 | :param gammas: a 1-D batch of gammas. 556 | :return: an [N x C x ...] Tensor of outputs. 557 | """ 558 | 559 | hs = [] 560 | gammas = gammas.view(-1, ) 561 | emb = self.cond_embed(gamma_embedding(gammas, self.inner_channel)) 562 | 563 | h = torch.cat([x, y], dim=1) 564 | h = h.type(torch.float32) 565 | for module in self.input_blocks: 566 | h = module(h, emb) 567 | hs.append(h) 568 | h = self.middle_block(h, emb) 569 | for module in self.output_blocks: 570 | h = torch.cat([h, hs.pop()], dim=1) 571 | h = module(h, emb) 572 | h = h.type(x.dtype) 573 | return self.out(h) 574 | --------------------------------------------------------------------------------
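A minimal usage sketch (not part of the repository) of the guided-diffusion UNet defined above, wired up the way models/palette.py constructs it; the argument values mirror the defaults in Palette.__init__, and the tensors and shapes are illustrative only.

import torch
from models.guided_diffusion.unet import UNet

# The condition image and the noisy target are concatenated inside forward(),
# so in_channel is twice the image channel count (3 + 3 here).
unet = UNet(
    image_size=256,
    in_channel=6,
    inner_channel=128,
    out_channel=3,                     # 6 if the variance is learned as well
    res_blocks=2,
    attn_res=(16, 8),                  # attention at 16x and 8x downsampling
    channel_mults=(1, 1, 2, 2, 4, 4),
    num_heads=4,
    dropout=0.1,
    conv_resample=True,
)

x = torch.randn(2, 3, 256, 256)        # conditioning image
y_t = torch.randn(2, 3, 256, 256)      # noisy target at some timestep t
gammas = torch.rand(2)                 # cumulative noise levels, one per sample

with torch.no_grad():
    noise_pred = unet(x, y_t, gammas)  # -> [2 x 3 x 256 x 256]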