├── .gitignore ├── LICENSE ├── README.md ├── acs_joint_train.py ├── acs_train.py ├── architecture.png ├── args.py ├── experiments.txt ├── kd_train.py ├── logs ├── cas │ ├── ours_1.log │ └── ours_2.log ├── cas_ablation │ ├── ours_no_c_adv_0.log │ ├── ours_no_c_adv_1.log │ ├── ours_no_c_adv_2.log │ ├── ours_no_gan_0.log │ ├── ours_no_gan_1.log │ ├── ours_no_gan_2.log │ ├── ours_no_ganvaelr_0.log │ ├── ours_no_ganvaelr_1.log │ ├── ours_no_ganvaelr_2.log │ ├── ours_no_vae_0.log │ ├── ours_no_vae_1.log │ └── ours_no_vae_2.log ├── cas_joint │ └── ours_joint.log ├── kd │ ├── kd_0.log │ ├── kd_1.log │ └── kd_2.log ├── kd_lambda │ ├── kd_lambda_01.log │ ├── kd_lambda_05.log │ └── kd_lambda_1.log ├── mas │ ├── mas_0.log │ ├── mas_1.log │ └── mas_2.log ├── mas_lambda │ ├── mas_lambda_01.log │ ├── mas_lambda_05.log │ └── mas_lambda_1.log ├── u_net │ ├── unet_0.log │ ├── unet_1.log │ └── unet_2.log ├── u_net_b │ ├── unet_b_0.log │ ├── unet_b_1.log │ └── unet_b_2.log └── u_net_joint │ └── unet_joint.log ├── mas_train.py ├── mp ├── agents │ ├── acs_agent.py │ ├── agent.py │ ├── autoencoding_agent.py │ ├── kd_agent.py │ ├── mas_agent.py │ ├── segmentation_agent.py │ └── unet_agent.py ├── data │ ├── data.py │ ├── datasets │ │ ├── dataset.py │ │ ├── dataset_classification.py │ │ ├── dataset_segmentation.py │ │ ├── dataset_utils.py │ │ ├── ds_mr_cardiac_mm.py │ │ ├── ds_mr_hippocampus_decathlon.py │ │ ├── ds_mr_hippocampus_dryad.py │ │ ├── ds_mr_hippocampus_harp.py │ │ └── ds_mr_prostate_decathlon.py │ └── pytorch │ │ ├── pytorch_dataset.py │ │ ├── pytorch_seg_dataset.py │ │ └── transformation.py ├── eval │ ├── accumulator.py │ ├── evaluate.py │ ├── inference │ │ ├── predict.py │ │ └── predictor.py │ ├── losses │ │ ├── loss_abstract.py │ │ ├── losses_autoencoding.py │ │ └── losses_segmentation.py │ ├── metrics │ │ ├── mean_scores.py │ │ └── scores.py │ └── result.py ├── experiments │ ├── data_splitting.py │ └── experiment.py ├── models │ ├── autoencoding │ │ ├── autoencoder.py │ │ ├── autoencoder_cnn.py │ │ ├── autoencoder_featured.py │ │ └── autoencoder_linear.py │ ├── classification │ │ └── small_cnn.py │ ├── continual │ │ ├── acs.py │ │ ├── kd.py │ │ ├── mas.py │ │ └── model_utils.py │ ├── model.py │ └── segmentation │ │ ├── model_utils.py │ │ ├── segmentation_model.py │ │ ├── unet_fepegar.py │ │ └── unet_milesial.py ├── paths.py ├── utils │ ├── connection │ │ └── check_connection.py │ ├── helper_functions.py │ ├── introspection.py │ ├── load_restore.py │ ├── pytorch │ │ ├── compute_normalization_values.py │ │ └── pytorch_load_restore.py │ ├── seaborn │ │ └── legend_utils.py │ ├── tensorboard.py │ └── update_bots │ │ └── telegram_bot.py └── visualization │ ├── confusion_matrix.py │ ├── plot_results.py │ └── visualize_imgs.py ├── qualitative_results.png ├── requirements.txt ├── setup.py ├── test ├── agents │ └── model_state_restore.py ├── cuda │ └── test_cuda.py ├── data │ ├── datasets │ │ ├── test_ds_mr_cardiac_mm.py │ │ └── test_ds_mr_prostate_decathlon.py │ └── pytorch │ │ └── test_transformation.py ├── eval │ ├── inference │ │ └── test_predict.py │ ├── losses │ │ └── test_losses_segmentation.py │ ├── metrics │ │ └── test_metrics_segmentation.py │ ├── test_accumulator.py │ └── test_result.py ├── experiment │ ├── test_data_splitting.py │ └── test_experiment.py ├── test_obj │ ├── 3dimg.png │ ├── 3dsegm.png │ ├── README.txt │ ├── agent_states_prostate_2D │ │ └── epoch_300 │ │ │ ├── agent_state_dict.pkl │ │ │ ├── model │ │ │ └── optimizer │ ├── example_result.png │ ├── img_00.nii │ ├── mask_00.nii │ └── 
test_confusion_matrix.png ├── utils │ ├── test_helper_functions.py │ └── test_introspection.py └── visualization │ ├── test_confusion_matrix.py │ ├── test_plot_results.py │ └── test_visualize_imgs.py └── unet_joint_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore the following directories and files: 2 | /ignored/* 3 | /storage/* 4 | paths.py 5 | 6 | # Private files (e.g. containing private data information) 7 | *_private.py 8 | 9 | # VS code project setup 10 | .vscode 11 | 12 | # Byte-compiled / optimized / DLL files 13 | *__pycache__* 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | 116 | # dataset files 117 | dataset -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 camgbus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Continual Learning for Multi-Domain Hippocampal Segmentation 2 | 3 | ## Abstract 4 | Deep learning for medical imaging suffers from temporal and privacy-related restrictions on data availability. To still obtain viable models, continual learning aims to train models sequentially, as and when data becomes available. The main challenge that continual learning methods face is preventing catastrophic forgetting, i.e., a decrease in performance on the data encountered earlier. This issue makes continuous training of segmentation models for medical applications extremely difficult. Yet, often, data from at least two different domains is available, which we can exploit to train the model in a way that disregards domain-specific information. We propose an architecture that leverages the simultaneous availability of two or more datasets to learn a disentanglement between content and domain in an adversarial fashion. The domain-invariant content representation then forms the basis for continual semantic segmentation. Our approach takes inspiration from domain adaptation and combines it with continual learning for hippocampal segmentation in brain MRI. We showcase that our method reduces catastrophic forgetting and outperforms state-of-the-art continual learning methods. 5 | 6 | For more information, please refer to our [paper](https://arxiv.org/abs/2107.08751). 7 | 8 | ## ACS Architecture 9 | ![alt text](https://github.com/memmelma/continual_adversarial_segmenter/raw/master/architecture.png) 10 | 11 | ## Qualitative Results 12 | ![alt text](https://github.com/memmelma/continual_adversarial_segmenter/raw/master/qualitative_results.png) 13 | Legend: **VP MRI** (original), **GT** (ground truth segmentation), **ACS** (segmentation), **GAN O/P** (output of GAN generator) 14 | 15 | ## Setup 16 | This repository builds on [medical_pytorch](https://github.com/camgbus/medical_pytorch) and [torchio](https://github.com/fepegar/torchio). Please install this repository as explained in [medical_pytorch](https://github.com/camgbus/medical_pytorch). We provide an implementation of the [adversarial continual segmenter (ACS)](https://arxiv.org/abs/2107.08751), and the baselines [memory aware synapses (MAS)](https://arxiv.org/abs/2005.00079) and [knowledge distillation (KD)](https://arxiv.org/abs/1907.13372). 17 | 18 | Our core implementation consists of the following structure: 19 | ``` 20 | mp 21 | ├── agents 22 | | ├── ACSAgent 23 | | ├── KDAgent 24 | | ├── MASAgent 25 | | ├── UNETAgent 26 | ├── data 27 | | ├── PytorchSeg2DDatasetDomain 28 | ├── models 29 | | ├── continual 30 | | ├── ACS 31 | | ├── MAS 32 | | ├── KD 33 | ├── *_train.py 34 | ``` 35 | 36 | ## Usage 37 | Use the `*_train.py` scripts to run the experiments, setting arguments as defined in `args.py`. The exact execution commands used to produce our results are given in `experiments.txt`, and console logs of the runs are provided in `logs/`.
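For example, the first ACS run from `experiments.txt` (shown here without the `nohup` wrapper and log redirection) trains on dataset combination 0 across four GPUs:
```
python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --combination 0 --experiment-name ours_0 --eval
```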
38 | 39 | ## Datasets 40 | Datasets should be placed in `storage/data/` and can be loaded via dataloaders provided in `mp.data.datasets` or by custom implementations. We use the following three datasets: 41 | - _DecathlonHippocampus_ [A](http://medicaldecathlon.com/) 42 | - _DryadHippocampus_ [B](https://www.nature.com/articles/sdata201559) 43 | - _HarP_ [C](https://pubmed.ncbi.nlm.nih.gov/25616957/). 44 | 45 | ## Additional Features 46 | - extensive logging via tensorboard 47 | - load/save/resume training 48 | - multi-GPU support 49 | 50 | ## Acknowledgements 51 | Supported by the _Bundesministerium für Gesundheit_ (BMG) with grant [ZMVI1-2520DAT03A] 52 | -------------------------------------------------------------------------------- /acs_joint_train.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Code to train ACS on all datasets simultaneously 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import sys 8 | from args import parse_args_as_dict 9 | 10 | import torch 11 | torch.set_num_threads(6) 12 | from torch.utils.data import DataLoader 13 | import torch.optim as optim 14 | 15 | from mp.experiments.experiment import Experiment 16 | from mp.data.data import Data 17 | from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus 18 | from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus 19 | from mp.data.datasets.ds_mr_hippocampus_harp import HarP 20 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDatasetDomain 21 | from mp.eval.losses.losses_segmentation import LossClassWeighted, LossDiceBCE 22 | from mp.agents.acs_agent import ACSAgent 23 | from mp.eval.result import Result 24 | from mp.utils.tensorboard import create_writer 25 | from mp.utils.helper_functions import seed_all 26 | from mp.models.continual.acs import ACS 27 | 28 | # Get configuration from arguments 29 | config = parse_args_as_dict(sys.argv[1:]) 30 | seed_all(42) 31 | 32 | config['class_weights'] = (0., 1.)
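# Editor's note (interpretation, not part of the original script): with
# merge_labels=True the datasets expose two labels (background, hippocampus),
# so the weight tuple (0., 1.) passed to LossClassWeighted below effectively
# zeroes out the background term and lets the foreground class drive the
# Dice+BCE loss.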
33 | 34 | print('config', config) 35 | 36 | # Create experiment directories 37 | exp = Experiment(config=config, name=config['experiment_name'], notes='', reload_exp=(config['resume_epoch'] is not None)) 38 | 39 | # Datasets 40 | data = Data() 41 | 42 | dataset_domain_a = DecathlonHippocampus(merge_labels=True) 43 | dataset_domain_a.name = 'DecathlonHippocampus' 44 | data.add_dataset(dataset_domain_a) 45 | 46 | dataset_domain_b = DryadHippocampus(merge_labels=True) 47 | dataset_domain_b.name = 'DryadHippocampus' 48 | data.add_dataset(dataset_domain_b) 49 | 50 | dataset_domain_c = HarP(merge_labels=True) 51 | dataset_domain_c.name = 'HarP' 52 | data.add_dataset(dataset_domain_c) 53 | 54 | nr_labels = data.nr_labels 55 | label_names = data.label_names 56 | 57 | if config['combination'] == 0: 58 | ds_a = ('DecathlonHippocampus', 'train') 59 | ds_b = ('DryadHippocampus', 'train') 60 | ds_c = ('HarP', 'train') 61 | elif config['combination'] == 1: 62 | ds_a = ('DecathlonHippocampus', 'train') 63 | ds_c = ('DryadHippocampus', 'train') 64 | ds_b = ('HarP', 'train') 65 | elif config['combination'] == 2: 66 | ds_c = ('DecathlonHippocampus', 'train') 67 | ds_b = ('DryadHippocampus', 'train') 68 | ds_a = ('HarP', 'train') 69 | 70 | # Create data splits for each repetition 71 | exp.set_data_splits(data) 72 | 73 | # Now repeat for each repetition 74 | for run_ix in range(config['nr_runs']): 75 | exp_run = exp.get_run(run_ix=0, reload_exp_run=(config['resume_epoch'] is not None)) 76 | 77 | # Bring data to Pytorch format and add domain_code 78 | datasets = dict() 79 | for idx, item in enumerate(data.datasets.items()): 80 | ds_name, ds = item 81 | for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items(): 82 | data_ixs = data_ixs[:config['n_samples']] 83 | if len(data_ixs) > 0: # Sometimes val indexes may be an empty list 84 | aug = config['augmentation'] if not('test' in split) else 'none' 85 | datasets[(ds_name, split)] = PytorchSeg2DDatasetDomain(ds, 86 | ix_lst=data_ixs, size=config['input_shape'] , aug_key=aug, 87 | resize=(not config['no_resize']), domain_code=idx, domain_code_size=config['domain_code_size']) 88 | 89 | dataset = torch.utils.data.ConcatDataset((datasets[(ds_a)], datasets[(ds_b)], datasets[(ds_c)])) 90 | train_dataloader_0 = DataLoader(dataset, batch_size=config['batch_size'], drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 91 | 92 | if config['eval']: 93 | drop = [] 94 | for key in datasets.keys(): 95 | if 'train' in key or 'val' in key: 96 | drop += [key] 97 | for d in drop: 98 | datasets.pop(d) 99 | 100 | model = ACS(input_shape=config['input_shape'], nr_labels=nr_labels, domain_code_size=config['domain_code_size'], latent_scaler_sample_size=250, 101 | unet_dropout=config['unet_dropout'], unet_monte_carlo_dropout=config['unet_monte_carlo_dropout'], unet_preactivation=config['unet_preactivation']) 102 | 103 | model.to(config['device']) 104 | 105 | # Define loss and optimizer 106 | loss_g = LossDiceBCE(bce_weight=1., smooth=1., device=config['device']) 107 | loss_f = LossClassWeighted(loss=loss_g, weights=config['class_weights'], device=config['device']) 108 | 109 | # Set optimizers 110 | model.set_optimizers(optim.Adam, lr=config['lr']) 111 | 112 | # Train model 113 | results = Result(name='training_trajectory') 114 | 115 | agent = ACSAgent(model=model, label_names=label_names, device=config['device']) 116 | agent.summary_writer = create_writer(os.path.join(exp_run.paths['states'], '..'), 0) 117 | 118 | init_epoch = 0 119 | 
nr_epochs = config['epochs'] 120 | 121 | # Resume training 122 | if config['resume_epoch'] is not None: 123 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 124 | init_epoch = agent.agent_state_dict['epoch'] + 1 125 | 126 | config['continual'] = False 127 | 128 | # Joint Training 129 | agent.train(results, loss_f, train_dataloader_0, train_dataloader_0, config, 130 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 131 | eval_datasets=datasets, eval_interval=config['eval_interval'], 132 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 133 | display_interval=config['display_interval'], 134 | resume_epoch=config['resume_epoch'], device_ids=config['device_ids']) 135 | 136 | print('Finished training on A and B and C') 137 | 138 | # Save and print results for this experiment run 139 | exp_run.finish(results=results, plot_metrics=['Mean_LossBCEWithLogits', 'Mean_LossDice[smooth=1.0]', 'Mean_LossCombined[1.0xLossDice[smooth=1.0]+1.0xLossBCEWithLogits]']) 140 | -------------------------------------------------------------------------------- /acs_train.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Code to train ACS in continual fashion 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import sys 8 | from args import parse_args_as_dict 9 | 10 | import torch 11 | torch.set_num_threads(6) 12 | from torch.utils.data import DataLoader 13 | import torch.optim as optim 14 | 15 | from mp.utils.helper_functions import seed_all 16 | from mp.experiments.experiment import Experiment 17 | from mp.data.data import Data 18 | from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus 19 | from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus 20 | from mp.data.datasets.ds_mr_hippocampus_harp import HarP 21 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDatasetDomain 22 | from mp.eval.losses.losses_segmentation import LossClassWeighted, LossDiceBCE 23 | from mp.eval.result import Result 24 | from mp.models.continual.acs import ACS 25 | from mp.agents.acs_agent import ACSAgent 26 | from mp.utils.tensorboard import create_writer 27 | 28 | # Get configuration from arguments 29 | config = parse_args_as_dict(sys.argv[1:]) 30 | seed_all(42) 31 | config['class_weights'] = (0., 1.) 
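# Overview of the two-stage schedule implemented below: stage 1 trains ACS
# jointly on domains A and B for the first config['epochs'] // 2 epochs with
# config['continual'] = False; stage 2 then freezes all parameters except the
# last two U-Net decoder blocks and the classifier and fine-tunes on domain C
# with lr_2, a StepLR decay, and config['continual'] = True.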
32 | print('config', config) 33 | 34 | # Create experiment directories 35 | exp = Experiment(config=config, name=config['experiment_name'], notes='', reload_exp=(config['resume_epoch'] is not None)) 36 | 37 | # Datasets 38 | data = Data() 39 | 40 | dataset_domain_a = DecathlonHippocampus(merge_labels=True) 41 | dataset_domain_a.name = 'DecathlonHippocampus' 42 | data.add_dataset(dataset_domain_a) 43 | 44 | dataset_domain_b = DryadHippocampus(merge_labels=True) 45 | dataset_domain_b.name = 'DryadHippocampus' 46 | data.add_dataset(dataset_domain_b) 47 | 48 | dataset_domain_c = HarP(merge_labels=True) 49 | dataset_domain_c.name = 'HarP' 50 | data.add_dataset(dataset_domain_c) 51 | 52 | nr_labels = data.nr_labels 53 | label_names = data.label_names 54 | 55 | if config['combination'] == 0: 56 | ds_a = ('DecathlonHippocampus', 'train') 57 | ds_b = ('DryadHippocampus', 'train') 58 | ds_c = ('HarP', 'train') 59 | elif config['combination'] == 1: 60 | ds_a = ('DecathlonHippocampus', 'train') 61 | ds_c = ('DryadHippocampus', 'train') 62 | ds_b = ('HarP', 'train') 63 | elif config['combination'] == 2: 64 | ds_c = ('DecathlonHippocampus', 'train') 65 | ds_b = ('DryadHippocampus', 'train') 66 | ds_a = ('HarP', 'train') 67 | 68 | # Create data splits for each repetition 69 | exp.set_data_splits(data) 70 | 71 | # Now repeat for each repetition 72 | for run_ix in range(config['nr_runs']): 73 | exp_run = exp.get_run(run_ix=0, reload_exp_run=(config['resume_epoch'] is not None)) 74 | 75 | # Bring data to Pytorch format and add domain_code 76 | datasets = dict() 77 | for idx, item in enumerate(data.datasets.items()): 78 | ds_name, ds = item 79 | for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items(): 80 | data_ixs = data_ixs[:config['n_samples']] 81 | if len(data_ixs) > 0: 82 | aug = config['augmentation'] if not('test' in split) else 'none' 83 | datasets[(ds_name, split)] = PytorchSeg2DDatasetDomain(ds, 84 | ix_lst=data_ixs, size=config['input_shape'] , aug_key=aug, 85 | resize=(not config['no_resize']), domain_code=idx, domain_code_size=config['domain_code_size']) 86 | 87 | dataset = torch.utils.data.ConcatDataset((datasets[(ds_a)], datasets[(ds_b)])) 88 | train_dataloader_0 = DataLoader(dataset, batch_size=config['batch_size'], drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 89 | train_dataloader_1 = DataLoader(datasets[(ds_c)], batch_size=config['batch_size'], shuffle=True, drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 90 | 91 | if config['eval']: 92 | drop = [] 93 | for key in datasets.keys(): 94 | if 'train' in key or 'val' in key: 95 | drop += [key] 96 | for d in drop: 97 | datasets.pop(d) 98 | 99 | model = ACS(input_shape=config['input_shape'], nr_labels=nr_labels, domain_code_size=config['domain_code_size'], latent_scaler_sample_size=250, 100 | unet_dropout=config['unet_dropout'], unet_monte_carlo_dropout=config['unet_monte_carlo_dropout'], unet_preactivation=config['unet_preactivation']) 101 | 102 | model.to(config['device']) 103 | 104 | # Define loss and optimizer 105 | loss_g = LossDiceBCE(bce_weight=1., smooth=1., device=config['device']) 106 | loss_f = LossClassWeighted(loss=loss_g, weights=config['class_weights'], device=config['device']) 107 | 108 | # Set optimizers 109 | model.set_optimizers(optim.Adam, lr=config['lr']) 110 | 111 | # Setup model 112 | results = Result(name='training_trajectory') 113 | 114 | agent = ACSAgent(model=model, label_names=label_names, device=config['device']) 115 
| agent.summary_writer = create_writer(os.path.join(exp_run.paths['states'], '..'), 0) 116 | 117 | init_epoch = 0 118 | nr_epochs = config['epochs'] // 2 119 | 120 | # Resume training 121 | if config['resume_epoch'] is not None: 122 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 123 | init_epoch = agent.agent_state_dict['epoch'] + 1 124 | 125 | # Training Stage 1 126 | if init_epoch < config['epochs'] / 2: 127 | config['continual'] = False 128 | agent.train(results, loss_f, train_dataloader_0, train_dataloader_1, config, 129 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 130 | eval_datasets=datasets, eval_interval=config['eval_interval'], 131 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 132 | display_interval=config['display_interval'], 133 | resume_epoch=config['resume_epoch'], device_ids=config['device_ids']) 134 | 135 | print('Finished training on A and B, starting training on C') 136 | 137 | init_epoch = config['epochs'] // 2 138 | nr_epochs = config['epochs'] 139 | 140 | # Resume training 141 | if config['resume_epoch'] is not None: 142 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 143 | init_epoch = agent.agent_state_dict['epoch'] + 1 144 | 145 | if init_epoch >= config['epochs'] / 2: 146 | 147 | # Freeze model for fine-tuning 148 | config['continual'] = True 149 | for param in model.parameters(): 150 | param.requires_grad = False 151 | 152 | # Unfreeze last two U-Net blocks 153 | if len(config['device_ids']) > 1: 154 | for param in model.unet.decoder.module.decoding_blocks[-2].parameters(): 155 | param.requires_grad = True 156 | for param in model.unet.decoder.module.decoding_blocks[-1].parameters(): 157 | param.requires_grad = True 158 | for param in model.unet.classifier.parameters(): 159 | param.requires_grad = True 160 | else: 161 | for param in model.unet.decoder.decoding_blocks[-2].parameters(): 162 | param.requires_grad = True 163 | for param in model.unet.decoder.decoding_blocks[-1].parameters(): 164 | param.requires_grad = True 165 | for param in model.unet.classifier.parameters(): 166 | param.requires_grad = True 167 | 168 | # Set optimizers 169 | model.set_optimizers(optim.Adam, lr=config['lr_2']) 170 | config['continual'] = True 171 | model.unet_scheduler = torch.optim.lr_scheduler.StepLR(model.unet_optim, (nr_epochs-init_epoch), gamma=0.1, last_epoch=-1) 172 | 173 | # Training Stage 2 174 | agent.train(results, loss_f, train_dataloader_1, train_dataloader_0, config, 175 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 176 | eval_datasets=datasets, eval_interval=config['eval_interval'], 177 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 178 | display_interval=config['display_interval'], 179 | resume_epoch=config['resume_epoch'], device_ids=[0]) 180 | 181 | print('Finished training on C') 182 | 183 | # Save and print results for this experiment run 184 | exp_run.finish(results=results, plot_metrics=['Mean_LossBCEWithLogits', 'Mean_LossDice[smooth=1.0]', 'Mean_LossCombined[1.0xLossDice[smooth=1.0]+1.0xLossBCEWithLogits]']) 185 | -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/architecture.png -------------------------------------------------------------------------------- /args.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | # parse train options 5 | def _get_parser(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # general 9 | parser.add_argument('--experiment-name', type=str, default='', help='experiment name for a new or resumed run') 10 | parser.add_argument('--nr-runs', type=int, default=1, help='# of runs') 11 | 12 | # hardware 13 | parser.add_argument('--device', type=str, default='cuda', help='device type cpu or cuda') 14 | parser.add_argument("--device-ids", nargs="+", default=[0], type=int, help="ID(s) of GPU device(s)") 15 | parser.add_argument('--n-workers', type=int, default=2, help='number of workers per GPU; total workers = n-workers * number of GPUs') 16 | 17 | # dataset 18 | parser.add_argument('--test-ratio', type=float, default=0.2, help='ratio of data to be used for testing') 19 | parser.add_argument('--val-ratio', type=float, default=0.125, help='ratio of data to be used for validation') 20 | parser.add_argument('--input_dim_c', type=int, default=1, help='input channels for images') 21 | parser.add_argument('--input_dim_hw', type=int, default=256, help='height and width for images') 22 | parser.add_argument('--no-resize', action='store_true', help='specify if images should not be resized') 23 | parser.add_argument('--augmentation', type=str, default='none', help='augmentation to be used') 24 | parser.add_argument('--n-samples', type=int, default=None, help='# of samples per dataloader, only use when debugging') 25 | parser.add_argument('--sampler', action='store_true', help='sample datasets to have equal # of samples') 26 | parser.add_argument('--combination', type=int, default=0, help='0: ab->c, 1: ac->b, 2: bc->a') 27 | 28 | # training 29 | parser.add_argument('--epochs', type=int, default=60, help='# of epochs') 30 | parser.add_argument('--batch-size', type=int, default=32, help='batch size') 31 | parser.add_argument('--lr', type=float, default=2e-4, help='learning rate for training stage 1') 32 | parser.add_argument('--lr-2', type=float, default=1e-4, help='learning rate for training stage 2') 33 | parser.add_argument('--domain-code-size', type=int, default=3, help='# of domains present in training stage 1') 34 | parser.add_argument('--cross-validation', action='store_true', help='specify if cross validation should be used') 35 | parser.add_argument('--d-iter', type=int, default=1, help='discriminator update iterations per epoch') 36 | 37 | # resume training 38 | parser.add_argument('--resume-epoch', type=int, default=None, help='resume training at epoch, -1 for latest, select run using experiment-name argument') 39 | 40 | # logging 41 | parser.add_argument('--eval-interval', type=int, default=7, help='evaluation interval (on all datasets)') 42 | parser.add_argument('--save-interval', type=int, default=1, help='save interval') 43 | parser.add_argument('--display-interval', type=int, default=1, help='display/tensorboard interval') 44 | 45 | # evaluation 46 | parser.add_argument('--eval', action='store_true', help='for general evaluation (only eval on test)') 47 | parser.add_argument('--lambda-eval', action='store_true', help='for tuning lambda (only eval on val)') 48 | 49 | # loss weighting 50 | parser.add_argument('--lambda-vae', type=float, default=5, help='lambda tuning vae loss') 51 | parser.add_argument('--lambda-c-adv', type=float, default=1, help='lambda tuning content adversarial loss') 52 | parser.add_argument('--lambda-lcr', type=float, default=1e-4, help='lambda tuning lcr loss') 53 | parser.add_argument('--lambda-seg', type=float, default=5, help='lambda tuning segmentation loss') 54 | parser.add_argument('--lambda-c-recon', type=float, default=0, help='lambda tuning content reconstruction loss') 55 | parser.add_argument('--lambda-gan', type=float, default=5, help='lambda tuning gan loss') 56 | parser.add_argument('--lambda-d', type=float, default=1, help='lambda for tuning MAS or KD loss') 57 | 58 | # U-Net 59 | parser.add_argument('--unet-only', action='store_true', help='only train UNet') 60 | parser.add_argument('--unet-dropout', type=float, default=0, help='apply dropout to UNet') 61 | parser.add_argument('--unet-monte-carlo-dropout', type=float, default=0, help='apply monte carlo dropout to UNet') 62 | parser.add_argument('--unet-preactivation', action='store_true', help='UNet preactivation; True: norm, act, conv; False: conv, norm, act') 63 | 64 | return parser 65 | 66 | def parse_args(argv): 67 | """Parses arguments passed from the console as, e.g. 68 | 'python acs_train.py --epochs 3' """ 69 | 70 | parser = _get_parser() 71 | args = parser.parse_args(argv) 72 | 73 | args.device = str(args.device+':'+str(args.device_ids[0]) if torch.cuda.is_available() and args.device == "cuda" else "cpu") 74 | device_name = str(torch.cuda.get_device_name(args.device) if 'cuda' in args.device else args.device) 75 | print('Device name: {}'.format(device_name)) 76 | args.input_shape = (args.input_dim_c, args.input_dim_hw, args.input_dim_hw) 77 | 78 | # add dummy class 79 | args.domain_code_size = args.domain_code_size + 1 80 | 81 | return args 82 | 83 | def parse_args_as_dict(argv): 84 | """Parses arguments passed from the console and returns a dictionary """ 85 | return vars(parse_args(argv)) 86 | 87 | def parse_dict_as_args(dictionary): 88 | """Parses arguments given in a dictionary form""" 89 | argv = [] 90 | for key, value in dictionary.items(): 91 | if isinstance(value, bool): 92 | if value: 93 | argv.append('--'+key) 94 | else: 95 | argv.append('--'+key) 96 | argv.append(str(value)) 97 | return parse_args(argv) -------------------------------------------------------------------------------- /experiments.txt: -------------------------------------------------------------------------------- 1 | ### RESULTS 2 | 3 | # ACS 4 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --combination 0 --experiment-name ours_0 --eval > ours_0.log 5 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --combination 1 --experiment-name ours_1 --eval > ours_1.log 6 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --combination 2 --experiment-name ours_2 --eval > ours_2.log 7 | 8 | # UNET 9 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name unet_0_pe --eval --unet-only > unet_b_0.log 10 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 1 --experiment-name unet_1_ne --eval --unet-only > unet_b_1.log 11 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 2 --experiment-name unet_2_ne --eval --unet-only > unet_b_2.log 12 | 13 | # MAS 14 | nohup python mas_train.py --combination 0 --lambda-d 0.1 --eval --experiment-name mas27_0 > mas_0.log 15 | nohup python mas_train.py --combination 1 --lambda-d 0.1 --eval --experiment-name mas27_1 > mas_1.log 16 | nohup python mas_train.py --combination 2 --lambda-d 0.1 --eval --experiment-name mas27_2 > mas_2.log 17 | 18 | # KD 19 | nohup python kd_train.py
--combination 0 --lambda-d 0.1 --eval --experiment-name kd27_0 > kd_0.log 20 | nohup python kd_train.py --combination 1 --lambda-d 0.1 --eval --experiment-name kd27_1 > kd_1.log 21 | nohup python kd_train.py --combination 2 --lambda-d 0.1 --eval --experiment-name kd27_2 > kd_2.log 22 | 23 | # Unet-Alternative 24 | nohup python mas_train.py --combination 0 --lambda-d 0 --eval --experiment-name unet27_0 --unet-only > unet_0.log 25 | nohup python mas_train.py --combination 1 --lambda-d 0 --eval --experiment-name unet27_1 --unet-only > unet_1.log 26 | nohup python mas_train.py --combination 2 --lambda-d 0 --eval --experiment-name unet27_2 --unet-only > unet_2.log 27 | 28 | 29 | ### HYPERPARAMETER 30 | 31 | # MAS 32 | nohup python mas_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name mas_lambda_1_ne --lambda-d 1 --lambda-eval > mas_lambda_1.log 33 | nohup python mas_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name mas_lambda_05_ne --lambda-d 0.5 --lambda-eval > mas_lambda_05.log 34 | nohup python mas_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name mas_lambda_01_ne --lambda-d 0.1 --lambda-eval > mas_lambda_01.log 35 | 36 | # KD 37 | nohup python kd_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name kd_lambda_1_ne --lambda-d 1 --lambda-eval > kd_lambda_1.log 38 | nohup python kd_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name kd_lambda_05_ne --lambda-d 0.5 --lambda-eval > kd_lambda_05.log 39 | nohup python kd_train.py --batch-size 40 --epochs 60 --device-ids 0 --combination 0 --experiment-name kd_lambda_01_ne --lambda-d 0.1 --lambda-eval > kd_lambda_01.log 40 | 41 | 42 | ### ABLATION 43 | 44 | #--lambda-c-adv 0 45 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-c-adv 0 --combination 0 --experiment-name ours_no_c_adv_0 --eval > ours_no_c_adv_0.log 46 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-c-adv 0 --combination 1 --experiment-name ours_no_c_adv_1 --eval > ours_no_c_adv_1.log 47 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-c-adv 0 --combination 2 --experiment-name ours_no_c_adv_2 --eval > ours_no_c_adv_2.log 48 | 49 | #--lambda-vae 0 --lambda-gan 0 --lambda-lcr 0 50 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --lambda-gan 0 --lambda-lcr 0 --combination 0 --experiment-name ours_no_ganvaelr_0 --eval > ours_no_ganvaelr_0.log 51 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --lambda-gan 0 --lambda-lcr 0 --combination 1 --experiment-name ours_no_ganvaelr_1 --eval > ours_no_ganvaelr_1.log 52 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --lambda-gan 0 --lambda-lcr 0 --combination 2 --experiment-name ours_no_ganvaelr_2 --eval > ours_no_ganvaelr_2.log 53 | 54 | #--lambda-vae 5 --lambda-gan 0 55 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-gan 0 --combination 0 --experiment-name ours_no_gan_0 --eval > ours_no_gan_0.log 56 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-gan 0 --combination 1 --experiment-name ours_no_gan_1 --eval > ours_no_gan_1.log 57 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-gan 0 --combination 2 --experiment-name ours_no_gan_2 
--eval > ours_no_gan_2.log 58 | 59 | #--lambda-vae 0 --lambda-gan 5 60 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --combination 0 --experiment-name ours_no_vae_0 --eval > ours_no_vae_0.log 61 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --combination 1 --experiment-name ours_no_vae_1 --eval > ours_no_vae_1.log 62 | nohup python acs_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --lambda-vae 0 --combination 2 --experiment-name ours_no_vae_2 --eval > ours_no_vae_2.log 63 | 64 | ### JOINT 65 | 66 | nohup python acs_joint_train.py --batch-size 40 --epochs 60 --device-ids 0 1 2 3 --experiment-name ours_joint_e0_60 --eval > ours_joint.log 67 | nohup python unet_joint_train.py --batch-size 40 --epochs 60 --device-ids 0 --experiment-name unet_all_joint --eval > unet_joint.log 68 | -------------------------------------------------------------------------------- /kd_train.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Code to train KD 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import sys 8 | from args import parse_args_as_dict 9 | from mp.utils.helper_functions import seed_all 10 | 11 | import torch 12 | torch.set_num_threads(6) 13 | from torch.utils.data import DataLoader 14 | import torch.optim as optim 15 | 16 | from mp.experiments.experiment import Experiment 17 | from mp.data.data import Data 18 | from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus 19 | from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus 20 | from mp.data.datasets.ds_mr_hippocampus_harp import HarP 21 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset 22 | from mp.eval.losses.losses_segmentation import LossClassWeighted, LossDiceBCE 23 | from mp.agents.kd_agent import KDAgent 24 | from mp.eval.result import Result 25 | from mp.utils.tensorboard import create_writer 26 | 27 | from mp.models.continual.kd import KD 28 | 29 | # Get configuration from arguments 30 | config = parse_args_as_dict(sys.argv[1:]) 31 | seed_all(42) 32 | 33 | config['class_weights'] = (0., 1.)
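# Note (added for orientation): kd_train.py mirrors the two-stage schedule of
# acs_train.py: joint training on A and B for the first half of the epochs,
# then training on C with config['continual'] = True, where the distillation
# term is presumably weighted by --lambda-d ('lambda for tuning MAS or KD
# loss' in args.py); the KD internals live in mp/models/continual/kd.py and
# mp/agents/kd_agent.py.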
34 | 35 | # Create experiment directories 36 | exp = Experiment(config=config, name=config['experiment_name'], notes='', reload_exp=(config['resume_epoch'] is not None)) 37 | 38 | # Datasets 39 | data = Data() 40 | 41 | dataset_domain_a = DecathlonHippocampus(merge_labels=True) 42 | dataset_domain_a.name = 'DecathlonHippocampus' 43 | data.add_dataset(dataset_domain_a) 44 | 45 | dataset_domain_b = DryadHippocampus(merge_labels=True) 46 | dataset_domain_b.name = 'DryadHippocampus' 47 | data.add_dataset(dataset_domain_b) 48 | 49 | dataset_domain_c = HarP(merge_labels=True) 50 | dataset_domain_c.name = 'HarP' 51 | data.add_dataset(dataset_domain_c) 52 | 53 | nr_labels = data.nr_labels 54 | label_names = data.label_names 55 | 56 | if config['combination'] == 0: 57 | ds_a = ('DecathlonHippocampus', 'train') 58 | ds_b = ('DryadHippocampus', 'train') 59 | ds_c = ('HarP', 'train') 60 | elif config['combination'] == 1: 61 | ds_a = ('DecathlonHippocampus', 'train') 62 | ds_c = ('DryadHippocampus', 'train') 63 | ds_b = ('HarP', 'train') 64 | elif config['combination'] == 2: 65 | ds_c = ('DecathlonHippocampus', 'train') 66 | ds_b = ('DryadHippocampus', 'train') 67 | ds_a = ('HarP', 'train') 68 | 69 | # Create data splits for each repetition 70 | exp.set_data_splits(data) 71 | 72 | # Now repeat for each repetition 73 | for run_ix in range(config['nr_runs']): 74 | exp_run = exp.get_run(run_ix=0, reload_exp_run=(config['resume_epoch'] is not None)) 75 | 76 | datasets = dict() 77 | for idx, item in enumerate(data.datasets.items()): 78 | ds_name, ds = item 79 | for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items(): 80 | data_ixs = data_ixs[:config['n_samples']] 81 | if len(data_ixs) > 0: # Sometimes val indexes may be an empty list 82 | aug = config['augmentation'] if not('test' in split) else 'none' 83 | datasets[(ds_name, split)] = PytorchSeg2DDataset(ds, 84 | ix_lst=data_ixs, size=config['input_shape'] , aug_key=aug, 85 | resize=(not config['no_resize'])) 86 | 87 | dataset = torch.utils.data.ConcatDataset((datasets[(ds_a)], datasets[(ds_b)])) 88 | train_dataloader_0 = DataLoader(dataset, batch_size=config['batch_size'], drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 89 | train_dataloader_1 = DataLoader(datasets[(ds_c)], batch_size=config['batch_size'], shuffle=True, drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 90 | 91 | if config['eval']: 92 | drop = [] 93 | for key in datasets.keys(): 94 | if 'train' in key or 'val' in key: 95 | drop += [key] 96 | for d in drop: 97 | datasets.pop(d) 98 | elif config['lambda_eval']: 99 | drop = [] 100 | for key in datasets.keys(): 101 | if 'train' in key or 'test' in key: 102 | drop += [key] 103 | for d in drop: 104 | datasets.pop(d) 105 | 106 | model = KD(input_shape=config['input_shape'], nr_labels=nr_labels, 107 | unet_dropout=config['unet_dropout'], unet_monte_carlo_dropout=config['unet_monte_carlo_dropout'], unet_preactivation=config['unet_preactivation']) 108 | 109 | model.to(config['device']) 110 | 111 | # Define loss and optimizer 112 | loss_g = LossDiceBCE(bce_weight=1., smooth=1., device=config['device']) 113 | loss_f = LossClassWeighted(loss=loss_g, weights=config['class_weights'], device=config['device']) 114 | 115 | # Set optimizers 116 | model.set_optimizers(optim.Adam, lr=config['lr']) 117 | 118 | # Train model 119 | results = Result(name='training_trajectory') 120 | 121 | agent = KDAgent(model=model, label_names=label_names, device=config['device']) 
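# The summary writer attached below (mp/utils/tensorboard.py) is rooted one
# directory above this run's states path (exp_run.paths['states']), and the
# logging frequency is controlled by --display-interval (see args.py).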
122 | agent.summary_writer = create_writer(os.path.join(exp_run.paths['states'], '..'), 0) 123 | 124 | 125 | init_epoch = 0 126 | nr_epochs = config['epochs'] // 2 127 | 128 | config['continual'] = False 129 | 130 | # Resume training 131 | if config['resume_epoch'] is not None: 132 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 133 | init_epoch = agent.agent_state_dict['epoch'] + 1 134 | 135 | # Training epochs 0 - 30 136 | if init_epoch < config['epochs'] / 2: 137 | agent.train(results, loss_f, train_dataloader_0, train_dataloader_1, config, 138 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 139 | eval_datasets=datasets, eval_interval=config['eval_interval'], 140 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 141 | display_interval=config['display_interval'], 142 | resume_epoch=config['resume_epoch'], device_ids=config['device_ids']) 143 | 144 | print('Finished training on A and B, starting training on C') 145 | 146 | init_epoch = config['epochs'] // 2 147 | nr_epochs = config['epochs'] 148 | 149 | # Resume training 150 | if config['resume_epoch'] is not None: 151 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 152 | init_epoch = agent.agent_state_dict['epoch'] + 1 153 | 154 | # Training epochs 30 - 60 155 | if init_epoch >= config['epochs'] / 2: 156 | 157 | model.set_optimizers(optim.Adam, lr=config['lr_2']) 158 | config['continual'] = True 159 | model.unet_scheduler = torch.optim.lr_scheduler.StepLR(model.unet_optim, (nr_epochs-init_epoch), gamma=0.1, last_epoch=-1) 160 | 161 | agent.train(results, loss_f, train_dataloader_1, train_dataloader_0, config, 162 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 163 | eval_datasets=datasets, eval_interval=config['eval_interval'], 164 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 165 | display_interval=config['display_interval'], 166 | resume_epoch=config['resume_epoch'], device_ids=[0]) 167 | 168 | print('Finished training on C') 169 | 170 | # Save and print results for this experiment run 171 | exp_run.finish(results=results, plot_metrics=['Mean_LossBCEWithLogits', 'Mean_LossDice[smooth=1.0]', 'Mean_LossCombined[1.0xLossDice[smooth=1.0]+1.0xLossBCEWithLogits]']) 172 | -------------------------------------------------------------------------------- /mas_train.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Code to train MAS 4 | # ------------------------------------------------------------------------------ 5 | 6 | # Imports 7 | import os 8 | import sys 9 | from args import parse_args_as_dict 10 | from mp.utils.helper_functions import seed_all 11 | 12 | import torch 13 | torch.set_num_threads(6) 14 | from torch.utils.data import DataLoader 15 | import torch.optim as optim 16 | 17 | from mp.experiments.experiment import Experiment 18 | from mp.data.data import Data 19 | from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus 20 | from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus 21 | from mp.data.datasets.ds_mr_hippocampus_harp import HarP 22 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset 23 | from mp.eval.losses.losses_segmentation import LossClassWeighted, LossDiceBCE 24 | from mp.agents.mas_agent import MASAgent 25 | from mp.eval.result import Result 26 | from mp.utils.tensorboard import 
create_writer 27 | from mp.models.continual.mas import MAS 28 | 29 | # Get configuration from arguments 30 | config = parse_args_as_dict(sys.argv[1:]) 31 | seed_all(42) 32 | 33 | config['class_weights'] = (0., 1.) 34 | 35 | # Create experiment directories 36 | exp = Experiment(config=config, name=config['experiment_name'], notes='', reload_exp=(config['resume_epoch'] is not None)) 37 | 38 | # Datasets 39 | data = Data() 40 | 41 | dataset_domain_a = DecathlonHippocampus(merge_labels=True) 42 | dataset_domain_a.name = 'DecathlonHippocampus' 43 | data.add_dataset(dataset_domain_a) 44 | 45 | dataset_domain_b = DryadHippocampus(merge_labels=True) 46 | dataset_domain_b.name = 'DryadHippocampus' 47 | data.add_dataset(dataset_domain_b) 48 | 49 | dataset_domain_c = HarP(merge_labels=True) 50 | dataset_domain_c.name = 'HarP' 51 | data.add_dataset(dataset_domain_c) 52 | 53 | nr_labels = data.nr_labels 54 | label_names = data.label_names 55 | 56 | if config['combination'] == 0: 57 | ds_a = ('DecathlonHippocampus', 'train') 58 | ds_b = ('DryadHippocampus', 'train') 59 | ds_c = ('HarP', 'train') 60 | elif config['combination'] == 1: 61 | ds_a = ('DecathlonHippocampus', 'train') 62 | ds_c = ('DryadHippocampus', 'train') 63 | ds_b = ('HarP', 'train') 64 | elif config['combination'] == 2: 65 | ds_c = ('DecathlonHippocampus', 'train') 66 | ds_b = ('DryadHippocampus', 'train') 67 | ds_a = ('HarP', 'train') 68 | 69 | # Create data splits for each repetition 70 | exp.set_data_splits(data) 71 | 72 | # Now repeat for each repetition 73 | for run_ix in range(config['nr_runs']): 74 | exp_run = exp.get_run(run_ix=0, reload_exp_run=(config['resume_epoch'] is not None)) 75 | 76 | # Bring data to Pytorch format and add domain_code 77 | datasets = dict() 78 | for idx, item in enumerate(data.datasets.items()): 79 | ds_name, ds = item 80 | for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items(): 81 | data_ixs = data_ixs[:config['n_samples']] 82 | if len(data_ixs) > 0: # Sometimes val indexes may be an empty list 83 | aug = config['augmentation'] if not('test' in split) else 'none' 84 | datasets[(ds_name, split)] = PytorchSeg2DDataset(ds, 85 | ix_lst=data_ixs, size=config['input_shape'] , aug_key=aug, 86 | resize=(not config['no_resize'])) 87 | 88 | dataset = torch.utils.data.ConcatDataset((datasets[(ds_a)], datasets[(ds_b)])) 89 | train_dataloader_0 = DataLoader(dataset, batch_size=config['batch_size'], drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 90 | train_dataloader_1 = DataLoader(datasets[(ds_c)], batch_size=config['batch_size'], shuffle=True, drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 91 | 92 | if config['eval']: 93 | drop = [] 94 | for key in datasets.keys(): 95 | if 'train' in key or 'val' in key: 96 | drop += [key] 97 | for d in drop: 98 | datasets.pop(d) 99 | elif config['lambda_eval']: 100 | drop = [] 101 | for key in datasets.keys(): 102 | if 'train' in key or 'test' in key: 103 | drop += [key] 104 | for d in drop: 105 | datasets.pop(d) 106 | 107 | model = MAS(input_shape=config['input_shape'], nr_labels=nr_labels, 108 | unet_dropout=config['unet_dropout'], unet_monte_carlo_dropout=config['unet_monte_carlo_dropout'], unet_preactivation=config['unet_preactivation']) 109 | 110 | model.to(config['device']) 111 | 112 | # Define loss and optimizer 113 | loss_g = LossDiceBCE(bce_weight=1., smooth=1., device=config['device']) 114 | loss_f = LossClassWeighted(loss=loss_g, weights=config['class_weights'], 
device=config['device']) 115 | 116 | # Set optimizers 117 | model.set_optimizers(optim.Adam, lr=config['lr']) 118 | 119 | # Train model 120 | results = Result(name='training_trajectory') 121 | 122 | agent = MASAgent(model=model, label_names=label_names, device=config['device']) 123 | agent.summary_writer = create_writer(os.path.join(exp_run.paths['states'], '..'), 0) 124 | 125 | init_epoch = 0 126 | nr_epochs = config['epochs'] // 2 127 | 128 | config['continual'] = False 129 | 130 | # Resume training 131 | if config['resume_epoch'] is not None: 132 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 133 | init_epoch = agent.agent_state_dict['epoch'] + 1 134 | 135 | # Train on A and B 136 | if init_epoch < config['epochs'] / 2: 137 | # if init_epoch < config['epochs'] * 2/3: 138 | agent.train(results, loss_f, train_dataloader_0, train_dataloader_1, config, 139 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 140 | eval_datasets=datasets, eval_interval=config['eval_interval'], 141 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 142 | display_interval=config['display_interval'], 143 | resume_epoch=config['resume_epoch'], device_ids=config['device_ids']) 144 | 145 | print('Finished training on A and B, starting training on C') 146 | 147 | init_epoch = config['epochs'] // 2 148 | nr_epochs = config['epochs'] 149 | 150 | # Resume training 151 | if config['resume_epoch'] is not None: 152 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 153 | init_epoch = agent.agent_state_dict['epoch'] + 1 154 | 155 | # Training epochs 30 - 60 156 | if init_epoch >= config['epochs'] / 2: 157 | 158 | # Set optimizers 159 | model.set_optimizers(optim.Adam, lr=config['lr_2']) 160 | config['continual'] = True 161 | model.unet_scheduler = torch.optim.lr_scheduler.StepLR(model.unet_optim, (nr_epochs-init_epoch), gamma=0.1, last_epoch=-1) 162 | 163 | agent.train(results, loss_f, train_dataloader_1, train_dataloader_0, config, 164 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 165 | eval_datasets=datasets, eval_interval=config['eval_interval'], 166 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 167 | display_interval=config['display_interval'], 168 | resume_epoch=config['resume_epoch'], device_ids=[0]) 169 | 170 | print('Finished training on C') 171 | 172 | # Save and print results for this experiment run 173 | exp_run.finish(results=results, plot_metrics=['Mean_LossBCEWithLogits', 'Mean_LossDice[smooth=1.0]', 'Mean_LossCombined[1.0xLossDice[smooth=1.0]+1.0xLossBCEWithLogits]']) 174 | -------------------------------------------------------------------------------- /mp/agents/autoencoding_agent.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # An autoencoding agent. 3 | # ------------------------------------------------------------------------------ 4 | 5 | from mp.agents.agent import Agent 6 | 7 | class AutoencodingAgent(Agent): 8 | r"""An Agent for autoencoder models.""" 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | def get_inputs_targets(self, data): 13 | r"""The usual dataloaders are used for autoencoders. 
However, these 14 | ignore the target and instead treat the input as target 15 | """ 16 | inputs, targets = data 17 | inputs = inputs.to(self.device) 18 | inputs = self.model.preprocess_input(inputs) 19 | targets = inputs.clone() 20 | return inputs, targets 21 | 22 | def predict_from_outputs(self, outputs): 23 | r"""No transformation is performed on the outputs, as the goal is to 24 | reconstruct the input. Therefore, the output should have the same dim 25 | as the input.""" 26 | return outputs -------------------------------------------------------------------------------- /mp/agents/segmentation_agent.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # A standard segmentation agent, which performs softmax on the outputs. 3 | # ------------------------------------------------------------------------------ 4 | 5 | from mp.agents.agent import Agent 6 | from mp.eval.inference.predict import softmax 7 | 8 | class SegmentationAgent(Agent): 9 | r"""An Agent for segmentation models.""" 10 | def __init__(self, *args, **kwargs): 11 | if 'metrics' not in kwargs: 12 | kwargs['metrics'] = ['ScoreDice', 'ScoreIoU'] 13 | super().__init__(*args, **kwargs) 14 | 15 | def get_outputs(self, inputs): 16 | r"""Applies a softmax transformation to the model outputs""" 17 | outputs = self.model(inputs) 18 | outputs = softmax(outputs).clamp(min=1e-08, max=1.-1e-08) 19 | return outputs -------------------------------------------------------------------------------- /mp/data/data.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # An instance of data includes a dictionary of datasets. 3 | # TODO: in the standard Dataset class, label_names are classes. These are used 4 | # to produce stratified folds. 5 | # ------------------------------------------------------------------------------ 6 | 7 | class Data: 8 | r"""A Data object stores a dictionary of datasets.""" 9 | def __init__(self): 10 | self.datasets = dict() 11 | self.label_names = None 12 | self.nr_labels = None 13 | 14 | def add_dataset(self, dataset): 15 | r"""Saves the dataset with its name as key. 16 | 17 | Args: 18 | dataset (mp.data.datasets.dataset.Dataset): a Dataset object 19 | 20 | """ 21 | assert dataset.name not in self.datasets 22 | if len(self.datasets) > 0: 23 | for other_dataset in self.datasets.values(): 24 | assert dataset.label_names == other_dataset.label_names, 'Datasets must have the same label names' 25 | else: 26 | self.label_names = dataset.label_names 27 | self.nr_labels = dataset.nr_labels 28 | self.datasets[dataset.name] = dataset -------------------------------------------------------------------------------- /mp/data/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Dataset class meant to store general information about a dataset and to divide 3 | # instances into folds, before converting to a torch.utils.data.Dataset. 4 | # All datasets descend from this Dataset class. 5 | # ------------------------------------------------------------------------------ 6 | 7 | class Dataset: 8 | r"""A dataset stores instances.
9 | 10 | Args: 11 | name (str): name of a dataset 12 | instances (list[mp.data.datasets.dataset.Instance]): list of Instances 13 | classes (tuple[str]): tuple with label names 14 | hold_out_ixs (list[int]): instances which are not evaluated until the end 15 | mean_shape (list[int]): mean input shape 16 | output_shape (list[int]): output shape 17 | x_norm (tuple[float]): normalization values for the input 18 | """ 19 | def __init__(self, name, instances, classes=('0',), hold_out_ixs=[], 20 | mean_shape=(1, 32, 32), output_shape=(1, 32, 32), x_norm=None): 21 | self.name = name 22 | # Sort instances in terms of name 23 | self.instances = sorted(instances, key=lambda ex: ex.name) 24 | self.size = len(instances) 25 | self.classes = classes 26 | self.hold_out_ixs = hold_out_ixs 27 | self.mean_shape = mean_shape 28 | self.output_shape = output_shape 29 | self.x_norm = x_norm 30 | 31 | def get_class_dist(self, ixs=None): 32 | r"""Get class (category) distribution 33 | 34 | Args: 35 | ixs (list[int]): if not None, distribution for only these indexes. 36 | Otherwise distribution for all indexes not part of the hold-out. 37 | """ 38 | if ixs is None: 39 | ixs = [ix for ix in range(self.size) if ix not in self.hold_out_ixs] 40 | class_dist = {class_ix: 0 for class_ix in self.classes} 41 | for ex_ix, ex in enumerate(self.instances): 42 | if ex_ix in ixs: 43 | class_dist[self.classes[ex.class_ix]] += 1 44 | return class_dist 45 | 46 | def get_class_instance_ixs(self, class_name, exclude_ixs): 47 | r"""Get instances for a class, excluding those in exclude_ixs.""" 48 | return [ix for ix, ex in enumerate(self.instances) if 49 | ex.class_ix==self.classes.index(class_name) and ix not in exclude_ixs] 50 | 51 | def get_instance(self, name): 52 | r"""Get an instance from a name.""" 53 | instances = [instance for instance in self.instances if instance.name == name] 54 | if len(instances) == 0: 55 | return None 56 | else: 57 | assert len(instances) == 1, "There is more than one instance with that name" 58 | return instances[0] 59 | 60 | def get_instance_ixs_from_names(self, name_lst): 61 | r"""Get instance ixs from a list of names.""" 62 | ixs = [ix for ix, instance in enumerate(self.instances) if instance.name in name_lst] 63 | return ixs 64 | 65 | class Instance: 66 | r"""A dataset instance. 67 | 68 | Args: 69 | x (Obj): input, can take different forms depending on the subclass 70 | y (Obj): ground truth 71 | name (str): instance name (e.g. file name) for case-wise evaluation 72 | class_ix (int): during splitting of the dataset, the resulting subsets 73 | are stratified according to this value (i.e. there are about as 74 | many examples of each class on each fold). For classification, 75 | class_ix==y, but can also be used solely for splitting. 76 | group_id (int): instances with same 'group id' should always 77 | remain on the same dataset split. A group id could be, for instance, 78 | a patient identifier (the same patient should typically not be in 79 | several different splits).
80 | """ 81 | def __init__(self, x, y, name=None, class_ix=0, group_id=None): 82 | self.x = x 83 | self.y = y 84 | self.name = name 85 | self.class_ix = class_ix 86 | self.group_id = group_id -------------------------------------------------------------------------------- /mp/data/datasets/dataset_classification.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Classes for creating new classification datasets. 3 | # ------------------------------------------------------------------------------ 4 | 5 | import os 6 | from mp.data.datasets.dataset import Dataset, Instance 7 | import mp.data.datasets.dataset_utils as du 8 | from mp.paths import original_data_paths 9 | 10 | class ClassificationPathInstance(Instance): 11 | r"""Instance class where x is a path and y is an integer label corr. to 12 | an index of the dataset 'classes' field. 13 | """ 14 | def __init__(self, x_path, y, name=None, group_id=None): 15 | assert isinstance(x_path, str) 16 | assert isinstance(y, int) 17 | super().__init__(x=x_path, y=y, class_ix=y, name=name, group_id=group_id) 18 | 19 | class SplitClassImageDataset(Dataset): 20 | r"""Classification Dataset with the structure root/split/class/filename, 21 | where 'split' is test for the hold-out test dataset and train for the rest. 22 | The instances are of the type 'PathInstance'. 23 | """ 24 | def __init__(self, name, root_path=None, input_shape=(1, 32, 32), x_norm=None): 25 | root_path = du.get_original_data_path(name) 26 | classes = [] 27 | instances = [] 28 | hold_out_start = None 29 | for split in ['train', 'test']: 30 | if split == 'test': 31 | hold_out_start = len(instances) 32 | split_path = os.path.join(root_path, split) 33 | for class_name in os.listdir(split_path): 34 | if class_name not in classes: 35 | classes.append(class_name) 36 | class_path = os.path.join(split_path, class_name) 37 | for img_name in os.listdir(class_path): 38 | instance = ClassificationPathInstance(name=img_name, x_path=os.path.join(class_path, img_name), y=classes.index(class_name)) 39 | instances.append(instance) 40 | super().__init__(name=name, classes=tuple(classes), instances=instances, 41 | mean_shape=input_shape, output_shape=len(classes), x_norm=x_norm, 42 | hold_out_ixs=list(range(hold_out_start, len(instances)))) 43 | 44 | class CIFAR10(SplitClassImageDataset): 45 | r"""The Cifar10 dataset. 46 | """ 47 | def __init__(self, root_path=None): 48 | super().__init__(name='Cifar10', root_path=root_path, 49 | input_shape=(3, 32, 32), 50 | x_norm={'mean': (0.4914, 0.4822, 0.4465), 'std': (0.247, 0.243, 0.262)} 51 | ) 52 | -------------------------------------------------------------------------------- /mp/data/datasets/dataset_segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # All datasets descend from this SegmentationDataset class storing segmentation 3 | # instances. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import sys 8 | from mp.data.datasets.dataset import Dataset, Instance 9 | import mp.data.datasets.dataset_utils as du 10 | import torchio 11 | 12 | class SegmentationInstance(Instance): 13 | def __init__(self, x_path, y_path, name=None, class_ix=0, group_id=None): 14 | r"""A segmentation instance, using the TorchIO library. 
15 | 
16 | Args:
17 | x_path (str): path to image
18 | y_path (str): path to segmentation mask
19 | name (str): name of instance for case-wise evaluation
20 | class_ix (int): optional "class" index. During splitting of the dataset,
21 | the resulting subsets are stratified according to this value (i.e.
22 | there are about as many examples of each class in each fold).
23 | group_id (comparable): instances with same group_id (e.g. patient id)
24 | are always in the same fold
25 | 
26 | Note that torchio images have the shape (channels, w, h, d)
27 | """
28 | assert isinstance(x_path, str)
29 | assert isinstance(y_path, str)
30 | x = torchio.Image(x_path, type=torchio.INTENSITY)
31 | y = torchio.Image(y_path, type=torchio.LABEL)
32 | self.shape = x.shape
33 | super().__init__(x=x, y=y, name=name, class_ix=class_ix,
34 | group_id=group_id)
35 | 
36 | def get_subject(self):
37 | return torchio.Subject(
38 | x=self.x,
39 | y=self.y
40 | )
41 | 
42 | class SegmentationDataset(Dataset):
43 | r"""A Dataset for segmentation tasks, that specific datasets descend from.
44 | 
45 | Args:
46 | instances (list[SegmentationInstance]): a list of instances
47 | name (str): the dataset name
48 | mean_shape (tuple[int]): the mean input shape of the data, or None
49 | label_names (list[str]): list with label names, or None
50 | nr_channels (int): number of input channels
51 | modality (str): modality of the data, e.g. MR, CT
52 | hold_out_ixs (list[int]): list of instance indexes to reserve for a
53 | separate hold-out dataset.
54 | check_correct_nr_labels (bool): whether it should be checked if the
55 | correct number of labels (the length of label_names) is consistent
56 | with the dataset. As it takes a long time to check, only set to True
57 | when initially testing a dataset.
58 | """
59 | def __init__(self, instances, name, mean_shape=None,
60 | label_names=None, nr_channels=1, modality='unknown', hold_out_ixs=[],
61 | check_correct_nr_labels=False):
62 | # Set mean input shape and mask labels, if these are not provided
63 | print('\nDATASET: {} with {} instances'.format(name, len(instances)))
64 | if mean_shape is None:
65 | mean_shape, shape_std = du.get_mean_std_shape(instances)
66 | print('Mean shape: {}, shape std: {}'.format(mean_shape, shape_std))
67 | if label_names is None:
68 | label_names = du.get_mask_labels(instances)
69 | else:
70 | if check_correct_nr_labels:
71 | du.check_correct_nr_labels(label_names, instances)
72 | print('Mask labels: {}\n'.format(label_names))
73 | self.mean_shape = mean_shape
74 | self.label_names = label_names
75 | self.nr_labels = len(label_names)
76 | self.nr_channels = nr_channels
77 | self.modality = modality
78 | super().__init__(name=name, instances=instances,
79 | mean_shape=mean_shape, output_shape=mean_shape,
80 | hold_out_ixs=hold_out_ixs)
81 | 
-------------------------------------------------------------------------------- /mp/data/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------
2 | # Utils for binding datasets as Dataset subclasses.
3 | # ------------------------------------------------------------------------------
4 | 
5 | import numpy as np
6 | import torch
7 | from mp.paths import original_data_paths
8 | 
9 | def get_original_data_path(global_name):
10 | r"""Get the original path from mp.paths.
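# A minimal sketch of wrapping an image/mask pair; the file names are
# placeholders for existing NIfTI files.
from mp.data.datasets.dataset_segmentation import SegmentationInstance

instance = SegmentationInstance(x_path='img_00.nii', y_path='mask_00.nii',
                                name='case_00')
print(instance.shape)             # torchio shape: (channels, w, h, d)
subject = instance.get_subject()  # torchio.Subject with keys 'x' and 'y'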
The global name is the key."""
11 | try:
12 | data_path = original_data_paths[global_name]
13 | except KeyError:
14 | raise Exception('Data path for {} must be set in paths.py.'.format(global_name))
15 | return data_path
16 | 
17 | def get_dataset_name(global_name, subset=None):
18 | r"""Get the name of a dataset by appending the subset entries to the global name."""
19 | if subset is None:
20 | return global_name
21 | else:
22 | name = global_name
23 | for key, value in subset.items():
24 | name += '[' + key+':'+ value + ']'
25 | return name
26 | 
27 | def get_mean_std_shape(instances):
28 | r"""Returns the mean shape as (channels, width, height, depth) for a
29 | list of instances.
30 | 
31 | Args:
32 | instances (list[SegmentationInstance]): a list of segmentation instances.
33 | 
34 | Returns (tuple[int]) tuple with form (channels, width, height, depth)
35 | """
36 | shapes = [np.array(instance.shape) for instance in instances]
37 | mean = np.mean(shapes, axis=0)
38 | std = np.std(shapes, axis=0)
39 | return tuple(int(x) for x in mean), tuple(int(x) for x in std)
40 | 
41 | def get_mask_labels(instances):
42 | r"""Returns a set of integer labels which appear in segmentation masks in
43 | a list of instances.
44 | 
45 | Args:
46 | instances (list[SegmentationInstance]): a list of segmentation instances.
47 | 
48 | Returns (list[str]): list of the form ['0', '1', '2', etc.] as replacement
49 | for not having the real label names.
50 | """
51 | labels = set()
52 | for instance in instances:
53 | instance_labels = list(np.unique(instance.y.tensor.numpy()))
54 | assert all(x == int(x) for x in instance_labels), "Masks contain non-integer values"
55 | labels = labels.union([int(x) for x in instance_labels])
56 | return [str(nr) for nr in range(max(labels)+1)]
57 | 
58 | def check_correct_nr_labels(labels, instances):
59 | r"""Check that the number of label names manually supplied is consistent
60 | with the dataset masks.
61 | """
62 | nr_labels = len(get_mask_labels(instances))
63 | assert nr_labels <= len(labels), "There are mask indexes not accounted for in the manually supplied label list"
64 | if nr_labels < len(labels):
65 | print('Warning: Some labels are not represented in the data')
66 | 
67 | def get_normalization_values(instances):
68 | r"""Get the channel-wise mean and standard deviation for a dataset."""
69 | count = 0
70 | mean = torch.zeros(3)
71 | std = torch.zeros(3)
72 | for instance in instances:
73 | c, w, h, d = instance.shape
74 | nb_pixels = d * h * w
75 | data = instance.x.tensor
76 | sum_ = torch.sum(data, dim=[1, 2, 3])
77 | sum_of_square = torch.sum(data ** 2, dim=[1, 2, 3])
78 | mean = (count * mean + sum_) / (count + nb_pixels)
79 | std = (count * std + sum_of_square) / (count + nb_pixels)
80 | count += nb_pixels
81 | return {'mean': mean, 'std': torch.sqrt(std - mean ** 2)} -------------------------------------------------------------------------------- /mp/data/datasets/ds_mr_cardiac_mm.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------
2 | # Multi-Centre, Multi-Vendor & Multi-Disease Cardiac Image Segmentation
3 | # Challenge (M&Ms) dataset.
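# get_dataset_name serializes the subset dictionary into the dataset name;
# a quick sketch:
import mp.data.datasets.dataset_utils as du

print(du.get_dataset_name('MM_Challenge', subset={'Vendor': 'A'}))
# -> 'MM_Challenge[Vendor:A]'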
4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import numpy as np 8 | import csv 9 | import SimpleITK as sitk 10 | from mp.data.datasets.dataset_segmentation import SegmentationDataset, SegmentationInstance 11 | from mp.paths import storage_data_path 12 | import mp.data.datasets.dataset_utils as du 13 | from mp.utils.load_restore import join_path 14 | 15 | class MM_Challenge(SegmentationDataset): 16 | r"""Class for importing the Multi-Centre, Multi-Vendor & Multi-Disease 17 | Cardiac Image Segmentation Challenge (M&Ms), found at www.ub.edu/mnms/.""" 18 | 19 | def __init__(self, subset={'Vendor': 'A'}, hold_out_ixs=[]): 20 | 21 | global_name = 'MM_Challenge' 22 | name = du.get_dataset_name(global_name, subset) 23 | dataset_path = os.path.join(storage_data_path, global_name) 24 | original_data_path = du.get_original_data_path(global_name) 25 | 26 | # Extract ED and ES images, if not already done 27 | if not os.path.isdir(dataset_path): 28 | _extract_segmented_slices(original_data_path, dataset_path) 29 | 30 | # Fetch metadata 31 | csv_info = os.path.join(original_data_path, "M&Ms Dataset Information.csv") 32 | data_info = _get_csv_patient_info(csv_info, id_ix=0) 33 | 34 | # Fetch all patient/study names in the directory (the csv includes 35 | # unlabeled data) 36 | study_names = set(file_name.split('_')[0] for file_name 37 | in os.listdir(dataset_path)) 38 | 39 | # Fetch image and mask for each study 40 | instances = [] 41 | for study_name in study_names: 42 | # If study is part of the defined subset, add ED and ES images 43 | if subset is None or all( 44 | [data_info[study_name][key] == value for key, value 45 | in subset.items()]): 46 | instance_ed = SegmentationInstance( 47 | x_path=os.path.join(dataset_path, study_name+'_ED.nii.gz'), 48 | y_path=os.path.join(dataset_path, study_name+'_ED_gt.nii.gz'), 49 | name=study_name+'_ED', 50 | group_id=study_name 51 | ) 52 | instance_es = SegmentationInstance( 53 | x_path=os.path.join(dataset_path, study_name+'_ES.nii.gz'), 54 | y_path=os.path.join(dataset_path, study_name+'_ES_gt.nii.gz'), 55 | name=study_name+'_ES', 56 | group_id=study_name 57 | ) 58 | instances.append(instance_ed) 59 | instances.append(instance_es) 60 | label_names = ['background', 'left ventricle', 'myocardium', 'right ventricle'] 61 | super().__init__(instances, name=name, label_names=label_names, 62 | modality='MR', nr_channels=1, hold_out_ixs=[]) 63 | 64 | 65 | def _get_csv_patient_info(file_full_path, id_ix=0): 66 | r"""From a .csv file with the description in the top row, turn into a dict 67 | where the keys are the dientifier entries and the values are a dictionary 68 | with all other entries. 69 | """ 70 | file_info = dict() 71 | with open(file_full_path, newline='') as csvfile: 72 | reader = csv.reader(csvfile, delimiter='\t') 73 | first_line = None 74 | for row in reader: 75 | if first_line is None: 76 | first_line = row 77 | else: 78 | file_info[row[id_ix]] = {key: row[key_ix] for key_ix, key in 79 | enumerate(first_line)} 80 | return file_info 81 | 82 | def _extract_segmented_slices(source_path, target_path): 83 | r"""The original dataset has the following structure: 84 | 85 | MM_Challenge_dataset 86 | ├── Training-corrected 87 | │ ├── Labeled 88 | │ │ ├── 89 | │ │ │ ├── _sa.nii.gz 90 | │ │ │ └── _sa_gt.nii.gz 91 | │ │ └── ... 92 | └──────── M&Ms Dataset Information.xlsx 93 | 94 | The "M&Ms Dataset Information.xlsx" file should first be converted to csv. 
95 | Each image and mask have the dimension (timesteps, slices, width, height). 96 | This method extracts only the segmented time steps (ED and ES). The result 97 | of applying the method is: 98 | 99 | 100 | ├── data 101 | │ ├── MM_Challenge 102 | │ │ ├── _ED.nii.gz 103 | │ │ ├── _ED_gt.nii.gz 104 | │ │ ├── _ES.nii.gz 105 | │ │ ├── _ES_gt.nii.gz 106 | │ │ └── ... 107 | 108 | Args: 109 | original_data_path (str): path to MM_Challenge_dataset, where the 110 | metadata file has been converted to csv. 111 | """ 112 | # Fetch metadata 113 | csv_info = os.path.join(source_path, "M&Ms Dataset Information.csv") 114 | data_info = _get_csv_patient_info(csv_info, id_ix=0) 115 | 116 | # Create directories 117 | os.makedirs(target_path) 118 | 119 | # Extract segmented timestamps (ED and ES) and save 120 | img_path = join_path([source_path, 'Training-corrected', 'Labeled']) 121 | for study_name in os.listdir(img_path): 122 | x_path = join_path([img_path, study_name, study_name+"_sa.nii.gz"]) 123 | mask_path = join_path([img_path, study_name, study_name+"_sa_gt.nii.gz"]) 124 | x = sitk.ReadImage(x_path) 125 | x = sitk.GetArrayFromImage(x) 126 | mask = sitk.ReadImage(mask_path) 127 | mask = sitk.GetArrayFromImage(mask) 128 | assert x.shape == mask.shape 129 | assert len(x.shape) == 4 130 | # There are two times for which segmentation is performed, ED and ES. 131 | # These are specified in the metadata file 132 | ed_slice = int(data_info[study_name]["ED"]) 133 | es_slice = int(data_info[study_name]["ES"]) 134 | # Store new images 135 | sitk.WriteImage(sitk.GetImageFromArray(x[ed_slice]), 136 | join_path([target_path, study_name+"_ED.nii.gz"])) 137 | sitk.WriteImage(sitk.GetImageFromArray(mask[ed_slice]), 138 | join_path([target_path, study_name+"_ED_gt.nii.gz"])) 139 | sitk.WriteImage(sitk.GetImageFromArray(x[es_slice]), 140 | join_path([target_path, study_name+"_ES.nii.gz"])) 141 | sitk.WriteImage(sitk.GetImageFromArray(mask[es_slice]), 142 | join_path([target_path, study_name+"_ES_gt.nii.gz"])) 143 | 144 | -------------------------------------------------------------------------------- /mp/data/datasets/ds_mr_hippocampus_decathlon.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Hippocampus segmentation task from the Medical Segmentation Decathlon 3 | # (http://medicaldecathlon.com/) 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | 8 | import SimpleITK as sitk 9 | 10 | import mp.data.datasets.dataset_utils as du 11 | from mp.data.datasets.dataset_segmentation import SegmentationDataset, SegmentationInstance 12 | from mp.paths import storage_data_path 13 | from mp.utils.load_restore import join_path 14 | 15 | 16 | class DecathlonHippocampus(SegmentationDataset): 17 | r"""Class for the hippocampus segmentation decathlon challenge, 18 | found at http://medicaldecathlon.com/. 19 | """ 20 | 21 | def __init__(self, subset=None, hold_out_ixs=None, merge_labels=True): 22 | assert subset is None, "No subsets for this dataset." 
23 | 24 | if hold_out_ixs is None: 25 | hold_out_ixs = [] 26 | 27 | global_name = 'DecathlonHippocampus' 28 | dataset_path = os.path.join(storage_data_path, global_name, "Merged Labels" if merge_labels else "Original") 29 | original_data_path = du.get_original_data_path(global_name) 30 | 31 | # Copy the images if not done already 32 | if not os.path.isdir(dataset_path): 33 | _extract_images(original_data_path, dataset_path, merge_labels) 34 | 35 | # Fetch all patient/study names 36 | study_names = set(file_name.split('.nii')[0].split('_gt')[0] for file_name 37 | in os.listdir(dataset_path)) 38 | 39 | # Build instances 40 | instances = [] 41 | for study_name in study_names: 42 | instances.append(SegmentationInstance( 43 | x_path=os.path.join(dataset_path, study_name + '.nii.gz'), 44 | y_path=os.path.join(dataset_path, study_name + '_gt.nii.gz'), 45 | name=study_name, 46 | group_id=None 47 | )) 48 | 49 | if merge_labels: 50 | label_names = ['background', 'hippocampus'] 51 | else: 52 | label_names = ['background', 'hippocampus proper', 'subiculum'] 53 | super().__init__(instances, name=global_name, label_names=label_names, 54 | modality='T1w MRI', nr_channels=1, hold_out_ixs=hold_out_ixs) 55 | 56 | 57 | def _extract_images(source_path, target_path, merge_labels): 58 | r"""Extracts images, merges mask labels (if specified) and saves the 59 | modified images. 60 | """ 61 | 62 | images_path = os.path.join(source_path, 'imagesTr') 63 | labels_path = os.path.join(source_path, 'labelsTr') 64 | 65 | # Filenames have the form 'hippocampus_XX.nii.gz' 66 | filenames = [x for x in os.listdir(images_path) if x[:5] == 'hippo'] 67 | 68 | # Create directories 69 | os.makedirs(target_path) 70 | 71 | for filename in filenames: 72 | 73 | # Extract only T2-weighted 74 | x = sitk.ReadImage(os.path.join(images_path, filename)) 75 | x = sitk.GetArrayFromImage(x) 76 | y = sitk.ReadImage(os.path.join(labels_path, filename)) 77 | y = sitk.GetArrayFromImage(y) 78 | 79 | # Shape expected: (35, 51, 35) 80 | # Average label shape: (24.5, 37.8, 21.0) 81 | assert x.shape == y.shape 82 | 83 | # No longer distinguish between hippocampus proper and subiculum 84 | if merge_labels: 85 | y[y == 2] = 1 86 | 87 | # Save new images so they can be loaded directly 88 | study_name = filename.replace('_', '').split('.nii')[0] 89 | sitk.WriteImage(sitk.GetImageFromArray(x), join_path([target_path, study_name + ".nii.gz"])) 90 | sitk.WriteImage(sitk.GetImageFromArray(y), join_path([target_path, study_name + "_gt.nii.gz"])) 91 | -------------------------------------------------------------------------------- /mp/data/datasets/ds_mr_hippocampus_dryad.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Hippocampus segmentation published by Dryad 3 | # (https://datadryad.org/stash/dataset/doi:10.5061/dryad.gc72v) 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | 8 | import SimpleITK as sitk 9 | 10 | import mp.data.datasets.dataset_utils as du 11 | from mp.data.datasets.dataset_segmentation import SegmentationDataset, SegmentationInstance 12 | from mp.paths import storage_data_path 13 | from mp.utils.load_restore import join_path 14 | import re 15 | import nibabel as nib 16 | import numpy as np 17 | import re 18 | 19 | 20 | class DryadHippocampus(SegmentationDataset): 21 | r"""Class for the segmentation of the HarP dataset, 22 | 
https://datadryad.org/stash/dataset/doi:10.5061/dryad.gc72v. 23 | """ 24 | 25 | def __init__(self, subset=None, hold_out_ixs=None, merge_labels=True): 26 | # Modality is either: "T1w" or "T2w" 27 | # Resolution is either: "Standard" or "Hires" 28 | # If you want to use different resolutions or modalities, please create another object with a different subset 29 | default = {"Modality": "T1w", "Resolution": "Standard"} 30 | if subset is not None: 31 | default.update(subset) 32 | subset = default 33 | else: 34 | subset = default 35 | 36 | # Hires T2w is not available 37 | assert not (subset["Resolution"] == "Standard" and subset["Modality"] == "T2w"), \ 38 | "Hires T2w not available for the Dryad Hippocampus dataset" 39 | 40 | if hold_out_ixs is None: 41 | hold_out_ixs = [] 42 | 43 | global_name = 'DryadHippocampus' 44 | name = du.get_dataset_name(global_name, subset) 45 | dataset_path = os.path.join(storage_data_path, 46 | global_name, 47 | "Merged Labels" if merge_labels else "Original", 48 | "".join([f"{key}[{subset[key]}]" for key in ["Modality", "Resolution"]]) 49 | ) 50 | original_data_path = du.get_original_data_path(global_name) 51 | 52 | # Copy the images if not done already 53 | if not os.path.isdir(dataset_path): 54 | _extract_images(original_data_path, dataset_path, merge_labels, subset) 55 | 56 | # Fetch all patient/study names 57 | study_names = set(file_name.split('.nii')[0].split('_gt')[0] for file_name in os.listdir(dataset_path)) 58 | 59 | # Build instances 60 | instances = [] 61 | for study_name in study_names: 62 | instances.append(SegmentationInstance( 63 | x_path=os.path.join(dataset_path, study_name + '.nii.gz'), 64 | y_path=os.path.join(dataset_path, study_name + '_gt.nii.gz'), 65 | name=study_name, 66 | group_id=None 67 | )) 68 | 69 | if merge_labels: 70 | label_names = ['background', 'hippocampus'] 71 | else: 72 | label_names = ['background', 'subiculum', 'CA1-3', 'CA4-DG'] 73 | 74 | super().__init__(instances, name=name, label_names=label_names, 75 | modality=subset["Modality"] + ' MRI', nr_channels=1, hold_out_ixs=hold_out_ixs) 76 | 77 | 78 | def _extract_images(source_path, target_path, merge_labels, subset): 79 | r"""Extracts images, merges mask labels (if specified) and saves the 80 | modified images. 81 | """ 82 | 83 | def bbox_3D(img): 84 | r = np.any(img, axis=(1, 2)) 85 | c = np.any(img, axis=(0, 2)) 86 | z = np.any(img, axis=(0, 1)) 87 | 88 | rmin, rmax = np.where(r)[0][[0, -1]] 89 | cmin, cmax = np.where(c)[0][[0, -1]] 90 | zmin, zmax = np.where(z)[0][[0, -1]] 91 | 92 | return rmin, rmax, cmin, cmax, zmin, zmax 93 | 94 | # Create directories 95 | os.makedirs(os.path.join(target_path)) 96 | 97 | # Patient folders s01, s02, ... 
98 | for patient_folder in filter(lambda s: re.match(r"^s[0-9]+.*", s), os.listdir(source_path)): 99 | 100 | # Loading the image 101 | image_path = os.path.join(source_path, patient_folder, 102 | f"{patient_folder}_{subset['Modality'].lower()}_" 103 | f"{subset['Resolution'].lower()}_defaced_MNI.nii.gz") 104 | x = sitk.ReadImage(image_path) 105 | x = sitk.GetArrayFromImage(x) 106 | 107 | # For each MRI, there are 2 segmentation (left and right hippocampus) 108 | for side in ["L", "R"]: 109 | # Loading the label 110 | label_path = os.path.join(source_path, patient_folder, 111 | f"{patient_folder}_hippolabels_" 112 | f"{'hres' if subset['Resolution'] == 'Hires' else 't1w_standard'}" 113 | f"_{side}_MNI.nii.gz") 114 | 115 | y = sitk.ReadImage(label_path) 116 | y = sitk.GetArrayFromImage(y) 117 | 118 | # We need to recover the study name of the image name to construct the name of the segmentation files 119 | study_name = f"{patient_folder}_{side}" 120 | 121 | # Average label shape (T1w, standard): (37.0, 36.3, 26.7) 122 | # Average label shape (T1w, hires): (94.1, 92.1, 68.5) 123 | # Average label shape (T2w, hires): (94.1, 92.1, 68.5) 124 | assert x.shape == y.shape 125 | 126 | # Disclaimer: next part is ugly and not many checks are made 127 | 128 | # So we first compute the bounding box 129 | rmin, rmax, cmin, cmax, zmin, zmax = bbox_3D(y) 130 | 131 | # Compute the start idx for each dim 132 | dr = (rmax - rmin) // 4 133 | dc = (cmax - cmin) // 4 134 | dz = (zmax - zmin) // 4 135 | 136 | # Reshaping 137 | y = y[rmin - dr: rmax + dr, 138 | cmin - dc: cmax + dc, 139 | zmin - dz: zmax + dz] 140 | 141 | if merge_labels: 142 | y[y > 1] = 1 143 | 144 | x_cropped = x[rmin - dr: rmax + dr, 145 | cmin - dc: cmax + dc, 146 | zmin - dz: zmax + dz] 147 | 148 | # Save new images so they can be loaded directly 149 | sitk.WriteImage(sitk.GetImageFromArray(y), 150 | join_path([target_path, study_name + "_gt.nii.gz"])) 151 | sitk.WriteImage(sitk.GetImageFromArray(x_cropped), 152 | join_path([target_path, study_name + ".nii.gz"])) 153 | -------------------------------------------------------------------------------- /mp/data/datasets/ds_mr_hippocampus_harp.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Hippocampus segmentation task for the HarP dataset 3 | # (http://www.hippocampal-protocol.net/SOPs/index.php) 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import re 8 | 9 | import SimpleITK as sitk 10 | import nibabel as nib 11 | import numpy as np 12 | 13 | import mp.data.datasets.dataset_utils as du 14 | from mp.data.datasets.dataset_segmentation import SegmentationDataset, SegmentationInstance 15 | from mp.paths import storage_data_path 16 | from mp.utils.load_restore import join_path 17 | 18 | 19 | class HarP(SegmentationDataset): 20 | r"""Class for the segmentation of the HarP dataset, 21 | found at http://www.hippocampal-protocol.net/SOPs/index.php 22 | with the masks as .nii files and the scans as .mnc files. 
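# The bounding-box cropping above is easiest to see on a toy volume; this
# self-contained numpy sketch repeats the same logic with a made-up mask.
import numpy as np

def bbox_3d_demo(img):
    # First/last nonzero index along the first axis, as in bbox_3D above
    r = np.any(img, axis=(1, 2))
    rmin, rmax = np.where(r)[0][[0, -1]]
    return rmin, rmax

y = np.zeros((20, 20, 20), dtype=np.uint8)
y[5:15, 6:14, 7:13] = 1
rmin, rmax = bbox_3d_demo(y)
dr = (rmax - rmin) // 4                              # a quarter of the extent as margin
print(rmin, rmax, y[rmin - dr: rmax + dr].shape[0])  # 5 14 13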
23 | """ 24 | 25 | def __init__(self, subset=None, hold_out_ixs=None, merge_labels=True): 26 | # Part is either: "Training", "Validation" or "All" 27 | default = {"Part": "All"} 28 | if subset is not None: 29 | default.update(subset) 30 | subset = default 31 | else: 32 | subset = default 33 | 34 | if hold_out_ixs is None: 35 | hold_out_ixs = [] 36 | 37 | global_name = 'HarP' 38 | name = du.get_dataset_name(global_name, subset) 39 | dataset_path = os.path.join(storage_data_path, global_name) 40 | original_data_path = du.get_original_data_path(global_name) 41 | 42 | # Build instances 43 | instances = [] 44 | folders = [] 45 | if subset["Part"] in ["Training", "All"]: 46 | folders.append(("100", "Training")) 47 | if subset["Part"] in ["Validation", "All"]: 48 | folders.append(("35", "Validation")) 49 | 50 | for orig_folder, dst_folder in folders: 51 | # Paths with the sub-folder for the current subset 52 | dst_folder_path = os.path.join(dataset_path, dst_folder) 53 | 54 | # Copy the images if not done already 55 | if not os.path.isdir(dst_folder_path): 56 | _extract_images(original_data_path, dst_folder_path, orig_folder) 57 | 58 | # Fetch all patient/study names 59 | study_names = set(file_name.split('.nii')[0].split('_gt')[0] for file_name 60 | in os.listdir(os.path.join(dataset_path, dst_folder))) 61 | 62 | for study_name in study_names: 63 | instances.append(SegmentationInstance( 64 | x_path=os.path.join(dataset_path, dst_folder, study_name + '.nii.gz'), 65 | y_path=os.path.join(dataset_path, dst_folder, study_name + '_gt.nii.gz'), 66 | name=study_name, 67 | group_id=None 68 | )) 69 | 70 | label_names = ['background', 'hippocampus'] 71 | 72 | super().__init__(instances, name=name, label_names=label_names, 73 | modality='T1w MRI', nr_channels=1, hold_out_ixs=hold_out_ixs) 74 | 75 | 76 | def _extract_images(source_path, target_path, subset): 77 | r"""Extracts images, merges mask labels (if specified) and saves the 78 | modified images. 
79 | """ 80 | 81 | def bbox_3D(img): 82 | r = np.any(img, axis=(1, 2)) 83 | c = np.any(img, axis=(0, 2)) 84 | z = np.any(img, axis=(0, 1)) 85 | 86 | rmin, rmax = np.where(r)[0][[0, -1]] 87 | cmin, cmax = np.where(c)[0][[0, -1]] 88 | zmin, zmax = np.where(z)[0][[0, -1]] 89 | 90 | return rmin, rmax, cmin, cmax, zmin, zmax 91 | 92 | # Folder 100 is for training (100 subjects), 35 subjects are left over for validation 93 | affine = np.array([[1, 0, 0, 0], 94 | [0, 1, 0, 0], 95 | [0, 0, 1, 0], 96 | [0, 0, 0, 1]]) 97 | 98 | images_path = os.path.join(source_path, subset) 99 | labels_path = os.path.join(source_path, f'Labels_{subset}_NIFTI') 100 | 101 | # Create directories 102 | os.makedirs(os.path.join(target_path)) 103 | 104 | # For each MRI, there are 2 segmentation (left and right hippocampus) 105 | for filename in os.listdir(images_path): 106 | # Loading the .mnc file and converting it to a .nii.gz file 107 | minc = nib.load(os.path.join(images_path, filename)) 108 | x = nib.Nifti1Image(minc.get_data(), affine=affine) 109 | 110 | # We need to recover the study name of the image name to construct the name of the segmentation files 111 | match = re.match(r"ADNI_[0-9]+_S_[0-9]+_[0-9]+", filename) 112 | if match is None: 113 | raise Exception(f"A file ({filename}) does not match the expected file naming format") 114 | 115 | # For each side of the brain 116 | for side in ["_L", "_R"]: 117 | study_name = match[0] + side 118 | 119 | y = sitk.ReadImage(os.path.join(labels_path, study_name + ".nii")) 120 | y = sitk.GetArrayFromImage(y) 121 | 122 | # Shape expected: (189, 233, 197) 123 | # Average label shape (Training): (27.1, 36.7, 22.0) 124 | # Average label shape (Validation): (27.7, 35.2, 21.8) 125 | assert x.shape == y.shape 126 | # Disclaimer: next part is ugly and not many checks are made 127 | # BUGFIX: Some segmentation have some weird values eg {26896.988, 26897.988} instead of {0, 1} 128 | y = (y - np.min(y.flat)).astype(np.uint32) 129 | 130 | # So we first compute the bounding box 131 | rmin, rmax, cmin, cmax, zmin, zmax = bbox_3D(y) 132 | 133 | # Compute the start idx for each dim 134 | dr = (rmax - rmin) // 4 135 | dc = (cmax - cmin) // 4 136 | dz = (zmax - zmin) // 4 137 | 138 | # Reshaping 139 | y = y[rmin - dr: rmax + dr, 140 | cmin - dc: cmax + dc, 141 | zmin - dz: zmax + dz] 142 | 143 | x_cropped = x.get_data()[rmin - dr: rmax + dr, 144 | cmin - dc: cmax + dc, 145 | zmin - dz: zmax + dz] 146 | 147 | # Save new images so they can be loaded directly 148 | sitk.WriteImage(sitk.GetImageFromArray(y), 149 | join_path([target_path, study_name + "_gt.nii.gz"])) 150 | sitk.WriteImage(sitk.GetImageFromArray(x_cropped), 151 | join_path([target_path, study_name + ".nii.gz"])) 152 | -------------------------------------------------------------------------------- /mp/data/datasets/ds_mr_prostate_decathlon.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Prostate segmentation task from the Medical Segmentation Decathlon 3 | # (http://medicaldecathlon.com/) 4 | # ------------------------------------------------------------------------------ 5 | 6 | import os 7 | import numpy as np 8 | import SimpleITK as sitk 9 | from mp.utils.load_restore import join_path 10 | from mp.data.datasets.dataset_segmentation import SegmentationDataset, SegmentationInstance 11 | from mp.paths import storage_data_path 12 | import mp.data.datasets.dataset_utils as du 13 | 14 | class 
DecathlonProstateT2(SegmentationDataset): 15 | r"""Class for the prostate segmentation decathlon challenge, only for T2, 16 | found at http://medicaldecathlon.com/. 17 | """ 18 | def __init__(self, subset=None, hold_out_ixs=[], merge_labels=True): 19 | assert subset is None, "No subsets for this dataset." 20 | 21 | global_name = 'DecathlonProstateT2' 22 | dataset_path = os.path.join(storage_data_path, global_name) 23 | original_data_path = du.get_original_data_path(global_name) 24 | 25 | # Separate T2 images, if not already done 26 | if not os.path.isdir(dataset_path): 27 | _extract_t2_images(original_data_path, dataset_path, merge_labels) 28 | 29 | # Fetch all patient/study names 30 | study_names = set(file_name.split('.nii')[0].split('_gt')[0] for file_name 31 | in os.listdir(dataset_path)) 32 | 33 | # Build instances 34 | instances = [] 35 | for study_name in study_names: 36 | instances.append(SegmentationInstance( 37 | x_path=os.path.join(dataset_path, study_name+'.nii.gz'), 38 | y_path=os.path.join(dataset_path, study_name+'_gt.nii.gz'), 39 | name=study_name, 40 | group_id=None 41 | )) 42 | if merge_labels: 43 | label_names = ['background', 'prostate'] 44 | else: 45 | label_names = ['background', 'peripheral zone', 'central gland'] 46 | super().__init__(instances, name=global_name, label_names=label_names, 47 | modality='MR', nr_channels=1, hold_out_ixs=[]) 48 | 49 | def _extract_t2_images(source_path, target_path, merge_labels): 50 | r"""Extracts T2 images, merges mask labels (if specified) and saves the 51 | modified images. 52 | """ 53 | images_path = os.path.join(source_path, 'imagesTr') 54 | labels_path = os.path.join(source_path, 'labelsTr') 55 | 56 | # Filenames have the form 'prostate_XX.nii.gz' 57 | filenames = [x for x in os.listdir(images_path) if x[:8] == 'prostate'] 58 | 59 | # Create directories 60 | os.makedirs(target_path) 61 | 62 | for filename in filenames: 63 | 64 | # Extract only T2-weighted 65 | x = sitk.ReadImage(os.path.join(images_path, filename)) 66 | x = sitk.GetArrayFromImage(x)[0] 67 | y = sitk.ReadImage(os.path.join(labels_path, filename)) 68 | y = sitk.GetArrayFromImage(y) 69 | assert x.shape == y.shape 70 | 71 | # No longer distinguish between central and peripheral zones 72 | if merge_labels: 73 | y = np.where(y==2, 1, y) 74 | 75 | # Save new images so they can be loaded directly 76 | study_name = filename.replace('_', '').split('.nii')[0] 77 | sitk.WriteImage(sitk.GetImageFromArray(x), 78 | join_path([target_path, study_name+".nii.gz"])) 79 | sitk.WriteImage(sitk.GetImageFromArray(y), 80 | join_path([target_path, study_name+"_gt.nii.gz"])) -------------------------------------------------------------------------------- /mp/data/pytorch/pytorch_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # This class builds a descendant of torch.utils.data.Dataset from a 3 | # mp.data.datasets.dataset.Dataset and a list of instance indexes. 4 | # ------------------------------------------------------------------------------ 5 | 6 | from torch.utils.data import Dataset 7 | 8 | class PytorchDataset(Dataset): 9 | def __init__(self, dataset, ix_lst=None, size=None): 10 | r"""A dataset which is compatible with PyTorch. 11 | 12 | Args: 13 | dataset (mp.data.datasets.dataset.Dataset): a descendant of the 14 | class defined internally for datasets. 15 | ix_lst (list[int]): list specifying the instances of 'dataset' to 16 | include. 
If 'None', all which are not in the hold-out dataset 17 | are incuded. 18 | size (tuple[int]): desired input size. 19 | 20 | :param resize: resize images into this new size. 21 | :param transform_lst: a list of torchvision transforms operations. 22 | :param norm: values to normalize the dataset with the form 23 | {'mean': tuple, 'std': tuple}, which can be generated by 24 | mp.utils.pytorch.compute_normalization_values 25 | """ 26 | # Indexes 27 | if ix_lst is None: 28 | ix_lst = [ix for ix in range(len(dataset.instances)) 29 | if ix not in dataset.hold_out_ixs] 30 | self.instances = [ex for ix, ex in enumerate(dataset.instances) 31 | if ix in ix_lst] 32 | self.size = size 33 | 34 | def __len__(self): 35 | return len(self.instances) 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /mp/data/pytorch/transformation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Tensor transformations. Mainly, transformations from the TorchIO library are 3 | # used (https://torchio.readthedocs.io/transforms). 4 | # ------------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torchio 9 | 10 | NORMALIZATION_STRATEGIES = {None:None, 11 | 'rescaling': torchio.transforms.RescaleIntensity(out_min_max=(0, 1), percentiles=(0.1, 99.)), 12 | 'z_norm': torchio.transforms.ZNormalization(masking_method=None) 13 | # TODO 14 | #'histogram_norm': torchio.transforms.HistogramStandardization(landmarks) 15 | } 16 | 17 | AUGMENTATION_STRATEGIES = {'none':None, 18 | 'standard': torchio.transforms.Compose([ 19 | torchio.transforms.OneOf({ 20 | torchio.transforms.RandomElasticDeformation(p=0.1, 21 | # num_control_points=(7,7,7), 22 | # max_displacement=7.5): 0.7, 23 | num_control_points=(5,5,5), 24 | max_displacement=5.5): 0.7, 25 | torchio.RandomAffine(p=0.1, 26 | scales=(0.5, 1.5), 27 | degrees=(5), 28 | isotropic=False, 29 | default_pad_value='otsu', 30 | image_interpolation='bspline'): 0.1 31 | }), 32 | torchio.transforms.RandomFlip(p=0.1, 33 | axes=(1, 0, 0)), 34 | torchio.transforms.RandomMotion(p=0.1, 35 | degrees=10, 36 | translation=10, 37 | num_transforms=2), 38 | torchio.transforms.RandomBiasField(p=0.1, 39 | coefficients=(0.5, 0.5), 40 | order=3), 41 | torchio.transforms.RandomNoise(p=0.1, 42 | mean=(0,0), 43 | std=(50, 50)), 44 | torchio.transforms.RandomBlur(p=0.1, 45 | std=(0, 1)) 46 | ]), 47 | 'antoine': torchio.transforms.Compose([ 48 | torchio.RandomAffine(p=0.1, 49 | scales=(0.5, 1.5), 50 | degrees=(5), 51 | isotropic=False, 52 | default_pad_value='otsu', 53 | image_interpolation='bspline'), 54 | torchio.transforms.RandomFlip(p=0.1, 55 | axes=(1, 0, 0)), 56 | torchio.transforms.RandomMotion(p=0.1, 57 | degrees=10, 58 | translation=10, 59 | num_transforms=2) 60 | ]) 61 | } 62 | 63 | def per_label_channel(y, nr_labels, channel_dim=0, device='cpu'): 64 | r"""Trans. 
a one-channeled mask where the integers specify the label to a 65 | multi-channel output with one channel per label, where 1 marks belonging to 66 | that label.""" 67 | masks = [] 68 | zeros = torch.zeros(y.shape, dtype=torch.float64).to(device) 69 | ones = torch.ones(y.shape, dtype=torch.float64).to(device) 70 | for label_nr in range(nr_labels): 71 | mask = torch.where(y == label_nr, ones, zeros) 72 | masks.append(mask) 73 | target = torch.cat(masks, dim=channel_dim) 74 | return target 75 | 76 | def _one_output_channel_single(y): 77 | r"""Helper function.""" 78 | channel_dim = 0 79 | target_shape = list(y.shape) 80 | nr_labels = target_shape[channel_dim] 81 | target_shape[channel_dim] = 1 82 | target = torch.zeros(target_shape, dtype=torch.float64) 83 | label_nr_mask = torch.zeros(target_shape, dtype=torch.float64) 84 | for label_nr in range(nr_labels): 85 | label_nr_mask.fill_(label_nr) 86 | target = torch.where(y[label_nr] == 1, label_nr_mask, target) 87 | return target 88 | 89 | def one_output_channel(y, channel_dim=0): 90 | r"""Inverses the operation of 'per_label_channel'. The output is 91 | one-channelled. It is stricter than making a prediction because the content 92 | must be 1 and not the largest float.""" 93 | if channel_dim == 0: 94 | return _one_output_channel_single(y) 95 | else: 96 | assert channel_dim == 1, "Not implemented for channel_dim > 1" 97 | batch = [_one_output_channel_single(x) for x in y] 98 | return torch.stack(batch, dim=0) 99 | 100 | def resize_2d(img, size=(1, 128, 128), label=False): 101 | r"""2D resize.""" 102 | img.unsqueeze_(0) # Add additional batch dimension so input is 4D 103 | if label: 104 | # Interpolation in 'nearest' mode leaves the original mask values. 105 | img = F.interpolate(img, size=size[1:], mode='nearest') 106 | else: 107 | img = F.interpolate(img, size=size[1:], mode='bilinear', align_corners=True) 108 | return img[0] 109 | 110 | def resize_3d(img, size=(1, 56, 56, 56), label=False): 111 | r"""3D resize.""" 112 | img.unsqueeze_(0) # Add additional batch dimension so input is 5D 113 | if label: 114 | # Interpolation in 'nearest' mode leaves the original mask values. 115 | img = F.interpolate(img, size=size[1:], mode='nearest') 116 | else: 117 | img = F.interpolate(img, size=size[1:], mode='trilinear') 118 | return img[0] 119 | 120 | def centre_crop_pad_2d(img, size=(1, 128, 128)): 121 | r"""Center-crops to the specified size, unless the image is to small in some 122 | dimension, then padding takes place. 123 | """ 124 | img = torch.unsqueeze(img, -1) 125 | size = (size[1], size[2], 1) 126 | transform = torchio.transforms.CropOrPad(target_shape=size, padding_mode=0) 127 | device = img.device 128 | img = transform(img.cpu()).to(device) 129 | img = torch.squeeze(img, -1) 130 | return img 131 | 132 | def centre_crop_pad_3d(img, size=(1, 56, 56, 56)): 133 | r"""Center-crops to the specified size, unless the image is to small in some 134 | dimension, then padding takes place. For 3D data. 135 | """ 136 | transform = torchio.transforms.CropOrPad(target_shape=size[1:], padding_mode=0) 137 | device = img.device 138 | img = transform(img.cpu()).to(device) 139 | return img 140 | 141 | def pad_3d_if_required(instance, size): 142 | r"""Pads if required in the last dimension, for 3D. 
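# A round-trip sketch for the two mask transformations above, on a toy
# one-channel mask with labels 0..2.
import torch
from mp.data.pytorch.transformation import per_label_channel, one_output_channel

y = torch.tensor([[0, 1], [2, 1]]).unsqueeze(0)  # (1, 2, 2)
multi = per_label_channel(y, nr_labels=3, channel_dim=0)
print(multi.shape)                               # (3, 2, 2): one channel per label
back = one_output_channel(multi, channel_dim=0)
assert torch.equal(back.long(), y.long())        # lossless round trip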
143 | """ 144 | if instance.shape[-1] < size[-1]: 145 | delta = size[-1]-instance.shape[-1] 146 | subject = instance.get_subject() 147 | transform = torchio.transforms.Pad(padding=(0, 0, 0, 0, 0, delta), padding_mode = 0) 148 | subject = transform(subject) 149 | instance.x = torchio.Image(tensor=subject.x.tensor, type=torchio.INTENSITY) 150 | instance.y = torchio.Image(tensor=subject.y.tensor, type=torchio.LABEL) 151 | instance.shape = subject.shape 152 | return instance 153 | 154 | from torchvision import transforms 155 | 156 | def torchvision_rescaling(x, size=(3, 224, 224), resize=False): 157 | r"""To use pretrained torchvision models, three-channeled 2D images must 158 | first be normalized between 0 and 1 and then noralized with predfined values 159 | (see https://pytorch.org/docs/stable/torchvision/models.html) 160 | """ 161 | device = x.device 162 | # Images should be normalized between 0 and 1 163 | assert torch.min(x) >= 0. 164 | assert torch.max(x) <= 1. 165 | transform_ops = [] 166 | # If images are one-channeled, triplicate 167 | if x.shape[1] == 1: 168 | transform_ops.append(transforms.Lambda(lambda x: x.repeat(3, 1, 1))) 169 | # Resize or crop 170 | transform_ops.append(transforms.ToPILImage()) 171 | if resize: 172 | transform_ops.append(transforms.Resize(size=(size[1], size[2]))) 173 | else: 174 | transform_ops.append(transforms.CenterCrop(size=(size[1], size[2]))) 175 | # Apply pre-defined normalization 176 | transform_ops.append(transforms.ToTensor()) 177 | transform_ops.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], 178 | std=[0.229, 0.224, 0.225])) 179 | # Apply transform operation 180 | transform = transforms.Compose(transform_ops) 181 | imgs = [] 182 | for img in x: 183 | imgs.append(transform(img.cpu()).to(device)) 184 | return torch.stack(imgs, dim=0) -------------------------------------------------------------------------------- /mp/eval/accumulator.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Accumulates results from a minibatch. 
3 | # ------------------------------------------------------------------------------ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | class Accumulator: 9 | def __init__(self, keys=None): 10 | self.values = dict() 11 | if keys is not None: 12 | self.init(keys) 13 | 14 | def update(self, acc): 15 | for key, value in acc.values.items(): 16 | if key not in self.values: 17 | self.values[key] = value 18 | 19 | def init(self, keys): 20 | for key in keys: 21 | self.values[key] = [] 22 | 23 | def ensure_key(self, key): 24 | if key not in self.values: 25 | self.values[key] = [] 26 | 27 | def add(self, key, value, count=1): 28 | self.ensure_key(key) 29 | if isinstance(value, torch.Tensor): 30 | np_value = float(value.detach().cpu()) 31 | else: 32 | np_value = value 33 | for _ in range(count): 34 | self.values[key].append(np_value) 35 | 36 | def mean(self, key): 37 | return np.mean(self.values[key]) 38 | 39 | def std(self, key): 40 | return np.std(self.values[key]) 41 | 42 | def sum(self, key): 43 | return sum(self.values[key]) 44 | 45 | def get_keys(self): 46 | return sorted(list(self.values.keys())) 47 | 48 | 49 | -------------------------------------------------------------------------------- /mp/eval/inference/predict.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Transform a multi-channeled network output into a prediction, and similar 3 | # helper functions. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import torch 7 | 8 | def arg_max(output, channel_dim=1): 9 | r"""Select the class with highest probability.""" 10 | return torch.argmax(output, dim=channel_dim) 11 | 12 | def softmax(output, channel_dim=1): 13 | r"""Softmax outputs so that the vlues add up to 1.""" 14 | f = torch.nn.Softmax(dim=channel_dim) 15 | return f(output) 16 | 17 | # TODO 18 | def confidence(softmaxed_output, channel_dim=1): 19 | r"""Returns the confidence for each voxel (the highest value along the 20 | channel dimension).""" 21 | pass 22 | 23 | def ensable_prediction(): 24 | pass 25 | -------------------------------------------------------------------------------- /mp/eval/inference/predictor.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # A Predictor makes a prediction for a subject index that has the same size 3 | # as the subject's target. It reverses the trandormation operations performed 4 | # so that inputs can be passed through the model. It, for instance, merges 5 | # patches and 2D slices into 3D volumes of the original size. 6 | # ------------------------------------------------------------------------------ 7 | 8 | import copy 9 | import torch 10 | import torchio 11 | import mp.data.pytorch.transformation as trans 12 | 13 | class Predictor(): 14 | r"""A predictor recreates a prediction with the correct dimensions from 15 | model outputs. There are different predictors for different PytorchDatasets, 16 | and these are setted internally with the creation of a PytorchDataset. 
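# The Accumulator is a small bookkeeping helper; a sketch of its typical
# use across an epoch (the loss values are invented).
from mp.eval.accumulator import Accumulator

acc = Accumulator()
for batch_loss, batch_size in [(0.8, 4), (0.6, 4), (0.5, 2)]:
    # 'count' repeats the value, so the mean is weighted by batch size
    acc.add('loss', batch_loss, count=batch_size)
print(acc.mean('loss'))  # (0.8*4 + 0.6*4 + 0.5*2) / 10 = 0.66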
17 | Args:
18 | instances (list[Instance]): a list of instances, as for a Dataset
19 | size (tuple[int]): size as (channels, width, height, Opt(depth))
20 | norm (torchio.transforms): a normalization strategy
21 | """
22 | def __init__(self, instances, size=(1, 56, 56, 10), norm=None):
23 | self.instances = instances
24 | assert len(size) > 2
25 | self.size = size
26 | self.norm = norm
27 | 
28 | def transform_subject(self, subject):
29 | r"""Apply normalization strategy to subject."""
30 | if self.norm is not None:
31 | subject = self.norm(subject)
32 | return subject
33 | 
34 | def get_subject(self, subject_ix):
35 | r"""Copy and load a TorchIO subject."""
36 | subject = copy.deepcopy(self.instances[subject_ix].get_subject())
37 | subject.load()
38 | subject = self.transform_subject(subject)
39 | return subject
40 | 
41 | def get_subject_prediction(self, agent, subject_ix):
42 | r"""Get a prediction for a 3D subject."""
43 | raise NotImplementedError
44 | 
45 | class Predictor2D(Predictor):
46 | r"""The Predictor2D makes a forward pass for each 2D slice and merges these
47 | into a volume.
48 | """
49 | def __init__(self, *args, resize=False, **kwargs):
50 | super().__init__(*args, **kwargs)
51 | self.resize = resize
52 | 
53 | def get_subject_prediction(self, agent, subject_ix):
54 | 
55 | subject = self.get_subject(subject_ix)
56 | 
57 | # Slices first
58 | x = subject.x.tensor.permute(3, 0, 1, 2)
59 | # Get original size
60 | original_size = subject['y'].data.shape
61 | original_size_2d = original_size[:3]
62 | 
63 | pred = []
64 | with torch.no_grad():
65 | for slice_idx in range(len(x)):
66 | if self.resize:
67 | inputs = trans.resize_2d(x[slice_idx], size=self.size).to(agent.device)
68 | inputs = torch.unsqueeze(inputs, 0)
69 | slice_pred = agent.predict(inputs).float()
70 | pred.append(trans.resize_2d(slice_pred, size=original_size_2d, label=True))
71 | else:
72 | inputs = trans.centre_crop_pad_2d(x[slice_idx], size=self.size).to(agent.device)
73 | inputs = torch.unsqueeze(inputs, 0)
74 | slice_pred = agent.predict(inputs).float()
75 | pred.append(trans.centre_crop_pad_2d(slice_pred, size=original_size_2d))
76 | 
77 | # Merge slices and rotate so depth is last
78 | pred = torch.stack(pred, dim=0)
79 | pred = pred.permute(1, 2, 3, 0)
80 | assert original_size == pred.shape
81 | return pred
82 | 
83 | class Predictor3D(Predictor):
84 | r"""The Predictor3D reconstructs an image into the original size after
85 | performing a forward pass.
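# The depth-last convention behind Predictor2D comes down to two permutes;
# a toy sketch of the shape handling, with a threshold standing in for a model.
import torch

volume = torch.rand(1, 56, 56, 10)            # (channels, w, h, d)
slices = volume.permute(3, 0, 1, 2)           # (d, channels, w, h): slices as batch
pred = torch.stack([(s[0] > 0.5).float() for s in slices], dim=0)
pred = pred.unsqueeze(1).permute(1, 2, 3, 0)  # back to (1, 56, 56, 10)
print(pred.shape)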
86 | """ 87 | def __init__(self, *args, resize=False, **kwargs): 88 | super().__init__(*args, **kwargs) 89 | self.resize = resize 90 | 91 | def get_subject_prediction(self, agent, subject_ix): 92 | subject = self.get_subject(subject_ix) 93 | 94 | x = subject['x'].data 95 | # Get original label size 96 | original_size = subject['y'].data.shape 97 | 98 | 99 | if self.resize: 100 | # Resize to appropiate model size and make prediction 101 | x = trans.resize_3d(x, size=self.size).to(agent.device) 102 | x = torch.unsqueeze(x, 0) 103 | with torch.no_grad(): 104 | pred = agent.predict(x).float() 105 | # Restore prediction to original size 106 | pred = trans.resize_3d(pred, size=original_size, label=True) 107 | 108 | else: 109 | # Crop or pad instead of interpolating 110 | x = trans.centre_crop_pad_3d(x, size=self.size).to(agent.device) 111 | x = torch.unsqueeze(x, 0) 112 | with torch.no_grad(): 113 | pred = agent.predict(x).float() 114 | pred = trans.centre_crop_pad_3d(pred, size=original_size) 115 | assert original_size == pred.shape 116 | return pred 117 | 118 | class GridPredictor(Predictor): 119 | r"""The GridPredictor deconstructs a 3D volume into patches, makes a forward 120 | pass through the model and reconstructs a prediction of the output size. 121 | """ 122 | def __init__(self, *args, patch_overlap = (0,0,0), **kwargs): 123 | super().__init__(*args, **kwargs) 124 | assert patch_overlap[2] == 0 # Otherwise, have gotten wrong overlap 125 | self.patch_overlap = patch_overlap 126 | self.patch_size = self.size[1:] 127 | 128 | def get_subject_prediction(self, agent, subject_ix): 129 | 130 | subject = self.get_subject(subject_ix) 131 | original_size = subject['y'].data.shape 132 | 133 | grid_sampler = torchio.inference.GridSampler( 134 | sample=subject, 135 | patch_size=self.patch_size, 136 | patch_overlap=self.patch_overlap) 137 | 138 | # Make sure the correct transformations are performed before predicting 139 | patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=5) 140 | patch_aggregator = torchio.inference.GridAggregator(grid_sampler) 141 | with torch.no_grad(): 142 | for patches_batch in patch_loader: 143 | input_tensor = patches_batch['x'][torchio.DATA].to(agent.device) 144 | locations = patches_batch[torchio.LOCATION].to(agent.device) 145 | pred = agent.predict(input_tensor) 146 | # Add dimension for channel, which is not in final output 147 | pred = torch.unsqueeze(pred, 1) 148 | 149 | patch_aggregator.add_batch(pred, locations) 150 | output = patch_aggregator.get_output_tensor().to(agent.device) 151 | 152 | assert original_size == output.shape 153 | return output -------------------------------------------------------------------------------- /mp/eval/losses/loss_abstract.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # An abstract loss function to use during training. These are defined in the 3 | # project to output respective evaluation dictionaries that report all 4 | # components of the loss separatedly. 5 | # ------------------------------------------------------------------------------ 6 | 7 | import torch.nn as nn 8 | 9 | class LossAbstract(nn.Module): 10 | r"""A named loss function, that loss functions should inherit from. 
11 | Args: 12 | device (str): device key 13 | """ 14 | def __init__(self, device='cuda:0'): 15 | super().__init__() 16 | self.device = device 17 | self.name = self.__class__.__name__ 18 | 19 | def get_evaluation_dict(self, output, target): 20 | r"""Return keys and values of all components making up this loss. 21 | Args: 22 | output (torch.tensor): a torch tensor for a multi-channeled model 23 | output 24 | target (torch.tensor): a torch tensor for a multi-channeled target 25 | """ 26 | return {self.name: float(self.forward(output, target).cpu())} -------------------------------------------------------------------------------- /mp/eval/losses/losses_autoencoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Similitude metrics between output and target. 3 | # ------------------------------------------------------------------------------ 4 | 5 | import torch.nn as nn 6 | from mp.eval.losses.loss_abstract import LossAbstract 7 | 8 | class LossMSE(LossAbstract): 9 | r"""Mean Squared Error.""" 10 | def __init__(self, device='cuda:0'): 11 | super().__init__(device=device) 12 | self.mse = nn.MSELoss(reduction='mean') 13 | 14 | def forward(self, output, target): 15 | return self.mse(output, target) 16 | 17 | class LossL1(LossAbstract): 18 | r"""L1 distance loss.""" 19 | def __init__(self, device='cuda:0'): 20 | super().__init__(device=device) 21 | self.l1 = nn.L1Loss(reduction='mean') 22 | 23 | def forward(self, output, target): 24 | return self.l1(output, target) -------------------------------------------------------------------------------- /mp/eval/losses/losses_segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Collection of loss metrics that can be used during training, including binary 3 | # cross-entropy and dice. Class-wise weights can be specified. 4 | # Losses receive a 'target' array with shape (batch_size, channel_dim, etc.) 5 | # and channel dimension equal to nr. of classes that has been previously 6 | # transformed (through e.g. softmax) so that values lie between 0 and 1, and an 7 | # 'output' array with the same dimension and values that are either 0 or 1. 8 | # The results of the loss is always averaged over batch items (the first dim). 9 | # ------------------------------------------------------------------------------ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from mp.eval.losses.loss_abstract import LossAbstract 14 | 15 | class LossDice(LossAbstract): 16 | r"""Dice loss with a smoothing factor.""" 17 | def __init__(self, smooth=1., device='cuda:0'): 18 | super().__init__(device=device) 19 | self.smooth = smooth 20 | self.device = device 21 | self.name = 'LossDice[smooth='+str(self.smooth)+']' 22 | 23 | def forward(self, output, target): 24 | output_flat = output.view(-1) 25 | target_flat = target.view(-1) 26 | intersection = (output_flat * target_flat).sum() 27 | return 1 - ((2. 
* intersection + self.smooth) /
28 | (output_flat.sum() + target_flat.sum() + self.smooth))
29 | 
30 | class LossBCE(LossAbstract):
31 | r"""Binary cross entropy loss."""
32 | def __init__(self, device='cuda:0'):
33 | super().__init__(device=device)
34 | self.device = device
35 | self.bce = nn.BCELoss(reduction='mean')
36 | 
37 | def forward(self, output, target):
38 | # Output values must lie in [0, 1]; the softmaxed model outputs are
39 | # clamped away from exactly 0 and 1 (see SegmentationAgent), so no
40 | # log(0) can occur here.
41 | return self.bce(output, target)
42 | 
43 | class LossBCEWithLogits(LossAbstract):
44 | r"""More stable than applying a sigmoid to the output and then the BCE
45 | loss (see
46 | https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html),
47 | but only use if applicable."""
48 | def __init__(self, device='cuda:0'):
49 | super().__init__(device=device)
50 | self.bce = nn.BCEWithLogitsLoss(reduction='mean')
51 | 
52 | def forward(self, output, target):
53 | return self.bce(output, target)
54 | 
55 | class LossCombined(LossAbstract):
56 | r"""A combination of several different losses."""
57 | def __init__(self, losses, weights, device='cuda:0'):
58 | super().__init__(device=device)
59 | self.losses = losses
60 | self.weights = weights
61 | # Set name
62 | self.name = 'LossCombined['
63 | for loss, weight in zip(self.losses, self.weights):
64 | self.name += str(weight)+'x'+loss.name + '+'
65 | self.name = self.name[:-1] + ']'
66 | 
67 | def forward(self, output, target):
68 | total_loss = torch.zeros(1).to(self.device)
69 | for loss, weight in zip(self.losses, self.weights):
70 | total_loss += weight*loss(output, target)
71 | return total_loss
72 | 
73 | def get_evaluation_dict(self, output, target):
74 | eval_dict = super().get_evaluation_dict(output, target)
75 | for loss, weight in zip(self.losses, self.weights):
76 | loss_eval_dict = loss.get_evaluation_dict(output, target)
77 | for key, value in loss_eval_dict.items():
78 | eval_dict[key] = value
79 | return eval_dict
80 | 
81 | class LossDiceBCE(LossCombined):
82 | r"""A combination of Dice and binary cross entropy."""
83 | def __init__(self, bce_weight=1., smooth=1., device='cuda:0'):
84 | super().__init__(losses=[LossDice(smooth=smooth), LossBCE()],
85 | weights=[1., bce_weight], device=device)
86 | 
87 | class LossClassWeighted(LossAbstract):
88 | r"""A loss that weights different labels differently. Often, weights should
89 | be set inverse to the ratio of pixels of that class in the data so that
90 | classes with high representation (e.g. background) do not monopolize the
91 | loss."""
92 | def __init__(self, loss, weights=None, nr_labels=None, device='cuda:0'):
93 | super().__init__(device)
94 | 
95 | self.loss = loss
96 | if weights is None:
97 | assert nr_labels is not None, "Specify either weights or number of labels."
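# A quick sketch of how the combined loss reports its components; random
# tensors stand in for softmaxed outputs and per-label-channel targets.
import torch
from mp.eval.losses.losses_segmentation import LossDiceBCE

loss_f = LossDiceBCE(bce_weight=0.5, smooth=1., device='cpu')
output = torch.rand(2, 2, 32, 32).clamp(1e-08, 1.-1e-08)
target = (torch.rand(2, 2, 32, 32) > 0.5).float()
print(loss_f(output, target))
# One key for the combined loss plus one per component
print(loss_f.get_evaluation_dict(output, target).keys())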
110 | self.class_weights = [1 for label_nr in range(nr_labels)] 111 | else: 112 | self.class_weights = weights 113 | # Set name 114 | self.name = 'LossClassWeighted[loss='+loss.name+'; weights='+str(tuple(self.class_weights))+']' 115 | # Set tensor class weights 116 | self.class_weights = torch.tensor(self.class_weights).to(self.device) 117 | self.added_weights = self.class_weights.sum() 118 | 119 | def forward(self, output, target): 120 | batch_loss = torch.zeros(1).to(self.device) 121 | for instance_output, instance_target in zip(output, target): 122 | instance_loss = torch.zeros(1).to(self.device) 123 | for out_channel_output, out_channel_target, weight in zip(instance_output, instance_target, self.class_weights): 124 | instance_loss += weight * self.loss(out_channel_output, 125 | out_channel_target) 126 | batch_loss += instance_loss / self.added_weights 127 | return batch_loss / len(output) 128 | 129 | def get_evaluation_dict(self, output, target): 130 | eval_dict = super().get_evaluation_dict(output, target) 131 | weighted_loss_values = [0 for weight in self.class_weights] 132 | for instance_output, instance_target in zip(output, target): 133 | for out_channel_output, out_channel_target, weight_ix in zip(instance_output, instance_target, range(len(weighted_loss_values))): 134 | instance_weighted_loss = self.loss(out_channel_output, out_channel_target) 135 | weighted_loss_values[weight_ix] += float(instance_weighted_loss.cpu()) 136 | for weight_ix, loss_value in enumerate(weighted_loss_values): 137 | eval_dict[self.loss.name+'['+str(weight_ix)+']'] = loss_value / len(output) 138 | return eval_dict -------------------------------------------------------------------------------- /mp/eval/metrics/mean_scores.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Collection of metrics to compare whole 1-channel segmentation masks. 3 | # Metrics receive two 1-channel integer arrays. 4 | # ------------------------------------------------------------------------------ 5 | 6 | import torch 7 | import mp.eval.metrics.scores as score_defs 8 | 9 | def get_tp_tn_fn_fp_segmentation(target, pred, class_ix=1): 10 | r"""Get TP, TN, FN and FP pixel values for segmentation.""" 11 | assert target.shape + pred.shape 12 | device, shape = target.device, target.shape 13 | zeros = torch.zeros(shape).to(device) 14 | ones = torch.ones(shape).to(device) 15 | target_class = torch.where(target==class_ix,ones,zeros) 16 | pred_class = torch.where(pred==class_ix,ones,zeros) 17 | tp = torch.where(target_class==1,pred_class,zeros).sum() 18 | tn = torch.where(target_class==0,1-pred_class,zeros).sum() 19 | fn = torch.where(target_class==1,1-pred_class,zeros).sum() 20 | fp = torch.where(pred_class==1,1-target_class,zeros).sum() 21 | tp, tn, fn, fp = int(tp), int(tn), int(fn), int(fp) 22 | #assert int(ones.sum()) == tp+tn+fn+fp 23 | return tp, tn, fn, fp 24 | 25 | def get_mean_scores(target, pred, metrics=['ScoreDice', 'ScoreIoU'], 26 | label_names=['background', 'class 1'], label_weights=None, 27 | segmentation=True): 28 | r"""Returns the scores per label, as well as the (weighted) mean, such as 29 | to avoid considering "don't care" classes. The weights don't have to be 30 | normalized. 
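
# Editor's usage sketch (not part of the original module). 'output' is assumed
# to be already softmaxed to [0, 1] and 'target' to be a one-hot mask, as the
# module header above requires:
if __name__ == '__main__':
    output = torch.softmax(torch.randn(4, 2, 32, 32), dim=1)
    target = torch.nn.functional.one_hot(
        torch.randint(0, 2, (4, 32, 32)), num_classes=2).permute(0, 3, 1, 2).float()
    loss_f = LossDiceBCE(bce_weight=0.5, device='cpu')
    print(loss_f(output, target))                      # combined scalar loss
    print(loss_f.get_evaluation_dict(output, target))  # value per component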
31 | """ 32 | scores = {metric: dict() for metric in metrics} 33 | # Calculate metric values per each class 34 | metrics = {metric: getattr(score_defs, metric)() for metric in metrics} 35 | for label_nr, label_name in enumerate(label_names): 36 | # TODO: enable also for classification 37 | tp, tn, fn, fp = get_tp_tn_fn_fp_segmentation(target, pred, class_ix=label_nr) 38 | for metric_key, metric_f in metrics.items(): 39 | score = metric_f.eval(tp, tn, fn, fp) 40 | scores[metric_key+'['+label_name+']'] = score 41 | scores[metric_key][label_name] = score 42 | # Calculate metric means 43 | if label_weights is None: 44 | label_weights = {label_name: 1 for label_name in label_names} 45 | for metric_key in metrics.keys(): 46 | # Replace the dictionary by the mean 47 | mean = sum([ 48 | label_score*label_weights[label_name] for label_name, label_score 49 | in scores[metric_key].items()]) /sum(list(label_weights.values())) 50 | scores[metric_key] = mean 51 | return scores -------------------------------------------------------------------------------- /mp/eval/metrics/scores.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Definition of a score metrics for classification and segmentation, taking 3 | # in tp, tn, fn and fp as inputs. For segmentation, these refer to pixel/voxel 4 | # values for one example. 5 | # ------------------------------------------------------------------------------ 6 | 7 | class ScoreAbstract: 8 | r"""Ab abstract definition of a metric that uses true positives, true 9 | negatives, false negatives and false positives to calculate a value.""" 10 | def __init__(self): 11 | self.name = self.__class__.__name__ 12 | 13 | def eval(self, tp, tn, fn, fp): 14 | raise NotImplementedError 15 | 16 | class ScoreDice(ScoreAbstract): 17 | r"""Dice score, inverce of a Dice loss except for the smoothing factor in 18 | the loss.""" 19 | def eval(self, tp, tn, fn, fp): 20 | if tp == 0: 21 | if fn+fp > 0: 22 | return 0. 23 | else: 24 | return 1. 25 | return (2*tp)/(2*tp+fp+fn) 26 | 27 | class ScoreIoU(ScoreAbstract): 28 | r"""Intersection over Union.""" 29 | def eval(self, tp, tn, fn, fp): 30 | if tp == 0: 31 | if fn+fp > 0: 32 | return 0. 33 | else: 34 | return 1. 35 | return tp/(tp+fp+fn) 36 | 37 | class ScorePrecision(ScoreAbstract): 38 | r"""Precision.""" 39 | def eval(self, tp, tn, fn, fp): 40 | if tp == 0: 41 | if fp > 0: 42 | return 0. 43 | else: 44 | return 1. 45 | return tp/(tp+fp) 46 | 47 | class ScorePPV(ScorePrecision): 48 | r"""Positive predictve value, equivalent to precision.""" 49 | pass 50 | 51 | class ScoreRecall(ScoreAbstract): 52 | r"""Recall.""" 53 | def eval(self, tp, tn, fn, fp): 54 | if tp == 0: 55 | if fp > 0: 56 | return 0. 57 | else: 58 | return 1. 59 | return tp/(tp+fn) 60 | 61 | class ScoreSensitivity(ScoreRecall): 62 | r"""Sensitivity, equivalent to recall.""" 63 | pass 64 | 65 | class ScoreTPR(ScoreRecall): 66 | r"""True positive rate, equivalent to recall.""" 67 | pass 68 | 69 | class ScoreSpecificity(ScoreAbstract): 70 | r"""Specificity.""" 71 | def eval(self, tp, tn, fn, fp): 72 | if tn == 0: 73 | if fp > 0: 74 | return 0. 75 | else: 76 | return 1. 
-------------------------------------------------------------------------------- /mp/eval/metrics/scores.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Definition of score metrics for classification and segmentation, taking
# tp, tn, fn and fp as inputs. For segmentation, these refer to pixel/voxel
# values for one example.
# ------------------------------------------------------------------------------

class ScoreAbstract:
    r"""An abstract definition of a metric that uses true positives, true
    negatives, false negatives and false positives to calculate a value."""
    def __init__(self):
        self.name = self.__class__.__name__

    def eval(self, tp, tn, fn, fp):
        raise NotImplementedError

class ScoreDice(ScoreAbstract):
    r"""Dice score, the inverse of a Dice loss except for the smoothing factor
    in the loss."""
    def eval(self, tp, tn, fn, fp):
        if tp == 0:
            if fn+fp > 0:
                return 0.
            else:
                return 1.
        return (2*tp)/(2*tp+fp+fn)

class ScoreIoU(ScoreAbstract):
    r"""Intersection over Union."""
    def eval(self, tp, tn, fn, fp):
        if tp == 0:
            if fn+fp > 0:
                return 0.
            else:
                return 1.
        return tp/(tp+fp+fn)

class ScorePrecision(ScoreAbstract):
    r"""Precision."""
    def eval(self, tp, tn, fn, fp):
        if tp == 0:
            if fp > 0:
                return 0.
            else:
                return 1.
        return tp/(tp+fp)

class ScorePPV(ScorePrecision):
    r"""Positive predictive value, equivalent to precision."""
    pass

class ScoreRecall(ScoreAbstract):
    r"""Recall."""
    def eval(self, tp, tn, fn, fp):
        if tp == 0:
            # Recall depends on the false negatives (the original check
            # against fp was a copy-paste slip from precision)
            if fn > 0:
                return 0.
            else:
                return 1.
        return tp/(tp+fn)

class ScoreSensitivity(ScoreRecall):
    r"""Sensitivity, equivalent to recall."""
    pass

class ScoreTPR(ScoreRecall):
    r"""True positive rate, equivalent to recall."""
    pass

class ScoreSpecificity(ScoreAbstract):
    r"""Specificity."""
    def eval(self, tp, tn, fn, fp):
        if tn == 0:
            if fp > 0:
                return 0.
            else:
                return 1.
        return tn/(tn+fp)

class ScoreTNR(ScoreSpecificity):
    r"""True negative rate, equivalent to specificity."""
    pass
-------------------------------------------------------------------------------- /mp/eval/result.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Classes which accumulate results for easy visualization.
# 'Result' stores the per-epoch results for one run, e.g. for a fold.
# 'ExperimentResults' calculates the average over all runs at the end.
# ------------------------------------------------------------------------------

import pandas as pd

class ExperimentResults():
    r"""Per-epoch results for all repetitions."""
    def __init__(self, global_result_lst, epoch_result_lst):
        pass
        # TODO

class Result():
    r"""Per-epoch results for 1 repetition."""
    def __init__(self, name='Results'):
        self.name = name
        self.results = dict()

    def add(self, epoch, metric, value, data='train'):
        r"""Add a new result entry."""
        assert isinstance(epoch, int)
        assert isinstance(metric, str)
        if isinstance(data, tuple):
            data = '_'.join(data)
        assert isinstance(data, str)
        assert isinstance(value, float) or isinstance(value, int)
        if metric not in self.results:
            self.results[metric] = dict()
        if epoch not in self.results[metric]:
            self.results[metric][epoch] = dict()
        self.results[metric][epoch][data] = value

    def get_epoch_metric(self, epoch, metric, data='train'):
        r"""Get the value for a metric and epoch."""
        try:
            value = self.results[metric][epoch][data]
            return value
        except Exception:
            return None

    def to_pandas(self):
        r"""Pandas representation of results."""
        data = [[metric, epoch, data,
            self.results[metric][epoch][data]]
            for metric in self.results.keys()
            for epoch in self.results[metric].keys()
            for data in self.results[metric][epoch].keys()]
        df = pd.DataFrame(data, columns=['Metric', 'Epoch', 'Data', 'Value'])
        return df

    def get_min_epoch(self, metric, data='val'):
        r"""Get the epoch with the lowest value for a metric and data split."""
        return min(self.results[metric].keys(), key=lambda e: self.results[metric][e][data])

    def get_max_epoch(self, metric, data='val'):
        r"""Get the epoch with the highest value for a metric and data split."""
        return max(self.results[metric].keys(), key=lambda e: self.results[metric][e][data])
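
# Editor's sketch of the intended use (the values are made up):
if __name__ == '__main__':
    res = Result(name='demo')
    res.add(epoch=0, metric='ScoreDice', value=0.5, data='val')
    res.add(epoch=1, metric='ScoreDice', value=0.7, data='val')
    print(res.to_pandas())
    print(res.get_max_epoch('ScoreDice', data='val'))  # 1, the epoch with the highest value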
-------------------------------------------------------------------------------- /mp/models/autoencoding/autoencoder.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Basic class for autoencoder models that reconstruct the input.
# ------------------------------------------------------------------------------

from mp.models.model import Model

class Autoencoder(Model):
    r"""A superclass for autoencoder models.

    Args:
        input_shape tuple (int): (channels, width, height, Opt(depth))
    """
    def __init__(self, input_shape):
        # An autoencoder has the same input and output shapes
        super().__init__(input_shape, output_shape=input_shape)

    def encode(self, x):
        r"""Encode the input."""
        raise NotImplementedError

    def decode(self, x):
        r"""Decode the features into an output."""
        raise NotImplementedError

    def forward(self, x):
        initial_shape = x.shape
        x = self.encode(x)
        x = self.decode(x)
        assert x.shape == initial_shape
        return x
-------------------------------------------------------------------------------- /mp/models/autoencoding/autoencoder_cnn.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# A 2D convolutional autoencoder. Note that the input must be normalized
# between 0 and 1.
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from mp.models.autoencoding.autoencoder import Autoencoder

class AutoencoderCNN(Autoencoder):
    r"""A simple CNN autoencoder."""
    def __init__(self, input_shape, hidden_ch=[16, 4]):
        super().__init__(input_shape=input_shape)
        in_channels = self.input_shape[0]

        # Encoder layers
        self.enc_conv1 = nn.Conv2d(in_channels=in_channels,
            out_channels=hidden_ch[0], kernel_size=3, stride=1, padding=1)
        self.enc_conv2 = nn.Conv2d(hidden_ch[0], hidden_ch[1],
            kernel_size=3, stride=1, padding=1)
        self.enc_pool = nn.MaxPool2d(2, 2)

        # Decoder layers
        self.dec_conv1 = nn.ConvTranspose2d(hidden_ch[1], hidden_ch[0],
            kernel_size=2, stride=2, padding=0)
        self.dec_conv2 = nn.ConvTranspose2d(hidden_ch[0], in_channels,
            kernel_size=2, stride=2, padding=0)

    def encode(self, x):
        x = F.relu(self.enc_conv1(x))
        x = self.enc_pool(x)
        x = F.relu(self.enc_conv2(x))
        x = self.enc_pool(x)
        return x

    def decode(self, x):
        x = F.relu(self.dec_conv1(x))
        # torch.sigmoid instead of the deprecated F.sigmoid; the input should
        # be normalized to [0, 1]
        x = torch.sigmoid(self.dec_conv2(x))
        return x
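
# Editor's sanity check (not in the original file): two pooling/upsampling
# stages mean H and W must be divisible by 4 for the round trip to restore
# the input shape.
if __name__ == '__main__':
    ae = AutoencoderCNN(input_shape=(1, 64, 64))
    x = torch.rand(2, 1, 64, 64)  # values in [0, 1]
    assert ae(x).shape == x.shape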
-------------------------------------------------------------------------------- /mp/models/autoencoding/autoencoder_featured.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# An autoencoder that reconstructs extracted features.
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from mp.models.autoencoding.autoencoder_linear import AutoencoderLinear
from mp.data.pytorch.transformation import torchvision_rescaling

class AutoencoderFeatured(AutoencoderLinear):
    r"""An autoencoder that reconstructs features."""
    def __init__(self, input_shape, hidden_dim=[128, 64],
        feature_model_name='AlexNet'):

        extractor_size = (3, 224, 224)  # For AlexNet, TODO clean up and others
        features_size = 9216

        super().__init__(input_shape=[features_size], hidden_dim=hidden_dim)
        self.extractor_size = extractor_size
        self.feature_extractor = self.get_feature_extractor(feature_model_name)

    def preprocess_input(self, x):
        r"""Preprocessing that is done to the input before performing the
        autoencoding, which is to say also to the target."""
        # Instead of doing a forward pass, we exclude the classifier
        # See https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py
        x = torchvision_rescaling(x, size=self.extractor_size, resize=False)
        x = self.feature_extractor.features(x)
        x = self.feature_extractor.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return x

    def get_feature_extractor(self, model_name='AlexNet'):
        r"""Features are extracted from the input data. These are normalized
        with the ImageNet statistics."""
        # Fetch pretrained model
        if model_name == 'AlexNet':  # input_size = 224 x 224
            feature_extractor = models.alexnet(pretrained=True)
        # Freeze pretrained parameters
        for param in feature_extractor.parameters():
            param.requires_grad = False
        return feature_extractor

    def to(self, device):
        super().to(device)
        self.feature_extractor.to(device)
-------------------------------------------------------------------------------- /mp/models/autoencoding/autoencoder_linear.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# A linear autoencoder. Note that the input must be normalized between
# 0 and 1.
# ------------------------------------------------------------------------------

from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from mp.models.autoencoding.autoencoder import Autoencoder

class AutoencoderLinear(Autoencoder):
    r"""An autoencoder with only linear layers."""
    def __init__(self, input_shape, hidden_dim=[128, 64]):
        super().__init__(input_shape=input_shape)
        in_dim = self.input_shape[0] if len(self.input_shape) < 2 else reduce(
            lambda x, y: x*y, self.input_shape)
        dims = [in_dim] + hidden_dim

        # Encoder layers
        self.enc_layers = nn.ModuleList([nn.Linear(in_features=dims[i], out_features=dims[i+1])
            for i in range(len(dims)-1)])

        # Decoder layers
        self.dec_layers = nn.ModuleList([nn.Linear(in_features=dims[i+1], out_features=dims[i])
            for i in reversed(range(len(dims)-1))])

    def preprocess_input(self, x):
        r"""Flatten x into one dimension."""
        return torch.flatten(x, start_dim=1)

    def encode(self, x):
        for layer in self.enc_layers:
            x = F.relu(layer(x))
        return x

    def decode(self, x):
        for layer in self.dec_layers:
            x = F.relu(layer(x))
        return x
-------------------------------------------------------------------------------- /mp/models/classification/small_cnn.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Example of a small CNN.
# ------------------------------------------------------------------------------

import torch.nn as nn
import torch.nn.functional as F
from mp.models.model import Model

class SmallCNN(Model):
    r"""A small CNN for classification."""
    def __init__(self, input_shape=(3, 32, 32), output_shape=10):
        super().__init__(input_shape, output_shape)
        # Derive the channel and class numbers from the constructor arguments
        # (they were previously hard-coded to 3 and 10)
        self.conv1 = nn.Conv2d(input_shape[0], 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_shape)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
-------------------------------------------------------------------------------- /mp/models/continual/kd.py: --------------------------------------------------------------------------------
from mp.models.model import Model
from mp.models.segmentation.unet_fepegar import UNet2D
import torch.optim as optim

class KD(Model):
    r"""Knowledge Distillation as proposed in 'Incremental learning techniques
    for semantic segmentation' by Michieli, U., Zanuttigh, P., 2019.
    """
    def __init__(self,
        input_shape=(1,256,256),
        nr_labels=2,
        unet_dropout=0,
        unet_monte_carlo_dropout=0,
        unet_preactivation=False
        ):
        r"""Constructor

        Args:
            input_shape (tuple of int): input shape of the images
            nr_labels (int): number of labels for the segmentation
            unet_dropout (float): dropout probability for the U-Net
            unet_monte_carlo_dropout (float): monte carlo dropout probability for the U-Net
            unet_preactivation (boolean): whether to use U-Net pre-activations
        """
        super(KD, self).__init__()

        self.input_shape = input_shape
        self.nr_labels = nr_labels

        self.unet_dropout = unet_dropout
        self.unet_monte_carlo_dropout = unet_monte_carlo_dropout
        self.unet_preactivation = unet_preactivation

        self.unet_new = UNet2D(self.input_shape, self.nr_labels, dropout=self.unet_dropout, monte_carlo_dropout=self.unet_monte_carlo_dropout, preactivation=self.unet_preactivation)
        self.unet_old = None

    def forward(self, x):
        r"""Forward pass of the current U-Net

        Args:
            x (torch.Tensor): input batch

        Returns:
            (torch.Tensor): segmented batch
        """
        return self.unet_new(x)

    def forward_old(self, x):
        r"""Forward pass of the previous U-Net

        Args:
            x (torch.Tensor): input batch

        Returns:
            (torch.Tensor): segmented batch
        """
        return self.unet_old(x)

    def freeze_unet(self, unet):
        r"""Freeze U-Net

        Args:
            unet (nn.Module): U-Net

        Returns:
            (nn.Module): U-Net with frozen weights
        """
        for param in unet.parameters():
            param.requires_grad = False
        return unet

    def freeze_decoder(self, unet):
        r"""Freeze U-Net decoder

        Args:
            unet (nn.Module): U-Net

        Returns:
            (nn.Module): U-Net with frozen decoder weights
        """
        for param in unet.decoder.parameters():
            param.requires_grad = False
        for param in unet.classifier.parameters():
            param.requires_grad = False
        return unet

    def finish(self):
        r"""Finish training, store the current U-Net as the old U-Net"""
        unet_new_state_dict = self.unet_new.state_dict()
        # Determine the device in all cases (previously 'device' was only set
        # when the parameters lived on the GPU, breaking the CPU case)
        device = next(self.unet_new.parameters()).device

        self.unet_old = UNet2D(self.input_shape, self.nr_labels, dropout=self.unet_dropout, monte_carlo_dropout=self.unet_monte_carlo_dropout, preactivation=self.unet_preactivation)
        self.unet_old.load_state_dict(unet_new_state_dict)
        self.unet_old = self.freeze_unet(self.unet_old)
        self.unet_old.to(device)

    def set_optimizers(self, optimizer=optim.SGD, lr=1e-4, weight_decay=1e-4):
        r"""Set optimizers for all modules

        Args:
            optimizer (torch.optim): optimizer class to use
            lr (float): learning rate to use
            weight_decay (float): weight decay (only used for SGD)
        """
        if optimizer == optim.SGD:
            self.unet_optim = optimizer(self.unet_new.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.unet_optim = optimizer(self.unet_new.parameters(), lr=lr)
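
# Editor's sketch of one distillation step, inferred from the class interface;
# the actual training loop lives in mp/agents/kd_agent.py and may differ (in
# particular, MSE is only a stand-in for the distillation loss):
if __name__ == '__main__':
    import torch
    import torch.nn.functional as F
    model = KD(input_shape=(1, 64, 64))
    model.set_optimizers(optim.SGD, lr=1e-4)
    model.finish()  # snapshot and freeze the current U-Net as 'unet_old'
    x = torch.rand(2, 1, 64, 64)
    # Keep the new network's outputs close to the frozen network's outputs
    kd_loss = F.mse_loss(model(x), model.forward_old(x).detach())
    model.unet_optim.zero_grad()
    kd_loss.backward()
    model.unet_optim.step()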
-------------------------------------------------------------------------------- /mp/models/continual/mas.py: --------------------------------------------------------------------------------
from mp.models.model import Model
from mp.models.segmentation.unet_fepegar import UNet2D
import torch.optim as optim

class MAS(Model):
    r"""Memory Aware Synapses for brain segmentation, as proposed in
    'Importance driven continual learning for segmentation across domains'
    by Oezguen et al., 2020.
    """
    def __init__(self,
        input_shape=(1,256,256),
        nr_labels=2,
        unet_dropout=0,
        unet_monte_carlo_dropout=0,
        unet_preactivation=False
        ):
        r"""Constructor

        Args:
            input_shape (tuple of int): input shape of the images
            nr_labels (int): number of labels for the segmentation
            unet_dropout (float): dropout probability for the U-Net
            unet_monte_carlo_dropout (float): monte carlo dropout probability for the U-Net
            unet_preactivation (boolean): whether to use U-Net pre-activations
        """
        super(MAS, self).__init__()

        self.input_shape = input_shape
        self.nr_labels = nr_labels

        self.unet_dropout = unet_dropout
        self.unet_monte_carlo_dropout = unet_monte_carlo_dropout
        self.unet_preactivation = unet_preactivation

        self.unet = UNet2D(self.input_shape, self.nr_labels, dropout=self.unet_dropout, monte_carlo_dropout=self.unet_monte_carlo_dropout, preactivation=self.unet_preactivation)
        self.unet_old = None

        self.importance_weights = None
        self.tasks = 0

        self.n_params_unet = sum(p.numel() for p in self.unet.parameters())

    def forward(self, x):
        r"""Forward pass of the current U-Net

        Args:
            x (torch.Tensor): input batch

        Returns:
            (torch.Tensor): segmented batch
        """
        return self.unet(x)

    def freeze_unet(self, unet):
        r"""Freeze U-Net

        Args:
            unet (nn.Module): U-Net

        Returns:
            (nn.Module): U-Net with frozen weights
        """
        for param in unet.parameters():
            param.requires_grad = False
        return unet

    def freeze_decoder(self, unet):
        r"""Freeze U-Net decoder

        Args:
            unet (nn.Module): U-Net

        Returns:
            (nn.Module): U-Net with frozen decoder weights
        """
        for param in unet.decoder.parameters():
            param.requires_grad = False
        for param in unet.classifier.parameters():
            param.requires_grad = False
        return unet

    def set_optimizers(self, optimizer=optim.SGD, lr=1e-4, weight_decay=1e-4):
        r"""Set optimizers for all modules

        Args:
            optimizer (torch.optim): optimizer class to use
            lr (float): learning rate to use
            weight_decay (float): weight decay (only used for SGD)
        """
        if optimizer == optim.SGD:
            self.unet_optim = optimizer(self.unet.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            self.unet_optim = optimizer(self.unet.parameters(), lr=lr)

    def update_importance_weights(self, importance_weights):
        r"""Update the stored importance weights with newly computed ones

        Args:
            importance_weights (torch.Tensor or list): the new importance weights
        """
        if self.importance_weights is None:
            self.importance_weights = importance_weights
        else:
            # Running average over the tasks seen so far
            for i in range(len(self.importance_weights)):
                self.importance_weights[i] -= self.importance_weights[i] / self.tasks
                self.importance_weights[i] += importance_weights[i] / self.tasks
        self.tasks += 1

    def finish(self):
        r"""Finish training, store the current U-Net as the old U-Net"""
        unet_new_state_dict = self.unet.state_dict()
        # Determine the device in all cases (previously 'device' was only set
        # when the parameters lived on the GPU, breaking the CPU case)
        device = next(self.unet.parameters()).device

        self.unet_old = UNet2D(self.input_shape, self.nr_labels, dropout=self.unet_dropout, monte_carlo_dropout=self.unet_monte_carlo_dropout, preactivation=self.unet_preactivation)
        self.unet_old.load_state_dict(unet_new_state_dict)
        self.unet_old = self.freeze_unet(self.unet_old)
        self.unet_old.to(device)
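
# Editor's sketch of the MAS penalty (an assumption; the importance weights
# are actually computed in mp/agents/mas_agent.py): deviations from the old
# parameters are penalized proportionally to their importance.
if __name__ == '__main__':
    import torch
    model = MAS(input_shape=(1, 64, 64))
    model.finish()
    model.update_importance_weights([torch.ones_like(p) for p in model.unet.parameters()])
    penalty = sum((w * (p - p_old) ** 2).sum() for w, p, p_old in zip(
        model.importance_weights, model.unet.parameters(), model.unet_old.parameters()))
    print(float(penalty))  # 0.0 right after finish(), since both nets are equal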
-------------------------------------------------------------------------------- /mp/models/model.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Class all model definitions should descend from.
# ------------------------------------------------------------------------------

import os
import torch.nn as nn
import numpy as np
from torchsummary import summary

from mp.utils.pytorch.pytorch_load_restore import load_model_state, save_model_state

class Model(nn.Module):
    r"""A model that descends from torch.nn.Module and includes methods to
    output a model summary, the input_shape and output_shape fields used in
    other models, and the logic to restore previous model states from a path.

    Args:
        input_shape tuple (int): Input shape with the form
            (channels, width, height, Opt(depth))
        output_shape (Obj): output shape, which takes different forms depending
            on the problem
    """
    def __init__(self, input_shape=(1, 32, 32), output_shape=2):
        super(Model, self).__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape

    def preprocess_input(self, x):
        r"""E.g. pretrained features. Override if needed."""
        return x

    def initialize(self, weights_init_path, device):
        r"""Tries to restore a previous model. If no model is found, the
        initial weights are saved.
        """
        path, name = os.path.split(weights_init_path)
        restored = load_model_state(self, path=path, name=name, device=device)
        if restored:
            print('Initial parameters {} were restored'.format(weights_init_path))
        else:
            save_model_state(self, path=path, name=name)
            print('Initial parameters {} were saved'.format(weights_init_path))

    def get_param_list_static(self):
        r"""Returns a 1D array of parameter values"""
        model_params_array = []
        for _, param in self.state_dict().items():
            model_params_array.append(param.reshape(-1).cpu().numpy())
        return np.concatenate(model_params_array)

    # Method to output model information

    def model_summary(self, verbose=False):
        r"""Return a Keras-style summary."""
        summary_str = str(summary(self, input_data=self.input_shape, verbose=0))
        if verbose:
            print(summary_str)
        return summary_str

    # Methods to calculate the feature size

    def num_flat_features(self, x):
        r"""Number of features in x, ignoring the first (batch) dimension."""
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def flatten(self, x):
        r"""Flatten x into 1 dimension."""
        return x.view(-1, self.num_flat_features(x))

    def size_before_lin(self, shape_input):
        r"""Size after linearization.

        Returns (int): integer of dense size
        """
        return shape_input[0]*shape_input[1]*shape_input[2]

    def size_after_conv(self, shape_input, output_channels, kernel):
        r"""Gives the number of output neurons after the conv operation.
        The first dimension is the channel depth and the other 2 are given by
        the input volume: (size - kernel size + 2*padding)/stride + 1
        """
        return (output_channels, shape_input[1]-kernel+1, shape_input[2]-kernel+1)

    def size_after_pooling(self, shape_input, shape_pooling):
        r"""Maintains the first input dimension, which is the output channels
        of the previous conv layer. The others are divided by the shape of the
        pooling.
        """
        return (shape_input[0], shape_input[1]//shape_pooling[0], shape_input[2]//shape_pooling[1])
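
# Editor's worked example for the size helpers above: a 5x5 convolution with
# 6 output channels and no padding, followed by 2x2 pooling:
if __name__ == '__main__':
    m = Model()
    print(m.size_after_conv((1, 32, 32), output_channels=6, kernel=5))  # (6, 28, 28)
    print(m.size_after_pooling((6, 28, 28), (2, 2)))                    # (6, 14, 14)
    print(m.size_before_lin((6, 14, 14)))                               # 1176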
96 | """ 97 | return (shape_input[0], shape_input[1]//shape_pooling[0], shape_input[2]//shape_pooling[1]) 98 | -------------------------------------------------------------------------------- /mp/models/segmentation/segmentation_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Basic class for segmentation models. 3 | # ------------------------------------------------------------------------------ 4 | 5 | from mp.models.model import Model 6 | 7 | class SegmentationModel(Model): 8 | r"""An abstract class for segmentation models that caluclates the output 9 | shape from the input shape and the number of labels.""" 10 | def __init__(self, input_shape, nr_labels): 11 | assert 2 < len(input_shape) < 5 12 | # The output shae is the same as the input shape, but instead of the 13 | # input channels it has the number of labels as channels 14 | output_shape = tuple([nr_labels] + list(input_shape[1:])) 15 | super(SegmentationModel, self).__init__(input_shape, output_shape=output_shape) 16 | self.nr_labels = nr_labels 17 | 18 | -------------------------------------------------------------------------------- /mp/models/segmentation/unet_fepegar.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # This UNet model is modified from https://github.com/fepegar/unet 3 | # (see https://zenodo.org/record/3522306#.X0FJnhmxVhE). 4 | # ------------------------------------------------------------------------------ 5 | 6 | from typing import Optional 7 | import torch.nn as nn 8 | from mp.models.segmentation.segmentation_model import SegmentationModel 9 | from mp.models.segmentation.model_utils import Encoder, EncodingBlock, Decoder, ConvolutionalBlock 10 | 11 | class UNet(SegmentationModel): 12 | def __init__( 13 | self, 14 | input_shape, 15 | nr_labels, 16 | dimensions: int = 2, 17 | num_encoding_blocks: int = 5, 18 | out_channels_first_layer: int = 64, 19 | normalization: Optional[str] = None, 20 | pooling_type: str = 'max', 21 | upsampling_type: str = 'conv', 22 | preactivation: bool = False, 23 | residual: bool = False, 24 | padding: int = 0, 25 | padding_mode: str = 'zeros', 26 | activation: Optional[str] = 'ReLU', 27 | initial_dilation: Optional[int] = None, 28 | dropout: float = 0, 29 | monte_carlo_dropout: float = 0, 30 | ): 31 | super(UNet, self).__init__(input_shape=input_shape, nr_labels=nr_labels) 32 | 33 | in_channels = input_shape[0] 34 | 35 | depth = num_encoding_blocks - 1 36 | 37 | # Force padding if residual blocks 38 | if residual: 39 | padding = 1 40 | 41 | # Encoder 42 | self.encoder = Encoder( 43 | in_channels, 44 | out_channels_first_layer, 45 | dimensions, 46 | pooling_type, 47 | depth, 48 | normalization, 49 | preactivation=preactivation, 50 | residual=residual, 51 | padding=padding, 52 | padding_mode=padding_mode, 53 | activation=activation, 54 | initial_dilation=initial_dilation, 55 | dropout=dropout, 56 | ) 57 | 58 | # Bottom (last encoding block) 59 | in_channels = self.encoder.out_channels 60 | if dimensions == 2: 61 | out_channels_first = 2 * in_channels 62 | else: 63 | out_channels_first = in_channels 64 | 65 | self.bottom_block = EncodingBlock( 66 | in_channels, 67 | out_channels_first, 68 | dimensions, 69 | normalization, 70 | pooling_type=None, 71 | preactivation=preactivation, 72 | residual=residual, 73 | padding=padding, 74 | 
-------------------------------------------------------------------------------- /mp/models/segmentation/unet_fepegar.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# This UNet model is modified from https://github.com/fepegar/unet
# (see https://zenodo.org/record/3522306#.X0FJnhmxVhE).
# ------------------------------------------------------------------------------

from typing import Optional
import torch.nn as nn
from mp.models.segmentation.segmentation_model import SegmentationModel
from mp.models.segmentation.model_utils import Encoder, EncodingBlock, Decoder, ConvolutionalBlock

class UNet(SegmentationModel):
    def __init__(
            self,
            input_shape,
            nr_labels,
            dimensions: int = 2,
            num_encoding_blocks: int = 5,
            out_channels_first_layer: int = 64,
            normalization: Optional[str] = None,
            pooling_type: str = 'max',
            upsampling_type: str = 'conv',
            preactivation: bool = False,
            residual: bool = False,
            padding: int = 0,
            padding_mode: str = 'zeros',
            activation: Optional[str] = 'ReLU',
            initial_dilation: Optional[int] = None,
            dropout: float = 0,
            monte_carlo_dropout: float = 0,
            ):
        super(UNet, self).__init__(input_shape=input_shape, nr_labels=nr_labels)

        in_channels = input_shape[0]
        depth = num_encoding_blocks - 1

        # Force padding if residual blocks
        if residual:
            padding = 1

        # Encoder
        self.encoder = Encoder(
            in_channels,
            out_channels_first_layer,
            dimensions,
            pooling_type,
            depth,
            normalization,
            preactivation=preactivation,
            residual=residual,
            padding=padding,
            padding_mode=padding_mode,
            activation=activation,
            initial_dilation=initial_dilation,
            dropout=dropout,
        )

        # Bottom (last encoding block)
        in_channels = self.encoder.out_channels
        if dimensions == 2:
            out_channels_first = 2 * in_channels
        else:
            out_channels_first = in_channels

        self.bottom_block = EncodingBlock(
            in_channels,
            out_channels_first,
            dimensions,
            normalization,
            pooling_type=None,
            preactivation=preactivation,
            residual=residual,
            padding=padding,
            padding_mode=padding_mode,
            activation=activation,
            dilation=self.encoder.dilation,
            dropout=dropout,
        )

        # Decoder
        if dimensions == 2:
            power = depth - 1
        elif dimensions == 3:
            power = depth
        in_channels = self.bottom_block.out_channels
        in_channels_skip_connection = out_channels_first_layer * 2**power
        num_decoding_blocks = depth
        self.decoder = Decoder(
            in_channels_skip_connection,
            dimensions,
            upsampling_type,
            num_decoding_blocks,
            normalization=normalization,
            preactivation=preactivation,
            residual=residual,
            padding=padding,
            padding_mode=padding_mode,
            activation=activation,
            initial_dilation=self.encoder.dilation,
            dropout=dropout,
        )

        # Monte Carlo dropout
        self.monte_carlo_layer = None
        if monte_carlo_dropout:
            dropout_class = getattr(nn, 'Dropout{}d'.format(dimensions))
            self.monte_carlo_layer = dropout_class(p=monte_carlo_dropout)

        # Classifier
        if dimensions == 2:
            in_channels = out_channels_first_layer
        elif dimensions == 3:
            in_channels = 2 * out_channels_first_layer
        self.classifier = ConvolutionalBlock(
            dimensions, in_channels, nr_labels,
            kernel_size=1, activation=None,
        )

    def forward(self, x):
        skip_connections, encoding = self.encoder(x)
        encoding = self.bottom_block(encoding)
        x = self.decoder(skip_connections, encoding)
        if self.monte_carlo_layer is not None:
            x = self.monte_carlo_layer(x)
        return self.classifier(x)

class UNet2D(UNet):
    def __init__(self, *args, **kwargs):
        assert len(args[0]) == 3, "Input shape must have dimensions channels, width, height. Received: {}".format(args[0])
        predef_kwargs = {}
        predef_kwargs['dimensions'] = 2
        predef_kwargs['num_encoding_blocks'] = 5
        predef_kwargs['out_channels_first_layer'] = 16  # 64
        predef_kwargs['normalization'] = 'batch'
        # predef_kwargs['preactivation'] = True  # TODO: evaluate
        # Padding is needed so that the skip connection and feature map
        # shapes match
        predef_kwargs['padding'] = True
        predef_kwargs.update(kwargs)
        super(UNet2D, self).__init__(*args, **predef_kwargs)
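
# Editor's sanity check (not in the original file): with padding enabled, the
# 2D variant maps (batch, channels, H, W) to (batch, nr_labels, H, W):
if __name__ == '__main__':
    import torch
    unet = UNet2D((1, 256, 256), nr_labels=2)
    print(unet(torch.rand(1, 1, 256, 256)).shape)  # torch.Size([1, 2, 256, 256])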
class UNet3D(UNet):
    def __init__(self, *args, **kwargs):
        assert len(args[0]) == 4, "Input shape must have dimensions channels, width, height, depth. Received: {}".format(args[0])
        predef_kwargs = {}
        predef_kwargs['dimensions'] = 3
        predef_kwargs['num_encoding_blocks'] = 4
        predef_kwargs['out_channels_first_layer'] = 8
        predef_kwargs['normalization'] = 'batch'
        predef_kwargs['upsampling_type'] = 'linear'
        predef_kwargs['padding'] = True
        predef_kwargs.update(kwargs)
        super().__init__(*args, **predef_kwargs)
-------------------------------------------------------------------------------- /mp/models/segmentation/unet_milesial.py: --------------------------------------------------------------------------------
""" Full assembly of the parts to form the complete network """

import torch
import torch.nn as nn
import torch.nn.functional as F
from mp.models.segmentation.segmentation_model import SegmentationModel

class UNet(SegmentationModel):
    def __init__(self, input_shape, n_classes, bilinear=True):
        super(UNet, self).__init__(input_shape, n_classes)
        self.n_channels = input_shape[0]
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(self.n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

""" Parts of the U-Net model """

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of
        # channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
-------------------------------------------------------------------------------- /mp/paths.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Module where paths should be defined.
# ------------------------------------------------------------------------------
import os

# Path where intermediate and final results are stored
storage_path = 'storage'
storage_data_path = os.path.join(storage_path, 'data')

# Original data paths. TODO: set necessary data paths.
# original_data_paths = {'example_dataset_name': 'storage/data'}
original_data_paths = {'DecathlonHippocampus': 'storage/data/DecathlonHippocampus',
    'DryadHippocampus': 'storage/data/DryadHippocampus',
    'HarP': 'storage/data/HarP'}

# Login for Telegram Bot
telegram_login = {'chat_id': 'TODO', 'token': 'TODO'}
-------------------------------------------------------------------------------- /mp/utils/connection/check_connection.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Check connection by sending messages to a Telegram bot once in a while.
# ------------------------------------------------------------------------------

from mp.utils.update_bots.telegram_bot import TelegramBot
from mp.paths import telegram_login
import time

def run_for_mins(bot, nr_mins):
    r"""Run for the given number of minutes, sending an update once per minute."""
    for i in range(1, nr_mins+1):
        time.sleep(60)
        bot.send_msg('It has been {} minutes.'.format(i))

def run_for_hours(bot, nr_hours):
    r"""Run for the given number of hours, sending an update once per hour."""
    for i in range(1, nr_hours+1):
        time.sleep(3600)  # was 360, which is only six minutes
        bot.send_msg('It has been {} hours.'.format(i))

"""
bot = TelegramBot(login_data=telegram_login)
nr_hs = 14
bot.send_msg('Connection testing script started. Should run for {} hours.'.format(nr_hs))
run_for_hours(bot, nr_hs)
"""
-------------------------------------------------------------------------------- /mp/utils/helper_functions.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Miscellaneous helper functions.
# ------------------------------------------------------------------------------

def f_optional_args(f, args, x):
    r"""If there are arguments, these are passed to the function."""
    if args:
        return f(x, **args)
    else:
        return f(x)

import datetime
def get_time_string(cover=False):
    r"""Returns the current time in the format YYYY-MM-DD_HH-MM, or
    [YYYY-MM-DD_HH-MM] if 'cover' is set to 'True'.
    """
    date = str(datetime.datetime.now()).replace(' ', '_').replace(':', '-').split('.')[0]
    if cover:
        return '['+date+']'
    else:
        return date

import ntpath
def divide_path_fname(path):
    r"""Divide path and name from a full path."""
    path_to_file, file_name = ntpath.split(path)
    if not file_name:
        # Case where the path ends with a slash
        file_name = ntpath.basename(path_to_file)
        path_to_file = path_to_file.split(file_name)[0]
    return path_to_file, file_name

import numpy as np
import random
import torch
def seed_all(seed=42):
    r"""Seed Python, NumPy and PyTorch for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
-------------------------------------------------------------------------------- /mp/utils/introspection.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Introspection function to create object instances from string arguments.
# ------------------------------------------------------------------------------

def introspect(class_path):
    r"""Fetches a class dynamically from a class path."""
    if isinstance(class_path, str):
        class_path = class_path.split('.')
    class_name = class_path[-1]
    module_path = class_path[:-1]
    module = __import__('.'.join(module_path))
    for m in module_path[1:]:
        module = getattr(module, m)
    module = getattr(module, class_name)
    return module
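
# Editor's example: the class path below points at a class from this
# repository; the dynamic lookup returns the very same class object.
if __name__ == '__main__':
    cls = introspect('mp.models.segmentation.unet_fepegar.UNet2D')
    from mp.models.segmentation.unet_fepegar import UNet2D
    assert cls is UNet2D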
-------------------------------------------------------------------------------- /mp/utils/load_restore.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Functions to save and restore different data types.
# ------------------------------------------------------------------------------

import os

# PICKLE
import pickle
def pkl_dump(obj, name, path='obj'):
    r"""Saves an object in pickle format."""
    if '.p' not in name:
        name = name + '.pkl'
    path = os.path.join(path, name)
    pickle.dump(obj, open(path, 'wb'))

def pkl_load(name, path='obj'):
    r"""Restores an object from a pickle file."""
    if '.p' not in name:
        name = name + '.pkl'
    path = os.path.join(path, name)
    try:
        obj = pickle.load(open(path, 'rb'))
    except FileNotFoundError:
        obj = None
    return obj

# NUMPY
from numpy import save, load

def np_dump(obj, name, path='obj'):
    r"""Saves an object in npy format."""
    if '.npy' not in name:
        name = name + '.npy'
    path = os.path.join(path, name)
    save(path, obj)

def np_load(name, path='obj'):
    r"""Restores an object from a npy file."""
    if '.npy' not in name:
        name = name + '.npy'
    path = os.path.join(path, name)
    try:
        obj = load(path)
    except FileNotFoundError:
        obj = None
    return obj

# JSON
import json
def save_json(dict_obj, path, name):
    r"""Saves a dictionary in json format."""
    if '.json' not in name:
        name += '.json'
    with open(os.path.join(path, name), 'w') as json_file:
        json.dump(dict_obj, json_file)

def load_json(path, name):
    r"""Restores a dictionary from a json file."""
    if '.json' not in name:
        name += '.json'
    with open(os.path.join(path, name), 'r') as json_file:
        return json.load(json_file)

# NIFTY
# numpy and SimpleITK were used below but never imported
import numpy as np
import SimpleITK as sitk

def nifty_dump(x, name, path):
    r"""Save a tensor or numpy array in nifty format."""
    if 'torch.Tensor' in str(type(x)):
        x = x.detach().cpu().numpy()
    if '.nii' not in name:
        name = name + '.nii.gz'
    # Remove channels dimension and rotate axes so depth comes first
    if len(x.shape) == 4:
        x = np.moveaxis(x[0], -1, 0)
    assert len(x.shape) == 3
    path = os.path.join(path, name)
    sitk.WriteImage(sitk.GetImageFromArray(x), path)

# OTHERS
import functools
def join_path(list):
    r"""From a list of chained directories, forms a path"""
    return functools.reduce(os.path.join, list)
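
# Editor's example of a JSON round trip (path and name are hypothetical):
if __name__ == '__main__':
    save_json({'epochs': 300, 'lr': 1e-4}, path='.', name='config')
    assert load_json(path='.', name='config') == {'epochs': 300, 'lr': 1e-4}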
-------------------------------------------------------------------------------- /mp/utils/pytorch/compute_normalization_values.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Torchvision requires the mean and standard deviation to be calculated manually
# for normalization. This method can be used for that. However, it is mainly
# meant for colored 2D images and therefore rarely relevant for medical data.
# ------------------------------------------------------------------------------

import torch

def normalization_values(dataset):
    r"""Compute normalization values for a dataset."""
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=False)
    count = 0
    # Start from zeros; torch.empty returns uninitialized memory, which could
    # poison the running averages below
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for data, _ in dataloader:
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        mean = (count * mean + sum_) / (count + nb_pixels)
        std = (count * std + sum_of_square) / (count + nb_pixels)
        count += nb_pixels

    return {'mean': mean, 'std': torch.sqrt(std - mean ** 2)}
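
# Editor's usage sketch with a synthetic 3-channel dataset (any dataset
# yielding (image, label) pairs works):
if __name__ == '__main__':
    dataset = torch.utils.data.TensorDataset(torch.rand(100, 3, 8, 8), torch.zeros(100))
    print(normalization_values(dataset))  # per-channel 'mean' and 'std'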
42 | """ 43 | full_path = os.path.join(path, name) 44 | torch.save(optimizer.state_dict(), full_path) 45 | 46 | def load_optimizer_state(optimizer, name, path, device='cpu'): 47 | r"""Restores a pytorch optimizer state.""" 48 | if os.path.exists(path): 49 | full_path = os.path.join(path, name) 50 | if os.path.isfile(full_path): 51 | optimizer.load_state_dict(torch.load(full_path, map_location=device)) 52 | return True 53 | return False 54 | 55 | def save_scheduler_state(scheduler, name, path): 56 | r"""Saves a scheduler state.""" 57 | full_path = os.path.join(path, name) 58 | torch.save(scheduler.state_dict(), full_path) 59 | 60 | def load_scheduler_state(scheduler, name, path, device='cpu'): 61 | r"""Loads a scheduler state.""" 62 | if os.path.exists(path): 63 | full_path = os.path.join(path, name) 64 | if os.path.isfile(full_path): 65 | scheduler.load_state_dict(torch.load(full_path, map_location=device)) 66 | return True 67 | return False -------------------------------------------------------------------------------- /mp/utils/seaborn/legend_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Functions to style Seaborn plots. 3 | # ------------------------------------------------------------------------------ 4 | 5 | import matplotlib.patches as mpatches 6 | 7 | def _remove_empyties_and_duplicates(handles, labels, titles): 8 | r"""Removes repeated entries and titles which are not followed by entries""" 9 | new_labels = [] 10 | new_handles = [] 11 | appeared = set() 12 | for ix, label in enumerate(labels): 13 | if label in appeared: 14 | continue 15 | else: 16 | appeared.add(label) 17 | if label in titles: 18 | # label is the last entry or the next label is a title 19 | if ix == len(labels) - 1 or labels[ix+1] in titles: 20 | titles.remove(label) 21 | continue 22 | new_labels.append(label) 23 | new_handles.append(handles[ix]) 24 | return new_handles, new_labels, titles 25 | 26 | def _bold_titles(labels, titles): 27 | r"""Styles title labels bold. 
28 | """ 29 | labels = ['$\\bf{'+label+'}$' if label in titles else label for label in labels] 30 | titles = ['$\\bf{'+title+'}$' for title in titles] 31 | return labels, titles 32 | 33 | def _insert_divider_before_titles(handles, labels, titles): 34 | r"""Inserts an empty line before each new legend easthetic 35 | param titles: elements of 'labels' before which a space should be inserted 36 | """ 37 | titles = titles[1:] # Do not need to insert space before first title 38 | empty_handle = mpatches.Patch(color='white', alpha=0) 39 | space_indexes = [labels.index(title) for title in titles] 40 | for i in range(len(space_indexes)): 41 | handles.insert(space_indexes[i], empty_handle) 42 | labels.insert(space_indexes[i], '') 43 | space_indexes = [i+1 for i in space_indexes] 44 | return handles, labels 45 | 46 | def _add_hue_dimension(handles, labels): 47 | # TODO 48 | handles.append(mpatches.Patch(color='red', alpha=0.5)) 49 | labels.append('white') 50 | return handles, labels 51 | 52 | def format_legend(ax, titles): 53 | r"""Format legend""" 54 | if 'numpy' in str(type(ax)): 55 | ax = ax.copy()[-1] 56 | # Fetch legend labels and handles 57 | handles, labels = ax.get_legend_handles_labels() 58 | handles, labels, titles = _remove_empyties_and_duplicates(handles, labels, titles) 59 | labels, titles = _bold_titles(labels, titles) 60 | handles, labels = _insert_divider_before_titles(handles, labels, titles) 61 | # Legend to the side 62 | ax.legend(handles, labels, bbox_to_anchor=(1, 1), loc=2) 63 | 64 | def _add_training_legend_items(handles, labels, alpha_training, alpha_not_training): 65 | handles.append(mpatches.Patch(color='white', alpha=0)) 66 | handles.append(mpatches.Patch(color='black', alpha=alpha_training)) 67 | handles.append(mpatches.Patch(color='black', alpha=alpha_not_training)) 68 | labels.append('Training') 69 | labels.append('On data') 70 | labels.append('On other data') 71 | return handles, labels -------------------------------------------------------------------------------- /mp/utils/tensorboard.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | def create_writer(path, init_epoch=0): 4 | ''' Creates tensorboard SummaryWriter.''' 5 | 6 | return SummaryWriter(path, purge_step=init_epoch) -------------------------------------------------------------------------------- /mp/utils/update_bots/telegram_bot.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Telegram bot to pass messages about the training, or inform when experiments 3 | # are done. Follow these links to get a token and chat-id 4 | # - https://www.christian-luetgens.de/homematic/telegram/botfather/Chat-Bot.htm 5 | # - https://stackoverflow.com/questions/32423837/telegram-bot-how-to-get-a-group-chat-id 6 | # Then, place these strings in a telegram_login.json file in this directory. 7 | # That file is ignored by git. 8 | # ------------------------------------------------------------------------------ 9 | 10 | import telegram as tel 11 | from mp.utils.load_restore import load_json, join_path 12 | 13 | class TelegramBot(): 14 | r"""Initialize a telegram bot. 
-------------------------------------------------------------------------------- /mp/utils/update_bots/telegram_bot.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Telegram bot to pass messages about the training, or inform when experiments
# are done. Follow these links to get a token and chat-id:
# - https://www.christian-luetgens.de/homematic/telegram/botfather/Chat-Bot.htm
# - https://stackoverflow.com/questions/32423837/telegram-bot-how-to-get-a-group-chat-id
# Then, place these strings in a telegram_login.json file in this directory.
# That file is ignored by git.
# ------------------------------------------------------------------------------

import telegram as tel
from mp.utils.load_restore import load_json, join_path

class TelegramBot():
    r"""Initialize a telegram bot.

    Args:
        login_data (dict[str -> str]): dictionary with the entries 'chat_id'
            and 'token'
    """
    def __init__(self, login_data=None):
        if login_data is None:
            login_data = load_json(path=join_path(['src', 'utils', 'telegram_bot']),
                name='telegram_login')
        self.chat_id = login_data['chat_id']
        self.bot = tel.Bot(token=login_data['token'])

    def send_msg(self, msg):
        r"""Send a message in string form."""
        self.bot.send_message(chat_id=self.chat_id, text=msg)
-------------------------------------------------------------------------------- /mp/visualization/confusion_matrix.py: --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# A confusion matrix for classification tasks.
# ------------------------------------------------------------------------------

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

class ConfusionMatrix:
    r"""An internal representation for a confusion matrix.

    Args:
        nr_classes (int): number of classes/labels
    """
    def __init__(self, nr_classes, labels=None):
        self.cm = [[0 for i in range(nr_classes)] for i in range(nr_classes)]
        if labels is None:
            self.labels = list(range(nr_classes))
        else:
            self.labels = labels

    def add(self, predicted, actual, count=1):
        r"""Add 'count' entries to the confusion matrix."""
        self.cm[self.labels.index(actual)][self.labels.index(predicted)] += count

    def plot(self, path, name='confusion_matrix', label_predicted='Predicted',
        label_actual='Actual', figure_size=(7,5), annot=True):
        r"""Plot using seaborn."""
        cm = self.cm.copy()
        nr_rows = len(cm)
        # Insert a dummy row so that the axis labels start at 1, then drop it
        cm.insert(0, [0]*nr_rows)
        df = pd.DataFrame(cm, columns=[c+1 for c in range(nr_rows)])
        df = df.drop([0])
        plt.figure()
        sns.set(rc={'figure.figsize': figure_size})
        ax = sns.heatmap(df, annot=annot)
        ax.set(xlabel=label_predicted, ylabel=label_actual)
        plt.savefig(os.path.join(path, name+'.png'), facecolor='w',
            bbox_inches="tight", dpi=300)

    def get_accuracy(self):
        r"""Get the accuracy."""
        correct = sum([self.cm[i][i] for i in range(len(self.cm))])
        all_instances = sum([sum(x) for x in self.cm])
        return correct/all_instances
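
# Editor's example for two classes (the counts are made up):
if __name__ == '__main__':
    cm = ConfusionMatrix(2)
    cm.add(predicted=0, actual=0, count=8)
    cm.add(predicted=1, actual=0, count=2)
    cm.add(predicted=1, actual=1, count=10)
    print(cm.get_accuracy())  # 18 of 20 correct -> 0.9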
17 | save_name (str): name with which plot is saved 18 | title (str): the title that will appear on the plot 19 | ending (str): can be '.png' or '.svg' 20 | ylog (bool): apply logarithm to y axis 21 | figsize (tuple[int]): figure size 22 | """ 23 | df = result.to_pandas() 24 | # Filter out measures that are not to be shown 25 | # The default is using all measures in the df 26 | if measures: 27 | df = df.loc[df['Metric'].isin(measures)] 28 | # Start a new figure so that different plots do not overlap 29 | plt.figure() 30 | sns.set(rc={'figure.figsize':figsize}) 31 | # Plot 32 | ax = sns.lineplot(x='Epoch', 33 | y='Value', 34 | hue='Metric', 35 | style='Data', 36 | alpha=0.7, 37 | data=df) 38 | ax = sns.scatterplot(x='Epoch', 39 | y='Value', 40 | hue='Metric', 41 | style='Data', 42 | alpha=1., 43 | data=df) 44 | # Optional logarithmic scale 45 | if ylog: 46 | ax.set_yscale('log') 47 | # Style legend 48 | titles = ['Metric', 'Data'] 49 | format_legend(ax, titles) 50 | # Set title 51 | if title: 52 | ax.set_title(title) 53 | # Save image 54 | if save_path: 55 | file_name = save_name if save_name is not None else result.name 56 | if not os.path.exists(save_path): 57 | os.makedirs(save_path) 58 | file_name = file_name.split('.')[0]+ending 59 | plt.savefig(os.path.join(save_path, file_name), facecolor='w', 60 | bbox_inches="tight", dpi = 300) 61 | -------------------------------------------------------------------------------- /qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/qualitative_results.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 
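# '-e .' installs the local mp package itself in editable mode (see setup.py below)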
2 | pytest==6.1.1 3 | numpy==1.19.1 4 | scipy==1.5.2 5 | pandas==1.0.3 6 | ipykernel==5.2.1 7 | ipython==7.13.0 8 | ipython-genutils==0.2.0 9 | ipywidgets==7.5.1 10 | Jinja2==2.10.1 11 | jupyter==1.0.0 12 | jupyter-client==6.1.5 13 | jupyter-console==6.1.0 14 | jupyter-core==4.6.3 15 | notebook==6.0.3 16 | matplotlib==3.2.1 17 | seaborn==0.10.1 18 | tensorboard==2.2.1 19 | mypy==0.770 20 | pylint==2.6.0 21 | simpleitk==1.2.4 22 | torch-summary==1.2.0 23 | torchio==0.17.46 24 | pexpect==4.8.0 25 | python-telegram-bot==12.8 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='medical_pytorch', 5 | version='0.1', 6 | description='A project bundling several MIC libraries.', 7 | url='https://github.com/camgbus/medical_pytorch', 8 | keywords='python setuptools', 9 | packages=find_packages(include=['mp', 'mp.*']), 10 | ) -------------------------------------------------------------------------------- /test/agents/model_state_restore.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mp.data.datasets.ds_mr_prostate_decathlon import DecathlonProstateT2 3 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset 4 | from mp.models.segmentation.unet_fepegar import UNet2D 5 | from mp.eval.losses.losses_segmentation import LossDice, LossClassWeighted 6 | from mp.agents.segmentation_agent import SegmentationAgent 7 | from mp.eval.evaluate import ds_losses_metrics 8 | 9 | def test_restore_model_state_and_eval(): 10 | device = 'cpu' 11 | 12 | # Fetch data 13 | data = DecathlonProstateT2() 14 | label_names = data.label_names 15 | nr_labels = data.nr_labels 16 | 17 | # Transform data to PyTorch format and build train dataloader 18 | input_shape = (1, 256, 256) 19 | datasets = dict() 20 | datasets['train'] = PytorchSeg2DDataset(data, ix_lst=[0], size=input_shape, aug_key='none', resize=False) 21 | datasets['mixed'] = PytorchSeg2DDataset(data, ix_lst=[0, 1], size=input_shape, aug_key='none', resize=False) 22 | datasets['test'] = PytorchSeg2DDataset(data, ix_lst=[1], size=input_shape, aug_key='none', resize=False) 23 | dl = torch.utils.data.DataLoader(datasets['train'], batch_size=8, shuffle=False) 24 | 25 | # Build model 26 | model = UNet2D(input_shape, nr_labels) 27 | 28 | # Define agent 29 | agent = SegmentationAgent(model=model, label_names=label_names, device=device, 30 | metrics=['ScoreDice']) 31 | 32 | # Restore model state 33 | agent.restore_state(states_path='test/test_obj/agent_states_prostate_2D', state_name="epoch_300") 34 | 35 | # Calculate metrics and compare 36 | loss_g = LossDice(1e-05) 37 | loss_f = LossClassWeighted(loss=loss_g, weights=(1.,1.), device=device) 38 | eval_dict = ds_losses_metrics(datasets['mixed'], agent, loss_f, metrics=['ScoreDice']) 39 | 40 | test_target_dict = {'ScoreDice': {'prostate00': 0.9280484305139076, 'prostate01': 0.5375613582619043, 'mean': 0.732804894387906, 'std': 0.19524353612600165}, 41 | 'ScoreDice[background]': {'prostate00': 0.996721191337123, 'prostate01': 0.9785040545630738, 'mean': 0.9876126229500983, 'std': 0.009108568387024618}, 42 | 'ScoreDice[prostate]': {'prostate00': 0.8593756696906922, 'prostate01': 0.09661866196073488, 'mean': 0.47799716582571355, 'std': 0.3813785038649787}, 43 | 'Loss_LossClassWeighted[loss=LossDice[smooth=1e-05]; weights=(1.0, 1.0)]': {'prostate00':
0.10226414799690246, 'prostate01': 0.4694981321692467, 'mean': 0.2858811400830746, 'std': 0.1836169920861721}, 44 | 'Loss_LossDice[smooth=1e-05][0]': {'prostate00': 0.005160685380299886, 'prostate01': 0.03430714905261993, 'mean': 0.01973391721645991, 'std': 0.014573231836160022}, 45 | 'Loss_LossDice[smooth=1e-05][1]': {'prostate00': 0.19936761061350505, 'prostate01': 0.9046891242265701, 'mean': 0.5520283674200376, 'std': 0.3526607568065325}} 46 | 47 | for metric_key, metric_dict in test_target_dict.items(): 48 | for key, value in metric_dict.items(): 49 | assert abs(value - eval_dict[metric_key][key]) < 0.01 50 | -------------------------------------------------------------------------------- /test/cuda/test_cuda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def test_cuda(): 5 | nr_devices = torch.cuda.device_count() 6 | assert nr_devices > 0 7 | device = nr_devices-1 # Last device chosen 8 | torch.cuda.set_device(device) 9 | tensor = torch.zeros((2,3)).cuda() 10 | assert str(tensor.device) == 'cuda:'+str(device) -------------------------------------------------------------------------------- /test/data/datasets/test_ds_mr_cardiac_mm.py: -------------------------------------------------------------------------------- 1 | from mp.data.datasets.ds_mr_cardiac_mm import MM_Challenge 2 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset 3 | 4 | def test_ds(): 5 | data = MM_Challenge(subset=None) 6 | assert data.label_names == ['background', 'left ventricle', 'myocardium', 'right ventricle'] 7 | assert data.nr_labels == 4 8 | assert data.modality == 'MR' 9 | assert data.size == 300 10 | ds = PytorchSeg2DDataset(data, size=(1, 256, 256), aug_key='none', resize=False) 11 | instance = ds.get_instance(0) 12 | assert instance.name == 'A0S9V9_ED' 13 | subject_ix = ds.get_ix_from_name('A0S9V9_ED') 14 | assert subject_ix == 0 15 | 16 | def test_ds_subset(): 17 | data = MM_Challenge(subset={'Vendor': 'B'}) 18 | print(data.size) 19 | assert data.size == 150 20 | ds = PytorchSeg2DDataset(data, size=(1, 256, 256), aug_key='none', resize=False) 21 | instance = ds.get_instance(0) 22 | assert instance.name == 'A1D0Q7_ED' 23 | subject_ix = ds.get_ix_from_name('A1D0Q7_ED') 24 | assert subject_ix == 0 25 | -------------------------------------------------------------------------------- /test/data/datasets/test_ds_mr_prostate_decathlon.py: -------------------------------------------------------------------------------- 1 | from mp.data.datasets.ds_mr_prostate_decathlon import DecathlonProstateT2 2 | 3 | def test_ds_label_merging(): 4 | data = DecathlonProstateT2(merge_labels=True) 5 | assert data.label_names == ['background', 'prostate'] 6 | assert data.nr_labels == 2 7 | assert data.modality == 'MR' 8 | assert data.size == 32 9 | assert data.name == 'DecathlonProstateT2' -------------------------------------------------------------------------------- /test/data/pytorch/test_transformation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mp.data.pytorch.transformation import per_label_channel, one_output_channel 3 | 4 | def test_per_label_channel(): 5 | A_1=[[0,0,0,0,0,0,0], 6 | [0,1,3,3,0,1,0], 7 | [0,0,3,1,1,2,2], 8 | [0,0,0,1,1,2,2]] 9 | A_2=[[0,0,0,0,0,0,0], 10 | [0,0,0,0,0,0,0], 11 | [0,0,0,0,3,3,3], 12 | [2,2,1,1,1,1,0]] 13 | A_3=[[0,0,0,0,0,0,0], 14 | [0,0,0,1,0,0,0], 15 | [2,0,1,1,0,0,0], 16 | [2,0,0,0,0,0,0]] 17 | A_4=[[1,1,1,0,0,0,0], 18 | 
[0,0,0,0,2,2,0], 19 | [0,2,0,0,2,2,0], 20 | [0,2,0,0,3,3,3]] 21 | a = torch.tensor([A_1, A_2, A_3, A_4]) 22 | a = a.unsqueeze(0) 23 | per_label_channel_a = per_label_channel(a, nr_labels=4, channel_dim=0) 24 | one_output_channel_a = one_output_channel(per_label_channel_a, channel_dim=0).numpy() 25 | assert (a.numpy() == one_output_channel_a).all() 26 | 27 | 28 | -------------------------------------------------------------------------------- /test/eval/inference/test_predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mp.eval.inference.predict import arg_max, softmax 3 | from mp.data.pytorch.transformation import per_label_channel 4 | 5 | def test_argmax_pred(): 6 | output = torch.tensor([[[[0.3, 0.7], [0.2, 0.1]], [[8., .0], [.03, 0.4]], 7 | [[5.7, .1], [.55, 0.45]]], [[[0.3, 0.7], [0.2, 0.1]], [[8., .0], 8 | [.03, 0.4]], [[5.7, .1], [.55, 0.45]]]]) 9 | target = torch.tensor([[[1, 0], [2, 2]], [[1, 0], [2, 2]]]) 10 | assert output.numpy().shape == (2,3,2,2) 11 | pred = arg_max(output).numpy() 12 | assert pred.shape == (2,2,2) 13 | assert (pred == target.numpy()).all() 14 | 15 | def test_argmax_another_channel_output(): 16 | output = torch.tensor([[[0.3, 0.7], [0.2, 0.1]], [[8., .0], [.03, 0.4]], 17 | [[5.7, .1], [.55, 0.45]]]) 18 | target = torch.tensor([[1, 0], [2, 2]]) 19 | assert output.numpy().shape == (3,2,2) 20 | pred = arg_max(output, channel_dim=0).numpy() 21 | assert pred.shape == (2,2) 22 | assert (pred == target.numpy()).all() 23 | 24 | def test_softmax(): 25 | output = torch.tensor([[[[3., 1.], [0.2, 0.05]], [[4., .0], [0.8, 0.4]], 26 | [[3., 9.], [0., 0.45]]],[[[3., 1.], [0.2, 0.05]], [[4., .0], [0.8, 0.4]], 27 | [[3., 9.], [0., 0.45]]]]) 28 | softmaxed_output = softmax(output).numpy() 29 | assert softmaxed_output.shape == (2,3,2,2) 30 | for k in [0,1]: 31 | for i, j in [(0,0),(0,1),(1,0),(1,1)]: 32 | assert abs(1 - sum(x[i][j] for x in softmaxed_output[k]) ) < 0.0001 33 | 34 | def test_softmax_another_channel_output(): 35 | output = torch.tensor([[[3., 1.], [0.2, 0.05]], [[4., .0], [0.8, 0.4]], 36 | [[3., 9.], [0., 0.45]]]) 37 | softmaxed_output = softmax(output, channel_dim=0).numpy() 38 | assert softmaxed_output.shape == (3,2,2) 39 | for i, j in [(0,0),(0,1),(1,0),(1,1)]: 40 | assert abs(1 - sum(x[i][j] for x in softmaxed_output) ) < 0.0001 41 | 42 | def test_per_label_channel_to_pred(): 43 | A_1=[[0,0,0,0,0,0,0], 44 | [0,1,3,3,0,1,0], 45 | [0,0,3,1,1,2,2], 46 | [0,0,0,1,1,2,2]] 47 | A_2=[[0,0,0,0,0,0,0], 48 | [0,0,0,0,0,0,0], 49 | [0,0,0,0,3,3,3], 50 | [2,2,1,1,1,1,0]] 51 | A_3=[[0,0,0,0,0,0,0], 52 | [0,0,0,1,0,0,0], 53 | [2,0,1,1,0,0,0], 54 | [2,0,0,0,0,0,0]] 55 | A_4=[[1,1,1,0,0,0,0], 56 | [0,0,0,0,2,2,0], 57 | [0,2,0,0,2,2,0], 58 | [0,2,0,0,3,3,3]] 59 | a = torch.tensor([A_1, A_2, A_3, A_4]) 60 | a = a.unsqueeze(0) 61 | per_label_channel_a = per_label_channel(a, nr_labels=4, channel_dim=0) 62 | a_pred = arg_max(per_label_channel_a, channel_dim=0).numpy() 63 | assert (a.numpy() == a_pred).all() -------------------------------------------------------------------------------- /test/eval/losses/test_losses_segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mp.eval.losses.losses_segmentation import LossBCE, LossDice, LossClassWeighted, LossDiceBCE 3 | from mp.data.pytorch.transformation import per_label_channel 4 | 5 | A_1=[[0,0,0,0,0,0,0], 6 | [0,1,3,3,0,1,0], 7 | [0,0,3,1,1,2,2], 8 | [0,0,0,1,1,2,2]] 9 | A_2=[[0,0,0,0,0,0,0], 10 | 
[0,0,0,0,0,0,0], 11 | [0,0,0,0,3,3,3], 12 | [2,2,1,1,1,1,0]] 13 | A_3=[[0,0,0,0,0,0,0], 14 | [0,0,0,1,0,0,0], 15 | [2,0,1,1,0,0,0], 16 | [2,0,0,0,0,0,0]] 17 | A_4=[[1,1,1,0,0,0,0], 18 | [0,0,0,0,2,2,0], 19 | [0,2,0,0,2,2,0], 20 | [0,2,0,0,3,3,3]] 21 | 22 | B_1=[[0,0,0,0,0,0,0], 23 | [0,0,0,0,0,0,0], 24 | [0,0,3,1,1,2,2], 25 | [0,0,0,1,1,2,2]] 26 | B_2=[[0,0,0,0,0,0,0], 27 | [0,0,0,0,0,0,0], 28 | [0,0,0,0,0,0,0], 29 | [2,2,2,2,1,1,0]] 30 | B_3=[[0,0,0,0,0,1,0], 31 | [0,0,2,0,1,1,0], 32 | [0,0,2,0,0,3,0], 33 | [0,0,0,0,3,3,0]] 34 | B_4=[[0,0,1,0,0,0,0], 35 | [1,1,0,0,0,0,0], 36 | [0,2,0,0,0,0,0], 37 | [0,2,0,0,0,0,3]] 38 | 39 | # Build test inputs 40 | a = torch.tensor([A_1, A_2, A_3, A_4]) 41 | b = torch.tensor([B_1, B_2, B_3, B_4]) 42 | c = torch.tensor([A_4, B_3, A_1, B_2]) 43 | a = per_label_channel(a.unsqueeze(0), nr_labels=4, channel_dim=0) 44 | b = per_label_channel(b.unsqueeze(0), nr_labels=4, channel_dim=0) 45 | c = per_label_channel(c.unsqueeze(0), nr_labels=4, channel_dim=0) 46 | 47 | # Batched-inputs 48 | a_batch = torch.stack([a, a, a]) 49 | b_batch = torch.stack([b, b, b]) 50 | c_batch = torch.stack([c, c, c]) 51 | 52 | # Single-instance inputs 53 | a = a.unsqueeze(0) 54 | b = b.unsqueeze(0) 55 | c = c.unsqueeze(0) 56 | 57 | # Zero- and one-filled arrays 58 | d = torch.zeros(a.shape, dtype=torch.float64) 59 | e = torch.ones(a.shape, dtype=torch.float64) 60 | d_batch = torch.zeros(a_batch.shape, dtype=torch.float64) 61 | e_batch = torch.ones(a_batch.shape, dtype=torch.float64) 62 | 63 | def test_bce(): 64 | loss = LossBCE() 65 | assert float(loss(a, a)) == float(loss(b, b)) == 0 66 | assert float(loss(a_batch, a_batch)) == float(loss(b_batch, b_batch)) == 0 67 | assert float(loss(a, b)) == float(loss(a, b)) 68 | assert abs(float(loss(a, b)) - 13.839) < 0.01 69 | assert abs(float(loss(a_batch, b_batch)) - 13.839) < 0.01 70 | assert float(loss(d, e)) == float(loss(e, d)) == 100 71 | assert float(loss(d_batch, e_batch)) == float(loss(e_batch, d_batch)) == 100 72 | assert loss(a, b) < loss(b, c) < loss(a, c) < loss(e, d) 73 | 74 | def test_batched_dice(): 75 | # For larger batches, the smoothed Dice loss is higher (it approaches the 76 | # actual loss more closely). 77 | loss = LossDice(smooth=.001) 78 | assert abs(float(loss(a, b)) - 0.2768) < 0.001 79 | assert abs(float(loss(a_batch, b_batch)) - 0.2768) < 0.001 80 | loss = LossDice(smooth=1.) 81 | assert abs(float(loss(a, b)) - 0.2756) < 0.001 82 | assert abs(float(loss(a_batch, b_batch)) - 0.2764) < 0.001 83 | 84 | def test_weighted_dice(): 85 | # The higher the smoothing factor, the lower the loss. The smaller the 86 | # smoothing factor, the closer the result to one minus the Dice score. 87 | dice_loss = LossDice(smooth=1.) 88 | loss = LossClassWeighted(loss=dice_loss, nr_labels=4) 89 | assert abs(float(loss(a, b)) - 0.4245) < 0.001 90 | assert abs(float(loss(a, a)) - 0.0) < 0.001 91 | assert abs(float(loss(d, e)) - 0.9911) < 0.001 92 | dice_loss = LossDice(smooth=.001) 93 | loss = LossClassWeighted(loss=dice_loss, nr_labels=4) 94 | assert abs(float(loss(a, b)) - 0.4446) < 0.001 95 | assert abs(float(loss(a, a)) - 0.0) < 0.001 96 | assert abs(float(loss(d, e)) - 0.9999) < 0.001 97 | loss = LossClassWeighted(loss=dice_loss, weights=[2,1,0,1]) 98 | assert abs(float(loss(a, b)) - 0.3933) < 0.001 99 | loss = LossClassWeighted(loss=dice_loss, weights=[0.2,0.1,.0,.1]) 100 | assert abs(float(loss(a, b)) - 0.3933) < 0.001 101 | 102 | def test_combined_losses(): 103 | loss = LossDiceBCE(bce_weight=1., smooth=1.)
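# LossDiceBCE evaluates to dice + bce_weight * bce, which the asserts below
# confirm numerically: 0.2756 + 1.0*13.839 ~= 14.115 and 0.2756 + 0.5*13.839 ~= 7.195
# (cf. the 'LossCombined[...]' key in test_get_evaluation_dict).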
104 | assert abs(float(loss(a, b)) - 14.115) < 0.001 105 | loss = LossDiceBCE(bce_weight=.5, smooth=1.) 106 | assert abs(float(loss(a, b)) - 7.195) < 0.001 107 | loss = LossDiceBCE(bce_weight=.5, smooth=.001) 108 | assert abs(float(loss(a, b)) - 7.1964) < 0.001 109 | assert abs(float(loss(a_batch, b_batch)) - 7.1964) < 0.001 110 | 111 | def test_get_evaluation_dict(): 112 | loss = LossBCE() 113 | evaluation_dict = loss.get_evaluation_dict(a, b) 114 | assert abs(evaluation_dict['LossBCE'] - 13.839) < 0.01 115 | 116 | dice_loss = LossDice(smooth=.001) 117 | loss = LossClassWeighted(loss=dice_loss, nr_labels=4) 118 | evaluation_dict = loss.get_evaluation_dict(a, b) 119 | assert abs(evaluation_dict['LossDice[smooth=0.001][0]'] - 0.1795) < 0.01 120 | assert abs(evaluation_dict['LossDice[smooth=0.001][1]'] - 0.5) < 0.01 121 | assert abs(evaluation_dict['LossDice[smooth=0.001][2]'] - 0.3846) < 0.01 122 | assert abs(evaluation_dict['LossDice[smooth=0.001][3]'] - 0.7142) < 0.01 123 | assert abs(evaluation_dict['LossClassWeighted[loss=LossDice[smooth=0.001]; weights=(1, 1, 1, 1)]'] - 0.4446) < 0.01 124 | 125 | dice_loss = LossDice(smooth=.001) 126 | loss = LossClassWeighted(loss=dice_loss, weights=[0.2,0.1,.0,.1]) 127 | evaluation_dict = loss.get_evaluation_dict(a, b) 128 | assert abs(evaluation_dict['LossDice[smooth=0.001][0]'] - 0.1795) < 0.01 129 | assert abs(evaluation_dict['LossDice[smooth=0.001][1]'] - 0.5) < 0.01 130 | assert abs(evaluation_dict['LossDice[smooth=0.001][2]'] - 0.3846) < 0.01 131 | assert abs(evaluation_dict['LossDice[smooth=0.001][3]'] - 0.7142) < 0.01 132 | assert abs(evaluation_dict['LossClassWeighted[loss=LossDice[smooth=0.001]; weights=(0.2, 0.1, 0.0, 0.1)]'] - 0.3933) < 0.01 133 | 134 | loss = LossDiceBCE(bce_weight=.5, smooth=1.) 135 | evaluation_dict = loss.get_evaluation_dict(a, b) 136 | assert abs(evaluation_dict['LossCombined[1.0xLossDice[smooth=1.0]+0.5xLossBCE]'] - 7.195) < 0.01 137 | assert abs(evaluation_dict['LossBCE'] - 13.839) < 0.01 138 | assert abs(evaluation_dict['LossDice[smooth=1.0]'] - 0.275) < 0.01 139 | 140 | def test_batched_weighted_dice(): 141 | dice_loss = LossDice(smooth=.001) 142 | loss = LossClassWeighted(loss=dice_loss, nr_labels=4) 143 | assert abs(float(loss(a_batch, b_batch)) - 0.4446) < 0.001 144 | assert abs(float(loss(a_batch, a_batch)) - 0.0) < 0.001 145 | assert abs(float(loss(d_batch, e_batch)) - 0.9999) < 0.001 146 | loss = LossClassWeighted(loss=dice_loss, weights=[2,1,0,1]) 147 | assert abs(float(loss(a_batch, b_batch)) - 0.3933) < 0.001 148 | loss = LossClassWeighted(loss=dice_loss, weights=[0.2,0.1,.0,.1]) 149 | assert abs(float(loss(a_batch, b_batch)) - 0.3933) < 0.001 150 | evaluation_dict = loss.get_evaluation_dict(a_batch, b_batch) 151 | assert abs(evaluation_dict['LossDice[smooth=0.001][0]'] - 0.1795) < 0.01 152 | assert abs(evaluation_dict['LossDice[smooth=0.001][1]'] - 0.5) < 0.01 153 | assert abs(evaluation_dict['LossDice[smooth=0.001][2]'] - 0.3846) < 0.01 154 | assert abs(evaluation_dict['LossDice[smooth=0.001][3]'] - 0.7142) < 0.01 155 | assert abs(evaluation_dict['LossClassWeighted[loss=LossDice[smooth=0.001]; weights=(0.2, 0.1, 0.0, 0.1)]'] - 0.3933) < 0.01 156 | 157 | def test_batched_weighted_dice_two(): 158 | dice_loss = LossDice(smooth=.001) 159 | loss = LossClassWeighted(loss=dice_loss, weights=[20.,10.,.0,10.]) 160 | assert abs(float(loss(a_batch, b_batch)) - 0.3933) < 0.001 161 | evaluation_dict = loss.get_evaluation_dict(a_batch, b_batch) 162 | assert abs(evaluation_dict['LossDice[smooth=0.001][0]'] - 0.1795) 
< 0.01 163 | assert abs(evaluation_dict['LossDice[smooth=0.001][1]'] - 0.5) < 0.01 164 | assert abs(evaluation_dict['LossDice[smooth=0.001][2]'] - 0.3846) < 0.01 165 | assert abs(evaluation_dict['LossDice[smooth=0.001][3]'] - 0.7142) < 0.01 166 | assert abs(evaluation_dict['LossClassWeighted[loss=LossDice[smooth=0.001]; weights=(20.0, 10.0, 0.0, 10.0)]'] - 0.3933) < 0.01 -------------------------------------------------------------------------------- /test/eval/metrics/test_metrics_segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mp.eval.metrics.mean_scores import get_tp_tn_fn_fp_segmentation, get_mean_scores 3 | 4 | A_1=[[0,0,0,0,0,0,0], 5 | [0,1,3,3,0,1,0], 6 | [0,0,3,1,1,2,2], 7 | [0,0,0,1,1,2,2]] 8 | A_2=[[0,0,0,0,0,0,0], 9 | [0,0,0,0,0,0,0], 10 | [0,0,0,0,3,3,3], 11 | [2,2,1,1,1,1,0]] 12 | A_3=[[0,0,0,0,0,0,0], 13 | [0,0,0,1,0,0,0], 14 | [2,0,1,1,0,0,0], 15 | [2,0,0,0,0,0,0]] 16 | A_4=[[1,1,1,0,0,0,0], 17 | [0,0,0,0,2,2,0], 18 | [0,2,0,0,2,2,0], 19 | [0,2,0,0,3,3,3]] 20 | 21 | B_1=[[0,0,0,0,0,0,0], 22 | [0,0,0,0,0,0,0], 23 | [0,0,3,1,1,2,2], 24 | [0,0,0,1,1,2,2]] 25 | B_2=[[0,0,0,0,0,0,0], 26 | [0,0,0,0,0,0,0], 27 | [0,0,0,0,0,0,0], 28 | [2,2,2,2,1,1,0]] 29 | B_3=[[0,0,0,0,0,1,0], 30 | [0,0,2,0,1,1,0], 31 | [0,0,2,0,0,3,0], 32 | [0,0,0,0,3,3,0]] 33 | B_4=[[0,0,1,0,0,0,0], 34 | [1,1,0,0,0,0,0], 35 | [0,2,0,0,0,0,0], 36 | [0,2,0,0,0,0,3]] 37 | 38 | a = torch.tensor([A_1, A_2, A_3, A_4]) 39 | b = torch.tensor([B_1, B_2, B_3, B_4]) 40 | 41 | # Batched-inputs 42 | a_batch = torch.stack([a, a, a]) 43 | b_batch = torch.stack([b, b, b]) 44 | 45 | # Single-instance inputs 46 | a = a.unsqueeze(0) 47 | b = b.unsqueeze(0) 48 | 49 | def test_tp_tn_fn_fp(): 50 | assert get_tp_tn_fn_fp_segmentation(a, b, 0) == (64, 20, 9, 19) 51 | assert get_tp_tn_fn_fp_segmentation(a, b, 1) == (7, 91, 9, 5) 52 | assert get_tp_tn_fn_fp_segmentation(a, b, 2) == (8, 94, 6, 4) 53 | assert get_tp_tn_fn_fp_segmentation(a, b, 3) == (2, 100, 7, 3) 54 | 55 | def test_dice_iou(): 56 | scores = get_mean_scores(a, b, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3']) 57 | target_scores = {'ScoreDice': 0.555, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.410, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 58 | for key, value in target_scores.items(): 59 | assert abs(value - scores[key]) <= 0.01 60 | 61 | def test_weighted_metrics(): 62 | scores = get_mean_scores(a, b, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3'], label_weights={'0':2, '1':1, '2':0, '3':1}) 63 | target_scores = {'ScoreDice': 0.607, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.473, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 64 | for key, value in target_scores.items(): 65 | assert abs(value - scores[key]) <= 0.01 66 | 67 | scores = get_mean_scores(a, b, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3'], label_weights={'0':0.2, '1':0.1, '2':0, '3':0.1}) 68 | target_scores = {'ScoreDice': 0.607, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.473, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 69 | for key, value in target_scores.items(): 70 | assert abs(value - scores[key]) <= 0.01 71 | 72 | def test_batched_tp_tn_fn_fp(): 73 | assert 
get_tp_tn_fn_fp_segmentation(a_batch, b_batch, 0) == (64*3, 20*3, 9*3, 19*3) 74 | assert get_tp_tn_fn_fp_segmentation(a_batch, b_batch, 1) == (7*3, 91*3, 9*3, 5*3) 75 | assert get_tp_tn_fn_fp_segmentation(a_batch, b_batch, 2) == (8*3, 94*3, 6*3, 4*3) 76 | assert get_tp_tn_fn_fp_segmentation(a_batch, b_batch, 3) == (2*3, 100*3, 7*3, 3*3) 77 | 78 | def test_batched_metrics(): 79 | scores = get_mean_scores(a_batch, b_batch, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3'], label_weights={'0':2, '1':1, '2':0, '3':1}) 80 | target_scores = {'ScoreDice': 0.607, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.473, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 81 | for key, value in target_scores.items(): 82 | assert abs(value - scores[key]) <= 0.01 83 | 84 | scores = get_mean_scores(a_batch, b_batch, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3'], label_weights={'0':2, '1':1, '2':0, '3':1}) 85 | target_scores = {'ScoreDice': 0.607, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.473, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 86 | for key, value in target_scores.items(): 87 | assert abs(value - scores[key]) <= 0.01 88 | 89 | scores = get_mean_scores(a_batch, b_batch, metrics=['ScoreDice', 'ScoreIoU'], label_names=['0', '1', '2', '3'], label_weights={'0':0.2, '1':0.1, '2':0, '3':0.1}) 90 | target_scores = {'ScoreDice': 0.607, 'ScoreDice[0]': 0.820, 'ScoreDice[1]': 0.5, 'ScoreDice[2]': 0.615, 'ScoreDice[3]': 0.286, 'ScoreIoU': 0.473, 'ScoreIoU[0]': 0.696, 'ScoreIoU[1]': 0.333, 'ScoreIoU[2]': 0.444, 'ScoreIoU[3]': 0.167} 91 | for key, value in target_scores.items(): 92 | assert abs(value - scores[key]) <= 0.01 -------------------------------------------------------------------------------- /test/eval/test_accumulator.py: -------------------------------------------------------------------------------- 1 | from mp.eval.accumulator import Accumulator 2 | 3 | def test_acc(): 4 | acc = Accumulator(keys=['A']) 5 | for i in range(5): 6 | acc.add('A', float(i)) 7 | assert acc.mean('A') == 2.0 8 | assert 1.41 < acc.std('A') < 1.415 -------------------------------------------------------------------------------- /test/eval/test_result.py: -------------------------------------------------------------------------------- 1 | from mp.eval.result import Result 2 | 3 | def test_results(): 4 | res = Result(name='Example') 5 | res.add(1, 'accuracy', 0.2, data='example') 6 | res.add(2, 'accuracy', 0.3, data='example') 7 | res.add(3, 'accuracy', 0.4, data='example') 8 | res.add(0, 'F1', 0.5, data='example') 9 | res.add(3, 'F1', 0.7, data='example') 10 | assert res.get_min_epoch(metric='accuracy', data='example') == 1 11 | assert res.get_max_epoch(metric='F1', data='example') == 3 12 | assert res.get_epoch_metric(epoch=2, metric='accuracy', data='example') == 0.3 13 | assert len(res.to_pandas()) == 5 14 | -------------------------------------------------------------------------------- /test/experiment/test_experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | from mp.experiments.experiment import Experiment 5 | from mp.eval.result import Result 6 | from mp.utils.load_restore import load_json 7 | from mp.paths import storage_path 8 | 9 | def test_success(): 10 | notes='A test experiment which is successful' 
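# Experiments persist their config and review notes under <storage_path>/exp/<name>/
# and per-run results under <name>/<run_ix>/results/, as the assertions below check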
11 | exp = Experiment({'test_param': 2, 'nr_runs': 1}, name='TEST_SUCCESS', notes=notes) 12 | exp_run = exp.get_run(0) 13 | res = Result(name='some_result') 14 | res.add(1, 'A', 2.0) 15 | res.add(3, 'A', 10.0) 16 | exp_run.finish(results=res) 17 | path = os.path.join(os.path.join(storage_path, 'exp'), 'TEST_SUCCESS') 18 | exp_review = load_json(path, 'review') 19 | assert exp_review['notes'] == notes 20 | config = load_json(path, 'config') 21 | assert config['test_param'] == 2 22 | res_path = os.path.join(path, os.path.join('0', 'results')) 23 | assert os.path.isfile(os.path.join(res_path, 'some_result.png')) 24 | shutil.rmtree(path) 25 | 26 | def test_failure(): 27 | notes='A test experiment which fails' 28 | exp = Experiment({'test_param': 2, 'nr_runs': 1}, name='TEST_FAILURE', notes=notes) 29 | exp_run = exp.get_run(0) 30 | exp_run.finish(exception=Exception) 31 | path = os.path.join(os.path.join(storage_path, 'exp'), 'TEST_FAILURE') 32 | exp_review = load_json(path, 'review') 33 | assert exp_review['notes'] == notes 34 | exp_run_review = load_json(os.path.join(path, '0'), 'review') 35 | assert 'FAILURE' in exp_run_review['state'] 36 | shutil.rmtree(path) 37 | 38 | def test_reload(): 39 | notes='A test experiment which is reloaded' 40 | # First experiment creation 41 | exp = Experiment({'test_param': 2, 'nr_runs': 1}, name='TEST_RELOAD', notes=notes) 42 | res = Result(name='some_result') 43 | # Experiment reload 44 | exp = Experiment(name='TEST_RELOAD', reload_exp=True) 45 | assert exp.review['notes'] == notes 46 | assert exp.config['test_param'] == 2 47 | path = os.path.join(os.path.join(storage_path, 'exp'), 'TEST_RELOAD') 48 | shutil.rmtree(path) 49 | -------------------------------------------------------------------------------- /test/test_obj/3dimg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/3dimg.png -------------------------------------------------------------------------------- /test/test_obj/3dsegm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/3dsegm.png -------------------------------------------------------------------------------- /test/test_obj/README.txt: -------------------------------------------------------------------------------- 1 | Images img_00.nii and mask_00.nii are taken from the Medical Segmentation Decathlon (http://medicaldecathlon.com/) for task 5 (prostate). 2 | 3 | @article{DBLP:journals/corr/abs-1902-09063, 4 | author = {Amber L. Simpson and 5 | Michela Antonelli and 6 | Spyridon Bakas and 7 | Michel Bilello and 8 | Keyvan Farahani and 9 | Bram van Ginneken and 10 | Annette Kopp{-}Schneider and 11 | Bennett A. Landman and 12 | Geert J. S. Litjens and 13 | Bjoern H. Menze and 14 | Olaf Ronneberger and 15 | Ronald M. Summers and 16 | Patrick Bilic and 17 | Patrick Ferdinand Christ and 18 | Richard K. G. Do and 19 | Marc Gollub and 20 | Jennifer Golia{-}Pernicka and 21 | Stephan Heckers and 22 | William R. Jarnagin and 23 | Maureen McHugo and 24 | Sandy Napel and 25 | Eugene Vorontsov and 26 | Lena Maier{-}Hein and 27 | M. 
Jorge Cardoso}, 28 | title = {A large annotated medical image dataset for the development and evaluation 29 | of segmentation algorithms}, 30 | journal = {CoRR}, 31 | volume = {abs/1902.09063}, 32 | year = {2019}, 33 | url = {http://arxiv.org/abs/1902.09063}, 34 | archivePrefix = {arXiv}, 35 | eprint = {1902.09063}, 36 | timestamp = {Tue, 21 May 2019 18:03:37 +0200}, 37 | biburl = {https://dblp.org/rec/journals/corr/abs-1902-09063.bib}, 38 | bibsource = {dblp computer science bibliography, https://dblp.org} 39 | } 40 | -------------------------------------------------------------------------------- /test/test_obj/agent_states_prostate_2D/epoch_300/agent_state_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/agent_states_prostate_2D/epoch_300/agent_state_dict.pkl -------------------------------------------------------------------------------- /test/test_obj/agent_states_prostate_2D/epoch_300/model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/agent_states_prostate_2D/epoch_300/model -------------------------------------------------------------------------------- /test/test_obj/agent_states_prostate_2D/epoch_300/optimizer: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/agent_states_prostate_2D/epoch_300/optimizer -------------------------------------------------------------------------------- /test/test_obj/example_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/example_result.png -------------------------------------------------------------------------------- /test/test_obj/img_00.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/img_00.nii -------------------------------------------------------------------------------- /test/test_obj/mask_00.nii: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/mask_00.nii -------------------------------------------------------------------------------- /test/test_obj/test_confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MECLabTUDA/ACS/bb418c5479a3585138c48c63112352f5cc8f64b1/test/test_obj/test_confusion_matrix.png -------------------------------------------------------------------------------- /test/utils/test_helper_functions.py: -------------------------------------------------------------------------------- 1 | import mp.utils.helper_functions as hf 2 | 3 | def test_date_time(): 4 | date = hf.get_time_string(True) 5 | assert len(date) == 21 6 | date = hf.get_time_string(False) 7 | assert len(date) == 19 -------------------------------------------------------------------------------- /test/utils/test_introspection.py: -------------------------------------------------------------------------------- 1 | from mp.utils.introspection import 
introspect 2 | from mp.utils.load_restore import join_path 3 | 4 | def test_introspection(): 5 | class_path = 'mp.models.classification.small_cnn.SmallCNN' 6 | exp = introspect(class_path)() 7 | assert exp.__class__.__name__ == 'SmallCNN' -------------------------------------------------------------------------------- /test/visualization/test_confusion_matrix.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mp.visualization.confusion_matrix import ConfusionMatrix 3 | 4 | def test_confusion_matrix(): 5 | # We have 3 classes 6 | cm = ConfusionMatrix(3) 7 | # 2 tp for each class 8 | cm.add(predicted=0, actual=0, count=2) 9 | cm.add(predicted=1, actual=1, count=2) 10 | cm.add(predicted=2, actual=2, count=2) 11 | # 3 examples of class 0 were predicted as class 1 12 | cm.add(predicted=1, actual=0, count=3) 13 | # 1 example of class 1 was predicted as class 2 14 | cm.add(predicted=2, actual=1, count=1) 15 | save_path = os.path.join('test', 'test_obj') 16 | cm.plot(path=save_path, name='test_confusion_matrix') 17 | assert os.path.isfile(os.path.join(save_path, 'test_confusion_matrix.png')) 18 | assert cm.cm == [[2, 3, 0], [0, 2, 1], [0, 0, 2]] 19 | assert cm.get_accuracy() == 0.6 -------------------------------------------------------------------------------- /test/visualization/test_plot_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mp.eval.result import Result 4 | from mp.visualization.plot_results import plot_results 5 | 6 | def test_plotting(): 7 | res = Result(name='example_result') 8 | res.add(1, 'accuracy', 0.2, data='train') 9 | res.add(2, 'accuracy', 0.3, data='train') 10 | res.add(3, 'accuracy', 0.4, data='train') 11 | res.add(0, 'F1', 0.5, data='train') 12 | res.add(3, 'F1', 0.7, data='train') 13 | res.add(0, 'F1', 0.3, data='val') 14 | res.add(3, 'F1', 0.45, data='val') 15 | save_path = os.path.join('test', 'test_obj') 16 | plot_results(res, measures=['accuracy', 'F1'], save_path=save_path, title='Test figure', ending='.png') 17 | assert os.path.isfile(os.path.join(save_path, 'example_result.png')) -------------------------------------------------------------------------------- /test/visualization/test_visualize_imgs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SimpleITK as sitk 3 | import mp.visualization.visualize_imgs as vi 4 | import torch 5 | 6 | def test_3d_img(): 7 | images_path = os.path.join('test', 'test_obj') 8 | x = sitk.ReadImage(os.path.join(images_path, 'img_00.nii')) 9 | x = sitk.GetArrayFromImage(x)[0] # Take only T2-weighted 10 | y = sitk.ReadImage(os.path.join(images_path, 'mask_00.nii')) 11 | y = sitk.GetArrayFromImage(y) 12 | save_path = os.path.join('test', 'test_obj') 13 | vi.plot_3d_img(x, save_path=os.path.join(save_path, '3dimg.png'), img_size=(128,128)) 14 | assert os.path.isfile(os.path.join(save_path, '3dimg.png')) 15 | 16 | def test_3d_seg(): 17 | images_path = os.path.join('test', 'test_obj') 18 | x = sitk.ReadImage(os.path.join(images_path, 'img_00.nii')) 19 | x = sitk.GetArrayFromImage(x)[0] # Take only T2-weighted 20 | y = sitk.ReadImage(os.path.join(images_path, 'mask_00.nii')) 21 | y = sitk.GetArrayFromImage(y) 22 | save_path = os.path.join('test', 'test_obj') 23 | vi.plot_3d_segmentation(x, y, save_path=os.path.join(save_path, '3dsegm.png'), img_size=(128, 128)) 24 | assert os.path.isfile(os.path.join(save_path, '3dsegm.png'))
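# ------------------------------------------------------------------------------
# Illustrative sketch (an addition, not one of the repository tests): the same
# utilities can visualize a model prediction by first collapsing a
# per-label-channel network output into a label map with arg_max. The helper
# name and the prediction tensor are hypothetical; os, sitk and vi are already
# imported above.
# ------------------------------------------------------------------------------
from mp.eval.inference.predict import arg_max

def plot_prediction_sketch(output):
    # output: float tensor of shape (nr_labels, D, H, W), e.g. a softmaxed model output
    images_path = os.path.join('test', 'test_obj')
    x = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(images_path, 'img_00.nii')))[0]
    y_pred = arg_max(output, channel_dim=0).numpy()  # label map of shape (D, H, W)
    vi.plot_3d_segmentation(x, y_pred, save_path=os.path.join(images_path, '3dpred.png'),
                            img_size=(128, 128))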
-------------------------------------------------------------------------------- /unet_joint_train.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Code to train U-Net on all datasets simultaneously 4 | # ------------------------------------------------------------------------------ 5 | 6 | # Imports 7 | import os 8 | import sys 9 | from args import parse_args_as_dict 10 | from mp.utils.helper_functions import seed_all 11 | 12 | import torch 13 | torch.set_num_threads(6) 14 | from torch.utils.data import DataLoader 15 | import torch.optim as optim 16 | 17 | from mp.experiments.experiment import Experiment 18 | from mp.data.data import Data 19 | from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus 20 | from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus 21 | from mp.data.datasets.ds_mr_hippocampus_harp import HarP 22 | from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset 23 | from mp.eval.losses.losses_segmentation import LossClassWeighted, LossDiceBCE 24 | from mp.agents.unet_agent import UNETAgent 25 | from mp.eval.result import Result 26 | from mp.utils.tensorboard import create_writer 27 | from mp.models.continual.mas import MAS 28 | 29 | # Get configuration from arguments 30 | config = parse_args_as_dict(sys.argv[1:]) 31 | seed_all(42) 32 | 33 | config['class_weights'] = (0., 1.) 34 | 35 | # Create experiment directories 36 | exp = Experiment(config=config, name=config['experiment_name'], notes='', reload_exp=(config['resume_epoch'] is not None)) 37 | 38 | # Datasets 39 | data = Data() 40 | 41 | dataset_domain_a = DecathlonHippocampus(merge_labels=True) 42 | dataset_domain_a.name = 'DecathlonHippocampus' 43 | data.add_dataset(dataset_domain_a) 44 | 45 | dataset_domain_b = DryadHippocampus(merge_labels=True) 46 | dataset_domain_b.name = 'DryadHippocampus' 47 | data.add_dataset(dataset_domain_b) 48 | 49 | dataset_domain_c = HarP(merge_labels=True) 50 | dataset_domain_c.name = 'HarP' 51 | data.add_dataset(dataset_domain_c) 52 | 53 | nr_labels = data.nr_labels 54 | label_names = data.label_names 55 | 56 | if config['combination'] == 0: 57 | ds_a = ('DecathlonHippocampus', 'train') 58 | ds_b = ('DryadHippocampus', 'train') 59 | ds_c = ('HarP', 'train') 60 | elif config['combination'] == 1: 61 | ds_a = ('DecathlonHippocampus', 'train') 62 | ds_c = ('DryadHippocampus', 'train') 63 | ds_b = ('HarP', 'train') 64 | elif config['combination'] == 2: 65 | ds_c = ('DecathlonHippocampus', 'train') 66 | ds_b = ('DryadHippocampus', 'train') 67 | ds_a = ('HarP', 'train') 68 | 69 | # Create data splits for each repetition 70 | exp.set_data_splits(data) 71 | 72 | # Now repeat for each repetition 73 | for run_ix in range(config['nr_runs']): 74 | exp_run = exp.get_run(run_ix=run_ix, reload_exp_run=(config['resume_epoch'] is not None)) 75 | datasets = dict() 76 | for idx, item in enumerate(data.datasets.items()): 77 | ds_name, ds = item 78 | for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items(): 79 | data_ixs = data_ixs[:config['n_samples']] 80 | if len(data_ixs) > 0: # Sometimes val indexes may be an empty list 81 | aug = config['augmentation'] if not('test' in split) else 'none' 82 | datasets[(ds_name, split)] = PytorchSeg2DDataset(ds, 83 | ix_lst=data_ixs, size=config['input_shape'], aug_key=aug, 84 | resize=(not config['no_resize'])) 85 | 86 | dataset = torch.utils.data.ConcatDataset((datasets[(ds_a)],
datasets[(ds_b)], datasets[(ds_c)])) 87 | train_dataloader_0 = DataLoader(dataset, batch_size=config['batch_size'], drop_last=False, pin_memory=True, num_workers=len(config['device_ids'])*config['n_workers']) 88 | 89 | if config['eval']: 90 | drop = [] 91 | for key in datasets.keys(): 92 | if 'train' in key or 'val' in key: 93 | drop += [key] 94 | for d in drop: 95 | datasets.pop(d) 96 | 97 | model = MAS(input_shape=config['input_shape'], nr_labels=nr_labels, 98 | unet_dropout=config['unet_dropout'], unet_monte_carlo_dropout=config['unet_monte_carlo_dropout'], unet_preactivation=config['unet_preactivation']) 99 | 100 | model.to(config['device']) 101 | 102 | # Define loss and optimizer 103 | loss_g = LossDiceBCE(bce_weight=1., smooth=1., device=config['device']) 104 | loss_f = LossClassWeighted(loss=loss_g, weights=config['class_weights'], device=config['device']) 105 | 106 | # Set optimizers 107 | model.set_optimizers(optim.Adam, lr=config['lr']) 108 | 109 | # Train model 110 | results = Result(name='training_trajectory') 111 | 112 | agent = UNETAgent(model=model, label_names=label_names, device=config['device']) 113 | agent.summary_writer = create_writer(os.path.join(exp_run.paths['states'], '..'), 0) 114 | 115 | init_epoch = 0 116 | nr_epochs = config['epochs'] 117 | 118 | config['continual'] = False 119 | 120 | # Resume training 121 | if config['resume_epoch'] is not None: 122 | agent.restore_state(exp_run.paths['states'], config['resume_epoch']) 123 | init_epoch = agent.agent_state_dict['epoch'] + 1 124 | 125 | # Joint Training 126 | agent.train(results, loss_f, train_dataloader_0, train_dataloader_0, config, 127 | init_epoch=init_epoch, nr_epochs=nr_epochs, run_loss_print_interval=1, 128 | eval_datasets=datasets, eval_interval=config['eval_interval'], 129 | save_path=exp_run.paths['states'], save_interval=config['save_interval'], 130 | display_interval=config['display_interval'], 131 | resume_epoch=config['resume_epoch'], device_ids=config['device_ids']) 132 | 133 | print('Finished training on A and B and C') 134 | 135 | # Save and print results for this experiment run 136 | exp_run.finish(results=results, plot_metrics=['Mean_LossBCEWithLogits', 'Mean_LossDice[smooth=1.0]', 'Mean_LossCombined[1.0xLossDice[smooth=1.0]+1.0xLossBCEWithLogits]']) 137 | --------------------------------------------------------------------------------
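A minimal post-training sketch (an addition, not a repository file): restoring a
saved agent state and scoring it on one dataset, following the pattern in
test/agents/model_state_restore.py above. The dataset choice, input shape, states
path and epoch name are illustrative assumptions, and whether a stored checkpoint
is compatible with this model class depends on how it was trained.

from mp.data.datasets.ds_mr_hippocampus_harp import HarP
from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset
from mp.models.segmentation.unet_fepegar import UNet2D
from mp.eval.losses.losses_segmentation import LossDice
from mp.agents.segmentation_agent import SegmentationAgent
from mp.eval.evaluate import ds_losses_metrics

data = HarP(merge_labels=True)
input_shape = (1, 256, 256)  # assumed; must match the shape the model was trained with
ds = PytorchSeg2DDataset(data, ix_lst=[0, 1], size=input_shape, aug_key='none', resize=False)
model = UNet2D(input_shape, data.nr_labels)
agent = SegmentationAgent(model=model, label_names=data.label_names, device='cpu',
                          metrics=['ScoreDice'])
# Hypothetical states path and epoch name
agent.restore_state(states_path='storage/exp/EXPERIMENT_NAME/0/states', state_name='epoch_60')
eval_dict = ds_losses_metrics(ds, agent, LossDice(1e-05), metrics=['ScoreDice'])
print(eval_dict['ScoreDice']['mean'])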