├── requirements.txt
├── model-zoo
│ ├── models
│ │ ├── brain_image_synthesis_latent_diffusion_model
│ │ │ ├── scripts
│ │ │ │ ├── __init__.py
│ │ │ │ ├── saver.py
│ │ │ │ └── sampler.py
│ │ │ ├── docs
│ │ │ │ ├── figure_1.png
│ │ │ │ └── README.md
│ │ │ ├── large_files.yml
│ │ │ └── configs
│ │ │   ├── logging.conf
│ │ │   ├── metadata.json
│ │ │   └── inference.json
│ │ ├── cxr_image_synthesis_latent_diffusion_model
│ │ │ ├── scripts
│ │ │ │ ├── __init__.py
│ │ │ │ ├── saver.py
│ │ │ │ └── sampler.py
│ │ │ ├── docs
│ │ │ │ ├── figure_1.png
│ │ │ │ └── README.md
│ │ │ ├── large_files.yml
│ │ │ └── configs
│ │ │   ├── logging.conf
│ │ │   ├── metadata.json
│ │ │   └── inference.json
│ │ └── mednist_ddpm
│ │   └── bundle
│ │     ├── configs
│ │     │ ├── logging.conf
│ │     │ ├── train_multigpu.yaml
│ │     │ ├── infer.yaml
│ │     │ ├── common.yaml
│ │     │ ├── metadata.json
│ │     │ └── train.yaml
│ │     ├── scripts
│ │     │ └── __init__.py
│ │     └── docs
│ │       ├── README.md
│ │       ├── sub_train.sh
│ │       └── sub_train_multigpu.sh
│ └── README.md
├── requirements-min.txt
├── .deepsource.toml
├── generative
│ ├── networks
│ │ ├── __init__.py
│ │ ├── layers
│ │ │ └── __init__.py
│ │ ├── blocks
│ │ │ ├── __init__.py
│ │ │ ├── encoder_modules.py
│ │ │ ├── spade_norm.py
│ │ │ ├── transformerblock.py
│ │ │ └── selfattention.py
│ │ ├── schedulers
│ │ │ └── __init__.py
│ │ └── nets
│ │   ├── __init__.py
│ │   └── transformer.py
│ ├── version.py
│ ├── __init__.py
│ ├── engines
│ │ ├── __init__.py
│ │ └── prepare_batch.py
│ ├── losses
│ │ ├── __init__.py
│ │ └── spectral_loss.py
│ ├── metrics
│ │ ├── __init__.py
│ │ ├── mmd.py
│ │ ├── fid.py
│ │ └── ms_ssim.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── misc.py
│ │ ├── enums.py
│ │ └── component_store.py
│ └── inferers
│   └── __init__.py
├── setup.py
├── .github
│ └── workflows
│   └── python-publish.yml
├── tests
│ ├── test_compute_fid_metric.py
│ ├── __init__.py
│ ├── test_transformer.py
│ ├── test_compute_mmd_metric.py
│ ├── test_misc.py
│ ├── test_controlnet.py
│ ├── test_selfattention.py
│ ├── min_tests.py
│ ├── test_component_store.py
│ ├── test_scheduler_ddim.py
│ ├── test_scheduler_pndm.py
│ ├── test_spectral_loss.py
│ ├── test_perceptual_loss.py
│ ├── test_compute_multiscalessim_metric.py
│ ├── test_vector_quantizer.py
│ ├── test_scheduler_ddpm.py
│ ├── test_encoder_modules.py
│ ├── test_adversarial.py
│ ├── test_spade_vaegan.py
│ ├── runner.py
│ ├── test_patch_gan.py
│ └── test_integration_workflows_adversarial.py
├── requirements-dev.txt
├── pyproject.toml
├── .pre-commit-config.yaml
├── README.md
├── CODE_OF_CONDUCT.md
├── setup.cfg
├── .gitignore
└── tutorials
  └── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.17
2 | torch>=1.8
3 | monai>=1.2.0rc1
4 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements-min.txt:
--------------------------------------------------------------------------------
1 | # Requirements for minimal tests
2 | -r requirements.txt
3 | setuptools>65.5.0,<66.0.0
4 | coverage>=5.5
5 | parameterized
6 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Project-MONAI/GenerativeModels/HEAD/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/figure_1.png
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Project-MONAI/GenerativeModels/HEAD/model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/figure_1.png
--------------------------------------------------------------------------------
/.deepsource.toml:
--------------------------------------------------------------------------------
1 | version = 1
2 |
3 | test_patterns = ["tests/**"]
4 |
5 | [[analyzers]]
6 | name = "python"
7 | enabled = true
8 |
9 | [analyzers.meta]
10 | runtime_version = "3.x.x"
11 |
12 | [[analyzers]]
13 | name = "test-coverage"
14 | enabled = true
15 |
16 | [[analyzers]]
17 | name = "docker"
18 | enabled = true
19 |
20 | [[analyzers]]
21 | name = "shell"
22 | enabled = true
23 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml:
--------------------------------------------------------------------------------
1 | large_files:
2 | - path: "models/autoencoder.pth"
3 | url: "https://drive.google.com/uc?export=download&id=1paDN1m-Q_Oy8d_BanPkRTi3RlNB_Sv_h"
4 | hash_val: ""
5 | hash_type: ""
6 | - path: "models/diffusion_model.pth"
7 | url: "https://drive.google.com/uc?export=download&id=1CjcmiPu5_QWr-f7wDJsXrCCcVeczneGT"
8 | hash_val: ""
9 | hash_type: ""
10 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/large_files.yml:
--------------------------------------------------------------------------------
1 | large_files:
2 | - path: "models/autoencoder.pth"
3 | url: "https://drive.google.com/uc?export=download&id=1CZHwxHJWybOsDavipD0EorDPOo_mzNeX"
4 | hash_val: ""
5 | hash_type: ""
6 | - path: "models/diffusion_model.pth"
7 | url: "https://drive.google.com/uc?export=download&id=1XO-ak93ZuOcGTCpgRtqgIeZq3dG5ExN6"
8 | hash_val: ""
9 | hash_type: ""
10 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/logging.conf:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root
3 |
4 | [handlers]
5 | keys=consoleHandler
6 |
7 | [formatters]
8 | keys=fullFormatter
9 |
10 | [logger_root]
11 | level=INFO
12 | handlers=consoleHandler
13 |
14 | [handler_consoleHandler]
15 | class=StreamHandler
16 | level=INFO
17 | formatter=fullFormatter
18 | args=(sys.stdout,)
19 |
20 | [formatter_fullFormatter]
21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
22 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/logging.conf:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root
3 |
4 | [handlers]
5 | keys=consoleHandler
6 |
7 | [formatters]
8 | keys=fullFormatter
9 |
10 | [logger_root]
11 | level=INFO
12 | handlers=consoleHandler
13 |
14 | [handler_consoleHandler]
15 | class=StreamHandler
16 | level=INFO
17 | formatter=fullFormatter
18 | args=(sys.stdout,)
19 |
20 | [formatter_fullFormatter]
21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
22 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/logging.conf:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root
3 |
4 | [handlers]
5 | keys=consoleHandler
6 |
7 | [formatters]
8 | keys=fullFormatter
9 |
10 | [logger_root]
11 | level=INFO
12 | handlers=consoleHandler
13 |
14 | [handler_consoleHandler]
15 | class=StreamHandler
16 | level=INFO
17 | formatter=fullFormatter
18 | args=(sys.stdout,)
19 |
20 | [formatter_fullFormatter]
21 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
22 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 |
4 | def inv_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
5 | """
6 | Inverts the comparison for metrics that decrease like loss values, so that the lower value is considered better.
7 |
8 | Args:
9 | current_metric: metric value of current round computation.
10 | prev_best: the best metric value of previous rounds to compare with.
11 | """
12 | return current_metric < prev_best
13 |
--------------------------------------------------------------------------------
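
As a quick, hypothetical illustration of the comparison function above (the `scripts` import assumes the bundle root is on `PYTHONPATH`, as the submission scripts under `docs/` arrange):

# Illustration only: with the inverted comparison, a lower validation loss is
# treated as an improvement over the previous best value.
from scripts import inv_metric_cmp_fn  # assumes the bundle root is on PYTHONPATH

assert inv_metric_cmp_fn(0.12, 0.20)      # 0.12 < 0.20 -> new best
assert not inv_metric_cmp_fn(0.25, 0.20)  # 0.25 >= 0.20 -> keep previous best
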
/model-zoo/README.md:
--------------------------------------------------------------------------------
1 | # Generative Models - Model Zoo
2 |
3 | In this directory, we include the prototypes of the model zoo for the MONAI Generative Models project.
4 | Unlike the [official model zoo](https://github.com/Project-MONAI/model-zoo), these prototypes do not include all of its features.
5 | For this reason, it is not possible to download the models directly with the `python -m monai.bundle run ...` command.
6 | In order to use our models, please download them manually via the links specified in the `large_files.yml` files,
7 | and place them inside the folder paths specified in the same .yml files.
8 |
--------------------------------------------------------------------------------
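
A minimal sketch of the manual download step described above (not part of the repository; `fetch_large_files` is a hypothetical helper, and it assumes `pyyaml` and `gdown` are installed, both of which appear in `requirements-dev.txt`):

# Hypothetical helper: download the weights listed in a bundle's large_files.yml
# into the relative paths that the bundle configs expect.
import os

import gdown  # assumption: installed, e.g. via `pip install gdown`
import yaml   # assumption: pyyaml is available


def fetch_large_files(bundle_dir: str) -> None:
    with open(os.path.join(bundle_dir, "large_files.yml")) as f:
        spec = yaml.safe_load(f)
    for entry in spec["large_files"]:
        target = os.path.join(bundle_dir, entry["path"])
        os.makedirs(os.path.dirname(target), exist_ok=True)
        gdown.download(entry["url"], target, quiet=False)


# fetch_large_files("model-zoo/models/cxr_image_synthesis_latent_diffusion_model")
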
/generative/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/saver.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 |
7 |
8 | class JPGSaver:
9 | def __init__(self, output_dir: str) -> None:
10 | super().__init__()
11 | self.output_dir = output_dir
12 |
13 | def save(self, image_data: torch.Tensor, file_name: str) -> None:
14 | image_data = np.clip(image_data.cpu().numpy(), 0, 1)
15 | image_data = (image_data * 255).astype(np.uint8)
16 | im = Image.fromarray(image_data[0, 0])
17 | im.save(self.output_dir + "/" + file_name + ".jpg")
18 |
--------------------------------------------------------------------------------
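
A small, hypothetical usage sketch for `JPGSaver` (not part of the bundle): it expects a `(batch, channel, H, W)` tensor in the `[0, 1]` range and writes `<output_dir>/<file_name>.jpg`.

# Hypothetical usage: write a random single-channel image as outputs/sample_000.jpg.
import os

import torch

from scripts.saver import JPGSaver  # assumes the bundle root is on PYTHONPATH

os.makedirs("outputs", exist_ok=True)  # JPGSaver does not create the directory itself
JPGSaver(output_dir="outputs").save(torch.rand(1, 1, 512, 512), "sample_000")
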
/generative/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | __version__ = "0.2.3"
15 |
--------------------------------------------------------------------------------
/generative/networks/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from .vector_quantizer import EMAQuantizer, VectorQuantizer
13 |
--------------------------------------------------------------------------------
/generative/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .version import __version__
15 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/docs/README.md:
--------------------------------------------------------------------------------
1 |
2 | # MedNIST DDPM Example Bundle
3 |
4 | This bundle implements code roughly equivalent to the "Denoising Diffusion Probabilistic Models with MedNIST Dataset" example notebook, and includes scripts for training with single or multiple GPUs as well as a visualisation notebook.
5 |
6 | The files included here demonstrate how to use the bundle:
7 | * [2d_ddpm_bundle_tutorial.ipynb](./2d_ddpm_bundle_tutorial.ipynb) - demonstrates command line and in-code invocation of the bundle's training and inference scripts
8 | * [sub_train.sh](sub_train.sh) - SLURM submission script example for training
9 | * [sub_train_multigpu.sh](sub_train_multigpu.sh) - SLURM submission script example for training with multiple GPUs
10 |
--------------------------------------------------------------------------------
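
The notebook referenced above is not included in this dump; as a rough, hedged sketch, in-code invocation of the bundle's training config through MONAI's `monai.bundle.run` Python API might look like the following (run from the bundle directory so that `scripts` is importable; the dataset path is a placeholder):

# Sketch of in-code invocation, approximating what the referenced notebook demonstrates.
from monai.bundle import run

run(
    "training",  # the config section executed by sub_train.sh
    meta_file="configs/metadata.json",
    config_file=["configs/common.yaml", "configs/train.yaml"],
    logging_file="configs/logging.conf",
    bundle_root=".",
    dataset_dir="/path/to/MedNIST",  # placeholder: wherever MedNIST is stored
)
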
/generative/engines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch
15 | from .trainer import AdversarialTrainer
16 |
--------------------------------------------------------------------------------
/generative/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .adversarial_loss import PatchAdversarialLoss
15 | from .perceptual import PerceptualLoss
16 | from .spectral_loss import JukeboxLoss
17 |
--------------------------------------------------------------------------------
/generative/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .fid import FIDMetric
15 | from .mmd import MMDMetric
16 | from .ms_ssim import MultiScaleSSIMMetric
17 | from .ssim import SSIMMetric
18 |
--------------------------------------------------------------------------------
/generative/networks/blocks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .encoder_modules import SpatialRescaler
15 | from .selfattention import SABlock
16 | from .transformerblock import TransformerBlock
17 |
--------------------------------------------------------------------------------
/generative/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .component_store import ComponentStore
15 | from .enums import AdversarialIterationEvents, AdversarialKeys
16 | from .misc import unsqueeze_left, unsqueeze_right
17 |
--------------------------------------------------------------------------------
/generative/networks/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .ddim import DDIMScheduler
15 | from .ddpm import DDPMScheduler
16 | from .pndm import PNDMScheduler
17 | from .scheduler import NoiseSchedules, Scheduler
18 |
--------------------------------------------------------------------------------
/generative/inferers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .inferer import (
15 | ControlNetDiffusionInferer,
16 | ControlNetLatentDiffusionInferer,
17 | DiffusionInferer,
18 | LatentDiffusionInferer,
19 | VQVAETransformerInferer,
20 | )
21 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/train_multigpu.yaml:
--------------------------------------------------------------------------------
1 | # This can be mixed in with the training script to enable multi-GPU training
2 |
3 | network:
4 | _target_: torch.nn.parallel.DistributedDataParallel
5 | module: $@network_def.to(@device)
6 | device_ids: ['@device']
7 | find_unused_parameters: true
8 |
9 | tsampler:
10 | _target_: DistributedSampler
11 | dataset: '@train_ds'
12 | even_divisible: true
13 | shuffle: true
14 | train_loader#sampler: '@tsampler'
15 | train_loader#shuffle: false
16 |
17 | vsampler:
18 | _target_: DistributedSampler
19 | dataset: '@val_ds'
20 | even_divisible: false
21 | shuffle: false
22 | val_loader#sampler: '@vsampler'
23 |
24 | training:
25 | - $import torch.distributed as dist
26 | - $dist.init_process_group(backend='nccl')
27 | - $torch.cuda.set_device(@device)
28 | - $monai.utils.set_determinism(seed=123)
29 | - $@trainer.run()
30 | - $dist.destroy_process_group()
31 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from setuptools import find_packages, setup
15 |
16 | setup(
17 | name="monai-generative",
18 | packages=find_packages(exclude=[]),
19 | version="0.2.3",
20 | description="Installer to help to use the prototypes from MONAI generative models in other projects.",
21 | install_requires=["monai>=1.3.0"],
22 | )
23 |
--------------------------------------------------------------------------------
/generative/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from typing import TypeVar
15 |
16 | T = TypeVar("T")
17 |
18 |
19 | def unsqueeze_right(arr: T, ndim: int) -> T:
20 | """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
21 | return arr[(...,) + (None,) * (ndim - arr.ndim)]
22 |
23 |
24 | def unsqueeze_left(arr: T, ndim: int) -> T:
25 | """Preppend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
26 | return arr[(None,) * (ndim - arr.ndim)]
27 |
--------------------------------------------------------------------------------
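
For reference, the two helpers behave as follows (mirroring the cases in `tests/test_misc.py`):

# Quick illustration of unsqueeze_right / unsqueeze_left on a 2-D tensor.
import torch

from generative.utils import unsqueeze_left, unsqueeze_right

x = torch.rand(3, 4)
print(unsqueeze_right(x, 5).shape)  # torch.Size([3, 4, 1, 1, 1])
print(unsqueeze_left(x, 5).shape)   # torch.Size([1, 1, 1, 3, 4])
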
/model-zoo/models/mednist_ddpm/bundle/docs/sub_train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH -J mednist_train
4 | #SBATCH -c 4
5 | #SBATCH --gres=gpu:1
6 | #SBATCH --time=2:00:00
7 | #SBATCH -p small
8 |
9 | set -v
10 |
11 | # change this if the job is submitted from a different directory
12 | export BUNDLE="$(pwd)/.."
13 |
14 | # have to set PYTHONPATH to find MONAI and GenerativeModels as well as the bundle's script directory
15 | export PYTHONPATH="$HOME/MONAI:$HOME/GenerativeModels:$BUNDLE"
16 |
17 | # change this to load a checkpoint instead of starting from scratch
18 | CKPT=none
19 |
20 | CONFIG="'$BUNDLE/configs/common.yaml', '$BUNDLE/configs/train.yaml'"
21 |
22 | # change this to point to where MedNIST is located
23 | DATASET="$(pwd)"
24 |
25 | # it's useful to include the configuration in the log file
26 | cat "$BUNDLE/configs/common.yaml"
27 | cat "$BUNDLE/configs/train.yaml"
28 |
29 | python -m monai.bundle run training \
30 | --meta_file "$BUNDLE/configs/metadata.json" \
31 | --config_file "$CONFIG" \
32 | --logging_file "$BUNDLE/configs/logging.conf" \
33 | --bundle_root "$BUNDLE" \
34 | --dataset_dir "$DATASET"
35 |
--------------------------------------------------------------------------------
/generative/networks/nets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from .autoencoderkl import AutoencoderKL
15 | from .controlnet import ControlNet
16 | from .diffusion_model_unet import DiffusionModelUNet
17 | from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
18 | from .spade_autoencoderkl import SPADEAutoencoderKL
19 | from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
20 | from .spade_network import SPADENet
21 | from .transformer import DecoderOnlyTransformer
22 | from .vqvae import VQVAE
23 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/saver.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import nibabel as nib
4 | import numpy as np
5 | import torch
6 |
7 |
8 | class NiftiSaver:
9 | def __init__(self, output_dir: str) -> None:
10 | super().__init__()
11 | self.output_dir = output_dir
12 | self.affine = np.array(
13 | [
14 | [-1.0, 0.0, 0.0, 96.48149872],
15 | [0.0, 1.0, 0.0, -141.47715759],
16 | [0.0, 0.0, 1.0, -156.55375671],
17 | [0.0, 0.0, 0.0, 1.0],
18 | ]
19 | )
20 |
21 | def save(self, image_data: torch.Tensor, file_name: str) -> None:
22 | image_data = image_data.cpu().numpy()
23 | image_data = image_data[0, 0, 5:-5, 5:-5, :-15]
24 | image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min())
25 | image_data = (image_data * 255).astype(np.uint8)
26 |
27 | empty_header = nib.Nifti1Header()
28 | sample_nii = nib.Nifti1Image(image_data, self.affine, empty_header)
29 | nib.save(sample_nii, f"{str(self.output_dir)}/{file_name}.nii.gz")
30 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/infer.yaml:
--------------------------------------------------------------------------------
1 | # This defines an inference script which generates an image from random noise and saves it to a PyTorch tensor file
2 |
3 | batch_size: 1
4 | num_workers: 0
5 |
6 | noise: $torch.rand(1,1,@image_dim,@image_dim) # create a random image every time this program is run
7 |
8 | out_file: "" # where to save the tensor to
9 |
10 | # using a lambda this defines a simple sampling function used below
11 | sample: '$lambda x: @inferer.sample(input_noise=x, diffusion_model=@network, scheduler=@scheduler)'
12 |
13 | load_state: '$@network.load_state_dict(torch.load(@ckpt_path))' # command to load the saved model weights
14 |
15 | save_trans:
16 | _target_: Compose
17 | transforms:
18 | - _target_: ScaleIntensity
19 | minv: 0.0
20 | maxv: 255.0
21 | - _target_: ToTensor
22 | track_meta: false
23 | - _target_: SaveImage
24 | output_ext: "jpg"
25 | resample: false
26 | output_dtype: '$torch.uint8'
27 | separate_folder: false
28 | output_postfix: '@out_file'
29 |
30 | # program to load the model weights, run `sample`, and store results to `out_file`
31 | testing:
32 | - '@load_state'
33 | - '$torch.save(@sample(@noise.to(@device)), @out_file)'
34 |
35 | #alternative version which saves to a jpg file
36 | testing_jpg:
37 | - '@load_state'
38 | - '$@save_trans(@sample(@noise.to(@device))[0])'
39 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/docs/sub_train_multigpu.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | #SBATCH --nodes=1
3 | #SBATCH -J mednist_train
4 | #SBATCH -c 4
5 | #SBATCH --gres=gpu:2
6 | #SBATCH --time=2:00:00
7 | #SBATCH -p big
8 |
9 | set -v
10 |
11 | # change this if the job is submitted from a different directory
12 | export BUNDLE="$(pwd)/.."
13 |
14 | # have to set PYTHONPATH to find MONAI and GenerativeModels as well as the bundle's script directory
15 | export PYTHONPATH="$HOME/MONAI:$HOME/GenerativeModels:$BUNDLE"
16 |
17 | # change this to load a checkpoint instead of starting from scratch
18 | CKPT=none
19 |
20 | CONFIG="'$BUNDLE/configs/common.yaml', '$BUNDLE/configs/train.yaml', '$BUNDLE/configs/train_multigpu.yaml'"
21 |
22 | # change this to point to where MedNIST is located
23 | DATASET="$(pwd)"
24 |
25 | # it's useful to include the configuration in the log file
26 | cat "$BUNDLE/configs/common.yaml"
27 | cat "$BUNDLE/configs/train.yaml"
28 | cat "$BUNDLE/configs/train_multigpu.yaml"
29 |
30 | # remember to change arguments to match how many nodes and GPUs you have
31 | torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run training \
32 | --meta_file "$BUNDLE/configs/metadata.json" \
33 | --config_file "$CONFIG" \
34 | --logging_file "$BUNDLE/configs/logging.conf" \
35 | --bundle_root "$BUNDLE" \
36 | --dataset_dir "$DATASET"
37 |
--------------------------------------------------------------------------------
/tests/test_compute_fid_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import numpy as np
17 | import torch
18 |
19 | from generative.metrics import FIDMetric
20 |
21 |
22 | class TestFIDMetric(unittest.TestCase):
23 | def test_results(self):
24 | x = torch.Tensor([[1, 2], [1, 2], [1, 2]])
25 | y = torch.Tensor([[2, 2], [1, 2], [1, 2]])
26 | results = FIDMetric()(x, y)
27 | np.testing.assert_allclose(results.cpu().numpy(), 0.4444, atol=1e-4)
28 |
29 | def test_input_dimensions(self):
30 | with self.assertRaises(ValueError):
31 | FIDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
32 |
33 |
34 | if __name__ == "__main__":
35 | unittest.main()
36 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # Full requirements for development
2 | -r requirements-min.txt
3 | pytorch-ignite==0.4.10
4 | gdown>=4.4.0
5 | scipy
6 | itk>=5.2
7 | nibabel
8 | pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
9 | tensorboard>=2.6 # https://github.com/Project-MONAI/MONAI/issues/5776
10 | scikit-image>=0.19.0
11 | tqdm>=4.47.0
12 | lmdb
13 | flake8>=3.8.1
14 | flake8-bugbear
15 | flake8-comprehensions
16 | flake8-executable
17 | pylint!=2.13 # https://github.com/PyCQA/pylint/issues/5969
18 | mccabe
19 | pep8-naming
20 | pycodestyle
21 | pyflakes
22 | black
23 | isort
24 | pytype>=2020.6.1; platform_system != "Windows"
25 | types-pkg_resources
26 | mypy>=0.790
27 | ninja
28 | torchvision
29 | psutil
30 | Sphinx==3.5.3
31 | recommonmark==0.6.0
32 | sphinx-autodoc-typehints==1.11.1
33 | sphinx-rtd-theme==0.5.2
34 | cucim==22.8.1; platform_system == "Linux"
35 | openslide-python==1.1.2
36 | imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
37 | tifffile; platform_system == "Linux" or platform_system == "Darwin"
38 | pandas
39 | requests
40 | einops
41 | transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
42 | mlflow
43 | matplotlib!=3.5.0
44 | tensorboardX
45 | types-PyYAML
46 | pyyaml
47 | fire
48 | jsonschema
49 | pynrrd
50 | pre-commit
51 | pydicom
52 | h5py
53 | nni
54 | optuna
55 | git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
56 | lpips==0.1.4
57 | xformers==0.0.16
58 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/scripts/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn as nn
5 | from monai.utils import optional_import
6 | from torch.cuda.amp import autocast
7 |
8 | tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
9 |
10 |
11 | class Sampler:
12 | def __init__(self) -> None:
13 | super().__init__()
14 |
15 | @torch.no_grad()
16 | def sampling_fn(
17 | self,
18 | noise: torch.Tensor,
19 | autoencoder_model: nn.Module,
20 | diffusion_model: nn.Module,
21 | scheduler: nn.Module,
22 | prompt_embeds: torch.Tensor,
23 | guidance_scale: float = 7.0,
24 | scale_factor: float = 0.3,
25 | ) -> torch.Tensor:
26 | if has_tqdm:
27 | progress_bar = tqdm(scheduler.timesteps)
28 | else:
29 | progress_bar = iter(scheduler.timesteps)
30 |
31 | for t in progress_bar:
32 | noise_input = torch.cat([noise] * 2)
33 | model_output = diffusion_model(
34 | noise_input, timesteps=torch.Tensor((t,)).to(noise.device).long(), context=prompt_embeds
35 | )
36 | noise_pred_uncond, noise_pred_text = model_output.chunk(2)
37 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
38 | noise, _ = scheduler.step(noise_pred, t, noise)
39 |
40 | with autocast():
41 | sample = autoencoder_model.decode_stage_2_outputs(noise / scale_factor)
42 |
43 | return sample
44 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import sys
15 | import unittest
16 | import warnings
17 |
18 |
19 | def _enter_pr_4800(self):
20 | """
21 | code from https://github.com/python/cpython/pull/4800
22 | """
23 | # The __warningregistry__'s need to be in a pristine state for tests
24 | # to work properly.
25 | for v in list(sys.modules.values()):
26 | if getattr(v, "__warningregistry__", None):
27 | v.__warningregistry__ = {}
28 | self.warnings_manager = warnings.catch_warnings(record=True)
29 | self.warnings = self.warnings_manager.__enter__()
30 | warnings.simplefilter("always", self.expected)
31 | return self
32 |
33 |
34 | # FIXME: workaround for https://bugs.python.org/issue29620
35 | try:
36 | # Suppression for issue #494: tests/__init__.py:34: error: Cannot assign to a method
37 | unittest.case._AssertWarnsContext.__enter__ = _enter_pr_4800 # type: ignore
38 | except AttributeError:
39 | pass
40 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/scripts/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | import torch.nn as nn
5 | from monai.utils import optional_import
6 | from torch.cuda.amp import autocast
7 |
8 | tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
9 |
10 |
11 | class Sampler:
12 | def __init__(self) -> None:
13 | super().__init__()
14 |
15 | @torch.no_grad()
16 | def sampling_fn(
17 | self,
18 | input_noise: torch.Tensor,
19 | autoencoder_model: nn.Module,
20 | diffusion_model: nn.Module,
21 | scheduler: nn.Module,
22 | conditioning: torch.Tensor,
23 | ) -> torch.Tensor:
24 | if has_tqdm:
25 | progress_bar = tqdm(scheduler.timesteps)
26 | else:
27 | progress_bar = iter(scheduler.timesteps)
28 |
29 | image = input_noise
30 | cond_concat = conditioning.squeeze(1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
31 | cond_concat = cond_concat.expand(list(cond_concat.shape[0:2]) + list(input_noise.shape[2:]))
32 | for t in progress_bar:
33 | with torch.no_grad():
34 | model_output = diffusion_model(
35 | torch.cat((image, cond_concat), dim=1),
36 | timesteps=torch.Tensor((t,)).to(input_noise.device).long(),
37 | context=conditioning,
38 | )
39 | image, _ = scheduler.step(model_output, t, image)
40 |
41 | with torch.no_grad():
42 | with autocast():
43 | sample = autoencoder_model.decode_stage_2_outputs(image)
44 |
45 | return sample
46 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/common.yaml:
--------------------------------------------------------------------------------
1 | # This file defines common definitions used in training and inference, most importantly the network definition
2 |
3 | imports:
4 | - $import os
5 | - $import datetime
6 | - $import torch
7 | - $import scripts
8 | - $import monai
9 | - $import generative
10 | - $import torch.distributed as dist
11 |
12 | image: $monai.utils.CommonKeys.IMAGE
13 | label: $monai.utils.CommonKeys.LABEL
14 | pred: $monai.utils.CommonKeys.PRED
15 |
16 | is_dist: '$dist.is_initialized()'
17 | rank: '$dist.get_rank() if @is_dist else 0'
18 | is_not_rank0: '$@rank > 0'
19 | device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'
20 |
21 | network_def:
22 | _target_: generative.networks.nets.DiffusionModelUNet
23 | spatial_dims: 2
24 | in_channels: 1
25 | out_channels: 1
26 | num_channels: [64, 128, 128]
27 | attention_levels: [false, true, true]
28 | num_res_blocks: 1
29 | num_head_channels: 128
30 |
31 | network: $@network_def.to(@device)
32 |
33 | bundle_root: .
34 | ckpt_path: $@bundle_root + '/models/model.pt'
35 | use_amp: true
36 | image_dim: 64
37 | image_size: [1, '@image_dim', '@image_dim']
38 | num_train_timesteps: 1000
39 |
40 | base_transforms:
41 | - _target_: LoadImaged
42 | keys: '@image'
43 | image_only: true
44 | - _target_: EnsureChannelFirstd
45 | keys: '@image'
46 | - _target_: ScaleIntensityRanged
47 | keys: '@image'
48 | a_min: 0.0
49 | a_max: 255.0
50 | b_min: 0.0
51 | b_max: 1.0
52 | clip: true
53 |
54 | scheduler:
55 | _target_: generative.networks.schedulers.DDPMScheduler
56 | num_train_timesteps: '@num_train_timesteps'
57 |
58 | inferer:
59 | _target_: generative.inferers.DiffusionInferer
60 | scheduler: '@scheduler'
61 |
--------------------------------------------------------------------------------
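
As a hedged sketch, the definitions above can also be instantiated programmatically with MONAI's `ConfigParser` (assuming the `generative` package is importable):

# Sketch: parse common.yaml and build two of the components it defines.
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("model-zoo/models/mednist_ddpm/bundle/configs/common.yaml")
unet = parser.get_parsed_content("network_def")     # generative.networks.nets.DiffusionModelUNet
scheduler = parser.get_parsed_content("scheduler")  # DDPMScheduler(num_train_timesteps=1000)
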
/tests/test_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from monai.networks import eval_mode
18 |
19 | from generative.networks.nets import DecoderOnlyTransformer
20 |
21 |
22 | class TestDecoderOnlyTransformer(unittest.TestCase):
23 | def test_unconditioned_models(self):
24 | net = DecoderOnlyTransformer(
25 | num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=2
26 | )
27 | with eval_mode(net):
28 | net.forward(torch.randint(0, 10, (1, 16)))
29 |
30 | def test_conditioned_models(self):
31 | net = DecoderOnlyTransformer(
32 | num_tokens=10,
33 | max_seq_len=16,
34 | attn_layers_dim=8,
35 | attn_layers_depth=2,
36 | attn_layers_heads=2,
37 | with_cross_attention=True,
38 | embedding_dropout_rate=0,
39 | )
40 | with eval_mode(net):
41 | net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 4, 8))
42 |
43 |
44 | if __name__ == "__main__":
45 | unittest.main()
46 |
--------------------------------------------------------------------------------
/tests/test_compute_mmd_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import numpy as np
17 | import torch
18 | from parameterized import parameterized
19 |
20 | from generative.metrics import MMDMetric
21 |
22 | TEST_CASES = [
23 | [
24 | {"y_transform": None, "y_pred_transform": None},
25 | {"y": torch.ones([3, 3, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144])},
26 | 0.0,
27 | ],
28 | [
29 | {"y_transform": None, "y_pred_transform": None},
30 | {"y": torch.ones([3, 3, 144, 144, 144]), "y_pred": torch.ones([3, 3, 144, 144, 144])},
31 | 0.0,
32 | ],
33 | ]
34 |
35 |
36 | class TestMMDMetric(unittest.TestCase):
37 | @parameterized.expand(TEST_CASES)
38 | def test_results(self, input_param, input_data, expected_val):
39 | metric = MMDMetric(**input_param)
40 | results = metric(**input_data)
41 | np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)
42 |
43 | def test_if_inputs_different_shapes(self):
44 | with self.assertRaises(ValueError):
45 | MMDMetric()(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))
46 |
47 |
48 | if __name__ == "__main__":
49 | unittest.main()
50 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ['py37', 'py38', 'py39', 'py310']
4 | include = '\.pyi?$'
5 | exclude = '''
6 | (
7 | /(
8 | # exclude a few common directories in the root of the project
9 | \.eggs
10 | | \.git
11 | | \.hg
12 | | \.mypy_cache
13 | | \.tox
14 | | \.venv
15 | | venv
16 | | \.pytype
17 | | _build
18 | | buck-out
19 | | build
20 | | dist
21 | )/
22 | # also separately exclude the generated version file
23 | | generative/_version.py
24 | )
25 | '''
26 |
27 | [tool.pycln]
28 | all = true
29 |
30 | [tool.pytype]
31 | # Space-separated list of files or directories to exclude.
32 | exclude = ["versioneer.py", "_version.py", "tutorials/"]
33 | # Space-separated list of files or directories to process.
34 | inputs = ["generative"]
35 | # Keep going past errors to analyze as many files as possible.
36 | keep_going = true
37 | # Run N jobs in parallel.
38 | jobs = 8
39 | # All pytype output goes here.
40 | output = ".pytype"
41 | # Paths to source code directories, separated by ':'.
42 | pythonpath = "."
43 | # Check attribute values against their annotations.
44 | check_attribute_types = true
45 | # Check container mutations against their annotations.
46 | check_container_types = true
47 | # Check parameter defaults and assignments against their annotations.
48 | check_parameter_types = true
49 | # Check variable values against their annotations.
50 | check_variable_types = true
51 | # Comma or space separated list of error names to ignore.
52 | disable = ["pyi-error"]
53 | # Report errors.
54 | report_errors = true
55 | # Experimental: Infer precise return types even for invalid function calls.
56 | precise_return = true
57 | # Experimental: solve unknown types to label with structural types.
58 | protocols = true
59 | # Experimental: Only load submodules that are explicitly imported.
60 | strict_import = false
61 |
--------------------------------------------------------------------------------
/tests/test_misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import numpy as np
17 | import torch
18 | from parameterized import parameterized
19 |
20 | from generative.utils import unsqueeze_left, unsqueeze_right
21 |
22 | RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))]
23 |
24 | LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))]
25 |
26 | ALL_CASES = [
27 | (np.random.rand(3, 4), 2, (3, 4)),
28 | (np.random.rand(3, 4), 0, (3, 4)),
29 | (np.random.rand(3, 4), -1, (3, 4)),
30 | (np.array(3), 4, (1, 1, 1, 1)),
31 | (np.array(3), 0, ()),
32 | (torch.rand(3, 4), 2, (3, 4)),
33 | (torch.rand(3, 4), 0, (3, 4)),
34 | (torch.rand(3, 4), -1, (3, 4)),
35 | (torch.tensor(3), 4, (1, 1, 1, 1)),
36 | (torch.tensor(3), 0, ()),
37 | ]
38 |
39 |
40 | class TestUnsqueeze(unittest.TestCase):
41 | @parameterized.expand(RIGHT_CASES + ALL_CASES)
42 | def test_unsqueeze_right(self, arr, ndim, shape):
43 | self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)
44 |
45 | @parameterized.expand(LEFT_CASES + ALL_CASES)
46 | def test_unsqueeze_left(self, arr, ndim, shape):
47 | self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)
48 |
--------------------------------------------------------------------------------
/tests/test_controlnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from monai.networks import eval_mode
18 | from parameterized import parameterized
19 |
20 | from generative.networks.nets.controlnet import ControlNet
21 |
22 | TEST_CASES = [
23 | [
24 | {
25 | "spatial_dims": 2,
26 | "in_channels": 1,
27 | "num_res_blocks": 1,
28 | "num_channels": (8, 8, 8),
29 | "attention_levels": (False, False, True),
30 | "num_head_channels": 8,
31 | "norm_num_groups": 8,
32 | "conditioning_embedding_in_channels": 1,
33 | "conditioning_embedding_num_channels": (8, 8),
34 | },
35 | 6,
36 | (1, 8, 4, 4),
37 | ]
38 | ]
39 |
40 |
41 | class TestControlNet(unittest.TestCase):
42 | @parameterized.expand(TEST_CASES)
43 | def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
44 | net = ControlNet(**input_param)
45 | with eval_mode(net):
46 | result = net.forward(
47 | torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 32, 32))
48 | )
49 | self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
50 | self.assertEqual(result[1].shape, expected_shape)
51 |
52 |
53 | if __name__ == "__main__":
54 | unittest.main()
55 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220729.json",
3 | "version": "0.1.0",
4 | "changelog": {
5 | "0.1.0": "Initial version"
6 | },
7 | "monai_version": "1.0.0",
8 | "pytorch_version": "1.10.2",
9 | "numpy_version": "1.21.2",
10 | "optional_packages_version": {
11 | "generative": "0.1.0"
12 | },
13 | "task": "MedNIST Hand Generation",
14 | "description": "",
15 | "authors": "Walter Hugo Lopez Pinaya, Mark Graham, and Eric Kerfoot",
16 | "copyright": "Copyright (c) KCL",
17 | "references": [],
18 | "intended_use": "This is suitable for research purposes only",
19 | "image_classes": "Single channel magnitude data",
20 | "data_source": "MedNIST",
21 | "network_data_format": {
22 | "inputs": {
23 | "image": {
24 | "type": "image",
25 | "format": "magnitude",
26 | "modality": "xray",
27 | "num_channels": 1,
28 | "spatial_shape": [
29 | 1,
30 | 64,
31 | 64
32 | ],
33 | "dtype": "float32",
34 | "value_range": [],
35 | "is_patch_data": false,
36 | "channel_def": {
37 | "0": "image"
38 | }
39 | }
40 | },
41 | "outputs": {
42 | "pred": {
43 | "type": "image",
44 | "format": "magnitude",
45 | "modality": "xray",
46 | "num_channels": 1,
47 | "spatial_shape": [
48 | 1,
49 | 64,
50 | 64
51 | ],
52 | "dtype": "float32",
53 | "value_range": [],
54 | "is_patch_data": false,
55 | "channel_def": {
56 | "0": "image"
57 | }
58 | }
59 | }
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | #default_language_version:
2 | # python: python3.8
3 |
4 | ci:
5 | autofix_prs: true
6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
7 | autoupdate_schedule: quarterly
8 | # submodules: true
9 |
10 | repos:
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v4.4.0
13 | hooks:
14 | - id: end-of-file-fixer
15 | exclude: ^tutorials/
16 | - id: trailing-whitespace
17 | - id: check-yaml
18 | - id: check-docstring-first
19 | - id: check-executables-have-shebangs
20 | - id: check-toml
21 | - id: check-case-conflict
22 | - id: check-added-large-files
23 | args: ['--maxkb=1024']
24 | - id: detect-private-key
25 | - id: forbid-new-submodules
26 | - id: pretty-format-json
27 | args: ['--autofix', '--no-sort-keys', '--indent=4']
28 | - id: mixed-line-ending
29 |
30 | - repo: https://github.com/asottile/pyupgrade
31 | rev: v3.3.1
32 | hooks:
33 | - id: pyupgrade
34 | args: [--py37-plus]
35 | name: Upgrade code
36 | exclude: |
37 | (?x)^(
38 | versioneer.py|
39 | monai/_version.py
40 | )$
41 |
42 | - repo: https://github.com/asottile/yesqa
43 | rev: v1.4.0
44 | hooks:
45 | - id: yesqa
46 | name: Unused noqa
47 | additional_dependencies:
48 | - flake8>=3.8.1
49 | - flake8-bugbear
50 | - flake8-comprehensions
51 | - flake8-executable
52 | - flake8-pyi
53 | - pep8-naming
54 | exclude: |
55 | (?x)^(
56 | generative/__init__.py|
57 | docs/source/conf.py
58 | )$
59 |
60 | - repo: https://github.com/hadialqattan/pycln
61 | rev: v2.1.2
62 | hooks:
63 | - id: pycln
64 | args: [--config=pyproject.toml]
65 |
66 | # - repo: https://github.com/psf/black
67 | # rev: 22.3.0
68 | # hooks:
69 | # - id: black
70 | #
71 | # - repo: https://github.com/PyCQA/isort
72 | # rev: 5.9.3
73 | # hooks:
74 | # - id: isort
75 |
--------------------------------------------------------------------------------
/tests/test_selfattention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from monai.networks import eval_mode
18 | from parameterized import parameterized
19 |
20 | from generative.networks.blocks.selfattention import SABlock
21 |
22 | TEST_CASE_SABLOCK = [
23 | [
24 | {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": False, "sequence_length": None},
25 | (2, 4, 16),
26 | (2, 4, 16),
27 | ],
28 | [
29 | {"hidden_size": 16, "num_heads": 8, "dropout_rate": 0.2, "causal": True, "sequence_length": 4},
30 | (2, 4, 16),
31 | (2, 4, 16),
32 | ],
33 | ]
34 |
35 |
36 | class TestSABlock(unittest.TestCase):
37 | @parameterized.expand(TEST_CASE_SABLOCK)
38 | def test_shape(self, input_param, input_shape, expected_shape):
39 | net = SABlock(**input_param)
40 | with eval_mode(net):
41 | result = net(torch.randn(input_shape))
42 | self.assertEqual(result.shape, expected_shape)
43 |
44 | def test_ill_arg(self):
45 | with self.assertRaises(ValueError):
46 | SABlock(hidden_size=12, num_heads=4, dropout_rate=6.0)
47 |
48 | with self.assertRaises(ValueError):
49 | SABlock(hidden_size=12, num_heads=4, dropout_rate=-6.0)
50 |
51 | with self.assertRaises(ValueError):
52 | SABlock(hidden_size=20, num_heads=8, dropout_rate=0.4)
53 |
54 | with self.assertRaises(ValueError):
55 | SABlock(hidden_size=12, num_heads=4, dropout_rate=0.4, causal=True, sequence_length=None)
56 |
57 |
58 | if __name__ == "__main__":
59 | unittest.main()
60 |
--------------------------------------------------------------------------------
/generative/utils/enums.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from typing import TYPE_CHECKING
15 |
16 | from monai.config import IgniteInfo
17 | from monai.utils import StrEnum, min_version, optional_import
18 |
19 | if TYPE_CHECKING:
20 | from ignite.engine import EventEnum
21 | else:
22 | EventEnum, _ = optional_import(
23 | "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
24 | )
25 |
26 |
27 | class AdversarialKeys(StrEnum):
28 | REALS = "reals"
29 | REAL_LOGITS = "real_logits"
30 | FAKES = "fakes"
31 | FAKE_LOGITS = "fake_logits"
32 | RECONSTRUCTION_LOSS = "reconstruction_loss"
33 | GENERATOR_LOSS = "generator_loss"
34 | DISCRIMINATOR_LOSS = "discriminator_loss"
35 |
36 |
37 | class AdversarialIterationEvents(EventEnum):
38 | RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
39 | GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
40 | GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
41 | GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
42 | GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
43 | GENERATOR_MODEL_COMPLETED = "generator_model_completed"
44 | DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
45 | DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
46 | DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
47 | DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
48 | DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"
49 |
50 |
51 | class OrderingType(StrEnum):
52 | RASTER_SCAN = "raster_scan"
53 | S_CURVE = "s_curve"
54 | RANDOM = "random"
55 |
56 |
57 | class OrderingTransformations(StrEnum):
58 | ROTATE_90 = "rotate_90"
59 | TRANSPOSE = "transpose"
60 | REFLECT = "reflect"
61 |
--------------------------------------------------------------------------------
/tests/min_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import glob
15 | import os
16 | import sys
17 | import unittest
18 |
19 |
20 | def run_testsuit():
21 | """
22 | Load test cases by excluding those that need external dependencies.
23 | The loaded cases should work with "requirements-min.txt"::
24 |
25 | # in the monai repo folder:
26 | pip install -r requirements-min.txt
27 | QUICKTEST=true python -m tests.min_tests
28 |
29 | :return: a test suite
30 | """
31 | exclude_cases = [ # these cases use external dependencies
32 | "test_autoencoderkl",
33 | "test_diffusion_inferer",
34 | "test_integration_workflows_adversarial",
35 | "test_latent_diffusion_inferer",
36 | "test_perceptual_loss",
37 | "test_transformer",
38 | ]
39 | assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
40 |
41 | files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py"))
42 |
43 | cases = []
44 | for case in files:
45 | test_module = os.path.basename(case)[:-3]
46 | if test_module in exclude_cases:
47 | exclude_cases.remove(test_module)
48 | print(f"skipping tests.{test_module}.")
49 | else:
50 | cases.append(f"tests.{test_module}")
51 | assert not exclude_cases, f"items in exclude_cases not used: {exclude_cases}."
52 | test_suite = unittest.TestLoader().loadTestsFromNames(cases)
53 | return test_suite
54 |
55 |
56 | if __name__ == "__main__":
57 | # testing import submodules
58 | from monai.utils.module import load_submodules
59 |
60 | _, err_mod = load_submodules(sys.modules["monai"], True)
61 | assert not err_mod, f"err_mod={err_mod} not empty"
62 |
63 | # testing all modules
64 | test_runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2)
65 | result = test_runner.run(run_testsuit())
66 | sys.exit(int(not result.wasSuccessful()))
67 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3 | "version": "1.0.0",
4 | "changelog": {
5 | "0.2": "Flipped images fixed"
6 | },
7 | "monai_version": "1.1.0",
8 | "pytorch_version": "1.13.0",
9 | "numpy_version": "1.22.4",
10 | "optional_packages_version": {
11 | "nibabel": "4.0.1",
12 | "generative": "0.1.0",
13 | "transformers": "4.26.1"
14 | },
15 | "task": "Chest X-ray image synthesis",
16 | "description": "A generative model for creating high-resolution chest X-ray images based on the MIMIC dataset",
17 | "copyright": "Copyright (c) MONAI Consortium",
18 | "data_source": "https://physionet.org/content/mimic-cxr-jpg/2.0.0/",
19 | "data_type": "image",
20 | "image_classes": "Radiography (X-ray) with 512 x 512 pixels",
21 | "intended_use": "This is a research tool/prototype and not to be used clinically",
22 | "network_data_format": {
23 | "inputs": {
24 | "latent_representation": {
25 | "type": "image",
26 | "format": "magnitude",
27 | "modality": "CXR",
28 | "num_channels": 3,
29 | "spatial_shape": [
30 | 64,
31 | 64
32 | ],
33 | "dtype": "float32",
34 | "value_range": [],
35 | "is_patch_data": false
36 | },
37 | "timesteps": {
38 | "type": "vector",
39 | "value_range": [
40 | 0,
41 | 1000
42 | ],
43 | "dtype": "long"
44 | },
45 | "context": {
46 | "type": "vector",
47 | "value_range": [],
48 | "dtype": "float32"
49 | }
50 | },
51 | "outputs": {
52 | "pred": {
53 | "type": "image",
54 | "format": "magnitude",
55 | "modality": "CXR",
56 | "num_channels": 1,
57 | "spatial_shape": [
58 | 512,
59 | 512
60 | ],
61 | "dtype": "float32",
62 | "value_range": [
63 | 0,
64 | 1
65 | ],
66 | "is_patch_data": false,
67 | "channel_def": {
68 | "0": "X-ray"
69 | }
70 | }
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
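As a quick orientation to the `network_data_format` declared in the config above, the sketch below builds tensors with the stated shapes and dtypes. It is illustrative only and is not the bundle's inference code; the length of the `context` vector is not specified in the metadata and is chosen arbitrarily here.

```python
import torch

# Shapes/dtypes taken from the metadata above (a batch dimension of 1 is added for illustration).
latent_representation = torch.randn(1, 3, 64, 64, dtype=torch.float32)  # num_channels=3, spatial_shape=[64, 64]
timesteps = torch.randint(0, 1000, (1,), dtype=torch.long)              # value_range [0, 1000], dtype "long"
context = torch.randn(1, 1, 8, dtype=torch.float32)                     # conditioning vector; length unspecified in the metadata
```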
/tests/test_component_store.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | from generative.utils import ComponentStore
17 |
18 |
19 | class TestComponentStore(unittest.TestCase):
20 | def setUp(self):
21 | self.cs = ComponentStore("TestStore", "I am a test store, please ignore")
22 |
23 | def test_empty(self):
24 | self.assertEqual(len(self.cs), 0)
25 | self.assertEqual(list(self.cs), [])
26 |
27 | def test_add(self):
28 | test_obj = object()
29 |
30 | self.assertFalse("test_obj" in self.cs)
31 |
32 | self.cs.add("test_obj", "Test object", test_obj)
33 |
34 | self.assertTrue("test_obj" in self.cs)
35 |
36 | self.assertEqual(len(self.cs), 1)
37 | self.assertEqual(list(self.cs), [("test_obj", test_obj)])
38 |
39 | self.assertEqual(self.cs.test_obj, test_obj)
40 | self.assertEqual(self.cs["test_obj"], test_obj)
41 |
42 | def test_add2(self):
43 | test_obj1 = object()
44 | test_obj2 = object()
45 |
46 | self.cs.add("test_obj1", "Test object", test_obj1)
47 | self.cs.add("test_obj2", "Test object", test_obj2)
48 |
49 | self.assertEqual(len(self.cs), 2)
50 | self.assertTrue("test_obj1" in self.cs)
51 | self.assertTrue("test_obj2" in self.cs)
52 |
53 | def test_add_def(self):
54 | self.assertFalse("test_func" in self.cs)
55 |
56 | @self.cs.add_def("test_func", "Test function")
57 | def test_func():
58 | return 123
59 |
60 | self.assertTrue("test_func" in self.cs)
61 |
62 | self.assertEqual(len(self.cs), 1)
63 | self.assertEqual(list(self.cs), [("test_func", test_func)])
64 |
65 | self.assertEqual(self.cs.test_func, test_func)
66 | self.assertEqual(self.cs["test_func"], test_func)
67 |
68 | # try adding the same function again
69 | self.cs.add_def("test_func", "Test function but with new description")(test_func)
70 |
71 | self.assertEqual(len(self.cs), 1)
72 | self.assertEqual(self.cs.test_func, test_func)
73 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 | "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3 | "version": "1.0.0",
4 | "changelog": {
5 | "1.0.8": "Initial release"
6 | },
7 | "monai_version": "1.1.0",
8 | "pytorch_version": "1.13.0",
9 | "numpy_version": "1.22.4",
10 | "optional_packages_version": {
11 | "nibabel": "4.0.1",
12 | "generative": "0.1.0"
13 | },
14 | "task": "Brain image synthesis",
15 | "description": "A generative model for creating high-resolution 3D brain MRI based on UK Biobank",
16 | "authors": "Walter H. L. Pinaya, Petru-Daniel Tudosiu, Jessica Dafflon, Pedro F Da Costa, Virginia Fernandez, Parashkev Nachev, Sebastien Ourselin, and M. Jorge Cardoso",
17 | "copyright": "Copyright (c) MONAI Consortium",
18 | "data_source": "https://www.ukbiobank.ac.uk/",
19 | "data_type": "nibabel",
20 | "image_classes": "T1w head MRI with 1x1x1 mm voxel size",
21 | "eval_metrics": {
22 | "fid": 0.0076,
23 | "msssim": 0.6555,
24 | "4gmsssim": 0.3883
25 | },
26 | "intended_use": "This is a research tool/prototype and not to be used clinically",
27 | "references": [
28 | "Pinaya, Walter HL, et al. \"Brain imaging generation with latent diffusion models.\" MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022."
29 | ],
30 | "network_data_format": {
31 | "inputs": {
32 | "image": {
33 | "type": "tabular",
34 | "num_channels": 1,
35 | "dtype": "float32",
36 | "value_range": [
37 | 0,
38 | 1
39 | ],
40 | "is_patch_data": false,
41 | "channel_def": {
42 | "0": "Gender",
43 | "1": "Age",
44 | "2": "Ventricular volume",
45 | "3": "Brain volume"
46 | }
47 | }
48 | },
49 | "outputs": {
50 | "pred": {
51 | "type": "image",
52 | "format": "magnitude",
53 | "modality": "MR",
54 | "num_channels": 1,
55 | "spatial_shape": [
56 | 160,
57 | 224,
58 | 160
59 | ],
60 | "dtype": "float32",
61 | "value_range": [
62 | 0,
63 | 1
64 | ],
65 | "is_patch_data": false,
66 | "channel_def": {
67 | "0": "T1w"
68 | }
69 | }
70 | }
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/tests/test_scheduler_ddim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.networks.schedulers import DDIMScheduler
20 |
21 | TEST_2D_CASE = []
22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
23 | TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
24 |
25 | TEST_3D_CASE = []
26 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
27 | TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
28 |
29 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
30 |
31 |
32 | class TestDDIMScheduler(unittest.TestCase):
33 | @parameterized.expand(TEST_CASES)
34 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape):
35 | scheduler = DDIMScheduler(**input_param)
36 | scheduler.set_timesteps(num_inference_steps=100)
37 | original_sample = torch.zeros(input_shape)
38 | noise = torch.randn_like(original_sample)
39 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
40 |
41 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
42 | self.assertEqual(noisy.shape, expected_shape)
43 |
44 | @parameterized.expand(TEST_CASES)
45 | def test_step_shape(self, input_param, input_shape, expected_shape):
46 | scheduler = DDIMScheduler(**input_param)
47 | scheduler.set_timesteps(num_inference_steps=100)
48 | model_output = torch.randn(input_shape)
49 | sample = torch.randn(input_shape)
50 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
51 | self.assertEqual(output_step[0].shape, expected_shape)
52 | self.assertEqual(output_step[1].shape, expected_shape)
53 |
54 | def test_set_timesteps(self):
55 | scheduler = DDIMScheduler(num_train_timesteps=1000)
56 | scheduler.set_timesteps(num_inference_steps=100)
57 | self.assertEqual(scheduler.num_inference_steps, 100)
58 | self.assertEqual(len(scheduler.timesteps), 100)
59 |
60 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
61 | scheduler = DDIMScheduler(num_train_timesteps=1000)
62 | with self.assertRaises(ValueError):
63 | scheduler.set_timesteps(num_inference_steps=2000)
64 |
65 |
66 | if __name__ == "__main__":
67 | unittest.main()
68 |
--------------------------------------------------------------------------------
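The tests above only check tensor shapes; for context, here is a hedged sketch of how `DDIMScheduler` is typically driven during sampling. The denoiser is a stand-in placeholder (it returns zeros) for a trained noise-prediction network, so the loop runs but does not produce meaningful images.

```python
import torch

from generative.networks.schedulers import DDIMScheduler


def dummy_denoiser(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Placeholder for a trained model that predicts the noise present in x at timestep t.
    return torch.zeros_like(x)


scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 1, 64, 64)  # start from pure noise
for t in scheduler.timesteps:
    model_output = dummy_denoiser(sample, t)
    # step() returns (predicted previous sample, predicted original sample), as exercised in test_step_shape.
    sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
```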
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/docs/README.md:
--------------------------------------------------------------------------------
1 | # Brain Imaging Generation with Latent Diffusion Models
2 |
3 | ### **Authors**
4 |
5 | Walter H. L. Pinaya, Petru-Daniel Tudosiu, Jessica Dafflon, Pedro F Da Costa, Virginia Fernandez, Parashkev Nachev,
6 | Sebastien Ourselin, and M. Jorge Cardoso
7 |
8 | ### **Tags**
9 | Synthetic data, Latent Diffusion Model, Generative model, Brain Imaging
10 |
11 | ## **Model Description**
12 | This model is trained using the Latent Diffusion Model architecture [1] and is used for the synthesis of conditioned 3D
13 | brain MRI data. The model is divided into two parts: an autoencoder with KL regularisation that compresses data
14 | into a latent space, and a diffusion model that learns to generate conditioned synthetic latent representations. This
15 | model is conditioned on age, sex, the volume of ventricular cerebrospinal fluid, and brain volume normalised for head size.
16 |
17 | 
18 |
19 | Figure 1 - Synthetic image from the model.
20 |
21 |
22 | ## **Data**
23 | The model was trained on brain data from 31,740 participants of the UK Biobank [2]. We used high-resolution 3D T1w MRI with a voxel size of 1 mm³, resulting in volumes of 160 x 224 x 160 voxels.
24 |
25 | #### **Preprocessing**
26 | For image pre-processing, we used UniRes [3] to perform a rigid-body registration to a common MNI space. Voxel intensities were normalised to the range [0, 1].
27 |
28 | ## **Performance**
29 | This model achieves the following results on UK Biobank: an FID of 0.0076, an MS-SSIM of 0.6555, and a 4-G-R-SSIM of 0.3883.
30 |
31 | Please check Table 1 of the original paper for more details on the evaluation results.
32 |
33 |
34 | ## **Example Commands**
35 | Execute sampling:
36 | ```
37 | export PYTHONPATH=$PYTHONPATH:""
38 | python -m monai.bundle run save_nii --config_file configs/inference.json --gender 1.0 --age 0.7 --ventricular_vol 0.7 --brain_vol 0.5
39 | ```
40 | All conditioning values are expected to be between 0 and 1.
41 |
42 | ## **Citation Info**
43 |
44 | ```
45 | @inproceedings{pinaya2022brain,
46 | title={Brain imaging generation with latent diffusion models},
47 | author={Pinaya, Walter HL and Tudosiu, Petru-Daniel and Dafflon, Jessica and Da Costa, Pedro F and Fernandez, Virginia and Nachev, Parashkev and Ourselin, Sebastien and Cardoso, M Jorge},
48 | booktitle={MICCAI Workshop on Deep Generative Models},
49 | pages={117--126},
50 | year={2022},
51 | organization={Springer}
52 | }
53 | ```
54 |
55 | ## **References**
56 |
57 |
58 |
59 | [1] Pinaya, Walter HL, et al. "Brain imaging generation with latent diffusion models." MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022.
60 |
61 | [2] Sudlow, Cathie, et al. "UK biobank: an open access resource for identifying the causes of a wide range of complex diseases of middle and old age." PLoS medicine 12.3 (2015): e1001779.
62 |
63 | [3] Brudfors, Mikael, et al. "MRI super-resolution using multi-channel total variation." Annual Conference on Medical Image Understanding and Analysis. Springer, Cham, 2018.
64 |
--------------------------------------------------------------------------------
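To make the conditioning described in the README above concrete, the sketch below assembles the four covariates into a single tensor, using the same values as the CLI example. The exact tensor layout expected by the bundle's inference config is an assumption here; the point is simply that all four values are normalised to [0, 1], matching the `channel_def` (gender, age, ventricular volume, brain volume) in the bundle metadata.

```python
import torch

gender, age, ventricular_vol, brain_vol = 1.0, 0.7, 0.7, 0.5  # values from the CLI example above
conditioning = torch.tensor([[gender, age, ventricular_vol, brain_vol]], dtype=torch.float32)
print(conditioning.shape)  # torch.Size([1, 4])
```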
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # MONAI Generative Models
6 | Prototyping repository for generative models to be integrated into MONAI core, MONAI tutorials, and MONAI model zoo.
7 | ## Features
8 | * Network architectures: Diffusion Model, Autoencoder-KL, VQ-VAE, Autoregressive transformers, (Multi-scale) Patch-GAN discriminator.
9 | * Diffusion Model Noise Schedulers: DDPM, DDIM, and PNDM.
10 | * Losses: Adversarial losses, Spectral losses, and Perceptual losses (for 2D and 3D data using LPIPS, RadImageNet, and 3DMedicalNet pre-trained models).
11 | * Metrics: Multi-Scale Structural Similarity Index Measure (MS-SSIM) and Fréchet inception distance (FID).
12 | * Diffusion Models, Latent Diffusion Models, and VQ-VAE + Transformer Inferers classes (compatible with MONAI style) containing methods to train, sample synthetic images, and obtain the likelihood of input data.
13 | * MONAI-compatible trainer engine (based on Ignite) to train models with reconstruction and adversarial components.
14 | * Tutorials including:
15 | * How to train VQ-VAEs, VQ-GANs, VQ-VAE + Transformers, AutoencoderKLs, Diffusion Models, and Latent Diffusion Models on 2D and 3D data.
16 | * Train diffusion model to perform conditional image generation with classifier-free guidance.
17 | * Comparison of different diffusion model schedulers.
18 | * Diffusion models with different parameterizations (e.g., v-prediction and epsilon parameterization).
19 | * Anomaly Detection using VQ-VAE + Transformers and Diffusion Models.
20 | * Inpainting with diffusion models (using the RePaint method).
21 | * Super-resolution with Latent Diffusion Models (using Noise Conditioning Augmentation).
22 |
23 | ## Roadmap
24 | Our short-term goals are available in the [Milestones](https://github.com/Project-MONAI/GenerativeModels/milestones)
25 | section of the repository.
26 |
27 | In the longer term, we aim to integrate the generative models into the MONAI core repository (supporting tasks such as
28 | image synthesis, anomaly detection, MRI reconstruction, and domain transfer).
29 |
30 | ## Installation
31 | To install the current release of MONAI Generative Models, you can run:
32 | ```
33 | pip install monai-generative
34 | ```
35 | To install the current main branch of the repository, run:
36 | ```
37 | pip install git+https://github.com/Project-MONAI/GenerativeModels.git
38 | ```
39 | Requires Python >= 3.8.
40 |
41 | ## Contributing
42 | For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/GenerativeModels/blob/main/CONTRIBUTING.md).
43 |
44 | ## Community
45 | Join the conversation on Twitter [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).
46 |
47 | # Citation
48 |
49 | If you use MONAI Generative in your research, please cite us! The citation can be exported from [the paper](https://arxiv.org/abs/2307.15208).
50 |
51 | ## Links
52 | - Website: https://monai.io/
53 | - Code: https://github.com/Project-MONAI/GenerativeModels
54 | - Project tracker: https://github.com/Project-MONAI/GenerativeModels/projects
55 | - Issue tracker: https://github.com/Project-MONAI/GenerativeModels/issues
56 |
--------------------------------------------------------------------------------
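A small post-install sanity check (illustrative only): after `pip install monai-generative`, one of the bundled metrics can be computed on random images, mirroring the usage in tests/test_compute_multiscalessim_metric.py.

```python
import torch

from generative.metrics import MultiScaleSSIMMetric

metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5])
preds = torch.rand(1, 1, 64, 64)   # random "prediction" image
target = torch.rand(1, 1, 64, 64)  # random "reference" image
metric(preds, target)
print(metric.aggregate())  # scalar MS-SSIM value
```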
/tests/test_scheduler_pndm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.networks.schedulers import PNDMScheduler
20 |
21 | TEST_2D_CASE = []
22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
23 | TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
24 |
25 | TEST_3D_CASE = []
26 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
27 | TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
28 |
29 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
30 |
31 |
32 | class TestPNDMScheduler(unittest.TestCase):
33 | @parameterized.expand(TEST_CASES)
34 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape):
35 | scheduler = PNDMScheduler(**input_param)
36 | original_sample = torch.zeros(input_shape)
37 | noise = torch.randn_like(original_sample)
38 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
39 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
40 | self.assertEqual(noisy.shape, expected_shape)
41 |
42 | @parameterized.expand(TEST_CASES)
43 | def test_step_shape(self, input_param, input_shape, expected_shape):
44 | scheduler = PNDMScheduler(**input_param)
45 | scheduler.set_timesteps(600)
46 | model_output = torch.randn(input_shape)
47 | sample = torch.randn(input_shape)
48 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
49 | self.assertEqual(output_step[0].shape, expected_shape)
50 | self.assertEqual(output_step[1], None)
51 |
52 | def test_set_timesteps(self):
53 | scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
54 | scheduler.set_timesteps(num_inference_steps=100)
55 | self.assertEqual(scheduler.num_inference_steps, 100)
56 | self.assertEqual(len(scheduler.timesteps), 100)
57 |
58 | def test_set_timesteps_prk(self):
59 | scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False)
60 | scheduler.set_timesteps(num_inference_steps=100)
61 | self.assertEqual(scheduler.num_inference_steps, 109)
62 | self.assertEqual(len(scheduler.timesteps), 109)
63 |
64 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
65 | scheduler = PNDMScheduler(num_train_timesteps=1000)
66 | with self.assertRaises(ValueError):
67 | scheduler.set_timesteps(num_inference_steps=2000)
68 |
69 |
70 | if __name__ == "__main__":
71 | unittest.main()
72 |
--------------------------------------------------------------------------------
/tests/test_spectral_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import numpy as np
17 | import torch
18 | from parameterized import parameterized
19 |
20 | from generative.losses import JukeboxLoss
21 | from tests.utils import test_script_save
22 |
23 | TEST_CASES = [
24 | [
25 | {"spatial_dims": 2},
26 | {
27 | "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
28 | "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
29 | },
30 | 0.070648,
31 | ],
32 | [
33 | {"spatial_dims": 2, "reduction": "sum"},
34 | {
35 | "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
36 | "target": torch.tensor([[[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
37 | },
38 | 0.8478,
39 | ],
40 | [
41 | {"spatial_dims": 3},
42 | {
43 | "input": torch.tensor(
44 | [
45 | [
46 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],
47 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],
48 | ]
49 | ]
50 | ),
51 | "target": torch.tensor(
52 | [
53 | [
54 | [[[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],
55 | [[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]],
56 | ]
57 | ]
58 | ),
59 | },
60 | 0.03838,
61 | ],
62 | ]
63 |
64 |
65 | class TestJukeboxLoss(unittest.TestCase):
66 | @parameterized.expand(TEST_CASES)
67 | def test_results(self, input_param, input_data, expected_val):
68 | results = JukeboxLoss(**input_param).forward(**input_data)
69 | np.testing.assert_allclose(results.detach().cpu().numpy(), expected_val, rtol=1e-4)
70 |
71 | def test_2d_shape(self):
72 | results = JukeboxLoss(spatial_dims=2, reduction="none").forward(**TEST_CASES[0][1])
73 | self.assertEqual(results.shape, (1, 2, 2, 3))
74 |
75 | def test_3d_shape(self):
76 | results = JukeboxLoss(spatial_dims=3, reduction="none").forward(**TEST_CASES[2][1])
77 | self.assertEqual(results.shape, (1, 2, 2, 2, 3))
78 |
79 | def test_script(self):
80 | loss = JukeboxLoss(spatial_dims=2)
81 | test_input = torch.ones(2, 1, 8, 8)
82 | test_script_save(loss, test_input, test_input)
83 |
84 |
85 | if __name__ == "__main__":
86 | unittest.main()
87 |
--------------------------------------------------------------------------------
/tests/test_perceptual_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.losses import PerceptualLoss
20 |
21 | TEST_CASES = [
22 | [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)],
23 | [
24 | {"spatial_dims": 3, "network_type": "squeeze", "is_fake_3d": True, "fake_3d_ratio": 0.1},
25 | (2, 1, 64, 64, 64),
26 | (2, 1, 64, 64, 64),
27 | ],
28 | [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)],
29 | [{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)],
30 | [
31 | {"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
32 | (2, 1, 64, 64, 64),
33 | (2, 1, 64, 64, 64),
34 | ],
35 | [
36 | {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
37 | (2, 1, 64, 64, 64),
38 | (2, 1, 64, 64, 64),
39 | ],
40 | [
41 | {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
42 | (2, 1, 64, 64, 64),
43 | (2, 1, 64, 64, 64),
44 | ],
45 | ]
46 |
47 |
48 | class TestPerceptualLoss(unittest.TestCase):
49 | @parameterized.expand(TEST_CASES)
50 | def test_shape(self, input_param, input_shape, target_shape):
51 | loss = PerceptualLoss(**input_param)
52 | result = loss(torch.randn(input_shape), torch.randn(target_shape))
53 | self.assertEqual(result.shape, torch.Size([]))
54 |
55 | @parameterized.expand(TEST_CASES)
56 | def test_identical_input(self, input_param, input_shape, target_shape):
57 | loss = PerceptualLoss(**input_param)
58 | tensor = torch.randn(input_shape)
59 | result = loss(tensor, tensor)
60 | self.assertEqual(result, torch.Tensor([0.0]))
61 |
62 | def test_different_shape(self):
63 | loss = PerceptualLoss(spatial_dims=2, network_type="squeeze")
64 | tensor = torch.randn(2, 1, 64, 64)
65 | target = torch.randn(2, 1, 32, 32)
66 | with self.assertRaises(ValueError):
67 | loss(tensor, target)
68 |
69 | def test_1d(self):
70 | with self.assertRaises(NotImplementedError):
71 | PerceptualLoss(spatial_dims=1)
72 |
73 | def test_medicalnet_on_2d_data(self):
74 | with self.assertRaises(ValueError):
75 | PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")
76 |
77 | with self.assertRaises(ValueError):
78 | PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
79 |
80 |
81 | if __name__ == "__main__":
82 | unittest.main()
83 |
--------------------------------------------------------------------------------
/tests/test_compute_multiscalessim_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from monai.utils import set_determinism
18 |
19 | from generative.metrics import MultiScaleSSIMMetric
20 |
21 |
22 | class TestMultiScaleSSIMMetric(unittest.TestCase):
23 | def test2d_gaussian(self):
24 | set_determinism(0)
25 | preds = torch.abs(torch.randn(1, 1, 64, 64))
26 | target = torch.abs(torch.randn(1, 1, 64, 64))
27 | preds = preds / preds.max()
28 | target = target / target.max()
29 |
30 | metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5])
31 | metric(preds, target)
32 | result = metric.aggregate()
33 | expected_value = 0.023176
34 | self.assertTrue(expected_value - result.item() < 0.000001)
35 |
36 | def test2d_uniform(self):
37 | set_determinism(0)
38 | preds = torch.abs(torch.randn(1, 1, 64, 64))
39 | target = torch.abs(torch.randn(1, 1, 64, 64))
40 | preds = preds / preds.max()
41 | target = target / target.max()
42 |
43 | metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_type="uniform", weights=[0.5, 0.5])
44 | metric(preds, target)
45 | result = metric.aggregate()
46 | expected_value = 0.022655
47 | self.assertTrue(expected_value - result.item() < 0.000001)
48 |
49 | def test3d_gaussian(self):
50 | set_determinism(0)
51 | preds = torch.abs(torch.randn(1, 1, 64, 64, 64))
52 | target = torch.abs(torch.randn(1, 1, 64, 64, 64))
53 | preds = preds / preds.max()
54 | target = target / target.max()
55 |
56 | metric = MultiScaleSSIMMetric(spatial_dims=3, data_range=1.0, kernel_type="gaussian", weights=[0.5, 0.5])
57 | metric(preds, target)
58 | result = metric.aggregate()
59 | expected_value = 0.061796
60 | self.assertTrue(expected_value - result.item() < 0.000001)
61 |
62 | def test_ill_input_shape_2d(self):
63 | metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5])
64 |
65 | with self.assertRaises(ValueError):
66 | metric(torch.randn(1, 1, 64, 64), torch.randn(1, 1, 64, 64))
67 |
68 | def test_ill_input_shape_3d(self):
69 | metric = MultiScaleSSIMMetric(spatial_dims=2, weights=[0.5, 0.5])
70 |
71 | with self.assertRaises(ValueError):
72 | metric(torch.randn(1, 1, 64, 64, 64), torch.randn(1, 1, 64, 64, 64))
73 |
74 | def test_small_inputs(self):
75 | metric = MultiScaleSSIMMetric(spatial_dims=2)
76 |
77 | with self.assertRaises(ValueError):
78 | metric(torch.randn(1, 1, 16, 16, 16), torch.randn(1, 1, 16, 16, 16))
79 |
80 |
81 | if __name__ == "__main__":
82 | unittest.main()
83 |
--------------------------------------------------------------------------------
/tests/test_vector_quantizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 |
18 | from generative.networks.layers import EMAQuantizer, VectorQuantizer
19 |
20 |
21 | class TestEMA(unittest.TestCase):
22 | def test_ema_shape(self):
23 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)
24 | input_shape = (1, 8, 8, 8)
25 | x = torch.randn(input_shape)
26 | layer = layer.train()
27 | outputs = layer(x)
28 | self.assertEqual(outputs[0].shape, input_shape)
29 | self.assertEqual(outputs[2].shape, (1, 8, 8))
30 |
31 | layer = layer.eval()
32 | outputs = layer(x)
33 | self.assertEqual(outputs[0].shape, input_shape)
34 | self.assertEqual(outputs[2].shape, (1, 8, 8))
35 |
36 | def test_ema_quantize(self):
37 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8)
38 | input_shape = (1, 8, 8, 8)
39 | x = torch.randn(input_shape)
40 | outputs = layer.quantize(x)
41 | self.assertEqual(outputs[0].shape, (64, 8)) # (HxW, C)
42 | self.assertEqual(outputs[1].shape, (64, 16)) # (HxW, E)
43 | self.assertEqual(outputs[2].shape, (1, 8, 8)) # (1, H, W)
44 |
45 | def test_ema(self):
46 | layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0)
47 | original_weight_0 = layer.embedding.weight[0].clone()
48 | original_weight_1 = layer.embedding.weight[1].clone()
49 | x_0 = original_weight_0
50 | x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
51 | x_0 = x_0.repeat(1, 1, 1, 2) + 0.001
52 |
53 | x_1 = original_weight_1
54 | x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
55 | x_1 = x_1.repeat(1, 1, 1, 2)
56 |
57 | x = torch.cat([x_0, x_1], dim=0)
58 | layer = layer.train()
59 | _ = layer(x)
60 |
61 | self.assertTrue(all(layer.embedding.weight[0] != original_weight_0))
62 | self.assertTrue(all(layer.embedding.weight[1] == original_weight_1))
63 |
64 |
65 | class TestVectorQuantizer(unittest.TestCase):
66 | def test_vector_quantizer_shape(self):
67 | layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8))
68 | input_shape = (1, 8, 8, 8)
69 | x = torch.randn(input_shape)
70 | outputs = layer(x)
71 | self.assertEqual(outputs[1].shape, input_shape)
72 |
73 | def test_vector_quantizer_quantize(self):
74 | layer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=16, embedding_dim=8))
75 | input_shape = (1, 8, 8, 8)
76 | x = torch.randn(input_shape)
77 | outputs = layer.quantize(x)
78 | self.assertEqual(outputs.shape, (1, 8, 8))
79 |
80 |
81 | if __name__ == "__main__":
82 | unittest.main()
83 |
--------------------------------------------------------------------------------
/generative/metrics/mmd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from collections.abc import Callable
15 |
16 | import torch
17 | from monai.metrics.metric import Metric
18 |
19 |
20 | class MMDMetric(Metric):
21 | """
22 | Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
23 | distributions. It is a non-negative metric where a smaller value indicates a closer match between the two
24 | distributions.
25 |
26 | Gretton, A., et al., 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.
27 |
28 | Args:
29 | y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace
30 | filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a
31 | feature extractor or an Identity function.
32 | y_pred_transform: Callable to transform the y_pred tensor before computing the metric.
33 | """
34 |
35 | def __init__(self, y_transform: Callable | None = None, y_pred_transform: Callable | None = None) -> None:
36 | super().__init__()
37 |
38 | self.y_transform = y_transform
39 | self.y_pred_transform = y_pred_transform
40 |
41 | def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
42 | """
43 | Args:
44 | y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
45 | y_pred: second sample (e.g., the reconstructed image). It must have the same shape as y.
46 | """
47 |
48 | # Beta and Gamma are not calculated since torch.mean is used at return
49 | beta = 1.0
50 | gamma = 2.0
51 |
52 | if self.y_transform is not None:
53 | y = self.y_transform(y)
54 |
55 | if self.y_pred_transform is not None:
56 | y_pred = self.y_pred_transform(y_pred)
57 |
58 | if y_pred.shape != y.shape:
59 | raise ValueError(
60 | "y_pred and y shapes don't match after being processed "
61 | f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}"
62 | )
63 |
64 | for d in range(len(y.shape) - 1, 1, -1):
65 | y = y.squeeze(dim=d)
66 | y_pred = y_pred.squeeze(dim=d)
67 |
68 | y = y.view(y.shape[0], -1)
69 | y_pred = y_pred.view(y_pred.shape[0], -1)
70 |
71 | y_y = torch.mm(y, y.t())
72 | y_pred_y_pred = torch.mm(y_pred, y_pred.t())
73 | y_pred_y = torch.mm(y_pred, y.t())
74 |
75 | y_y = y_y / y.shape[1]
76 | y_pred_y_pred = y_pred_y_pred / y.shape[1]
77 | y_pred_y = y_pred_y / y.shape[1]
78 |
79 | # Ref. 1 Eq. 3 (found under Lemma 6)
80 | return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y)
81 |
--------------------------------------------------------------------------------
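A minimal usage sketch for the MMDMetric defined above, with random tensors standing in for real reference and reconstructed images:

```python
import torch

from generative.metrics.mmd import MMDMetric

metric = MMDMetric()  # no transforms: raw pixel values are compared
y = torch.randn(4, 1, 64, 64)       # reference batch, (B, C, H, W)
y_pred = torch.randn(4, 1, 64, 64)  # e.g. reconstructions; must match y's shape
print(metric(y=y, y_pred=y_pred))   # scalar tensor; smaller means closer distributions
```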
/generative/networks/blocks/encoder_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from collections.abc import Sequence
15 | from functools import partial
16 |
17 | import torch
18 | import torch.nn as nn
19 | from monai.networks.blocks import Convolution
20 |
21 | __all__ = ["SpatialRescaler"]
22 |
23 |
24 | class SpatialRescaler(nn.Module):
25 | """
26 | SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py
27 |
28 | Args:
29 | spatial_dims: number of spatial dimensions.
30 | n_stages: number of interpolation stages.
31 | size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]).
32 | method: algorithm used for sampling.
33 | multiplier: multiplier for spatial size. If `multiplier` is a sequence,
34 | its length has to match the number of spatial dimensions; `input.dim() - 2`.
35 | in_channels: number of input channels.
36 | out_channels: number of output channels.
37 | bias: whether to have a bias term.
38 | """
39 |
40 | def __init__(
41 | self,
42 | spatial_dims: int = 2,
43 | n_stages: int = 1,
44 | size: Sequence[int] | int | None = None,
45 | method: str = "bilinear",
46 | multiplier: Sequence[float] | float | None = None,
47 | in_channels: int = 3,
48 | out_channels: int | None = None,
49 | bias: bool = False,
50 | ):
51 | super().__init__()
52 | self.n_stages = n_stages
53 | assert self.n_stages >= 0
54 | assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"]
55 | if size is not None and n_stages != 1:
56 | raise ValueError("when size is not None, n_stages should be 1.")
57 | if size is not None and multiplier is not None:
58 | raise ValueError("only one of size or multiplier should be defined.")
59 | self.multiplier = multiplier
60 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size)
61 | self.remap_output = out_channels is not None
62 | if self.remap_output:
63 | print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.")
64 | self.channel_mapper = Convolution(
65 | spatial_dims=spatial_dims,
66 | in_channels=in_channels,
67 | out_channels=out_channels,
68 | kernel_size=1,
69 | conv_only=True,
70 | bias=bias,
71 | )
72 |
73 | def forward(self, x: torch.Tensor) -> torch.Tensor:
74 | if self.remap_output:
75 | x = self.channel_mapper(x)
76 |
77 | for _ in range(self.n_stages):
78 | x = self.interpolator(x, scale_factor=self.multiplier)
79 |
80 | return x
81 |
82 | def encode(self, x: torch.Tensor) -> torch.Tensor:
83 | return self(x)
84 |
--------------------------------------------------------------------------------
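A minimal usage sketch for the SpatialRescaler above: a single bilinear stage that halves the spatial size of a 3-channel input (shapes are illustrative only).

```python
import torch

from generative.networks.blocks.encoder_modules import SpatialRescaler

rescaler = SpatialRescaler(spatial_dims=2, n_stages=1, method="bilinear", multiplier=0.5, in_channels=3)
x = torch.randn(1, 3, 64, 64)
print(rescaler(x).shape)  # torch.Size([1, 3, 32, 32])
```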
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at monai.contact@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/generative/losses/spectral_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from monai.utils import LossReduction
17 | from torch.fft import fftn
18 | from torch.nn.modules.loss import _Loss
19 |
20 |
21 | class JukeboxLoss(_Loss):
22 | """
23 | Calculate the spectral loss based on the magnitude of the Fast Fourier Transform (FFT).
24 |
25 | Based on:
26 | Dhariwal, et al. 'Jukebox: A generative model for music.' https://arxiv.org/abs/2005.00341
27 |
28 | Args:
29 | spatial_dims: number of spatial dimensions.
30 | fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information.
31 | fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See
32 | torch.fft.fftn() for more information.
33 |
34 | reduction: {``"none"``, ``"mean"``, ``"sum"``}
35 | Specifies the reduction to apply to the output. Defaults to ``"mean"``.
36 |
37 | - ``"none"``: no reduction will be applied.
38 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
39 | - ``"sum"``: the output will be summed.
40 | """
41 |
42 | def __init__(
43 | self,
44 | spatial_dims: int,
45 | fft_signal_size: tuple[int] | None = None,
46 | fft_norm: str = "ortho",
47 | reduction: LossReduction | str = LossReduction.MEAN,
48 | ) -> None:
49 | super().__init__(reduction=LossReduction(reduction).value)
50 |
51 | self.spatial_dims = spatial_dims
52 | self.fft_signal_size = fft_signal_size
53 | self.fft_dim = tuple(range(1, spatial_dims + 2))
54 | self.fft_norm = fft_norm
55 |
56 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
57 | input_amplitude = self._get_fft_amplitude(input)
58 | target_amplitude = self._get_fft_amplitude(target)
59 |
60 | # Compute distance between amplitude of frequency components
61 | # See Section 3.3 from https://arxiv.org/abs/2005.00341
62 | loss = F.mse_loss(target_amplitude, input_amplitude, reduction="none")
63 |
64 | if self.reduction == LossReduction.MEAN.value:
65 | loss = loss.mean()
66 | elif self.reduction == LossReduction.SUM.value:
67 | loss = loss.sum()
68 | elif self.reduction == LossReduction.NONE.value:
69 | pass
70 |
71 | return loss
72 |
73 | def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor:
74 | """
74 | Calculate the amplitude of the Fourier transform of the images.
76 |
77 | Args:
78 | images: Images that are to undergo fftn
79 |
80 | Returns:
81 | fourier transformation amplitude
82 | """
83 | img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm)
84 |
85 | amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2)
86 |
87 | return amplitude
88 |
--------------------------------------------------------------------------------
/generative/networks/blocks/spade_norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from monai.networks.blocks import ADN, Convolution
18 |
19 |
20 | class SPADE(nn.Module):
21 | """
22 | SPADE normalisation block based on the 2019 paper by Park et al. (doi: https://doi.org/10.48550/arXiv.1903.07291)
23 |
24 | Args:
25 | label_nc: number of semantic labels
26 | norm_nc: number of output channels
27 | kernel_size: kernel size
28 | spatial_dims: number of spatial dimensions
29 | hidden_channels: number of channels in the intermediate gamma and beta layers
30 | norm: type of base normalisation used before applying the SPADE normalisation
31 | norm_params: parameters for the base normalisation
32 | """
33 |
34 | def __init__(
35 | self,
36 | label_nc: int,
37 | norm_nc: int,
38 | kernel_size: int = 3,
39 | spatial_dims: int = 2,
40 | hidden_channels: int = 64,
41 | norm: str | tuple = "INSTANCE",
42 | norm_params: dict | None = None,
43 | ) -> None:
44 | super().__init__()
45 |
46 | if norm_params is None:
47 | norm_params = {}
48 | if len(norm_params) != 0:
49 | norm = (norm, norm_params)
50 | self.param_free_norm = ADN(
51 | act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc
52 | )
53 | self.mlp_shared = Convolution(
54 | spatial_dims=spatial_dims,
55 | in_channels=label_nc,
56 | out_channels=hidden_channels,
57 | kernel_size=kernel_size,
58 | norm=None,
59 | padding=kernel_size // 2,
60 | act="LEAKYRELU",
61 | )
62 | self.mlp_gamma = Convolution(
63 | spatial_dims=spatial_dims,
64 | in_channels=hidden_channels,
65 | out_channels=norm_nc,
66 | kernel_size=kernel_size,
67 | padding=kernel_size // 2,
68 | act=None,
69 | )
70 | self.mlp_beta = Convolution(
71 | spatial_dims=spatial_dims,
72 | in_channels=hidden_channels,
73 | out_channels=norm_nc,
74 | kernel_size=kernel_size,
75 | padding=kernel_size // 2,
76 | act=None,
77 | )
78 |
79 | def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
80 | """
81 | Args:
82 | x: input tensor
83 | segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels.
84 | The map will be interpolated to the dimension of x internally.
85 | """
86 |
87 | # Part 1. generate parameter-free normalized activations
88 | normalized = self.param_free_norm(x)
89 |
90 | # Part 2. produce scaling and bias conditioned on semantic map
91 | segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
92 | actv = self.mlp_shared(segmap)
93 | gamma = self.mlp_gamma(actv)
94 | beta = self.mlp_beta(actv)
95 | out = normalized * (1 + gamma) + beta
96 | return out
97 |
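A minimal sketch of how the SPADE block above is called, mainly to illustrate the expected tensor shapes (the segmentation map is interpolated to the spatial size of `x` internally):

```python
import torch

from generative.networks.blocks.spade_norm import SPADE

# 10 semantic channels in the segmentation map, 64 feature channels to normalise.
spade = SPADE(label_nc=10, norm_nc=64, spatial_dims=2)
features = torch.randn(2, 64, 32, 32)  # activations to be modulated
segmap = torch.randn(2, 10, 128, 128)  # in practice a one-hot label map; resized to 32 x 32 internally
out = spade(features, segmap)          # same shape as `features`
```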
--------------------------------------------------------------------------------
/generative/networks/blocks/transformerblock.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import torch
15 | import torch.nn as nn
16 | from monai.networks.blocks.mlp import MLPBlock
17 |
18 | from generative.networks.blocks.selfattention import SABlock
19 |
20 |
21 | class TransformerBlock(nn.Module):
22 | """
23 | A transformer block, based on: "Dosovitskiy et al.,
24 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
25 |
26 | Args:
27 | hidden_size: dimension of hidden layer.
28 | mlp_dim: dimension of feedforward layer.
29 | num_heads: number of attention heads.
30 |         dropout_rate: fraction of the input units to drop.
31 | qkv_bias: apply bias term for the qkv linear layer
32 | causal: whether to use causal attention.
33 | sequence_length: if causal is True, it is necessary to specify the sequence length.
34 | with_cross_attention: Whether to use cross attention for conditioning.
35 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
36 | """
37 |
38 | def __init__(
39 | self,
40 | hidden_size: int,
41 | mlp_dim: int,
42 | num_heads: int,
43 | dropout_rate: float = 0.0,
44 | qkv_bias: bool = False,
45 | causal: bool = False,
46 | sequence_length: int | None = None,
47 | with_cross_attention: bool = False,
48 | use_flash_attention: bool = False,
49 | ) -> None:
50 | self.with_cross_attention = with_cross_attention
51 | super().__init__()
52 |
53 | if not (0 <= dropout_rate <= 1):
54 | raise ValueError("dropout_rate should be between 0 and 1.")
55 |
56 | if hidden_size % num_heads != 0:
57 | raise ValueError("hidden_size should be divisible by num_heads.")
58 |
59 | self.norm1 = nn.LayerNorm(hidden_size)
60 | self.attn = SABlock(
61 | hidden_size=hidden_size,
62 | num_heads=num_heads,
63 | dropout_rate=dropout_rate,
64 | qkv_bias=qkv_bias,
65 | causal=causal,
66 | sequence_length=sequence_length,
67 | use_flash_attention=use_flash_attention,
68 | )
69 |
70 | self.norm2 = None
71 | self.cross_attn = None
72 | if self.with_cross_attention:
73 | self.norm2 = nn.LayerNorm(hidden_size)
74 | self.cross_attn = SABlock(
75 | hidden_size=hidden_size,
76 | num_heads=num_heads,
77 | dropout_rate=dropout_rate,
78 | qkv_bias=qkv_bias,
79 | with_cross_attention=with_cross_attention,
80 | causal=False,
81 | use_flash_attention=use_flash_attention,
82 | )
83 |
84 | self.norm3 = nn.LayerNorm(hidden_size)
85 | self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
86 |
87 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
88 | x = x + self.attn(self.norm1(x))
89 | if self.with_cross_attention:
90 | x = x + self.cross_attn(self.norm2(x), context=context)
91 | x = x + self.mlp(self.norm3(x))
92 | return x
93 |
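A minimal sketch of the block in use, with and without cross attention, following the `(batch, sequence, hidden_size)` convention used by the decoder-only transformer later in this repository. The cross-attention context is assumed to share the block's hidden size.

```python
import torch

from generative.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(hidden_size=64, mlp_dim=256, num_heads=4)
tokens = torch.randn(2, 16, 64)  # (batch, sequence, hidden_size)
out = block(tokens)              # same shape as the input

# With cross attention, a context sequence conditions the block.
cond_block = TransformerBlock(hidden_size=64, mlp_dim=256, num_heads=4, with_cross_attention=True)
out = cond_block(tokens, context=torch.randn(2, 8, 64))
```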
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/docs/README.md:
--------------------------------------------------------------------------------
1 | # Chest X-ray with Latent Diffusion Models
2 |
3 | ### **Authors**
4 |
5 | MONAI Generative Models
6 |
7 | ### **Tags**
8 | Synthetic data, Latent Diffusion Model, Generative model, Chest X-ray
9 |
10 | ## **Model Description**
11 | This model is trained from scratch using the Latent Diffusion Model architecture [1] and is used for the synthesis of
12 | 2D chest X-ray images conditioned on radiological reports. The model is divided into two parts: an autoencoder with
13 | KL regularisation that compresses the data into a latent space, and a diffusion model that learns to generate
14 | conditioned synthetic latent representations. The model is conditioned on the Findings and Impressions sections of
15 | radiological reports. The original repository can be found [here](https://github.com/Warvito/generative_chestxray).
16 |
17 | 
18 |
19 | Figure 1 - Synthetic images from the model.
20 |
21 | ## **Data**
22 | The model was trained on chest X-ray data from 90,000 participants from the MIMIC-CXR dataset [2] [3]. We downsampled
23 | the original images to 512 x 512 pixels.
24 |
25 | #### **Preprocessing**
26 | We resized the original images so that the smallest side measured 512 pixels. When inputting them to the network, we center
27 | cropped the images to 512 x 512. The pixel intensity was normalised to the range [0, 1]. The text data was obtained
28 | from the associated radiological reports. We randomly extracted sentences from the findings and impressions sections of the
29 | reports, with a maximum of 5 sentences and 77 tokens. The text was tokenised using the CLIPTokenizer from the
30 | transformers package (https://github.com/huggingface/transformers) (pretrained model
31 | "stabilityai/stable-diffusion-2-1-base") and then encoded using CLIPTextModel from the same package and pretrained
32 | model.
33 |
34 |
35 | ## **Commands Example**
36 | Here we include a few examples of commands to sample images from the model and save them as .jpg files. The available
37 | arguments for this task are: "--prompt" (str), the text prompt to condition the model on; "--guidance_scale" (float), the
38 | parameter that controls how much the image generation process follows the text prompt. The higher the value, the more
39 | closely the image sticks to a given text input (the common range is between 1 and 21).
40 |
41 | Examples:
42 |
43 | ```shell
44 | export PYTHONPATH=$PYTHONPATH:""
45 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Big right-sided pleural effusion" --guidance_scale 7.0
46 | ```
47 |
48 | ```shell
49 | export PYTHONPATH=$PYTHONPATH:""
50 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Small right-sided pleural effusion" --guidance_scale 7.0
51 | ```
52 |
53 | ```shell
54 | export PYTHONPATH=$PYTHONPATH:""
55 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Bilateral pleural effusion" --guidance_scale 7.0
56 | ```
57 |
58 | ```shell
59 | export PYTHONPATH=$PYTHONPATH:""
60 | python -m monai.bundle run save_jpg --config_file configs/inference.json --prompt "Cardiomegaly" --guidance_scale 7.0
61 | ```
62 |
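For reference, a guidance scale like this is conventionally applied via classifier-free guidance, mixing an unconditioned and a text-conditioned noise prediction at each sampling step. A minimal illustrative sketch (the tensor names below are placeholders, not taken from the bundle scripts):

```python
import torch

guidance_scale = 7.0
noise_uncond = torch.randn(1, 3, 64, 64)  # model prediction for the empty prompt ""
noise_cond = torch.randn(1, 3, 64, 64)    # model prediction for the text prompt
# Push the prediction towards the conditioned one; larger scales follow the prompt more closely.
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```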
63 |
64 | ## **References**
65 |
66 |
67 | [1] Pinaya, Walter HL, et al. "Brain imaging generation with latent diffusion models." MICCAI Workshop on Deep Generative Models. Springer, Cham, 2022.
68 |
69 | [2] Johnson, A., Lungren, M., Peng, Y., Lu, Z., Mark, R., Berkowitz, S., & Horng, S. (2019). MIMIC-CXR-JPG - chest radiographs with structured labels (version 2.0.0). PhysioNet. https://doi.org/10.13026/8360-t248.
70 |
71 | [3] Johnson AE, Pollard TJ, Berkowitz S, Greenbaum NR, Lungren MP, Deng CY, Mark RG, Horng S. MIMIC-CXR: A large publicly available database of labeled chest radiographs. arXiv preprint arXiv:1901.07042. 2019 Jan 21.
72 |
--------------------------------------------------------------------------------
/model-zoo/models/brain_image_synthesis_latent_diffusion_model/configs/inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "imports": [
3 | "$import torch",
4 | "$from datetime import datetime",
5 | "$from pathlib import Path"
6 | ],
7 | "bundle_root": ".",
8 | "model_dir": "$@bundle_root + '/models'",
9 | "output_dir": "$@bundle_root + '/output'",
10 | "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
11 | "gender": 0.0,
12 | "age": 0.1,
13 | "ventricular_vol": 0.2,
14 | "brain_vol": 0.4,
15 | "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
16 | "conditioning": "$torch.tensor([[@gender, @age, @ventricular_vol, @brain_vol]]).to(@device).unsqueeze(1)",
17 | "out_file": "$datetime.now().strftime('sample_%H%M%S_%d%m%Y') + '_' + str(@gender) + '_' + str(@age) + '_' + str(@ventricular_vol) + '_' + str(@brain_vol)",
18 | "autoencoder_def": {
19 | "_target_": "generative.networks.nets.AutoencoderKL",
20 | "spatial_dims": 3,
21 | "in_channels": 1,
22 | "out_channels": 1,
23 | "latent_channels": 3,
24 | "num_channels": [
25 | 64,
26 | 128,
27 | 128,
28 | 128
29 | ],
30 | "num_res_blocks": 2,
31 | "norm_num_groups": 32,
32 | "norm_eps": 1e-06,
33 | "attention_levels": [
34 | false,
35 | false,
36 | false,
37 | false
38 | ],
39 | "with_encoder_nonlocal_attn": false,
40 | "with_decoder_nonlocal_attn": false
41 | },
42 | "load_autoencoder_path": "$@model_dir + '/autoencoder.pth'",
43 | "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
44 | "autoencoder": "$@autoencoder_def.to(@device)",
45 | "diffusion_def": {
46 | "_target_": "generative.networks.nets.DiffusionModelUNet",
47 | "spatial_dims": 3,
48 | "in_channels": 7,
49 | "out_channels": 3,
50 | "num_channels": [
51 | 256,
52 | 512,
53 | 768
54 | ],
55 | "num_res_blocks": 2,
56 | "attention_levels": [
57 | false,
58 | true,
59 | true
60 | ],
61 | "norm_num_groups": 32,
62 | "norm_eps": 1e-06,
63 | "resblock_updown": true,
64 | "num_head_channels": [
65 | 0,
66 | 512,
67 | 768
68 | ],
69 | "with_conditioning": true,
70 | "transformer_num_layers": 1,
71 | "cross_attention_dim": 4,
72 | "upcast_attention": true,
73 | "use_flash_attention": false
74 | },
75 | "load_diffusion_path": "$@model_dir + '/diffusion_model.pth'",
76 | "load_diffusion": "$@diffusion_def.load_state_dict(torch.load(@load_diffusion_path))",
77 | "diffusion": "$@diffusion_def.to(@device)",
78 | "scheduler": {
79 | "_target_": "generative.networks.schedulers.DDIMScheduler",
80 | "_requires_": [
81 | "@load_diffusion",
82 | "@load_autoencoder"
83 | ],
84 | "beta_start": 0.0015,
85 | "beta_end": 0.0205,
86 | "num_train_timesteps": 1000,
87 | "schedule": "scaled_linear_beta",
88 | "clip_sample": false
89 | },
90 | "noise": "$torch.randn((1, 3, 20, 28, 20)).to(@device)",
91 | "set_timesteps": "$@scheduler.set_timesteps(num_inference_steps=50)",
92 | "sampler": {
93 | "_target_": "scripts.sampler.Sampler",
94 | "_requires_": "@set_timesteps"
95 | },
96 | "sample": "$@sampler.sampling_fn(@noise, @autoencoder, @diffusion, @scheduler, @conditioning)",
97 | "saver": {
98 | "_target_": "scripts.saver.NiftiSaver",
99 | "_requires_": "@create_output_dir",
100 | "output_dir": "@output_dir"
101 | },
102 | "save_nii": "$@saver.save(@sample, @out_file)",
103 | "save": "$torch.save(@sample, @output_dir + '/' + @out_file + '.pt')"
104 | }
105 |
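As a minimal sketch, a config like this can also be resolved programmatically with MONAI's `ConfigParser` (assuming it is run from the bundle root so the relative paths above resolve; the downloaded weights are required for the full sampling pipeline):

```python
from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("configs/inference.json")
# Resolving an id substitutes the "@..." references and evaluates the "$..." expression.
conditioning = parser.get_parsed_content("conditioning")
print(conditioning.shape)  # expected: torch.Size([1, 1, 4])
```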
--------------------------------------------------------------------------------
/tests/test_scheduler_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.networks.schedulers import DDPMScheduler
20 |
21 | TEST_2D_CASE = []
22 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
23 | for variance_type in ["fixed_small", "fixed_large"]:
24 | TEST_2D_CASE.append(
25 | [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)]
26 | )
27 |
28 | TEST_3D_CASE = []
29 | for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
30 | for variance_type in ["fixed_small", "fixed_large"]:
31 | TEST_3D_CASE.append(
32 | [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]
33 | )
34 |
35 | TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
36 |
37 |
38 | class TestDDPMScheduler(unittest.TestCase):
39 | @parameterized.expand(TEST_CASES)
40 | def test_add_noise_2d_shape(self, input_param, input_shape, expected_shape):
41 | scheduler = DDPMScheduler(**input_param)
42 | original_sample = torch.zeros(input_shape)
43 | noise = torch.randn_like(original_sample)
44 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
45 |
46 | noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
47 | self.assertEqual(noisy.shape, expected_shape)
48 |
49 | @parameterized.expand(TEST_CASES)
50 | def test_step_shape(self, input_param, input_shape, expected_shape):
51 | scheduler = DDPMScheduler(**input_param)
52 | model_output = torch.randn(input_shape)
53 | sample = torch.randn(input_shape)
54 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
55 | self.assertEqual(output_step[0].shape, expected_shape)
56 | self.assertEqual(output_step[1].shape, expected_shape)
57 |
58 | @parameterized.expand(TEST_CASES)
59 | def test_get_velocity_shape(self, input_param, input_shape, expected_shape):
60 | scheduler = DDPMScheduler(**input_param)
61 | sample = torch.randn(input_shape)
62 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long()
63 | velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps)
64 | self.assertEqual(velocity.shape, expected_shape)
65 |
66 | def test_step_learned(self):
67 | for variance_type in ["learned", "learned_range"]:
68 | scheduler = DDPMScheduler(variance_type=variance_type)
69 | model_output = torch.randn(2, 6, 16, 16)
70 | sample = torch.randn(2, 3, 16, 16)
71 | output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
72 | self.assertEqual(output_step[0].shape, sample.shape)
73 | self.assertEqual(output_step[1].shape, sample.shape)
74 |
75 | def test_set_timesteps(self):
76 | scheduler = DDPMScheduler(num_train_timesteps=1000)
77 | scheduler.set_timesteps(num_inference_steps=100)
78 | self.assertEqual(scheduler.num_inference_steps, 100)
79 | self.assertEqual(len(scheduler.timesteps), 100)
80 |
81 | def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
82 | scheduler = DDPMScheduler(num_train_timesteps=1000)
83 | with self.assertRaises(ValueError):
84 | scheduler.set_timesteps(num_inference_steps=2000)
85 |
86 |
87 | if __name__ == "__main__":
88 | unittest.main()
89 |
--------------------------------------------------------------------------------
/tests/test_encoder_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.networks.blocks import SpatialRescaler
20 |
21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22 |
23 | CASES = [
24 | [
25 | {
26 | "spatial_dims": 2,
27 | "n_stages": 1,
28 | "method": "bilinear",
29 | "multiplier": 0.5,
30 | "in_channels": None,
31 | "out_channels": None,
32 | },
33 | (1, 1, 16, 16),
34 | (1, 1, 8, 8),
35 | ],
36 | [
37 | {
38 | "spatial_dims": 2,
39 | "n_stages": 1,
40 | "method": "bilinear",
41 | "multiplier": 0.5,
42 | "in_channels": 3,
43 | "out_channels": 2,
44 | },
45 | (1, 3, 16, 16),
46 | (1, 2, 8, 8),
47 | ],
48 | [
49 | {
50 | "spatial_dims": 3,
51 | "n_stages": 1,
52 | "method": "trilinear",
53 | "multiplier": 0.5,
54 | "in_channels": None,
55 | "out_channels": None,
56 | },
57 | (1, 1, 16, 16, 16),
58 | (1, 1, 8, 8, 8),
59 | ],
60 | [
61 | {
62 | "spatial_dims": 3,
63 | "n_stages": 1,
64 | "method": "trilinear",
65 | "multiplier": 0.5,
66 | "in_channels": 3,
67 | "out_channels": 2,
68 | },
69 | (1, 3, 16, 16, 16),
70 | (1, 2, 8, 8, 8),
71 | ],
72 | [
73 | {
74 | "spatial_dims": 3,
75 | "n_stages": 1,
76 | "method": "trilinear",
77 | "multiplier": (0.25, 0.5, 0.75),
78 | "in_channels": 3,
79 | "out_channels": 2,
80 | },
81 | (1, 3, 20, 20, 20),
82 | (1, 2, 5, 10, 15),
83 | ],
84 | [
85 | {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2},
86 | (1, 3, 16, 16),
87 | (1, 2, 8, 8),
88 | ],
89 | [
90 | {
91 | "spatial_dims": 3,
92 | "n_stages": 1,
93 | "size": (8, 8, 8),
94 | "method": "trilinear",
95 | "in_channels": None,
96 | "out_channels": None,
97 | },
98 | (1, 1, 16, 16, 16),
99 | (1, 1, 8, 8, 8),
100 | ],
101 | ]
102 |
103 |
104 | class TestSpatialRescaler(unittest.TestCase):
105 | @parameterized.expand(CASES)
106 | def test_shape(self, input_param, input_shape, expected_shape):
107 | module = SpatialRescaler(**input_param).to(device)
108 |
109 | result = module(torch.randn(input_shape).to(device))
110 | self.assertEqual(result.shape, expected_shape)
111 |
112 | def test_method_not_in_available_options(self):
113 | with self.assertRaises(AssertionError):
114 | SpatialRescaler(method="none")
115 |
116 | def test_n_stages_is_negative(self):
117 | with self.assertRaises(AssertionError):
118 | SpatialRescaler(n_stages=-1)
119 |
120 | def test_use_size_but_n_stages_is_not_one(self):
121 | with self.assertRaises(ValueError):
122 | SpatialRescaler(n_stages=2, size=[8, 8, 8])
123 |
124 | def test_both_size_and_multiplier_defined(self):
125 | with self.assertRaises(ValueError):
126 | SpatialRescaler(size=[1, 2, 3], multiplier=0.5)
127 |
128 |
129 | if __name__ == "__main__":
130 | unittest.main()
131 |
--------------------------------------------------------------------------------
/tests/test_adversarial.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from parameterized import parameterized
18 |
19 | from generative.losses import PatchAdversarialLoss
20 |
21 | shapes_tensors = {"2d": [4, 1, 64, 64], "3d": [4, 1, 64, 64, 64]}
22 | reductions = ["sum", "mean"]
23 | criterion = ["bce", "least_squares", "hinge"]
24 |
25 | TEST_CASE_CREATION_FAIL = [{"reduction": "sum", "criterion": "invalid"}]
26 |
27 | TEST_CASES_LOSS_LOGIC_2D = []
28 | TEST_CASES_LOSS_LOGIC_3D = []
29 |
30 | for c in criterion:
31 | for r in reductions:
32 | TEST_CASES_LOSS_LOGIC_2D.append([{"reduction": r, "criterion": c}, shapes_tensors["2d"]])
33 | TEST_CASES_LOSS_LOGIC_3D.append([{"reduction": r, "criterion": c}, shapes_tensors["3d"]])
34 |
35 | TEST_CASES_LOSS_LOGIC_LIST = []
36 | for c in criterion:
37 | TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["2d"]])
38 | TEST_CASES_LOSS_LOGIC_LIST.append([{"reduction": "none", "criterion": c}, shapes_tensors["3d"]])
39 |
40 |
41 | class TestPatchAdversarialLoss(unittest.TestCase):
42 | def get_input(self, shape, is_positive):
43 | """
44 | Get tensor for the tests. The tensor is around (-1) or (+1), depending on
45 | is_positive.
46 | """
47 | if is_positive:
48 | offset = 1
49 | else:
50 | offset = -1
51 | return torch.ones(shape) * (offset) + 0.01 * torch.randn(shape)
52 |
53 | def test_criterion(self):
54 | """
55 |         Make sure that an unknown criterion fails.
56 | """
57 | with self.assertRaises(ValueError):
58 | PatchAdversarialLoss(**TEST_CASE_CREATION_FAIL[0])
59 |
60 | @parameterized.expand(TEST_CASES_LOSS_LOGIC_2D + TEST_CASES_LOSS_LOGIC_3D)
61 | def test_loss_logic(self, input_param: dict, shape_input: list):
62 | """
63 | We want to make sure that the adversarial losses do what they should.
64 | If the discriminator takes in a tensor that looks positive, yet the label is fake,
65 | the loss should be bigger than that obtained with a tensor that looks negative.
66 | Same for the real label, and for the generator.
67 | """
68 | loss = PatchAdversarialLoss(**input_param)
69 | fakes = self.get_input(shape_input, is_positive=False)
70 | reals = self.get_input(shape_input, is_positive=True)
71 | # Discriminator: fake label
72 | loss_disc_f_f = loss(fakes, target_is_real=False, for_discriminator=True)
73 | loss_disc_f_r = loss(reals, target_is_real=False, for_discriminator=True)
74 | assert loss_disc_f_f < loss_disc_f_r
75 | # Discriminator: real label
76 | loss_disc_r_f = loss(fakes, target_is_real=True, for_discriminator=True)
77 | loss_disc_r_r = loss(reals, target_is_real=True, for_discriminator=True)
78 | assert loss_disc_r_f > loss_disc_r_r
79 | # Generator:
80 | loss_gen_f = loss(fakes, target_is_real=True, for_discriminator=False) # target_is_real is overridden
81 | loss_gen_r = loss(reals, target_is_real=True, for_discriminator=False) # target_is_real is overridden
82 | assert loss_gen_f > loss_gen_r
83 |
84 | @parameterized.expand(TEST_CASES_LOSS_LOGIC_LIST)
85 | def test_multiple_discs(self, input_param: dict, shape_input):
86 | shapes = [shape_input] + [shape_input[0:2] + [int(i / j) for i in shape_input[2:]] for j in range(1, 3)]
87 | inputs = [self.get_input(shapes[i], is_positive=True) for i in range(len(shapes))]
88 | loss = PatchAdversarialLoss(**input_param)
89 | assert len(loss(inputs, for_discriminator=True, target_is_real=True)) == 3
90 |
91 |
92 | if __name__ == "__main__":
93 | unittest.main()
94 |
--------------------------------------------------------------------------------
/model-zoo/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json:
--------------------------------------------------------------------------------
1 | {
2 | "imports": [
3 | "$import torch",
4 | "$from datetime import datetime",
5 | "$from pathlib import Path",
6 | "$from transformers import CLIPTextModel",
7 | "$from transformers import CLIPTokenizer"
8 | ],
9 | "bundle_root": ".",
10 | "model_dir": "$@bundle_root + '/models'",
11 | "output_dir": "$@bundle_root + '/output'",
12 | "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
13 | "prompt": "Big right-sided pleural effusion",
14 | "prompt_list": "$['', @prompt]",
15 | "guidance_scale": 7.0,
16 | "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
17 | "tokenizer": "$CLIPTokenizer.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\", subfolder=\"tokenizer\")",
18 | "text_encoder": "$CLIPTextModel.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\", subfolder=\"text_encoder\")",
19 | "tokenized_prompt": "$@tokenizer(@prompt_list, padding=\"max_length\", max_length=@tokenizer.model_max_length, truncation=True,return_tensors=\"pt\")",
20 | "prompt_embeds": "$@text_encoder(@tokenized_prompt.input_ids.squeeze(1))[0].to(@device)",
21 | "out_file": "$datetime.now().strftime('sample_%H%M%S_%d%m%Y')",
22 | "autoencoder_def": {
23 | "_target_": "generative.networks.nets.AutoencoderKL",
24 | "spatial_dims": 2,
25 | "in_channels": 1,
26 | "out_channels": 1,
27 | "latent_channels": 3,
28 | "num_channels": [
29 | 64,
30 | 128,
31 | 128,
32 | 128
33 | ],
34 | "num_res_blocks": 2,
35 | "norm_num_groups": 32,
36 | "norm_eps": 1e-06,
37 | "attention_levels": [
38 | false,
39 | false,
40 | false,
41 | false
42 | ],
43 | "with_encoder_nonlocal_attn": false,
44 | "with_decoder_nonlocal_attn": false
45 | },
46 | "load_autoencoder_path": "$@model_dir + '/autoencoder.pth'",
47 | "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
48 | "autoencoder": "$@autoencoder_def.to(@device)",
49 | "diffusion_def": {
50 | "_target_": "generative.networks.nets.DiffusionModelUNet",
51 | "spatial_dims": 2,
52 | "in_channels": 3,
53 | "out_channels": 3,
54 | "num_channels": [
55 | 256,
56 | 512,
57 | 768
58 | ],
59 | "num_res_blocks": 2,
60 | "attention_levels": [
61 | false,
62 | true,
63 | true
64 | ],
65 | "norm_num_groups": 32,
66 | "norm_eps": 1e-06,
67 | "resblock_updown": false,
68 | "num_head_channels": [
69 | 0,
70 | 512,
71 | 768
72 | ],
73 | "with_conditioning": true,
74 | "transformer_num_layers": 1,
75 | "cross_attention_dim": 1024
76 | },
77 | "load_diffusion_path": "$@model_dir + '/diffusion_model.pth'",
78 | "load_diffusion": "$@diffusion_def.load_state_dict(torch.load(@load_diffusion_path))",
79 | "diffusion": "$@diffusion_def.to(@device)",
80 | "scheduler": {
81 | "_target_": "generative.networks.schedulers.DDIMScheduler",
82 | "_requires_": [
83 | "@load_diffusion",
84 | "@load_autoencoder"
85 | ],
86 | "beta_start": 0.0015,
87 | "beta_end": 0.0205,
88 | "num_train_timesteps": 1000,
89 | "schedule": "scaled_linear_beta",
90 | "prediction_type": "v_prediction",
91 | "clip_sample": false
92 | },
93 | "noise": "$torch.randn((1, 3, 64, 64)).to(@device)",
94 | "set_timesteps": "$@scheduler.set_timesteps(num_inference_steps=50)",
95 | "sampler": {
96 | "_target_": "scripts.sampler.Sampler",
97 | "_requires_": "@set_timesteps"
98 | },
99 | "sample": "$@sampler.sampling_fn(@noise, @autoencoder, @diffusion, @scheduler, @prompt_embeds)",
100 | "saver": {
101 | "_target_": "scripts.saver.JPGSaver",
102 | "_requires_": "@create_output_dir",
103 | "output_dir": "@output_dir"
104 | },
105 | "save_jpg": "$@saver.save(@sample, @out_file)",
106 | "save": "$torch.save(@sample, @output_dir + '/' + @out_file + '.pt')"
107 | }
108 |
--------------------------------------------------------------------------------
/model-zoo/models/mednist_ddpm/bundle/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # This defines the training script for the network
2 |
3 | # choose a new directory for every run
4 | output_dir: $datetime.datetime.now().strftime('./results/output_%y%m%d_%H%M%S')
5 | dataset_dir: ./data
6 |
7 | train_data:
8 | _target_ : MedNISTDataset
9 | root_dir: '@dataset_dir'
10 | section: training
11 | download: true
12 | progress: false
13 | seed: 0
14 |
15 | val_data:
16 | _target_ : MedNISTDataset
17 | root_dir: '@dataset_dir'
18 | section: validation
19 | download: true
20 | progress: false
21 | seed: 0
22 |
23 | train_datalist: '$[{"image": item["image"]} for item in @train_data.data if item["class_name"] == "Hand"]'
24 | val_datalist: '$[{"image": item["image"]} for item in @val_data.data if item["class_name"] == "Hand"]'
25 |
26 | batch_size: 8
27 | num_substeps: 1
28 | num_workers: 4
29 | use_thread_workers: false
30 |
31 | lr: 0.000025
32 | rand_prob: 0.5
33 | num_epochs: 75
34 | val_interval: 5
35 | save_interval: 5
36 |
37 | train_transforms:
38 | - _target_: RandAffined
39 | keys: '@image'
40 | rotate_range:
41 | - ['$-np.pi / 36', '$np.pi / 36']
42 | - ['$-np.pi / 36', '$np.pi / 36']
43 | translate_range:
44 | - [-1, 1]
45 | - [-1, 1]
46 | scale_range:
47 | - [-0.05, 0.05]
48 | - [-0.05, 0.05]
49 | spatial_size: [64, 64]
50 | padding_mode: "zeros"
51 | prob: '@rand_prob'
52 |
53 | train_ds:
54 | _target_: Dataset
55 | data: $@train_datalist
56 | transform:
57 | _target_: Compose
58 | transforms: '$@base_transforms + @train_transforms'
59 |
60 | train_loader:
61 | _target_: ThreadDataLoader
62 | dataset: '@train_ds'
63 | batch_size: '@batch_size'
64 | repeats: '@num_substeps'
65 | num_workers: '@num_workers'
66 | use_thread_workers: '@use_thread_workers'
67 | persistent_workers: '$@num_workers > 0'
68 | shuffle: true
69 |
70 | val_ds:
71 | _target_: Dataset
72 | data: $@val_datalist
73 | transform:
74 | _target_: Compose
75 | transforms: '@base_transforms'
76 |
77 | val_loader:
78 | _target_: DataLoader
79 | dataset: '@val_ds'
80 | batch_size: '@batch_size'
81 | num_workers: '@num_workers'
82 | persistent_workers: '$@num_workers > 0'
83 | shuffle: false
84 |
85 | lossfn:
86 | _target_: torch.nn.MSELoss
87 |
88 | optimizer:
89 | _target_: torch.optim.Adam
90 | params: $@network.parameters()
91 | lr: '@lr'
92 |
93 | prepare_batch:
94 | _target_: generative.engines.DiffusionPrepareBatch
95 | num_train_timesteps: '@num_train_timesteps'
96 |
97 | val_handlers:
98 | - _target_: StatsHandler
99 | name: train_log
100 | output_transform: '$lambda x: None'
101 | _disabled_: '@is_not_rank0'
102 |
103 | evaluator:
104 | _target_: SupervisedEvaluator
105 | device: '@device'
106 | val_data_loader: '@val_loader'
107 | network: '@network'
108 | amp: '@use_amp'
109 | inferer: '@inferer'
110 | prepare_batch: '@prepare_batch'
111 | key_val_metric:
112 | val_mean_abs_error:
113 | _target_: MeanAbsoluteError
114 | output_transform: $monai.handlers.from_engine([@pred, @label])
115 | metric_cmp_fn: '$scripts.inv_metric_cmp_fn'
116 | val_handlers: '$list(filter(bool, @val_handlers))'
117 |
118 | handlers:
119 | - _target_: CheckpointLoader
120 | _disabled_: $not os.path.exists(@ckpt_path)
121 | load_path: '@ckpt_path'
122 | load_dict:
123 | model: '@network'
124 | - _target_: ValidationHandler
125 | validator: '@evaluator'
126 | epoch_level: true
127 | interval: '@val_interval'
128 | - _target_: CheckpointSaver
129 | save_dir: '@output_dir'
130 | save_dict:
131 | model: '@network'
132 | save_interval: '@save_interval'
133 | save_final: true
134 | epoch_level: true
135 | _disabled_: '@is_not_rank0'
136 |
137 | trainer:
138 | _target_: SupervisedTrainer
139 | max_epochs: '@num_epochs'
140 | device: '@device'
141 | train_data_loader: '@train_loader'
142 | network: '@network'
143 | loss_function: '@lossfn'
144 | optimizer: '@optimizer'
145 | inferer: '@inferer'
146 | prepare_batch: '@prepare_batch'
147 | key_train_metric:
148 | train_acc:
149 | _target_: MeanSquaredError
150 | output_transform: $monai.handlers.from_engine([@pred, @label])
151 | metric_cmp_fn: '$scripts.inv_metric_cmp_fn'
152 | train_handlers: '$list(filter(bool, @handlers))'
153 | amp: '@use_amp'
154 |
155 | training:
156 | - '$monai.utils.set_determinism(0)'
157 | - '$@trainer.run()'
158 |
--------------------------------------------------------------------------------
/generative/networks/nets/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from generative.networks.blocks.transformerblock import TransformerBlock
18 |
19 | __all__ = ["DecoderOnlyTransformer"]
20 |
21 |
22 | class AbsolutePositionalEmbedding(nn.Module):
23 | """Absolute positional embedding.
24 |
25 | Args:
26 | max_seq_len: Maximum sequence length.
27 | embedding_dim: Dimensionality of the embedding.
28 | """
29 |
30 | def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
31 | super().__init__()
32 | self.max_seq_len = max_seq_len
33 | self.embedding_dim = embedding_dim
34 | self.embedding = nn.Embedding(max_seq_len, embedding_dim)
35 |
36 | def forward(self, x: torch.Tensor) -> torch.Tensor:
37 | batch_size, seq_len = x.size()
38 | positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
39 | return self.embedding(positions)
40 |
41 |
42 | class DecoderOnlyTransformer(nn.Module):
43 | """Decoder-only (Autoregressive) Transformer model.
44 |
45 | Args:
46 | num_tokens: Number of tokens in the vocabulary.
47 | max_seq_len: Maximum sequence length.
48 | attn_layers_dim: Dimensionality of the attention layers.
49 | attn_layers_depth: Number of attention layers.
50 | attn_layers_heads: Number of attention heads.
51 | with_cross_attention: Whether to use cross attention for conditioning.
52 | embedding_dropout_rate: Dropout rate for the embedding.
53 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
54 | """
55 |
56 | def __init__(
57 | self,
58 | num_tokens: int,
59 | max_seq_len: int,
60 | attn_layers_dim: int,
61 | attn_layers_depth: int,
62 | attn_layers_heads: int,
63 | with_cross_attention: bool = False,
64 | embedding_dropout_rate: float = 0.0,
65 | use_flash_attention: bool = False,
66 | ) -> None:
67 | super().__init__()
68 | self.num_tokens = num_tokens
69 | self.max_seq_len = max_seq_len
70 | self.attn_layers_dim = attn_layers_dim
71 | self.attn_layers_depth = attn_layers_depth
72 | self.attn_layers_heads = attn_layers_heads
73 | self.with_cross_attention = with_cross_attention
74 |
75 | self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
76 | self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
77 | self.embedding_dropout = nn.Dropout(embedding_dropout_rate)
78 |
79 | self.blocks = nn.ModuleList(
80 | [
81 | TransformerBlock(
82 | hidden_size=attn_layers_dim,
83 | mlp_dim=attn_layers_dim * 4,
84 | num_heads=attn_layers_heads,
85 | dropout_rate=0.0,
86 | qkv_bias=False,
87 | causal=True,
88 | sequence_length=max_seq_len,
89 | with_cross_attention=with_cross_attention,
90 | use_flash_attention=use_flash_attention,
91 | )
92 | for _ in range(attn_layers_depth)
93 | ]
94 | )
95 |
96 | self.to_logits = nn.Linear(attn_layers_dim, num_tokens)
97 |
98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
99 | tok_emb = self.token_embeddings(x)
100 | pos_emb = self.position_embeddings(x)
101 | x = self.embedding_dropout(tok_emb + pos_emb)
102 |
103 | for block in self.blocks:
104 | x = block(x, context=context)
105 |
106 | return self.to_logits(x)
107 |
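A minimal sketch of the model in use for next-token prediction over a small vocabulary:

```python
import torch

from generative.networks.nets.transformer import DecoderOnlyTransformer

model = DecoderOnlyTransformer(
    num_tokens=256, max_seq_len=64, attn_layers_dim=128, attn_layers_depth=2, attn_layers_heads=4
)
tokens = torch.randint(0, 256, (2, 64))  # (batch, sequence) of token indices
logits = model(tokens)                   # (2, 64, 256): per-position logits over the vocabulary
```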
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = monai-generative
3 | author = MONAI Consortium
4 | author_email = monai.contact@gmail.com
5 | url = https://monai.io/
6 | description = MONAI Generative Models makes it easy to train, evaluate, and deploy generative models and related applications
7 | long_description = file:README.md
8 | long_description_content_type = text/markdown; charset=UTF-8
9 | platforms = OS Independent
10 | license = Apache License 2.0
11 | license_files =
12 | LICENSE
13 | project_urls =
14 | Documentation=https://docs.monai.io/
15 | Bug Tracker=https://github.com/Project-MONAI/GenerativeModels/issues
16 | Source Code=https://github.com/Project-MONAI/GenerativeModels
17 | classifiers =
18 | Intended Audience :: Developers
19 | Intended Audience :: Education
20 | Intended Audience :: Science/Research
21 | Intended Audience :: Healthcare Industry
22 | Programming Language :: C++
23 | Programming Language :: Python :: 3
24 | Programming Language :: Python :: 3.8
25 | Programming Language :: Python :: 3.9
26 | Programming Language :: Python :: 3.10
27 | Topic :: Scientific/Engineering
28 | Topic :: Scientific/Engineering :: Artificial Intelligence
29 | Topic :: Scientific/Engineering :: Medical Science Apps.
30 | Topic :: Scientific/Engineering :: Information Analysis
31 | Topic :: Software Development
32 | Topic :: Software Development :: Libraries
33 | Typing :: Typed
34 |
35 | [options]
36 | python_requires = >= 3.8
37 | # for compiling and develop setup only
38 | # no need to specify the versions so that we could
39 | # compile for multiple targeted versions.
40 | setup_requires =
41 | torch
42 | ninja
43 | install_requires =
44 | monai>=1.2.0rc1
45 | torch>=1.9
46 | numpy>=1.20
47 |
48 | [flake8]
49 | select = B,C,E,F,N,P,T4,W,B9
50 | max_line_length = 120
51 | # C408 ignored because we like the dict keyword argument syntax
52 | # E501 is not flexible enough, we're using B950 instead
53 | # N812 lowercase 'torch.nn.functional' imported as non lowercase 'F'
54 | # B023 https://github.com/Project-MONAI/MONAI/issues/4627
55 | # B028 https://github.com/Project-MONAI/MONAI/issues/5855
56 | # B907 https://github.com/Project-MONAI/MONAI/issues/5868
57 | # B908 https://github.com/Project-MONAI/MONAI/issues/6503
58 | ignore =
59 | E203
60 | E501
61 | E741
62 | W503
63 | W504
64 | C408
65 | N812
66 | B023
67 | B905
68 | B028
69 | B907
70 | B908
71 | per_file_ignores = __init__.py: F401, __main__.py: F401
72 | exclude = *.pyi,.git,.eggs,generative/_version.py,versioneer.py,venv,.venv,_version.py,tutorials/
73 |
74 | [isort]
75 | known_first_party = generative
76 | profile = black
77 | line_length = 120
78 | # generative/networks/layers/ is excluded because it is raising JIT errors
79 | skip = .git, .eggs, venv, .venv, versioneer.py, _version.py, conf.py, monai/__init__.py, tutorials/, generative/networks/layers/
80 | skip_glob = *.pyi
81 | add_imports = from __future__ import annotations
82 | append_only = true
83 |
84 | [mypy]
85 | # Suppresses error messages about imports that cannot be resolved.
86 | ignore_missing_imports = True
87 | # Changes the treatment of arguments with a default value of None by not implicitly making their type Optional.
88 | no_implicit_optional = True
89 | # Warns about casting an expression to its inferred type.
90 | warn_redundant_casts = True
91 | # No error on unneeded # type: ignore comments.
92 | warn_unused_ignores = False
93 | # Shows a warning when returning a value with type Any from a function declared with a non-Any return type.
94 | warn_return_any = True
95 | # Prohibit equality checks, identity checks, and container checks between non-overlapping types.
96 | strict_equality = True
97 | # Shows column numbers in error messages.
98 | show_column_numbers = True
99 | # Shows error codes in error messages.
100 | show_error_codes = True
101 | # Use visually nicer output in error messages: use soft word wrap, show source code snippets, and show error location markers.
102 | pretty = False
103 | # Warns about per-module sections in the config file that do not match any files processed when invoking mypy.
104 | warn_unused_configs = True
105 | # Make arguments prepended via Concatenate be truly positional-only.
106 | strict_concatenate = True
107 |
108 | exclude = venv/
109 |
110 | [coverage:run]
111 | concurrency = multiprocessing
112 | source = .
113 | data_file = .coverage/.coverage
114 | omit = setup.py
115 |
116 | [coverage:report]
117 | exclude_lines =
118 | pragma: no cover
119 | if TYPE_CHECKING:
120 | # Don't complain if tests don't hit code:
121 | raise NotImplementedError
122 | if __name__ == .__main__.:
123 | show_missing = True
124 | skip_covered = True
125 |
126 | [coverage:xml]
127 | output = coverage.xml
128 |
--------------------------------------------------------------------------------
/generative/utils/component_store.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from collections import namedtuple
15 | from keyword import iskeyword
16 | from textwrap import dedent, indent
17 | from typing import Any, Callable, Dict, Iterable, TypeVar
18 |
19 | T = TypeVar("T")
20 |
21 |
22 | def is_variable(name):
23 | """Returns True if `name` is a valid Python variable name and also not a keyword."""
24 | return name.isidentifier() and not iskeyword(name)
25 |
26 |
27 | class ComponentStore:
28 | """
29 | Represents a storage object for other objects (specifically functions) keyed to a name with a description.
30 |
31 | These objects act as global named places for storing components for objects parameterised by component names.
32 |     Typically these are functions, although other objects can be added. Printing a component store will produce a
33 | list of members along with their docstring information if present.
34 |
35 | Example:
36 |
37 | .. code-block:: python
38 |
39 | TestStore = ComponentStore("Test Store", "A test store for demo purposes")
40 |
41 | @TestStore.add_def("my_func_name", "Some description of your function")
42 | def _my_func(a, b):
43 | '''A description of your function here.'''
44 | return a * b
45 |
46 | print(TestStore) # will print out name, description, and 'my_func_name' with the docstring
47 |
48 | func = TestStore["my_func_name"]
49 | result = func(7, 6)
50 |
51 | """
52 |
53 | _Component = namedtuple("Component", ("description", "value")) # internal value pair
54 |
55 | def __init__(self, name: str, description: str) -> None:
56 | self.components: Dict[str, self._Component] = {}
57 | self.name: str = name
58 | self.description: str = description
59 |
60 | self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip()
61 |
62 | def add(self, name: str, desc: str, value: T) -> T:
63 | """Store the object `value` under the name `name` with description `desc`."""
64 | if not is_variable(name):
65 | raise ValueError("Name of component must be valid Python identifier")
66 |
67 | self.components[name] = self._Component(desc, value)
68 | return value
69 |
70 | def add_def(self, name: str, desc: str) -> Callable:
71 | """Returns a decorator which stores the decorated function under `name` with description `desc`."""
72 |
73 | def deco(func):
74 | """Decorator to add a function to a store."""
75 | return self.add(name, desc, func)
76 |
77 | return deco
78 |
79 | def __contains__(self, name: str) -> bool:
80 | """Returns True if the given name is stored."""
81 | return name in self.components
82 |
83 | def __len__(self) -> int:
84 | """Returns the number of stored components."""
85 | return len(self.components)
86 |
87 | def __iter__(self) -> Iterable:
88 | """Yields name/component pairs."""
89 | for k, v in self.components.items():
90 | yield k, v.value
91 |
92 | def __str__(self):
93 | result = f"Component Store '{self.name}': {self.description}\nAvailable components:"
94 | for k, v in self.components.items():
95 | result += f"\n* {k}:"
96 |
97 |             if getattr(v.value, "__doc__", None):
98 | doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ")
99 | result += f"\n{doc}\n"
100 | else:
101 | result += f" {v.description}"
102 |
103 | return result
104 |
105 | def __getattr__(self, name: str) -> Any:
106 | """Returns the stored object under the given name."""
107 | if name in self.components:
108 | return self.components[name].value
109 | else:
110 | return self.__getattribute__(name)
111 |
112 | def __getitem__(self, name: str) -> Any:
113 | """Returns the stored object under the given name."""
114 | if name in self.components:
115 | return self.components[name].value
116 | else:
117 | raise ValueError(f"Component '{name}' not found")
118 |
--------------------------------------------------------------------------------
/generative/metrics/fid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 |
13 | from __future__ import annotations
14 |
15 | import numpy as np
16 | import torch
17 | from monai.metrics.metric import Metric
18 | from scipy import linalg
19 |
20 |
21 | class FIDMetric(Metric):
22 | """
23 | Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors.
24 | Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
25 | https://arxiv.org/abs/1706.08500#. The inputs for this metric should be two groups of feature vectors (with format
26 |     (number images, number of features)) extracted from a pretrained network.
27 |
28 | Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet.
29 |     However, other networks pretrained on medical datasets can be used as well (for example, RadImageNet for 2D and
30 |     MedicalNet for 3D images). If the chosen model output is not a scalar, a global spatial average pooling is
31 |     usually applied.
32 | """
33 |
34 | def __init__(self) -> None:
35 | super().__init__()
36 |
37 | def __call__(self, y_pred: torch.Tensor, y: torch.Tensor):
38 | return get_fid_score(y_pred, y)
39 |
40 |
41 | def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
42 | y = y.double()
43 | y_pred = y_pred.double()
44 |
45 | if y.ndimension() > 2:
46 | raise ValueError("Inputs should have (number images, number of features) shape.")
47 |
48 | mu_y_pred = torch.mean(y_pred, dim=0)
49 | sigma_y_pred = _cov(y_pred, rowvar=False)
50 | mu_y = torch.mean(y, dim=0)
51 | sigma_y = _cov(y, rowvar=False)
52 |
53 | return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y)
54 |
55 |
56 | def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
57 | """
58 | Estimate a covariance matrix of the variables.
59 |
60 | Args:
61 | input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
62 | and each column a single observation of all those variables.
63 | rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
64 | Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
65 | observations.
66 | """
67 | if input_data.dim() < 2:
68 | input_data = input_data.view(1, -1)
69 |
70 | if not rowvar and input_data.size(0) != 1:
71 | input_data = input_data.t()
72 |
73 | factor = 1.0 / (input_data.size(1) - 1)
74 | input_data = input_data - torch.mean(input_data, dim=1, keepdim=True)
75 | return factor * input_data.matmul(input_data.t()).squeeze()
76 |
77 |
78 | def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
79 | """Compute the square root of a matrix."""
80 | scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
81 | return torch.from_numpy(scipy_res)
82 |
83 |
84 | def compute_frechet_distance(
85 | mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6
86 | ) -> torch.Tensor:
87 | """The Frechet distance between multivariate normal distributions."""
88 | diff = mu_x - mu_y
89 |
90 | covmean = _sqrtm(sigma_x.mm(sigma_y))
91 |
92 | # Product might be almost singular
93 | if not torch.isfinite(covmean).all():
94 | print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates")
95 | offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon
96 | covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset))
97 |
98 | # Numerical error might give slight imaginary component
99 | if torch.is_complex(covmean):
100 | if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3):
101 | raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.")
102 | covmean = covmean.real
103 |
104 | tr_covmean = torch.trace(covmean)
105 | return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean
106 |
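A minimal usage sketch (the metric expects two groups of feature vectors rather than raw images, as described in the docstring above; `FIDMetric` is assumed to be exported from `generative.metrics`):

```python
import torch

from generative.metrics import FIDMetric  # assumed export path

fid = FIDMetric()
# Feature vectors of shape (number of images, number of features), e.g. pooled
# activations from a pretrained backbone applied to generated and real images.
fake_features = torch.randn(64, 128)
real_features = torch.randn(64, 128)
score = fid(fake_features, real_features)  # scalar tensor; lower indicates closer distributions
```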
--------------------------------------------------------------------------------
/generative/engines/prepare_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from typing import Dict, Mapping, Optional, Union
15 |
16 | import torch
17 | import torch.nn as nn
18 | from monai.engines import PrepareBatch, default_prepare_batch
19 |
20 |
21 | class DiffusionPrepareBatch(PrepareBatch):
22 | """
23 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
24 |
25 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
26 |     return the image and noise field as the image/target pair, plus the noise field in the kwargs under the key "noise".
27 | This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
28 |
29 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
30 |     field to be passed to the inferer. This will appear in the keyword arguments under the key "conditioning".
31 |
32 | """
33 |
34 | def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
35 | self.condition_name = condition_name
36 | self.num_train_timesteps = num_train_timesteps
37 |
38 | def get_noise(self, images: torch.Tensor) -> torch.Tensor:
39 | """Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
40 | return torch.randn_like(images)
41 |
42 | def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
43 | """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
44 | return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
45 |
46 | def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
47 | """Return the target for the loss function, this is the `noise` value by default."""
48 | return noise
49 |
50 | def __call__(
51 | self,
52 | batchdata: Dict[str, torch.Tensor],
53 | device: Optional[Union[str, torch.device]] = None,
54 | non_blocking: bool = False,
55 | **kwargs,
56 | ):
57 | images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
58 | noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
59 | timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
60 |
61 | target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
62 | infer_kwargs = {"noise": noise, "timesteps": timesteps}
63 |
64 | if self.condition_name is not None and isinstance(batchdata, Mapping):
65 | infer_kwargs["conditioning"] = batchdata[self.condition_name].to(
66 | device, non_blocking=non_blocking, **kwargs
67 | )
68 |
69 | # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
70 | return images, target, (), infer_kwargs
71 |
72 |
73 | class VPredictionPrepareBatch(DiffusionPrepareBatch):
74 | """
75 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
76 |
77 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
78 | from this compute the velocity using the provided scheduler. This value is used as the target in place of the
79 |     noise field itself, although the noise field is still in the kwargs under the key "noise". This assumes the inferer
80 | being used in conjunction with this class expects a "noise" parameter to be provided.
81 |
82 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
83 |     field to be passed to the inferer. This will appear in the keyword arguments under the key "conditioning".
84 |
85 | """
86 |
87 | def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
88 | super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
89 | self.scheduler = scheduler
90 |
91 | def get_target(self, images, noise, timesteps):
92 | return self.scheduler.get_velocity(images, noise, timesteps)
93 |
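A minimal sketch of `DiffusionPrepareBatch` called directly on a dictionary batch (in practice it is passed as the `prepare_batch` argument of a MONAI engine, as in the mednist_ddpm training config above):

```python
import torch

from generative.engines import DiffusionPrepareBatch

prepare_batch = DiffusionPrepareBatch(num_train_timesteps=1000)
batch = {"image": torch.randn(4, 1, 64, 64), "label": torch.zeros(4)}  # illustrative batch
images, target, args, kwargs = prepare_batch(batch, device=torch.device("cpu"))
# `target` is the sampled noise; `kwargs` carries "noise" and "timesteps" for the inferer.
```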
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # MONAI core
132 | MONAI/
133 |
134 | # Created by https://www.toptal.com/developers/gitignore/api/jetbrains
135 | # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains
136 |
137 | ### JetBrains ###
138 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
139 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
140 |
141 | # User-specific stuff
142 | .idea/**/workspace.xml
143 | .idea/**/tasks.xml
144 | .idea/**/usage.statistics.xml
145 | .idea/**/dictionaries
146 | .idea/**/shelf
147 |
148 | # AWS User-specific
149 | .idea/**/aws.xml
150 |
151 | # Generated files
152 | .idea/**/contentModel.xml
153 |
154 | # Sensitive or high-churn files
155 | .idea/**/dataSources/
156 | .idea/**/dataSources.ids
157 | .idea/**/dataSources.local.xml
158 | .idea/**/sqlDataSources.xml
159 | .idea/**/dynamic.xml
160 | .idea/**/uiDesigner.xml
161 | .idea/**/dbnavigator.xml
162 |
163 | # Gradle
164 | .idea/**/gradle.xml
165 | .idea/**/libraries
166 |
167 | # Gradle and Maven with auto-import
168 | # When using Gradle or Maven with auto-import, you should exclude module files,
169 | # since they will be recreated, and may cause churn. Uncomment if using
170 | # auto-import.
171 | # .idea/artifacts
172 | # .idea/compiler.xml
173 | # .idea/jarRepositories.xml
174 | # .idea/modules.xml
175 | # .idea/*.iml
176 | # .idea/modules
177 | # *.iml
178 | # *.ipr
179 |
180 | # CMake
181 | cmake-build-*/
182 |
183 | # Mongo Explorer plugin
184 | .idea/**/mongoSettings.xml
185 |
186 | # File-based project format
187 | *.iws
188 |
189 | # IntelliJ
190 | out/
191 |
192 | # mpeltonen/sbt-idea plugin
193 | .idea_modules/
194 |
195 | # JIRA plugin
196 | atlassian-ide-plugin.xml
197 |
198 | # Cursive Clojure plugin
199 | .idea/replstate.xml
200 |
201 | # SonarLint plugin
202 | .idea/sonarlint/
203 |
204 | # Crashlytics plugin (for Android Studio and IntelliJ)
205 | com_crashlytics_export_strings.xml
206 | crashlytics.properties
207 | crashlytics-build.properties
208 | fabric.properties
209 |
210 | # Editor-based Rest Client
211 | .idea/httpRequests
212 |
213 | # Android studio 3.1+ serialized cache file
214 | .idea/caches/build_file_checksums.ser
215 |
216 | ### JetBrains Patch ###
217 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
218 |
219 | # *.iml
220 | # modules.xml
221 | # .idea/misc.xml
222 | # *.ipr
223 |
224 | # Sonarlint plugin
225 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
226 | .idea/**/sonarlint/
227 |
228 | # SonarQube Plugin
229 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
230 | .idea/**/sonarIssues.xml
231 |
232 | # Markdown Navigator plugin
233 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
234 | .idea/**/markdown-navigator.xml
235 | .idea/**/markdown-navigator-enh.xml
236 | .idea/**/markdown-navigator/
237 |
238 | # Cache file creation bug
239 | # See https://youtrack.jetbrains.com/issue/JBR-2257
240 | .idea/$CACHE_FILE$
241 |
242 | # CodeStream plugin
243 | # https://plugins.jetbrains.com/plugin/12206-codestream
244 | .idea/codestream.xml
245 |
246 | # Azure Toolkit for IntelliJ plugin
247 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
248 | .idea/**/azureSettings.xml
249 |
250 | # End of https://www.toptal.com/developers/gitignore/api/jetbrains
251 |
--------------------------------------------------------------------------------
/tests/test_spade_vaegan.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import numpy as np
17 | import torch
18 | from monai.networks import eval_mode
19 | from parameterized import parameterized
20 |
21 | from generative.networks.nets import SPADENet
22 |
23 | CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]]
24 | CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]]
25 | CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]]
26 |
27 |
28 | def create_semantic_data(shape: list, semantic_regions: int):
29 |     """
30 |     Create mock semantic label and image inputs for the network.
31 |     Args:
32 |         shape: spatial shape of the inputs
33 |         semantic_regions: number of semantic regions
34 |     Returns: the one-hot encoded label map and the mock image, each with a batch dimension added.
35 |     """
36 | out_label = torch.zeros(shape)
37 | out_image = torch.zeros(shape) + torch.randn(shape) * 0.01
38 | for i in range(1, semantic_regions):
39 | shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape]
40 | start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)]
41 | if len(shape) == 2:
42 | out_label[
43 | start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])
44 | ] = i
45 | base_intensity = torch.ones(shape_square) * np.random.randn()
46 | out_image[
47 | start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])
48 | ] = (base_intensity + torch.randn(shape_square) * 0.1)
49 | elif len(shape) == 3:
50 | out_label[
51 | start_point[0] : (start_point[0] + shape_square[0]),
52 | start_point[1] : (start_point[1] + shape_square[1]),
53 | start_point[2] : (start_point[2] + shape_square[2]),
54 | ] = i
55 | base_intensity = torch.ones(shape_square) * np.random.randn()
56 | out_image[
57 | start_point[0] : (start_point[0] + shape_square[0]),
58 | start_point[1] : (start_point[1] + shape_square[1]),
59 | start_point[2] : (start_point[2] + shape_square[2]),
60 | ] = (base_intensity + torch.randn(shape_square) * 0.1)
61 | else:
62 |             raise ValueError("Supports only 2D and 3D tensors")
63 |
64 | # One hot encode label
65 | out_label_ = torch.zeros([semantic_regions] + list(out_label.shape))
66 | for ch in range(semantic_regions):
67 | out_label_[ch, ...] = out_label == ch
68 |
69 | return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0)
70 |
71 |
72 | class TestSPADENet(unittest.TestCase):
73 | @parameterized.expand(CASE_2D)
74 | def test_forward_2d(self, input_param):
75 | """
76 | Check that forward method is called correctly and output shape matches.
77 | """
78 | net = SPADENet(*input_param)
79 | in_label, in_image = create_semantic_data(input_param[4], input_param[3])
80 | with eval_mode(net):
81 | out, kld = net(in_label, in_image)
82 | self.assertEqual(
83 | False,
84 | True in torch.isnan(out)
85 | or True in torch.isinf(out)
86 | or True in torch.isinf(kld)
87 |                 or True in torch.isnan(kld),
88 | )
89 | self.assertEqual(list(out.shape), [1, 1, 64, 64])
90 |
91 | @parameterized.expand(CASE_2D_BIS)
92 | def test_encoder_decoder(self, input_param):
93 | """
94 |         Check that the encode and decode methods are called correctly and the output shapes match.
95 | """
96 | net = SPADENet(*input_param)
97 | in_label, in_image = create_semantic_data(input_param[4], input_param[3])
98 | with eval_mode(net):
99 | out_z = net.encode(in_image)
100 | self.assertEqual(list(out_z.shape), [1, 16])
101 | out_i = net.decode(in_label, out_z)
102 | self.assertEqual(list(out_i.shape), [1, 1, 64, 64])
103 |
104 | @parameterized.expand(CASE_3D)
105 | def test_forward_3d(self, input_param):
106 | """
107 | Check that forward method is called correctly and output shape matches.
108 | """
109 | net = SPADENet(*input_param)
110 | in_label, in_image = create_semantic_data(input_param[4], input_param[3])
111 | with eval_mode(net):
112 | out, kld = net(in_label, in_image)
113 | self.assertEqual(
114 | False,
115 | True in torch.isnan(out)
116 | or True in torch.isinf(out)
117 | or True in torch.isinf(kld)
118 |                 or True in torch.isnan(kld),
119 | )
120 | self.assertEqual(list(out.shape), [1, 1, 64, 64, 64])
121 |
122 | def test_shape_wrong(self):
123 | """
124 | We input an input shape that isn't divisible by 2**(n downstream steps)
125 | """
126 | with self.assertRaises(ValueError):
127 | _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True)
128 |
129 |
130 | if __name__ == "__main__":
131 | unittest.main()
132 |
--------------------------------------------------------------------------------
/tests/runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import argparse
15 | import glob
16 | import inspect
17 | import os
18 | import re
19 | import sys
20 | import time
21 | import unittest
22 |
23 | from monai.utils import PerfContext
24 |
25 | results: dict = {}
26 |
27 |
28 | class TimeLoggingTestResult(unittest.TextTestResult):
29 | """Overload the default results so that we can store the results."""
30 |
31 | def __init__(self, *args, **kwargs):
32 | super().__init__(*args, **kwargs)
33 | self.timed_tests = {}
34 |
35 | def startTest(self, test): # noqa: N802
36 | """Start timer, print test name, do normal test."""
37 | self.start_time = time.time()
38 | name = self.getDescription(test)
39 | self.stream.write(f"Starting test: {name}...\n")
40 | super().startTest(test)
41 |
42 | def stopTest(self, test): # noqa: N802
43 | """On test end, get time, print, store and do normal behaviour."""
44 | elapsed = time.time() - self.start_time
45 | name = self.getDescription(test)
46 | self.stream.write(f"Finished test: {name} ({elapsed:.03}s)\n")
47 | if name in results:
48 | raise AssertionError("expected all keys to be unique")
49 | results[name] = elapsed
50 | super().stopTest(test)
51 |
52 |
53 | def print_results(results, discovery_time, thresh, status):
54 |     # only keep results above the threshold
55 | results = dict(filter(lambda x: x[1] > thresh, results.items()))
56 | if len(results) == 0:
57 | return
58 | print(f"\n\n{status}, printing completed times >{thresh}s in ascending order...\n")
59 | timings = dict(sorted(results.items(), key=lambda item: item[1]))
60 |
61 | for r in timings:
62 | if timings[r] >= thresh:
63 | print(f"{r} ({timings[r]:.03}s)")
64 | print(f"test discovery time: {discovery_time:.03}s")
65 | print(f"total testing time: {sum(results.values()):.03}s")
66 | print("Remember to check above times for any errors!")
67 |
68 |
69 | def parse_args():
70 | parser = argparse.ArgumentParser(description="Runner for MONAI unittests with timing.")
71 | parser.add_argument(
72 | "-s", action="store", dest="path", default=".", help="Directory to start discovery (default: '%(default)s')"
73 | )
74 | parser.add_argument(
75 | "-p",
76 | action="store",
77 | dest="pattern",
78 | default="test_*.py",
79 | help="Pattern to match tests (default: '%(default)s')",
80 | )
81 | parser.add_argument(
82 | "-t",
83 | "--thresh",
84 | dest="thresh",
85 | default=10.0,
86 | type=float,
87 | help="Display tests longer than given threshold (default: %(default)d)",
88 | )
89 | parser.add_argument(
90 | "-v",
91 | "--verbosity",
92 | action="store",
93 | dest="verbosity",
94 | type=int,
95 | default=1,
96 | help="Verbosity level (default: %(default)d)",
97 | )
98 | parser.add_argument("-q", "--quick", action="store_true", dest="quick", default=False, help="Only do quick tests")
99 | parser.add_argument(
100 | "-f", "--failfast", action="store_true", dest="failfast", default=False, help="Stop testing on first failure"
101 | )
102 | args = parser.parse_args()
103 | print(f"Running tests in folder: '{args.path}'")
104 | if args.pattern:
105 | print(f"With file pattern: '{args.pattern}'")
106 |
107 | return args
108 |
109 |
110 | def get_default_pattern(loader):
111 | signature = inspect.signature(loader.discover)
112 | params = {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
113 | return params["pattern"]
114 |
115 |
116 | if __name__ == "__main__":
117 | # Parse input arguments
118 | args = parse_args()
119 |
120 | # If quick is desired, set environment variable
121 | if args.quick:
122 | os.environ["QUICKTEST"] = "True"
123 |
124 | # Get all test names (optionally from some path with some pattern)
125 | with PerfContext() as pc:
126 | # the files are searched from `tests/` folder, starting with `test_`
127 | files = glob.glob(os.path.join(os.path.dirname(__file__), "test_*.py"))
128 | cases = []
129 | for test_module in {os.path.basename(f)[:-3] for f in files}:
130 | if re.match(args.pattern, test_module):
131 | cases.append(f"tests.{test_module}")
132 | else:
133 | print(f"monai test runner: excluding tests.{test_module}")
134 | tests = unittest.TestLoader().loadTestsFromNames(cases)
135 | discovery_time = pc.total_time
136 | print(f"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.")
137 |
138 | test_runner = unittest.runner.TextTestRunner(
139 | resultclass=TimeLoggingTestResult, verbosity=args.verbosity, failfast=args.failfast
140 | )
141 |
142 | # Use try catches to print the current results if encountering exception or keyboard interruption
143 | try:
144 | test_result = test_runner.run(tests)
145 | print_results(results, discovery_time, args.thresh, "tests finished")
146 | sys.exit(not test_result.wasSuccessful())
147 | except KeyboardInterrupt:
148 | print_results(results, discovery_time, args.thresh, "tests cancelled")
149 | sys.exit(1)
150 | except Exception:
151 | print_results(results, discovery_time, args.thresh, "exception reached")
152 | raise
153 |
--------------------------------------------------------------------------------
/tests/test_patch_gan.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import unittest
15 |
16 | import torch
17 | from monai.networks import eval_mode
18 | from parameterized import parameterized
19 |
20 | from generative.networks.nets.patchgan_discriminator import MultiScalePatchDiscriminator
21 | from tests.utils import test_script_save
22 |
23 | TEST_2D = [
24 | {
25 | "num_d": 2,
26 | "num_layers_d": 3,
27 | "spatial_dims": 2,
28 | "num_channels": 8,
29 | "in_channels": 3,
30 | "out_channels": 1,
31 | "kernel_size": 3,
32 | "activation": "LEAKYRELU",
33 | "norm": "instance",
34 | "bias": False,
35 | "dropout": 0.1,
36 | "minimum_size_im": 256,
37 | },
38 | torch.rand([1, 3, 256, 512]),
39 | [(1, 1, 32, 64), (1, 1, 4, 8)],
40 | [4, 7],
41 | ]
42 | TEST_3D = [
43 | {
44 | "num_d": 2,
45 | "num_layers_d": 3,
46 | "spatial_dims": 3,
47 | "num_channels": 8,
48 | "in_channels": 3,
49 | "out_channels": 1,
50 | "kernel_size": 3,
51 | "activation": "LEAKYRELU",
52 | "norm": "instance",
53 | "bias": False,
54 | "dropout": 0.1,
55 | "minimum_size_im": 256,
56 | },
57 | torch.rand([1, 3, 256, 512, 256]),
58 | [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)],
59 | [4, 7],
60 | ]
61 | TEST_3D_POOL = [
62 | {
63 | "num_d": 2,
64 | "num_layers_d": 3,
65 | "spatial_dims": 3,
66 | "num_channels": 8,
67 | "in_channels": 3,
68 | "out_channels": 1,
69 | "kernel_size": 3,
70 | "pooling_method": "max",
71 | "activation": "LEAKYRELU",
72 | "norm": "instance",
73 | "bias": False,
74 | "dropout": 0.1,
75 | "minimum_size_im": 256,
76 | },
77 | torch.rand([1, 3, 256, 512, 256]),
78 | [(1, 1, 32, 64, 32), (1, 1, 16, 32, 16)],
79 | [4, 4],
80 | ]
81 | TEST_2D_POOL = [
82 | {
83 | "num_d": 4,
84 | "num_layers_d": 3,
85 | "spatial_dims": 2,
86 | "num_channels": 8,
87 | "in_channels": 3,
88 | "out_channels": 1,
89 | "kernel_size": 3,
90 | "pooling_method": "avg",
91 | "activation": "LEAKYRELU",
92 | "norm": "instance",
93 | "bias": False,
94 | "dropout": 0.1,
95 | "minimum_size_im": 256,
96 | },
97 | torch.rand([1, 3, 256, 512]),
98 | [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16), (1, 1, 4, 8)],
99 | [4, 4, 4, 4],
100 | ]
101 | TEST_LAYER_LIST = [
102 | {
103 | "num_d": 3,
104 |         "num_layers_d": [3, 4, 5],
105 | "spatial_dims": 2,
106 | "num_channels": 8,
107 | "in_channels": 3,
108 | "out_channels": 1,
109 | "kernel_size": 3,
110 | "activation": "LEAKYRELU",
111 | "norm": "instance",
112 | "bias": False,
113 | "dropout": 0.1,
114 | "minimum_size_im": 256,
115 | },
116 | torch.rand([1, 3, 256, 512]),
117 | [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16)],
118 | [4, 5, 6],
119 | ]
120 | TEST_TOO_SMALL_SIZE = [
121 | {
122 | "num_d": 2,
123 | "num_layers_d": 6,
124 | "spatial_dims": 2,
125 | "num_channels": 8,
126 | "in_channels": 3,
127 | "out_channels": 1,
128 | "kernel_size": 3,
129 | "activation": "LEAKYRELU",
130 | "norm": "instance",
131 | "bias": False,
132 | "dropout": 0.1,
133 | "minimum_size_im": 256,
134 | }
135 | ]
136 | TEST_MISMATCHED_NUM_LAYERS = [
137 | {
138 | "num_d": 5,
139 |         "num_layers_d": [3, 4, 5],
140 | "spatial_dims": 2,
141 | "num_channels": 8,
142 | "in_channels": 3,
143 | "out_channels": 1,
144 | "kernel_size": 3,
145 | "activation": "LEAKYRELU",
146 | "norm": "instance",
147 | "bias": False,
148 | "dropout": 0.1,
149 | "minimum_size_im": 256,
150 | }
151 | ]
152 |
153 | CASES = [TEST_2D, TEST_3D, TEST_3D_POOL, TEST_2D_POOL, TEST_LAYER_LIST]
154 |
155 | class TestPatchGAN(unittest.TestCase):
156 | @parameterized.expand(CASES)
157 | def test_shape(self, input_param, input_data, expected_shape, features_lengths=None):
158 | net = MultiScalePatchDiscriminator(**input_param)
159 | with eval_mode(net):
160 | result, features = net.forward(input_data)
161 | for r_ind, r in enumerate(result):
162 | self.assertEqual(tuple(r.shape), expected_shape[r_ind])
163 | for o_d_ind, o_d in enumerate(features):
164 | self.assertEqual(len(o_d), features_lengths[o_d_ind])
165 |
166 | def test_too_small_shape(self):
167 | with self.assertRaises(AssertionError):
168 | MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])
169 |
170 | def test_mismatched_num_layers(self):
171 | with self.assertRaises(AssertionError):
172 | MultiScalePatchDiscriminator(**TEST_MISMATCHED_NUM_LAYERS[0])
173 |
174 | def test_script(self):
175 | net = MultiScalePatchDiscriminator(
176 | num_d=2,
177 | num_layers_d=3,
178 | spatial_dims=2,
179 | num_channels=8,
180 | in_channels=3,
181 | out_channels=1,
182 | kernel_size=3,
183 | activation="LEAKYRELU",
184 | norm="instance",
185 | bias=False,
186 | dropout=0.1,
187 | minimum_size_im=256,
188 | )
189 | i = torch.rand([1, 3, 256, 512])
190 | test_script_save(net, i)
191 |
192 |
193 | if __name__ == "__main__":
194 | unittest.main()
195 |
--------------------------------------------------------------------------------
/generative/networks/blocks/selfattention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import importlib.util
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn import functional as F
20 |
21 | if importlib.util.find_spec("xformers") is not None:
22 | import xformers.ops as xops
23 |
24 | has_xformers = True
25 | else:
26 | has_xformers = False
27 |
28 |
29 | class SABlock(nn.Module):
30 | """
31 | A self-attention block, based on: "Dosovitskiy et al.,
32 |     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
33 |
34 | Args:
35 | hidden_size: dimension of hidden layer.
36 | num_heads: number of attention heads.
37 | dropout_rate: dropout ratio. Defaults to no dropout.
38 | qkv_bias: bias term for the qkv linear layer.
39 | causal: whether to use causal attention.
40 | sequence_length: if causal is True, it is necessary to specify the sequence length.
41 | with_cross_attention: Whether to use cross attention for conditioning.
42 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
43 | """
44 |
45 | def __init__(
46 | self,
47 | hidden_size: int,
48 | num_heads: int,
49 | dropout_rate: float = 0.0,
50 | qkv_bias: bool = False,
51 | causal: bool = False,
52 | sequence_length: int | None = None,
53 | with_cross_attention: bool = False,
54 | use_flash_attention: bool = False,
55 | ) -> None:
56 | super().__init__()
57 | self.hidden_size = hidden_size
58 | self.num_heads = num_heads
59 | self.head_dim = hidden_size // num_heads
60 | self.scale = 1.0 / math.sqrt(self.head_dim)
61 | self.causal = causal
62 | self.sequence_length = sequence_length
63 | self.with_cross_attention = with_cross_attention
64 | self.use_flash_attention = use_flash_attention
65 |
66 | if not (0 <= dropout_rate <= 1):
67 | raise ValueError("dropout_rate should be between 0 and 1.")
68 | self.dropout_rate = dropout_rate
69 |
70 | if hidden_size % num_heads != 0:
71 | raise ValueError("hidden size should be divisible by num_heads.")
72 |
73 | if causal and sequence_length is None:
74 | raise ValueError("sequence_length is necessary for causal attention.")
75 |
76 | if use_flash_attention and not has_xformers:
77 | raise ValueError("use_flash_attention is True but xformers is not installed.")
78 |
79 | # key, query, value projections
80 | self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
81 | self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
82 | self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
83 |
84 | # regularization
85 | self.drop_weights = nn.Dropout(dropout_rate)
86 | self.drop_output = nn.Dropout(dropout_rate)
87 |
88 | # output projection
89 | self.out_proj = nn.Linear(hidden_size, hidden_size)
90 |
91 | if causal and sequence_length is not None:
92 | # causal mask to ensure that attention is only applied to the left in the input sequence
93 | self.register_buffer(
94 | "causal_mask",
95 | torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
96 | )
97 |
98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
99 | b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)
100 |
101 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
102 | query = self.to_q(x)
103 |
104 | kv = context if context is not None else x
105 | _, kv_t, _ = kv.size()
106 | key = self.to_k(kv)
107 | value = self.to_v(kv)
108 |
109 | query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs)
110 | key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs)
111 | value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs)
112 |
113 | if self.use_flash_attention:
114 | query = query.contiguous()
115 | key = key.contiguous()
116 | value = value.contiguous()
117 | y = xops.memory_efficient_attention(
118 | query=query,
119 | key=key,
120 | value=value,
121 | scale=self.scale,
122 | p=self.dropout_rate,
123 | attn_bias=xops.LowerTriangularMask() if self.causal else None,
124 | )
125 |
126 | else:
127 | query = query.transpose(1, 2) # (b, nh, t, hs)
128 | key = key.transpose(1, 2) # (b, nh, kv_t, hs)
129 | value = value.transpose(1, 2) # (b, nh, kv_t, hs)
130 |
131 | # manual implementation of attention
132 | query = query * self.scale
133 | attention_scores = query @ key.transpose(-2, -1)
134 |
135 | if self.causal:
136 | attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
137 |
138 | attention_probs = F.softmax(attention_scores, dim=-1)
139 | attention_probs = self.drop_weights(attention_probs)
140 | y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)
141 |
142 | y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs)
143 |
144 | y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side
145 |
146 | y = self.out_proj(y)
147 | y = self.drop_output(y)
148 | return y
149 |
--------------------------------------------------------------------------------
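For reference, a minimal sketch of exercising `SABlock` in both self- and cross-attention modes, assuming the block can be imported from this module; the tensors follow the `(batch, sequence length, hidden_size)` convention used in `forward`.

```python
import torch

from generative.networks.blocks.selfattention import SABlock  # import path assumed from this file's location

block = SABlock(hidden_size=64, num_heads=4, dropout_rate=0.0, with_cross_attention=True)

x = torch.randn(2, 16, 64)        # (batch, sequence length, hidden_size)
context = torch.randn(2, 10, 64)  # conditioning sequence, possibly with a different length

self_out = block(x)            # context defaults to x, i.e. plain self-attention
cross_out = block(x, context)  # keys/values are projected from `context`

assert self_out.shape == x.shape
assert cross_out.shape == x.shape
```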
/generative/metrics/ms_ssim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | from collections.abc import Sequence
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from monai.metrics.regression import RegressionMetric
19 | from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep
20 |
21 | from generative.metrics.ssim import compute_ssim_and_cs
22 |
23 |
24 | class KernelType(StrEnum):
25 | GAUSSIAN = "gaussian"
26 | UNIFORM = "uniform"
27 |
28 |
29 | class MultiScaleSSIMMetric(RegressionMetric):
30 | """
31 | Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).
32 |
33 | [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November.
34 | Multiscale structural similarity for image quality assessment.
35 | In The Thirty-Seventh Asilomar Conference on Signals, Systems
36 | & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee.
37 |
38 | Args:
39 |         spatial_dims: number of spatial dimensions of the input images.
40 |         data_range: value range of the input images (usually 1.0 or 255).
41 |         kernel_type: type of kernel, can be "gaussian" or "uniform".
42 |         kernel_size: size of the kernel.
43 |         kernel_sigma: standard deviation for the Gaussian kernel.
44 |         k1: stability constant used in the luminance denominator.
45 |         k2: stability constant used in the contrast denominator.
46 |         weights: exponents applied to the similarity and contrast sensitivity terms at each resolution scale.
47 | reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
48 | available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
49 | ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
50 | get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
51 | """
52 |
53 | def __init__(
54 | self,
55 | spatial_dims: int,
56 | data_range: float = 1.0,
57 | kernel_type: KernelType | str = KernelType.GAUSSIAN,
58 |         kernel_size: int | Sequence[int] = 11,
59 |         kernel_sigma: float | Sequence[float] = 1.5,
60 | k1: float = 0.01,
61 | k2: float = 0.03,
62 |         weights: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
63 | reduction: MetricReduction | str = MetricReduction.MEAN,
64 | get_not_nans: bool = False,
65 | ) -> None:
66 | super().__init__(reduction=reduction, get_not_nans=get_not_nans)
67 |
68 | self.spatial_dims = spatial_dims
69 | self.data_range = data_range
70 | self.kernel_type = kernel_type
71 |
72 | if not isinstance(kernel_size, Sequence):
73 | kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
74 | self.kernel_size = kernel_size
75 |
76 | if not isinstance(kernel_sigma, Sequence):
77 | kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims)
78 | self.kernel_sigma = kernel_sigma
79 |
80 | self.k1 = k1
81 | self.k2 = k2
82 | self.weights = weights
83 |
84 | def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
85 | """
86 | Args:
87 | y_pred: Predicted image.
88 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
89 | y: Reference image.
90 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
91 |
92 | Raises:
93 | ValueError: when `y_pred` is not a 2D or 3D image.
94 | """
95 | dims = y_pred.ndimension()
96 | if self.spatial_dims == 2 and dims != 4:
97 | raise ValueError(
98 | f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} "
99 | f"spatial dimensions, got {dims}."
100 | )
101 |
102 | if self.spatial_dims == 3 and dims != 5:
103 | raise ValueError(
104 |                 f"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}"
105 | f" spatial dimensions, got {dims}."
106 | )
107 |
108 |         # check that the image is large enough for the number of downsamplings and the kernel size
109 | weights_div = max(1, (len(self.weights) - 1)) ** 2
110 | y_pred_spatial_dims = y_pred.shape[2:]
111 | for i in range(len(y_pred_spatial_dims)):
112 | if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1:
113 | raise ValueError(
114 | f"For a given number of `weights` parameters {len(self.weights)} and kernel size "
115 |                         f"{self.kernel_size[i]}, each spatial dimension of the image must be larger than "
116 | f"{(self.kernel_size[i] - 1) * weights_div}."
117 | )
118 |
119 | weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float)
120 |
121 | avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d")
122 |
123 | multiscale_list: list[torch.Tensor] = []
124 | for _ in range(len(weights)):
125 | ssim, cs = compute_ssim_and_cs(
126 | y_pred=y_pred,
127 | y=y,
128 | spatial_dims=self.spatial_dims,
129 | data_range=self.data_range,
130 | kernel_type=self.kernel_type,
131 | kernel_size=self.kernel_size,
132 | kernel_sigma=self.kernel_sigma,
133 | k1=self.k1,
134 | k2=self.k2,
135 | )
136 |
137 | cs_per_batch = cs.view(cs.shape[0], -1).mean(1)
138 |
139 | multiscale_list.append(torch.relu(cs_per_batch))
140 | y_pred = avg_pool(y_pred, kernel_size=2)
141 | y = avg_pool(y, kernel_size=2)
142 |
143 | ssim = ssim.view(ssim.shape[0], -1).mean(1)
144 | multiscale_list[-1] = torch.relu(ssim)
145 | multiscale_list = torch.stack(multiscale_list)
146 |
147 | ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0)
148 |
149 | ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean(
150 | 1, keepdim=True
151 | )
152 |
153 | return ms_ssim_per_batch
154 |
--------------------------------------------------------------------------------
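A short sketch of using `MultiScaleSSIMMetric` on a 2D batch, assuming the class is re-exported from `generative.metrics`; the 256x256 image size is chosen so that the default five-scale weights and 11-pixel kernel pass the size check above.

```python
import torch

from generative.metrics import MultiScaleSSIMMetric  # assumed re-export via generative/metrics/__init__.py

metric = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0)

y_pred = torch.rand(4, 1, 256, 256)  # predicted batch, values in [0, 1] to match data_range
y = torch.rand(4, 1, 256, 256)       # reference batch

per_item = metric(y_pred, y)  # one MS-SSIM value per batch item, shape (4, 1)
print(metric.aggregate())     # reduced according to `reduction` ("mean" by default)
```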
/tests/test_integration_workflows_adversarial.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from __future__ import annotations
13 |
14 | import os
15 | import shutil
16 | import tempfile
17 | import unittest
18 | from glob import glob
19 |
20 | import monai
21 | import nibabel as nib
22 | import numpy as np
23 | import torch
24 | from monai.data import create_test_image_2d
25 | from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler
26 | from monai.networks.nets import AutoEncoder, Discriminator
27 | from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd
28 | from monai.utils import CommonKeys, set_determinism
29 |
30 | from generative.engines import AdversarialTrainer
31 | from generative.utils import AdversarialKeys as Keys
32 | from tests.utils import DistTestCase, TimedCall, skip_if_quick
33 |
34 |
35 | def run_training_test(root_dir, device="cuda:0"):
36 | learning_rate = 2e-4
37 | real_label = 1
38 | fake_label = 0
39 |
40 | real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
41 | train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)]
42 |
43 | # prepare real data
44 | train_transforms = Compose(
45 | [
46 | LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]),
47 | EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2),
48 | ScaleIntensityd(keys=[CommonKeys.IMAGE]),
49 | RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5),
50 | ]
51 | )
52 | train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
53 | train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
54 |
55 | # Create Discriminator
56 | discriminator_net = Discriminator(
57 | in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5
58 | ).to(device)
59 | discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate)
60 | discriminator_loss_criterion = torch.nn.BCELoss()
61 |
62 | def discriminator_loss(real_logits, fake_logits):
63 | real_target = real_logits.new_full((real_logits.shape[0], 1), real_label)
64 | fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label)
65 | real_loss = discriminator_loss_criterion(real_logits, real_target)
66 | fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target)
67 | return torch.div(torch.add(real_loss, fake_loss), 2)
68 |
69 | # Create Generator
70 | generator_network = AutoEncoder(
71 | spatial_dims=2,
72 | in_channels=1,
73 | out_channels=1,
74 | channels=(8, 16, 32, 64),
75 | strides=(2, 2, 2, 2),
76 | num_res_units=1,
77 | num_inter_units=1,
78 | )
79 | generator_network = generator_network.to(device)
80 | generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate)
81 | generator_loss_criterion = torch.nn.MSELoss()
82 |
83 | def reconstruction_loss(recon_images, real_images):
84 | return generator_loss_criterion(recon_images, real_images)
85 |
86 | def generator_loss(fake_logits):
87 | fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label)
88 |         adversarial_loss = discriminator_loss_criterion(fake_logits, fake_target)
89 |         return adversarial_loss
90 |
91 | key_train_metric = None
92 |
93 | train_handlers = [
94 | StatsHandler(
95 | name="training_loss",
96 | output_transform=lambda x: {
97 | Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],
98 | Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],
99 | Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],
100 | },
101 | ),
102 | TensorBoardStatsHandler(
103 | log_dir=root_dir,
104 | tag_name="training_loss",
105 | output_transform=lambda x: {
106 | Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],
107 | Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],
108 | Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],
109 | },
110 | ),
111 | CheckpointSaver(
112 | save_dir=root_dir,
113 | save_dict={"g_net": generator_network, "d_net": discriminator_net},
114 | save_interval=2,
115 | epoch_level=True,
116 | ),
117 | ]
118 |
119 | num_epochs = 5
120 |
121 | trainer = AdversarialTrainer(
122 | device=device,
123 | max_epochs=num_epochs,
124 | train_data_loader=train_loader,
125 | g_network=generator_network,
126 | g_optimizer=generator_optimiser,
127 | g_loss_function=generator_loss,
128 | recon_loss_function=reconstruction_loss,
129 | d_network=discriminator_net,
130 | d_optimizer=discriminator_opt,
131 | d_loss_function=discriminator_loss,
132 | non_blocking=True,
133 | key_train_metric=key_train_metric,
134 | train_handlers=train_handlers,
135 | )
136 | trainer.run()
137 |
138 | return trainer.state
139 |
140 |
141 | @skip_if_quick
142 | class IntegrationWorkflowsAdversarialTrainer(DistTestCase):
143 | def setUp(self):
144 | set_determinism(seed=0)
145 |
146 | self.data_dir = tempfile.mkdtemp()
147 | for i in range(40):
148 | im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1)
149 | n = nib.Nifti1Image(im, np.eye(4))
150 | nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
151 |
152 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
153 | monai.config.print_config()
154 |
155 | def tearDown(self):
156 | set_determinism(seed=None)
157 | shutil.rmtree(self.data_dir)
158 |
159 | @TimedCall(seconds=200, daemon=False)
160 | def test_training(self):
161 | torch.manual_seed(0)
162 |
163 | finish_state = run_training_test(self.data_dir, device=self.device)
164 |
165 | # Assert AdversarialTrainer training finished
166 | self.assertEqual(finish_state.iteration, 100)
167 | self.assertEqual(finish_state.epoch, 5)
168 |
169 |
170 | if __name__ == "__main__":
171 | unittest.main()
172 |
--------------------------------------------------------------------------------
/tutorials/README.md:
--------------------------------------------------------------------------------
1 | # MONAI Generative Models Tutorials
2 | This directory hosts the MONAI Generative Models tutorials.
3 |
4 | ## Requirements
5 | To run the tutorials, you will need to install the Generative Models package.
6 | In addition, most of the examples and tutorials require
7 | [matplotlib](https://matplotlib.org/) and [Jupyter Notebook](https://jupyter.org/).
8 |
9 | These can be installed with the following:
10 |
11 | ```bash
12 | python -m pip install -U pip
13 | python -m pip install -U matplotlib
14 | python -m pip install -U notebook
15 | ```
16 |
17 | Some of the examples may require optional dependencies. In case of any optional import errors,
18 | please install the relevant packages according to MONAI's [installation guide](https://docs.monai.io/en/latest/installation.html).
19 | Alternatively, you can install all optional requirements with the following:
20 |
21 | ```bash
22 | pip install -r requirements-dev.txt
23 | ```
24 |
25 | ## List of notebooks and examples
26 |
27 | ### Table of Contents
28 | 1. [Diffusion Models](#1-diffusion-models)
29 | 2. [Latent Diffusion Models](#2-latent-diffusion-models)
30 | 3. [VQ-VAE + Transformers](#3-vq-vae--transformers)
31 |
32 |
33 | ### 1. Diffusion Models
34 |
35 | #### Image synthesis with Diffusion Models
36 |
37 | * [Training a 3D Denoising Diffusion Probabilistic Model](./generative/3d_ddpm/3d_ddpm_tutorial.ipynb): This tutorial shows how to easily
38 | train a DDPM on 3D medical data. In this example, we use a downsampled version of the BraTS dataset. We will show how to
39 | make use of the UNet model and the Noise Scheduler necessary to train a diffusion model. Besides that, we show how to
40 | use the DiffusionInferer class to simplify the training and sampling processes. Finally, after training the model, we
41 | show how to use a Noise Scheduler with fewer timesteps to sample synthetic images.
42 |
43 | * [Training a 2D Denoising Diffusion Probabilistic Model](./generative/2d_ddpm/2d_ddpm_tutorial.ipynb): This tutorial shows how to easily
44 | train a DDPM on medical data. In this example, we use the MedNIST dataset, which makes it well suited to beginners.
45 |
46 | * [Comparing different noise schedulers](./generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb): In this tutorial, we compare the
47 | performance of different noise schedulers. We will show how to sample a diffusion model using the DDPM, DDIM, and PNDM
48 | schedulers and how different numbers of timesteps affect the quality of the samples.
49 |
50 | * [Training a 2D Denoising Diffusion Probabilistic Model with different parameterisation](./generative/2d_ddpm/2d_ddpm_tutorial_v_prediction.ipynb):
51 | In MONAI Generative Models, we support different parameterizations for the diffusion model (epsilon, sample, and
52 | v-prediction). In this tutorial, we show how to train a DDPM using the v-prediction parameterization, which improves the
53 | stability and convergence of the model.
54 |
55 | * [Training a 2D DDPM using PyTorch Ignite](./generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb): Here, we show how to train a DDPM
56 | on medical data using PyTorch Ignite. We show how to use DiffusionPrepareBatch to prepare the model inputs, and MONAI's SupervisedTrainer and SupervisedEvaluator to train and evaluate DDPMs.
57 |
58 | * [Using a 2D DDPM to inpaint images](./generative/2d_ddpm/2d_ddpm_inpainting.ipynb): In this tutorial, we show how to use a DDPM to
59 | inpaint 2D images from the MedNIST dataset using the RePaint method.
60 |
61 | * [Generating conditional samples with a 2D DDPM using classifier-free guidance](./generative/classifier_free_guidance/2d_ddpm_classifier_free_guidance_tutorial.ipynb):
62 | This tutorial shows how easily we can train a Diffusion Model and generate conditional samples using classifier-free guidance in
63 | MONAI's framework.
64 |
65 | * [Training Diffusion models with Distributed Data Parallel](./generative/distributed_training/ddpm_training_ddp.py): This example shows how to execute distributed training and evaluation based on PyTorch's native DistributedDataParallel
66 | module with `torch.distributed.launch`.
67 |
68 | #### Anomaly Detection with Diffusion Models
69 |
70 | * [Weakly Supervised Anomaly Detection with Implicit Guidance](./generative/anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb):
71 | This tutorial shows how to use a DDPM to perform weakly supervised anomaly detection using classifier-free (implicit) guidance based on the
72 | method proposed by Sanchez et al. [What is Healthy? Generative Counterfactual Diffusion for Lesion Localization](https://arxiv.org/abs/2207.12268). DGM 4 MICCAI 2022
73 |
74 |
75 | ### 2. Latent Diffusion Models
76 |
77 | #### Image synthesis with Latent Diffusion Models
78 |
79 | * [Training a 3D Latent Diffusion Model](./generative/3d_ldm/3d_ldm_tutorial.ipynb): This tutorial shows how to train an LDM on 3D medical
80 | data. In this example, we use the BraTS dataset. We show how to train an AutoencoderKL and connect it to an LDM. We also
81 | discuss the importance of the scaling factor used to bring the latent representation of the AutoencoderKL into a range suitable
82 | for the diffusion model. Finally, we show how to use the LatentDiffusionInferer class to simplify training and sampling.
83 |
84 | * [Training a 2D Latent Diffusion Model](./generative/2d_ldm/2d_ldm_tutorial.ipynb): This tutorial shows how to train an LDM on medical
85 | images from the MedNIST dataset. We show how to train an AutoencoderKL and connect it to an LDM.
86 |
87 | * Training Autoencoders with KL-regularization: In this section, we focus on training an AutoencoderKL on [2D](./generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb) and [3D](./generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb) medical data
88 | that can be used as the compression model in a Latent Diffusion Model.
89 |
90 | #### Super-resolution with Latent Diffusion Models
91 |
92 | * [Super-resolution using Stable Diffusion Upscalers method](./generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution.ipynb):
93 | In this tutorial, we show how to perform super-resolution on 2D images from the MedNIST dataset using the Stable
94 | Diffusion Upscalers method. In this example, we will show how to condition a latent diffusion model on a low-resolution image
95 | as well as how to use the DiffusionModelUNet's class_labels conditioning to condition the model on the level of noise added to the image
96 | (also known as "noise conditioning augmentation").
97 |
98 |
99 | ### 3. VQ-VAE + Transformers
100 |
101 | #### Image synthesis with VQ-VAE + Transformers
102 |
103 | * [Training a 2D VQ-VAE + Autoregressive Transformers](./generative/2d_vqvae_transformer/2d_vqvae_transformer_tutorial.ipynb): This tutorial shows how to train
104 | a Vector-Quantized Variational Autoencoder + Transformers on the MedNIST dataset.
105 |
106 | * Training VQ-VAEs and VQ-GANs: In this section, we show how to train Vector-Quantized Variational Autoencoders (on [2D](./generative/2d_vqvae/2d_vqvae_tutorial.ipynb) and [3D](./generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb) data) and
107 | show how to use the PatchDiscriminator class to train a [VQ-GAN](./generative/2d_vqgan/2d_vqgan_tutorial.ipynb) and improve the quality of the generated images.
108 |
109 | #### Anomaly Detection with VQ-VAE + Transformers
110 |
111 | * [Anomaly Detection with 2D VQ-VAE + Autoregressive Transformers](./generative/anomaly_detection/anomaly_detection_with_transformers.ipynb): This tutorial shows how to
112 | train a Vector-Quantized Variational Autoencoder + Transformers on the MedNIST dataset and use it to estimate the likelihood that
113 | test images belong to the in-distribution class (seen during training).
114 |
--------------------------------------------------------------------------------