├── requirements.txt ├── model-zoo ├── models │ ├── brain_image_synthesis_latent_diffusion_model │ │ ├── scripts │ │ │ ├── __init__.py │ │ │ ├── saver.py │ │ │ └── sampler.py │ │ ├── docs │ │ │ ├── figure_1.png │ │ │ └── README.md │ │ ├── large_files.yml │ │ └── configs │ │ │ ├── logging.conf │ │ │ ├── metadata.json │ │ │ └── inference.json │ ├── cxr_image_synthesis_latent_diffusion_model │ │ ├── scripts │ │ │ ├── __init__.py │ │ │ ├── saver.py │ │ │ └── sampler.py │ │ ├── docs │ │ │ ├── figure_1.png │ │ │ └── README.md │ │ ├── large_files.yml │ │ └── configs │ │ │ ├── logging.conf │ │ │ ├── metadata.json │ │ │ └── inference.json │ └── mednist_ddpm │ │ └── bundle │ │ ├── configs │ │ ├── logging.conf │ │ ├── train_multigpu.yaml │ │ ├── infer.yaml │ │ ├── common.yaml │ │ ├── metadata.json │ │ └── train.yaml │ │ ├── scripts │ │ └── __init__.py │ │ └── docs │ │ ├── README.md │ │ ├── sub_train.sh │ │ └── sub_train_multigpu.sh └── README.md ├── requirements-min.txt ├── .deepsource.toml ├── generative ├── networks │ ├── __init__.py │ ├── layers │ │ └── __init__.py │ ├── blocks │ │ ├── __init__.py │ │ ├── encoder_modules.py │ │ ├── spade_norm.py │ │ ├── transformerblock.py │ │ └── selfattention.py │ ├── schedulers │ │ └── __init__.py │ └── nets │ │ ├── __init__.py │ │ └── transformer.py ├── version.py ├── __init__.py ├── engines │ ├── __init__.py │ └── prepare_batch.py ├── losses │ ├── __init__.py │ └── spectral_loss.py ├── metrics │ ├── __init__.py │ ├── mmd.py │ ├── fid.py │ └── ms_ssim.py ├── utils │ ├── __init__.py │ ├── misc.py │ ├── enums.py │ └── component_store.py └── inferers │ └── __init__.py ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── tests ├── test_compute_fid_metric.py ├── __init__.py ├── test_transformer.py ├── test_compute_mmd_metric.py ├── test_misc.py ├── test_controlnet.py ├── test_selfattention.py ├── min_tests.py ├── test_component_store.py ├── test_scheduler_ddim.py ├── test_scheduler_pndm.py ├── test_spectral_loss.py ├── test_perceptual_loss.py ├── test_compute_multiscalessim_metric.py ├── test_vector_quantizer.py ├── test_scheduler_ddpm.py ├── test_encoder_modules.py ├── test_adversarial.py ├── test_spade_vaegan.py ├── runner.py ├── test_patch_gan.py └── test_integration_workflows_adversarial.py ├── requirements-dev.txt ├── pyproject.toml ├── .pre-commit-config.yaml ├── README.md ├── CODE_OF_CONDUCT.md ├── setup.cfg ├── .gitignore └── tutorials └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.17 2 | torch>=1.8 3 | monai>=1.2.0rc1 4 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-min.txt: -------------------------------------------------------------------------------- 1 | # Requirements for minimal tests 2 | -r requirements.txt 3 | setuptools>65.5.0,<66.0.0 4 | coverage>=5.5 5 | parameterized 6 | -------------------------------------------------------------------------------- 
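Note on the requirements files above: `requirements.txt` lists the runtime dependencies, and `requirements-min.txt` layers the minimal-test extras on top of it (a fuller `requirements-dev.txt` appears later in this listing). As a rough, untested sketch of how an environment could be set up from the repository root, assuming a plain venv/pip workflow (the file names come from this listing; the commands themselves are only illustrative):
# create and activate an isolated environment with any recent Python 3 interpreter
python -m venv .venv
. .venv/bin/activate
# runtime dependencies only
pip install -r requirements.txt
# or: runtime dependencies plus the extras needed for the minimal test suite
pip install -r requirements-min.txt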
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Project-MONAI/GenerativeModels/HEAD/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/figure_1.png -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Project-MONAI/GenerativeModels/HEAD/model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/figure_1.png -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["tests/**"] 4 | 5 | [[analyzers]] 6 | name = "python" 7 | enabled = true 8 | 9 | [analyzers.meta] 10 | runtime_version = "3.x.x" 11 | 12 | [[analyzers]] 13 | name = "test-coverage" 14 | enabled = true 15 | 16 | [[analyzers]] 17 | name = "docker" 18 | enabled = true 19 | 20 | [[analyzers]] 21 | name = "shell" 22 | enabled = true 23 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml: -------------------------------------------------------------------------------- 1 | large_files: 2 | - path: "models/autoencoder.pth" 3 | url: "https://drive.google.com/uc?export=download&id=1paDN1m-Q_Oy8d_BanPkRTi3RlNB_Sv_h" 4 | hash_val: "" 5 | hash_type: "" 6 | - path: "models/diffusion_model.pth" 7 | url: "https://drive.google.com/uc?export=download&id=1CjcmiPu5_QWr-f7wDJsXrCCcVeczneGT" 8 | hash_val: "" 9 | hash_type: "" 10 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/large_files.yml: -------------------------------------------------------------------------------- 1 | large_files: 2 | - path: "models/autoencoder.pth" 3 | url: "https://drive.google.com/uc?export=download&id=1CZHwxHJWybOsDavipD0EorDPOo_mzNeX" 4 | hash_val: "" 5 | hash_type: "" 6 | - path: "models/diffusion_model.pth" 7 | url: "https://drive.google.com/uc?export=download&id=1XO-ak93ZuOcGTCpgRtqgIeZq3dG5ExN6" 8 | hash_val: "" 9 | hash_type: "" 10 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=fullFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | handlers=consoleHandler 13 | 14 | [handler_consoleHandler] 15 | class=StreamHandler 16 | level=INFO 17 | formatter=fullFormatter 18 | args=(sys.stdout,) 19 | 20 | [formatter_fullFormatter] 21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s 22 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=fullFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | handlers=consoleHandler 13 | 14 | [handler_consoleHandler] 
15 | class=StreamHandler 16 | level=INFO 17 | formatter=fullFormatter 18 | args=(sys.stdout,) 19 | 20 | [formatter_fullFormatter] 21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s 22 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=fullFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | handlers=consoleHandler 13 | 14 | [handler_consoleHandler] 15 | class=StreamHandler 16 | level=INFO 17 | formatter=fullFormatter 18 | args=(sys.stdout,) 19 | 20 | [formatter_fullFormatter] 21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s 22 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | def inv_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: 5 | """ 6 | Invert the comparison for metrics which decrease, such as loss values, so that the lower value is treated as the better one. 7 | 8 | Args: 9 | current_metric: metric value from the current round of computation. 10 | prev_best: the best metric value from previous rounds to compare against. 11 | """ 12 | return current_metric < prev_best 13 | -------------------------------------------------------------------------------- /model-zoo/README.md: -------------------------------------------------------------------------------- 1 | # Generative Models - Model Zoo 2 | 3 | This directory contains prototypes of the model zoo for the MONAI Generative Models project. 4 | These prototypes do not support all the features of the [official model zoo](https://github.com/Project-MONAI/model-zoo). 5 | For this reason, the model weights cannot be downloaded automatically by the `python -m monai.bundle run ...` command. 6 | To use our models, please download the weights manually from the links given in the `large_files.yml` files, 7 | and place them in the folder paths specified in the same .yml files. 8 | -------------------------------------------------------------------------------- /generative/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
11 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/saver.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | 8 | class JPGSaver: 9 | def __init__(self, output_dir: str) -> None: 10 | super().__init__() 11 | self.output_dir = output_dir 12 | 13 | def save(self, image_data: torch.Tensor, file_name: str) -> None: 14 | image_data = np.clip(image_data.cpu().numpy(), 0, 1) 15 | image_data = (image_data * 255).astype(np.uint8) 16 | im = Image.fromarray(image_data[0, 0]) 17 | im.save(self.output_dir + "/" + file_name + ".jpg") 18 | -------------------------------------------------------------------------------- /generative/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | __version__ = "0.2.3" 15 | -------------------------------------------------------------------------------- /generative/networks/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from .vector_quantizer import EMAQuantizer, VectorQuantizer 13 | -------------------------------------------------------------------------------- /generative/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from .version import __version__ 15 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/docs/README.md: -------------------------------------------------------------------------------- 1 | 2 | # MedNIST DDPM Example Bundle 3 | 4 | This implements roughly equivalent code to the "Denoising Diffusion Probabilistic Models with MedNIST Dataset" example notebook. This includes scripts for training with single or multiple GPUs and a visualisation notebook. 5 | 6 | The files included here demonstrate how to use the bundle: 7 | * [2d_ddpm_bundle_tutorial.ipynb](./2d_ddpm_bundle_tutorial.ipynb) - demonstrates command line and in-code invocation of the bundle's training and inference scripts 8 | * [sub_train.sh](sub_train.sh) - SLURM submission script example for training 9 | * [sub_train_multigpu.sh](sub_train_multigpu.sh) - SLURM submission script example for training with multiple GPUs 10 | -------------------------------------------------------------------------------- /generative/engines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch 15 | from .trainer import AdversarialTrainer 16 | -------------------------------------------------------------------------------- /generative/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .adversarial_loss import PatchAdversarialLoss 15 | from .perceptual import PerceptualLoss 16 | from .spectral_loss import JukeboxLoss 17 | -------------------------------------------------------------------------------- /generative/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .fid import FIDMetric 15 | from .mmd import MMDMetric 16 | from .ms_ssim import MultiScaleSSIMMetric 17 | from .ssim import SSIMMetric 18 | -------------------------------------------------------------------------------- /generative/networks/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .encoder_modules import SpatialRescaler 15 | from .selfattention import SABlock 16 | from .transformerblock import TransformerBlock 17 | -------------------------------------------------------------------------------- /generative/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .component_store import ComponentStore 15 | from .enums import AdversarialIterationEvents, AdversarialKeys 16 | from .misc import unsqueeze_left, unsqueeze_right 17 | -------------------------------------------------------------------------------- /generative/networks/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from .ddim import DDIMScheduler 15 | from .ddpm import DDPMScheduler 16 | from .pndm import PNDMScheduler 17 | from .scheduler import NoiseSchedules, Scheduler 18 | -------------------------------------------------------------------------------- /generative/inferers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .inferer import ( 15 | ControlNetDiffusionInferer, 16 | ControlNetLatentDiffusionInferer, 17 | DiffusionInferer, 18 | LatentDiffusionInferer, 19 | VQVAETransformerInferer, 20 | ) 21 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/train_multigpu.yaml: -------------------------------------------------------------------------------- 1 | # This can be mixed in with the training script to enable multi-GPU training 2 | 3 | network: 4 | _target_: torch.nn.parallel.DistributedDataParallel 5 | module: $@network_def.to(@device) 6 | device_ids: ['@device'] 7 | find_unused_parameters: true 8 | 9 | tsampler: 10 | _target_: DistributedSampler 11 | dataset: '@train_ds' 12 | even_divisible: true 13 | shuffle: true 14 | train_loader#sampler: '@tsampler' 15 | train_loader#shuffle: false 16 | 17 | vsampler: 18 | _target_: DistributedSampler 19 | dataset: '@val_ds' 20 | even_divisible: false 21 | shuffle: false 22 | val_loader#sampler: '@vsampler' 23 | 24 | training: 25 | - $import torch.distributed as dist 26 | - $dist.init_process_group(backend='nccl') 27 | - $torch.cuda.set_device(@device) 28 | - $monai.utils.set_determinism(seed=123) 29 | - $@trainer.run() 30 | - $dist.destroy_process_group() 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
11 | 12 | from __future__ import annotations 13 | 14 | from setuptools import find_packages, setup 15 | 16 | setup( 17 | name="monai-generative", 18 | packages=find_packages(exclude=[]), 19 | version="0.2.3", 20 | description="Installer to help use the prototypes from MONAI Generative Models in other projects.", 21 | install_requires=["monai>=1.3.0"], 22 | ) 23 | -------------------------------------------------------------------------------- /generative/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from typing import TypeVar 15 | 16 | T = TypeVar("T") 17 | 18 | 19 | def unsqueeze_right(arr: T, ndim: int) -> T: 20 | """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" 21 | return arr[(...,) + (None,) * (ndim - arr.ndim)] 22 | 23 | 24 | def unsqueeze_left(arr: T, ndim: int) -> T: 25 | """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" 26 | return arr[(None,) * (ndim - arr.ndim)] 27 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/docs/sub_train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH -J mednist_train 4 | #SBATCH -c 4 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --time=2:00:00 7 | #SBATCH -p small 8 | 9 | set -v 10 | 11 | # change this if the job is submitted from a different directory 12 | export BUNDLE="$(pwd)/.." 13 | 14 | # have to set PYTHONPATH to find MONAI and GenerativeModels as well as the bundle's script directory 15 | export PYTHONPATH="$HOME/MONAI:$HOME/GenerativeModels:$BUNDLE" 16 | 17 | # change this to load a checkpoint instead of starting from scratch 18 | CKPT=none 19 | 20 | CONFIG="'$BUNDLE/configs/common.yaml', '$BUNDLE/configs/train.yaml'" 21 | 22 | # change this to point to where MedNIST is located 23 | DATASET="$(pwd)" 24 | 25 | # it's useful to include the configuration in the log file 26 | cat "$BUNDLE/configs/common.yaml" 27 | cat "$BUNDLE/configs/train.yaml" 28 | 29 | python -m monai.bundle run training \ 30 | --meta_file "$BUNDLE/configs/metadata.json" \ 31 | --config_file "$CONFIG" \ 32 | --logging_file "$BUNDLE/configs/logging.conf" \ 33 | --bundle_root "$BUNDLE" \ 34 | --dataset_dir "$DATASET" 35 | -------------------------------------------------------------------------------- /generative/networks/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .autoencoderkl import AutoencoderKL 15 | from .controlnet import ControlNet 16 | from .diffusion_model_unet import DiffusionModelUNet 17 | from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator 18 | from .spade_autoencoderkl import SPADEAutoencoderKL 19 | from .spade_diffusion_model_unet import SPADEDiffusionModelUNet 20 | from .spade_network import SPADENet 21 | from .transformer import DecoderOnlyTransformer 22 | from .vqvae import VQVAE 23 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import nibabel as nib 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class NiftiSaver: 9 | def __init__(self, output_dir: str) -> None: 10 | super().__init__() 11 | self.output_dir = output_dir 12 | self.affine = np.array( 13 | [ 14 | [-1.0, 0.0, 0.0, 96.48149872], 15 | [0.0, 1.0, 0.0, -141.47715759], 16 | [0.0, 0.0, 1.0, -156.55375671], 17 | [0.0, 0.0, 0.0, 1.0], 18 | ] 19 | ) 20 | 21 | def save(self, image_data: torch.Tensor, file_name: str) -> None: 22 | image_data = image_data.cpu().numpy() 23 | image_data = image_data[0, 0, 5:-5, 5:-5, :-15] 24 | image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min()) 25 | image_data = (image_data * 255).astype(np.uint8) 26 | 27 | empty_header = nib.Nifti1Header() 28 | sample_nii = nib.Nifti1Image(image_data, self.affine, empty_header) 29 | nib.save(sample_nii, f"{str(self.output_dir)}/{file_name}.nii.gz") 30 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/infer.yaml: -------------------------------------------------------------------------------- 1 | # This defines an inference script that generates a random image and saves it to a PyTorch file 2 | 3 | batch_size: 1 4 | num_workers: 0 5 | 6 | noise: $torch.randn(1,1,@image_dim,@image_dim) # sample standard normal noise, giving a new random image every time this program is run 7 | 8 | out_file: "" # where to save the tensor to 9 | 10 | # using a lambda this defines a simple sampling function used below 11 | sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)' 12 | 13 | load_state: '$@network.load_state_dict(torch.load(@ckpt_path))' # command to load the saved model weights 14 | 15 | save_trans: 16 | _target_: Compose 17 | transforms: 18 | - _target_: ScaleIntensity 19 | minv: 0.0 20 | maxv: 255.0 21 | - _target_: ToTensor 22 | track_meta: false 23 | - _target_: SaveImage 24 | output_ext: "jpg" 25 | resample: false 26 | output_dtype: '$torch.uint8' 27 | separate_folder: false 28 | output_postfix: '@out_file' 29 | 30 | # program to load the model weights, run `sample`, and store results to `out_file` 31 | testing: 32 | - '@load_state' 33 | - '$torch.save(@sample(@noise.to(@device)), @out_file)' 34 | 35 | # alternative version which saves to a jpg file 36 | testing_jpg: 37 | - '@load_state' 38 | - '$@save_trans(@sample(@noise.to(@device))[0])' 39 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/docs/sub_train_multigpu.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | #SBATCH --nodes=1 3 | #SBATCH -J mednist_train 4 | #SBATCH -c 4 5 | #SBATCH --gres=gpu:2 6 | #SBATCH --time=2:00:00 7 | #SBATCH -p big 8 | 9 | set -v 10 | 11 | # change this if the job is submitted from a different directory 12 | export BUNDLE="$(pwd)/.."
13 | 14 | # have to set PYTHONPATH to find MONAI and GenerativeModels as well as the bundle's script directory 15 | export PYTHONPATH="$HOME/MONAI:$HOME/GenerativeModels:$BUNDLE" 16 | 17 | # change this to load a checkpoint instead of starting from scratch 18 | CKPT=none 19 | 20 | CONFIG="'$BUNDLE/configs/common.yaml', '$BUNDLE/configs/train.yaml', '$BUNDLE/configs/train_multigpu.yaml'" 21 | 22 | # change this to point to where MedNIST is located 23 | DATASET="$(pwd)" 24 | 25 | # it's useful to include the configuration in the log file 26 | cat "$BUNDLE/configs/common.yaml" 27 | cat "$BUNDLE/configs/train.yaml" 28 | cat "$BUNDLE/configs/train_multigpu.yaml" 29 | 30 | # remember to change arguments to match how many nodes and GPUs you have 31 | torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training \ 32 | --meta_file "$BUNDLE/configs/metadata.json" \ 33 | --config_file "$CONFIG" \ 34 | --logging_file "$BUNDLE/configs/logging.conf" \ 35 | --bundle_root "$BUNDLE" \ 36 | --dataset_dir "$DATASET" 37 | -------------------------------------------------------------------------------- /tests/test_compute_fid_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import numpy as np 17 | import torch 18 | 19 | from generative.metrics import FIDMetric 20 | 21 | 22 | class TestFIDMetric(unittest.TestCase): 23 | def test_results(self): 24 | x = torch.Tensor([[1, 2], [1, 2], [1, 2]]) 25 | y = torch.Tensor([[2, 2], [1, 2], [1, 2]]) 26 | results = FIDMetric()(x, y) 27 | np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4) 28 | 29 | def test_input_dimensions(self): 30 | with self.assertRaises(ValueError): 31 | FIDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) 32 | 33 | 34 | if __name__ == "__main__": 35 | unittest.main() 36 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Full requirements for developments 2 | -r requirements-min.txt 3 | pytorch-ignite==0.4.10 4 | gdown>=4.4.0 5 | scipy 6 | itk>=5.2 7 | nibabel 8 | pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571 9 | tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776 10 | scikit-image>=0.19.0 11 | tqdm>=4.47.0 12 | lmdb 13 | flake8>=3.8.1 14 | flake8-bugbear 15 | flake8-comprehensions 16 | flake8-executable 17 | pylint!=2.13 # https://github.com/PyCQA/pylint/issues/5969 18 | mccabe 19 | pep8-naming 20 | pycodestyle 21 | pyflakes 22 | black 23 | isort 24 | pytype>=2020.6.1; platform_system != "Windows" 25 | types-pkg_resources 26 | mypy>=0.790 27 | ninja 28 | torchvision 29 | psutil 30 | Sphinx==3.5.3 31 | recommonmark==0.6.0 32 | sphinx-autodoc-typehints==1.11.1 33 | sphinx-rtd-theme==0.5.2 34 | cucim==22.8.1; platform_system == "Linux" 35 | openslide-python==1.1.2 36 | imagecodecs; platform_system == "Linux" or platform_system == "Darwin" 37 | tifffile; platform_system == "Linux" or platform_system == "Darwin" 38 | pandas 39 | requests 40 | einops 41 | transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 42 | mlflow 43 | matplotlib!=3.5.0 44 | tensorboardX 45 | types-PyYAML 46 | pyyaml 47 | fire 48 | jsonschema 49 | pynrrd 50 | pre-commit 51 | pydicom 52 | h5py 53 | nni 54 | optuna 55 | git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded 56 | lpips==0.1.4 57 | xformers==0.0.16 58 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn as nn 5 | from monai.utils import optional_import 6 | from torch.cuda.amp import autocast 7 | 8 | tqdm, has_tqdm = optional_import("tqdm", name="tqdm") 9 | 10 | 11 | class Sampler: 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | @torch.no_grad() 16 | def sampling_fn( 17 | self, 18 | noise: torch.Tensor, 19 | autoencoder_model: nn.Module, 20 | diffusion_model: nn.Module, 21 | scheduler: nn.Module, 22 | prompt_embeds: torch.Tensor, 23 | guidance_scale: float = 7.0, 24 | scale_factor: float = 0.3, 25 | ) -> torch.Tensor: 26 | if has_tqdm: 27 | progress_bar = tqdm(scheduler.timesteps) 28 | else: 29 | progress_bar = iter(scheduler.timesteps) 30 | 31 | for t in progress_bar: 32 | noise_input = torch.cat([noise] * 2) 33 | model_output = diffusion_model( 34 | noise_input, timesteps=torch.Tensor((t,)).to(noise.device).long(), context=prompt_embeds 35 | ) 
36 | noise_pred_uncond, noise_pred_text = model_output.chunk(2) 37 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 38 | noise, _ = scheduler.step(noise_pred, t, noise) 39 | 40 | with autocast(): 41 | sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor) 42 | 43 | return sample 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import sys 15 | import unittest 16 | import warnings 17 | 18 | 19 | def _enter_pr_4800(self): 20 | """ 21 | code from https://github.com/python/cpython/pull/4800 22 | """ 23 | # The __warningregistry__'s need to be in a pristine state for tests 24 | # to work properly. 25 | for v in list(sys.modules.values()): 26 | if getattr(v, "__warningregistry__", None): 27 | v.__warningregistry__ = {} 28 | self.warnings_manager = warnings.catch_warnings(record=True) 29 | self.warnings = self.warnings_manager.__enter__() 30 | warnings.simplefilter("always", self.expected) 31 | return self 32 | 33 | 34 | # FIXME: workaround for https://bugs.python.org/issue29620 35 | try: 36 | # Suppression for issue #494: tests/__init__.py:34: error: Cannot assign to a method 37 | unittest.case._AssertWarnsContext.__enter__ = _enter_pr_4800 # type: ignore 38 | except AttributeError: 39 | pass 40 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch.nn as nn 5 | from monai.utils import optional_import 6 | from torch.cuda.amp import autocast 7 | 8 | tqdm, has_tqdm = optional_import("tqdm", name="tqdm") 9 | 10 | 11 | class Sampler: 12 | def __init__(self) -> None: 13 | super().__init__() 14 | 15 | @torch.no_grad() 16 | def sampling_fn( 17 | self, 18 | input_noise: torch.Tensor, 19 | autoencoder_model: nn.Module, 20 | diffusion_model: nn.Module, 21 | scheduler: nn.Module, 22 | conditioning: torch.Tensor, 23 | ) -> torch.Tensor: 24 | if has_tqdm: 25 | progress_bar = tqdm(scheduler.timesteps) 26 | else: 27 | progress_bar = iter(scheduler.timesteps) 28 | 29 | image = input_noise 30 | cond_concat = conditioning.squeeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 31 | cond_concat = cond_concat.expand(list(cond_concat.shape[0:2]) + list(input_noise.shape[2:])) 32 | for t in progress_bar: 33 | with torch.no_grad(): 34 | model_output = diffusion_model( 35 | torch.cat((image, cond_concat), dim=1), 36 | timesteps=torch.Tensor((t,)).to(input_noise.device).long(), 37 | context=conditioning, 38 | ) 39 | image, _ = scheduler.step(model_output, t, image) 40 | 41 | with torch.no_grad(): 42 | with 
autocast(): 43 | sample = autoencoder_model.decode_stage_2_outputs(image) 44 | 45 | return sample 46 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/common.yaml: -------------------------------------------------------------------------------- 1 | # This file defines common definitions used in training and inference, most importantly the network definition 2 | 3 | imports: 4 | - $import os 5 | - $import datetime 6 | - $import torch 7 | - $import scripts 8 | - $import monai 9 | - $import generative 10 | - $import torch.distributed as dist 11 | 12 | image: $monai.utils.CommonKeys.IMAGE 13 | label: $monai.utils.CommonKeys.LABEL 14 | pred: $monai.utils.CommonKeys.PRED 15 | 16 | is_dist: '$dist.is_initialized()' 17 | rank: '$dist.get_rank() if @is_dist else 0' 18 | is_not_rank0: '$@rank > 0' 19 | device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")' 20 | 21 | network_def: 22 | _target_: generative.networks.nets.DiffusionModelUNet 23 | spatial_dims: 2 24 | in_channels: 1 25 | out_channels: 1 26 | num_channels: [64, 128, 128] 27 | attention_levels: [false, true, true] 28 | num_res_blocks: 1 29 | num_head_channels: 128 30 | 31 | network: $@network_def.to(@device) 32 | 33 | bundle_root: . 34 | ckpt_path: $@bundle_root + '/models/model.pt' 35 | use_amp: true 36 | image_dim: 64 37 | image_size: [1, '@image_dim', '@image_dim'] 38 | num_train_timesteps: 1000 39 | 40 | base_transforms: 41 | - _target_: LoadImaged 42 | keys: '@image' 43 | image_only: true 44 | - _target_: EnsureChannelFirstd 45 | keys: '@image' 46 | - _target_: ScaleIntensityRanged 47 | keys: '@image' 48 | a_min: 0.0 49 | a_max: 255.0 50 | b_min: 0.0 51 | b_max: 1.0 52 | clip: true 53 | 54 | scheduler: 55 | _target_: generative.networks.schedulers.DDPMScheduler 56 | num_train_timesteps: '@num_train_timesteps' 57 | 58 | inferer: 59 | _target_: generative.inferers.DiffusionInferer 60 | scheduler: '@scheduler' 61 | -------------------------------------------------------------------------------- /tests/test_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from monai.networks import eval_mode 18 | 19 | from generative.networks.nets import DecoderOnlyTransformer 20 | 21 | 22 | class TestDecoderOnlyTransformer(unittest.TestCase): 23 | def test_unconditioned_models(self): 24 | net = DecoderOnlyTransformer( 25 | num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2 26 | ) 27 | with eval_mode(net): 28 | net.forward(torch.randint(0, 10, (1, 16))) 29 | 30 | def test_conditioned_models(self): 31 | net = DecoderOnlyTransformer( 32 | num_tokens=10, 33 | max_seq_len=16, 34 | attn_layers_dim=8, 35 | attn_layers_depth=2, 36 | attn_layers_heads=2, 37 | with_cross_attention=True, 38 | embedding_dropout_rate=0, 39 | ) 40 | with eval_mode(net): 41 | net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8)) 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /tests/test_compute_mmd_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import numpy as np 17 | import torch 18 | from parameterized import parameterized 19 | 20 | from generative.metrics import MMDMetric 21 | 22 | TEST_CASES = [ 23 | [ 24 | {"y_transform": None, "y_pred_transform": None}, 25 | {"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])}, 26 | 0.0, 27 | ], 28 | [ 29 | {"y_transform": None, "y_pred_transform": None}, 30 | {"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])}, 31 | 0.0, 32 | ], 33 | ] 34 | 35 | 36 | class TestMMDMetric(unittest.TestCase): 37 | @parameterized.expand(TEST_CASES) 38 | def test_results(self, input_param, input_data, expected_val): 39 | metric = MMDMetric(**input_param) 40 | results = metric(**input_data) 41 | np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) 42 | 43 | def test_if_inputs_different_shapes(self): 44 | with self.assertRaises(ValueError): 45 | MMDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145])) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37', 'py38', 'py39', 'py310'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | ( 7 | /( 8 | # exclude a few common directories in the root of the project 9 | \.eggs 10 | | \.git 11 | | \.hg 12 | | \.mypy_cache 13 | | \.tox 14 | | \.venv 15 | | venv 16 | | \.pytype 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | # also separately exclude a file named versioneer.py 23 | | 
generative/_version.py 24 | ) 25 | ''' 26 | 27 | [tool.pycln] 28 | all = true 29 | 30 | [tool.pytype] 31 | # Space-separated list of files or directories to exclude. 32 | exclude = ["versioneer.py", "_version.py", "tutorials/"] 33 | # Space-separated list of files or directories to process. 34 | inputs = ["generative"] 35 | # Keep going past errors to analyze as many files as possible. 36 | keep_going = true 37 | # Run N jobs in parallel. 38 | jobs = 8 39 | # All pytype output goes here. 40 | output = ".pytype" 41 | # Paths to source code directories, separated by ':'. 42 | pythonpath = "." 43 | # Check attribute values against their annotations. 44 | check_attribute_types = true 45 | # Check container mutations against their annotations. 46 | check_container_types = true 47 | # Check parameter defaults and assignments against their annotations. 48 | check_parameter_types = true 49 | # Check variable values against their annotations. 50 | check_variable_types = true 51 | # Comma or space separated list of error names to ignore. 52 | disable = ["pyi-error"] 53 | # Report errors. 54 | report_errors = true 55 | # Experimental: Infer precise return types even for invalid function calls. 56 | precise_return = true 57 | # Experimental: solve unknown types to label with structural types. 58 | protocols = true 59 | # Experimental: Only load submodules that are explicitly imported. 60 | strict_import = false 61 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import numpy as np 17 | import torch 18 | from parameterized import parameterized 19 | 20 | from generative.utils import unsqueeze_left, unsqueeze_right 21 | 22 | RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))] 23 | 24 | LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))] 25 | 26 | ALL_CASES = [ 27 | (np.random.rand(3, 4), 2, (3, 4)), 28 | (np.random.rand(3, 4), 0, (3, 4)), 29 | (np.random.rand(3, 4), -1, (3, 4)), 30 | (np.array(3), 4, (1, 1, 1, 1)), 31 | (np.array(3), 0, ()), 32 | (torch.rand(3, 4), 2, (3, 4)), 33 | (torch.rand(3, 4), 0, (3, 4)), 34 | (torch.rand(3, 4), -1, (3, 4)), 35 | (torch.tensor(3), 4, (1, 1, 1, 1)), 36 | (torch.tensor(3), 0, ()), 37 | ] 38 | 39 | 40 | class TestUnsqueeze(unittest.TestCase): 41 | @parameterized.expand(RIGHT_CASES + ALL_CASES) 42 | def test_unsqueeze_right(self, arr, ndim, shape): 43 | self.assertEqual(unsqueeze_right(arr, ndim).shape, shape) 44 | 45 | @parameterized.expand(LEFT_CASES + ALL_CASES) 46 | def test_unsqueeze_left(self, arr, ndim, shape): 47 | self.assertEqual(unsqueeze_left(arr, ndim).shape, shape) 48 | -------------------------------------------------------------------------------- /tests/test_controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from monai.networks import eval_mode 18 | from parameterized import parameterized 19 | 20 | from generative.networks.nets.controlnet import ControlNet 21 | 22 | TEST_CASES = [ 23 | [ 24 | { 25 | "spatial_dims": 2, 26 | "in_channels": 1, 27 | "num_res_blocks": 1, 28 | "num_channels": (8, 8, 8), 29 | "attention_levels": (False, False, True), 30 | "num_head_channels": 8, 31 | "norm_num_groups": 8, 32 | "conditioning_embedding_in_channels": 1, 33 | "conditioning_embedding_num_channels": (8, 8), 34 | }, 35 | 6, 36 | (1, 8, 4, 4), 37 | ] 38 | ] 39 | 40 | 41 | class TestControlNet(unittest.TestCase): 42 | @parameterized.expand(TEST_CASES) 43 | def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): 44 | net = ControlNet(**input_param) 45 | with eval_mode(net): 46 | result = net.forward( 47 | torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32)) 48 | ) 49 | self.assertEqual(len(result[0]), expected_num_down_blocks_residuals) 50 | self.assertEqual(result[1].shape, expected_shape) 51 | 52 | 53 | if __name__ == "__main__": 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220729.json", 3 | "version": "0.1.0", 4 | "changelog": { 5 | "0.1.0": "Initial version" 6 | }, 7 | "monai_version": "1.0.0", 8 | "pytorch_version": "1.10.2", 9 | "numpy_version": "1.21.2", 10 | "optional_packages_version": { 11 | "generative": "0.1.0" 12 | }, 13 | "task": "MedNIST Hand Generation", 14 | "description": "", 15 | "authors": "Walter Hugo Lopez Pinaya, Mark Graham, and Eric Kerfoot", 16 | "copyright": "Copyright (c) KCL", 17 | "references": [], 18 | "intended_use": "This is suitable for research purposes only", 19 | "image_classes": "Single channel magnitude data", 20 | "data_source": "MedNIST", 21 | "network_data_format": { 22 | "inputs": { 23 | "image": { 24 | "type": "image", 25 | "format": "magnitude", 26 | "modality": "xray", 27 | "num_channels": 1, 28 | "spatial_shape": [ 29 | 1, 30 | 64, 31 | 64 32 | ], 33 | "dtype": "float32", 34 | "value_range": [], 35 | "is_patch_data": false, 36 | "channel_def": { 37 | "0": "image" 38 | } 39 | } 40 | }, 41 | "outputs": { 42 | "pred": { 43 | "type": "image", 44 | "format": "magnitude", 45 | "modality": "xray", 46 | "num_channels": 1, 47 | "spatial_shape": [ 48 | 1, 49 | 64, 50 | 64 51 | ], 52 | "dtype": "float32", 53 | "value_range": [], 54 | "is_patch_data": false, 55 | "channel_def": { 56 | "0": "image" 57 | } 58 | } 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | #default_language_version: 2 | # python: python3.8 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | # submodules: true 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.4.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | exclude: ^tutorials/ 16 | - id: trailing-whitespace 17 | - id: check-yaml 18 | - id: check-docstring-first 19 | 
- id: check-executables-have-shebangs 20 | - id: check-toml 21 | - id: check-case-conflict 22 | - id: check-added-large-files 23 | args: ['--maxkb=1024'] 24 | - id: detect-private-key 25 | - id: forbid-new-submodules 26 | - id: pretty-format-json 27 | args: ['--autofix', '--no-sort-keys', '--indent=4'] 28 | - id: mixed-line-ending 29 | 30 | - repo: https://github.com/asottile/pyupgrade 31 | rev: v3.3.1 32 | hooks: 33 | - id: pyupgrade 34 | args: [--py37-plus] 35 | name: Upgrade code 36 | exclude: | 37 | (?x)^( 38 | versioneer.py| 39 | monai/_version.py 40 | )$ 41 | 42 | - repo: https://github.com/asottile/yesqa 43 | rev: v1.4.0 44 | hooks: 45 | - id: yesqa 46 | name: Unused noqa 47 | additional_dependencies: 48 | - flake8>=3.8.1 49 | - flake8-bugbear 50 | - flake8-comprehensions 51 | - flake8-executable 52 | - flake8-pyi 53 | - pep8-naming 54 | exclude: | 55 | (?x)^( 56 | generative/__init__.py| 57 | docs/source/conf.py 58 | )$ 59 | 60 | - repo: https://github.com/hadialqattan/pycln 61 | rev: v2.1.2 62 | hooks: 63 | - id: pycln 64 | args: [--config=pyproject.toml] 65 | 66 | # - repo: https://github.com/psf/black 67 | # rev: 22.3.0 68 | # hooks: 69 | # - id: black 70 | # 71 | # - repo: https://github.com/PyCQA/isort 72 | # rev: 5.9.3 73 | # hooks: 74 | # - id: isort 75 | -------------------------------------------------------------------------------- /tests/test_selfattention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from monai.networks import eval_mode 18 | from parameterized import parameterized 19 | 20 | from generative.networks.blocks.selfattention import SABlock 21 | 22 | TEST_CASE_SABLOCK = [ 23 | [ 24 | {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None}, 25 | (2, 4, 16), 26 | (2, 4, 16), 27 | ], 28 | [ 29 | {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4}, 30 | (2, 4, 16), 31 | (2, 4, 16), 32 | ], 33 | ] 34 | 35 | 36 | class TestResBlock(unittest.TestCase): 37 | @parameterized.expand(TEST_CASE_SABLOCK) 38 | def test_shape(self, input_param, input_shape, expected_shape): 39 | net = SABlock(**input_param) 40 | with eval_mode(net): 41 | result = net(torch.randn(input_shape)) 42 | self.assertEqual(result.shape, expected_shape) 43 | 44 | def test_ill_arg(self): 45 | with self.assertRaises(ValueError): 46 | SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0) 47 | 48 | with self.assertRaises(ValueError): 49 | SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0) 50 | 51 | with self.assertRaises(ValueError): 52 | SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4) 53 | 54 | with self.assertRaises(ValueError): 55 | SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None) 56 | 57 | 58 | if __name__ == "__main__": 59 | unittest.main() 60 | -------------------------------------------------------------------------------- /generative/utils/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from typing import TYPE_CHECKING 15 | 16 | from monai.config import IgniteInfo 17 | from monai.utils import StrEnum, min_version, optional_import 18 | 19 | if TYPE_CHECKING: 20 | from ignite.engine import EventEnum 21 | else: 22 | EventEnum, _ = optional_import( 23 | "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" 24 | ) 25 | 26 | 27 | class AdversarialKeys(StrEnum): 28 | REALS = "reals" 29 | REAL_LOGITS = "real_logits" 30 | FAKES = "fakes" 31 | FAKE_LOGITS = "fake_logits" 32 | RECONSTRUCTION_LOSS = "reconstruction_loss" 33 | GENERATOR_LOSS = "generator_loss" 34 | DISCRIMINATOR_LOSS = "discriminator_loss" 35 | 36 | 37 | class AdversarialIterationEvents(EventEnum): 38 | RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" 39 | GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" 40 | GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" 41 | GENERATOR_LOSS_COMPLETED = "generator_loss_completed" 42 | GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed" 43 | GENERATOR_MODEL_COMPLETED = "generator_model_completed" 44 | DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed" 45 | DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed" 46 | DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" 47 | DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" 48 | DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" 49 | 50 | 51 | class OrderingType(StrEnum): 52 | RASTER_SCAN = "raster_scan" 53 | S_CURVE = "s_curve" 54 | RANDOM = "random" 55 | 56 | 57 | class OrderingTransformations(StrEnum): 58 | ROTATE_90 = "rotate_90" 59 | TRANSPOSE = "transpose" 60 | REFLECT = "reflect" 61 | -------------------------------------------------------------------------------- /tests/min_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import glob 15 | import os 16 | import sys 17 | import unittest 18 | 19 | 20 | def run_testsuit(): 21 | """ 22 | Load test cases by excluding those need external dependencies. 
23 | The loaded cases should work with "requirements-min.txt":: 24 | 25 | # in the monai repo folder: 26 | pip install -r requirements-min.txt 27 | QUICKTEST=true python -m tests.min_tests 28 | 29 | :return: a test suite 30 | """ 31 | exclude_cases = [ # these cases use external dependencies 32 | "test_autoencoderkl", 33 | "test_diffusion_inferer", 34 | "test_integration_workflows_adversarial", 35 | "test_latent_diffusion_inferer", 36 | "test_perceptual_loss", 37 | "test_transformer", 38 | ] 39 | assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" 40 | 41 | files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py")) 42 | 43 | cases = [] 44 | for case in files: 45 | test_module = os.path.basename(case)[:-3] 46 | if test_module in exclude_cases: 47 | exclude_cases.remove(test_module) 48 | print(f"skipping tests.{test_module}.") 49 | else: 50 | cases.append(f"tests.{test_module}") 51 | assert not exclude_cases, f"items in exclude_cases not used: {exclude_cases}." 52 | test_suite = unittest.TestLoader().loadTestsFromNames(cases) 53 | return test_suite 54 | 55 | 56 | if __name__ == "__main__": 57 | # testing import submodules 58 | from monai.utils.module import load_submodules 59 | 60 | _, err_mod = load_submodules(sys.modules["monai"], True) 61 | assert not err_mod, f"err_mod={err_mod} not empty" 62 | 63 | # testing all modules 64 | test_runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2) 65 | result = test_runner.run(run_testsuit()) 66 | sys.exit(int(not result.wasSuccessful())) 67 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", 3 | "version": "1.0.0", 4 | "changelog": { 5 | "0.2": "Flipped images fixed" 6 | }, 7 | "monai_version": "1.1.0", 8 | "pytorch_version": "1.13.0", 9 | "numpy_version": "1.22.4", 10 | "optional_packages_version": { 11 | "nibabel": "4.0.1", 12 | "generative": "0.1.0", 13 | "transformers": "4.26.1" 14 | }, 15 | "task": "Chest X-ray image synthesis", 16 | "description": "A generative model for creating high-resolution chest X-ray based on MIMIC dataset", 17 | "copyright": "Copyright (c) MONAI Consortium", 18 | "data_source": "https://physionet.org/content/mimic-cxr-jpg/2.0.0/", 19 | "data_type": "image", 20 | "image_classes": "Radiography (X-ray) with 512 x 512 pixels", 21 | "intended_use": "This is a research tool/prototype and not to be used clinically", 22 | "network_data_format": { 23 | "inputs": { 24 | "latent_representation": { 25 | "type": "image", 26 | "format": "magnitude", 27 | "modality": "CXR", 28 | "num_channels": 3, 29 | "spatial_shape": [ 30 | 64, 31 | 64 32 | ], 33 | "dtype": "float32", 34 | "value_range": [], 35 | "is_patch_data": false 36 | }, 37 | "timesteps": { 38 | "type": "vector", 39 | "value_range": [ 40 | 0, 41 | 1000 42 | ], 43 | "dtype": "long" 44 | }, 45 | "context": { 46 | "type": "vector", 47 | "value_range": [], 48 | "dtype": "float32" 49 | } 50 | }, 51 | "outputs": { 52 | "pred": { 53 | "type": "image", 54 | "format": "magnitude", 55 | "modality": "CXR", 56 | "num_channels": 1, 57 | "spatial_shape": [ 58 | 512, 59 | 512 60 | ], 61 | "dtype": "float32", 62 | "value_range": [ 63 | 0, 64 | 1 65 | ], 66 | "is_patch_data": false, 67 | "channel_def": { 
68 | "0": "X-ray" 69 | } 70 | } 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /tests/test_component_store.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | from generative.utils import ComponentStore 17 | 18 | 19 | class TestComponentStore(unittest.TestCase): 20 | def setUp(self): 21 | self.cs = ComponentStore("TestStore", "I am a test store, please ignore") 22 | 23 | def test_empty(self): 24 | self.assertEqual(len(self.cs), 0) 25 | self.assertEqual(list(self.cs), []) 26 | 27 | def test_add(self): 28 | test_obj = object() 29 | 30 | self.assertFalse("test_obj" in self.cs) 31 | 32 | self.cs.add("test_obj", "Test object", test_obj) 33 | 34 | self.assertTrue("test_obj" in self.cs) 35 | 36 | self.assertEqual(len(self.cs), 1) 37 | self.assertEqual(list(self.cs), [("test_obj", test_obj)]) 38 | 39 | self.assertEqual(self.cs.test_obj, test_obj) 40 | self.assertEqual(self.cs["test_obj"], test_obj) 41 | 42 | def test_add2(self): 43 | test_obj1 = object() 44 | test_obj2 = object() 45 | 46 | self.cs.add("test_obj1", "Test object", test_obj1) 47 | self.cs.add("test_obj2", "Test object", test_obj2) 48 | 49 | self.assertEqual(len(self.cs), 2) 50 | self.assertTrue("test_obj1" in self.cs) 51 | self.assertTrue("test_obj2" in self.cs) 52 | 53 | def test_add_def(self): 54 | self.assertFalse("test_func" in self.cs) 55 | 56 | @self.cs.add_def("test_func", "Test function") 57 | def test_func(): 58 | return 123 59 | 60 | self.assertTrue("test_func" in self.cs) 61 | 62 | self.assertEqual(len(self.cs), 1) 63 | self.assertEqual(list(self.cs), [("test_func", test_func)]) 64 | 65 | self.assertEqual(self.cs.test_func, test_func) 66 | self.assertEqual(self.cs["test_func"], test_func) 67 | 68 | # try adding the same function again 69 | self.cs.add_def("test_func", "Test function but with new description")(test_func) 70 | 71 | self.assertEqual(len(self.cs), 1) 72 | self.assertEqual(self.cs.test_func, test_func) 73 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", 3 | "version": "1.0.0", 4 | "changelog": { 5 | "1.0.8": "Initial release" 6 | }, 7 | "monai_version": "1.1.0", 8 | "pytorch_version": "1.13.0", 9 | "numpy_version": "1.22.4", 10 | "optional_packages_version": { 11 | "nibabel": "4.0.1", 12 | "generative": "0.1.0" 13 | }, 14 | "task": "Brain image synthesis", 15 | "description": "A generative model for creating high-resolution 3D brain MRI based on UK Biobank", 16 | "authors": "Walter H. L. 
Pinaya, Petru-Daniel Tudosiu, Jessica Dafflon, Pedro F Da Costa, Virginia Fernandez, Parashkev Nachev, Sebastien Ourselin, and M. Jorge Cardoso", 17 | "copyright": "Copyright (c) MONAI Consortium", 18 | "data_source": "https://www.ukbiobank.ac.uk/", 19 | "data_type": "nibabel", 20 | "image_classes": "T1w head MRI with 1x1x1 mm voxel size", 21 | "eval_metrics": { 22 | "fid": 0.0076, 23 | "msssim": 0.6555, 24 | "4gmsssim": 0.3883 25 | }, 26 | "intended_use": "This is a research tool/prototype and not to be used clinically", 27 | "references": [ 28 | "Pinaya, Walter HL, et al. \"Brain imaging generation with latent diffusion models.\" MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022." 29 | ], 30 | "network_data_format": { 31 | "inputs": { 32 | "image": { 33 | "type": "tabular", 34 | "num_channels": 1, 35 | "dtype": "float32", 36 | "value_range": [ 37 | 0, 38 | 1 39 | ], 40 | "is_patch_data": false, 41 | "channel_def": { 42 | "0": "Gender", 43 | "1": "Age", 44 | "2": "Ventricular volume", 45 | "3": "Brain volume" 46 | } 47 | } 48 | }, 49 | "outputs": { 50 | "pred": { 51 | "type": "image", 52 | "format": "magnitude", 53 | "modality": "MR", 54 | "num_channels": 1, 55 | "spatial_shape": [ 56 | 160, 57 | 224, 58 | 160 59 | ], 60 | "dtype": "float32", 61 | "value_range": [ 62 | 0, 63 | 1 64 | ], 65 | "is_patch_data": false, 66 | "channel_def": { 67 | "0": "T1w" 68 | } 69 | } 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /tests/test_scheduler_ddim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.networks.schedulers import DDIMScheduler 20 | 21 | TEST_2D_CASE = [] 22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 23 | TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) 24 | 25 | TEST_3D_CASE = [] 26 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 27 | TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) 28 | 29 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE 30 | 31 | 32 | class TestDDPMScheduler(unittest.TestCase): 33 | @parameterized.expand(TEST_CASES) 34 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape): 35 | scheduler = DDIMScheduler(**input_param) 36 | scheduler.set_timesteps(num_inference_steps=100) 37 | original_sample = torch.zeros(input_shape) 38 | noise = torch.randn_like(original_sample) 39 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() 40 | 41 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) 42 | self.assertEqual(noisy.shape, expected_shape) 43 | 44 | @parameterized.expand(TEST_CASES) 45 | def test_step_shape(self, input_param, input_shape, expected_shape): 46 | scheduler = DDIMScheduler(**input_param) 47 | scheduler.set_timesteps(num_inference_steps=100) 48 | model_output = torch.randn(input_shape) 49 | sample = torch.randn(input_shape) 50 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) 51 | self.assertEqual(output_step[0].shape, expected_shape) 52 | self.assertEqual(output_step[1].shape, expected_shape) 53 | 54 | def test_set_timesteps(self): 55 | scheduler = DDIMScheduler(num_train_timesteps=1000) 56 | scheduler.set_timesteps(num_inference_steps=100) 57 | self.assertEqual(scheduler.num_inference_steps, 100) 58 | self.assertEqual(len(scheduler.timesteps), 100) 59 | 60 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): 61 | scheduler = DDIMScheduler(num_train_timesteps=1000) 62 | with self.assertRaises(ValueError): 63 | scheduler.set_timesteps(num_inference_steps=2000) 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/README.md: -------------------------------------------------------------------------------- 1 | # Brain Imaging Generation with Latent Diffusion Models 2 | 3 | ### **Authors** 4 | 5 | Walter H. L. Pinaya, Petru-Daniel Tudosiu, Jessica Dafflon, Pedro F Da Costa, Virginia Fernandez, Parashkev Nachev, 6 | Sebastien Ourselin, and M. Jorge Cardoso 7 | 8 | ### **Tags** 9 | Synthetic data, Latent Diffusion Model, Generative model, Brain Imaging 10 | 11 | ## **Model Description** 12 | This model is trained using the Latent Diffusion Model architecture [1] and is used for the synthesis of conditioned 3D 13 | brain MRI data. The model is divided into two parts: an autoencoder with a KL-regularisation model that compresses data 14 | into a latent space and a diffusion model that learns to generate conditioned synthetic latent representations. This 15 | model is conditioned on age, sex, the volume of ventricular cerebrospinal fluid, and brain volume normalised for head size. 16 | 17 | ![](./figure_1.png)
18 | 19 | Figure 1 - Synthetic image from the model.
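Below is a minimal, illustrative sketch of the two-stage sampling procedure described above. It is **not** the bundle's actual inference pipeline (which is defined by `configs/inference.json` together with `scripts/sampler.py` and the weights listed in `large_files.yml`); the stand-in functions are placeholders for the trained diffusion model and autoencoder decoder, and `DDIMScheduler` is used here purely for illustration:

```
import torch

from generative.networks.schedulers import DDIMScheduler


def placeholder_diffusion_model(latent, timesteps, context):
    # Stand-in for the trained latent diffusion model, which predicts the noise
    # component of `latent` given the timestep(s) and the conditioning vector.
    return torch.randn_like(latent)


def placeholder_decode(latent):
    # Stand-in for the decoder of the trained KL-regularised autoencoder, which
    # maps a latent representation back to image space.
    return latent


# Conditioning: gender, age, ventricular volume, and brain volume, all scaled to [0, 1]
# (the values mirror the command-line example further down).
context = torch.tensor([[[1.0, 0.7, 0.7, 0.5]]])

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

# Illustrative latent shape only; the real latent dimensions are set by the bundle config.
latent = torch.randn(1, 3, 20, 28, 20)

for t in scheduler.timesteps:
    noise_pred = placeholder_diffusion_model(latent, timesteps=t, context=context)
    latent, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=latent)

synthetic_volume = placeholder_decode(latent)  # decoded synthetic T1w volume in the real pipeline
```

The `monai.bundle` command shown under **Command example** below drives this kind of loop with the real networks and configuration.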

20 | 21 | 22 | ## **Data** 23 | The model was trained on brain data from 31,740 participants from the UK Biobank [2]. We used high-resolution 3D T1w MRI with a voxel size of 1 mm³, resulting in volumes of 160 x 224 x 160 voxels. 24 | 25 | #### **Preprocessing** 26 | We used UniRes [3] to perform a rigid body registration to a common MNI space for image pre-processing. The voxel intensity was normalised to the range [0, 1]. 27 | 28 | ## **Performance** 29 | This model achieves the following results on UK Biobank: an FID of 0.0076, an MS-SSIM of 0.6555, and a 4-G-R-SSIM of 0.3883. 30 | 31 | Please check Table 1 of the original paper for more details regarding the evaluation results. 32 | 33 | 34 | ## **Command example** 35 | Execute sampling: 36 | ``` 37 | export PYTHONPATH=$PYTHONPATH:"" 38 | python -m monai.bundle run save_nii --config_file configs/inference.json --gender 1.0 --age 0.7 --ventricular_vol 0.7 --brain_vol 0.5 39 | ``` 40 | All conditioning values are expected to be between 0 and 1. 41 | 42 | ## **Citation Info** 43 | 44 | ``` 45 | @inproceedings{pinaya2022brain, 46 | title={Brain imaging generation with latent diffusion models}, 47 | author={Pinaya, Walter HL and Tudosiu, Petru-Daniel and Dafflon, Jessica and Da Costa, Pedro F and Fernandez, Virginia and Nachev, Parashkev and Ourselin, Sebastien and Cardoso, M Jorge}, 48 | booktitle={MICCAI Workshop on Deep Generative Models}, 49 | pages={117--126}, 50 | year={2022}, 51 | organization={Springer} 52 | } 53 | ``` 54 | 55 | ## **References** 56 | 57 | 58 | 59 | [1] Pinaya, Walter HL, et al. "Brain imaging generation with latent diffusion models." MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022. 60 | 61 | [2] Sudlow, Cathie, et al. "UK biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age." PLoS medicine 12.3 (2015): e1001779. 62 | 63 | [3] Brudfors, Mikael, et al. "MRI super-resolution using multi-channel total variation." Annual Conference on Medical Image Understanding and Analysis. Springer, Cham, 2018. 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | project-monai 3 |
4 | 5 | # MONAI Generative Models 6 | Prototyping repository for generative models to be integrated into MONAI core, MONAI tutorials, and MONAI model zoo. 7 | ## Features 8 | * Network architectures: Diffusion Model, Autoencoder-KL, VQ-VAE, Autoregressive transformers, (Multi-scale) Patch-GAN discriminator. 9 | * Diffusion Model Noise Schedulers: DDPM, DDIM, and PNDM. 10 | * Losses: Adversarial losses, Spectral losses, and Perceptual losses (for 2D and 3D data using LPIPS, RadImageNet, and 3DMedicalNet pre-trained models). 11 | * Metrics: Multi-Scale Structural Similarity Index Measure (MS-SSIM) and Fréchet Inception Distance (FID). 12 | * Diffusion Models, Latent Diffusion Models, and VQ-VAE + Transformer Inferer classes (compatible with MONAI style) containing methods to train, sample synthetic images, and obtain the likelihood of input data. 13 | * MONAI-compatible trainer engine (based on Ignite) to train models with reconstruction and adversarial components. 14 | * Tutorials including: 15 | * How to train VQ-VAEs, VQ-GANs, VQ-VAE + Transformers, AutoencoderKLs, Diffusion Models, and Latent Diffusion Models on 2D and 3D data. 16 | * Training a diffusion model to perform conditional image generation with classifier-free guidance. 17 | * Comparison of different diffusion model schedulers. 18 | * Diffusion models with different parameterizations (e.g., v-prediction and epsilon parameterization). 19 | * Anomaly Detection using VQ-VAE + Transformers and Diffusion Models. 20 | * Inpainting with diffusion models (using the RePaint method). 21 | * Super-resolution with Latent Diffusion Models (using Noise Conditioning Augmentation). 22 | 23 | ## Roadmap 24 | Our short-term goals are available in the [Milestones](https://github.com/Project-MONAI/GenerativeModels/milestones) 25 | section of the repository. 26 | 27 | In the longer term, we aim to integrate the generative models into the MONAI core repository (supporting tasks such as 28 | image synthesis, anomaly detection, MRI reconstruction, and domain transfer). 29 | 30 | ## Installation 31 | To install the current release of MONAI Generative Models, you can run: 32 | ``` 33 | pip install monai-generative 34 | ``` 35 | To install the current main branch of the repository, run: 36 | ``` 37 | pip install git+https://github.com/Project-MONAI/GenerativeModels.git 38 | ``` 39 | Requires Python >= 3.8. 40 | 41 | ## Contributing 42 | For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/GenerativeModels/blob/main/CONTRIBUTING.md). 43 | 44 | ## Community 45 | Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9). 46 | 47 | # Citation 48 | 49 | If you use MONAI Generative in your research, please cite us! The citation can be exported from [the paper](https://arxiv.org/abs/2307.15208).
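For a flavour of how the pieces listed under Features fit together, here is a small, self-contained sketch. It is illustrative only and not one of the documented tutorials; it assumes the package is installed and that `MMDMetric` is exported from `generative.metrics` alongside the other metrics:

```
import torch

from generative.metrics import MMDMetric
from generative.networks.schedulers import DDIMScheduler

# Create a toy batch of 2D "images" and corrupt it with scheduler noise at random timesteps.
images = torch.rand(2, 1, 64, 64)
noise = torch.randn_like(images)

scheduler = DDIMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()
noisy_images = scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

# Compare the noisy batch against the clean one with the kernel-based MMD metric.
metric = MMDMetric()
print(metric(images, noisy_images))
```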
50 | 51 | ## Links 52 | - Website: https://monai.io/ 53 | - Code: https://github.com/Project-MONAI/GenerativeModels 54 | - Project tracker: https://github.com/Project-MONAI/GenerativeModels/projects 55 | - Issue tracker: https://github.com/Project-MONAI/GenerativeModels/issues 56 | -------------------------------------------------------------------------------- /tests/test_scheduler_pndm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.networks.schedulers import PNDMScheduler 20 | 21 | TEST_2D_CASE = [] 22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 23 | TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)]) 24 | 25 | TEST_3D_CASE = [] 26 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 27 | TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]) 28 | 29 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE 30 | 31 | 32 | class TestDDPMScheduler(unittest.TestCase): 33 | @parameterized.expand(TEST_CASES) 34 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape): 35 | scheduler = PNDMScheduler(**input_param) 36 | original_sample = torch.zeros(input_shape) 37 | noise = torch.randn_like(original_sample) 38 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() 39 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) 40 | self.assertEqual(noisy.shape, expected_shape) 41 | 42 | @parameterized.expand(TEST_CASES) 43 | def test_step_shape(self, input_param, input_shape, expected_shape): 44 | scheduler = PNDMScheduler(**input_param) 45 | scheduler.set_timesteps(600) 46 | model_output = torch.randn(input_shape) 47 | sample = torch.randn(input_shape) 48 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) 49 | self.assertEqual(output_step[0].shape, expected_shape) 50 | self.assertEqual(output_step[1], None) 51 | 52 | def test_set_timesteps(self): 53 | scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True) 54 | scheduler.set_timesteps(num_inference_steps=100) 55 | self.assertEqual(scheduler.num_inference_steps, 100) 56 | self.assertEqual(len(scheduler.timesteps), 100) 57 | 58 | def test_set_timesteps_prk(self): 59 | scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False) 60 | scheduler.set_timesteps(num_inference_steps=100) 61 | self.assertEqual(scheduler.num_inference_steps, 109) 62 | self.assertEqual(len(scheduler.timesteps), 109) 63 | 64 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): 65 | scheduler = PNDMScheduler(num_train_timesteps=1000) 66 | with self.assertRaises(ValueError): 67 | 
scheduler.set_timesteps(num_inference_steps=2000) 68 | 69 | 70 | if __name__ == "__main__": 71 | unittest.main() 72 | -------------------------------------------------------------------------------- /tests/test_spectral_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import numpy as np 17 | import torch 18 | from parameterized import parameterized 19 | 20 | from generative.losses import JukeboxLoss 21 | from tests.utils import test_script_save 22 | 23 | TEST_CASES = [ 24 | [ 25 | {"spatial_dims": 2}, 26 | { 27 | "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), 28 | "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), 29 | }, 30 | 0.070648, 31 | ], 32 | [ 33 | {"spatial_dims": 2, "reduction": "sum"}, 34 | { 35 | "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), 36 | "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]), 37 | }, 38 | 0.8478, 39 | ], 40 | [ 41 | {"spatial_dims": 3}, 42 | { 43 | "input": torch.tensor( 44 | [ 45 | [ 46 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], 47 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], 48 | ] 49 | ] 50 | ), 51 | "target": torch.tensor( 52 | [ 53 | [ 54 | [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], 55 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]], 56 | ] 57 | ] 58 | ), 59 | }, 60 | 0.03838, 61 | ], 62 | ] 63 | 64 | 65 | class TestJukeboxLoss(unittest.TestCase): 66 | @parameterized.expand(TEST_CASES) 67 | def test_results(self, input_param, input_data, expected_val): 68 | results = JukeboxLoss(**input_param).forward(**input_data) 69 | np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4) 70 | 71 | def test_2d_shape(self): 72 | results = JukeboxLoss(spatial_dims=2, reduction="none").forward(**TEST_CASES[0][1]) 73 | self.assertEqual(results.shape, (1, 2, 2, 3)) 74 | 75 | def test_3d_shape(self): 76 | results = JukeboxLoss(spatial_dims=3, reduction="none").forward(**TEST_CASES[2][1]) 77 | self.assertEqual(results.shape, (1, 2, 2, 2, 3)) 78 | 79 | def test_script(self): 80 | loss = JukeboxLoss(spatial_dims=2) 81 | test_input = torch.ones(2, 1, 8, 8) 82 | test_script_save(loss, test_input, test_input) 83 | 84 | 85 | if __name__ == "__main__": 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /tests/test_perceptual_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in 
compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.losses import PerceptualLoss 20 | 21 | TEST_CASES = [ 22 | [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], 23 | [ 24 | {"spatial_dims": 3, "network_type": "squeeze", "is_fake_3d": True, "fake_3d_ratio": 0.1}, 25 | (2, 1, 64, 64, 64), 26 | (2, 1, 64, 64, 64), 27 | ], 28 | [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)], 29 | [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)], 30 | [ 31 | {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1}, 32 | (2, 1, 64, 64, 64), 33 | (2, 1, 64, 64, 64), 34 | ], 35 | [ 36 | {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False}, 37 | (2, 1, 64, 64, 64), 38 | (2, 1, 64, 64, 64), 39 | ], 40 | [ 41 | {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2}, 42 | (2, 1, 64, 64, 64), 43 | (2, 1, 64, 64, 64), 44 | ], 45 | ] 46 | 47 | 48 | class TestPerceptualLoss(unittest.TestCase): 49 | @parameterized.expand(TEST_CASES) 50 | def test_shape(self, input_param, input_shape, target_shape): 51 | loss = PerceptualLoss(**input_param) 52 | result = loss(torch.randn(input_shape), torch.randn(target_shape)) 53 | self.assertEqual(result.shape, torch.Size([])) 54 | 55 | @parameterized.expand(TEST_CASES) 56 | def test_identical_input(self, input_param, input_shape, target_shape): 57 | loss = PerceptualLoss(**input_param) 58 | tensor = torch.randn(input_shape) 59 | result = loss(tensor, tensor) 60 | self.assertEqual(result, torch.Tensor([0.0])) 61 | 62 | def test_different_shape(self): 63 | loss = PerceptualLoss(spatial_dims=2, network_type="squeeze") 64 | tensor = torch.randn(2, 1, 64, 64) 65 | target = torch.randn(2, 1, 32, 32) 66 | with self.assertRaises(ValueError): 67 | loss(tensor, target) 68 | 69 | def test_1d(self): 70 | with self.assertRaises(NotImplementedError): 71 | PerceptualLoss(spatial_dims=1) 72 | 73 | def test_medicalnet_on_2d_data(self): 74 | with self.assertRaises(ValueError): 75 | PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets") 76 | 77 | with self.assertRaises(ValueError): 78 | PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets") 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /tests/test_compute_multiscalessim_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from monai.utils import set_determinism 18 | 19 | from generative.metrics import MultiScaleSSIMMetric 20 | 21 | 22 | class TestMultiScaleSSIMMetric(unittest.TestCase): 23 | def test2d_gaussian(self): 24 | set_determinism(0) 25 | preds = torch.abs(torch.randn(1, 1, 64, 64)) 26 | target = torch.abs(torch.randn(1, 1, 64, 64)) 27 | preds = preds / preds.max() 28 | target = target / target.max() 29 | 30 | metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) 31 | metric(preds, target) 32 | result = metric.aggregate() 33 | expected_value = 0.023176 34 | self.assertTrue(expected_value - result.item() < 0.000001) 35 | 36 | def test2d_uniform(self): 37 | set_determinism(0) 38 | preds = torch.abs(torch.randn(1, 1, 64, 64)) 39 | target = torch.abs(torch.randn(1, 1, 64, 64)) 40 | preds = preds / preds.max() 41 | target = target / target.max() 42 | 43 | metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform", weights=[0.5, 0.5]) 44 | metric(preds, target) 45 | result = metric.aggregate() 46 | expected_value = 0.022655 47 | self.assertTrue(expected_value - result.item() < 0.000001) 48 | 49 | def test3d_gaussian(self): 50 | set_determinism(0) 51 | preds = torch.abs(torch.randn(1, 1, 64, 64, 64)) 52 | target = torch.abs(torch.randn(1, 1, 64, 64, 64)) 53 | preds = preds / preds.max() 54 | target = target / target.max() 55 | 56 | metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5]) 57 | metric(preds, target) 58 | result = metric.aggregate() 59 | expected_value = 0.061796 60 | self.assertTrue(expected_value - result.item() < 0.000001) 61 | 62 | def input_ill_input_shape2d(self): 63 | metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) 64 | 65 | with self.assertRaises(ValueError): 66 | metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64)) 67 | 68 | def input_ill_input_shape3d(self): 69 | metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5]) 70 | 71 | with self.assertRaises(ValueError): 72 | metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64)) 73 | 74 | def small_inputs(self): 75 | metric = MultiScaleSSIMMetric(spatial_dims=2) 76 | 77 | with self.assertRaises(ValueError): 78 | metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16)) 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /tests/test_vector_quantizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | 18 | from generative.networks.layers import EMAQuantizer, VectorQuantizer 19 | 20 | 21 | class TestEMA(unittest.TestCase): 22 | def test_ema_shape(self): 23 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8) 24 | input_shape = (1, 8, 8, 8) 25 | x = torch.randn(input_shape) 26 | layer = layer.train() 27 | outputs = layer(x) 28 | self.assertEqual(outputs[0].shape, input_shape) 29 | self.assertEqual(outputs[2].shape, (1, 8, 8)) 30 | 31 | layer = layer.eval() 32 | outputs = layer(x) 33 | self.assertEqual(outputs[0].shape, input_shape) 34 | self.assertEqual(outputs[2].shape, (1, 8, 8)) 35 | 36 | def test_ema_quantize(self): 37 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8) 38 | input_shape = (1, 8, 8, 8) 39 | x = torch.randn(input_shape) 40 | outputs = layer.quantize(x) 41 | self.assertEqual(outputs[0].shape, (64, 8)) # (HxW, C) 42 | self.assertEqual(outputs[1].shape, (64, 16)) # (HxW, E) 43 | self.assertEqual(outputs[2].shape, (1, 8, 8)) # (1, H, W) 44 | 45 | def test_ema(self): 46 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0) 47 | original_weight_0 = layer.embedding.weight[0].clone() 48 | original_weight_1 = layer.embedding.weight[1].clone() 49 | x_0 = original_weight_0 50 | x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 51 | x_0 = x_0.repeat(1, 1, 1, 2) + 0.001 52 | 53 | x_1 = original_weight_1 54 | x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 55 | x_1 = x_1.repeat(1, 1, 1, 2) 56 | 57 | x = torch.cat([x_0, x_1], dim=0) 58 | layer = layer.train() 59 | _ = layer(x) 60 | 61 | self.assertTrue(all(layer.embedding.weight[0] != original_weight_0)) 62 | self.assertTrue(all(layer.embedding.weight[1] == original_weight_1)) 63 | 64 | 65 | class TestVectorQuantizer(unittest.TestCase): 66 | def test_vector_quantizer_shape(self): 67 | layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)) 68 | input_shape = (1, 8, 8, 8) 69 | x = torch.randn(input_shape) 70 | outputs = layer(x) 71 | self.assertEqual(outputs[1].shape, input_shape) 72 | 73 | def test_vector_quantizer_quantize(self): 74 | layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)) 75 | input_shape = (1, 8, 8, 8) 76 | x = torch.randn(input_shape) 77 | outputs = layer.quantize(x) 78 | self.assertEqual(outputs.shape, (1, 8, 8)) 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /generative/metrics/mmd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Callable 15 | 16 | import torch 17 | from monai.metrics.metric import Metric 18 | 19 | 20 | class MMDMetric(Metric): 21 | """ 22 | Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two 23 | distributions. It is a non-negative metric where a smaller value indicates a closer match between the two 24 | distributions. 25 | 26 | Gretton, A., et al,, 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773. 27 | 28 | Args: 29 | y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace 30 | filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a 31 | feature extractor or an Identity function. 32 | y_pred_transform: Callable to transform the y_pred tensor before computing the metric. 33 | """ 34 | 35 | def __init__(self, y_transform: Callable | None = None, y_pred_transform: Callable | None = None) -> None: 36 | super().__init__() 37 | 38 | self.y_transform = y_transform 39 | self.y_pred_transform = y_pred_transform 40 | 41 | def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Args: 44 | y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. 45 | y_pred: second sample (e.g., the reconstructed image). It has similar shape as y. 46 | """ 47 | 48 | # Beta and Gamma are not calculated since torch.mean is used at return 49 | beta = 1.0 50 | gamma = 2.0 51 | 52 | if self.y_transform is not None: 53 | y = self.y_transform(y) 54 | 55 | if self.y_pred_transform is not None: 56 | y_pred = self.y_pred_transform(y_pred) 57 | 58 | if y_pred.shape != y.shape: 59 | raise ValueError( 60 | "y_pred and y shapes dont match after being processed " 61 | f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}" 62 | ) 63 | 64 | for d in range(len(y.shape) - 1, 1, -1): 65 | y = y.squeeze(dim=d) 66 | y_pred = y_pred.squeeze(dim=d) 67 | 68 | y = y.view(y.shape[0], -1) 69 | y_pred = y_pred.view(y_pred.shape[0], -1) 70 | 71 | y_y = torch.mm(y, y.t()) 72 | y_pred_y_pred = torch.mm(y_pred, y_pred.t()) 73 | y_pred_y = torch.mm(y_pred, y.t()) 74 | 75 | y_y = y_y / y.shape[1] 76 | y_pred_y_pred = y_pred_y_pred / y.shape[1] 77 | y_pred_y = y_pred_y / y.shape[1] 78 | 79 | # Ref. 1 Eq. 3 (found under Lemma 6) 80 | return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y) 81 | -------------------------------------------------------------------------------- /generative/networks/blocks/encoder_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Sequence 15 | from functools import partial 16 | 17 | import torch 18 | import torch.nn as nn 19 | from monai.networks.blocks import Convolution 20 | 21 | __all__ = ["SpatialRescaler"] 22 | 23 | 24 | class SpatialRescaler(nn.Module): 25 | """ 26 | SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py 27 | 28 | Args: 29 | spatial_dims: number of spatial dimensions. 30 | n_stages: number of interpolation stages. 31 | size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). 32 | method: algorithm used for sampling. 33 | multiplier: multiplier for spatial size. If `multiplier` is a sequence, 34 | its length has to match the number of spatial dimensions; `input.dim() - 2`. 35 | in_channels: number of input channels. 36 | out_channels: number of output channels. 37 | bias: whether to have a bias term. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | spatial_dims: int = 2, 43 | n_stages: int = 1, 44 | size: Sequence[int] | int | None = None, 45 | method: str = "bilinear", 46 | multiplier: Sequence[float] | float | None = None, 47 | in_channels: int = 3, 48 | out_channels: int = None, 49 | bias: bool = False, 50 | ): 51 | super().__init__() 52 | self.n_stages = n_stages 53 | assert self.n_stages >= 0 54 | assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] 55 | if size is not None and n_stages != 1: 56 | raise ValueError("when size is not None, n_stages should be 1.") 57 | if size is not None and multiplier is not None: 58 | raise ValueError("only one of size or multiplier should be defined.") 59 | self.multiplier = multiplier 60 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size) 61 | self.remap_output = out_channels is not None 62 | if self.remap_output: 63 | print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") 64 | self.channel_mapper = Convolution( 65 | spatial_dims=spatial_dims, 66 | in_channels=in_channels, 67 | out_channels=out_channels, 68 | kernel_size=1, 69 | conv_only=True, 70 | bias=bias, 71 | ) 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | if self.remap_output: 75 | x = self.channel_mapper(x) 76 | 77 | for _ in range(self.n_stages): 78 | x = self.interpolator(x, scale_factor=self.multiplier) 79 | 80 | return x 81 | 82 | def encode(self, x: torch.Tensor) -> torch.Tensor: 83 | return self(x) 84 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and 
expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at monai.contact@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /generative/losses/spectral_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from monai.utils import LossReduction 17 | from torch.fft import fftn 18 | from torch.nn.modules.loss import _Loss 19 | 20 | 21 | class JukeboxLoss(_Loss): 22 | """ 23 | Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT). 24 | 25 | Based on: 26 | Dhariwal, et al. 'Jukebox: A generative model for music.'https://arxiv.org/abs/2005.00341 27 | 28 | Args: 29 | spatial_dims: number of spatial dimensions. 30 | fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information. 31 | fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See 32 | torch.fft.fftn() for more information. 33 | 34 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 35 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 36 | 37 | - ``"none"``: no reduction will be applied. 38 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 39 | - ``"sum"``: the output will be summed. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | spatial_dims: int, 45 | fft_signal_size: tuple[int] | None = None, 46 | fft_norm: str = "ortho", 47 | reduction: LossReduction | str = LossReduction.MEAN, 48 | ) -> None: 49 | super().__init__(reduction=LossReduction(reduction).value) 50 | 51 | self.spatial_dims = spatial_dims 52 | self.fft_signal_size = fft_signal_size 53 | self.fft_dim = tuple(range(1, spatial_dims + 2)) 54 | self.fft_norm = fft_norm 55 | 56 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 57 | input_amplitude = self._get_fft_amplitude(target) 58 | target_amplitude = self._get_fft_amplitude(input) 59 | 60 | # Compute distance between amplitude of frequency components 61 | # See Section 3.3 from https://arxiv.org/abs/2005.00341 62 | loss = F.mse_loss(target_amplitude, input_amplitude, reduction="none") 63 | 64 | if self.reduction == LossReduction.MEAN.value: 65 | loss = loss.mean() 66 | elif self.reduction == LossReduction.SUM.value: 67 | loss = loss.sum() 68 | elif self.reduction == LossReduction.NONE.value: 69 | pass 70 | 71 | return loss 72 | 73 | def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor: 74 | """ 75 | Calculate the amplitude of the fourier transformations representation of the images 76 | 77 | Args: 78 | images: Images that are to undergo fftn 79 | 80 | Returns: 81 | fourier transformation amplitude 82 | """ 83 | img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm) 84 | 85 | amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2) 86 | 87 | return amplitude 88 | -------------------------------------------------------------------------------- /generative/networks/blocks/spade_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from monai.networks.blocks import ADN, Convolution 18 | 19 | 20 | class SPADE(nn.Module): 21 | """ 22 | SPADE normalisation block based on the 2019 paper by Park et al. 
(doi: https://doi.org/10.48550/arXiv.1903.07291) 23 | 24 | Args: 25 | label_nc: number of semantic labels 26 | norm_nc: number of output channels 27 | kernel_size: kernel size 28 | spatial_dims: number of spatial dimensions 29 | hidden_channels: number of channels in the intermediate gamma and beta layers 30 | norm: type of base normalisation used before applying the SPADE normalisation 31 | norm_params: parameters for the base normalisation 32 | """ 33 | 34 | def __init__( 35 | self, 36 | label_nc: int, 37 | norm_nc: int, 38 | kernel_size: int = 3, 39 | spatial_dims: int = 2, 40 | hidden_channels: int = 64, 41 | norm: str | tuple = "INSTANCE", 42 | norm_params: dict | None = None, 43 | ) -> None: 44 | super().__init__() 45 | 46 | if norm_params is None: 47 | norm_params = {} 48 | if len(norm_params) != 0: 49 | norm = (norm, norm_params) 50 | self.param_free_norm = ADN( 51 | act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc 52 | ) 53 | self.mlp_shared = Convolution( 54 | spatial_dims=spatial_dims, 55 | in_channels=label_nc, 56 | out_channels=hidden_channels, 57 | kernel_size=kernel_size, 58 | norm=None, 59 | padding=kernel_size // 2, 60 | act="LEAKYRELU", 61 | ) 62 | self.mlp_gamma = Convolution( 63 | spatial_dims=spatial_dims, 64 | in_channels=hidden_channels, 65 | out_channels=norm_nc, 66 | kernel_size=kernel_size, 67 | padding=kernel_size // 2, 68 | act=None, 69 | ) 70 | self.mlp_beta = Convolution( 71 | spatial_dims=spatial_dims, 72 | in_channels=hidden_channels, 73 | out_channels=norm_nc, 74 | kernel_size=kernel_size, 75 | padding=kernel_size // 2, 76 | act=None, 77 | ) 78 | 79 | def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: 80 | """ 81 | Args: 82 | x: input tensor 83 | segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. 84 | The map will be interpolated to the dimension of x internally. 85 | """ 86 | 87 | # Part 1. generate parameter-free normalized activations 88 | normalized = self.param_free_norm(x) 89 | 90 | # Part 2. produce scaling and bias conditioned on semantic map 91 | segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") 92 | actv = self.mlp_shared(segmap) 93 | gamma = self.mlp_gamma(actv) 94 | beta = self.mlp_beta(actv) 95 | out = normalized * (1 + gamma) + beta 96 | return out 97 | -------------------------------------------------------------------------------- /generative/networks/blocks/transformerblock.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | from monai.networks.blocks.mlp import MLPBlock 17 | 18 | from generative.networks.blocks.selfattention import SABlock 19 | 20 | 21 | class TransformerBlock(nn.Module): 22 | """ 23 | A transformer block, based on: "Dosovitskiy et al., 24 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 25 | 26 | Args: 27 | hidden_size: dimension of hidden layer. 28 | mlp_dim: dimension of feedforward layer. 29 | num_heads: number of attention heads. 30 | dropout_rate: faction of the input units to drop. 31 | qkv_bias: apply bias term for the qkv linear layer 32 | causal: whether to use causal attention. 33 | sequence_length: if causal is True, it is necessary to specify the sequence length. 34 | with_cross_attention: Whether to use cross attention for conditioning. 35 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | hidden_size: int, 41 | mlp_dim: int, 42 | num_heads: int, 43 | dropout_rate: float = 0.0, 44 | qkv_bias: bool = False, 45 | causal: bool = False, 46 | sequence_length: int | None = None, 47 | with_cross_attention: bool = False, 48 | use_flash_attention: bool = False, 49 | ) -> None: 50 | self.with_cross_attention = with_cross_attention 51 | super().__init__() 52 | 53 | if not (0 <= dropout_rate <= 1): 54 | raise ValueError("dropout_rate should be between 0 and 1.") 55 | 56 | if hidden_size % num_heads != 0: 57 | raise ValueError("hidden_size should be divisible by num_heads.") 58 | 59 | self.norm1 = nn.LayerNorm(hidden_size) 60 | self.attn = SABlock( 61 | hidden_size=hidden_size, 62 | num_heads=num_heads, 63 | dropout_rate=dropout_rate, 64 | qkv_bias=qkv_bias, 65 | causal=causal, 66 | sequence_length=sequence_length, 67 | use_flash_attention=use_flash_attention, 68 | ) 69 | 70 | self.norm2 = None 71 | self.cross_attn = None 72 | if self.with_cross_attention: 73 | self.norm2 = nn.LayerNorm(hidden_size) 74 | self.cross_attn = SABlock( 75 | hidden_size=hidden_size, 76 | num_heads=num_heads, 77 | dropout_rate=dropout_rate, 78 | qkv_bias=qkv_bias, 79 | with_cross_attention=with_cross_attention, 80 | causal=False, 81 | use_flash_attention=use_flash_attention, 82 | ) 83 | 84 | self.norm3 = nn.LayerNorm(hidden_size) 85 | self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) 86 | 87 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 88 | x = x + self.attn(self.norm1(x)) 89 | if self.with_cross_attention: 90 | x = x + self.cross_attn(self.norm2(x), context=context) 91 | x = x + self.mlp(self.norm3(x)) 92 | return x 93 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/README.md: -------------------------------------------------------------------------------- 1 | # Chest X-ray with Latent Diffusion Models 2 | 3 | ### **Authors** 4 | 5 | MONAI Generative Models 6 | 7 | ### **Tags** 8 | Synthetic data, Latent Diffusion Model, Generative model, Chest X-ray 9 | 10 | ## **Model Description** 11 | This model is trained from scratch using the Latent Diffusion Model architecture [1] and is used for the synthesis of 12 | 2D Chest X-ray conditioned on Radiological reports. 
The model is divided into two parts: an autoencoder with a 13 | KL-regularisation model that compresses data into a latent space and a diffusion model that learns to generate 14 | conditioned synthetic latent representations. This model is conditioned on Findings and Impressions from radiological 15 | reports. The original repository can be found [here](https://github.com/Warvito/generative_chestxray) 16 | 17 | ![](./figure_1.png)
18 | 19 | Figure 1 - Synthetic images from the model.
20 | 21 | ## **Data** 22 | The model was trained on chest X-ray data from 90,000 participants in the MIMIC-CXR dataset [2] [3]. We downsampled the 23 | original images to 512 x 512 pixels. 24 | 25 | #### **Preprocessing** 26 | We resized the original images so that the smallest side measured 512 pixels and, when feeding them to the network, center 27 | cropped the images to 512 x 512. The pixel intensity was normalised to be between [0, 1]. The text data was obtained 28 | from the associated radiological reports. We randomly extracted sentences from the findings and impressions sections of the 29 | reports, with a maximum of 5 sentences and 77 tokens. The text was tokenised using the CLIPTokenizer from the 30 | transformers package (https://github.com/huggingface/transformers) (pretrained model 31 | "stabilityai/stable-diffusion-2-1-base") and then encoded using the CLIPTextModel from the same package and pretrained 32 | model. 33 | 34 | 35 | ## **Commands Example** 36 | Here we include a few examples of commands to sample images from the model and save them as .jpg files. The available 37 | arguments for this task are: "--prompt" (str), the text prompt to condition the model on, and "--guidance_scale" (float), the 38 | parameter that controls how closely the image generation process follows the text prompt. The higher the value, the more 39 | the image sticks to the given text input (the common range is between 1 and 21). 40 | 41 | Examples: 42 | 43 | ```shell 44 | export PYTHONPATH=$PYTHONPATH:"" 45 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Big right-sided pleural effusion" --guidance_scale 7.0 46 | ``` 47 | 48 | ```shell 49 | export PYTHONPATH=$PYTHONPATH:"" 50 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Small right-sided pleural effusion" --guidance_scale 7.0 51 | ``` 52 | 53 | ```shell 54 | export PYTHONPATH=$PYTHONPATH:"" 55 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Bilateral pleural effusion" --guidance_scale 7.0 56 | ``` 57 | 58 | ```shell 59 | export PYTHONPATH=$PYTHONPATH:"" 60 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Cardiomegaly" --guidance_scale 7.0 61 | ``` 62 | 63 | 64 | ## **References** 65 | 66 | 67 | [1] Pinaya, Walter HL, et al. "Brain imaging generation with latent diffusion models." MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022. 68 | 69 | [2] Johnson, A., Lungren, M., Peng, Y., Lu, Z., Mark, R., Berkowitz, S., & Horng, S. (2019). MIMIC-CXR-JPG - chest radiographs with structured labels (version 2.0.0). PhysioNet. https://doi.org/10.13026/8360-t248. 70 | 71 | [3] Johnson AE, Pollard TJ, Berkowitz S, Greenbaum NR, Lungren MP, Deng CY, Mark RG, Horng S. MIMIC-CXR: A large publicly available database of labeled chest radiographs. arXiv preprint arXiv:1901.07042. 2019 Jan 21.
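
For reference, the text-conditioning step described under **Preprocessing** can be reproduced with a few lines of Python. The snippet below is a minimal, illustrative sketch rather than part of the bundle: it mirrors the tokenizer and text-encoder settings used in `configs/inference.json`, assumes the `transformers` package is installed, and uses an example prompt; the empty string supplies the unconditioned embedding used for classifier-free guidance.

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Same pretrained weights and subfolders as in configs/inference.json
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder")

# First entry is the empty (unconditioned) prompt used for classifier-free guidance
prompts = ["", "Big right-sided pleural effusion"]
tokens = tokenizer(
    prompts,
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77 tokens, as described above
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder(tokens.input_ids)[0]

# Expected shape: (2, 77, 1024), matching the cross_attention_dim of the diffusion model
print(prompt_embeds.shape)
```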
72 | -------------------------------------------------------------------------------- /model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json: -------------------------------------------------------------------------------- 1 | { 2 | "imports": [ 3 | "$import torch", 4 | "$from datetime import datetime", 5 | "$from pathlib import Path" 6 | ], 7 | "bundle_root": ".", 8 | "model_dir": "$@bundle_root + '/models'", 9 | "output_dir": "$@bundle_root + '/output'", 10 | "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", 11 | "gender": 0.0, 12 | "age": 0.1, 13 | "ventricular_vol": 0.2, 14 | "brain_vol": 0.4, 15 | "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", 16 | "conditioning": "$torch.tensor([[@gender, @age, @ventricular_vol, @brain_vol]]).to(@device).unsqueeze(1)", 17 | "out_file": "$datetime.now().strftime('sample_%H%M%S_%d%m%Y') + '_' + str(@gender) + '_' + str(@age) + '_' + str(@ventricular_vol) + '_' + str(@brain_vol)", 18 | "autoencoder_def": { 19 | "_target_": "generative.networks.nets.AutoencoderKL", 20 | "spatial_dims": 3, 21 | "in_channels": 1, 22 | "out_channels": 1, 23 | "latent_channels": 3, 24 | "num_channels": [ 25 | 64, 26 | 128, 27 | 128, 28 | 128 29 | ], 30 | "num_res_blocks": 2, 31 | "norm_num_groups": 32, 32 | "norm_eps": 1e-06, 33 | "attention_levels": [ 34 | false, 35 | false, 36 | false, 37 | false 38 | ], 39 | "with_encoder_nonlocal_attn": false, 40 | "with_decoder_nonlocal_attn": false 41 | }, 42 | "load_autoencoder_path": "$@model_dir + '/autoencoder.pth'", 43 | "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))", 44 | "autoencoder": "$@autoencoder_def.to(@device)", 45 | "diffusion_def": { 46 | "_target_": "generative.networks.nets.DiffusionModelUNet", 47 | "spatial_dims": 3, 48 | "in_channels": 7, 49 | "out_channels": 3, 50 | "num_channels": [ 51 | 256, 52 | 512, 53 | 768 54 | ], 55 | "num_res_blocks": 2, 56 | "attention_levels": [ 57 | false, 58 | true, 59 | true 60 | ], 61 | "norm_num_groups": 32, 62 | "norm_eps": 1e-06, 63 | "resblock_updown": true, 64 | "num_head_channels": [ 65 | 0, 66 | 512, 67 | 768 68 | ], 69 | "with_conditioning": true, 70 | "transformer_num_layers": 1, 71 | "cross_attention_dim": 4, 72 | "upcast_attention": true, 73 | "use_flash_attention": false 74 | }, 75 | "load_diffusion_path": "$@model_dir + '/diffusion_model.pth'", 76 | "load_diffusion": "$@diffusion_def.load_state_dict(torch.load(@load_diffusion_path))", 77 | "diffusion": "$@diffusion_def.to(@device)", 78 | "scheduler": { 79 | "_target_": "generative.networks.schedulers.DDIMScheduler", 80 | "_requires_": [ 81 | "@load_diffusion", 82 | "@load_autoencoder" 83 | ], 84 | "beta_start": 0.0015, 85 | "beta_end": 0.0205, 86 | "num_train_timesteps": 1000, 87 | "schedule": "scaled_linear_beta", 88 | "clip_sample": false 89 | }, 90 | "noise": "$torch.randn((1, 3, 20, 28, 20)).to(@device)", 91 | "set_timesteps": "$@scheduler.set_timesteps(num_inference_steps=50)", 92 | "sampler": { 93 | "_target_": "scripts.sampler.Sampler", 94 | "_requires_": "@set_timesteps" 95 | }, 96 | "sample": "$@sampler.sampling_fn(@noise, @autoencoder, @diffusion, @scheduler, @conditioning)", 97 | "saver": { 98 | "_target_": "scripts.saver.NiftiSaver", 99 | "_requires_": "@create_output_dir", 100 | "output_dir": "@output_dir" 101 | }, 102 | "save_nii": "$@saver.save(@sample, @out_file)", 103 | "save": "$torch.save(@sample, @output_dir + '/' + @out_file + '.pt')" 104 | } 105 | 
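
The inference config above is normally executed through the `monai.bundle` CLI, but it can also be driven directly from Python with `ConfigParser`. The sketch below is illustrative only: it assumes the bundle root is the working directory (so that the `scripts` package is importable), that the checkpoints listed in `large_files.yml` have been downloaded into `./models`, and that the conditioning override values are arbitrary examples.

```python
from monai.bundle import ConfigParser

# Load the bundle's inference configuration
parser = ConfigParser()
parser.read_config("configs/inference.json")

# Override the normalised conditioning variables declared at the top of the config
# (illustrative values only)
parser["gender"] = 1.0
parser["age"] = 0.7
parser["ventricular_vol"] = 0.3
parser["brain_vol"] = 0.5

# Resolving "save_nii" pulls in its dependencies: loading the autoencoder and
# diffusion weights, DDIM sampling, and saving the decoded volume as NIfTI
parser.get_parsed_content("save_nii")
```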
-------------------------------------------------------------------------------- /tests/test_scheduler_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.networks.schedulers import DDPMScheduler 20 | 21 | TEST_2D_CASE = [] 22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 23 | for variance_type in ["fixed_small", "fixed_large"]: 24 | TEST_2D_CASE.append( 25 | [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)] 26 | ) 27 | 28 | TEST_3D_CASE = [] 29 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]: 30 | for variance_type in ["fixed_small", "fixed_large"]: 31 | TEST_3D_CASE.append( 32 | [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)] 33 | ) 34 | 35 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE 36 | 37 | 38 | class TestDDPMScheduler(unittest.TestCase): 39 | @parameterized.expand(TEST_CASES) 40 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape): 41 | scheduler = DDPMScheduler(**input_param) 42 | original_sample = torch.zeros(input_shape) 43 | noise = torch.randn_like(original_sample) 44 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long() 45 | 46 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps) 47 | self.assertEqual(noisy.shape, expected_shape) 48 | 49 | @parameterized.expand(TEST_CASES) 50 | def test_step_shape(self, input_param, input_shape, expected_shape): 51 | scheduler = DDPMScheduler(**input_param) 52 | model_output = torch.randn(input_shape) 53 | sample = torch.randn(input_shape) 54 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) 55 | self.assertEqual(output_step[0].shape, expected_shape) 56 | self.assertEqual(output_step[1].shape, expected_shape) 57 | 58 | @parameterized.expand(TEST_CASES) 59 | def test_get_velocity_shape(self, input_param, input_shape, expected_shape): 60 | scheduler = DDPMScheduler(**input_param) 61 | sample = torch.randn(input_shape) 62 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long() 63 | velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps) 64 | self.assertEqual(velocity.shape, expected_shape) 65 | 66 | def test_step_learned(self): 67 | for variance_type in ["learned", "learned_range"]: 68 | scheduler = DDPMScheduler(variance_type=variance_type) 69 | model_output = torch.randn(2, 6, 16, 16) 70 | sample = torch.randn(2, 3, 16, 16) 71 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample) 72 | self.assertEqual(output_step[0].shape, sample.shape) 73 | self.assertEqual(output_step[1].shape, 
sample.shape) 74 | 75 | def test_set_timesteps(self): 76 | scheduler = DDPMScheduler(num_train_timesteps=1000) 77 | scheduler.set_timesteps(num_inference_steps=100) 78 | self.assertEqual(scheduler.num_inference_steps, 100) 79 | self.assertEqual(len(scheduler.timesteps), 100) 80 | 81 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self): 82 | scheduler = DDPMScheduler(num_train_timesteps=1000) 83 | with self.assertRaises(ValueError): 84 | scheduler.set_timesteps(num_inference_steps=2000) 85 | 86 | 87 | if __name__ == "__main__": 88 | unittest.main() 89 | -------------------------------------------------------------------------------- /tests/test_encoder_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.networks.blocks import SpatialRescaler 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | CASES = [ 24 | [ 25 | { 26 | "spatial_dims": 2, 27 | "n_stages": 1, 28 | "method": "bilinear", 29 | "multiplier": 0.5, 30 | "in_channels": None, 31 | "out_channels": None, 32 | }, 33 | (1, 1, 16, 16), 34 | (1, 1, 8, 8), 35 | ], 36 | [ 37 | { 38 | "spatial_dims": 2, 39 | "n_stages": 1, 40 | "method": "bilinear", 41 | "multiplier": 0.5, 42 | "in_channels": 3, 43 | "out_channels": 2, 44 | }, 45 | (1, 3, 16, 16), 46 | (1, 2, 8, 8), 47 | ], 48 | [ 49 | { 50 | "spatial_dims": 3, 51 | "n_stages": 1, 52 | "method": "trilinear", 53 | "multiplier": 0.5, 54 | "in_channels": None, 55 | "out_channels": None, 56 | }, 57 | (1, 1, 16, 16, 16), 58 | (1, 1, 8, 8, 8), 59 | ], 60 | [ 61 | { 62 | "spatial_dims": 3, 63 | "n_stages": 1, 64 | "method": "trilinear", 65 | "multiplier": 0.5, 66 | "in_channels": 3, 67 | "out_channels": 2, 68 | }, 69 | (1, 3, 16, 16, 16), 70 | (1, 2, 8, 8, 8), 71 | ], 72 | [ 73 | { 74 | "spatial_dims": 3, 75 | "n_stages": 1, 76 | "method": "trilinear", 77 | "multiplier": (0.25, 0.5, 0.75), 78 | "in_channels": 3, 79 | "out_channels": 2, 80 | }, 81 | (1, 3, 20, 20, 20), 82 | (1, 2, 5, 10, 15), 83 | ], 84 | [ 85 | {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2}, 86 | (1, 3, 16, 16), 87 | (1, 2, 8, 8), 88 | ], 89 | [ 90 | { 91 | "spatial_dims": 3, 92 | "n_stages": 1, 93 | "size": (8, 8, 8), 94 | "method": "trilinear", 95 | "in_channels": None, 96 | "out_channels": None, 97 | }, 98 | (1, 1, 16, 16, 16), 99 | (1, 1, 8, 8, 8), 100 | ], 101 | ] 102 | 103 | 104 | class TestSpatialRescaler(unittest.TestCase): 105 | @parameterized.expand(CASES) 106 | def test_shape(self, input_param, input_shape, expected_shape): 107 | module = SpatialRescaler(**input_param).to(device) 108 | 109 | result = module(torch.randn(input_shape).to(device)) 110 | self.assertEqual(result.shape, 
expected_shape) 111 | 112 | def test_method_not_in_available_options(self): 113 | with self.assertRaises(AssertionError): 114 | SpatialRescaler(method="none") 115 | 116 | def test_n_stages_is_negative(self): 117 | with self.assertRaises(AssertionError): 118 | SpatialRescaler(n_stages=-1) 119 | 120 | def test_use_size_but_n_stages_is_not_one(self): 121 | with self.assertRaises(ValueError): 122 | SpatialRescaler(n_stages=2, size=[8, 8, 8]) 123 | 124 | def test_both_size_and_multiplier_defined(self): 125 | with self.assertRaises(ValueError): 126 | SpatialRescaler(size=[1, 2, 3], multiplier=0.5) 127 | 128 | 129 | if __name__ == "__main__": 130 | unittest.main() 131 | -------------------------------------------------------------------------------- /tests/test_adversarial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from parameterized import parameterized 18 | 19 | from generative.losses import PatchAdversarialLoss 20 | 21 | shapes_tensors = {"2d": [4, 1, 64, 64], "3d": [4, 1, 64, 64, 64]} 22 | reductions = ["sum", "mean"] 23 | criterion = ["bce", "least_squares", "hinge"] 24 | 25 | TEST_CASE_CREATION_FAIL = [{"reduction": "sum", "criterion": "invalid"}] 26 | 27 | TEST_CASES_LOSS_LOGIC_2D = [] 28 | TEST_CASES_LOSS_LOGIC_3D = [] 29 | 30 | for c in criterion: 31 | for r in reductions: 32 | TEST_CASES_LOSS_LOGIC_2D.append([{"reduction": r, "criterion": c}, shapes_tensors["2d"]]) 33 | TEST_CASES_LOSS_LOGIC_3D.append([{"reduction": r, "criterion": c}, shapes_tensors["3d"]]) 34 | 35 | TEST_CASES_LOSS_LOGIC_LIST = [] 36 | for c in criterion: 37 | TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["2d"]]) 38 | TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["3d"]]) 39 | 40 | 41 | class TestPatchAdversarialLoss(unittest.TestCase): 42 | def get_input(self, shape, is_positive): 43 | """ 44 | Get tensor for the tests. The tensor is around (-1) or (+1), depending on 45 | is_positive. 46 | """ 47 | if is_positive: 48 | offset = 1 49 | else: 50 | offset = -1 51 | return torch.ones(shape) * (offset) + 0.01 * torch.randn(shape) 52 | 53 | def test_criterion(self): 54 | """ 55 | Make sure that unknown criterion fail. 56 | """ 57 | with self.assertRaises(ValueError): 58 | PatchAdversarialLoss(**TEST_CASE_CREATION_FAIL[0]) 59 | 60 | @parameterized.expand(TEST_CASES_LOSS_LOGIC_2D + TEST_CASES_LOSS_LOGIC_3D) 61 | def test_loss_logic(self, input_param: dict, shape_input: list): 62 | """ 63 | We want to make sure that the adversarial losses do what they should. 64 | If the discriminator takes in a tensor that looks positive, yet the label is fake, 65 | the loss should be bigger than that obtained with a tensor that looks negative. 66 | Same for the real label, and for the generator. 
67 | """ 68 | loss = PatchAdversarialLoss(**input_param) 69 | fakes = self.get_input(shape_input, is_positive=False) 70 | reals = self.get_input(shape_input, is_positive=True) 71 | # Discriminator: fake label 72 | loss_disc_f_f = loss(fakes, target_is_real=False, for_discriminator=True) 73 | loss_disc_f_r = loss(reals, target_is_real=False, for_discriminator=True) 74 | assert loss_disc_f_f < loss_disc_f_r 75 | # Discriminator: real label 76 | loss_disc_r_f = loss(fakes, target_is_real=True, for_discriminator=True) 77 | loss_disc_r_r = loss(reals, target_is_real=True, for_discriminator=True) 78 | assert loss_disc_r_f > loss_disc_r_r 79 | # Generator: 80 | loss_gen_f = loss(fakes, target_is_real=True, for_discriminator=False) # target_is_real is overridden 81 | loss_gen_r = loss(reals, target_is_real=True, for_discriminator=False) # target_is_real is overridden 82 | assert loss_gen_f > loss_gen_r 83 | 84 | @parameterized.expand(TEST_CASES_LOSS_LOGIC_LIST) 85 | def test_multiple_discs(self, input_param: dict, shape_input): 86 | shapes = [shape_input] + [shape_input[0:2] + [int(i / j) for i in shape_input[2:]] for j in range(1, 3)] 87 | inputs = [self.get_input(shapes[i], is_positive=True) for i in range(len(shapes))] 88 | loss = PatchAdversarialLoss(**input_param) 89 | assert len(loss(inputs, for_discriminator=True, target_is_real=True)) == 3 90 | 91 | 92 | if __name__ == "__main__": 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json: -------------------------------------------------------------------------------- 1 | { 2 | "imports": [ 3 | "$import torch", 4 | "$from datetime import datetime", 5 | "$from pathlib import Path", 6 | "$from transformers import CLIPTextModel", 7 | "$from transformers import CLIPTokenizer" 8 | ], 9 | "bundle_root": ".", 10 | "model_dir": "$@bundle_root + '/models'", 11 | "output_dir": "$@bundle_root + '/output'", 12 | "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", 13 | "prompt": "Big right-sided pleural effusion", 14 | "prompt_list": "$['', @prompt]", 15 | "guidance_scale": 7.0, 16 | "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", 17 | "tokenizer": "$CLIPTokenizer.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\", subfolder=\"tokenizer\")", 18 | "text_encoder": "$CLIPTextModel.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\", subfolder=\"text_encoder\")", 19 | "tokenized_prompt": "$@tokenizer(@prompt_list, padding=\"max_length\", max_length=@tokenizer.model_max_length, truncation=True,return_tensors=\"pt\")", 20 | "prompt_embeds": "$@text_encoder(@tokenized_prompt.input_ids.squeeze(1))[0].to(@device)", 21 | "out_file": "$datetime.now().strftime('sample_%H%M%S_%d%m%Y')", 22 | "autoencoder_def": { 23 | "_target_": "generative.networks.nets.AutoencoderKL", 24 | "spatial_dims": 2, 25 | "in_channels": 1, 26 | "out_channels": 1, 27 | "latent_channels": 3, 28 | "num_channels": [ 29 | 64, 30 | 128, 31 | 128, 32 | 128 33 | ], 34 | "num_res_blocks": 2, 35 | "norm_num_groups": 32, 36 | "norm_eps": 1e-06, 37 | "attention_levels": [ 38 | false, 39 | false, 40 | false, 41 | false 42 | ], 43 | "with_encoder_nonlocal_attn": false, 44 | "with_decoder_nonlocal_attn": false 45 | }, 46 | "load_autoencoder_path": "$@model_dir + '/autoencoder.pth'", 47 | "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))", 48 | "autoencoder": 
"$@autoencoder_def.to(@device)", 49 | "diffusion_def": { 50 | "_target_": "generative.networks.nets.DiffusionModelUNet", 51 | "spatial_dims": 2, 52 | "in_channels": 3, 53 | "out_channels": 3, 54 | "num_channels": [ 55 | 256, 56 | 512, 57 | 768 58 | ], 59 | "num_res_blocks": 2, 60 | "attention_levels": [ 61 | false, 62 | true, 63 | true 64 | ], 65 | "norm_num_groups": 32, 66 | "norm_eps": 1e-06, 67 | "resblock_updown": false, 68 | "num_head_channels": [ 69 | 0, 70 | 512, 71 | 768 72 | ], 73 | "with_conditioning": true, 74 | "transformer_num_layers": 1, 75 | "cross_attention_dim": 1024 76 | }, 77 | "load_diffusion_path": "$@model_dir + '/diffusion_model.pth'", 78 | "load_diffusion": "$@diffusion_def.load_state_dict(torch.load(@load_diffusion_path))", 79 | "diffusion": "$@diffusion_def.to(@device)", 80 | "scheduler": { 81 | "_target_": "generative.networks.schedulers.DDIMScheduler", 82 | "_requires_": [ 83 | "@load_diffusion", 84 | "@load_autoencoder" 85 | ], 86 | "beta_start": 0.0015, 87 | "beta_end": 0.0205, 88 | "num_train_timesteps": 1000, 89 | "schedule": "scaled_linear_beta", 90 | "prediction_type": "v_prediction", 91 | "clip_sample": false 92 | }, 93 | "noise": "$torch.randn((1, 3, 64, 64)).to(@device)", 94 | "set_timesteps": "$@scheduler.set_timesteps(num_inference_steps=50)", 95 | "sampler": { 96 | "_target_": "scripts.sampler.Sampler", 97 | "_requires_": "@set_timesteps" 98 | }, 99 | "sample": "$@sampler.sampling_fn(@noise, @autoencoder, @diffusion, @scheduler, @prompt_embeds)", 100 | "saver": { 101 | "_target_": "scripts.saver.JPGSaver", 102 | "_requires_": "@create_output_dir", 103 | "output_dir": "@output_dir" 104 | }, 105 | "save_jpg": "$@saver.save(@sample, @out_file)", 106 | "save": "$torch.save(@sample, @output_dir + '/' + @out_file + '.pt')" 107 | } 108 | -------------------------------------------------------------------------------- /model-zoo/models/mednist_ddpm/bundle/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # This defines the training script for the network 2 | 3 | # choose a new directory for every run 4 | output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S') 5 | dataset_dir: ./data 6 | 7 | train_data: 8 | _target_ : MedNISTDataset 9 | root_dir: '@dataset_dir' 10 | section: training 11 | download: true 12 | progress: false 13 | seed: 0 14 | 15 | val_data: 16 | _target_ : MedNISTDataset 17 | root_dir: '@dataset_dir' 18 | section: validation 19 | download: true 20 | progress: false 21 | seed: 0 22 | 23 | train_datalist: '$[{"image": item["image"]} for item in @train_data.data if item["class_name"] == "Hand"]' 24 | val_datalist: '$[{"image": item["image"]} for item in @val_data.data if item["class_name"] == "Hand"]' 25 | 26 | batch_size: 8 27 | num_substeps: 1 28 | num_workers: 4 29 | use_thread_workers: false 30 | 31 | lr: 0.000025 32 | rand_prob: 0.5 33 | num_epochs: 75 34 | val_interval: 5 35 | save_interval: 5 36 | 37 | train_transforms: 38 | - _target_: RandAffined 39 | keys: '@image' 40 | rotate_range: 41 | - ['$-np.pi / 36', '$np.pi / 36'] 42 | - ['$-np.pi / 36', '$np.pi / 36'] 43 | translate_range: 44 | - [-1, 1] 45 | - [-1, 1] 46 | scale_range: 47 | - [-0.05, 0.05] 48 | - [-0.05, 0.05] 49 | spatial_size: [64, 64] 50 | padding_mode: "zeros" 51 | prob: '@rand_prob' 52 | 53 | train_ds: 54 | _target_: Dataset 55 | data: $@train_datalist 56 | transform: 57 | _target_: Compose 58 | transforms: '$@base_transforms + @train_transforms' 59 | 60 | train_loader: 61 | _target_: 
ThreadDataLoader 62 | dataset: '@train_ds' 63 | batch_size: '@batch_size' 64 | repeats: '@num_substeps' 65 | num_workers: '@num_workers' 66 | use_thread_workers: '@use_thread_workers' 67 | persistent_workers: '$@num_workers > 0' 68 | shuffle: true 69 | 70 | val_ds: 71 | _target_: Dataset 72 | data: $@val_datalist 73 | transform: 74 | _target_: Compose 75 | transforms: '@base_transforms' 76 | 77 | val_loader: 78 | _target_: DataLoader 79 | dataset: '@val_ds' 80 | batch_size: '@batch_size' 81 | num_workers: '@num_workers' 82 | persistent_workers: '$@num_workers > 0' 83 | shuffle: false 84 | 85 | lossfn: 86 | _target_: torch.nn.MSELoss 87 | 88 | optimizer: 89 | _target_: torch.optim.Adam 90 | params: $@network.parameters() 91 | lr: '@lr' 92 | 93 | prepare_batch: 94 | _target_: generative.engines.DiffusionPrepareBatch 95 | num_train_timesteps: '@num_train_timesteps' 96 | 97 | val_handlers: 98 | - _target_: StatsHandler 99 | name: train_log 100 | output_transform: '$lambda x: None' 101 | _disabled_: '@is_not_rank0' 102 | 103 | evaluator: 104 | _target_: SupervisedEvaluator 105 | device: '@device' 106 | val_data_loader: '@val_loader' 107 | network: '@network' 108 | amp: '@use_amp' 109 | inferer: '@inferer' 110 | prepare_batch: '@prepare_batch' 111 | key_val_metric: 112 | val_mean_abs_error: 113 | _target_: MeanAbsoluteError 114 | output_transform: $monai.handlers.from_engine([@pred, @label]) 115 | metric_cmp_fn: '$scripts.inv_metric_cmp_fn' 116 | val_handlers: '$list(filter(bool, @val_handlers))' 117 | 118 | handlers: 119 | - _target_: CheckpointLoader 120 | _disabled_: $not os.path.exists(@ckpt_path) 121 | load_path: '@ckpt_path' 122 | load_dict: 123 | model: '@network' 124 | - _target_: ValidationHandler 125 | validator: '@evaluator' 126 | epoch_level: true 127 | interval: '@val_interval' 128 | - _target_: CheckpointSaver 129 | save_dir: '@output_dir' 130 | save_dict: 131 | model: '@network' 132 | save_interval: '@save_interval' 133 | save_final: true 134 | epoch_level: true 135 | _disabled_: '@is_not_rank0' 136 | 137 | trainer: 138 | _target_: SupervisedTrainer 139 | max_epochs: '@num_epochs' 140 | device: '@device' 141 | train_data_loader: '@train_loader' 142 | network: '@network' 143 | loss_function: '@lossfn' 144 | optimizer: '@optimizer' 145 | inferer: '@inferer' 146 | prepare_batch: '@prepare_batch' 147 | key_train_metric: 148 | train_acc: 149 | _target_: MeanSquaredError 150 | output_transform: $monai.handlers.from_engine([@pred, @label]) 151 | metric_cmp_fn: '$scripts.inv_metric_cmp_fn' 152 | train_handlers: '$list(filter(bool, @handlers))' 153 | amp: '@use_amp' 154 | 155 | training: 156 | - '$monai.utils.set_determinism(0)' 157 | - '$@trainer.run()' 158 | -------------------------------------------------------------------------------- /generative/networks/nets/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from generative.networks.blocks.transformerblock import TransformerBlock 18 | 19 | __all__ = ["DecoderOnlyTransformer"] 20 | 21 | 22 | class AbsolutePositionalEmbedding(nn.Module): 23 | """Absolute positional embedding. 24 | 25 | Args: 26 | max_seq_len: Maximum sequence length. 27 | embedding_dim: Dimensionality of the embedding. 28 | """ 29 | 30 | def __init__(self, max_seq_len: int, embedding_dim: int) -> None: 31 | super().__init__() 32 | self.max_seq_len = max_seq_len 33 | self.embedding_dim = embedding_dim 34 | self.embedding = nn.Embedding(max_seq_len, embedding_dim) 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | batch_size, seq_len = x.size() 38 | positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) 39 | return self.embedding(positions) 40 | 41 | 42 | class DecoderOnlyTransformer(nn.Module): 43 | """Decoder-only (Autoregressive) Transformer model. 44 | 45 | Args: 46 | num_tokens: Number of tokens in the vocabulary. 47 | max_seq_len: Maximum sequence length. 48 | attn_layers_dim: Dimensionality of the attention layers. 49 | attn_layers_depth: Number of attention layers. 50 | attn_layers_heads: Number of attention heads. 51 | with_cross_attention: Whether to use cross attention for conditioning. 52 | embedding_dropout_rate: Dropout rate for the embedding. 53 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 54 | """ 55 | 56 | def __init__( 57 | self, 58 | num_tokens: int, 59 | max_seq_len: int, 60 | attn_layers_dim: int, 61 | attn_layers_depth: int, 62 | attn_layers_heads: int, 63 | with_cross_attention: bool = False, 64 | embedding_dropout_rate: float = 0.0, 65 | use_flash_attention: bool = False, 66 | ) -> None: 67 | super().__init__() 68 | self.num_tokens = num_tokens 69 | self.max_seq_len = max_seq_len 70 | self.attn_layers_dim = attn_layers_dim 71 | self.attn_layers_depth = attn_layers_depth 72 | self.attn_layers_heads = attn_layers_heads 73 | self.with_cross_attention = with_cross_attention 74 | 75 | self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) 76 | self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) 77 | self.embedding_dropout = nn.Dropout(embedding_dropout_rate) 78 | 79 | self.blocks = nn.ModuleList( 80 | [ 81 | TransformerBlock( 82 | hidden_size=attn_layers_dim, 83 | mlp_dim=attn_layers_dim * 4, 84 | num_heads=attn_layers_heads, 85 | dropout_rate=0.0, 86 | qkv_bias=False, 87 | causal=True, 88 | sequence_length=max_seq_len, 89 | with_cross_attention=with_cross_attention, 90 | use_flash_attention=use_flash_attention, 91 | ) 92 | for _ in range(attn_layers_depth) 93 | ] 94 | ) 95 | 96 | self.to_logits = nn.Linear(attn_layers_dim, num_tokens) 97 | 98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 99 | tok_emb = self.token_embeddings(x) 100 | pos_emb = self.position_embeddings(x) 101 | x = self.embedding_dropout(tok_emb + pos_emb) 102 | 103 | for block in self.blocks: 104 | x = block(x, context=context) 105 | 106 | return self.to_logits(x) 107 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = monai-generative 3 | author = MONAI Consortium 4 | author_email = monai.contact@gmail.com 5 | url = https://monai.io/ 6 | 
description = MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications 7 | long_description = file:README.md 8 | long_description_content_type = text/markdown; charset=UTF-8 9 | platforms = OS Independent 10 | license = Apache License 2.0 11 | license_files = 12 | LICENSE 13 | project_urls = 14 | Documentation=https://docs.monai.io/ 15 | Bug Tracker=https://github.com/Project-MONAI/GenerativeModels/issues 16 | Source Code=https://github.com/Project-MONAI/GenerativeModels 17 | classifiers = 18 | Intended Audience :: Developers 19 | Intended Audience :: Education 20 | Intended Audience :: Science/Research 21 | Intended Audience :: Healthcare Industry 22 | Programming Language :: C++ 23 | Programming Language :: Python :: 3 24 | Programming Language :: Python :: 3.8 25 | Programming Language :: Python :: 3.9 26 | Programming Language :: Python :: 3.10 27 | Topic :: Scientific/Engineering 28 | Topic :: Scientific/Engineering :: Artificial Intelligence 29 | Topic :: Scientific/Engineering :: Medical Science Apps. 30 | Topic :: Scientific/Engineering :: Information Analysis 31 | Topic :: Software Development 32 | Topic :: Software Development :: Libraries 33 | Typing :: Typed 34 | 35 | [options] 36 | python_requires = >= 3.8 37 | # for compiling and develop setup only 38 | # no need to specify the versions so that we could 39 | # compile for multiple targeted versions. 40 | setup_requires = 41 | torch 42 | ninja 43 | install_requires = 44 | monai>=1.2.0rc1 45 | torch>=1.9 46 | numpy>=1.20 47 | 48 | [flake8] 49 | select = B,C,E,F,N,P,T4,W,B9 50 | max_line_length = 120 51 | # C408 ignored because we like the dict keyword argument syntax 52 | # E501 is not flexible enough, we're using B950 instead 53 | # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F' 54 | # B023 https://github.com/Project-MONAI/MONAI/issues/4627 55 | # B028 https://github.com/Project-MONAI/MONAI/issues/5855 56 | # B907 https://github.com/Project-MONAI/MONAI/issues/5868 57 | # B908 https://github.com/Project-MONAI/MONAI/issues/6503 58 | ignore = 59 | E203 60 | E501 61 | E741 62 | W503 63 | W504 64 | C408 65 | N812 66 | B023 67 | B905 68 | B028 69 | B907 70 | B908 71 | per_file_ignores = __init__.py: F401, __main__.py: F401 72 | exclude = *.pyi,.git,.eggs,generative/_version.py,versioneer.py,venv,.venv,_version.py,tutorials/ 73 | 74 | [isort] 75 | known_first_party = generative 76 | profile = black 77 | line_length = 120 78 | # generative/networks/layers/ is excluded because it is raising JIT errors 79 | skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py, tutorials/, generative/networks/layers/ 80 | skip_glob = *.pyi 81 | add_imports = from __future__ import annotations 82 | append_only = true 83 | 84 | [mypy] 85 | # Suppresses error messages about imports that cannot be resolved. 86 | ignore_missing_imports = True 87 | # Changes the treatment of arguments with a default value of None by not implicitly making their type Optional. 88 | no_implicit_optional = True 89 | # Warns about casting an expression to its inferred type. 90 | warn_redundant_casts = True 91 | # No error on unneeded # type: ignore comments. 92 | warn_unused_ignores = False 93 | # Shows a warning when returning a value with type Any from a function declared with a non-Any return type. 94 | warn_return_any = True 95 | # Prohibit equality checks, identity checks, and container checks between non-overlapping types. 
96 | strict_equality = True 97 | # Shows column numbers in error messages. 98 | show_column_numbers = True 99 | # Shows error codes in error messages. 100 | show_error_codes = True 101 | # Use visually nicer output in error messages: use soft word wrap, show source code snippets, and show error location markers. 102 | pretty = False 103 | # Warns about per-module sections in the config file that do not match any files processed when invoking mypy. 104 | warn_unused_configs = True 105 | # Make arguments prepended via Concatenate be truly positional-only. 106 | strict_concatenate = True 107 | 108 | exclude = venv/ 109 | 110 | [coverage:run] 111 | concurrency = multiprocessing 112 | source = . 113 | data_file = .coverage/.coverage 114 | omit = setup.py 115 | 116 | [coverage:report] 117 | exclude_lines = 118 | pragma: no cover 119 | if TYPE_CHECKING: 120 | # Don't complain if tests don't hit code: 121 | raise NotImplementedError 122 | if __name__ == .__main__.: 123 | show_missing = True 124 | skip_covered = True 125 | 126 | [coverage:xml] 127 | output = coverage.xml 128 | -------------------------------------------------------------------------------- /generative/utils/component_store.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections import namedtuple 15 | from keyword import iskeyword 16 | from textwrap import dedent, indent 17 | from typing import Any, Callable, Dict, Iterable, TypeVar 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | def is_variable(name): 23 | """Returns True if `name` is a valid Python variable name and also not a keyword.""" 24 | return name.isidentifier() and not iskeyword(name) 25 | 26 | 27 | class ComponentStore: 28 | """ 29 | Represents a storage object for other objects (specifically functions) keyed to a name with a description. 30 | 31 | These objects act as global named places for storing components for objects parameterised by component names. 32 | Typically this is functions although other objects can be added. Printing a component store will produce a 33 | list of members along with their docstring information if present. 34 | 35 | Example: 36 | 37 | .. 
code-block:: python 38 | 39 | TestStore = ComponentStore("Test Store", "A test store for demo purposes") 40 | 41 | @TestStore.add_def("my_func_name", "Some description of your function") 42 | def _my_func(a, b): 43 | '''A description of your function here.''' 44 | return a * b 45 | 46 | print(TestStore) # will print out name, description, and 'my_func_name' with the docstring 47 | 48 | func = TestStore["my_func_name"] 49 | result = func(7, 6) 50 | 51 | """ 52 | 53 | _Component = namedtuple("Component", ("description", "value")) # internal value pair 54 | 55 | def __init__(self, name: str, description: str) -> None: 56 | self.components: Dict[str, self._Component] = {} 57 | self.name: str = name 58 | self.description: str = description 59 | 60 | self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() 61 | 62 | def add(self, name: str, desc: str, value: T) -> T: 63 | """Store the object `value` under the name `name` with description `desc`.""" 64 | if not is_variable(name): 65 | raise ValueError("Name of component must be valid Python identifier") 66 | 67 | self.components[name] = self._Component(desc, value) 68 | return value 69 | 70 | def add_def(self, name: str, desc: str) -> Callable: 71 | """Returns a decorator which stores the decorated function under `name` with description `desc`.""" 72 | 73 | def deco(func): 74 | """Decorator to add a function to a store.""" 75 | return self.add(name, desc, func) 76 | 77 | return deco 78 | 79 | def __contains__(self, name: str) -> bool: 80 | """Returns True if the given name is stored.""" 81 | return name in self.components 82 | 83 | def __len__(self) -> int: 84 | """Returns the number of stored components.""" 85 | return len(self.components) 86 | 87 | def __iter__(self) -> Iterable: 88 | """Yields name/component pairs.""" 89 | for k, v in self.components.items(): 90 | yield k, v.value 91 | 92 | def __str__(self): 93 | result = f"Component Store '{self.name}': {self.description}\nAvailable components:" 94 | for k, v in self.components.items(): 95 | result += f"\n* {k}:" 96 | 97 | if hasattr(v.value, "__doc__"): 98 | doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") 99 | result += f"\n{doc}\n" 100 | else: 101 | result += f" {v.description}" 102 | 103 | return result 104 | 105 | def __getattr__(self, name: str) -> Any: 106 | """Returns the stored object under the given name.""" 107 | if name in self.components: 108 | return self.components[name].value 109 | else: 110 | return self.__getattribute__(name) 111 | 112 | def __getitem__(self, name: str) -> Any: 113 | """Returns the stored object under the given name.""" 114 | if name in self.components: 115 | return self.components[name].value 116 | else: 117 | raise ValueError(f"Component '{name}' not found") 118 | -------------------------------------------------------------------------------- /generative/metrics/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from __future__ import annotations 14 | 15 | import numpy as np 16 | import torch 17 | from monai.metrics.metric import Metric 18 | from scipy import linalg 19 | 20 | 21 | class FIDMetric(Metric): 22 | """ 23 | Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors. 24 | Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." 25 | https://arxiv.org/abs/1706.08500#. The inputs for this metric should be two groups of feature vectors (with format 26 | (number images, number of features)) extracted from the a pretrained network. 27 | 28 | Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet. 29 | However, others networks pretrained on medical datasets can be used as well (for example, RadImageNwt for 2D and 30 | MedicalNet for 3D images). If the chosen model output is not a scalar, usually it is used a global spatial 31 | average pooling. 32 | """ 33 | 34 | def __init__(self) -> None: 35 | super().__init__() 36 | 37 | def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): 38 | return get_fid_score(y_pred, y) 39 | 40 | 41 | def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 42 | y = y.double() 43 | y_pred = y_pred.double() 44 | 45 | if y.ndimension() > 2: 46 | raise ValueError("Inputs should have (number images, number of features) shape.") 47 | 48 | mu_y_pred = torch.mean(y_pred, dim=0) 49 | sigma_y_pred = _cov(y_pred, rowvar=False) 50 | mu_y = torch.mean(y, dim=0) 51 | sigma_y = _cov(y, rowvar=False) 52 | 53 | return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) 54 | 55 | 56 | def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor: 57 | """ 58 | Estimate a covariance matrix of the variables. 59 | 60 | Args: 61 | input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable, 62 | and each column a single observation of all those variables. 63 | rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns. 64 | Otherwise, the relationship is transposed: each column represents a variable, while the rows contain 65 | observations. 
66 | """ 67 | if input_data.dim() < 2: 68 | input_data = input_data.view(1, -1) 69 | 70 | if not rowvar and input_data.size(0) != 1: 71 | input_data = input_data.t() 72 | 73 | factor = 1.0 / (input_data.size(1) - 1) 74 | input_data = input_data - torch.mean(input_data, dim=1, keepdim=True) 75 | return factor * input_data.matmul(input_data.t()).squeeze() 76 | 77 | 78 | def _sqrtm(input_data: torch.Tensor) -> torch.Tensor: 79 | """Compute the square root of a matrix.""" 80 | scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False) 81 | return torch.from_numpy(scipy_res) 82 | 83 | 84 | def compute_frechet_distance( 85 | mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6 86 | ) -> torch.Tensor: 87 | """The Frechet distance between multivariate normal distributions.""" 88 | diff = mu_x - mu_y 89 | 90 | covmean = _sqrtm(sigma_x.mm(sigma_y)) 91 | 92 | # Product might be almost singular 93 | if not torch.isfinite(covmean).all(): 94 | print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates") 95 | offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon 96 | covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset)) 97 | 98 | # Numerical error might give slight imaginary component 99 | if torch.is_complex(covmean): 100 | if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3): 101 | raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.") 102 | covmean = covmean.real 103 | 104 | tr_covmean = torch.trace(covmean) 105 | return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean 106 | -------------------------------------------------------------------------------- /generative/engines/prepare_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from typing import Dict, Mapping, Optional, Union 15 | 16 | import torch 17 | import torch.nn as nn 18 | from monai.engines import PrepareBatch, default_prepare_batch 19 | 20 | 21 | class DiffusionPrepareBatch(PrepareBatch): 22 | """ 23 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. 24 | 25 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and 26 | return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". 27 | This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. 28 | 29 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition 30 | field to be passed to the inferer. 
This will appear in the keyword arguments under the key "condition". 31 | 32 | """ 33 | 34 | def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: 35 | self.condition_name = condition_name 36 | self.num_train_timesteps = num_train_timesteps 37 | 38 | def get_noise(self, images: torch.Tensor) -> torch.Tensor: 39 | """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" 40 | return torch.randn_like(images) 41 | 42 | def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: 43 | """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" 44 | return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() 45 | 46 | def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: 47 | """Return the target for the loss function, this is the `noise` value by default.""" 48 | return noise 49 | 50 | def __call__( 51 | self, 52 | batchdata: Dict[str, torch.Tensor], 53 | device: Optional[Union[str, torch.device]] = None, 54 | non_blocking: bool = False, 55 | **kwargs, 56 | ): 57 | images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) 58 | noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) 59 | timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) 60 | 61 | target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) 62 | infer_kwargs = {"noise": noise, "timesteps": timesteps} 63 | 64 | if self.condition_name is not None and isinstance(batchdata, Mapping): 65 | infer_kwargs["conditioning"] = batchdata[self.condition_name].to( 66 | device, non_blocking=non_blocking, **kwargs 67 | ) 68 | 69 | # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value 70 | return images, target, (), infer_kwargs 71 | 72 | 73 | class VPredictionPrepareBatch(DiffusionPrepareBatch): 74 | """ 75 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. 76 | 77 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and 78 | from this compute the velocity using the provided scheduler. This value is used as the target in place of the 79 | noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer 80 | being used in conjunction with this class expects a "noise" parameter to be provided. 81 | 82 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition 83 | field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". 
84 | 85 | """ 86 | 87 | def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: 88 | super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) 89 | self.scheduler = scheduler 90 | 91 | def get_target(self, images, noise, timesteps): 92 | return self.scheduler.get_velocity(images, noise, timesteps) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # MONAI core 132 | MONAI/ 133 | 134 | # Created by https://www.toptal.com/developers/gitignore/api/jetbrains 135 | # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains 136 | 137 | ### JetBrains ### 138 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 139 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 140 | 141 | # User-specific stuff 142 | .idea/**/workspace.xml 143 | .idea/**/tasks.xml 144 | .idea/**/usage.statistics.xml 145 | .idea/**/dictionaries 146 | .idea/**/shelf 147 | 148 | # AWS User-specific 149 | .idea/**/aws.xml 150 | 151 | # Generated files 152 | .idea/**/contentModel.xml 153 | 154 | # Sensitive or high-churn files 155 | .idea/**/dataSources/ 156 | .idea/**/dataSources.ids 157 | .idea/**/dataSources.local.xml 158 | .idea/**/sqlDataSources.xml 159 | .idea/**/dynamic.xml 160 | .idea/**/uiDesigner.xml 161 | .idea/**/dbnavigator.xml 162 | 163 | # Gradle 164 | .idea/**/gradle.xml 165 | .idea/**/libraries 166 | 167 | # Gradle and Maven with auto-import 168 | # When using Gradle or Maven with auto-import, you should exclude module files, 169 | # since they will be recreated, and may cause churn. Uncomment if using 170 | # auto-import. 
171 | # .idea/artifacts 172 | # .idea/compiler.xml 173 | # .idea/jarRepositories.xml 174 | # .idea/modules.xml 175 | # .idea/*.iml 176 | # .idea/modules 177 | # *.iml 178 | # *.ipr 179 | 180 | # CMake 181 | cmake-build-*/ 182 | 183 | # Mongo Explorer plugin 184 | .idea/**/mongoSettings.xml 185 | 186 | # File-based project format 187 | *.iws 188 | 189 | # IntelliJ 190 | out/ 191 | 192 | # mpeltonen/sbt-idea plugin 193 | .idea_modules/ 194 | 195 | # JIRA plugin 196 | atlassian-ide-plugin.xml 197 | 198 | # Cursive Clojure plugin 199 | .idea/replstate.xml 200 | 201 | # SonarLint plugin 202 | .idea/sonarlint/ 203 | 204 | # Crashlytics plugin (for Android Studio and IntelliJ) 205 | com_crashlytics_export_strings.xml 206 | crashlytics.properties 207 | crashlytics-build.properties 208 | fabric.properties 209 | 210 | # Editor-based Rest Client 211 | .idea/httpRequests 212 | 213 | # Android studio 3.1+ serialized cache file 214 | .idea/caches/build_file_checksums.ser 215 | 216 | ### JetBrains Patch ### 217 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 218 | 219 | # *.iml 220 | # modules.xml 221 | # .idea/misc.xml 222 | # *.ipr 223 | 224 | # Sonarlint plugin 225 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 226 | .idea/**/sonarlint/ 227 | 228 | # SonarQube Plugin 229 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 230 | .idea/**/sonarIssues.xml 231 | 232 | # Markdown Navigator plugin 233 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 234 | .idea/**/markdown-navigator.xml 235 | .idea/**/markdown-navigator-enh.xml 236 | .idea/**/markdown-navigator/ 237 | 238 | # Cache file creation bug 239 | # See https://youtrack.jetbrains.com/issue/JBR-2257 240 | .idea/$CACHE_FILE$ 241 | 242 | # CodeStream plugin 243 | # https://plugins.jetbrains.com/plugin/12206-codestream 244 | .idea/codestream.xml 245 | 246 | # Azure Toolkit for IntelliJ plugin 247 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 248 | .idea/**/azureSettings.xml 249 | 250 | # End of https://www.toptal.com/developers/gitignore/api/jetbrains 251 | -------------------------------------------------------------------------------- /tests/test_spade_vaegan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import numpy as np 17 | import torch 18 | from monai.networks import eval_mode 19 | from parameterized import parameterized 20 | 21 | from generative.networks.nets import SPADENet 22 | 23 | CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] 24 | CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] 25 | CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] 26 | 27 | 28 | def create_semantic_data(shape: list, semantic_regions: int): 29 | """ 30 | To create semantic and image mock inputs for the network. 31 | Args: 32 | shape: input shape 33 | semantic_regions: number of semantic regions 34 | Returns: one-hot encoded label map and mock image, both with a batch dimension. 35 | """ 36 | out_label = torch.zeros(shape) 37 | out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 38 | for i in range(1, semantic_regions): 39 | shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] 40 | start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] 41 | if len(shape) == 2: 42 | out_label[ 43 | start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) 44 | ] = i 45 | base_intensity = torch.ones(shape_square) * np.random.randn() 46 | out_image[ 47 | start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) 48 | ] = (base_intensity + torch.randn(shape_square) * 0.1) 49 | elif len(shape) == 3: 50 | out_label[ 51 | start_point[0] : (start_point[0] + shape_square[0]), 52 | start_point[1] : (start_point[1] + shape_square[1]), 53 | start_point[2] : (start_point[2] + shape_square[2]), 54 | ] = i 55 | base_intensity = torch.ones(shape_square) * np.random.randn() 56 | out_image[ 57 | start_point[0] : (start_point[0] + shape_square[0]), 58 | start_point[1] : (start_point[1] + shape_square[1]), 59 | start_point[2] : (start_point[2] + shape_square[2]), 60 | ] = (base_intensity + torch.randn(shape_square) * 0.1) 61 | else: 62 | raise ValueError("Supports only 2D and 3D tensors") 63 | 64 | # One hot encode label 65 | out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) 66 | for ch in range(semantic_regions): 67 | out_label_[ch, ...] = out_label == ch 68 | 69 | return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) 70 | 71 | 72 | class TestDiffusionModelUNet2D(unittest.TestCase): 73 | @parameterized.expand(CASE_2D) 74 | def test_forward_2d(self, input_param): 75 | """ 76 | Check that forward method is called correctly and output shape matches. 77 | """ 78 | net = SPADENet(*input_param) 79 | in_label, in_image = create_semantic_data(input_param[4], input_param[3]) 80 | with eval_mode(net): 81 | out, kld = net(in_label, in_image) 82 | self.assertEqual( 83 | False, 84 | True in torch.isnan(out) 85 | or True in torch.isinf(out) 86 | or True in torch.isnan(kld) 87 | or True in torch.isinf(kld), 88 | ) 89 | self.assertEqual(list(out.shape), [1, 1, 64, 64]) 90 | 91 | @parameterized.expand(CASE_2D_BIS) 92 | def test_encoder_decoder(self, input_param): 93 | """ 94 | Check that the encode and decode methods run and output shapes match.
95 | """ 96 | net = SPADENet(*input_param) 97 | in_label, in_image = create_semantic_data(input_param[4], input_param[3]) 98 | with eval_mode(net): 99 | out_z = net.encode(in_image) 100 | self.assertEqual(list(out_z.shape), [1, 16]) 101 | out_i = net.decode(in_label, out_z) 102 | self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) 103 | 104 | @parameterized.expand(CASE_3D) 105 | def test_forward_3d(self, input_param): 106 | """ 107 | Check that forward method is called correctly and output shape matches. 108 | """ 109 | net = SPADENet(*input_param) 110 | in_label, in_image = create_semantic_data(input_param[4], input_param[3]) 111 | with eval_mode(net): 112 | out, kld = net(in_label, in_image) 113 | self.assertEqual( 114 | False, 115 | True in torch.isnan(out) 116 | or True in torch.isinf(out) 117 | or True in torch.isnan(kld) 118 | or True in torch.isinf(kld), 119 | ) 120 | self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) 121 | 122 | def test_shape_wrong(self): 123 | """ 124 | We input an input shape that isn't divisible by 2**(number of downsampling steps) 125 | """ 126 | with self.assertRaises(ValueError): 127 | _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) 128 | 129 | 130 | if __name__ == "__main__": 131 | unittest.main() 132 | -------------------------------------------------------------------------------- /tests/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
11 | 12 | from __future__ import annotations 13 | 14 | import argparse 15 | import glob 16 | import inspect 17 | import os 18 | import re 19 | import sys 20 | import time 21 | import unittest 22 | 23 | from monai.utils import PerfContext 24 | 25 | results: dict = {} 26 | 27 | 28 | class TimeLoggingTestResult(unittest.TextTestResult): 29 | """Overload the default results so that we can store the results.""" 30 | 31 | def __init__(self, *args, **kwargs): 32 | super().__init__(*args, **kwargs) 33 | self.timed_tests = {} 34 | 35 | def startTest(self, test): # noqa: N802 36 | """Start timer, print test name, do normal test.""" 37 | self.start_time = time.time() 38 | name = self.getDescription(test) 39 | self.stream.write(f"Starting test: {name}...\n") 40 | super().startTest(test) 41 | 42 | def stopTest(self, test): # noqa: N802 43 | """On test end, get time, print, store and do normal behaviour.""" 44 | elapsed = time.time() - self.start_time 45 | name = self.getDescription(test) 46 | self.stream.write(f"Finished test: {name} ({elapsed:.03}s)\n") 47 | if name in results: 48 | raise AssertionError("expected all keys to be unique") 49 | results[name] = elapsed 50 | super().stopTest(test) 51 | 52 | 53 | def print_results(results, discovery_time, thresh, status): 54 | # only keep results >= threshold 55 | results = dict(filter(lambda x: x[1] > thresh, results.items())) 56 | if len(results) == 0: 57 | return 58 | print(f"\n\n{status}, printing completed times >{thresh}s in ascending order...\n") 59 | timings = dict(sorted(results.items(), key=lambda item: item[1])) 60 | 61 | for r in timings: 62 | if timings[r] >= thresh: 63 | print(f"{r} ({timings[r]:.03}s)") 64 | print(f"test discovery time: {discovery_time:.03}s") 65 | print(f"total testing time: {sum(results.values()):.03}s") 66 | print("Remember to check above times for any errors!") 67 | 68 | 69 | def parse_args(): 70 | parser = argparse.ArgumentParser(description="Runner for MONAI unittests with timing.") 71 | parser.add_argument( 72 | "-s", action="store", dest="path", default=".", help="Directory to start discovery (default: '%(default)s')" 73 | ) 74 | parser.add_argument( 75 | "-p", 76 | action="store", 77 | dest="pattern", 78 | default="test_*.py", 79 | help="Pattern to match tests (default: '%(default)s')", 80 | ) 81 | parser.add_argument( 82 | "-t", 83 | "--thresh", 84 | dest="thresh", 85 | default=10.0, 86 | type=float, 87 | help="Display tests longer than given threshold (default: %(default)d)", 88 | ) 89 | parser.add_argument( 90 | "-v", 91 | "--verbosity", 92 | action="store", 93 | dest="verbosity", 94 | type=int, 95 | default=1, 96 | help="Verbosity level (default: %(default)d)", 97 | ) 98 | parser.add_argument("-q", "--quick", action="store_true", dest="quick", default=False, help="Only do quick tests") 99 | parser.add_argument( 100 | "-f", "--failfast", action="store_true", dest="failfast", default=False, help="Stop testing on first failure" 101 | ) 102 | args = parser.parse_args() 103 | print(f"Running tests in folder: '{args.path}'") 104 | if args.pattern: 105 | print(f"With file pattern: '{args.pattern}'") 106 | 107 | return args 108 | 109 | 110 | def get_default_pattern(loader): 111 | signature = inspect.signature(loader.discover) 112 | params = {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} 113 | return params["pattern"] 114 | 115 | 116 | if __name__ == "__main__": 117 | # Parse input arguments 118 | args = parse_args() 119 | 120 | # If quick is desired, set environment 
variable 121 | if args.quick: 122 | os.environ["QUICKTEST"] = "True" 123 | 124 | # Get all test names (optionally from some path with some pattern) 125 | with PerfContext() as pc: 126 | # the files are searched from `tests/` folder, starting with `test_` 127 | files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py")) 128 | cases = [] 129 | for test_module in {os.path.basename(f)[:-3] for f in files}: 130 | if re.match(args.pattern, test_module): 131 | cases.append(f"tests.{test_module}") 132 | else: 133 | print(f"monai test runner: excluding tests.{test_module}") 134 | tests = unittest.TestLoader().loadTestsFromNames(cases) 135 | discovery_time = pc.total_time 136 | print(f"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.") 137 | 138 | test_runner = unittest.runner.TextTestRunner( 139 | resultclass=TimeLoggingTestResult, verbosity=args.verbosity, failfast=args.failfast 140 | ) 141 | 142 | # Use try catches to print the current results if encountering exception or keyboard interruption 143 | try: 144 | test_result = test_runner.run(tests) 145 | print_results(results, discovery_time, args.thresh, "tests finished") 146 | sys.exit(not test_result.wasSuccessful()) 147 | except KeyboardInterrupt: 148 | print_results(results, discovery_time, args.thresh, "tests cancelled") 149 | sys.exit(1) 150 | except Exception: 151 | print_results(results, discovery_time, args.thresh, "exception reached") 152 | raise 153 | -------------------------------------------------------------------------------- /tests/test_patch_gan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import unittest 15 | 16 | import torch 17 | from monai.networks import eval_mode 18 | from parameterized import parameterized 19 | 20 | from generative.networks.nets.patchgan_discriminator import MultiScalePatchDiscriminator 21 | from tests.utils import test_script_save 22 | 23 | TEST_2D = [ 24 | { 25 | "num_d": 2, 26 | "num_layers_d": 3, 27 | "spatial_dims": 2, 28 | "num_channels": 8, 29 | "in_channels": 3, 30 | "out_channels": 1, 31 | "kernel_size": 3, 32 | "activation": "LEAKYRELU", 33 | "norm": "instance", 34 | "bias": False, 35 | "dropout": 0.1, 36 | "minimum_size_im": 256, 37 | }, 38 | torch.rand([1, 3, 256, 512]), 39 | [(1, 1, 32, 64), (1, 1, 4, 8)], 40 | [4, 7], 41 | ] 42 | TEST_3D = [ 43 | { 44 | "num_d": 2, 45 | "num_layers_d": 3, 46 | "spatial_dims": 3, 47 | "num_channels": 8, 48 | "in_channels": 3, 49 | "out_channels": 1, 50 | "kernel_size": 3, 51 | "activation": "LEAKYRELU", 52 | "norm": "instance", 53 | "bias": False, 54 | "dropout": 0.1, 55 | "minimum_size_im": 256, 56 | }, 57 | torch.rand([1, 3, 256, 512, 256]), 58 | [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)], 59 | [4, 7], 60 | ] 61 | TEST_3D_POOL = [ 62 | { 63 | "num_d": 2, 64 | "num_layers_d": 3, 65 | "spatial_dims": 3, 66 | "num_channels": 8, 67 | "in_channels": 3, 68 | "out_channels": 1, 69 | "kernel_size": 3, 70 | "pooling_method": "max", 71 | "activation": "LEAKYRELU", 72 | "norm": "instance", 73 | "bias": False, 74 | "dropout": 0.1, 75 | "minimum_size_im": 256, 76 | }, 77 | torch.rand([1, 3, 256, 512, 256]), 78 | [(1, 1, 32, 64, 32), (1, 1, 16, 32, 16)], 79 | [4, 4], 80 | ] 81 | TEST_2D_POOL = [ 82 | { 83 | "num_d": 4, 84 | "num_layers_d": 3, 85 | "spatial_dims": 2, 86 | "num_channels": 8, 87 | "in_channels": 3, 88 | "out_channels": 1, 89 | "kernel_size": 3, 90 | "pooling_method": "avg", 91 | "activation": "LEAKYRELU", 92 | "norm": "instance", 93 | "bias": False, 94 | "dropout": 0.1, 95 | "minimum_size_im": 256, 96 | }, 97 | torch.rand([1, 3, 256, 512]), 98 | [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16), (1, 1, 4, 8)], 99 | [4, 4, 4, 4], 100 | ] 101 | TEST_LAYER_LIST = [ 102 | { 103 | "num_d": 3, 104 | "num_layers_d": [3,4,5], 105 | "spatial_dims": 2, 106 | "num_channels": 8, 107 | "in_channels": 3, 108 | "out_channels": 1, 109 | "kernel_size": 3, 110 | "activation": "LEAKYRELU", 111 | "norm": "instance", 112 | "bias": False, 113 | "dropout": 0.1, 114 | "minimum_size_im": 256, 115 | }, 116 | torch.rand([1, 3, 256, 512]), 117 | [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16)], 118 | [4, 5, 6], 119 | ] 120 | TEST_TOO_SMALL_SIZE = [ 121 | { 122 | "num_d": 2, 123 | "num_layers_d": 6, 124 | "spatial_dims": 2, 125 | "num_channels": 8, 126 | "in_channels": 3, 127 | "out_channels": 1, 128 | "kernel_size": 3, 129 | "activation": "LEAKYRELU", 130 | "norm": "instance", 131 | "bias": False, 132 | "dropout": 0.1, 133 | "minimum_size_im": 256, 134 | } 135 | ] 136 | TEST_MISMATCHED_NUM_LAYERS = [ 137 | { 138 | "num_d": 5, 139 | "num_layers_d": [3,4,5], 140 | "spatial_dims": 2, 141 | "num_channels": 8, 142 | "in_channels": 3, 143 | "out_channels": 1, 144 | "kernel_size": 3, 145 | "activation": "LEAKYRELU", 146 | "norm": "instance", 147 | "bias": False, 148 | "dropout": 0.1, 149 | "minimum_size_im": 256, 150 | } 151 | ] 152 | 153 | CASES = [TEST_2D, TEST_3D, TEST_3D_POOL, TEST_2D_POOL, TEST_LAYER_LIST] 154 | 155 | class TestPatchGAN(unittest.TestCase): 156 | @parameterized.expand(CASES) 157 | def test_shape(self, input_param, input_data, expected_shape, features_lengths=None): 158 | 
net = MultiScalePatchDiscriminator(**input_param) 159 | with eval_mode(net): 160 | result, features = net.forward(input_data) 161 | for r_ind, r in enumerate(result): 162 | self.assertEqual(tuple(r.shape), expected_shape[r_ind]) 163 | for o_d_ind, o_d in enumerate(features): 164 | self.assertEqual(len(o_d), features_lengths[o_d_ind]) 165 | 166 | def test_too_small_shape(self): 167 | with self.assertRaises(AssertionError): 168 | MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0]) 169 | 170 | def test_mismatched_num_layers(self): 171 | with self.assertRaises(AssertionError): 172 | MultiScalePatchDiscriminator(**TEST_MISMATCHED_NUM_LAYERS[0]) 173 | 174 | def test_script(self): 175 | net = MultiScalePatchDiscriminator( 176 | num_d=2, 177 | num_layers_d=3, 178 | spatial_dims=2, 179 | num_channels=8, 180 | in_channels=3, 181 | out_channels=1, 182 | kernel_size=3, 183 | activation="LEAKYRELU", 184 | norm="instance", 185 | bias=False, 186 | dropout=0.1, 187 | minimum_size_im=256, 188 | ) 189 | i = torch.rand([1, 3, 256, 512]) 190 | test_script_save(net, i) 191 | 192 | 193 | if __name__ == "__main__": 194 | unittest.main() 195 | -------------------------------------------------------------------------------- /generative/networks/blocks/selfattention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import importlib.util 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | 21 | if importlib.util.find_spec("xformers") is not None: 22 | import xformers.ops as xops 23 | 24 | has_xformers = True 25 | else: 26 | has_xformers = False 27 | 28 | 29 | class SABlock(nn.Module): 30 | """ 31 | A self-attention block, based on: "Dosovitskiy et al., 32 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 33 | 34 | Args: 35 | hidden_size: dimension of hidden layer. 36 | num_heads: number of attention heads. 37 | dropout_rate: dropout ratio. Defaults to no dropout. 38 | qkv_bias: bias term for the qkv linear layer. 39 | causal: whether to use causal attention. 40 | sequence_length: if causal is True, it is necessary to specify the sequence length. 41 | with_cross_attention: Whether to use cross attention for conditioning. 42 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
43 | """ 44 | 45 | def __init__( 46 | self, 47 | hidden_size: int, 48 | num_heads: int, 49 | dropout_rate: float = 0.0, 50 | qkv_bias: bool = False, 51 | causal: bool = False, 52 | sequence_length: int | None = None, 53 | with_cross_attention: bool = False, 54 | use_flash_attention: bool = False, 55 | ) -> None: 56 | super().__init__() 57 | self.hidden_size = hidden_size 58 | self.num_heads = num_heads 59 | self.head_dim = hidden_size // num_heads 60 | self.scale = 1.0 / math.sqrt(self.head_dim) 61 | self.causal = causal 62 | self.sequence_length = sequence_length 63 | self.with_cross_attention = with_cross_attention 64 | self.use_flash_attention = use_flash_attention 65 | 66 | if not (0 <= dropout_rate <= 1): 67 | raise ValueError("dropout_rate should be between 0 and 1.") 68 | self.dropout_rate = dropout_rate 69 | 70 | if hidden_size % num_heads != 0: 71 | raise ValueError("hidden size should be divisible by num_heads.") 72 | 73 | if causal and sequence_length is None: 74 | raise ValueError("sequence_length is necessary for causal attention.") 75 | 76 | if use_flash_attention and not has_xformers: 77 | raise ValueError("use_flash_attention is True but xformers is not installed.") 78 | 79 | # key, query, value projections 80 | self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 81 | self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 82 | self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 83 | 84 | # regularization 85 | self.drop_weights = nn.Dropout(dropout_rate) 86 | self.drop_output = nn.Dropout(dropout_rate) 87 | 88 | # output projection 89 | self.out_proj = nn.Linear(hidden_size, hidden_size) 90 | 91 | if causal and sequence_length is not None: 92 | # causal mask to ensure that attention is only applied to the left in the input sequence 93 | self.register_buffer( 94 | "causal_mask", 95 | torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), 96 | ) 97 | 98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 99 | b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) 100 | 101 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 102 | query = self.to_q(x) 103 | 104 | kv = context if context is not None else x 105 | _, kv_t, _ = kv.size() 106 | key = self.to_k(kv) 107 | value = self.to_v(kv) 108 | 109 | query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) 110 | key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) 111 | value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) 112 | 113 | if self.use_flash_attention: 114 | query = query.contiguous() 115 | key = key.contiguous() 116 | value = value.contiguous() 117 | y = xops.memory_efficient_attention( 118 | query=query, 119 | key=key, 120 | value=value, 121 | scale=self.scale, 122 | p=self.dropout_rate, 123 | attn_bias=xops.LowerTriangularMask() if self.causal else None, 124 | ) 125 | 126 | else: 127 | query = query.transpose(1, 2) # (b, nh, t, hs) 128 | key = key.transpose(1, 2) # (b, nh, kv_t, hs) 129 | value = value.transpose(1, 2) # (b, nh, kv_t, hs) 130 | 131 | # manual implementation of attention 132 | query = query * self.scale 133 | attention_scores = query @ key.transpose(-2, -1) 134 | 135 | if self.causal: 136 | attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) 137 | 138 | attention_probs = 
F.softmax(attention_scores, dim=-1) 139 | attention_probs = self.drop_weights(attention_probs) 140 | y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) 141 | 142 | y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) 143 | 144 | y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side 145 | 146 | y = self.out_proj(y) 147 | y = self.drop_output(y) 148 | return y 149 | -------------------------------------------------------------------------------- /generative/metrics/ms_ssim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Sequence 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from monai.metrics.regression import RegressionMetric 19 | from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep 20 | 21 | from generative.metrics.ssim import compute_ssim_and_cs 22 | 23 | 24 | class KernelType(StrEnum): 25 | GAUSSIAN = "gaussian" 26 | UNIFORM = "uniform" 27 | 28 | 29 | class MultiScaleSSIMMetric(RegressionMetric): 30 | """ 31 | Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM). 32 | 33 | [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. 34 | Multiscale structural similarity for image quality assessment. 35 | In The Thirty-Seventh Asilomar Conference on Signals, Systems 36 | & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee. 37 | 38 | Args: 39 | spatial_dims: number of spatial dimensions of the input images. 40 | data_range: value range of input images. (usually 1.0 or 255) 41 | kernel_type: type of kernel, can be "gaussian" or "uniform". 42 | kernel_size: size of kernel 43 | kernel_sigma: standard deviation for Gaussian kernel. 44 | k1: stability constant used in the luminance denominator 45 | k2: stability constant used in the contrast denominator 46 | weights: parameters for image similarity and contrast sensitivity at different resolution scores. 47 | reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, 48 | available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, 49 | ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction 50 | get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) 51 | """ 52 | 53 | def __init__( 54 | self, 55 | spatial_dims: int, 56 | data_range: float = 1.0, 57 | kernel_type: KernelType | str = KernelType.GAUSSIAN, 58 | kernel_size: int | Sequence[int, ...] = 11, 59 | kernel_sigma: float | Sequence[float, ...] = 1.5, 60 | k1: float = 0.01, 61 | k2: float = 0.03, 62 | weights: Sequence[float, ...] 
= (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), 63 | reduction: MetricReduction | str = MetricReduction.MEAN, 64 | get_not_nans: bool = False, 65 | ) -> None: 66 | super().__init__(reduction=reduction, get_not_nans=get_not_nans) 67 | 68 | self.spatial_dims = spatial_dims 69 | self.data_range = data_range 70 | self.kernel_type = kernel_type 71 | 72 | if not isinstance(kernel_size, Sequence): 73 | kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) 74 | self.kernel_size = kernel_size 75 | 76 | if not isinstance(kernel_sigma, Sequence): 77 | kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) 78 | self.kernel_sigma = kernel_sigma 79 | 80 | self.k1 = k1 81 | self.k2 = k2 82 | self.weights = weights 83 | 84 | def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 85 | """ 86 | Args: 87 | y_pred: Predicted image. 88 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 89 | y: Reference image. 90 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 91 | 92 | Raises: 93 | ValueError: when `y_pred` is not a 2D or 3D image. 94 | """ 95 | dims = y_pred.ndimension() 96 | if self.spatial_dims == 2 and dims != 4: 97 | raise ValueError( 98 | f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " 99 | f"spatial dimensions, got {dims}." 100 | ) 101 | 102 | if self.spatial_dims == 3 and dims != 5: 103 | raise ValueError( 104 | f"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" 105 | f" spatial dimensions, got {dims}." 106 | ) 107 | 108 | # check that the image is large enough for the number of downsamplings and the kernel size 109 | weights_div = max(1, (len(self.weights) - 1)) ** 2 110 | y_pred_spatial_dims = y_pred.shape[2:] 111 | for i in range(len(y_pred_spatial_dims)): 112 | if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1: 113 | raise ValueError( 114 | f"For a given number of `weights` parameters {len(self.weights)} and kernel size " 115 | f"{self.kernel_size[i]}, each spatial dimension of the image must be larger than " 116 | f"{(self.kernel_size[i] - 1) * weights_div}."
117 | ) 118 | 119 | weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float) 120 | 121 | avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d") 122 | 123 | multiscale_list: list[torch.Tensor] = [] 124 | for _ in range(len(weights)): 125 | ssim, cs = compute_ssim_and_cs( 126 | y_pred=y_pred, 127 | y=y, 128 | spatial_dims=self.spatial_dims, 129 | data_range=self.data_range, 130 | kernel_type=self.kernel_type, 131 | kernel_size=self.kernel_size, 132 | kernel_sigma=self.kernel_sigma, 133 | k1=self.k1, 134 | k2=self.k2, 135 | ) 136 | 137 | cs_per_batch = cs.view(cs.shape[0], -1).mean(1) 138 | 139 | multiscale_list.append(torch.relu(cs_per_batch)) 140 | y_pred = avg_pool(y_pred, kernel_size=2) 141 | y = avg_pool(y, kernel_size=2) 142 | 143 | ssim = ssim.view(ssim.shape[0], -1).mean(1) 144 | multiscale_list[-1] = torch.relu(ssim) 145 | multiscale_list = torch.stack(multiscale_list) 146 | 147 | ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0) 148 | 149 | ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean( 150 | 1, keepdim=True 151 | ) 152 | 153 | return ms_ssim_per_batch 154 | -------------------------------------------------------------------------------- /tests/test_integration_workflows_adversarial.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import os 15 | import shutil 16 | import tempfile 17 | import unittest 18 | from glob import glob 19 | 20 | import monai 21 | import nibabel as nib 22 | import numpy as np 23 | import torch 24 | from monai.data import create_test_image_2d 25 | from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler 26 | from monai.networks.nets import AutoEncoder, Discriminator 27 | from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd 28 | from monai.utils import CommonKeys, set_determinism 29 | 30 | from generative.engines import AdversarialTrainer 31 | from generative.utils import AdversarialKeys as Keys 32 | from tests.utils import DistTestCase, TimedCall, skip_if_quick 33 | 34 | 35 | def run_training_test(root_dir, device="cuda:0"): 36 | learning_rate = 2e-4 37 | real_label = 1 38 | fake_label = 0 39 | 40 | real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz"))) 41 | train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)] 42 | 43 | # prepare real data 44 | train_transforms = Compose( 45 | [ 46 | LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]), 47 | EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2), 48 | ScaleIntensityd(keys=[CommonKeys.IMAGE]), 49 | RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5), 50 | ] 51 | ) 52 | train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5) 53 | train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4) 54 | 55 | # Create Discriminator 56 | discriminator_net = Discriminator( 57 | in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5 58 | ).to(device) 59 | discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate) 60 | discriminator_loss_criterion = torch.nn.BCELoss() 61 | 62 | def discriminator_loss(real_logits, fake_logits): 63 | real_target = real_logits.new_full((real_logits.shape[0], 1), real_label) 64 | fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label) 65 | real_loss = discriminator_loss_criterion(real_logits, real_target) 66 | fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) 67 | return torch.div(torch.add(real_loss, fake_loss), 2) 68 | 69 | # Create Generator 70 | generator_network = AutoEncoder( 71 | spatial_dims=2, 72 | in_channels=1, 73 | out_channels=1, 74 | channels=(8, 16, 32, 64), 75 | strides=(2, 2, 2, 2), 76 | num_res_units=1, 77 | num_inter_units=1, 78 | ) 79 | generator_network = generator_network.to(device) 80 | generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate) 81 | generator_loss_criterion = torch.nn.MSELoss() 82 | 83 | def reconstruction_loss(recon_images, real_images): 84 | return generator_loss_criterion(recon_images, real_images) 85 | 86 | def generator_loss(fake_logits): 87 | fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label) 88 | recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target) 89 | return recon_loss 90 | 91 | key_train_metric = None 92 | 93 | train_handlers = [ 94 | StatsHandler( 95 | name="training_loss", 96 | output_transform=lambda x: { 97 | Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], 98 | Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], 99 | Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], 100 | }, 101 | ), 102 | 
TensorBoardStatsHandler( 103 | log_dir=root_dir, 104 | tag_name="training_loss", 105 | output_transform=lambda x: { 106 | Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS], 107 | Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS], 108 | Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS], 109 | }, 110 | ), 111 | CheckpointSaver( 112 | save_dir=root_dir, 113 | save_dict={"g_net": generator_network, "d_net": discriminator_net}, 114 | save_interval=2, 115 | epoch_level=True, 116 | ), 117 | ] 118 | 119 | num_epochs = 5 120 | 121 | trainer = AdversarialTrainer( 122 | device=device, 123 | max_epochs=num_epochs, 124 | train_data_loader=train_loader, 125 | g_network=generator_network, 126 | g_optimizer=generator_optimiser, 127 | g_loss_function=generator_loss, 128 | recon_loss_function=reconstruction_loss, 129 | d_network=discriminator_net, 130 | d_optimizer=discriminator_opt, 131 | d_loss_function=discriminator_loss, 132 | non_blocking=True, 133 | key_train_metric=key_train_metric, 134 | train_handlers=train_handlers, 135 | ) 136 | trainer.run() 137 | 138 | return trainer.state 139 | 140 | 141 | @skip_if_quick 142 | class IntegrationWorkflowsAdversarialTrainer(DistTestCase): 143 | def setUp(self): 144 | set_determinism(seed=0) 145 | 146 | self.data_dir = tempfile.mkdtemp() 147 | for i in range(40): 148 | im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1) 149 | n = nib.Nifti1Image(im, np.eye(4)) 150 | nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz")) 151 | 152 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0") 153 | monai.config.print_config() 154 | 155 | def tearDown(self): 156 | set_determinism(seed=None) 157 | shutil.rmtree(self.data_dir) 158 | 159 | @TimedCall(seconds=200, daemon=False) 160 | def test_training(self): 161 | torch.manual_seed(0) 162 | 163 | finish_state = run_training_test(self.data_dir, device=self.device) 164 | 165 | # Assert AdversarialTrainer training finished 166 | self.assertEqual(finish_state.iteration, 100) 167 | self.assertEqual(finish_state.epoch, 5) 168 | 169 | 170 | if __name__ == "__main__": 171 | unittest.main() 172 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # MONAI Generative Models Tutorials 2 | This directory hosts the MONAI Generative Models tutorials. 3 | 4 | ## Requirements 5 | To run the tutorials, you will need to install the Generative Models package. 6 | Besides that, most of the examples and tutorials require 7 | [matplotlib](https://matplotlib.org/) and [Jupyter Notebook](https://jupyter.org/). 8 | 9 | These can be installed with the following: 10 | 11 | ```bash 12 | python -m pip install -U pip 13 | python -m pip install -U matplotlib 14 | python -m pip install -U notebook 15 | ``` 16 | 17 | Some of the examples may require optional dependencies. In case of any optional import errors, 18 | please install the relevant packages according to MONAI's [installation guide](https://docs.monai.io/en/latest/installation.html). 19 | Or install all optional requirements with the following: 20 | 21 | ```bash 22 | pip install -r requirements-dev.txt 23 | ``` 24 | 25 | ## List of notebooks and examples 26 | 27 | ### Table of Contents 28 | 1. [Diffusion Models](#1-diffusion-models) 29 | 2. [Latent Diffusion Models](#2-latent-diffusion-models) 30 | 3. [VQ-VAE + Transformers](#3-vq-vae--transformers) 31 | 32 | 33 | ### 1. 
Diffusion Models 34 | 35 | #### Image synthesis with Diffusion Models 36 | 37 | * [Training a 3D Denoising Diffusion Probabilistic Model](./generative/3d_ddpm/3d_ddpm_tutorial.ipynb): This tutorial shows how to easily 38 | train a DDPM on 3D medical data. In this example, we use a downsampled version of the BraTS dataset. We will show how to 39 | make use of the UNet model and the Noise Scheduler necessary to train a diffusion model. Besides that, we show how to 40 | use the DiffusionInferer class to simplify the training and sampling processes. Finally, after training the model, we 41 | show how to use a Noise Scheduler with fewer timesteps to sample synthetic images. 42 | 43 | * [Training a 2D Denoising Diffusion Probabilistic Model](./generative/2d_ddpm/2d_ddpm_tutorial.ipynb): This tutorial shows how to easily 44 | train a DDPM on medical data. In this example, we use the MedNIST dataset, which makes it very suitable for beginners. 45 | 46 | * [Comparing different noise schedulers](./generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb): In this tutorial, we compare the 47 | performance of different noise schedulers. We will show how to sample a diffusion model using the DDPM, DDIM, and PNDM 48 | schedulers and how different numbers of timesteps affect the quality of the samples. 49 | 50 | * [Training a 2D Denoising Diffusion Probabilistic Model with a different parameterisation](./generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb): 51 | In MONAI Generative Models, we support different parameterizations for the diffusion model (epsilon, sample, and 52 | v-prediction). In this tutorial, we show how to train a DDPM using the v-prediction parameterization, which improves the 53 | stability and convergence of the model. 54 | 55 | * [Training a 2D DDPM using PyTorch Ignite](./generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb): Here, we show how to train a DDPM 56 | on medical data using PyTorch Ignite. We will show how to use the DiffusionPrepareBatch class to prepare the model inputs and MONAI's SupervisedTrainer and SupervisedEvaluator to train DDPMs. 57 | 58 | * [Using a 2D DDPM to inpaint images](./generative/2d_ddpm/2d_ddpm_inpainting.ipynb): In this tutorial, we show how to use a DDPM to 59 | inpaint 2D images from the MedNIST dataset using the RePaint method. 60 | 61 | * [Generating conditional samples with a 2D DDPM using classifier-free guidance](./generative/classifier_free_guidance/2d_ddpm_classifier_free_guidance_tutorial.ipynb): 62 | This tutorial shows how easily we can train a Diffusion Model and generate conditional samples using classifier-free guidance in 63 | the MONAI framework. 64 | 65 | * [Training Diffusion models with Distributed Data Parallel](./generative/distributed_training/ddpm_training_ddp.py): This example shows how to execute distributed training and evaluation based on the PyTorch native DistributedDataParallel 66 | module with torch.distributed.launch. 67 | 68 | #### Anomaly Detection with Diffusion Models 69 | 70 | * [Weakly Supervised Anomaly Detection with Implicit Guidance](./generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb): 71 | This tutorial shows how to use a DDPM to perform weakly supervised anomaly detection using classifier-free (implicit) guidance, based on the 72 | method proposed by Sanchez et al. in [What is Healthy? Generative Counterfactual Diffusion for Lesion Localization](https://arxiv.org/abs/2207.12268), DGM4MICCAI 2022. 73 | 74 | 75 | ### 2. Latent Diffusion Models 76 | 77 | #### Image synthesis with Latent Diffusion Models 78 | 79 | * [Training a 3D Latent Diffusion Model](./generative/3d_ldm/3d_ldm_tutorial.ipynb): This tutorial shows how to train an LDM on 3D medical 80 | data. In this example, we use the BraTS dataset. We show how to train an AutoencoderKL and connect it to an LDM. We also 81 | comment on the importance of the scaling factor in the LDM, used to scale the latent representation of the AEKL to a suitable 82 | range for the diffusion model. Finally, we show how to use the LatentDiffusionInferer class to simplify the training and sampling. 83 | 84 | * [Training a 2D Latent Diffusion Model](./generative/2d_ldm/2d_ldm_tutorial.ipynb): This tutorial shows how to train an LDM on medical images 85 | from the MedNIST dataset. We show how to train an AutoencoderKL and connect it to an LDM. 86 | 87 | * Training an Autoencoder with KL-regularization: In this section, we focus on training an AutoencoderKL on [2D](./generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb) and [3D](./generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb) medical data 88 | that can be used as the compression model in a Latent Diffusion Model. 89 | 90 | #### Super-resolution with Latent Diffusion Models 91 | 92 | * [Super-resolution using the Stable Diffusion Upscalers method](./generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb): 93 | In this tutorial, we show how to perform super-resolution on 2D images from the MedNIST dataset using the Stable 94 | Diffusion Upscalers method. In this example, we will show how to condition a latent diffusion model on a low-resolution image, 95 | as well as how to use the DiffusionModelUNet's class_labels conditioning to condition the model on the level of noise added to the image 96 | (aka "noise conditioning augmentation"). 97 | 98 | 99 | ### 3. VQ-VAE + Transformers 100 | 101 | #### Image synthesis with VQ-VAE + Transformers 102 | 103 | * [Training a 2D VQ-VAE + Autoregressive Transformers](./generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb): This tutorial shows how to train 104 | a Vector-Quantized Variational Autoencoder + Transformer on the MedNIST dataset. 105 | 106 | * Training VQ-VAEs and VQ-GANs: In this section, we show how to train Vector-Quantized Variational Autoencoders (on [2D](./generative/2d_vqvae/2d_vqvae_tutorial.ipynb) and [3D](./generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb) data) and 107 | show how to use the PatchDiscriminator class to train a [VQ-GAN](./generative/2d_vqgan/2d_vqgan_tutorial.ipynb) and improve the quality of the generated images. 108 | 109 | #### Anomaly Detection with VQ-VAE + Transformers 110 | 111 | * [Anomaly Detection with 2D VQ-VAE + Autoregressive Transformers](./generative/anomaly_detection/anomaly_detection_with_transformers.ipynb): This tutorial shows how to 112 | train a Vector-Quantized Variational Autoencoder + Transformer on the MedNIST dataset and use it to estimate the likelihood of 113 | test images belonging to the in-distribution class (the one used during training). 114 | --------------------------------------------------------------------------------
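
The denoising diffusion tutorials listed above all train a network to predict the noise that was added to an image at a randomly drawn timestep. The snippet below is a minimal, plain-PyTorch sketch of that epsilon-prediction training step, given only for orientation: `model` is a placeholder for any network mapping `(noisy_images, timesteps)` to predicted noise, and the linear beta schedule and hyperparameter values are illustrative assumptions rather than the tutorials' exact settings (the tutorials implement this with the `DiffusionModelUNet`, `DDPMScheduler`, and `DiffusionInferer` classes).

```python
import torch
import torch.nn.functional as F

# Illustrative DDPM settings (assumptions, not the tutorials' defaults)
num_train_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)  # linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def ddpm_training_step(model, images, optimizer):
    """One epsilon-prediction training step for a placeholder noise-prediction `model`."""
    b = images.shape[0]
    timesteps = torch.randint(0, num_train_timesteps, (b,), device=images.device)
    noise = torch.randn_like(images)
    # forward (noising) process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    abar = alphas_cumprod.to(images.device)[timesteps].view(b, *([1] * (images.ndim - 1)))
    noisy_images = abar.sqrt() * images + (1.0 - abar).sqrt() * noise
    # the network is trained to recover the noise that was added
    noise_pred = model(noisy_images, timesteps)
    loss = F.mse_loss(noise_pred, noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Sampling then runs the learned denoiser backwards through the timesteps, starting from pure noise; in the tutorials this loop is handled by the inferer and scheduler classes, so it does not need to be written by hand.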