├── img ├── View.png ├── Model.png └── Strategy.png ├── mask ├── mask_4.pt ├── mask_6.pt ├── mask_8.pt ├── mask_10.pt └── mask.py ├── scripts ├── train_vae.sh ├── eval.sh ├── train_sd.sh └── train_controlnet.sh ├── utils ├── __pycache__ │ ├── common.cpython-38.pyc │ ├── common.cpython-39.pyc │ ├── common.cpython-310.pyc │ ├── convert_time.cpython-310.pyc │ ├── visual_stuff.cpython-38.pyc │ ├── visual_stuff.cpython-39.pyc │ ├── Class_dataloader.cpython-39.pyc │ ├── Condition_dataloader.cpython-310.pyc │ ├── Condition_dataloader.cpython-38.pyc │ ├── Condition_dataloader.cpython-39.pyc │ └── Condition_aug_dataloader.cpython-310.pyc ├── visual_stuff.py ├── common.py └── Condition_aug_dataloader.py ├── config ├── __pycache__ │ ├── con_vae.cpython-310.pyc │ ├── sd_config_512.cpython-310.pyc │ ├── vae_config_512.cpython-310.pyc │ └── controlnet_config.cpython-310.pyc ├── vae │ ├── __pycache__ │ │ ├── con_vae.cpython-310.pyc │ │ ├── config_vae_zheer.cpython-310.pyc │ │ └── config_monaivae_zheer.cpython-310.pyc │ ├── config_monaivae_zheer.py │ ├── vae_config_512.py │ ├── config_vae_zheer.py │ └── config_con_vae.py ├── diffusion │ ├── __pycache__ │ │ ├── sd_config_512.cpython-310.pyc │ │ ├── controlnet_config.cpython-310.pyc │ │ ├── config_mix_controlnet.cpython-310.pyc │ │ ├── config_mix_diffusion.cpython-310.pyc │ │ └── config_zheer_controlnet.cpython-310.pyc │ └── config_controlnet.py ├── cvae │ ├── __pycache__ │ │ └── config_cvae_controlnet.cpython-310.pyc │ └── config_cvae_controlnet.py └── 512_vae_config.yaml ├── unet ├── __pycache__ │ ├── MC_model.cpython-310.pyc │ └── Model_UKAN_Hybrid.cpython-310.pyc └── MC_model.py ├── generative ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── version.cpython-310.pyc ├── utils │ ├── __pycache__ │ │ ├── misc.cpython-310.pyc │ │ ├── enums.cpython-310.pyc │ │ ├── __init__.cpython-310.pyc │ │ └── component_store.cpython-310.pyc │ ├── __init__.py │ ├── misc.py │ ├── enums.py │ ├── component_store.py │ └── ordering.py ├── losses │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── perceptual.cpython-310.pyc │ │ ├── spectral_loss.cpython-310.pyc │ │ └── adversarial_loss.cpython-310.pyc │ ├── __init__.py │ ├── spectral_loss.py │ └── adversarial_loss.py ├── inferers │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── inferer.cpython-310.pyc │ └── __init__.py ├── networks │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ ├── nets │ │ ├── __pycache__ │ │ │ ├── vqvae.cpython-310.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── controlnet.cpython-310.pyc │ │ │ ├── transformer.cpython-310.pyc │ │ │ ├── autoencoderkl.cpython-310.pyc │ │ │ ├── spade_network.cpython-310.pyc │ │ │ ├── spade_autoencoderkl.cpython-310.pyc │ │ │ ├── diffusion_model_unet.cpython-310.pyc │ │ │ ├── patchgan_discriminator.cpython-310.pyc │ │ │ └── spade_diffusion_model_unet.cpython-310.pyc │ │ ├── __init__.py │ │ ├── transformer.py │ │ └── patchgan_discriminator.py │ ├── blocks │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── spade_norm.cpython-310.pyc │ │ │ ├── selfattention.cpython-310.pyc │ │ │ ├── encoder_modules.cpython-310.pyc │ │ │ └── transformerblock.cpython-310.pyc │ │ ├── __init__.py │ │ ├── encoder_modules.py │ │ ├── spade_norm.py │ │ ├── transformerblock.py │ │ └── selfattention.py │ ├── layers │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── vector_quantizer.cpython-310.pyc │ │ ├── __init__.py │ │ └── vector_quantizer.py │ ├── schedulers │ │ ├── __pycache__ │ │ │ ├── ddim.cpython-310.pyc │ │ │ ├── ddpm.cpython-310.pyc │ │ │ ├── 
pndm.cpython-310.pyc │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── scheduler.cpython-310.pyc │ │ ├── __init__.py │ │ ├── scheduler.py │ │ └── ddpm.py │ └── __init__.py ├── version.py ├── __init__.py ├── engines │ ├── __init__.py │ └── prepare_batch.py └── metrics │ ├── __init__.py │ ├── mmd.py │ ├── fid.py │ ├── ms_ssim.py │ └── ssim.py ├── my_vqvae ├── __pycache__ │ ├── conditional_vae.cpython-310.pyc │ └── conditional_encoder.cpython-310.pyc └── train_vae.py ├── requirements.txt ├── README.md └── stable_diffusion ├── val_model.py ├── train_sd.py └── trian_model.py /img/View.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/img/View.png -------------------------------------------------------------------------------- /img/Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/img/Model.png -------------------------------------------------------------------------------- /mask/mask_4.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/mask/mask_4.pt -------------------------------------------------------------------------------- /mask/mask_6.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/mask/mask_6.pt -------------------------------------------------------------------------------- /mask/mask_8.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/mask/mask_8.pt -------------------------------------------------------------------------------- /img/Strategy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/img/Strategy.png -------------------------------------------------------------------------------- /mask/mask_10.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/mask/mask_10.pt -------------------------------------------------------------------------------- /scripts/train_vae.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | CUDA_VISIBLE_DEVICES="0" nohup python my_vqvae/train_cvae.py > train.log 2>&1 & -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | CUDA_VISIBLE_DEVICES="3" nohup python evaluation/controlnet_eval.py > train_sd.log 2>&1 & -------------------------------------------------------------------------------- /scripts/train_sd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | CUDA_VISIBLE_DEVICES="0" nohup python stable_diffusion/zheer_sd.py > train_sd.log 2>&1 & -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /config/__pycache__/con_vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/__pycache__/con_vae.cpython-310.pyc -------------------------------------------------------------------------------- /unet/__pycache__/MC_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/unet/__pycache__/MC_model.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /config/vae/__pycache__/con_vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/vae/__pycache__/con_vae.cpython-310.pyc -------------------------------------------------------------------------------- /generative/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/__pycache__/version.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/__pycache__/version.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/convert_time.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/convert_time.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual_stuff.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/visual_stuff.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visual_stuff.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/visual_stuff.cpython-39.pyc -------------------------------------------------------------------------------- /config/__pycache__/sd_config_512.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/__pycache__/sd_config_512.cpython-310.pyc -------------------------------------------------------------------------------- /config/__pycache__/vae_config_512.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/__pycache__/vae_config_512.cpython-310.pyc -------------------------------------------------------------------------------- /generative/utils/__pycache__/misc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/utils/__pycache__/misc.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/train_controlnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | CUDA_VISIBLE_DEVICES="2" nohup python stable_diffusion/zheer_late_controlnet.py > train_diff_loss.log 2>&1 & -------------------------------------------------------------------------------- /utils/__pycache__/Class_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/Class_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /config/__pycache__/controlnet_config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/__pycache__/controlnet_config.cpython-310.pyc -------------------------------------------------------------------------------- /generative/utils/__pycache__/enums.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/utils/__pycache__/enums.cpython-310.pyc -------------------------------------------------------------------------------- /my_vqvae/__pycache__/conditional_vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/my_vqvae/__pycache__/conditional_vae.cpython-310.pyc -------------------------------------------------------------------------------- /unet/__pycache__/Model_UKAN_Hybrid.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/unet/__pycache__/Model_UKAN_Hybrid.cpython-310.pyc -------------------------------------------------------------------------------- /generative/losses/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/losses/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Condition_dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/Condition_dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Condition_dataloader.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/Condition_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Condition_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/Condition_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /config/vae/__pycache__/config_vae_zheer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/vae/__pycache__/config_vae_zheer.cpython-310.pyc -------------------------------------------------------------------------------- /generative/inferers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/inferers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/inferers/__pycache__/inferer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/inferers/__pycache__/inferer.cpython-310.pyc -------------------------------------------------------------------------------- /generative/losses/__pycache__/perceptual.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/losses/__pycache__/perceptual.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /my_vqvae/__pycache__/conditional_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/my_vqvae/__pycache__/conditional_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /config/diffusion/__pycache__/sd_config_512.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/diffusion/__pycache__/sd_config_512.cpython-310.pyc -------------------------------------------------------------------------------- /generative/losses/__pycache__/spectral_loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/losses/__pycache__/spectral_loss.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/vqvae.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/vqvae.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/Condition_aug_dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/utils/__pycache__/Condition_aug_dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /config/cvae/__pycache__/config_cvae_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/cvae/__pycache__/config_cvae_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /config/diffusion/__pycache__/controlnet_config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/diffusion/__pycache__/controlnet_config.cpython-310.pyc -------------------------------------------------------------------------------- /config/vae/__pycache__/config_monaivae_zheer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/vae/__pycache__/config_monaivae_zheer.cpython-310.pyc -------------------------------------------------------------------------------- /generative/losses/__pycache__/adversarial_loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/losses/__pycache__/adversarial_loss.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/utils/__pycache__/component_store.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/utils/__pycache__/component_store.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/blocks/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/blocks/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/layers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/layers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/transformer.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/schedulers/__pycache__/ddim.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/schedulers/__pycache__/ddim.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/schedulers/__pycache__/ddpm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/schedulers/__pycache__/ddpm.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/schedulers/__pycache__/pndm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/schedulers/__pycache__/pndm.cpython-310.pyc -------------------------------------------------------------------------------- /config/diffusion/__pycache__/config_mix_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/diffusion/__pycache__/config_mix_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /config/diffusion/__pycache__/config_mix_diffusion.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/diffusion/__pycache__/config_mix_diffusion.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/blocks/__pycache__/spade_norm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/blocks/__pycache__/spade_norm.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/autoencoderkl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/autoencoderkl.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/spade_network.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/spade_network.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/schedulers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/schedulers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /config/diffusion/__pycache__/config_zheer_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/config/diffusion/__pycache__/config_zheer_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/blocks/__pycache__/selfattention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/blocks/__pycache__/selfattention.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/schedulers/__pycache__/scheduler.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/schedulers/__pycache__/scheduler.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/blocks/__pycache__/encoder_modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/blocks/__pycache__/encoder_modules.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/blocks/__pycache__/transformerblock.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/blocks/__pycache__/transformerblock.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/layers/__pycache__/vector_quantizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/layers/__pycache__/vector_quantizer.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/spade_autoencoderkl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/spade_autoencoderkl.cpython-310.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-fid==0.30.0 2 | torch==2.3.0 3 | torchvision==0.18.0 4 | tqdm 5 | timm==0.9.16 6 | scikit-image==0.23.1 7 | opencv-python==4.10.0.84 8 | accelerate==0.32.1 9 | numpy==1.26.4 -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/diffusion_model_unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/diffusion_model_unet.cpython-310.pyc -------------------------------------------------------------------------------- 
/generative/networks/nets/__pycache__/patchgan_discriminator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/patchgan_discriminator.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/nets/__pycache__/spade_diffusion_model_unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lcbkmm/TC-KANRecon/HEAD/generative/networks/nets/__pycache__/spade_diffusion_model_unet.cpython-310.pyc -------------------------------------------------------------------------------- /generative/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | -------------------------------------------------------------------------------- /generative/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | __version__ = "0.2.3" 15 | -------------------------------------------------------------------------------- /generative/networks/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from .vector_quantizer import EMAQuantizer, VectorQuantizer 13 | -------------------------------------------------------------------------------- /generative/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .version import __version__ 15 | -------------------------------------------------------------------------------- /generative/engines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch 15 | from .trainer import AdversarialTrainer 16 | -------------------------------------------------------------------------------- /generative/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .adversarial_loss import PatchAdversarialLoss 15 | from .perceptual import PerceptualLoss 16 | from .spectral_loss import JukeboxLoss 17 | -------------------------------------------------------------------------------- /generative/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .fid import FIDMetric 15 | from .mmd import MMDMetric 16 | from .ms_ssim import MultiScaleSSIMMetric 17 | from .ssim import SSIMMetric 18 | -------------------------------------------------------------------------------- /generative/networks/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .encoder_modules import SpatialRescaler 15 | from .selfattention import SABlock 16 | from .transformerblock import TransformerBlock 17 | -------------------------------------------------------------------------------- /generative/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .component_store import ComponentStore 15 | from .enums import AdversarialIterationEvents, AdversarialKeys 16 | from .misc import unsqueeze_left, unsqueeze_right 17 | -------------------------------------------------------------------------------- /generative/networks/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from .ddim import DDIMScheduler 15 | from .ddpm import DDPMScheduler 16 | from .pndm import PNDMScheduler 17 | from .scheduler import NoiseSchedules, Scheduler 18 | -------------------------------------------------------------------------------- /generative/inferers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from .inferer import ( 15 | ControlNetDiffusionInferer, 16 | ControlNetLatentDiffusionInferer, 17 | DiffusionInferer, 18 | LatentDiffusionInferer, 19 | VQVAETransformerInferer, 20 | ) 21 | -------------------------------------------------------------------------------- /config/vae/config_monaivae_zheer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 8 7 | eval_bc = 8 8 | num_epochs = 120 9 | data_path = 'mridata/firstMRI/trainImg' 10 | eval_path = 'mridata/firstMRI/trainImg' 11 | single_channel = True 12 | mode = 'double' 13 | val_length = 2000 14 | 15 | # train process configuration 16 | val_inter = 3 17 | save_inter = 5 18 | sample_size = (256, 256) 19 | 20 | # model parameters 21 | in_channels = 1 22 | out_channels = 1 23 | up_and_down = (128, 256, 512) 24 | num_res_layers = 2 25 | vae_path = '' 26 | dis_path = '' 27 | autoencoder_warm_up_n_epochs = 100 28 | 29 | # mask 30 | mask_path = 'mask/mask_4.pt' 31 | # accelerate config 32 | split_batches = False 33 | mixed_precision = 'fp16' 34 | log_with = 'tensorboard' 35 | project_dir = 'weights/vae/exp_7_20' 36 | gradient_accumulation_steps = 1 37 | -------------------------------------------------------------------------------- /config/vae/vae_config_512.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 8 7 | eval_bc = 8 8 | num_epochs = 120 9 | data_path = '/home/wangchangmiao/yuxiao/MRIRecon/mridata/firstMRI/trainImg' 10 | eval_path = '/home/wangchangmiao/yuxiao/MRIRecon/mridata/firstMRI/trainImg' 11 | single_channel = True 12 | mode = 'double' 13 | val_length = 2000 14 | 15 | # train process configuration 16 | val_inter = 2 17 | save_inter = 10 18 | sample_size = 256 19 | 20 | # model parameters 21 | in_channels = 1 22 | out_channels = 1 23 | up_and_down = (128, 256, 512) 24 | num_res_layers = 2 25 | num_embeddings = 256 26 | autoencoder_warm_up_n_epochs = 0 27 | 28 | # mask 29 | mask_path = '/home/wangchangmiao/yuxiao/TC-UKANRecon/mask/mask_4.pt' 30 | # accelerate config 31 | split_batches = False 32 | mixed_precision = 'fp16' 33 | log_with = 'tensorboard' 34 | project_dir = 'weights/vae/exp_7_15' 35 | gradient_accumulation_steps = 1 36 | 
-------------------------------------------------------------------------------- /config/vae/config_vae_zheer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 1 7 | eval_bc = 8 8 | num_epochs = 2000 9 | mode = 'double' 10 | data_path = '/home/chenyifei/data/CYFData/FZJ/308_1.06/color/train' 11 | eval_path = '/home/chenyifei/data/CYFData/FZJ/308_1.06/color/test' 12 | single_channel = True 13 | mode = 'double' 14 | val_length = 2000 15 | 16 | # train process configuration 17 | val_inter = 120 18 | save_inter = 200 19 | sample_size = 512 20 | 21 | # model parameters 22 | in_channels = 1 23 | out_channels = 1 24 | condition_channels=3 25 | up_and_down = (128, 256, 512) 26 | num_res_layers = 2 27 | num_embeddings = 512 28 | 29 | resume_path = '/home/chenyifei/data/CYFData/FZJ/mri-autoencoder-v0.1' 30 | 31 | autoencoder_warm_up_n_epochs = 0 32 | # accelerate config 33 | split_batches = False 34 | mixed_precision = 'fp16' 35 | log_with = 'tensorboard' 36 | project_dir = 'weights/exp_6_5' 37 | gradient_accumulation_steps = 1 38 | -------------------------------------------------------------------------------- /generative/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from typing import TypeVar 15 | 16 | T = TypeVar("T") 17 | 18 | 19 | def unsqueeze_right(arr: T, ndim: int) -> T: 20 | """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" 21 | return arr[(...,) + (None,) * (ndim - arr.ndim)] 22 | 23 | 24 | def unsqueeze_left(arr: T, ndim: int) -> T: 25 | """Preppend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" 26 | return arr[(None,) * (ndim - arr.ndim)] 27 | -------------------------------------------------------------------------------- /generative/networks/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from .autoencoderkl import AutoencoderKL 15 | from .controlnet import ControlNet 16 | from .diffusion_model_unet import DiffusionModelUNet 17 | from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator 18 | from .spade_autoencoderkl import SPADEAutoencoderKL 19 | from .spade_diffusion_model_unet import SPADEDiffusionModelUNet 20 | from .spade_network import SPADENet 21 | from .transformer import DecoderOnlyTransformer 22 | from .vqvae import VQVAE 23 | -------------------------------------------------------------------------------- /config/512_vae_config.yaml: -------------------------------------------------------------------------------- 1 | # base configuration 2 | train_bc: 4 3 | eval_bc: 8 4 | num_epochs: 200 5 | data_path: ['/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/data_sloeaffa', 6 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/data_slolaffa', 7 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/early_koufu_ea', 8 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/early_koufu_la'] 9 | eval_path: '/mntcephfs/lab_data/wangcm/fzj/adv-Diff/dataset/308_12_26/grey/train' 10 | mode: 'double' 11 | val_length: 2000 12 | 13 | # train process configuration 14 | val_inter: 4 15 | save_inter: 5 16 | 17 | # model configuration 18 | in_channels: 3 19 | out_channels: 3 20 | block_out_channels: [ 21 | 128, 22 | 256, 23 | 512 24 | ] 25 | down_block_types: [ 26 | "DownEncoderBlock2D", 27 | "DownEncoderBlock2D", 28 | "DownEncoderBlock2D", 29 | ] 30 | up_block_types: [ 31 | "UpDecoderBlock2D", 32 | "UpDecoderBlock2D", 33 | "UpDecoderBlock2D", 34 | ] 35 | latent_channels: 4 36 | layers_per_block: 2 37 | sample_size: 512 38 | resume_path: '' 39 | 40 | # accelerate config 41 | split_batches: False 42 | mixed_precision: 'fp16' 43 | log_with: 'tensorboard' 44 | project_dir: 'weights/exp_2_13' 45 | gradient_accumulation_steps: 1 -------------------------------------------------------------------------------- /config/vae/config_con_vae.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 16 7 | eval_bc = 16 8 | num_epochs = 200 9 | data_path = ['/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/data_sloeaffa', 10 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/data_slolaffa', 11 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/early_koufu_ea', 12 | '/mntcephfs/lab_data/wangcm/fzj/advanced_VT/dataset/early_koufu_la'] 13 | eval_path = '/mntcephfs/lab_data/wangcm/fzj/adv-Diff/dataset/308_12_26/grey/train' 14 | single_channel = True 15 | mode = 'double' 16 | val_length = 2000 17 | 18 | # train process configuration 19 | val_inter = 2 20 | save_inter = 10 21 | sample_size = 512 22 | 23 | # model parameters 24 | in_channels = 1 25 | out_channels = 1 26 | condition_channels=3 27 | up_and_down = (128, 256, 512) 28 | num_res_layers = 2 29 | num_embeddings = 512 30 | 31 | out_vae_path = 'weights/exp_5_3/gen_save/vqgan.pth' 32 | in_vae_path = 'weights/exp_5_2/gen_save/vqgan.pth' 33 | 34 | autoencoder_warm_up_n_epochs = 0 35 | # accelerate config 36 | split_batches = False 37 | mixed_precision = 'fp16' 38 | log_with = 'tensorboard' 39 | project_dir = 'weights/exp_5_9' 40 | gradient_accumulation_steps = 1 41 | -------------------------------------------------------------------------------- /config/diffusion/config_controlnet.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 16 7 | eval_bc = 16 8 | num_epochs = 120 9 | data_path = '' 10 | eval_path = '' 11 | 12 | single_channel = False 13 | mode = 'double' 14 | val_length = 2000 15 | 16 | # train process configuration 17 | val_inter = 3 18 | save_inter = 3 19 | sample_size = 256 20 | 21 | # aekl parameters 22 | in_channels = 1 23 | out_channels = 1 24 | up_and_down = (128, 256, 512) 25 | num_res_layers = 2 26 | scaling_factor = 0.18215 27 | vae_resume_path = '' 28 | 29 | 30 | # stable_model parameters 31 | sd_num_channels = (128, 256, 512, 1024) 32 | attention_levels = (False, False, True, True) 33 | sd_resume_path = '' 34 | controlnet_path = '' 35 | mc_model_path = '' 36 | 37 | # controlnet_model parameters 38 | conditioning_embedding_num_channels = (32, 96, 256) 39 | diff_loss_coefficient = 0.25 40 | offset_noise = False 41 | 42 | # mask 43 | mask_path = 'mask/mask_4.pt' 44 | 45 | # scheduler 46 | beta_start = 0.0008 47 | beta_end = 0.02 48 | beta_schedule = "squaredcos_cap_v2" 49 | clip_sample = True 50 | initial_clip_sample_range = 1.8 51 | clip_rate = 0.0014 52 | 53 | # accelerate config 54 | split_batches = False 55 | mixed_precision = 'fp16' 56 | log_with = 'tensorboard' 57 | project_dir = '' 58 | -------------------------------------------------------------------------------- /config/cvae/config_cvae_controlnet.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config(): 6 | train_bc = 8 7 | eval_bc = 8 8 | num_epochs = 1000 9 | data_path = '/home/chenyifei/data/CYFData/FZJ/308_1.06/color/train' 10 | eval_path = '/home/chenyifei/data/CYFData/FZJ/308_1.06/color/test' 11 | single_channel = False 12 | mode = 'double' 13 | val_length = 2000 14 | 15 | # train process configuration 16 | val_inter = 80 17 | save_inter = 100 18 | sample_size = 512 19 | 20 | # vqgan parameters 21 | in_channels = 1 22 | out_channels = 1 23 | up_and_down = (128, 256, 512) 24 | num_res_layers = 2 25 | num_embeddings = 512 26 | input_path = '/home/chenyifei/data/CYFData/FZJ/weights_of_models/weights_dif_series/exp_2_26/input_save' 27 | output_path = '/home/chenyifei/data/CYFData/FZJ/weights_of_models/weights_dif_series/exp_2_26/output_save' 28 | 29 | 30 | # stable_model parameters 31 | sd_num_channels = (128, 256, 512, 1024) 32 | attention_levels = (False, True, True, True) 33 | sd_resume_path = 'weights/exp_5_11/model_save/model.pth' 34 | controlnet_path = 'weights/exp_5_6/model_save/model.pth' 35 | 36 | # controlnet_model parameters 37 | 38 | 39 | 40 | autoencoder_warm_up_n_epochs = 0 41 | # accelerate config 42 | split_batches = False 43 | mixed_precision = 'fp16' 44 | log_with = 'tensorboard' 45 | project_dir = 'weights/exp_6_10' 46 | gradient_accumulation_steps = 1 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TC-KANRecon 2 | 3 | Overall structure of TC-KANRecon: 4 | 5 | ![model](img/Model.png) 6 | 7 | Workflow of the dynamic clipping strategy: 8 | 9 | ![Strategy](img/Strategy.png) 10 | 11 | Detailed comparison of model-generated results (AF=4): 12 | 13 | ![renderings](img/View.png) 14 | 15 | ## 1. Installation 16 | 17 | Clone this repository and navigate to it in your terminal.
Then run: 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## 2. Data Preparation 24 | 25 | The two datasets we used are both public. The fastMRI dataset (referred to as `firstMRI` in our data paths) is available at [Link](https://fastmri.med.nyu.edu/) and includes 1,172 subjects with more than 41,020 slices; the SKM-TEA dataset is available at [Link](https://stanfordaimi.azurewebsites.net/datasets/4aaeafb9-c6e6-4e3c-9188-3aaaf0e0a9e7) and includes 155 subjects with more than 24,800 slices. For both datasets, we use the single-coil knee data. 26 | 27 | Once your dataset is ready, update the dataset paths in the configuration files below: 28 | 29 | - for vae: 30 | ``` 31 | config/vae/config_monaivae_zheer.py 32 | ``` 33 | - for model: 34 | ``` 35 | config/diffusion/config_controlnet.py 36 | ``` 37 | 38 | ## 3. Training 39 | - for vae: 40 | 41 | ``` 42 | python my_vqvae/train_vae.py 43 | 44 | ``` 45 | - for model: 46 | 47 | ``` 48 | python stable_diffusion/train_sd.py 49 | python stable_diffusion/trian_model.py 50 | 51 | ``` 52 | 53 | ## 4. Evaluation 54 | 55 | ``` 56 | python stable_diffusion/val_model.py 57 | 58 | ``` 59 | 60 | ## 5. Citation 61 | ``` 62 | @article{ge2024tc, 63 | title={TC-KANRecon: High-Quality and Accelerated MRI Reconstruction via Adaptive KAN Mechanisms and Intelligent Feature Scaling}, 64 | author={Ge, Ruiquan and Yu, Xiao and Chen, Yifei and Jia, Fan and Zhu, Shenghao and Zhou, Guanyu and Huang, Yiyu and Zhang, Chenyan and Zeng, Dong and Wang, Changmiao and others}, 65 | journal={arXiv preprint arXiv:2408.05705}, 66 | year={2024} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /utils/visual_stuff.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.tensorboard import SummaryWriter 3 | import torch as th 4 | import os 5 | class Visualizer: 6 | def __init__(self, weights_up_dir, image_size, model, diffusion): 7 | weights_up_dir = self.check_dir(weights_up_dir) 8 | self.recorder = SummaryWriter(weights_up_dir) 9 | self.model = model 10 | self.diffusion = diffusion 11 | self.image_title = 'Default' 12 | if isinstance(image_size, int): 13 | self.image_size = (image_size, image_size) 14 | else: 15 | self.image_size = image_size 16 | 17 | def check_dir(self, dir): 18 | if not os.path.exists(dir): 19 | os.makedirs(dir) 20 | return dir 21 | 22 | def inverse_transform(self, tensor): 23 | return (tensor + 1) / 2 24 | 25 | def performance_display(self, val_iter, step): 26 | batch, cond = next(val_iter) 27 | sample = self.diffusion.p_sample_loop( 28 | self.model, 29 | (1, 3, self.image_size[0], self.image_size[1]), 30 | clip_denoised=True, 31 | model_kwargs=cond, 32 | device=batch.device 33 | ) 34 | sample = th.cat([cond['condition'], sample, batch], dim=0) 35 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 36 | # sample = sample.permute(0, 2, 3, 1) 37 | sample = sample.contiguous() 38 | self.recorder.add_images('p_result', sample.cpu().detach(), step) 39 | self.recorder.flush() 40 | 41 | def display_single(self, tensor, step, normalize=True): 42 | ''' 43 | tensor range: (-1, 1) if normalize=True else (0, 1) 44 | step: milestone you want to show in tensorboard 45 | ''' 46 | if normalize: 47 | tensor = self.inverse_transform(tensor) 48 | self.recorder.add_images(self.image_title, tensor.cpu().detach(), step) 49 | self.recorder.flush() 50 | 51 | def tb_draw_scalars(self, value, which_epoch): 52 | self.recorder.add_scalar('Rescale_loss',
value, which_epoch) 53 | self.recorder.flush() -------------------------------------------------------------------------------- /generative/utils/enums.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from typing import TYPE_CHECKING 15 | 16 | from monai.config import IgniteInfo 17 | from monai.utils import StrEnum, min_version, optional_import 18 | 19 | if TYPE_CHECKING: 20 | from ignite.engine import EventEnum 21 | else: 22 | EventEnum, _ = optional_import( 23 | "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base" 24 | ) 25 | 26 | 27 | class AdversarialKeys(StrEnum): 28 | REALS = "reals" 29 | REAL_LOGITS = "real_logits" 30 | FAKES = "fakes" 31 | FAKE_LOGITS = "fake_logits" 32 | RECONSTRUCTION_LOSS = "reconstruction_loss" 33 | GENERATOR_LOSS = "generator_loss" 34 | DISCRIMINATOR_LOSS = "discriminator_loss" 35 | 36 | 37 | class AdversarialIterationEvents(EventEnum): 38 | RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed" 39 | GENERATOR_FORWARD_COMPLETED = "generator_forward_completed" 40 | GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed" 41 | GENERATOR_LOSS_COMPLETED = "generator_loss_completed" 42 | GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed" 43 | GENERATOR_MODEL_COMPLETED = "generator_model_completed" 44 | DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed" 45 | DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed" 46 | DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed" 47 | DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed" 48 | DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed" 49 | 50 | 51 | class OrderingType(StrEnum): 52 | RASTER_SCAN = "raster_scan" 53 | S_CURVE = "s_curve" 54 | RANDOM = "random" 55 | 56 | 57 | class OrderingTransformations(StrEnum): 58 | ROTATE_90 = "rotate_90" 59 | TRANSPOSE = "transpose" 60 | REFLECT = "reflect" 61 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | import inspect 4 | from shutil import copyfile, copy 5 | import os 6 | import torch 7 | from torchvision.utils import make_grid 8 | 9 | def one_to_three(x): 10 | return torch.cat([x]*3, dim=1) 11 | 12 | 13 | def save_config_to_yaml(config_obj, parrent_dir: str): 14 | """ 15 | Saves the given Config object as a YAML file. The output file name is derived 16 | from the module name where the Config class is defined. 17 | 18 | Args: 19 | config_obj (Config): The Config object to be saved. 20 | parrent_dir (str): Directory in which the generated YAML file is saved.
21 | """ 22 | os.makedirs(parrent_dir, exist_ok=True) 23 | # Get the source file path for the specified module 24 | module_path = inspect.getfile(config_obj) 25 | 26 | # Extract the base file name without extension 27 | base_filename = os.path.splitext(os.path.basename(module_path))[0] 28 | 29 | # Construct the output YAML file name 30 | output_file = f"{base_filename}.yaml" 31 | output_file = os.path.join(parrent_dir, output_file) 32 | 33 | # Convert the Config object to a dictionary 34 | config_dict = config_obj.__dict__ 35 | config_dict = {k: v for k, v in config_obj.__dict__.items() if not k.startswith('__')} 36 | # Save the dictionary as a YAML file 37 | with open(output_file, 'w') as yaml_file: 38 | yaml.dump(config_dict, yaml_file, sort_keys=False) 39 | 40 | return config_dict 41 | 42 | 43 | def copy_yaml_to_folder(yaml_file, folder): 44 | """ 45 | Copy a YAML file into a folder. 46 | :param yaml_file: path of the YAML file 47 | :param folder: path of the target folder 48 | """ 49 | # Make sure the target folder exists 50 | os.makedirs(folder, exist_ok=True) 51 | 52 | # Get the file name of the YAML file 53 | file_name = os.path.basename(yaml_file) 54 | 55 | # Copy the YAML file into the target folder 56 | copy(yaml_file, os.path.join(folder, file_name)) 57 | 58 | def force_remove_empty_dir(path): 59 | try: 60 | os.rmdir(path) 61 | print(f"Directory '{path}' removed successfully.") 62 | except FileNotFoundError: 63 | print(f"Directory '{path}' not found.") 64 | except OSError as e: 65 | print(f"Error removing directory '{path}': {e}") 66 | 67 | def load_config(file_path): 68 | with open(file_path, 'r', encoding='utf-8') as f: 69 | config = yaml.safe_load(f) 70 | for key in config.keys(): 71 | if type(config[key]) == list: 72 | config[key] = tuple(config[key]) 73 | return config 74 | 75 | def get_parameters(fn, original_dict): 76 | new_dict = dict() 77 | arg_names = inspect.getfullargspec(fn)[0] 78 | for k in original_dict.keys(): 79 | if k in arg_names: 80 | new_dict[k] = original_dict[k] 81 | return new_dict 82 | 83 | def write_config(config_path, save_path): 84 | copyfile(config_path, save_path) 85 | 86 | 87 | def check_dir(dire): 88 | if not os.path.exists(dire): 89 | os.makedirs(dire) 90 | return dire 91 | 92 | def combine_tensors_2_tb(tensor_list:list=None): 93 | image = torch.cat(tensor_list, dim=-1) 94 | image = (make_grid(image, nrow=1).unsqueeze(0)+1)/2 95 | return image.clamp(0, 1) 96 | 97 | -------------------------------------------------------------------------------- /generative/metrics/mmd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Callable 15 | 16 | import torch 17 | from monai.metrics.metric import Metric 18 | 19 | 20 | class MMDMetric(Metric): 21 | """ 22 | Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two 23 | distributions.
It is a non-negative metric where a smaller value indicates a closer match between the two 24 | distributions. 25 | 26 | Gretton, A., et al,, 2012. A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773. 27 | 28 | Args: 29 | y_transform: Callable to transform the y tensor before computing the metric. It is usually a Gaussian or Laplace 30 | filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a 31 | feature extractor or an Identity function. 32 | y_pred_transform: Callable to transform the y_pred tensor before computing the metric. 33 | """ 34 | 35 | def __init__(self, y_transform: Callable | None = None, y_pred_transform: Callable | None = None) -> None: 36 | super().__init__() 37 | 38 | self.y_transform = y_transform 39 | self.y_pred_transform = y_pred_transform 40 | 41 | def __call__(self, y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Args: 44 | y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D. 45 | y_pred: second sample (e.g., the reconstructed image). It has similar shape as y. 46 | """ 47 | 48 | # Beta and Gamma are not calculated since torch.mean is used at return 49 | beta = 1.0 50 | gamma = 2.0 51 | 52 | if self.y_transform is not None: 53 | y = self.y_transform(y) 54 | 55 | if self.y_pred_transform is not None: 56 | y_pred = self.y_pred_transform(y_pred) 57 | 58 | if y_pred.shape != y.shape: 59 | raise ValueError( 60 | "y_pred and y shapes dont match after being processed " 61 | f"by their transforms, received y_pred: {y_pred.shape} and y: {y.shape}" 62 | ) 63 | 64 | for d in range(len(y.shape) - 1, 1, -1): 65 | y = y.squeeze(dim=d) 66 | y_pred = y_pred.squeeze(dim=d) 67 | 68 | y = y.view(y.shape[0], -1) 69 | y_pred = y_pred.view(y_pred.shape[0], -1) 70 | 71 | y_y = torch.mm(y, y.t()) 72 | y_pred_y_pred = torch.mm(y_pred, y_pred.t()) 73 | y_pred_y = torch.mm(y_pred, y.t()) 74 | 75 | y_y = y_y / y.shape[1] 76 | y_pred_y_pred = y_pred_y_pred / y.shape[1] 77 | y_pred_y = y_pred_y / y.shape[1] 78 | 79 | # Ref. 1 Eq. 3 (found under Lemma 6) 80 | return beta * (torch.mean(y_y) + torch.mean(y_pred_y_pred)) - gamma * torch.mean(y_pred_y) 81 | -------------------------------------------------------------------------------- /generative/networks/blocks/encoder_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Sequence 15 | from functools import partial 16 | 17 | import torch 18 | import torch.nn as nn 19 | from monai.networks.blocks import Convolution 20 | 21 | __all__ = ["SpatialRescaler"] 22 | 23 | 24 | class SpatialRescaler(nn.Module): 25 | """ 26 | SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py 27 | 28 | Args: 29 | spatial_dims: number of spatial dimensions. 30 | n_stages: number of interpolation stages. 31 | size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). 32 | method: algorithm used for sampling. 33 | multiplier: multiplier for spatial size. If `multiplier` is a sequence, 34 | its length has to match the number of spatial dimensions; `input.dim() - 2`. 35 | in_channels: number of input channels. 36 | out_channels: number of output channels. 37 | bias: whether to have a bias term. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | spatial_dims: int = 2, 43 | n_stages: int = 1, 44 | size: Sequence[int] | int | None = None, 45 | method: str = "bilinear", 46 | multiplier: Sequence[float] | float | None = None, 47 | in_channels: int = 3, 48 | out_channels: int = None, 49 | bias: bool = False, 50 | ): 51 | super().__init__() 52 | self.n_stages = n_stages 53 | assert self.n_stages >= 0 54 | assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] 55 | if size is not None and n_stages != 1: 56 | raise ValueError("when size is not None, n_stages should be 1.") 57 | if size is not None and multiplier is not None: 58 | raise ValueError("only one of size or multiplier should be defined.") 59 | self.multiplier = multiplier 60 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size) 61 | self.remap_output = out_channels is not None 62 | if self.remap_output: 63 | print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") 64 | self.channel_mapper = Convolution( 65 | spatial_dims=spatial_dims, 66 | in_channels=in_channels, 67 | out_channels=out_channels, 68 | kernel_size=1, 69 | conv_only=True, 70 | bias=bias, 71 | ) 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | if self.remap_output: 75 | x = self.channel_mapper(x) 76 | 77 | for _ in range(self.n_stages): 78 | x = self.interpolator(x, scale_factor=self.multiplier) 79 | 80 | return x 81 | 82 | def encode(self, x: torch.Tensor) -> torch.Tensor: 83 | return self(x) 84 | -------------------------------------------------------------------------------- /generative/losses/spectral_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from monai.utils import LossReduction 17 | from torch.fft import fftn 18 | from torch.nn.modules.loss import _Loss 19 | 20 | 21 | class JukeboxLoss(_Loss): 22 | """ 23 | Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT). 24 | 25 | Based on: 26 | Dhariwal, et al. 'Jukebox: A generative model for music.'https://arxiv.org/abs/2005.00341 27 | 28 | Args: 29 | spatial_dims: number of spatial dimensions. 30 | fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information. 31 | fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See 32 | torch.fft.fftn() for more information. 33 | 34 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 35 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 36 | 37 | - ``"none"``: no reduction will be applied. 38 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 39 | - ``"sum"``: the output will be summed. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | spatial_dims: int, 45 | fft_signal_size: tuple[int] | None = None, 46 | fft_norm: str = "ortho", 47 | reduction: LossReduction | str = LossReduction.MEAN, 48 | ) -> None: 49 | super().__init__(reduction=LossReduction(reduction).value) 50 | 51 | self.spatial_dims = spatial_dims 52 | self.fft_signal_size = fft_signal_size 53 | self.fft_dim = tuple(range(1, spatial_dims + 2)) 54 | self.fft_norm = fft_norm 55 | 56 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 57 | input_amplitude = self._get_fft_amplitude(target) 58 | target_amplitude = self._get_fft_amplitude(input) 59 | 60 | # Compute distance between amplitude of frequency components 61 | # See Section 3.3 from https://arxiv.org/abs/2005.00341 62 | loss = F.mse_loss(target_amplitude, input_amplitude, reduction="none") 63 | 64 | if self.reduction == LossReduction.MEAN.value: 65 | loss = loss.mean() 66 | elif self.reduction == LossReduction.SUM.value: 67 | loss = loss.sum() 68 | elif self.reduction == LossReduction.NONE.value: 69 | pass 70 | 71 | return loss 72 | 73 | def _get_fft_amplitude(self, images: torch.Tensor) -> torch.Tensor: 74 | """ 75 | Calculate the amplitude of the fourier transformations representation of the images 76 | 77 | Args: 78 | images: Images that are to undergo fftn 79 | 80 | Returns: 81 | fourier transformation amplitude 82 | """ 83 | img_fft = fftn(images, s=self.fft_signal_size, dim=self.fft_dim, norm=self.fft_norm) 84 | 85 | amplitude = torch.sqrt(torch.real(img_fft) ** 2 + torch.imag(img_fft) ** 2) 86 | 87 | return amplitude 88 | -------------------------------------------------------------------------------- /generative/networks/blocks/spade_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from monai.networks.blocks import ADN, Convolution 18 | 19 | 20 | class SPADE(nn.Module): 21 | """ 22 | SPADE normalisation block based on the 2019 paper by Park et al. (doi: https://doi.org/10.48550/arXiv.1903.07291) 23 | 24 | Args: 25 | label_nc: number of semantic labels 26 | norm_nc: number of output channels 27 | kernel_size: kernel size 28 | spatial_dims: number of spatial dimensions 29 | hidden_channels: number of channels in the intermediate gamma and beta layers 30 | norm: type of base normalisation used before applying the SPADE normalisation 31 | norm_params: parameters for the base normalisation 32 | """ 33 | 34 | def __init__( 35 | self, 36 | label_nc: int, 37 | norm_nc: int, 38 | kernel_size: int = 3, 39 | spatial_dims: int = 2, 40 | hidden_channels: int = 64, 41 | norm: str | tuple = "INSTANCE", 42 | norm_params: dict | None = None, 43 | ) -> None: 44 | super().__init__() 45 | 46 | if norm_params is None: 47 | norm_params = {} 48 | if len(norm_params) != 0: 49 | norm = (norm, norm_params) 50 | self.param_free_norm = ADN( 51 | act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc 52 | ) 53 | self.mlp_shared = Convolution( 54 | spatial_dims=spatial_dims, 55 | in_channels=label_nc, 56 | out_channels=hidden_channels, 57 | kernel_size=kernel_size, 58 | norm=None, 59 | padding=kernel_size // 2, 60 | act="LEAKYRELU", 61 | ) 62 | self.mlp_gamma = Convolution( 63 | spatial_dims=spatial_dims, 64 | in_channels=hidden_channels, 65 | out_channels=norm_nc, 66 | kernel_size=kernel_size, 67 | padding=kernel_size // 2, 68 | act=None, 69 | ) 70 | self.mlp_beta = Convolution( 71 | spatial_dims=spatial_dims, 72 | in_channels=hidden_channels, 73 | out_channels=norm_nc, 74 | kernel_size=kernel_size, 75 | padding=kernel_size // 2, 76 | act=None, 77 | ) 78 | 79 | def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: 80 | """ 81 | Args: 82 | x: input tensor 83 | segmap: input segmentation map (bxcx[spatial-dimensions]) where c is the number of semantic channels. 84 | The map will be interpolated to the dimension of x internally. 85 | """ 86 | 87 | # Part 1. generate parameter-free normalized activations 88 | normalized = self.param_free_norm(x) 89 | 90 | # Part 2. produce scaling and bias conditioned on semantic map 91 | segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") 92 | actv = self.mlp_shared(segmap) 93 | gamma = self.mlp_gamma(actv) 94 | beta = self.mlp_beta(actv) 95 | out = normalized * (1 + gamma) + beta 96 | return out 97 | -------------------------------------------------------------------------------- /generative/networks/blocks/transformerblock.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | from monai.networks.blocks.mlp import MLPBlock 17 | 18 | from generative.networks.blocks.selfattention import SABlock 19 | 20 | 21 | class TransformerBlock(nn.Module): 22 | """ 23 | A transformer block, based on: "Dosovitskiy et al., 24 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 25 | 26 | Args: 27 | hidden_size: dimension of hidden layer. 28 | mlp_dim: dimension of feedforward layer. 29 | num_heads: number of attention heads. 30 | dropout_rate: faction of the input units to drop. 31 | qkv_bias: apply bias term for the qkv linear layer 32 | causal: whether to use causal attention. 33 | sequence_length: if causal is True, it is necessary to specify the sequence length. 34 | with_cross_attention: Whether to use cross attention for conditioning. 35 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | hidden_size: int, 41 | mlp_dim: int, 42 | num_heads: int, 43 | dropout_rate: float = 0.0, 44 | qkv_bias: bool = False, 45 | causal: bool = False, 46 | sequence_length: int | None = None, 47 | with_cross_attention: bool = False, 48 | use_flash_attention: bool = False, 49 | ) -> None: 50 | self.with_cross_attention = with_cross_attention 51 | super().__init__() 52 | 53 | if not (0 <= dropout_rate <= 1): 54 | raise ValueError("dropout_rate should be between 0 and 1.") 55 | 56 | if hidden_size % num_heads != 0: 57 | raise ValueError("hidden_size should be divisible by num_heads.") 58 | 59 | self.norm1 = nn.LayerNorm(hidden_size) 60 | self.attn = SABlock( 61 | hidden_size=hidden_size, 62 | num_heads=num_heads, 63 | dropout_rate=dropout_rate, 64 | qkv_bias=qkv_bias, 65 | causal=causal, 66 | sequence_length=sequence_length, 67 | use_flash_attention=use_flash_attention, 68 | ) 69 | 70 | self.norm2 = None 71 | self.cross_attn = None 72 | if self.with_cross_attention: 73 | self.norm2 = nn.LayerNorm(hidden_size) 74 | self.cross_attn = SABlock( 75 | hidden_size=hidden_size, 76 | num_heads=num_heads, 77 | dropout_rate=dropout_rate, 78 | qkv_bias=qkv_bias, 79 | with_cross_attention=with_cross_attention, 80 | causal=False, 81 | use_flash_attention=use_flash_attention, 82 | ) 83 | 84 | self.norm3 = nn.LayerNorm(hidden_size) 85 | self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) 86 | 87 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 88 | x = x + self.attn(self.norm1(x)) 89 | if self.with_cross_attention: 90 | x = x + self.cross_attn(self.norm2(x), context=context) 91 | x = x + self.mlp(self.norm3(x)) 92 | return x 93 | -------------------------------------------------------------------------------- /generative/networks/nets/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from generative.networks.blocks.transformerblock import TransformerBlock 18 | 19 | __all__ = ["DecoderOnlyTransformer"] 20 | 21 | 22 | class AbsolutePositionalEmbedding(nn.Module): 23 | """Absolute positional embedding. 24 | 25 | Args: 26 | max_seq_len: Maximum sequence length. 27 | embedding_dim: Dimensionality of the embedding. 28 | """ 29 | 30 | def __init__(self, max_seq_len: int, embedding_dim: int) -> None: 31 | super().__init__() 32 | self.max_seq_len = max_seq_len 33 | self.embedding_dim = embedding_dim 34 | self.embedding = nn.Embedding(max_seq_len, embedding_dim) 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | batch_size, seq_len = x.size() 38 | positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1) 39 | return self.embedding(positions) 40 | 41 | 42 | class DecoderOnlyTransformer(nn.Module): 43 | """Decoder-only (Autoregressive) Transformer model. 44 | 45 | Args: 46 | num_tokens: Number of tokens in the vocabulary. 47 | max_seq_len: Maximum sequence length. 48 | attn_layers_dim: Dimensionality of the attention layers. 49 | attn_layers_depth: Number of attention layers. 50 | attn_layers_heads: Number of attention heads. 51 | with_cross_attention: Whether to use cross attention for conditioning. 52 | embedding_dropout_rate: Dropout rate for the embedding. 53 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 54 | """ 55 | 56 | def __init__( 57 | self, 58 | num_tokens: int, 59 | max_seq_len: int, 60 | attn_layers_dim: int, 61 | attn_layers_depth: int, 62 | attn_layers_heads: int, 63 | with_cross_attention: bool = False, 64 | embedding_dropout_rate: float = 0.0, 65 | use_flash_attention: bool = False, 66 | ) -> None: 67 | super().__init__() 68 | self.num_tokens = num_tokens 69 | self.max_seq_len = max_seq_len 70 | self.attn_layers_dim = attn_layers_dim 71 | self.attn_layers_depth = attn_layers_depth 72 | self.attn_layers_heads = attn_layers_heads 73 | self.with_cross_attention = with_cross_attention 74 | 75 | self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim) 76 | self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim) 77 | self.embedding_dropout = nn.Dropout(embedding_dropout_rate) 78 | 79 | self.blocks = nn.ModuleList( 80 | [ 81 | TransformerBlock( 82 | hidden_size=attn_layers_dim, 83 | mlp_dim=attn_layers_dim * 4, 84 | num_heads=attn_layers_heads, 85 | dropout_rate=0.0, 86 | qkv_bias=False, 87 | causal=True, 88 | sequence_length=max_seq_len, 89 | with_cross_attention=with_cross_attention, 90 | use_flash_attention=use_flash_attention, 91 | ) 92 | for _ in range(attn_layers_depth) 93 | ] 94 | ) 95 | 96 | self.to_logits = nn.Linear(attn_layers_dim, num_tokens) 97 | 98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 99 | tok_emb = self.token_embeddings(x) 100 | pos_emb = self.position_embeddings(x) 101 | x = self.embedding_dropout(tok_emb + pos_emb) 102 | 103 | for block in self.blocks: 104 | x = block(x, context=context) 105 | 106 | return self.to_logits(x) 107 | -------------------------------------------------------------------------------- /generative/utils/component_store.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections import namedtuple 15 | from keyword import iskeyword 16 | from textwrap import dedent, indent 17 | from typing import Any, Callable, Dict, Iterable, TypeVar 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | def is_variable(name): 23 | """Returns True if `name` is a valid Python variable name and also not a keyword.""" 24 | return name.isidentifier() and not iskeyword(name) 25 | 26 | 27 | class ComponentStore: 28 | """ 29 | Represents a storage object for other objects (specifically functions) keyed to a name with a description. 30 | 31 | These objects act as global named places for storing components for objects parameterised by component names. 32 | Typically this is functions although other objects can be added. Printing a component store will produce a 33 | list of members along with their docstring information if present. 34 | 35 | Example: 36 | 37 | .. code-block:: python 38 | 39 | TestStore = ComponentStore("Test Store", "A test store for demo purposes") 40 | 41 | @TestStore.add_def("my_func_name", "Some description of your function") 42 | def _my_func(a, b): 43 | '''A description of your function here.''' 44 | return a * b 45 | 46 | print(TestStore) # will print out name, description, and 'my_func_name' with the docstring 47 | 48 | func = TestStore["my_func_name"] 49 | result = func(7, 6) 50 | 51 | """ 52 | 53 | _Component = namedtuple("Component", ("description", "value")) # internal value pair 54 | 55 | def __init__(self, name: str, description: str) -> None: 56 | self.components: Dict[str, self._Component] = {} 57 | self.name: str = name 58 | self.description: str = description 59 | 60 | self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() 61 | 62 | def add(self, name: str, desc: str, value: T) -> T: 63 | """Store the object `value` under the name `name` with description `desc`.""" 64 | if not is_variable(name): 65 | raise ValueError("Name of component must be valid Python identifier") 66 | 67 | self.components[name] = self._Component(desc, value) 68 | return value 69 | 70 | def add_def(self, name: str, desc: str) -> Callable: 71 | """Returns a decorator which stores the decorated function under `name` with description `desc`.""" 72 | 73 | def deco(func): 74 | """Decorator to add a function to a store.""" 75 | return self.add(name, desc, func) 76 | 77 | return deco 78 | 79 | def __contains__(self, name: str) -> bool: 80 | """Returns True if the given name is stored.""" 81 | return name in self.components 82 | 83 | def __len__(self) -> int: 84 | """Returns the number of stored components.""" 85 | return len(self.components) 86 | 87 | def __iter__(self) -> Iterable: 88 | """Yields name/component pairs.""" 89 | for k, v in self.components.items(): 90 | yield k, v.value 91 | 92 | def __str__(self): 93 | result = f"Component Store '{self.name}': {self.description}\nAvailable components:" 94 | for k, v in 
self.components.items(): 95 | result += f"\n* {k}:" 96 | 97 | if hasattr(v.value, "__doc__"): 98 | doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") 99 | result += f"\n{doc}\n" 100 | else: 101 | result += f" {v.description}" 102 | 103 | return result 104 | 105 | def __getattr__(self, name: str) -> Any: 106 | """Returns the stored object under the given name.""" 107 | if name in self.components: 108 | return self.components[name].value 109 | else: 110 | return self.__getattribute__(name) 111 | 112 | def __getitem__(self, name: str) -> Any: 113 | """Returns the stored object under the given name.""" 114 | if name in self.components: 115 | return self.components[name].value 116 | else: 117 | raise ValueError(f"Component '{name}' not found") 118 | -------------------------------------------------------------------------------- /generative/metrics/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from __future__ import annotations 14 | 15 | import numpy as np 16 | import torch 17 | from monai.metrics.metric import Metric 18 | from scipy import linalg 19 | 20 | 21 | class FIDMetric(Metric): 22 | """ 23 | Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors. 24 | Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium." 25 | https://arxiv.org/abs/1706.08500#. The inputs for this metric should be two groups of feature vectors (with format 26 | (number images, number of features)) extracted from the a pretrained network. 27 | 28 | Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet. 29 | However, others networks pretrained on medical datasets can be used as well (for example, RadImageNwt for 2D and 30 | MedicalNet for 3D images). If the chosen model output is not a scalar, usually it is used a global spatial 31 | average pooling. 32 | """ 33 | 34 | def __init__(self) -> None: 35 | super().__init__() 36 | 37 | def __call__(self, y_pred: torch.Tensor, y: torch.Tensor): 38 | return get_fid_score(y_pred, y) 39 | 40 | 41 | def get_fid_score(y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 42 | y = y.double() 43 | y_pred = y_pred.double() 44 | 45 | if y.ndimension() > 2: 46 | raise ValueError("Inputs should have (number images, number of features) shape.") 47 | 48 | mu_y_pred = torch.mean(y_pred, dim=0) 49 | sigma_y_pred = _cov(y_pred, rowvar=False) 50 | mu_y = torch.mean(y, dim=0) 51 | sigma_y = _cov(y, rowvar=False) 52 | 53 | return compute_frechet_distance(mu_y_pred, sigma_y_pred, mu_y, sigma_y) 54 | 55 | 56 | def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor: 57 | """ 58 | Estimate a covariance matrix of the variables. 59 | 60 | Args: 61 | input_data: A 1-D or 2-D array containing multiple variables and observations. 
Each row of `m` represents a variable, 62 | and each column a single observation of all those variables. 63 | rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns. 64 | Otherwise, the relationship is transposed: each column represents a variable, while the rows contain 65 | observations. 66 | """ 67 | if input_data.dim() < 2: 68 | input_data = input_data.view(1, -1) 69 | 70 | if not rowvar and input_data.size(0) != 1: 71 | input_data = input_data.t() 72 | 73 | factor = 1.0 / (input_data.size(1) - 1) 74 | input_data = input_data - torch.mean(input_data, dim=1, keepdim=True) 75 | return factor * input_data.matmul(input_data.t()).squeeze() 76 | 77 | 78 | def _sqrtm(input_data: torch.Tensor) -> torch.Tensor: 79 | """Compute the square root of a matrix.""" 80 | scipy_res, _ = linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False) 81 | return torch.from_numpy(scipy_res) 82 | 83 | 84 | def compute_frechet_distance( 85 | mu_x: torch.Tensor, sigma_x: torch.Tensor, mu_y: torch.Tensor, sigma_y: torch.Tensor, epsilon: float = 1e-6 86 | ) -> torch.Tensor: 87 | """The Frechet distance between multivariate normal distributions.""" 88 | diff = mu_x - mu_y 89 | 90 | covmean = _sqrtm(sigma_x.mm(sigma_y)) 91 | 92 | # Product might be almost singular 93 | if not torch.isfinite(covmean).all(): 94 | print(f"FID calculation produces singular product; adding {epsilon} to diagonal of covariance estimates") 95 | offset = torch.eye(sigma_x.size(0), device=mu_x.device, dtype=mu_x.dtype) * epsilon 96 | covmean = _sqrtm((sigma_x + offset).mm(sigma_y + offset)) 97 | 98 | # Numerical error might give slight imaginary component 99 | if torch.is_complex(covmean): 100 | if not torch.allclose(torch.diagonal(covmean).imag, torch.tensor(0, dtype=torch.double), atol=1e-3): 101 | raise ValueError(f"Imaginary component {torch.max(torch.abs(covmean.imag))} too high.") 102 | covmean = covmean.real 103 | 104 | tr_covmean = torch.trace(covmean) 105 | return diff.dot(diff) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2 * tr_covmean 106 | -------------------------------------------------------------------------------- /generative/engines/prepare_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from typing import Dict, Mapping, Optional, Union 15 | 16 | import torch 17 | import torch.nn as nn 18 | from monai.engines import PrepareBatch, default_prepare_batch 19 | 20 | 21 | class DiffusionPrepareBatch(PrepareBatch): 22 | """ 23 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. 
24 | 25 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and 26 | return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise". 27 | This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided. 28 | 29 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition 30 | field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". 31 | 32 | """ 33 | 34 | def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: 35 | self.condition_name = condition_name 36 | self.num_train_timesteps = num_train_timesteps 37 | 38 | def get_noise(self, images: torch.Tensor) -> torch.Tensor: 39 | """Returns the noise tensor for input tensor `images`, override this for different noise distributions.""" 40 | return torch.randn_like(images) 41 | 42 | def get_timesteps(self, images: torch.Tensor) -> torch.Tensor: 43 | """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`.""" 44 | return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long() 45 | 46 | def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: 47 | """Return the target for the loss function, this is the `noise` value by default.""" 48 | return noise 49 | 50 | def __call__( 51 | self, 52 | batchdata: Dict[str, torch.Tensor], 53 | device: Optional[Union[str, torch.device]] = None, 54 | non_blocking: bool = False, 55 | **kwargs, 56 | ): 57 | images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs) 58 | noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs) 59 | timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs) 60 | 61 | target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs) 62 | infer_kwargs = {"noise": noise, "timesteps": timesteps} 63 | 64 | if self.condition_name is not None and isinstance(batchdata, Mapping): 65 | infer_kwargs["conditioning"] = batchdata[self.condition_name].to( 66 | device, non_blocking=non_blocking, **kwargs 67 | ) 68 | 69 | # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value 70 | return images, target, (), infer_kwargs 71 | 72 | 73 | class VPredictionPrepareBatch(DiffusionPrepareBatch): 74 | """ 75 | This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training. 76 | 77 | Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and 78 | from this compute the velocity using the provided scheduler. This value is used as the target in place of the 79 | noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer 80 | being used in conjunction with this class expects a "noise" parameter to be provided. 81 | 82 | If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition 83 | field to be passed to the inferer. This will appear in the keyword arguments under the key "condition". 
84 | 85 | """ 86 | 87 | def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None: 88 | super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name) 89 | self.scheduler = scheduler 90 | 91 | def get_target(self, images, noise, timesteps): 92 | return self.scheduler.get_velocity(images, noise, timesteps) 93 | -------------------------------------------------------------------------------- /utils/Condition_aug_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | import glob 4 | from torchvision import transforms 5 | import cv2 6 | from os.path import join 7 | from PIL import Image 8 | import os 9 | from natsort import natsorted 10 | from os.path import dirname as di 11 | import torch 12 | import sys;sys.path.append('./') 13 | from utils.common import load_config 14 | import pickle 15 | 16 | def get_image_address(updir): 17 | files = natsorted(os.listdir(updir)) 18 | slo_list = [] 19 | for file_name in files: 20 | # Build the full path of the file 21 | old_path = os.path.join(updir, file_name) 22 | 23 | # Split the file name and extension 24 | base_name, picture_form = os.path.splitext(file_name) 25 | 26 | if os.path.isfile(old_path): 27 | if len(base_name.split('.')) == 1: 28 | new_name = f'{base_name}{picture_form}' 29 | slo_list.append(os.path.join(updir, new_name)) 30 | return slo_list 31 | 32 | def get_address_list(up_dir, picture_form: str): 33 | if up_dir[-1] != '/': 34 | up_dir = f'{up_dir}/' 35 | return glob.glob(up_dir+'*.'+picture_form) 36 | 37 | 38 | 39 | class Double_dataset(data.Dataset): 40 | def __init__(self, data_path, img_size, 41 | mode='double', read_channel='color', mask=None, 42 | data_aug=True): 43 | ''' 44 | data_path: the parent directory of the data 45 | img_size: the image size to read (tuple or int) 46 | mode: one of 1. 'double' 2. 'first' 3.
'second' 47 | read_channel: 'color' or 'gray' 48 | ''' 49 | super(Double_dataset, self).__init__() 50 | if isinstance(data_path, list): 51 | self.img_path = [] 52 | for path in data_path: 53 | self.img_path += get_image_address(path) 54 | else: 55 | self.img_path = get_image_address(data_path) 56 | if isinstance(img_size, int): 57 | img_size = (img_size, img_size) 58 | 59 | basic_trans_list = [ 60 | transforms.ToTensor(), 61 | # transforms.Resize((512, 640), antialias=True), 62 | ] 63 | 64 | self.data_aug = data_aug 65 | if data_aug: 66 | self.augmentator = transforms.Compose([ 67 | # transforms.RandomRotation(5), 68 | transforms.Resize((256, 256)), 69 | # transforms.RandomVerticalFlip(), 70 | # transforms.RandomHorizontalFlip(), 71 | # transforms.Normalize(mean=0.5, std=0.5) 72 | ]) 73 | else: 74 | basic_trans_list.append(transforms.Resize(img_size, antialias=True)) 75 | basic_trans_list.append(transforms.Normalize(mean=0.5, std=0.5)) 76 | self.transformer = transforms.Compose(basic_trans_list) 77 | self.mode = mode 78 | if read_channel == 'color': 79 | self.img_reader = self.colorloader 80 | else: 81 | self.img_reader = self.grayloader 82 | 83 | self.slo_reader = self.grayloader 84 | self.mask = mask 85 | 86 | def double_get(self, pt_path, mask) -> list: 87 | f_kspace = pickle.load(open(pt_path, 'rb'))['img'] 88 | 89 | f_kspace1 = np.fft.fft2(f_kspace) 90 | f_kspace1 = f_kspace1 * mask 91 | mri_mask = np.fft.ifft2(f_kspace1) 92 | 93 | mri = np.abs(f_kspace) 94 | mri_mask = np.abs(mri_mask) 95 | 96 | mri_min = np.min(mri) 97 | mri_max = np.max(mri) 98 | mri = 2 * (mri - mri_min) / (mri_max - mri_min) - 1 99 | 100 | mri_mask_min = np.min(mri_mask) 101 | mri_mask_max = np.max(mri_mask) 102 | mri_mask = 2 * (mri_mask - mri_mask_min) / (mri_mask_max - mri_mask_min) - 1 103 | 104 | 105 | mri_mask = self.transformer(mri_mask) 106 | mri = self.transformer(mri) 107 | 108 | if self.data_aug: 109 | mri = self.augmentator(mri) 110 | mri_mask = self.augmentator(mri_mask) 111 | else: 112 | mri = mri 113 | mri_mask = mri_mask 114 | return mri, mri_mask 115 | 116 | def __getitem__(self, index)->list: 117 | slo_name = self.img_path[index] 118 | var_list, info = None, None 119 | if self.mode == 'double': 120 | var_list, info = self.double_get(slo_name, mask=self.mask) 121 | 122 | return var_list, info 123 | 124 | def __len__(self): 125 | return len(self.img_path) 126 | 127 | def colorloader(self, path): 128 | with open(path, 'rb') as f: 129 | img = Image.open(f) 130 | return img.convert('RGB') 131 | 132 | def grayloader(self, path): 133 | with open(path, 'rb') as f: 134 | img = Image.open(f) 135 | return img.convert('L') 136 | 137 | def double_form_dataloader(updir, image_size, batch_size, mode, 138 | read_channel='color', mask = None, data_aug=True, 139 | shuffle=True, drop_last=True): 140 | dataset = Double_dataset(updir, image_size, mode, read_channel, mask, data_aug) 141 | return data.DataLoader(dataset, batch_size, shuffle=shuffle, drop_last=drop_last) 142 | -------------------------------------------------------------------------------- /mask/mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 3 | 4 | import contextlib 5 | import numpy as np 6 | import torch 7 | 8 | import pickle 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | ACC4 = { 13 | "mask_type": "random", 14 | "center_fractions": [0.08], 15 | "accelerations": [4] 16 | } 17 | ACC8 = { 18 | "mask_type": "random", 19 | "center_fractions": 
[0.04], 20 | "accelerations": [8] 21 | } 22 | 23 | 24 | @contextlib.contextmanager 25 | def temp_seed(rng, seed): 26 | state = rng.get_state() 27 | rng.seed(seed) 28 | try: 29 | yield 30 | finally: 31 | rng.set_state(state) 32 | 33 | 34 | class MaskFunc(object): 35 | """ 36 | An object for GRAPPA-style sampling masks. 37 | 38 | This crates a sampling mask that densely samples the center while 39 | subsampling outer k-space regions based on the under-sampling factor. 40 | """ 41 | 42 | def __init__(self, center_fractions, accelerations): 43 | """ 44 | :param center_fractions: list of float, fraction of low-frequency columns to be 45 | retained. If multiple values are provided, then one of these 46 | numbers is chosen uniformly each time. 47 | :param accelerations: list of int, amount of under-sampling. This should have 48 | the same length as center_fractions. If multiple values are 49 | provided, then one of these is chosen uniformly each time. 50 | """ 51 | if len(center_fractions) != len(accelerations): 52 | raise ValueError("number of center fractions should match number of accelerations.") 53 | 54 | self.center_fractions = center_fractions 55 | self.accelerations = accelerations 56 | self.rng = np.random 57 | 58 | def choose_acceleration(self): 59 | """ 60 | Choose acceleration based on class parameters. 61 | """ 62 | choice = self.rng.randint(0, len(self.accelerations)) 63 | center_fraction = self.center_fractions[choice] 64 | acceleration = self.accelerations[choice] 65 | 66 | return center_fraction, acceleration 67 | 68 | 69 | class RandomMaskFunc(MaskFunc): 70 | """ 71 | RandomMaskFunc creates a sub-sampling mask of a given shape. 72 | 73 | The mask selects a subset of columns from the input k-space data. If the 74 | k-space data has N columns, the mask picks out: 75 | 1. N_low_freqs = (N * center_fraction) columns in the center 76 | corresponding to low-frequencies. 77 | 2. The other columns are selected uniformly at random with a 78 | probability equal to: prob = (N / acceleration - N_low_freqs) / 79 | (N - N_low_freqs). This ensures that the expected number of columns 80 | selected is equal to (N / acceleration). 81 | 82 | It is possible to use multiple center_fractions and accelerations, in which 83 | case one possible (center_fraction, acceleration) is chosen uniformly at 84 | random each time the RandomMaskFunc object is called. 85 | 86 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], 87 | then there is a 50% probability that 4-fold acceleration with 8% center 88 | fraction is selected and a 50% probability that 8-fold acceleration with 4% 89 | center fraction is selected. 90 | """ 91 | 92 | def __call__(self, shape, seed=None): 93 | """ 94 | Create the mask. 95 | 96 | :param shape: (iterable[int]), the shape of the mask to be created. 97 | :param seed: (int, optional), seed for the random number generator. Setting 98 | the seed ensures the same mask is generated each time for the 99 | same shape. The random state is reset afterwards. 100 | :return torch.Tensor, a mask of the specified shape. Its shape should be 101 | (2, height, width) and the two channels are the same. 
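Worked example (illustrative, using the ACC4 settings defined above): for shape (1, 320, 320) with center_fraction=0.08 and acceleration=4, num_low_freqs = round(320 * 0.08) = 26 and prob = (320/4 - 26) / (320 - 26) = 54/294 ≈ 0.18, so on average the 26 center columns plus roughly 54 random outer columns are retained, i.e. about 320/4 = 80 columns in total.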
102 | """ 103 | with temp_seed(self.rng, seed): 104 | num_cols = shape[-1] 105 | center_fraction, acceleration = self.choose_acceleration() 106 | 107 | # create the mask 108 | num_low_freqs = int(round(num_cols * center_fraction)) 109 | prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) 110 | mask_location = self.rng.uniform(size=num_cols) < prob 111 | pad = (num_cols - num_low_freqs + 1) // 2 112 | mask_location[pad: pad + num_low_freqs] = True 113 | mask = np.zeros(shape, dtype=np.float32) 114 | mask[..., mask_location] = 1.0 115 | 116 | return mask 117 | 118 | 119 | def create_mask_for_mask_type(mask_type_str, center_fractions, accelerations): 120 | if mask_type_str == "random": 121 | return RandomMaskFunc(center_fractions, accelerations) 122 | else: 123 | raise Exception(f"{mask_type_str} not supported") 124 | 125 | # Mask generation 126 | # 4x acceleration 127 | # a = create_mask_for_mask_type("random",[0.08], [4]) 128 | # 6x acceleration 129 | # a = create_mask_for_mask_type("random",[0.06], [6]) 130 | # 8x acceleration 131 | # a = create_mask_for_mask_type("random",[0.04], [8]) 132 | # 10x acceleration 133 | # a = create_mask_for_mask_type("random",[0.02], [10]) 134 | # mask = a((1, 320, 320)) 135 | # torch.save(mask, 'mask_10.pt') 136 | 137 | # Mask visualization 138 | # temp=torch.load(open('./mask_10.pt','rb')) 139 | # plt.imshow(temp[0], cmap='gray') 140 | # plt.show() 141 | -------------------------------------------------------------------------------- /stable_diffusion/val_model.py: -------------------------------------------------------------------------------- 1 | import sys;sys.path.append('./') 2 | 3 | from tqdm import tqdm 4 | import numpy 5 | from utils.Condition_aug_dataloader import double_form_dataloader 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from generative.networks.nets import AutoencoderKL 10 | from generative.inferers import DiffusionInferer 11 | from diffusers import DDPMScheduler 12 | 13 | from utils.common import get_parameters, save_config_to_yaml, one_to_three 14 | from config.diffusion.config_controlnet import Config 15 | from os.path import join as j 16 | from accelerate import Accelerator 17 | from torchvision.utils import make_grid, save_image 18 | from unet.MF_UKAN import MF_UKAN 19 | from unet.MC_model import MC_MODEL 20 | import os 21 | 22 | def main(): 23 | mask = torch.load(Config.mask_path, map_location=torch.device('cpu')) 24 | mask = numpy.squeeze(mask) 25 | cf = save_config_to_yaml(Config, Config.project_dir) 26 | accelerator = Accelerator(**get_parameters(Accelerator, cf)) 27 | device = 'cuda' 28 | val_dataloader = double_form_dataloader( 29 | Config.eval_path, 30 | Config.sample_size, 31 | Config.eval_bc, 32 | Config.mode, 33 | read_channel='gray', 34 | mask = mask) 35 | attention_levels = (False, ) * len(Config.up_and_down) 36 | vae = AutoencoderKL( 37 | spatial_dims=2, 38 | in_channels=Config.in_channels, 39 | out_channels=Config.out_channels, 40 | num_channels=Config.up_and_down, 41 | latent_channels=4, 42 | num_res_blocks=Config.num_res_layers, 43 | attention_levels = attention_levels 44 | ) 45 | vae = vae.eval().to(device) 46 | if len(Config.vae_resume_path): 47 | vae.load_state_dict(torch.load(Config.vae_resume_path)) 48 | 49 | model = MF_UKAN( 50 | T=1000, 51 | ch=64, 52 | ch_mult=[1, 2, 3, 4], 53 | attn=[2], 54 | num_res_blocks=2, 55 | dropout=0.15).to(device) 56 | if len(Config.sd_resume_path): 57 | model.load_state_dict(torch.load(Config.sd_resume_path), strict=False) 58 | model = model.to(device) 59 | 60 | mc_model = MC_MODEL( 61 | T=1000, 62 | ch=64,
63 | ch_mult=[1, 2, 3, 4], 64 | attn=[2], 65 | num_res_blocks=2, 66 | dropout=0.15).to(device) 67 | if len(Config.sd_resume_path): 68 | mc_model.load_state_dict(torch.load(Config.mc_model_path), strict=False) 69 | mc_model = mc_model.to(device) 70 | 71 | 72 | # scheduler = DDPMScheduler(num_train_timesteps=1000) 73 | # Dynamic clipping strategy 74 | scheduler = DDPMScheduler(num_train_timesteps=1000, 75 | beta_start=Config.beta_start, 76 | beta_end=Config.beta_end, 77 | beta_schedule=Config.beta_schedule, 78 | clip_sample=Config.clip_sample, 79 | clip_sample_range=Config.initial_clip_sample_range, 80 | ) 81 | 82 | if len(Config.log_with): 83 | accelerator.init_trackers('train_example') 84 | 85 | latent_shape = None 86 | # scaling_factor = 1 / torch.std(next(iter(train_dataloader))[0][1]) 87 | scaling_factor = Config.scaling_factor 88 | for step, batch in enumerate(val_dataloader): 89 | model.eval() 90 | mc_model.eval() 91 | progress_bar = tqdm(total=len(val_dataloader)) 92 | progress_bar.set_description(f"No. {step+1}") 93 | mri = batch[0].float().to(device) 94 | mri_mask = batch[1].float().to(device) 95 | pt_name = batch[2][0].split('.')[0] 96 | with torch.no_grad(): 97 | latent_mri = vae.encode_stage_2_inputs(mri) 98 | latent_shape = list(latent_mri.shape);latent_shape[0] = Config.eval_bc 99 | noise = torch.randn(latent_shape).to(device) 100 | progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110, position=0, leave=True) 101 | with torch.no_grad(): 102 | for i, t in enumerate(progress_bar_sampling): 103 | t_tensor = torch.tensor([t], dtype=torch.long).to(device) 104 | down_block_res_samples, _ = mc_model( 105 | x=noise, t=t_tensor, controlnet_cond=mri_mask 106 | ) 107 | noise_pred = model( 108 | noise, 109 | t=t_tensor, 110 | down_block_additional_residuals=down_block_res_samples 111 | ) 112 | 113 | scheduler = DDPMScheduler(num_train_timesteps=1000, 114 | beta_start=Config.beta_start, 115 | beta_end=Config.beta_end, 116 | beta_schedule=Config.beta_schedule, 117 | clip_sample=Config.clip_sample, 118 | clip_sample_range=Config.initial_clip_sample_range - Config.clip_rate * i, 119 | ) 120 | noise = scheduler.step(model_output=noise_pred, timestep=t, sample=noise).prev_sample 121 | with torch.no_grad(): 122 | image = vae.decode_stage_2_outputs(noise / scaling_factor) 123 | 124 | image, mri = one_to_three(image), one_to_three(mri), 125 | image = torch.cat([mri, image], dim=-1) 126 | image = (make_grid(image, nrow=1).unsqueeze(0)+1)/2 127 | log_image = {"MRI": image.clamp(0, 1)} 128 | save_path = j(Config.output_dir, 'image_save') 129 | os.makedirs(save_path, exist_ok=True) 130 | save_image(log_image["MRI"], j(save_path, f'{pt_name}.png')) 131 | # accelerator.trackers[0].log_images(log_image, epoch+1) 132 | 133 | 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /generative/networks/blocks/selfattention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import importlib.util 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | 21 | if importlib.util.find_spec("xformers") is not None: 22 | import xformers.ops as xops 23 | 24 | has_xformers = True 25 | else: 26 | has_xformers = False 27 | 28 | 29 | class SABlock(nn.Module): 30 | """ 31 | A self-attention block, based on: "Dosovitskiy et al., 32 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 33 | 34 | Args: 35 | hidden_size: dimension of hidden layer. 36 | num_heads: number of attention heads. 37 | dropout_rate: dropout ratio. Defaults to no dropout. 38 | qkv_bias: bias term for the qkv linear layer. 39 | causal: whether to use causal attention. 40 | sequence_length: if causal is True, it is necessary to specify the sequence length. 41 | with_cross_attention: Whether to use cross attention for conditioning. 42 | use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | hidden_size: int, 48 | num_heads: int, 49 | dropout_rate: float = 0.0, 50 | qkv_bias: bool = False, 51 | causal: bool = False, 52 | sequence_length: int | None = None, 53 | with_cross_attention: bool = False, 54 | use_flash_attention: bool = False, 55 | ) -> None: 56 | super().__init__() 57 | self.hidden_size = hidden_size 58 | self.num_heads = num_heads 59 | self.head_dim = hidden_size // num_heads 60 | self.scale = 1.0 / math.sqrt(self.head_dim) 61 | self.causal = causal 62 | self.sequence_length = sequence_length 63 | self.with_cross_attention = with_cross_attention 64 | self.use_flash_attention = use_flash_attention 65 | 66 | if not (0 <= dropout_rate <= 1): 67 | raise ValueError("dropout_rate should be between 0 and 1.") 68 | self.dropout_rate = dropout_rate 69 | 70 | if hidden_size % num_heads != 0: 71 | raise ValueError("hidden size should be divisible by num_heads.") 72 | 73 | if causal and sequence_length is None: 74 | raise ValueError("sequence_length is necessary for causal attention.") 75 | 76 | if use_flash_attention and not has_xformers: 77 | raise ValueError("use_flash_attention is True but xformers is not installed.") 78 | 79 | # key, query, value projections 80 | self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 81 | self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 82 | self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 83 | 84 | # regularization 85 | self.drop_weights = nn.Dropout(dropout_rate) 86 | self.drop_output = nn.Dropout(dropout_rate) 87 | 88 | # output projection 89 | self.out_proj = nn.Linear(hidden_size, hidden_size) 90 | 91 | if causal and sequence_length is not None: 92 | # causal mask to ensure that attention is only applied to the left in the input sequence 93 | self.register_buffer( 94 | "causal_mask", 95 | torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), 96 | ) 97 | 98 | def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: 99 | b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) 100 | 101 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 102 | query = self.to_q(x) 103 | 104 | kv = context if context is not None else x 105 | _, kv_t, _ = 
kv.size() 106 | key = self.to_k(kv) 107 | value = self.to_v(kv) 108 | 109 | query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) 110 | key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) 111 | value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) 112 | 113 | if self.use_flash_attention: 114 | query = query.contiguous() 115 | key = key.contiguous() 116 | value = value.contiguous() 117 | y = xops.memory_efficient_attention( 118 | query=query, 119 | key=key, 120 | value=value, 121 | scale=self.scale, 122 | p=self.dropout_rate, 123 | attn_bias=xops.LowerTriangularMask() if self.causal else None, 124 | ) 125 | 126 | else: 127 | query = query.transpose(1, 2) # (b, nh, t, hs) 128 | key = key.transpose(1, 2) # (b, nh, kv_t, hs) 129 | value = value.transpose(1, 2) # (b, nh, kv_t, hs) 130 | 131 | # manual implementation of attention 132 | query = query * self.scale 133 | attention_scores = query @ key.transpose(-2, -1) 134 | 135 | if self.causal: 136 | attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) 137 | 138 | attention_probs = F.softmax(attention_scores, dim=-1) 139 | attention_probs = self.drop_weights(attention_probs) 140 | y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) 141 | 142 | y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) 143 | 144 | y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side 145 | 146 | y = self.out_proj(y) 147 | y = self.drop_output(y) 148 | return y 149 | -------------------------------------------------------------------------------- /stable_diffusion/train_sd.py: -------------------------------------------------------------------------------- 1 | import sys;sys.path.append('./') 2 | 3 | from tqdm import tqdm 4 | import numpy 5 | from utils.Condition_aug_dataloader import double_form_dataloader 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from generative.networks.nets import AutoencoderKL 10 | from generative.inferers import DiffusionInferer 11 | from generative.networks.schedulers import DDPMScheduler 12 | 13 | from utils.common import get_parameters, save_config_to_yaml 14 | from config.diffusion.config_controlnet import Config 15 | from os.path import join as j 16 | from accelerate import Accelerator 17 | from torchvision.utils import make_grid, save_image 18 | from unet.MF_UKAN import MF_UKAN 19 | import os 20 | 21 | def main(): 22 | mask = torch.load(Config.mask_path, map_location=torch.device('cpu')) 23 | mask = numpy.squeeze(mask) 24 | cf = save_config_to_yaml(Config, Config.project_dir) 25 | accelerator = Accelerator(**get_parameters(Accelerator, cf)) 26 | train_dataloader = double_form_dataloader(Config.data_path, 27 | Config.sample_size, 28 | Config.train_bc, 29 | Config.mode, 30 | read_channel='gray', 31 | mask=mask) 32 | device = 'cuda' 33 | val_dataloader = double_form_dataloader( 34 | Config.eval_path, 35 | Config.sample_size, 36 | Config.eval_bc, 37 | Config.mode, 38 | read_channel='gray', 39 | mask=mask) 40 | 41 | device = 'cuda' 42 | attention_levels = (False, ) * len(Config.up_and_down) 43 | vae = AutoencoderKL( 44 | spatial_dims=2, 45 | in_channels=Config.in_channels, 46 | out_channels=Config.out_channels, 47 | num_channels=Config.up_and_down, 48 | latent_channels=4, 49 | num_res_blocks=Config.num_res_layers, 50 | attention_levels = attention_levels 51 | ) 52 | vae = vae.eval().to(device) 53 | if len(Config.vae_resume_path): 54 
| vae.load_state_dict(torch.load(Config.vae_resume_path)) 55 | 56 | 57 | model =MF_UKAN( 58 | T=1000, 59 | ch=64, 60 | ch_mult=[1, 2, 3, 4], 61 | attn=[2], 62 | num_res_blocks=2, 63 | dropout=0.15).to(device) 64 | if len(Config.sd_resume_path): 65 | model.load_state_dict(torch.load(Config.sd_resume_path), strict=False) 66 | model = model.to(device) 67 | 68 | scheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=True) 69 | # optimizer_con = torch.optim.Adam(params=controlnet.parameters(), lr=2.5e-5) 70 | optimizer_sd = torch.optim.Adam(params=model.parameters(), lr=2.5e-5) 71 | inferer = DiffusionInferer(scheduler) 72 | 73 | val_interval = Config.val_inter 74 | save_interval = Config.save_inter 75 | 76 | if len(Config.log_with): 77 | accelerator.init_trackers('train_example') 78 | 79 | global_step = 0 80 | latent_shape = None 81 | # scaling_factor = 1 / torch.std(next(iter(train_dataloader))[0][1]) 82 | scaling_factor = Config.scaling_factor 83 | for epoch in range(Config.num_epochs): 84 | model.train() 85 | # controlnet.train() 86 | epoch_loss = 0 87 | progress_bar = tqdm(total=len(train_dataloader)) 88 | progress_bar.set_description(f"Epoch {epoch+1}") 89 | for step, batch in enumerate(train_dataloader): 90 | mri = batch[0].float().to(device) 91 | # optimizer_con.zero_grad(set_to_none=True) 92 | optimizer_sd.zero_grad(set_to_none=True) 93 | with torch.no_grad(): 94 | latented_mri = vae.encode_stage_2_inputs(mri) 95 | latented_mri = latented_mri * scaling_factor 96 | latent_shape = list(latented_mri.shape);latent_shape[0] = Config.train_bc 97 | latented_noise = torch.randn_like(latented_mri) 98 | timesteps = torch.randint( 99 | 0, inferer.scheduler.num_train_timesteps, (latented_mri.shape[0],), device=latented_mri.device 100 | ).long() 101 | latented_mri_noised = scheduler.add_noise(latented_mri, latented_noise, timesteps) 102 | # down_block_res_samples, mid_block_res_sample = controlnet( 103 | # x=ffa_noised, timesteps=timesteps, controlnet_cond=slo 104 | # ) 105 | noise_pred = model( 106 | x=latented_mri_noised, 107 | t=timesteps, 108 | ) 109 | loss = F.mse_loss(noise_pred.float(), latented_noise.float()) 110 | loss.backward() 111 | optimizer_sd.step() 112 | epoch_loss += loss.item() 113 | logs = {"loss": epoch_loss / (step + 1)} 114 | progress_bar.update() 115 | progress_bar.set_postfix(logs) 116 | accelerator.log(logs, step=global_step) 117 | global_step += 1 118 | 119 | if (epoch + 1) % val_interval == 0 or epoch == cf['num_epochs'] - 1: 120 | model.eval() 121 | batch = next(iter(val_dataloader)) 122 | mri = batch[0].float().to(device) 123 | noise = torch.randn(latent_shape).to(device) 124 | progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110, position=0, leave=True) 125 | with torch.no_grad(): 126 | for t in progress_bar_sampling: 127 | t_tensor = torch.tensor([t], dtype=torch.long).to(device) 128 | noise_pred = model( 129 | x=noise, 130 | t=t_tensor, 131 | ) 132 | noise, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=noise) 133 | 134 | 135 | with torch.no_grad(): 136 | image = vae.decode_stage_2_outputs(noise / scaling_factor) 137 | image = torch.cat([mri, image], dim=-1) 138 | image = (make_grid(image, nrow=1).unsqueeze(0)+1)/2 139 | log_image = {"MRI": image.clip(0, 1)} 140 | save_path = j(Config.project_dir, 'image_save') 141 | os.makedirs(save_path, exist_ok=True) 142 | save_image(log_image["MRI"], j(save_path, f'epoch_{epoch + 1}_firstMRI.png')) 143 | 144 | # accelerator.trackers[0].log_images(log_image, epoch+1) 
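# Note: the commented-out line near the top of main() suggests estimating the latent
# scaling factor from data instead of using the fixed Config.scaling_factor. A minimal
# sketch of that estimation (illustrative only; it reuses the `vae`, `train_dataloader`
# and `device` defined above, and `estimated_scaling_factor` is a hypothetical name):
#
#     with torch.no_grad():
#         sample_mri = next(iter(train_dataloader))[0].float().to(device)
#         sample_latents = vae.encode_stage_2_inputs(sample_mri)
#         estimated_scaling_factor = (1.0 / sample_latents.std()).item()
#
#     # Latents would then be multiplied by estimated_scaling_factor before noising,
#     # and decoded with vae.decode_stage_2_outputs(latents / estimated_scaling_factor).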
145 | 146 | if (epoch + 1) % save_interval == 0 or epoch == cf['num_epochs'] - 1: 147 | save_path = j(Config.project_dir, 'model_save') 148 | os.makedirs(save_path, exist_ok=True) 149 | torch.save(model.state_dict(), j(save_path, 'model.pth')) 150 | # torch.save(controlnet.state_dict(), j(save_path, 'controlnet.pth')) 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /generative/metrics/ms_ssim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Sequence 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from monai.metrics.regression import RegressionMetric 19 | from monai.utils import MetricReduction, StrEnum, ensure_tuple_rep 20 | 21 | from generative.metrics.ssim import compute_ssim_and_cs 22 | 23 | 24 | class KernelType(StrEnum): 25 | GAUSSIAN = "gaussian" 26 | UNIFORM = "uniform" 27 | 28 | 29 | class MultiScaleSSIMMetric(RegressionMetric): 30 | """ 31 | Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM). 32 | 33 | [1] Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. 34 | Multiscale structural similarity for image quality assessment. 35 | In The Thirty-Seventh Asilomar Conference on Signals, Systems 36 | & Computers, 2003 (Vol. 2, pp. 1398-1402). Ieee. 37 | 38 | Args: 39 | spatial_dims: number of spatial dimensions of the input images. 40 | data_range: value range of input images. (usually 1.0 or 255) 41 | kernel_type: type of kernel, can be "gaussian" or "uniform". 42 | kernel_size: size of kernel 43 | kernel_sigma: standard deviation for Gaussian kernel. 44 | k1: stability constant used in the luminance denominator 45 | k2: stability constant used in the contrast denominator 46 | weights: parameters for image similarity and contrast sensitivity at different resolution scores. 47 | reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, 48 | available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, 49 | ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction 50 | get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) 51 | """ 52 | 53 | def __init__( 54 | self, 55 | spatial_dims: int, 56 | data_range: float = 1.0, 57 | kernel_type: KernelType | str = KernelType.GAUSSIAN, 58 | kernel_size: int | Sequence[int, ...] = 11, 59 | kernel_sigma: float | Sequence[float, ...] = 1.5, 60 | k1: float = 0.01, 61 | k2: float = 0.03, 62 | weights: Sequence[float, ...] 
= (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), 63 | reduction: MetricReduction | str = MetricReduction.MEAN, 64 | get_not_nans: bool = False, 65 | ) -> None: 66 | super().__init__(reduction=reduction, get_not_nans=get_not_nans) 67 | 68 | self.spatial_dims = spatial_dims 69 | self.data_range = data_range 70 | self.kernel_type = kernel_type 71 | 72 | if not isinstance(kernel_size, Sequence): 73 | kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) 74 | self.kernel_size = kernel_size 75 | 76 | if not isinstance(kernel_sigma, Sequence): 77 | kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) 78 | self.kernel_sigma = kernel_sigma 79 | 80 | self.k1 = k1 81 | self.k2 = k2 82 | self.weights = weights 83 | 84 | def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 85 | """ 86 | Args: 87 | y_pred: Predicted image. 88 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 89 | y: Reference image. 90 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 91 | 92 | Raises: 93 | ValueError: when `y_pred` is not a 2D or 3D image. 94 | """ 95 | dims = y_pred.ndimension() 96 | if self.spatial_dims == 2 and dims != 4: 97 | raise ValueError( 98 | f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " 99 | f"spatial dimensions, got {dims}." 100 | ) 101 | 102 | if self.spatial_dims == 3 and dims != 5: 103 | raise ValueError( 104 | f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" 105 | f" spatial dimensions, got {dims}." 106 | ) 107 | 108 | # check if image have enough size for the number of downsamplings and the size of the kernel 109 | weights_div = max(1, (len(self.weights) - 1)) ** 2 110 | y_pred_spatial_dims = y_pred.shape[2:] 111 | for i in range(len(y_pred_spatial_dims)): 112 | if y_pred_spatial_dims[i] // weights_div <= self.kernel_size[i] - 1: 113 | raise ValueError( 114 | f"For a given number of `weights` parameters {len(self.weights)} and kernel size " 115 | f"{self.kernel_size[i]}, the image height must be larger than " 116 | f"{(self.kernel_size[i] - 1) * weights_div}." 
117 | ) 118 | 119 | weights = torch.tensor(self.weights, device=y_pred.device, dtype=torch.float) 120 | 121 | avg_pool = getattr(F, f"avg_pool{self.spatial_dims}d") 122 | 123 | multiscale_list: list[torch.Tensor] = [] 124 | for _ in range(len(weights)): 125 | ssim, cs = compute_ssim_and_cs( 126 | y_pred=y_pred, 127 | y=y, 128 | spatial_dims=self.spatial_dims, 129 | data_range=self.data_range, 130 | kernel_type=self.kernel_type, 131 | kernel_size=self.kernel_size, 132 | kernel_sigma=self.kernel_sigma, 133 | k1=self.k1, 134 | k2=self.k2, 135 | ) 136 | 137 | cs_per_batch = cs.view(cs.shape[0], -1).mean(1) 138 | 139 | multiscale_list.append(torch.relu(cs_per_batch)) 140 | y_pred = avg_pool(y_pred, kernel_size=2) 141 | y = avg_pool(y, kernel_size=2) 142 | 143 | ssim = ssim.view(ssim.shape[0], -1).mean(1) 144 | multiscale_list[-1] = torch.relu(ssim) 145 | multiscale_list = torch.stack(multiscale_list) 146 | 147 | ms_ssim_value_full_image = torch.prod(multiscale_list ** weights.view(-1, 1), dim=0) 148 | 149 | ms_ssim_per_batch: torch.Tensor = ms_ssim_value_full_image.view(ms_ssim_value_full_image.shape[0], -1).mean( 150 | 1, keepdim=True 151 | ) 152 | 153 | return ms_ssim_per_batch 154 | -------------------------------------------------------------------------------- /my_vqvae/train_vae.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | from tqdm import tqdm;sys.path.append('./') 5 | 6 | import torch 7 | import numpy 8 | from torch.nn import functional as F 9 | from generative.networks.nets import AutoencoderKL 10 | from generative.losses import PatchAdversarialLoss, PerceptualLoss 11 | from generative.networks.nets import PatchDiscriminator 12 | from utils.Condition_aug_dataloader import double_form_dataloader 13 | from utils.common import get_parameters, save_config_to_yaml 14 | from config.vae.config_monaivae_zheer import Config 15 | from os.path import join as j 16 | from accelerate import Accelerator 17 | from torchvision.utils import make_grid, save_image 18 | import os 19 | 20 | def main(): 21 | mask = torch.load(Config.mask_path, map_location=torch.device('cpu')) 22 | mask = numpy.squeeze(mask) 23 | cf = save_config_to_yaml(Config, Config.project_dir) 24 | accelerator = Accelerator(**get_parameters(Accelerator, cf)) 25 | train_dataloader = double_form_dataloader(Config.data_path, 26 | Config.sample_size, 27 | Config.train_bc, 28 | Config.mode, 29 | read_channel='gray', 30 | mask=mask) 31 | device = 'cuda' 32 | val_dataloader = double_form_dataloader( 33 | Config.eval_path, 34 | Config.sample_size, 35 | Config.eval_bc, 36 | Config.mode, 37 | read_channel='gray', 38 | mask=mask) 39 | 40 | up_and_down = Config.up_and_down 41 | attention_levels = (False, ) * len(up_and_down) 42 | vae = AutoencoderKL( 43 | spatial_dims=2, 44 | in_channels=Config.in_channels, 45 | out_channels=Config.out_channels, 46 | num_channels=Config.up_and_down, 47 | latent_channels=4, 48 | num_res_blocks=Config.num_res_layers, 49 | attention_levels = attention_levels 50 | ) 51 | 52 | discriminator = PatchDiscriminator(spatial_dims=2, num_channels=64, 53 | in_channels=Config.out_channels, out_channels=1) 54 | 55 | perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="vgg") 56 | vae, discriminator, perceptual_loss = vae.to(device), discriminator.to(device), perceptual_loss.to(device) 57 | 58 | adv_loss = PatchAdversarialLoss(criterion="least_squares") 59 | adv_weight = 0.01 60 | perceptual_weight = 0.001 61 | kl_weight = 1e-6 62 | 63 | 
optimizer_g = torch.optim.Adam(params=vae.parameters(), lr=1e-4) 64 | optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4) 65 | 66 | val_interval = Config.val_inter 67 | save_interval = Config.save_inter 68 | autoencoder_warm_up_n_epochs = 0 if (len(Config.vae_path) and len(Config.dis_path)) else Config.autoencoder_warm_up_n_epochs 69 | 70 | 71 | if len(Config.log_with): 72 | accelerator.init_trackers('train_example') 73 | 74 | global_step = 0 75 | for epoch in range(Config.num_epochs): 76 | progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process) 77 | progress_bar.set_description(f"Epoch {epoch+1}") 78 | vae.train() 79 | discriminator.train() 80 | for step, batch in enumerate(train_dataloader): 81 | images = batch[0].to(device).clip(-1, 1).float() 82 | # print("Image size: ", images.shape) 83 | optimizer_g.zero_grad(set_to_none=True) 84 | 85 | reconstruction, z_mu, z_sigma = vae(images) 86 | 87 | recons_loss = F.mse_loss(reconstruction.float(), images.float()) 88 | 89 | kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3]) 90 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 91 | p_loss = perceptual_loss(reconstruction.float(), images.float()) 92 | loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss 93 | # loss_g = recons_loss + perceptual_weight * p_loss 94 | 95 | if epoch+1 > autoencoder_warm_up_n_epochs: 96 | logits_fake = discriminator(reconstruction.contiguous().float())[-1] 97 | generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False) 98 | loss_g += adv_weight * generator_loss 99 | 100 | loss_g.backward() 101 | optimizer_g.step() 102 | 103 | if epoch+1 > autoencoder_warm_up_n_epochs: 104 | optimizer_d.zero_grad(set_to_none=True) 105 | 106 | logits_fake = discriminator(reconstruction.contiguous().detach())[-1] 107 | loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True) 108 | logits_real = discriminator(images.contiguous().detach())[-1] 109 | loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True) 110 | discriminator_loss = (loss_d_fake + loss_d_real) * 0.5 111 | 112 | loss_d = adv_weight * discriminator_loss 113 | 114 | loss_d.backward() 115 | optimizer_d.step() 116 | 117 | progress_bar.update(1) 118 | logs = {"gen_loss": loss_g.detach().item(), 119 | "dis_loss": loss_d.detach().item() if epoch+1 > autoencoder_warm_up_n_epochs else 0, 120 | "kl_loss": kl_loss.detach().item(), 121 | "pp_loss": p_loss.detach().item(), 122 | "adv_loss": generator_loss.detach().item() if epoch+1 > autoencoder_warm_up_n_epochs else 0} 123 | progress_bar.set_postfix(**logs) 124 | # accelerator.log(logs, step=global_step) 125 | global_step += 1 126 | 127 | if (epoch + 1) % val_interval == 0 or epoch == cf['num_epochs'] - 1: 128 | vae.eval() 129 | total_mse_loss = 0.0 130 | with torch.no_grad(): 131 | for batch_idx, batch in enumerate(val_dataloader): 132 | mri = batch[0].to(device).float() 133 | val_recon_1, _, _ = vae(mri) 134 | mse_loss_1 = F.mse_loss(val_recon_1, mri) 135 | total_mse_loss += mse_loss_1 136 | if batch_idx == 0: 137 | save_path = j(Config.project_dir, 'image_save') 138 | os.makedirs(save_path, exist_ok=True) 139 | val_recon = torch.cat([mri, val_recon_1], dim=-1) 140 | val_recon = make_grid(val_recon, nrow=1).unsqueeze(0) 141 | log_image = {"MRI": (val_recon + 1) / 2} 142 | save_image(log_image["MRI"].clip(0, 1), j(save_path, f'epoch_{epoch + 1}_firstMRI.png')) 143 | 
average_mse_loss = total_mse_loss / len(val_dataloader) 144 | print(f'Epoch {epoch + 1}, Average MSE Loss: {average_mse_loss.item()}') 145 | del average_mse_loss, total_mse_loss, mse_loss_1 146 | 147 | 148 | if (epoch + 1) % save_interval == 0 or epoch == cf['num_epochs'] - 1: 149 | gen_path = j(Config.project_dir, 'gen_save') 150 | dis_path = j(Config.project_dir, 'dis_save') 151 | os.makedirs(gen_path, exist_ok=True) 152 | os.makedirs(dis_path, exist_ok=True) 153 | torch.save(vae.state_dict(), j(gen_path, 'vae.pth')) 154 | torch.save(discriminator.state_dict(), j(dis_path, 'dis.pth')) 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /stable_diffusion/trian_model.py: -------------------------------------------------------------------------------- 1 | import sys;sys.path.append('./') 2 | 3 | from tqdm import tqdm 4 | import numpy 5 | from utils.Condition_aug_dataloader import double_form_dataloader 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from generative.networks.nets import AutoencoderKL 10 | from generative.inferers import DiffusionInferer 11 | from generative.networks.schedulers import DDPMScheduler 12 | 13 | from utils.common import get_parameters, save_config_to_yaml, one_to_three 14 | from config.diffusion.config_controlnet import Config 15 | from os.path import join as j 16 | from accelerate import Accelerator 17 | from torchvision.utils import make_grid, save_image 18 | from unet.MF_UKAN import MF_UKAN 19 | from unet.MC_model import MC_MODEL 20 | import os 21 | 22 | def main(): 23 | mask = torch.load(Config.mask_path, map_location=torch.device('cpu')) 24 | mask = numpy.squeeze(mask) 25 | cf = save_config_to_yaml(Config, Config.project_dir) 26 | accelerator = Accelerator(**get_parameters(Accelerator, cf)) 27 | train_dataloader = double_form_dataloader(Config.data_path, 28 | Config.sample_size, 29 | Config.train_bc, 30 | Config.mode, 31 | read_channel='gray', 32 | mask = mask) 33 | device = 'cuda' 34 | val_dataloader = double_form_dataloader( 35 | Config.eval_path, 36 | Config.sample_size, 37 | Config.eval_bc, 38 | Config.mode, 39 | read_channel='gray', 40 | mask = mask) 41 | 42 | attention_levels = (False, ) * len(Config.up_and_down) 43 | vae = AutoencoderKL( 44 | spatial_dims=2, 45 | in_channels=Config.in_channels, 46 | out_channels=Config.out_channels, 47 | num_channels=Config.up_and_down, 48 | latent_channels=4, 49 | num_res_blocks=Config.num_res_layers, 50 | attention_levels = attention_levels 51 | ) 52 | vae = vae.eval().to(device) 53 | if len(Config.vae_resume_path): 54 | vae.load_state_dict(torch.load(Config.vae_resume_path)) 55 | 56 | model =MF_UKAN( 57 | T=1000, 58 | ch=64, 59 | ch_mult=[1, 2, 3, 4], 60 | attn=[2], 61 | num_res_blocks=2, 62 | dropout=0.15).to(device) 63 | if len(Config.sd_resume_path): 64 | model.load_state_dict(torch.load(Config.sd_resume_path), strict=False) 65 | model = model.to(device) 66 | 67 | mc_model =MC_MODEL( 68 | T=1000, 69 | ch=64, 70 | ch_mult=[1, 2, 3, 4], 71 | attn=[2], 72 | num_res_blocks=2, 73 | dropout=0.15).to(device) 74 | if len(Config.mc_model_path): 75 | mc_model.load_state_dict(torch.load(Config.mc_model_path), strict=False) 76 | mc_model = mc_model.to(device) 77 | 78 | 79 | scheduler = DDPMScheduler(num_train_timesteps=1000) 80 | optimizer_con = torch.optim.Adam(params=mc_model.parameters(), lr=2.5e-5) 81 | optimizer_sd = torch.optim.Adam(params=model.parameters(), lr=2.5e-5) 82 | inferer = DiffusionInferer(scheduler) 83 
| 84 | val_interval = Config.val_inter 85 | save_interval = Config.save_inter 86 | 87 | if len(Config.log_with): 88 | accelerator.init_trackers('train_example') 89 | 90 | global_step = 0 91 | latent_shape = None 92 | # scaling_factor = 1 / torch.std(next(iter(train_dataloader))[0][1]) 93 | scaling_factor = Config.scaling_factor 94 | for epoch in range(Config.num_epochs): 95 | model.train() 96 | mc_model.train() 97 | epoch_loss = 0 98 | progress_bar = tqdm(total=len(train_dataloader)) 99 | progress_bar.set_description(f"Epoch {epoch+1}") 100 | for step, batch in enumerate(train_dataloader): 101 | mri = batch[0].float().to(device) 102 | mri_mask = batch[1].float().to(device) 103 | optimizer_con.zero_grad(set_to_none=True) 104 | optimizer_sd.zero_grad(set_to_none=True) 105 | with torch.no_grad(): 106 | latented_mri = vae.encode_stage_2_inputs(mri) 107 | latented_mri = latented_mri * scaling_factor 108 | latent_shape = list(latented_mri.shape);latent_shape[0] = Config.train_bc 109 | latent_noise = torch.randn_like(latented_mri) + 0.1 * torch.randn(latented_mri.shape[0], latented_mri.shape[1], 1, 1).to(latented_mri.device) 110 | timesteps = torch.randint( 111 | 0, inferer.scheduler.num_train_timesteps, (latented_mri.shape[0],), device=latented_mri.device 112 | ).long() 113 | latented_mri_noised = scheduler.add_noise(latented_mri, latent_noise, timesteps) 114 | down_block_res_samples, _ = mc_model( 115 | x=latented_mri_noised, t=timesteps, controlnet_cond=mri_mask 116 | ) 117 | noise_pred = model( 118 | x=latented_mri_noised, 119 | t=timesteps, 120 | down_block_additional_residuals=down_block_res_samples 121 | ) 122 | loss = F.mse_loss(noise_pred.float(), latent_noise.float()) 123 | loss.backward() 124 | optimizer_con.step() 125 | optimizer_sd.step() 126 | epoch_loss += loss.item() 127 | logs = {"loss": epoch_loss / (step + 1)} 128 | progress_bar.update() 129 | progress_bar.set_postfix(logs) 130 | accelerator.log(logs, step=global_step) 131 | global_step += 1 132 | 133 | if (epoch + 1) % val_interval == 0 or epoch == cf['num_epochs'] - 1: 134 | model.eval() 135 | mc_model.eval() 136 | batch = next(iter(val_dataloader)) 137 | mri = batch[0].float().to(device) 138 | mri_mask = batch[1].float().to(device) 139 | noise = torch.randn(latent_shape).to(device) 140 | progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110, position=0, leave=True) 141 | with torch.no_grad(): 142 | for t in progress_bar_sampling: 143 | t_tensor = torch.tensor([t], dtype=torch.long).to(device) 144 | down_block_res_samples, _ = mc_model( 145 | x=noise, t=t_tensor, controlnet_cond=mri_mask 146 | ) 147 | noise_pred = model( 148 | noise, 149 | t=t_tensor, 150 | down_block_additional_residuals=down_block_res_samples 151 | ) 152 | noise, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=noise) 153 | 154 | 155 | with torch.no_grad(): 156 | image = vae.decode_stage_2_outputs(noise / scaling_factor) 157 | mri, image = one_to_three(mri), one_to_three(image) 158 | image = torch.cat([mri, image], dim=-1) 159 | image = (make_grid(image, nrow=1).unsqueeze(0)+1)/2 160 | log_image = {"MRI": image.clamp(0, 1)} 161 | save_path = j(Config.project_dir, 'image_save') 162 | os.makedirs(save_path, exist_ok=True) 163 | save_image(log_image["MRI"], j(save_path, f'epoch_{epoch + 1}_MRI.png')) 164 | 165 | # accelerator.trackers[0].log_images(log_image, epoch+1) 166 | 167 | if (epoch + 1) % save_interval == 0 or epoch == cf['num_epochs'] - 1: 168 | save_path = j(Config.project_dir, 'model_save') 169 | 
os.makedirs(save_path, exist_ok=True) 170 | torch.save(model.state_dict(), j(save_path, 'model.pth')) 171 | torch.save(mc_model.state_dict(), j(save_path, 'mc_modle.pth')) 172 | 173 | if __name__ == '__main__': 174 | main() 175 | -------------------------------------------------------------------------------- /generative/losses/adversarial_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import warnings 15 | 16 | import torch 17 | from monai.networks.layers.utils import get_act_layer 18 | from monai.utils import LossReduction 19 | from monai.utils.enums import StrEnum 20 | from torch.nn.modules.loss import _Loss 21 | 22 | 23 | class AdversarialCriterions(StrEnum): 24 | BCE = "bce" 25 | HINGE = "hinge" 26 | LEAST_SQUARE = "least_squares" 27 | 28 | 29 | class PatchAdversarialLoss(_Loss): 30 | """ 31 | Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator. 32 | Warning: due to the possibility of using different criterions, the output of the discrimination 33 | mustn't be passed to a final activation layer. That is taken care of internally within the loss. 34 | 35 | Args: 36 | reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. 37 | Defaults to ``"mean"``. 38 | - ``"none"``: no reduction will be applied. 39 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 40 | - ``"sum"``: the output will be summed. 41 | criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs. 42 | Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs 43 | through an activation layer prior to calling the loss. 44 | no_activation_leastsq: if True, the activation layer in the case of least-squares is removed. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | reduction: LossReduction | str = LossReduction.MEAN, 50 | criterion: str = AdversarialCriterions.LEAST_SQUARE.value, 51 | no_activation_leastsq: bool = False, 52 | ) -> None: 53 | super().__init__(reduction=LossReduction(reduction).value) 54 | 55 | if criterion.lower() not in [m.value for m in AdversarialCriterions]: 56 | raise ValueError( 57 | "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" 58 | % ", ".join([m.value for m in AdversarialCriterions]) 59 | ) 60 | 61 | # Depending on the criterion, a different activation layer is used. 
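# The branches below map each criterion to its activation and base loss:
#   - "bce": Sigmoid activation + BCELoss, with real/fake targets of 1.0 / 0.0.
#   - "hinge": Tanh activation; the fake label becomes -1.0 and the hinge term itself
#     is computed in forward_single().
#   - "least_squares": LeakyReLU(0.05) activation (skipped if no_activation_leastsq) + MSELoss.
# Typical usage in this repository (see my_vqvae/train_vae.py above):
#   generator term:      adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
#   discriminator terms: adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
#                        adv_loss(logits_real, target_is_real=True, for_discriminator=True)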
62 | self.real_label = 1.0 63 | self.fake_label = 0.0 64 | if criterion == AdversarialCriterions.BCE.value: 65 | self.activation = get_act_layer("SIGMOID") 66 | self.loss_fct = torch.nn.BCELoss(reduction=reduction) 67 | elif criterion == AdversarialCriterions.HINGE.value: 68 | self.activation = get_act_layer("TANH") 69 | self.fake_label = -1.0 70 | elif criterion == AdversarialCriterions.LEAST_SQUARE.value: 71 | if no_activation_leastsq: 72 | self.activation = None 73 | else: 74 | self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05})) 75 | self.loss_fct = torch.nn.MSELoss(reduction=reduction) 76 | 77 | self.criterion = criterion 78 | self.reduction = reduction 79 | 80 | def get_target_tensor(self, input: torch.FloatTensor, target_is_real: bool) -> torch.Tensor: 81 | """ 82 | Gets the ground truth tensor for the discriminator depending on whether the input is real or fake. 83 | 84 | Args: 85 | input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale 86 | discriminator). This is used to match the shape. 87 | target_is_real: whether the input is real or wannabe-real (1s) or fake (0s). 88 | Returns: 89 | """ 90 | filling_label = self.real_label if target_is_real else self.fake_label 91 | label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device) 92 | label_tensor.requires_grad_(False) 93 | return label_tensor.expand_as(input) 94 | 95 | def get_zero_tensor(self, input: torch.FloatTensor) -> torch.Tensor: 96 | """ 97 | Gets a zero tensor. 98 | 99 | Args: 100 | input: tensor which shape you want the zeros tensor to correspond to. 101 | Returns: 102 | """ 103 | 104 | zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device) 105 | zero_label_tensor.requires_grad_(False) 106 | return zero_label_tensor.expand_as(input) 107 | 108 | def forward( 109 | self, input: torch.FloatTensor | list, target_is_real: bool, for_discriminator: bool 110 | ) -> torch.Tensor | list[torch.Tensor]: 111 | """ 112 | 113 | Args: 114 | input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of 115 | tensors or a tensor; they shouldn't have gone through an activation layer. 116 | target_is_real: whereas the input corresponds to discriminator output for real or fake images 117 | for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last 118 | case, target_is_real is set to True, as the generator wants the input to be dimmed as real. 119 | Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale 120 | discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the 121 | summed or mean loss over the tensor and discriminator/s. 122 | 123 | """ 124 | 125 | if not for_discriminator and not target_is_real: 126 | target_is_real = True # With generator, we always want this to be true! 127 | warnings.warn( 128 | "Variable target_is_real has been set to False, but for_discriminator is set" 129 | "to False. To optimise a generator, target_is_real must be set to True." 
130 | ) 131 | 132 | if type(input) is not list: 133 | input = [input] 134 | target_ = [] 135 | for _, disc_out in enumerate(input): 136 | if self.criterion != AdversarialCriterions.HINGE.value: 137 | target_.append(self.get_target_tensor(disc_out, target_is_real)) 138 | else: 139 | target_.append(self.get_zero_tensor(disc_out)) 140 | 141 | # Loss calculation 142 | loss = [] 143 | for disc_ind, disc_out in enumerate(input): 144 | if self.activation is not None: 145 | disc_out = self.activation(disc_out) 146 | if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real: 147 | loss_ = self.forward_single(-disc_out, target_[disc_ind]) 148 | else: 149 | loss_ = self.forward_single(disc_out, target_[disc_ind]) 150 | loss.append(loss_) 151 | 152 | if loss is not None: 153 | if self.reduction == LossReduction.MEAN.value: 154 | loss = torch.mean(torch.stack(loss)) 155 | elif self.reduction == LossReduction.SUM.value: 156 | loss = torch.sum(torch.stack(loss)) 157 | 158 | return loss 159 | 160 | def forward_single(self, input: torch.FloatTensor, target: torch.FloatTensor) -> torch.Tensor | None: 161 | if ( 162 | self.criterion == AdversarialCriterions.BCE.value 163 | or self.criterion == AdversarialCriterions.LEAST_SQUARE.value 164 | ): 165 | return self.loss_fct(input, target) 166 | elif self.criterion == AdversarialCriterions.HINGE.value: 167 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 168 | return -torch.mean(minval) 169 | else: 170 | return None 171 | -------------------------------------------------------------------------------- /generative/utils/ordering.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import numpy as np 15 | import torch 16 | 17 | from generative.utils.enums import OrderingTransformations, OrderingType 18 | 19 | 20 | class Ordering: 21 | """ 22 | Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with 23 | one of the following transformations: 24 | - Reflection - see np.flip for more details. 25 | - Transposition - see np.transpose for more details. 26 | - 90-degree rotation - see np.rot90 for more details. 27 | 28 | The transformations are applied in the order specified by the transformation_order parameter. 29 | 30 | Args: 31 | ordering_type: The ordering type. One of the following: 32 | - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from 33 | top to bottom. Also called a row major ordering. 34 | - 's_curve': The image is projected into a 1D sequence by scanning the image in a circular snake like 35 | pattern from top left towards right gowing in a spiral towards the center. 36 | - 'random': The image is projected into a 1D sequence by randomly shuffling the image. 37 | spatial_dims: The number of spatial dimensions of the image. 
38 | dimensions: The dimensions of the image. 39 | reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension. 40 | transpositions_axes: A tuple of tuples indicating the axes to transpose the image along. 41 | rot90_axes: A tuple of tuples indicating the axes to rotate the image along. 42 | transformation_order: The order in which to apply the transformations. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | ordering_type: str, 48 | spatial_dims: int, 49 | dimensions: tuple[int, int, int] | tuple[int, int, int, int], 50 | reflected_spatial_dims: tuple[bool, bool] | tuple[bool, bool, bool] = (), 51 | transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] = (), 52 | rot90_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] = (), 53 | transformation_order: tuple[str, ...] = ( 54 | OrderingTransformations.TRANSPOSE.value, 55 | OrderingTransformations.ROTATE_90.value, 56 | OrderingTransformations.REFLECT.value, 57 | ), 58 | ) -> None: 59 | super().__init__() 60 | self.ordering_type = ordering_type 61 | 62 | if self.ordering_type not in list(OrderingType): 63 | raise ValueError( 64 | f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}." 65 | ) 66 | 67 | self.spatial_dims = spatial_dims 68 | self.dimensions = dimensions 69 | 70 | if len(dimensions) != self.spatial_dims + 1: 71 | raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.") 72 | 73 | self.reflected_spatial_dims = reflected_spatial_dims 74 | self.transpositions_axes = transpositions_axes 75 | self.rot90_axes = rot90_axes 76 | if len(set(transformation_order)) != len(transformation_order): 77 | raise ValueError(f"No duplicates are allowed. Received {transformation_order}.") 78 | 79 | for transformation in transformation_order: 80 | if transformation not in list(OrderingTransformations): 81 | raise ValueError( 82 | f"Valid transformations are {list(OrderingTransformations)} but received {transformation}." 
83 | ) 84 | self.transformation_order = transformation_order 85 | 86 | self.template = self._create_template() 87 | self._sequence_ordering = self._create_ordering() 88 | self._revert_sequence_ordering = np.argsort(self._sequence_ordering) 89 | 90 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 91 | x = x[self._sequence_ordering] 92 | 93 | return x 94 | 95 | def get_sequence_ordering(self) -> np.ndarray: 96 | return self._sequence_ordering 97 | 98 | def get_revert_sequence_ordering(self) -> np.ndarray: 99 | return self._revert_sequence_ordering 100 | 101 | def _create_ordering(self) -> np.ndarray: 102 | self.template = self._transform_template() 103 | order = self._order_template(template=self.template) 104 | 105 | return order 106 | 107 | def _create_template(self) -> np.ndarray: 108 | spatial_dimensions = self.dimensions[1:] 109 | template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions) 110 | 111 | return template 112 | 113 | def _transform_template(self) -> np.ndarray: 114 | for transformation in self.transformation_order: 115 | if transformation == OrderingTransformations.TRANSPOSE.value: 116 | self.template = self._transpose_template(template=self.template) 117 | elif transformation == OrderingTransformations.ROTATE_90.value: 118 | self.template = self._rot90_template(template=self.template) 119 | elif transformation == OrderingTransformations.REFLECT.value: 120 | self.template = self._flip_template(template=self.template) 121 | 122 | return self.template 123 | 124 | def _transpose_template(self, template: np.ndarray) -> np.ndarray: 125 | for axes in self.transpositions_axes: 126 | template = np.transpose(template, axes=axes) 127 | 128 | return template 129 | 130 | def _flip_template(self, template: np.ndarray) -> np.ndarray: 131 | for axis, to_reflect in enumerate(self.reflected_spatial_dims): 132 | template = np.flip(template, axis=axis) if to_reflect else template 133 | 134 | return template 135 | 136 | def _rot90_template(self, template: np.ndarray) -> np.ndarray: 137 | for axes in self.rot90_axes: 138 | template = np.rot90(template, axes=axes) 139 | 140 | return template 141 | 142 | def _order_template(self, template: np.ndarray) -> np.ndarray: 143 | depths = None 144 | if self.spatial_dims == 2: 145 | rows, columns = template.shape[0], template.shape[1] 146 | else: 147 | rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2]) 148 | 149 | sequence = eval(f"self.{self.ordering_type}_idx")(rows, columns, depths) 150 | 151 | ordering = np.array([template[tuple(e)] for e in sequence]) 152 | 153 | return ordering 154 | 155 | @staticmethod 156 | def raster_scan_idx(rows: int, cols: int, depths: int = None) -> np.ndarray: 157 | idx = [] 158 | 159 | for r in range(rows): 160 | for c in range(cols): 161 | if depths: 162 | for d in range(depths): 163 | idx.append((r, c, d)) 164 | else: 165 | idx.append((r, c)) 166 | 167 | idx = np.array(idx) 168 | 169 | return idx 170 | 171 | @staticmethod 172 | def s_curve_idx(rows: int, cols: int, depths: int = None) -> np.ndarray: 173 | idx = [] 174 | 175 | for r in range(rows): 176 | col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1) 177 | for c in col_idx: 178 | if depths: 179 | depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1) 180 | 181 | for d in depth_idx: 182 | idx.append((r, c, d)) 183 | else: 184 | idx.append((r, c)) 185 | 186 | idx = np.array(idx) 187 | 188 | return idx 189 | 190 | @staticmethod 191 | def random_idx(rows: int, cols: int, depths: int = 
None) -> np.ndarray: 192 | idx = [] 193 | 194 | for r in range(rows): 195 | for c in range(cols): 196 | if depths: 197 | for d in range(depths): 198 | idx.append((r, c, d)) 199 | else: 200 | idx.append((r, c)) 201 | 202 | idx = np.array(idx) 203 | np.random.shuffle(idx) 204 | 205 | return idx 206 | -------------------------------------------------------------------------------- /generative/networks/schedulers/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | # 12 | # ========================================================================= 13 | # Adapted from https://github.com/huggingface/diffusers 14 | # which has the following license: 15 | # https://github.com/huggingface/diffusers/blob/main/LICENSE 16 | # 17 | # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. 18 | # 19 | # Licensed under the Apache License, Version 2.0 (the "License"); 20 | # you may not use this file except in compliance with the License. 21 | # You may obtain a copy of the License at 22 | # 23 | # http://www.apache.org/licenses/LICENSE-2.0 24 | # 25 | # Unless required by applicable law or agreed to in writing, software 26 | # distributed under the License is distributed on an "AS IS" BASIS, 27 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28 | # See the License for the specific language governing permissions and 29 | # limitations under the License. 30 | # ========================================================================= 31 | 32 | 33 | from __future__ import annotations 34 | 35 | import torch 36 | import torch.nn as nn 37 | 38 | from generative.utils import ComponentStore, unsqueeze_right 39 | 40 | NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules") 41 | 42 | 43 | @NoiseSchedules.add_def("linear_beta", "Linear beta schedule") 44 | def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): 45 | """ 46 | Linear beta noise schedule function. 47 | 48 | Args: 49 | num_train_timesteps: number of timesteps 50 | beta_start: start of beta range, default 1e-4 51 | beta_end: end of beta range, default 2e-2 52 | 53 | Returns: 54 | betas: beta schedule tensor 55 | """ 56 | return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 57 | 58 | 59 | @NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule") 60 | def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2): 61 | """ 62 | Scaled linear beta noise schedule function. 
63 | 64 | Args: 65 | num_train_timesteps: number of timesteps 66 | beta_start: start of beta range, default 1e-4 67 | beta_end: end of beta range, default 2e-2 68 | 69 | Returns: 70 | betas: beta schedule tensor 71 | """ 72 | return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 73 | 74 | 75 | @NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule") 76 | def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6): 77 | """ 78 | Sigmoid beta noise schedule function. 79 | 80 | Args: 81 | num_train_timesteps: number of timesteps 82 | beta_start: start of beta range, default 1e-4 83 | beta_end: end of beta range, default 2e-2 84 | sig_range: pos/neg range of sigmoid input, default 6 85 | 86 | Returns: 87 | betas: beta schedule tensor 88 | """ 89 | betas = torch.linspace(-sig_range, sig_range, num_train_timesteps) 90 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 91 | 92 | 93 | @NoiseSchedules.add_def("cosine", "Cosine schedule") 94 | def _cosine_beta(num_train_timesteps: int, s: float = 8e-3): 95 | """ 96 | Cosine noise schedule, see https://arxiv.org/abs/2102.09672 97 | 98 | Args: 99 | num_train_timesteps: number of timesteps 100 | s: smoothing factor, default 8e-3 (see referenced paper) 101 | 102 | Returns: 103 | (betas, alphas, alpha_cumprod) values 104 | """ 105 | x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1) 106 | alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 107 | alphas_cumprod /= alphas_cumprod[0].item() 108 | alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999) 109 | betas = 1.0 - alphas 110 | return betas, alphas, alphas_cumprod[:-1] 111 | 112 | 113 | class Scheduler(nn.Module): 114 | """ 115 | Base class for other schedulers based on a noise schedule function. 116 | 117 | This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here 118 | the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`, 119 | which is the name of a component in NoiseSchedules. These components must all be callables which return either 120 | the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions 121 | can be provided by using the NoiseSchedules.add_def, for example: 122 | 123 | .. code-block:: python 124 | from generative.networks.schedulers import NoiseSchedules, DDPMScheduler 125 | 126 | @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function") 127 | def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2): 128 | return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 129 | 130 | scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule") 131 | 132 | All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of 133 | timesteps the schedule is for, otherwise any other arguments can be given which will be passed by keyword through 134 | the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules 135 | to get a listing of stored objects with their docstring descriptions. 
136 | 137 | Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule 138 | type, this now replaced with `schedule` and most names used with the previous argument now have "_beta" appended 139 | to them, eg. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are 140 | still used for some schedules but these are provided as keyword arguments now. 141 | 142 | Args: 143 | num_train_timesteps: number of diffusion steps used to train the model. 144 | schedule: member of NoiseSchedules, 145 | a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple 146 | schedule_args: arguments to pass to the schedule function 147 | """ 148 | 149 | def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None: 150 | super().__init__() 151 | schedule_args["num_train_timesteps"] = num_train_timesteps 152 | noise_sched = NoiseSchedules[schedule](**schedule_args) 153 | 154 | # set betas, alphas, alphas_cumprod based off return value from noise function 155 | if isinstance(noise_sched, tuple): 156 | self.betas, self.alphas, self.alphas_cumprod = noise_sched 157 | else: 158 | self.betas = noise_sched 159 | self.alphas = 1.0 - self.betas 160 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 161 | 162 | self.num_train_timesteps = num_train_timesteps 163 | self.one = torch.tensor(1.0) 164 | 165 | # settable values 166 | self.num_inference_steps = None 167 | self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1) 168 | 169 | def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: 170 | """ 171 | Add noise to the original samples. 172 | 173 | Args: 174 | original_samples: original samples 175 | noise: noise to add to samples 176 | timesteps: timesteps tensor indicating the timestep to be computed for each sample. 
177 | 178 | Returns: 179 | noisy_samples: sample with added noise 180 | """ 181 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 182 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 183 | timesteps = timesteps.to(original_samples.device) 184 | 185 | sqrt_alpha_cumprod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim) 186 | sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim) 187 | 188 | noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise 189 | return noisy_samples 190 | 191 | def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: 192 | # Make sure alphas_cumprod and timestep have same device and dtype as sample 193 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) 194 | timesteps = timesteps.to(sample.device) 195 | 196 | sqrt_alpha_prod = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim) 197 | sqrt_one_minus_alpha_prod = unsqueeze_right((1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim) 198 | 199 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 200 | return velocity 201 | -------------------------------------------------------------------------------- /generative/metrics/ssim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | from collections.abc import Sequence 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from monai.metrics.regression import RegressionMetric 19 | from monai.utils import MetricReduction, StrEnum, convert_data_type, ensure_tuple_rep 20 | from monai.utils.type_conversion import convert_to_dst_type 21 | 22 | 23 | class KernelType(StrEnum): 24 | GAUSSIAN = "gaussian" 25 | UNIFORM = "uniform" 26 | 27 | 28 | class SSIMMetric(RegressionMetric): 29 | r""" 30 | Computes the Structural Similarity Index Measure (SSIM). 31 | 32 | .. math:: 33 | \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \ 34 | \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)} 35 | 36 | For more info, visit 37 | https://vicuesoft.com/glossary/term/ssim-ms-ssim/ 38 | 39 | SSIM reference paper: 40 | Wang, Zhou, et al. "Image quality assessment: from error visibility to structural 41 | similarity." IEEE transactions on image processing 13.4 (2004): 600-612. 42 | 43 | Args: 44 | spatial_dims: number of spatial dimensions of the input images. 45 | data_range: value range of input images. (usually 1.0 or 255) 46 | kernel_type: type of kernel, can be "gaussian" or "uniform". 47 | kernel_size: size of kernel 48 | kernel_sigma: standard deviation for Gaussian kernel. 
49 | k1: stability constant used in the luminance denominator 50 | k2: stability constant used in the contrast denominator 51 | reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, 52 | available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, 53 | ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction 54 | get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans) 55 | """ 56 | 57 | def __init__( 58 | self, 59 | spatial_dims: int, 60 | data_range: float = 1.0, 61 | kernel_type: KernelType | str = KernelType.GAUSSIAN, 62 | kernel_size: int | Sequence[int, ...] = 11, 63 | kernel_sigma: float | Sequence[float, ...] = 1.5, 64 | k1: float = 0.01, 65 | k2: float = 0.03, 66 | reduction: MetricReduction | str = MetricReduction.MEAN, 67 | get_not_nans: bool = False, 68 | ) -> None: 69 | super().__init__(reduction=reduction, get_not_nans=get_not_nans) 70 | 71 | self.spatial_dims = spatial_dims 72 | self.data_range = data_range 73 | self.kernel_type = kernel_type 74 | 75 | if not isinstance(kernel_size, Sequence): 76 | kernel_size = ensure_tuple_rep(kernel_size, spatial_dims) 77 | self.kernel_size = kernel_size 78 | 79 | if not isinstance(kernel_sigma, Sequence): 80 | kernel_sigma = ensure_tuple_rep(kernel_sigma, spatial_dims) 81 | self.kernel_sigma = kernel_sigma 82 | 83 | self.k1 = k1 84 | self.k2 = k2 85 | 86 | def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 87 | """ 88 | Args: 89 | y_pred: Predicted image. 90 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 91 | y: Reference image. 92 | It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D]. 93 | 94 | Raises: 95 | ValueError: when `y_pred` is not a 2D or 3D image. 96 | """ 97 | dims = y_pred.ndimension() 98 | if self.spatial_dims == 2 and dims != 4: 99 | raise ValueError( 100 | f"y_pred should have 4 dimensions (batch, channel, height, width) when using {self.spatial_dims} " 101 | f"spatial dimensions, got {dims}." 102 | ) 103 | 104 | if self.spatial_dims == 3 and dims != 5: 105 | raise ValueError( 106 | f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}" 107 | f" spatial dimensions, got {dims}." 108 | ) 109 | 110 | ssim_value_full_image, _ = compute_ssim_and_cs( 111 | y_pred=y_pred, 112 | y=y, 113 | spatial_dims=self.spatial_dims, 114 | data_range=self.data_range, 115 | kernel_type=self.kernel_type, 116 | kernel_size=self.kernel_size, 117 | kernel_sigma=self.kernel_sigma, 118 | k1=self.k1, 119 | k2=self.k2, 120 | ) 121 | 122 | ssim_per_batch: torch.Tensor = ssim_value_full_image.view(ssim_value_full_image.shape[0], -1).mean( 123 | 1, keepdim=True 124 | ) 125 | 126 | return ssim_per_batch 127 | 128 | 129 | def _gaussian_kernel( 130 | spatial_dims: int, num_channels: int, kernel_size: Sequence[int, ...], kernel_sigma: Sequence[float, ...] 131 | ) -> torch.Tensor: 132 | """Computes 2D or 3D gaussian kernel. 133 | 134 | Args: 135 | spatial_dims: number of spatial dimensions of the input images. 136 | num_channels: number of channels in the image 137 | kernel_size: size of kernel 138 | kernel_sigma: standard deviation for Gaussian kernel. 139 | """ 140 | 141 | def gaussian_1d(kernel_size: int, sigma: float) -> torch.Tensor: 142 | """Computes 1D gaussian kernel. 
143 | 144 | Args: 145 | kernel_size: size of the gaussian kernel 146 | sigma: Standard deviation of the gaussian kernel 147 | """ 148 | dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1) 149 | gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2) 150 | return (gauss / gauss.sum()).unsqueeze(dim=0) 151 | 152 | gaussian_kernel_x = gaussian_1d(kernel_size[0], kernel_sigma[0]) 153 | gaussian_kernel_y = gaussian_1d(kernel_size[1], kernel_sigma[1]) 154 | kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) 155 | 156 | kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1]) 157 | 158 | if spatial_dims == 3: 159 | gaussian_kernel_z = gaussian_1d(kernel_size[2], kernel_sigma[2])[None,] 160 | kernel = torch.mul( 161 | kernel.unsqueeze(-1).repeat(1, 1, kernel_size[2]), 162 | gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]), 163 | ) 164 | kernel_dimensions = (num_channels, 1, kernel_size[0], kernel_size[1], kernel_size[2]) 165 | 166 | return kernel.expand(kernel_dimensions) 167 | 168 | 169 | def compute_ssim_and_cs( 170 | y_pred: torch.Tensor, 171 | y: torch.Tensor, 172 | spatial_dims: int, 173 | data_range: float = 1.0, 174 | kernel_type: KernelType | str = KernelType.GAUSSIAN, 175 | kernel_size: Sequence[int, ...] = 11, 176 | kernel_sigma: Sequence[float, ...] = 1.5, 177 | k1: float = 0.01, 178 | k2: float = 0.03, 179 | ) -> tuple[torch.Tensor, torch.Tensor]: 180 | """ 181 | Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch 182 | of images. 183 | 184 | Args: 185 | y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) 186 | y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3]) 187 | spatial_dims: number of spatial dimensions of the images (2, 3) 188 | data_range: the data range of the images. 189 | kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform". 190 | kernel_size: the size of the kernel to use for the SSIM computation. 191 | kernel_sigma: the standard deviation of the kernel to use for the SSIM computation. 192 | k1: the first stability constant. 193 | k2: the second stability constant. 194 | 195 | Returns: 196 | ssim: the Structural Similarity Index Measure score for the batch of images. 197 | cs: the Contrast Sensitivity for the batch of images. 
198 | """ 199 | if y.shape != y_pred.shape: 200 | raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.") 201 | 202 | y_pred = convert_data_type(y_pred, output_type=torch.Tensor, dtype=torch.float)[0] 203 | y = convert_data_type(y, output_type=torch.Tensor, dtype=torch.float)[0] 204 | 205 | num_channels = y_pred.size(1) 206 | 207 | if kernel_type == KernelType.GAUSSIAN: 208 | kernel = _gaussian_kernel(spatial_dims, num_channels, kernel_size, kernel_sigma) 209 | elif kernel_type == KernelType.UNIFORM: 210 | kernel = torch.ones((num_channels, 1, *kernel_size)) / torch.prod(torch.tensor(kernel_size)) 211 | 212 | kernel = convert_to_dst_type(src=kernel, dst=y_pred)[0] 213 | 214 | c1 = (k1 * data_range) ** 2 # stability constant for luminance 215 | c2 = (k2 * data_range) ** 2 # stability constant for contrast 216 | 217 | conv_fn = getattr(F, f"conv{spatial_dims}d") 218 | mu_x = conv_fn(y_pred, kernel, groups=num_channels) 219 | mu_y = conv_fn(y, kernel, groups=num_channels) 220 | mu_xx = conv_fn(y_pred * y_pred, kernel, groups=num_channels) 221 | mu_yy = conv_fn(y * y, kernel, groups=num_channels) 222 | mu_xy = conv_fn(y_pred * y, kernel, groups=num_channels) 223 | 224 | sigma_x = mu_xx - mu_x * mu_x 225 | sigma_y = mu_yy - mu_y * mu_y 226 | sigma_xy = mu_xy - mu_x * mu_y 227 | 228 | contrast_sensitivity = (2 * sigma_xy + c2) / (sigma_x + sigma_y + c2) 229 | ssim_value_full_image = ((2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)) * contrast_sensitivity 230 | 231 | return ssim_value_full_image, contrast_sensitivity 232 | -------------------------------------------------------------------------------- /generative/networks/layers/vector_quantizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Sequence, Tuple 13 | 14 | import torch 15 | from torch import nn 16 | 17 | __all__ = ["VectorQuantizer", "EMAQuantizer"] 18 | 19 | 20 | class EMAQuantizer(nn.Module): 21 | """ 22 | Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural 23 | Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation 24 | that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit 25 | 58d9a2746493717a7c9252938da7efa6006f3739. 26 | 27 | This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due 28 | to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353 29 | on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False. 30 | 31 | Args: 32 | spatial_dims : number of spatial spatial_dims. 33 | num_embeddings: number of atomic elements in the codebook. 34 | embedding_dim: number of channels of the input and atomic elements. 
35 | commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25. 36 | decay: EMA decay. Defaults to 0.99. 37 | epsilon: epsilon value. Defaults to 1e-5. 38 | embedding_init: initialization method for the codebook. Defaults to "normal". 39 | ddp_sync: whether to synchronize the codebook across processes. Defaults to True. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | spatial_dims: int, 45 | num_embeddings: int, 46 | embedding_dim: int, 47 | commitment_cost: float = 0.25, 48 | decay: float = 0.99, 49 | epsilon: float = 1e-5, 50 | embedding_init: str = "normal", 51 | ddp_sync: bool = True, 52 | ): 53 | super().__init__() 54 | self.spatial_dims: int = spatial_dims 55 | self.embedding_dim: int = embedding_dim 56 | self.num_embeddings: int = num_embeddings 57 | 58 | assert self.spatial_dims in [2, 3], ValueError( 59 | f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}." 60 | ) 61 | 62 | self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim) 63 | if embedding_init == "normal": 64 | # Initialization is passed since the default one is normal inside the nn.Embedding 65 | pass 66 | elif embedding_init == "kaiming_uniform": 67 | torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear") 68 | self.embedding.weight.requires_grad = False 69 | 70 | self.commitment_cost: float = commitment_cost 71 | 72 | self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings)) 73 | self.register_buffer("ema_w", self.embedding.weight.data.clone()) 74 | 75 | self.decay: float = decay 76 | self.epsilon: float = epsilon 77 | 78 | self.ddp_sync: bool = ddp_sync 79 | 80 | # Precalculating required permutation shapes 81 | self.flatten_permutation: Sequence[int] = [0] + list(range(2, self.spatial_dims + 2)) + [1] 82 | self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list( 83 | range(1, self.spatial_dims + 1) 84 | ) 85 | 86 | def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 87 | """ 88 | Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss. 89 | 90 | Args: 91 | inputs: Encoding space tensors 92 | 93 | Returns: 94 | torch.Tensor: Flatten version of the input of shape [B*D*H*W, C]. 95 | torch.Tensor: One-hot representation of the quantization indices of shape [B*D*H*W, self.num_embeddings]. 
96 | torch.Tensor: Quantization indices of shape [B,D,H,W,1] 97 | 98 | """ 99 | encoding_indices_view = list(inputs.shape) 100 | del encoding_indices_view[1] 101 | 102 | with torch.cuda.amp.autocast(enabled=False): 103 | inputs = inputs.float() 104 | 105 | # Converting to channel last format 106 | flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim) 107 | 108 | # Calculate Euclidean distances 109 | distances = ( 110 | (flat_input**2).sum(dim=1, keepdim=True) 111 | + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True) 112 | - 2 * torch.mm(flat_input, self.embedding.weight.t()) 113 | ) 114 | 115 | # Mapping distances to indexes 116 | encoding_indices = torch.max(-distances, dim=1)[1] 117 | encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float() 118 | 119 | # Quantize and reshape 120 | encoding_indices = encoding_indices.view(encoding_indices_view) 121 | 122 | return flat_input, encodings, encoding_indices 123 | 124 | def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: 125 | """ 126 | Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space 127 | [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the 128 | decoder. 129 | 130 | Args: 131 | embedding_indices: Tensor in channel last format which holds indices referencing atomic 132 | elements from self.embedding 133 | 134 | Returns: 135 | torch.Tensor: Quantize space representation of encoding_indices in channel first format. 136 | """ 137 | with torch.cuda.amp.autocast(enabled=False): 138 | return self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() 139 | 140 | @torch.jit.unused 141 | def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None: 142 | """ 143 | TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the 144 | example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused 145 | 146 | Args: 147 | encodings_sum: The summation of one hot representation of what encoding was used for each 148 | position. 149 | dw: The multiplication of the one hot representation of what encoding was used for each 150 | position with the flattened input. 
151 | 152 | Returns: 153 | None 154 | """ 155 | if self.ddp_sync and torch.distributed.is_initialized(): 156 | torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM) 157 | torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM) 158 | else: 159 | pass 160 | 161 | def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 162 | flat_input, encodings, encoding_indices = self.quantize(inputs) 163 | quantized = self.embed(encoding_indices) 164 | 165 | # Use EMA to update the embedding vectors 166 | if self.training: 167 | with torch.no_grad(): 168 | encodings_sum = encodings.sum(0) 169 | dw = torch.mm(encodings.t(), flat_input) 170 | 171 | if self.ddp_sync: 172 | self.distributed_synchronization(encodings_sum, dw) 173 | 174 | self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay)) 175 | 176 | # Laplace smoothing of the cluster size 177 | n = self.ema_cluster_size.sum() 178 | weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n 179 | self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay)) 180 | self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1)) 181 | 182 | # Encoding Loss 183 | loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs) 184 | 185 | # Straight Through Estimator 186 | quantized = inputs + (quantized - inputs).detach() 187 | 188 | return quantized, loss, encoding_indices 189 | 190 | 191 | class VectorQuantizer(torch.nn.Module): 192 | """ 193 | Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of 194 | the quantization in their own class. 195 | 196 | Args: 197 | quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index 198 | based quantized representation. Defaults to None 199 | """ 200 | 201 | def __init__(self, quantizer: torch.nn.Module = None): 202 | super().__init__() 203 | 204 | self.quantizer: torch.nn.Module = quantizer 205 | 206 | self.perplexity: torch.Tensor = torch.rand(1) 207 | 208 | def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 209 | quantized, loss, encoding_indices = self.quantizer(inputs) 210 | 211 | # Perplexity calculations 212 | avg_probs = ( 213 | torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings) 214 | .float() 215 | .div(encoding_indices.numel()) 216 | ) 217 | 218 | self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 219 | 220 | return loss, quantized 221 | 222 | def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: 223 | return self.quantizer.embed(embedding_indices=embedding_indices) 224 | 225 | def quantize(self, encodings: torch.Tensor) -> torch.Tensor: 226 | _, _, encoding_indices = self.quantizer(encodings) 227 | 228 | return encoding_indices 229 | -------------------------------------------------------------------------------- /generative/networks/schedulers/ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | # 12 | # ========================================================================= 13 | # Adapted from https://github.com/huggingface/diffusers 14 | # which has the following license: 15 | # https://github.com/huggingface/diffusers/blob/main/LICENSE 16 | # 17 | # Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. 18 | # 19 | # Licensed under the Apache License, Version 2.0 (the "License"); 20 | # you may not use this file except in compliance with the License. 21 | # You may obtain a copy of the License at 22 | # 23 | # http://www.apache.org/licenses/LICENSE-2.0 24 | # 25 | # Unless required by applicable law or agreed to in writing, software 26 | # distributed under the License is distributed on an "AS IS" BASIS, 27 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 28 | # See the License for the specific language governing permissions and 29 | # limitations under the License. 30 | # ========================================================================= 31 | 32 | from __future__ import annotations 33 | 34 | import numpy as np 35 | import torch 36 | from monai.utils import StrEnum 37 | 38 | from .scheduler import Scheduler 39 | 40 | 41 | class DDPMVarianceType(StrEnum): 42 | """ 43 | Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise 44 | to the denoised sample. 45 | """ 46 | 47 | FIXED_SMALL = "fixed_small" 48 | FIXED_LARGE = "fixed_large" 49 | LEARNED = "learned" 50 | LEARNED_RANGE = "learned_range" 51 | 52 | 53 | class DDPMPredictionType(StrEnum): 54 | """ 55 | Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument. 56 | 57 | epsilon: predicting the noise of the diffusion process 58 | sample: directly predicting the noisy sample 59 | v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf 60 | """ 61 | 62 | EPSILON = "epsilon" 63 | SAMPLE = "sample" 64 | V_PREDICTION = "v_prediction" 65 | 66 | 67 | class DDPMScheduler(Scheduler): 68 | """ 69 | Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and 70 | Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models" 71 | https://arxiv.org/abs/2006.11239 72 | 73 | Args: 74 | num_train_timesteps: number of diffusion steps used to train the model. 75 | schedule: member of NoiseSchedules, name of noise schedule function in component store 76 | variance_type: member of DDPMVarianceType 77 | clip_sample: option to clip predicted sample between -1 and 1 for numerical stability. 78 | prediction_type: member of DDPMPredictionType 79 | clip_sample_min: if clip_sample is True, minimum value to clamp the prediction by. 80 | clip_sample_max: if clip_sample is True, maximum value to clamp the prediction by.
81 | schedule_args: arguments to pass to the schedule function 82 | """ 83 | 84 | def __init__( 85 | self, 86 | num_train_timesteps: int = 1000, 87 | schedule: str = "linear_beta", 88 | variance_type: str = DDPMVarianceType.FIXED_SMALL, 89 | clip_sample: bool = True, 90 | prediction_type: str = DDPMPredictionType.EPSILON, 91 | clip_sample_min: int = -1, 92 | clip_sample_max: int = 1, 93 | **schedule_args, 94 | ) -> None: 95 | super().__init__(num_train_timesteps, schedule, **schedule_args) 96 | 97 | if variance_type not in DDPMVarianceType.__members__.values(): 98 | raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`") 99 | 100 | if prediction_type not in DDPMPredictionType.__members__.values(): 101 | raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`") 102 | 103 | if clip_sample_min >= clip_sample_max: 104 | raise ValueError("clip_sample_min must be < clip_sample_max") 105 | 106 | self.clip_sample = clip_sample 107 | self.variance_type = variance_type 108 | self.prediction_type = prediction_type 109 | self.clip_sample_values = [clip_sample_min, clip_sample_max] 110 | 111 | def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None: 112 | """ 113 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 114 | 115 | Args: 116 | num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. 117 | device: target device to put the data. 118 | """ 119 | if num_inference_steps > self.num_train_timesteps: 120 | raise ValueError( 121 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" 122 | f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" 123 | f" maximal {self.num_train_timesteps} timesteps." 124 | ) 125 | 126 | self.num_inference_steps = num_inference_steps 127 | step_ratio = self.num_train_timesteps // self.num_inference_steps 128 | # creates integer timesteps by multiplying by ratio 129 | # casting to int to avoid issues when num_inference_step is power of 3 130 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) 131 | self.timesteps = torch.from_numpy(timesteps).to(device) 132 | 133 | def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: 134 | """ 135 | Compute the mean of the posterior at timestep t. 136 | 137 | Args: 138 | timestep: current timestep. 139 | x0: the noise-free input. 140 | x_t: the input noised to timestep t. 141 | 142 | Returns: 143 | Returns the mean 144 | """ 145 | # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0), 146 | # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf) 147 | alpha_t = self.alphas[timestep] 148 | alpha_prod_t = self.alphas_cumprod[timestep] 149 | alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one 150 | 151 | x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t) 152 | x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) 153 | 154 | mean = x_0_coefficient * x_0 + x_t_coefficient * x_t 155 | 156 | return mean 157 | 158 | def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor: 159 | """ 160 | Compute the variance of the posterior at timestep t. 161 | 162 | Args: 163 | timestep: current timestep. 
164 | predicted_variance: variance predicted by the model. 165 | 166 | Returns: 167 | Returns the variance 168 | """ 169 | alpha_prod_t = self.alphas_cumprod[timestep] 170 | alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one 171 | 172 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) 173 | # and sample from it to get previous sample 174 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample 175 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] 176 | # hacks - were probably added for training stability 177 | if self.variance_type == DDPMVarianceType.FIXED_SMALL: 178 | variance = torch.clamp(variance, min=1e-20) 179 | elif self.variance_type == DDPMVarianceType.FIXED_LARGE: 180 | variance = self.betas[timestep] 181 | elif self.variance_type == DDPMVarianceType.LEARNED: 182 | return predicted_variance 183 | elif self.variance_type == DDPMVarianceType.LEARNED_RANGE: 184 | min_log = variance 185 | max_log = self.betas[timestep] 186 | frac = (predicted_variance + 1) / 2 187 | variance = frac * max_log + (1 - frac) * min_log 188 | 189 | return variance 190 | 191 | def step( 192 | self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None 193 | ) -> tuple[torch.Tensor, torch.Tensor]: 194 | """ 195 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 196 | process from the learned model outputs (most often the predicted noise). 197 | 198 | Args: 199 | model_output: direct output from learned diffusion model. 200 | timestep: current discrete timestep in the diffusion chain. 201 | sample: current instance of sample being created by diffusion process. 202 | generator: random number generator. 203 | 204 | Returns: 205 | pred_prev_sample: Predicted previous sample 206 | """ 207 | if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: 208 | model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) 209 | else: 210 | predicted_variance = None 211 | 212 | # 1. compute alphas, betas 213 | alpha_prod_t = self.alphas_cumprod[timestep] 214 | alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one 215 | beta_prod_t = 1 - alpha_prod_t 216 | beta_prod_t_prev = 1 - alpha_prod_t_prev 217 | 218 | # 2. compute predicted original sample from predicted noise also called 219 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 220 | if self.prediction_type == DDPMPredictionType.EPSILON: 221 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 222 | elif self.prediction_type == DDPMPredictionType.SAMPLE: 223 | pred_original_sample = model_output 224 | elif self.prediction_type == DDPMPredictionType.V_PREDICTION: 225 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 226 | 227 | # 3. Clip "predicted x_0" 228 | if self.clip_sample: 229 | pred_original_sample = torch.clamp( 230 | pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1] 231 | ) 232 | 233 | # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t 234 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 235 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t 236 | current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t 237 | 238 | # 5. Compute predicted previous sample µ_t 239 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 240 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample 241 | 242 | # 6. Add noise 243 | variance = 0 244 | if timestep > 0: 245 | noise = torch.randn( 246 | model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator 247 | ).to(model_output.device) 248 | variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise 249 | 250 | pred_prev_sample = pred_prev_sample + variance 251 | 252 | return pred_prev_sample, pred_original_sample 253 | -------------------------------------------------------------------------------- /unet/MC_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from __future__ import annotations 4 | from torch.nn import init 5 | from collections.abc import Sequence 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from monai.networks.blocks import Convolution 10 | from monai.utils import ensure_tuple_rep 11 | from torch import nn 12 | from unet.MF_UKAN import TimeEmbedding, ResBlock, DownSample, OverlapPatchEmbed, shiftedBlock 13 | from generative.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding 14 | from config.diffusion.config_controlnet import Config 15 | 16 | class ControlNetConditioningEmbedding(nn.Module): 17 | """ 18 | Network used to encode the control condition into the latent space. 19 | """ 20 | 21 | def __init__( 22 | self, spatial_dims: int, in_channels: int, out_channels: int, 23 | num_channels: Sequence[int] = (16, 32, 96, 256) 24 | ): 25 | super().__init__() 26 | 27 | # input convolution layer for initial feature extraction 28 | self.conv_in = Convolution( 29 | spatial_dims=spatial_dims, 30 | in_channels=in_channels, 31 | out_channels=num_channels[0], 32 | strides=1, 33 | kernel_size=3, 34 | padding=1, 35 | conv_only=True, 36 | ) 37 | 38 | self.blocks = nn.ModuleList([]) 39 | 40 | # build a series of convolution blocks 41 | for i in range(len(num_channels) - 2): 42 | channel_in = num_channels[i] 43 | channel_out = num_channels[i + 1] 44 | 45 | self.blocks.append( 46 | Convolution( 47 | spatial_dims=spatial_dims, 48 | in_channels=channel_in, 49 | out_channels=channel_in, 50 | strides=1, 51 | kernel_size=3, 52 | padding=1, 53 | conv_only=True, 54 | ) 55 | ) 56 | self.blocks.append( 57 | Convolution( 58 | spatial_dims=spatial_dims, 59 | in_channels=channel_in, 60 | out_channels=channel_out, 61 | strides=2, 62 | kernel_size=3, 63 | padding=1, 64 | conv_only=True, 65 | ) 66 | ) 67 | 68 | self.conv_out = zero_module( 69 | Convolution( 70 | spatial_dims=spatial_dims, 71 | in_channels=num_channels[-2], 72 | out_channels=out_channels, 73 | strides=1, 74 | kernel_size=3, 75 | padding=1, 76 | conv_only=True, 77 | ) 78 | ) 79 | 80 | def forward(self, conditioning): 81 | embedding = self.conv_in(conditioning) 82 | embedding = F.silu(embedding) 83 | 84 | for block in self.blocks: 85 | embedding = block(embedding) 86 | embedding = F.silu(embedding) 87 | 88 | embedding = self.conv_out(embedding) 89 | 90 | return embedding 91 | 92 | 93 | def zero_module(module): 94 | for p in module.parameters(): 95 | nn.init.zeros_(p)
96 | return module 97 | 98 | 99 | def copy_weights_to_controlnet(controlnet: nn.Module, diffusion_model: nn.Module, verbose: bool = True) -> None: 100 | """ 101 | Args: 102 | controlnet: ControlNet instance. 103 | diffusion_model: DiffusionModelUnet or SPADEDiffusionModelUnet instance. 104 | verbose: if True, prints the matched and mismatched keys. 105 | """ 106 | 107 | output = controlnet.load_state_dict(diffusion_model.state_dict(), strict=False) 108 | if verbose: 109 | dm_keys = [p[0] for p in list(diffusion_model.named_parameters()) if p[0] not in output.unexpected_keys] 110 | print( 111 | f"Copied weights from {len(dm_keys)} keys of the diffusion model into the ControlNet:" 112 | f"\n{'; '.join(dm_keys)}\nControlNet missing keys: {len(output.missing_keys)}:" 113 | f"\n{'; '.join(output.missing_keys)}\nDiffusion model incompatible keys: {len(output.unexpected_keys)}:" 114 | f"\n{'; '.join(output.unexpected_keys)}" 115 | ) 116 | 117 | 118 | class MC_MODEL(nn.Module): 119 | """ 120 | Args: 121 | spatial_dims: number of spatial dimensions (e.g. 2 for 2D, 3 for 3D). 122 | in_channels: number of input channels. 123 | num_res_blocks: number of residual blocks per stage. 124 | num_channels: output channels of each stage. 125 | attention_levels: list indicating at which stages the attention mechanism is added. 126 | norm_num_groups: number of groups used for normalisation. 127 | norm_eps: epsilon used for normalisation. 128 | resblock_updown: if True, use residual blocks for up/downsampling. 129 | num_head_channels: number of channels per attention head. 130 | with_conditioning: if True, add spatial transformers for conditioning. 131 | transformer_num_layers: number of layers in the Transformer blocks. 132 | cross_attention_dim: number of context dimensions to use. 133 | num_class_embeds: if specified (as an int), the model uses class-conditional generation. 134 | upcast_attention: if True, upcast the attention operations to full precision. 135 | use_flash_attention: if True, use the memory-efficient flash attention mechanism. 136 | conditioning_embedding_in_channels: number of input channels of the conditioning embedding. 137 | conditioning_embedding_num_channels: number of channels of the conditioning embedding. 138 | """ 139 | 140 | def __init__( 141 | self, T, ch, ch_mult, attn, num_res_blocks, 142 | dropout,spatial_dims = 2, 143 | num_zero_res_blocks = (2, 2, 2, 2), 144 | mlp_ratios = (1, 1) 145 | ): 146 | super().__init__() 147 | assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' 148 | 149 | tdim = ch * 4 150 | self.time_embedding = TimeEmbedding(T, ch, tdim) 151 | attn = [] # indices of the layers that use attention; reset to an empty list here 152 | self.head = nn.Conv2d(4, ch, kernel_size=3, stride=1, padding=1) # initial convolution layer 153 | self.downblocks = nn.ModuleList() # list of downsample blocks 154 | self.controlnet_down_blocks = nn.ModuleList() # list of ControlNet downsample blocks 155 | self.controlnet_kan_blocks = nn.ModuleList() # list of KAN blocks 156 | chs = [ch] # record the channel counts used during downsampling 157 | now_ch = ch 158 | 159 | controlnet_block = Convolution( 160 | spatial_dims=spatial_dims, 161 | in_channels=now_ch, 162 | out_channels=now_ch, 163 | strides=1, 164 | kernel_size=1, 165 | padding=0, 166 | conv_only=True, 167 | ) 168 | # initialize the parameters to zero 169 | controlnet_block = zero_module(controlnet_block) 170 | self.controlnet_down_blocks.append(controlnet_block) 171 | # build the downsample blocks 172 | self.zero_res_chs = [now_ch] 173 | for i, mult in enumerate(ch_mult): 174 | h_ch = ch * mult 175 | for _ in range(num_res_blocks): 176 | self.downblocks.append(ResBlock( 177 | in_ch=now_ch, h_ch=h_ch, tdim=tdim, 178 | dropout=dropout, attn=(i in attn))) 179 | now_ch = h_ch 180 | chs.append(now_ch) 181 | controlnet_block = Convolution( 182 | spatial_dims=spatial_dims, 183 | in_channels=now_ch, 184 | out_channels=now_ch, 185 | strides=1, 186 | kernel_size=1, 187 | padding=0, 188 | conv_only=True, 189 | ) 190 | # initialize the parameters to zero 191 | controlnet_block = zero_module(controlnet_block) 192 | self.controlnet_down_blocks.append(controlnet_block) 193 | if i != len(ch_mult) - 1: 194 | self.downblocks.append(DownSample(now_ch)) 195 | chs.append(now_ch) 196 | controlnet_block
= Convolution( 197 | spatial_dims=spatial_dims, 198 | in_channels=now_ch, 199 | out_channels=now_ch, 200 | strides=1, 201 | kernel_size=1, 202 | padding=0, 203 | conv_only=True, 204 | ) 205 | # initialize the parameters to zero 206 | controlnet_block = zero_module(controlnet_block) 207 | self.controlnet_down_blocks.append(controlnet_block) 208 | 209 | self.chs = chs 210 | 211 | # extra feature extraction layers 212 | embed_dims = [256, 320, 512] 213 | norm_layer = nn.LayerNorm 214 | dpr = [0.0, 0.0, 0.0] 215 | self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 216 | embed_dim=embed_dims[1]) 217 | self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 218 | embed_dim=embed_dims[2]) 219 | 220 | self.norm3 = norm_layer(embed_dims[1]) 221 | self.norm4 = norm_layer(embed_dims[2]) 222 | self.dnorm3 = norm_layer(embed_dims[1]) 223 | 224 | self.kan_block1 = nn.ModuleList([shiftedBlock( 225 | dim=embed_dims[1], mlp_ratio=mlp_ratios[0], drop_path=dpr[0], norm_layer=norm_layer)]) 226 | 227 | self.controlnet_kan_blocks.append(zero_module(Convolution( 228 | spatial_dims=spatial_dims, 229 | in_channels=embed_dims[1] * mlp_ratios[0], 230 | out_channels=embed_dims[1] * mlp_ratios[0], 231 | strides=1, 232 | kernel_size=1, 233 | padding=0, 234 | conv_only=True, 235 | ))) 236 | 237 | self.kan_block2 = nn.ModuleList([shiftedBlock( 238 | dim=embed_dims[2], mlp_ratio=mlp_ratios[1], drop_path=dpr[1], norm_layer=norm_layer)]) 239 | 240 | self.controlnet_kan_blocks.append(zero_module(Convolution( 241 | spatial_dims=spatial_dims, 242 | in_channels=embed_dims[2] * mlp_ratios[1], 243 | out_channels=embed_dims[2] * mlp_ratios[1], 244 | strides=1, 245 | kernel_size=1, 246 | padding=0, 247 | conv_only=True, 248 | ))) 249 | 250 | self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 251 | spatial_dims=2, 252 | in_channels=1, 253 | num_channels=(64, 128, 192, 256), 254 | out_channels=64, 255 | ) 256 | 257 | self.initialize() # initialize the weights 258 | 259 | def initialize(self): 260 | # initialize the convolution weights with a Xavier uniform distribution 261 | init.xavier_uniform_(self.head.weight) 262 | init.zeros_(self.head.bias) 263 | 264 | def forward(self, x, t, controlnet_cond, conditioning_scale=1.0): 265 | # time embedding 266 | t_emb = self.time_embedding(t) 267 | temb = t_emb.to(dtype=x.dtype) 268 | 269 | # initial convolution 270 | h = self.head(x) 271 | # inject the conditioning 272 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 273 | h += controlnet_cond 274 | 275 | hs = [h] 276 | for layer in self.downblocks: 277 | h = layer(h, temb) 278 | hs.append(h) 279 | 280 | # extra feature extraction blocks 281 | B = x.shape[0] 282 | 283 | # h -- [B, H*W, embed_dim] 284 | h, H, W = self.patch_embed3(h) 285 | for i, blk in enumerate(self.kan_block1): 286 | h = blk(h, H, W, temb) 287 | h = self.norm3(h) 288 | h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 289 | kan_hs = [h] 290 | 291 | h, H, W = self.patch_embed4(h) 292 | for i, blk in enumerate(self.kan_block2): 293 | h = blk(h, H, W, temb) 294 | h = self.norm4(h) 295 | h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 296 | kan_hs.append(h) 297 | 298 | # 6.
Control net blocks 299 | controlnet_down_block_res_samples = () 300 | for h, controlnet_block in zip(hs, self.controlnet_down_blocks): 301 | h = controlnet_block(h) 302 | controlnet_down_block_res_samples += (h,) 303 | 304 | for kan_h, controlnet_kan_block in zip(kan_hs, self.controlnet_kan_blocks): 305 | kan_h = controlnet_kan_block(kan_h) 306 | controlnet_down_block_res_samples += (kan_h,) 307 | 308 | down_block_res_samples = controlnet_down_block_res_samples 309 | 310 | # 7. Scaling 311 | down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] 312 | # mid_block_res_sample *= conditioning_scale 313 | mid_block_res_sample = down_block_res_samples[-1] 314 | 315 | return down_block_res_samples, mid_block_res_sample 316 | -------------------------------------------------------------------------------- /generative/networks/nets/patchgan_discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from __future__ import annotations 13 | 14 | import warnings 15 | from collections.abc import Sequence 16 | 17 | import torch 18 | import torch.nn as nn 19 | from monai.networks.blocks import Convolution 20 | from monai.networks.layers import Act, get_pool_layer 21 | 22 | 23 | class MultiScalePatchDiscriminator(nn.Sequential): 24 | """ 25 | Multi-scale Patch-GAN discriminator based on Pix2PixHD: 26 | High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs 27 | Ting-Chun Wang1, Ming-Yu Liu1, Jun-Yan Zhu2, Andrew Tao1, Jan Kautz1, Bryan Catanzaro (1) 28 | (1) NVIDIA Corporation, 2UC Berkeley 29 | In CVPR 2018. 30 | Multi-Scale discriminator made up of several Patch-GAN discriminators, that process the images 31 | up to different spatial scales. 32 | 33 | Args: 34 | num_d: number of discriminators 35 | num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in each 36 | of the discriminators. In each layer, the number of channels are doubled and the spatial size is 37 | divided by 2. 38 | spatial_dims: number of spatial dimensions (1D, 2D etc.) 39 | num_channels: number of filters in the first convolutional layer (double of the value is taken from then on) 40 | in_channels: number of input channels 41 | pooling_method: pooling method to be applied before each discriminator after the first. 42 | If None, the number of layers is multiplied by the number of discriminators. 43 | out_channels: number of output channels in each discriminator 44 | kernel_size: kernel size of the convolution layers 45 | activation: activation layer type 46 | norm: normalisation type 47 | bias: introduction of layer bias 48 | dropout: proportion of dropout applied, defaults to 0. 49 | minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture 50 | requested isn't going to downsample the input image beyond value of 1.
51 | last_conv_kernel_size: kernel size of the last convolutional layer. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | num_d: int, 57 | num_layers_d: int | list[int], 58 | spatial_dims: int, 59 | num_channels: int, 60 | in_channels: int, 61 | pooling_method: str = None, 62 | out_channels: int = 1, 63 | kernel_size: int = 4, 64 | activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), 65 | norm: str | tuple = "BATCH", 66 | bias: bool = False, 67 | dropout: float | tuple = 0.0, 68 | minimum_size_im: int = 256, 69 | last_conv_kernel_size: int = 1, 70 | ) -> None: 71 | super().__init__() 72 | self.num_d = num_d 73 | if isinstance(num_layers_d, int) and pooling_method is None: 74 | # if pooling_method is None, calculate the number of layers for each discriminator by multiplying by the number of discriminators 75 | num_layers_d = [num_layers_d * i for i in range(1, num_d + 1)] 76 | elif isinstance(num_layers_d, int) and pooling_method is not None: 77 | # if pooling_method is not None, the number of layers is the same for all discriminators 78 | num_layers_d = [num_layers_d] * num_d 79 | self.num_layers_d = num_layers_d 80 | assert ( 81 | len(self.num_layers_d) == self.num_d 82 | ), f"MultiScalePatchDiscriminator: num_d {num_d} must match the number of num_layers_d. {num_layers_d}" 83 | 84 | self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims) 85 | 86 | if pooling_method is None: 87 | pool = None 88 | else: 89 | pool = get_pool_layer( 90 | (pooling_method, {"kernel_size": kernel_size, "stride": 2, "padding": self.padding}), 91 | spatial_dims=spatial_dims, 92 | ) 93 | self.num_channels = num_channels 94 | for i_ in range(self.num_d): 95 | num_layers_d_i = self.num_layers_d[i_] 96 | output_size = float(minimum_size_im) / (2**num_layers_d_i) 97 | if output_size < 1: 98 | raise AssertionError( 99 | "Your image size is too small to take in up to %d discriminators with num_layers = %d." 100 | "Please reduce num_layers, reduce num_D or enter bigger images." % (i_, num_layers_d_i) 101 | ) 102 | if i_ == 0 or pool is None: 103 | subnet_d = PatchDiscriminator( 104 | spatial_dims=spatial_dims, 105 | num_channels=self.num_channels, 106 | in_channels=in_channels, 107 | out_channels=out_channels, 108 | num_layers_d=num_layers_d_i, 109 | kernel_size=kernel_size, 110 | activation=activation, 111 | norm=norm, 112 | bias=bias, 113 | padding=self.padding, 114 | dropout=dropout, 115 | last_conv_kernel_size=last_conv_kernel_size, 116 | ) 117 | else: 118 | subnet_d = nn.Sequential( 119 | *[pool] * i_, 120 | PatchDiscriminator( 121 | spatial_dims=spatial_dims, 122 | num_channels=self.num_channels, 123 | in_channels=in_channels, 124 | out_channels=out_channels, 125 | num_layers_d=num_layers_d_i, 126 | kernel_size=kernel_size, 127 | activation=activation, 128 | norm=norm, 129 | bias=bias, 130 | padding=self.padding, 131 | dropout=dropout, 132 | last_conv_kernel_size=last_conv_kernel_size, 133 | ), 134 | ) 135 | 136 | self.add_module("discriminator_%d" % i_, subnet_d) 137 | 138 | def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]: 139 | """ 140 | 141 | Args: 142 | i: Input tensor 143 | Returns: 144 | list of outputs and another list of lists with the intermediate features 145 | of each discriminator. 
146 | """ 147 | 148 | out: list[torch.Tensor] = [] 149 | intermediate_features: list[list[torch.Tensor]] = [] 150 | for disc in self.children(): 151 | out_d: list[torch.Tensor] = disc(i) 152 | out.append(out_d[-1]) 153 | intermediate_features.append(out_d[:-1]) 154 | 155 | return out, intermediate_features 156 | 157 | 158 | class PatchDiscriminator(nn.Sequential): 159 | """ 160 | Patch-GAN discriminator based on Pix2PixHD: 161 | High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs 162 | Ting-Chun Wang1, Ming-Yu Liu1, Jun-Yan Zhu2, Andrew Tao1, Jan Kautz1, Bryan Catanzaro (1) 163 | (1) NVIDIA Corporation, 2UC Berkeley 164 | In CVPR 2018. 165 | 166 | Args: 167 | spatial_dims: number of spatial dimensions (1D, 2D etc.) 168 | num_channels: number of filters in the first convolutional layer (double of the value is taken from then on) 169 | in_channels: number of input channels 170 | out_channels: number of output channels in each discriminator 171 | num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in each 172 | of the discriminators. In each layer, the number of channels are doubled and the spatial size is 173 | divided by 2. 174 | kernel_size: kernel size of the convolution layers 175 | activation: activation layer type 176 | norm: normalisation type 177 | bias: introduction of layer bias 178 | padding: padding to be applied to the convolutional layers 179 | dropout: proportion of dropout applied, defaults to 0. 180 | last_conv_kernel_size: kernel size of the last convolutional layer. 181 | """ 182 | 183 | def __init__( 184 | self, 185 | spatial_dims: int, 186 | num_channels: int, 187 | in_channels: int, 188 | out_channels: int = 1, 189 | num_layers_d: int = 3, 190 | kernel_size: int = 4, 191 | activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), 192 | norm: str | tuple = "BATCH", 193 | bias: bool = False, 194 | padding: int | Sequence[int] = 1, 195 | dropout: float | tuple = 0.0, 196 | last_conv_kernel_size: int | None = None, 197 | ) -> None: 198 | super().__init__() 199 | self.num_layers_d = num_layers_d 200 | self.num_channels = num_channels 201 | if last_conv_kernel_size is None: 202 | last_conv_kernel_size = kernel_size 203 | 204 | self.add_module( 205 | "initial_conv", 206 | Convolution( 207 | spatial_dims=spatial_dims, 208 | kernel_size=kernel_size, 209 | in_channels=in_channels, 210 | out_channels=num_channels, 211 | act=activation, 212 | bias=True, 213 | norm=None, 214 | dropout=dropout, 215 | padding=padding, 216 | strides=2, 217 | ), 218 | ) 219 | 220 | input_channels = num_channels 221 | output_channels = num_channels * 2 222 | 223 | # Initial Layer 224 | for l_ in range(self.num_layers_d): 225 | if l_ == self.num_layers_d - 1: 226 | stride = 1 227 | else: 228 | stride = 2 229 | layer = Convolution( 230 | spatial_dims=spatial_dims, 231 | kernel_size=kernel_size, 232 | in_channels=input_channels, 233 | out_channels=output_channels, 234 | act=activation, 235 | bias=bias, 236 | norm=norm, 237 | dropout=dropout, 238 | padding=padding, 239 | strides=stride, 240 | ) 241 | self.add_module("%d" % l_, layer) 242 | input_channels = output_channels 243 | output_channels = output_channels * 2 244 | 245 | # Final layer 246 | self.add_module( 247 | "final_conv", 248 | Convolution( 249 | spatial_dims=spatial_dims, 250 | kernel_size=last_conv_kernel_size, 251 | in_channels=input_channels, 252 | out_channels=out_channels, 253 | bias=True, 254 | conv_only=True, 255 | padding=int((last_conv_kernel_size - 1) / 2), 256 
| dropout=0.0, 257 | strides=1, 258 | ), 259 | ) 260 | 261 | self.apply(self.initialise_weights) 262 | if norm.lower() == "batch" and torch.distributed.is_initialized(): 263 | warnings.warn( 264 | "WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. " 265 | "To train with DDP, convert discriminator to SyncBatchNorm using " 266 | "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)." 267 | ) 268 | 269 | def forward(self, x: torch.Tensor) -> list[torch.Tensor]: 270 | """ 271 | 272 | Args: 273 | x: input tensor 274 | Returns: 275 | list of intermediate features, with the last element being the output; the intermediate features 276 | can be used to compute a feature-matching (regulariser) loss on the discriminators as well (see Pix2Pix paper). 277 | """ 278 | out = [x] 279 | for submodel in self.children(): 280 | intermediate_output = submodel(out[-1]) 281 | out.append(intermediate_output) 282 | 283 | return out[1:] 284 | 285 | def initialise_weights(self, m: nn.Module) -> None: 286 | """ 287 | Initialise weights of Convolution and BatchNorm layers. 288 | 289 | Args: 290 | m: instance of torch.nn.module (or of class inheriting torch.nn.module) 291 | """ 292 | classname = m.__class__.__name__ 293 | if classname.find("Conv2d") != -1: 294 | nn.init.normal_(m.weight.data, 0.0, 0.02) 295 | elif classname.find("Conv3d") != -1: 296 | nn.init.normal_(m.weight.data, 0.0, 0.02) 297 | elif classname.find("Conv1d") != -1: 298 | nn.init.normal_(m.weight.data, 0.0, 0.02) 299 | elif classname.find("BatchNorm") != -1: 300 | nn.init.normal_(m.weight.data, 1.0, 0.02) 301 | nn.init.constant_(m.bias.data, 0) 302 | --------------------------------------------------------------------------------
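Usage sketch: the snippet below shows how the pieces reproduced above fit together, combining the DDPM forward-noising and reverse step from generative/networks/schedulers/ddpm.py with the SSIM helper from generative/metrics/ssim.py. It is a minimal, hypothetical example written against the signatures visible in this dump: the toy denoiser, the tensor shapes, the explicit kernel tuples, and the `add_noise(original_samples, noise, timesteps)` keyword names (which follow the method body shown at the start of this section) are assumptions for illustration, not part of the repository.

import torch

from generative.metrics.ssim import compute_ssim_and_cs
from generative.networks.schedulers.ddpm import DDPMScheduler


def toy_denoiser(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a trained network that predicts the added noise (prediction_type="epsilon").
    return torch.zeros_like(x)


scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")

# Forward (noising) process, as implemented by add_noise above.
clean = torch.randn(2, 1, 64, 64)  # [B, C, H, W]
noise = torch.randn_like(clean)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (clean.shape[0],))
noisy = scheduler.add_noise(original_samples=clean, noise=noise, timesteps=timesteps)

# One reverse step: step() returns (predicted previous sample, predicted x_0).
pred_prev, pred_x0 = scheduler.step(toy_denoiser(noisy, timesteps), timestep=999, sample=noisy)

# Image-quality check with the SSIM helper; note that the standalone function expects
# per-spatial-dimension kernel sequences (the SSIMMetric class expands scalars itself).
ssim_map, _ = compute_ssim_and_cs(
    y_pred=pred_x0, y=clean, spatial_dims=2, kernel_size=(11, 11), kernel_sigma=(1.5, 1.5)
)
print(ssim_map.mean())  # values are illustrative only, since the inputs here are random noise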