├── code
│   ├── .DS_Store
│   ├── pretrained_ckpt
│   │   └── readme.txt
│   ├── configs
│   │   └── swin_tiny_patch4_window7_224_lite.yaml
│   ├── networks
│   │   ├── net_factory_3d.py
│   │   ├── discriminator.py
│   │   ├── attention.py
│   │   ├── unet_3D.py
│   │   ├── VoxResNet.py
│   │   ├── vision_transformer.py
│   │   └── unet_3D_dv_semi.py
│   ├── test_brats2019_semi_seg.sh
│   ├── test_acdc_unet_semi_seg.sh
│   ├── train_acdc_unet_semi_seg.sh
│   ├── utils
│   │   ├── metrics.py
│   │   └── ramps.py
│   ├── dataloaders
│   │   ├── acdc_data_processing.py
│   │   └── brats_proprecessing.py
│   ├── test_3D.py
│   ├── train_brats2019_semi_seg.sh
│   ├── augmentations
│   │   └── __init__.py
│   ├── test_urpc.py
│   └── val_2D.py
├── utils
│   ├── __pycache__
│   ├── loss_functions
│   │   ├── nd_softmax.py
│   │   ├── __pycache__
│   │   ├── crossentropy.py
│   │   ├── TopK_loss.py
│   │   └── tensor_utils.py
│   ├── imgname.py
│   ├── metrics.py
│   ├── ramps.py
│   ├── dist_helper.py
│   ├── sortName.py
│   ├── generate_prompts.py
│   ├── builder.py
│   ├── metrics_samus.py
│   ├── utils.py
│   └── vis_torch.py
├── networks
│   ├── __pycache__
│   ├── net_factory_3d.py
│   ├── swin_block.py
│   ├── discriminator.py
│   ├── attention.py
│   ├── unet_3D.py
│   ├── VoxResNet.py
│   ├── vision_transformer.py
│   ├── mind.py
│   └── unet_3D_dv_semi.py
├── dataloaders
│   ├── __pycache__
│   ├── acdc_data_processing.py
│   ├── transform.py
│   ├── acdc_ex.py
│   ├── mrliver.py
│   ├── acdc.py
│   ├── Synapse.py
│   ├── busi.py
│   └── brats_proprecessing.py
├── model_sam
│   ├── __pycache__
│   ├── segment_anything
│   │   ├── __pycache__
│   │   ├── utils
│   │   │   ├── __pycache__
│   │   │   └── __init__.py
│   │   ├── modeling
│   │   │   ├── __pycache__
│   │   │   ├── __init__.py
│   │   │   └── common.py
│   │   └── __init__.py
│   ├── segment_anything_samus
│   │   ├── __pycache__
│   │   ├── utils
│   │   │   ├── __pycache__
│   │   │   └── __init__.py
│   │   ├── modeling
│   │   │   ├── __pycache__
│   │   │   ├── __init__.py
│   │   │   └── common.py
│   │   └── __init__.py
│   └── model_dict.py
├── augmentations
│   ├── __pycache__
│   └── __init__.py
├── configs
│   ├── acdc.yaml
│   └── mrliver.yaml
├── meduni_both.sh
├── README.md
└── data
    ├── data_format_trans.py
    ├── data_split_train_test.py
    ├── data_split_train_test_val.py
    └── gen_train_test_val_name.py
/code/pretrained_ckpt/readme.txt:
--------------------------------------------------------------------------------
Download the pre-trained model to this folder. Link: https://drive.google.com/drive/folders/1UC3XOoezeum0uck4KBVGa8osahs6rKUY
--------------------------------------------------------------------------------
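The checkpoints live in a shared Google Drive folder. A minimal download sketch, assuming the third-party gdown package is available (pip install gdown); the output path mirrors the code/pretrained_ckpt folder in the tree above and is an illustrative choice, not something the repo prescribes:

import gdown

# Hypothetical helper: pull the whole Drive folder from readme.txt into
# code/pretrained_ckpt/ (assumes gdown is installed and the folder is public).
url = "https://drive.google.com/drive/folders/1UC3XOoezeum0uck4KBVGa8osahs6rKUY"
gdown.download_folder(url, output="code/pretrained_ckpt", quiet=False)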
/utils/loss_functions/nd_softmax.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F

softmax_helper = lambda x: F.softmax(x, 1)
--------------------------------------------------------------------------------
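softmax_helper normalizes logits along dimension 1, the class/channel axis, so the same helper works for 2-D and 3-D segmentation outputs. A short usage sketch; the tensor shapes below are made up for illustration:

import torch
import torch.nn.functional as F

softmax_helper = lambda x: F.softmax(x, 1)  # per-class probabilities along dim 1

# (batch, classes, H, W) logits from a 2-D segmentation head
logits_2d = torch.randn(2, 4, 256, 256)
probs_2d = softmax_helper(logits_2d)   # same shape; sums to 1 over dim 1

# (batch, classes, D, H, W) logits from a 3-D network
logits_3d = torch.randn(1, 2, 96, 96, 96)
probs_3d = softmax_helper(logits_3d)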
-------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/samus.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/common.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/common.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/__pycache__/automatic_mask_generator.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-310.pyc 
-------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/mask_decoder.cpython-38.pyc -------------------------------------------------------------------------------- 
/model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/transformer.cpython-312.pyc -------------------------------------------------------------------------------- /networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-310.pyc -------------------------------------------------------------------------------- /networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/networks/__pycache__/swin_transformer_unet_skip_expand_decoder_sys.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/__pycache__/automatic_mask_generator.cpython-38.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-310.pyc -------------------------------------------------------------------------------- 
/model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/image_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple1986/SAMatch/HEAD/model_sam/segment_anything_samus/modeling/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /configs/acdc.yaml: -------------------------------------------------------------------------------- 1 | # arguments for dataset 2 | dataset: acdc 3 | nclass: 4 4 | crop_size: 256 5 | data_root: Your/ACDC/Path 6 | 7 | # arguments for training 8 | epochs: 300 9 | batch_size: 12 # per GPU x 1 GPU 10 | lr: 0.01 11 | conf_thresh: 0.95 12 | -------------------------------------------------------------------------------- /model_sam/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /configs/mrliver.yaml: -------------------------------------------------------------------------------- 1 | # arguments for dataset 2 | dataset: mrliver 3 | nclass: 2 4 | crop_size: 256 5 | data_root: /home/gxu/proj1/smatch/data/MRliver 6 | 7 | # arguments for training 8 | epochs: 300 9 | batch_size: 12 # per GPU x 1 GPU 10 | lr: 0.01 11 | conf_thresh: 0.95 12 | -------------------------------------------------------------------------------- /meduni_both.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --job-name=appletree 4 | #SBATCH --output=appletree_case0_5.txt 5 | # 6 | #SBATCH --partition=dgx2q 7 | #SBATCH --gres=gpu:a100:1 8 | # 9 | #SBATCH --ntasks=1 10 | 11 | srun python /data/maia/gpxu/proj1/samatch/train_unimatch_medsam_F2_ft_both_acdc.py -------------------------------------------------------------------------------- /utils/imgname.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def keep_img_name(path): 4 | with open("./imgname.txt", "w") as f: 5 | f.write(path) 6 | 7 | def read_img_name(): 8 | f = open(r"./imgname.txt", "r") 9 | file = f.readlines() 10 | for each in file: 11 | each = each.strip('\n') 12 | return each -------------------------------------------------------------------------------- /code/configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 2, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .samus import Samus 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam_us import ( 8 | build_samus, 9 | build_samus_vit_h, 10 | build_samus_vit_l, 11 | build_samus_vit_b, 12 | samus_model_registry, 13 | ) 14 | from .automatic_mask_generator import SamAutomaticMaskGenerator 15 | -------------------------------------------------------------------------------- /utils/loss_functions/crossentropy.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | 3 | class RobustCrossEntropyLoss(nn.CrossEntropyLoss): 4 | """ 5 | this is just a compatibility layer because my target tensor is float and has an extra dimension 6 | """ 7 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 8 | if len(target.shape) == len(input.shape): 9 | assert target.shape[1] == 1 10 | target = target[:, 0] 11 | return super().forward(input, target.long()) -------------------------------------------------------------------------------- /model_sam/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /model_sam/model_dict.py: -------------------------------------------------------------------------------- 1 | from model_sam.segment_anything.build_sam import sam_model_registry 2 | from model_sam.segment_anything_samus.build_sam_us import samus_model_registry 3 | 4 | def get_model(modelname="SAM", args=None, opt=None): 5 | if modelname == "SAM": 6 | model = sam_model_registry['vit_b'](checkpoint=args.ckpt) 7 | elif modelname == "SAMUS": 8 | model = samus_model_registry['vit_b'](args=args, checkpoint=args.ckpt) 9 | else: 10 | raise RuntimeError("Could not find the model:", modelname) 11 | return model 12 | -------------------------------------------------------------------------------- /utils/loss_functions/TopK_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils.loss_functions.crossentropy import RobustCrossEntropyLoss 4 | 5 | class TopKLoss(RobustCrossEntropyLoss): 6 | """ 7 | Network has to have NO LINEARITY! 
8 | """ 9 | def __init__(self, weight=None, ignore_index=-100, k=10): 10 | self.k = k 11 | super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False) 12 | 13 | def forward(self, inp, target): 14 | target = target[:, 0].long() 15 | res = super(TopKLoss, self).forward(inp, target) 16 | num_voxels = np.prod(res.shape, dtype=np.int64) 17 | res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) 18 | return res.mean() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The dataset and model can be downloaded in https://drive.google.com/drive/folders/1Kg1-fbFq2PD9N9Gqy9uLRKyxyivRlRED?usp=sharing 2 | 3 | Please mkdir a folder named checkpoint, and unzip the model (SAM) into this folder. 4 | 5 | Please unzip the dataset (ACDC and BUSI) into the data folder 6 | 7 | Please use the default splitting for train, validation and testing datasets for fair comparison. 8 | 9 | # SAMatch 10 | SAMatch: SAM-Guided and Match-based Semi-Supervised Segmentation for Medical Image 11 | 12 | 13 | This paper has been published in Medical Physics: 14 | https://aapm.onlinelibrary.wiley.com/doi/full/10.1002/mp.17785 15 | 16 | Xu, Guoping, et al. "A segment anything model‐guided and match‐based semi‐supervised segmentation framework for medical imaging." Medical physics (2025). 17 | 18 | The paper can be found in https://arxiv.org/abs/2411.16949 19 | 20 | Xu G, Qian X, Shao H C, et al. A SAM-guided and Match-based Semi-Supervised Segmentation Framework for Medical Imaging[J]. arXiv preprint arXiv:2411.16949, 2024. 21 | 22 | -------------------------------------------------------------------------------- /networks/net_factory_3d.py: -------------------------------------------------------------------------------- 1 | from networks.unet_3D import unet_3D 2 | from networks.vnet import VNet 3 | from networks.VoxResNet import VoxResNet 4 | from networks.attention_unet import Attention_UNet 5 | from networks.nnunet import initialize_network 6 | 7 | 8 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2): 9 | if net_type == "unet_3D": 10 | net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda() 11 | elif net_type == "attention_unet": 12 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda() 13 | elif net_type == "voxresnet": 14 | net = VoxResNet(in_chns=in_chns, feature_chns=64, 15 | class_num=class_num).cuda() 16 | elif net_type == "vnet": 17 | net = VNet(n_channels=in_chns, n_classes=class_num, 18 | normalization='batchnorm', has_dropout=True).cuda() 19 | elif net_type == "nnUNet": 20 | net = initialize_network(num_classes=class_num).cuda() 21 | else: 22 | net = None 23 | return net 24 | -------------------------------------------------------------------------------- /code/networks/net_factory_3d.py: -------------------------------------------------------------------------------- 1 | from networks.unet_3D import unet_3D 2 | from networks.vnet import VNet 3 | from networks.VoxResNet import VoxResNet 4 | from networks.attention_unet import Attention_UNet 5 | from networks.nnunet import initialize_network 6 | 7 | 8 | def net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2): 9 | if net_type == "unet_3D": 10 | net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda() 11 | elif net_type == "attention_unet": 12 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda() 13 | elif net_type == "voxresnet": 14 | 
net = VoxResNet(in_chns=in_chns, feature_chns=64, 15 | class_num=class_num).cuda() 16 | elif net_type == "vnet": 17 | net = VNet(n_channels=in_chns, n_classes=class_num, 18 | normalization='batchnorm', has_dropout=True).cuda() 19 | elif net_type == "nnUNet": 20 | net = initialize_network(num_classes=class_num).cuda() 21 | else: 22 | net = None 23 | return net 24 | -------------------------------------------------------------------------------- /utils/loss_functions/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | 6 | def sum_tensor(inp, axes, keepdim=False): 7 | axes = np.unique(axes).astype(int) 8 | if keepdim: 9 | for ax in axes: 10 | inp = inp.sum(int(ax), keepdim=True) 11 | else: 12 | for ax in sorted(axes, reverse=True): 13 | inp = inp.sum(int(ax)) 14 | return inp 15 | 16 | 17 | def mean_tensor(inp, axes, keepdim=False): 18 | axes = np.unique(axes).astype(int) 19 | if keepdim: 20 | for ax in axes: 21 | inp = inp.mean(int(ax), keepdim=True) 22 | else: 23 | for ax in sorted(axes, reverse=True): 24 | inp = inp.mean(int(ax)) 25 | return inp 26 | 27 | 28 | def flip(x, dim): 29 | """ 30 | flips the tensor at dimension dim (mirroring!) 31 | :param x: 32 | :param dim: 33 | :return: 34 | """ 35 | indices = [slice(None)] * x.dim() 36 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, 37 | dtype=torch.long, device=x.device) 38 | return x[tuple(indices)] 39 | -------------------------------------------------------------------------------- /code/test_brats2019_semi_seg.sh: -------------------------------------------------------------------------------- 1 | # & means run these methods at the same time, and && means run these methods one by one 2 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Fully_supervised_25 --model unet_3D && 3 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Fully_supervised_250 --model unet_3D && 4 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Mean_Teacher_25 --model unet_3D && 5 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Uncertainty_Aware_Mean_Teacher_25 --model unet_3D && 6 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Interpolation_Consistency_Training_25 --model unet_3D && 7 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Entropy_Minimization_25 --model unet_3D && 8 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Cross_Pseudo_Supervision_25 --model unet_3D && 9 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Adversarial_Network_25 --model unet_3D && 10 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Uncertainty_Rectified_Pyramid_Consistency_25 --model unet_3D_dv_semi -------------------------------------------------------------------------------- /code/test_acdc_unet_semi_seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 7 && \ 2 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Entropy_Minimization --num_classes 4 --labeled_num 7 && \ 3 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Interpolation_Consistency_Training --num_classes 4 --labeled_num 7 && \ 4 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py 
--root_path ../data/ACDC --exp ACDC/Mean_Teacher --num_classes 4 --labeled_num 7 && \ 5 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Aware_Mean_Teacher --num_classes 4 --labeled_num 7 && \ 6 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Adversarial_Network --num_classes 4 --labeled_num 7 && \ 7 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Rectified_Pyramid_Consistency --model unet_urpc --num_classes 4 --labeled_num 7 && \ 8 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 140 -------------------------------------------------------------------------------- /code/train_acdc_unet_semi_seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_fully_supervised_2D.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 7 && \ 2 | CUDA_VISIBLE_DEVICES=0 python train_entropy_minimization_2D.py --root_path ../data/ACDC --exp ACDC/Entropy_Minimization --num_classes 4 --labeled_num 7 && \ 3 | CUDA_VISIBLE_DEVICES=0 python train_interpolation_consistency_training_2D.py --root_path ../data/ACDC --exp ACDC/Interpolation_Consistency_Training --num_classes 4 --labeled_num 7 && \ 4 | CUDA_VISIBLE_DEVICES=0 python train_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/Mean_Teacher --num_classes 4 --labeled_num 7 && \ 5 | CUDA_VISIBLE_DEVICES=0 python train_uncertainty_aware_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Aware_Mean_Teacher --num_classes 4 --labeled_num 7 && \ 6 | CUDA_VISIBLE_DEVICES=0 python train_adversarial_network_2D.py --root_path ../data/ACDC --exp ACDC/Adversarial_Network --num_classes 4 --labeled_num 7 && \ 7 | CUDA_VISIBLE_DEVICES=0 python train_uncertainty_rectified_pyramid_consistency_2D.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Rectified_Pyramid_Consistency --num_classes 4 --labeled_num 7 && \ 8 | CUDA_VISIBLE_DEVICES=0 python train_fully_supervised_2D.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 140 -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/14 下午4:41 4 | # @Author : chuyu zhang 5 | # @File : metrics.py 6 | # @Software: PyCharm 7 | 8 | 9 | import numpy as np 10 | from medpy import metric 11 | 12 | 13 | def cal_dice(prediction, label, num=2): 14 | total_dice = np.zeros(num-1) 15 | for i in range(1, num): 16 | prediction_tmp = (prediction == i) 17 | label_tmp = (label == i) 18 | prediction_tmp = prediction_tmp.astype(np.float) 19 | label_tmp = label_tmp.astype(np.float) 20 | 21 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 22 | total_dice[i - 1] += dice 23 | 24 | return total_dice 25 | 26 | 27 | def calculate_metric_percase(pred, gt): 28 | dc = metric.binary.dc(pred, gt) 29 | jc = metric.binary.jc(pred, gt) 30 | hd = metric.binary.hd95(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | 33 | return dc, jc, hd, asd 34 | 35 | 36 | def dice(input, target, ignore_index=None): 37 | smooth = 1. 38 | # using clone, so that it can do change to original target. 
39 | iflat = input.clone().view(-1) 40 | tflat = target.clone().view(-1) 41 | if ignore_index is not None: 42 | mask = tflat == ignore_index 43 | tflat[mask] = 0 44 | iflat[mask] = 0 45 | intersection = (iflat * tflat).sum() 46 | 47 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) -------------------------------------------------------------------------------- /code/utils/metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/14 下午4:41 4 | # @Author : chuyu zhang 5 | # @File : metrics.py 6 | # @Software: PyCharm 7 | 8 | 9 | import numpy as np 10 | from medpy import metric 11 | 12 | 13 | def cal_dice(prediction, label, num=2): 14 | total_dice = np.zeros(num-1) 15 | for i in range(1, num): 16 | prediction_tmp = (prediction == i) 17 | label_tmp = (label == i) 18 | prediction_tmp = prediction_tmp.astype(np.float) 19 | label_tmp = label_tmp.astype(np.float) 20 | 21 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 22 | total_dice[i - 1] += dice 23 | 24 | return total_dice 25 | 26 | 27 | def calculate_metric_percase(pred, gt): 28 | dc = metric.binary.dc(pred, gt) 29 | jc = metric.binary.jc(pred, gt) 30 | hd = metric.binary.hd95(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | 33 | return dc, jc, hd, asd 34 | 35 | 36 | def dice(input, target, ignore_index=None): 37 | smooth = 1. 38 | # using clone, so that it can do change to original target. 39 | iflat = input.clone().view(-1) 40 | tflat = target.clone().view(-1) 41 | if ignore_index is not None: 42 | mask = tflat == ignore_index 43 | tflat[mask] = 0 44 | iflat[mask] = 0 45 | intersection = (iflat * tflat).sum() 46 | 47 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) -------------------------------------------------------------------------------- /utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 
13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /code/utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /utils/dist_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def setup_distributed(backend="nccl", port=None): 9 | """AdaHessian Optimizer 10 | Lifted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py 11 | Originally licensed MIT, Copyright (c) 2020 Wei Li 12 | """ 13 | num_gpus = torch.cuda.device_count() 14 | 15 | if "SLURM_JOB_ID" in os.environ: 16 | rank = int(os.environ["SLURM_PROCID"]) 17 | world_size = int(os.environ["SLURM_NTASKS"]) 18 | node_list = os.environ["SLURM_NODELIST"] 19 | addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") 20 | # specify master port 21 | if port is not None: 22 | os.environ["MASTER_PORT"] = str(port) 23 | elif "MASTER_PORT" not in os.environ: 24 | os.environ["MASTER_PORT"] = "10685" 25 | if "MASTER_ADDR" not in os.environ: 26 | os.environ["MASTER_ADDR"] = addr 27 | 
os.environ["WORLD_SIZE"] = str(world_size) 28 | os.environ["LOCAL_RANK"] = str(rank % num_gpus) 29 | os.environ["RANK"] = str(rank) 30 | else: 31 | rank = int(os.environ["RANK"]) 32 | world_size = int(os.environ["WORLD_SIZE"]) 33 | 34 | torch.cuda.set_device(rank % num_gpus) 35 | 36 | dist.init_process_group( 37 | backend=backend, 38 | world_size=world_size, 39 | rank=rank, 40 | ) 41 | return rank, world_size 42 | -------------------------------------------------------------------------------- /dataloaders/acdc_data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | import SimpleITK as sitk 7 | 8 | slice_num = 0 9 | mask_path = sorted(glob.glob("/home/xdluo/data/ACDC/image/*.nii.gz")) 10 | for case in mask_path: 11 | img_itk = sitk.ReadImage(case) 12 | origin = img_itk.GetOrigin() 13 | spacing = img_itk.GetSpacing() 14 | direction = img_itk.GetDirection() 15 | image = sitk.GetArrayFromImage(img_itk) 16 | msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz") 17 | if os.path.exists(msk_path): 18 | print(msk_path) 19 | msk_itk = sitk.ReadImage(msk_path) 20 | mask = sitk.GetArrayFromImage(msk_itk) 21 | image = (image - image.min()) / (image.max() - image.min()) 22 | print(image.shape) 23 | image = image.astype(np.float32) 24 | item = case.split("/")[-1].split(".")[0] 25 | if image.shape != mask.shape: 26 | print("Error") 27 | print(item) 28 | for slice_ind in range(image.shape[0]): 29 | f = h5py.File( 30 | '/home/xdluo/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w') 31 | f.create_dataset( 32 | 'image', data=image[slice_ind], compression="gzip") 33 | f.create_dataset('label', data=mask[slice_ind], compression="gzip") 34 | f.close() 35 | slice_num += 1 36 | print("Converted all ACDC volumes to 2D slices") 37 | print("Total {} slices".format(slice_num)) 38 | -------------------------------------------------------------------------------- /code/dataloaders/acdc_data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | import SimpleITK as sitk 7 | 8 | slice_num = 0 9 | mask_path = sorted(glob.glob("/home/xdluo/data/ACDC/image/*.nii.gz")) 10 | for case in mask_path: 11 | img_itk = sitk.ReadImage(case) 12 | origin = img_itk.GetOrigin() 13 | spacing = img_itk.GetSpacing() 14 | direction = img_itk.GetDirection() 15 | image = sitk.GetArrayFromImage(img_itk) 16 | msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz") 17 | if os.path.exists(msk_path): 18 | print(msk_path) 19 | msk_itk = sitk.ReadImage(msk_path) 20 | mask = sitk.GetArrayFromImage(msk_itk) 21 | image = (image - image.min()) / (image.max() - image.min()) 22 | print(image.shape) 23 | image = image.astype(np.float32) 24 | item = case.split("/")[-1].split(".")[0] 25 | if image.shape != mask.shape: 26 | print("Error") 27 | print(item) 28 | for slice_ind in range(image.shape[0]): 29 | f = h5py.File( 30 | '/home/xdluo/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w') 31 | f.create_dataset( 32 | 'image', data=image[slice_ind], compression="gzip") 33 | f.create_dataset('label', data=mask[slice_ind], compression="gzip") 34 | f.close() 35 | slice_num += 1 36 | print("Converted all ACDC volumes to 2D slices") 37 | print("Total {} slices".format(slice_num)) 38 | -------------------------------------------------------------------------------- 
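A quick way to sanity-check the slices written by acdc_data_processing.py above is to open one of the generated HDF5 files and confirm that the image/label pair looks as expected. The following is a minimal sketch, not part of the repository; the file path and the "patient001_frame01_slice_0.h5" name are placeholders for whatever your own preprocessing run produced under its data directory.

import h5py
import numpy as np

# Hypothetical output file from acdc_data_processing.py; adjust the path to your own data root.
slice_path = "/home/xdluo/data/ACDC/data/patient001_frame01_slice_0.h5"

with h5py.File(slice_path, "r") as f:
    image = f["image"][:]   # float32 slice, min-max normalized to [0, 1] by the preprocessing script
    label = f["label"][:]   # integer mask; for ACDC this holds background plus the three cardiac structures

# The script only writes slices whose image and mask volumes had matching shapes,
# so each stored 2D pair should match as well.
assert image.shape == label.shape
print(image.shape, image.dtype, np.unique(label))

If the shapes disagree or the label values fall outside the expected class range, the corresponding volume was likely skipped or mislabeled during preprocessing and is worth inspecting before training.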
/dataloaders/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from PIL import ImageFilter 5 | from scipy import ndimage 6 | 7 | 8 | def random_rot_flip(img, mask): 9 | k = np.random.randint(0, 4) 10 | img = np.rot90(img, k) 11 | mask = np.rot90(mask, k) 12 | axis = np.random.randint(0, 2) 13 | img = np.flip(img, axis=axis).copy() 14 | mask = np.flip(mask, axis=axis).copy() 15 | return img, mask 16 | 17 | 18 | def random_rotate(img, mask): 19 | angle = np.random.randint(-20, 20) 20 | img = ndimage.rotate(img, angle, order=0, reshape=False) 21 | mask = ndimage.rotate(mask, angle, order=0, reshape=False) 22 | return img, mask 23 | 24 | 25 | def blur(img, p=0.5): 26 | if random.random() < p: 27 | sigma = np.random.uniform(0.1, 2.0) 28 | img = img.filter(ImageFilter.GaussianBlur(radius=sigma)) 29 | return img 30 | 31 | 32 | def obtain_cutmix_box(img_size, p=0.5, size_min=0.02, size_max=0.4, ratio_1=0.3, ratio_2=1/0.3): 33 | mask = torch.zeros(img_size, img_size) 34 | if random.random() > p: 35 | return mask 36 | 37 | size = np.random.uniform(size_min, size_max) * img_size * img_size 38 | while True: 39 | ratio = np.random.uniform(ratio_1, ratio_2) 40 | cutmix_w = int(np.sqrt(size / ratio)) 41 | cutmix_h = int(np.sqrt(size * ratio)) 42 | x = np.random.randint(0, img_size) 43 | y = np.random.randint(0, img_size) 44 | 45 | if x + cutmix_w <= img_size and y + cutmix_h <= img_size: 46 | break 47 | 48 | mask[y:y + cutmix_h, x:x + cutmix_w] = 1 49 | 50 | return mask 51 | 52 | -------------------------------------------------------------------------------- /networks/swin_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class SelfAttention(nn.Module): 6 | def __init__(self, in_channels): 7 | super(SelfAttention, self).__init__() 8 | 9 | # Define query, key, and value linear transformations 10 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1) 11 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1) 12 | self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1) 13 | 14 | # Attention softmax 15 | self.softmax = nn.Softmax(dim=-1) 16 | 17 | def forward(self, x): 18 | # Transform input for query, key, and value 19 | proj_query = self.query_conv(x) 20 | proj_key = self.key_conv(x) 21 | proj_value = self.value_conv(x) 22 | 23 | # Reshape for matrix multiplication 24 | B, C, H, W = proj_query.size() 25 | proj_query = proj_query.view(B, -1, H * W).permute(0, 2, 1) 26 | proj_key = proj_key.view(B, -1, H * W) 27 | 28 | # Compute attention scores 29 | energy = torch.bmm(proj_query, proj_key) 30 | attention = self.softmax(energy) 31 | 32 | # Compute weighted sum using attention scores 33 | proj_value = proj_value.view(B, -1, H * W) 34 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 35 | out = out.view(B, C, H, W) 36 | 37 | return out 38 | 39 | # Example usage 40 | input_tensor = torch.randn(24, 16, 128, 128) 41 | self_attention = SelfAttention(in_channels=16) 42 | output = self_attention(input_tensor) 43 | print(output.shape) 44 | -------------------------------------------------------------------------------- /code/test_3D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from glob import glob 5 | 6 | import torch 7 | 8 | from 
networks.unet_3D import unet_3D 9 | from test_3D_util import test_all_case 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='../data/BraTS2019', help='Name of Experiment') 14 | parser.add_argument('--exp', type=str, 15 | default='BraTS2019/Interpolation_Consistency_Training_25', help='experiment_name') 16 | parser.add_argument('--model', type=str, 17 | default='unet_3D', help='model_name') 18 | 19 | 20 | def Inference(FLAGS): 21 | snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model) 22 | num_classes = 2 23 | test_save_path = "../model/{}/Prediction".format(FLAGS.exp) 24 | if os.path.exists(test_save_path): 25 | shutil.rmtree(test_save_path) 26 | os.makedirs(test_save_path) 27 | net = unet_3D(n_classes=num_classes, in_channels=1).cuda() 28 | save_mode_path = os.path.join( 29 | snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 30 | net.load_state_dict(torch.load(save_mode_path)) 31 | print("init weight from {}".format(save_mode_path)) 32 | net.eval() 33 | avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model, test_list="test.txt", num_classes=num_classes, 34 | patch_size=(96, 96, 96), stride_xy=64, stride_z=64, test_save_path=test_save_path) 35 | return avg_metric 36 | 37 | 38 | if __name__ == '__main__': 39 | FLAGS = parser.parse_args() 40 | metric = Inference(FLAGS) 41 | print(metric) 42 | -------------------------------------------------------------------------------- /model_sam/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /utils/sortName.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Nov 18 10:29:35 2020 4 | 5 | @author: David 6 | """ 7 | 8 | import glob 9 | import os 10 | import json 11 | import re 12 | from tifffile import imread 13 | import numpy as np 14 | import h5py 15 | from random import shuffle 16 | import math 17 | 18 | import matplotlib.pyplot as plt 19 | plt.switch_backend('agg') 20 | 21 | def tryint(s): 22 | try: 23 | return int(s) 24 | except ValueError: 25 | return s 26 | 27 | def str2int(v_str): 28 | return [tryint(sub_str) for sub_str in re.split('([0-9]+)', v_str)] 29 | 30 | def sort_humanly(v_list): 31 | return sorted(v_list, key=str2int) 32 | 33 | # def getfilename(root_path): 34 | # for root, dirs, files in os.walk(root_path): 35 | # array = dirs 36 | # if array: 37 | # return array 38 | 39 | def get_filename(data_path, pattern="*.png"): 40 | image_name = [] 41 | data_image = os.path.join(data_path, pattern) 42 | for name in glob.glob(data_image, recursive=True): 43 | image_name.append(name) 44 | image_name = sort_humanly(image_name) 45 | return image_name 46 | 47 | 48 | if __name__=='__main__': 49 | data_path = '/home/gxu/proj1/smatch/data/MRbrain/DICOM' # the whole brain path 50 | save_path = '/home/gpxu/vess_seg/vess_efficient/aping' 51 | # print(data_label) 52 | ## obtain file name 53 | image_name = get_filename(data_path, pattern="*.mat") 54 | image_name = sort_humanly(image_name) 55 | print(image_name[:10]) 56 | 57 | # with open('whole_brain_name.txt','a') as f: 58 | # for n in image_name: 59 | # f.write(n + '\n') 60 | 61 | print('done') -------------------------------------------------------------------------------- /code/train_brats2019_semi_seg.sh: -------------------------------------------------------------------------------- 1 | # & means run these methods at the same time, and && means run these methods one by one 2 | python -u train_fully_supervised_3D.py --labeled_num 25 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Fully_supervised --base_lr 0.1 && 3 | python -u train_fully_supervised_3D.py 
--labeled_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Fully_supervised --base_lr 0.1 && 4 | python -u train_adversarial_network_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Adversarial_Network --base_lr 0.1 && 5 | python -u train_entropy_minimization_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Entropy_Minimization --base_lr 0.1 && 6 | python -u train_interpolation_consistency_training_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Interpolation_Consistency_Training && 7 | python -u train_mean_teacher_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Mean_Teacher --base_lr 0.1 && 8 | python -u train_uncertainty_aware_mean_teacher_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Uncertainty_Aware_Mean_Teacher && 9 | python -u train_uncertainty_rectified_pyramid_consistency_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Uncertainty_Rectified_Pyramid_Consistency && 10 | python -u train_cross_pseudo_supervision_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Cross_Pseudo_Supervision -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | from augmentations.ctaugment import * 5 | 6 | 7 | class StorableCTAugment(CTAugment): 8 | def load_state_dict(self, state): 9 | for k in ["decay", "depth", "th", "rates"]: 10 | assert k in state, "{} not in {}".format(k, state.keys()) 11 | setattr(self, k, state[k]) 12 | 13 | def state_dict(self): 14 | return OrderedDict( 15 | [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]] 16 | ) 17 | 18 | 19 | def get_default_cta(): 20 | return StorableCTAugment() 21 | 22 | 23 | def cta_apply(pil_img, ops): 24 | if ops is None: 25 | return pil_img 26 | for op, args in ops: 27 | pil_img = OPS[op].f(pil_img, *args) 28 | return pil_img 29 | 30 | 31 | def deserialize(policy_str): 32 | return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)] 33 | 34 | 35 | def stats(cta): 36 | return "\n".join( 37 | "%-16s %s" 38 | % ( 39 | k, 40 | " / ".join( 41 | " ".join("%.2f" % x for x in cta.rate_to_p(rate)) 42 | for rate in cta.rates[k] 43 | ), 44 | ) 45 | for k in sorted(OPS.keys()) 46 | ) 47 | 48 | 49 | def interleave(x, batch, inverse=False): 50 | """ 51 | TF code 52 | def interleave(x, batch): 53 | s = x.get_shape().as_list() 54 | return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:]) 55 | """ 56 | shape = x.shape 57 | axes = [batch, -1] if inverse else [-1, batch] 58 | return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:]) 59 | 60 | 61 | def deinterleave(x, batch): 62 | return interleave(x, batch, inverse=True) 63 | -------------------------------------------------------------------------------- /code/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | from 
augmentations.ctaugment import * 5 | 6 | 7 | class StorableCTAugment(CTAugment): 8 | def load_state_dict(self, state): 9 | for k in ["decay", "depth", "th", "rates"]: 10 | assert k in state, "{} not in {}".format(k, state.keys()) 11 | setattr(self, k, state[k]) 12 | 13 | def state_dict(self): 14 | return OrderedDict( 15 | [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]] 16 | ) 17 | 18 | 19 | def get_default_cta(): 20 | return StorableCTAugment() 21 | 22 | 23 | def cta_apply(pil_img, ops): 24 | if ops is None: 25 | return pil_img 26 | for op, args in ops: 27 | pil_img = OPS[op].f(pil_img, *args) 28 | return pil_img 29 | 30 | 31 | def deserialize(policy_str): 32 | return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)] 33 | 34 | 35 | def stats(cta): 36 | return "\n".join( 37 | "%-16s %s" 38 | % ( 39 | k, 40 | " / ".join( 41 | " ".join("%.2f" % x for x in cta.rate_to_p(rate)) 42 | for rate in cta.rates[k] 43 | ), 44 | ) 45 | for k in sorted(OPS.keys()) 46 | ) 47 | 48 | 49 | def interleave(x, batch, inverse=False): 50 | """ 51 | TF code 52 | def interleave(x, batch): 53 | s = x.get_shape().as_list() 54 | return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:]) 55 | """ 56 | shape = x.shape 57 | axes = [batch, -1] if inverse else [-1, batch] 58 | return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:]) 59 | 60 | 61 | def deinterleave(x, batch): 62 | return interleave(x, batch, inverse=True) 63 | -------------------------------------------------------------------------------- /data/data_format_trans.py: -------------------------------------------------------------------------------- 1 | ## change .mat file to h5 format and save each volume as slices 2 | 3 | import scipy.io as sio 4 | import os 5 | import utils.sortName as getname 6 | import h5py 7 | 8 | # set path 9 | root_path = "/home/gxu/proj1/smatch/data/MRbrain" 10 | img_path = os.path.join(root_path, "DICOM") 11 | lab_path = os.path.join(root_path, "Labels") 12 | img_lab_slice_path = os.path.join(root_path, "img_label_slice") 13 | img_lab_volume_path = os.path.join(root_path, "img_label_volume") 14 | 15 | # satatical all files 16 | img_name = getname.get_filename(img_path, pattern="*.mat") 17 | lab_name = getname.get_filename(lab_path, pattern="*.mat") 18 | print(len(img_name)) 19 | # print(img_name[:10]) 20 | 21 | ## convert .mat to h5 22 | # load data 23 | total_slice = 0 24 | for one_img_path, one_lab_path in zip(img_name, lab_name): 25 | # basename 26 | img_basename = os.path.basename(one_img_path).split(".mat")[0] 27 | # lab_basename = os.path.basename(one_lab_path) 28 | # load data 29 | img_vol = sio.loadmat(one_img_path)["img"] 30 | lab_vol = sio.loadmat(one_lab_path)["label"] 31 | # save as volume 32 | vol_name = os.path.join(img_lab_volume_path, "subj"+img_basename +".h5") 33 | # save as h5 34 | hf = h5py.File(vol_name, 'w') 35 | hf.create_dataset('image', data=img_vol) 36 | hf.create_dataset('label', data=lab_vol) 37 | hf.close() 38 | 39 | # save as slice 40 | total_slice = total_slice + img_vol.shape[2] 41 | # read each slice 42 | for n in range(0, img_vol.shape[2]): # how many slices in each case 43 | img = img_vol[:,:,n] 44 | lab = lab_vol[:,:,n] 45 | 46 | # save path 47 | data_save_name = os.path.join(img_lab_slice_path, "subj"+img_basename+"_s"+str(n)+".h5") 48 | # save as h5 49 | hf = h5py.File(data_save_name, 'w') 50 | hf.create_dataset('image', data=img) 51 | hf.create_dataset('label', data=lab) 52 | hf.close() 53 | 54 
| print("total slice number: {}".format(total_slice)) 55 | 56 | # ## test 57 | # h5f = h5py.File(data_save_name, "r") 58 | # image = h5f["image"][:] 59 | # label = h5f["label"][:] 60 | # print(image.shape) 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /code/test_urpc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from glob import glob 5 | import numpy 6 | 7 | import torch 8 | 9 | from networks.unet_3D_dv_semi import unet_3D_dv_semi 10 | from networks.unet_3D import unet_3D 11 | from test_urpc_util import test_all_case 12 | 13 | 14 | def net_factory(net_type="unet_3D", num_classes=3, in_channels=1): 15 | if net_type == "unet_3D": 16 | net = unet_3D(n_classes=num_classes, in_channels=in_channels).cuda() 17 | elif net_type == "unet_3D_dv_semi": 18 | net = unet_3D_dv_semi(n_classes=num_classes, 19 | in_channels=in_channels).cuda() 20 | else: 21 | net = None 22 | return net 23 | 24 | 25 | def Inference(FLAGS): 26 | snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model) 27 | num_classes = 2 28 | test_save_path = "../model/{}/Prediction".format(FLAGS.exp) 29 | if os.path.exists(test_save_path): 30 | shutil.rmtree(test_save_path) 31 | os.makedirs(test_save_path) 32 | net = net_factory(FLAGS.model, num_classes, in_channels=1) 33 | save_mode_path = os.path.join( 34 | snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 35 | net.load_state_dict(torch.load(save_mode_path)) 36 | print("init weight from {}".format(save_mode_path)) 37 | net.eval() 38 | avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model, test_list="test.txt", num_classes=num_classes, 39 | patch_size=(96, 96, 96), stride_xy=64, stride_z=64, test_save_path=test_save_path) 40 | return avg_metric 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--root_path', type=str, 47 | default='../data/BraTS2019', help='Name of Experiment') 48 | parser.add_argument('--exp', type=str, 49 | default="BraTS2019/Uncertainty_Rectified_Pyramid_Consistency_25_labeled", help='experiment_name') 50 | parser.add_argument('--model', type=str, 51 | default="unet_3D_dv_semi", help='model_name') 52 | FLAGS = parser.parse_args() 53 | 54 | metric = Inference(FLAGS) 55 | print(metric) 56 | -------------------------------------------------------------------------------- /data/data_split_train_test.py: -------------------------------------------------------------------------------- 1 | ## split subject randomly for training and testing 2 | import os 3 | import utils.sortName as getname 4 | from random import randint, sample 5 | 6 | # get all subject id: total 48 subjects 7 | # set path 8 | root_path = "/home/gxu/proj1/smatch/data/MRbrain" 9 | subj_path = os.path.join(root_path, "DICOM") 10 | slice_path = os.path.join(root_path, "img_label_slice") 11 | 12 | # select 36 subjects for training and the other 12 subjects for testing 13 | # get all 48 subjects id 14 | subj_name = getname.get_filename(subj_path, pattern="*.mat") 15 | subj_id = [os.path.basename(name).split(".mat")[0] for name in subj_name] 16 | print(len(subj_id)) 17 | train_id = sample(subj_id, 36) 18 | test_id = [num for num in subj_id if num not in train_id] 19 | 20 | # filter all training slices according to selected training subject id 21 | # and filter all other testing slices according to selected testing subjects ID 22 | # 1. 
get all slice names 23 | slice_names = getname.get_filename(slice_path, pattern="*.h5") 24 | 25 | # 2. save all training slices name 26 | with open(os.path.join(root_path, "train_slices.list"),'w') as f: 27 | # for training slices 28 | cnt = 0 29 | for name in train_id: 30 | for slice_path in slice_names: 31 | if name in slice_path: 32 | # save slice name for training 33 | sel_slice = slice_path.split("_slice/")[1] 34 | f.write(sel_slice[:-3] + '\n') 35 | cnt += 1 36 | print(cnt) 37 | # 3. save all testing slices name 38 | with open(os.path.join(root_path, "test_slices.list"),'w') as f: 39 | # for training slices 40 | cnt = 0 41 | for name in test_id: 42 | for slice_path in slice_names: 43 | if name in slice_path: 44 | # save slice name for training 45 | sel_slice = slice_path.split("_slice/")[1] 46 | f.write(sel_slice[:-3] + '\n') 47 | cnt += 1 48 | print(cnt) 49 | 50 | # 4. save all testing volume name 51 | with open(os.path.join(root_path, "test_volume.list"),'w') as f: 52 | for sel_slice in test_id: 53 | f.write("subj"+sel_slice + '\n') 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /code/val_2D.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | 6 | 7 | def calculate_metric_percase(pred, gt): 8 | pred[pred > 0] = 1 9 | gt[gt > 0] = 1 10 | if pred.sum() > 0: 11 | dice = metric.binary.dc(pred, gt) 12 | hd95 = metric.binary.hd95(pred, gt) 13 | return dice, hd95 14 | else: 15 | return 0, 0 16 | 17 | 18 | def test_single_volume(image, label, net, classes, patch_size=[256, 256]): 19 | image, label = image.squeeze(0).cpu().detach( 20 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 21 | prediction = np.zeros_like(label) 22 | for ind in range(image.shape[0]): 23 | slice = image[ind, :, :] 24 | x, y = slice.shape[0], slice.shape[1] 25 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 26 | input = torch.from_numpy(slice).unsqueeze( 27 | 0).unsqueeze(0).float().cuda() 28 | net.eval() 29 | with torch.no_grad(): 30 | out = torch.argmax(torch.softmax( 31 | net(input), dim=1), dim=1).squeeze(0) 32 | out = out.cpu().detach().numpy() 33 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 34 | prediction[ind] = pred 35 | metric_list = [] 36 | for i in range(1, classes): 37 | metric_list.append(calculate_metric_percase( 38 | prediction == i, label == i)) 39 | return metric_list 40 | 41 | 42 | def test_single_volume_ds(image, label, net, classes, patch_size=[256, 256]): 43 | image, label = image.squeeze(0).cpu().detach( 44 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 45 | prediction = np.zeros_like(label) 46 | for ind in range(image.shape[0]): 47 | slice = image[ind, :, :] 48 | x, y = slice.shape[0], slice.shape[1] 49 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 50 | input = torch.from_numpy(slice).unsqueeze( 51 | 0).unsqueeze(0).float().cuda() 52 | net.eval() 53 | with torch.no_grad(): 54 | output_main, _, _, _ = net(input) 55 | out = torch.argmax(torch.softmax( 56 | output_main, dim=1), dim=1).squeeze(0) 57 | out = out.cpu().detach().numpy() 58 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 59 | prediction[ind] = pred 60 | metric_list = [] 61 | for i in range(1, classes): 62 | metric_list.append(calculate_metric_percase( 63 | prediction == i, label == i)) 64 | return metric_list 65 | 
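
Note: the two helpers in val_2D.py above (test_single_volume and test_single_volume_ds) are the per-volume validation entry points used elsewhere in this repository. Below is a minimal sketch of how test_single_volume is typically driven from a validation loop; the names valloader and model, the 4-class ACDC-style setting, and the validate wrapper itself are illustrative assumptions, not part of this file, and the model is assumed to already be on CUDA since test_single_volume moves its inputs to .cuda().

# illustrative validation loop (assumed names: model, valloader; 4 classes as in ACDC)
import numpy as np
from val_2D import test_single_volume

def validate(model, valloader, num_classes=4, patch_size=(256, 256)):
    model.eval()
    # accumulate per-foreground-class (dice, hd95) pairs over the validation set
    metric_sum = np.zeros((num_classes - 1, 2))
    for image, label in valloader:  # image/label: 1 x D x H x W volumes from the val split
        metric_i = test_single_volume(image, label, model, classes=num_classes,
                                      patch_size=list(patch_size))
        metric_sum += np.asarray(metric_i)
    return metric_sum / len(valloader)  # mean dice / hd95 per foreground class
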
-------------------------------------------------------------------------------- /utils/generate_prompts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def generate_click_prompt(img, msk, pt_label = 1): 5 | # return: img, prompt, prompt mask 6 | pt_list = [] 7 | msk_list = [] 8 | b, c, h, w, d = msk.size() 9 | msk = msk[:,0,:,:,:] 10 | for i in range(d): 11 | pt_list_s = [] 12 | msk_list_s = [] 13 | for j in range(b): 14 | msk_s = msk[j,:,:,i] 15 | indices = torch.nonzero(msk_s) 16 | if indices.size(0) == 0: 17 | # generate a random array between [0-h, 0-h]: 18 | random_index = torch.randint(0, h, (2,)).to(device = msk.device) 19 | new_s = msk_s 20 | else: 21 | random_index = random.choice(indices) 22 | label = msk_s[random_index[0], random_index[1]] 23 | new_s = torch.zeros_like(msk_s) 24 | # convert bool tensor to int 25 | new_s = (msk_s == label).to(dtype = torch.float) 26 | # new_s[msk_s == label] = 1 27 | pt_list_s.append(random_index) 28 | msk_list_s.append(new_s) 29 | pts = torch.stack(pt_list_s, dim=0) # b 2 30 | msks = torch.stack(msk_list_s, dim=0) 31 | pt_list.append(pts) # c b 2 32 | msk_list.append(msks) 33 | pt = torch.stack(pt_list, dim=-1) # b 2 d 34 | msk = torch.stack(msk_list, dim=-1) # b h w d 35 | msk = msk.unsqueeze(1) # b c h w d 36 | return img, pt, msk #[b, 2, d], [b, c, h, w, d] 37 | 38 | def get_click_prompt(datapack, opt): 39 | if 'pt' not in datapack: 40 | imgs, pt, masks = generate_click_prompt(imgs, masks) 41 | else: 42 | pt = datapack['pt'] 43 | point_labels = datapack['p_label'] 44 | 45 | point_coords = pt 46 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float32, device=opt.device) 47 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=opt.device) 48 | if len(pt.shape) == 2: 49 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 50 | pt = (coords_torch, labels_torch) 51 | return pt 52 | import numpy as np 53 | def get_click_prompt_1(pt,point_labels): 54 | pt = np.array(pt) 55 | point_coords = pt 56 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float32, device=('cuda:0')) 57 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=('cuda:0')) 58 | if len(pt.shape) == 2: 59 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 60 | pt = (coords_torch, labels_torch) 61 | return pt 62 | -------------------------------------------------------------------------------- /data/data_split_train_test_val.py: -------------------------------------------------------------------------------- 1 | ## split subject randomly for training and testing 2 | import os 3 | import utils.sortName as getname 4 | from random import randint, sample 5 | import random 6 | random.seed(2024) 7 | 8 | 9 | # get all subject id: total 48 subjects 10 | # set path 11 | root_path = "/home/gxu/proj1/smatch/data/MRliver" 12 | subj_path = os.path.join(root_path, "DICOM") 13 | slice_path = os.path.join(root_path, "slices") 14 | 15 | # select 36 subjects for training and the other 12 subjects for testing 16 | # get all 48 subjects id 17 | subj_name = getname.get_filename(subj_path, pattern="*.mat") 18 | subj_id = [os.path.basename(name).split(".mat")[0] for name in subj_name] 19 | print(len(subj_id)) 20 | train_id = sample(subj_id, 30) 21 | test_val_id = [num for num in subj_id if num not in train_id] 22 | test_id = sample(test_val_id, 12) 23 | val_id = [num for num in test_val_id if num not in test_id] 24 | 25 | 
# filter all training slices according to selected training subject id 26 | # and filter all other testing slices according to selected testing subjects ID 27 | # 1. get all slice names 28 | slice_names = getname.get_filename(slice_path, pattern="*.h5") 29 | 30 | # 2. save all training slices name 31 | with open(os.path.join(root_path, "train_slices.list"),'w') as f: 32 | # for training slices 33 | cnt = 0 34 | for name in train_id: 35 | for slice_path in slice_names: 36 | if name in slice_path: 37 | # save slice name for training 38 | sel_slice = slice_path.split("slices/")[1] 39 | f.write(sel_slice[:-3] + '\n') 40 | cnt += 1 41 | print(cnt) 42 | # 3. save all testing and validation slices name 43 | with open(os.path.join(root_path, "test_slices.list"),'w') as f: 44 | # for training slices 45 | cnt = 0 46 | for name in test_id: 47 | for slice_path in slice_names: 48 | if name in slice_path: 49 | # save slice name for training 50 | sel_slice = slice_path.split("slices/")[1] 51 | f.write(sel_slice[:-3] + '\n') 52 | cnt += 1 53 | print(cnt) 54 | 55 | with open(os.path.join(root_path, "val_slices.list"),'w') as f: 56 | # for val slices 57 | cnt = 0 58 | for name in val_id: 59 | for slice_path in slice_names: 60 | if name in slice_path: 61 | # save slice name for training 62 | sel_slice = slice_path.split("slices/")[1] 63 | f.write(sel_slice[:-3] + '\n') 64 | cnt += 1 65 | print(cnt) 66 | 67 | # 4. save all val volume name 68 | with open(os.path.join(root_path, "val.list"),'w') as f: 69 | for sel_slice in val_id: 70 | f.write("subj"+sel_slice + '\n') 71 | 72 | # 5. save all testing volume name 73 | with open(os.path.join(root_path, "test.list"),'w') as f: 74 | for sel_slice in test_id: 75 | f.write("subj"+sel_slice + '\n') 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /dataloaders/acdc_ex.py: -------------------------------------------------------------------------------- 1 | from dataloaders.transform import random_rot_flip, random_rotate, blur, obtain_cutmix_box 2 | 3 | from copy import deepcopy 4 | import h5py 5 | import math 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import random 10 | from scipy.ndimage.interpolation import zoom 11 | from scipy import ndimage 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | from torchvision import transforms 16 | 17 | 18 | class ACDCDataset(Dataset): 19 | def __init__(self, name, root, mode, size=None, id_path=None, nsample=None): 20 | self.name = name 21 | self.root = root 22 | self.mode = mode 23 | self.size = size 24 | 25 | if mode == 'train_l' or mode == 'train_u': 26 | with open(os.path.join(self.root, id_path), 'r') as f: 27 | self.ids = f.read().splitlines() 28 | if mode == 'train_l' and nsample is not None: 29 | self.ids *= math.ceil(nsample / len(self.ids)) 30 | self.ids = self.ids[:nsample] 31 | else: 32 | with open(self.root + "/val.list", "r")as f: 33 | self.ids = f.read().splitlines() 34 | 35 | def __getitem__(self, item): 36 | id = self.ids[item] 37 | if "train" in self.mode: 38 | sample = h5py.File(self.root + "/slices/{}.h5".format(id), "r") 39 | else: 40 | sample = h5py.File(self.root + "/volumes/{}.h5".format(id), "r") 41 | 42 | img = sample['image'][:] 43 | mask = sample['label'][:] 44 | # normalize 45 | img = (img - img.min()) / (img.max()- img.min()+1e-9) 46 | 47 | if self.mode == 'val': 48 | return torch.from_numpy(img).float(), torch.from_numpy(mask).long() 49 | 50 | if random.random() > 0.5: 51 | 
img, mask = random_rot_flip(img, mask) 52 | elif random.random() > 0.5: 53 | img, mask = random_rotate(img, mask) 54 | x, y = img.shape 55 | img = zoom(img, (self.size / x, self.size / y), order=0) 56 | mask = zoom(mask, (self.size / x, self.size / y), order=0) 57 | 58 | if self.mode == 'train_l': 59 | return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(np.array(mask)).long() 60 | 61 | img = Image.fromarray((img * 255).astype(np.uint8)) 62 | img_s1, img_s2 = deepcopy(img), deepcopy(img) 63 | img = torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0 64 | 65 | if random.random() < 0.8: 66 | img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1) 67 | img_s1 = blur(img_s1, p=0.5) 68 | cutmix_box1 = obtain_cutmix_box(self.size, p=0.5) 69 | img_s1 = torch.from_numpy(np.array(img_s1)).unsqueeze(0).float() / 255.0 70 | 71 | if random.random() < 0.8: 72 | img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2) 73 | img_s2 = blur(img_s2, p=0.5) 74 | cutmix_box2 = obtain_cutmix_box(self.size, p=0.5) 75 | img_s2 = torch.from_numpy(np.array(img_s2)).unsqueeze(0).float() / 255.0 76 | 77 | return img, img_s1, img_s2, cutmix_box1, cutmix_box2 78 | 79 | def __len__(self): 80 | return len(self.ids) 81 | -------------------------------------------------------------------------------- /dataloaders/mrliver.py: -------------------------------------------------------------------------------- 1 | from dataloaders.transform import random_rot_flip, random_rotate, blur, obtain_cutmix_box 2 | 3 | from copy import deepcopy 4 | import h5py 5 | import math 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import random 10 | from scipy.ndimage.interpolation import zoom 11 | from scipy import ndimage 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | from torchvision import transforms 16 | 17 | 18 | class MRliverDataset(Dataset): 19 | def __init__(self, name, root, mode, size=None, id_path=None, nsample=None): 20 | self.name = name 21 | self.root = root 22 | self.mode = mode 23 | self.size = size 24 | 25 | if mode == 'train_l' or mode == 'train_u': 26 | with open(os.path.join(self.root, id_path), 'r') as f: 27 | self.ids = f.read().splitlines() 28 | if mode == 'train_l' and nsample is not None: 29 | self.ids *= math.ceil(nsample / len(self.ids)) 30 | self.ids = self.ids[:nsample] 31 | else: 32 | with open(self.root + "/val.list", "r")as f: 33 | self.ids = f.read().splitlines() 34 | 35 | def __getitem__(self, item): 36 | id = self.ids[item] 37 | if "train" in self.mode: 38 | sample = h5py.File(self.root + "/slices/{}.h5".format(id), "r") 39 | else: 40 | sample = h5py.File(self.root + "/volumes/{}.h5".format(id), "r") 41 | 42 | img = sample['image'][:] 43 | mask = sample['label'][:] 44 | # normalize 45 | img = (img - img.min()) / (img.max()- img.min()+1e-9) 46 | 47 | if self.mode == 'val': 48 | return torch.from_numpy(img).float(), torch.from_numpy(mask).long() 49 | 50 | if random.random() > 0.5: 51 | img, mask = random_rot_flip(img, mask) 52 | elif random.random() > 0.5: 53 | img, mask = random_rotate(img, mask) 54 | x, y = img.shape 55 | img = zoom(img, (self.size / x, self.size / y), order=0) 56 | mask = zoom(mask, (self.size / x, self.size / y), order=0) 57 | 58 | if self.mode == 'train_l': 59 | return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(np.array(mask)).long() 60 | 61 | img = Image.fromarray((img * 255).astype(np.uint8)) 62 | img_s1, img_s2 = deepcopy(img), deepcopy(img) 63 | img = 
torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0 64 | 65 | if random.random() < 0.8: 66 | img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1) 67 | img_s1 = blur(img_s1, p=0.5) 68 | cutmix_box1 = obtain_cutmix_box(self.size, p=0.5) 69 | img_s1 = torch.from_numpy(np.array(img_s1)).unsqueeze(0).float() / 255.0 70 | 71 | if random.random() < 0.8: 72 | img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2) 73 | img_s2 = blur(img_s2, p=0.5) 74 | cutmix_box2 = obtain_cutmix_box(self.size, p=0.5) 75 | img_s2 = torch.from_numpy(np.array(img_s2)).unsqueeze(0).float() / 255.0 76 | 77 | return img, img_s1, img_s2, cutmix_box1, cutmix_box2 78 | 79 | def __len__(self): 80 | return len(self.ids) 81 | -------------------------------------------------------------------------------- /dataloaders/acdc.py: -------------------------------------------------------------------------------- 1 | from dataloaders.transform import random_rot_flip, random_rotate, blur, obtain_cutmix_box 2 | 3 | from copy import deepcopy 4 | import h5py 5 | import math 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import random 10 | from scipy.ndimage.interpolation import zoom 11 | from scipy import ndimage 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | from torchvision import transforms 16 | import cv2 17 | import augmentations 18 | from augmentations.ctaugment import OPS 19 | 20 | class ACDCDataset(Dataset): 21 | def __init__(self, name, root, mode, size=None, id_path=None, nsample=None): 22 | self.name = name 23 | self.root = root 24 | self.mode = mode 25 | self.size = size 26 | if mode == 'train_l' or mode == 'train_u': 27 | with open(id_path, 'r') as f: 28 | self.ids = f.read().splitlines() 29 | if mode == 'train_l' and nsample is not None: 30 | self.ids *= math.ceil(nsample / len(self.ids)) 31 | self.ids = self.ids[:nsample] 32 | else: 33 | with open('/home/gxu/proj1/smatch/data/ACDC/%s/val.txt' % name, 'r') as f: 34 | self.ids = f.read().splitlines() 35 | 36 | def __getitem__(self, item): 37 | id = self.ids[item] 38 | sample = h5py.File(os.path.join(self.root, id), 'r') 39 | img = sample['image'][:] 40 | mask = sample['label'][:] 41 | 42 | if self.mode == 'val' or self.mode == 'test': 43 | return torch.from_numpy(img).float(), torch.from_numpy(mask).long() 44 | 45 | if random.random() > 0.5: 46 | img, mask = random_rot_flip(img, mask) 47 | elif random.random() > 0.5: 48 | img, mask = random_rotate(img, mask) 49 | x, y = img.shape 50 | img = zoom(img, (self.size / x, self.size / y), order=0) 51 | mask = zoom(mask, (self.size / x, self.size / y), order=0) 52 | 53 | if self.mode == 'train_l': 54 | return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(np.array(mask)).long() 55 | 56 | img = Image.fromarray((img * 255).astype(np.uint8)) 57 | img_s1, img_s2 = deepcopy(img), deepcopy(img) 58 | img = torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0 59 | 60 | if random.random() < 0.8: 61 | #img_s1 = augmentations.cta_apply(img_s1, self.ops_strong) 62 | img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1) 63 | img_s1 = blur(img_s1, p=0.5) 64 | cutmix_box1 = obtain_cutmix_box(self.size, p=0.5) 65 | img_s1 = torch.from_numpy(np.array(img_s1)).unsqueeze(0).float() / 255.0 66 | 67 | if random.random() < 0.8: 68 | #img_s2 = pil_to_tensor_fft(img_s2) 69 | #img_s2 = augmentations.cta_apply(img_s2, self.ops_strong) 70 | img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2) 71 | img_s2 = blur(img_s2, p=0.5) 
72 | cutmix_box2 = obtain_cutmix_box(self.size, p=0.5) 73 | img_s2 = torch.from_numpy(np.array(img_s2)).unsqueeze(0).float() / 255.0 74 | 75 | return img, img_s1, img_s2, cutmix_box1, cutmix_box2 76 | 77 | def __len__(self): 78 | return len(self.ids) 79 | 80 | 81 | -------------------------------------------------------------------------------- /dataloaders/Synapse.py: -------------------------------------------------------------------------------- 1 | from dataloaders.transform import random_rot_flip, random_rotate, blur, obtain_cutmix_box 2 | 3 | from copy import deepcopy 4 | import h5py 5 | import math 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import random 10 | from scipy.ndimage.interpolation import zoom 11 | from scipy import ndimage 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | from torchvision import transforms 16 | import cv2 17 | import augmentations 18 | from augmentations.ctaugment import OPS 19 | 20 | class SynapseDataset(Dataset): 21 | def __init__(self, name, root, mode, size=None, id_path=None, nsample=None): 22 | self.name = name 23 | self.root = root 24 | self.mode = mode 25 | self.size = size 26 | if mode == 'train_l' or mode == 'train_u': 27 | with open(id_path, 'r') as f: 28 | self.ids = f.read().splitlines() 29 | if mode == 'train_l' and nsample is not None: 30 | self.ids *= math.ceil(nsample / len(self.ids)) 31 | self.ids = self.ids[:nsample] 32 | else: 33 | with open('/home/data/lxs/SSL1/SSL4MIS-master/SSL4MIS-master/splits/%s/val.txt' % name, 'r') as f: 34 | self.ids = f.read().splitlines() 35 | 36 | def __getitem__(self, item): 37 | id = self.ids[item] 38 | if self.mode == "train_l" or self.mode == "train_u": 39 | sample = np.load(os.path.join(self.root+"/data/train_npz", id+".npz"), 'r') 40 | else: 41 | sample = h5py.File(os.path.join(self.root+"/data/val", id+".npy.h5"), 'r') 42 | img = sample['image'][:] 43 | mask = sample['label'][:] 44 | 45 | if self.mode == 'val' or self.mode == 'test': 46 | return torch.from_numpy(img).float(), torch.from_numpy(mask).long() 47 | 48 | if random.random() > 0.5: 49 | img, mask = random_rot_flip(img, mask) 50 | elif random.random() > 0.5: 51 | img, mask = random_rotate(img, mask) 52 | x, y = img.shape 53 | img = zoom(img, (self.size / x, self.size / y), order=0) 54 | mask = zoom(mask, (self.size / x, self.size / y), order=0) 55 | 56 | if self.mode == 'train_l': 57 | return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(np.array(mask)).long() 58 | 59 | img = Image.fromarray((img * 255).astype(np.uint8)) 60 | img_s1, img_s2 = deepcopy(img), deepcopy(img) 61 | img = torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0 62 | 63 | if random.random() < 0.8: 64 | img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1) 65 | img_s1 = blur(img_s1, p=0.5) 66 | cutmix_box1 = obtain_cutmix_box(self.size, p=0.5) 67 | img_s1 = torch.from_numpy(np.array(img_s1)).unsqueeze(0).float() / 255.0 68 | 69 | if random.random() < 0.8: 70 | img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2) 71 | img_s2 = blur(img_s2, p=0.5) 72 | cutmix_box2 = obtain_cutmix_box(self.size, p=0.5) 73 | img_s2 = torch.from_numpy(np.array(img_s2)).unsqueeze(0).float() / 255.0 74 | 75 | return img, img_s1, img_s2, cutmix_box1, cutmix_box2 76 | 77 | def __len__(self): 78 | return len(self.ids) 79 | 80 | 81 | -------------------------------------------------------------------------------- /data/gen_train_test_val_name.py: 
-------------------------------------------------------------------------------- 1 | ## split subject randomly for training and testing 2 | import os 3 | import utils.sortName as getname 4 | from random import randint, sample 5 | 6 | # get all subject id: total 48 subjects 7 | # set path 8 | root_path = "/home/gxu/proj1/smatch/data/MRliver" 9 | subj_path = os.path.join(root_path, "DICOM") 10 | slice_path = os.path.join(root_path, "slices") 11 | 12 | # select 30 subjects for training and the other 12 subjects for testing, 6 for validation 13 | # get all 48 subjects id 14 | subj_name = getname.get_filename(subj_path, pattern="*.mat") 15 | subj_id = ["subj"+os.path.basename(name).split(".mat")[0] for name in subj_name] 16 | print(len(subj_id)) 17 | 18 | # get test id, and val id 19 | test_id_obj = open(os.path.join(root_path, "test.list"),'r') 20 | test_id = test_id_obj.read().split("\n")[:12] 21 | 22 | val_id_obj = open(os.path.join(root_path, "val.list"),'r') 23 | val_id = val_id_obj.read().split("\n")[:6] 24 | 25 | train_val_id = [num for num in subj_id if num not in test_id] 26 | train_id = [num for num in train_val_id if num not in val_id] 27 | 28 | 29 | # filter all training slices according to selected training subject id 30 | # and filter all other testing slices according to selected testing subjects ID 31 | # 1. get all slice names 32 | slice_names = getname.get_filename(slice_path, pattern="*.h5") 33 | 34 | # 2. save all training slices name 35 | with open(os.path.join(root_path, "train_slices.list"),'w') as f: 36 | # for training slices 37 | cnt = 0 38 | for name in train_id: 39 | for slice_path in slice_names: 40 | if name in slice_path: 41 | # save slice name for training 42 | sel_slice = slice_path.split("slices/")[1] 43 | f.write(sel_slice[:-3] + '\n') 44 | cnt += 1 45 | print(cnt) 46 | # 3. save all testing and validation slices name 47 | with open(os.path.join(root_path, "test_slices.list"),'w') as f: 48 | # for training slices 49 | cnt = 0 50 | for name in test_id: 51 | for slice_path in slice_names: 52 | if name in slice_path: 53 | # save slice name for training 54 | sel_slice = slice_path.split("slices/")[1] 55 | f.write(sel_slice[:-3] + '\n') 56 | cnt += 1 57 | print(cnt) 58 | 59 | with open(os.path.join(root_path, "val_slices.list"),'w') as f: 60 | # for val slices 61 | cnt = 0 62 | for name in val_id: 63 | for slice_path in slice_names: 64 | if name in slice_path: 65 | # save slice name for training 66 | sel_slice = slice_path.split("slices/")[1] 67 | f.write(sel_slice[:-3] + '\n') 68 | cnt += 1 69 | print(cnt) 70 | 71 | # # 4. save all val volume name 72 | # with open(os.path.join(root_path, "val.list"),'w') as f: 73 | # for sel_slice in val_id: 74 | # f.write("subj"+sel_slice + '\n') 75 | 76 | # # 5. save all testing volume name 77 | # with open(os.path.join(root_path, "test.list"),'w') as f: 78 | # for sel_slice in test_id: 79 | # f.write("subj"+sel_slice + '\n') 80 | 81 | # 6. 
save all training volume name 82 | with open(os.path.join(root_path, "train.list"),'w') as f: 83 | for sel_slice in train_id: 84 | f.write(sel_slice + '\n') 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /utils/builder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Union 2 | 3 | import torch.nn as nn 4 | from mmengine import MODELS as MMENGINE_MODELS 5 | from mmengine import Config, ConfigDict, Registry, build_from_cfg 6 | from torch.nn import Module 7 | 8 | MODELS = Registry('models', parent=MMENGINE_MODELS) 9 | LOSSES = MODELS 10 | ENCODERS = MODELS 11 | DECODERS = MODELS 12 | FLOW_ESTIMATORS = MODELS 13 | BACKBONES = MODELS 14 | METRICS = MODELS 15 | REGISTRATION_HEAD = MODELS 16 | 17 | CFG = Union[dict, Config, ConfigDict] 18 | 19 | 20 | def build(cfg: Union[Sequence[CFG], CFG], 21 | registry: Registry, 22 | default_args: Optional[dict] = None): 23 | """Build a module. 24 | 25 | Args: 26 | cfg (dict, list[dict]): The config of modules, is either a dict 27 | or a list of configs. 28 | registry (:obj:`Registry`): A registry the module belongs to. 29 | default_args (dict, optional): Default arguments to build the module. 30 | Defaults to None. 31 | Returns: 32 | nn.Module: A built nn module. 33 | """ 34 | if isinstance(cfg, list): 35 | modules = [ 36 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 37 | ] 38 | return nn.Sequential(*modules) 39 | else: 40 | return build_from_cfg(cfg, registry, default_args) 41 | 42 | 43 | def build_loss(cfg: dict) -> Module: 44 | """Build loss function. 45 | 46 | Args: 47 | cfg (dict): Config for loss function. 48 | Returns: 49 | Module: Loss function. 50 | """ 51 | cfg_ = cfg.copy() 52 | cfg_.pop('weight', None) 53 | return build(cfg_, LOSSES) 54 | 55 | 56 | def build_metrics(cfg: dict) -> Module: 57 | """Build metric function. 58 | 59 | Args: 60 | cfg (dict): Config for encoder. 61 | Returns: 62 | Module: Metric function. 63 | """ 64 | return build(cfg, METRICS) 65 | 66 | 67 | def build_encoder(cfg: dict) -> Module: 68 | """Build encoder for flow estimator. 69 | 70 | Args: 71 | cfg (dict): Config for encoder. 72 | Returns: 73 | Module: Encoder module. 74 | """ 75 | return build(cfg, ENCODERS) 76 | 77 | 78 | def build_decoder(cfg: dict) -> Module: 79 | """Build decoder for flow estimator. 80 | 81 | Args: 82 | cfg (dict): Config for decoder. 83 | Returns: 84 | Module: Decoder module. 85 | """ 86 | return build(cfg, DECODERS) 87 | 88 | 89 | def build_flow_estimator(cfg: dict) -> Module: 90 | """Build flow estimator. 91 | 92 | Args: 93 | cfg (dict): Config for optical flow estimator. 94 | Returns: 95 | Module: Flow estimator. 96 | """ 97 | return build(cfg, FLOW_ESTIMATORS) 98 | 99 | 100 | def build_backbone(cfg: dict) -> Module: 101 | """Build backbone. 102 | 103 | Args: 104 | cfg (dict): Config for optical flow estimator. 105 | Returns: 106 | Module: Backbone. 107 | """ 108 | return build(cfg, BACKBONES) 109 | 110 | 111 | def build_registration_head(cfg: dict) -> Module: 112 | """Build registration head. 113 | 114 | Args: 115 | cfg (dict): Config for registration head. 116 | Returns: 117 | Module: Registration head. 
118 | """ 119 | return build(cfg, REGISTRATION_HEAD) -------------------------------------------------------------------------------- /model_sam/segment_anything_samus/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | class Adapter(nn.Module): 13 | def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): #0.25 14 | super().__init__() 15 | self.skip_connect = skip_connect 16 | D_hidden_features = int(D_features * mlp_ratio) 17 | self.act = act_layer() 18 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 19 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 20 | 21 | def forward(self, x): 22 | # x is (BT, HW+1, D) 23 | xs = self.D_fc1(x) 24 | xs = self.act(xs) 25 | xs = self.D_fc2(xs) 26 | if self.skip_connect: 27 | x = x + xs 28 | else: 29 | x = xs 30 | return x 31 | 32 | 33 | class AugAdapter(nn.Module): 34 | def __init__(self, D_features, mlp_ratio=0.25, num_heads=12, act_layer=nn.GELU, skip_connect=True): #0.25 35 | super().__init__() 36 | self.skip_connect = skip_connect 37 | D_hidden_features = int(D_features * mlp_ratio) 38 | self.act = act_layer() 39 | self.D_fc1 = nn.Linear(D_features, D_hidden_features) 40 | self.D_fc2 = nn.Linear(D_hidden_features, D_features) 41 | self.aug_fc = nn.Linear(num_heads, D_hidden_features) 42 | 43 | def forward(self, x, important_key): 44 | # x is (BT, HW+1, D) 45 | xs = self.D_fc1(x) 46 | aug = self.aug_fc(important_key) 47 | xs = self.act(xs * aug) 48 | xs = self.D_fc2(xs) 49 | if self.skip_connect: 50 | x = x + xs 51 | else: 52 | x = xs 53 | return x 54 | 55 | 56 | class MLPBlock(nn.Module): 57 | def __init__( 58 | self, 59 | embedding_dim: int, 60 | mlp_dim: int, 61 | act: Type[nn.Module] = nn.GELU, 62 | ) -> None: 63 | super().__init__() 64 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 65 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 66 | self.act = act() 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | x = self.lin1(x) 70 | x = self.act(x) 71 | x = self.lin2(x) 72 | return x 73 | #return self.lin2(self.act(self.lin1(x))) 74 | 75 | 76 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 77 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 78 | class LayerNorm2d(nn.Module): 79 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 80 | super().__init__() 81 | self.weight = nn.Parameter(torch.ones(num_channels)) 82 | self.bias = nn.Parameter(torch.zeros(num_channels)) 83 | self.eps = eps 84 | 85 | def forward(self, x: torch.Tensor) -> torch.Tensor: 86 | u = x.mean(1, keepdim=True) 87 | s = (x - u).pow(2).mean(1, keepdim=True) 88 | x = (x - u) / torch.sqrt(s + self.eps) 89 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 90 | return x 91 | -------------------------------------------------------------------------------- /utils/metrics_samus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from hausdorff import hausdorff_distance 4 | 5 | def dice_coefficient(pred, gt, smooth=1e-5): 6 | """ 
computational formula: 7 | dice = 2TP/(FP + 2TP + FN) 8 | """ 9 | N = gt.shape[0] 10 | pred[pred >= 1] = 1 11 | gt[gt >= 1] = 1 12 | pred_flat = pred.reshape(N, -1) 13 | gt_flat = gt.reshape(N, -1) 14 | # if (pred.sum() + gt.sum()) == 0: 15 | # return 1 16 | intersection = (pred_flat * gt_flat).sum(1) 17 | unionset = pred_flat.sum(1) + gt_flat.sum(1) 18 | dice = (2 * intersection + smooth) / (unionset + smooth) 19 | return dice.sum() / N 20 | 21 | def sespiou_coefficient(pred, gt, smooth=1e-5): 22 | """ computational formula: 23 | sensitivity = TP/(TP+FN) 24 | specificity = TN/(FP+TN) 25 | iou = TP/(FP+TP+FN) 26 | """ 27 | N = gt.shape[0] 28 | pred[pred >= 1] = 1 29 | gt[gt >= 1] = 1 30 | pred_flat = pred.reshape(N, -1) 31 | gt_flat = gt.reshape(N, -1) 32 | #pred_flat = pred.view(N, -1) 33 | #gt_flat = gt.view(N, -1) 34 | TP = (pred_flat * gt_flat).sum(1) 35 | FN = gt_flat.sum(1) - TP 36 | pred_flat_no = (pred_flat + 1) % 2 37 | gt_flat_no = (gt_flat + 1) % 2 38 | TN = (pred_flat_no * gt_flat_no).sum(1) 39 | FP = pred_flat.sum(1) - TP 40 | SE = (TP + smooth) / (TP + FN + smooth) 41 | SP = (TN + smooth) / (FP + TN + smooth) 42 | IOU = (TP + smooth) / (FP + TP + FN + smooth) 43 | return SE.sum() / N, SP.sum() / N, IOU.sum() / N 44 | 45 | def sespiou_coefficient2(pred, gt, all=False, smooth=1e-5): 46 | """ computational formula: 47 | sensitivity = TP/(TP+FN) 48 | specificity = TN/(FP+TN) 49 | iou = TP/(FP+TP+FN) 50 | """ 51 | N = gt.shape[0] 52 | pred[pred >= 1] = 1 53 | gt[gt >= 1] = 1 54 | pred_flat = pred.reshape(N, -1) 55 | gt_flat = gt.reshape(N, -1) 56 | #pred_flat = pred.view(N, -1) 57 | #gt_flat = gt.view(N, -1) 58 | TP = (pred_flat * gt_flat).sum(1) 59 | FN = gt_flat.sum(1) - TP 60 | pred_flat_no = (pred_flat + 1) % 2 61 | gt_flat_no = (gt_flat + 1) % 2 62 | TN = (pred_flat_no * gt_flat_no).sum(1) 63 | FP = pred_flat.sum(1) - TP 64 | SE = (TP + smooth) / (TP + FN + smooth) 65 | SP = (TN + smooth) / (FP + TN + smooth) 66 | IOU = (TP + smooth) / (FP + TP + FN + smooth) 67 | Acc = (TP + TN + smooth)/(TP + FP + FN + TN + smooth) 68 | Precision = (TP + smooth) / (TP + FP + smooth) 69 | Recall = (TP + smooth) / (TP + FN + smooth) 70 | F1 = 2*Precision*Recall/(Recall + Precision +smooth) 71 | if all: 72 | return SE.sum() / N, SP.sum() / N, IOU.sum() / N, Acc.sum()/N, F1.sum()/N, Precision.sum()/N, Recall.sum()/N 73 | else: 74 | return IOU.sum() / N, Acc.sum()/N, SE.sum() / N, SP.sum() / N 75 | 76 | def get_matrix(pred, gt, smooth=1e-5): 77 | """ computational formula: 78 | sensitivity = TP/(TP+FN) 79 | specificity = TN/(FP+TN) 80 | iou = TP/(FP+TP+FN) 81 | """ 82 | N = gt.shape[0] 83 | pred[pred >= 1] = 1 84 | gt[gt >= 1] = 1 85 | pred_flat = pred.reshape(N, -1) 86 | gt_flat = gt.reshape(N, -1) 87 | TP = (pred_flat * gt_flat).sum(1) 88 | FN = gt_flat.sum(1) - TP 89 | pred_flat_no = (pred_flat + 1) % 2 90 | gt_flat_no = (gt_flat + 1) % 2 91 | TN = (pred_flat_no * gt_flat_no).sum(1) 92 | FP = pred_flat.sum(1) - TP 93 | return TP, FP, TN, FN -------------------------------------------------------------------------------- /networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FC3DDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FC3DDiscriminator, self).__init__() 10 | # downsample 16 11 | self.conv0 = nn.Conv3d( 12 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 13 | self.conv1 = nn.Conv3d( 
14 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 15 | 16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv3d( 18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 19 | self.conv4 = nn.Conv3d( 20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # (D/16, W/16, H/16) 22 | self.classifier = nn.Linear(ndf*8, 2) 23 | 24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 25 | self.dropout = nn.Dropout3d(0.5) 26 | self.Softmax = nn.Softmax() 27 | 28 | def forward(self, map, image): 29 | batch_size = map.shape[0] 30 | map_feature = self.conv0(map) 31 | image_feature = self.conv1(image) 32 | x = torch.add(map_feature, image_feature) 33 | x = self.leaky_relu(x) 34 | x = self.dropout(x) 35 | 36 | x = self.conv2(x) 37 | x = self.leaky_relu(x) 38 | x = self.dropout(x) 39 | 40 | x = self.conv3(x) 41 | x = self.leaky_relu(x) 42 | x = self.dropout(x) 43 | 44 | x = self.conv4(x) 45 | x = self.leaky_relu(x) 46 | 47 | x = self.avgpool(x) 48 | 49 | x = x.view(batch_size, -1) 50 | 51 | x = self.classifier(x) 52 | x = x.reshape((batch_size, 2)) 53 | # x = self.Softmax(x) 54 | 55 | return x 56 | 57 | 58 | class FCDiscriminator(nn.Module): 59 | 60 | def __init__(self, num_classes, ndf=64, n_channel=1): 61 | super(FCDiscriminator, self).__init__() 62 | self.conv0 = nn.Conv2d( 63 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 64 | self.conv1 = nn.Conv2d( 65 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 67 | self.conv3 = nn.Conv2d( 68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 69 | self.conv4 = nn.Conv2d( 70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 71 | self.classifier = nn.Linear(ndf*32, 2) 72 | self.avgpool = nn.AvgPool2d((7, 7)) 73 | 74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | self.dropout = nn.Dropout2d(0.5) 76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 77 | # self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, map, feature): 80 | map_feature = self.conv0(map) 81 | image_feature = self.conv1(feature) 82 | x = torch.add(map_feature, image_feature) 83 | 84 | x = self.conv2(x) 85 | x = self.leaky_relu(x) 86 | x = self.dropout(x) 87 | 88 | x = self.conv3(x) 89 | x = self.leaky_relu(x) 90 | x = self.dropout(x) 91 | 92 | x = self.conv4(x) 93 | x = self.leaky_relu(x) 94 | x = self.avgpool(x) 95 | x = x.view(x.size(0), -1) 96 | x = self.classifier(x) 97 | # x = self.up_sample(x) 98 | # x = self.sigmoid(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /code/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FC3DDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FC3DDiscriminator, self).__init__() 10 | # downsample 16 11 | self.conv0 = nn.Conv3d( 12 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 13 | self.conv1 = nn.Conv3d( 14 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 15 | 16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv3d( 18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 19 | self.conv4 = nn.Conv3d( 20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # 
(D/16, W/16, H/16) 22 | self.classifier = nn.Linear(ndf*8, 2) 23 | 24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 25 | self.dropout = nn.Dropout3d(0.5) 26 | self.Softmax = nn.Softmax() 27 | 28 | def forward(self, map, image): 29 | batch_size = map.shape[0] 30 | map_feature = self.conv0(map) 31 | image_feature = self.conv1(image) 32 | x = torch.add(map_feature, image_feature) 33 | x = self.leaky_relu(x) 34 | x = self.dropout(x) 35 | 36 | x = self.conv2(x) 37 | x = self.leaky_relu(x) 38 | x = self.dropout(x) 39 | 40 | x = self.conv3(x) 41 | x = self.leaky_relu(x) 42 | x = self.dropout(x) 43 | 44 | x = self.conv4(x) 45 | x = self.leaky_relu(x) 46 | 47 | x = self.avgpool(x) 48 | 49 | x = x.view(batch_size, -1) 50 | 51 | x = self.classifier(x) 52 | x = x.reshape((batch_size, 2)) 53 | # x = self.Softmax(x) 54 | 55 | return x 56 | 57 | 58 | class FCDiscriminator(nn.Module): 59 | 60 | def __init__(self, num_classes, ndf=64, n_channel=1): 61 | super(FCDiscriminator, self).__init__() 62 | self.conv0 = nn.Conv2d( 63 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 64 | self.conv1 = nn.Conv2d( 65 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 67 | self.conv3 = nn.Conv2d( 68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 69 | self.conv4 = nn.Conv2d( 70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 71 | self.classifier = nn.Linear(ndf*32, 2) 72 | self.avgpool = nn.AvgPool2d((7, 7)) 73 | 74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | self.dropout = nn.Dropout2d(0.5) 76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 77 | # self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, map, feature): 80 | map_feature = self.conv0(map) 81 | image_feature = self.conv1(feature) 82 | x = torch.add(map_feature, image_feature) 83 | 84 | x = self.conv2(x) 85 | x = self.leaky_relu(x) 86 | x = self.dropout(x) 87 | 88 | x = self.conv3(x) 89 | x = self.leaky_relu(x) 90 | x = self.dropout(x) 91 | 92 | x = self.conv4(x) 93 | x = self.leaky_relu(x) 94 | x = self.avgpool(x) 95 | x = x.view(x.size(0), -1) 96 | x = self.classifier(x) 97 | # x = self.up_sample(x) 98 | # x = self.sigmoid(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /networks/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | 9 | class Conv2dReLU(nn.Sequential): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | padding=0, 16 | stride=1, 17 | use_batchnorm=True, 18 | ): 19 | 20 | if use_batchnorm == "inplace" and InPlaceABN is None: 21 | raise RuntimeError( 22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" 23 | + "To install see: https://github.com/mapillary/inplace_abn" 24 | ) 25 | 26 | super().__init__() 27 | 28 | conv = nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | bias=not (use_batchnorm), 35 | ) 36 | relu = nn.ReLU(inplace=True) 37 | 38 | if use_batchnorm == "inplace": 39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 40 | relu = nn.Identity() 41 | 42 | elif use_batchnorm and use_batchnorm != "inplace": 43 | bn = nn.BatchNorm2d(out_channels) 44 | 45 | else: 46 | bn = nn.Identity() 47 | 48 | super(Conv2dReLU, self).__init__(conv, bn, relu) 49 | 50 | 51 | class SCSEModule(nn.Module): 52 | def __init__(self, in_channels, reduction=16): 53 | super().__init__() 54 | self.cSE = nn.Sequential( 55 | nn.AdaptiveAvgPool2d(1), 56 | nn.Conv2d(in_channels, in_channels // reduction, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(in_channels // reduction, in_channels, 1), 59 | nn.Sigmoid(), 60 | ) 61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | return x * self.cSE(x) + x * self.sSE(x) 65 | 66 | 67 | class Activation(nn.Module): 68 | 69 | def __init__(self, name, **params): 70 | 71 | super().__init__() 72 | 73 | if name is None or name == 'identity': 74 | self.activation = nn.Identity(**params) 75 | elif name == 'sigmoid': 76 | self.activation = nn.Sigmoid() 77 | elif name == 'softmax2d': 78 | self.activation = nn.Softmax(dim=1, **params) 79 | elif name == 'softmax': 80 | self.activation = nn.Softmax(**params) 81 | elif name == 'logsoftmax': 82 | self.activation = nn.LogSoftmax(**params) 83 | elif callable(name): 84 | self.activation = name(**params) 85 | else: 86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 87 | 88 | def forward(self, x): 89 | return self.activation(x) 90 | 91 | 92 | class Attention(nn.Module): 93 | 94 | def __init__(self, name, **params): 95 | super().__init__() 96 | 97 | if name is None: 98 | self.attention = nn.Identity(**params) 99 | elif name == 'scse': 100 | self.attention = SCSEModule(**params) 101 | else: 102 | raise ValueError("Attention {} is not implemented".format(name)) 103 | 104 | def forward(self, x): 105 | return self.attention(x) 106 | 107 | 108 | class Flatten(nn.Module): 109 | def forward(self, x): 110 | return x.view(x.shape[0], -1) 111 | -------------------------------------------------------------------------------- /code/networks/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | 9 | class Conv2dReLU(nn.Sequential): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | padding=0, 16 | stride=1, 17 | use_batchnorm=True, 18 | ): 19 | 20 | if use_batchnorm == "inplace" and InPlaceABN is None: 21 | raise RuntimeError( 22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" 23 | + "To install see: https://github.com/mapillary/inplace_abn" 24 | ) 25 | 26 | super().__init__() 27 | 28 | conv = nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | bias=not (use_batchnorm), 35 | ) 36 | relu = nn.ReLU(inplace=True) 37 | 38 | if use_batchnorm == "inplace": 39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 40 | relu = nn.Identity() 41 | 42 | elif use_batchnorm and use_batchnorm != "inplace": 43 | bn = nn.BatchNorm2d(out_channels) 44 | 45 | else: 46 | bn = nn.Identity() 47 | 48 | super(Conv2dReLU, self).__init__(conv, bn, relu) 49 | 50 | 51 | class SCSEModule(nn.Module): 52 | def __init__(self, in_channels, reduction=16): 53 | super().__init__() 54 | self.cSE = nn.Sequential( 55 | nn.AdaptiveAvgPool2d(1), 56 | nn.Conv2d(in_channels, in_channels // reduction, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(in_channels // reduction, in_channels, 1), 59 | nn.Sigmoid(), 60 | ) 61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | return x * self.cSE(x) + x * self.sSE(x) 65 | 66 | 67 | class Activation(nn.Module): 68 | 69 | def __init__(self, name, **params): 70 | 71 | super().__init__() 72 | 73 | if name is None or name == 'identity': 74 | self.activation = nn.Identity(**params) 75 | elif name == 'sigmoid': 76 | self.activation = nn.Sigmoid() 77 | elif name == 'softmax2d': 78 | self.activation = nn.Softmax(dim=1, **params) 79 | elif name == 'softmax': 80 | self.activation = nn.Softmax(**params) 81 | elif name == 'logsoftmax': 82 | self.activation = nn.LogSoftmax(**params) 83 | elif callable(name): 84 | self.activation = name(**params) 85 | else: 86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 87 | 88 | def forward(self, x): 89 | return self.activation(x) 90 | 91 | 92 | class Attention(nn.Module): 93 | 94 | def __init__(self, name, **params): 95 | super().__init__() 96 | 97 | if name is None: 98 | self.attention = nn.Identity(**params) 99 | elif name == 'scse': 100 | self.attention = SCSEModule(**params) 101 | else: 102 | raise ValueError("Attention {} is not implemented".format(name)) 103 | 104 | def forward(self, x): 105 | return self.attention(x) 106 | 107 | 108 | class Flatten(nn.Module): 109 | def forward(self, x): 110 | return x.view(x.shape[0], -1) 111 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | from torch import nn 6 | import torch 7 | 8 | 9 | def count_params(model): 10 | param_num = sum(p.numel() for p in model.parameters()) 11 | return param_num / 1e6 12 | 13 | 14 | class DiceLoss(nn.Module): 15 | def __init__(self, n_classes): 16 | super(DiceLoss, self).__init__() 17 | self.n_classes = n_classes 18 | 19 | def _one_hot_encoder(self, input_tensor): 20 | tensor_list = [] 21 | for i in range(self.n_classes): 22 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 23 | tensor_list.append(temp_prob) 24 | output_tensor = torch.cat(tensor_list, dim=1) 25 | return output_tensor.float() 26 | 27 | def _dice_loss(self, score, target, ignore): 28 | target = target.float() 29 | smooth = 1e-5 30 | intersect = torch.sum(score[ignore != 1] * target[ignore != 1]) 31 | y_sum = torch.sum(target[ignore != 1] * target[ignore != 1]) 32 | z_sum = torch.sum(score[ignore != 
1] * score[ignore != 1]) 33 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 34 | loss = 1 - loss 35 | return loss 36 | 37 | def forward(self, inputs, target, weight=None, softmax=False, ignore=None): 38 | if softmax: 39 | inputs = torch.softmax(inputs, dim=1) 40 | target = self._one_hot_encoder(target) 41 | if weight is None: 42 | weight = [1] * self.n_classes 43 | assert inputs.size() == target.size(), 'predict & target shape do not match' 44 | class_wise_dice = [] 45 | loss = 0.0 46 | for i in range(0, self.n_classes): 47 | dice = self._dice_loss(inputs[:, i], target[:, i], ignore) 48 | class_wise_dice.append(1.0 - dice.item()) 49 | loss += dice * weight[i] 50 | return loss / self.n_classes 51 | 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | 56 | def __init__(self, length=0): 57 | self.length = length 58 | self.reset() 59 | 60 | def reset(self): 61 | if self.length > 0: 62 | self.history = [] 63 | else: 64 | self.count = 0 65 | self.sum = 0.0 66 | self.val = 0.0 67 | self.avg = 0.0 68 | 69 | def update(self, val, num=1): 70 | if self.length > 0: 71 | # currently assert num==1 to avoid bad usage, refine when there are some explict requirements 72 | assert num == 1 73 | self.history.append(val) 74 | if len(self.history) > self.length: 75 | del self.history[0] 76 | 77 | self.val = self.history[-1] 78 | self.avg = np.mean(self.history) 79 | else: 80 | self.val = val 81 | self.sum += val * num 82 | self.count += num 83 | self.avg = self.sum / self.count 84 | 85 | 86 | logs = set() 87 | 88 | 89 | def init_log(name, level=logging.INFO): 90 | if (name, level) in logs: 91 | return 92 | logs.add((name, level)) 93 | logger = logging.getLogger(name) 94 | logger.setLevel(level) 95 | ch = logging.StreamHandler() 96 | ch.setLevel(level) 97 | if "SLURM_PROCID" in os.environ: 98 | rank = int(os.environ["SLURM_PROCID"]) 99 | logger.addFilter(lambda record: rank == 0) 100 | else: 101 | rank = 0 102 | format_str = "[%(asctime)s][%(levelname)8s] %(message)s" 103 | formatter = logging.Formatter(format_str) 104 | ch.setFormatter(formatter) 105 | logger.addHandler(ch) 106 | return logger 107 | -------------------------------------------------------------------------------- /networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the 3D U-Net paper: 4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: 5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. 6 | MICCAI (2) 2016: 424-432 7 | Note that there are some modifications from the original paper, such as 8 | the use of batch normalization, dropout, and leaky relu here. 
9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks 10 | """ 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from networks.networks_other import init_weights 17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT 18 | 19 | 20 | class unet_3D(nn.Module): 21 | 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 23 | super(unet_3D, self).__init__() 24 | self.is_deconv = is_deconv 25 | self.in_channels = in_channels 26 | self.is_batchnorm = is_batchnorm 27 | self.feature_scale = feature_scale 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 34 | 3, 3, 3), padding_size=(1, 1, 1)) 35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 36 | 37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 38 | 3, 3, 3), padding_size=(1, 1, 1)) 39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 40 | 41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 42 | 3, 3, 3), padding_size=(1, 1, 1)) 43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 44 | 45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 46 | 3, 3, 3), padding_size=(1, 1, 1)) 47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 48 | 49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 50 | 3, 3, 3), padding_size=(1, 1, 1)) 51 | 52 | # upsampling 53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(filters[0], n_classes, 1) 60 | 61 | self.dropout1 = nn.Dropout(p=0.3) 62 | self.dropout2 = nn.Dropout(p=0.3) 63 | 64 | # initialise weights 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv3d): 67 | init_weights(m, init_type='kaiming') 68 | elif isinstance(m, nn.BatchNorm3d): 69 | init_weights(m, init_type='kaiming') 70 | 71 | def forward(self, inputs): 72 | conv1 = self.conv1(inputs) 73 | maxpool1 = self.maxpool1(conv1) 74 | 75 | conv2 = self.conv2(maxpool1) 76 | maxpool2 = self.maxpool2(conv2) 77 | 78 | conv3 = self.conv3(maxpool2) 79 | maxpool3 = self.maxpool3(conv3) 80 | 81 | conv4 = self.conv4(maxpool3) 82 | maxpool4 = self.maxpool4(conv4) 83 | 84 | center = self.center(maxpool4) 85 | center = self.dropout1(center) 86 | up4 = self.up_concat4(conv4, center) 87 | up3 = self.up_concat3(conv3, up4) 88 | up2 = self.up_concat2(conv2, up3) 89 | up1 = self.up_concat1(conv1, up2) 90 | up1 = self.dropout2(up1) 91 | 92 | final = self.final(up1) 93 | 94 | return final 95 | 96 | @staticmethod 97 | def apply_argmax_softmax(pred): 98 | log_p = F.softmax(pred, dim=1) 99 | 100 | return log_p 101 | -------------------------------------------------------------------------------- /code/networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the 3D U-Net paper: 4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. 
Lienkamp, Thomas Brox, Olaf Ronneberger: 5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. 6 | MICCAI (2) 2016: 424-432 7 | Note that there are some modifications from the original paper, such as 8 | the use of batch normalization, dropout, and leaky relu here. 9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks 10 | """ 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from networks.networks_other import init_weights 17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT 18 | 19 | 20 | class unet_3D(nn.Module): 21 | 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 23 | super(unet_3D, self).__init__() 24 | self.is_deconv = is_deconv 25 | self.in_channels = in_channels 26 | self.is_batchnorm = is_batchnorm 27 | self.feature_scale = feature_scale 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 34 | 3, 3, 3), padding_size=(1, 1, 1)) 35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 36 | 37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 38 | 3, 3, 3), padding_size=(1, 1, 1)) 39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 40 | 41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 42 | 3, 3, 3), padding_size=(1, 1, 1)) 43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 44 | 45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 46 | 3, 3, 3), padding_size=(1, 1, 1)) 47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 48 | 49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 50 | 3, 3, 3), padding_size=(1, 1, 1)) 51 | 52 | # upsampling 53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(filters[0], n_classes, 1) 60 | 61 | self.dropout1 = nn.Dropout(p=0.3) 62 | self.dropout2 = nn.Dropout(p=0.3) 63 | 64 | # initialise weights 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv3d): 67 | init_weights(m, init_type='kaiming') 68 | elif isinstance(m, nn.BatchNorm3d): 69 | init_weights(m, init_type='kaiming') 70 | 71 | def forward(self, inputs): 72 | conv1 = self.conv1(inputs) 73 | maxpool1 = self.maxpool1(conv1) 74 | 75 | conv2 = self.conv2(maxpool1) 76 | maxpool2 = self.maxpool2(conv2) 77 | 78 | conv3 = self.conv3(maxpool2) 79 | maxpool3 = self.maxpool3(conv3) 80 | 81 | conv4 = self.conv4(maxpool3) 82 | maxpool4 = self.maxpool4(conv4) 83 | 84 | center = self.center(maxpool4) 85 | center = self.dropout1(center) 86 | up4 = self.up_concat4(conv4, center) 87 | up3 = self.up_concat3(conv3, up4) 88 | up2 = self.up_concat2(conv2, up3) 89 | up1 = self.up_concat1(conv1, up2) 90 | up1 = self.dropout2(up1) 91 | 92 | final = self.final(up1) 93 | 94 | return final 95 | 96 | @staticmethod 97 | def apply_argmax_softmax(pred): 98 | log_p = F.softmax(pred, dim=1) 99 | 100 | return log_p 101 | -------------------------------------------------------------------------------- /dataloaders/busi.py: 
-------------------------------------------------------------------------------- 1 | from dataloaders.transform import random_rot_flip, random_rotate, blur, obtain_cutmix_box 2 | 3 | from copy import deepcopy 4 | import h5py 5 | import math 6 | import numpy as np 7 | import os 8 | from PIL import Image 9 | import random 10 | from scipy.ndimage.interpolation import zoom 11 | from scipy import ndimage 12 | import torch 13 | from torch.utils.data import Dataset 14 | import torch.nn.functional as F 15 | from torchvision import transforms 16 | import cv2 17 | import augmentations 18 | from augmentations.ctaugment import OPS 19 | 20 | class BUSIDataset(Dataset): 21 | def __init__(self, name, root, mode, size=256, id_path=None, nsample=None): 22 | self.name = name 23 | self.root = root 24 | self.mode = mode 25 | self.size = size 26 | if mode == 'train_l' or mode == 'train_u': 27 | with open(id_path, 'r') as f: 28 | self.ids = [line.split()[0] for line in f.readlines()]#f.read()#.splitlines() 29 | with open(id_path, 'r') as f1: 30 | self.ids_label = [line.split()[1] for line in f1.readlines()] 31 | if mode == 'train_l' and nsample is not None: 32 | self.ids *= math.ceil(nsample / len(self.ids)) 33 | self.ids = self.ids[:nsample] 34 | self.ids_label *= math.ceil(nsample / len(self.ids_label)) 35 | self.ids_label = self.ids_label[:nsample] 36 | else: 37 | with open('/home/data/lxs/SSL1/SSL4MIS-master/SSL4MIS-master/splits/%s/valtest.txt' % name, 'r') as f: 38 | self.ids = [line.split()[0] for line in f.readlines()] 39 | with open('/home/data/lxs/SSL1/SSL4MIS-master/SSL4MIS-master/splits/%s/valtest.txt' % name, 'r') as f1: 40 | self.ids_label = [line.split()[1] for line in f1.readlines()] 41 | 42 | 43 | def __getitem__(self, item): 44 | id = self.ids[item] 45 | id_l = self.ids_label[item] 46 | img = cv2.imread(self.root + "/{}".format(id), cv2.IMREAD_GRAYSCALE) / 255.0 47 | mask = cv2.imread(self.root + "/{}".format(id_l), cv2.IMREAD_GRAYSCALE) 48 | # sample = h5py.File(os.path.join(self.root, id), 'r') 49 | # img = sample['image'][:] 50 | # mask = sample['label'][:] 51 | 52 | if self.mode == 'val' or self.mode == 'test': 53 | return torch.from_numpy(img).float().unsqueeze(0), torch.from_numpy(mask).unsqueeze(0).long() 54 | 55 | if random.random() > 0.5: 56 | img, mask = random_rot_flip(img, mask) 57 | elif random.random() > 0.5: 58 | img, mask = random_rotate(img, mask) 59 | x, y = img.shape 60 | img = zoom(img, (self.size / x, self.size / y), order=0) 61 | mask = zoom(mask, (self.size / x, self.size / y), order=0) 62 | 63 | if self.mode == 'train_l': 64 | return torch.from_numpy(img).unsqueeze(0).float(), torch.from_numpy(np.array(mask)).long() 65 | 66 | img = Image.fromarray((img * 255).astype(np.uint8)) 67 | img_s1, img_s2 = deepcopy(img), deepcopy(img) 68 | img = torch.from_numpy(np.array(img)).unsqueeze(0).float() / 255.0 69 | 70 | if random.random() < 0.8: 71 | #img_s1 = augmentations.cta_apply(img_s1, self.ops_strong) 72 | img_s1 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s1) 73 | img_s1 = blur(img_s1, p=0.5) 74 | cutmix_box1 = obtain_cutmix_box(self.size, p=0.5) 75 | img_s1 = torch.from_numpy(np.array(img_s1)).unsqueeze(0).float() / 255.0 76 | 77 | if random.random() < 0.8: 78 | #img_s2 = pil_to_tensor_fft(img_s2) 79 | #img_s2 = augmentations.cta_apply(img_s2, self.ops_strong) 80 | img_s2 = transforms.ColorJitter(0.5, 0.5, 0.5, 0.25)(img_s2) 81 | img_s2 = blur(img_s2, p=0.5) 82 | cutmix_box2 = obtain_cutmix_box(self.size, p=0.5) 83 | img_s2 = 
torch.from_numpy(np.array(img_s2)).unsqueeze(0).float() / 255.0 84 | 85 | return img, img_s1, img_s2, cutmix_box1, cutmix_box2 86 | 87 | def __len__(self): 88 | return len(self.ids) 89 | -------------------------------------------------------------------------------- /utils/vis_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.utils import tensor_to_image 3 | import matplotlib 4 | matplotlib.use('TkAgg') 5 | import matplotlib.pyplot as plt 6 | from kornia.geometry.boxes import Boxes 7 | from kornia.geometry.keypoints import Keypoints 8 | from kornia.contrib.models import SegmentationResults 9 | 10 | def colorize_masks(binary_masks: torch.Tensor, merge: bool = True, alpha: None | float = None) -> list[torch.Tensor]: 11 | """Convert binary masks (B, C, H, W), boolean tensors, into masks with colors (B, (3, 4) , H, W) - RGB or RGBA. Where C refers to the number of masks. 12 | Args: 13 | binary_masks: a batched boolean tensor (B, C, H, W) 14 | merge: If true, will join the batch dimension into a unique mask. 15 | alpha: alpha channel value. If None, will generate RGB images 16 | 17 | Returns: 18 | A list of `C` colored masks. 19 | """ 20 | B, C, H, W = binary_masks.shape 21 | OUT_C = 4 if alpha else 3 22 | 23 | output_masks = [] 24 | 25 | for idx in range(C): 26 | _out = torch.zeros(B, OUT_C, H, W, device=binary_masks.device, dtype=torch.float32) 27 | for b in range(B): 28 | color = torch.rand(1, 3, 1, 1, device=binary_masks.device, dtype=torch.float32) 29 | if alpha: 30 | color = torch.cat([color, torch.tensor([[[[alpha]]]], device=binary_masks.device, dtype=torch.float32)], dim=1) 31 | 32 | to_colorize = binary_masks[b, idx, ...].view(1, 1, H, W).repeat(1, OUT_C, 1, 1) 33 | _out[b, ...] = torch.where(to_colorize, color, _out[b, ...]) 34 | output_masks.append(_out) 35 | 36 | if merge: 37 | output_masks = [c.max(dim=0)[0] for c in output_masks] 38 | 39 | return output_masks 40 | 41 | 42 | def show_binary_masks(binary_masks: torch.Tensor, axes) -> None: 43 | """plot binary masks, with shape (B, C, H, W), where C refers to the number of masks. 44 | 45 | will merge the `B` channel into a unique mask. 
46 | Args: 47 | binary_masks: a batched boolean tensor (B, C, H, W) 48 | ax: a list of matplotlib axes with lenght of C 49 | """ 50 | colored_masks = colorize_masks(binary_masks, True, 0.6) 51 | 52 | for ax, mask in zip(axes, colored_masks): 53 | ax.imshow(tensor_to_image(mask)) 54 | 55 | 56 | def show_boxes(boxes: Boxes, ax) -> None: 57 | boxes_tensor = boxes.to_tensor(mode="xywh").detach().cpu().numpy() 58 | for box in boxes_tensor: 59 | x0, y0, w, h = box 60 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="orange", facecolor=(0, 0, 0, 0), lw=2)) 61 | 62 | 63 | def show_points(points: tuple[Keypoints, torch.Tensor], ax, marker_size=200): 64 | coords, labels = points 65 | pos_points = coords[labels == 1].to_tensor().detach().cpu().numpy() 66 | neg_points = coords[labels == 0].to_tensor().detach().cpu().numpy() 67 | 68 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color="green", marker="+", s=marker_size, linewidth=2) 69 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color="red", marker="x", s=marker_size, linewidth=2) 70 | 71 | 72 | def show_image(image: torch.Tensor): 73 | plt.imshow(tensor_to_image(image)) 74 | plt.axis("off") 75 | plt.show() 76 | 77 | 78 | def show_predictions( 79 | image: torch.Tensor, 80 | predictions: SegmentationResults, 81 | points: tuple[Keypoints, torch.Tensor] | None = None, 82 | boxes: Boxes | None = None, 83 | ) -> None: 84 | n_masks = predictions.logits.shape[1] 85 | 86 | fig, axes = plt.subplots(1, n_masks, figsize=(21, 16)) 87 | axes = [axes] if n_masks == 1 else axes 88 | 89 | for idx, ax in enumerate(axes): 90 | score = predictions.scores[:, idx, ...].mean() 91 | ax.imshow(tensor_to_image(image)) 92 | ax.set_title(f"Mask {idx+1}, Score: {score:.3f}", fontsize=18) 93 | 94 | if points: 95 | show_points(points, ax) 96 | 97 | if boxes: 98 | show_boxes(boxes, ax) 99 | 100 | ax.axis("off") 101 | 102 | show_binary_masks(predictions.binary_masks, axes) 103 | plt.show() -------------------------------------------------------------------------------- /networks/VoxResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SEBlock(nn.Module): 10 | def __init__(self, in_channels, r): 11 | super(SEBlock, self).__init__() 12 | 13 | redu_chns = int(in_channels / r) 14 | self.se_layers = nn.Sequential( 15 | nn.AdaptiveAvgPool3d(1), 16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0), 17 | nn.ReLU(), 18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0), 19 | nn.ReLU()) 20 | 21 | def forward(self, x): 22 | f = self.se_layers(x) 23 | return f * x + x 24 | 25 | 26 | class VoxRex(nn.Module): 27 | def __init__(self, in_channels): 28 | super(VoxRex, self).__init__() 29 | self.block = nn.Sequential( 30 | nn.InstanceNorm3d(in_channels), 31 | nn.ReLU(inplace=True), 32 | nn.Conv3d(in_channels, in_channels, 33 | kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm3d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv3d(in_channels, in_channels, 37 | kernel_size=3, padding=1, bias=False) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.block(x)+x 42 | 43 | 44 | class ConvBlock(nn.Module): 45 | """two convolution layers with batch norm and leaky relu""" 46 | 47 | def __init__(self, in_channels, out_channels): 48 | super(ConvBlock, self).__init__() 49 | self.conv_conv = nn.Sequential( 50 | nn.InstanceNorm3d(in_channels), 51 | nn.ReLU(inplace=True), 52 | 
nn.Conv3d(in_channels, out_channels, 53 | kernel_size=3, padding=1, bias=False), 54 | nn.InstanceNorm3d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv3d(out_channels, out_channels, 57 | kernel_size=3, padding=1, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.conv_conv(x) 62 | 63 | 64 | class UpBlock(nn.Module): 65 | """Upssampling followed by ConvBlock""" 66 | 67 | def __init__(self, in_channels, out_channels): 68 | super(UpBlock, self).__init__() 69 | self.up = nn.Upsample( 70 | scale_factor=2, mode='trilinear', align_corners=True) 71 | self.conv = ConvBlock(in_channels, out_channels) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | x = torch.cat([x2, x1], dim=1) 76 | return self.conv(x) 77 | 78 | 79 | class VoxResNet(nn.Module): 80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2): 81 | super(VoxResNet, self).__init__() 82 | self.in_chns = in_chns 83 | self.ft_chns = feature_chns 84 | self.n_class = class_num 85 | 86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1) 87 | self.res1 = VoxRex(feature_chns) 88 | self.res2 = VoxRex(feature_chns) 89 | self.res3 = VoxRex(feature_chns) 90 | self.res4 = VoxRex(feature_chns) 91 | self.res5 = VoxRex(feature_chns) 92 | self.res6 = VoxRex(feature_chns) 93 | 94 | self.up1 = UpBlock(feature_chns * 2, feature_chns) 95 | self.up2 = UpBlock(feature_chns * 2, feature_chns) 96 | 97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1) 98 | 99 | self.maxpool = nn.MaxPool3d(2) 100 | self.upsample = nn.Upsample( 101 | scale_factor=2, mode='trilinear', align_corners=True) 102 | 103 | def forward(self, x): 104 | x = self.maxpool(self.conv1(x)) 105 | x1 = self.res1(x) 106 | x2 = self.res2(x1) 107 | x2_pool = self.maxpool(x2) 108 | x3 = self.res3(x2_pool) 109 | x4 = self.maxpool(self.res4(x3)) 110 | x5 = self.res5(x4) 111 | x6 = self.res6(x5) 112 | up1 = self.up1(x6, x2_pool) 113 | up2 = self.up2(up1, x) 114 | up = self.upsample(up2) 115 | out = self.out(up) 116 | return out 117 | -------------------------------------------------------------------------------- /code/networks/VoxResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SEBlock(nn.Module): 10 | def __init__(self, in_channels, r): 11 | super(SEBlock, self).__init__() 12 | 13 | redu_chns = int(in_channels / r) 14 | self.se_layers = nn.Sequential( 15 | nn.AdaptiveAvgPool3d(1), 16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0), 17 | nn.ReLU(), 18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0), 19 | nn.ReLU()) 20 | 21 | def forward(self, x): 22 | f = self.se_layers(x) 23 | return f * x + x 24 | 25 | 26 | class VoxRex(nn.Module): 27 | def __init__(self, in_channels): 28 | super(VoxRex, self).__init__() 29 | self.block = nn.Sequential( 30 | nn.InstanceNorm3d(in_channels), 31 | nn.ReLU(inplace=True), 32 | nn.Conv3d(in_channels, in_channels, 33 | kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm3d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv3d(in_channels, in_channels, 37 | kernel_size=3, padding=1, bias=False) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.block(x)+x 42 | 43 | 44 | class ConvBlock(nn.Module): 45 | """two convolution layers with batch norm and leaky relu""" 46 | 47 | def __init__(self, in_channels, out_channels): 48 | super(ConvBlock, self).__init__() 49 | 
self.conv_conv = nn.Sequential( 50 | nn.InstanceNorm3d(in_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv3d(in_channels, out_channels, 53 | kernel_size=3, padding=1, bias=False), 54 | nn.InstanceNorm3d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv3d(out_channels, out_channels, 57 | kernel_size=3, padding=1, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.conv_conv(x) 62 | 63 | 64 | class UpBlock(nn.Module): 65 | """Upssampling followed by ConvBlock""" 66 | 67 | def __init__(self, in_channels, out_channels): 68 | super(UpBlock, self).__init__() 69 | self.up = nn.Upsample( 70 | scale_factor=2, mode='trilinear', align_corners=True) 71 | self.conv = ConvBlock(in_channels, out_channels) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | x = torch.cat([x2, x1], dim=1) 76 | return self.conv(x) 77 | 78 | 79 | class VoxResNet(nn.Module): 80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2): 81 | super(VoxResNet, self).__init__() 82 | self.in_chns = in_chns 83 | self.ft_chns = feature_chns 84 | self.n_class = class_num 85 | 86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1) 87 | self.res1 = VoxRex(feature_chns) 88 | self.res2 = VoxRex(feature_chns) 89 | self.res3 = VoxRex(feature_chns) 90 | self.res4 = VoxRex(feature_chns) 91 | self.res5 = VoxRex(feature_chns) 92 | self.res6 = VoxRex(feature_chns) 93 | 94 | self.up1 = UpBlock(feature_chns * 2, feature_chns) 95 | self.up2 = UpBlock(feature_chns * 2, feature_chns) 96 | 97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1) 98 | 99 | self.maxpool = nn.MaxPool3d(2) 100 | self.upsample = nn.Upsample( 101 | scale_factor=2, mode='trilinear', align_corners=True) 102 | 103 | def forward(self, x): 104 | x = self.maxpool(self.conv1(x)) 105 | x1 = self.res1(x) 106 | x2 = self.res2(x1) 107 | x2_pool = self.maxpool(x2) 108 | x3 = self.res3(x2_pool) 109 | x4 = self.maxpool(self.res4(x3)) 110 | x5 = self.res5(x4) 111 | x6 = self.res6(x5) 112 | up1 = self.up1(x6, x2_pool) 113 | up2 = self.up2(up1, x) 114 | up = self.upsample(up2) 115 | out = self.out(up) 116 | return out 117 | -------------------------------------------------------------------------------- /networks/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import logging 9 | import math 10 | 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | self.config = config 30 | 31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 32 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 33 | in_chans=config.MODEL.SWIN.IN_CHANS, 34 | num_classes=self.num_classes, 35 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 36 | 
depths=config.MODEL.SWIN.DEPTHS, 37 | num_heads=config.MODEL.SWIN.NUM_HEADS, 38 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 41 | qk_scale=config.MODEL.SWIN.QK_SCALE, 42 | drop_rate=config.MODEL.DROP_RATE, 43 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 44 | ape=config.MODEL.SWIN.APE, 45 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 47 | 48 | def forward(self, x): 49 | if x.size()[1] == 1: 50 | x = x.repeat(1,3,1,1) 51 | logits = self.swin_unet(x) 52 | return logits 53 | 54 | def load_from(self, config): 55 | pretrained_path = config.MODEL.PRETRAIN_CKPT 56 | if pretrained_path is not None: 57 | print("pretrained_path:{}".format(pretrained_path)) 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | pretrained_dict = torch.load(pretrained_path, map_location=device) 60 | if "model" not in pretrained_dict: 61 | print("---start load pretrained modle by splitting---") 62 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} 63 | for k in list(pretrained_dict.keys()): 64 | if "output" in k: 65 | print("delete key:{}".format(k)) 66 | del pretrained_dict[k] 67 | msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) 68 | # print(msg) 69 | return 70 | pretrained_dict = pretrained_dict['model'] 71 | print("---start load pretrained modle of swin encoder---") 72 | 73 | model_dict = self.swin_unet.state_dict() 74 | full_dict = copy.deepcopy(pretrained_dict) 75 | for k, v in pretrained_dict.items(): 76 | if "layers." in k: 77 | current_layer_num = 3-int(k[7:8]) 78 | current_k = "layers_up." + str(current_layer_num) + k[8:] 79 | full_dict.update({current_k:v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /code/networks/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import logging 9 | import math 10 | 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | self.config = config 30 | 31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 32 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 33 | in_chans=config.MODEL.SWIN.IN_CHANS, 34 | num_classes=self.num_classes, 35 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 36 | 
depths=config.MODEL.SWIN.DEPTHS, 37 | num_heads=config.MODEL.SWIN.NUM_HEADS, 38 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 41 | qk_scale=config.MODEL.SWIN.QK_SCALE, 42 | drop_rate=config.MODEL.DROP_RATE, 43 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 44 | ape=config.MODEL.SWIN.APE, 45 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 47 | 48 | def forward(self, x): 49 | if x.size()[1] == 1: 50 | x = x.repeat(1,3,1,1) 51 | logits = self.swin_unet(x) 52 | return logits 53 | 54 | def load_from(self, config): 55 | pretrained_path = config.MODEL.PRETRAIN_CKPT 56 | if pretrained_path is not None: 57 | print("pretrained_path:{}".format(pretrained_path)) 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | pretrained_dict = torch.load(pretrained_path, map_location=device) 60 | if "model" not in pretrained_dict: 61 | print("---start load pretrained modle by splitting---") 62 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} 63 | for k in list(pretrained_dict.keys()): 64 | if "output" in k: 65 | print("delete key:{}".format(k)) 66 | del pretrained_dict[k] 67 | msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) 68 | # print(msg) 69 | return 70 | pretrained_dict = pretrained_dict['model'] 71 | print("---start load pretrained modle of swin encoder---") 72 | 73 | model_dict = self.swin_unet.state_dict() 74 | full_dict = copy.deepcopy(pretrained_dict) 75 | for k, v in pretrained_dict.items(): 76 | if "layers." in k: 77 | current_layer_num = 3-int(k[7:8]) 78 | current_k = "layers_up." + str(current_layer_num) + k[8:] 79 | full_dict.update({current_k:v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /networks/mind.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchvision 4 | import SimpleITK as sitk 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | def gaussian_kernel(sigma, sz): 8 | xpos_vec = np.arange(sz) 9 | ypos_vec = np.arange(sz) 10 | output = np.ones([1, 1,sz, sz], dtype=np.single) 11 | midpos = sz // 2 12 | for xpos in xpos_vec: 13 | for ypos in ypos_vec: 14 | output[:,:,xpos,ypos] = np.exp(-((xpos-midpos)**2 + (ypos-midpos)**2) / (2 * sigma**2)) / (2 * np.pi * sigma**2) 15 | return output 16 | def torch_image_translate(input_, tx, ty, interpolation='nearest'): 17 | # got these parameters from solving the equations for pixel translations 18 | # on https://www.tensorflow.org/api_docs/python/tf/contrib/image/transform 19 | translation_matrix = torch.zeros([input_.shape[0], 3, 3], dtype=torch.float) 20 | translation_matrix[:, 0, 0] = 1.0 21 | translation_matrix[:, 1, 1] = 1.0 22 | translation_matrix[:, 0, 2] = -2*tx/(input_.size()[2]-1) 23 | translation_matrix[:, 1, 2] = -2*ty/(input_.size()[3]-1) 24 | translation_matrix[:, 2, 2] = 1.0 25 | grid = F.affine_grid(translation_matrix[:, 0:2, :], input_.size()).to(input_.device) 26 | wrp = F.grid_sample(input_.to(torch.float32), grid, mode=interpolation) 27 | 
return wrp 28 | def Dp(image, xshift, yshift, sigma, patch_size): 29 | shift_image = torch_image_translate(image, xshift, yshift, interpolation='nearest')  # shift the image along x/y to obtain I`(a) 30 | diff = torch.sub(image, shift_image).cuda()  # difference image I-I`(a) 31 | diff_square = torch.mul(diff, diff).cuda()  # (I-I`(a))^2 32 | res = torch.conv2d(diff_square, weight =torch.from_numpy(gaussian_kernel(sigma, patch_size)).cuda(), stride=1, padding=3)  # C*(I-I`(a))^2 33 | return res 34 | 35 | def MIND(image, patch_size = 7, neigh_size = 9, sigma = 2.0, eps = 1e-5,image_size0=256,image_size1=256, name='MIND'): 36 | # compute the Modality independent neighbourhood descriptor (MIND) of input image. 37 | # suppose the neighbor size is R, patch size is P. 38 | # input image is 384 x 256 x input_c_dim 39 | # output MIND is (384-P-R+2) x (256-P-R+2) x R*R 40 | reduce_size = int((patch_size + neigh_size - 2) / 2)  # size reduced by the convolution 41 | 42 | # estimate the local variance of each pixel within the input image. 43 | Vimg = torch.add(Dp(image, -1, 0, sigma, patch_size), Dp(image, 1, 0, sigma, patch_size)) 44 | Vimg = torch.add(Vimg, Dp(image, 0, -1, sigma, patch_size)) 45 | Vimg = torch.add(Vimg, Dp(image, 0, 1, sigma, patch_size))  # sum(Dp) 46 | Vimg = torch.div(Vimg,4) + torch.mul(torch.ones_like(Vimg), eps)  # avoid division by zero 47 | # estimate the (R*R)-length MIND feature by shifting the input image by R*R times. 48 | xshift_vec = np.arange( -(neigh_size//2), neigh_size - (neigh_size//2))  # neighbourhood offsets 49 | yshift_vec = np.arange(-(neigh_size // 2), neigh_size - (neigh_size // 2))  # neighbourhood offsets 50 | iter_pos = 0 51 | for xshift in xshift_vec: 52 | for yshift in yshift_vec: 53 | if (xshift,yshift) == (0,0): 54 | continue 55 | MIND_tmp = torch.exp(torch.mul(torch.div(Dp(image, xshift, yshift, sigma, patch_size), Vimg), -1))  # exp(-D(I)/V(I)) 56 | tmp = MIND_tmp[:, :, reduce_size:(image_size0 - reduce_size), reduce_size:(image_size1 - reduce_size)] 57 | if iter_pos == 0: 58 | output = tmp 59 | else: 60 | output = torch.cat([output,tmp], 1) 61 | iter_pos = iter_pos + 1 62 | 63 | # normalization. 
64 | input_max, input_indexes = torch.max(output, dim=1) 65 | output = torch.div(output,input_max.unsqueeze(1)) 66 | 67 | return output 68 | def abs_criterion(in_, target): 69 | return torch.mean(torch.abs(in_ - target)) 70 | if __name__ == '__main__': 71 | patch_size=7 72 | neigh_size=9 73 | sigma=2.0 74 | eps=1e-5 75 | image_size0=256 76 | image_size1=256 77 | 78 | 79 | 80 | A = torch.randn(16, 1,256, 256).cuda() 81 | B = torch.randn(16, 1,256, 256).cuda() 82 | 83 | 84 | A_MIND = MIND(A, patch_size, neigh_size, sigma, eps,image_size0,image_size1, name='realA_MIND') 85 | B_MIND = MIND(B, patch_size, neigh_size, sigma, eps,image_size0,image_size1, name='realA_MIND') 86 | g_loss_MIND = abs_criterion(A_MIND, B_MIND) 87 | print('g_loss_MIND', g_loss_MIND) 88 | 89 | -------------------------------------------------------------------------------- /dataloaders/brats_proprecessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | from skimage import measure 5 | import nibabel as nib 6 | import SimpleITK as sitk 7 | import glob 8 | 9 | 10 | def brain_bbox(data, gt): 11 | mask = (data != 0) 12 | brain_voxels = np.where(mask != 0) 13 | minZidx = int(np.min(brain_voxels[0])) 14 | maxZidx = int(np.max(brain_voxels[0])) 15 | minXidx = int(np.min(brain_voxels[1])) 16 | maxXidx = int(np.max(brain_voxels[1])) 17 | minYidx = int(np.min(brain_voxels[2])) 18 | maxYidx = int(np.max(brain_voxels[2])) 19 | data_bboxed = data[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 20 | gt_bboxed = gt[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 21 | return data_bboxed, gt_bboxed 22 | 23 | 24 | def volume_bounding_box(data, gt, expend=0, status="train"): 25 | data, gt = brain_bbox(data, gt) 26 | print(data.shape) 27 | mask = (gt != 0) 28 | brain_voxels = np.where(mask != 0) 29 | z, x, y = data.shape 30 | minZidx = int(np.min(brain_voxels[0])) 31 | maxZidx = int(np.max(brain_voxels[0])) 32 | minXidx = int(np.min(brain_voxels[1])) 33 | maxXidx = int(np.max(brain_voxels[1])) 34 | minYidx = int(np.min(brain_voxels[2])) 35 | maxYidx = int(np.max(brain_voxels[2])) 36 | 37 | minZidx_jitterd = max(minZidx - expend, 0) 38 | maxZidx_jitterd = min(maxZidx + expend, z) 39 | minXidx_jitterd = max(minXidx - expend, 0) 40 | maxXidx_jitterd = min(maxXidx + expend, x) 41 | minYidx_jitterd = max(minYidx - expend, 0) 42 | maxYidx_jitterd = min(maxYidx + expend, y) 43 | 44 | data_bboxed = data[minZidx_jitterd:maxZidx_jitterd, 45 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 46 | print([minZidx, maxZidx, minXidx, maxXidx, minYidx, maxYidx]) 47 | print([minZidx_jitterd, maxZidx_jitterd, 48 | minXidx_jitterd, maxXidx_jitterd, minYidx_jitterd, maxYidx_jitterd]) 49 | 50 | if status == "train": 51 | gt_bboxed = np.zeros_like(data_bboxed, dtype=np.uint8) 52 | gt_bboxed[expend:maxZidx_jitterd-expend, expend:maxXidx_jitterd - 53 | expend, expend:maxYidx_jitterd - expend] = 1 54 | return data_bboxed, gt_bboxed 55 | 56 | if status == "test": 57 | gt_bboxed = gt[minZidx_jitterd:maxZidx_jitterd, 58 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 59 | return data_bboxed, gt_bboxed 60 | 61 | 62 | def itensity_normalize_one_volume(volume): 63 | """ 64 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 65 | inputs: 66 | volume: the input nd volume 67 | outputs: 68 | out: the normalized nd volume 69 | """ 70 | 71 | pixels = volume[volume > 0] 72 | mean 
= pixels.mean() 73 | std = pixels.std() 74 | out = (volume - mean)/std 75 | out_random = np.random.normal(0, 1, size=volume.shape) 76 | # out[volume == 0] = out_random[volume == 0] 77 | out = out.astype(np.float32) 78 | return out 79 | 80 | 81 | class MedicalImageDeal(object): 82 | def __init__(self, img, percent=1): 83 | self.img = img 84 | self.percent = percent 85 | 86 | @property 87 | def valid_img(self): 88 | from skimage import exposure 89 | cdf = exposure.cumulative_distribution(self.img) 90 | watershed = cdf[1][cdf[0] >= self.percent][0] 91 | return np.clip(self.img, self.img.min(), watershed) 92 | 93 | @property 94 | def norm_img(self): 95 | return (self.img - self.img.min()) / (self.img.max() - self.img.min()) 96 | 97 | 98 | all_flair = glob.glob("flair/*_flair.nii.gz") 99 | for p in all_flair: 100 | data = sitk.GetArrayFromImage(sitk.ReadImage(p)) 101 | lab = sitk.GetArrayFromImage(sitk.ReadImage(p.replace("flair", "seg"))) 102 | img, lab = brain_bbox(data, lab) 103 | img = MedicalImageDeal(img, percent=0.999).valid_img 104 | img = itensity_normalize_one_volume(img) 105 | lab[lab > 0] = 1 106 | uid = p.split("/")[-1] 107 | sitk.WriteImage(sitk.GetImageFromArray( 108 | img), "/media/xdluo/Data/brats19/data/flair/{}".format(uid)) 109 | sitk.WriteImage(sitk.GetImageFromArray( 110 | lab), "/media/xdluo/Data/brats19/data/label/{}".format(uid)) 111 | -------------------------------------------------------------------------------- /code/dataloaders/brats_proprecessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | from skimage import measure 5 | import nibabel as nib 6 | import SimpleITK as sitk 7 | import glob 8 | 9 | 10 | def brain_bbox(data, gt): 11 | mask = (data != 0) 12 | brain_voxels = np.where(mask != 0) 13 | minZidx = int(np.min(brain_voxels[0])) 14 | maxZidx = int(np.max(brain_voxels[0])) 15 | minXidx = int(np.min(brain_voxels[1])) 16 | maxXidx = int(np.max(brain_voxels[1])) 17 | minYidx = int(np.min(brain_voxels[2])) 18 | maxYidx = int(np.max(brain_voxels[2])) 19 | data_bboxed = data[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 20 | gt_bboxed = gt[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 21 | return data_bboxed, gt_bboxed 22 | 23 | 24 | def volume_bounding_box(data, gt, expend=0, status="train"): 25 | data, gt = brain_bbox(data, gt) 26 | print(data.shape) 27 | mask = (gt != 0) 28 | brain_voxels = np.where(mask != 0) 29 | z, x, y = data.shape 30 | minZidx = int(np.min(brain_voxels[0])) 31 | maxZidx = int(np.max(brain_voxels[0])) 32 | minXidx = int(np.min(brain_voxels[1])) 33 | maxXidx = int(np.max(brain_voxels[1])) 34 | minYidx = int(np.min(brain_voxels[2])) 35 | maxYidx = int(np.max(brain_voxels[2])) 36 | 37 | minZidx_jitterd = max(minZidx - expend, 0) 38 | maxZidx_jitterd = min(maxZidx + expend, z) 39 | minXidx_jitterd = max(minXidx - expend, 0) 40 | maxXidx_jitterd = min(maxXidx + expend, x) 41 | minYidx_jitterd = max(minYidx - expend, 0) 42 | maxYidx_jitterd = min(maxYidx + expend, y) 43 | 44 | data_bboxed = data[minZidx_jitterd:maxZidx_jitterd, 45 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 46 | print([minZidx, maxZidx, minXidx, maxXidx, minYidx, maxYidx]) 47 | print([minZidx_jitterd, maxZidx_jitterd, 48 | minXidx_jitterd, maxXidx_jitterd, minYidx_jitterd, maxYidx_jitterd]) 49 | 50 | if status == "train": 51 | gt_bboxed = np.zeros_like(data_bboxed, dtype=np.uint8) 52 | 
gt_bboxed[expend:maxZidx_jitterd-expend, expend:maxXidx_jitterd - 53 | expend, expend:maxYidx_jitterd - expend] = 1 54 | return data_bboxed, gt_bboxed 55 | 56 | if status == "test": 57 | gt_bboxed = gt[minZidx_jitterd:maxZidx_jitterd, 58 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 59 | return data_bboxed, gt_bboxed 60 | 61 | 62 | def itensity_normalize_one_volume(volume): 63 | """ 64 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 65 | inputs: 66 | volume: the input nd volume 67 | outputs: 68 | out: the normalized nd volume 69 | """ 70 | 71 | pixels = volume[volume > 0] 72 | mean = pixels.mean() 73 | std = pixels.std() 74 | out = (volume - mean)/std 75 | out_random = np.random.normal(0, 1, size=volume.shape) 76 | # out[volume == 0] = out_random[volume == 0] 77 | out = out.astype(np.float32) 78 | return out 79 | 80 | 81 | class MedicalImageDeal(object): 82 | def __init__(self, img, percent=1): 83 | self.img = img 84 | self.percent = percent 85 | 86 | @property 87 | def valid_img(self): 88 | from skimage import exposure 89 | cdf = exposure.cumulative_distribution(self.img) 90 | watershed = cdf[1][cdf[0] >= self.percent][0] 91 | return np.clip(self.img, self.img.min(), watershed) 92 | 93 | @property 94 | def norm_img(self): 95 | return (self.img - self.img.min()) / (self.img.max() - self.img.min()) 96 | 97 | 98 | all_flair = glob.glob("flair/*_flair.nii.gz") 99 | for p in all_flair: 100 | data = sitk.GetArrayFromImage(sitk.ReadImage(p)) 101 | lab = sitk.GetArrayFromImage(sitk.ReadImage(p.replace("flair", "seg"))) 102 | img, lab = brain_bbox(data, lab) 103 | img = MedicalImageDeal(img, percent=0.999).valid_img 104 | img = itensity_normalize_one_volume(img) 105 | lab[lab > 0] = 1 106 | uid = p.split("/")[-1] 107 | sitk.WriteImage(sitk.GetImageFromArray( 108 | img), "/media/xdluo/Data/brats19/data/flair/{}".format(uid)) 109 | sitk.WriteImage(sitk.GetImageFromArray( 110 | lab), "/media/xdluo/Data/brats19/data/label/{}".format(uid)) 111 | -------------------------------------------------------------------------------- /networks/unet_3D_dv_semi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/ozan-oktay/Attention-Gated-Networks 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT, UnetDsv3 9 | import torch.nn.functional as F 10 | from networks.networks_other import init_weights 11 | 12 | 13 | class unet_3D_dv_semi(nn.Module): 14 | 15 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 16 | super(unet_3D_dv_semi, self).__init__() 17 | self.is_deconv = is_deconv 18 | self.in_channels = in_channels 19 | self.is_batchnorm = is_batchnorm 20 | self.feature_scale = feature_scale 21 | 22 | filters = [64, 128, 256, 512, 1024] 23 | filters = [int(x / self.feature_scale) for x in filters] 24 | 25 | # downsampling 26 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 27 | 3, 3, 3), padding_size=(1, 1, 1)) 28 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 29 | 30 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 31 | 3, 3, 3), padding_size=(1, 1, 1)) 32 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 33 | 34 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 35 | 3, 3, 3), padding_size=(1, 1, 1)) 36 | self.maxpool3 = 
nn.MaxPool3d(kernel_size=(2, 2, 2)) 37 | 38 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 39 | 3, 3, 3), padding_size=(1, 1, 1)) 40 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 41 | 42 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 43 | 3, 3, 3), padding_size=(1, 1, 1)) 44 | 45 | # upsampling 46 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 47 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 48 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 49 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 50 | 51 | # deep supervision 52 | self.dsv4 = UnetDsv3( 53 | in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3( 55 | in_size=filters[2], out_size=n_classes, scale_factor=4) 56 | self.dsv2 = UnetDsv3( 57 | in_size=filters[1], out_size=n_classes, scale_factor=2) 58 | self.dsv1 = nn.Conv3d( 59 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 60 | 61 | self.dropout1 = nn.Dropout3d(p=0.5) 62 | self.dropout2 = nn.Dropout3d(p=0.3) 63 | self.dropout3 = nn.Dropout3d(p=0.2) 64 | self.dropout4 = nn.Dropout3d(p=0.1) 65 | 66 | # initialise weights 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv3d): 69 | init_weights(m, init_type='kaiming') 70 | elif isinstance(m, nn.BatchNorm3d): 71 | init_weights(m, init_type='kaiming') 72 | 73 | def forward(self, inputs): 74 | conv1 = self.conv1(inputs) 75 | maxpool1 = self.maxpool1(conv1) 76 | 77 | conv2 = self.conv2(maxpool1) 78 | maxpool2 = self.maxpool2(conv2) 79 | 80 | conv3 = self.conv3(maxpool2) 81 | maxpool3 = self.maxpool3(conv3) 82 | 83 | conv4 = self.conv4(maxpool3) 84 | maxpool4 = self.maxpool4(conv4) 85 | 86 | center = self.center(maxpool4) 87 | 88 | up4 = self.up_concat4(conv4, center) 89 | up4 = self.dropout1(up4) 90 | 91 | up3 = self.up_concat3(conv3, up4) 92 | up3 = self.dropout2(up3) 93 | 94 | up2 = self.up_concat2(conv2, up3) 95 | up2 = self.dropout3(up2) 96 | 97 | up1 = self.up_concat1(conv1, up2) 98 | up1 = self.dropout4(up1) 99 | 100 | # Deep Supervision 101 | dsv4 = self.dsv4(up4) 102 | dsv3 = self.dsv3(up3) 103 | dsv2 = self.dsv2(up2) 104 | dsv1 = self.dsv1(up1) 105 | 106 | return dsv1, dsv2, dsv3, dsv4 107 | 108 | @staticmethod 109 | def apply_argmax_softmax(pred): 110 | log_p = F.softmax(pred, dim=1) 111 | 112 | return log_p 113 | -------------------------------------------------------------------------------- /code/networks/unet_3D_dv_semi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/ozan-oktay/Attention-Gated-Networks 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT, UnetDsv3 9 | import torch.nn.functional as F 10 | from networks.networks_other import init_weights 11 | 12 | 13 | class unet_3D_dv_semi(nn.Module): 14 | 15 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 16 | super(unet_3D_dv_semi, self).__init__() 17 | self.is_deconv = is_deconv 18 | self.in_channels = in_channels 19 | self.is_batchnorm = is_batchnorm 20 | self.feature_scale = feature_scale 21 | 22 | filters = [64, 128, 256, 512, 1024] 23 | filters = [int(x / self.feature_scale) for x in filters] 24 | 25 | # downsampling 26 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 27 | 3, 3, 3), 
padding_size=(1, 1, 1)) 28 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 29 | 30 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 31 | 3, 3, 3), padding_size=(1, 1, 1)) 32 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 33 | 34 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 35 | 3, 3, 3), padding_size=(1, 1, 1)) 36 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 37 | 38 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 39 | 3, 3, 3), padding_size=(1, 1, 1)) 40 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 41 | 42 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 43 | 3, 3, 3), padding_size=(1, 1, 1)) 44 | 45 | # upsampling 46 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 47 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 48 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 49 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 50 | 51 | # deep supervision 52 | self.dsv4 = UnetDsv3( 53 | in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3( 55 | in_size=filters[2], out_size=n_classes, scale_factor=4) 56 | self.dsv2 = UnetDsv3( 57 | in_size=filters[1], out_size=n_classes, scale_factor=2) 58 | self.dsv1 = nn.Conv3d( 59 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 60 | 61 | self.dropout1 = nn.Dropout3d(p=0.5) 62 | self.dropout2 = nn.Dropout3d(p=0.3) 63 | self.dropout3 = nn.Dropout3d(p=0.2) 64 | self.dropout4 = nn.Dropout3d(p=0.1) 65 | 66 | # initialise weights 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv3d): 69 | init_weights(m, init_type='kaiming') 70 | elif isinstance(m, nn.BatchNorm3d): 71 | init_weights(m, init_type='kaiming') 72 | 73 | def forward(self, inputs): 74 | conv1 = self.conv1(inputs) 75 | maxpool1 = self.maxpool1(conv1) 76 | 77 | conv2 = self.conv2(maxpool1) 78 | maxpool2 = self.maxpool2(conv2) 79 | 80 | conv3 = self.conv3(maxpool2) 81 | maxpool3 = self.maxpool3(conv3) 82 | 83 | conv4 = self.conv4(maxpool3) 84 | maxpool4 = self.maxpool4(conv4) 85 | 86 | center = self.center(maxpool4) 87 | 88 | up4 = self.up_concat4(conv4, center) 89 | up4 = self.dropout1(up4) 90 | 91 | up3 = self.up_concat3(conv3, up4) 92 | up3 = self.dropout2(up3) 93 | 94 | up2 = self.up_concat2(conv2, up3) 95 | up2 = self.dropout3(up2) 96 | 97 | up1 = self.up_concat1(conv1, up2) 98 | up1 = self.dropout4(up1) 99 | 100 | # Deep Supervision 101 | dsv4 = self.dsv4(up4) 102 | dsv3 = self.dsv3(up3) 103 | dsv2 = self.dsv2(up2) 104 | dsv1 = self.dsv1(up1) 105 | 106 | return dsv1, dsv2, dsv3, dsv4 107 | 108 | @staticmethod 109 | def apply_argmax_softmax(pred): 110 | log_p = F.softmax(pred, dim=1) 111 | 112 | return log_p 113 | --------------------------------------------------------------------------------
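A minimal usage sketch (not part of the repository, added for illustration) tying the deep-supervision 3D U-Net above to the DiceLoss defined in utils/utils.py. The import paths, the 96^3 patch size, and the 2-class setting are assumptions; the helper blocks UnetConv3, UnetUp3_CT, and UnetDsv3 live in networks/utils.py, which is not reproduced in this dump, so the forward pass is a plausible smoke test rather than a verified one.

import torch
from networks.unet_3D_dv_semi import unet_3D_dv_semi  # assumed import path
from utils.utils import DiceLoss                       # assumed import path

# Hypothetical smoke test: single-channel 96^3 patch, 2 classes (side length divisible by 16).
model = unet_3D_dv_semi(feature_scale=4, n_classes=2, in_channels=1)
volume = torch.randn(1, 1, 96, 96, 96)            # (B, C, D, H, W)
label = torch.randint(0, 2, (1, 1, 96, 96, 96))   # integer mask; DiceLoss one-hot encodes along dim 1

dsv1, dsv2, dsv3, dsv4 = model(volume)            # four supervision heads; dsv1 is a 1x1x1 conv at full resolution
criterion = DiceLoss(n_classes=2)
loss = criterion(dsv1, label, softmax=True)       # softmax and one-hot encoding happen inside DiceLoss
print(dsv1.shape, loss.item())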