├── configs ├── metrics │ ├── default.yaml │ ├── reconstruction_minimal.yaml │ ├── reconstruction_all.yaml │ ├── reconstruction_3d.yaml │ └── binaryclass.yaml ├── criterion │ ├── bce.yaml │ ├── biomedclip.yaml │ ├── default.yaml │ └── lpips_with_discriminator.yaml ├── logger │ └── wandb.yaml ├── optimizer │ ├── default.yaml │ └── adam.yaml ├── model │ ├── monai_seresnet152.yaml │ └── default.yaml ├── dataloader │ ├── default.yaml │ ├── mmgs.yaml │ ├── example_dataset.yaml │ └── mri_ct_3d.yaml ├── hydra │ └── default.yaml ├── experiment │ ├── medvae_4x_1c_3d_finetuning.yaml │ ├── medvae_4x_1c_2d_finetuning.yaml │ ├── medvae_4x_3c_2d_finetuning.yaml │ ├── medvae_4x_3c_2d_s2_finetuning.yaml │ ├── medvae_4x_1c_2d_s2_finetuning.yaml │ └── example_cls.yaml ├── paths │ └── default.yaml └── finetuned_vae.yaml ├── medvae ├── utils │ ├── __init__.py │ ├── vae │ │ ├── __init__.py │ │ ├── distributions.py │ │ ├── train_components_stage2.py │ │ ├── train_components.py │ │ ├── loss_components.py │ │ ├── diffusionmodels.py │ │ └── diffusionmodels_3d.py │ ├── cls │ │ ├── __init__.py │ │ └── train_components.py │ ├── transforms.py │ ├── lr_schedulers.py │ ├── extras.py │ ├── factory.py │ └── loaders.py ├── __init__.py ├── losses │ └── __init__.py ├── dataloaders │ ├── __init__.py │ ├── concat_dataset.py │ └── generic_dataset.py ├── models │ ├── __init__.py │ ├── autoencoder_kl.py │ └── autoencoder_kl_3d.py ├── metrics │ ├── __init__.py │ └── reconstruction_metrics.py ├── medvae_inference.py ├── medvae_main.py ├── medvae_cls.py ├── medvae_finetune.py └── medvae_finetune_s2.py ├── documentation ├── assets │ └── overview.png ├── classification.md ├── create_csv.ipynb ├── demo.ipynb ├── finetune.md └── inference.md ├── .github └── workflows │ └── ruff.yml ├── setup.py ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── pyproject.toml └── README.md /configs/metrics/default.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /medvae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /medvae/utils/vae/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/criterion/bce.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.BCEWithLogitsLoss 2 | -------------------------------------------------------------------------------- /documentation/assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/MedVAE/HEAD/documentation/assets/overview.png -------------------------------------------------------------------------------- /medvae/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.medvae_main import MVAE 2 | from medvae import losses 3 | 4 | __all__ = ["MVAE", "losses"] -------------------------------------------------------------------------------- /configs/criterion/biomedclip.yaml: -------------------------------------------------------------------------------- 1 | _target_: medvae.losses.BiomedClipLoss 2 | compute_rec_loss: false 3 | compute_lat_loss: true 
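Every file under `configs/` is a Hydra config: `_target_` names the class to instantiate, the remaining keys become constructor arguments, and `_partial_: true` defers binding until the missing arguments (e.g. model parameters) are supplied. A minimal sketch of that convention (not the project's actual training entry point), using only standard Hydra/OmegaConf calls:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Build the loss object described by a criterion config.
cfg = OmegaConf.load("configs/criterion/bce.yaml")      # _target_: torch.nn.BCEWithLogitsLoss
criterion = instantiate(cfg)                            # -> torch.nn.BCEWithLogitsLoss()

# Partial configs (see configs/optimizer/adam.yaml) return a callable that
# still expects the model parameters.
opt_cfg = OmegaConf.load("configs/optimizer/adam.yaml")
build_optimizer = instantiate(opt_cfg)                  # functools.partial(torch.optim.AdamW, ...)
# optimizer = build_optimizer(model.parameters())
```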
-------------------------------------------------------------------------------- /configs/criterion/default.yaml: -------------------------------------------------------------------------------- 1 | # the criterion used for the loss function 2 | _target_: torch.nn.CosineSimilarity 3 | dim: 1 4 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | entity: null 2 | id: null 3 | group: null 4 | name: ${task_name} 5 | resume: allow 6 | dir: ${paths.output_dir} 7 | -------------------------------------------------------------------------------- /configs/metrics/reconstruction_minimal.yaml: -------------------------------------------------------------------------------- 1 | - - "PSNR" 2 | - _target_: medvae.metrics.PSNR 3 | - - "MSE" 4 | - _target_: medvae.metrics.MSE -------------------------------------------------------------------------------- /configs/criterion/lpips_with_discriminator.yaml: -------------------------------------------------------------------------------- 1 | _target_: medvae.losses.LPIPSWithDiscriminator 2 | disc_start: 10000000 #50001 3 | kl_weight: 0.000001 4 | disc_weight: 0.5 5 | -------------------------------------------------------------------------------- /medvae/utils/cls/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.utils.cls.train_components import training_epoch, validation_epoch 2 | 3 | __all__ = [ 4 | "training_epoch", 5 | "validation_epoch", 6 | ] 7 | -------------------------------------------------------------------------------- /configs/optimizer/default.yaml: -------------------------------------------------------------------------------- 1 | # partially instantiate an optimizer, so that it only requires params 2 | _partial_: true 3 | _target_: torch.optim.SGD 4 | lr: 0.005 5 | weight_decay: 0.0001 6 | momentum: 0.9 -------------------------------------------------------------------------------- /medvae/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.losses.vae_losses import LPIPSWithDiscriminator 2 | from medvae.losses.vae_losses import BiomedClipLoss 3 | 4 | __all__ = ["LPIPSWithDiscriminator", "BiomedClipLoss"] 5 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [ push, pull_request ] 3 | jobs: 4 | ruff: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: astral-sh/ruff-action@v3 -------------------------------------------------------------------------------- /configs/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # partially instantiate an optimizer, so that it only requires params 2 | _partial_: true 3 | _target_: torch.optim.AdamW 4 | lr: 0.0001 5 | # weight_decay: 0.000001 6 | weight_decay: 0.05 7 | -------------------------------------------------------------------------------- /medvae/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.dataloaders.generic_dataset import GenericDataset 2 | from medvae.dataloaders.concat_dataset import ConcatDataset 3 | 4 | __all__ = [ 5 | "GenericDataset", 6 | "ConcatDataset", 7 | ] 8 | 
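These two dataset classes are what the `_target_` entries in `configs/dataloader/*.yaml` point at. Below is a hedged sketch of building the same training dataset directly in Python (the keyword names are read off those YAML files and should be treated as assumptions, not a documented API):

```python
from torchvision import transforms

from medvae.dataloaders import ConcatDataset, GenericDataset
from medvae.utils.loaders import load_2d_finetune

# Mirrors the train dataset in configs/dataloader/mmgs.yaml.
mmg_train = GenericDataset(
    split_path="medvae/data/mmg_data.csv",
    split_column="split",
    split_name="train",
    data_dir="medvae/data/mmg_data/",
    dataset_id=0,
    img_column="image_uuid",
    img_suffix=".png",
    img_transform=transforms.Compose([load_2d_finetune]),
    merge_channels=True,
)

# ConcatDataset also accepts zero-argument callables and can filter by dataset_ids.
train_set = ConcatDataset([mmg_train])
```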
-------------------------------------------------------------------------------- /medvae/models/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.models.autoencoder_kl_3d import AutoencoderKL as AutoencoderKL_3D 2 | from medvae.models.autoencoder_kl import AutoencoderKL as AutoencoderKL_2D 3 | 4 | __all__ = ["AutoencoderKL_2D", "AutoencoderKL_3D"] -------------------------------------------------------------------------------- /configs/metrics/reconstruction_all.yaml: -------------------------------------------------------------------------------- 1 | - - "PSNR" 2 | - _target_: medvae.metrics.PSNR 3 | - - "MSE" 4 | - _target_: medvae.metrics.MSE 5 | - - "FID-Inception" 6 | - _target_: medvae.metrics.FID_Inception 7 | - - "FID-CLIP" 8 | - _target_: medvae.metrics.FID_CLIP 9 | version: CLIP -------------------------------------------------------------------------------- /configs/metrics/reconstruction_3d.yaml: -------------------------------------------------------------------------------- 1 | - - "PSNR" 2 | - _target_: medvae.metrics.PSNR 3 | - - "MSE" 4 | - _target_: medvae.metrics.MSE 5 | - - "MS-SSIM" 6 | - _target_: medvae.metrics.MS_SSIM 7 | - - "FID-Inception" 8 | - _target_: medvae.metrics.FID_Inception_3D 9 | - - "FID-CLIP" 10 | - _target_: medvae.metrics.FID_CLIP_3D -------------------------------------------------------------------------------- /medvae/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from medvae.metrics.reconstruction_metrics import ( 2 | MS_SSIM, 3 | MS_SSIM_SMALL, 4 | FID_Inception_3D, 5 | MSE, 6 | PSNR, 7 | FID_Inception, 8 | ) 9 | 10 | __all__ = [ 11 | "MSE", 12 | "PSNR", 13 | "MS_SSIM", 14 | "MS_SSIM_SMALL", 15 | "FID_Inception", 16 | "FID_Inception_3D", 17 | ] 18 | -------------------------------------------------------------------------------- /configs/model/monai_seresnet152.yaml: -------------------------------------------------------------------------------- 1 | _target_: monai.networks.nets.SEResNet152 2 | 3 | # Define the number of spatial dimensions 4 | spatial_dims: 3 5 | 6 | # the number of channels for the images in the dataset (=in_chans for timm) 7 | in_channels: ${in_channels} 8 | 9 | # number of out_channels, its binary 10 | num_classes: 1 11 | 12 | # Dropout 13 | dropout_prob: 0.2 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import os 3 | 4 | def create_env(): 5 | """Create a .env file in the package root directory.""" 6 | with open(".env", "w") as f: 7 | f.write(f'PROJECT_DIR = "{os.path.dirname(os.path.abspath(__file__))}"') 8 | print(".env file created!") 9 | 10 | if __name__ == "__main__": 11 | create_env() 12 | setup() -------------------------------------------------------------------------------- /configs/model/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: timm.create_model 2 | 3 | # backbone/encoder of the network (=model_name for timm) 4 | model_name: resnet50 5 | 6 | # the number of channels for the images in the dataset (=in_chans for timm) 7 | in_chans: ??? 8 | 9 | # number of classes in the target 10 | num_classes: ??? 
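# Note: `???` is OmegaConf's mandatory-value marker; Hydra raises an error unless an
# experiment config or a command-line override (e.g. `model.in_chans=1`) fills these
# fields in before instantiation.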
11 | 12 | # initialize with ImageNet weights 13 | pretrained: false 14 | -------------------------------------------------------------------------------- /configs/metrics/binaryclass.yaml: -------------------------------------------------------------------------------- 1 | - - "Accuracy" 2 | - _target_: torchmetrics.classification.BinaryAccuracy 3 | - - "F1Score" 4 | - _target_: torchmetrics.classification.BinaryF1Score 5 | - - "Precision" 6 | - _target_: torchmetrics.classification.BinaryPrecision 7 | - - "Recall" 8 | - _target_: torchmetrics.classification.BinaryRecall 9 | - - "AUROC" 10 | - _target_: torchmetrics.classification.BinaryAUROC 11 | -------------------------------------------------------------------------------- /medvae/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ["to_dict"] 4 | 5 | def to_dict(x): 6 | if isinstance(x, dict): 7 | if "group_id" not in x: 8 | x["group_id"] = torch.zeros( 9 | (x["img"].size(0),), dtype=x["img"].dtype, device=x["img"].device 10 | ) 11 | return x 12 | 13 | group = x[2] if len(x) == 3 else None 14 | return {"img": x[0], "lbl": x[1], "group_id": group} 15 | -------------------------------------------------------------------------------- /configs/dataloader/default.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | dataset: ??? 3 | batch_size: ${batch_size} # total amount across gpus and nodes 4 | shuffle: true 5 | num_workers: 10 # total amount across nodes 6 | pin_memory: true 7 | drop_last: true 8 | 9 | # the valid dataloader is not always used, so ??? are replaced by null 10 | valid: 11 | dataset: null 12 | batch_size: ${batch_size} 13 | shuffle: false 14 | num_workers: 10 15 | pin_memory: true 16 | drop_last: false -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.11.2 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | # Run the formatter. 
11 | - id: ruff-format 12 | types_or: [ python, pyi ] 13 | 14 | - repo: https://github.com/executablebooks/mdformat 15 | rev: 0.7.13 # Use the ref you want to point at 16 | hooks: 17 | - id: mdformat 18 | additional_dependencies: 19 | - mdformat-ruff -------------------------------------------------------------------------------- /configs/experiment/medvae_4x_1c_3d_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataloader: mri_ct_3d.yaml 5 | - override /metrics: reconstruction_minimal.yaml 6 | - override /criterion: lpips_with_discriminator.yaml 7 | 8 | # Finetuning parameters to change 9 | task_name: mri_ct_finetuning 10 | model_name: medvae_4_1_3d 11 | stage2: false 12 | device: cuda 13 | 14 | # Training parameters to change 15 | mixed_precision: "no" 16 | gradient_accumulation_steps: 1 17 | max_epoch: 6400 18 | batch_size: 4 19 | log_every_n_steps: 20 20 | base_learning_rate: 4.5e-6 21 | ema_decay: null 22 | fast_dev_run: false 23 | 24 | criterion: 25 | num_channels: 1 26 | disc_start: 3125 -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | project_dir: ${oc.env:PROJECT_DIR} 5 | 6 | # path to data directory 7 | data_dir: ${paths.project_dir}/medvae/data 8 | 9 | # path to logging directory 10 | log_dir: ${paths.project_dir}/medvae/logs 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} -------------------------------------------------------------------------------- /configs/experiment/medvae_4x_1c_2d_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataloader: mmgs.yaml 5 | - override /metrics: reconstruction_minimal.yaml 6 | - override /criterion: lpips_with_discriminator.yaml 7 | 8 | # Finetuning parameters to change 9 | task_name: mmg_finetuning 10 | model_name: medvae_4_1_2d 11 | dataset_name: mmg_data 12 | merge_channels: true 13 | stage2: false 14 | device: cuda 15 | 16 | # Training parameters to change 17 | mixed_precision: "no" 18 | gradient_accumulation_steps: 2 19 | max_epoch: 100 20 | batch_size: 2 21 | log_every_n_steps: 20 22 | base_learning_rate: 4.5e-6 23 | ema_decay: null 24 | fast_dev_run: false 25 | 26 | criterion: 27 | num_channels: 1 28 | disc_start: 3125 -------------------------------------------------------------------------------- /configs/experiment/medvae_4x_3c_2d_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataloader: mmgs.yaml 5 | - override /metrics: reconstruction_minimal.yaml 6 | - override /criterion: lpips_with_discriminator.yaml 7 | 8 | # Finetuning parameters to change 9 | task_name: mmg_finetuning 10 | model_name: medvae_4_3_2d 11 | dataset_name: mmg_data 12 | merge_channels: false 13 | stage2: false 14 | device: cuda 15 | 
16 | # Training parameters to change 17 | mixed_precision: "no" 18 | gradient_accumulation_steps: 2 19 | max_epoch: 100 20 | batch_size: 2 21 | log_every_n_steps: 20 22 | base_learning_rate: 4.5e-6 23 | ema_decay: null 24 | fast_dev_run: false 25 | 26 | criterion: 27 | num_channels: 1 28 | disc_start: 3125 -------------------------------------------------------------------------------- /configs/experiment/medvae_4x_3c_2d_s2_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataloader: mmgs.yaml 5 | - override /metrics: reconstruction_minimal.yaml 6 | - override /criterion: biomedclip.yaml 7 | 8 | # Finetuning parameters to change 9 | task_name: mmg_finetuning 10 | model_name: medvae_4_3_2d 11 | dataset_name: mmg_data 12 | merge_channels: false 13 | # Note: Need to replace the stage2_ckpt with the path to the stage 1 checkpoint 14 | stage2: true 15 | stage2_ckpt: medvae/logs/mmg_finetuning/runs/2025-03-30_16-29-57/checkpoints/best.pt/pytorch_model.bin 16 | 17 | device: cuda 18 | 19 | # Training parameters to change 20 | mixed_precision: "no" 21 | gradient_accumulation_steps: 1 22 | max_epoch: 100 23 | batch_size: 4 24 | log_every_n_steps: 20 25 | base_learning_rate: 1.0e-4 26 | ema_decay: null 27 | fast_dev_run: false -------------------------------------------------------------------------------- /configs/experiment/medvae_4x_1c_2d_s2_finetuning.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /dataloader: mmgs.yaml 5 | - override /metrics: reconstruction_minimal.yaml 6 | - override /criterion: biomedclip.yaml 7 | 8 | # Finetuning parameters to change 9 | task_name: mmg_finetuning 10 | model_name: medvae_4_1_2d 11 | dataset_name: mmg_data 12 | merge_channels: true 13 | 14 | # Note: Need to replace the stage2_ckpt with the path to the stage 1 checkpoint 15 | stage2: true 16 | stage2_ckpt: medvae/logs/mmg_finetuning/runs/2025-03-30_14-33-36/checkpoints/best.pt/pytorch_model.bin 17 | 18 | device: cuda 19 | 20 | # Training parameters to change 21 | mixed_precision: "no" 22 | gradient_accumulation_steps: 1 23 | max_epoch: 100 24 | batch_size: 4 25 | log_every_n_steps: 20 26 | base_learning_rate: 1.0e-4 27 | ema_decay: null 28 | fast_dev_run: false -------------------------------------------------------------------------------- /configs/experiment/example_cls.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Note: This config file does not actually run any experiments. It is only a template for the user to follow. 4 | 5 | # All of these defaults are included. Please refer to their documentation for more details. 
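# Launch (after preparing your own dataset config) with, for example:
#   CUDA_VISIBLE_DEVICES=0 medvae_classify experiment=example_cls
# See documentation/classification.md for the full walkthrough.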
6 | defaults: 7 | - override /dataloader: example_dataset.yaml 8 | - override /model: monai_seresnet152.yaml 9 | - override /criterion: bce.yaml 10 | - override /metrics: binaryclass.yaml 11 | - override /optimizer: adam.yaml 12 | 13 | optimizer: 14 | lr: 0.001 15 | 16 | # General settings 17 | exp_name: null 18 | in_channels: 1 19 | dataset_id: 6 20 | task_name: cls_${exp_name} 21 | init_proj_name: cls_evals 22 | 23 | # Training settings 24 | gradient_accumulation_steps: 1 25 | mixed_precision: "no" 26 | max_epoch: 100 27 | batch_size: 16 28 | log_every_n_steps: 20 29 | ckpt_every_n_epochs: 5 30 | num_workers: 8 31 | weight_data_loader: True -------------------------------------------------------------------------------- /medvae/utils/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | __all__ = ["CosineScheduler"] 6 | 7 | 8 | class CosineScheduler(_LRScheduler): 9 | def __init__(self, optimizer, max_epochs: int, verbose: bool = False): 10 | self.optimizer = optimizer 11 | self.max_epochs = max_epochs 12 | 13 | # last epoch should always be -1, use load_state_dict to resume 14 | super().__init__(optimizer, -1, verbose) 15 | 16 | def _compute_lr(self, param_group): 17 | init_lr = param_group["initial_lr"] 18 | current_lr = ( 19 | init_lr 20 | * 0.5 21 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_epochs)) 22 | ) 23 | 24 | if "fixed_lr" in param_group and param_group["fixed_lr"]: 25 | return init_lr 26 | else: 27 | return current_lr 28 | 29 | def get_lr(self): 30 | return [ 31 | self._compute_lr(param_group) for param_group in self.optimizer.param_groups 32 | ] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Stanford MIMI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | metrics/ 30 | !configs/metrics/ 31 | !medvae/metrics/ 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Remove data file 58 | /medvae/data/ 59 | 60 | # Environment files 61 | *.env 62 | *.vscode 63 | 64 | # Data files 65 | data/ 66 | medvae/data/ 67 | 68 | # Log files 69 | logs/ 70 | *.log -------------------------------------------------------------------------------- /medvae/dataloaders/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import torch 4 | 5 | from torch.utils.data import ConcatDataset 6 | 7 | __all__ = ["ConcatDataset"] 8 | 9 | 10 | class ConcatDataset(ConcatDataset): 11 | def __init__(self, datasets, dataset_ids: List[int] = None): 12 | for i in range(len(datasets)): 13 | if callable(datasets[i]): 14 | datasets[i] = datasets[i]() 15 | 16 | if dataset_ids is not None: 17 | datasets = [ds for ds in datasets if ds.dataset_id in dataset_ids] 18 | 19 | super().__init__(datasets) 20 | 21 | 22 | def _calculate_weights(self, datasets): 23 | # Calculate weights for each sample in each dataset 24 | dataset_lengths = np.array([len(ds) for ds in datasets]) 25 | weight = 1.0 / dataset_lengths 26 | # Create a list of weights 27 | weights = np.repeat(weight, dataset_lengths) 28 | return weights 29 | 30 | def _calculate_label_weights(self, datasets): 31 | # Concatenate all labels from all datasets 32 | all_labels = torch.cat([torch.tensor(ds.get_labels()) for ds in datasets]) 33 | 34 | # Count the number of occurrences of each class 35 | class_counts = torch.bincount(all_labels) 36 | 37 | # Calculate weights 38 | weights = 1. 
/ class_counts[all_labels] 39 | 40 | return weights 41 | 42 | def get_weights(self): 43 | # Return the calculated weights 44 | return self._calculate_weights(self.datasets) 45 | 46 | def get_label_weights(self): 47 | return self._calculate_label_weights(self.datasets) 48 | -------------------------------------------------------------------------------- /configs/finetuned_vae.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - hydra: default.yaml 8 | - paths: default.yaml 9 | - criterion: default.yaml 10 | - dataloader: default.yaml 11 | - logger: null 12 | 13 | # Choose a dataloader 14 | - dataloader: ??? 15 | 16 | # keep track of metrics during the experiment 17 | - metrics: default.yaml 18 | 19 | # experiment configs allow for version control of specific hyperparameters 20 | # e.g. best hyperparameters for given model and datamodule 21 | - experiment: ??? 22 | 23 | 24 | # log frequency 25 | log_every_n_steps: 10 26 | 27 | # start trainig at epoch n 28 | start_epoch: 0 29 | 30 | # stop training at epoch n 31 | max_epoch: 50 32 | 33 | # mini-batch size 34 | batch_size: 256 35 | 36 | # Set the task name 37 | task_name: null 38 | 39 | # Set the input directory 40 | input: null 41 | 42 | # Set the output directory 43 | output: null 44 | 45 | # Set the model name 46 | model_name: null 47 | 48 | # Set the device 49 | device: null 50 | 51 | # gradient accumulation 52 | gradient_accumulation_steps: 1 53 | 54 | # checkpoint path to the accelerator state (this should be a directory) 55 | resume_from_ckpt: null 56 | 57 | # directory to store checkpoints in, same as the global output dir for this run 58 | ckpt_dir: ${paths.output_dir}/checkpoints 59 | 60 | # checkpoint frequency 61 | ckpt_every_n_steps: 500 62 | 63 | # seed for random number generators in pytorch, numpy and python.random 64 | # sets cudnn to deterministic too 65 | seed: null 66 | 67 | # Run one training step, and if applicable a validation step 68 | fast_dev_run: false 69 | 70 | # For mixed precision training 71 | mixed_precision: "no" 72 | -------------------------------------------------------------------------------- /medvae/utils/vae/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class DiagonalGaussianDistribution: 6 | def __init__(self, parameters, deterministic=False): 7 | self.parameters = parameters 8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 10 | self.deterministic = deterministic 11 | self.std = torch.exp(0.5 * self.logvar) 12 | self.var = torch.exp(self.logvar) 13 | if self.deterministic: 14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 15 | 16 | def sample(self): 17 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 18 | return x 19 | 20 | def kl(self, other=None): 21 | if self.deterministic: 22 | return torch.Tensor([0.0]) 23 | else: 24 | if other is None: 25 | return 0.5 * torch.sum( 26 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3] 27 | ) 28 | else: 29 | return 0.5 * torch.sum( 30 | torch.pow(self.mean - other.mean, 2) / other.var 31 | + self.var / other.var 32 | - 1.0 33 | - self.logvar 34 | + other.logvar, 35 | dim=[1, 2, 
3], 36 | ) 37 | 38 | def nll(self, sample, dims=[1, 2, 3]): 39 | if self.deterministic: 40 | return torch.Tensor([0.0]) 41 | logtwopi = np.log(2.0 * np.pi) 42 | return 0.5 * torch.sum( 43 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims 44 | ) 45 | 46 | def mode(self): 47 | return self.mean 48 | -------------------------------------------------------------------------------- /configs/dataloader/mmgs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train: 5 | dataset: 6 | _target_: medvae.dataloaders.ConcatDataset 7 | datasets: 8 | # MMGS 9 | - _target_: medvae.dataloaders.GenericDataset 10 | split_path: ${paths.data_dir}/${dataset_name}.csv 11 | split_column: split 12 | split_name: train 13 | data_dir: ${paths.data_dir}/${dataset_name}/ 14 | dataset_id: 0 15 | img_column: image_uuid 16 | img_suffix: .png 17 | img_transform: 18 | _target_: torchvision.transforms.Compose 19 | transforms: 20 | - _partial_: true 21 | _target_: medvae.utils.loaders.load_2d_finetune 22 | 23 | # If you don't have single channel input 24 | merge_channels: ${merge_channels} 25 | 26 | valid: 27 | dataset: 28 | _target_: medvae.dataloaders.ConcatDataset 29 | datasets: 30 | # MMGS 31 | - _target_: medvae.dataloaders.GenericDataset 32 | split_path: ${paths.data_dir}/${dataset_name}.csv 33 | split_column: split 34 | split_name: val 35 | data_dir: ${paths.data_dir}/${dataset_name}/ 36 | dataset_id: 0 37 | img_column: image_uuid 38 | img_suffix: .png 39 | img_transform: 40 | _target_: torchvision.transforms.Compose 41 | transforms: 42 | - _partial_: true 43 | _target_: medvae.utils.loaders.load_2d_finetune 44 | 45 | # If you don't have single channel input 46 | merge_channels: ${merge_channels} 47 | 48 | test: 49 | dataset: 50 | _target_: medvae.dataloaders.ConcatDataset 51 | datasets: 52 | # MMGS 53 | - _target_: medvae.dataloaders.GenericDataset 54 | split_path: ${paths.data_dir}/${dataset_name}.csv 55 | split_column: split 56 | split_name: test 57 | data_dir: ${paths.data_dir}/${dataset_name}/ 58 | dataset_id: 0 59 | img_column: image_uuid 60 | img_suffix: .png 61 | img_transform: 62 | _target_: torchvision.transforms.Compose 63 | transforms: 64 | - _partial_: true 65 | _target_: medvae.utils.loaders.load_2d_finetune 66 | 67 | # If you don't have single channel input 68 | merge_channels: ${merge_channels} -------------------------------------------------------------------------------- /documentation/classification.md: -------------------------------------------------------------------------------- 1 | # ✨ MedVAE classification documentation 2 | 3 | MedVAE includes a flexible classification framework designed to leverage latent representations for downstream classification tasks. The current setup provides examples that you can adapt to your dataset. Please follow the instructions below to configure and run classification models using MedVAE. 4 | 5 | ## Requirements 6 | 7 | - Python 3.9 (Recommended) 8 | - Clone of the MedVAE GitHub repository 9 | 10 | ## Preparing Your Dataset 11 | 12 | The provided example dataset does not include labels or classification examples. You will need to: 13 | 14 | - Prepare a CSV file that clearly specifies your dataset splits (train, validation, test) and corresponding labels. 15 | - Adjust the [dataloader configuration](../configs/dataloader/example_dataset.yaml) to reference your CSV file and dataset path. 
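For reference, a split CSV in the shape those configs expect (column names follow [example_dataset.yaml](../configs/dataloader/example_dataset.yaml); the image IDs and labels below are placeholders):

```
row_nr,image_uuid,target,split
0,case_0001,0,train
1,case_0002,1,train
2,case_0003,1,val
3,case_0004,0,test
```

The `image_uuid` values are file names without their suffix, `target` holds the binary label, and `split` must match the `split_name` values (e.g., `train`, `val`, `test`) used in the dataloader config.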
16 | 17 | ## Configuration and Customization 18 | 19 | The classification framework requires modifications in several key configuration areas: 20 | 21 | ### Dataloader 22 | 23 | - Customize the [example dataloader file](../configs/dataloader/example_dataset.yaml) to match your dataset structure and labeling schema. 24 | - Ensure the CSV file accurately includes columns for data splits and class labels. 25 | 26 | ### Criterion 27 | 28 | - The default loss function is set for binary cross-entropy. To use alternative loss functions (e.g., multi-class cross-entropy), you must update the criterion configuration accordingly. 29 | 30 | ### Model 31 | 32 | - The default [model configuration](../configs/model/monai_seresnet152.yaml) provided in [example_cls.yaml](../configs/experiment/example_cls.yaml) is designed for 3D volume classification. 33 | - For 2D image classification tasks, please adapt this configuration using [default.yaml](../configs/model/default.yaml) or another suitable architecture. 34 | - Identify and select model architectures that best suit your specific dataset and task. 35 | 36 | ### Experiment 37 | 38 | - Modify the experiment configuration file to adjust hyperparameters such as batch size, learning rate, and epochs tailored to your dataset and task. 39 | 40 | ## Running Classification 41 | 42 | Use the following command to start the classification process with your configured experiment: 43 | 44 | ```bash 45 | CUDA_VISIBLE_DEVICES=0 medvae_classify experiment=example_cls 46 | ``` 47 | 48 | ### Multi-GPU Training 49 | 50 | The classification framework supports multi-GPU training using HuggingFace Accelerate. Example: 51 | 52 | ```bash 53 | CUDA_VISIBLE_DEVICES=0,1,2 medvae_classify experiment=example_cls 54 | ``` 55 | 56 | ## Support 57 | 58 | For questions or support regarding classification tasks using MedVAE, please submit an issue on our [GitHub issues page](https://github.com/StanfordMIMI/MedVAE/issues). 59 | -------------------------------------------------------------------------------- /configs/dataloader/example_dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | ### NOTE: This config file does not actually run any experiments. It is only a template for the user to follow. 5 | ### It will NOT work. Please refer to this file for clarify on understanding the dataloader format for the img and lbl formats. 6 | ### You may want to add a test set if you want to run the code. 
7 | ### Example file format: line number, image_uuid, label (0 or 1), split (train, val, test) 8 | 9 | train: 10 | dataset: 11 | _target_: medvae.dataloaders.ConcatDataset 12 | dataset_ids: 13 | - ${dataset_id} 14 | datasets: 15 | # DATASET 16 | - _target_: medvae.dataloaders.GenericDataset 17 | split_path: ${paths.data_dir}/mmgs.csv 18 | split_column: split 19 | split_name: train 20 | data_dir: ${paths.data_dir}/mmgs_data/ 21 | dataset_id: 5 22 | img_column: image_uuid 23 | # These are the latents 24 | img_suffix: .nii.gz 25 | img_transform: 26 | _target_: torchvision.transforms.Compose 27 | transforms: 28 | # You will need a proper latent / image dataloader 29 | - _partial_: true 30 | _target_: medvae.utils.loaders.load_monai_tensor 31 | - _target_: torchvision.transforms.Normalize 32 | mean: [0] 33 | std: [4.6] 34 | - _target_: medvae.utils.transforms.MergeChannels 35 | merge: ${merge_channels} 36 | lbl_columns: ["target"] 37 | lbl_transform: 38 | _partial_: true 39 | _target_: medvae.utils.loaders.load_labels 40 | dtype: float32 41 | squeeze: -1 42 | 43 | valid: 44 | dataset: 45 | _target_: medvae.dataloaders.ConcatDataset 46 | dataset_ids: 47 | - ${dataset_id} 48 | datasets: 49 | # DATASET 50 | - _target_: medvae.dataloaders.GenericDataset 51 | split_path: ${paths.data_dir}/mmgs.csv 52 | split_column: split 53 | split_name: test 54 | data_dir: ${paths.data_dir}/mmgs_data/ 55 | dataset_id: 5 56 | img_column: image_uuid 57 | # These are the latents 58 | img_suffix: .nii.gz 59 | img_transform: 60 | _target_: torchvision.transforms.Compose 61 | transforms: 62 | - _partial_: true 63 | _target_: medvae.utils.loaders.load_monai_tensor 64 | - _target_: torchvision.transforms.Normalize 65 | mean: [0] 66 | std: [4.6] 67 | - _target_: medvae.utils.transforms.MergeChannels 68 | merge: ${merge_channels} 69 | lbl_columns: ["target"] 70 | lbl_transform: 71 | _partial_: true 72 | _target_: medvae.utils.loaders.load_labels 73 | dtype: float32 74 | squeeze: -1 -------------------------------------------------------------------------------- /medvae/utils/extras.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellaneous utility functions. 3 | """ 4 | 5 | import os 6 | import random 7 | import torch 8 | import numpy as np 9 | import torch.backends.cudnn as cudnn 10 | 11 | 12 | def create_directory(directory: str) -> None: 13 | """ 14 | Create a directory if it does not exist. 15 | 16 | Parameters: 17 | directory (str): The directory to create. 18 | """ 19 | if not os.path.exists(directory): 20 | os.makedirs(directory) 21 | return directory 22 | 23 | 24 | def cite_function(): 25 | """ 26 | Print a message to cite the MedVAE paper. 27 | """ 28 | print( 29 | "\n#######################################################################\nPlease cite the following paper " 30 | "when using MedVAE: Varma, M., Kumar, A., van der Sluijs, R., Ostmeier, S., " 31 | "Blankemeier, L., Chambon, P., Bluethgen, C., Prince, J., Langlotz, C., " 32 | "Chaudhari, A. (2025). MedVAE: Efficient Automated Interpretation of Medical Images with Large-Scale Generalizable Autoencoders. " 33 | "arXiv preprint arXiv:2502.14753.\n" 34 | "#######################################################################\n" 35 | ) 36 | 37 | 38 | """ 39 | Calculate the region of interest (ROI) size for each dimension of the input image shape. 40 | 41 | This function determines the appropriate ROI size based on the target GPU dimension. 
42 | If a dimension exceeds the target GPU dimension, it finds the largest power of 2 that 43 | results in a size less than the target dimension. 44 | 45 | @param image_shape: A tuple or list representing the shape of the input image (e.g., (depth, height, width)). 46 | @param target_gpu_dim: The maximum dimension size allowed for processing on the GPU (default is 160). 47 | @return: A list of calculated ROI sizes for each dimension of the input image. 48 | """ 49 | 50 | 51 | def roi_size_calc(image_shape, target_gpu_dim=160): 52 | roi_size = [] 53 | for dim in image_shape: 54 | if dim > target_gpu_dim: 55 | target_shape = target_gpu_dim 56 | # For loop for powers of 2 57 | for power in [2**i for i in range(8)]: 58 | if dim // power < target_shape: 59 | roi_size.append(dim // power) 60 | break 61 | else: 62 | roi_size.append(dim) 63 | 64 | return roi_size 65 | 66 | 67 | """ 68 | This function sanitizes the keyword arguments for a DataLoader by converting the 'num_workers' argument to an integer. 69 | This is necessary when 'num_workers' is retrieved from the OS environment, as it might be stored as a string. 70 | """ 71 | 72 | 73 | def sanitize_dataloader_kwargs(kwargs): 74 | if "num_workers" in kwargs: 75 | kwargs["num_workers"] = int(kwargs["num_workers"]) 76 | 77 | return kwargs 78 | 79 | 80 | def set_seed(seed: int): 81 | """Seed the RNGs.""" 82 | 83 | print(f"=> Setting seed [seed={seed}]") 84 | random.seed(seed) 85 | torch.manual_seed(seed) 86 | np.random.seed(seed) 87 | cudnn.deterministic = True 88 | 89 | print("=> Setting a seed slows down training considerably!") 90 | 91 | 92 | def get_weight_dtype(accelerator): 93 | """Get the weight dtype from the accelerator.""" 94 | 95 | if accelerator.mixed_precision == "fp16": 96 | weight_dtype = torch.float16 97 | else: 98 | weight_dtype = torch.float32 99 | 100 | return weight_dtype 101 | -------------------------------------------------------------------------------- /configs/dataloader/mri_ct_3d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | train: 5 | dataset: 6 | _target_: medvae.dataloaders.ConcatDataset 7 | datasets: 8 | # MRI 9 | - _target_: medvae.dataloaders.GenericDataset 10 | split_path: ${paths.data_dir}/mri_data.csv 11 | split_column: split 12 | split_name: train 13 | data_dir: ${paths.data_dir}/mri_data/ 14 | dataset_id: 0 15 | img_column: image_uuid 16 | img_suffix: .nii.gz 17 | img_transform: 18 | _target_: torchvision.transforms.Compose 19 | transforms: 20 | - _partial_: true 21 | _target_: medvae.utils.loaders.load_mri_3d_finetune 22 | 23 | # CT 24 | - _target_: medvae.dataloaders.GenericDataset 25 | split_path: ${paths.data_dir}/ct_data.csv 26 | split_column: split 27 | split_name: train 28 | data_dir: ${paths.data_dir}/ct_data/ 29 | dataset_id: 1 30 | img_column: image_uuid 31 | img_suffix: .nii.gz 32 | img_transform: 33 | _target_: torchvision.transforms.Compose 34 | transforms: 35 | - _partial_: true 36 | _target_: medvae.utils.loaders.load_ct_3d_finetune 37 | valid: 38 | dataset: 39 | _target_: medvae.dataloaders.ConcatDataset 40 | datasets: 41 | # MRI 42 | - _target_: medvae.dataloaders.GenericDataset 43 | split_path: ${paths.data_dir}/mri_data.csv 44 | split_column: split 45 | split_name: val 46 | data_dir: ${paths.data_dir}/mri_data/ 47 | dataset_id: 0 48 | img_column: image_uuid 49 | img_suffix: .nii.gz 50 | img_transform: 51 | _target_: torchvision.transforms.Compose 52 | transforms: 53 | - _partial_: true 54 | _target_: 
medvae.utils.loaders.load_mri_3d_finetune 55 | 56 | # CT 57 | - _target_: medvae.dataloaders.GenericDataset 58 | split_path: ${paths.data_dir}/ct_data.csv 59 | split_column: split 60 | split_name: val 61 | data_dir: ${paths.data_dir}/ct_data/ 62 | dataset_id: 1 63 | img_column: image_uuid 64 | img_suffix: .nii.gz 65 | img_transform: 66 | _target_: torchvision.transforms.Compose 67 | transforms: 68 | - _partial_: true 69 | _target_: medvae.utils.loaders.load_ct_3d_finetune 70 | 71 | test: 72 | dataset: 73 | _target_: medvae.dataloaders.ConcatDataset 74 | datasets: 75 | # MRI 76 | - _target_: medvae.dataloaders.GenericDataset 77 | split_path: ${paths.data_dir}/mri_data.csv 78 | split_column: split 79 | split_name: test 80 | data_dir: ${paths.data_dir}/mri_data/ 81 | dataset_id: 0 82 | img_column: image_uuid 83 | img_suffix: .nii.gz 84 | img_transform: 85 | _target_: torchvision.transforms.Compose 86 | transforms: 87 | - _partial_: true 88 | _target_: medvae.utils.loaders.load_mri_3d_finetune 89 | 90 | # CT 91 | - _target_: medvae.dataloaders.GenericDataset 92 | split_path: ${paths.data_dir}/ct_data.csv 93 | split_column: split 94 | split_name: test 95 | data_dir: ${paths.data_dir}/ct_data/ 96 | dataset_id: 1 97 | img_column: image_uuid 98 | img_suffix: .nii.gz 99 | img_transform: 100 | _target_: torchvision.transforms.Compose 101 | transforms: 102 | - _partial_: true 103 | _target_: medvae.utils.loaders.load_ct_3d_finetune -------------------------------------------------------------------------------- /documentation/create_csv.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# create_csv.ipynb\n", 8 | "\n", 9 | "## Description\n", 10 | "This jupyter notebook creates a csv file based on the directory you want to finetune.\n", 11 | "Specifically, this code will take the file and create train, val, and test splits for your directory.\n", 12 | "The csv file will be placed in the preceding directory. It is important to run this for your dataloader. The csv file will be the same name as the directory. The random split probabilities will most accurate as number of samples increase." 
13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 18, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import pandas as pd\n", 22 | "import numpy as np\n", 23 | "import os\n", 24 | "import random\n", 25 | "from tqdm import tqdm" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 19, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "## Note: CODE TO CHANGE HERE ##\n", 35 | "\n", 36 | "# Example data here is mammogram data for training\n", 37 | "DATA_PATH = os.path.abspath(\"../medvae/data/mmg_data\")\n", 38 | "\n", 39 | "train_pct = 0.6\n", 40 | "val_pct = 0.2\n", 41 | "test_pct = 0.2\n", 42 | "\n", 43 | "# Create a list of probabilities\n", 44 | "probabilities = [train_pct, val_pct, test_pct]\n", 45 | "splits = ['train', 'val', 'test']\n", 46 | "\n", 47 | "# Make sure the sum of the percentages is 1\n", 48 | "assert sum(probabilities) == 1\n", 49 | "\n", 50 | "# Make sure splits == pcts\n", 51 | "assert len(splits) == len(probabilities)\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 21, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stderr", 61 | "output_type": "stream", 62 | "text": [ 63 | "100%|██████████| 10/10 [00:00<00:00, 7588.75it/s]\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Iterate \n", 69 | "data_files = os.listdir(DATA_PATH) \n", 70 | "\n", 71 | "# Shuffle the data files\n", 72 | "random.seed(42)\n", 73 | "random.shuffle(data_files)\n", 74 | "\n", 75 | "save_df = []\n", 76 | "\n", 77 | "# Iterate through all the files in the data directory\n", 78 | "for i, data_file in tqdm(enumerate(data_files), total=len(data_files)):\n", 79 | " file_id = data_file.split('.')[0]\n", 80 | " \n", 81 | " save_df.append({\n", 82 | " 'row_nr': i,\n", 83 | " 'image_uuid': file_id,\n", 84 | " # Randomly assign the split, with 60% train, 20% val, 20% test\n", 85 | " 'split': np.random.choice(splits, p=probabilities)\n", 86 | " })\n", 87 | "\n", 88 | "# Create a pandas DataFrame from the save_df list\n", 89 | "save_df = pd.DataFrame(save_df)\n", 90 | "\n", 91 | "# Save the DataFrame as a CSV file\n", 92 | "save_df.to_csv(os.path.join('/'.join(DATA_PATH.split('/')[:-1]), f'{DATA_PATH.split(\"/\")[-1]}.csv'), index=False)" 93 | ] 94 | } 95 | ], 96 | "metadata": { 97 | "kernelspec": { 98 | "display_name": "compress", 99 | "language": "python", 100 | "name": "python3" 101 | }, 102 | "language_info": { 103 | "codemirror_mode": { 104 | "name": "ipython", 105 | "version": 3 106 | }, 107 | "file_extension": ".py", 108 | "mimetype": "text/x-python", 109 | "name": "python", 110 | "nbconvert_exporter": "python", 111 | "pygments_lexer": "ipython3", 112 | "version": "3.10.14" 113 | } 114 | }, 115 | "nbformat": 4, 116 | "nbformat_minor": 2 117 | } 118 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "medvae" 3 | version = "0.1.7" 4 | requires-python = ">=3.9" 5 | description = "MedVAE is a family of six medical image autoencoders that can encode high-dimensional medical images into latent representations." 
6 | readme = "README.md" 7 | license = { file = "LICENSE" } 8 | authors = [ 9 | { name = "Maya Varma", email = "mayavarma@cs.stanford.edu"}, 10 | { name = "Ashwin Kumar", email = "akkumar@stanford.edu"}, 11 | { name = "Rogier van der Sluijs", email = "sluijs@stanford.edu"}, 12 | { name = "Stanford Machine Intelligence for Medical Imaging (MIMI)" } 13 | ] 14 | classifiers = [ 15 | "Development Status :: 5 - Production/Stable", 16 | "Intended Audience :: Developers", 17 | "Intended Audience :: Science/Research", 18 | "Intended Audience :: Healthcare Industry", 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: MIT License", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 23 | ] 24 | keywords = [ 25 | 'deep learning', 26 | 'image compression', 27 | 'compression', 28 | 'efficiency', 29 | 'computer aided diagnosis', 30 | 'medical image analysis', 31 | 'autoencoders', 32 | 'representation learning', 33 | 'Med-VAE', 34 | 'medvae' 35 | ] 36 | dependencies = [ 37 | "torch>=2.4.1", # tested using 2.6.0 38 | "accelerate>=0.34.2", # tested using 0.34.2 39 | "wandb==0.14.0; python_version < '3.12'", # tested using 0.14.0 for older python 40 | "wandb>=0.16.0; python_version >= '3.12'", # use newer wandb for Python 3.12 41 | "tqdm", # tested using 4.67.1 42 | "dicom2nifti", # tested using 2.5.1 43 | "scipy", # tested using 1.13.1 44 | "batchgenerators>=0.25", # tested using 0.25.1 45 | "numpy>=1.24", # tested using 1.26.4 46 | "scikit-learn", # tested using 1.6.1 47 | "scikit-image>=0.19.3", # tested using 0.24.0 48 | "SimpleITK>=2.4.0", # tested using 2.4.1 49 | "omegaconf>=2.3.0", # tested using 2.3.0 50 | "pandas", # tested using 2.2.3 51 | 'requests', # tested using 2.32.3 52 | "nibabel", # tested using 5.3.2 53 | "matplotlib", # tested using 3.9.4 54 | "seaborn", # tested using 0.13.2 55 | "imagecodecs", # tested using 2024.12.30 56 | "yacs", # tested using 0.1.8 57 | "batchgeneratorsv2>=0.2", # tested using 0.2.3 58 | "einops>=0.8.0", # tested using 0.8.1 59 | "monai>=1.3.2", # tested using 1.4.0 60 | "torchvision>=0.19.1", # tested using 0.21.0 61 | "gdown", # tested using 5.2.0 62 | "nilearn", # tested using 0.11.1 63 | "pyrootutils", # tested using 1.0.4 64 | "hydra-core", # tested using 1.3.2 65 | "torchmetrics", # tested using 1.7.0 66 | "hydra-colorlog", # tested using 1.2.0 67 | "open_clip_torch==2.24.0", 68 | "polars>=0.19.10", # tested using 0.19.10 69 | "rich", # tested using 13.9.4 70 | "clip@git+https://github.com/openai/CLIP.git" # tested using 1.0 71 | ] 72 | 73 | [project.urls] 74 | homepage = "https://github.com/StanfordMIMI/MedVAE" 75 | repository = "https://github.com/StanfordMIMI/MedVAE" 76 | 77 | [project.scripts] 78 | medvae_inference = "medvae.medvae_inference:main" 79 | medvae_finetune = "medvae.medvae_finetune:main" 80 | medvae_finetune_s2 = "medvae.medvae_finetune_s2:main" 81 | medvae_classify = "medvae.medvae_cls:main" 82 | 83 | [project.optional-dependencies] 84 | dev = [ 85 | "ruff", 86 | "pre-commit", 87 | "mdformat" 88 | ] 89 | 90 | [build-system] 91 | requires = ["setuptools>=67.8.0"] 92 | build-backend = "setuptools.build_meta" 93 | 94 | [tool.setuptools.packages.find] 95 | include = ["medvae", "medvae.utils", "medvae.models", "medvae.utils.vae"] 96 | exclude = ["documentation", "configs"] 97 | 98 | [tool.codespell] 99 | skip = '.git,*.pdf,*.svg, *.png' 100 | -------------------------------------------------------------------------------- 
/medvae/models/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from medvae.utils.vae.diffusionmodels import Decoder, Encoder 4 | from medvae.utils.vae.distributions import DiagonalGaussianDistribution 5 | 6 | 7 | class AutoencoderKL(torch.nn.Module): 8 | def __init__( 9 | self, 10 | ddconfig, 11 | embed_dim, 12 | ckpt_path=None, 13 | ignore_keys=[], 14 | apply_channel_ds=True, 15 | state_dict=True, 16 | ): 17 | super().__init__() 18 | self.encoder = Encoder(**ddconfig) 19 | self.decoder = Decoder(**ddconfig) 20 | assert ddconfig["double_z"] 21 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 22 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 23 | self.embed_dim = embed_dim 24 | 25 | self.apply_channel_ds = apply_channel_ds 26 | if self.apply_channel_ds: 27 | self.channel_ds = torch.nn.Sequential( 28 | torch.nn.Conv2d(self.embed_dim, 64, 1), 29 | torch.nn.ReLU(), 30 | torch.nn.Conv2d(64, 64, 3, padding="same"), 31 | torch.nn.ReLU(), 32 | torch.nn.Conv2d(64, self.embed_dim, 1), 33 | ) 34 | self.channel_proj = torch.nn.Conv2d(self.embed_dim, self.embed_dim, 1) 35 | 36 | if ckpt_path is not None: 37 | self.init_from_ckpt( 38 | ckpt_path, ignore_keys=ignore_keys, state_dict=state_dict 39 | ) 40 | 41 | def init_from_ckpt(self, path, ignore_keys=list(), state_dict=True): 42 | if not state_dict: 43 | sd = torch.load(path, map_location="cpu") 44 | else: 45 | sd = torch.load(path, map_location="cpu")["state_dict"] 46 | keys = list(sd.keys()) 47 | for k in keys: 48 | for ik in ignore_keys: 49 | if k.startswith(ik): 50 | print(f"Deleting key {k} from state_dict.") 51 | del sd[k] 52 | missing, unexpected = self.load_state_dict(sd, strict=False) 53 | print( 54 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 55 | ) 56 | if len(missing) > 0: 57 | print(f"Missing Keys: {missing}") 58 | print(f"Unexpected Keys: {unexpected}") 59 | 60 | def encode_moments(self, x): 61 | h = self.encoder(x) 62 | moments = self.quant_conv(h) 63 | posterior = DiagonalGaussianDistribution(moments) 64 | return moments, posterior 65 | 66 | def moment_diagonal(self, moments): 67 | return DiagonalGaussianDistribution(moments) 68 | 69 | def encode(self, x): 70 | h = self.encoder(x) 71 | moments = self.quant_conv(h) 72 | posterior = DiagonalGaussianDistribution(moments) 73 | return posterior 74 | 75 | def decode(self, z): 76 | z = self.post_quant_conv(z) 77 | dec = self.decoder(z) 78 | return dec 79 | 80 | def compute_latent_proj(self, x, sample_posterior=True): 81 | posterior = self.encode(x) 82 | if sample_posterior: 83 | z = posterior.sample() 84 | else: 85 | z = posterior.mode() 86 | if self.apply_channel_ds: 87 | return z, posterior, self.channel_proj(self.channel_ds(z) + z) 88 | 89 | return z, posterior, None 90 | 91 | def forward(self, input, sample_posterior=True, decode=True): 92 | posterior = self.encode(input) 93 | if sample_posterior: 94 | z = posterior.sample() 95 | else: 96 | z = posterior.mode() 97 | latent = self.channel_proj(self.channel_ds(z) + z) 98 | if decode: 99 | dec = self.decode(z) 100 | return dec, posterior, latent 101 | else: 102 | return z, posterior, latent 103 | 104 | def get_last_layer(self): 105 | return self.decoder.conv_out.weight 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
MedVAE: Efficient Automated Interpretation of Medical Images with Large-Scale Generalizable Autoencoders 2 | 3 | [![Hugging Face](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg)](https://huggingface.co/stanfordmimi/MedVAE)    [![pypi](https://img.shields.io/pypi/v/medvae?style=for-the-badge)](https://pypi.org/project/medvae/)    [![arXiv](https://img.shields.io/badge/arXiv-2502.14753-b31b1b.svg?style=for-the-badge)](https://arxiv.org/abs/2502.14753)    [![Watch the Talk on YouTube](https://img.shields.io/badge/YouTube-Talk-red?style=for-the-badge&logo=youtube)](https://www.youtube.com/watch?v=5zoxHz71ZgY)    [![License](https://img.shields.io/github/license/stanfordmimi/medvae?style=for-the-badge)](LICENSE) 4 | 5 | This repository contains the official PyTorch implementation for [MedVAE: Efficient Automated Interpretation of Medical Images with Large-Scale Generalizable Autoencoders](https://arxiv.org/abs/2502.14753) (MIDL 2025; Best Oral Paper Award). 6 | 7 | ![Overview](documentation/assets/overview.png) 8 | 9 | ## 🫁 What is MedVAE? 10 | 11 | MedVAE is a family of six large-scale, generalizable 2D and 3D variational autoencoders (VAEs) designed for medical imaging. It is trained on over one million medical images across multiple anatomical regions and modalities. MedVAE autoencoders encode medical images as downsized latent representations and decode latent representations back to high-resolution images. Across diverse tasks obtained from 20 medical image datasets, we demonstrate that utilizing MedVAE latent representations in place of high-resolution images when training downstream models can lead to efficiency benefits (up to 70x improvement in throughput) while simultaneously preserving clinically-relevant features. 12 | 13 | ## ⚡️ Installation 14 | 15 | To install MedVAE, you can simply run: 16 | 17 | ```python 18 | pip install medvae 19 | ``` 20 | 21 | For an editable installation, use the following commands to clone and install this repository. 22 | 23 | ```python 24 | git clone https://github.com/StanfordMIMI/MedVAE.git 25 | cd MedVAE 26 | pip install -e .[dev] 27 | pre-commit install 28 | pre-commit 29 | ``` 30 | 31 | ## 🚀 Inference Instructions 32 | 33 | ```python 34 | import torch 35 | from medvae import MVAE 36 | 37 | fpath = "documentation/data/mmg_data/isJV8hQ2hhJsvEP5rdQNiy.png" 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | 40 | model = MVAE(model_name="medvae_4_3_2d", modality="xray").to(device) 41 | img = model.apply_transform(fpath).to(device) 42 | 43 | model.requires_grad_(False) 44 | model.eval() 45 | 46 | with torch.no_grad(): 47 | latent = model(img) 48 | ``` 49 | 50 | We also developed an easy-to-use CLI inference tool for compressing your high-dimensional medical images into usable latents: 51 | 52 | ```python 53 | medvae_inference -i INPUT_FOLDER -o OUTPUT_FOLDER -model_name MED_VAE_MODEL -modality MODALITY 54 | ``` 55 | 56 | For more information, please check our [inference documentation](/documentation/inference.md) and [demo](documentation/demo.ipynb). 57 | 58 | ## 🔧 Finetuning Instructions 59 | 60 | Easily finetune MedVAE on **your own dataset**! Follow the instructions below (requires Python 3.9 and cloning the repository). 
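The experiment configs look for your images under `medvae/data/<dataset_name>/`, with a `<dataset_name>.csv` split file next to that folder; the [create_csv notebook](documentation/create_csv.ipynb) generates the CSV for you. A rough sketch of the expected layout for the default `mmg_data` setup (file names are illustrative):

```
medvae/data/
├── mmg_data.csv        # columns: row_nr, image_uuid, split
└── mmg_data/
    ├── <image_uuid>.png
    └── ...
```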
61 | 62 | Run the following commands depending on your finetuning scenario: 63 | 64 | **Stage 1 (2D) Finetuning** 65 | 66 | ```bash 67 | medvae_finetune experiment=medvae_4x_1c_2d_finetuning 68 | ``` 69 | 70 | **Stage 2 (2D) Finetuning:** 71 | 72 | ```bash 73 | medvae_finetune_s2 experiment=medvae_4x_1c_2d_s2_finetuning 74 | ``` 75 | 76 | **Stage 2 (3D) Finetuning:** 77 | 78 | ```bash 79 | medvae_finetune experiment=medvae_4x_1c_3d_finetuning 80 | ``` 81 | 82 | This setup supports multi-GPU training and includes integration with Weights & Biases for experiment tracking. 83 | 84 | For detailed finetuning guidelines, see the [Finetuning Documentation](documentation/finetune.md). 85 | 86 | To create classification models using downsized latent representations, refer to the [Classification Documentation](documentation/classification.md). 87 | 88 | ## 📎 Citation 89 | 90 | If you find this repository useful for your work, please cite the following paper: 91 | 92 | ```bibtex 93 | @misc{varma2025medvaeefficientautomatedinterpretation, 94 | title={MedVAE: Efficient Automated Interpretation of Medical Images with Large-Scale Generalizable Autoencoders}, 95 | author={Maya Varma and Ashwin Kumar and Rogier van der Sluijs and Sophie Ostmeier and Louis Blankemeier and Pierre Chambon and Christian Bluethgen and Jip Prince and Curtis Langlotz and Akshay Chaudhari}, 96 | year={2025}, 97 | eprint={2502.14753}, 98 | archivePrefix={arXiv}, 99 | primaryClass={eess.IV}, 100 | url={https://arxiv.org/abs/2502.14753}, 101 | } 102 | ``` 103 | 104 | This repository is powered by [Hydra](https://github.com/facebookresearch/hydra) and [HuggingFace Accelerate](https://github.com/huggingface/accelerate). Our implementation of MedVAE is inspired by prior work on diffusion models from [CompVis](https://github.com/CompVis/latent-diffusion) and [Stability AI](https://github.com/Stability-AI/stablediffusion). 
105 | -------------------------------------------------------------------------------- /medvae/models/autoencoder_kl_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from medvae.utils.vae.diffusionmodels_3d import Decoder, Encoder 4 | from medvae.utils.vae.distributions import DiagonalGaussianDistribution 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | 8 | class AutoencoderKL(torch.nn.Module): 9 | def __init__( 10 | self, ddconfig, embed_dim, ckpt_path=None, ignore_keys=[], apply_channel_ds=True 11 | ): 12 | super().__init__() 13 | self.encoder = Encoder(**ddconfig) 14 | self.decoder = Decoder(**ddconfig) 15 | assert ddconfig["double_z"] 16 | self.quant_conv = torch.nn.Conv3d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 17 | self.post_quant_conv = torch.nn.Conv3d(embed_dim, ddconfig["z_channels"], 1) 18 | self.embed_dim = embed_dim 19 | 20 | self.apply_channel_ds = apply_channel_ds 21 | if self.apply_channel_ds: 22 | self.channel_ds = torch.nn.Sequential( 23 | torch.nn.Conv3d(self.embed_dim, 64, 1), 24 | torch.nn.ReLU(), 25 | torch.nn.Conv3d(64, 64, 3, padding="same"), 26 | torch.nn.ReLU(), 27 | torch.nn.Conv3d(64, self.embed_dim, 1), 28 | ) 29 | self.channel_proj = torch.nn.Conv3d(self.embed_dim, self.embed_dim, 1) 30 | 31 | if ckpt_path is not None: 32 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, state_dict=True) 33 | 34 | def init_from_ckpt(self, path, ignore_keys=list(), state_dict=False): 35 | if not state_dict: 36 | sd = torch.load(path, map_location="cpu") 37 | else: 38 | sd = torch.load(path, map_location="cpu")["state_dict"] 39 | keys = list(sd.keys()) 40 | for k in keys: 41 | for ik in ignore_keys: 42 | if k.startswith(ik): 43 | print(f"Deleting key {k} from state_dict.") 44 | del sd[k] 45 | 46 | model_state_dict = self.state_dict() 47 | missing = [] 48 | unexpected = [] 49 | for name, param in sd.items(): 50 | if name in model_state_dict: 51 | if param.size() == model_state_dict[name].size(): 52 | model_state_dict[name].copy_(param) 53 | else: 54 | # If the dimensions of the checkpoint param is one less than model param, then copy the checkpoint param across 55 | # middle dimension of model param 56 | if ( 57 | len(param.size()) == len(model_state_dict[name].size()) - 1 58 | and len(model_state_dict[name].size()) == 5 59 | ): 60 | ckpt_param = param.unsqueeze(4) 61 | ckpt_param_repeat = ckpt_param.repeat( 62 | 1, 1, 1, 1, model_state_dict[name].size(4) 63 | ) 64 | # Only save the weight to the middle kernel slice 65 | ckpt_param_repeat = torch.zeros_like(ckpt_param_repeat) 66 | ckpt_param_repeat[ 67 | :, :, :, :, model_state_dict[name].size(4) // 2 68 | ] = param 69 | 70 | model_state_dict[name].copy_(ckpt_param_repeat) 71 | else: 72 | missing.append(name) 73 | else: 74 | unexpected.append(name) 75 | 76 | # Load state 77 | self.load_state_dict(model_state_dict, strict=False) 78 | 79 | # missing, unexpected = self.load_state_dict(sd, strict=False) 80 | print( 81 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 82 | ) 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | print(f"Unexpected Keys: {unexpected}") 86 | elif len(unexpected) > 0: 87 | print(f"Unexpected Keys: {unexpected}") 88 | 89 | def encode(self, x): 90 | h = self.encoder(x) 91 | moments = self.quant_conv(h) 92 | posterior = DiagonalGaussianDistribution(moments) 93 | return posterior 94 | 95 | def decode(self, z): 96 | z = self.post_quant_conv(z) 97 | dec = 
self.decoder(z) 98 | return dec 99 | 100 | def compute_latent_proj(self, x, sample_posterior=True): 101 | posterior = self.encode(x) 102 | if sample_posterior: 103 | z = posterior.sample() 104 | else: 105 | z = posterior.mode() 106 | if self.apply_channel_ds: 107 | return z, posterior, self.channel_proj(self.channel_ds(z) + z) 108 | 109 | return z, posterior, None 110 | 111 | def forward(self, input, sample_posterior=True, decode=True): 112 | posterior = checkpoint(self.encode, input, use_reentrant=False) 113 | if sample_posterior: 114 | z = posterior.sample() 115 | else: 116 | z = posterior.mode() 117 | latent = checkpoint( 118 | self.channel_proj, self.channel_ds(z) + z, use_reentrant=False 119 | ) 120 | if decode: 121 | dec = checkpoint(self.decode, z, use_reentrant=False) 122 | return dec, posterior, z 123 | else: 124 | return z, posterior, latent 125 | 126 | def get_last_layer(self): 127 | return self.decoder.conv_out.weight 128 | -------------------------------------------------------------------------------- /medvae/utils/vae/train_components_stage2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from torch.optim import Optimizer 8 | from rich import print 9 | from torch.utils.data import DataLoader 10 | from torchmetrics import Metric 11 | from medvae.utils.extras import ( 12 | get_weight_dtype, 13 | ) 14 | from medvae.utils.transforms import to_dict 15 | 16 | __all__ = ["training_epoch", "validation_epoch"] 17 | 18 | def training_epoch( 19 | epoch: int, 20 | global_step: int, 21 | accelerator: Accelerator, 22 | dataloader: DataLoader, 23 | model: torch.nn.Module, 24 | criterion: torch.nn.Module, 25 | default_metrics: List[Metric], 26 | rec_metrics: List[Metric], 27 | opt: Optimizer, 28 | options: Dict[str, Any], 29 | ): 30 | """Train a single epoch of a VAE model.""" 31 | for metric in default_metrics: 32 | metric.reset() 33 | 34 | for _, metric in rec_metrics: 35 | metric.reset() 36 | 37 | ( 38 | metric_bc_loss, 39 | metric_data, 40 | metric_batch, 41 | ) = default_metrics 42 | 43 | model.train() 44 | epoch_start, batch_start = time(), time() 45 | dtype = get_weight_dtype(accelerator) 46 | for i, batch in enumerate(dataloader): 47 | data_time = time() - batch_start 48 | 49 | batch = to_dict(batch) 50 | images = batch["img"].to(dtype) 51 | _, _, latent = model(images, decode=False) 52 | 53 | loss = criterion(images, latent=latent) 54 | loss = loss.sum() / images.shape[0] 55 | 56 | opt.zero_grad() 57 | accelerator.backward(loss) 58 | opt.step() 59 | 60 | # Update metrics 61 | batch_time = time() - batch_start 62 | metric_data.update(data_time) 63 | metric_batch.update(batch_time) 64 | metric_bc_loss.update(loss) 65 | 66 | # Logging values 67 | print( 68 | f"\r[Epoch <{epoch:03}/{options['max_epoch']}>: Step <{i:03}/{len(dataloader)}>] - " 69 | + f"Data(s): {data_time:.3f} ({metric_data.compute():.3f}) - " 70 | + f"Batch(s): {batch_time:.3f} ({metric_batch.compute():.3f}) - " 71 | + f"BC Loss: {loss.item():.3f} ({metric_bc_loss.compute():.3f}) - " 72 | + f"{(((time() - epoch_start) / (i + 1)) * (len(dataloader) - i)) / 60:.2f} m remaining\n" 73 | ) 74 | 75 | if options["is_logging"] and i % options["log_every_n_steps"] == 0: 76 | log_data = { 77 | "epoch": epoch, 78 | "mean_data": metric_data.compute(), 79 | "mean_batch": metric_batch.compute(), 80 | "step": i, 81 | "step_global": global_step, 82 | 
"step_data": data_time, 83 | "step_batch": batch_time, 84 | } 85 | log_data["mean_bc_loss"] = metric_bc_loss.compute() 86 | 87 | for name, metric in rec_metrics: 88 | log_data[name] = metric.compute() 89 | 90 | accelerator.log(log_data) 91 | global_step += 1 92 | 93 | if global_step % options["ckpt_every_n_steps"] == 0: 94 | try: 95 | accelerator.save_state(os.path.join(options["ckpt_dir"], f"step_{global_step}.pt")) 96 | except Exception as e: 97 | print(e) 98 | 99 | batch_start = time() 100 | 101 | if options["fast_dev_run"]: 102 | break 103 | 104 | return global_step 105 | 106 | 107 | def validation_epoch( 108 | options: Dict[str, Any], 109 | epoch: int, 110 | accelerator: Accelerator, 111 | dataloader: DataLoader, 112 | model: torch.nn.Module, 113 | criterion: torch.nn.Module, 114 | default_metrics: List[Metric], 115 | rec_metrics: List[Metric], 116 | global_step: int = 0, 117 | postfix: str = "", 118 | ): 119 | """Validate one epoch for the VAE model.""" 120 | for metric in default_metrics: 121 | metric.reset() 122 | 123 | for _, metric in rec_metrics: 124 | metric.reset() 125 | 126 | metric_bcloss = default_metrics[0] 127 | metric_bcloss.reset() 128 | 129 | model.eval() 130 | criterion.eval() 131 | epoch_start = time() 132 | dtype = get_weight_dtype(accelerator) 133 | with torch.no_grad(): 134 | for i, batch in enumerate(dataloader): 135 | batch = to_dict(batch) 136 | 137 | images = batch["img"].to(dtype) 138 | _, _, latent = model(images, decode=False) 139 | 140 | loss = criterion(images, latent=latent) 141 | loss = loss.sum() / images.shape[0] 142 | 143 | metric_bcloss.update(loss) 144 | 145 | # Logging values 146 | print( 147 | f"\r Validation{postfix}: " 148 | + f"\r[Epoch <{epoch:03}/{options['max_epoch']}>: Step <{i:03}/{len(dataloader)}>] - " 149 | + f"BC Loss: {loss.item():.3f} ({metric_bcloss.compute():.3f}) - " 150 | + f"{(((time() - epoch_start) / (i + 1)) * (len(dataloader) - i)) / 60:.2f} m remaining\n" 151 | ) 152 | 153 | if options["fast_dev_run"]: 154 | break 155 | 156 | if options["is_logging"]: 157 | log_data = { 158 | f"valid{postfix}/epoch": epoch, 159 | f"valid{postfix}/mean_bcloss": metric_bcloss.compute(), 160 | } 161 | for name, metric in rec_metrics: 162 | log_data[f"valid{postfix}/{name}"] = metric.compute() 163 | 164 | accelerator.log(log_data) 165 | 166 | return metric_bcloss.compute() -------------------------------------------------------------------------------- /medvae/medvae_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from genericpath import isdir 3 | from medvae.utils.extras import create_directory, cite_function 4 | import torch 5 | import nibabel as nib 6 | import os 7 | from tqdm import tqdm 8 | import numpy as np 9 | from os.path import join as pjoin 10 | from medvae import MVAE 11 | 12 | def parse_arguments(): 13 | parser = argparse.ArgumentParser(description='Use this to run inference with MedVAE. This function is used when ' 14 | 'you want to manually specify a folder containing an pretrained MedVAE ' 15 | 'model. ', 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | 18 | parser.add_argument('-i', type=str, required=True, 19 | help='input folder. These should contain files that you want the latent to be processed for' 20 | 'Remember *.pngs are for 2D images and *.nii.gz are for 3D images. ' 21 | 'The filename should not have a "." in it apart from suffix') 22 | 23 | parser.add_argument('-o', type=str, required=True, 24 | help='Output folder. 
If it does not exist, it will be created. Predicted latents will ' 25 | 'have the same name as their source images.') 26 | parser.add_argument( 27 | '-model_name', type=str, required=True, 28 | help=( 29 | "There are six MedVAE models that can be used for inference. Choose between:\n" 30 | "(1) medvae_4_1_2d: 2D images with a 4x compression in each dim (16x total) with a 1 channel latent.\n" 31 | "(2) medvae_4_3_2d: 2D images with a 4x compression in each dim (16x total) with a 3 channel latent.\n" 32 | "(3) medvae_8_1_2d: 2D images with an 8x compression in each dim (64x total) with a 1 channel latent.\n" 33 | "(4) medvae_8_4_2d: 2D images with an 8x compression in each dim (64x total) with a 4 channel latent.\n" 34 | "(5) medvae_4_1_3d: 3D images with a 4x compression in each dim (64x total) with a 1 channel latent.\n" 35 | "(6) medvae_8_1_3d: 3D images with an 8x compression in each dim (512x total) with a 1 channel latent.\n" 36 | ) 37 | ) 38 | 39 | parser.add_argument('-modality', type=str, required=True, 40 | help='Modality of the input images. Choose between xray, ct, or mri.') 41 | 42 | parser.add_argument('-ckpt_path', type=str, required=False, 43 | help='Path to the checkpoint file. If provided, the model will be loaded from the weights in this file. ' + 44 | 'Note: This should be a ckpt from stage 2 2D or 3D finetuning. If you want stage 1, then modifications need to be made.') 45 | 46 | parser.add_argument('-roi_size', type=int, default=160, required=False, 47 | help='Region of interest size for 3D models. This is the maximum dimension size allowed for processing on the GPU.') 48 | 49 | parser.add_argument('-device', type=str, default='cuda', required=False, 50 | help="Use this to set the device the inference should run with. Available options are 'cuda' " 51 | "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " 52 | "Use CUDA_VISIBLE_DEVICES=X medvae_inference [...] instead!") 53 | 54 | # Print a message to cite the medvae paper 55 | cite_function() 56 | 57 | args, unknownargs = parser.parse_known_args() 58 | if unknownargs: 59 | print(f"Ignoring arguments: {unknownargs}") 60 | 61 | # Check if input folder exists 62 | assert isdir(args.i), f"Input folder {args.i} does not exist." 63 | 64 | # Create output directory if it does not exist 65 | create_directory(args.o) 66 | 67 | assert args.device in ['cpu', 'cuda', 68 | 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' 
69 | 70 | if args.device == 'cpu': 71 | # let's allow torch to use lots of threads 72 | import multiprocessing 73 | torch.set_num_threads(multiprocessing.cpu_count()) 74 | device = torch.device('cpu') 75 | elif args.device == 'cuda': 76 | # multithreading in torch doesn't help medvae if run on GPU 77 | torch.set_num_threads(1) 78 | torch.set_num_interop_threads(1) 79 | device = torch.device('cuda') 80 | else: 81 | device = torch.device('mps') 82 | 83 | return args, device 84 | 85 | def __init__(): 86 | 87 | args, device = parse_arguments() 88 | 89 | # Build the model and transform 90 | model = MVAE(args.model_name, args.modality, args.roi_size).to(device) 91 | 92 | # If a checkpoint path is provided, load the model from the weight in this file 93 | if args.ckpt_path: 94 | model.init_from_ckpt(args.ckpt_path, state_dict=False) 95 | 96 | model.requires_grad_(False) 97 | model.eval() 98 | 99 | # Run inference on the input folder 100 | print("Running inference at {}".format(args.i)) 101 | for fpath in tqdm(os.listdir(args.i), total = len(os.listdir(args.i))): 102 | 103 | img = model.apply_transform(pjoin(args.i, fpath)).to(device) 104 | latent = model(img).detach().cpu().numpy() 105 | 106 | # Save the latent 107 | nib.save(nib.Nifti1Image(latent, np.eye(4)), pjoin(args.o, fpath.split('.')[0] + ".nii.gz")) 108 | 109 | print("Inference complete! Output saved at {}".format(args.o)) 110 | 111 | def main(): 112 | __init__() -------------------------------------------------------------------------------- /documentation/demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from medvae import MVAE\n", 10 | "import torch\n", 11 | "from nilearn.plotting import view_img\n", 12 | "import numpy as np\n", 13 | "import nibabel as nib" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "## Download Sample Data\n", 21 | "\n", 22 | "If you have trouble download from huggingface, you can delete local cache 'rm -rf .cache' and restart your jupter kernel." 
23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from huggingface_hub import snapshot_download\n", 32 | "import shutil\n", 33 | "\n", 34 | "repo_id = \"stanfordmimi/MedVAE\"\n", 35 | "\n", 36 | "# Download the example_data directory\n", 37 | "local_dir = snapshot_download(\n", 38 | " repo_id=repo_id,\n", 39 | " allow_patterns=[\"example_data/*\"], # Only download files in example_data folder\n", 40 | " local_dir=\"./\", # Save to data directory\n", 41 | " max_workers=1,\n", 42 | " etag_timeout=10000,\n", 43 | " force_download=True,\n", 44 | ")\n", 45 | "\n", 46 | "# Rename the example_data directory to data\n", 47 | "shutil.move(\"example_data\", \"data\")\n", 48 | "\n", 49 | "# Remove the .cache directory\n", 50 | "shutil.rmtree(\".cache\")\n", 51 | "\n", 52 | "print(\"Download completed\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Example with 2D MedVAE (f=16; C=3)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 69 | "\n", 70 | "model = MVAE(\n", 71 | " model_name='medvae_4_3_2d',\n", 72 | " modality='xray',\n", 73 | ").to(device)\n", 74 | "model.requires_grad_(False)\n", 75 | "model.eval()\n", 76 | "\n", 77 | "fpath = 'data/mmg_data/TQcBVJediTG8E34ftHnapA.png'\n", 78 | "\n", 79 | "# Getting the transform and applying it\n", 80 | "transform = model.get_transform()\n", 81 | "img = transform(fpath).unsqueeze(0).to(device)\n", 82 | "\n", 83 | "# Getting the latent representation\n", 84 | "with torch.no_grad():\n", 85 | " latent = model(img).cpu().detach().numpy()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "view_img(stat_map_img=nib.Nifti1Image(latent.transpose(1, 2, 0), np.eye(4)), bg_img=False, cmap='gray')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## CT Example with 3D MedVAE (f=64; C=1)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 111 | "\n", 112 | "model = MVAE(\n", 113 | " model_name='medvae_4_1_3d',\n", 114 | " modality='CT',\n", 115 | ").to(device)\n", 116 | "model.requires_grad_(False)\n", 117 | "model.eval()\n", 118 | "\n", 119 | "fpath = 'data/ct_data/sino_7858_0398.nii.gz'\n", 120 | "\n", 121 | "# Apply the model transform -- easiest way\n", 122 | "img = model.apply_transform(fpath).to(device)\n", 123 | "\n", 124 | "# Getting the latent representation\n", 125 | "with torch.no_grad():\n", 126 | " latent = model(img).cpu().detach().numpy()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "view_img(stat_map_img=nib.Nifti1Image(latent, np.eye(4)), bg_img=False, cmap='gray')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## MRI Example with 3D MedVAE (f=64; C=1)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 
'cpu')\n", 152 | "\n", 153 | "model = MVAE(\n", 154 | " model_name='medvae_4_1_3d',\n", 155 | " modality='MRI',\n", 156 | ").to(device)\n", 157 | "model.requires_grad_(False)\n", 158 | "model.eval()\n", 159 | "\n", 160 | "fpath = 'data/mri_data/t1oasis_case_1286.nii.gz'\n", 161 | "\n", 162 | "# Apply the model transform -- easiest way\n", 163 | "img = model.apply_transform(fpath).to(device)\n", 164 | "\n", 165 | "# Getting the latent representation\n", 166 | "with torch.no_grad():\n", 167 | " latent = model(img).cpu().detach().numpy()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "view_img(stat_map_img=nib.Nifti1Image(latent, np.eye(4)), bg_img=False, cmap='gray')" 177 | ] 178 | } 179 | ], 180 | "metadata": { 181 | "kernelspec": { 182 | "display_name": "compress", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 3 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython3", 196 | "version": "3.10.14" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 2 201 | } 202 | -------------------------------------------------------------------------------- /medvae/dataloaders/generic_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, List, Union 3 | 4 | import polars as pl 5 | from torch.utils.data import Dataset 6 | 7 | __all__ = ["GenericDataset"] 8 | 9 | 10 | class GenericDataset(Dataset): 11 | def __init__( 12 | self, 13 | split_path: Union[os.PathLike, str], 14 | data_dir: Union[os.PathLike, str], # DEPRECATED 15 | dataset_id: int, 16 | *, 17 | split_column: str = None, 18 | split_name: str = None, 19 | img_column: str = None, 20 | img_dir: str = None, 21 | img_suffix: str = None, 22 | img_transform: Callable = None, 23 | lbl_columns: List[str] = None, 24 | lbl_transform: Callable = None, 25 | msk_column: str = None, 26 | msk_dir: str = None, 27 | msk_suffix: str = None, 28 | msk_transform: Callable = None, 29 | txt_column: str = None, 30 | txt_dir: str = None, 31 | txt_suffix: str = None, 32 | txt_transform: Callable = None, 33 | com_transform: Callable = None, 34 | **kwargs, 35 | ): 36 | """A Generic Dataset implementation. 37 | 38 | The generic dataset can be used to create any dataset from a given CSV "split" and a data 39 | directory. The dataset aims to be as flexible as possible, allowing the user to specify 40 | the columns to use for images, labels, and text. The user can also specify the transforms 41 | to apply to each of these components. 42 | 43 | Args: 44 | split_path (Union[os.PathLike, str]): Path to the split file. 45 | data_dir (Union[os.PathLike, str]): Path to the data directory. 46 | img_column (str, optional): Image column. Defaults to "image_uuid". 47 | img_suffix (str, optional): Image suffix. Defaults to ".npy". 48 | img_transform (Callable, optional): Image transform. Defaults to None. 49 | lbl_columns (List[str], optional): Label columns. Defaults to None. 50 | lbl_transform (Callable, optional): Label transform. Defaults to None. 51 | txt_column (str, optional): Text column. Defaults to None. 52 | txt_transform (Callable, optional): Text transform. Defaults to None. 53 | com_transform (Callable, optional): Composite transform. Defaults to None. 
54 | dataset_id (int, optional): Dataset ID. Defaults to None. 55 | """ 56 | 57 | # Store arguments 58 | self.split_path = split_path 59 | self.split_column = split_column 60 | self.split_name = split_name 61 | self.dataset_id = dataset_id 62 | self.img_dir = img_dir if img_dir is not None else data_dir 63 | self.img_column = img_column 64 | self.img_suffix = img_suffix 65 | self.img_transform = img_transform 66 | self.lbl_columns = lbl_columns 67 | self.lbl_transform = lbl_transform 68 | self.msk_column = msk_column 69 | self.msk_dir = msk_dir 70 | self.msk_suffix = msk_suffix 71 | self.msk_transform = msk_transform 72 | self.txt_column = txt_column 73 | self.txt_dir = txt_dir 74 | self.txt_suffix = txt_suffix 75 | self.txt_transform = txt_transform 76 | self.com_transform = com_transform 77 | self.kwargs = kwargs 78 | 79 | self.samples = {} 80 | 81 | # Create the samples for the images, labels, and text 82 | if not os.path.exists(split_path): 83 | raise ValueError(f"Split path {split_path} does not exist.") 84 | 85 | self.df = pl.read_csv(split_path) 86 | if isinstance(split_column, str) and isinstance(split_name, str): 87 | self.df = self.df.filter(pl.col(split_column) == split_name) 88 | 89 | # Generate image paths 90 | if img_column is not None: 91 | self.samples["img"] = ( 92 | self.df.get_column(img_column) 93 | .apply(lambda x: os.path.join(self.img_dir, f"{x}{img_suffix or ''}")) 94 | .to_list() 95 | ) 96 | 97 | # Extract the columns with labels 98 | if lbl_columns is not None: 99 | self.samples["lbl"] = self.df.select(lbl_columns) 100 | 101 | # Extract the column with text or a path to a text file 102 | if txt_column is not None: 103 | self.samples["txt"] = self.df.get_column(txt_column).to_list() 104 | 105 | # Extract the column with masks 106 | if msk_column is not None: 107 | self.samples["msk"] = ( 108 | self.df.get_column(msk_column) 109 | .apply(lambda x: os.path.join(self.msk_dir, f"{x}{msk_suffix or ''}")) 110 | .to_list() 111 | ) 112 | 113 | self.print_stats() 114 | 115 | def __getitem__(self, idx: int): 116 | """Return a dictionary with the requested sample.""" 117 | sample = {"group_id": self.dataset_id} 118 | 119 | # Image 120 | if "img" in self.samples: 121 | sample["img"] = self.samples["img"][idx] 122 | if callable(self.img_transform): 123 | sample["img"] = self.img_transform(sample["img"]) 124 | 125 | # Labels 126 | if "lbl" in self.samples: 127 | sample["lbl"] = self.samples["lbl"][idx] 128 | if callable(self.lbl_transform): 129 | sample["lbl"] = self.lbl_transform(sample["lbl"]) 130 | 131 | # Mask 132 | if "msk" in self.samples: 133 | sample["msk"] = self.samples["msk"][idx] 134 | if callable(self.msk_transform): 135 | sample["msk"] = self.msk_transform(sample["msk"]) 136 | 137 | # Text 138 | if "txt" in self.samples: 139 | sample["txt"] = self.samples["txt"][idx] 140 | if callable(self.txt_transform): 141 | sample["txt"] = self.txt_transform(sample["txt"]) 142 | 143 | # Common transform applied on sample level 144 | if callable(self.com_transform): 145 | return self.com_transform(sample) 146 | 147 | return sample 148 | 149 | def __len__(self): 150 | return len(self.df) 151 | 152 | def print_stats(self): 153 | print( 154 | f""" 155 | === Dataset stats for split={self.split_name or "full"} === 156 | CSV file: {self.split_path} 157 | Data directory: {self.img_dir} 158 | Number of samples: {len(self.df)} 159 | """ 160 | ) 161 | 162 | def get_labels(self): 163 | return self.samples["lbl"].to_numpy().flatten() 164 | 
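A short usage note on `GenericDataset` above: the docstring describes how images, labels, masks, and text are assembled from a CSV split file, and the minimal sketch below makes that concrete. The CSV path, image directory, and column names are hypothetical placeholders rather than values shipped with the repository.

```python
# Hypothetical usage sketch for GenericDataset -- the CSV path, image directory,
# and column names are placeholders and must match your own split file.
from medvae.dataloaders import GenericDataset

train_set = GenericDataset(
    split_path="data/my_dataset_split.csv",  # CSV with one row per image
    data_dir="data/my_dataset",              # directory holding the image files
    dataset_id=0,                            # arbitrary tag returned as sample["group_id"]
    split_column="split",                    # column containing "train"/"val"/"test"
    split_name="train",                      # keep only the training rows
    img_column="image_uuid",                 # column with image file stems
    img_suffix=".png",                       # appended to each stem to form the path
    lbl_columns=["label"],                   # label column(s) returned under sample["lbl"]
)

sample = train_set[0]  # dict with "img", "lbl", and "group_id"
print(len(train_set), sample["group_id"])
```

With `img_transform` left unset, `sample["img"]` is simply the constructed file path; in the finetuning configs a loader (e.g., from `medvae/utils/loaders.py`) is expected to be passed here so that the image is actually read and normalized.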
-------------------------------------------------------------------------------- /medvae/utils/factory.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from medvae.models import AutoencoderKL_2D, AutoencoderKL_3D 3 | from omegaconf import OmegaConf 4 | from medvae.utils.lora import inject_trainable_lora_extended 5 | from medvae.utils.loaders import load_mri_3d, load_ct_3d, load_2d 6 | 7 | HF_REPO_PATH = "stanfordmimi/MedVAE" 8 | 9 | FILE_DICT_ASSOCIATIONS = { 10 | "medvae_4_1_2d": { 11 | "config": "model_weights/medvae_4x1.yaml", 12 | "ckpt": "model_weights/vae_4x_1c_2D.ckpt", 13 | }, 14 | "medvae_4_3_2d": { 15 | "config": "model_weights/medvae_4x3.yaml", 16 | "ckpt": "model_weights/vae_4x_3c_2D.ckpt", 17 | }, 18 | "medvae_4_4_2d": { 19 | "config": "model_weights/medvae_4x4.yaml", 20 | "ckpt": "model_weights/vae_4x_4c_2D.ckpt", 21 | }, 22 | "medvae_8_1_2d": { 23 | "config": "model_weights/medvae_8x1.yaml", 24 | "ckpt": "model_weights/vae_8x_1c_2D.ckpt", 25 | }, 26 | "medvae_8_4_2d": { 27 | "config": "model_weights/medvae_8x4.yaml", 28 | "ckpt": "model_weights/vae_8x_4c_2D.ckpt", 29 | }, 30 | "medvae_4_1_3d": { 31 | "config": "model_weights/medvae_4x1.yaml", 32 | "ckpt": "model_weights/vae_4x_1c_3D.ckpt", 33 | }, 34 | "medvae_8_1_3d": { 35 | "config": "model_weights/medvae_8x1.yaml", 36 | "ckpt": "model_weights/vae_8x_1c_3D.ckpt", 37 | }, 38 | } 39 | 40 | """ 41 | Download model weights from Hugging Face Hub 42 | """ 43 | 44 | 45 | def download_model_weights(hfpath): 46 | fpath = hf_hub_download(repo_id=HF_REPO_PATH, filename=hfpath) 47 | return fpath 48 | 49 | 50 | """ 51 | Build the Med-VAE models for inference using the model weights 52 | """ 53 | 54 | 55 | def build_model( 56 | model_name: str, 57 | config_fpath: str, 58 | ckpt_fpath: str, 59 | training: bool = False, 60 | existing_weight: str = None, 61 | state_dict: bool = True, 62 | ): 63 | if ( 64 | model_name == "medvae_4_1_2d" 65 | or model_name == "medvae_8_1_2d" 66 | or model_name == "medvae_4_4_2d" 67 | ): 68 | conf = OmegaConf.load(config_fpath) 69 | model = AutoencoderKL_2D( 70 | ddconfig=conf.ddconfig, 71 | embed_dim=conf.embed_dim, 72 | ckpt_path=ckpt_fpath if existing_weight is None else existing_weight, 73 | state_dict=state_dict, 74 | ) 75 | elif model_name == "medvae_4_3_2d" or model_name == "medvae_8_4_2d": 76 | conf = OmegaConf.load(config_fpath) 77 | model = AutoencoderKL_2D( 78 | ddconfig=conf.model.params.ddconfig, 79 | embed_dim=conf.model.params.embed_dim, 80 | ) 81 | # If training, freeze the encoder and decoder and inject the lora 82 | if training: 83 | print( 84 | "Trainable Params before LORA:", 85 | sum(p.numel() for p in model.parameters() if p.requires_grad), 86 | ) 87 | model.encoder.requires_grad_(False) 88 | model.decoder.requires_grad_(False) 89 | _, _ = inject_trainable_lora_extended( 90 | model, {"ResnetBlock", "AttnBlock"}, r=4 91 | ) 92 | print( 93 | "Trainable Params after LORA:", 94 | sum(p.numel() for p in model.parameters() if p.requires_grad), 95 | ) 96 | else: 97 | _, _ = inject_trainable_lora_extended( 98 | model, {"ResnetBlock", "AttnBlock"}, r=4 99 | ) 100 | 101 | model.init_from_ckpt( 102 | ckpt_fpath if existing_weight is None else existing_weight, 103 | state_dict=state_dict, 104 | ) 105 | elif model_name == "medvae_4_1_3d" or model_name == "medvae_8_1_3d": 106 | conf = OmegaConf.load(config_fpath) 107 | model = AutoencoderKL_3D( 108 | ddconfig=conf.ddconfig, 109 | embed_dim=conf.embed_dim, 110 | ) 
111 | model.init_from_ckpt( 112 | ckpt_fpath if existing_weight is None else existing_weight, 113 | state_dict=state_dict, 114 | ) 115 | 116 | return model 117 | 118 | 119 | """ 120 | Build the transform for the model 121 | """ 122 | 123 | 124 | def build_transform(model_name: str, modality: str): 125 | if "3d" in model_name: 126 | if "ct" in modality.lower(): 127 | transform = load_ct_3d 128 | elif "mri" in modality.lower(): 129 | transform = load_mri_3d 130 | else: 131 | raise ValueError(f"Modality {modality} not supported for 3D models") 132 | elif "2d" in model_name: 133 | transform = load_2d 134 | else: 135 | raise ValueError( 136 | f"Model name {model_name} not supported. Needs to be a 2D or 3D model." 137 | ) 138 | 139 | return transform 140 | 141 | 142 | """ 143 | Create a model and transform from a model name 144 | """ 145 | 146 | 147 | def create_model_and_transform( 148 | model_name: str, 149 | modality: str, 150 | ): 151 | # Check if model_name is in FILE_DICT_ASSOCIATIONS 152 | if model_name not in FILE_DICT_ASSOCIATIONS: 153 | raise ValueError(f"Model name {model_name} not found in FILE_DICT_ASSOCIATIONS") 154 | 155 | # Download the model_weights 156 | config_fpath = download_model_weights(FILE_DICT_ASSOCIATIONS[model_name]["config"]) 157 | ckpt_fpath = download_model_weights(FILE_DICT_ASSOCIATIONS[model_name]["ckpt"]) 158 | 159 | # Build the model 160 | model = build_model(model_name, config_fpath, ckpt_fpath) 161 | 162 | # Get the transform 163 | transform = build_transform(model_name, modality) 164 | 165 | return model, transform 166 | 167 | 168 | """ 169 | Create a model from a model name 170 | """ 171 | 172 | 173 | def create_model( 174 | model_name: str, 175 | existing_weight: str = None, 176 | training: bool = True, 177 | state_dict: bool = True, 178 | ): 179 | # Check if model_name is in FILE_DICT_ASSOCIATIONS 180 | if model_name not in FILE_DICT_ASSOCIATIONS: 181 | raise ValueError(f"Model name {model_name} not found in FILE_DICT_ASSOCIATIONS") 182 | 183 | # Download the model_weights 184 | config_fpath = download_model_weights(FILE_DICT_ASSOCIATIONS[model_name]["config"]) 185 | ckpt_fpath = download_model_weights(FILE_DICT_ASSOCIATIONS[model_name]["ckpt"]) 186 | 187 | # Build the model 188 | model = build_model( 189 | model_name, 190 | config_fpath, 191 | ckpt_fpath, 192 | training=training, 193 | existing_weight=existing_weight, 194 | state_dict=state_dict, 195 | ) 196 | 197 | return model 198 | -------------------------------------------------------------------------------- /documentation/finetune.md: -------------------------------------------------------------------------------- 1 | # 🔧 MedVAE Finetuning Documentation 2 | 3 | Our finetuning framework leverages the flexibility and power of [Hydra](https://github.com/facebookresearch/hydra) and [HuggingFace Accelerate](https://github.com/huggingface/accelerate), providing multi-GPU training support and easy integration with [Weights and Biases (wandb)](https://wandb.ai/) for experiment tracking. We recommend using Python 3.9 and installing packages specified in our `pyproject.toml` file. 4 | 5 | While MedVAE is primarily tested and validated on medical imaging modalities like X-ray, MRI, and CT, our framework can potentially support finetuning on datasets from other imaging modalities. However, please note that performance has not been validated outside these modalities. 
6 | 7 | ## Configuration Structure (Hydra) 8 | 9 | Through Hydra, we can modify our finetuning hyperparameters via [config files](../configs/). Our configuration files for finetuning are organized into three main directories: 10 | 11 | - **Criterion:** Contains loss function configurations for various training stages. Notable losses include: 12 | 13 | - `lpips_with_discriminator`: Used for 2D stage 1 and 3D stage 2 finetuning. 14 | - `biomedclip`: Used specifically for 2D stage 2 finetuning. 15 | 16 | - **Dataloader:** Includes preconfigured data loaders for different imaging types: 17 | 18 | - `mmgs.yaml`: Loads 2D Full-Field Digital Mammograms (FFDMs). 19 | - `mri_ct_3d.yaml`: Loads 3D MRI and CT imaging data. 20 | 21 | - **Experiment:** Centralizes and abstracts hyperparameters, allowing easy customization of your finetuning process. You will mainly need to change parameters in your experiment file for different finetuning runs. 22 | 23 | ## Running Finetuning Stages 24 | 25 | We provide example configuration files for both 2D and 3D image finetuning using a 4x downscaled latent representation. Adjust the GPU identifier (`CUDA_VISIBLE_DEVICES`) and batch size according to your hardware capabilities and dataset size. Typically, larger batch sizes are recommended. 26 | 27 | ### Stage 1 (2D Finetuning) 28 | 29 | Finetune the base model using 4x downsizing with either 1-channel or 3-channel latent representations. Key parameters to update in the experiment configuration file include `dataloader`, `dataset_name`, and `task_name`. 30 | 31 | - **1-channel latent:** 32 | 33 | ```bash 34 | CUDA_VISIBLE_DEVICES=0 medvae_finetune experiment=medvae_4x_1c_2d_finetuning 35 | ``` 36 | 37 | - **3-channel latent (using LoRA):** 38 | 39 | ```bash 40 | CUDA_VISIBLE_DEVICES=0 medvae_finetune experiment=medvae_4x_3c_2d_finetuning 41 | ``` 42 | 43 | ### Stage 2 (2D Finetuning) 44 | 45 | Stage 2 involves training a lightweight projection layer to enrich latent representations for downstream tasks. Ensure the `stage2_ckpt` parameter in the experiment file points to your stage 1 finetuning checkpoint. 46 | 47 | - **1-channel latent:** 48 | 49 | ```bash 50 | CUDA_VISIBLE_DEVICES=0 medvae_finetune_s2 experiment=medvae_4x_1c_2d_s2_finetuning 51 | ``` 52 | 53 | - **3-channel latent (using LoRA):** 54 | 55 | ```bash 56 | CUDA_VISIBLE_DEVICES=0 medvae_finetune_s2 experiment=medvae_4x_3c_2d_s2_finetuning 57 | ``` 58 | 59 | ### Stage 2 (3D Finetuning) 60 | 61 | Directly finetune 3D latent representations. Similar to stage 1, ensure appropriate parameters (`dataloader`, `dataset_name`, `task_name`) are correctly configured. 62 | 63 | ```bash 64 | CUDA_VISIBLE_DEVICES=0 medvae_finetune experiment=medvae_4x_1c_3d_finetuning 65 | ``` 66 | 67 | ## Multi-GPU Training 68 | 69 | Multi-GPU training is seamlessly supported through Accelerate. Configure Accelerate appropriately, then specify multiple GPUs as shown below: 70 | 71 | ```bash 72 | CUDA_VISIBLE_DEVICES=1,2,3,4 medvae_finetune experiment=medvae_4x_1c_2d_finetuning 73 | ``` 74 | 75 | ## CSV File Creation 76 | 77 | For your data, you will need to create a CSV file with the appropriate train, val, and test splits. Please see [create_csv.ipynb](create_csv.ipynb) to assist with this task; a minimal sketch also follows below. 
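As a rough illustration of such a split file, the sketch below builds one with polars (the same library `GenericDataset` uses to read splits). The directory layout, column names, and split fractions are assumptions made for illustration; adapt them to whatever your dataloader configuration expects.

```python
# Illustrative sketch only: the image directory, column names, and split
# fractions are assumptions -- align them with your dataloader configuration.
import os
import random

import polars as pl

image_dir = "data/my_dataset"  # hypothetical folder of *.png training images
stems = sorted(os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith(".png"))

random.seed(0)
random.shuffle(stems)

n = len(stems)
n_train, n_val = int(0.8 * n), int(0.1 * n)
splits = ["train"] * n_train + ["val"] * n_val + ["test"] * (n - n_train - n_val)

df = pl.DataFrame({"image_uuid": stems, "split": splits})
df.write_csv("data/my_dataset_split.csv")
```

The resulting CSV can then be referenced from your dataloader configuration (e.g., as the dataset's `split_path`).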
78 | 79 | ## Symbolic Links for Data 80 | 81 | If you prefer not to modify the dataloader configuration, you can symbolically link your dataset to the default data directory: 82 | 83 | ```bash 84 | ln -s YOUR_DATASET_PATH /medvae/data 85 | ``` 86 | 87 | ## Creating Custom Data Loaders 88 | 89 | For maximum flexibility, we recommend creating your own data loading method. You will need to change the loader in the dataloader configuration file. Refer to the [`loaders.py`](../medvae/utils/loaders.py) file to see examples: 90 | 91 | - `load_2d_finetune` 92 | - `load_mri_3d_finetune` 93 | - `load_ct_3d_finetune` 94 | 95 | Use these functions as templates for developing loaders tailored to your specific dataset structure and requirements. 96 | 97 | ## Logging with Weights & Biases 98 | 99 | To log training runs with wandb, ensure your wandb API key is set up. Enable logging as follows: 100 | 101 | ```bash 102 | CUDA_VISIBLE_DEVICES=0 medvae_finetune experiment=medvae_4x_1c_2d_finetuning logger=wandb 103 | ``` 104 | 105 | Note: We recommend using `wandb` version `0.14.0`. We have noticed logging errors with other versions in our setup. 106 | 107 | ## Inference Post-Finetuning 108 | 109 | Use our built-in inference engine to perform inference on your finetuned models: 110 | 111 | ```bash 112 | medvae_inference -i INPUT_FOLDER -o OUTPUT_FOLDER -model_name MED_VAE_MODEL -modality MODALITY -ckpt_path YOUR_CKPT_PATH 113 | ``` 114 | 115 | ## Troubleshooting Tips 116 | 117 | - If you encounter a `state_dict` warning during checkpoint loading, simply wrap your checkpoint weights within a dictionary under the `'state_dict'` key. 118 | - Creating a separate conda environment will help debugging considerably, especially with Hydra / Accelerate configurations. 119 | - Input to the VAEs needs to be normalized to \[-1, 1\]; the provided dataloaders handle this. 120 | - The discriminator currently starts after 3125 steps. If you want it to start earlier, you can adjust it in the main config experiment file. We set it to 3125 for our batch size, which was 32. Typically, the discriminator learns to discriminate fairly quickly, which is why it is set to start training a bit later, after the model has already finetuned for a while. The discriminator is randomly initialized based on a small distribution (ln 236 in vae_losses; line 70 in loss components). 121 | - We recommend maintaining gradient accumulation as 1 for numerical stability. 122 | - Do not worry if the L1 loss (reconstruction) and perceptual loss are wildly different. They are on different scales, but this should not affect the backprop, since the gradient directions would stay the same. 123 | 124 | ## Support 125 | 126 | For questions or issues regarding finetuning MedVAE models, please submit a request on our [GitHub issues page](https://github.com/StanfordMIMI/MedVAE/issues). 127 | -------------------------------------------------------------------------------- /medvae/medvae_main.py: -------------------------------------------------------------------------------- 1 | from medvae.utils.factory import create_model_and_transform 2 | from medvae.utils.extras import roi_size_calc 3 | import torch 4 | from monai.inferers import sliding_window_inference 5 | 6 | """ 7 | Large Med-VAE class to abstract all the models. 8 | 9 | This allows for interfacing with Med-VAE as a PyTorch model. 10 | Can be used for 2D and 3D inference / finetuning. 11 | 12 | @param model_name: The name of the model to use. 
Choose between: 13 | (1) medvae_4_1_2d: 2D images with a 4x compression in each dim (16x total) with a 1 channel latent. 14 | (2) medvae_4_3_2d: 2D images with a 4x compression in each dim (64x total) with a 3 channel latent. 15 | (3) medvae_8_1_2d: 2D images with an 8x compression in each dim (64x total) with a 1 channel latent. 16 | (4) medvae_8_4_2d: 2D images with an 8x compression in each dim (64x total) with a 4 channel latent. 17 | (5) medvae_4_1_3d: 3D images with a 4x compression in each dim (64x total) with a 1 channel latent. 18 | (6) medvae_8_1_3d: 3D images with an 8x compression in each dim (64x total) with a 1 channel latent. 19 | 20 | @param modality: Modality of the input images. Choose between xray, ct, or mri. 21 | 22 | @param gpu_dim: The maximum dimension size allowed for processing on the GPU (default is 160). 23 | 24 | @return (forward): The latent representation of the input image (torch.tensor). 25 | """ 26 | 27 | 28 | class MVAE(torch.nn.Module): 29 | def __init__(self, model_name: str, modality: str, gpu_dim=160): 30 | super(MVAE, self).__init__() 31 | 32 | self.model_name = model_name 33 | self.modality = modality 34 | 35 | self.model, self.transform = create_model_and_transform( 36 | self.model_name, self.modality 37 | ) 38 | 39 | self.gpu_dim = gpu_dim 40 | 41 | self.encoded_latent = None 42 | self.decoded_latent = None 43 | 44 | def apply_transform(self, fpath: str): 45 | if "3d" in self.model_name: 46 | return self.transform(fpath).unsqueeze(0) 47 | elif "2d" in self.model_name: 48 | return self.transform( 49 | fpath, merge_channels="1_2d" in self.model_name 50 | ).unsqueeze(0) 51 | else: 52 | raise ValueError( 53 | f"Model name {self.model_name} not supported. Needs to be a 2D or 3D model." 54 | ) 55 | 56 | def get_transform(self): 57 | return self.transform 58 | 59 | def init_from_ckpt(self, ckpt_path: str, state_dict: bool = True): 60 | self.model.init_from_ckpt(ckpt_path, state_dict=state_dict) 61 | 62 | def _process_3d(self, img, decode: bool = False): 63 | """Handle 3D image processing with sliding window.""" 64 | 65 | def predict_latent(patch): 66 | if decode: 67 | dec, _, z = self.model(patch, decode=True) 68 | return dec, z 69 | else: 70 | z, _, _ = self.model(patch, decode=False) 71 | return z 72 | 73 | roi_size = roi_size_calc(img.shape[-3:], target_gpu_dim=self.gpu_dim) 74 | result = sliding_window_inference( 75 | inputs=img, 76 | roi_size=roi_size, 77 | sw_batch_size=1, 78 | mode="gaussian", 79 | predictor=predict_latent, 80 | ) 81 | 82 | if decode: 83 | dec, latent = result 84 | # This is the decoded image and the latent representation of the image 85 | return dec.squeeze().squeeze(), latent.squeeze().squeeze() 86 | else: 87 | # This is the latent representation of the image 88 | return result.squeeze().squeeze() 89 | 90 | def _process_2d(self, img, decode: bool = False): 91 | """Handle 2D image processing.""" 92 | if decode: 93 | dec, _, latent = self.model(img, decode=True) 94 | # This is the decoded image and the latent representation of the image 95 | return dec.squeeze().squeeze(), latent.squeeze().squeeze() 96 | else: 97 | _, _, latent = self.model(img, decode=False) 98 | # This is the latent representation of the image 99 | return latent.squeeze().squeeze() 100 | 101 | def encode(self, img: torch.tensor): 102 | """Encode the image into a latent representation. 
(S1 for 2D, S2 for 3D)""" 103 | if "3d" in self.model_name: 104 | 105 | def encode_latent(patch): 106 | z, _, _ = self.model(patch, decode=False) 107 | return z 108 | 109 | roi_size = roi_size_calc(img.shape[-3:], target_gpu_dim=self.gpu_dim) 110 | s2_latent = sliding_window_inference( 111 | inputs=img, 112 | roi_size=roi_size, 113 | sw_batch_size=1, 114 | mode="gaussian", 115 | predictor=encode_latent, 116 | ) 117 | 118 | return s2_latent 119 | 120 | if "2d" in self.model_name: 121 | s1_latent = self.model.encode(img).sample() 122 | return s1_latent 123 | 124 | def decode(self, latent: torch.tensor): 125 | """Decode the latent representation into an image. (S1 for 2D, S2 for 3D)""" 126 | if "3d" in self.model_name: 127 | 128 | def decode_latent(patch): 129 | dec = self.model.decode(patch) 130 | return dec 131 | 132 | # Extract compression factor from model name (e.g., "medvae_4_1_3d" -> 4) 133 | compression_factor = int(self.model_name.split("_")[1]) 134 | 135 | # Calculate ROI size for the original dimensions 136 | roi_size = roi_size_calc( 137 | [x * compression_factor for x in latent.shape[-3:]], 138 | target_gpu_dim=self.gpu_dim, 139 | ) 140 | 141 | # Scale down the ROI size to match the latent space 142 | roi_size = [size // compression_factor for size in roi_size] 143 | 144 | dec = sliding_window_inference( 145 | inputs=latent, 146 | roi_size=roi_size, 147 | sw_batch_size=1, 148 | mode="gaussian", 149 | predictor=decode_latent, 150 | ) 151 | return dec 152 | 153 | if "2d" in self.model_name: 154 | dec = self.model.decode(latent) 155 | return dec 156 | 157 | """ 158 | Forward pass for the model. It will return the S2 2D and 3D latent representation. 159 | @param img: The image to run inference on. 160 | @return: The latent representation of the input image. 
161 | """ 162 | 163 | def forward(self, img: torch.tensor, decode: bool = False): 164 | if "3d" in self.model_name: 165 | return self._process_3d(img, decode) 166 | 167 | if "2d" in self.model_name: 168 | return self._process_2d(img, decode) 169 | -------------------------------------------------------------------------------- /medvae/utils/cls/train_components.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from time import time 3 | from typing import Any, List, Tuple 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from torch import nn 8 | from torch.optim import Optimizer 9 | from torch.utils.data import DataLoader 10 | from torchmetrics import Metric 11 | from tqdm import tqdm 12 | 13 | from medvae.utils.extras import get_weight_dtype 14 | from medvae.utils.transforms import to_dict 15 | 16 | 17 | __all__ = ["training_epoch", "validation_epoch"] 18 | 19 | 20 | def training_epoch( 21 | cfg: Any, 22 | epoch: int, 23 | accelerator: Accelerator, 24 | dataloader: DataLoader, 25 | model: nn.Module, 26 | criterion: nn.Module, 27 | optimizer: Optimizer, 28 | default_metrics: List[Metric], 29 | classif_metrics: List[Tuple[str, Metric]], 30 | ): 31 | """Train one epoch for a classifier.""" 32 | 33 | # Setup metrics 34 | for metric in default_metrics: 35 | metric.reset() 36 | 37 | for _, metric in classif_metrics: 38 | metric.reset() 39 | 40 | metric_loss, metric_data, metric_batch = default_metrics 41 | model.train() 42 | 43 | batch_start = time() 44 | dtype = get_weight_dtype(accelerator) 45 | 46 | for i, batch in enumerate(dataloader): 47 | data_time = time() - batch_start 48 | 49 | with accelerator.accumulate(model): 50 | batch = to_dict(batch) 51 | images, targets = batch["img"].to(dtype), batch["lbl"] 52 | 53 | # Compute output and loss 54 | output = model(images) 55 | loss = criterion(output, targets) 56 | 57 | optimizer.zero_grad() 58 | accelerator.backward(loss) 59 | optimizer.step() 60 | 61 | batch_time = time() - batch_start 62 | 63 | # Update metrics 64 | metric_loss.update(loss) 65 | metric_data.update(data_time) 66 | metric_batch.update(batch_time) 67 | 68 | output, targets = accelerator.gather_for_metrics((output, targets)) 69 | for _, metric in classif_metrics: 70 | metric.update(output, targets) 71 | 72 | # Logging 73 | print( 74 | f"\r[Epoch <{epoch:03}/{cfg['max_epoch']}>: Step <{i:03}/{len(dataloader)}>] - " 75 | + f"Data(s): {data_time:.3f} ({metric_data.compute():.3f}) - " 76 | + f"Batch(s): {batch_time:.3f} ({metric_batch.compute():.3f}) - " 77 | + f"Loss: {loss.item():.4f} ({metric_loss.compute():.4f}) \n" 78 | ) 79 | 80 | if cfg.get("logger", False) and i % cfg["log_every_n_steps"] == 0: 81 | log_data = { 82 | "epoch": epoch, 83 | "mean_loss": metric_loss.compute(), 84 | "mean_data": metric_data.compute(), 85 | "mean_batch": metric_batch.compute(), 86 | "step": i, 87 | "step_global": len(dataloader) * epoch + i, 88 | "step_loss": loss, 89 | "step_data": data_time, 90 | "step_batch": batch_time, 91 | } 92 | 93 | for name, metric in classif_metrics: 94 | log_data[name] = metric.compute() 95 | 96 | for idx, pg in enumerate(optimizer.param_groups): 97 | name = pg["name"] if "name" in pg else f"param_group_{idx}" 98 | if "lr" in pg: 99 | log_data[f"{name}/lr"] = pg["lr"] 100 | 101 | if "momentum" in pg: 102 | log_data[f"{name}/momentum"] = pg["momentum"] 103 | 104 | if "weight_decay" in pg: 105 | log_data[f"{name}/weight_decay"] = pg["weight_decay"] 106 | 107 | 
accelerator.log(log_data) 108 | 109 | batch_start = time() 110 | if cfg.get("fast_dev_run", False): 111 | break 112 | 113 | 114 | def validation_epoch( 115 | cfg: Any, 116 | epoch: int, 117 | accelerator: Accelerator, 118 | dataloader: DataLoader, 119 | model: nn.Module, 120 | criterion: nn.Module, 121 | default_metrics: List[Metric], 122 | classif_metrics: List[Tuple[str, Metric]], 123 | slice_metrics: List[Tuple[str, int, Metric]], 124 | ): 125 | """Validate one epoch for a classifier.""" 126 | 127 | # setup metrics 128 | for metric in default_metrics: 129 | metric.reset() 130 | 131 | for _, metric in classif_metrics: 132 | metric.reset() 133 | 134 | slice_metric_counter = defaultdict(int) 135 | for _, _, metric in slice_metrics: 136 | metric.reset() 137 | 138 | metric_loss = default_metrics[0] 139 | metric_loss.reset() 140 | 141 | model.eval() 142 | # if model1 is not None: 143 | # model1.eval() 144 | dtype = get_weight_dtype(accelerator) 145 | with torch.no_grad(): 146 | for i, batch in tqdm(enumerate(dataloader)): 147 | try: 148 | batch = to_dict(batch) 149 | images, targets, group_id = ( 150 | batch["img"].to(dtype), 151 | batch["lbl"], 152 | batch["group_id"], 153 | ) 154 | 155 | output = model(images) 156 | loss = criterion(output, targets) 157 | 158 | # Update metrics 159 | metric_loss.update(loss) 160 | 161 | output, targets, group_id = accelerator.gather_for_metrics( 162 | (output, targets, group_id) 163 | ) 164 | for _, metric in classif_metrics: 165 | metric.update(output, targets) 166 | 167 | for n, g, metric in slice_metrics: 168 | idx = group_id == g 169 | if idx.sum() == 0: 170 | continue 171 | metric.update(output[idx], targets[idx]) 172 | slice_metric_counter[n] += idx.sum() 173 | 174 | except Exception as e: 175 | print("<= Exception in cls_eval =>") 176 | print(e) 177 | 178 | if cfg.get("fast_dev_run", False): 179 | return 0 180 | 181 | # logging 182 | print(f"Validation - loss: {loss.item():.4f} ({metric_loss.compute():.4f}) \n") 183 | if cfg.get("logger", False): 184 | log_data = { 185 | "valid/epoch": epoch, 186 | "valid/mean_loss": metric_loss.compute(), 187 | "valid/step_loss": loss, 188 | } 189 | 190 | for name, metric in classif_metrics: 191 | log_data[f"valid/{name}"] = metric.compute() 192 | 193 | for name, g, metric in slice_metrics: 194 | if slice_metric_counter[name] > 0: 195 | log_data[f"valid_slice/{name}"] = metric.compute() 196 | 197 | accelerator.log(log_data) 198 | 199 | return metric_loss.compute() 200 | -------------------------------------------------------------------------------- /medvae/metrics/reconstruction_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import ( 3 | Metric, 4 | MultiScaleStructuralSimilarityIndexMeasure, 5 | PeakSignalNoiseRatio, 6 | ) 7 | from torchmetrics.image.fid import FrechetInceptionDistance 8 | 9 | from monai.transforms import SpatialPad 10 | 11 | 12 | class MSE(Metric): 13 | def __init__(self): 14 | super().__init__() 15 | self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum") 16 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 17 | 18 | def update(self, images, reconstructions): 19 | assert images.shape == reconstructions.shape 20 | images = images.contiguous() 21 | reconstructions = reconstructions.contiguous() 22 | 23 | images = images.view(images.shape[0], -1) 24 | reconstructions = reconstructions.view(reconstructions.shape[0], -1) 25 | err = ((images - reconstructions) ** 2).mean(-1) 
26 | 27 | self.error += err.sum() 28 | self.total += images.shape[0] 29 | 30 | def compute(self): 31 | return self.error.float() / self.total 32 | 33 | 34 | class PSNR(Metric): 35 | def __init__(self): 36 | super().__init__() 37 | self.add_state("psnr", default=torch.tensor(0.0), dist_reduce_fx="sum") 38 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 39 | 40 | self.func = PeakSignalNoiseRatio(data_range=1.0, reduction=None, dim=1) 41 | 42 | def update(self, images, reconstructions): 43 | assert images.shape == reconstructions.shape 44 | images = images.contiguous() 45 | reconstructions = reconstructions.contiguous() 46 | 47 | # Undo normalization and transform images to 0 to 1 range 48 | images = images.view(images.shape[0], -1) 49 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 50 | 51 | # Undo normalization and transform reconstructions to 0 to 1 range 52 | reconstructions = reconstructions.view(reconstructions.shape[0], -1) 53 | reconstructions = torch.clamp((reconstructions + 1.0) / 2.0, min=0.0, max=1.0) 54 | 55 | # Compute PSNR 56 | psnr = self.func(reconstructions, images).sum() 57 | 58 | self.psnr += psnr 59 | self.total += images.shape[0] 60 | 61 | def compute(self): 62 | return self.psnr.float() / self.total 63 | 64 | 65 | class MS_SSIM(Metric): 66 | def __init__(self): 67 | super().__init__() 68 | self.add_state("ms_ssim", default=torch.tensor(0.0), dist_reduce_fx="sum") 69 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 70 | 71 | self.func = MultiScaleStructuralSimilarityIndexMeasure( 72 | reduction="none", data_range=1.0 73 | ) 74 | 75 | def update(self, images, reconstructions): 76 | assert images.shape == reconstructions.shape 77 | 78 | # Undo normalization and transform reconstructions to 0 to 1 range 79 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 80 | 81 | # Undo normalization and transform reconstructions to 0 to 1 range 82 | reconstructions = torch.clamp((reconstructions + 1.0) / 2.0, min=0.0, max=1.0) 83 | 84 | # Compute MS-SSIM 85 | ms_ssim = self.func(reconstructions, images).sum() 86 | 87 | self.ms_ssim += ms_ssim 88 | self.total += images.shape[0] 89 | 90 | def compute(self): 91 | return self.ms_ssim.float() / self.total 92 | 93 | 94 | class MS_SSIM_SMALL(Metric): 95 | def __init__(self): 96 | super().__init__() 97 | self.add_state("ms_ssim", default=torch.tensor(0.0), dist_reduce_fx="sum") 98 | self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum") 99 | 100 | self.func = MultiScaleStructuralSimilarityIndexMeasure( 101 | reduction="none", data_range=1.0 102 | ) 103 | 104 | def update(self, images, reconstructions): 105 | assert images.shape == reconstructions.shape 106 | 107 | # Undo normalization and transform reconstructions to 0 to 1 range 108 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) 109 | images = SpatialPad(spatial_size=(1, 192, 192, 192))(images).as_tensor() 110 | 111 | # Undo normalization and transform reconstructions to 0 to 1 range 112 | reconstructions = torch.clamp((reconstructions + 1.0) / 2.0, min=0.0, max=1.0) 113 | reconstructions = SpatialPad(spatial_size=(1, 192, 192, 192))( 114 | reconstructions 115 | ).as_tensor() 116 | 117 | # Compute MS-SSIM 118 | ms_ssim = self.func(reconstructions, images).sum() 119 | 120 | self.ms_ssim += ms_ssim 121 | self.total += images.shape[0] 122 | 123 | def compute(self): 124 | return self.ms_ssim.float() / self.total 125 | 126 | 127 | class FID_Inception(Metric): 128 | def __init__(self): 
129 | super().__init__() 130 | self.func = FrechetInceptionDistance(normalize=True) 131 | 132 | def update(self, images, reconstructions): 133 | assert images.shape == reconstructions.shape 134 | 135 | # Undo normalization and transform reconstructions to 0 to 1 range 136 | images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0).expand( 137 | -1, 3, -1, -1 138 | ) 139 | 140 | # Undo normalization and transform reconstructions to 0 to 1 range 141 | reconstructions = torch.clamp( 142 | (reconstructions + 1.0) / 2.0, min=0.0, max=1.0 143 | ).expand(-1, 3, -1, -1) 144 | 145 | # Compute FID 146 | self.func.update(images, real=True) 147 | self.func.update(reconstructions, real=False) 148 | 149 | def compute(self): 150 | return self.func.compute() 151 | 152 | 153 | class FID_Inception_3D(Metric): 154 | def __init__(self): 155 | super().__init__() 156 | self.func = FID_Inception() 157 | 158 | def update(self, images, reconstructions): 159 | assert images.shape == reconstructions.shape 160 | 161 | for dim_idx, dim_name in enumerate(["depth", "height", "width"], start=2): 162 | # Iterate over slices along the current dimension 163 | for j in range(images.size(dim_idx)): 164 | # Select the appropriate slice along each dimension 165 | if dim_name == "depth": 166 | slice_i = images[:, :, j, :, :] 167 | recon_i = reconstructions[:, :, j, :, :] 168 | elif dim_name == "height": 169 | slice_i = images[:, :, :, j, :] 170 | recon_i = reconstructions[:, :, :, j, :] 171 | else: # dim_name == 'width' 172 | slice_i = images[:, :, :, :, j] 173 | recon_i = reconstructions[:, :, :, :, j] 174 | 175 | # Compute FID 176 | self.func.update(slice_i, recon_i) 177 | 178 | def compute(self): 179 | return self.func.compute() 180 | -------------------------------------------------------------------------------- /documentation/inference.md: -------------------------------------------------------------------------------- 1 | # Inference Usage Instruction 2 | 3 | MedVAE can be run using either: 4 | 5 | - A PyTorch model (programmatic use) 6 | - A command-line interface (CLI) (recommended for beginners) 7 | 8 | **Please see the [demo](demo.ipynb) for programmatic examples.** 9 | 10 | If you are new to MedVAE and want to downsize your medical images, the CLI approach is recommended. 11 | 12 | ## **Available MedVAE Models** 13 | 14 | MedVAE provides **six pre-trained models** for **2D and 3D medical images**, each with different compression settings: 15 | 16 | ### **📌 2D Models** 17 | 18 | | Model Name | Compression | Latent Channels | Total Compression | 19 | |------------------|------------|-----------------|-------------------| 20 | | `medvae_4_1_2d` | 4× per dim | 1 | 16× total | 21 | | `medvae_4_3_2d` | 4× per dim | 3 | 16× total | 22 | | `medvae_8_1_2d` | 8× per dim | 1 | 64× total | 23 | | `medvae_8_4_2d` | 8× per dim | 4 | 64× total | 24 | 25 | ### **📌 3D Models** 26 | 27 | | Model Name | Compression | Latent Channels | Total Compression | 28 | |------------------|------------|-----------------|-------------------| 29 | | `medvae_4_1_3d` | 4× per dim | 1 | 64× total | 30 | | `medvae_8_1_3d` | 8× per dim | 1 | 512× total | 31 | 32 | ## 👨‍💻 Programmatic Usage 33 | 34 | If you are integrating MedVAE into an existing PyTorch workflow, using it as a PyTorch model is recommended. The [MVAE](../medvae/medvae.py) class provides an easy way to load and use MedVAE models programmatically. 
35 | 36 | #### **Instantiating a MedVAE Model** 37 | 38 | To create an `MVAE` model object, three parameters are needed: 39 | 40 | - **`model_name`** – Specifies which of the six available MedVAE models to use. 41 | - **`modality`** – Defines the medical imaging modality (`"xray"`, `"ct"`, or `"mri"`). 42 | - **`gpu_dim`** (optional) – Sets the largest volumetric dimension the GPU can handle. 43 | - Default: `160`, optimized for a 48GB Nvidia A6000 GPU. 44 | 45 | #### **Applying Tranforms** 46 | 47 | The `MVAE` class provides an `apply_transforms` method, which automatically applies the appropriate transformation based on the input file type and modality. 48 | 49 | - **2D MedVAE models** → Input should be a 2D `.png` file. 50 | - **3D MedVAE models** → Input should be a compressed 3D NIfTI (`*.nii.gz`) file. 51 | 52 | For more details, the transforms file is located [here](../medvae/utils/loaders.py). 53 | 54 | #### **Example Usage:** 55 | 56 | ```python 57 | import torch 58 | from medvae import MVAE 59 | 60 | fpath = "documentation/data/mmg_data/isJV8hQ2hhJsvEP5rdQNiy.png" 61 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | 63 | model = MVAE(model_name="medvae_4_3_2d", modality="xray").to(device) 64 | img = model.apply_transform(fpath).to(device) 65 | 66 | model.requires_grad_(False) 67 | model.eval() 68 | 69 | with torch.no_grad(): 70 | latent = model(img) 71 | ``` 72 | 73 | ## Obtain Decoded Representations 74 | 75 | To obtain the decoded representations, you can set `decode = True` when calling the forward method of the `MVAE` class. Here's an example of how to use it: 76 | 77 | ```python 78 | with torch.no_grad(): 79 | decoded_img, latent = model(img, decode=True) 80 | ``` 81 | 82 | The `decoded_img` variable will contain the decoded image, and the `latent` variable will contain the latent representation of the input image. 83 | 84 | ## Using MedVAE part of LDM pipelines 85 | 86 | We provide two functions, `encode` and `decode`, to work with latent representations produced by the model: Stage 1 latents for 2D images and Stage 2 latents for 3D volumes. These latents can potentially be used in downstream tasks such as medical latent diffusion models (*though we have not tested this*). 87 | 88 | ```python 89 | with torch.no_grad(): 90 | latent = model.encode(img) 91 | dec = model.decode(img) 92 | ``` 93 | 94 | ## 🖥️ CLI Usage 95 | 96 | The CLI script runs inference using MedVAE, processing 2D or 3D medical images to generate latent representations. It allows users to specify a pretrained MedVAE model and input modalities (X-ray, CT, MRI). Given an input directory, it will process all the medical images into latent representations and save them in the specified folder. 97 | 98 | ```python 99 | medvae_inference -i INPUT_FOLDER -o OUTPUT_FOLDER -model_name MED_VAE_MODEL -modality MODALITY 100 | ``` 101 | 102 | ### Arguments 103 | 104 | | Argument | Type | Required | Description | 105 | |--------------|------|----------|-------------------------------------------------------------------------------------------------| 106 | | -i | str | ✅ Yes | Path to the input folder containing images (\*.png for 2D, \*.nii.gz for 3D). The filenames must not contain multiple dots. | 107 | | -o | str | ✅ Yes | Path to the output folder where latent representations will be saved. If the folder does not exist, it will be created. | 108 | | -model_name | str | ✅ Yes | Specifies the Med-VAE model to use. See available options above. 
| 109 | | -modality | str | ✅ Yes | Specifies the image modality: "xray", "ct", or "mri". | 110 | | -roi_size | int | ❌ No (Default: 160) | Sets the region of interest (ROI) size for 3D models (used to manage GPU memory). | 111 | | -device | str | ❌ No (Default: "cuda") | Specifies the device to run inference on: "cuda" (GPU), "cpu" (CPU), "mps" (Apple M1/M2). Do not specify GPU ID here! Use CUDA_VISIBLE_DEVICES=X instead. | 112 | 113 | ## 🤗 Model Files on Huggingface 114 | 115 | | Total Compression Factor | Channels | Dimensions | Modalities | Anatomies | Config File | Model File | 116 | |----------|----------|----------|----------|----------|----------|----------| 117 | | 16 | 1 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_4x1.yaml ](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_4x1.yaml) | [vae_4x_1c_2D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_4x_1c_2D.ckpt) 118 | | 16 | 3 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_4x3.yaml](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_4x3.yaml) | [vae_4x_3c_2D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_4x_3c_2D.ckpt) 119 | | 64 | 1 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_8x1.yaml](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_8x1.yaml) | [vae_8x_1c_2D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_8x_1c_2D.ckpt) 120 | | 64 | 3 | 2D | X-ray | Chest, Breast (FFDM) | [medvae_8x4.yaml](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_8x4.yaml) | [vae_8x_4c_2D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_8x_4c_2D.ckpt) 121 | | 64 | 1 | 3D | MRI, CT | Whole-Body | [medvae_4x1.yaml ](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_4x1.yaml) | [vae_4x_1c_3D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_4x_1c_3D.ckpt) 122 | | 512 | 1 | 3D | MRI, CT | Whole-Body | [medvae_8x1.yaml](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/medvae_8x1.yaml) | [vae_8x_1c_3D.ckpt](https://huggingface.co/stanfordmimi/MedVAE/blob/main/model_weights/vae_8x_1c_3D.ckpt) 123 | 124 | ## Creating a MedVAE conda environment 125 | 126 | Run the following in your terminal or command prompt: 127 | 128 | ```python 129 | conda create --name medvae python=3.9 130 | ``` 131 | 132 | To activate the environment, enter: 133 | 134 | ```python 135 | conda activate medvae 136 | ``` 137 | 138 | To delete the environment, enter: 139 | 140 | ```python 141 | conda remove --name medvae --all 142 | ``` 143 | 144 | ## Running pre-commit 145 | 146 | To install the project as a development package, run the following command in your terminal or command prompt: 147 | 148 | ```python 149 | pip install -e .[dev] 150 | ``` 151 | 152 | Install pre-commit 153 | 154 | ```python 155 | pre-commit install 156 | ``` 157 | 158 | Run pre-commit 159 | 160 | ```python 161 | pre - commit 162 | ``` 163 | -------------------------------------------------------------------------------- /medvae/utils/loaders.py: -------------------------------------------------------------------------------- 1 | from monai.transforms import ( 2 | EnsureChannelFirst, 3 | Compose, 4 | LoadImage, 5 | Orientation, 6 | ScaleIntensity, 7 | CropForeground, 8 | ScaleIntensityRange, 9 | RandSpatialCrop, 10 | ) 11 | import torch 12 | import torch.nn.functional as F 13 | from monai.transforms import Transform 14 | import torchvision 
15 | from PIL import Image 16 | import polars as pl 17 | import numpy as np 18 | 19 | 20 | class MonaiNormalize(Transform): 21 | def __init__(self, mean, std): 22 | self.mean = mean 23 | self.std = std 24 | 25 | def __call__(self, img): 26 | return torchvision.transforms.Normalize(self.mean, self.std)(img) 27 | 28 | 29 | class MonaiPad(Transform): 30 | def __init__(self, size, mode="constant", value=0): 31 | self.size = size 32 | self.mode = mode 33 | self.value = value 34 | 35 | def __call__(self, img): 36 | padding = [] 37 | for i in range(len(img.shape) - 1, 0, -1): # Start from the last dimension 38 | total_pad = max(self.size[i - 1] - img.shape[i], 0) 39 | padding.extend([total_pad // 2, total_pad - total_pad // 2]) 40 | padded_img = F.pad(img, padding, mode=self.mode, value=self.value) 41 | return padded_img 42 | 43 | 44 | class MonaiImageOpen(Transform): 45 | def __init__(self): 46 | pass 47 | 48 | def __call__(self, path): 49 | return Image.open(path) 50 | 51 | 52 | """ 53 | Custom transform to normalize and pad 2D images 54 | @input: path to image (str) 55 | @Output: padding (np.array) 56 | """ 57 | 58 | 59 | def load_2d( 60 | path: str, 61 | merge_channels: bool = False, 62 | dtype: torch.dtype = torch.float32, 63 | **kwargs, 64 | ): 65 | img_transforms = Compose( 66 | transforms=[ 67 | MonaiImageOpen(), 68 | torchvision.transforms.ToTensor(), 69 | ScaleIntensity(channel_wise=True, minv=0, maxv=1), 70 | MonaiNormalize(mean=[0.5], std=[0.5]), 71 | ], 72 | lazy=True, 73 | ) 74 | 75 | try: 76 | img = img_transforms(path).as_tensor() 77 | 78 | if merge_channels: 79 | img = img.mean(0, keepdim=True) 80 | 81 | return img 82 | except Exception as e: 83 | print(f"Error in loading {path} with error: {e}") 84 | return torch.zeros((1, 384, 384)) 85 | 86 | 87 | def load_2d_finetune( 88 | path: str, dtype: torch.dtype = torch.float32, merge_channels: bool = True, **kwargs 89 | ): 90 | img_transforms = Compose( 91 | transforms=[ 92 | MonaiImageOpen(), 93 | torchvision.transforms.ToTensor(), 94 | torchvision.transforms.Resize((384, 384), interpolation=3, antialias=True), 95 | ScaleIntensity(channel_wise=True, minv=0, maxv=1), 96 | MonaiNormalize(mean=[0.5], std=[0.5]), 97 | ], 98 | lazy=True, 99 | ) 100 | 101 | try: 102 | img = img_transforms(path).as_tensor() 103 | if merge_channels: 104 | img = img.mean(0, keepdim=True) 105 | return img 106 | except Exception as e: 107 | print(f"Error in loading {path} with error: {e}") 108 | return torch.zeros((1, 384, 384)) 109 | 110 | 111 | """ 112 | Custom transform to normalize, crop, and pad 3D volumes 113 | @input: path to image (str) 114 | @Output: padding (np.array) 115 | """ 116 | 117 | 118 | def load_mri_3d(path: str, dtype: torch.dtype = torch.float32, **kwargs): 119 | mri_transforms = Compose( 120 | transforms=[ 121 | LoadImage(), 122 | EnsureChannelFirst(), 123 | Orientation(axcodes="RAS"), 124 | ScaleIntensity(channel_wise=True, minv=0, maxv=1), 125 | MonaiNormalize(mean=[0.5], std=[0.5]), 126 | CropForeground(k_divisible=[16, 16, 16]), 127 | ], 128 | lazy=True, 129 | ) 130 | 131 | try: 132 | mr_augmented = mri_transforms(path).as_tensor() 133 | return mr_augmented 134 | except Exception as e: 135 | print(f"Error in loading {path} with error: {e}") 136 | return torch.zeros((1, 128, 128, 128)) 137 | 138 | 139 | """ 140 | Custom transform to normalize, crop, and pad 3D CT volumes 141 | @input: path to image (str) 142 | @Output: padding (np.array) 143 | """ 144 | 145 | 146 | def load_ct_3d(path: str, dtype: torch.dtype = torch.float32, **kwargs): 
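    # CT intensities are windowed to [-1000, 1000] HU, rescaled to [0, 1], then normalized
    # to [-1, 1]; CropForeground keeps each spatial dimension divisible by 16.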
147 | ct_transforms = Compose( 148 | transforms=[ 149 | LoadImage(), 150 | EnsureChannelFirst(), 151 | Orientation(axcodes="RAS"), 152 | ScaleIntensityRange( 153 | a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True 154 | ), 155 | MonaiNormalize(mean=[0.5], std=[0.5]), 156 | CropForeground(k_divisible=[16, 16, 16]), 157 | ], 158 | lazy=True, 159 | ) 160 | 161 | try: 162 | ct_augmented = ct_transforms(path).as_tensor() 163 | return ct_augmented 164 | except Exception as e: 165 | print(f"Error in loading {path} with error: {e}") 166 | return torch.zeros((1, 256, 256, 256)) 167 | 168 | 169 | """ 170 | Custom transform to normalize, crop, and pad 3D volumes 171 | @input: path to image (str) 172 | @Output: padding (np.array) 173 | """ 174 | 175 | 176 | def load_mri_3d_finetune(path: str, dtype: torch.dtype = torch.float32, **kwargs): 177 | mri_transforms = Compose( 178 | transforms=[ 179 | LoadImage(), 180 | EnsureChannelFirst(), 181 | Orientation(axcodes="RAS"), 182 | ScaleIntensity(channel_wise=True, minv=0, maxv=1), 183 | MonaiNormalize(mean=[0.5], std=[0.5]), 184 | MonaiPad(size=[64, 64, 64], value=-1), 185 | RandSpatialCrop(roi_size=[64, 64, 64]), 186 | ], 187 | lazy=True, 188 | ) 189 | 190 | try: 191 | mr_augmented = mri_transforms(path).as_tensor() 192 | return mr_augmented 193 | except Exception as e: 194 | print(f"Error in loading {path} with error: {e}") 195 | return torch.zeros((1, 64, 64, 64)) 196 | 197 | 198 | """ 199 | Custom transform to normalize, crop, and pad ct 3D volumes 200 | @input: path to image (str) 201 | @Output: padding (np.array) 202 | """ 203 | 204 | 205 | def load_ct_3d_finetune(path: str, dtype: torch.dtype = torch.float32, **kwargs): 206 | ct_transforms = Compose( 207 | transforms=[ 208 | LoadImage(), 209 | EnsureChannelFirst(), 210 | Orientation(axcodes="RAS"), 211 | ScaleIntensityRange( 212 | a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True 213 | ), 214 | MonaiNormalize(mean=[0.5], std=[0.5]), 215 | MonaiPad(size=[64, 64, 64], value=-1), 216 | RandSpatialCrop(roi_size=[64, 64, 64]), 217 | ], 218 | lazy=True, 219 | ) 220 | 221 | try: 222 | ct_augmented = ct_transforms(path).as_tensor() 223 | return ct_augmented 224 | except Exception as e: 225 | print(f"Error in loading {path} with error: {e}") 226 | return torch.zeros((1, 64, 64, 64)) 227 | 228 | 229 | def load_labels( 230 | df: pl.DataFrame, 231 | dtype: np.dtype = None, 232 | # fill_null=None, 233 | fill_nan: float = None, 234 | squeeze: int = None, 235 | ) -> torch.Tensor: 236 | """Load the labels from a dataframe.""" 237 | # BUG: Polars hangs when trying to convert to numpy in a DataLoader 238 | x = df.to_pandas().to_numpy() 239 | if dtype is not None: 240 | x = x.astype(dtype) 241 | 242 | if isinstance(squeeze, int): 243 | out = torch.from_numpy(x).squeeze(dim=squeeze) 244 | else: 245 | out = torch.from_numpy(x).squeeze() 246 | 247 | if isinstance(fill_nan, float): 248 | out = torch.where(out.isnan(), fill_nan, out) 249 | 250 | return out 251 | -------------------------------------------------------------------------------- /medvae/medvae_cls.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import pyrootutils 5 | import torch 6 | from accelerate import Accelerator 7 | from accelerate.utils import GradientAccumulationPlugin 8 | from hydra.utils import instantiate 9 | from omegaconf import DictConfig, OmegaConf 10 | from rich import print 11 | from torch import nn 12 | from torch.utils.data import DataLoader 13 | from 
torchmetrics import MeanMetric 14 | 15 | 16 | from medvae.utils.cls import training_epoch, validation_epoch 17 | from medvae.utils.extras import sanitize_dataloader_kwargs, set_seed 18 | from medvae.utils.lr_schedulers import CosineScheduler 19 | 20 | from torch.utils.data import WeightedRandomSampler 21 | 22 | 23 | # Set the project root 24 | root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True) 25 | config_dir = os.path.join(root, "configs") 26 | 27 | # Register configuration resolvers 28 | OmegaConf.register_new_resolver("eval", eval) 29 | 30 | 31 | @hydra.main( 32 | version_base="1.2", config_path=config_dir, config_name="finetuned_vae.yaml" 33 | ) 34 | def main(cfg: DictConfig): 35 | print(f"=> Starting [experiment={cfg['task_name']}]") 36 | print("=> Initializing Hydra configuration") 37 | cfg = instantiate(cfg) 38 | 39 | seed = cfg.get("seed", None) 40 | if seed is not None: 41 | set_seed(seed) 42 | 43 | # Setup accelerator 44 | is_logging = cfg.get("logger", None) is not None 45 | print(f"=> Instantiate accelerator [logging={is_logging}]") 46 | logger_name = "wandb" if is_logging else None 47 | logger_kwargs = {"wandb": cfg.get("logger", None)} 48 | 49 | assert cfg.get("mixed_precision", None) in ["bf16", "fp16", "no"] 50 | gradient_accumulation_steps = cfg.get("gradient_accumulation_steps", 1) 51 | accelerator = Accelerator( 52 | gradient_accumulation_plugin=GradientAccumulationPlugin( 53 | num_steps=gradient_accumulation_steps, 54 | adjust_scheduler=False, 55 | ), 56 | mixed_precision=cfg.get("mixed_precision", None), 57 | log_with=logger_name, 58 | split_batches=True, 59 | ) 60 | 61 | print(f"=> Mixed precision: {accelerator.mixed_precision}") 62 | init_proj_name = cfg.get("init_proj_name", "compress_evals_verse") 63 | accelerator.init_trackers(init_proj_name, config=cfg, init_kwargs=logger_kwargs) 64 | device = accelerator.device 65 | 66 | # instantiate dataloaders 67 | print(f"=> Instantiating train dataloader [device={device}]") 68 | if not cfg["weight_data_loader"]: 69 | train_dataloader = DataLoader( 70 | **sanitize_dataloader_kwargs(cfg["dataloader"]["train"]) 71 | ) 72 | else: 73 | train_weights = cfg["dataloader"]["train"]["dataset"].get_label_weights() 74 | cfg["dataloader"]["train"]["shuffle"] = False 75 | train_dataloader = DataLoader( 76 | **sanitize_dataloader_kwargs(cfg["dataloader"]["train"]), 77 | sampler=WeightedRandomSampler(train_weights, len(train_weights)), 78 | ) 79 | 80 | print(f"=> Instantiating valid dataloader [device={device}]") 81 | if not cfg["weight_data_loader"]: 82 | valid_dataloader = DataLoader( 83 | **sanitize_dataloader_kwargs(cfg["dataloader"]["valid"]) 84 | ) 85 | else: 86 | val_weights = cfg["dataloader"]["valid"]["dataset"].get_label_weights() 87 | cfg["dataloader"]["valid"]["shuffle"] = False 88 | valid_dataloader = DataLoader( 89 | **sanitize_dataloader_kwargs(cfg["dataloader"]["valid"]), 90 | sampler=WeightedRandomSampler(val_weights, len(val_weights)), 91 | ) 92 | 93 | # create the model 94 | print(f"=> Creating model [device={device}]") 95 | model = cfg["model"] 96 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 97 | 98 | # build optimizer from a partial optimizer 99 | print(f"=> Instantiating the optimizer [device={device}]") 100 | params = list(filter(lambda p: p.requires_grad, model.parameters())) 101 | 102 | lr, batch_size = ( 103 | cfg["optimizer"].keywords["lr"], 104 | cfg["dataloader"]["train"]["batch_size"], 105 | ) 106 | init_lr = lr * batch_size * gradient_accumulation_steps / 256 107 | optimizer = 
cfg["optimizer"](params, lr=init_lr) 108 | 109 | # learning rate scheduler 110 | print(f"=> Instantiating LR scheduler [device={device}]") 111 | scheduler = CosineScheduler(optimizer, cfg["max_epoch"]) 112 | 113 | # loss function 114 | criterion = cfg["criterion"] 115 | 116 | # prepare the components for multi-gpu/mixed precision training 117 | ( 118 | train_dataloader, 119 | valid_dataloader, 120 | model, 121 | optimizer, 122 | scheduler, 123 | criterion, 124 | ) = accelerator.prepare( 125 | train_dataloader, 126 | valid_dataloader, 127 | model, 128 | optimizer, 129 | scheduler, 130 | criterion, 131 | ) 132 | 133 | accelerator.register_for_checkpointing(scheduler) 134 | 135 | # prepare the metrics 136 | default_metrics = accelerator.prepare(*[MeanMetric() for _ in range(3)]) 137 | if len(cfg["metrics"]) > 0: 138 | names, metrics = list(zip(*cfg["metrics"])) 139 | metrics = list(zip(names, accelerator.prepare(*metrics))) 140 | else: 141 | metrics = [] 142 | 143 | if len(cfg["metrics_slice"]) > 0: 144 | names, group_ids, metrics_slice = list(zip(*cfg["metrics_slice"])) 145 | slice_metrics = list( 146 | zip( 147 | names, 148 | [g["group_id"] for g in group_ids], 149 | accelerator.prepare(*metrics_slice), 150 | ) 151 | ) 152 | else: 153 | slice_metrics = [] 154 | 155 | # resume from checkpoint 156 | start_epoch = cfg["start_epoch"] 157 | if cfg["resume_from_ckpt"] is not None: 158 | accelerator.load_state(cfg["resume_from_ckpt"]) 159 | custom_ckpt = torch.load( 160 | os.path.join(cfg["resume_from_ckpt"], "custom_checkpoint_0.pkl") 161 | ) 162 | start_epoch = custom_ckpt["last_epoch"] 163 | 164 | # setup metrics 165 | max_metric = None 166 | 167 | print(f"=> Starting model training [epochs={cfg['max_epoch']}]") 168 | for epoch in range(start_epoch, cfg["max_epoch"]): 169 | # train one epoch 170 | training_epoch( 171 | cfg=cfg, 172 | epoch=epoch, 173 | accelerator=accelerator, 174 | dataloader=train_dataloader, 175 | model=model, 176 | criterion=criterion, 177 | optimizer=optimizer, 178 | default_metrics=default_metrics, 179 | classif_metrics=metrics, 180 | ) 181 | 182 | accelerator.wait_for_everyone() 183 | # adjust the learning rate per epoch 184 | scheduler.step() 185 | 186 | # evaluate the model 187 | metric = validation_epoch( 188 | cfg=cfg, 189 | epoch=epoch, 190 | accelerator=accelerator, 191 | dataloader=valid_dataloader, 192 | model=model, 193 | criterion=criterion, 194 | default_metrics=default_metrics, 195 | classif_metrics=metrics, 196 | slice_metrics=slice_metrics, 197 | ) 198 | 199 | # save the best model 200 | if max_metric is None or metric < max_metric: 201 | max_metric = metric 202 | try: 203 | accelerator.save_state(os.path.join(cfg["ckpt_dir"], "best.pt")) 204 | except Exception as e: 205 | print(e) 206 | 207 | # save checkpoint 208 | if (epoch + 1) % cfg["ckpt_every_n_epochs"] == 0 and not cfg.get( 209 | "disable_ckpts", False 210 | ): 211 | print(f"=> Saving checkpoint [epoch={epoch}]") 212 | try: 213 | accelerator.save_state( 214 | os.path.join(cfg["ckpt_dir"], f"epoch-{epoch:04d}.pt") 215 | ) 216 | except Exception as e: 217 | print(e) 218 | 219 | # save last model 220 | if not cfg.get("disable_ckpts", False): 221 | print(f"=> Saving last checkpoint [epoch={epoch}]") 222 | accelerator.save_state(os.path.join(cfg["ckpt_dir"], "last.pt")) 223 | 224 | print( 225 | f"=> Finished model training [epochs={cfg['max_epoch']}, metric={max_metric}]" 226 | ) 227 | accelerator.end_training() 228 | 229 | 230 | if __name__ == "__main__": 231 | main() 232 | 
-------------------------------------------------------------------------------- /medvae/medvae_finetune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pyrootutils 3 | from medvae.utils.extras import cite_function 4 | from medvae.utils.factory import create_model 5 | from medvae.utils.extras import sanitize_dataloader_kwargs, set_seed 6 | from omegaconf import DictConfig, OmegaConf 7 | import hydra 8 | from hydra.utils import instantiate 9 | import os 10 | from accelerate import Accelerator, DistributedDataParallelKwargs 11 | from accelerate.utils import GradientAccumulationPlugin 12 | from torch.utils.data import DataLoader 13 | from torchmetrics import MeanMetric 14 | from medvae.utils.vae.train_components import training_epoch, validation_epoch 15 | 16 | """ 17 | Process configuration from Hydra instead of command line arguments. 18 | 19 | Args: 20 | cfg: Hydra configuration object 21 | 22 | Returns: 23 | The processed configuration: hydra config object 24 | """ 25 | 26 | 27 | def parse_arguments(cfg: DictConfig): 28 | # Print a message to cite the med-vae paper 29 | cite_function() 30 | 31 | # Validate model_name 32 | valid_model_names = [ 33 | "medvae_4_1_2d", 34 | "medvae_4_3_2d", 35 | "medvae_4_4_2d", 36 | "medvae_8_1_2d", 37 | "medvae_8_4_2d", 38 | "medvae_4_1_3d", 39 | "medvae_8_1_3d", 40 | ] 41 | assert cfg.model_name in valid_model_names, ( 42 | f"model_name must be one of {valid_model_names}. Got: {cfg.model_name}." 43 | ) 44 | 45 | assert cfg.stage2 is False, ( 46 | f"stage2 must be False. This script is not used for 2D stage 2 finetuning. Got: {cfg.stage2}." 47 | ) 48 | 49 | return cfg 50 | 51 | 52 | # Set the project root 53 | root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True) 54 | config_dir = os.path.join(root, "configs") 55 | 56 | # Register configuration resolvers 57 | OmegaConf.register_new_resolver("eval", eval) 58 | 59 | 60 | @hydra.main( 61 | version_base="1.2", config_path=config_dir, config_name="finetuned_vae.yaml" 62 | ) 63 | def main(cfg: DictConfig): 64 | cfg = parse_arguments(cfg) 65 | 66 | # Instantiating config 67 | print(f"=> Starting [experiment={cfg.get('task_name', 'default')}]") 68 | cfg = instantiate(cfg) 69 | 70 | # Seeding 71 | if cfg.get("seed", None) is not None: 72 | print(f"=> Setting seed [seed={cfg.seed}]") 73 | set_seed(cfg.seed) 74 | 75 | torch.backends.cuda.matmul.allow_tf32 = True 76 | 77 | # Setup accelerator 78 | logger_kwargs = cfg.get("logger", None) 79 | is_logging = bool(logger_kwargs) 80 | print(f"=> Instantiate accelerator [logging={is_logging}]") 81 | 82 | gradient_accumulation_steps = cfg.get("gradient_accumulation_steps", 1) 83 | accelerator = Accelerator( 84 | gradient_accumulation_plugin=GradientAccumulationPlugin( 85 | num_steps=gradient_accumulation_steps, 86 | adjust_scheduler=False, 87 | ), 88 | mixed_precision=cfg.get("mixed_precision", None), 89 | log_with="wandb" if is_logging else None, 90 | split_batches=True, 91 | kwargs_handlers=[ 92 | DistributedDataParallelKwargs( 93 | find_unused_parameters=True, 94 | ) 95 | ], 96 | ) 97 | accelerator.init_trackers( 98 | "medvae", config=cfg, init_kwargs={"wandb": logger_kwargs} 99 | ) 100 | 101 | # Determine the mode 102 | print(f"=> Mixed precision: {accelerator.mixed_precision}") 103 | 104 | inference_mode = cfg.get("inference", False) 105 | print(f"=> Running in inference mode: {inference_mode}") 106 | 107 | print(f"=> Instantiating train dataloader [device={accelerator.device}]") 108 | 
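    # DataLoader kwargs come straight from the Hydra config; sanitize_dataloader_kwargs
    # presumably strips config-only keys that torch.utils.data.DataLoader does not accept.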
train_dataloader = DataLoader( 109 | **sanitize_dataloader_kwargs(cfg["dataloader"]["train"]) 110 | ) 111 | 112 | print(f"=> Instantiating valid dataloader [device={accelerator.device}]") 113 | valid_dataloader = DataLoader( 114 | **sanitize_dataloader_kwargs(cfg["dataloader"]["valid"]) 115 | ) 116 | 117 | # Create loss function 118 | criterion = cfg.criterion 119 | 120 | # Only run the discriminator when its needed (stage 1 -- 2D; Stage 2 -- 3D) 121 | # We delay the discriminator by setting a start to avoid potential mode collapse 122 | discriminator_iter_start = criterion.discriminator_iter_start 123 | 124 | # Create model 125 | model = create_model(cfg.model_name) 126 | 127 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 128 | 129 | # Create two optimizers: one for the autoencoder and one for the discriminator 130 | print(f"=> Instantiating the optimizer [device={accelerator.device}]") 131 | 132 | batch_size, lr = cfg.batch_size, cfg.base_learning_rate 133 | lr = gradient_accumulation_steps * batch_size * lr 134 | 135 | # Create autoencoder parameters 136 | ae_params = ( 137 | list(model.encoder.parameters()) 138 | + list(model.decoder.parameters()) 139 | + list(model.quant_conv.parameters()) 140 | + list(model.post_quant_conv.parameters()) 141 | ) 142 | 143 | if criterion.learn_logvar: 144 | ae_params.append(criterion.logvar) 145 | opt_ae = torch.optim.Adam(ae_params, lr=lr, betas=(0.5, 0.9)) 146 | 147 | opt_disc = torch.optim.Adam( 148 | criterion.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 149 | ) 150 | 151 | num_metrics = 5 152 | # This mean uses lora and needs biomedclip loss 153 | if cfg.model_name in ["medvae_4_3_2d", "medvae_8_4_2d"]: 154 | num_metrics += 1 155 | 156 | # Prepare components for multi-gpu/mixed precision training 157 | (train_dataloader, valid_dataloader, model, opt_ae, opt_disc, criterion) = ( 158 | accelerator.prepare( 159 | train_dataloader, 160 | valid_dataloader, 161 | model, 162 | opt_ae, 163 | opt_disc, 164 | criterion, 165 | ) 166 | ) 167 | 168 | # Create metrics: aeloss, discloss, recloss, data(time), batch(time) 169 | default_metrics = accelerator.prepare(*[MeanMetric() for _ in range(num_metrics)]) 170 | if len(cfg["metrics"]) > 0: 171 | names, metrics = list(zip(*cfg["metrics"])) 172 | metrics = list(zip(names, accelerator.prepare(*metrics))) 173 | else: 174 | metrics = [] 175 | 176 | # Resume from checkpoint 177 | start_epoch = cfg.start_epoch 178 | if cfg.resume_from_ckpt is not None: 179 | print("Loading Model from Checkpoint: ", cfg.resume_from_ckpt) 180 | accelerator.load_state(cfg.resume_from_ckpt) 181 | 182 | options = { 183 | "max_epoch": cfg["max_epoch"], 184 | "is_logging": is_logging, 185 | "log_every_n_steps": cfg["log_every_n_steps"], 186 | "ckpt_every_n_steps": cfg["ckpt_every_n_steps"], 187 | "ckpt_dir": cfg["ckpt_dir"], 188 | "fast_dev_run": cfg["fast_dev_run"], 189 | } 190 | 191 | print(f"=> Starting model training [epochs={cfg['max_epoch']}]") 192 | min_loss = None 193 | global_step = cfg.get("global_step", 0) 194 | for epoch in range(start_epoch, cfg["max_epoch"]): 195 | global_step = training_epoch( 196 | options=options, 197 | epoch=epoch, 198 | global_step=global_step, 199 | accelerator=accelerator, 200 | dataloader=train_dataloader, 201 | model=model, 202 | criterion=criterion, 203 | discriminator_iter_start=discriminator_iter_start, 204 | default_metrics=default_metrics, 205 | rec_metrics=metrics, 206 | optimizer_ae=opt_ae, 207 | optimizer_disc=opt_disc, 208 | ) 209 | 210 | accelerator.wait_for_everyone() 
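        # All processes are synchronized here so validation starts only after every rank
        # has finished the training epoch.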
211 | 212 | loss = validation_epoch( 213 | options=options, 214 | epoch=epoch, 215 | accelerator=accelerator, 216 | dataloader=valid_dataloader, 217 | model=model, 218 | criterion=criterion, 219 | default_metrics=default_metrics, 220 | rec_metrics=metrics, 221 | global_step=global_step, 222 | ) 223 | 224 | # save the best model 225 | if min_loss is None or loss < min_loss: 226 | try: 227 | accelerator.save_state( 228 | os.path.join(cfg.ckpt_dir, "best.pt"), safe_serialization=False 229 | ) 230 | except Exception as e: 231 | print(e) 232 | min_loss = loss 233 | 234 | # save checkpoint 235 | if (epoch + 1) % cfg.get("ckpt_every_n_epochs", 1) == 0: 236 | print(f"=> Saving checkpoint [epoch={epoch}]") 237 | try: 238 | accelerator.save_state( 239 | os.path.join(cfg.ckpt_dir, f"epoch-{epoch:04d}.pt"), 240 | safe_serialization=False, 241 | ) 242 | except Exception as e: 243 | print(e) 244 | 245 | # save last model 246 | accelerator.save_state( 247 | os.path.join(cfg.ckpt_dir, "last.pt"), safe_serialization=False 248 | ) 249 | 250 | # save model to output directory 251 | accelerator.save( 252 | model, os.path.join(cfg.output, "model.pt"), safe_serialization=False 253 | ) 254 | 255 | print(f"=> Finished model training [epochs={cfg['max_epoch']}, metric={min_loss}]") 256 | accelerator.end_training() 257 | 258 | 259 | if __name__ == "__main__": 260 | main() 261 | -------------------------------------------------------------------------------- /medvae/medvae_finetune_s2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Please note this script is for 2D stage 2 finetuning. If you have a 3D model, please use the medvae_finetune.py script. 3 | """ 4 | 5 | import torch 6 | import pyrootutils 7 | from medvae.utils.extras import cite_function 8 | from medvae.utils.factory import create_model 9 | from medvae.utils.extras import sanitize_dataloader_kwargs, set_seed 10 | from omegaconf import DictConfig, OmegaConf 11 | import hydra 12 | from hydra.utils import instantiate 13 | import os 14 | from accelerate import Accelerator, DistributedDataParallelKwargs 15 | from accelerate.utils import GradientAccumulationPlugin 16 | from torch.utils.data import DataLoader 17 | from torchmetrics import MeanMetric 18 | from medvae.utils.vae.train_components_stage2 import training_epoch, validation_epoch 19 | 20 | """ 21 | Process configuration from Hydra instead of command line arguments. 22 | 23 | Args: 24 | cfg: Hydra configuration object 25 | 26 | Returns: 27 | The processed configuration: hydra config object 28 | """ 29 | 30 | 31 | def parse_arguments(cfg: DictConfig): 32 | # Print a message to cite the med-vae paper 33 | cite_function() 34 | 35 | # Add visual emphasis to important warning message 36 | print("\n" + "=" * 80) 37 | print("⚠️ WARNING: This script is for 2D stage 2 finetuning ONLY ⚠️") 38 | print("If you have a 3D model, please use the medvae_finetune.py script instead.") 39 | print("=" * 80 + "\n") 40 | 41 | # Validate model_name 42 | valid_model_names = [ 43 | "medvae_4_1_2d", 44 | "medvae_4_3_2d", 45 | "medvae_4_4_2d", 46 | "medvae_8_1_2d", 47 | "medvae_8_4_2d", 48 | ] 49 | assert cfg.model_name in valid_model_names, ( 50 | f"model_name must be one of {valid_model_names}. Got: {cfg.model_name}." 51 | ) 52 | 53 | assert cfg.stage2 is True, ( 54 | f"stage2 must be True for stage 2 finetuning. This is used for 2D stage 2 finetuning. Got: {cfg.stage2}." 
55 | ) 56 | 57 | cfg.stage2_ckpt = os.path.abspath(cfg.stage2_ckpt) 58 | if not os.path.exists(cfg.stage2_ckpt): 59 | raise FileNotFoundError(f"stage2_ckpt {cfg.stage2_ckpt} does not exist.") 60 | 61 | return cfg 62 | 63 | 64 | # Set the project root 65 | root = pyrootutils.setup_root(__file__, dotenv=True, pythonpath=True) 66 | config_dir = os.path.join(root, "configs") 67 | 68 | # Register configuration resolvers 69 | OmegaConf.register_new_resolver("eval", eval) 70 | 71 | 72 | @hydra.main( 73 | version_base="1.2", config_path=config_dir, config_name="finetuned_vae.yaml" 74 | ) 75 | def main(cfg: DictConfig): 76 | cfg = parse_arguments(cfg) 77 | 78 | # Instantiating config 79 | print(f"=> Starting [experiment={cfg.get('task_name', 'default')}]") 80 | cfg = instantiate(cfg) 81 | 82 | # Seeding 83 | if cfg.get("seed", None) is not None: 84 | print(f"=> Setting seed [seed={cfg.seed}]") 85 | set_seed(cfg.seed) 86 | 87 | torch.backends.cuda.matmul.allow_tf32 = True 88 | 89 | # Setup accelerator 90 | logger_kwargs = cfg.get("logger", None) 91 | is_logging = bool(logger_kwargs) 92 | print(f"=> Instantiate accelerator [logging={is_logging}]") 93 | 94 | gradient_accumulation_steps = cfg.get("gradient_accumulation_steps", 1) 95 | accelerator = Accelerator( 96 | gradient_accumulation_plugin=GradientAccumulationPlugin( 97 | num_steps=gradient_accumulation_steps, 98 | adjust_scheduler=False, 99 | ), 100 | mixed_precision=cfg.get("mixed_precision", None), 101 | log_with="wandb" if is_logging else None, 102 | split_batches=True, 103 | kwargs_handlers=[ 104 | DistributedDataParallelKwargs( 105 | find_unused_parameters=True, 106 | ) 107 | ], 108 | ) 109 | accelerator.init_trackers( 110 | "medvae", config=cfg, init_kwargs={"wandb": logger_kwargs} 111 | ) 112 | 113 | # Determine the mode 114 | print(f"=> Mixed precision: {accelerator.mixed_precision}") 115 | 116 | inference_mode = cfg.get("inference", False) 117 | print(f"=> Running in inference mode: {inference_mode}") 118 | 119 | print(f"=> Instantiating train dataloader [device={accelerator.device}]") 120 | train_dataloader = DataLoader( 121 | **sanitize_dataloader_kwargs(cfg["dataloader"]["train"]) 122 | ) 123 | 124 | print(f"=> Instantiating valid dataloader [device={accelerator.device}]") 125 | valid_dataloader = DataLoader( 126 | **sanitize_dataloader_kwargs(cfg["dataloader"]["valid"]) 127 | ) 128 | 129 | # Create loss function 130 | criterion = cfg.criterion 131 | 132 | # Create model and use prior weight for stage 1 weight for stage 2 finetuning 133 | model = create_model( 134 | cfg.model_name, 135 | existing_weight=cfg.stage2_ckpt, 136 | training=False, 137 | state_dict=False, 138 | ) 139 | 140 | # Freeze the encoder, decoder, quant_conv, and post_quant_conv layers, so that only the projection head is trainable 141 | print( 142 | "Trainable Params before freeze:", 143 | sum(p.numel() for p in model.parameters() if p.requires_grad), 144 | ) 145 | model.encoder.requires_grad_(False) 146 | model.decoder.requires_grad_(False) 147 | model.quant_conv.requires_grad_(False) 148 | model.post_quant_conv.requires_grad_(False) 149 | print( 150 | "Trainable Params after freeze:", 151 | sum(p.numel() for p in model.parameters() if p.requires_grad), 152 | ) 153 | 154 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 155 | 156 | # Create two optimizers: one for the autoencoder and one for the discriminator 157 | print(f"=> Instantiating the optimizer [device={accelerator.device}]") 158 | 159 | batch_size, lr = cfg.batch_size, cfg.base_learning_rate 
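    # Linear learning-rate scaling: grow the base rate with the effective batch size
    # (per-step batch size x gradient-accumulation steps).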
160 | lr = gradient_accumulation_steps * batch_size * lr 161 | 162 | ae_params = list(model.channel_ds.parameters()) + list( 163 | model.channel_proj.parameters() 164 | ) 165 | opt_ae = torch.optim.Adam(ae_params, lr=lr, betas=(0.5, 0.9)) 166 | 167 | num_metrics = 3 168 | 169 | # Prepare components for multi-gpu/mixed precision training 170 | (train_dataloader, valid_dataloader, model, criterion) = accelerator.prepare( 171 | train_dataloader, 172 | valid_dataloader, 173 | model, 174 | criterion, 175 | ) 176 | 177 | # Create metrics: aeloss, discloss, recloss, data(time), batch(time) 178 | default_metrics = accelerator.prepare(*[MeanMetric() for _ in range(num_metrics)]) 179 | if len(cfg["metrics"]) > 0: 180 | names, metrics = list(zip(*cfg["metrics"])) 181 | metrics = list(zip(names, accelerator.prepare(*metrics))) 182 | else: 183 | metrics = [] 184 | 185 | options = { 186 | "max_epoch": cfg["max_epoch"], 187 | "is_logging": is_logging, 188 | "log_every_n_steps": cfg["log_every_n_steps"], 189 | "ckpt_every_n_steps": cfg["ckpt_every_n_steps"], 190 | "ckpt_dir": cfg["ckpt_dir"], 191 | "fast_dev_run": cfg["fast_dev_run"], 192 | } 193 | 194 | print(f"=> Starting model training [epochs={cfg['max_epoch']}]") 195 | min_loss = None 196 | global_step = cfg.get("global_step", 0) 197 | for epoch in range(cfg["max_epoch"]): 198 | global_step = training_epoch( 199 | options=options, 200 | epoch=epoch, 201 | global_step=global_step, 202 | accelerator=accelerator, 203 | dataloader=train_dataloader, 204 | model=model, 205 | criterion=criterion, 206 | default_metrics=default_metrics, 207 | rec_metrics=metrics, 208 | opt=opt_ae, 209 | ) 210 | 211 | accelerator.wait_for_everyone() 212 | 213 | loss = validation_epoch( 214 | options=options, 215 | epoch=epoch, 216 | accelerator=accelerator, 217 | dataloader=valid_dataloader, 218 | model=model, 219 | criterion=criterion, 220 | default_metrics=default_metrics, 221 | rec_metrics=metrics, 222 | global_step=global_step, 223 | ) 224 | 225 | # save the best model 226 | if min_loss is None or loss < min_loss: 227 | try: 228 | accelerator.save_state( 229 | os.path.join(cfg.ckpt_dir, "best.pt"), safe_serialization=False 230 | ) 231 | except Exception as e: 232 | print(e) 233 | min_loss = loss 234 | 235 | # save checkpoint 236 | if (epoch + 1) % cfg.get("ckpt_every_n_epochs", 1) == 0: 237 | print(f"=> Saving checkpoint [epoch={epoch}]") 238 | try: 239 | accelerator.save_state( 240 | os.path.join(cfg.ckpt_dir, f"epoch-{epoch:04d}.pt"), 241 | safe_serialization=False, 242 | ) 243 | except Exception as e: 244 | print(e) 245 | 246 | # save last model 247 | accelerator.save_state( 248 | os.path.join(cfg.ckpt_dir, "last.pt"), safe_serialization=False 249 | ) 250 | 251 | # save model to output directory 252 | accelerator.save( 253 | model, os.path.join(cfg.output, "model.pt"), safe_serialization=False 254 | ) 255 | 256 | print(f"=> Finished model training [epochs={cfg['max_epoch']}, metric={min_loss}]") 257 | accelerator.end_training() 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /medvae/utils/vae/train_components.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | from accelerate import Accelerator 7 | from rich import print 8 | from torch import nn 9 | from torch.optim import Optimizer 10 | from torch.utils.data import DataLoader 11 | from 
torchmetrics import Metric 12 | 13 | from medvae.utils.extras import get_weight_dtype 14 | from medvae.utils.transforms import to_dict 15 | 16 | __all__ = ["training_epoch", "validation_epoch"] 17 | 18 | 19 | def training_epoch( 20 | epoch: int, 21 | global_step: int, 22 | accelerator: Accelerator, 23 | dataloader: DataLoader, 24 | model: nn.Module, 25 | criterion: nn.Module, 26 | discriminator_iter_start: int, 27 | default_metrics: List[Metric], 28 | rec_metrics: List[Metric], 29 | optimizer_ae: Optimizer, 30 | optimizer_disc: Optimizer, 31 | options: Dict[str, Any], 32 | ): 33 | """Train a single epoch of a VAE model.""" 34 | for metric in default_metrics: 35 | metric.reset() 36 | 37 | for _, metric in rec_metrics: 38 | metric.reset() 39 | 40 | if len(default_metrics) == 6: 41 | ( 42 | metric_aeloss, 43 | metric_discloss, 44 | metric_ae_recloss, 45 | metric_bc_loss, 46 | metric_data, 47 | metric_batch, 48 | ) = default_metrics 49 | else: 50 | ( 51 | metric_aeloss, 52 | metric_discloss, 53 | metric_ae_recloss, 54 | metric_data, 55 | metric_batch, 56 | ) = default_metrics 57 | 58 | model.train() 59 | epoch_start, batch_start = time(), time() 60 | dtype = get_weight_dtype(accelerator) 61 | discloss = torch.tensor(0.0) 62 | aeloss = torch.tensor(0.0) 63 | compute_disc = False 64 | for i, batch in enumerate(dataloader): 65 | data_time = time() - batch_start 66 | 67 | batch = to_dict(batch) 68 | if global_step >= discriminator_iter_start: 69 | compute_disc = True 70 | images = batch["img"].to(dtype) 71 | 72 | if (compute_disc and i % 2 == 0) or (not compute_disc): 73 | # Train the encoder and decoder 74 | reconstructions, posterior, latent = model(images) 75 | 76 | aeloss, log_dict_ae = criterion( 77 | inputs=images, 78 | reconstructions=reconstructions, 79 | posteriors=posterior, 80 | latent=latent, 81 | optimizer_idx=0, 82 | global_step=global_step, 83 | weight_dtype=dtype, 84 | last_layer=accelerator.unwrap_model(model).get_last_layer(), 85 | split="train", 86 | ) 87 | 88 | optimizer_ae.zero_grad() 89 | accelerator.backward(aeloss) 90 | optimizer_ae.step() 91 | 92 | elif compute_disc and i % 2 == 1: 93 | with torch.no_grad(): 94 | reconstructions, posterior, latent = model(images) 95 | discloss, _log_dict_disc = criterion( 96 | inputs=images, 97 | reconstructions=reconstructions, 98 | posteriors=posterior, 99 | latent=latent, 100 | optimizer_idx=1, 101 | global_step=global_step, 102 | weight_dtype=dtype, 103 | last_layer=None, 104 | split="train", 105 | ) 106 | 107 | optimizer_disc.zero_grad() 108 | accelerator.backward(discloss) 109 | optimizer_disc.step() 110 | 111 | # Update metrics 112 | batch_time = time() - batch_start 113 | metric_aeloss.update(aeloss) 114 | metric_ae_recloss.update(log_dict_ae["train/rec_loss"]) 115 | metric_discloss.update(discloss) 116 | metric_data.update(data_time) 117 | metric_batch.update(batch_time) 118 | if "train/bc_loss" in log_dict_ae: 119 | metric_bc_loss.update(log_dict_ae["train/bc_loss"]) 120 | 121 | images, reconstructions = accelerator.gather_for_metrics((images, reconstructions)) 122 | for _, metric in rec_metrics: 123 | metric.update(images, reconstructions) 124 | 125 | # Logging values 126 | bc_print = "" 127 | if "train/bc_loss" in log_dict_ae: 128 | bc_print = f"BC Loss: {log_dict_ae['train/bc_loss'].item():.3f} ({metric_bc_loss.compute():.3f}) - " 129 | print( 130 | f"\r[Epoch <{epoch:03}/{options['max_epoch']}>: Step <{i:03}/{len(dataloader)}>] - " 131 | + f"Data(s): {data_time:.3f} ({metric_data.compute():.3f}) - " 132 | + f"Batch(s): 
{batch_time:.3f} ({metric_batch.compute():.3f}) - " 133 | + f"AE Loss: {aeloss.item():.3f} ({metric_aeloss.compute():.3f}) - " 134 | + f"AE Rec Loss: {log_dict_ae['train/rec_loss'].item():.3f} ({metric_ae_recloss.compute():.3f}) - " 135 | + bc_print 136 | + f"Disc Loss: {discloss.item():.3f} ({metric_discloss.compute():.3f}) - " 137 | + f"{(((time() - epoch_start) / (i + 1)) * (len(dataloader) - i)) / 60:.2f} m remaining\n" 138 | ) 139 | 140 | if options["is_logging"] and i % options["log_every_n_steps"] == 0: 141 | log_data = { 142 | "epoch": epoch, 143 | "mean_aeloss": metric_aeloss.compute(), 144 | "mean_ae_recloss": metric_ae_recloss.compute(), 145 | "mean_discloss": metric_discloss.compute(), 146 | "mean_data": metric_data.compute(), 147 | "mean_batch": metric_batch.compute(), 148 | "step": i, 149 | "step_global": global_step, 150 | "step_aeloss": aeloss, 151 | "step_ae_recloss": log_dict_ae["train/rec_loss"], 152 | "step_discloss": discloss, 153 | "step_data": data_time, 154 | "step_batch": batch_time, 155 | } 156 | 157 | if "train/bc_loss" in log_dict_ae: 158 | log_data["mean_bc_loss"] = metric_bc_loss.compute() 159 | 160 | for name, metric in rec_metrics: 161 | log_data[name] = metric.compute() 162 | 163 | accelerator.log(log_data) 164 | global_step += 1 165 | 166 | if global_step % options["ckpt_every_n_steps"] == 0: 167 | try: 168 | accelerator.save_state(os.path.join(options["ckpt_dir"], f"step_{global_step}.pt")) 169 | except Exception as e: 170 | print(e) 171 | 172 | batch_start = time() 173 | 174 | if options["fast_dev_run"]: 175 | break 176 | 177 | return global_step 178 | 179 | 180 | def validation_epoch( 181 | options: Dict[str, Any], 182 | epoch: int, 183 | accelerator: Accelerator, 184 | dataloader: DataLoader, 185 | model: nn.Module, 186 | criterion: nn.Module, 187 | default_metrics: List[Metric], 188 | rec_metrics: List[Metric], 189 | global_step: int = 0, 190 | postfix: str = "", 191 | ): 192 | """Validate one epoch for the VAE model.""" 193 | for metric in default_metrics: 194 | metric.reset() 195 | 196 | for _, metric in rec_metrics: 197 | metric.reset() 198 | 199 | metric_aeloss = default_metrics[0] 200 | metric_discloss = default_metrics[1] 201 | metric_ae_recloss = default_metrics[2] 202 | metric_aeloss.reset() 203 | metric_discloss.reset() 204 | metric_ae_recloss.reset() 205 | 206 | model.eval() 207 | criterion.eval() 208 | epoch_start = time() 209 | dtype = get_weight_dtype(accelerator) 210 | with torch.no_grad(): 211 | for i, batch in enumerate(dataloader): 212 | batch = to_dict(batch) 213 | 214 | images = batch["img"].to(dtype) 215 | reconstructions, posterior, latent = model(images) 216 | 217 | aeloss, log_dict_ae = criterion( 218 | inputs=images, 219 | reconstructions=reconstructions, 220 | posteriors=posterior, 221 | latent=latent, 222 | optimizer_idx=0, 223 | global_step=global_step, 224 | weight_dtype=dtype, 225 | last_layer=accelerator.unwrap_model(model).get_last_layer(), 226 | split="valid", 227 | ) 228 | 229 | discloss, _log_dict_disc = criterion( 230 | inputs=images, 231 | reconstructions=reconstructions, 232 | posteriors=posterior, 233 | latent=latent, 234 | optimizer_idx=1, 235 | global_step=global_step, 236 | weight_dtype=dtype, 237 | last_layer=accelerator.unwrap_model(model).get_last_layer(), 238 | split="valid", 239 | ) 240 | 241 | metric_aeloss.update(aeloss) 242 | metric_ae_recloss.update(log_dict_ae["valid/rec_loss"]) 243 | metric_discloss.update(discloss) 244 | images, reconstructions = accelerator.gather_for_metrics((images, 
reconstructions)) 245 | for _, metric in rec_metrics: 246 | metric.update(images, reconstructions) 247 | 248 | # Logging values 249 | print( 250 | f"\r Validation{postfix}: " 251 | + f"\r[Epoch <{epoch:03}/{options['max_epoch']}>: Step <{i:03}/{len(dataloader)}>] - " 252 | + f"AE Loss: {aeloss.item():.3f} ({metric_aeloss.compute():.3f}) - " 253 | + f"AE Rec Loss: {log_dict_ae['valid/rec_loss'].item():.3f} ({metric_ae_recloss.compute():.3f}) - " 254 | + f"Disc Loss: {discloss.item():.3f} ({metric_discloss.compute():.3f}) - " 255 | + f"{(((time() - epoch_start) / (i + 1)) * (len(dataloader) - i)) / 60:.2f} m remaining\n" 256 | ) 257 | 258 | if options["fast_dev_run"]: 259 | break 260 | 261 | if options["is_logging"]: 262 | log_data = { 263 | f"valid{postfix}/epoch": epoch, 264 | f"valid{postfix}/mean_aeloss": metric_aeloss.compute(), 265 | f"valid{postfix}/mean_ae_recloss": metric_ae_recloss.compute(), 266 | f"valid{postfix}/mean_discloss": metric_discloss.compute(), 267 | } 268 | for name, metric in rec_metrics: 269 | log_data[f"valid{postfix}/{name}"] = metric.compute() 270 | 271 | accelerator.log(log_data) 272 | 273 | return metric_ae_recloss.compute() 274 | -------------------------------------------------------------------------------- /medvae/utils/vae/loss_components.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from collections import namedtuple 4 | 5 | import requests 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import models 10 | from tqdm import tqdm 11 | 12 | 13 | def hinge_loss(logits_real, logits_fake): 14 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 15 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 16 | d_loss = 0.5 * (loss_real + loss_fake) 17 | return d_loss 18 | 19 | 20 | class LPIPS(nn.Module): 21 | # Learned perceptual metric 22 | def __init__(self, use_dropout=True): 23 | super().__init__() 24 | self.scaling_layer = ScalingLayer() 25 | self.chns = [64, 128, 256, 512, 512] # vg16 features 26 | self.net = vgg16(requires_grad=False) 27 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 28 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 29 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 30 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 31 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 32 | self.load_from_pretrained() 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def load_from_pretrained(self, name="vgg_lpips"): 37 | ckpt = get_ckpt_path(name, ".cache") 38 | self.load_state_dict( 39 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 40 | ) 41 | print(f"loaded pretrained LPIPS loss from {ckpt}") 42 | 43 | @classmethod 44 | def from_pretrained(cls, name="vgg_lpips"): 45 | if name != "vgg_lpips": 46 | raise NotImplementedError 47 | model = cls() 48 | ckpt = get_ckpt_path(name) 49 | model.load_state_dict( 50 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 51 | ) 52 | return model 53 | 54 | def forward(self, input, target): 55 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 56 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 57 | feats0, feats1, diffs = {}, {}, {} 58 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 59 | for kk in range(len(self.chns)): 60 | feats0[kk], feats1[kk] = ( 61 | normalize_tensor(outs0[kk]), 62 | 
normalize_tensor(outs1[kk]), 63 | ) 64 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 65 | 66 | res = [ 67 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 68 | for kk in range(len(self.chns)) 69 | ] 70 | val = res[0] 71 | for c in range(1, len(self.chns)): 72 | val += res[c] 73 | return val 74 | 75 | 76 | def weights_init(m): 77 | classname = m.__class__.__name__ 78 | if classname.find("Conv") != -1: 79 | nn.init.normal_(m.weight.data, 0.0, 0.02) 80 | elif classname.find("BatchNorm") != -1: 81 | nn.init.normal_(m.weight.data, 1.0, 0.02) 82 | nn.init.constant_(m.bias.data, 0) 83 | 84 | 85 | class ScalingLayer(nn.Module): 86 | def __init__(self): 87 | super().__init__() 88 | self.register_buffer( 89 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 90 | ) 91 | self.register_buffer( 92 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 93 | ) 94 | 95 | def forward(self, inp): 96 | return (inp - self.shift) / self.scale 97 | 98 | 99 | class NetLinLayer(nn.Module): 100 | """A single linear layer which does a 1x1 conv.""" 101 | 102 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 103 | super().__init__() 104 | layers = ( 105 | [ 106 | nn.Dropout(), 107 | ] 108 | if (use_dropout) 109 | else [] 110 | ) 111 | layers += [ 112 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 113 | ] 114 | self.model = nn.Sequential(*layers) 115 | 116 | 117 | class vgg16(torch.nn.Module): 118 | def __init__(self, requires_grad=False): 119 | super().__init__() 120 | vgg_pretrained_features = models.vgg16( 121 | weights="VGG16_Weights.IMAGENET1K_V1" 122 | ).features 123 | self.slice1 = torch.nn.Sequential() 124 | self.slice2 = torch.nn.Sequential() 125 | self.slice3 = torch.nn.Sequential() 126 | self.slice4 = torch.nn.Sequential() 127 | self.slice5 = torch.nn.Sequential() 128 | self.N_slices = 5 129 | for x in range(4): 130 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 131 | for x in range(4, 9): 132 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 133 | for x in range(9, 16): 134 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 135 | for x in range(16, 23): 136 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 137 | for x in range(23, 30): 138 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 139 | if not requires_grad: 140 | for param in self.parameters(): 141 | param.requires_grad = False 142 | 143 | def forward(self, X): 144 | h = self.slice1(X) 145 | h_relu1_2 = h 146 | h = self.slice2(h) 147 | h_relu2_2 = h 148 | h = self.slice3(h) 149 | h_relu3_3 = h 150 | h = self.slice4(h) 151 | h_relu4_3 = h 152 | h = self.slice5(h) 153 | h_relu5_3 = h 154 | vgg_outputs = namedtuple( 155 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 156 | ) 157 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 158 | return out 159 | 160 | 161 | def normalize_tensor(x, eps=1e-10): 162 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 163 | return x / (norm_factor + eps) 164 | 165 | 166 | def spatial_average(x, keepdim=True): 167 | return x.mean([2, 3], keepdim=keepdim) 168 | 169 | 170 | def get_ckpt_path(name, root, check=False): 171 | vgg_url = "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 172 | vgg_ckpt = "vgg.pth" 173 | # vgg_md5 = "d507d7349b931f0638a25a48a722f98a" 174 | path = os.path.join(root, vgg_ckpt) 175 | if not os.path.exists(path): 176 | print(f"Downloading {name} model from {vgg_url} to {path}") 177 | 
download(vgg_url, path) 178 | return path 179 | 180 | 181 | def download(url, local_path, chunk_size=1024): 182 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 183 | with requests.get(url, stream=True) as r: 184 | total_size = int(r.headers.get("content-length", 0)) 185 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 186 | with open(local_path, "wb") as f: 187 | for data in r.iter_content(chunk_size=chunk_size): 188 | if data: 189 | f.write(data) 190 | pbar.update(chunk_size) 191 | 192 | 193 | class NLayerDiscriminator(nn.Module): 194 | """Defines a PatchGAN discriminator as in Pix2Pix. 195 | 196 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 197 | """ 198 | 199 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 200 | """Construct a PatchGAN discriminator 201 | Parameters: 202 | input_nc (int) -- the number of channels in input images 203 | ndf (int) -- the number of filters in the last conv layer 204 | n_layers (int) -- the number of conv layers in the discriminator 205 | norm_layer -- normalization layer 206 | """ 207 | super().__init__() 208 | if not use_actnorm: 209 | norm_layer = nn.BatchNorm2d 210 | else: 211 | norm_layer = ActNorm 212 | if isinstance( 213 | norm_layer, functools.partial 214 | ): # no need to use bias as BatchNorm2d has affine parameters 215 | use_bias = norm_layer.func != nn.BatchNorm2d 216 | else: 217 | use_bias = norm_layer != nn.BatchNorm2d 218 | 219 | kw = 4 220 | padw = 1 221 | sequence = [ 222 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 223 | nn.LeakyReLU(0.2, True), 224 | ] 225 | nf_mult = 1 226 | nf_mult_prev = 1 227 | for n in range(1, n_layers): # gradually increase the number of filters 228 | nf_mult_prev = nf_mult 229 | nf_mult = min(2**n, 8) 230 | sequence += [ 231 | nn.Conv2d( 232 | ndf * nf_mult_prev, 233 | ndf * nf_mult, 234 | kernel_size=kw, 235 | stride=2, 236 | padding=padw, 237 | bias=use_bias, 238 | ), 239 | norm_layer(ndf * nf_mult), 240 | nn.LeakyReLU(0.2, True), 241 | ] 242 | 243 | nf_mult_prev = nf_mult 244 | nf_mult = min(2**n_layers, 8) 245 | sequence += [ 246 | nn.Conv2d( 247 | ndf * nf_mult_prev, 248 | ndf * nf_mult, 249 | kernel_size=kw, 250 | stride=1, 251 | padding=padw, 252 | bias=use_bias, 253 | ), 254 | norm_layer(ndf * nf_mult), 255 | nn.LeakyReLU(0.2, True), 256 | ] 257 | 258 | sequence += [ 259 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 260 | ] # output 1 channel prediction map 261 | self.main = nn.Sequential(*sequence) 262 | 263 | def forward(self, input): 264 | """Standard forward.""" 265 | return self.main(input) 266 | 267 | 268 | class ActNorm(nn.Module): 269 | def __init__( 270 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 271 | ): 272 | assert affine 273 | super().__init__() 274 | self.logdet = logdet 275 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 276 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 277 | self.allow_reverse_init = allow_reverse_init 278 | 279 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 280 | 281 | def initialize(self, input): 282 | with torch.no_grad(): 283 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 284 | mean = ( 285 | flatten.mean(1) 286 | .unsqueeze(1) 287 | .unsqueeze(2) 288 | .unsqueeze(3) 289 | .permute(1, 0, 2, 3) 290 | ) 291 | std = ( 292 | flatten.std(1) 293 | .unsqueeze(1) 294 | .unsqueeze(2) 295 | .unsqueeze(3) 296 | 
.permute(1, 0, 2, 3) 297 | ) 298 | 299 | self.loc.data.copy_(-mean) 300 | self.scale.data.copy_(1 / (std + 1e-6)) 301 | 302 | def forward(self, input, reverse=False): 303 | if reverse: 304 | return self.reverse(input) 305 | if len(input.shape) == 2: 306 | input = input[:, :, None, None] 307 | squeeze = True 308 | else: 309 | squeeze = False 310 | 311 | _, _, height, width = input.shape 312 | 313 | if self.training and self.initialized.item() == 0: 314 | self.initialize(input) 315 | self.initialized.fill_(1) 316 | 317 | h = self.scale * (input + self.loc) 318 | 319 | if squeeze: 320 | h = h.squeeze(-1).squeeze(-1) 321 | 322 | if self.logdet: 323 | log_abs = torch.log(torch.abs(self.scale)) 324 | logdet = height * width * torch.sum(log_abs) 325 | logdet = logdet * torch.ones(input.shape[0]).to(input) 326 | return h, logdet 327 | 328 | return h 329 | 330 | def reverse(self, output): 331 | if self.training and self.initialized.item() == 0: 332 | if not self.allow_reverse_init: 333 | raise RuntimeError( 334 | "Initializing ActNorm in reverse direction is " 335 | "disabled by default. Use allow_reverse_init=True to enable." 336 | ) 337 | else: 338 | self.initialize(output) 339 | self.initialized.fill_(1) 340 | 341 | if len(output.shape) == 2: 342 | output = output[:, :, None, None] 343 | squeeze = True 344 | else: 345 | squeeze = False 346 | 347 | h = output / self.scale - self.loc 348 | 349 | if squeeze: 350 | h = h.squeeze(-1).squeeze(-1) 351 | return h 352 | -------------------------------------------------------------------------------- /medvae/utils/vae/diffusionmodels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | __all__ = ["Encoder", "Decoder"] 7 | 8 | # ----------------------ENCODER & DECODER DEFINITIONS------------------------ 9 | 10 | 11 | class Encoder(nn.Module): 12 | def __init__( 13 | self, 14 | *, 15 | ch, 16 | out_ch, 17 | ch_mult=(1, 2, 4, 8), 18 | num_res_blocks, 19 | attn_resolutions, 20 | dropout=0.0, 21 | resamp_with_conv=True, 22 | in_channels, 23 | resolution, 24 | z_channels, 25 | double_z=True, 26 | use_linear_attn=False, 27 | attn_type="vanilla", 28 | **ignore_kwargs, 29 | ): 30 | super().__init__() 31 | if use_linear_attn: 32 | attn_type = "linear" 33 | self.ch = ch 34 | self.temb_ch = 0 35 | self.num_resolutions = len(ch_mult) 36 | self.num_res_blocks = num_res_blocks 37 | self.resolution = resolution 38 | self.in_channels = in_channels 39 | 40 | # downsampling 41 | self.conv_in = torch.nn.Conv2d( 42 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 43 | ) 44 | 45 | curr_res = resolution 46 | in_ch_mult = (1,) + tuple(ch_mult) 47 | self.in_ch_mult = in_ch_mult 48 | self.down = nn.ModuleList() 49 | for i_level in range(self.num_resolutions): 50 | block = nn.ModuleList() 51 | attn = nn.ModuleList() 52 | block_in = ch * in_ch_mult[i_level] 53 | block_out = ch * ch_mult[i_level] 54 | for i_block in range(self.num_res_blocks): 55 | block.append( 56 | ResnetBlock( 57 | in_channels=block_in, 58 | out_channels=block_out, 59 | temb_channels=self.temb_ch, 60 | dropout=dropout, 61 | ) 62 | ) 63 | block_in = block_out 64 | if curr_res in attn_resolutions: 65 | attn.append(make_attn(block_in, attn_type=attn_type)) 66 | down = nn.Module() 67 | down.block = block 68 | down.attn = attn 69 | if i_level != self.num_resolutions - 1: 70 | down.downsample = Downsample(block_in, resamp_with_conv) 71 | curr_res = curr_res // 2 72 | 
self.down.append(down) 73 | 74 | # middle 75 | self.mid = nn.Module() 76 | self.mid.block_1 = ResnetBlock( 77 | in_channels=block_in, 78 | out_channels=block_in, 79 | temb_channels=self.temb_ch, 80 | dropout=dropout, 81 | ) 82 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 83 | self.mid.block_2 = ResnetBlock( 84 | in_channels=block_in, 85 | out_channels=block_in, 86 | temb_channels=self.temb_ch, 87 | dropout=dropout, 88 | ) 89 | 90 | # end 91 | self.norm_out = Normalize(block_in) 92 | self.conv_out = torch.nn.Conv2d( 93 | block_in, 94 | 2 * z_channels if double_z else z_channels, 95 | kernel_size=3, 96 | stride=1, 97 | padding=1, 98 | ) 99 | 100 | def forward(self, x): 101 | # timestep embedding 102 | temb = None 103 | 104 | # downsampling 105 | hs = [self.conv_in(x)] 106 | for i_level in range(self.num_resolutions): 107 | for i_block in range(self.num_res_blocks): 108 | h = self.down[i_level].block[i_block](hs[-1], temb) 109 | if len(self.down[i_level].attn) > 0: 110 | h = self.down[i_level].attn[i_block](h) 111 | hs.append(h) 112 | if i_level != self.num_resolutions - 1: 113 | hs.append(self.down[i_level].downsample(hs[-1])) 114 | 115 | # middle 116 | h = hs[-1] 117 | h = self.mid.block_1(h, temb) 118 | h = self.mid.attn_1(h) 119 | h = self.mid.block_2(h, temb) 120 | 121 | # end 122 | h = self.norm_out(h) 123 | h = nonlinearity(h) 124 | h = self.conv_out(h) 125 | return h 126 | 127 | 128 | class Decoder(nn.Module): 129 | def __init__( 130 | self, 131 | *, 132 | ch, 133 | out_ch, 134 | ch_mult=(1, 2, 4, 8), 135 | num_res_blocks, 136 | attn_resolutions, 137 | dropout=0.0, 138 | resamp_with_conv=True, 139 | in_channels, 140 | resolution, 141 | z_channels, 142 | give_pre_end=False, 143 | tanh_out=False, 144 | use_linear_attn=False, 145 | attn_type="vanilla", 146 | **ignorekwargs, 147 | ): 148 | super().__init__() 149 | if use_linear_attn: 150 | attn_type = "linear" 151 | self.ch = ch 152 | self.temb_ch = 0 153 | self.num_resolutions = len(ch_mult) 154 | self.num_res_blocks = num_res_blocks 155 | self.resolution = resolution 156 | self.in_channels = in_channels 157 | self.give_pre_end = give_pre_end 158 | self.tanh_out = tanh_out 159 | 160 | # compute in_ch_mult, block_in and curr_res at lowest res 161 | # in_ch_mult = (1,) + tuple(ch_mult) 162 | block_in = ch * ch_mult[self.num_resolutions - 1] 163 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 164 | self.z_shape = (1, z_channels, curr_res, curr_res) 165 | print( 166 | "Working with z of shape {} = {} dimensions.".format( 167 | self.z_shape, np.prod(self.z_shape) 168 | ) 169 | ) 170 | 171 | # z to block_in 172 | self.conv_in = torch.nn.Conv2d( 173 | z_channels, block_in, kernel_size=3, stride=1, padding=1 174 | ) 175 | 176 | # middle 177 | self.mid = nn.Module() 178 | self.mid.block_1 = ResnetBlock( 179 | in_channels=block_in, 180 | out_channels=block_in, 181 | temb_channels=self.temb_ch, 182 | dropout=dropout, 183 | ) 184 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 185 | self.mid.block_2 = ResnetBlock( 186 | in_channels=block_in, 187 | out_channels=block_in, 188 | temb_channels=self.temb_ch, 189 | dropout=dropout, 190 | ) 191 | 192 | # upsampling 193 | self.up = nn.ModuleList() 194 | for i_level in reversed(range(self.num_resolutions)): 195 | block = nn.ModuleList() 196 | attn = nn.ModuleList() 197 | block_out = ch * ch_mult[i_level] 198 | for i_block in range(self.num_res_blocks + 1): 199 | block.append( 200 | ResnetBlock( 201 | in_channels=block_in, 202 | out_channels=block_out, 203 | 
temb_channels=self.temb_ch, 204 | dropout=dropout, 205 | ) 206 | ) 207 | block_in = block_out 208 | if curr_res in attn_resolutions: 209 | attn.append(make_attn(block_in, attn_type=attn_type)) 210 | up = nn.Module() 211 | up.block = block 212 | up.attn = attn 213 | if i_level != 0: 214 | up.upsample = Upsample(block_in, resamp_with_conv) 215 | curr_res = curr_res * 2 216 | self.up.insert(0, up) # prepend to get consistent order 217 | 218 | # end 219 | self.norm_out = Normalize(block_in) 220 | self.conv_out = torch.nn.Conv2d( 221 | block_in, out_ch, kernel_size=3, stride=1, padding=1 222 | ) 223 | 224 | def forward(self, z): 225 | # assert z.shape[1:] == self.z_shape[1:] 226 | self.last_z_shape = z.shape 227 | 228 | # timestep embedding 229 | temb = None 230 | 231 | # z to block_in 232 | h = self.conv_in(z) 233 | 234 | # middle 235 | h = self.mid.block_1(h, temb) 236 | h = self.mid.attn_1(h) 237 | h = self.mid.block_2(h, temb) 238 | 239 | # upsampling 240 | for i_level in reversed(range(self.num_resolutions)): 241 | for i_block in range(self.num_res_blocks + 1): 242 | h = self.up[i_level].block[i_block](h, temb) 243 | if len(self.up[i_level].attn) > 0: 244 | h = self.up[i_level].attn[i_block](h) 245 | if i_level != 0: 246 | h = self.up[i_level].upsample(h) 247 | 248 | # end 249 | if self.give_pre_end: 250 | return h 251 | 252 | h = self.norm_out(h) 253 | h = nonlinearity(h) 254 | h = self.conv_out(h) 255 | if self.tanh_out: 256 | h = torch.tanh(h) 257 | return h 258 | 259 | 260 | # ----------------------HELPER FUNCTIONS------------------------ 261 | 262 | 263 | class LinearAttention(nn.Module): 264 | def __init__(self, dim, heads=4, dim_head=32): 265 | super().__init__() 266 | self.heads = heads 267 | hidden_dim = dim_head * heads 268 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 269 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 270 | 271 | def forward(self, x): 272 | b, c, h, w = x.shape 273 | qkv = self.to_qkv(x) 274 | q, k, v = rearrange( 275 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 276 | ) 277 | k = k.softmax(dim=-1) 278 | context = torch.einsum("bhdn,bhen->bhde", k, v) 279 | out = torch.einsum("bhde,bhdn->bhen", context, q) 280 | out = rearrange( 281 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 282 | ) 283 | return self.to_out(out) 284 | 285 | 286 | def nonlinearity(x): 287 | return x * torch.sigmoid(x) 288 | 289 | 290 | def make_attn(in_channels, attn_type="vanilla"): 291 | assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" 292 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 293 | if attn_type == "vanilla": 294 | return AttnBlock(in_channels) 295 | elif attn_type == "none": 296 | return nn.Identity(in_channels) 297 | else: 298 | return LinAttnBlock(in_channels) 299 | 300 | 301 | def Normalize(in_channels, num_groups=32): 302 | return torch.nn.GroupNorm( 303 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 304 | ) 305 | 306 | 307 | class Upsample(nn.Module): 308 | def __init__(self, in_channels, with_conv): 309 | super().__init__() 310 | self.with_conv = with_conv 311 | if self.with_conv: 312 | self.conv = torch.nn.Conv2d( 313 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 314 | ) 315 | 316 | def forward(self, x): 317 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 318 | if self.with_conv: 319 | x = self.conv(x) 320 | return x 321 | 322 | 323 | class Downsample(nn.Module): 324 
| def __init__(self, in_channels, with_conv): 325 | super().__init__() 326 | self.with_conv = with_conv 327 | if self.with_conv: 328 | # no asymmetric padding in torch conv, must do it ourselves 329 | self.conv = torch.nn.Conv2d( 330 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 331 | ) 332 | 333 | def forward(self, x): 334 | if self.with_conv: 335 | pad = (0, 1, 0, 1) 336 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 337 | x = self.conv(x) 338 | else: 339 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 340 | return x 341 | 342 | 343 | class LinAttnBlock(LinearAttention): 344 | def __init__(self, in_channels): 345 | super().__init__(dim=in_channels, heads=1, dim_head=in_channels) 346 | 347 | 348 | class AttnBlock(nn.Module): 349 | def __init__(self, in_channels): 350 | super().__init__() 351 | self.in_channels = in_channels 352 | 353 | self.norm = Normalize(in_channels) 354 | self.q = torch.nn.Conv2d( 355 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 356 | ) 357 | self.k = torch.nn.Conv2d( 358 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 359 | ) 360 | self.v = torch.nn.Conv2d( 361 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 362 | ) 363 | self.proj_out = torch.nn.Conv2d( 364 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 365 | ) 366 | 367 | def forward(self, x): 368 | h_ = x 369 | h_ = self.norm(h_) 370 | q = self.q(h_) 371 | k = self.k(h_) 372 | v = self.v(h_) 373 | 374 | # compute attention 375 | b, c, h, w = q.shape 376 | q = q.reshape(b, c, h * w) 377 | q = q.permute(0, 2, 1) # b,hw,c 378 | k = k.reshape(b, c, h * w) # b,c,hw 379 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 380 | w_ = w_ * (int(c) ** (-0.5)) 381 | w_ = torch.nn.functional.softmax(w_, dim=2) 382 | 383 | # attend to values 384 | v = v.reshape(b, c, h * w) 385 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 386 | h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 387 | h_ = h_.reshape(b, c, h, w) 388 | 389 | h_ = self.proj_out(h_) 390 | 391 | return x + h_ 392 | 393 | 394 | class ResnetBlock(nn.Module): 395 | def __init__( 396 | self, 397 | *, 398 | in_channels, 399 | out_channels=None, 400 | conv_shortcut=False, 401 | dropout, 402 | temb_channels=512, 403 | ): 404 | super().__init__() 405 | self.in_channels = in_channels 406 | out_channels = in_channels if out_channels is None else out_channels 407 | self.out_channels = out_channels 408 | self.use_conv_shortcut = conv_shortcut 409 | 410 | self.norm1 = Normalize(in_channels) 411 | self.conv1 = torch.nn.Conv2d( 412 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 413 | ) 414 | if temb_channels > 0: 415 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 416 | self.norm2 = Normalize(out_channels) 417 | self.dropout = torch.nn.Dropout(dropout) 418 | self.conv2 = torch.nn.Conv2d( 419 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 420 | ) 421 | if self.in_channels != self.out_channels: 422 | if self.use_conv_shortcut: 423 | self.conv_shortcut = torch.nn.Conv2d( 424 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 425 | ) 426 | else: 427 | self.nin_shortcut = torch.nn.Conv2d( 428 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 429 | ) 430 | 431 | def forward(self, x, temb): 432 | h = x 433 | h = self.norm1(h) 434 | h = nonlinearity(h) 435 | h = self.conv1(h) 436 | 437 | if temb is not None: 438 | h = h + 
self.temb_proj(nonlinearity(temb))[:, :, None, None] 439 | 440 | h = self.norm2(h) 441 | h = nonlinearity(h) 442 | h = self.dropout(h) 443 | h = self.conv2(h) 444 | 445 | if self.in_channels != self.out_channels: 446 | if self.use_conv_shortcut: 447 | x = self.conv_shortcut(x) 448 | else: 449 | x = self.nin_shortcut(x) 450 | 451 | return x + h 452 | -------------------------------------------------------------------------------- /medvae/utils/vae/diffusionmodels_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | from torch.utils.checkpoint import checkpoint 6 | 7 | __all__ = ["Encoder", "Decoder"] 8 | 9 | # ----------------------ENCODER & DECODER DEFINITIONS------------------------ 10 | 11 | 12 | class Encoder(nn.Module): 13 | def __init__( 14 | self, 15 | *, 16 | ch, 17 | out_ch, 18 | ch_mult=(1, 2, 4, 8), 19 | num_res_blocks, 20 | attn_resolutions, 21 | dropout=0.0, 22 | resamp_with_conv=True, 23 | in_channels, 24 | resolution, 25 | z_channels, 26 | double_z=True, 27 | use_linear_attn=False, 28 | attn_type="vanilla", 29 | **ignore_kwargs, 30 | ): 31 | super().__init__() 32 | if use_linear_attn: 33 | attn_type = "linear" 34 | self.ch = ch 35 | self.temb_ch = 0 36 | self.num_resolutions = len(ch_mult) 37 | self.num_res_blocks = num_res_blocks 38 | self.resolution = resolution 39 | self.in_channels = in_channels 40 | 41 | # downsampling 42 | self.conv_in = torch.nn.Conv3d( 43 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 44 | ) 45 | 46 | curr_res = resolution 47 | in_ch_mult = (1,) + tuple(ch_mult) 48 | self.in_ch_mult = in_ch_mult 49 | self.down = nn.ModuleList() 50 | for i_level in range(self.num_resolutions): 51 | block = nn.ModuleList() 52 | attn = nn.ModuleList() 53 | block_in = ch * in_ch_mult[i_level] 54 | block_out = ch * ch_mult[i_level] 55 | for i_block in range(self.num_res_blocks): 56 | block.append( 57 | ResnetBlock( 58 | in_channels=block_in, 59 | out_channels=block_out, 60 | temb_channels=self.temb_ch, 61 | dropout=dropout, 62 | ) 63 | ) 64 | block_in = block_out 65 | if curr_res in attn_resolutions: 66 | attn.append(make_attn(block_in, attn_type=attn_type)) 67 | down = nn.Module() 68 | down.block = block 69 | down.attn = attn 70 | if i_level != self.num_resolutions - 1: 71 | down.downsample = Downsample(block_in, resamp_with_conv) 72 | curr_res = curr_res // 2 73 | self.down.append(down) 74 | 75 | # middle 76 | self.mid = nn.Module() 77 | self.mid.block_1 = ResnetBlock( 78 | in_channels=block_in, 79 | out_channels=block_in, 80 | temb_channels=self.temb_ch, 81 | dropout=dropout, 82 | ) 83 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 84 | self.mid.block_2 = ResnetBlock( 85 | in_channels=block_in, 86 | out_channels=block_in, 87 | temb_channels=self.temb_ch, 88 | dropout=dropout, 89 | ) 90 | 91 | # end 92 | self.norm_out = Normalize(block_in) 93 | self.conv_out = torch.nn.Conv3d( 94 | block_in, 95 | 2 * z_channels if double_z else z_channels, 96 | kernel_size=3, 97 | stride=1, 98 | padding=1, 99 | ) 100 | 101 | def forward(self, x): 102 | # timestep embedding 103 | temb = None 104 | 105 | # downsampling 106 | hs = [checkpoint(self.conv_in, x, use_reentrant=False)] 107 | for i_level in range(self.num_resolutions): 108 | for i_block in range(self.num_res_blocks): 109 | h = checkpoint( 110 | self.down[i_level].block[i_block], hs[-1], temb, use_reentrant=False 111 | ) 112 | if len(self.down[i_level].attn) > 0: 113 | 
h = checkpoint( 114 | self.down[i_level].attn[i_block], h, use_reentrant=False 115 | ) 116 | hs.append(h) 117 | if i_level != self.num_resolutions - 1: 118 | hs.append( 119 | checkpoint( 120 | self.down[i_level].downsample, hs[-1], use_reentrant=False 121 | ) 122 | ) 123 | 124 | # middle 125 | h = hs[-1] 126 | h = checkpoint(self.mid.block_1, h, temb, use_reentrant=False) 127 | h = checkpoint(self.mid.attn_1, h, use_reentrant=False) 128 | h = checkpoint(self.mid.block_2, h, temb, use_reentrant=False) 129 | 130 | # end 131 | h = checkpoint(self.norm_out, h, use_reentrant=False) 132 | h = nonlinearity(h) 133 | h = checkpoint(self.conv_out, h, use_reentrant=False) 134 | return h 135 | 136 | 137 | class Decoder(nn.Module): 138 | def __init__( 139 | self, 140 | *, 141 | ch, 142 | out_ch, 143 | ch_mult=(1, 2, 4, 8), 144 | num_res_blocks, 145 | attn_resolutions, 146 | dropout=0.0, 147 | resamp_with_conv=True, 148 | in_channels, 149 | resolution, 150 | z_channels, 151 | give_pre_end=False, 152 | tanh_out=False, 153 | use_linear_attn=False, 154 | attn_type="vanilla", 155 | **ignorekwargs, 156 | ): 157 | super().__init__() 158 | if use_linear_attn: 159 | attn_type = "linear" 160 | self.ch = ch 161 | self.temb_ch = 0 162 | self.num_resolutions = len(ch_mult) 163 | self.num_res_blocks = num_res_blocks 164 | self.resolution = resolution 165 | self.in_channels = in_channels 166 | self.give_pre_end = give_pre_end 167 | self.tanh_out = tanh_out 168 | 169 | # compute in_ch_mult, block_in and curr_res at lowest res 170 | # in_ch_mult = (1,) + tuple(ch_mult) 171 | block_in = ch * ch_mult[self.num_resolutions - 1] 172 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 173 | self.z_shape = (1, z_channels, curr_res, curr_res) 174 | print( 175 | "Working with z of shape {} = {} dimensions.".format( 176 | self.z_shape, np.prod(self.z_shape) 177 | ) 178 | ) 179 | 180 | # z to block_in 181 | self.conv_in = torch.nn.Conv3d( 182 | z_channels, block_in, kernel_size=3, stride=1, padding=1 183 | ) 184 | 185 | # middle 186 | self.mid = nn.Module() 187 | self.mid.block_1 = ResnetBlock( 188 | in_channels=block_in, 189 | out_channels=block_in, 190 | temb_channels=self.temb_ch, 191 | dropout=dropout, 192 | ) 193 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 194 | self.mid.block_2 = ResnetBlock( 195 | in_channels=block_in, 196 | out_channels=block_in, 197 | temb_channels=self.temb_ch, 198 | dropout=dropout, 199 | ) 200 | 201 | # upsampling 202 | self.up = nn.ModuleList() 203 | for i_level in reversed(range(self.num_resolutions)): 204 | block = nn.ModuleList() 205 | attn = nn.ModuleList() 206 | block_out = ch * ch_mult[i_level] 207 | for i_block in range(self.num_res_blocks + 1): 208 | block.append( 209 | ResnetBlock( 210 | in_channels=block_in, 211 | out_channels=block_out, 212 | temb_channels=self.temb_ch, 213 | dropout=dropout, 214 | ) 215 | ) 216 | block_in = block_out 217 | if curr_res in attn_resolutions: 218 | attn.append(make_attn(block_in, attn_type=attn_type)) 219 | up = nn.Module() 220 | up.block = block 221 | up.attn = attn 222 | if i_level != 0: 223 | up.upsample = Upsample(block_in, resamp_with_conv) 224 | curr_res = curr_res * 2 225 | self.up.insert(0, up) # prepend to get consistent order 226 | 227 | # end 228 | self.norm_out = Normalize(block_in) 229 | self.conv_out = torch.nn.Conv3d( 230 | block_in, out_ch, kernel_size=3, stride=1, padding=1 231 | ) 232 | 233 | def forward(self, z): 234 | # assert z.shape[1:] == self.z_shape[1:] 235 | self.last_z_shape = z.shape 236 | 237 | # timestep
embedding 238 | temb = None 239 | 240 | # z to block_in 241 | h = checkpoint(self.conv_in, z, use_reentrant=False) 242 | 243 | # middle 244 | h = checkpoint(self.mid.block_1, h, temb, use_reentrant=False) 245 | h = checkpoint(self.mid.attn_1, h, use_reentrant=False) 246 | h = checkpoint(self.mid.block_2, h, temb, use_reentrant=False) 247 | 248 | # upsampling 249 | for i_level in reversed(range(self.num_resolutions)): 250 | for i_block in range(self.num_res_blocks + 1): 251 | h = checkpoint( 252 | self.up[i_level].block[i_block], h, temb, use_reentrant=False 253 | ) 254 | if len(self.up[i_level].attn) > 0: 255 | h = checkpoint( 256 | self.up[i_level].attn[i_block], h, use_reentrant=False 257 | ) 258 | if i_level != 0: 259 | h = checkpoint(self.up[i_level].upsample, h, use_reentrant=False) 260 | 261 | # end 262 | if self.give_pre_end: 263 | return h 264 | 265 | h = checkpoint(self.norm_out, h, use_reentrant=False) 266 | h = nonlinearity(h) 267 | h = checkpoint(self.conv_out, h, use_reentrant=False) 268 | if self.tanh_out: 269 | h = checkpoint(torch.tanh, h, use_reentrant=False) 270 | return h 271 | 272 | 273 | # ----------------------HELPER FUNCTIONS------------------------ 274 | 275 | 276 | class LinearAttention(nn.Module): 277 | def __init__(self, dim, heads=4, dim_head=32): 278 | super().__init__() 279 | self.heads = heads 280 | hidden_dim = dim_head * heads 281 | self.to_qkv = nn.Conv3d(dim, hidden_dim * 3, 1, bias=False) 282 | self.to_out = nn.Conv3d(hidden_dim, dim, 1) 283 | 284 | def forward(self, x): 285 | b, c, h, w = x.shape 286 | qkv = self.to_qkv(x) 287 | q, k, v = rearrange( 288 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 289 | ) 290 | k = k.softmax(dim=-1) 291 | context = torch.einsum("bhdn,bhen->bhde", k, v) 292 | out = torch.einsum("bhde,bhdn->bhen", context, q) 293 | out = rearrange( 294 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 295 | ) 296 | return self.to_out(out) 297 | 298 | 299 | def nonlinearity(x): 300 | return x * torch.sigmoid(x) 301 | 302 | 303 | def make_attn(in_channels, attn_type="vanilla"): 304 | assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" 305 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 306 | if attn_type == "vanilla": 307 | return AttnBlock(in_channels) 308 | elif attn_type == "none": 309 | return nn.Identity(in_channels) 310 | else: 311 | return LinAttnBlock(in_channels) 312 | 313 | 314 | def Normalize(in_channels, num_groups=32): 315 | return torch.nn.GroupNorm( 316 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 317 | ) 318 | 319 | 320 | class Upsample(nn.Module): 321 | def __init__(self, in_channels, with_conv): 322 | super().__init__() 323 | self.with_conv = with_conv 324 | if self.with_conv: 325 | self.conv = torch.nn.Conv3d( 326 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 327 | ) 328 | 329 | def forward(self, x): 330 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 331 | if self.with_conv: 332 | x = self.conv(x) 333 | return x 334 | 335 | 336 | class Downsample(nn.Module): 337 | def __init__(self, in_channels, with_conv): 338 | super().__init__() 339 | self.with_conv = with_conv 340 | if self.with_conv: 341 | # no asymmetric padding in torch conv, must do it ourselves 342 | self.conv = torch.nn.Conv3d( 343 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 344 | ) 345 | 346 | def forward(self, x): 347 | if self.with_conv: 348 | pad = 
(0, 1, 0, 1, 0, 1) 349 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 350 | x = self.conv(x) 351 | else: 352 | x = torch.nn.functional.avg_pool3d(x, kernel_size=2, stride=2) 353 | return x 354 | 355 | 356 | class LinAttnBlock(LinearAttention): 357 | def __init__(self, in_channels): 358 | super().__init__(dim=in_channels, heads=1, dim_head=in_channels) 359 | 360 | 361 | class AttnBlock(nn.Module): 362 | def __init__(self, in_channels): 363 | super().__init__() 364 | self.in_channels = in_channels 365 | 366 | self.norm = Normalize(in_channels) 367 | self.q = torch.nn.Conv3d( 368 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 369 | ) 370 | self.k = torch.nn.Conv3d( 371 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 372 | ) 373 | self.v = torch.nn.Conv3d( 374 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 375 | ) 376 | self.proj_out = torch.nn.Conv3d( 377 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 378 | ) 379 | 380 | def forward(self, x): 381 | h_ = x 382 | h_ = self.norm(h_) 383 | q = self.q(h_) 384 | k = self.k(h_) 385 | v = self.v(h_) 386 | 387 | # compute attention 388 | b, c, d, h, w = q.shape 389 | q = q.reshape(b, c, d * h * w) 390 | q = q.permute(0, 2, 1) # b,dhw,c 391 | k = k.reshape(b, c, d * h * w) # b,c,dhw 392 | w_ = torch.bmm(q, k) # b,dhw,dhw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 393 | w_ = w_ * (int(c) ** (-0.5)) 394 | w_ = torch.nn.functional.softmax(w_, dim=2) 395 | 396 | # # attend to values 397 | v = v.reshape(b, c, d * h * w) 398 | w_ = w_.permute(0, 2, 1) # b,dhw,dhw (first hw of k, second of q) 399 | h_ = torch.bmm(v, w_) # b, c,dhw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 400 | h_ = h_.reshape(b, c, d, h, w) 401 | 402 | h_ = self.proj_out(h_) 403 | 404 | return x + h_ 405 | 406 | 407 | class ResnetBlock(nn.Module): 408 | def __init__( 409 | self, 410 | *, 411 | in_channels, 412 | out_channels=None, 413 | conv_shortcut=False, 414 | dropout, 415 | temb_channels=512, 416 | ): 417 | super().__init__() 418 | self.in_channels = in_channels 419 | out_channels = in_channels if out_channels is None else out_channels 420 | self.out_channels = out_channels 421 | self.use_conv_shortcut = conv_shortcut 422 | 423 | self.norm1 = Normalize(in_channels) 424 | self.conv1 = torch.nn.Conv3d( 425 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 426 | ) 427 | if temb_channels > 0: 428 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 429 | self.norm2 = Normalize(out_channels) 430 | self.dropout = torch.nn.Dropout(dropout) 431 | self.conv2 = torch.nn.Conv3d( 432 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 433 | ) 434 | if self.in_channels != self.out_channels: 435 | if self.use_conv_shortcut: 436 | self.conv_shortcut = torch.nn.Conv3d( 437 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 438 | ) 439 | else: 440 | self.nin_shortcut = torch.nn.Conv3d( 441 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 442 | ) 443 | 444 | def forward(self, x, temb): 445 | h = x 446 | h = self.norm1(h) 447 | h = nonlinearity(h) 448 | h = self.conv1(h) 449 | 450 | if temb is not None: 451 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 452 | 453 | h = self.norm2(h) 454 | h = nonlinearity(h) 455 | h = self.dropout(h) 456 | h = self.conv2(h) 457 | 458 | if self.in_channels != self.out_channels: 459 | if self.use_conv_shortcut: 460 | x = self.conv_shortcut(x) 461 | else: 462 | x = self.nin_shortcut(x) 463 | 464 | return x + h 465 | 
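466 | 
467 | # Minimal smoke-test sketch for the 3D encoder/decoder pair defined above. The
468 | # hyperparameters below (ch, ch_mult, resolution, z_channels, ...) are
469 | # illustrative assumptions chosen only to keep the example small and fast on CPU;
470 | # they are not the values used by the released MedVAE models.
471 | if __name__ == "__main__":
472 |     cfg = dict(
473 |         ch=32,
474 |         out_ch=1,
475 |         ch_mult=(1, 2),
476 |         num_res_blocks=1,
477 |         attn_resolutions=[],
478 |         in_channels=1,
479 |         resolution=32,
480 |         z_channels=4,
481 |     )
482 |     enc = Encoder(double_z=True, **cfg)
483 |     dec = Decoder(**cfg)
484 | 
485 |     x = torch.randn(1, 1, 32, 32, 32)  # (batch, channel, depth, height, width)
486 |     moments = enc(x)  # (1, 2 * z_channels, 16, 16, 16): concatenated mean / logvar
487 |     mean, _ = torch.chunk(moments, 2, dim=1)
488 |     recon = dec(mean)  # (1, 1, 32, 32, 32): back to the input resolution
489 |     print(moments.shape, recon.shape)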
--------------------------------------------------------------------------------