├── code
│   ├── util
│   │   ├── CelebAMask-HQ
│   │   │   ├── __init__.py
│   │   │   └── Data_preprocessing
│   │   │       ├── utils.py
│   │   │       ├── g_mask.py
│   │   │       └── v_mask.py
│   │   ├── __init__.py
│   │   ├── html.py
│   │   ├── iter_counter.py
│   │   ├── coco.py
│   │   ├── visualizer.py
│   │   └── util.py
│   ├── networks
│   │   ├── net_factory.py
│   │   └── unet2d.py
│   ├── utils
│   │   ├── metrics.py
│   │   ├── ramps.py
│   │   ├── util.py
│   │   └── losses.py
│   ├── models
│   │   ├── networks
│   │   │   ├── sync_batchnorm
│   │   │   │   ├── __init__.py
│   │   │   │   ├── unittest.py
│   │   │   │   ├── batchnorm_reimpl.py
│   │   │   │   ├── replicate.py
│   │   │   │   └── comm.py
│   │   │   ├── encoder.py
│   │   │   ├── __init__.py
│   │   │   ├── base_network.py
│   │   │   ├── discriminator.py
│   │   │   ├── generator.py
│   │   │   ├── loss.py
│   │   │   ├── architecture.py
│   │   │   └── normalization.py
│   │   └── __init__.py
│   ├── dataloaders
│   │   ├── dataset_covid.py
│   │   ├── dataset.py
│   │   └── utils.py
│   ├── test.py
│   ├── test_covid.py
│   ├── train_template.py
│   ├── train_SAST.py
│   └── train_SACPS.py
├── dataset.png
├── framework.png
├── data
│   ├── MOS1000
│   │   ├── val_slice.xlsx
│   │   ├── test_slice.xlsx
│   │   ├── test_volume.xlsx
│   │   ├── val_volume.xlsx
│   │   ├── train_slice_label.xlsx
│   │   └── train_slice_unlabel.xlsx
│   └── COVID249
│       ├── test_slice.xlsx
│       ├── test_volume.xlsx
│       ├── train_0.1_l.xlsx
│       ├── train_0.1_u.xlsx
│       ├── train_0.2_l.xlsx
│       ├── train_0.2_u.xlsx
│       ├── train_0.3_l.xlsx
│       ├── train_0.3_u.xlsx
│       ├── val_slice.xlsx
│       └── val_volume.xlsx
└── README.md
--------------------------------------------------------------------------------
/code/util/CelebAMask-HQ/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/dataset.png
--------------------------------------------------------------------------------
/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/framework.png
--------------------------------------------------------------------------------
/data/MOS1000/val_slice.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/val_slice.xlsx
--------------------------------------------------------------------------------
/data/COVID249/test_slice.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/test_slice.xlsx
--------------------------------------------------------------------------------
/data/COVID249/test_volume.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/test_volume.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.1_l.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.1_l.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.1_u.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.1_u.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.2_l.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.2_l.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.2_u.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.2_u.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.3_l.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.3_l.xlsx
--------------------------------------------------------------------------------
/data/COVID249/train_0.3_u.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/train_0.3_u.xlsx
--------------------------------------------------------------------------------
/data/COVID249/val_slice.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/val_slice.xlsx
--------------------------------------------------------------------------------
/data/COVID249/val_volume.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/COVID249/val_volume.xlsx
--------------------------------------------------------------------------------
/data/MOS1000/test_slice.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/test_slice.xlsx
--------------------------------------------------------------------------------
/data/MOS1000/test_volume.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/test_volume.xlsx
--------------------------------------------------------------------------------
/data/MOS1000/val_volume.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/val_volume.xlsx
--------------------------------------------------------------------------------
/data/MOS1000/train_slice_label.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/train_slice_label.xlsx
--------------------------------------------------------------------------------
/data/MOS1000/train_slice_unlabel.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FeiLyu/SASSL/HEAD/data/MOS1000/train_slice_unlabel.xlsx
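--------------------------------------------------------------------------------
Usage sketch (not a file in the repository): the spreadsheets above define the
data splits, e.g. train_0.1_l.xlsx / train_0.1_u.xlsx for the labeled and
unlabeled parts of the 10%-labeled COVID249 setting. A minimal way to inspect
one with pandas; the column layout is an assumption, as it is not documented in
this dump.
--------------------------------------------------------------------------------
import pandas as pd

# Assumption: each sheet lists the slice/volume identifiers of one split.
df = pd.read_excel("data/COVID249/train_0.1_l.xlsx")
print(df.shape)   # number of labeled training slices in the 10% setting
print(df.head())  # exact column names are not documented in this dump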
--------------------------------------------------------------------------------
/code/util/CelebAMask-HQ/Data_preprocessing/utils.py:
--------------------------------------------------------------------------------
import os


def make_folder(path):
    # Create the directory (and any missing parents) if it does not exist yet.
    os.makedirs(path, exist_ok=True)
--------------------------------------------------------------------------------
/code/util/__init__.py:
--------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
--------------------------------------------------------------------------------
/code/networks/net_factory.py:
--------------------------------------------------------------------------------
from networks.unet2d import UNet2D


def net_factory(net_type="unet"):
    if net_type == "unet":
        net = UNet2D().cuda()
    else:
        net = None
    return net
--------------------------------------------------------------------------------
/code/models/networks/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
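--------------------------------------------------------------------------------
Usage sketch (not a file in the repository): the sync_batchnorm package is
typically applied by wrapping a model in DataParallel and converting its
BatchNorm layers, following the upstream Synchronized-BatchNorm-PyTorch
documentation. The two-GPU device list below is illustrative.
--------------------------------------------------------------------------------
import torch.nn as nn
from models.networks.sync_batchnorm import convert_model

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)
model = nn.DataParallel(model, device_ids=[0, 1])
model = convert_model(model)  # BatchNorm2d layers are now synchronized across GPUs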
--------------------------------------------------------------------------------
/code/models/networks/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y), message)
--------------------------------------------------------------------------------
/code/util/CelebAMask-HQ/Data_preprocessing/g_mask.py:
--------------------------------------------------------------------------------
import os
import cv2
import glob
import numpy as np
from utils import make_folder


label_list = ['skin', 'nose', 'eye_g', 'l_eye', 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip',
              'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth']

folder_base = '/media/zhup/Data/CelebAMask-HQ/CelebAMaskHQ-mask-anno'
folder_save = '/media/zhup/Data/CelebAMask-HQ/CelebAMaskHQ-mask'
img_num = 30000

make_folder(folder_save)

# Merge the per-attribute binary masks into a single label map per image.
for k in range(img_num):
    folder_num = int(k / 2000)  # annotations are stored in folders of 2000 images each
    im_base = np.zeros((512, 512))
    for idx, label in enumerate(label_list):
        filename = os.path.join(folder_base, str(folder_num), str(k).rjust(5, '0') + '_' + label + '.png')
        if os.path.exists(filename):
            print(label, idx + 1)
            im = cv2.imread(filename)
            im = im[:, :, 0]
            im_base[im != 0] = (idx + 1)  # label value 0 is background

    filename_save = os.path.join(folder_save, str(k) + '.png')
    print(filename_save)
    cv2.imwrite(filename_save, im_base)
--------------------------------------------------------------------------------
/code/utils/metrics.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2019/12/14 4:41 PM
# @Author  : chuyu zhang
# @File    : metrics.py
# @Software: PyCharm


import numpy as np
from medpy import metric


def cal_dice(prediction, label, num=2):
    total_dice = np.zeros(num - 1)
    for i in range(1, num):
        prediction_tmp = (prediction == i)
        label_tmp = (label == i)
        prediction_tmp = prediction_tmp.astype(np.float64)  # np.float was removed in NumPy 1.24
        label_tmp = label_tmp.astype(np.float64)

        dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp))
        total_dice[i - 1] += dice

    return total_dice


def calculate_metric_percase(pred, gt):
    dc = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)

    return dc, jc, hd, asd


def dice(input, target, ignore_index=None):
    smooth = 1.
    # clone so that masking out ignored pixels does not modify the original tensors
    iflat = input.clone().view(-1)
    tflat = target.clone().view(-1)
    if ignore_index is not None:
        mask = tflat == ignore_index
        tflat[mask] = 0
        iflat[mask] = 0
    intersection = (iflat * tflat).sum()

    return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
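--------------------------------------------------------------------------------
Usage sketch (not a file in the repository): cal_dice on two toy binary masks.
calculate_metric_percase additionally requires both masks to be non-empty,
since hd95 and asd are undefined for empty surfaces.
--------------------------------------------------------------------------------
import numpy as np
from utils.metrics import cal_dice, calculate_metric_percase

pred = np.zeros((8, 8), dtype=np.uint8)
gt = np.zeros((8, 8), dtype=np.uint8)
pred[2:6, 2:6] = 1  # 16 predicted foreground pixels
gt[3:7, 3:7] = 1    # 16 ground-truth pixels, 9 of them overlapping

print(cal_dice(pred, gt, num=2))  # [0.5625] = 2*9 / (16 + 16)
dc, jc, hd, asd = calculate_metric_percase(pred, gt)  # medpy-based metrics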
--------------------------------------------------------------------------------
/code/utils/ramps.py:
--------------------------------------------------------------------------------
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Functions for ramping hyperparameters up or down.

Each function takes the current training step or epoch, and the
ramp length in the same format, and returns a multiplier between
0 and 1.
"""


import numpy as np


def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))


def linear_rampup(current, rampup_length):
    """Linear rampup"""
    assert current >= 0 and rampup_length >= 0
    if current >= rampup_length:
        return 1.0
    else:
        return current / rampup_length


def cosine_rampdown(current, rampdown_length):
    """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
    assert 0 <= current <= rampdown_length
    return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
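--------------------------------------------------------------------------------
Usage sketch (not a file in the repository): these ramps are commonly used to
schedule the consistency-loss weight in semi-supervised training. The weight
cap and ramp length below are illustrative values, not taken from this
repository.
--------------------------------------------------------------------------------
from utils.ramps import sigmoid_rampup

w_max, rampup_length = 1.0, 40
for epoch in (0, 10, 20, 40, 60):
    w = w_max * sigmoid_rampup(epoch, rampup_length)
    print(epoch, round(w, 4))  # 0.0067, 0.06, 0.2865, 1.0, 1.0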
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import importlib
import torch


def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py"
    # will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of torch.nn.Module,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
                and issubclass(cls, torch.nn.Module):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase."
              % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % (type(instance).__name__))

    return instance
--------------------------------------------------------------------------------
/code/models/networks/encoder.py:
--------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer


class ConvEncoder(BaseNetwork):
    """ Same architecture as the image discriminator """

    def __init__(self, opt):
        super().__init__()

        kw = 3
        pw = int(np.ceil((kw - 1.0) / 2))
        ndf = opt.ngf
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
        self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw))
        self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw))
        self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw))
        self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw))
        self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))
        if opt.crop_size >= 256:
            self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw))

        self.so = s0 = 4
        self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256)
        self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256)

        self.actvn = nn.LeakyReLU(0.2, False)
        self.opt = opt

    def forward(self, x):
        if x.size(2) != 256 or x.size(3) != 256:
            x = F.interpolate(x, size=(256, 256), mode='bilinear')

        x = self.layer1(x)
        x = self.layer2(self.actvn(x))
        x = self.layer3(self.actvn(x))
        x = self.layer4(self.actvn(x))
        x = self.layer5(self.actvn(x))
        if self.opt.crop_size >= 256:
            x = self.layer6(self.actvn(x))
        x = self.actvn(x)

        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_var(x)

        return mu, logvar
--------------------------------------------------------------------------------
/code/models/networks/__init__.py:
--------------------------------------------------------------------------------
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
from models.networks.base_network import BaseNetwork
from models.networks.loss import *
from models.networks.discriminator import *
from models.networks.generator import *
from models.networks.encoder import *
import util.util as util


def find_network_using_name(target_network_name, filename):
    target_class_name = target_network_name + filename
    module_name = 'models.networks.' + filename
    network = util.find_class_in_module(target_class_name, module_name)

    assert issubclass(network, BaseNetwork), \
        "Class %s should be a subclass of BaseNetwork" % network

    return network


def modify_commandline_options(parser, is_train):
    opt, _ = parser.parse_known_args()

    netG_cls = find_network_using_name(opt.netG, 'generator')
    parser = netG_cls.modify_commandline_options(parser, is_train)
    if is_train:
        netD_cls = find_network_using_name(opt.netD, 'discriminator')
        parser = netD_cls.modify_commandline_options(parser, is_train)
    netE_cls = find_network_using_name('conv', 'encoder')
    parser = netE_cls.modify_commandline_options(parser, is_train)

    return parser


def create_network(cls, opt):
    net = cls(opt)
    net.print_network()
    if len(opt.gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.cuda()
    net.init_weights(opt.init_type, opt.init_variance)
    return net


def define_G(opt):
    netG_cls = find_network_using_name(opt.netG, 'generator')
    return create_network(netG_cls, opt)


def define_D(opt):
    netD_cls = find_network_using_name(opt.netD, 'discriminator')
    return create_network(netD_cls, opt)


def define_E(opt):
    # there exists only one encoder type
    netE_cls = find_network_using_name('conv', 'encoder')
    return create_network(netE_cls, opt)
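--------------------------------------------------------------------------------
Illustration (not a file in the repository): both factories above share one
convention — the requested name plus the file suffix must equal a class name,
compared case-insensitively (e.g. a hypothetical netG "spade" plus "generator"
would resolve to a class named SPADEGenerator). A self-contained sketch of that
lookup:
--------------------------------------------------------------------------------
import importlib

def find_class_by_convention(prefix, filename, package="models.networks"):
    target = (prefix + filename).lower()  # e.g. "spadegenerator"
    module = importlib.import_module(package + "." + filename)
    for name, cls in vars(module).items():
        if name.lower() == target:
            return cls
    raise ValueError("no class matching %r in %s.%s" % (target, package, filename))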
--------------------------------------------------------------------------------
/code/util/CelebAMask-HQ/Data_preprocessing/v_mask.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from utils import make_folder
import skimage.io


def labelcolormap(N):
    if N == 19:  # CelebAMask-HQ
        cmap = np.array([(0, 0, 0), (204, 0, 0), (76, 153, 0),
                         (204, 204, 0), (51, 51, 255), (204, 0, 204), (0, 255, 255),
                         (51, 255, 255), (102, 51, 0), (255, 0, 0), (102, 204, 0),
                         (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
                         (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0)],
                        dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            id = i
            for j in range(7):
                str_id = uint82bin(id)
                r = r ^ (np.uint8(str_id[-1]) << (7 - j))
                g = g ^ (np.uint8(str_id[-2]) << (7 - j))
                b = b ^ (np.uint8(str_id[-3]) << (7 - j))
                id = id >> 3
            cmap[i, 0] = r
            cmap[i, 1] = g
            cmap[i, 2] = b
    return cmap


def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


def colorize(gray_image, cmap):
    size = gray_image.shape
    color_image = np.zeros((size[0], size[1], 3), np.uint8)

    for label in range(0, len(cmap)):
        mask = (label == gray_image[:, :])
        color_image[:, :, 0][mask] = cmap[label][0]
        color_image[:, :, 1][mask] = cmap[label][1]
        color_image[:, :, 2][mask] = cmap[label][2]

    return color_image


folder_base = '/media/zhup/Data/CelebAMask-HQ/CelebAMaskHQ-mask'
folder_save = '/media/zhup/Data/CelebAMask-HQ/CelebAMaskHQ-mask-vis'
img_num = 30000

make_folder(folder_save)
my_cmp = labelcolormap(19)

for k in range(img_num):
    filename = os.path.join(folder_base, str(k) + '.png')
    if os.path.exists(filename):
        print(k + 1)
        im = skimage.io.imread(filename)
        im_vis = colorize(im, my_cmp)

        filename_save = os.path.join(folder_save, str(k) + '.png')
        print(filename_save)
        skimage.io.imsave(filename_save, im_vis)
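--------------------------------------------------------------------------------
Usage sketch (not a file in the repository): colorize on a toy 2x2 label mask,
assuming labelcolormap and colorize from v_mask.py are importable. Label values
follow g_mask.py's convention (0 = background, 1 = skin, 2 = nose, 13 = hair).
--------------------------------------------------------------------------------
import numpy as np

toy = np.array([[0, 1],
                [2, 13]], dtype=np.uint8)
cmap = labelcolormap(19)
rgb = colorize(toy, cmap)
print(rgb[0, 1])  # [204   0   0] -> label 1 ("skin") maps to red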
62 | print(k + 1) 63 | im = skimage.io.imread(filename) 64 | im_vis = colorize(im, my_cmp) 65 | 66 | filename_save = os.path.join(folder_save, str(k) + '.png') 67 | print(filename_save) 68 | skimage.io.imsave(filename_save, im_vis) -------------------------------------------------------------------------------- /code/models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | -------------------------------------------------------------------------------- /code/models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl']  # must match the class name defined below 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability.
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /code/util/html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import datetime 7 | import dominate 8 | from dominate.tags import * 9 | import os 10 | 11 | 12 | class HTML: 13 | def __init__(self, web_dir, title, refresh=0): 14 | if web_dir.endswith('.html'): 15 | web_dir, html_name = os.path.split(web_dir) 16 | else: 17 | web_dir, html_name = web_dir, 'index.html' 18 | self.title = title 19 | self.web_dir = web_dir 20 | self.html_name = html_name 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 23 | os.makedirs(self.web_dir) 24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 25 | os.makedirs(self.img_dir) 26 | 27 | self.doc = dominate.document(title=title) 28 | with self.doc: 29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 30 | if refresh > 0: 31 | with self.doc.head: 32 | meta(http_equiv="refresh", content=str(refresh)) 33 | 34 | def get_image_dir(self): 35 | return self.img_dir 36 | 37 | def add_header(self, str): 38 | with self.doc: 39 | h3(str) 40 | 41 | def add_table(self, border=1): 42 | self.t = table(border=border, style="table-layout: fixed;") 43 | self.doc.add(self.t) 44 | 45 | def add_images(self, ims, txts, links, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | for im, txt, link in zip(ims, txts, links): 50 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 51 | with p(): 52 | with a(href=os.path.join('images', link)): 53 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 54 | br() 55 | p(txt.encode('utf-8')) 56 | 57 | def save(self): 58 | html_file = os.path.join(self.web_dir, self.html_name) 59 | f = open(html_file, 'wt') 60 | f.write(self.doc.render()) 61 | f.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | html = HTML('web/', 'test_html') 66 | html.add_header('hello world') 67 | 68 | ims = [] 69 | txts = [] 70 | links = [] 71 | for n in range(4): 72 | ims.append('image_%d.jpg' % n) 73 | txts.append('text_%d' % n) 74 | links.append('image_%d.jpg' % n) 75 | html.add_images(ims, txts, links) 76 | html.save() 77 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | from torch.nn.modules.loss import CrossEntropyLoss 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision.utils import make_grid 21 | from tqdm import tqdm 22 | from itertools import cycle 23 | import numpy as np 24 | import cv2 25 | 26 | from dataloaders import utils 27 | from dataloaders.dataset_covid import (CovidDataSets, RandomGenerator) 28 | from networks.net_factory import net_factory 29 | from utils import losses, metrics, ramps 30 | from test_covid import get_model_metric 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--root_path', type=str, default='/home/code/SSL/', help='Name of Experiment') 35 | parser.add_argument('--exp', type=str, default='Test', help='experiment_name') 36 | parser.add_argument('--model', type=str, default='unet2', help='model_name') 37 | 
parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training') 38 | parser.add_argument('--patch_size', type=list, default=[512, 512], help='patch size of network input') 39 | parser.add_argument('--num_classes', type=int, default=2, help='output channel of network') 40 | # label and unlabel 41 | parser.add_argument('--labeled_per', type=float, default=0.1, help='percent of labeled data') 42 | 43 | if False: 44 | parser.add_argument('--dataset_name', type=str, default='COVID249', help='Name of dataset') 45 | parser.add_argument('--model_path', type=str, default='/home/code/SSL/exp/COVID249/model.pth', help='path of teacher model') 46 | else: 47 | parser.add_argument('--dataset_name', type=str, default='MOS1000', help='Name of dataset') 48 | parser.add_argument('--model_path', type=str, default='/home/code/SSL/exp/MOS1000/model.pth', help='path of teacher model') 49 | args = parser.parse_args() 50 | 51 | 52 | 53 | def test(args, snapshot_path): 54 | model = net_factory(net_type=args.model) 55 | model.load_state_dict(torch.load(args.model_path)) 56 | model.eval() 57 | 58 | nsd, dice = get_model_metric(args = args, model = model, snapshot_path=snapshot_path, model_name='model', mode='test') 59 | print('nsd : %f dice : %f ' % (nsd, dice)) 60 | 61 | 62 | 63 | if __name__ == "__main__": 64 | snapshot_path = "{}exp/{}/test_{}_{}_{}".format(args.root_path, args.dataset_name, args.exp, args.labeled_per, args.model) 65 | if not os.path.exists(snapshot_path): 66 | os.makedirs(snapshot_path) 67 | test(args, snapshot_path) -------------------------------------------------------------------------------- /code/util/iter_counter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | 10 | 11 | # Helper class that keeps track of training iterations 12 | class IterationCounter(): 13 | def __init__(self, opt, dataset_size): 14 | self.opt = opt 15 | self.dataset_size = dataset_size 16 | 17 | self.first_epoch = 1 18 | self.total_epochs = opt.niter + opt.niter_decay + 1000 19 | self.epoch_iter = 0 # iter number within each epoch 20 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 21 | if opt.isTrain and opt.continue_train: 22 | try: 23 | self.first_epoch, self.epoch_iter = np.loadtxt( 24 | self.iter_record_path, delimiter=',', dtype=int) 25 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 26 | except: 27 | print('Could not load iteration record at %s. Starting from beginning.' 
% 28 | self.iter_record_path) 29 | 30 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter 31 | 32 | # return the iterator of epochs for the training 33 | def training_epochs(self): 34 | return range(self.first_epoch, self.total_epochs + 1) 35 | 36 | def record_epoch_start(self, epoch): 37 | self.epoch_start_time = time.time() 38 | self.epoch_iter = 0 39 | self.last_iter_time = time.time() 40 | self.current_epoch = epoch 41 | 42 | def record_one_iteration(self): 43 | current_time = time.time() 44 | 45 | # the last remaining batch is dropped (see data/__init__.py), 46 | # so we can assume batch size is always opt.batchSize 47 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 48 | self.last_iter_time = current_time 49 | self.total_steps_so_far += self.opt.batchSize 50 | self.epoch_iter += self.opt.batchSize 51 | 52 | def record_epoch_end(self): 53 | current_time = time.time() 54 | self.time_per_epoch = current_time - self.epoch_start_time 55 | print('End of epoch %d / %d \t Time Taken: %d sec' % 56 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 57 | if self.current_epoch % self.opt.save_epoch_freq == 0: 58 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 59 | delimiter=',', fmt='%d') 60 | print('Saved current iteration count at %s.' % self.iter_record_path) 61 | 62 | def record_current_iter(self): 63 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 64 | delimiter=',', fmt='%d') 65 | print('Saved current iteration count at %s.' % self.iter_record_path) 66 | 67 | def needs_saving(self): 68 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 69 | 70 | def needs_printing(self): 71 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 72 | 73 | def needs_displaying(self): 74 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 75 | -------------------------------------------------------------------------------- /code/models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called before the callback 38 | of any slave copies.
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the module is created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /code/networks/unet2d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the U-Net paper: 4 | Olaf Ronneberger, Philipp Fischer, Thomas Brox: 5 | U-Net: Convolutional Networks for Biomedical Image Segmentation. 6 | MICCAI (3) 2015: 234-241 7 | Note that there are some modifications from the original paper, such as 8 | the use of instance normalization and nearest-neighbour upsampling here.
9 | """ 10 | 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torchvision import models 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | class double_conv(nn.Module): 20 | def __init__(self, in_ch, out_ch): 21 | super(double_conv, self).__init__() 22 | self.conv = nn.Sequential( 23 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), 24 | #nn.BatchNorm2d(out_ch), 25 | nn.InstanceNorm2d(out_ch), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 28 | #nn.BatchNorm2d(out_ch), 29 | nn.InstanceNorm2d(out_ch), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class inconv(nn.Module): 39 | def __init__(self, in_ch, out_ch): 40 | super(inconv, self).__init__() 41 | self.conv = double_conv(in_ch, out_ch) 42 | 43 | def forward(self, x): 44 | x = self.conv(x) 45 | return x 46 | 47 | 48 | class down(nn.Module): 49 | def __init__(self, in_ch, out_ch): 50 | super(down, self).__init__() 51 | self.max_pool_conv = nn.Sequential( 52 | nn.MaxPool2d(2), 53 | double_conv(in_ch, out_ch) 54 | ) 55 | 56 | def forward(self, x): 57 | x = self.max_pool_conv(x) 58 | return x 59 | 60 | 61 | class up(nn.Module): 62 | def __init__(self, in_ch, out_ch, bilinear=True): 63 | super(up, self).__init__() 64 | if bilinear: 65 | #self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 66 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 67 | else: 68 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 69 | 70 | self.conv = double_conv(in_ch, out_ch) 71 | 72 | def forward(self, x1, x2): 73 | x1 = self.up(x1) 74 | diffX = x1.size()[2] - x2.size()[2] 75 | diffY = x1.size()[3] - x2.size()[3] 76 | x2 = F.pad(x2, (diffX // 2, int(diffX / 2), diffY // 2, int(diffY / 2))) 77 | x = torch.cat([x2, x1], dim=1) 78 | x = self.conv(x) 79 | return x 80 | 81 | 82 | class outconv(nn.Module): 83 | def __init__(self, in_ch, out_ch): 84 | super(outconv, self).__init__() 85 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | class UNet2D(nn.Module): 93 | def __init__(self, n_channels=3, n_classes=2): 94 | super(UNet2D, self).__init__() 95 | self.inc = inconv(n_channels, 32) 96 | self.down1 = down(32, 64) 97 | self.down2 = down(64, 128) 98 | self.down3 = down(128, 256) 99 | self.down4 = down(256, 256) 100 | self.up1 = up(512, 128) 101 | self.up2 = up(256, 64) 102 | self.up3 = up(128, 32) 103 | self.up4 = up(64, 32) 104 | self.outc = outconv(32, n_classes) 105 | self.relu = nn.ReLU() 106 | 107 | def forward(self, x): 108 | x1 = self.inc(x) 109 | x2 = self.down1(x1) 110 | x3 = self.down2(x2) 111 | x4 = self.down3(x3) 112 | x5 = self.down4(x4) 113 | x = self.up1(x5, x4) 114 | x = self.up2(x, x3) 115 | x = self.up3(x, x2) 116 | x = self.up4(x, x1) 117 | x = self.outc(x) 118 | #x = self.relu(x) 119 | return x 120 | -------------------------------------------------------------------------------- /code/dataloaders/dataset_covid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import cv2 4 | import torch 5 | import random 6 | import numpy as np 7 | from glob import glob 8 | from torch.utils.data import Dataset 9 | from scipy.ndimage.interpolation import zoom 10 | import itertools 11 | from scipy import ndimage 12 | from torch.utils.data.sampler import Sampler 13 | import pandas as pd 14 | from PIL import Image 15 | 16 | 17 | class 
CovidDataSets(Dataset): 18 | def __init__(self, root_path=None, dataset_name='COVID249', file_name='val_slice.xlsx', aug=False): 19 | self.root_path = root_path 20 | self.file_name = file_name 21 | self.dataset_name = dataset_name 22 | self.file_path = root_path + "data/{}/{}".format(dataset_name, file_name) 23 | self.aug = aug 24 | 25 | 26 | excelData = pd.read_excel(self.file_path) 27 | length = excelData.shape[0] 28 | self.paths = [] 29 | for i in range(length): 30 | file_name_i = excelData.iloc[i][0] 31 | self.paths.append(file_name_i) 32 | 33 | def __len__(self): 34 | return len(self.paths) 35 | 36 | 37 | def __getitem__(self, idx): 38 | case = self.paths[idx] 39 | 40 | case_img_path = self.root_path + "data/{}/PNG/images/{}".format(self.dataset_name, case) 41 | case_label_path = self.root_path + "data/{}/PNG/labels/{}".format(self.dataset_name, case) 42 | case_lung_path = self.root_path + "data/{}/PNG/lung/{}".format(self.dataset_name, case) 43 | 44 | image = Image.open(case_img_path) 45 | 46 | if os.path.exists(case_label_path): 47 | label = Image.open(case_label_path) 48 | else:  # unlabeled slice: fall back to the lung mask as a placeholder label 49 | label = Image.open(case_lung_path) 50 | 51 | lung = Image.open(case_lung_path) 52 | 53 | if self.aug: 54 | if random.random() > 0.5: 55 | image, label, lung = random_rot_flip(image, label, lung) 56 | elif random.random() > 0.5: 57 | image, label, lung = random_rotate(image, label, lung) 58 | 59 | image = (torch.from_numpy(np.asarray(image).astype(np.float32)).permute(2, 0, 1).contiguous())/255.0 60 | label = torch.from_numpy(np.asarray(label).astype(np.uint8)) 61 | lung = torch.from_numpy(np.asarray(lung).astype(np.uint8)) 62 | 63 | return image, label, case, lung 64 | 65 | 66 | def random_rot_flip(image, label, lung=None):  # lung is optional so RandomGenerator below can reuse this for (image, label) pairs 67 | k = np.random.randint(0, 4) 68 | image = np.rot90(image, k) 69 | label = np.rot90(label, k) 70 | lung = np.rot90(lung, k) if lung is not None else None 71 | axis = np.random.randint(0, 2) 72 | image = np.flip(image, axis=axis).copy() 73 | label = np.flip(label, axis=axis).copy() 74 | if lung is None: 75 | return image, label 76 | return image, label, np.flip(lung, axis=axis).copy() 77 | 78 | 79 | def random_rotate(image, label, lung=None): 80 | angle = np.random.randint(-20, 20) 81 | image = ndimage.rotate(image, angle, order=0, reshape=False) 82 | label = ndimage.rotate(label, angle, order=0, reshape=False) 83 | if lung is None: 84 | return image, label 85 | return image, label, ndimage.rotate(lung, angle, order=0, reshape=False) 86 | 87 | 88 | def rotate_90(image, label, lung): 89 | angle = 90 90 | image = ndimage.rotate(image, angle, order=0, reshape=False) 91 | label = ndimage.rotate(label, angle, order=0, reshape=False) 92 | lung = ndimage.rotate(lung, angle, order=0, reshape=False) 93 | return image, label, lung 94 | 95 | def rotate_n90(image, label, lung): 96 | angle = -90 97 | image = ndimage.rotate(image, angle, order=0, reshape=False) 98 | label = ndimage.rotate(label, angle, order=0, reshape=False) 99 | lung = ndimage.rotate(lung, angle, order=0, reshape=False) 100 | return image, label, lung 101 | 102 | 103 | 104 | class RandomGenerator(object): 105 | def __init__(self, output_size): 106 | self.output_size = output_size 107 | 108 | def __call__(self, image, label): 109 | if random.random() > 0.5: 110 | image, label = random_rot_flip(image, label) 111 | elif random.random() > 0.5: 112 | image, label = random_rotate(image, label) 113 | x, y = image.shape 114 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 115 | label = torch.from_numpy(label.astype(np.uint8)) 116 | return image, label 117 | 118 | 119 |
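# Hedged usage sketch (added annotation, not part of the original file). It assumes
# the data layout implied by __init__/__getitem__ above: <root>/data/<dataset>/*.xlsx
# index files plus PNG/images, PNG/labels and PNG/lung folders.
#
#   from torch.utils.data import DataLoader
#   val_set = CovidDataSets(root_path='/home/code/SSL/', dataset_name='COVID249',
#                           file_name='val_slice.xlsx', aug=False)
#   loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=2)
#   image, label, case, lung = next(iter(loader))
#   # image: float tensor (4, 3, H, W) scaled to [0, 1]; label/lung: uint8 masks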
-------------------------------------------------------------------------------- /code/models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | import util.util as util 12 | 13 | 14 | class MultiscaleDiscriminator(BaseNetwork): 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train): 17 | parser.add_argument('--netD_subarch', type=str, default='n_layer', 18 | help='architecture of each discriminator') 19 | parser.add_argument('--num_D', type=int, default=2, 20 | help='number of discriminators to be used in multiscale') 21 | opt, _ = parser.parse_known_args() 22 | 23 | # define properties of each discriminator of the multiscale discriminator 24 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', 25 | 'models.networks.discriminator') 26 | subnetD.modify_commandline_options(parser, is_train) 27 | 28 | return parser 29 | 30 | def __init__(self, opt): 31 | super().__init__() 32 | self.opt = opt 33 | 34 | for i in range(opt.num_D): 35 | subnetD = self.create_single_discriminator(opt) 36 | self.add_module('discriminator_%d' % i, subnetD) 37 | 38 | def create_single_discriminator(self, opt): 39 | subarch = opt.netD_subarch 40 | if subarch == 'n_layer': 41 | netD = NLayerDiscriminator(opt) 42 | else: 43 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch) 44 | return netD 45 | 46 | def downsample(self, input): 47 | return F.avg_pool2d(input, kernel_size=3, 48 | stride=2, padding=[1, 1], 49 | count_include_pad=False) 50 | 51 | # Returns list of lists of discriminator outputs. 52 | # The final result is of size opt.num_D x opt.n_layers_D 53 | def forward(self, input): 54 | result = [] 55 | get_intermediate_features = not self.opt.no_ganFeat_loss 56 | for name, D in self.named_children(): 57 | out = D(input) 58 | if not get_intermediate_features: 59 | out = [out] 60 | result.append(out) 61 | input = self.downsample(input) 62 | 63 | return result 64 | 65 | 66 | # Defines the PatchGAN discriminator with the specified arguments. 
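# Added note (illustrative arithmetic; the concrete values are assumptions, not
# this repo's settings): compute_D_input_nc below sums opt.label_nc + opt.output_nc,
# plus 1 for the don't-care label and 1 for the instance edge map when enabled.
# E.g. label_nc=182 and output_nc=3 with both extras give
# 182 + 3 + 1 + 1 = 187 discriminator input channels.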
67 | class NLayerDiscriminator(BaseNetwork): 68 | @staticmethod 69 | def modify_commandline_options(parser, is_train): 70 | parser.add_argument('--n_layers_D', type=int, default=3, 71 | help='# layers in each discriminator') 72 | return parser 73 | 74 | def __init__(self, opt): 75 | super().__init__() 76 | self.opt = opt 77 | 78 | kw = 4 79 | padw = int(np.ceil((kw - 1.0) / 2)) 80 | nf = opt.ndf 81 | input_nc = self.compute_D_input_nc(opt) 82 | 83 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D) 84 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 85 | nn.LeakyReLU(0.2, False)]] 86 | 87 | for n in range(1, opt.n_layers_D): 88 | nf_prev = nf 89 | nf = min(nf * 2, 512) 90 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 91 | stride=2, padding=padw)), 92 | nn.LeakyReLU(0.2, False) 93 | ]] 94 | 95 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 96 | 97 | # We divide the layers into groups to extract intermediate layer outputs 98 | for n in range(len(sequence)): 99 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 100 | 101 | def compute_D_input_nc(self, opt): 102 | input_nc = opt.label_nc + opt.output_nc 103 | if opt.contain_dontcare_label: 104 | input_nc += 1 105 | if not opt.no_instance: 106 | input_nc += 1 107 | return input_nc 108 | 109 | def forward(self, input): 110 | results = [input] 111 | for submodel in self.children(): 112 | intermediate_output = submodel(results[-1]) 113 | results.append(intermediate_output) 114 | 115 | get_intermediate_features = not self.opt.no_ganFeat_loss 116 | if get_intermediate_features: 117 | return results[1:] 118 | else: 119 | return results[-1] 120 | -------------------------------------------------------------------------------- /code/models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object.
58 | 59 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /code/models/networks/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | from models.networks.architecture import ResnetBlock as ResnetBlock 12 | from models.networks.architecture import SPADEResnetBlock as SPADEResnetBlock 13 | from models.networks.architecture import Zencoder 14 | import random 15 | 16 | class SPADEGenerator(BaseNetwork): 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | parser.set_defaults(norm_G='spectralspadesyncbatch3x3') 20 | parser.add_argument('--num_upsampling_layers', 21 | choices=('normal', 'more', 'most'), default='normal', 22 | help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator") 23 | 24 | return parser 25 | 26 | def __init__(self, opt): 27 | super().__init__() 28 | self.opt = opt 29 | nf = opt.ngf 30 | 31 | self.sw, self.sh = self.compute_latent_vector_size(opt) 32 | 33 | self.Zencoder = Zencoder(3, 512) 34 | 35 | 36 | self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1) 37 | 38 | self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='head_0') 39 | 40 | self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_0') 41 | self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt, Block_Name='G_middle_1') 42 | 43 | self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt, Block_Name='up_0') 44 | self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt, Block_Name='up_1') 45 | self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt, Block_Name='up_2') 46 | self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt, Block_Name='up_3', use_rgb=False) 47 | 48 | final_nc = nf 49 | 50 | if opt.num_upsampling_layers == 'most': 51 | self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt, Block_Name='up_4') 52 | final_nc = nf // 2 53 | 54 | self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) 55 | 56 | self.up = nn.Upsample(scale_factor=2) 57 | #self.up = nn.Upsample(scale_factor=2, mode='bilinear') 58 | 59 | 60 | def compute_latent_vector_size(self, opt): 61 | if opt.num_upsampling_layers == 'normal': 62 | num_up_layers = 5 63 | elif opt.num_upsampling_layers == 'more': 64 | num_up_layers = 6 65 | elif opt.num_upsampling_layers == 'most': 66 | num_up_layers = 7 67 | else: 68 | raise ValueError('opt.num_upsampling_layers [%s] not recognized' % 69 | opt.num_upsampling_layers) 70 | 71 | sw = opt.crop_size // (2**num_up_layers) 72 | sh = round(sw / opt.aspect_ratio) 73 | 74 | return sw, sh 75 | 76 | def forward(self, input, rgb_img, obj_dic=None, return_style = False, style_input=None, alpha=0): 77 | seg = input 78 | 79 | x = F.interpolate(seg, size=(self.sh, self.sw)) 80 | x = self.fc(x) 81 | 82 | style_codes = self.Zencoder(input=rgb_img, segmap=seg) 83 | 84 | #---------------------------------------------------------------------------------------------------------- 85 | if return_style: 86 | return style_codes 87 | 88 | 89 | if style_input is not None and len(style_input)>0: 90 | extra_codes = torch.mean(torch.cat(style_input, 0), 0, keepdim=True) 91 | style_codes = alpha*style_codes + (1-alpha)*extra_codes 92 | 93 | #---------------------------------------------------------------------------------------------------------- 94 | 95 | 96 | x = self.head_0(x, seg, style_codes, obj_dic=obj_dic) 97 | 98 | x = self.up(x) 99 | x = self.G_middle_0(x, seg, style_codes, obj_dic=obj_dic) 100 
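# Added note (illustrative, assuming opt.crop_size=256 and opt.aspect_ratio=1):
# with num_upsampling_layers == 'normal', compute_latent_vector_size() gives
# num_up_layers = 5, so sw = 256 // 2**5 = 8, and the five up() calls in this
# forward pass double the 8x8 grid back to 256x256.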
| 101 | if self.opt.num_upsampling_layers == 'more' or \ 102 | self.opt.num_upsampling_layers == 'most': 103 | x = self.up(x) 104 | 105 | x = self.G_middle_1(x, seg, style_codes, obj_dic=obj_dic) 106 | 107 | x = self.up(x) 108 | x = self.up_0(x, seg, style_codes, obj_dic=obj_dic) 109 | x = self.up(x) 110 | x = self.up_1(x, seg, style_codes, obj_dic=obj_dic) 111 | x = self.up(x) 112 | x = self.up_2(x, seg, style_codes, obj_dic=obj_dic) 113 | x = self.up(x) 114 | x = self.up_3(x, seg, style_codes, obj_dic=obj_dic) 115 | 116 | # if self.opt.num_upsampling_layers == 'most': 117 | # x = self.up(x) 118 | # x= self.up_4(x, seg, style_codes, obj_dic=obj_dic) 119 | 120 | x = self.conv_img(F.leaky_relu(x, 2e-1)) 121 | x = F.tanh(x) 122 | return x 123 | -------------------------------------------------------------------------------- /code/models/networks/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from models.networks.architecture import VGG19 10 | 11 | 12 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 13 | # When LSGAN is used, it is basically same as MSELoss, 14 | # but it abstracts away the need to create the target label tensor 15 | # that has the same size as the input 16 | class GANLoss(nn.Module): 17 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 18 | tensor=torch.FloatTensor, opt=None): 19 | super(GANLoss, self).__init__() 20 | self.real_label = target_real_label 21 | self.fake_label = target_fake_label 22 | self.real_label_tensor = None 23 | self.fake_label_tensor = None 24 | self.zero_tensor = None 25 | self.Tensor = tensor 26 | self.gan_mode = gan_mode 27 | self.opt = opt 28 | if gan_mode == 'ls': 29 | pass 30 | elif gan_mode == 'original': 31 | pass 32 | elif gan_mode == 'w': 33 | pass 34 | elif gan_mode == 'hinge': 35 | pass 36 | else: 37 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 38 | 39 | def get_target_tensor(self, input, target_is_real): 40 | if target_is_real: 41 | if self.real_label_tensor is None: 42 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 43 | self.real_label_tensor.requires_grad_(False) 44 | return self.real_label_tensor.expand_as(input) 45 | else: 46 | if self.fake_label_tensor is None: 47 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 48 | self.fake_label_tensor.requires_grad_(False) 49 | return self.fake_label_tensor.expand_as(input) 50 | 51 | def get_zero_tensor(self, input): 52 | if self.zero_tensor is None: 53 | self.zero_tensor = self.Tensor(1).fill_(0) 54 | self.zero_tensor.requires_grad_(False) 55 | return self.zero_tensor.expand_as(input) 56 | 57 | def loss(self, input, target_is_real, for_discriminator=True): 58 | if self.gan_mode == 'original': # cross entropy loss 59 | target_tensor = self.get_target_tensor(input, target_is_real) 60 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 61 | return loss 62 | elif self.gan_mode == 'ls': 63 | target_tensor = self.get_target_tensor(input, target_is_real) 64 | return F.mse_loss(input, target_tensor) 65 | elif self.gan_mode == 'hinge': 66 | if for_discriminator: 67 | if target_is_real: 68 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 69 | loss = -torch.mean(minval) 
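# Added note: the two discriminator branches implement the standard hinge GAN loss,
#   L_D = E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))],
# since -mean(min(input - 1, 0)) == mean(max(0, 1 - input)) and
# -mean(min(-input - 1, 0)) == mean(max(0, 1 + input)); the generator branch
# below uses L_G = -E[D(fake)].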
70 | else: 71 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 72 | loss = -torch.mean(minval) 73 | else: 74 | assert target_is_real, "The generator's hinge loss must be aiming for real" 75 | loss = -torch.mean(input) 76 | return loss 77 | else: 78 | # wgan 79 | if target_is_real: 80 | return -input.mean() 81 | else: 82 | return input.mean() 83 | 84 | def __call__(self, input, target_is_real, for_discriminator=True): 85 | # computing the loss is a bit complicated because |input| may not be 86 | # a tensor but a list of tensors, in the case of a multiscale discriminator 87 | if isinstance(input, list): 88 | loss = 0 89 | for pred_i in input: 90 | if isinstance(pred_i, list): 91 | pred_i = pred_i[-1] 92 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 93 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 94 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 95 | loss += new_loss 96 | return loss / len(input) 97 | else: 98 | return self.loss(input, target_is_real, for_discriminator) 99 | 100 | 101 | # Perceptual loss that uses a pretrained VGG network 102 | class VGGLoss(nn.Module): 103 | def __init__(self, gpu_ids): 104 | super(VGGLoss, self).__init__() 105 | self.vgg = VGG19().cuda() 106 | self.criterion = nn.L1Loss() 107 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 108 | 109 | def forward(self, x, y): 110 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 111 | loss = 0 112 | for i in range(len(x_vgg)): 113 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 114 | return loss 115 | 116 | 117 | # KL Divergence loss used in VAE with an image encoder 118 | # class KLDLoss(nn.Module): 119 | # def forward(self, mu, logvar): 120 | # return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 121 | -------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import pickle 9 | import numpy as np 10 | from scipy.ndimage import distance_transform_edt as distance 11 | from skimage import segmentation as skimage_seg 12 | import torch 13 | from torch.utils.data.sampler import Sampler 14 | 15 | import networks 16 | 17 | def load_model(path): 18 | """Loads a model and returns it without the DataParallel table.""" 19 | if os.path.isfile(path): 20 | print("=> loading checkpoint '{}'".format(path)) 21 | checkpoint = torch.load(path) 22 | 23 | # size of the top layer 24 | N = checkpoint['state_dict']['top_layer.bias'].size() 25 | 26 | # build skeleton of the model 27 | sob = 'sobel.0.weight' in checkpoint['state_dict'].keys() 28 | model = networks.__dict__[checkpoint['arch']](sobel=sob, out=int(N[0]))  # `networks` is the module imported above 29 | 30 | # deal with a dataparallel table 31 | def rename_key(key): 32 | if 'module' not in key: 33 | return key 34 | return ''.join(key.split('.module')) 35 | 36 | checkpoint['state_dict'] = {rename_key(key): val 37 | for key, val 38 | in checkpoint['state_dict'].items()} 39 | 40 | # load weights 41 | model.load_state_dict(checkpoint['state_dict']) 42 | print("Loaded") 43 | else: 44 | model = None 45 | print("=> no checkpoint found at '{}'".format(path)) 46 | return model 47 | 48 | 49 | class UnifLabelSampler(Sampler): 50 | """Samples elements uniformly across pseudolabels.
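Example (illustrative numbers, not from the original): with N=4 and
    images_lists={0: [1, 2], 1: [3, 4, 5]}, generate_indexes_epoch() draws
    size_per_pseudolabel = int(4 / 2) + 1 = 3 indices per pseudolabel, shuffles
    the 6 candidates, and truncates the result back to N=4.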
51 | Args: 52 | N (int): size of returned iterator. 53 | images_lists: dict of key (target), value (list of data with this target) 54 | """ 55 | 56 | def __init__(self, N, images_lists): 57 | self.N = N 58 | self.images_lists = images_lists 59 | self.indexes = self.generate_indexes_epoch() 60 | 61 | def generate_indexes_epoch(self): 62 | size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1 63 | res = np.zeros(size_per_pseudolabel * len(self.images_lists)) 64 | 65 | for i in range(len(self.images_lists)): 66 | indexes = np.random.choice( 67 | self.images_lists[i], 68 | size_per_pseudolabel, 69 | replace=(len(self.images_lists[i]) <= size_per_pseudolabel) 70 | ) 71 | res[i * size_per_pseudolabel: (i + 1) * size_per_pseudolabel] = indexes 72 | 73 | np.random.shuffle(res) 74 | return res[:self.N].astype('int') 75 | 76 | def __iter__(self): 77 | return iter(self.indexes) 78 | 79 | def __len__(self): 80 | return self.N 81 | 82 | 83 | class AverageMeter(object): 84 | """Computes and stores the average and current value""" 85 | def __init__(self): 86 | self.reset() 87 | 88 | def reset(self): 89 | self.val = 0 90 | self.avg = 0 91 | self.sum = 0 92 | self.count = 0 93 | 94 | def update(self, val, n=1): 95 | self.val = val 96 | self.sum += val * n 97 | self.count += n 98 | self.avg = self.sum / self.count 99 | 100 | 101 | def learning_rate_decay(optimizer, t, lr_0): 102 | for param_group in optimizer.param_groups: 103 | lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t) 104 | param_group['lr'] = lr 105 | 106 | 107 | class Logger(): 108 | """ Class to update every epoch to keep track of the results 109 | Methods: 110 | - log() log and save 111 | """ 112 | 113 | def __init__(self, path): 114 | self.path = path 115 | self.data = [] 116 | 117 | def log(self, train_point): 118 | self.data.append(train_point) 119 | with open(os.path.join(self.path), 'wb') as fp: 120 | pickle.dump(self.data, fp, -1) 121 | 122 | 123 | def compute_sdf(img_gt, out_shape): 124 | """ 125 | compute the signed distance map of a binary mask 126 | input: segmentation, shape = (batch_size, x, y, z) 127 | output: the Signed Distance Map (SDM), where for boundary points y: 128 | sdf(x) = 0 if x is on the segmentation boundary 129 | sdf(x) = -min_y ||x - y|| if x is inside the segmentation 130 | sdf(x) = +min_y ||x - y|| if x is outside the segmentation 131 | the sdf is then normalized to [-1, 1] 132 | """ 133 | 134 | img_gt = img_gt.astype(np.uint8) 135 | normalized_sdf = np.zeros(out_shape) 136 | 137 | for b in range(out_shape[0]): # batch size 138 | posmask = img_gt[b].astype(bool)  # np.bool was removed from NumPy; use the builtin bool 139 | if posmask.any(): 140 | negmask = ~posmask 141 | posdis = distance(posmask) 142 | negdis = distance(negmask) 143 | boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8) 144 | sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis)) 145 | sdf[boundary==1] = 0 146 | normalized_sdf[b] = sdf 147 | # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis)) 148 | # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis)) 149 | 150 | return normalized_sdf -------------------------------------------------------------------------------- /code/dataloaders/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | import numpy as np 6 | from glob import glob 7 | from torch.utils.data import Dataset 8 | import h5py 9 | from scipy.ndimage.interpolation import zoom 10 | import
itertools 11 | from scipy import ndimage 12 | from torch.utils.data.sampler import Sampler 13 | 14 | 15 | class BaseDataSets(Dataset): 16 | def __init__(self, base_dir=None, split='train', num=None, transform=None): 17 | self._base_dir = base_dir 18 | self.sample_list = [] 19 | self.split = split 20 | self.transform = transform 21 | if self.split == 'train': 22 | with open(self._base_dir + '/train_slices.list', 'r') as f1: 23 | self.sample_list = f1.readlines() 24 | self.sample_list = [item.replace('\n', '') 25 | for item in self.sample_list] 26 | 27 | elif self.split == 'val': 28 | with open(self._base_dir + '/val.list', 'r') as f: 29 | self.sample_list = f.readlines() 30 | self.sample_list = [item.replace('\n', '') 31 | for item in self.sample_list] 32 | if num is not None and self.split == "train": 33 | self.sample_list = self.sample_list[:num] 34 | print("total {} samples".format(len(self.sample_list))) 35 | 36 | def __len__(self): 37 | return len(self.sample_list) 38 | 39 | def __getitem__(self, idx): 40 | case = self.sample_list[idx] 41 | if self.split == "train": 42 | h5f = h5py.File(self._base_dir + 43 | "/data/slices/{}.h5".format(case), 'r') 44 | else: 45 | h5f = h5py.File(self._base_dir + "/data/{}.h5".format(case), 'r') 46 | image = h5f['image'][:] 47 | label = h5f['label'][:] 48 | sample = {'image': image, 'label': label} 49 | if self.split == "train": 50 | sample = self.transform(sample) 51 | sample["idx"] = idx 52 | return sample 53 | 54 | 55 | def random_rot_flip(image, label): 56 | k = np.random.randint(0, 4) 57 | image = np.rot90(image, k) 58 | label = np.rot90(label, k) 59 | axis = np.random.randint(0, 2) 60 | image = np.flip(image, axis=axis).copy() 61 | label = np.flip(label, axis=axis).copy() 62 | return image, label 63 | 64 | 65 | def random_rotate(image, label): 66 | angle = np.random.randint(-20, 20) 67 | image = ndimage.rotate(image, angle, order=0, reshape=False) 68 | label = ndimage.rotate(label, angle, order=0, reshape=False) 69 | return image, label 70 | 71 | 72 | class RandomGenerator(object): 73 | def __init__(self, output_size): 74 | self.output_size = output_size 75 | 76 | def __call__(self, sample): 77 | image, label = sample['image'], sample['label'] 78 | # ind = random.randrange(0, img.shape[0]) 79 | # image = img[ind, ...] 80 | # label = lab[ind, ...] 81 | if random.random() > 0.5: 82 | image, label = random_rot_flip(image, label) 83 | elif random.random() > 0.5: 84 | image, label = random_rotate(image, label) 85 | x, y = image.shape 86 | image = zoom( 87 | image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 88 | label = zoom( 89 | label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 90 | image = torch.from_numpy( 91 | image.astype(np.float32)).unsqueeze(0) 92 | label = torch.from_numpy(label.astype(np.uint8)) 93 | sample = {'image': image, 'label': label} 94 | return sample 95 | 96 | 97 | class TwoStreamBatchSampler(Sampler): 98 | """Iterate two sets of indices 99 | 100 | An 'epoch' is one iteration through the primary indices. 101 | During the epoch, the secondary indices are iterated through 102 | as many times as needed. 
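Example (illustrative numbers, not from the original): with batch_size=4 and
    secondary_batch_size=2, primary_batch_size = 4 - 2 = 2, so every batch mixes
    2 primary (labeled) indices with 2 secondary (unlabeled) indices, the usual
    labeled/unlabeled pairing in semi-supervised training.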
103 | """ 104 | 105 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 106 | self.primary_indices = primary_indices 107 | self.secondary_indices = secondary_indices 108 | self.secondary_batch_size = secondary_batch_size 109 | self.primary_batch_size = batch_size - secondary_batch_size 110 | 111 | assert len(self.primary_indices) >= self.primary_batch_size > 0 112 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 113 | 114 | def __iter__(self): 115 | primary_iter = iterate_once(self.primary_indices) 116 | secondary_iter = iterate_eternally(self.secondary_indices) 117 | return ( 118 | primary_batch + secondary_batch 119 | for (primary_batch, secondary_batch) 120 | in zip(grouper(primary_iter, self.primary_batch_size), 121 | grouper(secondary_iter, self.secondary_batch_size)) 122 | ) 123 | 124 | def __len__(self): 125 | return len(self.primary_indices) // self.primary_batch_size 126 | 127 | 128 | def iterate_once(iterable): 129 | return np.random.permutation(iterable) 130 | 131 | 132 | def iterate_eternally(indices): 133 | def infinite_shuffles(): 134 | while True: 135 | yield np.random.permutation(indices) 136 | return itertools.chain.from_iterable(infinite_shuffles()) 137 | 138 | 139 | def grouper(iterable, n): 140 | "Collect data into fixed-length chunks or blocks" 141 | # grouper('ABCDEFG', 3) --> ABC DEF (the leftover 'G' is dropped) 142 | args = [iter(iterable)] * n 143 | return zip(*args) 144 | -------------------------------------------------------------------------------- /code/util/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """ 5 | 6 | 7 | def id2label(id): 8 | if id == 182: 9 | id = 0 10 | else: 11 | id = id + 1 12 | labelmap = \ 13 | {0: 'unlabeled', 14 | 1: 'person', 15 | 2: 'bicycle', 16 | 3: 'car', 17 | 4: 'motorcycle', 18 | 5: 'airplane', 19 | 6: 'bus', 20 | 7: 'train', 21 | 8: 'truck', 22 | 9: 'boat', 23 | 10: 'traffic light', 24 | 11: 'fire hydrant', 25 | 12: 'street sign', 26 | 13: 'stop sign', 27 | 14: 'parking meter', 28 | 15: 'bench', 29 | 16: 'bird', 30 | 17: 'cat', 31 | 18: 'dog', 32 | 19: 'horse', 33 | 20: 'sheep', 34 | 21: 'cow', 35 | 22: 'elephant', 36 | 23: 'bear', 37 | 24: 'zebra', 38 | 25: 'giraffe', 39 | 26: 'hat', 40 | 27: 'backpack', 41 | 28: 'umbrella', 42 | 29: 'shoe', 43 | 30: 'eye glasses', 44 | 31: 'handbag', 45 | 32: 'tie', 46 | 33: 'suitcase', 47 | 34: 'frisbee', 48 | 35: 'skis', 49 | 36: 'snowboard', 50 | 37: 'sports ball', 51 | 38: 'kite', 52 | 39: 'baseball bat', 53 | 40: 'baseball glove', 54 | 41: 'skateboard', 55 | 42: 'surfboard', 56 | 43: 'tennis racket', 57 | 44: 'bottle', 58 | 45: 'plate', 59 | 46: 'wine glass', 60 | 47: 'cup', 61 | 48: 'fork', 62 | 49: 'knife', 63 | 50: 'spoon', 64 | 51: 'bowl', 65 | 52: 'banana', 66 | 53: 'apple', 67 | 54: 'sandwich', 68 | 55: 'orange', 69 | 56: 'broccoli', 70 | 57: 'carrot', 71 | 58: 'hot dog', 72 | 59: 'pizza', 73 | 60: 'donut', 74 | 61: 'cake', 75 | 62: 'chair', 76 | 63: 'couch', 77 | 64: 'potted plant', 78 | 65: 'bed', 79 | 66: 'mirror', 80 | 67: 'dining table', 81 | 68: 'window', 82 | 69: 'desk', 83 | 70: 'toilet', 84 | 71: 'door', 85 | 72: 'tv', 86 | 73: 'laptop', 87 | 74: 'mouse', 88 | 75: 'remote', 89 | 76: 'keyboard', 90 | 77: 'cell phone', 91 | 78: 'microwave', 92 | 79: 'oven', 93 | 80: 'toaster', 94 | 81: 'sink', 95 | 82: 'refrigerator', 96 | 83: 'blender', 97 | 84: 'book', 98 | 85: 'clock', 99 | 86: 'vase', 100 | 87: 'scissors', 101 | 88: 'teddy bear', 102 | 89: 'hair drier', 103 | 90: 'toothbrush', 104 | 91: 'hair brush', # Last class of Thing 105 | 92: 'banner', # Beginning of Stuff 106 | 93: 'blanket', 107 | 94: 'branch', 108 | 95: 'bridge', 109 | 96: 'building-other', 110 | 97: 'bush', 111 | 98: 'cabinet', 112 | 99: 'cage', 113 | 100: 'cardboard', 114 | 101: 'carpet', 115 | 102: 'ceiling-other', 116 | 103: 'ceiling-tile', 117 | 104: 'cloth', 118 | 105: 'clothes', 119 | 106: 'clouds', 120 | 107: 'counter', 121 | 108: 'cupboard', 122 | 109: 'curtain', 123 | 110: 'desk-stuff', 124 | 111: 'dirt', 125 | 112: 'door-stuff', 126 | 113: 'fence', 127 | 114: 'floor-marble', 128 | 115: 'floor-other', 129 | 116: 'floor-stone', 130 | 117: 'floor-tile', 131 | 118: 'floor-wood', 132 | 119: 'flower', 133 | 120: 'fog', 134 | 121: 'food-other', 135 | 122: 'fruit', 136 | 123: 'furniture-other', 137 | 124: 'grass', 138 | 125: 'gravel', 139 | 126: 'ground-other', 140 | 127: 'hill', 141 | 128: 'house', 142 | 129: 'leaves', 143 | 130: 'light', 144 | 131: 'mat', 145 | 132: 'metal', 146 | 133: 'mirror-stuff', 147 | 134: 'moss', 148 | 135: 'mountain', 149 | 136: 'mud', 150 | 137: 'napkin', 151 | 138: 'net', 152 | 139: 'paper', 153 | 140: 'pavement', 154 | 141: 'pillow', 155 | 142: 'plant-other', 156 | 143: 'plastic', 157 | 144: 'platform', 158 | 145: 'playingfield', 159 | 146: 'railing', 160 | 147: 'railroad', 161 | 148: 'river', 162 | 149: 'road', 163 | 150: 'rock', 164 | 151: 'roof', 165 | 152: 'rug', 166 | 153: 'salad', 167 | 154: 'sand', 168 | 155: 'sea', 169 | 156: 'shelf', 170 | 157: 'sky-other', 171 | 158: 'skyscraper', 172 | 159: 'snow', 173 | 160: 'solid-other', 174 | 161: 'stairs', 175 | 162: 'stone', 176 | 163: 'straw', 177 | 164: 
'structural-other',
178 |      165: 'table',
179 |      166: 'tent',
180 |      167: 'textile-other',
181 |      168: 'towel',
182 |      169: 'tree',
183 |      170: 'vegetable',
184 |      171: 'wall-brick',
185 |      172: 'wall-concrete',
186 |      173: 'wall-other',
187 |      174: 'wall-panel',
188 |      175: 'wall-stone',
189 |      176: 'wall-tile',
190 |      177: 'wall-wood',
191 |      178: 'water-other',
192 |      179: 'waterdrops',
193 |      180: 'window-blind',
194 |      181: 'window-other',
195 |      182: 'wood'}
196 |     if id in labelmap:
197 |         return labelmap[id]
198 |     else:
199 |         return 'unknown'
200 | 
--------------------------------------------------------------------------------
/code/test_covid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | 
5 | import h5py
6 | import nibabel as nib
7 | import numpy as np
8 | import SimpleITK as sitk
9 | import torch
10 | from medpy import metric
11 | from scipy.ndimage import zoom
12 | # from scipy.ndimage.interpolation import zoom  # duplicate of the import above (deprecated alias)
13 | from tqdm import tqdm
14 | 
15 | # from networks.efficientunet import UNet
16 | from networks.net_factory import net_factory
17 | from dataloaders.dataset_covid import (CovidDataSets, RandomGenerator)
18 | from torch.utils.data import DataLoader
19 | import cv2
20 | import pandas as pd
21 | from utils.distance_metric import compute_surface_distances, compute_surface_dice_at_tolerance, compute_dice_coefficient, compute_robust_hausdorff
22 | 
23 | 
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--root_path', type=str, default='/home/code/SSL/', help='Name of Experiment')
26 | parser.add_argument('--dataset_name', type=str, default='COVID249', help='Name of dataset')
27 | parser.add_argument('--exp', type=str, default='Cross_Pseudo_Supervision', help='experiment_name')
28 | parser.add_argument('--model', type=str, default='unet', help='model_name')
29 | 
30 | 
31 | def save_sample_png(png_results_path, file_name, out, rot=False):
32 |     split_list = file_name.split('_')
33 |     if len(split_list) > 2:
34 |         volume_num = split_list[0] + '_' + split_list[1]
35 |         save_img_name = split_list[2]
36 |     else:
37 |         volume_num = split_list[0]
38 |         save_img_name = split_list[1]
39 | 
40 |     volume_path = os.path.join(png_results_path, volume_num)
41 |     if not os.path.exists(volume_path):
42 |         os.makedirs(volume_path)
43 |     img = out * 255
44 |     img = np.clip(img, 0, 255)
45 |     img = img.astype(np.uint8)
46 |     #cv2.imshow('result', th2)
47 |     #cv2.waitKey(0)
48 |     #------------------------------------------------------------------------
49 |     # Save to certain path
50 |     save_img_path = os.path.join(volume_path, save_img_name)
51 |     cv2.imwrite(save_img_path, img)
52 | 
53 | 
54 | def pngs_2_niigz(args, png_results_path, nii_results_path, file_volume_name):
55 |     volume_files = pd.read_excel(args.root_path + "data/{}/{}".format(args.dataset_name, file_volume_name))
56 |     length = volume_files.shape[0]
57 |     for idx in range(length):
58 |         volume_file = volume_files.iloc[idx][0]
59 |         # load original nii
60 |         ori_path = args.root_path + "data/{}/NII/{}".format(args.dataset_name, volume_file+'_ct.nii.gz')
61 |         ori_nii = sitk.ReadImage(ori_path, sitk.sitkUInt8)
62 |         ori_data = sitk.GetArrayFromImage(ori_nii)
63 | 
64 |         volume_png_folder = os.path.join(png_results_path, volume_file)
65 |         if os.path.exists(volume_png_folder):
66 |             png_files = os.listdir(volume_png_folder)
67 |             for png_file in png_files:
68 |                 png_file_slice = int(png_file.split('.')[0])
69 |                 png_file_data = cv2.imread(os.path.join(volume_png_folder, png_file), -1)
70
| ori_data_slice = ori_data[png_file_slice, :,:] 71 | ori_data_slice = 0*ori_data_slice 72 | ori_data_slice[png_file_data==255]=1 73 | ori_data[png_file_slice, :,:] = ori_data_slice 74 | 75 | #save nii 76 | out_path = os.path.join(nii_results_path, volume_file+'.nii.gz') 77 | img_new = sitk.GetImageFromArray(ori_data) 78 | img_new.CopyInformation(ori_nii) 79 | sitk.WriteImage(img_new, out_path) 80 | print(volume_file) 81 | 82 | 83 | def evaluate_nii(args, nii_results_path, file_volume_name): 84 | nsd_sum = 0 85 | dice_sum = 0 86 | hau_sum = 0 87 | volume_files = pd.read_excel(args.root_path + "data/{}/{}".format(args.dataset_name, file_volume_name)) 88 | length = volume_files.shape[0] 89 | for idx in range(length): 90 | volume_file = volume_files.iloc[idx][0] 91 | # load gt nii 92 | gt_path= args.root_path + "data/{}/NII/{}".format(args.dataset_name, volume_file+'_seg.nii.gz') 93 | gt_nii = nib.load(gt_path) 94 | gt_data = np.uint8(gt_nii.get_fdata()) 95 | 96 | pred_path= nii_results_path + volume_file+'.nii.gz' 97 | pred_nii = nib.load(pred_path) 98 | pred_data = np.uint8(pred_nii.get_fdata()) 99 | 100 | spacing = gt_nii.header.get_zooms() 101 | 102 | surface_distances = compute_surface_distances(gt_data, pred_data, spacing_mm=spacing) 103 | nsd = compute_surface_dice_at_tolerance(surface_distances, 1) 104 | dice = compute_dice_coefficient(gt_data, pred_data) 105 | print(nsd) 106 | print(dice) 107 | nsd_sum += nsd 108 | dice_sum +=dice 109 | mean_nsd = nsd_sum/length 110 | mean_dice = dice_sum/length 111 | 112 | return mean_nsd, mean_dice 113 | 114 | 115 | 116 | 117 | def get_model_metric(args, model, snapshot_path, model_name, mode='test'): 118 | model.eval() 119 | 120 | file_slice_name = '{}_slice.xlsx'.format(mode) 121 | file_volume_name = '{}_volume.xlsx'.format(mode) 122 | val_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = file_slice_name) 123 | print('The overall number of validation images equals to %d' % len(val_dataset)) 124 | val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1) 125 | 126 | png_results_path = os.path.join(snapshot_path, '{}_png/'.format(model_name)) 127 | if os.path.isdir(png_results_path) is False: 128 | os.mkdir(png_results_path) 129 | 130 | for batch_idx, (image, label, file_name, _) in enumerate(val_dataloader): 131 | image = image.cuda() 132 | label = label.cuda() 133 | 134 | with torch.no_grad(): 135 | out_main = model(image) 136 | out = torch.argmax(torch.softmax(out_main, dim=1), dim=1).squeeze(0) 137 | out = out.cpu().detach().numpy() 138 | save_sample_png(png_results_path, file_name = file_name[0], out=out) 139 | 140 | # png results to nii.gz label 141 | nii_results_path = os.path.join(snapshot_path, '{}_nii/'.format(model_name)) 142 | if os.path.isdir(nii_results_path) is False: 143 | os.mkdir(nii_results_path) 144 | pngs_2_niigz(args= args, png_results_path = png_results_path, nii_results_path=nii_results_path, file_volume_name = file_volume_name) 145 | # evaluate result 146 | nsd, dice = evaluate_nii(args = args, nii_results_path=nii_results_path, file_volume_name = file_volume_name) 147 | return nsd, dice 148 | 149 | 150 | if __name__ == '__main__': 151 | args = parser.parse_args() 152 | 153 | -------------------------------------------------------------------------------- /code/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import 
torch.nn as nn
5 | from torch.autograd import Variable
6 | 
7 | 
8 | def dice_loss(score, target):
9 |     target = target.float()
10 |     smooth = 1e-5
11 |     intersect = torch.sum(score * target)
12 |     y_sum = torch.sum(target * target)
13 |     z_sum = torch.sum(score * score)
14 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
15 |     loss = 1 - loss
16 |     return loss
17 | 
18 | 
19 | def dice_loss1(score, target):
20 |     target = target.float()
21 |     smooth = 1e-5
22 |     intersect = torch.sum(score * target)
23 |     y_sum = torch.sum(target)
24 |     z_sum = torch.sum(score)
25 |     loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
26 |     loss = 1 - loss
27 |     return loss
28 | 
29 | 
30 | def entropy_loss(p, C=2):
31 |     # p N*C*W*H*D
32 |     y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / \
33 |         torch.tensor(np.log(C)).cuda()
34 |     ent = torch.mean(y1)
35 | 
36 |     return ent
37 | 
38 | 
39 | def softmax_dice_loss(input_logits, target_logits):
40 |     """Takes softmax on both sides and returns the channel-averaged Dice loss
41 | 
42 |     Note:
43 |     - Returns a scalar that is already averaged over the channels; no further
44 |       reduction is needed.
45 |     - Detach the targets beforehand if gradients should not flow into them.
46 |     """
47 |     assert input_logits.size() == target_logits.size()
48 |     input_softmax = F.softmax(input_logits, dim=1)
49 |     target_softmax = F.softmax(target_logits, dim=1)
50 |     n = input_logits.shape[1]
51 |     dice = 0
52 |     for i in range(0, n):
53 |         dice += dice_loss1(input_softmax[:, i], target_softmax[:, i])
54 |     mean_dice = dice / n
55 | 
56 |     return mean_dice
57 | 
58 | 
59 | def entropy_loss_map(p, C=2):
60 |     ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1,
61 |                        keepdim=True)/torch.tensor(np.log(C)).cuda()
62 |     return ent
63 | 
64 | 
65 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False):
66 |     """Takes softmax on both sides and returns the element-wise squared error
67 | 
68 |     Note:
69 |     - Returns the unreduced, element-wise squared error map; take the mean or
70 |       sum afterwards if you want a scalar.
71 |     - Detach the targets beforehand if gradients should not flow into them.
72 |     """
73 |     assert input_logits.size() == target_logits.size()
74 |     if sigmoid:
75 |         input_softmax = torch.sigmoid(input_logits)
76 |         target_softmax = torch.sigmoid(target_logits)
77 |     else:
78 |         input_softmax = F.softmax(input_logits, dim=1)
79 |         target_softmax = F.softmax(target_logits, dim=1)
80 | 
81 |     mse_loss = (input_softmax-target_softmax)**2
82 |     return mse_loss
83 | 
84 | 
85 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False):
86 |     """Takes softmax on both sides and returns the KL divergence
87 | 
88 |     Note:
89 |     - Returns the KL divergence averaged over all elements
90 |       (reduction='mean').
91 |     - Detach the targets beforehand if gradients should not flow into them.
92 |     """
93 |     assert input_logits.size() == target_logits.size()
94 |     if sigmoid:
95 |         input_log_softmax = torch.log(torch.sigmoid(input_logits))
96 |         target_softmax = torch.sigmoid(target_logits)
97 |     else:
98 |         input_log_softmax = F.log_softmax(input_logits, dim=1)
99 |         target_softmax = F.softmax(target_logits, dim=1)
100 | 
101 |     # return F.kl_div(input_log_softmax, target_softmax)
102 |     kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean')
103 |     # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...])
104 |     return kl_div
105 | 
106 | 
107 | def symmetric_mse_loss(input1, input2):
108 |     """Like F.mse_loss, but sends gradients in both directions
109 | 
110 |     Note:
111 |     - Returns the mean over all elements; no further reduction over the batch
112 |       is needed.
113 |     - Sends gradients to both input1 and input2.
114 |     """
115 |     assert input1.size() == input2.size()
116 |     return torch.mean((input1 - input2)**2)
117 | 
118 | 
119 | class FocalLoss(nn.Module):
120 |     def __init__(self, gamma=2, alpha=None, size_average=True):
121 |         super(FocalLoss, self).__init__()
122 |         self.gamma = gamma
123 |         self.alpha = alpha
124 |         if isinstance(alpha, (float, int)):
125 |             self.alpha = torch.Tensor([alpha, 1-alpha])
126 |         if isinstance(alpha, list):
127 |             self.alpha = torch.Tensor(alpha)
128 |         self.size_average = size_average
129 | 
130 |     def forward(self, input, target):
131 |         if input.dim() > 2:
132 |             # N,C,H,W => N,C,H*W
133 |             input = input.view(input.size(0), input.size(1), -1)
134 |             input = input.transpose(1, 2)    # N,C,H*W => N,H*W,C
135 |             input = input.contiguous().view(-1, input.size(2))   # N,H*W,C => N*H*W,C
136 |         target = target.view(-1, 1)
137 | 
138 |         logpt = F.log_softmax(input, dim=1)
139 |         logpt = logpt.gather(1, target)
140 |         logpt = logpt.view(-1)
141 |         pt = Variable(logpt.data.exp())
142 | 
143 |         if self.alpha is not None:
144 |             if self.alpha.type() != input.data.type():
145 |                 self.alpha = self.alpha.type_as(input.data)
146 |             at = self.alpha.gather(0, target.data.view(-1))
147 |             logpt = logpt * Variable(at)
148 | 
149 |         loss = -1 * (1-pt)**self.gamma * logpt
150 |         if self.size_average:
151 |             return loss.mean()
152 |         else:
153 |             return loss.sum()
154 | 
155 | 
156 | class DiceLoss(nn.Module):
157 |     def __init__(self, n_classes):
158 |         super(DiceLoss, self).__init__()
159 |         self.n_classes = n_classes
160 | 
161 |     def _one_hot_encoder(self, input_tensor):
162 |         tensor_list = []
163 |         for i in range(self.n_classes):
164 |             temp_prob = input_tensor == i * torch.ones_like(input_tensor)
165 |             tensor_list.append(temp_prob)
166 |         output_tensor = torch.cat(tensor_list, dim=1)
167 |         return output_tensor.float()
168 | 
169 |     def _dice_loss(self, score, target):
170 |         target = target.float()
171 |         smooth = 1e-5
172 |         intersect = torch.sum(score * target)
173 |         y_sum = torch.sum(target * target)
174 |         z_sum = torch.sum(score * score)
175 |         loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
176 |         loss = 1 - loss
177 |         return loss
178 | 
179 |     def forward(self, inputs, target, weight=None, softmax=False):
180 |         if softmax:
181 |             inputs = torch.softmax(inputs, dim=1)
182 |         target = self._one_hot_encoder(target)
183 |         if weight is None:
184 |             weight = [1] * self.n_classes
185 |         assert inputs.size() == target.size(), 'predict & target shape do not match'
186 |         class_wise_dice = []
187 |         loss = 0.0
188 |         for i in range(0, self.n_classes):
189 |             dice = self._dice_loss(inputs[:, i], target[:, i])
190 |             class_wise_dice.append(1.0 - dice.item())
191 |             loss += dice * weight[i]
192 |         return loss / self.n_classes
193 | 
194 | 
195 | def entropy_minmization(p):
196 |     y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1)
197 |     ent = torch.mean(y1)
198 | 
199 |     return ent
200 | 
201 | 
202 | def entropy_map(p):
203 |     ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1,
204 |                            keepdim=True)
205 |     return ent_map
206 | 
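# --- Annotation (added for clarity; not part of the original losses.py) ------
# Minimal usage sketch for the DiceLoss class above. The shapes are
# illustrative assumptions for binary (2-class) segmentation on 512x512
# slices; `softmax=True` applies the softmax internally, and the integer
# label map is given as (N, 1, H, W) so the one-hot encoder can expand it.
import torch

criterion = DiceLoss(n_classes=2)
logits = torch.randn(4, 2, 512, 512)            # (N, C, H, W) raw network output
labels = torch.randint(0, 2, (4, 1, 512, 512))  # (N, 1, H, W) integer class map
loss = criterion(logits, labels, softmax=True)  # scalar Dice loss in [0, 1]
# --- End annotation -----------------------------------------------------------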
114 | """ 115 | assert input1.size() == input2.size() 116 | return torch.mean((input1 - input2)**2) 117 | 118 | 119 | class FocalLoss(nn.Module): 120 | def __init__(self, gamma=2, alpha=None, size_average=True): 121 | super(FocalLoss, self).__init__() 122 | self.gamma = gamma 123 | self.alpha = alpha 124 | if isinstance(alpha, (float, int)): 125 | self.alpha = torch.Tensor([alpha, 1-alpha]) 126 | if isinstance(alpha, list): 127 | self.alpha = torch.Tensor(alpha) 128 | self.size_average = size_average 129 | 130 | def forward(self, input, target): 131 | if input.dim() > 2: 132 | # N,C,H,W => N,C,H*W 133 | input = input.view(input.size(0), input.size(1), -1) 134 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 135 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 136 | target = target.view(-1, 1) 137 | 138 | logpt = F.log_softmax(input, dim=1) 139 | logpt = logpt.gather(1, target) 140 | logpt = logpt.view(-1) 141 | pt = Variable(logpt.data.exp()) 142 | 143 | if self.alpha is not None: 144 | if self.alpha.type() != input.data.type(): 145 | self.alpha = self.alpha.type_as(input.data) 146 | at = self.alpha.gather(0, target.data.view(-1)) 147 | logpt = logpt * Variable(at) 148 | 149 | loss = -1 * (1-pt)**self.gamma * logpt 150 | if self.size_average: 151 | return loss.mean() 152 | else: 153 | return loss.sum() 154 | 155 | 156 | class DiceLoss(nn.Module): 157 | def __init__(self, n_classes): 158 | super(DiceLoss, self).__init__() 159 | self.n_classes = n_classes 160 | 161 | def _one_hot_encoder(self, input_tensor): 162 | tensor_list = [] 163 | for i in range(self.n_classes): 164 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 165 | tensor_list.append(temp_prob) 166 | output_tensor = torch.cat(tensor_list, dim=1) 167 | return output_tensor.float() 168 | 169 | def _dice_loss(self, score, target): 170 | target = target.float() 171 | smooth = 1e-5 172 | intersect = torch.sum(score * target) 173 | y_sum = torch.sum(target * target) 174 | z_sum = torch.sum(score * score) 175 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 176 | loss = 1 - loss 177 | return loss 178 | 179 | def forward(self, inputs, target, weight=None, softmax=False): 180 | if softmax: 181 | inputs = torch.softmax(inputs, dim=1) 182 | target = self._one_hot_encoder(target) 183 | if weight is None: 184 | weight = [1] * self.n_classes 185 | assert inputs.size() == target.size(), 'predict & target shape do not match' 186 | class_wise_dice = [] 187 | loss = 0.0 188 | for i in range(0, self.n_classes): 189 | dice = self._dice_loss(inputs[:, i], target[:, i]) 190 | class_wise_dice.append(1.0 - dice.item()) 191 | loss += dice * weight[i] 192 | return loss / self.n_classes 193 | 194 | 195 | def entropy_minmization(p): 196 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) 197 | ent = torch.mean(y1) 198 | 199 | return ent 200 | 201 | 202 | def entropy_map(p): 203 | ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 204 | keepdim=True) 205 | return ent_map 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation. 2 | 3 | Implementation of [Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation](https://ieeexplore.ieee.org/document/9931157). 4 | 5 |

6 | 7 |

8 | 
9 | ## Implementation
10 | 
11 | ### 1. Installation
12 | 
13 | ```python
14 | pytorch==1.9.0
15 | ```
16 | 
17 | ### 2. Dataset Preparation
18 | 
19 | ```bash
20 | ├── COVID249
21 | │   ├── NII (Original dataset in NIFTI)
22 | │   ├── PNG (Pre-processed dataset in PNG)
23 | │   ├── train_0.1_l.xlsx (datasplit for 10% setting)
24 | │   ├── train_0.1_u.xlsx (datasplit for 10% setting)
25 | │   ├── train_0.2_l.xlsx (datasplit for 20% setting)
26 | │   ├── train_0.2_u.xlsx (datasplit for 20% setting)
27 | │   ├── train_0.3_l.xlsx (datasplit for 30% setting)
28 | │   ├── train_0.3_u.xlsx (datasplit for 30% setting)
29 | │   ├── test_slice.xlsx (datasplit for testing)
30 | │   ├── val_slice.xlsx (datasplit for validation)
31 | ├── MOS1000
32 | │   ├── NII (Original dataset in NIFTI)
33 | │   ├── PNG (Pre-processed dataset in PNG)
34 | │   ├── train_slice_label.xlsx (datasplit)
35 | │   ├── train_slice_unlabel.xlsx (datasplit)
36 | │   ├── test_slice.xlsx (datasplit for testing)
37 | │   ├── val_slice.xlsx (datasplit for validation)
38 | ```
39 | - Convert the NIfTI images to int32 PNG format and save them in the Image folder, following the processing steps of [Deeplesion](https://nihcc.app.box.com/v/DeepLesion/file/306055882594); subtracting 32768 from the stored pixel intensities recovers the original Hounsfield unit (HU) values.
40 | - The lung regions can be extracted with the pre-trained lung segmentation model provided by [JoHof](https://github.com/JoHof/lungmask).
41 | - The pre-processed COVID249 dataset can be downloaded from this [link](https://drive.google.com/file/d/1A2f3RRblSByFncUlf5MEr9VEjFlqD0ge/view?usp=sharing). MOS1000 can be processed with the same steps; we do not provide all of its processed images due to the large dataset size.
42 | - You can train SEAN models following [SEAN](https://github.com/ZPdesu/SEAN), or download our pre-trained checkpoints from this [link](https://drive.google.com/file/d/1K9Job_cJp3kfOZsc8IchTzotTmFyMeNy/view?usp=sharing).
43 | 
44 | 
45 | ### 3. Training Our Models
46 | 
47 | ```python
48 | python train_SACPS.py
49 | python train_SAST.py
50 | ```
51 | 
52 | ### 4. Training Other Models
53 | 
54 | We provide a template for training other models, with the dataloader, optimizer, etc. already implemented. The core of the training loop is shown below; an illustrative sketch of the placeholder step follows the snippet.
55 | 
56 | ```python
57 | for epoch in range(max_epoch):
58 |     print("Start epoch ", epoch+1, "!")
59 | 
60 |     tbar = tqdm(range(len(unlabeled_dataloader)), ncols=70)
61 |     labeled_dataloader_iter = iter(labeled_dataloader)
62 |     unlabeled_dataloader_iter = iter(unlabeled_dataloader)
63 | 
64 |     for batch_idx in tbar:
65 |         try:
66 |             input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
67 |         except StopIteration:
68 |             labeled_dataloader_iter = iter(labeled_dataloader)
69 |             input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
70 | 
71 |         # load data
72 |         input_ul, target_ul, file_name_ul, lung_ul = next(unlabeled_dataloader_iter)
73 |         input_ul, target_ul, lung_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True), lung_ul.cuda(non_blocking=True)
74 |         input_l, target_l, lung_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True), lung_l.cuda(non_blocking=True)
75 | 
76 | 
77 |         # Add implementation here: the training process
78 |         #-------------------------------------------------------------
79 |         #*************************************************************
80 |         #-------------------------------------------------------------
81 | 
82 | 
83 |         optimizer.zero_grad()
84 |         loss.backward()
85 |         optimizer.step()
86 |         lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
87 |         for param_group in optimizer.param_groups:
88 |             param_group['lr'] = lr_
89 | 
90 |         iter_num = iter_num + 1
91 |         writer.add_scalar('info/lr', lr_, iter_num)
92 |         writer.add_scalar('info/total_loss', loss, iter_num)
93 |         logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
94 | writer.close()
95 | ```
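As a concrete illustration, the placeholder above could be filled with a plain supervised term on the labeled batch plus a hard pseudo-label term on the unlabeled batch. The sketch below is only a minimal baseline, not the SAST/SACPS training step from the paper; it assumes a segmentation network `model` together with the `ce_loss`/`dice_loss` criteria and tensor shapes used in `train_template.py`, and the constant weight is an arbitrary choice:

```python
# Illustrative sketch only -- NOT the paper's method.
outputs_l = model(input_l)                               # logits, labeled batch
soft_l = torch.softmax(outputs_l, dim=1)
supervised_loss = 0.5 * (ce_loss(outputs_l, target_l.long())
                         + dice_loss(soft_l, target_l.unsqueeze(1)))

outputs_ul = model(input_ul)                             # logits, unlabeled batch
pseudo_label = torch.argmax(outputs_ul.detach(), dim=1)  # hard pseudo-labels
unsup_weight = 0.1                                       # assumed constant
loss = supervised_loss + unsup_weight * ce_loss(outputs_ul, pseudo_label)
```

In practice the unsupervised weight is usually ramped up over training rather than held constant (see the helpers in `utils/ramps.py`).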
96 | 
97 | ### 5. Testing
98 | 
99 | ```python
100 | python segment_test.py
101 | ```
102 | 
103 | ```python
104 | def test(args, snapshot_path):
105 |     model = net_factory(net_type=args.model)
106 |     model.load_state_dict(torch.load(args.model_path))
107 |     model.eval()
108 | 
109 |     nsd, dice = get_model_metric(args = args, model = model, snapshot_path=snapshot_path, model_name='model', mode='test')
110 |     print('nsd : %f dice : %f ' % (nsd, dice))
111 | ```
112 | 
113 | - snapshot_path: folder for saving results.
114 | - args.model: model type.
115 | - args.model_path: trained model path.
116 | - get_model_metric(): runs prediction, converts the PNG results to NIfTI, and computes NSD and DSC.
117 | 
118 | ## Supplementary information
119 | 
120 | 
121 | 1. Statistics of the datasets.
122 | 
123 | Descriptive statistics, including x-, y- and z-spacing, of both datasets are shown as follows.
124 | 
125 | 

126 | 127 |

128 | 129 | 130 | 2. Links for competing methods. 131 | - [Self-Ensembling]: [JBHI 2021](https://ieeexplore.ieee.org/abstract/document/9511146); [code](https://github.com/CaiziLee/SECT) 132 | - [Cross Pseudo Supervision]: [CVPR 2021](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Semi-Supervised_Semantic_Segmentation_With_Cross_Pseudo_Supervision_CVPR_2021_paper.pdf); [code](https://github.com/charlesCXK/TorchSemiSeg) 133 | - [Uncertainty-Aware Mean-Teacher]: [CVPR 2021](https://link.springer.com/chapter/10.1007/978-3-030-32245-8_67); [code](https://github.com/yulequan/UA-MT) 134 | - [Cross-Consistency Training]: [CVPR 2020](https://openaccess.thecvf.com/content_CVPR_2020/papers/Ouali_Semi-Supervised_Semantic_Segmentation_With_Cross-Consistency_Training_CVPR_2020_paper.pdf); [code](https://github.com/yassouali/CCT) 135 | - [Uncertainty-guided Dual-Consistency]: [MICCAI 2021](https://link.springer.com/chapter/10.1007/978-3-030-87196-3_19); [code](https://github.com/poiuohke/UDC-Net) 136 | - [SemiInfNet]: [TMI 2020](https://ieeexplore.ieee.org/abstract/document/9098956); [code](https://github.com/DengPingFan/Inf-Net) 137 | - [Self-Training]: [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/27e9661e033a73a6ad8cefcde965c54d-Abstract.html); [code](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/self_training) 138 | 139 | 140 | 141 | ## Citation 142 | If you find this repository useful for your research, please cite the following: 143 | ``` 144 | @ARTICLE{9931157, 145 | author={Lyu, Fei and Ye, Mang and Carlsen, Jonathan Frederik and Erleben, Kenny and Darkner, Sune and Yuen, Pong C.}, 146 | journal={IEEE Transactions on Medical Imaging}, 147 | title={Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation}, 148 | year={2022}, 149 | volume={}, 150 | number={}, 151 | pages={1-1}, 152 | doi={10.1109/TMI.2022.3217501}} 153 | ``` 154 | 155 | ## Acknowledgments 156 | We thank Luo, Xiangde for sharing his codes, our code borrows heavily from https://github.com/HiLab-git/SSL4MIS. 
157 | -------------------------------------------------------------------------------- /code/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | # import matplotlib.pyplot as plt 6 | from skimage import measure 7 | import scipy.ndimage as nd 8 | 9 | 10 | def recursive_glob(rootdir='.', suffix=''): 11 | """Performs recursive glob with given suffix and rootdir 12 | :param rootdir is the root directory 13 | :param suffix is the suffix to be searched 14 | """ 15 | return [os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames if filename.endswith(suffix)] 18 | 19 | def get_cityscapes_labels(): 20 | return np.array([ 21 | # [ 0, 0, 0], 22 | [128, 64, 128], 23 | [244, 35, 232], 24 | [70, 70, 70], 25 | [102, 102, 156], 26 | [190, 153, 153], 27 | [153, 153, 153], 28 | [250, 170, 30], 29 | [220, 220, 0], 30 | [107, 142, 35], 31 | [152, 251, 152], 32 | [0, 130, 180], 33 | [220, 20, 60], 34 | [255, 0, 0], 35 | [0, 0, 142], 36 | [0, 0, 70], 37 | [0, 60, 100], 38 | [0, 80, 100], 39 | [0, 0, 230], 40 | [119, 11, 32]]) 41 | 42 | def get_pascal_labels(): 43 | """Load the mapping that associates pascal classes with label colors 44 | Returns: 45 | np.ndarray with dimensions (21, 3) 46 | """ 47 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 48 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 49 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 50 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 51 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 52 | [0, 64, 128]]) 53 | 54 | 55 | def encode_segmap(mask): 56 | """Encode segmentation label images as pascal classes 57 | Args: 58 | mask (np.ndarray): raw segmentation label image of dimension 59 | (M, N, 3), in which the Pascal classes are encoded as colours. 60 | Returns: 61 | (np.ndarray): class map with dimensions (M,N), where the value at 62 | a given location is the integer denoting the class index. 63 | """ 64 | mask = mask.astype(int) 65 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 66 | for ii, label in enumerate(get_pascal_labels()): 67 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 68 | label_mask = label_mask.astype(int) 69 | return label_mask 70 | 71 | 72 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 73 | rgb_masks = [] 74 | for label_mask in label_masks: 75 | rgb_mask = decode_segmap(label_mask, dataset) 76 | rgb_masks.append(rgb_mask) 77 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 78 | return rgb_masks 79 | 80 | def decode_segmap(label_mask, dataset, plot=False): 81 | """Decode segmentation class labels into a color image 82 | Args: 83 | label_mask (np.ndarray): an (M,N) array of integer values denoting 84 | the class label at each spatial location. 85 | plot (bool, optional): whether to show the resulting color image 86 | in a figure. 87 | Returns: 88 | (np.ndarray, optional): the resulting decoded color image. 
89 | """ 90 | if dataset == 'pascal': 91 | n_classes = 21 92 | label_colours = get_pascal_labels() 93 | elif dataset == 'cityscapes': 94 | n_classes = 19 95 | label_colours = get_cityscapes_labels() 96 | else: 97 | raise NotImplementedError 98 | 99 | r = label_mask.copy() 100 | g = label_mask.copy() 101 | b = label_mask.copy() 102 | for ll in range(0, n_classes): 103 | r[label_mask == ll] = label_colours[ll, 0] 104 | g[label_mask == ll] = label_colours[ll, 1] 105 | b[label_mask == ll] = label_colours[ll, 2] 106 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 107 | rgb[:, :, 0] = r / 255.0 108 | rgb[:, :, 1] = g / 255.0 109 | rgb[:, :, 2] = b / 255.0 110 | if plot: 111 | plt.imshow(rgb) 112 | plt.show() 113 | else: 114 | return rgb 115 | 116 | def generate_param_report(logfile, param): 117 | log_file = open(logfile, 'w') 118 | # for key, val in param.items(): 119 | # log_file.write(key + ':' + str(val) + '\n') 120 | log_file.write(str(param)) 121 | log_file.close() 122 | 123 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 124 | n, c, h, w = logit.size() 125 | # logit = logit.permute(0, 2, 3, 1) 126 | target = target.squeeze(1) 127 | if weight is None: 128 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 129 | else: 130 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 131 | loss = criterion(logit, target.long()) 132 | 133 | if size_average: 134 | loss /= (h * w) 135 | 136 | if batch_average: 137 | loss /= n 138 | 139 | return loss 140 | 141 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 142 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 143 | 144 | 145 | def get_iou(pred, gt, n_classes=21): 146 | total_iou = 0.0 147 | for i in range(len(pred)): 148 | pred_tmp = pred[i] 149 | gt_tmp = gt[i] 150 | 151 | intersect = [0] * n_classes 152 | union = [0] * n_classes 153 | for j in range(n_classes): 154 | match = (pred_tmp == j) + (gt_tmp == j) 155 | 156 | it = torch.sum(match == 2).item() 157 | un = torch.sum(match > 0).item() 158 | 159 | intersect[j] += it 160 | union[j] += un 161 | 162 | iou = [] 163 | for k in range(n_classes): 164 | if union[k] == 0: 165 | continue 166 | iou.append(intersect[k] / union[k]) 167 | 168 | img_iou = (sum(iou) / len(iou)) 169 | total_iou += img_iou 170 | 171 | return total_iou 172 | 173 | def get_dice(pred, gt): 174 | total_dice = 0.0 175 | pred = pred.long() 176 | gt = gt.long() 177 | for i in range(len(pred)): 178 | pred_tmp = pred[i] 179 | gt_tmp = gt[i] 180 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 181 | print(dice) 182 | total_dice += dice 183 | 184 | return total_dice 185 | 186 | def get_mc_dice(pred, gt, num=2): 187 | # num is the total number of classes, include the background 188 | total_dice = np.zeros(num-1) 189 | pred = pred.long() 190 | gt = gt.long() 191 | for i in range(len(pred)): 192 | for j in range(1, num): 193 | pred_tmp = (pred[i]==j) 194 | gt_tmp = (gt[i]==j) 195 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 196 | total_dice[j-1] +=dice 197 | return total_dice 198 | 199 | def post_processing(prediction): 200 | prediction = nd.binary_fill_holes(prediction) 201 | label_cc, num_cc = measure.label(prediction,return_num=True) 202 | total_cc = np.sum(prediction) 203 | measure.regionprops(label_cc) 204 | 
for cc in range(1,num_cc+1): 205 | single_cc = (label_cc==cc) 206 | single_vol = np.sum(single_cc) 207 | if single_vol/total_cc<0.2: 208 | prediction[single_cc]=0 209 | 210 | return prediction 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /code/util/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import ntpath 8 | import time 9 | from . import util 10 | from . import html 11 | import scipy.misc 12 | try: 13 | from StringIO import StringIO # Python 2.7 14 | except ImportError: 15 | from io import BytesIO # Python 3.x 16 | 17 | class Visualizer(): 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.tf_log = opt.isTrain and opt.tf_log 21 | self.use_html = opt.isTrain and not opt.no_html 22 | self.win_size = opt.display_winsize 23 | self.name = opt.name 24 | if self.tf_log: 25 | import tensorflow as tf 26 | self.tf = tf 27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 28 | self.writer = tf.summary.FileWriter(self.log_dir) 29 | 30 | if self.use_html: 31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 32 | self.img_dir = os.path.join(self.web_dir, 'images') 33 | print('create web directory %s...' % self.web_dir) 34 | util.mkdirs([self.web_dir, self.img_dir]) 35 | if opt.isTrain: 36 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 37 | with open(self.log_name, "a") as log_file: 38 | now = time.strftime("%c") 39 | log_file.write('================ Training Loss (%s) ================\n' % now) 40 | 41 | # |visuals|: dictionary of images to display or save 42 | def display_current_results(self, visuals, epoch, step): 43 | 44 | ## convert tensors to numpy arrays 45 | visuals = self.convert_visuals_to_numpy(visuals) 46 | 47 | if self.tf_log: # show images in tensorboard output 48 | img_summaries = [] 49 | for label, image_numpy in visuals.items(): 50 | # Write the image to a string 51 | try: 52 | s = StringIO() 53 | except: 54 | s = BytesIO() 55 | if len(image_numpy.shape) >= 4: 56 | image_numpy = image_numpy[0] 57 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 58 | # Create an Image object 59 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 60 | # Create a Summary value 61 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 62 | 63 | # Create and write Summary 64 | summary = self.tf.Summary(value=img_summaries) 65 | self.writer.add_summary(summary, step) 66 | 67 | if self.use_html: # save images to a html file 68 | for label, image_numpy in visuals.items(): 69 | if isinstance(image_numpy, list): 70 | for i in range(len(image_numpy)): 71 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 72 | util.save_image(image_numpy[i], img_path) 73 | else: 74 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 75 | if len(image_numpy.shape) >= 4: 76 | image_numpy = image_numpy[0] 77 | util.save_image(image_numpy, img_path) 78 | 79 | # update website 80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 81 | for n in range(epoch, 0, -1): 82 | webpage.add_header('epoch [%d]' % n) 83 | ims = [] 84 
| txts = [] 85 | links = [] 86 | 87 | for label, image_numpy in visuals.items(): 88 | if isinstance(image_numpy, list): 89 | for i in range(len(image_numpy)): 90 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 91 | ims.append(img_path) 92 | txts.append(label+str(i)) 93 | links.append(img_path) 94 | else: 95 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 96 | ims.append(img_path) 97 | txts.append(label) 98 | links.append(img_path) 99 | if len(ims) < 10: 100 | webpage.add_images(ims, txts, links, width=self.win_size) 101 | else: 102 | num = int(round(len(ims)/2.0)) 103 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 104 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 105 | webpage.save() 106 | 107 | # errors: dictionary of error labels and values 108 | def plot_current_errors(self, errors, step): 109 | if self.tf_log: 110 | for tag, value in errors.items(): 111 | value = value.mean().float() 112 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 113 | self.writer.add_summary(summary, step) 114 | 115 | # errors: same format as |errors| of plotCurrentErrors 116 | def print_current_errors(self, epoch, i, errors, t): 117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 118 | for k, v in errors.items(): 119 | #print(v) 120 | #if v != 0: 121 | v = v.mean().float() 122 | message += '%s: %.3f ' % (k, v) 123 | 124 | print(message) 125 | with open(self.log_name, "a") as log_file: 126 | log_file.write('%s\n' % message) 127 | 128 | def convert_visuals_to_numpy(self, visuals): 129 | for key, t in visuals.items(): 130 | tile = self.opt.batchSize > 8 131 | if 'input_label' == key: 132 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 133 | else: 134 | t = util.tensor2im(t, tile=tile) 135 | visuals[key] = t 136 | return visuals 137 | 138 | # save image to the disk 139 | def save_images(self, webpage, visuals, image_path): 140 | visuals = self.convert_visuals_to_numpy(visuals) 141 | 142 | image_dir = webpage.get_image_dir() 143 | short_path = ntpath.basename(image_path[0]) 144 | name = os.path.splitext(short_path)[0] 145 | 146 | webpage.add_header(name) 147 | ims = [] 148 | txts = [] 149 | links = [] 150 | 151 | for label, image_numpy in visuals.items(): 152 | image_name = os.path.join(label, '%s.png' % (name)) 153 | save_path = os.path.join(image_dir, image_name) 154 | util.save_image(image_numpy, save_path, create_dir=True) 155 | 156 | ims.append(image_name) 157 | txts.append(label) 158 | links.append(image_name) 159 | webpage.add_images(ims, txts, links, width=self.win_size) 160 | 161 | # Our codes My single image convert 162 | 163 | def convert_image(self, generated): 164 | tile = self.opt.batchSize > 8 165 | t = util.tensor2im(generated, tile=tile)[0] 166 | 167 | #image_pil = Image.fromarray(t) 168 | 169 | # save to png 170 | #image_pil.save('test.png') 171 | 172 | return (t) -------------------------------------------------------------------------------- /code/train_template.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | 
from torch.nn.modules.loss import CrossEntropyLoss
18 | from torch.utils.data import DataLoader
19 | from torchvision import transforms
20 | from torchvision.utils import make_grid
21 | from tqdm import tqdm
22 | from itertools import cycle
23 | import numpy as np
24 | import cv2
25 | 
26 | from dataloaders import utils
27 | from dataloaders.dataset_covid import (CovidDataSets, RandomGenerator)
28 | from networks.net_factory import net_factory
29 | from utils import losses, metrics, ramps
30 | from test_covid import get_model_metric
31 | from models.pix2pix_model import Pix2PixModel, get_opt
32 | 
33 | 
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--root_path', type=str, default='/home/code/SSL/', help='Name of Experiment')
36 | 
37 | parser.add_argument('--labeled_per', type=float, default=0.1, help='percent of labeled data')
38 | if False:  # flip to True to use the COVID249 splits below instead of MOS1000
39 |     parser.add_argument('--dataset_name', type=str, default='COVID249', help='Name of dataset')
40 |     parser.add_argument('--excel_file_name_label', type=str, default='train_0.1_l.xlsx', help='Name of dataset')
41 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_0.1_u.xlsx', help='Name of dataset')
42 | else:
43 |     parser.add_argument('--dataset_name', type=str, default='MOS1000', help='Name of dataset')
44 |     parser.add_argument('--excel_file_name_label', type=str, default='train_slice_label.xlsx', help='Name of dataset')
45 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_slice_unlabel.xlsx', help='Name of dataset')
46 | 
47 | parser.add_argument('--exp', type=str, default='Template', help='experiment_name')
48 | 
49 | parser.add_argument('--model', type=str, default='unet', help='model_name')
50 | parser.add_argument('--max_epoch', type=int, default=20, help='maximum epoch number to train')
51 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
52 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
53 | parser.add_argument('--patch_size', type=list, default=[512, 512], help='patch size of network input')
54 | parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')
55 | 
56 | # label and unlabel
57 | parser.add_argument('--batch_size_label', type=int, default=8, help='batch_size per gpu')
58 | parser.add_argument('--batch_size_unlabel', type=int, default=8, help='batch_size per gpu')
59 | 
60 | args = parser.parse_args()
61 | 
62 | def train(args, snapshot_path):
63 |     base_lr = args.base_lr
64 |     # base_lr = args.base_lr  # duplicate of the line above
65 |     num_classes = args.num_classes
66 |     max_epoch = args.max_epoch
67 |     excel_file_name_label = args.excel_file_name_label
68 |     excel_file_name_unlabel = args.excel_file_name_unlabel
69 | 
70 | 
71 |     # create model
72 |     teacher_model = net_factory(net_type=args.model)
73 |     student_model = net_factory(net_type=args.model)
74 |     teacher_model.load_state_dict(torch.load(args.teacher_path))  # NOTE: requires a --teacher_path argument (path to a pre-trained teacher checkpoint) that is missing from the parser above
75 |     teacher_model.eval()
76 |     opt = get_opt()
77 |     syn_model = Pix2PixModel(opt)
78 |     syn_model.eval()
79 | 
80 |     # Define the dataset
81 |     labeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_label, aug = True)
82 |     unlabeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_unlabel, aug = True)
83 |     print('The overall number of labeled training images equals %d' % len(labeled_train_dataset))
84 |     print('The overall number of unlabeled training images equals %d' % len(unlabeled_train_dataset))
85 | 
86 |     student_model.train()
87 | 
88 |     optimizer = optim.SGD(student_model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
89 |     ce_loss = CrossEntropyLoss()
90 |     dice_loss = losses.DiceLoss(num_classes)
91 |     writer = SummaryWriter(snapshot_path + '/log')
92 | 
93 |     # Define the dataloader
94 |     labeled_dataloader = DataLoader(labeled_train_dataset, batch_size = args.batch_size_label, shuffle = True, num_workers = 4, pin_memory = True)
95 |     unlabeled_dataloader = DataLoader(unlabeled_train_dataset, batch_size = args.batch_size_unlabel, shuffle = True, num_workers = 4, pin_memory = True)
96 | 
97 |     iter_num = 0
98 |     max_iterations = max_epoch * len(unlabeled_dataloader)
99 |     for epoch in range(max_epoch):
100 |         print("Start epoch ", epoch+1, "!")
101 | 
102 |         tbar = tqdm(range(len(unlabeled_dataloader)), ncols=70)
103 |         labeled_dataloader_iter = iter(labeled_dataloader)
104 |         unlabeled_dataloader_iter = iter(unlabeled_dataloader)
105 | 
106 |         for batch_idx in tbar:
107 | 
108 |             try:
109 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
110 |             except StopIteration:
111 |                 labeled_dataloader_iter = iter(labeled_dataloader)
112 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
113 | 
114 |             style_output_global_positive_list = []  # unused in this template
115 | 
116 |             input_ul, target_ul, file_name_ul, lung_ul = next(unlabeled_dataloader_iter)
117 |             input_ul, target_ul, lung_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True), lung_ul.cuda(non_blocking=True)
118 |             input_l, target_l, lung_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True), lung_l.cuda(non_blocking=True)
119 | 
120 | 
121 |             # Add new implementation here: design the training process
122 |             #--------------------------------------------------------------------------------------------------------
123 |             #********************************************************************************************************
124 |             #--------------------------------------------------------------------------------------------------------
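# --- Annotation (added for clarity; not part of the original template file) ---
# The placeholder above must produce a scalar `loss`; the update step below
# calls loss.backward(). A minimal (assumed) baseline would mirror the sketch
# in README section 4: a CE + Dice supervised term on (input_l, target_l)
# plus a weighted pseudo-label term on input_ul, e.g. with teacher_model
# predictions as targets. The released train_SAST.py / train_SACPS.py
# implement the paper's actual training steps.
# --- End annotation ------------------------------------------------------------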
125 | 
126 | 
127 |             optimizer.zero_grad()
128 |             loss.backward()
129 |             optimizer.step()
130 |             lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
131 |             for param_group in optimizer.param_groups:
132 |                 param_group['lr'] = lr_
133 | 
134 |             iter_num = iter_num + 1
135 |             writer.add_scalar('info/lr', lr_, iter_num)
136 |             writer.add_scalar('info/total_loss', loss, iter_num)
137 |             logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
138 |     writer.close()
139 | 
140 | 
141 | 
142 | if __name__ == "__main__":
143 |     # for reproducing
144 |     cv2.setNumThreads(0)
145 |     cv2.ocl.setUseOpenCL(False)
146 |     seed = 66
147 |     print("[ Using Seed : ", seed, " ]")
148 |     torch.manual_seed(seed)
149 |     torch.cuda.manual_seed_all(seed)
150 |     torch.cuda.manual_seed(seed)
151 |     np.random.seed(seed)
152 |     random.seed(seed)
153 |     torch.backends.cudnn.deterministic = True
154 |     torch.backends.cudnn.benchmark = False
155 |     os.environ["PYTHONHASHSEED"] = str(seed)
156 | 
157 |     snapshot_path = "{}exp/{}/exp_{}_{}_{}".format(args.root_path, args.dataset_name, args.exp, args.labeled_per, args.model)
158 |     if not os.path.exists(snapshot_path):
159 |         os.makedirs(snapshot_path)
160 | 
161 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
162 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
163 | logging.info(str(args)) 164 | train(args, snapshot_path) -------------------------------------------------------------------------------- /code/models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torch.nn.utils.spectral_norm as spectral_norm 11 | from models.networks.normalization import SPADE, ACE 12 | 13 | 14 | # ResNet block that uses SPADE. 15 | # It differs from the ResNet block of pix2pixHD in that 16 | # it takes in the segmentation map as input, learns the skip connection if necessary, 17 | # and applies normalization first and then convolution. 18 | # This architecture seemed like a standard architecture for unconditional or 19 | # class-conditional GAN architecture using residual block. 20 | # The code was inspired from https://github.com/LMescheder/GAN_stability. 21 | class SPADEResnetBlock(nn.Module): 22 | def __init__(self, fin, fout, opt, Block_Name=None, use_rgb=True): 23 | super().__init__() 24 | 25 | self.use_rgb = use_rgb 26 | 27 | self.Block_Name = Block_Name 28 | self.status = opt.status 29 | 30 | # Attributes 31 | self.learned_shortcut = (fin != fout) 32 | fmiddle = min(fin, fout) 33 | 34 | # create conv layers 35 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) 36 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 37 | if self.learned_shortcut: 38 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 39 | 40 | # apply spectral norm if specified 41 | if 'spectral' in opt.norm_G: 42 | self.conv_0 = spectral_norm(self.conv_0) 43 | self.conv_1 = spectral_norm(self.conv_1) 44 | if self.learned_shortcut: 45 | self.conv_s = spectral_norm(self.conv_s) 46 | 47 | # define normalization layers 48 | spade_config_str = opt.norm_G.replace('spectral', '') 49 | 50 | 51 | ########### Modifications 1 52 | normtype_list = ['spadeinstance3x3', 'spadesyncbatch3x3', 'spadebatch3x3'] 53 | our_norm_type = 'spadesyncbatch3x3' 54 | 55 | self.ace_0 = ACE(our_norm_type, fin, 3, ACE_Name= Block_Name + '_ACE_0', status=self.status, spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb) 56 | ########### Modifications 1 57 | 58 | 59 | ########### Modifications 1 60 | self.ace_1 = ACE(our_norm_type, fmiddle, 3, ACE_Name= Block_Name + '_ACE_1', status=self.status, spade_params=[spade_config_str, fmiddle, opt.semantic_nc], use_rgb=use_rgb) 61 | ########### Modifications 1 62 | 63 | if self.learned_shortcut: 64 | self.ace_s = ACE(our_norm_type, fin, 3, ACE_Name= Block_Name + '_ACE_s', status=self.status, spade_params=[spade_config_str, fin, opt.semantic_nc], use_rgb=use_rgb) 65 | 66 | # note the resnet block with SPADE also takes in |seg|, 67 | # the semantic segmentation map as input 68 | def forward(self, x, seg, style_codes, obj_dic=None): 69 | 70 | 71 | x_s = self.shortcut(x, seg, style_codes, obj_dic) 72 | 73 | 74 | ########### Modifications 1 75 | dx = self.ace_0(x, seg, style_codes, obj_dic) 76 | 77 | dx = self.conv_0(self.actvn(dx)) 78 | 79 | dx = self.ace_1(dx, seg, style_codes, obj_dic) 80 | 81 | dx = self.conv_1(self.actvn(dx)) 82 | ########### Modifications 1 83 | 84 | 85 | out = x_s + dx 86 | return out 87 | 88 | def shortcut(self, x, seg, style_codes, obj_dic): 
89 | if self.learned_shortcut: 90 | x_s = self.ace_s(x, seg, style_codes, obj_dic) 91 | x_s = self.conv_s(x_s) 92 | 93 | else: 94 | x_s = x 95 | return x_s 96 | 97 | def actvn(self, x): 98 | return F.leaky_relu(x, 2e-1) 99 | 100 | 101 | # ResNet block used in pix2pixHD 102 | # We keep the same architecture as pix2pixHD. 103 | class ResnetBlock(nn.Module): 104 | def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3): 105 | super().__init__() 106 | 107 | pw = (kernel_size - 1) // 2 108 | self.conv_block = nn.Sequential( 109 | nn.ReflectionPad2d(pw), 110 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), 111 | activation, 112 | nn.ReflectionPad2d(pw), 113 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)) 114 | ) 115 | 116 | def forward(self, x): 117 | y = self.conv_block(x) 118 | out = x + y 119 | return out 120 | 121 | 122 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 123 | class VGG19(torch.nn.Module): 124 | def __init__(self, requires_grad=False): 125 | super().__init__() 126 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 127 | self.slice1 = torch.nn.Sequential() 128 | self.slice2 = torch.nn.Sequential() 129 | self.slice3 = torch.nn.Sequential() 130 | self.slice4 = torch.nn.Sequential() 131 | self.slice5 = torch.nn.Sequential() 132 | for x in range(2): 133 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(2, 7): 135 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(7, 12): 137 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(12, 21): 139 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 140 | for x in range(21, 30): 141 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 142 | if not requires_grad: 143 | for param in self.parameters(): 144 | param.requires_grad = False 145 | 146 | def forward(self, X): 147 | h_relu1 = self.slice1(X) 148 | h_relu2 = self.slice2(h_relu1) 149 | h_relu3 = self.slice3(h_relu2) 150 | h_relu4 = self.slice4(h_relu3) 151 | h_relu5 = self.slice5(h_relu4) 152 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 153 | return out 154 | 155 | 156 | class Zencoder(torch.nn.Module): 157 | def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=2, norm_layer=nn.InstanceNorm2d): 158 | super(Zencoder, self).__init__() 159 | self.output_nc = output_nc 160 | 161 | model = [nn.ReflectionPad2d(1), nn.Conv2d(input_nc, ngf, kernel_size=3, padding=0), 162 | norm_layer(ngf), nn.LeakyReLU(0.2, False)] 163 | ### downsample 164 | for i in range(n_downsampling): 165 | mult = 2**i 166 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 167 | norm_layer(ngf * mult * 2), nn.LeakyReLU(0.2, False)] 168 | 169 | ### upsample 170 | for i in range(1): 171 | mult = 2**(n_downsampling - i) 172 | model += [nn.ConvTranspose2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, output_padding=1), 173 | norm_layer(int(ngf * mult / 2)), nn.LeakyReLU(0.2, False)] 174 | 175 | model += [nn.ReflectionPad2d(1), nn.Conv2d(256, output_nc, kernel_size=3, padding=0), nn.Tanh()] 176 | self.model = nn.Sequential(*model) 177 | 178 | 179 | def forward(self, input, segmap): 180 | 181 | codes = self.model(input) 182 | 183 | segmap = F.interpolate(segmap, size=codes.size()[2:], mode='nearest') 184 | 185 | # print(segmap.shape) 186 | # print(codes.shape) 187 | 188 | 189 | b_size = codes.shape[0] 190 | # h_size = codes.shape[2] 191 | # w_size = 
codes.shape[3] 192 | f_size = codes.shape[1] 193 | 194 | s_size = segmap.shape[1] 195 | 196 | codes_vector = torch.zeros((b_size, s_size, f_size), dtype=codes.dtype, device=codes.device) 197 | 198 | 199 | for i in range(b_size): 200 | for j in range(s_size): 201 | component_mask_area = torch.sum(segmap.bool()[i, j]) 202 | 203 | if component_mask_area > 0: 204 | codes_component_feature = codes[i].masked_select(segmap.bool()[i, j]).reshape(f_size, component_mask_area).mean(1) 205 | codes_vector[i][j] = codes_component_feature 206 | 207 | # codes_avg[i].masked_scatter_(segmap.bool()[i, j], codes_component_mu) 208 | 209 | return codes_vector -------------------------------------------------------------------------------- /code/util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | import importlib 8 | import torch 9 | from argparse import Namespace 10 | import numpy as np 11 | from PIL import Image 12 | import os 13 | import argparse 14 | import dill as pickle 15 | import util.coco 16 | 17 | 18 | def save_obj(obj, name): 19 | with open(name, 'wb') as f: 20 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 21 | 22 | 23 | def load_obj(name): 24 | with open(name, 'rb') as f: 25 | return pickle.load(f) 26 | 27 | # returns a configuration for creating a generator 28 | # |default_opt| should be the opt of the current experiment 29 | # |**kwargs|: if any configuration should be overriden, it can be specified here 30 | 31 | 32 | def copyconf(default_opt, **kwargs): 33 | conf = argparse.Namespace(**vars(default_opt)) 34 | for key in kwargs: 35 | print(key, kwargs[key]) 36 | setattr(conf, key, kwargs[key]) 37 | return conf 38 | 39 | 40 | def tile_images(imgs, picturesPerRow=4): 41 | """ Code borrowed from 42 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 43 | """ 44 | 45 | # Padding 46 | if imgs.shape[0] % picturesPerRow == 0: 47 | rowPadding = 0 48 | else: 49 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 50 | if rowPadding > 0: 51 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 52 | 53 | # Tiling Loop (The conditionals are not necessary anymore) 54 | tiled = [] 55 | for i in range(0, imgs.shape[0], picturesPerRow): 56 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 57 | 58 | tiled = np.concatenate(tiled, axis=0) 59 | return tiled 60 | 61 | 62 | # Converts a Tensor into a Numpy array 63 | # |imtype|: the desired type of the converted numpy array 64 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): 65 | if isinstance(image_tensor, list): 66 | image_numpy = [] 67 | for i in range(len(image_tensor)): 68 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 69 | return image_numpy 70 | 71 | if image_tensor.dim() == 4: 72 | # transform each image in the batch 73 | images_np = [] 74 | for b in range(image_tensor.size(0)): 75 | one_image = image_tensor[b] 76 | one_image_np = tensor2im(one_image) 77 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 78 | images_np = np.concatenate(images_np, axis=0) 79 | if tile: 80 | images_tiled = tile_images(images_np) 81 | return images_tiled 82 | else: 83 | return images_np 84 | 85 | if 
image_tensor.dim() == 2: 86 | image_tensor = image_tensor.unsqueeze(0) 87 | image_numpy = image_tensor.detach().cpu().float().numpy() 88 | if normalize: 89 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 90 | else: 91 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 92 | image_numpy = np.clip(image_numpy, 0, 255) 93 | if image_numpy.shape[2] == 1: 94 | image_numpy = image_numpy[:, :, 0] 95 | return image_numpy.astype(imtype) 96 | 97 | 98 | # Converts a one-hot tensor into a colorful label map 99 | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): 100 | if label_tensor.dim() == 4: 101 | # transform each image in the batch 102 | images_np = [] 103 | for b in range(label_tensor.size(0)): 104 | one_image = label_tensor[b] 105 | one_image_np = tensor2label(one_image, n_label, imtype) 106 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 107 | images_np = np.concatenate(images_np, axis=0) 108 | if tile: 109 | images_tiled = tile_images(images_np) 110 | return images_tiled 111 | else: 112 | images_np = images_np[0] 113 | return images_np 114 | 115 | if label_tensor.dim() == 1: 116 | return np.zeros((64, 64, 3), dtype=np.uint8) 117 | if n_label == 0: 118 | return tensor2im(label_tensor, imtype) 119 | label_tensor = label_tensor.cpu().float() 120 | if label_tensor.size()[0] > 1: 121 | label_tensor = label_tensor.max(0, keepdim=True)[1] 122 | label_tensor = Colorize(n_label)(label_tensor) 123 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 124 | result = label_numpy.astype(imtype) 125 | return result 126 | 127 | 128 | def save_image(image_numpy, image_path, create_dir=False): 129 | if create_dir: 130 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 131 | if len(image_numpy.shape) == 2: 132 | image_numpy = np.expand_dims(image_numpy, axis=2) 133 | if image_numpy.shape[2] == 1: 134 | image_numpy = np.repeat(image_numpy, 3, 2) 135 | image_pil = Image.fromarray(image_numpy) 136 | 137 | # save to png 138 | image_pil.save(image_path.replace('.jpg', '.png')) 139 | 140 | 141 | def mkdirs(paths): 142 | if isinstance(paths, list) and not isinstance(paths, str): 143 | for path in paths: 144 | mkdir(path) 145 | else: 146 | mkdir(paths) 147 | 148 | 149 | def mkdir(path): 150 | if not os.path.exists(path): 151 | os.makedirs(path) 152 | 153 | 154 | def atoi(text): 155 | return int(text) if text.isdigit() else text 156 | 157 | 158 | def natural_keys(text): 159 | ''' 160 | alist.sort(key=natural_keys) sorts in human order 161 | http://nedbatchelder.com/blog/200712/human_sorting.html 162 | (See Toothy's implementation in the comments) 163 | ''' 164 | return [atoi(c) for c in re.split('(\d+)', text)] 165 | 166 | 167 | def natural_sort(items): 168 | items.sort(key=natural_keys) 169 | 170 | 171 | def str2bool(v): 172 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 173 | return True 174 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 175 | return False 176 | else: 177 | raise argparse.ArgumentTypeError('Boolean value expected.') 178 | 179 | 180 | def find_class_in_module(target_cls_name, module): 181 | target_cls_name = target_cls_name.replace('_', '').lower() 182 | clslib = importlib.import_module(module) 183 | cls = None 184 | for name, clsobj in clslib.__dict__.items(): 185 | if name.lower() == target_cls_name: 186 | cls = clsobj 187 | 188 | if cls is None: 189 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 190 | exit(0) 191 | 192 | return 
cls 193 | 194 | 195 | def save_network(net, label, epoch, opt): 196 | save_filename = '%s_net_%s.pth' % (epoch, label) 197 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 198 | torch.save(net.cpu().state_dict(), save_path) 199 | if len(opt.gpu_ids) and torch.cuda.is_available(): 200 | net.cuda() 201 | 202 | 203 | def load_network(net, label, epoch, opt): 204 | save_filename = '%s_net_%s.pth' % (epoch, label) 205 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 206 | save_path = os.path.join(save_dir, save_filename) 207 | weights = torch.load(save_path) 208 | net.load_state_dict(weights) 209 | return net 210 | 211 | 212 | ############################################################################### 213 | # Code from 214 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 215 | # Modified so it complies with the Citscape label map colors 216 | ############################################################################### 217 | def uint82bin(n, count=8): 218 | """returns the binary of integer n, count refers to amount of bits""" 219 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) 220 | 221 | 222 | def labelcolormap(N): 223 | if N == 35: # cityscape 224 | cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), 225 | (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), 226 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), 227 | (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), 228 | (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)], 229 | dtype=np.uint8) 230 | else: 231 | cmap = np.zeros((N, 3), dtype=np.uint8) 232 | for i in range(N): 233 | r, g, b = 0, 0, 0 234 | id = i + 1 # let's give 0 a color 235 | for j in range(7): 236 | str_id = uint82bin(id) 237 | r = r ^ (np.uint8(str_id[-1]) << (7 - j)) 238 | g = g ^ (np.uint8(str_id[-2]) << (7 - j)) 239 | b = b ^ (np.uint8(str_id[-3]) << (7 - j)) 240 | id = id >> 3 241 | cmap[i, 0] = r 242 | cmap[i, 1] = g 243 | cmap[i, 2] = b 244 | 245 | if N == 182: # COCO 246 | important_colors = { 247 | 'sea': (54, 62, 167), 248 | 'sky-other': (95, 219, 255), 249 | 'tree': (140, 104, 47), 250 | 'clouds': (170, 170, 170), 251 | 'grass': (29, 195, 49) 252 | } 253 | for i in range(N): 254 | name = util.coco.id2label(i) 255 | if name in important_colors: 256 | color = important_colors[name] 257 | cmap[i] = np.array(list(color)) 258 | 259 | return cmap 260 | 261 | 262 | class Colorize(object): 263 | def __init__(self, n=35): 264 | self.cmap = labelcolormap(n) 265 | self.cmap = torch.from_numpy(self.cmap[:n]) 266 | 267 | def __call__(self, gray_image): 268 | size = gray_image.size() 269 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 270 | 271 | for label in range(0, len(self.cmap)): 272 | mask = (label == gray_image[0]).cpu() 273 | color_image[0][mask] = self.cmap[label][0] 274 | color_image[1][mask] = self.cmap[label][1] 275 | color_image[2][mask] = self.cmap[label][2] 276 | 277 | return color_image 278 | -------------------------------------------------------------------------------- /code/models/networks/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import re
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
11 | import torch.nn.utils.spectral_norm as spectral_norm
12 | import os
13 | import numpy as np
14 |
15 |
16 |
17 |
18 | # Returns a function that creates a normalization function
19 | # that does not condition on the semantic map
20 | def get_nonspade_norm_layer(opt, norm_type='instance'):
21 |     # helper function to get # output channels of the previous layer
22 |     def get_out_channel(layer):
23 |         if hasattr(layer, 'out_channels'):
24 |             return getattr(layer, 'out_channels')
25 |         return layer.weight.size(0)
26 |
27 |     # this function will be returned
28 |     def add_norm_layer(layer):
29 |         nonlocal norm_type
30 |         if norm_type.startswith('spectral'):  # callers pass e.g. 'spectralinstance'; a norm_type without the prefix would leave subnorm_type undefined below
31 |             layer = spectral_norm(layer)
32 |             subnorm_type = norm_type[len('spectral'):]
33 |
34 |         if subnorm_type == 'none' or len(subnorm_type) == 0:
35 |             return layer
36 |
37 |         # remove bias in the previous layer, which is meaningless
38 |         # since it has no effect after normalization
39 |         if getattr(layer, 'bias', None) is not None:
40 |             delattr(layer, 'bias')
41 |             layer.register_parameter('bias', None)
42 |
43 |         if subnorm_type == 'batch':
44 |             norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
45 |         elif subnorm_type == 'sync_batch':
46 |             norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
47 |         elif subnorm_type == 'instance':
48 |             norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
49 |         else:
50 |             raise ValueError('normalization layer %s is not recognized' % subnorm_type)
51 |
52 |         return nn.Sequential(layer, norm_layer)
53 |
54 |     return add_norm_layer
55 |
56 |
57 | # Creates SPADE normalization layer based on the given configuration
58 | # SPADE consists of two steps. First, it normalizes the activations using
59 | # your favorite normalization method, such as Batch Norm or Instance Norm.
60 | # Second, it applies scale and bias to the normalized output, conditioned on
61 | # the segmentation map.
62 | # The format of |config_text| is spade(norm)(ks), where
63 | # (norm) specifies the type of parameter-free normalization.
64 | #        (e.g. syncbatch, batch, instance)
65 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
66 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
67 | # Also, the other arguments are
68 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
69 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
70 |
71 |
72 |
73 | class ACE(nn.Module):
74 |     def __init__(self, config_text, norm_nc, label_nc, ACE_Name=None, status='train', spade_params=None, use_rgb=True):
75 |         super().__init__()
76 |
77 |         self.ACE_Name = ACE_Name
78 |         self.status = status
79 |         self.save_npy = True
80 |         self.Spade = SPADE(*spade_params)
81 |         self.use_rgb = use_rgb
82 |         self.style_length = 512
83 |         self.blending_gamma = nn.Parameter(torch.zeros(1), requires_grad=True)
84 |         self.blending_beta = nn.Parameter(torch.zeros(1), requires_grad=True)
85 |         self.noise_var = nn.Parameter(torch.zeros(norm_nc), requires_grad=True)
86 |
87 |
88 |         assert config_text.startswith('spade')
89 |         parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
90 |         param_free_norm_type = str(parsed.group(1))
91 |         ks = int(parsed.group(2))
92 |         pw = ks // 2
93 |
94 |         if param_free_norm_type == 'instance':
95 |             self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
96 |         elif param_free_norm_type == 'syncbatch':
97 |             self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
98 |         elif param_free_norm_type == 'batch':
99 |             self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
100 |         else:
101 |             raise ValueError('%s is not a recognized param-free norm type in SPADE'
102 |                              % param_free_norm_type)
103 |
104 |         # The dimension of the intermediate embedding space (style_length = 512 above). Yes, hardcoded.
105 |
106 |
107 |         if self.use_rgb:
108 |             self.create_gamma_beta_fc_layers()
109 |
110 |             self.conv_gamma = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)
111 |             self.conv_beta = nn.Conv2d(self.style_length, norm_nc, kernel_size=ks, padding=pw)
112 |
113 |
114 |
115 |
116 |     def forward(self, x, segmap, style_codes=None, obj_dic=None):
117 |
118 |         # Part 1. generate parameter-free normalized activations
119 |         added_noise = (torch.randn(x.shape[0], x.shape[3], x.shape[2], 1, device=x.device) * self.noise_var).transpose(1, 3)  # create noise on x's device instead of assuming CUDA
120 |         normalized = self.param_free_norm(x + added_noise)
121 |
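        # noise_var gives each channel a learned scale for the Gaussian noise
        # injected before normalization (a StyleGAN-style trick, as far as the
        # uncommented original lets one tell); it is initialized to zeros, so
        # training starts effectively noise-free.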
122 |         # Part 2. produce scaling and bias conditioned on the semantic map
123 |         segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
124 |
125 |         if self.use_rgb:
126 |             [b_size, f_size, h_size, w_size] = normalized.shape
127 |             middle_avg = torch.zeros((b_size, self.style_length, h_size, w_size), device=normalized.device)
128 |
129 |             if self.status == 'UI_mode':
130 |                 ############## hard coding
131 |
132 |                 for i in range(1):
133 |                     for j in range(segmap.shape[1]):
134 |
135 |                         component_mask_area = torch.sum(segmap.bool()[i, j])
136 |
137 |                         if component_mask_area > 0:
138 |                             if obj_dic is None:
139 |                                 print('obj_dic is None even though this is the first input')
140 |                             else:
141 |                                 style_code_tmp = obj_dic[str(j)]['ACE']
142 |
143 |                                 middle_mu = F.relu(self.__getattr__('fc_mu' + str(j))(style_code_tmp))
144 |                                 component_mu = middle_mu.reshape(self.style_length, 1).expand(self.style_length, component_mask_area)
145 |
146 |                                 middle_avg[i].masked_scatter_(segmap.bool()[i, j], component_mu)
147 |
148 |             else:
149 |
150 |                 for i in range(b_size):
151 |                     for j in range(segmap.shape[1]):
152 |                         component_mask_area = torch.sum(segmap.bool()[i, j])
153 |
154 |                         if component_mask_area > 0:
155 |
156 |
157 |                             middle_mu = F.relu(self.__getattr__('fc_mu' + str(j))(style_codes[i][j]))
158 |                             component_mu = middle_mu.reshape(self.style_length, 1).expand(self.style_length, component_mask_area)
159 |
160 |                             middle_avg[i].masked_scatter_(segmap.bool()[i, j], component_mu)
161 |
162 |
163 |                             if self.status == 'test' and self.save_npy and self.ACE_Name == 'up_2_ACE_0':
164 |                                 tmp = style_codes[i][j].cpu().numpy()
165 |                                 dir_path = 'styles_test'
166 |
167 |                                 ############### some problem with obj_dic[i]
168 |
169 |                                 im_name = os.path.basename(obj_dic[i])
170 |                                 folder_path = os.path.join(dir_path, 'style_codes', im_name, str(j))
171 |                                 if not os.path.exists(folder_path):
172 |                                     os.makedirs(folder_path)
173 |
174 |                                 style_code_path = os.path.join(folder_path, 'ACE.npy')
175 |                                 np.save(style_code_path, tmp)
176 |
177 |
178 |             gamma_avg = self.conv_gamma(middle_avg)
179 |             beta_avg = self.conv_beta(middle_avg)
180 |
181 |
182 |             gamma_spade, beta_spade = self.Spade(segmap)
183 |
184 |             gamma_alpha = torch.sigmoid(self.blending_gamma)  # torch.sigmoid replaces the deprecated F.sigmoid
185 |             beta_alpha = torch.sigmoid(self.blending_beta)
186 |
187 |             gamma_final = gamma_alpha * gamma_avg + (1 - gamma_alpha) * gamma_spade
188 |             beta_final = beta_alpha * beta_avg + (1 - beta_alpha) * beta_spade
189 |             out = normalized * (1 + gamma_final) + beta_final
190 |         else:
191 |             gamma_spade, beta_spade = self.Spade(segmap)
192 |             gamma_final = gamma_spade
193 |             beta_final = beta_spade
194 |             out = normalized * (1 + gamma_final) + beta_final
195 |
196 |         return out
197 |
198 |
199 |
200 |
201 |
202 |     def create_gamma_beta_fc_layers(self):
203 |
204 |
205 |         ################### These codes should be replaced with torch.nn.ModuleList
206 |
207 |         style_length = self.style_length
208 |
209 |         self.fc_mu0 = nn.Linear(style_length, style_length)
210 |         self.fc_mu1 = nn.Linear(style_length, style_length)
211 |         self.fc_mu2 = nn.Linear(style_length, style_length)
212 |         self.fc_mu3 = nn.Linear(style_length, style_length)
213 |         self.fc_mu4 = nn.Linear(style_length, style_length)
214 |         self.fc_mu5 = nn.Linear(style_length, style_length)
215 |         self.fc_mu6 = nn.Linear(style_length, style_length)
216 |         self.fc_mu7 = nn.Linear(style_length, style_length)
217 |         self.fc_mu8 = nn.Linear(style_length, style_length)
218 |         self.fc_mu9 = nn.Linear(style_length, style_length)
219 |         self.fc_mu10 = nn.Linear(style_length, style_length)
220 |         self.fc_mu11 = nn.Linear(style_length, style_length)
221 |         self.fc_mu12 = nn.Linear(style_length, style_length)
222 |         self.fc_mu13 = nn.Linear(style_length, style_length)
223 |         self.fc_mu14 = nn.Linear(style_length, style_length)
224 |         self.fc_mu15 = nn.Linear(style_length, style_length)
225 |         self.fc_mu16 = nn.Linear(style_length, style_length)
226 |         self.fc_mu17 = nn.Linear(style_length, style_length)
227 |         self.fc_mu18 = nn.Linear(style_length, style_length)
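    # The note above is right that the 19 per-class layers could collapse into
    # an nn.ModuleList. A hedged sketch (not in the original code; 19 matches
    # the fc_mu0..fc_mu18 layers defined above):
    #
    #     self.fc_mus = nn.ModuleList(
    #         [nn.Linear(style_length, style_length) for _ in range(19)])
    #
    # with the forward() lookups `self.__getattr__('fc_mu' + str(j))` becoming
    # `self.fc_mus[j]`. Existing checkpoints would need their state-dict keys
    # remapped, which is presumably why the flat layout was kept.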
228 |
229 |
230 |
231 |
232 | class SPADE(nn.Module):
233 |     def __init__(self, config_text, norm_nc, label_nc):
234 |         super().__init__()
235 |
236 |         assert config_text.startswith('spade')
237 |         parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
238 |         param_free_norm_type = str(parsed.group(1))
239 |         ks = int(parsed.group(2))
240 |
241 |         if param_free_norm_type == 'instance':
242 |             self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
243 |         elif param_free_norm_type == 'syncbatch':
244 |             self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
245 |         elif param_free_norm_type == 'batch':
246 |             self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
247 |         else:
248 |             raise ValueError('%s is not a recognized param-free norm type in SPADE'
249 |                              % param_free_norm_type)
250 |
251 |         # The dimension of the intermediate embedding space. Yes, hardcoded.
252 |         nhidden = 128
253 |
254 |         pw = ks // 2
255 |         self.mlp_shared = nn.Sequential(
256 |             nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
257 |             nn.ReLU()
258 |         )
259 |
260 |         self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
261 |         self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
262 |
263 |     def forward(self, segmap):
264 |
265 |         inputmap = segmap
266 |
267 |         actv = self.mlp_shared(inputmap)
268 |         gamma = self.mlp_gamma(actv)
269 |         beta = self.mlp_beta(actv)
270 |
271 |         return gamma, beta
272 |
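# Hedged usage sketch (illustrative values, not part of the original file):
# the ACE block above consumes SPADE like this, after resizing the one-hot
# segmentation map to the feature resolution:
#
#     spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=4)
#     gamma, beta = spade(segmap)              # segmap: (N, 4, H, W) float one-hot
#     out = normalized * (1 + gamma) + beta    # normalized: (N, 64, H, W)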
--------------------------------------------------------------------------------
/code/train_SAST.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import shutil
6 | import sys
7 | import time
8 |
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from tensorboardX import SummaryWriter
16 | from torch.nn import BCEWithLogitsLoss
17 | from torch.nn.modules.loss import CrossEntropyLoss
18 | from torch.utils.data import DataLoader
19 | from torchvision import transforms
20 | from torchvision.utils import make_grid
21 | from tqdm import tqdm
22 | from itertools import cycle
23 |
24 | import cv2
25 |
26 | from dataloaders import utils
27 | from dataloaders.dataset_covid import (CovidDataSets, RandomGenerator)
28 | from networks.net_factory import net_factory
29 | from utils import losses, metrics, ramps
30 | from test_covid import get_model_metric
31 | from models.pix2pix_model import Pix2PixModel, get_opt
32 |
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument('--root_path', type=str, default='/home/code/SSL/', help='root path of the project')
36 |
37 | parser.add_argument('--labeled_per', type=float, default=0.1, help='percent of labeled data')
38 | if False:  # dataset switch: False selects the MOS1000 branch below
39 |     parser.add_argument('--dataset_name', type=str, default='COVID249', help='Name of dataset')
40 |     parser.add_argument('--excel_file_name_label', type=str, default='train_0.1_l.xlsx', help='excel file listing the labeled training slices')
41 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_0.1_u.xlsx', help='excel file listing the unlabeled training slices')
42 |
43 |     # path1
44 |     parser.add_argument('--teacher_path', type=str, default='/home/code/SSL/exp/COVID249/model.pth', help='path of teacher model')
45 |
46 | else:
47 |     parser.add_argument('--dataset_name', type=str, default='MOS1000', help='Name of dataset')
48 |     parser.add_argument('--excel_file_name_label', type=str, default='train_slice_label.xlsx', help='excel file listing the labeled training slices')
49 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_slice_unlabel.xlsx', help='excel file listing the unlabeled training slices')
50 |
51 |     # path1
52 |     parser.add_argument('--teacher_path', type=str, default='/home/code/SSL/exp/MOS1000/model.pth', help='path of teacher model')
53 |
54 |
55 | parser.add_argument('--exp', type=str, default='SAST', help='experiment_name')
56 | parser.add_argument('--consistency_syn', type=float, default=0.5, help='consistency')
57 | parser.add_argument('--consistency_pseudo', type=float, default=0.5, help='consistency')
58 |
59 | parser.add_argument('--model', type=str, default='unet', help='model_name')
60 | parser.add_argument('--max_epoch', type=int, default=20, help='maximum epoch number to train')
61 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
62 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
63 | parser.add_argument('--patch_size', type=list, default=[512, 512], help='patch size of network input')  # note: argparse with type=list splits a passed string into characters; only the default is usable as-is
64 | parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')
65 |
66 | # label and unlabel
67 | parser.add_argument('--batch_size_label', type=int, default=8, help='batch_size per gpu')
68 | parser.add_argument('--batch_size_unlabel', type=int, default=8, help='batch_size per gpu')
69 |
70 | args = parser.parse_args()
71 |
72 | def train(args, snapshot_path):
73 |     base_lr = args.base_lr
74 |
75 |     num_classes = args.num_classes
76 |     max_epoch = args.max_epoch
77 |     excel_file_name_label = args.excel_file_name_label
78 |     excel_file_name_unlabel = args.excel_file_name_unlabel
79 |
80 |
81 |     # create model
82 |     teacher_model = net_factory(net_type=args.model)
83 |     student_model = net_factory(net_type=args.model)
84 |     teacher_model.load_state_dict(torch.load(args.teacher_path))
85 |     teacher_model.eval()
86 |     opt = get_opt()
87 |     syn_model = Pix2PixModel(opt)
88 |     syn_model.eval()
89 |
90 |     # Define the dataset
91 |     labeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_label, aug = True)
92 |     unlabeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_unlabel, aug = True)
93 |     print('The overall number of labeled training images equals to %d' % len(labeled_train_dataset))
94 |     print('The overall number of unlabeled training images equals to %d' % len(unlabeled_train_dataset))
95 |
96 |     student_model.train()
97 |
98 |     optimizer_s = optim.SGD(student_model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
99 |     ce_loss = CrossEntropyLoss()
100 |     dice_loss = losses.DiceLoss(num_classes)
101 |     writer = SummaryWriter(snapshot_path + '/log')
102 |
103 |     # Define the dataloader
104 |     labeled_dataloader = DataLoader(labeled_train_dataset, batch_size = args.batch_size_label, shuffle = True, num_workers = 4, pin_memory = True)
105 |     unlabeled_dataloader = DataLoader(unlabeled_train_dataset, batch_size = args.batch_size_unlabel, shuffle = True, num_workers = 4, pin_memory = True)
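    # One "epoch" below is measured by the unlabeled loader: with 10% labeled
    # data the labeled set is far smaller, so its iterator is simply re-created
    # whenever it is exhausted (the StopIteration handler in the batch loop).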
106 |
107 |     iter_num_s = 0
108 |     max_iterations_s = max_epoch * len(unlabeled_dataloader)
109 |     for epoch in range(max_epoch):
110 |         print("Start epoch ", epoch+1, "!")
111 |
112 |         tbar = tqdm(range(len(unlabeled_dataloader)), ncols=70)
113 |         labeled_dataloader_iter = iter(labeled_dataloader)
114 |         unlabeled_dataloader_iter = iter(unlabeled_dataloader)
115 |
116 |         style_output_global_positive_list = []
117 |
118 |
119 |         for batch_idx in tbar:
120 |
121 |             try:
122 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
123 |             except StopIteration:
124 |                 labeled_dataloader_iter = iter(labeled_dataloader)
125 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
126 |
127 |             style_output_global_positive_list = []
128 |
129 |             input_ul, target_ul, file_name_ul, lung_ul = next(unlabeled_dataloader_iter)
130 |             input_ul, target_ul, lung_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True), lung_ul.cuda(non_blocking=True)
131 |             input_l, target_l, lung_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True), lung_l.cuda(non_blocking=True)
132 |
133 |
134 |             if True:  # style-guidance switch; the else branch disables style codes
135 |                 # generate style codes from labeled data
136 |                 lung_l[target_l>0] = 3
137 |                 len_style = input_l.shape[0]
138 |                 for style_idx in range(len_style):
139 |                     with torch.no_grad():
140 |                         normalize_input = transforms.functional.normalize(input_l[style_idx], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
141 |                         output_ul_style = syn_model(normalize_input.unsqueeze_(0), lung_l[style_idx:style_idx+1].unsqueeze_(1), mode='style', data_path=[file_name_l[style_idx]])
142 |                         output_ul_style_check = torch.mean(output_ul_style, 2)
143 |                         # keep the style code only if class 3 (the lesion label) is present
144 |                         if output_ul_style_check[0,3] != 0:
145 |                             style_output_global_positive_list.append((output_ul_style).detach())
146 |                 style_output = style_output_global_positive_list
147 |                 if len(style_output) > 10:
148 |                     style_output = random.choices(style_output, k=10)
149 |             else:
150 |                 style_output = None
151 |
152 |
153 |
154 |             with torch.no_grad():
155 |                 t_output = teacher_model(input_ul)
156 |                 t_output = torch.softmax(t_output, dim=1)
157 |                 target_ul_pred = torch.argmax(t_output.detach(), dim=1, keepdim=False)
158 |
159 |                 # generate syn
160 |                 lung_ul[target_ul_pred>0] = 3
161 |                 syn_output_list = []
162 |                 len_syn = input_ul.shape[0]
163 |                 for syn_idx in range(len_syn):
164 |                     normalize_input = transforms.functional.normalize(input_ul[syn_idx], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
165 |                     output_ul_syn = syn_model(normalize_input.unsqueeze_(0), lung_ul[syn_idx:syn_idx+1].unsqueeze_(1), mode='inference', data_path=[file_name_ul[syn_idx]],
166 |                                               style_code=style_output, alpha=0)
167 |                     output_ul_syn_numpy = output_ul_syn[0].detach()
168 |                     output_ul_syn_numpy = (output_ul_syn_numpy + 1) / 2.0
169 |                     syn_output_list.append((output_ul_syn_numpy).unsqueeze_(0))
170 |                 syn_output = torch.cat(syn_output_list, 0)
171 |
172 |
173 |             volume_batch = torch.cat([input_l, input_ul, syn_output], 0)
174 |             label_batch = torch.cat([target_l, target_ul_pred, target_ul_pred], 0)
175 |
176 |             outputs = student_model(volume_batch)
177 |             outputs_soft = torch.softmax(outputs, dim=1)
178 |
179 |             # calculate loss
180 |             labeled_loss = 0.5 * (ce_loss(outputs[:args.batch_size_label], label_batch[:args.batch_size_label].long()) + dice_loss(
181 |                 outputs_soft[:args.batch_size_label], label_batch[:args.batch_size_label].unsqueeze(1)))
182 |             pseudo_supervision = 0.5 * (ce_loss(outputs[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].long()) + dice_loss(
183 |                 outputs_soft[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].unsqueeze(1)))
184 |             syn_supervision = 0.5 * (ce_loss(outputs[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].long()) + dice_loss(
185 |                 outputs_soft[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].unsqueeze(1)))
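            # The three terms above make up the student objective assembled just
            # below: loss = labeled_loss + c_pseudo * pseudo_supervision
            # + c_syn * syn_supervision, each term a 0.5 * (CE + Dice) mixture.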
186 |
187 |
188 |             # calculate loss
189 |             c_syn = args.consistency_syn
190 |             c_pseudo = args.consistency_pseudo
191 |             loss = labeled_loss + c_pseudo*pseudo_supervision + c_syn*syn_supervision
192 |
193 |
194 |             optimizer_s.zero_grad()
195 |             loss.backward()
196 |             optimizer_s.step()
197 |             lr_ = base_lr * (1.0 - iter_num_s / max_iterations_s) ** 0.9
198 |             for param_group in optimizer_s.param_groups:
199 |                 param_group['lr'] = lr_
200 |
201 |             iter_num_s = iter_num_s + 1
202 |             writer.add_scalar('info/lr', lr_, iter_num_s)
203 |             writer.add_scalar('info/total_loss', loss, iter_num_s)
204 |             writer.add_scalar('info/labeled_loss', labeled_loss, iter_num_s)
205 |             writer.add_scalar('info/pseudo_supervision', pseudo_supervision, iter_num_s)
206 |             logging.info('iteration %d : loss : %f, labeled_loss: %f, pseudo_supervision: %f' % (iter_num_s, loss.item(), labeled_loss.item(), pseudo_supervision.item()))
207 |
208 |
209 |     writer.close()
210 |
211 |
212 |
213 |
214 | if __name__ == "__main__":
215 |     cv2.setNumThreads(0)
216 |     cv2.ocl.setUseOpenCL(False)
217 |     seed = 66
218 |     print("[ Using Seed : ", seed, " ]")
219 |     torch.manual_seed(seed)
220 |     torch.cuda.manual_seed_all(seed)
221 |     torch.cuda.manual_seed(seed)
222 |     np.random.seed(seed)
223 |     random.seed(seed)
224 |     torch.backends.cudnn.deterministic = True
225 |     torch.backends.cudnn.benchmark = False
226 |     os.environ["PYTHONHASHSEED"] = str(seed)
227 |
228 |     snapshot_path = "{}exp/{}/exp_{}_{}_{}".format(args.root_path, args.dataset_name, args.exp, args.labeled_per, args.model)
229 |     if not os.path.exists(snapshot_path):
230 |         os.makedirs(snapshot_path)
231 |
232 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
233 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
234 |     logging.info(str(args))
235 |     train(args, snapshot_path)
236 |
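# A hedged pseudocode summary of one SAST iteration above (names shortened;
# B = batch_size_label = batch_size_unlabel):
#
#     pseudo = argmax(softmax(teacher(x_u)))                   # frozen teacher
#     x_syn  = syn_model(x_u, lung mask with pseudo set to 3)  # lesion-aware synthesis
#     out    = student(cat([x_l, x_u, x_syn]))
#     loss   = sup(out[:B], y_l)
#              + c_pseudo * sup(out[B:2B], pseudo)
#              + c_syn    * sup(out[2B:], pseudo)
#
# where sup(p, y) = 0.5 * (CE(p, y) + Dice(softmax(p), y)).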
--------------------------------------------------------------------------------
/code/train_SACPS.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | import shutil
6 | import sys
7 | import time
8 |
9 | import numpy as np
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from tensorboardX import SummaryWriter
16 | from torch.nn import BCEWithLogitsLoss
17 | from torch.nn.modules.loss import CrossEntropyLoss
18 | from torch.utils.data import DataLoader
19 | from torchvision import transforms
20 | from torchvision.utils import make_grid
21 | from tqdm import tqdm
22 | from itertools import cycle
23 |
24 | from dataloaders import utils
25 | from dataloaders.dataset_covid import (CovidDataSets, RandomGenerator)
26 | from networks.net_factory import net_factory
27 | from utils import losses, metrics, ramps
28 | from test_covid import get_model_metric
29 | import cv2
30 | from models.pix2pix_model import Pix2PixModel, get_opt
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--root_path', type=str, default='/home/', help='root path of the project')
34 |
35 | parser.add_argument('--consistency_syn', type=float, default=0.5, help='consistency')
36 | parser.add_argument('--consistency_pseudo', type=float, default=0.5, help='consistency')
37 |
38 | parser.add_argument('--labeled_per', type=float, default=0.1, help='percent of labeled data')
39 | if True:  # dataset switch: True selects the COVID249 branch
40 |     parser.add_argument('--dataset_name', type=str, default='COVID249', help='Name of dataset')
41 |     parser.add_argument('--excel_file_name_label', type=str, default='train_0.1_l.xlsx', help='excel file listing the labeled training slices')
42 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_0.1_u.xlsx', help='excel file listing the unlabeled training slices')
43 | else:
44 |     parser.add_argument('--dataset_name', type=str, default='MOS1000', help='Name of dataset')
45 |     parser.add_argument('--excel_file_name_label', type=str, default='train_slice_label.xlsx', help='excel file listing the labeled training slices')
46 |     parser.add_argument('--excel_file_name_unlabel', type=str, default='train_slice_unlabel.xlsx', help='excel file listing the unlabeled training slices')
47 |
48 | parser.add_argument('--exp', type=str, default='sacps', help='experiment_name')
49 | parser.add_argument('--model', type=str, default='unet', help='model_name')
50 | parser.add_argument('--max_epoch', type=int, default=20, help='maximum epoch number to train')
51 | parser.add_argument('--batch_size', type=int, default=8, help='batch_size per gpu')
52 | parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
53 | parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
54 | parser.add_argument('--patch_size', type=list, default=[512, 512], help='patch size of network input')
55 | parser.add_argument('--num_classes', type=int, default=2, help='output channel of network')
56 |
57 | # label and unlabel
58 | parser.add_argument('--labeled_bs', type=int, default=8, help='labeled_batch_size per gpu')
59 | parser.add_argument('--batch_size_label', type=int, default=8, help='batch_size per gpu')
60 | parser.add_argument('--batch_size_unlabel', type=int, default=8, help='batch_size per gpu')
61 |
62 | # costs
63 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
64 | parser.add_argument('--consistency_type', type=str, default="mse", help='consistency_type')
65 | parser.add_argument('--consistency', type=float, default=0.1, help='consistency')
66 | parser.add_argument('--consistency_rampup', type=float, default=10.0, help='consistency_rampup')
67 | parser.add_argument('--alpha', type=float, default=1.0, help='alpha')
68 |
69 | args = parser.parse_args()
70 |
71 |
72 | def train(args, snapshot_path):
73 |     base_lr = args.base_lr
74 |     num_classes = args.num_classes
75 |     batch_size = args.batch_size
76 |     max_epoch = args.max_epoch
77 |     excel_file_name_label = args.excel_file_name_label
78 |     excel_file_name_unlabel = args.excel_file_name_unlabel
79 |
80 |
81 |     # create model
82 |     model1 = net_factory(net_type=args.model)
83 |     model2 = net_factory(net_type=args.model)
84 |     opt = get_opt()
85 |     syn_model = Pix2PixModel(opt)
86 |     syn_model.eval()
87 |
88 |     # Define the dataset
89 |     labeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_label, aug = True)
90 |     unlabeled_train_dataset = CovidDataSets(root_path=args.root_path, dataset_name=args.dataset_name, file_name = excel_file_name_unlabel, aug = True)
91 |     print('The overall number of labeled training images equals to %d' % len(labeled_train_dataset))
92 |     print('The overall number of unlabeled training images equals to %d' % len(unlabeled_train_dataset))
93 |
94 |
95 |     # start training
96 |     model1.train()
97 |     model2.train()
98 |
99 |     optimizer1 = optim.SGD(model1.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
100 |     optimizer2 = optim.SGD(model2.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
101 |
102 |
103 |     ce_loss = CrossEntropyLoss()
104 |     dice_loss = losses.DiceLoss(num_classes)
105 |
106 |     writer = SummaryWriter(snapshot_path + '/log')
107 |     #logging.info("{} iterations per epoch".format(len(trainloader)))
108 |
109 |     iter_num = 0
110 |     best_performance1 = 0.0
111 |     best_performance2 = 0.0
112 |
113 |     # Define the dataloader
114 |     labeled_dataloader = DataLoader(labeled_train_dataset, batch_size = args.batch_size_label, shuffle = True, num_workers = 4, pin_memory = True)
115 |     unlabeled_dataloader = DataLoader(unlabeled_train_dataset, batch_size = args.batch_size_unlabel, shuffle = True, num_workers = 4, pin_memory = True)
116 |     max_iterations = max_epoch * len(unlabeled_dataloader)
117 |
118 |     for epoch in range(max_epoch):
119 |         print("Start epoch ", epoch, "!")
120 |
121 |         style_output_global_positive_list = []
122 |         style_output_global_list = []
123 |
124 |         tbar = tqdm(range(len(unlabeled_dataloader)), ncols=70)
125 |         labeled_dataloader_iter = iter(labeled_dataloader)
126 |         unlabeled_dataloader_iter = iter(unlabeled_dataloader)
127 |
128 |         for batch_idx in tbar:
129 |             try:
130 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
131 |             except StopIteration:
132 |                 labeled_dataloader_iter = iter(labeled_dataloader)
133 |                 input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
134 |                 print('length: style_output_global_positive_list')
135 |                 print(len(style_output_global_positive_list))
136 |
137 |             style_output_global_positive_list = []
138 |             style_output_global_list = []
139 |
140 |             input_ul, target_ul, file_name_ul, lung_ul = next(unlabeled_dataloader_iter)
141 |             input_ul, target_ul, lung_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True), lung_ul.cuda(non_blocking=True)
142 |             input_l, target_l, lung_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True), lung_l.cuda(non_blocking=True)
143 |
144 |             if input_l.shape[0] != args.batch_size_label:
145 |                 continue
146 |
147 |             # generate style codes from labeled data
148 |             lung_l[target_l>0] = 3
149 |             style_output_list = []
150 |             len_style = input_l.shape[0]
151 |             for style_idx in range(len_style):
152 |                 with torch.no_grad():
153 |                     normalize_input = transforms.functional.normalize(input_l[style_idx], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
154 |                     output_ul_style = syn_model(normalize_input.unsqueeze_(0), lung_l[style_idx:style_idx+1].unsqueeze_(1), mode='style', data_path=[file_name_l[style_idx]])
155 |                     output_ul_style_check = torch.mean(output_ul_style, 2)
156 |                     # keep the style code only if class 3 (the lesion label) is present
157 |                     if output_ul_style_check[0,3] != 0:
158 |                         style_output_global_positive_list.append((output_ul_style).detach())
159 |                     style_output_list.append((output_ul_style).detach())
160 |
161 |             style_output = style_output_global_positive_list #+ style_output_list
162 |             if len(style_output) > 10:
163 |                 style_output = random.choices(style_output, k=10)
164 |
165 |             # get pseudo labels from model1 for unlabeled data
166 |             model1.eval()
167 |             with torch.no_grad():
168 |                 outputs_unlabeled = model1(input_ul)
169 |                 outputs_unlabeled_soft = torch.softmax(outputs_unlabeled, dim=1)
170 |                 pseudo_labels_s1 = torch.argmax(outputs_unlabeled_soft.detach(), dim=1, keepdim=False)
171 |             model1.train()
172 |
173 |             # get pseudo labels from model2 for unlabeled data
174 |             model2.eval()
175 |             with torch.no_grad():
176 |                 outputs_unlabeled = model2(input_ul)
177 |                 outputs_unlabeled_soft = torch.softmax(outputs_unlabeled, dim=1)
178 |                 pseudo_labels_s2 = torch.argmax(outputs_unlabeled_soft.detach(), dim=1, keepdim=False)
179 |             model2.train()
180 |
181 |             # exchange pseudo labels (cross pseudo supervision)
182 |             pseudo_labels_1 = pseudo_labels_s2
183 |             pseudo_labels_2 = pseudo_labels_s1
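            # This swap is the heart of cross pseudo supervision: each network is
            # supervised by the other's hard pseudo labels, so the two models
            # regularize each other rather than reinforcing their own mistakes.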
184 |
185 |
186 |             # generate syn for model 1
187 |             syn_mask = lung_ul.clone()  # clone so the two branches do not mutate the shared lung mask in place
188 |             syn_mask[pseudo_labels_1>0] = 3
189 |             syn_output_list = []
190 |             len_syn = input_ul.shape[0]
191 |             for syn_idx in range(len_syn):
192 |                 with torch.no_grad():
193 |                     normalize_input = transforms.functional.normalize(input_ul[syn_idx], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
194 |                     output_ul_syn = syn_model(normalize_input.unsqueeze_(0), syn_mask[syn_idx:syn_idx+1].unsqueeze_(1), mode='inference', data_path=[file_name_ul[syn_idx]],
195 |                                               style_code=style_output, alpha=0)
196 |                     output_ul_syn_numpy = output_ul_syn[0].detach()
197 |                     output_ul_syn_numpy = (output_ul_syn_numpy + 1) / 2.0
198 |                     syn_output_list.append((output_ul_syn_numpy).unsqueeze_(0))
199 |             syn_output_1 = torch.cat(syn_output_list, 0)
200 |
201 |             # generate syn for model 2
202 |             syn_mask = lung_ul.clone()
203 |             syn_mask[pseudo_labels_2>0] = 3
204 |             syn_output_list = []
205 |             len_syn = input_ul.shape[0]
206 |             for syn_idx in range(len_syn):
207 |                 with torch.no_grad():
208 |                     normalize_input = transforms.functional.normalize(input_ul[syn_idx], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
209 |                     output_ul_syn = syn_model(normalize_input.unsqueeze_(0), syn_mask[syn_idx:syn_idx+1].unsqueeze_(1), mode='inference', data_path=[file_name_ul[syn_idx]],
210 |                                               style_code=style_output, alpha=args.alpha)  # the bare name `alpha` was undefined here; args.alpha is the parsed argument
211 |                     output_ul_syn_numpy = output_ul_syn[0].detach()
212 |                     output_ul_syn_numpy = (output_ul_syn_numpy + 1) / 2.0
213 |                     syn_output_list.append((output_ul_syn_numpy).unsqueeze_(0))
214 |             syn_output_2 = torch.cat(syn_output_list, 0)
215 |
216 |
217 |
218 |             # train model 1
219 |             volume_batch = torch.cat([input_l, syn_output_1, input_ul], 0)
220 |             label_batch = torch.cat([target_l, pseudo_labels_1, pseudo_labels_1], 0)
221 |
222 |             outputs_1 = model1(volume_batch)
223 |             outputs_soft_1 = torch.softmax(outputs_1, dim=1)
224 |
225 |             labeled_loss_1 = 0.5 * (ce_loss(outputs_1[:args.batch_size_label], label_batch[:args.batch_size_label].long()) + dice_loss(
226 |                 outputs_soft_1[:args.batch_size_label], label_batch[:args.batch_size_label].unsqueeze(1)))
227 |             syn_supervision_1 = 0.5 * (ce_loss(outputs_1[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].long()) + dice_loss(
228 |                 outputs_soft_1[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].unsqueeze(1)))
229 |             pseudo_supervision_1 = 0.5 * (ce_loss(outputs_1[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].long()) + dice_loss(
230 |                 outputs_soft_1[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].unsqueeze(1)))
231 |
232 |             # train model 2
233 |             volume_batch = torch.cat([input_l, syn_output_2, input_ul], 0)
234 |             label_batch = torch.cat([target_l, pseudo_labels_2, pseudo_labels_2], 0)
235 |
236 |             outputs_2 = model2(volume_batch)
237 |             outputs_soft_2 = torch.softmax(outputs_2, dim=1)
238 |
239 |             labeled_loss_2 = 0.5 * (ce_loss(outputs_2[:args.batch_size_label], label_batch[:args.batch_size_label].long()) + dice_loss(
240 |                 outputs_soft_2[:args.batch_size_label], label_batch[:args.batch_size_label].unsqueeze(1)))
241 |             syn_supervision_2 = 0.5 * (ce_loss(outputs_2[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].long()) + dice_loss(
242 |                 outputs_soft_2[args.batch_size_label:args.batch_size_label*2], label_batch[args.batch_size_label:args.batch_size_label*2].unsqueeze(1)))
243 |             pseudo_supervision_2 = 0.5 * (ce_loss(outputs_2[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].long()) + dice_loss(
244 |                 outputs_soft_2[args.batch_size_label*2:], label_batch[args.batch_size_label*2:].unsqueeze(1)))
245 |
246 |
247 |
248 |             # calculate loss
249 |             c_syn = args.consistency_syn
250 |             c_pseudo = args.consistency_pseudo
251 |
252 |             model1_loss = labeled_loss_1 + c_syn*syn_supervision_1 + c_pseudo*pseudo_supervision_1
253 |             model2_loss = labeled_loss_2 + c_syn*syn_supervision_2 + c_pseudo*pseudo_supervision_2
254 |             loss = model1_loss + model2_loss
255 |
256 |
257 |             optimizer1.zero_grad()
258 |             optimizer2.zero_grad()
259 |             loss.backward()
260 |             optimizer1.step()
261 |             optimizer2.step()
262 |
263 |
264 |             # write summary
265 |             iter_num = iter_num + 1
266 |             lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
267 |             for param_group in optimizer1.param_groups:
268 |                 param_group['lr'] = lr_
269 |             for param_group in optimizer2.param_groups:
270 |                 param_group['lr'] = lr_
271 |             writer.add_scalar('lr', lr_, iter_num)
272 |             writer.add_scalar('loss/model1_loss', model1_loss, iter_num)
273 |             writer.add_scalar('loss/model2_loss', model2_loss, iter_num)
274 |             logging.info('iteration %d : model1 loss : %f model2 loss : %f' % (iter_num, model1_loss.item(), model2_loss.item()))
275 |
276 |     writer.close()
277 |
278 |
279 | if __name__ == "__main__":
280 |     cv2.setNumThreads(0)
281 |     cv2.ocl.setUseOpenCL(False)
282 |     seed = 66
283 |     print("[ Using Seed : ", seed, " ]")
284 |     torch.manual_seed(seed)
285 |     torch.cuda.manual_seed_all(seed)
286 |     torch.cuda.manual_seed(seed)
287 |     np.random.seed(seed)
288 |     random.seed(seed)
289 |     torch.backends.cudnn.deterministic = True
290 |     torch.backends.cudnn.benchmark = False
291 |     os.environ["PYTHONHASHSEED"] = str(seed)
292 |
293 |     snapshot_path = "{}exp/{}/exp_{}_{}_{}".format(args.root_path, args.dataset_name, args.exp, args.labeled_per, args.model)
294 |     if not os.path.exists(snapshot_path):
295 |         os.makedirs(snapshot_path)
296 |
297 |     logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
298 |     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
299 |     logging.info(str(args))
300 |     train(args, snapshot_path)
301 |
--------------------------------------------------------------------------------