├── datasets
│   ├── __init__.py
│   ├── data_module.py
│   └── lidc.py
├── assets
│   └── model.png
├── .gitignore
├── vqgan
│   ├── gan_losses.py
│   ├── perceivers.py
│   ├── model_base.py
│   ├── train.py
│   ├── evaluate.ipynb
│   ├── attention_blocks.py
│   ├── conv_blocks.py
│   └── model.py
├── requirements.txt
├── gptnano
│   ├── utils.py
│   ├── evaluate.ipynb
│   ├── mingpt.py
│   ├── train.py
│   ├── translator.py
│   └── nanogpt.py
├── README.md
└── preprocessing
    └── preprocess.py
/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets.lidc import LIDCDataset -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FirasGit/transformers_ct_reconstruction/HEAD/assets/model.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | runs/ 3 | gpt_results/ 4 | .vscode/ 5 | .env 6 | __pycache__/ 7 | lidc_indices/ 8 | debug/ 9 | paper/ 10 | -------------------------------------------------------------------------------- /vqgan/gan_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def exp_d_loss(logits_real, logits_fake): 5 | loss_real = torch.mean(torch.exp(-logits_real)) 6 | loss_fake = torch.mean(torch.exp(logits_fake)) 7 | d_loss = 0.5 * (loss_real + loss_fake) 8 | return d_loss 9 | 10 | def hinge_d_loss(logits_real, logits_fake): 11 | loss_real = torch.mean(F.relu(1. - logits_real)) 12 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 13 | d_loss = 0.5 * (loss_real + loss_fake) 14 | return d_loss 15 | 16 | def vanilla_d_loss(logits_real, logits_fake): 17 | d_loss = 0.5 * ( 18 | torch.mean(F.softplus(-logits_real)) + 19 | torch.mean(F.softplus(logits_fake))) 20 | return d_loss -------------------------------------------------------------------------------- /vqgan/perceivers.py: -------------------------------------------------------------------------------- 1 | import lpips 2 | import torch 3 | 4 | class LPIPS(torch.nn.Module): 5 | """Learned Perceptual Image Patch Similarity (LPIPS)""" 6 | def __init__(self, linear_calibration=False, normalize=False): 7 | super().__init__() 8 | self.loss_fn = lpips.LPIPS(net='vgg', lpips=linear_calibration) # Note: only 'vgg' valid as loss 9 | self.normalize = normalize # If true, normalize [0, 1] to [-1, 1] 10 | 11 | 12 | def forward(self, pred, target): 13 | # No need to do that because ScalingLayer was introduced in version 0.1 which does this indirectly 14 | # if pred.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB 15 | # pred = torch.concat([pred, pred, pred], dim=1) 16 | # if target.shape[1] == 1: # convert 1-channel gray images to 3-channel RGB 17 | # target = torch.concat([target, target, target], dim=1) 18 | 19 | if pred.ndim == 5: # 3D Image: Just use 2D model and compute average over slices 20 | depth = pred.shape[2] 21 | losses = torch.stack([self.loss_fn(pred[:,:,d], target[:,:,d], normalize=self.normalize) for d in range(depth)], dim=2) 22 | return torch.mean(losses, dim=2, keepdim=True) 23 | else: 24 | return self.loss_fn(pred, target, normalize=self.normalize) -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | aiohttp==3.8.4 3 | aiosignal==1.3.1 4 | async-timeout==4.0.2 5 | attrs==22.2.0 6 | brotlipy==0.7.0 7 | cachetools==5.3.0 8 | certifi==2022.12.7 9 | contourpy==1.0.7 10 | cycler==0.11.0 11 | debugpy==1.6.6 12 | einops==0.6.0 13 | fonttools==4.38.0 14 | frozenlist==1.3.3 15 | fsspec==2023.3.0 16 | google-auth==2.16.2 17 | google-auth-oauthlib==0.4.6 18 | grpcio==1.51.3 19 | imageio==2.26.0 20 | importlib-metadata==6.0.0 21 | importlib-resources==5.12.0 22 | Jinja2==3.1.2 23 | kiwisolver==1.4.4 24 | lazy_loader==0.1 25 | lpips==0.1.4 26 | Markdown==3.4.1 27 | MarkupSafe==2.1.2 28 | matplotlib==3.7.1 29 | mkl-fft==1.3.1 30 | mkl-service==2.4.0 31 | monai==1.1.0 32 | mpmath==1.2.1 33 | multidict==6.0.4 34 | oauthlib==3.2.2 35 | Pillow==9.4.0 36 | protobuf==3.20.1 37 | pyasn1==0.4.8 38 | pyasn1-modules==0.2.8 39 | pyDeprecate==0.3.2 40 | pyparsing==3.0.9 41 | pytorch-lightning==1.6.4 42 | pytorch-msssim==0.2.1 43 | PyWavelets==1.4.1 44 | PyYAML==6.0 45 | pyzmq==19.0.2 46 | requests-oauthlib==1.3.1 47 | rsa==4.9 48 | scikit-image==0.20.0 49 | scipy==1.9.1 50 | tensorboard==2.12.0 51 | tensorboard-data-server==0.7.0 52 | tensorboard-plugin-wit==1.8.1 53 | tifffile==2023.2.28 54 | torch==2.1.0.dev20230305 55 | torchaudio==2.0.0.dev20230305 56 | torchmetrics==0.11.3 57 | torchvision==0.15.0.dev20230305 58 | tqdm==4.65.0 59 | triton==2.0.0 60 | Werkzeug==2.2.3 61 | yarl==1.8.2 62 | zipp==3.15.0 63 | -------------------------------------------------------------------------------- /gptnano/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch.nn as nn 4 | from PIL import Image 5 | from torch.utils.data import Dataset, DataLoader 6 | import matplotlib.pyplot as plt 7 | from datasets import LIDCDataset 8 | 9 | 10 | # --------------------------------------------- # 11 | # Data Utils 12 | # --------------------------------------------- # 13 | 14 | def load_data(args): 15 | train_data = LIDCDataset( 16 | root_dir=args.path_to_preprocessed_data, 17 | augmentation=False, 18 | projection=False, 19 | return_indices=True, 20 | indices_path=args.path_to_data_indices, 21 | split='train' 22 | ) 23 | 24 | val_data = LIDCDataset( 25 | root_dir=args.path_to_preprocessed_data, 26 | augmentation=False, 27 | projection=False, 28 | return_indices=True, 29 | indices_path=args.path_to_data_indices, 30 | split='val' 31 | ) 32 | 33 | test_data = LIDCDataset( 34 | root_dir=args.path_to_preprocessed_data, 35 | augmentation=False, 36 | projection=False, 37 | return_indices=True, 38 | indices_path=args.path_to_data_indices, 39 | split='test' 40 | ) 41 | 42 | train_loader = DataLoader( 43 | train_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 44 | 45 | val_loader = DataLoader( 46 | val_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 47 | 48 | test_loader = DataLoader( 49 | test_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 50 | return train_loader, val_loader, test_loader 51 | 52 | 53 | # --------------------------------------------- # 54 | # Module Utils 55 | # for Encoder, Decoder etc. 
56 | # --------------------------------------------- # 57 | 58 | def weights_init(m): 59 | classname = m.__class__.__name__ 60 | if classname.find('Conv') != -1: 61 | nn.init.normal_(m.weight.data, 0.0, 0.02) 62 | elif classname.find('BatchNorm') != -1: 63 | nn.init.normal_(m.weight.data, 1.0, 0.02) 64 | nn.init.constant_(m.bias.data, 0) 65 | 66 | 67 | def plot_images(images, slice_idx=60): 68 | x = images["input"] 69 | reconstruction = images["rec"] 70 | half_sample = images["half_sample"] 71 | full_sample = images["full_sample"] 72 | 73 | fig, axarr = plt.subplots(1, 4) 74 | axarr[0].imshow(x.cpu().detach().numpy()[0][0][slice_idx]) 75 | axarr[1].imshow(reconstruction.cpu().detach().numpy() 76 | [0][0][slice_idx]) 77 | axarr[2].imshow(half_sample.cpu().detach().numpy()[0][0][slice_idx]) 78 | axarr[3].imshow(full_sample.cpu().detach().numpy()[0][0][slice_idx]) 79 | plt.show() 80 | -------------------------------------------------------------------------------- /datasets/data_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch.utils.data.dataloader import DataLoader 4 | import torch.multiprocessing as mp 5 | from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler 6 | 7 | 8 | 9 | class SimpleDataModule(pl.LightningDataModule): 10 | 11 | def __init__(self, 12 | ds_train: object, 13 | ds_val:object =None, 14 | ds_test:object =None, 15 | batch_size: int = 1, 16 | num_workers: int = mp.cpu_count(), 17 | seed: int = 0, 18 | pin_memory: bool = False, 19 | weights: list = None 20 | ): 21 | super().__init__() 22 | self.hyperparameters = {**locals()} 23 | self.hyperparameters.pop('__class__') 24 | self.hyperparameters.pop('self') 25 | 26 | self.ds_train = ds_train 27 | self.ds_val = ds_val 28 | self.ds_test = ds_test 29 | 30 | self.batch_size = batch_size 31 | self.num_workers = num_workers 32 | self.seed = seed 33 | self.pin_memory = pin_memory 34 | self.weights = weights 35 | 36 | 37 | 38 | def train_dataloader(self): 39 | generator = torch.Generator() 40 | generator.manual_seed(self.seed) 41 | 42 | if self.weights is not None: 43 | sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator) 44 | else: 45 | sampler = RandomSampler(self.ds_train, replacement=False, generator=generator) 46 | return DataLoader(self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers, 47 | sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory) 48 | 49 | 50 | def val_dataloader(self): 51 | generator = torch.Generator() 52 | generator.manual_seed(self.seed) 53 | if self.ds_val is not None: 54 | return DataLoader(self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, 55 | generator=generator, drop_last=False, pin_memory=self.pin_memory) 56 | else: 57 | raise AssertionError("A validation set was not initialized.") 58 | 59 | 60 | def test_dataloader(self): 61 | generator = torch.Generator() 62 | generator.manual_seed(self.seed) 63 | if self.ds_test is not None: 64 | return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, 65 | generator = generator, drop_last=False, pin_memory=self.pin_memory) 66 | else: 67 | raise AssertionError("A test set was not initialized.") 68 | 69 | 70 | -------------------------------------------------------------------------------- /datasets/lidc.py: -------------------------------------------------------------------------------- 
1 | import numpy as np 2 | import torch 3 | from torch.utils.data.dataset import Dataset 4 | import os 5 | from torchvision import transforms 6 | import glob 7 | 8 | 9 | class LIDCDataset(Dataset): 10 | def __init__(self, root_dir='../LIDC', augmentation=False, projection=False, projection_plane=None, return_indices=False, indices_path=None, split='train'): 11 | self.root_dir = root_dir 12 | self.split = split 13 | self.file_names = self.get_split() 14 | self.augmentation = augmentation 15 | self.projection = projection 16 | self.projection_plane = projection_plane 17 | self.return_indices = return_indices 18 | self.indices_path = indices_path 19 | 20 | def get_split(self): 21 | file_names_ = glob.glob(os.path.join( 22 | self.root_dir, './**/*.npy'), recursive=True) 23 | if self.split == 'train': 24 | # take 70% of the data 25 | file_names = file_names_[:int(len(file_names_)*0.7)] 26 | if self.split == 'val': 27 | # take 20% of the data 28 | file_names = file_names_[ 29 | int(len(file_names_)*0.7):int(len(file_names_)*0.9)] 30 | if self.split == 'test': 31 | file_names = file_names_[int(len(file_names_)*0.9):] 32 | return file_names 33 | 34 | def __len__(self): 35 | return len(self.file_names) 36 | 37 | def __getitem__(self, index): 38 | if self.return_indices: 39 | path = self.file_names[index] 40 | indices_path = os.path.join( 41 | self.indices_path, self.split, path.split('/')[-2]) 42 | indices_ct = torch.tensor( 43 | np.load(os.path.join(indices_path, 'CT.npy'))) 44 | indices_ap = torch.tensor( 45 | np.load(os.path.join(indices_path, 'ap.npy'))) 46 | indices_lat = torch.tensor( 47 | np.load(os.path.join(indices_path, 'lat.npy'))) 48 | return {'indices_ct': indices_ct, 'indices_ap': indices_ap, 'indices_lat': indices_lat, 'file_name': path} 49 | 50 | else: 51 | path = self.file_names[index] 52 | img = np.load(path) 53 | 54 | if self.augmentation: 55 | random_n = torch.rand(1) 56 | if random_n[0] > 0.5: 57 | img = np.flip(img, 2) 58 | 59 | if self.projection: 60 | if self.projection_plane is None: 61 | projection_plane = np.random.choice(['ap', 'lat']) 62 | else: 63 | projection_plane = self.projection_plane 64 | 65 | # Flip the image because for some reason the projection is flipped. 66 | # Note that we take the mean projection here for ease of use, since we have seen no difference in image quality when doing so. 67 | if projection_plane == 'ap': 68 | img = np.flip(img.mean(axis=1), 0) 69 | elif projection_plane == 'lat': 70 | img = np.flip(img.mean(axis=2), 0) 71 | 72 | imageout = torch.from_numpy(img.copy()).float() 73 | imageout = imageout.unsqueeze(0) 74 | 75 | return {'source': imageout, 'file_name': path} 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Transformers for CT Reconstruction From Monoplanar and Biplanar Radiographs 2 | ========================= 3 | 4 | This repository contains the code to our corresponding publication "Transformers for CT Reconstruction From Monoplanar and Biplanar Radiographs" 5 | 6 | - [Springer Proceedings](https://link.springer.com/chapter/10.1007/978-3-031-44689-4_1) 7 | - [Arxiv Preprint](https://arxiv.org/abs/2305.06965) 8 | 9 | 10 | ![alt text](assets/model.png) 11 | 12 | # System Requirements 13 | This code has been tested on Ubuntu 20.04 and an NVIDIA Quadro RTX A6000 GPU. Furthermore it was developed using Python v3.8. 
14 | 15 | # Setup 16 | 17 | In order to run this model, please download the LIDC-IDRI dataset (https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254). 18 | 19 | Additionally, create a virtual environment (e.g. with conda): 20 | ``` 21 | conda create -n ct_reconstruction python=3.8 22 | ``` 23 | and run 24 | ``` 25 | conda activate ct_reconstruction 26 | ``` 27 | followed by 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | to download and install the required dependencies. 32 | 33 | # Preprocessing 34 | 35 | Next, before we run the code, we have to preprocess the CT images and convert them from DICOM to .npy files. To do so, run 36 | ``` 37 | python preprocessing/preprocess.py --input_path <path_to_lidc_dicom> --path_output <path_to_preprocessed_data> 38 | ``` 39 | 40 | where <path_to_lidc_dicom> points to the folder containing all LIDC-IDRI DICOM folders and <path_to_preprocessed_data> points to the desired output location. 41 | 42 | # Training the VQ-GAN models 43 | 44 | Once everything is set up, we can start training our models. The first step is to train the VQ-GAN models. Note that this is done in two steps: first, we train a normal VQ-VAE (i.e., the VQ-GAN without the discriminator) for a few epochs, and subsequently we train the full VQ-GAN. 45 | 46 | To train the 2D model, run the following commands. 47 | To train the VQ-VAE, run: 48 | 49 | ``` 50 | python vqgan/train.py --mode 2D --data-dir <path_to_preprocessed_data> 51 | ``` 52 | After training has finished, copy the path to the best checkpoint (e.g., ".../runs/2023_10_16_180438/epoch=XXX-step=XXX.ckpt") 53 | and train the VQ-GAN model using: 54 | ``` 55 | python vqgan/train.py --mode 2D --data-dir <path_to_preprocessed_data> --best-vq-vae-ckpt <path_to_best_vq_vae_ckpt> 56 | ``` 57 | 58 | The same steps can be repeated to train the 3D model, i.e.: 59 | ``` 60 | python vqgan/train.py --mode 3D --data-dir <path_to_preprocessed_data> 61 | ``` 62 | and 63 | ``` 64 | python vqgan/train.py --mode 3D --data-dir <path_to_preprocessed_data> --best-vq-vae-ckpt <path_to_best_vq_vae_ckpt> 65 | ``` 66 | 67 | Before we continue with training the GPT model, and in order to speed up the training process, we first convert all images (the 3D CT images and the 2D digitally reconstructed radiographs) into their discrete latent indices. 68 | This can be done by navigating to ./vqgan/evaluate.ipynb and running the Jupyter notebook. 69 | NOTE: The notebook has to be run three times with different settings. Please check the second cell, which details the exact settings required, directly in the notebook. 70 | 71 | # Training the GPT model 72 | To train the GPT model, run the following command: 73 | ``` 74 | python gptnano/train.py --num-codebook-vectors 8192 --checkpoint-path-3d-vqgan <path_to_best_3d_vqgan_ckpt> --checkpoint-path-2d-vqgan <path_to_best_2d_vqgan_ckpt> --path-to-preprocessed-data <path_to_preprocessed_data> --path-to-data-indices <path_to_data_indices> 75 | ``` 76 | where <path_to_best_2d_vqgan_ckpt> and <path_to_best_3d_vqgan_ckpt> point to the best checkpoints of the 2D and 3D VQ-GAN models (the same ones used in ./vqgan/evaluate.ipynb) and <path_to_data_indices> points to the folder containing the pre-extracted latent codebook indices of the images (this is set in the ./vqgan/evaluate.ipynb notebook as STORAGE_DIR). 77 | 78 | Once the training has finished, we can synthesize the CT images. To do so, navigate to ./gptnano/evaluate.ipynb and set the necessary paths; a quick check of the index folder layout is sketched below. 
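Tip: after running ./vqgan/evaluate.ipynb for all three settings, the indices folder should contain one subfolder per patient under each split, holding `CT.npy`, `ap.npy` and `lat.npy` (this is the layout expected by datasets/lidc.py). Below is a minimal sanity check; the indices path and the patient folder name are placeholders to adapt to your setup:
```python
import os
import numpy as np

indices_root = "<path_to_data_indices>"  # the STORAGE_DIR chosen in ./vqgan/evaluate.ipynb
patient_dir = os.path.join(indices_root, "train", "LIDC-IDRI-0001")  # hypothetical patient folder

# One file with the 3D CT codebook indices and one per projection plane (ap/lat)
for name in ["CT.npy", "ap.npy", "lat.npy"]:
    indices = np.load(os.path.join(patient_dir, name))
    print(name, indices.shape, indices.dtype)
```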
79 | 80 | 81 | # Citation 82 | To cite our work, please use 83 | ```bibtex 84 | @inproceedings{khader_transformers_2023, 85 | address = {Cham}, 86 | title = {Transformers for {CT} {Reconstruction} from {Monoplanar} and {Biplanar} {Radiographs}}, 87 | isbn = {978-3-031-44689-4}, 88 | booktitle = {Simulation and {Synthesis} in {Medical} {Imaging}}, 89 | publisher = {Springer Nature Switzerland}, 90 | author = {Khader, Firas and Müller-Franzes, Gustav and Han, Tianyu and Nebelung, Sven and Kuhl, Christiane and Stegmaier, Johannes and Truhn, Daniel}, 91 | year = {2023}, 92 | pages = {1--10}, 93 | } 94 | ``` 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /vqgan/model_base.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from pytorch_lightning.utilities.cloud_io import load as pl_load 9 | from pytorch_lightning.utilities.migration import pl_legacy_patch 10 | 11 | class VeryBasicModel(pl.LightningModule): 12 | def __init__(self): 13 | super().__init__() 14 | self.save_hyperparameters() 15 | self._step_train = 0 16 | self._step_val = 0 17 | self._step_test = 0 18 | 19 | 20 | def forward(self, x_in): 21 | raise NotImplementedError 22 | 23 | def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx:int): 24 | raise NotImplementedError 25 | 26 | def training_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0 ): 27 | self._step_train += 1 # =self.global_step 28 | return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx) 29 | 30 | def validation_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0): 31 | self._step_val += 1 32 | return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx ) 33 | 34 | def test_step(self, batch: dict, batch_idx: int, optimizer_idx:int = 0): 35 | self._step_test += 1 36 | return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx) 37 | 38 | def _epoch_end(self, outputs: list, state: str): 39 | return 40 | 41 | def training_epoch_end(self, outputs): 42 | self._epoch_end(outputs, "train") 43 | 44 | def validation_epoch_end(self, outputs): 45 | self._epoch_end(outputs, "val") 46 | 47 | def test_epoch_end(self, outputs): 48 | self._epoch_end(outputs, "test") 49 | 50 | @classmethod 51 | def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path): 52 | with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f: 53 | json.dump({'best_model_epoch': Path(best_model_path).name}, f) 54 | 55 | @classmethod 56 | def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs): 57 | path_version = 'lightning_logs/version_'+str(version) 58 | with open(Path(path_checkpoint_dir) / path_version/ 'best_checkpoint.json', 'r') as f: 59 | path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch']) 60 | return Path(path_checkpoint_dir)/path_rel_best_checkpoint 61 | 62 | @classmethod 63 | def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs): 64 | path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version) 65 | return cls.load_from_checkpoint(path_best_checkpoint, **kwargs) 66 | 67 | def load_pretrained(self, checkpoint_path, map_location=None, **kwargs): 68 | if checkpoint_path.is_dir(): 69 | checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, 
**kwargs) 70 | 71 | with pl_legacy_patch(): 72 | if map_location is not None: 73 | checkpoint = pl_load(checkpoint_path, map_location=map_location) 74 | else: 75 | checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) 76 | return self.load_weights(checkpoint["state_dict"], **kwargs) 77 | 78 | def load_weights(self, pretrained_weights, strict=True, **kwargs): 79 | filter = kwargs.get('filter', lambda key:key in pretrained_weights) 80 | init_weights = self.state_dict() 81 | pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)} 82 | init_weights.update(pretrained_weights) 83 | self.load_state_dict(init_weights, strict=strict) 84 | return self 85 | 86 | 87 | 88 | 89 | class BasicModel(VeryBasicModel): 90 | def __init__(self, 91 | optimizer=torch.optim.AdamW, 92 | optimizer_kwargs={'lr':1e-3, 'weight_decay':1e-2}, 93 | lr_scheduler= None, 94 | lr_scheduler_kwargs={}, 95 | ): 96 | super().__init__() 97 | self.save_hyperparameters() 98 | self.optimizer = optimizer 99 | self.optimizer_kwargs = optimizer_kwargs 100 | self.lr_scheduler = lr_scheduler 101 | self.lr_scheduler_kwargs = lr_scheduler_kwargs 102 | 103 | def configure_optimizers(self): 104 | optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs) 105 | if self.lr_scheduler is not None: 106 | lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs) 107 | return [optimizer], [lr_scheduler] 108 | else: 109 | return [optimizer] 110 | -------------------------------------------------------------------------------- /vqgan/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/mueller-franzes/medfusion 3 | """ 4 | 5 | from pathlib import Path 6 | from datetime import datetime 7 | import argparse 8 | 9 | import torch 10 | from torch.utils.data import ConcatDataset 11 | from pytorch_lightning.trainer import Trainer 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | 14 | 15 | from datasets.data_module import SimpleDataModule 16 | from datasets import LIDCDataset 17 | from vqgan.model import VQVAE, VQGAN, VAE, VAEGAN 18 | 19 | import torch.multiprocessing 20 | torch.multiprocessing.set_sharing_strategy('file_system') 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Training Script") 25 | parser.add_argument("--mode", type=str, choices=["2D", "3D"], default="2D", help="Whether to train 2D or 3D model.") 26 | parser.add_argument("--data-dir", type=str, default="/data/LIDC/preprocessed", help="Directory containing preprocessed LIDC-IDRI data.") 27 | parser.add_argument("--best-vq-vae-ckpt", type=str, default=None, help="Path to the best checkpoint for the VQ-VAE model.") 28 | return parser.parse_args() 29 | 30 | 31 | if __name__ == "__main__": 32 | args = parse_args() 33 | 34 | # --------------- Settings -------------------- 35 | current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") 36 | path_run_dir = Path.cwd() / 'runs' / str(current_time) 37 | path_run_dir.mkdir(parents=True, exist_ok=True) 38 | gpus = [0] if torch.cuda.is_available() else None 39 | 40 | if args.mode == "2D": 41 | lidc_dataset_train = LIDCDataset( 42 | root_dir=args.data_dir, augmentation=False, projection=True, split='train') 43 | 44 | lidc_dataset_val = LIDCDataset( 45 | root_dir=args.data_dir, augmentation=False, projection=True, split='val') 46 | 47 | lidc_dataset_test = LIDCDataset( 48 | root_dir=args.data_dir, augmentation=False, 
projection=True, split='test') 49 | 50 | dm = SimpleDataModule( 51 | ds_train=lidc_dataset_train, 52 | ds_val=lidc_dataset_val, 53 | ds_test=lidc_dataset_test, 54 | batch_size=50, 55 | num_workers=30, 56 | pin_memory=True 57 | ) 58 | 59 | if not args.best_vq_vae_ckpt: 60 | model = VQVAE( 61 | in_channels=1, 62 | out_channels=1, 63 | emb_channels=512, 64 | num_embeddings=8192, 65 | spatial_dims=2, 66 | hid_chs=[64, 128, 256, 512], 67 | kernel_sizes=[3, 3, 3, 3], 68 | strides=[1, 2, 2, 2], 69 | embedding_loss_weight=1, 70 | beta=1, 71 | loss=torch.nn.L1Loss, 72 | deep_supervision=1, 73 | use_attention='none', 74 | sample_every_n_steps=50, 75 | ) 76 | else: 77 | model = VQGAN( 78 | in_channels=1, 79 | out_channels=1, 80 | emb_channels=512, 81 | num_embeddings=8192, 82 | spatial_dims=2, 83 | hid_chs=[64, 128, 256, 512], 84 | kernel_sizes=[3, 3, 3, 3], 85 | strides=[1, 2, 2, 2], 86 | embedding_loss_weight=1, 87 | beta=1, 88 | pixel_loss=torch.nn.L1Loss, 89 | deep_supervision=1, 90 | use_attention='none', 91 | sample_every_n_steps=50, 92 | ) 93 | 94 | model.vqvae.load_pretrained(args.best_vq_vae_ckpt) 95 | 96 | elif args.mode == "3D": 97 | lidc_dataset_train = LIDCDataset( 98 | root_dir=args.data_dir, augmentation=False, split='train') 99 | 100 | lidc_dataset_val = LIDCDataset( 101 | root_dir=args.data_dir, augmentation=False, split='val') 102 | 103 | lidc_dataset_test = LIDCDataset( 104 | root_dir=args.data_dir, augmentation=False, split='test') 105 | 106 | dm = SimpleDataModule( 107 | ds_train=lidc_dataset_train, 108 | ds_val=lidc_dataset_val, 109 | ds_test=lidc_dataset_test, 110 | batch_size=2, 111 | num_workers=30, 112 | pin_memory=True 113 | ) 114 | 115 | if not args.best_vq_vae_ckpt: 116 | model = VQVAE( 117 | in_channels=1, 118 | out_channels=1, 119 | emb_channels=256, 120 | num_embeddings=8192, 121 | spatial_dims=3, 122 | hid_chs=[32, 64, 128, 256], 123 | kernel_sizes=[3, 3, 3, 3], 124 | strides=[1, 2, 2, 2], 125 | embedding_loss_weight=1, 126 | beta=1, 127 | loss=torch.nn.L1Loss, 128 | deep_supervision=0, 129 | use_attention='none', 130 | norm_name=("GROUP", {'num_groups': 4, "affine": True}), 131 | sample_every_n_steps=200, 132 | ) 133 | else: 134 | model = VQGAN( 135 | in_channels=1, 136 | out_channels=1, 137 | emb_channels=256, 138 | num_embeddings=8192, 139 | spatial_dims=3, 140 | hid_chs=[32, 64, 128, 256], 141 | kernel_sizes=[3, 3, 3, 3], 142 | strides=[1, 2, 2, 2], 143 | embedding_loss_weight=1, 144 | beta=1, 145 | pixel_loss=torch.nn.L1Loss, 146 | deep_supervision=0, 147 | use_attention='none', 148 | norm_name=("GROUP", {'num_groups': 4, "affine": True}), 149 | sample_every_n_steps=200, 150 | ) 151 | 152 | model.vqvae.load_pretrained(args.best_vq_vae_ckpt) 153 | 154 | ############################################## 155 | 156 | # -------------- Training Initialization --------------- 157 | to_monitor = "val/ssim_epoch" # "train/L1" # "val/loss" 158 | min_max = "max" 159 | save_and_sample_every = 50 160 | 161 | early_stopping = EarlyStopping( 162 | monitor=to_monitor, 163 | min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement 164 | patience=30, # number of checks with no improvement 165 | mode=min_max 166 | ) 167 | checkpointing = ModelCheckpoint( 168 | dirpath=str(path_run_dir), # dirpath 169 | monitor=to_monitor, 170 | every_n_train_steps=save_and_sample_every, 171 | save_last=True, 172 | save_top_k=1, 173 | mode=min_max, 174 | ) 175 | trainer = Trainer( 176 | accelerator='gpu', 177 | devices=[1], 178 | # precision=16, 179 | # amp_backend='apex', 
180 | # amp_level='O2', 181 | # gradient_clip_val=0.5, 182 | default_root_dir=str(path_run_dir), 183 | callbacks=[checkpointing], 184 | # callbacks=[checkpointing, early_stopping], 185 | enable_checkpointing=True, 186 | check_val_every_n_epoch=1, 187 | log_every_n_steps=save_and_sample_every, 188 | auto_lr_find=False, 189 | # limit_train_batches=1000, 190 | # limit_val_batches=0, # 0 = disable validation - Note: Early Stopping no longer available 191 | min_epochs=100, 192 | max_epochs=1001, 193 | num_sanity_val_steps=2, 194 | ) 195 | 196 | # ---------------- Execute Training ---------------- 197 | trainer.fit(model, datamodule=dm) 198 | 199 | # ------------- Save path to best model ------------- 200 | model.save_best_checkpoint( 201 | trainer.logger.log_dir, checkpointing.best_model_path) 202 | -------------------------------------------------------------------------------- /preprocessing/preprocess.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/peterhan91/cycleGAN/blob/db8f1d958c0879c29cf3932cae74a166317be812/prepro.py#L39""" 2 | 3 | import os 4 | import numpy as np 5 | from glob import glob 6 | import pydicom 7 | import scipy.ndimage 8 | from pathlib import Path 9 | import argparse 10 | from multiprocessing import Pool, cpu_count 11 | from tqdm import tqdm 12 | 13 | 14 | 15 | class CTExtractor: 16 | def __init__(self, input_path, out_path): 17 | super(CTExtractor, self).__init__() 18 | 19 | self.MIN_BOUND = -1000.0 20 | self.MAX_BOUND = 400.0 21 | self.PIXEL_MEAN = 0.25 22 | self.roi = 320 23 | self.size = 128 24 | 25 | self.path = input_path 26 | self.outpath = out_path 27 | self.slices = [] 28 | self.fname = '' 29 | 30 | # Load the scans in given folder path 31 | def load_scan(self): 32 | slices_ = [pydicom.read_file(s) for s in glob( 33 | os.path.join(self.path, self.fname, '*/*/*.dcm'))] 34 | 35 | # Problem when CXR is available. This fixes it. 36 | num_subfolders = len(os.listdir(os.path.join(self.path, self.fname))) 37 | if num_subfolders > 1: 38 | print(f"Filename: {self.fname}, No. Subfolders: {num_subfolders}") 39 | slices = [] 40 | for s in slices_: 41 | if s.Modality == 'CT': 42 | slices.append(s) 43 | else: 44 | print(s.Modality) 45 | else: 46 | slices = slices_ 47 | 48 | slices.sort(key=lambda x: float(x.ImagePositionPatient[2])) 49 | try: 50 | slice_thickness = np.abs( 51 | slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2]) 52 | except: 53 | slice_thickness = np.abs( 54 | slices[0].SliceLocation - slices[1].SliceLocation) 55 | 56 | for s in slices: 57 | s.SliceThickness = slice_thickness 58 | if s.Modality != 'CT': 59 | print(f"NOT A CT. 
This is a {s.Modality}") 60 | 61 | return slices 62 | 63 | def get_pixels_hu(self, slices): 64 | image = np.stack([s.pixel_array for s in slices]) 65 | # Convert to int16 (from sometimes int16), 66 | # should be possible as values should always be low enough (<32k) 67 | image = image.astype(np.int16) 68 | 69 | # Set outside-of-scan pixels to 0 70 | # The intercept is usually -1024, so air is approximately 0 71 | image[image == -2000] = 0 72 | 73 | # Convert to Hounsfield units (HU) 74 | for slice_number in range(len(slices)): 75 | 76 | intercept = slices[slice_number].RescaleIntercept 77 | slope = slices[slice_number].RescaleSlope 78 | 79 | if slope != 1: 80 | image[slice_number] = slope * \ 81 | image[slice_number].astype(np.float64) 82 | image[slice_number] = image[slice_number].astype(np.int16) 83 | 84 | image[slice_number] += np.int16(intercept) 85 | 86 | return np.array(image, dtype=np.int16) 87 | 88 | def resample(self, image, scan, new_spacing=[1.0, 1.0, 1.0]): 89 | # Determine current pixel spacing 90 | # print(scan[0].SliceThickness) 91 | # print(scan[0].PixelSpacing) 92 | spacing = np.array([scan[0].SliceThickness] + 93 | list(scan[0].PixelSpacing), dtype=np.float32) 94 | 95 | resize_factor = spacing / new_spacing 96 | new_real_shape = image.shape * resize_factor 97 | new_shape = np.round(new_real_shape) 98 | real_resize_factor = new_shape / image.shape 99 | new_spacing = spacing / real_resize_factor 100 | 101 | image = scipy.ndimage.interpolation.zoom( 102 | image, real_resize_factor, mode='nearest') 103 | 104 | return image, new_spacing 105 | 106 | def normalize(self, image): 107 | image = (image - self.MIN_BOUND) / (self.MAX_BOUND - self.MIN_BOUND) 108 | image[image > 1] = 1. 109 | image[image < 0] = 0. 110 | return image*2-1. 111 | 112 | def zero_center(self, image): 113 | image = image - self.PIXEL_MEAN 114 | return image 115 | 116 | def pad_center(self, pix_resampled): 117 | pad_z = max(self.roi - pix_resampled.shape[0], 0) 118 | pad_x = max(self.roi - pix_resampled.shape[1], 0) 119 | pad_y = max(self.roi - pix_resampled.shape[2], 0) 120 | try: 121 | pad = np.pad(pix_resampled, 122 | [(pad_z//2, pad_z-pad_z//2), (pad_x//2, 123 | pad_x-pad_x//2), (pad_y//2, pad_y-pad_y//2)], 124 | mode='constant', 125 | constant_values=pix_resampled[0][10][10]) 126 | except ValueError: 127 | print(pix_resampled.shape) 128 | except IndexError: 129 | print(pix_resampled.shape) 130 | pass 131 | return pad 132 | 133 | def crop_center(self, vol, cropz, cropy, cropx): 134 | z, y, x = vol.shape 135 | startx = x//2-(cropx//2) 136 | starty = y//2-(cropy//2) 137 | startz = z//2-(cropz//2) 138 | return vol[startz:startz+cropz, starty:starty+cropy, startx:startx+cropx] 139 | 140 | def save(self): 141 | path = os.path.join(self.outpath, self.fname, '128.npy') 142 | Path(os.path.join(self.outpath, self.fname)).mkdir( 143 | parents=True, exist_ok=True) 144 | np.save(path, self.vol) 145 | 146 | def run(self, fname): 147 | self.fname = fname 148 | self.patient = self.load_scan() 149 | self.vol = self.get_pixels_hu(self.patient) 150 | self.vol, _ = self.resample(self.vol, self.patient) 151 | if self.vol.shape[0] >= self.roi and self.vol.shape[1] >= self.roi and self.vol.shape[2] >= self.roi: 152 | self.vol = self.crop_center(self.vol, self.roi, self.roi, self.roi) 153 | else: 154 | self.vol = self.pad_center(self.vol) 155 | self.vol = self.crop_center(self.vol, self.roi, self.roi, self.roi) 156 | assert self.vol.shape == (self.roi, self.roi, self.roi) 157 | self.vol = scipy.ndimage.zoom(self.vol, 158 | 
[self.size/self.roi, self.size / 159 | self.roi, self.size/self.roi], 160 | mode='nearest') 161 | assert self.vol.shape == (self.size, self.size, self.size) 162 | self.vol = self.normalize(self.vol) 163 | self.save() 164 | 165 | 166 | def worker(fname, extractor): 167 | try: 168 | extractor.run(fname) 169 | except: 170 | print('Error extracting the lung CT') 171 | print(fname) 172 | 173 | 174 | if __name__ == "__main__": 175 | # Argument Parsing 176 | parser = argparse.ArgumentParser(description='CTExtractor for processing CT scans.') 177 | parser.add_argument('--input_path', type=str, required=True, help='Path to the input CT scans directory') 178 | parser.add_argument('--path_output', type=str, required=True, help='Path to the directory to save processed CT scans') 179 | args = parser.parse_args() 180 | 181 | input_path = args.input_path 182 | path_output = args.path_output 183 | 184 | 185 | extractor = CTExtractor(input_path, path_output) 186 | 187 | def worker_partial(fname): 188 | return worker(fname, extractor) 189 | 190 | fnames = os.listdir(input_path) 191 | print('total # of scans', len(fnames)) 192 | 193 | with Pool(processes=4) as pool: 194 | res = list(tqdm(pool.imap( 195 | worker_partial, iter(fnames)), total=len(fnames))) 196 | -------------------------------------------------------------------------------- /gptnano/evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import numpy as np\n", 11 | "from tqdm import tqdm\n", 12 | "import argparse\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.nn.functional as F\n", 16 | "from torchvision import utils as vutils\n", 17 | "from gptnano.translator import VQGANTransformer\n", 18 | "from utils import load_data, plot_images\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import imageio\n", 21 | "import itertools\n", 22 | "from contextlib import nullcontext\n", 23 | "\n", 24 | "from metrics import Structural_Similarity\n", 25 | "from scipy.ndimage.filters import gaussian_filter" 26 | ] 27 | }, 28 | { 29 | "attachments": {}, 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Load Models" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "class Args:\n", 43 | "\tnum_codebook_vectors = 8192\n", 44 | "\tcheckpoint_path_3d_vqgan = '' # Path to the best 3D VQ-GAN model checkpoint\n", 45 | "\tcheckpoint_path_2d_vqgan = '' # Path to the best 2D VQ-GAN model checkpoint\n", 46 | "\tcheckpoint_path_gpt = '' # Path to the best GPT model checkpoint (located under .../gpt_results/run_X/checkpoint/transformer_X_X.pt)\n", 47 | "\tpkeep = 0.5\n", 48 | "\tsos_token = 0\n", 49 | "\tblock_size = 4096 + 256 * 2\n", 50 | "\tn_unmasked = 256 * 2 + 1\n", 51 | "\tdevice = \"cuda:0\"\n", 52 | "\tbatch_size = 1\n", 53 | "\tepochs = 100\n", 54 | "\tlearning_rate = 2.25e-05\n", 55 | "\tnum_workers = 1\n", 56 | " \n", 57 | "args = Args()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "model = VQGANTransformer(args).to(device=args.device)\n", 67 | "\n", 68 | "model.load_gpt(args, strict=True)" 69 | ] 70 | }, 71 | { 72 | "attachments": {}, 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "## Sample 
Images" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "train_dataloader, val_dataloader, test_dataloader = load_data(args)\n", 86 | "\n", 87 | "index= 192\n", 88 | "data = next(itertools.islice(iter(test_dataloader), index, None))\n", 89 | "\n", 90 | "imgs_ct = data['indices_ct']\n", 91 | "imgs_ap = data['indices_ap']\n", 92 | "imgs_lat = data['indices_lat']\n", 93 | "\n", 94 | "imgs_ct = imgs_ct.to(device=args.device)\n", 95 | "imgs_ap = imgs_ap.to(device=args.device)\n", 96 | "imgs_lat = imgs_lat.to(device=args.device)\n", 97 | "\n", 98 | "orig_ct = np.load(data['file_name'][0])" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "device=args.device\n", 108 | "dtype = 'bfloat16'\n", 109 | "\n", 110 | "# for later use in torch.autocast\n", 111 | "device_type = 'cuda' if 'cuda' in device else 'cpu'\n", 112 | "# note: float16 data type will automatically use a GradScaler\n", 113 | "ptdtype = {'float32': torch.float32,\n", 114 | " 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n", 115 | "ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(\n", 116 | " device_type=device_type, dtype=ptdtype)\n", 117 | "\n", 118 | "with ctx:\n", 119 | " log, sampled_imgs_ct, sampled_imgs_ap, sampled_imgs_lat = model.log_images(\n", 120 | " (imgs_ct[0][None], imgs_ap[0][None], imgs_lat[0][None]), temperature=1.0, top_k=100)\n", 121 | "\n", 122 | "sampled_imgs_ct = sampled_imgs_ct.float()\n", 123 | "sampled_imgs_ap = sampled_imgs_ap.float()\n", 124 | "sampled_imgs_lat = sampled_imgs_lat.float()" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "\n", 134 | "for slide in range(0, len(orig_ct), 10):\n", 135 | "\t# Plot the results side by side\n", 136 | "\timages = {'CT Full reconstruction': sampled_imgs_ct.detach().cpu()[2][0][slide], 'CT Half reconstruction': sampled_imgs_ct.detach().cpu()[1][0][slide], 'Reconstructed (no GPT)': sampled_imgs_ct.detach().cpu()[0][0][slide], 'Original CT': orig_ct[slide]}\n", 137 | "\tfig, ax = plt.subplots(1, len(images), figsize=(20, 20))\n", 138 | "\tfor i, (title, image) in enumerate(images.items()):\n", 139 | "\t\tax[i].imshow(image, cmap='gray')\n", 140 | "\t\tax[i].axis('off')\n", 141 | "\t\tax[i].set_title(title)\n", 142 | "\tplt.show()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# Plot the results side by side\n", 152 | "\n", 153 | "images = {'CT Full reconstruction': sampled_imgs_ct.detach().cpu()[2][0], 'CT Half reconstruction': sampled_imgs_ct.detach().cpu()[1][0], 'Reconstructed (no GPT)': sampled_imgs_ct.detach().cpu()[0][0], 'Original CT': orig_ct}\n", 154 | "fig, ax = plt.subplots(1, 2 * len(images), figsize=(20, 20))\n", 155 | "i = 0\n", 156 | "for title, image in images.items():\n", 157 | "\tax[i].imshow(np.flip(np.mean(np.array(image), axis=1), 0), cmap='gray')\n", 158 | "\tax[i].axis('off')\n", 159 | "\tax[i].set_title(title)\n", 160 | "\n", 161 | "\tax[i + 1].imshow(np.flip(np.mean(np.array(image), axis=2), 0), cmap='gray')\n", 162 | "\tax[i + 1].axis('off')\n", 163 | "\tax[i + 1].set_title(title)\n", 164 | "\ti+=2\n", 165 | "plt.show()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": 
{}, 172 | "outputs": [], 173 | "source": [ 174 | "# Plot the results side by side\n", 175 | "SLICE_IDX = 60\n", 176 | "\n", 177 | "images = {'CT Full reconstruction': sampled_imgs_ct.detach().cpu()[2][0], 'CT Half reconstruction': sampled_imgs_ct.detach().cpu()[1][0], 'Reconstructed (no GPT)': sampled_imgs_ct.detach().cpu()[0][0], 'Original CT': orig_ct}\n", 178 | "fig, ax = plt.subplots(len(images), 3, figsize=(20, 20))\n", 179 | "i = 0\n", 180 | "for title, image in images.items():\n", 181 | "\tax[i][0].imshow(np.flip(np.array(image[:, SLICE_IDX, :]), 0), cmap='gray')\n", 182 | "\tax[i][0].axis('off')\n", 183 | "\tax[i][0].set_title(title)\n", 184 | "\n", 185 | "\tax[i][1].imshow(np.flip(np.array(image[:, :, SLICE_IDX]), 0), cmap='gray')\n", 186 | "\tax[i][1].axis('off')\n", 187 | "\tax[i][1].set_title(title)\n", 188 | "\n", 189 | "\tax[i][2].imshow(np.flip(np.array(image[SLICE_IDX, :, :]), 0), cmap='gray')\n", 190 | "\tax[i][2].axis('off')\n", 191 | "\tax[i][2].set_title(title)\n", 192 | "\ti += 1\n", 193 | "plt.show()" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "x2ct_trans", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.9.16" 214 | }, 215 | "orig_nbformat": 4, 216 | "vscode": { 217 | "interpreter": { 218 | "hash": "e57f99a62812bf754688c50e8ec7c45df4e600ea8ca1cb4c958cb0a42792f43b" 219 | } 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 2 224 | } 225 | -------------------------------------------------------------------------------- /gptnano/mingpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | taken from: https://github.com/karpathy/minGPT/ 3 | GPT model: 4 | - the initial stem consists of a combination of token encoding and a positional encoding 5 | - the meat of it is a uniform sequence of Transformer blocks 6 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 7 | - all blocks feed into a central residual pathway similar to resnets 8 | - the final decoder is a linear projection into a vanilla Softmax classifier 9 | """ 10 | 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import functional as F 15 | from performer_pytorch import SelfAttention 16 | from xformers.components import MultiHeadDispatch 17 | from xformers.components.attention import FavorAttention 18 | from flash_attn.flash_attention import FlashMHA 19 | 20 | 21 | class GPTConfig: 22 | """ base GPT config, params common to all GPT versions """ 23 | embd_pdrop = 0.1 24 | resid_pdrop = 0.1 25 | attn_pdrop = 0.1 26 | 27 | def __init__(self, vocab_size, block_size, **kwargs): 28 | self.vocab_size = vocab_size 29 | self.block_size = block_size 30 | for k, v in kwargs.items(): 31 | setattr(self, k, v) 32 | 33 | 34 | class CausalSelfAttention(nn.Module): 35 | """ 36 | A vanilla multi-head masked self-attention layer with a projection at the end. 37 | It is possible to use torch.nn.MultiheadAttention here but I am including an 38 | explicit implementation here to show that there is nothing too scary here. 
39 | """ 40 | 41 | def __init__(self, config): 42 | super().__init__() 43 | assert config.n_embd % config.n_head == 0 44 | # key, query, value projections for all heads 45 | self.key = nn.Linear(config.n_embd, config.n_embd) 46 | self.query = nn.Linear(config.n_embd, config.n_embd) 47 | self.value = nn.Linear(config.n_embd, config.n_embd) 48 | # regularization 49 | self.attn_drop = nn.Dropout(config.attn_pdrop) 50 | self.resid_drop = nn.Dropout(config.resid_pdrop) 51 | # output projection 52 | self.proj = nn.Linear(config.n_embd, config.n_embd) 53 | # causal mask to ensure that attention is only applied to the left in the input sequence 54 | mask = torch.tril(torch.ones(config.block_size, 55 | config.block_size)) 56 | if hasattr(config, "n_unmasked"): 57 | mask[:config.n_unmasked, :config.n_unmasked] = 1 58 | self.register_buffer("mask", mask.view( 59 | 1, 1, config.block_size, config.block_size)) 60 | self.n_head = config.n_head 61 | 62 | def forward(self, x, layer_past=None): 63 | B, T, C = x.size() 64 | 65 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 66 | k = self.key(x).view(B, T, self.n_head, C // 67 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | q = self.query(x).view(B, T, self.n_head, C // 69 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 70 | v = self.value(x).view(B, T, self.n_head, C // 71 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 72 | 73 | present = torch.stack((k, v)) 74 | if layer_past is not None: 75 | past_key, past_value = layer_past 76 | k = torch.cat((past_key, k), dim=-2) 77 | v = torch.cat((past_value, v), dim=-2) 78 | 79 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 80 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 81 | if layer_past is None: 82 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) 83 | 84 | att = F.softmax(att, dim=-1) 85 | att = self.attn_drop(att) 86 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 87 | # re-assemble all head outputs side by side 88 | y = y.transpose(1, 2).contiguous().view(B, T, C) 89 | 90 | # output projection 91 | y = self.resid_drop(self.proj(y)) 92 | return y, present # TODO: check that this does not break anything 93 | 94 | 95 | class CausalSelfAttentionPytorchPerfomer(nn.Module): 96 | def __init__(self, config): 97 | super().__init__() 98 | assert config.n_embd % config.n_head == 0 99 | # key, query, value projections for all heads 100 | self.key = nn.Linear(config.n_embd, config.n_embd) 101 | self.query = nn.Linear(config.n_embd, config.n_embd) 102 | self.value = nn.Linear(config.n_embd, config.n_embd) 103 | # output projection 104 | self.proj = nn.Linear(config.n_embd, config.n_embd) 105 | 106 | dim_features = config.n_embd / config.n_head 107 | self.multi_head_attn = MultiHeadDispatch(attention=FavorAttention( 108 | dim_features=dim_features, dropout=config.attn_pdrop, causal=True, iter_before_redraw=100), num_heads=config.n_head, dim_model=config.n_embd, residual_dropout=config.resid_pdrop) 109 | 110 | def forward(self, x, layer_past=None): 111 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 112 | k = self.key(x) 113 | q = self.query(x) 114 | v = self.value(x) 115 | 116 | y = self.multi_head_attn(q, k, v) 117 | 118 | return y 119 | 120 | 121 | class Block(nn.Module): 122 | """ an unassuming Transformer block """ 123 | 124 | def __init__(self, config): 125 | super().__init__() 126 | self.config = config 127 | 
self.ln1 = nn.LayerNorm(config.n_embd) 128 | self.ln2 = nn.LayerNorm(config.n_embd) 129 | if config.use_normal_attention: 130 | self.attn = CausalSelfAttention(config) 131 | else: 132 | # self.attn = SelfAttention(dim=config.n_embd, heads=config.n_head, 133 | # dim_head=None, causal=True, dropout=config.attn_pdrop, nb_features=32, feature_redraw_interval=50) 134 | #self.attn = CausalSelfAttentionPytorchPerfomer(config) 135 | self.attn = FlashMHA(embed_dim=config.n_embd, num_heads=config.n_head, 136 | attention_dropout=config.attn_pdrop, causal=True, dtype=torch.float16) 137 | self.mlp = nn.Sequential( 138 | nn.Linear(config.n_embd, 4 * config.n_embd), 139 | nn.GELU(), # nice 140 | nn.Linear(4 * config.n_embd, config.n_embd), 141 | nn.Dropout(config.resid_pdrop), 142 | ) 143 | 144 | def forward(self, x, layer_past=None, return_present=False): 145 | # TODO: check that training still works 146 | if return_present: 147 | assert not self.training 148 | # layer past: tuple of length two with B, nh, T, hs 149 | if self.config.use_normal_attention: 150 | attn, present = self.attn(self.ln1(x), layer_past=layer_past) 151 | else: 152 | attn, present = self.attn(self.ln1(x)), None 153 | 154 | x = x + attn 155 | x = x + self.mlp(self.ln2(x)) 156 | if layer_past is not None or return_present: 157 | return x, present 158 | return x 159 | 160 | 161 | class GPT(nn.Module): 162 | """ the full GPT language model, with a context size of block_size """ 163 | 164 | def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, 165 | embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0, use_normal_attention=False): 166 | super().__init__() 167 | config = GPTConfig(vocab_size=vocab_size, block_size=block_size, 168 | embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, 169 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, 170 | n_unmasked=n_unmasked, use_normal_attention=use_normal_attention) 171 | # input embedding stem 172 | self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) 173 | self.pos_emb = nn.Parameter(torch.zeros( 174 | 1, config.block_size, config.n_embd)) # 512 x 1024 175 | self.drop = nn.Dropout(config.embd_pdrop) 176 | # transformer 177 | self.blocks = nn.Sequential(*[Block(config) 178 | for _ in range(config.n_layer)]) 179 | # decoder head 180 | self.ln_f = nn.LayerNorm(config.n_embd) 181 | self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 182 | 183 | self.block_size = config.block_size 184 | self.apply(self._init_weights) 185 | self.config = config 186 | 187 | def get_block_size(self): 188 | return self.block_size 189 | 190 | def _init_weights(self, module): 191 | if isinstance(module, (nn.Linear, nn.Embedding)): 192 | module.weight.data.normal_(mean=0.0, std=0.02) 193 | if isinstance(module, nn.Linear) and module.bias is not None: 194 | module.bias.data.zero_() 195 | elif isinstance(module, nn.LayerNorm): 196 | module.bias.data.zero_() 197 | module.weight.data.fill_(1.0) 198 | 199 | def forward(self, idx, embeddings=None): 200 | # each index maps to a (learnable) vector 201 | token_embeddings = self.tok_emb(idx) 202 | 203 | if embeddings is not None: # prepend explicit embeddings 204 | token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) 205 | 206 | t = token_embeddings.shape[1] 207 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
208 | # each position maps to a (learnable) vector 209 | position_embeddings = self.pos_emb[:, :t, :] 210 | x = self.drop(token_embeddings + position_embeddings) 211 | x = self.blocks(x) 212 | x = self.ln_f(x) 213 | logits = self.head(x) 214 | 215 | return logits, None 216 | -------------------------------------------------------------------------------- /vqgan/evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Set environment as current working directory\n", 10 | "import sys\n", 11 | "sys.path.append('..')\n", 12 | "\n", 13 | "from pathlib import Path\n", 14 | "from datetime import datetime\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "\n", 17 | "import torch\n", 18 | "from torch.utils.data import ConcatDataset\n", 19 | "from pytorch_lightning.trainer import Trainer\n", 20 | "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", 21 | "\n", 22 | "\n", 23 | "from datasets.data_module import SimpleDataModule\n", 24 | "from datasets import LIDCDataset\n", 25 | "from vqgan.model import VQVAE, VQGAN, VAE, VAEGAN\n", 26 | "\n", 27 | "import torch.multiprocessing\n", 28 | "torch.multiprocessing.set_sharing_strategy('file_system')\n", 29 | "import os\n", 30 | "import numpy as np\n", 31 | "from tqdm import tqdm" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# NOTE: This Notebook has to be run three times using different settings for the parameters below.\n", 41 | "# First Setting: Choose USE_2D = False\n", 42 | "# Second Setting: Choose USE_2D = True and PROJECTION_PLANE = 'lat'\n", 43 | "# Third Setting: Choose USE_2D = True and PROJECTION_PLANE = 'ap'\n", 44 | "\n", 45 | "USE_2D = False # True or False\n", 46 | "PROJECTION_PLANE = 'lat' # 'ap' or 'lat'\n", 47 | "PATH_TO_PREPROCESSED_DATA = '' # Replace this with the folder containing the preprocessed LIDC-IDRI dataset (i.e., )\n", 48 | "BEST_VQ_GAN_CKPT_2D = '' # Replace this with the best VQ-GAN checkpoint for the 2D model\n", 49 | "BEST_VQ_GAN_CKPT_3D = '' # Replace this with the best VQ-GAN checkpoint for the 3D model\n", 50 | "STORAGE_DIR = '' # Replace this with the desired path for storing the indices (e.g. 
/data/lidc_indices/)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "def create_dir(dir_path):\n", 60 | "\tif not os.path.exists(dir_path):\n", 61 | "\t\tos.makedirs(dir_path)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "gpus = [0] if torch.cuda.is_available() else None" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "if USE_2D == True:\n", 80 | " lidc_dataset_train = LIDCDataset(\n", 81 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, projection=True, split='train', projection_plane=PROJECTION_PLANE)\n", 82 | "\n", 83 | " lidc_dataset_val = LIDCDataset(\n", 84 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, projection=True, split='val', projection_plane=PROJECTION_PLANE)\n", 85 | "\n", 86 | " lidc_dataset_test = LIDCDataset(\n", 87 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, projection=True, split='test', projection_plane=PROJECTION_PLANE)\n", 88 | "\n", 89 | " dm = SimpleDataModule(\n", 90 | " ds_train=lidc_dataset_train,\n", 91 | " ds_val=lidc_dataset_val,\n", 92 | " ds_test=lidc_dataset_test,\n", 93 | " batch_size=1,\n", 94 | " num_workers=1,\n", 95 | " pin_memory=True\n", 96 | " )\n", 97 | "else:\n", 98 | " lidc_dataset_train = LIDCDataset(\n", 99 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, split='train')\n", 100 | "\n", 101 | " lidc_dataset_val = LIDCDataset(\n", 102 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, split='val')\n", 103 | "\n", 104 | " lidc_dataset_test = LIDCDataset(\n", 105 | " root_dir=PATH_TO_PREPROCESSED_DATA, augmentation=False, split='test')\n", 106 | "\n", 107 | " dm = SimpleDataModule(\n", 108 | " ds_train=lidc_dataset_train,\n", 109 | " ds_val=lidc_dataset_val,\n", 110 | " ds_test=lidc_dataset_test,\n", 111 | " batch_size=1,\n", 112 | " num_workers=30,\n", 113 | " pin_memory=True\n", 114 | " )" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "if USE_2D:\n", 124 | " model = VQGAN(\n", 125 | " in_channels=1,\n", 126 | " out_channels=1,\n", 127 | " emb_channels=512,\n", 128 | " num_embeddings=8192,\n", 129 | " spatial_dims=2,\n", 130 | " hid_chs=[64, 128, 256, 512],\n", 131 | " kernel_sizes=[3, 3, 3, 3],\n", 132 | " strides=[1, 2, 2, 2],\n", 133 | " embedding_loss_weight=1,\n", 134 | " beta=1,\n", 135 | " pixel_loss=torch.nn.L1Loss,\n", 136 | " deep_supervision=1,\n", 137 | " use_attention='none',\n", 138 | " sample_every_n_steps=50,\n", 139 | " )\n", 140 | "\n", 141 | " model.load_pretrained(BEST_VQ_GAN_CKPT_2D)\n", 142 | "else:\n", 143 | " model = VQGAN(\n", 144 | " in_channels=1,\n", 145 | " out_channels=1,\n", 146 | " emb_channels=256,\n", 147 | " num_embeddings=8192,\n", 148 | " spatial_dims=3,\n", 149 | " hid_chs=[32, 64, 128, 256],\n", 150 | " kernel_sizes=[3, 3, 3, 3],\n", 151 | " strides=[1, 2, 2, 2],\n", 152 | " embedding_loss_weight=1,\n", 153 | " beta=1,\n", 154 | " pixel_loss=torch.nn.L1Loss,\n", 155 | " deep_supervision=0,\n", 156 | " use_attention='none',\n", 157 | " norm_name=(\"GROUP\", {'num_groups': 4, \"affine\": True}),\n", 158 | " sample_every_n_steps=200,\n", 159 | " )\n", 160 | "\n", 161 | " model.load_pretrained(BEST_VQ_GAN_CKPT_3D)\n", 162 | "\n", 163 | "model.eval()" 164 | ] 165 | }, 166 
| { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# get next element of dataloader\n", 173 | "test_sample = next(iter(dm.test_dataloader()))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "SLICE_NUM = 60\n", 183 | "\n", 184 | "if USE_2D:\n", 185 | "\tplt.imshow(test_sample['source'][0][0], cmap='gray')\n", 186 | "\tplt.axis('off')\n", 187 | "else:\n", 188 | "\tplt.imshow(test_sample['source'][0][0][SLICE_NUM], cmap='gray')\n", 189 | "\tplt.axis('off')" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "out_sample = model(test_sample['source'])" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "if USE_2D:\n", 208 | "\tplt.imshow(out_sample[0][0][0].detach().cpu(), cmap='gray')\n", 209 | "\tplt.axis('off')\n", 210 | "else:\n", 211 | "\tplt.imshow(out_sample[0][0][0][SLICE_NUM].detach().cpu(), cmap='gray')\n", 212 | "\tplt.axis('off')\n", 213 | " " 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "indices, embedding_shape = model.vqvae.encode_to_indices(test_sample['source'])" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "print(indices.shape)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "print(embedding_shape)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "out_sample_2 = model.vqvae.decode_from_indices(indices, embedding_shape)\n", 250 | "\n", 251 | "if USE_2D:\n", 252 | "\tplt.imshow(out_sample_2[0][0].detach().cpu(), cmap='gray')\n", 253 | "\tplt.axis('off')\n", 254 | "else:\n", 255 | "\tplt.imshow(out_sample_2[0][0][SLICE_NUM].detach().cpu(), cmap='gray')\n", 256 | "\tplt.axis('off')" 257 | ] 258 | }, 259 | { 260 | "attachments": {}, 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "# Convert all images to indices" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "# get next element of dataloader\n", 274 | "storage_dir = STORAGE_DIR \n", 275 | "train_path = os.path.join(storage_dir, 'train') \n", 276 | "val_path = os.path.join(storage_dir, 'val')\n", 277 | "test_path = os.path.join(storage_dir, 'test')\n", 278 | "create_dir(train_path)\n", 279 | "create_dir(val_path)\n", 280 | " \n", 281 | "for split in [[train_path, dm.train_dataloader()], [val_path, dm.val_dataloader()], [test_path, dm.test_dataloader()]]:\n", 282 | "\tfor sample in tqdm(split[1]):\n", 283 | "\t\tindices, embedding_shape = model.vqvae.encode_to_indices(sample['source'])\n", 284 | "\t\tfile_name = sample['file_name'][0].split('/')[-2] \n", 285 | "\t\tindices_np = indices.detach().cpu().numpy()\n", 286 | "\t\tfolder_path = os.path.join(split[0], file_name)\n", 287 | "\t\tcreate_dir(folder_path)\n", 288 | "\t\tif USE_2D:\n", 289 | "\t\t\tnp.save(os.path.join(folder_path, f'{PROJECTION_PLANE}.npy'), 
indices_np)\n", 290 | "\t\telse:\n", 291 | "\t\t\tnp.save(os.path.join(folder_path, 'CT.npy'), indices_np)" 292 | ] 293 | } 294 | ], 295 | "metadata": { 296 | "kernelspec": { 297 | "display_name": "medicaldiffusion", 298 | "language": "python", 299 | "name": "python3" 300 | }, 301 | "language_info": { 302 | "codemirror_mode": { 303 | "name": "ipython", 304 | "version": 3 305 | }, 306 | "file_extension": ".py", 307 | "mimetype": "text/x-python", 308 | "name": "python", 309 | "nbconvert_exporter": "python", 310 | "pygments_lexer": "ipython3", 311 | "version": "3.8.13" 312 | }, 313 | "orig_nbformat": 4, 314 | "vscode": { 315 | "interpreter": { 316 | "hash": "489df2b2e73de1ddceac97fbee82a53bd2a027d2efa7299f0a23dfd27bb8968f" 317 | } 318 | } 319 | }, 320 | "nbformat": 4, 321 | "nbformat_minor": 2 322 | } 323 | -------------------------------------------------------------------------------- /gptnano/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | from tqdm import tqdm 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import utils as vutils 10 | from gptnano.translator import VQGANTransformer 11 | from utils import load_data, plot_images 12 | import matplotlib.pyplot as plt 13 | from gptnano.nanogpt import LayerNorm 14 | from contextlib import nullcontext 15 | import inspect 16 | 17 | 18 | # given a numpy array of images, plot them in a grid with n rows and m columns 19 | def save_images(images, rows, cols, path): 20 | fig, axes = plt.subplots( 21 | rows, cols, figsize=(20, 20)) 22 | for i, ax in enumerate(axes.flatten()): 23 | ax.imshow(images[i], cmap='gray') 24 | ax.axis('off') 25 | plt.savefig(path) 26 | 27 | 28 | def create_dir(path): 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | 32 | 33 | class TrainTransformer: 34 | def __init__(self, args): 35 | self.model = VQGANTransformer(args).to(device=args.device) 36 | self.optim = self.configure_optimizers() 37 | 38 | self.train(args) 39 | 40 | def configure_optimizers(self): 41 | device_type = 'cuda' 42 | weight_decay = 0.01 43 | self.learning_rate = 6e-4 # 4.5e-06 44 | betas = (0.9, 0.95) 45 | 46 | optimizer = self.model.transformer.configure_optimizers( 47 | weight_decay=weight_decay, 48 | learning_rate=self.learning_rate, 49 | betas=betas, 50 | device_type=device_type, 51 | ) 52 | 53 | return optimizer 54 | 55 | def setup_training(self, args): 56 | seed_offset = 0 57 | device = 'cuda' 58 | # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 59 | dtype = 'bfloat16' 60 | compile = True 61 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 62 | 63 | torch.manual_seed(1337 + seed_offset) 64 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 65 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 66 | # for later use in torch.autocast 67 | device_type = 'cuda' if 'cuda' in device else 'cpu' 68 | # note: float16 data type will automatically use a GradScaler 69 | ptdtype = {'float32': torch.float32, 70 | 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 71 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( 72 | device_type=device_type, dtype=ptdtype) 73 | 74 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 75 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 76 | 77 | # compile the model 78 | if compile: 79 | print("compiling the model... (takes a ~minute)") 80 | unoptimized_model = self.model 81 | self.model = torch.compile( 82 | unoptimized_model) # requires PyTorch 2.0 83 | 84 | return ctx, scaler, grad_clip 85 | 86 | # learning rate decay scheduler (cosine with warmup) 87 | def get_lr(self, it): 88 | warmup_iters = 2000 89 | learning_rate = self.learning_rate 90 | lr_decay_iters = 600000 91 | min_lr = 6e-5 92 | 93 | # 1) linear warmup for warmup_iters steps 94 | if it < warmup_iters: 95 | return learning_rate * it / warmup_iters 96 | # 2) if it > lr_decay_iters, return min learning rate 97 | if it > lr_decay_iters: 98 | return min_lr 99 | # 3) in between, use cosine decay down to min learning rate 100 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 101 | assert 0 <= decay_ratio <= 1 102 | # coeff ranges 0..1 103 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 104 | return min_lr + coeff * (learning_rate - min_lr) 105 | 106 | def train(self, args): 107 | decay_lr = True 108 | ctx, scaler, grad_clip = self.setup_training(args) 109 | 110 | train_dataloader, val_dataloader, test_dataloader = load_data(args) 111 | current_best_val_epoch_loss = torch.tensor(float('inf')) 112 | current_best_epoch = 0 113 | run_idx = 0 114 | while os.path.exists(os.path.join("gpt_results", f"run_{run_idx}")): 115 | run_idx += 1 116 | 117 | iter_num = 0 118 | for epoch in range(args.epochs): 119 | with tqdm(range(len(train_dataloader))) as pbar: 120 | train_epoch_loss = [] 121 | for i, data in zip(pbar, train_dataloader): 122 | # determine and set the learning rate for this iteration 123 | lr = self.get_lr( 124 | iter_num) if decay_lr else self.learning_rate 125 | for param_group in self.optim.param_groups: 126 | param_group['lr'] = lr 127 | 128 | imgs_ct = data['indices_ct'] 129 | imgs_ap = data['indices_ap'] 130 | imgs_lat = data['indices_lat'] 131 | 132 | self.optim.zero_grad() 133 | 134 | imgs_ct = imgs_ct.to(device=args.device) 135 | imgs_ap = imgs_ap.to(device=args.device) 136 | imgs_lat = imgs_lat.to(device=args.device) 137 | 138 | with ctx: 139 | logits, targets = self.model( 140 | (imgs_ct, imgs_ap, imgs_lat)) 141 | 142 | loss = F.cross_entropy( 143 | logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) 144 | 145 | # loss.backward() 146 | # backward pass, with gradient scaling if training in fp16 147 | scaler.scale(loss).backward() 148 | 149 | # clip the gradient 150 | if grad_clip != 0.0: 151 | scaler.unscale_(self.optim) 152 | torch.nn.utils.clip_grad_norm_( 153 | self.model.parameters(), grad_clip) 154 | # step the optimizer and scaler if training in fp16 155 | scaler.step(self.optim) 156 | scaler.update() 157 | 158 | # self.optim.step() 159 | pbar.set_postfix(Transformer_Loss=np.round( 160 | loss.item(), 4), lr=np.round(lr, 6)) 161 | pbar.update(0) 162 | train_epoch_loss.append(loss) 163 | 164 | iter_num += 1 165 | print("Epoch: ", epoch, "Loss (mean): ", 166 | torch.mean(torch.tensor(train_epoch_loss))) 167 | 168 | with tqdm(range(len(val_dataloader))) as pbar: 169 | val_epoch_loss = [] 170 | self.model.eval() 171 | with torch.no_grad(): 172 | for i, data in zip(pbar, val_dataloader): 173 | imgs_ct = data['indices_ct'] 174 | imgs_ap = data['indices_ap'] 175 | imgs_lat = data['indices_lat'] 176 | 177 | imgs_ct = imgs_ct.to(device=args.device) 178 | imgs_ap = imgs_ap.to(device=args.device) 179 | imgs_lat = 
imgs_lat.to(device=args.device) 180 | 181 | with ctx: 182 | logits, targets = self.model( 183 | (imgs_ct, imgs_ap, imgs_lat)) 184 | 185 | loss = F.cross_entropy( 186 | logits.reshape(-1, logits.size(-1)), targets.reshape(-1)) 187 | pbar.set_postfix(Transformer_Loss=np.round( 188 | loss.item(), 4)) 189 | pbar.update(0) 190 | val_epoch_loss.append(loss) 191 | mean_val_epoch_loss = torch.mean(torch.tensor(val_epoch_loss)) 192 | print("Epoch: ", epoch, "Loss (mean): ", mean_val_epoch_loss) 193 | self.model.train() 194 | 195 | 196 | checkpoint_dir = os.path.join( 197 | "gpt_results", f"run_{run_idx}", "checkpoints") 198 | create_dir(checkpoint_dir) 199 | 200 | # Store last checkpoint 201 | torch.save(self.model.state_dict(), os.path.join( 202 | checkpoint_dir, "transformer_last.pt")) 203 | 204 | # Store best checkpoint 205 | if mean_val_epoch_loss < current_best_val_epoch_loss: 206 | torch.save(self.model.state_dict(), os.path.join( 207 | checkpoint_dir, f"transformer_{epoch}_{np.round(mean_val_epoch_loss.item(), 4)}.pt")) 208 | 209 | # remove the previous best checkpoint if it exists 210 | path = os.path.join( 211 | checkpoint_dir, f"transformer_{current_best_epoch}_{np.round(current_best_val_epoch_loss.item(), 4)}.pt") 212 | if os.path.isfile(path): 213 | os.remove(path) 214 | 215 | current_best_val_epoch_loss = mean_val_epoch_loss 216 | current_best_epoch = epoch 217 | 218 | if epoch % 10 == 0: 219 | self.model.eval() 220 | with ctx: 221 | log, sampled_imgs_ct, sampled_imgs_ap, sampled_imgs_lat = self.model.log_images( 222 | (imgs_ct[0][None], imgs_ap[0][None], imgs_lat[0][None])) 223 | SLICE_IDX = 60 224 | 225 | image_dir = os.path.join( 226 | "gpt_results", f"run_{run_idx}", "results") 227 | 228 | create_dir(image_dir) 229 | 230 | np_imgs = torch.cat((sampled_imgs_ct[:, :, :, SLICE_IDX].flip( 231 | dims=(1, 2)), sampled_imgs_ap, sampled_imgs_lat)).detach().cpu().to(torch.float).numpy()[:, 0] 232 | save_images(np_imgs, rows=3, cols=3, path=os.path.join( 233 | image_dir, f"transformer_{epoch}.jpg")) 234 | 235 | # vutils.save_image(torch.cat((sampled_imgs_ct[:, :, :, SLICE_IDX].flip(dims=(1, 2)), sampled_imgs_ap, sampled_imgs_lat)), os.path.join( 236 | # image_dir, f"transformer_{epoch}.jpg"), nrow=3) 237 | # plot_images(log) 238 | 239 | self.model.train() 240 | 241 | 242 | if __name__ == '__main__': 243 | parser = argparse.ArgumentParser(description="VQGAN") 244 | # VQGAN 245 | parser.add_argument('--num-codebook-vectors', type=int, 246 | default=8192, help='Number of codebook vectors.') 247 | parser.add_argument('--checkpoint-path-3d-vqgan', type=str, 248 | default=None, help='Path to checkpoint.') 249 | 250 | parser.add_argument('--checkpoint-path-2d-vqgan', type=str, 251 | default=None, help='Path to checkpoint.') 252 | 253 | parser.add_argument('--path-to-preprocessed-data', type=str, default=None, help='Path to preprocessed LIDC-IDRI data') 254 | parser.add_argument('--path-to-data-indices', type=str, default=None, help='Path to pre-extracted indices') 255 | 256 | # Transformer 257 | parser.add_argument('--pkeep', type=float, default=0.5, 258 | help='Percentage for how much latent codes to keep.') 259 | parser.add_argument('--sos-token', type=int, default=0, 260 | help='Start of Sentence token.') 261 | # Block size = 4096 CT tokens + 256 * 2 tokens for the AP and LAT projections 262 | parser.add_argument('--block_size', type=int, default=4096 + 256 * 2, 263 | help='Block size of GPT') 264 | # n_unmasked = 256 * 2 conditioning tokens (AP and LAT projections) + 1 to include
the sos token 265 | parser.add_argument('--n_unmasked', type=int, default=256 * 2 + 1, 266 | help='Number of unmasked tokens (needed for the 2D images)') 267 | parser.add_argument('--use_normal_attention', 268 | action='store_true', help='Use normal attention instead of performer attention.') 269 | 270 | # Training 271 | parser.add_argument('--device', type=str, default="cuda", 272 | help='Which device the training is on') 273 | parser.add_argument('--batch-size', type=int, default=12, 274 | help='Input batch size for training.') 275 | parser.add_argument('--epochs', type=int, default=300, 276 | help='Number of epochs to train.') 277 | parser.add_argument('--learning-rate', type=float, 278 | default=2.25e-05, help='Learning rate.') 279 | parser.add_argument('--num_workers', type=int, 280 | default=20, help='Number of workers for the dataloader.') 281 | 282 | args = parser.parse_args() 283 | # create dir if not exists 284 | train_transformer = TrainTransformer(args) 285 | 286 | 287 | # TODO: Add Performer 288 | # TODO: Add GAN in latent space 289 | -------------------------------------------------------------------------------- /gptnano/translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from nanogpt import GPT 5 | from vqgan.model import VQGAN 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | from dataclasses import dataclass 9 | from collections import OrderedDict 10 | 11 | 12 | def rename_state_dict_keys(state_dict, key_transformation): 13 | """ 14 | Rename the keys of a state dict. 15 | state_dict -> The saved state dict. 16 | key_transformation -> Function that accepts the old key names of the state 17 | dict as the only argument and returns the new key name. 
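    Example (illustrative sketch, mirroring the call in load_gpt below): drop the
    '_orig_mod.' prefix that torch.compile adds to parameter names:
        checkpoint = rename_state_dict_keys(
            checkpoint, lambda k: '.'.join(k.split('.')[1:]) if k.startswith('_orig_mod') else k)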
18 | """ 19 | new_state_dict = OrderedDict() 20 | 21 | for key, value in state_dict.items(): 22 | new_key = key_transformation(key) 23 | new_state_dict[new_key] = value 24 | 25 | return new_state_dict 26 | 27 | 28 | class VQGANTransformer(nn.Module): 29 | def __init__(self, args): 30 | super(VQGANTransformer, self).__init__() 31 | 32 | self.sos_token = args.sos_token 33 | 34 | self.vqgan_3d = self.load_vqgan_3d(args) 35 | self.vqgan_2d = self.load_vqgan_2d(args) 36 | 37 | @dataclass 38 | class TransformerConfig: 39 | block_size: int = args.block_size 40 | vocab_size: int = args.num_codebook_vectors 41 | n_layer: int = 8 42 | n_head: int = 8 43 | n_embd: int = 512 44 | dropout: float = 0.1 45 | bias: bool = False 46 | 47 | self.transformer = GPT(TransformerConfig) 48 | 49 | self.pkeep = args.pkeep 50 | 51 | def load_gpt(self, args, strict=True): 52 | checkpoint_path = Path(args.checkpoint_path_gpt) 53 | checkpoint = torch.load(checkpoint_path) 54 | checkpoint = rename_state_dict_keys(checkpoint, lambda x: '.'.join(x.split('.')[1:]) if x.startswith('_orig_mod') else x) 55 | self.load_state_dict(checkpoint, strict=strict) 56 | self.transformer = self.transformer.eval() 57 | 58 | @staticmethod 59 | def load_vqgan_3d(args): 60 | # TODO: This shouldn't be hardcoded 61 | model = VQGAN( 62 | in_channels=1, 63 | out_channels=1, 64 | emb_channels=256, 65 | num_embeddings=8192, 66 | spatial_dims=3, 67 | hid_chs=[32, 64, 128, 256], 68 | kernel_sizes=[3, 3, 3, 3], 69 | strides=[1, 2, 2, 2], 70 | embedding_loss_weight=1, 71 | beta=1, 72 | pixel_loss=torch.nn.L1Loss, 73 | deep_supervision=0, 74 | use_attention='none', 75 | norm_name=("GROUP", {'num_groups': 4, "affine": True}), 76 | sample_every_n_steps=200, 77 | ) 78 | checkpoint_path = Path(args.checkpoint_path_3d_vqgan) 79 | model.load_pretrained(checkpoint_path) 80 | model = model.eval() 81 | return model 82 | 83 | @staticmethod 84 | def load_vqgan_2d(args): 85 | # TODO: This shouldn't be hardcoded 86 | model = VQGAN( 87 | in_channels=1, 88 | out_channels=1, 89 | emb_channels=512, 90 | num_embeddings=8192, 91 | spatial_dims=2, 92 | hid_chs=[64, 128, 256, 512], 93 | kernel_sizes=[3, 3, 3, 3], 94 | strides=[1, 2, 2, 2], 95 | embedding_loss_weight=1, 96 | beta=1, 97 | pixel_loss=torch.nn.L1Loss, 98 | deep_supervision=1, 99 | use_attention='none', 100 | sample_every_n_steps=50, 101 | ) 102 | checkpoint_path = Path(args.checkpoint_path_2d_vqgan) 103 | model.load_pretrained(checkpoint_path) 104 | model = model.eval() 105 | return model 106 | 107 | @torch.no_grad() 108 | def encode_to_z(self, x, use_3d=True): 109 | if use_3d: 110 | indices, _ = self.vqgan_3d.vqvae.encode_to_indices(x) 111 | else: 112 | indices, _ = self.vqgan_2d.vqvae.encode_to_indices(x) 113 | return indices 114 | 115 | @torch.no_grad() 116 | def z_to_image(self, indices, ch=8, p1=16, p2=16, p3=16, use_3d=True): 117 | if use_3d: 118 | image = self.vqgan_3d.vqvae.decode_from_indices( 119 | indices, (1, p1, p2, p3, ch)) 120 | else: 121 | image = self.vqgan_2d.vqvae.decode_from_indices( 122 | indices, (1, p1, p2, ch)) 123 | return image 124 | 125 | def forward(self, x, use_indices=True): 126 | if use_indices: 127 | indices_ct, indices_ap, indices_lat = x 128 | else: 129 | # TODO: Encode all images to z (3d and 2d) 130 | #indices = self.encode_to_z(x) 131 | raise NotImplementedError("Encoding to z is not implemented yet") 132 | 133 | sos_tokens = torch.ones(indices_ct.shape[0], 1) * self.sos_token 134 | sos_tokens = sos_tokens.long().to("cuda") 135 | 136 | mask = torch.bernoulli( 137 | 
self.pkeep * torch.ones(indices_ct.shape, device=indices_ct.device)) 138 | mask = mask.round().to(dtype=torch.int64) 139 | random_indices = torch.randint_like( 140 | indices_ct, self.transformer.config.vocab_size) 141 | new_indices = mask * indices_ct + (1 - mask) * random_indices 142 | 143 | new_indices = torch.cat( 144 | (sos_tokens, indices_ap, indices_lat, new_indices), dim=1) 145 | 146 | target = torch.cat((indices_ap, indices_lat, indices_ct), dim=1) 147 | 148 | logits, _ = self.transformer(new_indices[:, :-1]) 149 | 150 | return logits, target 151 | 152 | def top_k_logits(self, logits, k): 153 | v, ix = torch.topk(logits, k) 154 | out = logits.clone() 155 | out[out < v[..., [-1]]] = -float("inf") 156 | return out 157 | 158 | @torch.no_grad() 159 | def sample(self, x, c, steps, temperature=1.0, top_k=100): 160 | self.transformer.eval() 161 | x = torch.cat((c, x), dim=1) 162 | for k in tqdm(range(steps)): 163 | logits, _ = self.transformer(x) 164 | logits = logits[:, -1, :] / temperature 165 | 166 | if top_k is not None: 167 | logits = self.top_k_logits(logits, top_k) 168 | 169 | probs = F.softmax(logits, dim=-1) 170 | 171 | ix = torch.multinomial(probs, num_samples=1) 172 | 173 | x = torch.cat((x, ix), dim=1) 174 | 175 | x = x[:, c.shape[1]:] 176 | self.transformer.train() 177 | return x 178 | 179 | @torch.no_grad() 180 | def log_images(self, x, use_indices=True, temperature=1.0, top_k=100): 181 | log = dict() 182 | 183 | if use_indices: 184 | indices_ct, indices_ap, indices_lat = x 185 | else: 186 | # TODO: Encode all images to z (3d and 2d) 187 | #indices = self.encode_to_z(x) 188 | raise NotImplementedError("Encoding to z is not implemented yet") 189 | 190 | sos_tokens = torch.ones(indices_ct.shape[0], 1) * self.sos_token 191 | sos_tokens = sos_tokens.long().to("cuda") 192 | 193 | start_indices_ = indices_ct[:, :indices_ct.shape[1] // 2] 194 | start_indices = torch.cat( 195 | (indices_ap, indices_lat, start_indices_), dim=1) 196 | sample_indices = self.sample( 197 | start_indices, sos_tokens, steps=indices_ct.shape[1] - start_indices_.shape[1], temperature=temperature, top_k=top_k) 198 | 199 | sample_indices_ap = sample_indices[:, :indices_ap.shape[1]] 200 | sample_indices_lat = sample_indices[:, indices_ap.shape[1] 201 | :indices_ap.shape[1]+indices_lat.shape[1]] 202 | sample_indices_ct = sample_indices[:, 203 | indices_ap.shape[1]+indices_lat.shape[1]:] 204 | half_sample_ct = self.z_to_image( 205 | sample_indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 206 | half_sample_ap = self.z_to_image( 207 | sample_indices_ap, use_3d=False, ch=512, p1=16, p2=16) 208 | half_sample_lat = self.z_to_image( 209 | sample_indices_lat, use_3d=False, ch=512, p1=16, p2=16) 210 | 211 | start_indices_ = indices_ct[:, :0] 212 | start_indices = torch.cat( 213 | (indices_ap, indices_lat, start_indices_), dim=1) 214 | sample_indices = self.sample( 215 | start_indices, sos_tokens, steps=indices_ct.shape[1], temperature=temperature, top_k=top_k) 216 | 217 | sample_indices_ap = sample_indices[:, :indices_ap.shape[1]] 218 | sample_indices_lat = sample_indices[:, indices_ap.shape[1] 219 | :indices_ap.shape[1]+indices_lat.shape[1]] 220 | sample_indices_ct = sample_indices[:, 221 | indices_ap.shape[1]+indices_lat.shape[1]:] 222 | 223 | full_sample_ct = self.z_to_image( 224 | sample_indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 225 | full_sample_ap = self.z_to_image( 226 | sample_indices_ap, use_3d=False, ch=512, p1=16, p2=16) 227 | full_sample_lat = self.z_to_image( 228 | sample_indices_lat, 
use_3d=False, ch=512, p1=16, p2=16) 229 | 230 | x_rec_ct = self.z_to_image( 231 | indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 232 | x_rec_ap = self.z_to_image( 233 | indices_ap, use_3d=False, ch=512, p1=16, p2=16) 234 | x_rec_lat = self.z_to_image( 235 | indices_lat, use_3d=False, ch=512, p1=16, p2=16) 236 | 237 | log["input"] = x 238 | log["rec_ct"] = x_rec_ct 239 | log["rec_ap"] = x_rec_ap 240 | log["rec_lat"] = x_rec_lat 241 | log["half_sample_ct"] = half_sample_ct 242 | log["half_sample_ap"] = half_sample_ap 243 | log["half_sample_lat"] = half_sample_lat 244 | log["full_sample_ct"] = full_sample_ct 245 | log["full_sample_ap"] = full_sample_ap 246 | log["full_sample_lat"] = full_sample_lat 247 | 248 | return log, torch.concat((x_rec_ct, half_sample_ct, full_sample_ct)), torch.concat((x_rec_ap, half_sample_ap, full_sample_ap)), torch.concat((x_rec_lat, half_sample_lat, full_sample_lat)) 249 | 250 | @torch.no_grad() 251 | def log_images_monoplanar(self, x, use_indices=True, temperature=1.0, top_k=100): 252 | log = dict() 253 | 254 | if use_indices: 255 | indices_ct, indices_ap, indices_lat = x 256 | else: 257 | # TODO: Encode all images to z (3d and 2d) 258 | #indices = self.encode_to_z(x) 259 | raise NotImplementedError("Encoding to z is not implemented yet") 260 | 261 | sos_tokens = torch.ones(indices_ct.shape[0], 1) * self.sos_token 262 | sos_tokens = sos_tokens.long().to("cuda") 263 | 264 | # Monoplanar 265 | start_indices_ = indices_ct[:, :0] 266 | start_indices = torch.cat( 267 | (indices_ap, start_indices_), dim=1) 268 | sample_indices = self.sample( 269 | start_indices, sos_tokens, steps=indices_lat.shape[1] + indices_ct.shape[1], temperature=temperature, top_k=top_k) 270 | 271 | sample_indices_ap = sample_indices[:, :indices_ap.shape[1]] 272 | sample_indices_lat = sample_indices[:, indices_ap.shape[1] 273 | :indices_ap.shape[1]+indices_lat.shape[1]] 274 | sample_indices_ct = sample_indices[:, 275 | indices_ap.shape[1]+indices_lat.shape[1]:] 276 | half_sample_ct = self.z_to_image( 277 | sample_indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 278 | half_sample_ap = self.z_to_image( 279 | sample_indices_ap, use_3d=False, ch=512, p1=16, p2=16) 280 | half_sample_lat = self.z_to_image( 281 | sample_indices_lat, use_3d=False, ch=512, p1=16, p2=16) 282 | 283 | # Biplanar 284 | start_indices_ = indices_ct[:, :0] 285 | start_indices = torch.cat( 286 | (indices_ap, indices_lat, start_indices_), dim=1) 287 | sample_indices = self.sample( 288 | start_indices, sos_tokens, steps=indices_ct.shape[1], temperature=temperature, top_k=top_k) 289 | 290 | sample_indices_ap = sample_indices[:, :indices_ap.shape[1]] 291 | sample_indices_lat = sample_indices[:, indices_ap.shape[1] 292 | :indices_ap.shape[1]+indices_lat.shape[1]] 293 | sample_indices_ct = sample_indices[:, 294 | indices_ap.shape[1]+indices_lat.shape[1]:] 295 | 296 | full_sample_ct = self.z_to_image( 297 | sample_indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 298 | full_sample_ap = self.z_to_image( 299 | sample_indices_ap, use_3d=False, ch=512, p1=16, p2=16) 300 | full_sample_lat = self.z_to_image( 301 | sample_indices_lat, use_3d=False, ch=512, p1=16, p2=16) 302 | 303 | x_rec_ct = self.z_to_image( 304 | indices_ct, use_3d=True, ch=256, p1=16, p2=16, p3=16) 305 | x_rec_ap = self.z_to_image( 306 | indices_ap, use_3d=False, ch=512, p1=16, p2=16) 307 | x_rec_lat = self.z_to_image( 308 | indices_lat, use_3d=False, ch=512, p1=16, p2=16) 309 | 310 | log["input"] = x 311 | log["rec_ct"] = x_rec_ct 312 | log["rec_ap"] = 
x_rec_ap 313 | log["rec_lat"] = x_rec_lat 314 | log["half_sample_ct"] = half_sample_ct 315 | log["half_sample_ap"] = half_sample_ap 316 | log["half_sample_lat"] = half_sample_lat 317 | log["full_sample_ct"] = full_sample_ct 318 | log["full_sample_ap"] = full_sample_ap 319 | log["full_sample_lat"] = full_sample_lat 320 | 321 | return log, torch.concat((x_rec_ct, half_sample_ct, full_sample_ct)), torch.concat((x_rec_ap, half_sample_ap, full_sample_ap)), torch.concat((x_rec_lat, half_sample_lat, full_sample_lat)) 322 | 323 | -------------------------------------------------------------------------------- /vqgan/attention_blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | 5 | from monai.networks.blocks import TransformerBlock 6 | from monai.networks.layers.utils import get_norm_layer, get_dropout_layer 7 | from monai.networks.layers.factories import Conv 8 | from einops import rearrange 9 | 10 | 11 | class GEGLU(nn.Module): 12 | def __init__(self, in_channels, out_channels): 13 | super().__init__() 14 | self.norm = nn.LayerNorm(in_channels) 15 | self.proj = nn.Linear(in_channels, out_channels*2, bias=True) 16 | 17 | def forward(self, x): 18 | # x expected to be [B, C, *] 19 | # Workaround as layer norm can't currently be applied on arbitrary dimension: https://github.com/pytorch/pytorch/issues/71465 20 | b, c, *spatial = x.shape 21 | x = x.reshape(b, c, -1).transpose(1, 2) # -> [B, C, N] -> [B, N, C] 22 | x = self.norm(x) 23 | x, gate = self.proj(x).chunk(2, dim=-1) 24 | x = x * F.gelu(gate) 25 | return x.transpose(1, 2).reshape(b, -1, *spatial) # -> [B, C, N] -> [B, C, *] 26 | 27 | def zero_module(module): 28 | """ 29 | Zero out the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().zero_() 33 | return module 34 | 35 | def compute_attention(q,k,v , num_heads, scale): 36 | q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v)) # [(BxHeads), Dim_per_head, N] 37 | 38 | attn = (torch.einsum('b d i, b d j -> b i j', q*scale, k*scale)).softmax(dim=-1) # Matrix product = [(BxHeads), Dim_per_head, N] * [(BxHeads), Dim_per_head, N'] =[(BxHeads), N, N'] 39 | 40 | out = torch.einsum('b i j, b d j-> b d i', attn, v) # Matrix product: [(BxHeads), N, N'] * [(BxHeads), Dim_per_head, N'] = [(BxHeads), Dim_per_head, N] 41 | out = rearrange(out, '(b h) d n-> b (h d) n', h=num_heads) # -> [B, (Heads x Dim_per_head), N] 42 | 43 | return out 44 | 45 | 46 | class LinearTransformerNd(nn.Module): 47 | """ Combines multi-head self-attention and multi-head cross-attention. 48 | Multi-Head Self-Attention: 49 | Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare Monai TransformerBlock) 50 | Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
51 | Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278 52 | Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80 53 | Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209 54 | Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150 55 | CrossAttention: 56 | Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152 57 | 58 | """ 59 | def __init__( 60 | self, 61 | spatial_dims, 62 | in_channels, 63 | out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled 64 | num_heads=8, 65 | ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) 66 | norm_name=("GROUP", {'num_groups':32, "affine": True}), # Or use LayerNorm but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1) 67 | dropout=None, 68 | emb_dim=None, 69 | ): 70 | super().__init__() 71 | hid_channels = num_heads*ch_per_head 72 | self.num_heads = num_heads 73 | self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale 74 | 75 | self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) 76 | emb_dim = in_channels if emb_dim is None else emb_dim 77 | 78 | Convolution = Conv["conv", spatial_dims] 79 | self.to_q = Convolution(in_channels, hid_channels, 1) 80 | self.to_k = Convolution(emb_dim, hid_channels, 1) 81 | self.to_v = Convolution(emb_dim, hid_channels, 1) 82 | 83 | self.to_out = nn.Sequential( 84 | zero_module(Convolution(hid_channels, out_channels, 1)), 85 | nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims) 86 | ) 87 | 88 | def forward(self, x, embedding=None): 89 | # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] 90 | # if no embedding is given, cross-attention defaults to self-attention 91 | 92 | # Normalize 93 | b, c, *spatial = x.shape 94 | x_n = self.norm_x(x) 95 | 96 | # Attention: embedding (cross-attention) or x (self-attention) 97 | if embedding is None: 98 | embedding = x_n # WARNING: This assumes that emb_dim==in_channels 99 | else: 100 | if embedding.ndim == 2: 101 | embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *] 102 | # Why no normalization for embedding here? 
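        # Illustrative note (sketch, not part of the original forward pass): self.scale
        # is ch_per_head**-0.25, so compute_attention scales q and k by d**-0.25 each,
        # which is mathematically equivalent to the usual single 1/sqrt(d) factor:
        #   (q * self.scale) @ (k * self.scale).T  ==  (q @ k.T) / sqrt(ch_per_head)
        # Splitting the factor keeps the intermediate products smaller, which is the
        # convention the comment in __init__ attributes to the OpenAI implementation.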
103 | 104 | # Convolution 105 | q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), *] 106 | k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), *] 107 | v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), *] 108 | 109 | # Flatten 110 | q = q.reshape(b, c, -1) # -> [B, (Heads x Dim_per_head), N] 111 | k = k.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N'] 112 | v = v.reshape(*embedding.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N'] 113 | 114 | # Apply attention 115 | out = compute_attention(q, k, v, self.num_heads, self.scale) 116 | 117 | out = out.reshape(*out.shape[:2], *spatial) # -> [B, (Heads x Dim_per_head), *] 118 | out = self.to_out(out) # -> [B, C', *] 119 | 120 | 121 | if x.shape == out.shape: 122 | out = x + out 123 | return out # [B, C', *] 124 | 125 | 126 | class LinearTransformer(nn.Module): 127 | """ See LinearTransformer, however this implementation is fixed to Conv1d/Linear""" 128 | def __init__( 129 | self, 130 | spatial_dims, 131 | in_channels, 132 | out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled 133 | num_heads, 134 | ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) 135 | norm_name=("GROUP", {'num_groups':32, "affine": True}), 136 | dropout=None, 137 | emb_dim=None 138 | ): 139 | super().__init__() 140 | hid_channels = num_heads*ch_per_head 141 | self.num_heads = num_heads 142 | self.scale = ch_per_head**-0.25 # Should be 1/sqrt("queries and keys of dimension"), Note: additional sqrt needed as it follows OpenAI: (q * scale) * (k * scale) instead of (q *k) * scale 143 | 144 | self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) 145 | emb_dim = in_channels if emb_dim is None else emb_dim 146 | 147 | # Note: Conv1d and Linear are interchangeable but order of input changes [B, C, N] <-> [B, N, C] 148 | self.to_q = nn.Conv1d(in_channels, hid_channels, 1) 149 | self.to_k = nn.Conv1d(emb_dim, hid_channels, 1) 150 | self.to_v = nn.Conv1d(emb_dim, hid_channels, 1) 151 | # self.to_qkv = nn.Conv1d(emb_dim, hid_channels*3, 1) 152 | 153 | self.to_out = nn.Sequential( 154 | zero_module(nn.Conv1d(hid_channels, out_channels, 1)), 155 | nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims) 156 | ) 157 | 158 | def forward(self, x, embedding=None): 159 | # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] 160 | # if no embedding is given, cross-attention defaults to self-attention 161 | 162 | # Normalize 163 | b, c, *spatial = x.shape 164 | x_n = self.norm_x(x) 165 | 166 | # Attention: embedding (cross-attention) or x (self-attention) 167 | if embedding is None: 168 | embedding = x_n # WARNING: This assumes that emb_dim==in_channels 169 | else: 170 | if embedding.ndim == 2: 171 | embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *] 172 | # Why no normalization for embedding here? 
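        # Illustrative note (sketch): as stated in __init__, nn.Conv1d(C_in, C_out, 1)
        # on [B, C, N] tensors and nn.Linear(C_in, C_out) on [B, N, C] tensors are
        # interchangeable given shared weights:
        #   conv1d(x)  ==  linear(x.transpose(1, 2)).transpose(1, 2)
        # which is why the projections below operate on the flattened [B, C, N] layout.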
173 | 174 | # Flatten 175 | x_n = x_n.reshape(b, c, -1) # [B, C, *] -> [B, C, N] 176 | embedding = embedding.reshape(*embedding.shape[:2], -1) # [B, C*, *] -> [B, C*, N'] 177 | 178 | # Convolution 179 | q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), N] 180 | k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), N'] 181 | v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), N'] 182 | # qkv = self.to_qkv(x_n) 183 | # q,k,v = qkv.split(qkv.shape[1]//3, dim=1) 184 | 185 | # Apply attention 186 | out = compute_attention(q, k, v, self.num_heads, self.scale) 187 | 188 | out = self.to_out(out) # -> [B, C', N] 189 | out = out.reshape(*out.shape[:2], *spatial) # -> [B, C', *] 190 | 191 | if x.shape == out.shape: 192 | out = x + out 193 | return out # [B, C', *] 194 | 195 | 196 | 197 | 198 | class BasicTransformerBlock(nn.Module): 199 | def __init__( 200 | self, 201 | spatial_dims, 202 | in_channels, 203 | out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled 204 | num_heads, 205 | ch_per_head=32, 206 | norm_name=("GROUP", {'num_groups':32, "affine": True}), 207 | dropout=None, 208 | emb_dim=None 209 | ): 210 | super().__init__() 211 | self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None) 212 | if emb_dim is not None: 213 | self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim) 214 | self.proj_out = nn.Sequential( 215 | GEGLU(in_channels, in_channels*4), 216 | nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims), 217 | Conv["conv", spatial_dims](in_channels*4, out_channels, 1, bias=True) 218 | ) 219 | 220 | 221 | def forward(self, x, embedding=None): 222 | # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] 223 | x = self.self_atn(x) 224 | if embedding is not None: 225 | x = self.cros_atn(x, embedding=embedding) 226 | out = self.proj_out(x) 227 | if out.shape[1] == x.shape[1]: 228 | return out + x 229 | return x 230 | 231 | class SpatialTransformer(nn.Module): 232 | """ Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218 233 | Unrelated to: https://arxiv.org/abs/1506.02025 234 | """ 235 | def __init__( 236 | self, 237 | spatial_dims, 238 | in_channels, 239 | out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled 240 | num_heads, 241 | ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) 242 | norm_name = ("GROUP", {'num_groups':32, "affine": True}), 243 | dropout=None, 244 | emb_dim=None, 245 | depth=1 246 | ): 247 | super().__init__() 248 | self.in_channels = in_channels 249 | self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels) 250 | conv_class = Conv["conv", spatial_dims] 251 | hid_channels = num_heads*ch_per_head 252 | 253 | self.proj_in = conv_class( 254 | in_channels, 255 | hid_channels, 256 | kernel_size=1, 257 | stride=1, 258 | padding=0, 259 | ) 260 | 261 | self.transformer_blocks = nn.ModuleList([ 262 | BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim) 263 | for _ in range(depth)] 264 | ) 265 | 266 | self.proj_out = conv_class( # Note: zero_module is used in original code 267 | hid_channels, 268 | out_channels, 269 | kernel_size=1, 270 | stride=1, 271 | 
padding=0, 272 | ) 273 | 274 | def forward(self, x, embedding=None): 275 | # x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *] 276 | # Note: if no embedding is given, cross-attention is disabled 277 | h = self.norm(x) 278 | h = self.proj_in(h) 279 | 280 | for block in self.transformer_blocks: 281 | h = block(h, embedding=embedding) 282 | 283 | h = self.proj_out(h) # -> [B, C'', *] 284 | if h.shape == x.shape: 285 | return h + x 286 | return h 287 | 288 | 289 | class Attention(nn.Module): 290 | def __init__( 291 | self, 292 | spatial_dims, 293 | in_channels, 294 | out_channels, 295 | num_heads=8, 296 | ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs) 297 | norm_name = ("GROUP", {'num_groups':32, "affine": True}), 298 | dropout=0, 299 | emb_dim=None, 300 | depth=1, 301 | attention_type='linear' 302 | ) -> None: 303 | super().__init__() 304 | if attention_type == 'spatial': 305 | self.attention = SpatialTransformer( 306 | spatial_dims=spatial_dims, 307 | in_channels=in_channels, 308 | out_channels=out_channels, 309 | num_heads=num_heads, 310 | ch_per_head=ch_per_head, 311 | depth=depth, 312 | norm_name=norm_name, 313 | dropout=dropout, 314 | emb_dim=emb_dim 315 | ) 316 | elif attention_type == 'linear': 317 | self.attention = LinearTransformer( 318 | spatial_dims=spatial_dims, 319 | in_channels=in_channels, 320 | out_channels=out_channels, 321 | num_heads=num_heads, 322 | ch_per_head=ch_per_head, 323 | norm_name=norm_name, 324 | dropout=dropout, 325 | emb_dim=emb_dim 326 | ) 327 | 328 | 329 | def forward(self, x, emb=None): 330 | if hasattr(self, 'attention'): 331 | return self.attention(x, emb) 332 | else: 333 | return x -------------------------------------------------------------------------------- /vqgan/conv_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple, Union, Type 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | from monai.networks.blocks.dynunet_block import get_padding, get_output_padding 10 | from monai.networks.layers import Pool, Conv 11 | from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_dropout_layer 12 | from monai.utils.misc import ensure_tuple_rep 13 | 14 | from vqgan.attention_blocks import Attention, zero_module 15 | 16 | def save_add(*args): 17 | args = [arg for arg in args if arg is not None] 18 | return sum(args) if len(args)>0 else None 19 | 20 | 21 | class SequentialEmb(nn.Sequential): 22 | def forward(self, input, emb): 23 | for module in self: 24 | input = module(input, emb) 25 | return input 26 | 27 | 28 | class BasicDown(nn.Module): 29 | def __init__( 30 | self, 31 | spatial_dims, 32 | in_channels, 33 | out_channels, 34 | kernel_size=3, 35 | stride=2, 36 | learnable_interpolation=True, 37 | use_res=False 38 | ) -> None: 39 | super().__init__() 40 | 41 | if learnable_interpolation: 42 | Convolution = Conv[Conv.CONV, spatial_dims] 43 | self.down_op = Convolution( 44 | in_channels, 45 | out_channels, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=get_padding(kernel_size, stride), 49 | dilation=1, 50 | groups=1, 51 | bias=True, 52 | ) 53 | 54 | if use_res: 55 | self.down_skip = nn.PixelUnshuffle(2) # WARNING: Only supports 2D, , out_channels == 4*in_channels 56 | 57 | else: 58 | Pooling = Pool['avg', spatial_dims] 59 | self.down_op = Pooling( 60 | kernel_size=kernel_size, 61 | 
stride=stride, 62 | padding=get_padding(kernel_size, stride) 63 | ) 64 | 65 | 66 | def forward(self, x, emb=None): 67 | y = self.down_op(x) 68 | if hasattr(self, 'down_skip'): 69 | y = y+self.down_skip(x) 70 | return y 71 | 72 | class BasicUp(nn.Module): 73 | def __init__( 74 | self, 75 | spatial_dims, 76 | in_channels, 77 | out_channels, 78 | kernel_size=2, 79 | stride=2, 80 | learnable_interpolation=True, 81 | use_res=False, 82 | ) -> None: 83 | super().__init__() 84 | self.learnable_interpolation = learnable_interpolation 85 | if learnable_interpolation: 86 | # TransConvolution = Conv[Conv.CONVTRANS, spatial_dims] 87 | # padding = get_padding(kernel_size, stride) 88 | # output_padding = get_output_padding(kernel_size, stride, padding) 89 | # self.up_op = TransConvolution( 90 | # in_channels, 91 | # out_channels, 92 | # kernel_size=kernel_size, 93 | # stride=stride, 94 | # padding=padding, 95 | # output_padding=output_padding, 96 | # groups=1, 97 | # bias=True, 98 | # dilation=1 99 | # ) 100 | 101 | self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size) 102 | -2*np.atleast_1d(get_padding(kernel_size, stride))) 103 | Convolution = Conv[Conv.CONV, spatial_dims] 104 | self.up_op = Convolution( 105 | in_channels, 106 | out_channels, 107 | kernel_size=3, 108 | stride=1, 109 | padding=1, 110 | dilation=1, 111 | groups=1, 112 | bias=True, 113 | ) 114 | 115 | if use_res: 116 | self.up_skip = nn.PixelShuffle(2) # WARNING: Only supports 2D, out_channels == in_channels/4 117 | else: 118 | self.calc_shape = lambda x: tuple((np.asarray(x)-1)*np.atleast_1d(stride)+np.atleast_1d(kernel_size) 119 | -2*np.atleast_1d(get_padding(kernel_size, stride))) 120 | 121 | def forward(self, x, emb=None): 122 | if self.learnable_interpolation: 123 | new_size = self.calc_shape(x.shape[2:]) 124 | x_res = F.interpolate(x, size=new_size, mode='nearest-exact') 125 | y = self.up_op(x_res) 126 | if hasattr(self, 'up_skip'): 127 | y = y+self.up_skip(x) 128 | return y 129 | else: 130 | new_size = self.calc_shape(x.shape[2:]) 131 | return F.interpolate(x, size=new_size, mode='nearest-exact') 132 | 133 | 134 | class BasicBlock(nn.Module): 135 | """ 136 | A block that consists of Conv-Norm-Drop-Act, similar to blocks.Convolution. 137 | 138 | Args: 139 | spatial_dims: number of spatial dimensions. 140 | in_channels: number of input channels. 141 | out_channels: number of output channels. 142 | kernel_size: convolution kernel size. 143 | stride: convolution stride. 144 | norm_name: feature normalization type and arguments. 145 | act_name: activation layer type and arguments. 146 | dropout: dropout probability. 147 | zero_conv: zero out the parameters of the convolution. 
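
    Example (illustrative; the argument values here are assumptions, not defaults):
        block = BasicBlock(spatial_dims=2, in_channels=1, out_channels=32, kernel_size=3,
                           norm_name=("GROUP", {"num_groups": 4, "affine": True}),
                           act_name="RELU")
        y = block(torch.randn(1, 1, 64, 64))  # -> [1, 32, 64, 64]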
148 | """ 149 | 150 | def __init__( 151 | self, 152 | spatial_dims: int, 153 | in_channels: int, 154 | out_channels: int, 155 | kernel_size: Union[Sequence[int], int], 156 | stride: Union[Sequence[int], int]=1, 157 | norm_name: Union[Tuple, str, None]=None, 158 | act_name: Union[Tuple, str, None] = None, 159 | dropout: Optional[Union[Tuple, str, float]] = None, 160 | zero_conv: bool = False, 161 | ): 162 | super().__init__() 163 | Convolution = Conv[Conv.CONV, spatial_dims] 164 | conv = Convolution( 165 | in_channels, 166 | out_channels, 167 | kernel_size=kernel_size, 168 | stride=stride, 169 | padding=get_padding(kernel_size, stride), 170 | dilation=1, 171 | groups=1, 172 | bias=True, 173 | ) 174 | self.conv = zero_module(conv) if zero_conv else conv 175 | 176 | if norm_name is not None: 177 | self.norm = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) 178 | if dropout is not None: 179 | self.drop = get_dropout_layer(name=dropout, dropout_dim=spatial_dims) 180 | if act_name is not None: 181 | self.act = get_act_layer(name=act_name) 182 | 183 | 184 | def forward(self, inp): 185 | out = self.conv(inp) 186 | if hasattr(self, "norm"): 187 | out = self.norm(out) 188 | if hasattr(self, 'drop'): 189 | out = self.drop(out) 190 | if hasattr(self, "act"): 191 | out = self.act(out) 192 | return out 193 | 194 | class BasicResBlock(nn.Module): 195 | """ 196 | A block that consists of Conv-Act-Norm + skip. 197 | 198 | Args: 199 | spatial_dims: number of spatial dimensions. 200 | in_channels: number of input channels. 201 | out_channels: number of output channels. 202 | kernel_size: convolution kernel size. 203 | stride: convolution stride. 204 | norm_name: feature normalization type and arguments. 205 | act_name: activation layer type and arguments. 206 | dropout: dropout probability. 207 | zero_conv: zero out the parameters of the convolution. 208 | """ 209 | def __init__( 210 | self, 211 | spatial_dims: int, 212 | in_channels: int, 213 | out_channels: int, 214 | kernel_size: Union[Sequence[int], int], 215 | stride: Union[Sequence[int], int]=1, 216 | norm_name: Union[Tuple, str, None]=None, 217 | act_name: Union[Tuple, str, None] = None, 218 | dropout: Optional[Union[Tuple, str, float]] = None, 219 | zero_conv: bool = False 220 | ): 221 | super().__init__() 222 | self.basic_block = BasicBlock(spatial_dims, in_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, zero_conv) 223 | Convolution = Conv[Conv.CONV, spatial_dims] 224 | self.conv_res = Convolution( 225 | in_channels, 226 | out_channels, 227 | kernel_size=1, 228 | stride=stride, 229 | padding=get_padding(1, stride), 230 | dilation=1, 231 | groups=1, 232 | bias=True, 233 | ) if in_channels != out_channels else nn.Identity() 234 | 235 | 236 | def forward(self, inp): 237 | out = self.basic_block(inp) 238 | residual = self.conv_res(inp) 239 | out = out+residual 240 | return out 241 | 242 | 243 | 244 | class UnetBasicBlock(nn.Module): 245 | """ 246 | A modified version of monai.networks.blocks.UnetBasicBlock with additional embedding 247 | Args: 248 | spatial_dims: number of spatial dimensions. 249 | in_channels: number of input channels. 250 | out_channels: number of output channels. 251 | kernel_size: convolution kernel size. 252 | stride: convolution stride. 253 | norm_name: feature normalization type and arguments. 254 | act_name: activation layer type and arguments. 255 | dropout: dropout probability. 
256 | emb_channels: Number of embedding channels 257 | """ 258 | 259 | def __init__( 260 | self, 261 | spatial_dims: int, 262 | in_channels: int, 263 | out_channels: int, 264 | kernel_size: Union[Sequence[int], int], 265 | stride: Union[Sequence[int], int]=1, 266 | norm_name: Union[Tuple, str]=None, 267 | act_name: Union[Tuple, str]=None, 268 | dropout: Optional[Union[Tuple, str, float]] = None, 269 | emb_channels: int = None, 270 | blocks = 2 271 | ): 272 | super().__init__() 273 | self.block_seq = nn.ModuleList([ 274 | BasicBlock(spatial_dims, in_channels if i==0 else out_channels, out_channels, kernel_size, stride, norm_name, act_name, dropout, i==blocks-1) 275 | for i in range(blocks) 276 | ]) 277 | 278 | if emb_channels is not None: 279 | self.local_embedder = nn.Sequential( 280 | get_act_layer(name=act_name), 281 | nn.Linear(emb_channels, out_channels), 282 | ) 283 | 284 | def forward(self, x, emb=None): 285 | # ------------ Embedding ---------- 286 | if emb is not None: 287 | emb = self.local_embedder(emb) 288 | b,c, *_ = emb.shape 289 | sp_dim = x.ndim-2 290 | emb = emb.reshape(b, c, *((1,)*sp_dim) ) 291 | # scale, shift = emb.chunk(2, dim = 1) 292 | # x = x * (scale + 1) + shift 293 | # x = x+emb 294 | 295 | # ----------- Convolution --------- 296 | n_blocks = len(self.block_seq) 297 | for i, block in enumerate(self.block_seq): 298 | x = block(x) 299 | if (emb is not None) and i= 2.0 60 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 61 | if not self.flash: 62 | print("WARNING: using slow attention. Flash Attention atm needs PyTorch >= 2.0") 63 | # causal mask to ensure that attention is only applied to the left in the input sequence 64 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 65 | .view(1, 1, config.block_size, config.block_size)) 66 | 67 | def forward(self, x): 68 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 69 | 70 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 71 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 72 | k = k.view(B, T, self.n_head, C // 73 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 74 | q = q.view(B, T, self.n_head, C // 75 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 76 | v = v.view(B, T, self.n_head, C // 77 | self.n_head).transpose(1, 2) # (B, nh, T, hs) 78 | 79 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 80 | if self.flash: 81 | # efficient attention using Flash Attention CUDA kernels 82 | y = torch.nn.functional.scaled_dot_product_attention( 83 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) 84 | else: 85 | # manual implementation of attention 86 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 87 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) 88 | att = F.softmax(att, dim=-1) 89 | att = self.attn_dropout(att) 90 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 91 | # re-assemble all head outputs side by side 92 | y = y.transpose(1, 2).contiguous().view(B, T, C) 93 | 94 | # output projection 95 | y = self.resid_dropout(self.c_proj(y)) 96 | return y 97 | 98 | 99 | class MLP(nn.Module): 100 | 101 | def __init__(self, config): 102 | super().__init__() 103 | self.c_fc = nn.Linear( 104 | config.n_embd, 4 * config.n_embd, bias=config.bias) 105 | self.c_proj = nn.Linear( 106 | 4 * config.n_embd, config.n_embd, bias=config.bias) 107 | self.dropout = 
nn.Dropout(config.dropout) 108 | 109 | def forward(self, x): 110 | x = self.c_fc(x) 111 | x = new_gelu(x) 112 | x = self.c_proj(x) 113 | x = self.dropout(x) 114 | return x 115 | 116 | 117 | class Block(nn.Module): 118 | 119 | def __init__(self, config): 120 | super().__init__() 121 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 122 | self.attn = CausalSelfAttention(config) 123 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 124 | self.mlp = MLP(config) 125 | 126 | def forward(self, x): 127 | x = x + self.attn(self.ln_1(x)) 128 | x = x + self.mlp(self.ln_2(x)) 129 | return x 130 | 131 | 132 | @dataclass 133 | class GPTConfig: 134 | block_size: int = 1024 135 | # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 136 | vocab_size: int = 50304 137 | n_layer: int = 12 138 | n_head: int = 12 139 | n_embd: int = 768 140 | dropout: float = 0.0 141 | # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 142 | bias: bool = True 143 | 144 | 145 | class GPT(nn.Module): 146 | 147 | def __init__(self, config): 148 | super().__init__() 149 | assert config.vocab_size is not None 150 | assert config.block_size is not None 151 | self.config = config 152 | 153 | self.transformer = nn.ModuleDict(dict( 154 | wte=nn.Embedding(config.vocab_size, config.n_embd), 155 | wpe=nn.Embedding(config.block_size, config.n_embd), 156 | drop=nn.Dropout(config.dropout), 157 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 158 | ln_f=LayerNorm(config.n_embd, bias=config.bias), 159 | )) 160 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 161 | # with weight tying when using torch.compile() some warnings get generated: 162 | # "UserWarning: functional_call was passed multiple values for tied weights. 163 | # This behavior is deprecated and will be an error in future versions" 164 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 165 | # https://paperswithcode.com/method/weight-tying 166 | self.transformer.wte.weight = self.lm_head.weight 167 | 168 | # init all weights 169 | self.apply(self._init_weights) 170 | # apply special scaled init to the residual projections, per GPT-2 paper 171 | for pn, p in self.named_parameters(): 172 | if pn.endswith('c_proj.weight'): 173 | torch.nn.init.normal_( 174 | p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 175 | 176 | # report number of parameters 177 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 178 | 179 | def get_num_params(self, non_embedding=True): 180 | """ 181 | Return the number of parameters in the model. 182 | For non-embedding count (default), the position embeddings get subtracted. 183 | The token embeddings would too, except due to the parameter sharing these 184 | params are actually used as weights in the final layer, so we include them. 
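        As an illustrative reference point, the default GPTConfig (12 layers, 12 heads,
        n_embd=768, vocab_size=50304) yields roughly 124M parameters, i.e. GPT-2 small.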
185 | """ 186 | n_params = sum(p.numel() for p in self.parameters()) 187 | if non_embedding: 188 | n_params -= self.transformer.wpe.weight.numel() 189 | return n_params 190 | 191 | def _init_weights(self, module): 192 | if isinstance(module, nn.Linear): 193 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 194 | if module.bias is not None: 195 | torch.nn.init.zeros_(module.bias) 196 | elif isinstance(module, nn.Embedding): 197 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 198 | 199 | def forward(self, idx, targets=None): 200 | device = idx.device 201 | b, t = idx.size() 202 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 203 | pos = torch.arange(0, t, dtype=torch.long, 204 | device=device).unsqueeze(0) # shape (1, t) 205 | 206 | # forward the GPT model itself 207 | # token embeddings of shape (b, t, n_embd) 208 | tok_emb = self.transformer.wte(idx) 209 | # position embeddings of shape (1, t, n_embd) 210 | pos_emb = self.transformer.wpe(pos) 211 | x = self.transformer.drop(tok_emb + pos_emb) 212 | for block in self.transformer.h: 213 | x = block(x) 214 | x = self.transformer.ln_f(x) 215 | 216 | if targets is not None: 217 | # if we are given some desired targets also calculate the loss 218 | logits = self.lm_head(x) 219 | loss = F.cross_entropy( 220 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 221 | 222 | else: 223 | logits = self.lm_head(x) 224 | loss = None 225 | 226 | return logits, loss 227 | 228 | # else: 229 | # inference-time mini-optimization: only forward the lm_head on the very last position 230 | # note: using list [-1] to preserve the time dim 231 | #logits = self.lm_head(x[:, [-1], :]) 232 | #loss = None 233 | 234 | # return logits, loss 235 | 236 | def crop_block_size(self, block_size): 237 | # model surgery to decrease the block size if necessary 238 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 239 | # but want to use a smaller block size for some smaller, simpler model 240 | assert block_size <= self.config.block_size 241 | self.config.block_size = block_size 242 | self.transformer.wpe.weight = nn.Parameter( 243 | self.transformer.wpe.weight[:block_size]) 244 | for block in self.transformer.h: 245 | block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] 246 | 247 | @classmethod 248 | def from_pretrained(cls, model_type, override_args=None): 249 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 250 | override_args = override_args or {} # default to empty dict 251 | # only dropout can be overridden see more notes below 252 | assert all(k == 'dropout' for k in override_args) 253 | from transformers import GPT2LMHeadModel 254 | print("loading weights from pretrained gpt: %s" % model_type) 255 | 256 | # n_layer, n_head and n_embd are determined from model_type 257 | config_args = { 258 | # 124M params 259 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), 260 | # 350M params 261 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), 262 | # 774M params 263 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), 264 | # 1558M params 265 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), 266 | }[model_type] 267 | print("forcing vocab_size=50257, block_size=1024, bias=True") 268 | # always 50257 for GPT model checkpoints 269 | config_args['vocab_size'] = 50257 270 | # always 1024 for GPT model checkpoints 271 | config_args['block_size'] = 1024 272 | config_args['bias'] = True # always True for GPT model checkpoints 273 | # we can override the dropout rate, if desired 274 | if 'dropout' in override_args: 275 | print(f"overriding dropout rate to {override_args['dropout']}") 276 | config_args['dropout'] = override_args['dropout'] 277 | # create a from-scratch initialized minGPT model 278 | config = GPTConfig(**config_args) 279 | model = GPT(config) 280 | sd = model.state_dict() 281 | sd_keys = sd.keys() 282 | # discard this mask / buffer, not a param 283 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] 284 | 285 | # init a huggingface/transformers model 286 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 287 | sd_hf = model_hf.state_dict() 288 | 289 | # copy while ensuring all of the parameters are aligned and match in names and shapes 290 | sd_keys_hf = sd_hf.keys() 291 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith( 292 | '.attn.masked_bias')] # ignore these, just a buffer 293 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith( 294 | '.attn.bias')] # same, just the mask (buffer) 295 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 296 | 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 297 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 298 | # this means that we have to transpose these weights when we import them 299 | assert len(sd_keys_hf) == len( 300 | sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 301 | for k in sd_keys_hf: 302 | if any(k.endswith(w) for w in transposed): 303 | # special treatment for the Conv1D weights we need to transpose 304 | assert sd_hf[k].shape[::-1] == sd[k].shape 305 | with torch.no_grad(): 306 | sd[k].copy_(sd_hf[k].t()) 307 | else: 308 | # vanilla copy over the other parameters 309 | assert sd_hf[k].shape == sd[k].shape 310 | with torch.no_grad(): 311 | sd[k].copy_(sd_hf[k]) 312 | 313 | return model 314 | 315 | def 
configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 316 | """ 317 | This long function is unfortunately doing something very simple and is being very defensive: 318 | We are separating out all parameters of the model into two buckets: those that will experience 319 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 320 | We are then returning the PyTorch optimizer object. 321 | """ 322 | 323 | # separate out all parameters to those that will and won't experience regularizing weight decay 324 | decay = set() 325 | no_decay = set() 326 | whitelist_weight_modules = (torch.nn.Linear, ) 327 | blacklist_weight_modules = ( 328 | torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) 329 | for mn, m in self.named_modules(): 330 | for pn, p in m.named_parameters(): 331 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 332 | # random note: because named_modules and named_parameters are recursive 333 | # we will see the same tensors p many many times. but doing it this way 334 | # allows us to know which parent module any tensor p belongs to... 335 | if pn.endswith('bias'): 336 | # all biases will not be decayed 337 | no_decay.add(fpn) 338 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 339 | # weights of whitelist modules will be weight decayed 340 | decay.add(fpn) 341 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 342 | # weights of blacklist modules will NOT be weight decayed 343 | no_decay.add(fpn) 344 | 345 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 346 | # will appear in the no_decay and decay sets respectively after the above. 347 | # In addition, because named_parameters() doesn't return duplicates, it 348 | # will only return the first occurence, key'd by 'transformer.wte.weight', below. 349 | # so let's manually remove 'lm_head.weight' from decay set. This will include 350 | # this tensor into optimization via transformer.wte.weight only, and not decayed. 351 | decay.remove('lm_head.weight') 352 | 353 | # validate that we considered every parameter 354 | param_dict = {pn: p for pn, p in self.named_parameters()} 355 | inter_params = decay & no_decay 356 | union_params = decay | no_decay 357 | assert len( 358 | inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 359 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 360 | % (str(param_dict.keys() - union_params), ) 361 | 362 | # create the pytorch optimizer object 363 | optim_groups = [ 364 | {"params": [param_dict[pn] for pn in sorted( 365 | list(decay))], "weight_decay": weight_decay}, 366 | {"params": [param_dict[pn] 367 | for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 368 | ] 369 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster 370 | use_fused = (device_type == 'cuda') and ( 371 | 'fused' in inspect.signature(torch.optim.AdamW).parameters) 372 | print(f"using fused AdamW: {use_fused}") 373 | extra_args = dict(fused=True) if use_fused else dict() 374 | optimizer = torch.optim.AdamW( 375 | optim_groups, lr=learning_rate, betas=betas, **extra_args) 376 | 377 | return optimizer 378 | 379 | def estimate_mfu(self, fwdbwd_per_iter, dt): 380 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 381 | # first estimate the number of flops we do per iteration. 
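# (rough worked example, assuming GPT-2 small sizes: N ~ 124e6, L=12, H=12, Q=64, T=1024
#  gives flops_per_token ~ 6*124e6 + 12*12*12*64*1024 ~ 8.6e8, i.e. ~8.8e11 flops for
#  one forward+backward pass of a single full-length sequence)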
382 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 383 | N = self.get_num_params() 384 | cfg = self.config 385 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 386 | flops_per_token = 6*N + 12*L*H*Q*T 387 | flops_per_fwdbwd = flops_per_token * T 388 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 389 | # express our flops throughput as ratio of A100 bfloat16 peak flops 390 | flops_achieved = flops_per_iter * (1.0/dt) # per second 391 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 392 | mfu = flops_achieved / flops_promised 393 | return mfu 394 | 395 | @torch.no_grad() 396 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 397 | """ 398 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 399 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 400 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 401 | """ 402 | for _ in range(max_new_tokens): 403 | # if the sequence context is growing too long we must crop it at block_size 404 | idx_cond = idx if idx.size( 405 | 1) <= self.config.block_size else idx[:, -self.config.block_size:] 406 | # forward the model to get the logits for the index in the sequence 407 | logits, _ = self(idx_cond) 408 | # pluck the logits at the final step and scale by desired temperature 409 | logits = logits[:, -1, :] / temperature 410 | # optionally crop the logits to only the top k options 411 | if top_k is not None: 412 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 413 | logits[logits < v[:, [-1]]] = -float('Inf') 414 | # apply softmax to convert logits to (normalized) probabilities 415 | probs = F.softmax(logits, dim=-1) 416 | # sample from the distribution 417 | idx_next = torch.multinomial(probs, num_samples=1) 418 | # append sampled index to the running sequence and continue 419 | idx = torch.cat((idx, idx_next), dim=1) 420 | 421 | return idx 422 | -------------------------------------------------------------------------------- /vqgan/model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.utils import save_image 7 | from monai.networks.blocks import UnetOutBlock 8 | 9 | 10 | from vqgan.conv_blocks import DownBlock, UpBlock, BasicBlock, BasicResBlock, UnetResBlock, UnetBasicBlock 11 | from vqgan.gan_losses import hinge_d_loss 12 | from vqgan.perceivers import LPIPS 13 | from vqgan.model_base import BasicModel, VeryBasicModel 14 | 15 | 16 | from pytorch_msssim import SSIM, ssim 17 | 18 | 19 | class DiagonalGaussianDistribution(nn.Module): 20 | 21 | def forward(self, x): 22 | mean, logvar = torch.chunk(x, 2, dim=1) 23 | logvar = torch.clamp(logvar, -30.0, 20.0) 24 | std = torch.exp(0.5 * logvar) 25 | sample = torch.randn(mean.shape, generator=None, device=x.device) 26 | z = mean + std * sample 27 | 28 | batch_size = x.shape[0] 29 | var = torch.exp(logvar) 30 | kl = 0.5 * torch.sum(torch.pow(mean, 2) + var - 31 | 1.0 - logvar)/batch_size 32 | 33 | return z, kl 34 | 35 | 36 | class VectorQuantizer(nn.Module): 37 | def __init__(self, num_embeddings, emb_channels, beta=0.25): 38 | super().__init__() 39 | self.num_embeddings = num_embeddings 40 | self.emb_channels = emb_channels 41 | self.beta = beta 42 | 43 | self.embedder = nn.Embedding(num_embeddings, emb_channels) 44 | 
self.embedder.weight.data.uniform_(-1.0 / 45 | self.num_embeddings, 1.0 / self.num_embeddings) 46 | 47 | def forward(self, z): 48 | assert z.shape[1] == self.emb_channels, "Channels of z and codebook don't match" 49 | z_ch = torch.moveaxis(z, 1, -1) # [B, C, *] -> [B, *, C] 50 | # [B, *, C] -> [Bx*, C], Note: or use contiguous() and view() 51 | z_flattened = z_ch.reshape(-1, self.emb_channels) 52 | 53 | # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z 54 | dist = (torch.sum(z_flattened**2, dim=1, keepdim=True) 55 | + torch.sum(self.embedder.weight**2, dim=1) 56 | - 2 * torch.einsum("bd,dn->bn", z_flattened, 57 | self.embedder.weight.t()) 58 | ) # [Bx*, num_embeddings] 59 | 60 | min_encoding_indices = torch.argmin(dist, dim=1) # [Bx*] 61 | z_q = self.embedder(min_encoding_indices) # [Bx*, C] 62 | z_q = z_q.view(z_ch.shape) # [Bx*, C] -> [B, *, C] 63 | z_q = torch.moveaxis(z_q, -1, 1) # [B, *, C] -> [B, C, *] 64 | 65 | # Compute Embedding Loss 66 | loss = self.beta * \ 67 | torch.mean((z_q.detach() - z) ** 2) + \ 68 | torch.mean((z_q - z.detach()) ** 2) 69 | 70 | # preserve gradients 71 | z_q = z + (z_q - z).detach() 72 | 73 | return z_q, loss 74 | 75 | def image_to_indices(self, z): 76 | assert z.shape[1] == self.emb_channels, "Channels of z and codebook don't match" 77 | z_ch = torch.moveaxis(z, 1, -1) # [B, C, *] -> [B, *, C] 78 | # [B, *, C] -> [Bx*, C], Note: or use contiguous() and view() 79 | z_flattened = z_ch.reshape(-1, self.emb_channels) 80 | 81 | # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z 82 | dist = (torch.sum(z_flattened**2, dim=1, keepdim=True) 83 | + torch.sum(self.embedder.weight**2, dim=1) 84 | - 2 * torch.einsum("bd,dn->bn", z_flattened, 85 | self.embedder.weight.t()) 86 | ) # [Bx*, num_embeddings] 87 | 88 | min_encoding_indices = torch.argmin(dist, dim=1) # [Bx*] 89 | embedding_shape = z_ch.shape 90 | return min_encoding_indices, embedding_shape 91 | 92 | def indices_to_image(self, indices, embedding_shape): 93 | z_q = self.embedder(indices) # [Bx*, C] 94 | z_q = z_q.view(embedding_shape) # [Bx*, C] -> [B, *, C] 95 | z_q = torch.moveaxis(z_q, -1, 1) # [B, *, C] -> [B, C, *] 96 | return z_q 97 | 98 | 99 | class Discriminator(nn.Module): 100 | def __init__(self, 101 | in_channels=1, 102 | spatial_dims=3, 103 | hid_chs=[32, 64, 128, 256, 512], 104 | kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], 105 | strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], 106 | act_name=("Swish", {}), 107 | norm_name=("GROUP", {'num_groups': 32, "affine": True}), 108 | dropout=None 109 | ): 110 | super().__init__() 111 | 112 | self.inc = BasicBlock( 113 | spatial_dims=spatial_dims, 114 | in_channels=in_channels, 115 | out_channels=hid_chs[0], 116 | # 2*pad = kernel-stride -> kernel = 2*pad + stride => 1 = 2*0+1, 3, =2*1+1, 2 = 2*0+2, 4 = 2*1+2 117 | kernel_size=kernel_sizes[0], 118 | stride=strides[0], 119 | norm_name=norm_name, 120 | act_name=act_name, 121 | dropout=dropout, 122 | ) 123 | 124 | self.encoder = nn.Sequential(*[ 125 | BasicBlock( 126 | spatial_dims=spatial_dims, 127 | in_channels=hid_chs[i-1], 128 | out_channels=hid_chs[i], 129 | kernel_size=kernel_sizes[i], 130 | stride=strides[i], 131 | act_name=act_name, 132 | norm_name=norm_name, 133 | dropout=dropout) 134 | for i in range(1, len(hid_chs)) 135 | ]) 136 | 137 | self.outc = BasicBlock( 138 | spatial_dims=spatial_dims, 139 | in_channels=hid_chs[-1], 140 | out_channels=1, 141 | kernel_size=3, 142 | stride=1, 143 | act_name=None, 144 | norm_name=None, 145 | dropout=None, 146 | zero_conv=True 
147 | ) 148 | 149 | def forward(self, x): 150 | x = self.inc(x) 151 | x = self.encoder(x) 152 | return self.outc(x) 153 | 154 | 155 | class NLayerDiscriminator(nn.Module): 156 | def __init__(self, 157 | in_channels=1, 158 | spatial_dims=3, 159 | hid_chs=[64, 128, 256, 512, 512], 160 | kernel_sizes=[4, 4, 4, 4, 4], 161 | strides=[2, 2, 2, 1, 1], 162 | act_name=("LeakyReLU", {'negative_slope': 0.2}), 163 | norm_name=("BATCH", {}), 164 | dropout=None 165 | ): 166 | super().__init__() 167 | 168 | self.inc = BasicBlock( 169 | spatial_dims=spatial_dims, 170 | in_channels=in_channels, 171 | out_channels=hid_chs[0], 172 | kernel_size=kernel_sizes[0], 173 | stride=strides[0], 174 | norm_name=None, 175 | act_name=act_name, 176 | dropout=dropout, 177 | ) 178 | 179 | self.encoder = nn.Sequential(*[ 180 | BasicBlock( 181 | spatial_dims=spatial_dims, 182 | in_channels=hid_chs[i-1], 183 | out_channels=hid_chs[i], 184 | kernel_size=kernel_sizes[i], 185 | stride=strides[i], 186 | act_name=act_name, 187 | norm_name=norm_name, 188 | dropout=dropout) 189 | for i in range(1, len(strides)) 190 | ]) 191 | 192 | self.outc = BasicBlock( 193 | spatial_dims=spatial_dims, 194 | in_channels=hid_chs[-1], 195 | out_channels=1, 196 | kernel_size=4, 197 | stride=1, 198 | norm_name=None, 199 | act_name=None, 200 | dropout=False, 201 | ) 202 | 203 | def forward(self, x): 204 | x = self.inc(x) 205 | x = self.encoder(x) 206 | return self.outc(x) 207 | 208 | 209 | class VQVAE(BasicModel): 210 | def __init__( 211 | self, 212 | in_channels=3, 213 | out_channels=3, 214 | spatial_dims=2, 215 | emb_channels=4, 216 | num_embeddings=8192, 217 | hid_chs=[32, 64, 128, 256], 218 | kernel_sizes=[3, 3, 3, 3], 219 | strides=[1, 2, 2, 2], 220 | norm_name=("GROUP", {'num_groups': 32, "affine": True}), 221 | act_name=("Swish", {}), 222 | dropout=0.0, 223 | use_res_block=True, 224 | deep_supervision=False, 225 | learnable_interpolation=True, 226 | use_attention='none', 227 | beta=0.25, 228 | embedding_loss_weight=1.0, 229 | perceiver=LPIPS, 230 | perceiver_kwargs={}, 231 | perceptual_loss_weight=1.0, 232 | 233 | 234 | optimizer=torch.optim.Adam, 235 | optimizer_kwargs={'lr': 1e-4}, 236 | lr_scheduler=None, 237 | lr_scheduler_kwargs={}, 238 | loss=torch.nn.L1Loss, 239 | loss_kwargs={'reduction': 'none'}, 240 | 241 | sample_every_n_steps=1000 242 | 243 | ): 244 | super().__init__( 245 | optimizer=optimizer, 246 | optimizer_kwargs=optimizer_kwargs, 247 | lr_scheduler=lr_scheduler, 248 | lr_scheduler_kwargs=lr_scheduler_kwargs 249 | ) 250 | self.sample_every_n_steps = sample_every_n_steps 251 | self.loss_fct = loss(**loss_kwargs) 252 | self.embedding_loss_weight = embedding_loss_weight 253 | self.perceiver = perceiver( 254 | **perceiver_kwargs).eval() if perceiver is not None else None 255 | self.perceptual_loss_weight = perceptual_loss_weight 256 | use_attention = use_attention if isinstance(use_attention, list) else [ 257 | use_attention]*len(strides) 258 | self.depth = len(strides) 259 | self.deep_supervision = deep_supervision 260 | 261 | # ----------- In-Convolution ------------ 262 | ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock 263 | self.inc = ConvBlock(spatial_dims, in_channels, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], 264 | act_name=act_name, norm_name=norm_name) 265 | 266 | # ----------- Encoder ---------------- 267 | self.encoders = nn.ModuleList([ 268 | DownBlock( 269 | spatial_dims, 270 | hid_chs[i-1], 271 | hid_chs[i], 272 | kernel_sizes[i], 273 | strides[i], 274 | kernel_sizes[i], 275 | norm_name, 
276 | act_name, 277 | dropout, 278 | use_res_block, 279 | learnable_interpolation, 280 | use_attention[i]) 281 | for i in range(1, self.depth) 282 | ]) 283 | 284 | # ----------- Out-Encoder ------------ 285 | self.out_enc = BasicBlock(spatial_dims, hid_chs[-1], emb_channels, 1) 286 | 287 | # ----------- Quantizer -------------- 288 | self.quantizer = VectorQuantizer( 289 | num_embeddings=num_embeddings, 290 | emb_channels=emb_channels, 291 | beta=beta 292 | ) 293 | 294 | # ----------- In-Decoder ------------ 295 | self.inc_dec = ConvBlock( 296 | spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name) 297 | 298 | # ------------ Decoder ---------- 299 | self.decoders = nn.ModuleList([ 300 | UpBlock( 301 | spatial_dims, 302 | hid_chs[i+1], 303 | hid_chs[i], 304 | kernel_size=kernel_sizes[i+1], 305 | stride=strides[i+1], 306 | upsample_kernel_size=strides[i+1], 307 | norm_name=norm_name, 308 | act_name=act_name, 309 | dropout=dropout, 310 | use_res_block=use_res_block, 311 | learnable_interpolation=learnable_interpolation, 312 | use_attention=use_attention[i], 313 | skip_channels=0) 314 | for i in range(self.depth-1) 315 | ]) 316 | 317 | # --------------- Out-Convolution ---------------- 318 | self.outc = BasicBlock( 319 | spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True) 320 | if isinstance(deep_supervision, bool): 321 | deep_supervision = self.depth-1 if deep_supervision else 0 322 | self.outc_ver = nn.ModuleList([ 323 | BasicBlock(spatial_dims, hid_chs[i], 324 | out_channels, 1, zero_conv=True) 325 | for i in range(1, deep_supervision+1) 326 | ]) 327 | 328 | def encode(self, x): 329 | h = self.inc(x) 330 | for i in range(len(self.encoders)): 331 | h = self.encoders[i](h) 332 | z = self.out_enc(h) 333 | return z 334 | 335 | def encode_to_indices(self, x): 336 | h = self.inc(x) 337 | for i in range(len(self.encoders)): 338 | h = self.encoders[i](h) 339 | z = self.out_enc(h) 340 | indices, embedding_shape = self.quantizer.image_to_indices(z) 341 | return indices, embedding_shape 342 | 343 | def decode_from_indices(self, indices, embedding_shape): 344 | z = self.quantizer.indices_to_image(indices, embedding_shape) 345 | h = self.inc_dec(z) 346 | for i in range(len(self.decoders), 0, -1): 347 | h = self.decoders[i-1](h) 348 | x = self.outc(h) 349 | return x 350 | 351 | def decode(self, z): 352 | z, _ = self.quantizer(z) 353 | h = self.inc_dec(z) 354 | for i in range(len(self.decoders), 0, -1): 355 | h = self.decoders[i-1](h) 356 | x = self.outc(h) 357 | return x 358 | 359 | def forward(self, x_in): 360 | # --------- Encoder -------------- 361 | h = self.inc(x_in) 362 | for i in range(len(self.encoders)): 363 | h = self.encoders[i](h) 364 | z = self.out_enc(h) 365 | 366 | # --------- Quantizer -------------- 367 | z_q, emb_loss = self.quantizer(z) 368 | 369 | # -------- Decoder ----------- 370 | out_hor = [] 371 | h = self.inc_dec(z_q) 372 | for i in range(len(self.decoders)-1, -1, -1): 373 | out_hor.append(self.outc_ver[i](h)) if i < len( 374 | self.outc_ver) else None 375 | h = self.decoders[i](h) 376 | out = self.outc(h) 377 | 378 | return out, out_hor[::-1], emb_loss 379 | 380 | def perception_loss(self, pred, target, depth=0): 381 | if (self.perceiver is not None) and (depth < 2): 382 | self.perceiver.eval() 383 | return self.perceiver(pred, target)*self.perceptual_loss_weight 384 | else: 385 | return 0 386 | 387 | def ssim_loss(self, pred, target): 388 | return 1-ssim(((pred+1)/2).clamp(0, 1), (target.type(pred.dtype)+1)/2, data_range=1, 
size_average=False, 389 | nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1)) 390 | 391 | def rec_loss(self, pred, pred_vertical, target): 392 | interpolation_mode = 'nearest-exact' 393 | # horizontal (equal) + vertical (reducing with every step down) 394 | weights = [1/2**i for i in range(1+len(pred_vertical))] 395 | tot_weight = sum(weights) 396 | weights = [w/tot_weight for w in weights] 397 | 398 | # Loss 399 | loss = 0 400 | loss += torch.mean(self.loss_fct(pred, target)+self.perception_loss( 401 | pred, target)+self.ssim_loss(pred, target))*weights[0] 402 | 403 | for i, pred_i in enumerate(pred_vertical): 404 | target_i = F.interpolate( 405 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 406 | loss += torch.mean(self.loss_fct(pred_i, target_i)+self.perception_loss( 407 | pred_i, target_i)+self.ssim_loss(pred_i, target_i))*weights[i+1] 408 | 409 | return loss 410 | 411 | def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): 412 | # ------------------------- Get Source/Target --------------------------- 413 | x = batch['source'] 414 | target = x 415 | 416 | # ------------------------- Run Model --------------------------- 417 | pred, pred_vertical, emb_loss = self(x) 418 | 419 | # ------------------------- Compute Loss --------------------------- 420 | loss = self.rec_loss(pred, pred_vertical, target) 421 | loss += emb_loss*self.embedding_loss_weight 422 | 423 | # --------------------- Compute Metrics ------------------------------- 424 | with torch.no_grad(): 425 | logging_dict = {'loss': loss, 'emb_loss': emb_loss} 426 | logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target) 427 | logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target) 428 | logging_dict['ssim'] = ssim( 429 | (pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1) 430 | 431 | # ----------------- Log Scalars ---------------------- 432 | for metric_name, metric_val in logging_dict.items(): 433 | self.log(f"{state}/{metric_name}", metric_val, 434 | batch_size=x.shape[0], on_step=True, on_epoch=True) 435 | 436 | # ----------------- Save Image ------------------------------ 437 | if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: 438 | log_step = self.global_step // self.sample_every_n_steps 439 | path_out = Path(self.logger.log_dir)/'images' 440 | path_out.mkdir(parents=True, exist_ok=True) 441 | # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images 442 | 443 | def depth2batch(tensor, batch=0): 444 | return (tensor if tensor.ndim < 5 else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:, None]) 445 | images = torch.cat([depth2batch(img)[:16] for img in (x, pred)]) 446 | save_image( 447 | images, path_out/f'sample_{log_step}.png', nrow=images.shape[0] // 2, normalize=True) 448 | 449 | return loss 450 | 451 | 452 | class VQGAN(VeryBasicModel): 453 | def __init__( 454 | self, 455 | in_channels=3, 456 | out_channels=3, 457 | spatial_dims=2, 458 | emb_channels=4, 459 | num_embeddings=8192, 460 | hid_chs=[64, 128, 256, 512], 461 | kernel_sizes=[3, 3, 3, 3], 462 | strides=[1, 2, 2, 2], 463 | norm_name=("GROUP", {'num_groups': 32, "affine": True}), 464 | act_name=("Swish", {}), 465 | dropout=0.0, 466 | use_res_block=True, 467 | deep_supervision=False, 468 | learnable_interpolation=True, 469 | use_attention='none', 470 | beta=0.25, 471 | embedding_loss_weight=1.0, 472 | perceiver=LPIPS, 473 | perceiver_kwargs={}, 474 | perceptual_loss_weight: float = 1.0, 475 | 
476 | 477 | start_gan_train_step=50000, # NOTE step increase with each optimizer 478 | gan_loss_weight: float = 1.0, # = discriminator 479 | 480 | optimizer_vqvae=torch.optim.Adam, 481 | optimizer_gan=torch.optim.Adam, 482 | optimizer_vqvae_kwargs={'lr': 1e-6}, 483 | optimizer_gan_kwargs={'lr': 1e-6}, 484 | lr_scheduler_vqvae=None, 485 | lr_scheduler_vqvae_kwargs={}, 486 | lr_scheduler_gan=None, 487 | lr_scheduler_gan_kwargs={}, 488 | 489 | pixel_loss=torch.nn.L1Loss, 490 | pixel_loss_kwargs={'reduction': 'none'}, 491 | gan_loss_fct=hinge_d_loss, 492 | 493 | sample_every_n_steps=1000 494 | 495 | ): 496 | super().__init__() 497 | self.sample_every_n_steps = sample_every_n_steps 498 | self.start_gan_train_step = start_gan_train_step 499 | self.gan_loss_weight = gan_loss_weight 500 | self.embedding_loss_weight = embedding_loss_weight 501 | 502 | self.optimizer_vqvae = optimizer_vqvae 503 | self.optimizer_gan = optimizer_gan 504 | self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs 505 | self.optimizer_gan_kwargs = optimizer_gan_kwargs 506 | self.lr_scheduler_vqvae = lr_scheduler_vqvae 507 | self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs 508 | self.lr_scheduler_gan = lr_scheduler_gan 509 | self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs 510 | 511 | self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs) 512 | self.gan_loss_fct = gan_loss_fct 513 | 514 | self.vqvae = VQVAE(in_channels, out_channels, spatial_dims, emb_channels, num_embeddings, hid_chs, kernel_sizes, 515 | strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention, 516 | beta, embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight) 517 | 518 | self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides, 519 | act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)]) 520 | 521 | # self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims) 522 | # for _ in range(len(self.vqvae.decoder.outc_ver)+1)]) 523 | 524 | def encode(self, x): 525 | return self.vqvae.encode(x) 526 | 527 | def decode(self, z): 528 | return self.vqvae.decode(z) 529 | 530 | def forward(self, x): 531 | return self.vqvae.forward(x) 532 | 533 | def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0): 534 | # ------ VQVAE ------- 535 | rec_loss = self.vqvae.rec_loss(pred, [], target) 536 | 537 | # ------- GAN ----- 538 | if step > self.start_gan_train_step: 539 | gan_loss = -torch.mean(discriminator[depth](pred)) 540 | lambda_weight = self.compute_lambda( 541 | rec_loss, gan_loss, dec_out_layer) 542 | gan_loss = gan_loss*lambda_weight 543 | 544 | with torch.no_grad(): 545 | self.log(f"train/gan_loss_{depth}", 546 | gan_loss, on_step=True, on_epoch=True) 547 | self.log(f"train/lambda_{depth}", 548 | lambda_weight, on_step=True, on_epoch=True) 549 | else: 550 | # torch.tensor([0.0], requires_grad=True, device=target.device) 551 | gan_loss = 0 552 | 553 | return self.gan_loss_weight*gan_loss+rec_loss 554 | 555 | def gan_img_loss(self, pred, target, step, discriminators, depth): 556 | if (step > self.start_gan_train_step) and (depth < len(discriminators)): 557 | logits_real = discriminators[depth](target.detach()) 558 | logits_fake = discriminators[depth](pred.detach()) 559 | loss = self.gan_loss_fct(logits_real, logits_fake) 560 | else: 561 | loss = torch.tensor(0.0, requires_grad=True, device=target.device) 562 | 563 | with torch.no_grad(): 564 | 
self.log(f"train/loss_1_{depth}", loss, 565 | on_step=True, on_epoch=True) 566 | return loss 567 | 568 | def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): 569 | # ------------------------- Get Source/Target --------------------------- 570 | x = batch['source'] 571 | target = x 572 | 573 | # ------------------------- Run Model --------------------------- 574 | pred, pred_vertical, emb_loss = self(x) 575 | 576 | # ------------------------- Compute Loss --------------------------- 577 | interpolation_mode = 'area' 578 | # horizontal + vertical (reducing with every step down) 579 | weights = [1/2**i for i in range(1+len(pred_vertical))] 580 | tot_weight = sum(weights) 581 | weights = [w/tot_weight for w in weights] 582 | logging_dict = {} 583 | 584 | if optimizer_idx == 0: 585 | # Horizontal/Top Layer 586 | img_loss = self.vae_img_loss( 587 | pred, target, self.vqvae.outc.conv, step, self.discriminator, 0)*weights[0] 588 | 589 | # Vertical/Deep Layer 590 | for i, pred_i in enumerate(pred_vertical): 591 | target_i = F.interpolate( 592 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 593 | img_loss += self.vae_img_loss( 594 | pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1)*weights[i+1] 595 | loss = img_loss+self.embedding_loss_weight*emb_loss 596 | 597 | with torch.no_grad(): 598 | logging_dict[f'img_loss'] = img_loss 599 | logging_dict[f'emb_loss'] = emb_loss 600 | logging_dict['loss_0'] = loss 601 | 602 | elif optimizer_idx == 1: 603 | # Horizontal/Top Layer 604 | loss = self.gan_img_loss( 605 | pred, target, step, self.discriminator, 0)*weights[0] 606 | 607 | # Vertical/Deep Layer 608 | for i, pred_i in enumerate(pred_vertical): 609 | target_i = F.interpolate( 610 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 611 | loss += self.gan_img_loss(pred_i, target_i, 612 | step, self.discriminator, i+1)*weights[i+1] 613 | 614 | with torch.no_grad(): 615 | logging_dict['loss_1'] = loss 616 | 617 | # --------------------- Compute Metrics ------------------------------- 618 | with torch.no_grad(): 619 | logging_dict['loss'] = loss 620 | logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x) 621 | logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x) 622 | logging_dict['ssim'] = ssim( 623 | (pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1) 624 | 625 | # ----------------- Log Scalars ---------------------- 626 | for metric_name, metric_val in logging_dict.items(): 627 | self.log(f"{state}/{metric_name}", metric_val, 628 | batch_size=x.shape[0], on_step=True, on_epoch=True) 629 | 630 | # ----------------- Save Image ------------------------------ 631 | # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ... 
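# (assumption, restating the NOTE above: with two optimizers, Lightning advances
#  global_step once per optimizer step, so sample_every_n_steps here counts
#  optimizer steps rather than data batches)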
632 | if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: 633 | 634 | log_step = self.global_step // self.sample_every_n_steps 635 | path_out = Path(self.logger.log_dir)/'images' 636 | path_out.mkdir(parents=True, exist_ok=True) 637 | # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images 638 | 639 | def depth2batch(tensor, batch=0): 640 | return (tensor if tensor.ndim < 5 else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:, None]) 641 | images = torch.cat([depth2batch(img)[:16] for img in (x, pred)]) 642 | save_image( 643 | images, path_out/f'sample_{log_step}.png', nrow=images.shape[0] // 2, normalize=True) 644 | 645 | return loss 646 | 647 | def configure_optimizers(self): 648 | opt_vqvae = self.optimizer_vqvae( 649 | self.vqvae.parameters(), **self.optimizer_vqvae_kwargs) 650 | opt_gan = self.optimizer_gan( 651 | self.discriminator.parameters(), **self.optimizer_gan_kwargs) 652 | schedulers = [] 653 | if self.lr_scheduler_vqvae is not None: 654 | schedulers.append({ 655 | 'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs), 656 | 'interval': 'step', 657 | 'frequency': 1 658 | }) 659 | if self.lr_scheduler_gan is not None: 660 | schedulers.append({ 661 | 'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs), 662 | 'interval': 'step', 663 | 'frequency': 1 664 | }) 665 | return [opt_vqvae, opt_gan], schedulers 666 | 667 | def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4): 668 | """Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841""" 669 | rec_grads = torch.autograd.grad( 670 | rec_loss, dec_out_layer.weight, retain_graph=True)[0] 671 | gan_grads = torch.autograd.grad( 672 | gan_loss, dec_out_layer.weight, retain_graph=True)[0] 673 | d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps) 674 | d_weight = torch.clamp(d_weight, 0.0, 1e4) 675 | return d_weight.detach() 676 | 677 | 678 | class VAE(BasicModel): 679 | def __init__( 680 | self, 681 | in_channels=3, 682 | out_channels=3, 683 | spatial_dims=2, 684 | emb_channels=4, 685 | hid_chs=[64, 128, 256, 512], 686 | kernel_sizes=[3, 3, 3, 3], 687 | strides=[1, 2, 2, 2], 688 | norm_name=("GROUP", {'num_groups': 8, "affine": True}), 689 | act_name=("Swish", {}), 690 | dropout=None, 691 | use_res_block=True, 692 | deep_supervision=False, 693 | learnable_interpolation=True, 694 | use_attention='none', 695 | embedding_loss_weight=1e-6, 696 | perceiver=LPIPS, 697 | perceiver_kwargs={}, 698 | perceptual_loss_weight=1.0, 699 | 700 | 701 | optimizer=torch.optim.Adam, 702 | optimizer_kwargs={'lr': 1e-4}, 703 | lr_scheduler=None, 704 | lr_scheduler_kwargs={}, 705 | loss=torch.nn.L1Loss, 706 | loss_kwargs={'reduction': 'none'}, 707 | 708 | sample_every_n_steps=1000 709 | 710 | ): 711 | super().__init__( 712 | optimizer=optimizer, 713 | optimizer_kwargs=optimizer_kwargs, 714 | lr_scheduler=lr_scheduler, 715 | lr_scheduler_kwargs=lr_scheduler_kwargs 716 | ) 717 | self.sample_every_n_steps = sample_every_n_steps 718 | self.loss_fct = loss(**loss_kwargs) 719 | # self.ssim_fct = SSIM(data_range=1, size_average=False, channel=out_channels, spatial_dims=spatial_dims, nonnegative_ssim=True) 720 | self.embedding_loss_weight = embedding_loss_weight 721 | self.perceiver = perceiver( 722 | **perceiver_kwargs).eval() if perceiver is not None else None 723 | self.perceptual_loss_weight = perceptual_loss_weight 724 | use_attention = use_attention if isinstance(use_attention, 
list) else [ 725 | use_attention]*len(strides) 726 | self.depth = len(strides) 727 | self.deep_supervision = deep_supervision 728 | downsample_kernel_sizes = kernel_sizes 729 | upsample_kernel_sizes = strides 730 | 731 | # -------- Loss-Reg--------- 732 | # self.logvar = nn.Parameter(torch.zeros(size=()) ) 733 | 734 | # ----------- In-Convolution ------------ 735 | ConvBlock = UnetResBlock if use_res_block else UnetBasicBlock 736 | self.inc = ConvBlock( 737 | spatial_dims, 738 | in_channels, 739 | hid_chs[0], 740 | kernel_size=kernel_sizes[0], 741 | stride=strides[0], 742 | act_name=act_name, 743 | norm_name=norm_name, 744 | emb_channels=None 745 | ) 746 | 747 | # ----------- Encoder ---------------- 748 | self.encoders = nn.ModuleList([ 749 | DownBlock( 750 | spatial_dims=spatial_dims, 751 | in_channels=hid_chs[i-1], 752 | out_channels=hid_chs[i], 753 | kernel_size=kernel_sizes[i], 754 | stride=strides[i], 755 | downsample_kernel_size=downsample_kernel_sizes[i], 756 | norm_name=norm_name, 757 | act_name=act_name, 758 | dropout=dropout, 759 | use_res_block=use_res_block, 760 | learnable_interpolation=learnable_interpolation, 761 | use_attention=use_attention[i], 762 | emb_channels=None 763 | ) 764 | for i in range(1, self.depth) 765 | ]) 766 | 767 | # ----------- Out-Encoder ------------ 768 | self.out_enc = nn.Sequential( 769 | BasicBlock(spatial_dims, hid_chs[-1], 2*emb_channels, 3), 770 | BasicBlock(spatial_dims, 2*emb_channels, 2*emb_channels, 1) 771 | ) 772 | 773 | # ----------- Reparameterization -------------- 774 | self.quantizer = DiagonalGaussianDistribution() 775 | 776 | # ----------- In-Decoder ------------ 777 | self.inc_dec = ConvBlock( 778 | spatial_dims, emb_channels, hid_chs[-1], 3, act_name=act_name, norm_name=norm_name) 779 | 780 | # ------------ Decoder ---------- 781 | self.decoders = nn.ModuleList([ 782 | UpBlock( 783 | spatial_dims=spatial_dims, 784 | in_channels=hid_chs[i+1], 785 | out_channels=hid_chs[i], 786 | kernel_size=kernel_sizes[i+1], 787 | stride=strides[i+1], 788 | upsample_kernel_size=upsample_kernel_sizes[i+1], 789 | norm_name=norm_name, 790 | act_name=act_name, 791 | dropout=dropout, 792 | use_res_block=use_res_block, 793 | learnable_interpolation=learnable_interpolation, 794 | use_attention=use_attention[i], 795 | emb_channels=None, 796 | skip_channels=0 797 | ) 798 | for i in range(self.depth-1) 799 | ]) 800 | 801 | # --------------- Out-Convolution ---------------- 802 | self.outc = BasicBlock( 803 | spatial_dims, hid_chs[0], out_channels, 1, zero_conv=True) 804 | if isinstance(deep_supervision, bool): 805 | deep_supervision = self.depth-1 if deep_supervision else 0 806 | self.outc_ver = nn.ModuleList([ 807 | BasicBlock(spatial_dims, hid_chs[i], 808 | out_channels, 1, zero_conv=True) 809 | for i in range(1, deep_supervision+1) 810 | ]) 811 | # self.logvar_ver = nn.ParameterList([ 812 | # nn.Parameter(torch.zeros(size=()) ) 813 | # for _ in range(1, deep_supervision+1) 814 | # ]) 815 | 816 | def encode(self, x): 817 | h = self.inc(x) 818 | for i in range(len(self.encoders)): 819 | h = self.encoders[i](h) 820 | z = self.out_enc(h) 821 | z, _ = self.quantizer(z) 822 | return z 823 | 824 | def decode(self, z): 825 | h = self.inc_dec(z) 826 | for i in range(len(self.decoders), 0, -1): 827 | h = self.decoders[i-1](h) 828 | x = self.outc(h) 829 | return x 830 | 831 | def forward(self, x_in): 832 | # --------- Encoder -------------- 833 | h = self.inc(x_in) 834 | for i in range(len(self.encoders)): 835 | h = self.encoders[i](h) 836 | z = self.out_enc(h) 837 
| 838 | # --------- Quantizer -------------- 839 | z_q, emb_loss = self.quantizer(z) 840 | 841 | # -------- Decoder ----------- 842 | out_hor = [] 843 | h = self.inc_dec(z_q) 844 | for i in range(len(self.decoders)-1, -1, -1): 845 | out_hor.append(self.outc_ver[i](h)) if i < len( 846 | self.outc_ver) else None 847 | h = self.decoders[i](h) 848 | out = self.outc(h) 849 | 850 | return out, out_hor[::-1], emb_loss 851 | 852 | def perception_loss(self, pred, target, depth=0): 853 | if (self.perceiver is not None) and (depth < 2): 854 | self.perceiver.eval() 855 | return self.perceiver(pred, target)*self.perceptual_loss_weight 856 | else: 857 | return 0 858 | 859 | def ssim_loss(self, pred, target): 860 | return 1-ssim(((pred+1)/2).clamp(0, 1), (target.type(pred.dtype)+1)/2, data_range=1, size_average=False, 861 | nonnegative_ssim=True).reshape(-1, *[1]*(pred.ndim-1)) 862 | 863 | def rec_loss(self, pred, pred_vertical, target): 864 | interpolation_mode = 'nearest-exact' 865 | 866 | # Loss 867 | loss = 0 868 | rec_loss = self.loss_fct( 869 | pred, target)+self.perception_loss(pred, target)+self.ssim_loss(pred, target) 870 | # rec_loss = rec_loss/ torch.exp(self.logvar) + self.logvar # Note this is include in Stable-Diffusion but logvar is not used in optimizer 871 | loss += torch.sum(rec_loss)/pred.shape[0] 872 | 873 | for i, pred_i in enumerate(pred_vertical): 874 | target_i = F.interpolate( 875 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 876 | rec_loss_i = self.loss_fct( 877 | pred_i, target_i)+self.perception_loss(pred_i, target_i)+self.ssim_loss(pred_i, target_i) 878 | # rec_loss_i = rec_loss_i/ torch.exp(self.logvar_ver[i]) + self.logvar_ver[i] 879 | loss += torch.sum(rec_loss_i)/pred.shape[0] 880 | 881 | return loss 882 | 883 | def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): 884 | # ------------------------- Get Source/Target --------------------------- 885 | x = batch['source'] 886 | target = x 887 | 888 | # ------------------------- Run Model --------------------------- 889 | pred, pred_vertical, emb_loss = self(x) 890 | 891 | # ------------------------- Compute Loss --------------------------- 892 | loss = self.rec_loss(pred, pred_vertical, target) 893 | loss += emb_loss*self.embedding_loss_weight 894 | 895 | # --------------------- Compute Metrics ------------------------------- 896 | with torch.no_grad(): 897 | logging_dict = {'loss': loss, 'emb_loss': emb_loss} 898 | logging_dict['L2'] = torch.nn.functional.mse_loss(pred, target) 899 | logging_dict['L1'] = torch.nn.functional.l1_loss(pred, target) 900 | logging_dict['ssim'] = ssim( 901 | (pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1) 902 | # logging_dict['logvar'] = self.logvar 903 | 904 | # ----------------- Log Scalars ---------------------- 905 | for metric_name, metric_val in logging_dict.items(): 906 | self.log(f"{state}/{metric_name}", metric_val, 907 | batch_size=x.shape[0], on_step=True, on_epoch=True) 908 | 909 | # ----------------- Save Image ------------------------------ 910 | if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: 911 | log_step = self.global_step // self.sample_every_n_steps 912 | path_out = Path(self.logger.log_dir)/'images' 913 | path_out.mkdir(parents=True, exist_ok=True) 914 | # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images 915 | 916 | def depth2batch(image): 917 | return (image if image.ndim < 5 else torch.swapaxes(image[0], 0, 1)) 918 | images = 
torch.cat([depth2batch(img)[:16] for img in (x, pred)]) 919 | save_image( 920 | images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True) 921 | 922 | return loss 923 | 924 | 925 | class VAEGAN(VeryBasicModel): 926 | def __init__( 927 | self, 928 | in_channels=3, 929 | out_channels=3, 930 | spatial_dims=2, 931 | emb_channels=4, 932 | hid_chs=[64, 128, 256, 512], 933 | kernel_sizes=[3, 3, 3, 3], 934 | strides=[1, 2, 2, 2], 935 | norm_name=("GROUP", {'num_groups': 8, "affine": True}), 936 | act_name=("Swish", {}), 937 | dropout=0.0, 938 | use_res_block=True, 939 | deep_supervision=False, 940 | learnable_interpolation=True, 941 | use_attention='none', 942 | embedding_loss_weight=1e-6, 943 | perceiver=LPIPS, 944 | perceiver_kwargs={}, 945 | perceptual_loss_weight=1.0, 946 | 947 | 948 | start_gan_train_step=50000, # NOTE step increase with each optimizer 949 | gan_loss_weight: float = 1.0, # = discriminator 950 | 951 | optimizer_vqvae=torch.optim.Adam, 952 | optimizer_gan=torch.optim.Adam, 953 | # 'weight_decay':1e-2, {'lr':1e-6, 'betas':(0.5, 0.9)} 954 | optimizer_vqvae_kwargs={'lr': 1e-6}, 955 | optimizer_gan_kwargs={'lr': 1e-6}, # 'weight_decay':1e-2, 956 | lr_scheduler_vqvae=None, 957 | lr_scheduler_vqvae_kwargs={}, 958 | lr_scheduler_gan=None, 959 | lr_scheduler_gan_kwargs={}, 960 | 961 | pixel_loss=torch.nn.L1Loss, 962 | pixel_loss_kwargs={'reduction': 'none'}, 963 | gan_loss_fct=hinge_d_loss, 964 | 965 | sample_every_n_steps=1000 966 | 967 | ): 968 | super().__init__() 969 | self.sample_every_n_steps = sample_every_n_steps 970 | self.start_gan_train_step = start_gan_train_step 971 | self.gan_loss_weight = gan_loss_weight 972 | self.embedding_loss_weight = embedding_loss_weight 973 | 974 | self.optimizer_vqvae = optimizer_vqvae 975 | self.optimizer_gan = optimizer_gan 976 | self.optimizer_vqvae_kwargs = optimizer_vqvae_kwargs 977 | self.optimizer_gan_kwargs = optimizer_gan_kwargs 978 | self.lr_scheduler_vqvae = lr_scheduler_vqvae 979 | self.lr_scheduler_vqvae_kwargs = lr_scheduler_vqvae_kwargs 980 | self.lr_scheduler_gan = lr_scheduler_gan 981 | self.lr_scheduler_gan_kwargs = lr_scheduler_gan_kwargs 982 | 983 | self.pixel_loss_fct = pixel_loss(**pixel_loss_kwargs) 984 | self.gan_loss_fct = gan_loss_fct 985 | 986 | self.vqvae = VAE(in_channels, out_channels, spatial_dims, emb_channels, hid_chs, kernel_sizes, 987 | strides, norm_name, act_name, dropout, use_res_block, deep_supervision, learnable_interpolation, use_attention, 988 | embedding_loss_weight, perceiver, perceiver_kwargs, perceptual_loss_weight) 989 | 990 | self.discriminator = nn.ModuleList([Discriminator(in_channels, spatial_dims, hid_chs, kernel_sizes, strides, 991 | act_name, norm_name, dropout) for i in range(len(self.vqvae.outc_ver)+1)]) 992 | 993 | # self.discriminator = nn.ModuleList([NLayerDiscriminator(in_channels, spatial_dims) 994 | # for _ in range(len(self.vqvae.outc_ver)+1)]) 995 | 996 | def encode(self, x): 997 | return self.vqvae.encode(x) 998 | 999 | def decode(self, z): 1000 | return self.vqvae.decode(z) 1001 | 1002 | def forward(self, x): 1003 | return self.vqvae.forward(x) 1004 | 1005 | def vae_img_loss(self, pred, target, dec_out_layer, step, discriminator, depth=0): 1006 | # ------ VQVAE ------- 1007 | rec_loss = self.vqvae.rec_loss(pred, [], target) 1008 | 1009 | # ------- GAN ----- 1010 | if (step > self.start_gan_train_step) and (depth < 2): 1011 | # clamp(..., None, 0) => only punish areas that were rated as fake (<0) by discriminator => ensures loss >0 and +- don't cannel out in sum 
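# the GAN term below is rescaled by compute_lambda (eq. 7 of the VQGAN paper, see the
# method at the end of this class):
#   lambda = ||grad(rec_loss, dec_out_layer.weight)|| / (||grad(gan_loss, dec_out_layer.weight)|| + eps)
# so the adversarial gradient at the decoder output is matched in scale to the
# reconstruction gradient before the weighted GAN loss is added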
1012 | gan_loss = -torch.sum(discriminator[depth](pred)) 1013 | lambda_weight = self.compute_lambda( 1014 | rec_loss, gan_loss, dec_out_layer) 1015 | gan_loss = gan_loss*lambda_weight 1016 | 1017 | with torch.no_grad(): 1018 | self.log(f"train/gan_loss_{depth}", 1019 | gan_loss, on_step=True, on_epoch=True) 1020 | self.log(f"train/lambda_{depth}", 1021 | lambda_weight, on_step=True, on_epoch=True) 1022 | else: 1023 | # torch.tensor([0.0], requires_grad=True, device=target.device) 1024 | gan_loss = 0 1025 | 1026 | return self.gan_loss_weight*gan_loss+rec_loss 1027 | 1028 | def gan_img_loss(self, pred, target, step, discriminators, depth): 1029 | if (step > self.start_gan_train_step) and (depth < len(discriminators)): 1030 | logits_real = discriminators[depth](target.detach()) 1031 | logits_fake = discriminators[depth](pred.detach()) 1032 | loss = self.gan_loss_fct(logits_real, logits_fake) 1033 | else: 1034 | loss = torch.tensor(0.0, requires_grad=True, device=target.device) 1035 | 1036 | with torch.no_grad(): 1037 | self.log(f"train/loss_1_{depth}", loss, 1038 | on_step=True, on_epoch=True) 1039 | return loss 1040 | 1041 | def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): 1042 | # ------------------------- Get Source/Target --------------------------- 1043 | x = batch['source'] 1044 | target = x 1045 | 1046 | # ------------------------- Run Model --------------------------- 1047 | pred, pred_vertical, emb_loss = self(x) 1048 | 1049 | # ------------------------- Compute Loss --------------------------- 1050 | interpolation_mode = 'area' 1051 | logging_dict = {} 1052 | 1053 | if optimizer_idx == 0: 1054 | # Horizontal/Top Layer 1055 | img_loss = self.vae_img_loss( 1056 | pred, target, self.vqvae.outc.conv, step, self.discriminator, 0) 1057 | 1058 | # Vertical/Deep Layer 1059 | for i, pred_i in enumerate(pred_vertical): 1060 | target_i = F.interpolate( 1061 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 1062 | img_loss += self.vae_img_loss( 1063 | pred_i, target_i, self.vqvae.outc_ver[i].conv, step, self.discriminator, i+1) 1064 | loss = img_loss+self.embedding_loss_weight*emb_loss 1065 | 1066 | with torch.no_grad(): 1067 | logging_dict[f'img_loss'] = img_loss 1068 | logging_dict[f'emb_loss'] = emb_loss 1069 | logging_dict['loss_0'] = loss 1070 | 1071 | elif optimizer_idx == 1: 1072 | # Horizontal/Top Layer 1073 | loss = self.gan_img_loss(pred, target, step, self.discriminator, 0) 1074 | 1075 | # Vertical/Deep Layer 1076 | for i, pred_i in enumerate(pred_vertical): 1077 | target_i = F.interpolate( 1078 | target, size=pred_i.shape[2:], mode=interpolation_mode, align_corners=None) 1079 | loss += self.gan_img_loss(pred_i, target_i, 1080 | step, self.discriminator, i+1) 1081 | 1082 | with torch.no_grad(): 1083 | logging_dict['loss_1'] = loss 1084 | 1085 | # --------------------- Compute Metrics ------------------------------- 1086 | with torch.no_grad(): 1087 | logging_dict['loss'] = loss 1088 | logging_dict[f'L2'] = torch.nn.functional.mse_loss(pred, x) 1089 | logging_dict[f'L1'] = torch.nn.functional.l1_loss(pred, x) 1090 | logging_dict['ssim'] = ssim( 1091 | (pred+1)/2, (target.type(pred.dtype)+1)/2, data_range=1) 1092 | # logging_dict['logvar'] = self.vqvae.logvar 1093 | 1094 | # ----------------- Log Scalars ---------------------- 1095 | for metric_name, metric_val in logging_dict.items(): 1096 | self.log(f"{state}/{metric_name}", metric_val, 1097 | batch_size=x.shape[0], on_step=True, on_epoch=True) 1098 | 1099 | # 
----------------- Save Image ------------------------------ 1100 | # NOTE: step 1 (opt1) , step=2 (opt2), step=3 (opt1), ... 1101 | if self.global_step != 0 and self.global_step % self.sample_every_n_steps == 0: 1102 | 1103 | log_step = self.global_step // self.sample_every_n_steps 1104 | path_out = Path(self.logger.log_dir)/'images' 1105 | path_out.mkdir(parents=True, exist_ok=True) 1106 | # for 3D images use depth as batch :[D, C, H, W], never show more than 16+16 =32 images 1107 | 1108 | def depth2batch(image): 1109 | return (image if image.ndim < 5 else torch.swapaxes(image[0], 0, 1)) 1110 | images = torch.cat([depth2batch(img)[:16] for img in (x, pred)]) 1111 | save_image( 1112 | images, path_out/f'sample_{log_step}.png', nrow=x.shape[0], normalize=True) 1113 | 1114 | return loss 1115 | 1116 | def configure_optimizers(self): 1117 | opt_vqvae = self.optimizer_vqvae( 1118 | self.vqvae.parameters(), **self.optimizer_vqvae_kwargs) 1119 | opt_gan = self.optimizer_gan( 1120 | self.discriminator.parameters(), **self.optimizer_gan_kwargs) 1121 | schedulers = [] 1122 | if self.lr_scheduler_vqvae is not None: 1123 | schedulers.append({ 1124 | 'scheduler': self.lr_scheduler_vqvae(opt_vqvae, **self.lr_scheduler_vqvae_kwargs), 1125 | 'interval': 'step', 1126 | 'frequency': 1 1127 | }) 1128 | if self.lr_scheduler_gan is not None: 1129 | schedulers.append({ 1130 | 'scheduler': self.lr_scheduler_gan(opt_gan, **self.lr_scheduler_gan_kwargs), 1131 | 'interval': 'step', 1132 | 'frequency': 1 1133 | }) 1134 | return [opt_vqvae, opt_gan], schedulers 1135 | 1136 | def compute_lambda(self, rec_loss, gan_loss, dec_out_layer, eps=1e-4): 1137 | """Computes adaptive weight as proposed in eq. 7 of https://arxiv.org/abs/2012.09841""" 1138 | rec_grads = torch.autograd.grad( 1139 | rec_loss, dec_out_layer.weight, retain_graph=True)[0] 1140 | gan_grads = torch.autograd.grad( 1141 | gan_loss, dec_out_layer.weight, retain_graph=True)[0] 1142 | d_weight = torch.norm(rec_grads) / (torch.norm(gan_grads) + eps) 1143 | d_weight = torch.clamp(d_weight, 0.0, 1e4) 1144 | return d_weight.detach() 1145 | --------------------------------------------------------------------------------
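The VQVAE above exposes encode_to_indices / decode_from_indices, which is presumably what the gptnano side builds on (the LIDC dataset is loaded with return_indices=True): images become flat codebook indices, are modelled autoregressively, and are decoded back. A minimal round-trip sketch, not taken from the repo's training scripts; it assumes the default 2D blocks in vqgan/conv_blocks.py build cleanly for these settings and uses a toy input in place of LIDC data:

import torch
from vqgan.model import VQVAE

# small 2D config; perceiver=None avoids loading LPIPS/VGG weights for this sketch
model = VQVAE(in_channels=1, out_channels=1, spatial_dims=2, emb_channels=4,
              num_embeddings=8192, hid_chs=[32, 64, 128, 256],
              kernel_sizes=[3, 3, 3, 3], strides=[1, 2, 2, 2],
              perceiver=None).eval()

x = torch.randn(1, 1, 64, 64)  # toy single-channel slice standing in for a CT image
with torch.no_grad():
    indices, emb_shape = model.encode_to_indices(x)   # flat codebook ids + latent shape
    x_rec = model.decode_from_indices(indices, emb_shape)

# with strides [1, 2, 2, 2] the spatial size drops by 8x, so indices should hold
# 1*8*8 = 64 entries and x_rec should match the input shape [1, 1, 64, 64]
print(indices.shape, x_rec.shape)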