├── .github └── ISSUE_TEMPLATE │ └── question-on-boundary-loss.md ├── .gitignore ├── .gitmodules ├── LICENSE ├── acdc.make ├── data ├── ISLES.lineage ├── acdc.lineage └── wmh.lineage ├── dataloader.py ├── extract.sh ├── hist.py ├── isles.make ├── keras_loss.py ├── losses.py ├── main.py ├── metrics_overtime.py ├── models ├── __init__.py ├── enet.py ├── residualunet.py ├── unet.py └── unet_3d.py ├── moustache.py ├── networks.py ├── plot.py ├── preprocess ├── slice_acdc.py ├── slice_isles.py └── slice_wmh.py ├── readme.md ├── release.md ├── report.py ├── resources ├── acdc_bl.png └── readme_comparison.png ├── scheduler.py ├── test.py ├── utils.py └── wmh.make /.github/ISSUE_TEMPLATE/question-on-boundary-loss.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question on boundary loss 3 | about: If you have a question on the formulation or its inner workings 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Before posting, do check out the FAQ: https://github.com/LIVIAETS/boundary-loss#frequently-asked-question ; the answer to your question might be there 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | inference/ 3 | archives/ 4 | 5 | data/* 6 | !data/*.lineage 7 | 8 | *.zip 9 | *.pkl 10 | *.tar.gz 11 | plots 12 | Results/ 13 | results 14 | *.npy 15 | plot_bounds/ 16 | /*.png 17 | /**/RANDOM_DATA* 18 | 19 | .DS_Store 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | .static_storage/ 77 | .media/ 78 | local_settings.py 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # PyBuilder 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # celery beat schedule file 100 | celerybeat-schedule 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "viewer"] 2 | path = viewer 3 | url = git@github.com:HKervadec/segmentation_viewer.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hoel Kervadec 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
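A quick orientation note before the build files and sources below: the boundary (surface) loss implemented in losses.py, and wired into training through the --losses recipes of acdc.make, isles.make and wmh.make, boils down to a pixel-wise product between the softmax probabilities and a pre-computed signed distance map of the ground truth. The sketch below is not part of the repository; it is a minimal, self-contained illustration under the usual convention (negative distances inside the object, positive outside), with hypothetical tensors and a hand-rolled signed_distance helper standing in for the repository's dist_map_transform.

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt as edt


def signed_distance(one_hot_gt: np.ndarray) -> np.ndarray:
    """(K, H, W) binary masks -> (K, H, W) signed distance maps (negative inside)."""
    res = np.zeros_like(one_hot_gt, dtype=np.float32)
    for k in range(one_hot_gt.shape[0]):
        posmask = one_hot_gt[k].astype(bool)
        if posmask.any():
            negmask = ~posmask
            # Distance to the object outside, negative distance to the background inside
            res[k] = edt(negmask) * negmask - (edt(posmask) - 1) * posmask
    return res


# Hypothetical 2-class example: one square foreground object
gt = np.zeros((2, 64, 64), dtype=np.uint8)
gt[1, 20:40, 20:40] = 1
gt[0] = 1 - gt[1]

dist = torch.from_numpy(signed_distance(gt))[None, ...]   # (1, K, H, W)
probs = torch.softmax(torch.randn(1, 2, 64, 64), dim=1)   # stand-in network output, (B, K, H, W)

# Boundary/surface loss on the foreground class only, as SurfaceLoss with idc=[1] computes in losses.py
boundary_loss = (probs[:, [1], ...] * dist[:, [1], ...]).mean()
print(float(boundary_loss))

In the full pipeline this term is weighted and added to a regional loss (cross-entropy or GeneralizedDice), exactly as the --losses lists in the Makefiles combine them, and its weight can be rescheduled over epochs (see scheduler.py and the StealWeight recipes).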
-------------------------------------------------------------------------------- /acdc.make: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2023 Hoel Kervadec 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | CC = python3.9 24 | PP = PYTHONPATH="$(PYTHONPATH):." 25 | SHELL = zsh 26 | 27 | 28 | .PHONY: all geodist train plot view view_labels npy pack report weak 29 | 30 | red:=$(shell tput bold ; tput setaf 1) 31 | green:=$(shell tput bold ; tput setaf 2) 32 | yellow:=$(shell tput bold ; tput setaf 3) 33 | blue:=$(shell tput bold ; tput setaf 4) 34 | reset:=$(shell tput sgr0) 35 | 36 | # RD stands for Result DIR -- useful way to report from extracted archive 37 | RD = results/acdc 38 | 39 | # CFLAGS = -O 40 | # DEBUG = --debug 41 | EPC = 100 42 | BS = 8 # BS stands for Batch Size 43 | K = 4 # K for class 44 | 45 | G_RGX = (patient\d+_\d+_\d+)_\d+ 46 | B_DATA = [('img', png_transform, False), ('gt', gt_transform, True)] 47 | NET = ENet 48 | # NET = Dummy 49 | 50 | 51 | TRN = $(RD)/ce \ 52 | $(RD)/diceloss \ 53 | $(RD)/boundary 54 | 55 | 56 | GRAPH = $(RD)/val_dice.png $(RD)/tra_dice.png \ 57 | $(RD)/tra_loss.png \ 58 | $(RD)/val_3d_dsc.png 59 | BOXPLOT = $(RD)/val_3d_dsc_boxplot.png 60 | PLT = $(GRAPH) $(HIST) $(BOXPLOT) 61 | 62 | REPO = $(shell basename `git rev-parse --show-toplevel`) 63 | DATE = $(shell date +"%y%m%d") 64 | HASH = $(shell git rev-parse --short HEAD) 65 | HOSTNAME = $(shell hostname) 66 | PBASE = archives 67 | PACK = $(PBASE)/$(REPO)-$(DATE)-$(HASH)-$(HOSTNAME)-acdc.tar.gz 68 | 69 | all: pack 70 | 71 | train: $(TRN) 72 | plot: $(PLT) 73 | 74 | pack: $(PACK) report 75 | $(PACK): $(PLT) $(TRN) 76 | $(info $(red)tar cf $@$(reset)) 77 | mkdir -p $(@D) 78 | tar cf - $^ | pigz > $@ 79 | chmod -w $@ 80 | # tar -zc -f $@ $^ # Use if pigz is not available 81 | 82 | 83 | # Data generation 84 | data/ACDC-2D: OPT = --seed=0 --retain 25 85 | data/ACDC-2D: data/acdc 86 | $(info $(yellow)$(CC) $(CFLAGS) preprocess/slice_acdc.py$(reset)) 87 | rm -rf $@_tmp $@ 88 | $(PP) $(CC) $(CFLAGS) preprocess/slice_acdc.py --source_dir="data/acdc/training" --dest_dir=$@_tmp $(OPT) 89 | mv $@_tmp $@ 90 | 91 | data/acdc: data/acdc.lineage data/acdc.zip 92 | $(info $(yellow)unzip data/acdc.zip$(reset)) 93 | md5sum -c $< 94 | rm -rf $@_tmp $@ 95 | unzip -q $(word 2, $^) -d $@_tmp 96 | rm $@_tmp/training/*/*_4d.nii.gz # space optimization 97 | mv $@_tmp $@ 
98 | 99 | 100 | data/ACDC-2D/train/img data/ACDC-2D/val/img: | data/ACDC-2D 101 | data/ACDC-2D/train/gt data/ACDC-2D/val/gt: | data/ACDC-2D 102 | 103 | 104 | # Trainings 105 | $(RD)/ce: OPT = --losses="[('CrossEntropy', {'idc': [0, 1, 2, 3]}, 1)]" 106 | $(RD)/ce: data/ACDC-2D/train/gt data/ACDC-2D/val/gt 107 | $(RD)/ce: DATA = --folders="$(B_DATA)+[('gt', gt_transform, True)]" 108 | 109 | $(RD)/diceloss: OPT = --losses="[('DiceLoss', {'idc': [0, 1, 2, 3]}, 1)]" 110 | $(RD)/diceloss: data/ACDC-2D/train/gt data/ACDC-2D/val/gt 111 | $(RD)/diceloss: DATA = --folders="$(B_DATA)+[('gt', gt_transform, True)]" 112 | 113 | $(RD)/boundary: OPT = --losses="[('BoundaryLoss', {'idc': [0, 1, 2, 3]}, 1)]" 114 | $(RD)/boundary: data/ACDC-2D/train/gt data/ACDC-2D/val/gt 115 | $(RD)/boundary: DATA = --folders="$(B_DATA)+[('gt', dist_map_transform, False)]" 116 | 117 | # Template 118 | $(RD)/%: 119 | $(info $(green)$(CC) $(CFLAGS) main.py $@$(reset)) 120 | rm -rf $@_tmp 121 | mkdir -p $@_tmp 122 | printenv > $@_tmp/env.txt 123 | git diff > $@_tmp/repo.diff 124 | git rev-parse --short HEAD > $@_tmp/commit_hash 125 | $(CC) $(CFLAGS) main.py --dataset=$(dir $( None: 38 | # assert len(args.folders) <= len(colors) 39 | 40 | # if len(args.columns) > 1: 41 | # raise NotImplementedError("Only 1 columns at a time is handled for now") 42 | 43 | paths: List[Path] = [Path(f, args.filename) for f in args.folders] 44 | arrays: List[np.ndarray] = map_(np.load, paths) 45 | metric_name: str = paths[0].stem 46 | 47 | if len(arrays[0].shape) == 2: 48 | arrays = map_(lambda a: a[..., np.newaxis], arrays) 49 | epoch, _, class_ = arrays[0].shape 50 | for a in arrays[1:]: 51 | ea, _, ca = a.shape 52 | assert epoch == ea, (epoch, ea) 53 | 54 | if not args.dynamic_third_axis: 55 | assert class_ == ca, (class_, ca) 56 | 57 | fig = plt.figure(figsize=(14, 9)) 58 | ax = fig.gca() 59 | # ax.set_ylim([0, 1]) 60 | ax.set_xlim([0, 1]) 61 | ax.set_xlabel(metric_name) 62 | ax.set_ylabel("Percentage") 63 | ax.grid(True, axis='y') 64 | ax.set_title(f"{metric_name} histograms") 65 | 66 | bins = np.linspace(0, 1, args.nbins) 67 | c = 0 68 | for a, p in zip(arrays, paths): 69 | for k in args.columns: 70 | mean_a = a[..., k].mean(axis=1) 71 | best_epoch: int = np.argmax(mean_a) 72 | 73 | # values = a[args.epc, :, k] 74 | values = a[best_epoch, :, k] 75 | 76 | ax.hist(values, bins, alpha=0.5, label=f"{p.parent.name}-{k}", color=colors[c]) 77 | c += 1 78 | ax.legend() 79 | 80 | fig.tight_layout() 81 | if args.savefig: 82 | fig.savefig(args.savefig) 83 | 84 | if not args.headless: 85 | plt.show() 86 | 87 | 88 | def get_args() -> argparse.Namespace: 89 | parser = argparse.ArgumentParser(description='Plot data over time') 90 | parser.add_argument('--folders', type=str, required=True, nargs='+', help="The folders containing the file") 91 | parser.add_argument('--filename', type=str, required=True) 92 | parser.add_argument('--columns', type=int, nargs='+', default=0, help="Which columns of the third axis to plot") 93 | parser.add_argument("--savefig", type=str, default=None) 94 | parser.add_argument("--headless", action="store_true") 95 | parser.add_argument("--smooth", action="store_true", 96 | help="Help for compatibility with other plotting functions, does not do anything.") 97 | parser.add_argument("--nbins", type=int, default=100) 98 | parser.add_argument("--epc", type=int, required=True) 99 | 100 | parser.add_argument("--dynamic_third_axis", action="store_true", 101 | help="Allow the third axis of the arguments to be of varying size") 102 | 103 | # 
Dummies 104 | parser.add_argument("--debug", action="store_true", help="Dummy for compatibility") 105 | parser.add_argument("--cpu", action="store_true", help="Dummy for compatibility") 106 | parser.add_argument("--fontsize", type=int, default=10, help="Dummy opt for compatibility") 107 | parser.add_argument("--ylabel", type=str, default='') 108 | parser.add_argument("--loc", type=str, default=None) 109 | parser.add_argument("--labels", type=str, nargs='*') 110 | args = parser.parse_args() 111 | 112 | print(args) 113 | 114 | return args 115 | 116 | 117 | if __name__ == "__main__": 118 | run(get_args()) 119 | -------------------------------------------------------------------------------- /isles.make: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2023 Hoel Kervadec 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | CC = python3 24 | SHELL = /usr/bin/zsh 25 | PP = PYTHONPATH="$(PYTHONPATH):." 
26 | 27 | # RD stands for Result DIR -- useful way to report from extracted archive 28 | RD = results/isles 29 | 30 | .PHONY: all boundary plot train metrics hausdorff pack 31 | 32 | red:=$(shell tput bold ; tput setaf 1) 33 | green:=$(shell tput bold ; tput setaf 2) 34 | yellow:=$(shell tput bold ; tput setaf 3) 35 | blue:=$(shell tput bold ; tput setaf 4) 36 | reset:=$(shell tput sgr0) 37 | 38 | # CFLAGS = -O 39 | # DEBUG = --debug 40 | EPC = 100 41 | # EPC = 5 42 | 43 | K = 2 44 | BS = 8 45 | G_RGX = (case_\d+_\d+)_\d+ 46 | P_RGX = (case_\d+)_\d+_\d+ 47 | NET = UNet 48 | B_DATA = [('in_npy', tensor_transform, False), ('gt_npy', gt_transform, True)] 49 | 50 | TRN = $(RD)/gdl $(RD)/gdl_surface_steal $(RD)/gdl_3d_surface_steal $(RD)/gdl_hausdorff_w 51 | 52 | GRAPH = $(RD)/tra_loss.png $(RD)/val_loss.png \ 53 | $(RD)/val_dice.png $(RD)/tra_dice.png \ 54 | $(RD)/val_3d_hausdorff.png \ 55 | $(RD)/val_3d_hd95.png 56 | BOXPLOT = $(RD)/val_dice_boxplot.png 57 | PLT = $(GRAPH) $(BOXPLOT) 58 | 59 | REPO = $(shell basename `git rev-parse --show-toplevel`) 60 | DATE = $(shell date +"%y%m%d") 61 | HASH = $(shell git rev-parse --short HEAD) 62 | HOSTNAME = $(shell hostname) 63 | PBASE = archives 64 | PACK = $(PBASE)/$(REPO)-$(DATE)-$(HASH)-$(HOSTNAME)-isles.tar.gz 65 | 66 | all: $(PACK) 67 | 68 | plot: $(PLT) 69 | 70 | train: $(TRN) 71 | 72 | pack: report $(PACK) 73 | $(PACK): $(PLT) $(TRN) 74 | $(info $(red)tar cf $@$(reset)) 75 | mkdir -p $(@D) 76 | tar cf - $^ | pigz > $@ 77 | chmod -w $@ 78 | # tar -zc -f $@ $^ # Use if pigz is not available 79 | $(LIGHTPACK): $(PLT) $(TRN) 80 | mkdir -p $(@D) 81 | $(eval PLTS:=$(filter %.png, $^)) 82 | $(eval FF:=$(filter-out %.png, $^)) 83 | $(eval TGT:=$(addsuffix /best_epoch, $(FF)) $(addsuffix /*.npy, $(FF)) $(addsuffix /best_epoch.txt, $(FF)) $(addsuffix /metrics.csv, $(FF))) 84 | tar cf - $(PLTS) $(TGT) | pigz > $@ 85 | chmod -w $@ 86 | 87 | 88 | 89 | # Extraction and slicing 90 | data/ISLES/train/in_npy data/ISLES/val/in_npy: data/ISLES 91 | data/ISLES: data/isles/TRAINING 92 | $(info $(yellow)$(CC) $(CFLAGS) preprocess/slice_isles.py$(reset)) 93 | rm -rf $@_tmp $@ 94 | $(PP) $(CC) $(CFLAGS) preprocess/slice_isles.py --source_dir $< --dest_dir $@_tmp --n_augment=0 --retain=20 95 | mv $@_tmp $@ 96 | data/ISLES/test: data/isles/TESTING 97 | 98 | data/isles/TESTING data/isles/TRAINING: data/isles 99 | data/isles: data/ISLES.lineage data/ISLES2018_Training.zip data/ISLES2018_Testing.zip 100 | $(info $(yellow)unzip data/ISLES2018_Training.zip data/ISLES2018_Testing.zip$(reset)) 101 | md5sum -c $< 102 | rm -rf $@_tmp $@ 103 | unzip -q $(word 2, $^) -d $@_tmp 104 | unzip -q $(word 3, $^) -d $@_tmp 105 | rm -r $@_tmp/__MACOSX 106 | rm -r $@_tmp/*/*/*CT_4DPWI* # For space efficiency 107 | mv $@_tmp $@ 108 | 109 | 110 | # Training 111 | $(RD)/gdl: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1)]" 112 | $(RD)/gdl: data/ISLES/train/in_npy data/ISLES/val/in_npy 113 | $(RD)/gdl: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True)]" 114 | 115 | $(RD)/gdl_surface_w: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 116 | ('SurfaceLoss', {'idc': [1]}, 0.1)]" 117 | $(RD)/gdl_surface_w: data/ISLES/train/in_npy data/ISLES/val/in_npy 118 | $(RD)/gdl_surface_w: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 119 | ('gt_npy', dist_map_transform, False)]" 120 | 121 | $(RD)/gdl_hausdorff_w: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 122 | ('HausdorffLoss', {'idc': [1]}, 0.1)]" 123 | $(RD)/gdl_hausdorff_w:
data/ISLES/train/in_npy data/ISLES/val/in_npy 124 | $(RD)/gdl_hausdorff_w: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 125 | ('gt_npy', gt_transform, True)]" 126 | 127 | 128 | $(RD)/hausdorff: OPT = --losses="[('HausdorffLoss', {'idc': [1]}, 0.1)]" 129 | $(RD)/hausdorff: data/ISLES/train/in_npy data/ISLES/val/in_npy 130 | $(RD)/hausdorff: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True)]" 131 | 132 | 133 | $(RD)/gdl_surface_add: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 134 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 135 | $(RD)/gdl_surface_add: data/ISLES/train/in_npy data/ISLES/val/in_npy 136 | $(RD)/gdl_surface_add: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 137 | ('gt_npy', dist_map_transform, False)]" \ 138 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 139 | 140 | $(RD)/gdl_surface_steal: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 141 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 142 | $(RD)/gdl_surface_steal: data/ISLES/train/in_npy data/ISLES/val/in_npy 143 | $(RD)/gdl_surface_steal: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 144 | ('gt_npy', dist_map_transform, False)]" \ 145 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 146 | 147 | 148 | $(RD)/gdl_3d_surface_steal: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 149 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 150 | $(RD)/gdl_3d_surface_steal: data/ISLES/train/in_npy data/ISLES/val/in_npy 151 | $(RD)/gdl_3d_surface_steal: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 152 | ('3d_distmap', raw_npy_transform, False)]" \ 153 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 154 | 155 | $(RD)/surface: OPT = --losses="[('SurfaceLoss', {'idc': [1]}, 0.1)]" 156 | $(RD)/surface: data/ISLES/train/in_npy data/ISLES/val/in_npy 157 | $(RD)/surface: DATA = --folders="$(B_DATA)+[('gt_npy', dist_map_transform, False)]" 158 | 159 | 160 | $(RD)/%: 161 | $(info $(green)$(CC) $(CFLAGS) main.py $@$(reset)) 162 | rm -rf $@_tmp 163 | mkdir -p $@_tmp 164 | printenv > $@_tmp/env.txt 165 | git diff > $@_tmp/repo.diff 166 | git rev-parse --short HEAD > $@_tmp/commit_hash 167 | $(CC) $(CFLAGS) main.py --dataset=$(dir $( Tensor: 42 | assert simplex(probs) and simplex(target) 43 | 44 | log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log() 45 | mask: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32)) 46 | 47 | loss = - einsum("bkwh,bkwh->", mask, log_p) 48 | loss /= mask.sum() + 1e-10 49 | 50 | return loss 51 | 52 | 53 | class GeneralizedDice(): 54 | def __init__(self, **kwargs): 55 | # Self.idc is used to filter out some classes of the target mask. 
Use fancy indexing 56 | self.idc: List[int] = kwargs["idc"] 57 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 58 | 59 | def __call__(self, probs: Tensor, target: Tensor) -> Tensor: 60 | assert simplex(probs) and simplex(target) 61 | 62 | pc = probs[:, self.idc, ...].type(torch.float32) 63 | tc = target[:, self.idc, ...].type(torch.float32) 64 | 65 | w: Tensor = 1 / ((einsum("bkwh->bk", tc).type(torch.float32) + 1e-10) ** 2) 66 | intersection: Tensor = w * einsum("bkwh,bkwh->bk", pc, tc) 67 | union: Tensor = w * (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc)) 68 | 69 | divided: Tensor = 1 - 2 * (einsum("bk->b", intersection) + 1e-10) / (einsum("bk->b", union) + 1e-10) 70 | 71 | loss = divided.mean() 72 | 73 | return loss 74 | 75 | 76 | class DiceLoss(): 77 | def __init__(self, **kwargs): 78 | # Self.idc is used to filter out some classes of the target mask. Use fancy indexing 79 | self.idc: List[int] = kwargs["idc"] 80 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 81 | 82 | def __call__(self, probs: Tensor, target: Tensor) -> Tensor: 83 | assert simplex(probs) and simplex(target) 84 | 85 | pc = probs[:, self.idc, ...].type(torch.float32) 86 | tc = target[:, self.idc, ...].type(torch.float32) 87 | 88 | intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc) 89 | union: Tensor = (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc)) 90 | 91 | divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10) 92 | 93 | loss = divided.mean() 94 | 95 | return loss 96 | 97 | 98 | class SurfaceLoss(): 99 | def __init__(self, **kwargs): 100 | # Self.idc is used to filter out some classes of the target mask. Use fancy indexing 101 | self.idc: List[int] = kwargs["idc"] 102 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 103 | 104 | def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor: 105 | assert simplex(probs) 106 | assert not one_hot(dist_maps) 107 | 108 | pc = probs[:, self.idc, ...].type(torch.float32) 109 | dc = dist_maps[:, self.idc, ...].type(torch.float32) 110 | 111 | multipled = einsum("bkwh,bkwh->bkwh", pc, dc) 112 | 113 | loss = multipled.mean() 114 | 115 | return loss 116 | 117 | 118 | BoundaryLoss = SurfaceLoss 119 | 120 | 121 | class HausdorffLoss(): 122 | """ 123 | Implementation heavily inspired from https://github.com/JunMa11/SegWithDistMap 124 | """ 125 | def __init__(self, **kwargs): 126 | # Self.idc is used to filter out some classes of the target mask. 
Use fancy indexing 127 | self.idc: List[int] = kwargs["idc"] 128 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 129 | 130 | def __call__(self, probs: Tensor, target: Tensor) -> Tensor: 131 | assert simplex(probs) 132 | assert simplex(target) 133 | assert probs.shape == target.shape 134 | 135 | B, K, *xyz = probs.shape # type: ignore 136 | 137 | pc = cast(Tensor, probs[:, self.idc, ...].type(torch.float32)) 138 | tc = cast(Tensor, target[:, self.idc, ...].type(torch.float32)) 139 | assert pc.shape == tc.shape == (B, len(self.idc), *xyz) 140 | 141 | target_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(tc[b].cpu().detach().numpy()) 142 | for b in range(B)], axis=0) 143 | assert target_dm_npy.shape == tc.shape == pc.shape 144 | tdm: Tensor = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32) 145 | 146 | pred_segmentation: Tensor = probs2one_hot(probs).cpu().detach() 147 | pred_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(pred_segmentation[b, self.idc, ...].numpy()) 148 | for b in range(B)], axis=0) 149 | assert pred_dm_npy.shape == tc.shape == pc.shape 150 | pdm: Tensor = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32) 151 | 152 | delta = (pc - tc)**2 153 | dtm = tdm**2 + pdm**2 154 | 155 | multipled = einsum("bkwh,bkwh->bkwh", delta, dtm) 156 | 157 | loss = multipled.mean() 158 | 159 | return loss 160 | 161 | 162 | class FocalLoss(): 163 | def __init__(self, **kwargs): 164 | # Self.idc is used to filter out some classes of the target mask. Use fancy indexing 165 | self.idc: List[int] = kwargs["idc"] 166 | self.gamma: float = kwargs["gamma"] 167 | print(f"Initialized {self.__class__.__name__} with {kwargs}") 168 | 169 | def __call__(self, probs: Tensor, target: Tensor) -> Tensor: 170 | assert simplex(probs) and simplex(target) 171 | 172 | masked_probs: Tensor = probs[:, self.idc, ...] 173 | log_p: Tensor = (masked_probs + 1e-10).log() 174 | mask: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32)) 175 | 176 | w: Tensor = (1 - masked_probs)**self.gamma 177 | loss = - einsum("bkwh,bkwh,bkwh->", w, mask, log_p) 178 | loss /= mask.sum() + 1e-10 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import argparse 26 | import warnings 27 | from pathlib import Path 28 | from functools import reduce 29 | from operator import add, itemgetter 30 | from shutil import copytree, rmtree 31 | from typing import Any, Callable, Optional, Tuple, cast 32 | 33 | import torch 34 | import numpy as np 35 | import pandas as pd 36 | import torch.nn.functional as F 37 | from torch import Tensor 38 | from torch.utils.data import DataLoader 39 | 40 | from dataloader import get_loaders 41 | from utils import map_ 42 | from utils import depth 43 | from utils import probs2one_hot, probs2class 44 | from utils import dice_coef, save_images, tqdm_, dice_batch 45 | 46 | 47 | def setup(args, n_class: int) -> Tuple[Any, Any, Any, list[list[Callable]], list[list[float]], Callable]: 48 | print("\n>>> Setting up") 49 | cpu: bool = args.cpu or not torch.cuda.is_available() 50 | device = torch.device("cpu") if cpu else torch.device("cuda") 51 | 52 | if args.weights: 53 | if cpu: 54 | net = torch.load(args.weights, map_location='cpu') 55 | else: 56 | net = torch.load(args.weights) 57 | print(f">> Restored weights from {args.weights} successfully.") 58 | else: 59 | net_class = getattr(__import__('networks'), args.network) 60 | net = net_class(args.modalities, n_class).to(device) 61 | net.init_weights() 62 | net.to(device) 63 | 64 | optimizer: Any # disable an error for the optmizer (ADAM and SGD not same type) 65 | if args.use_sgd: 66 | optimizer = torch.optim.SGD(net.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4) 67 | else: 68 | optimizer = torch.optim.Adam(net.parameters(), lr=args.l_rate, betas=(0.9, 0.99), amsgrad=False) 69 | 70 | # print(args.losses) 71 | list_losses = eval(args.losses) 72 | if depth(list_losses) == 1: # For compatibility reasons, avoid changing all the previous configuration files 73 | list_losses = [list_losses] 74 | 75 | loss_fns: list[list[Callable]] = [] 76 | for i, losses in enumerate(list_losses): 77 | print(f">> {i}th list of losses: {losses}") 78 | tmp: list[Callable] = [] 79 | for loss_name, loss_params, _ in losses: 80 | loss_class = getattr(__import__('losses'), loss_name) 81 | tmp.append(loss_class(**loss_params)) 82 | loss_fns.append(tmp) 83 | 84 | loss_weights: list[list[float]] = [map_(itemgetter(2), losses) for losses in list_losses] 85 | 86 | scheduler = getattr(__import__('scheduler'), args.scheduler)(**eval(args.scheduler_params)) 87 | 88 | return net, optimizer, device, loss_fns, loss_weights, scheduler 89 | 90 | 91 | def do_epoch(mode: str, net: Any, device: Any, loaders: list[DataLoader], epc: int, 92 | list_loss_fns: list[list[Callable]], list_loss_weights: list[list[float]], K: int, 93 | savedir: str = "", optimizer: Any = None, 94 | metric_axis: list[int] = [1], 95 | compute_3d_dice: bool = False, 96 | temperature: float = 1) -> Tuple[Tensor, 97 | Tensor, 98 | Optional[Tensor]]: 99 | assert mode in ["train", "val", "dual"] 100 | 101 | if mode == "train": 102 | net.train() 103 | desc = f">> Training ({epc})" 104 | elif mode == "val": 105 | net.eval() 106 | desc = f">> Validation ({epc})" 107 | 108 | total_iteration: int = sum(len(loader) for loader in loaders) # U 109 | total_images: int = sum(len(loader.dataset) for loader in loaders) # D 110 | n_loss: int = 
max(map(len, list_loss_fns)) 111 | 112 | all_dices: Tensor = torch.zeros((total_images, K), dtype=torch.float32, device=device) 113 | loss_log: Tensor = torch.zeros((total_iteration, n_loss), dtype=torch.float32, device=device) 114 | 115 | three_d_dices: Optional[Tensor] 116 | if compute_3d_dice: 117 | three_d_dices = torch.zeros((total_iteration, K), dtype=torch.float32, device=device) 118 | else: 119 | three_d_dices = None 120 | 121 | done_img: int = 0 122 | done_batch: int = 0 123 | tq_iter = tqdm_(total=total_iteration, desc=desc) 124 | for i, (loader, loss_fns, loss_weights) in enumerate(zip(loaders, list_loss_fns, list_loss_weights)): 125 | for data in loader: 126 | # t0 = time() 127 | image: Tensor = data["images"].to(device) 128 | target: Tensor = data["gt"].to(device) 129 | filenames: list[str] = data["filenames"] 130 | assert not target.requires_grad 131 | labels: list[Tensor] = [e.to(device) for e in data["labels"]] 132 | B, C, *_ = image.shape 133 | 134 | # Reset gradients 135 | if optimizer: 136 | optimizer.zero_grad() 137 | 138 | # Forward 139 | pred_logits: Tensor = net(image) 140 | pred_probs: Tensor = F.softmax(temperature * pred_logits, dim=1) 141 | predicted_mask: Tensor = probs2one_hot(pred_probs.detach()) # Used only for dice computation 142 | assert not predicted_mask.requires_grad 143 | 144 | assert len(loss_fns) == len(loss_weights) == len(labels) 145 | ziped = zip(loss_fns, labels, loss_weights) 146 | losses = [w * loss_fn(pred_probs, label) 147 | for loss_fn, label, w in ziped] 148 | loss = reduce(add, losses) 149 | assert loss.shape == (), loss.shape 150 | 151 | # Backward 152 | if optimizer: 153 | loss.backward() 154 | optimizer.step() 155 | 156 | # Compute and log metrics 157 | for j in range(len(loss_fns)): 158 | loss_log[done_batch, j] = losses[j].detach() 159 | 160 | sm_slice = slice(done_img, done_img + B) # Values only for current batch 161 | 162 | dices: Tensor = dice_coef(predicted_mask, target) 163 | assert dices.shape == (B, K), (dices.shape, B, K) 164 | all_dices[sm_slice, ...] 
= dices 165 | 166 | if compute_3d_dice: 167 | three_d_DSC: Tensor = dice_batch(predicted_mask, target) 168 | assert three_d_DSC.shape == (K,) 169 | 170 | three_d_dices[done_batch] = three_d_DSC # type: ignore 171 | 172 | # Save images 173 | if savedir: 174 | with warnings.catch_warnings(): 175 | warnings.filterwarnings("ignore", category=UserWarning) 176 | predicted_class: Tensor = probs2class(pred_probs) 177 | save_images(predicted_class, filenames, savedir, mode, epc) 178 | 179 | # Logging 180 | big_slice = slice(0, done_img + B) # Value for current and previous batches 181 | 182 | dsc_dict: dict = {f"DSC{n}": all_dices[big_slice, n].mean() for n in metric_axis} | \ 183 | ({f"3d_DSC{n}": three_d_dices[:done_batch, n].mean() for n in metric_axis} 184 | if three_d_dices is not None else {}) 185 | 186 | loss_dict = {f"loss_{i}": loss_log[:done_batch].mean(dim=0)[i] for i in range(n_loss)} 187 | 188 | stat_dict = dsc_dict | loss_dict 189 | nice_dict = {k: f"{v:.3f}" for (k, v) in stat_dict.items()} 190 | 191 | done_img += B 192 | done_batch += 1 193 | tq_iter.set_postfix({**nice_dict, "loader": str(i)}) 194 | tq_iter.update(1) 195 | tq_iter.close() 196 | 197 | print(f"{desc} " + ', '.join(f"{k}={v}" for (k, v) in nice_dict.items())) 198 | 199 | return (loss_log.detach().cpu(), 200 | all_dices.detach().cpu(), 201 | three_d_dices.detach().cpu() if three_d_dices is not None else None) 202 | 203 | 204 | def run(args: argparse.Namespace) -> dict[str, Tensor]: 205 | n_class: int = args.n_class 206 | lr: float = args.l_rate 207 | savedir: str = args.workdir 208 | n_epoch: int = args.n_epoch 209 | val_f: int = args.val_loader_id 210 | 211 | loss_fns: list[list[Callable]] 212 | loss_weights: list[list[float]] 213 | net, optimizer, device, loss_fns, loss_weights, scheduler = setup(args, n_class) 214 | train_loaders: list[DataLoader] 215 | val_loaders: list[DataLoader] 216 | train_loaders, val_loaders = get_loaders(args, args.dataset, 217 | args.batch_size, n_class, 218 | args.debug, args.in_memory, args.dimensions, args.use_spacing) 219 | 220 | n_tra: int = sum(len(tr_lo.dataset) for tr_lo in train_loaders) # Number of images in dataset 221 | l_tra: int = sum(len(tr_lo) for tr_lo in train_loaders) # Number of iteration per epc: different if batch_size > 1 222 | n_val: int = sum(len(vl_lo.dataset) for vl_lo in val_loaders) 223 | l_val: int = sum(len(vl_lo) for vl_lo in val_loaders) 224 | n_loss: int = max(map(len, loss_fns)) 225 | 226 | best_dice: Tensor = cast(Tensor, torch.zeros(1).type(torch.float32)) 227 | best_epoch: int = 0 228 | metrics: dict[str, Tensor] = {"val_dice": torch.zeros((n_epoch, n_val, n_class)).type(torch.float32), 229 | "val_loss": torch.zeros((n_epoch, l_val, len(loss_fns[val_f]))).type(torch.float32), 230 | "tra_dice": torch.zeros((n_epoch, n_tra, n_class)).type(torch.float32), 231 | "tra_loss": torch.zeros((n_epoch, l_tra, n_loss)).type(torch.float32)} 232 | if args.compute_3d_dice: 233 | metrics["val_3d_dsc"] = cast(Tensor, torch.zeros((n_epoch, l_val, n_class)).type(torch.float32)) 234 | 235 | print("\n>>> Starting the training") 236 | for i in range(n_epoch): 237 | # Do training and validation loops 238 | tra_loss, tra_dice, _ = do_epoch("train", net, device, train_loaders, i, 239 | loss_fns, loss_weights, n_class, 240 | savedir=savedir if args.save_train else "", 241 | optimizer=optimizer, 242 | metric_axis=args.metric_axis, 243 | temperature=args.temperature) 244 | with torch.no_grad(): 245 | val_res = do_epoch("val", net, device, val_loaders, i, 246 | [loss_fns[val_f]], 
247 | [loss_weights[val_f]], 248 | n_class, 249 | savedir=savedir, 250 | metric_axis=args.metric_axis, 251 | compute_3d_dice=args.compute_3d_dice, 252 | temperature=args.temperature) 253 | val_loss, val_dice, val_3d_dsc = val_res 254 | 255 | # Sort and save the metrics 256 | for k in metrics: 257 | assert metrics[k][i].shape == eval(k).shape, (metrics[k][i].shape, eval(k).shape, k) 258 | metrics[k][i] = eval(k) 259 | 260 | for k, e in metrics.items(): 261 | np.save(Path(savedir, f"{k}.npy"), e.cpu().numpy()) 262 | 263 | df = pd.DataFrame({"tra_loss": metrics["tra_loss"].mean(dim=(1, 2)).numpy(), 264 | "val_loss": metrics["val_loss"].mean(dim=(1, 2)).numpy(), 265 | "tra_dice": metrics["tra_dice"][:, :, -1].mean(dim=1).numpy(), 266 | "val_dice": metrics["val_dice"][:, :, -1].mean(dim=1).numpy()}) 267 | df.to_csv(Path(savedir, args.csv), float_format="%.4f", index_label="epoch") 268 | 269 | # Save model if better 270 | current_dice: Tensor = val_dice[:, args.metric_axis].mean() 271 | if current_dice > best_dice: 272 | best_epoch = i 273 | best_dice = current_dice 274 | 275 | with open(Path(savedir, "best_epoch.txt"), 'w') as f: 276 | f.write(str(i)) 277 | best_folder = Path(savedir, "best_epoch") 278 | if best_folder.exists(): 279 | rmtree(best_folder) 280 | copytree(Path(savedir, f"iter{i:03d}"), Path(best_folder)) 281 | torch.save(net, Path(savedir, "best.pkl")) 282 | 283 | optimizer, loss_fns, loss_weights = scheduler(i, optimizer, loss_fns, loss_weights) 284 | 285 | # if args.schedule and (i > (best_epoch + 20)): 286 | if args.schedule and (i % (best_epoch + 20) == 0): # Yeah, ugly but will clean that later 287 | for param_group in optimizer.param_groups: 288 | lr *= 0.5 289 | param_group['lr'] = lr 290 | print(f'>> New learning Rate: {lr}') 291 | 292 | if i > 0 and not (i % 5): 293 | maybe_3d = ', 3d_DSC: {best_3d_dsc:.3f}' if args.compute_3d_dice else '' 294 | print(f">> Best results at epoch {best_epoch}: DSC: {best_dice:.3f}{maybe_3d}") 295 | 296 | # Because displaying the results at the end is actually convenient 297 | maybe_3d = ', 3d_DSC: {best_3d_dsc:.3f}' if args.compute_3d_dice else '' 298 | print(f">> Best results at epoch {best_epoch}: DSC: {best_dice:.3f}{maybe_3d}") 299 | for metric in metrics: 300 | # Do not care about training values, nor the loss (keep it simple) 301 | if "val" in metric and "loss" not in metric: 302 | print(f"\t{metric}: {metrics[metric][best_epoch].mean(dim=0)}") 303 | 304 | return metrics 305 | 306 | 307 | def get_args() -> argparse.Namespace: 308 | parser = argparse.ArgumentParser(description='Hyperparams') 309 | parser.add_argument('--dataset', type=str, required=True) 310 | parser.add_argument("--csv", type=str, required=True) 311 | parser.add_argument("--workdir", type=str, required=True) 312 | parser.add_argument("--losses", type=str, required=True, 313 | help="List of list of (loss_name, loss_params, bounds_name, bounds_params, fn, weight)") 314 | parser.add_argument("--folders", type=str, required=True, 315 | help="List of list of (subfolder, transform, is_hot)") 316 | parser.add_argument("--network", type=str, required=True, help="The network to use") 317 | parser.add_argument("--n_class", type=int, required=True) 318 | parser.add_argument("--metric_axis", type=int, nargs='*', required=True, help="Classes to display metrics. 
\ 319 | Display only the average of everything if empty") 320 | 321 | parser.add_argument("--debug", action="store_true") 322 | parser.add_argument("--cpu", action='store_true') 323 | parser.add_argument("--in_memory", action='store_true') 324 | parser.add_argument("--schedule", action='store_true') 325 | parser.add_argument("--use_sgd", action='store_true') 326 | parser.add_argument("--compute_3d_dice", action='store_true') 327 | parser.add_argument("--save_train", action='store_true') 328 | parser.add_argument("--use_spacing", action='store_true') 329 | parser.add_argument("--no_assert_dataloader", action='store_true') 330 | parser.add_argument("--ignore_norm_dataloader", action='store_true') 331 | parser.add_argument("--group", action='store_true', help="Group the patient slices together for validation. \ 332 | Useful to compute the 3d dice, but might destroy the memory for datasets with a lot of slices per patient.") 333 | parser.add_argument("--group_train", action='store_true', help="Group the patient slices together for training. \ 334 | Useful to compute the 3d dice, but might destroy the memory for datasets with a lot of slices per patient.") 335 | 336 | parser.add_argument('--n_epoch', nargs='?', type=int, default=200, 337 | help='# of the epochs') 338 | parser.add_argument('--l_rate', nargs='?', type=float, default=5e-4, 339 | help='Learning Rate') 340 | parser.add_argument("--grp_regex", type=str, default=None) 341 | parser.add_argument('--temperature', type=float, default=1, help="Temperature for the softmax") 342 | parser.add_argument("--scheduler", type=str, default="DummyScheduler") 343 | parser.add_argument("--scheduler_params", type=str, default="{}") 344 | parser.add_argument("--modalities", type=int, default=1) 345 | parser.add_argument("--dimensions", type=int, default=2) 346 | parser.add_argument('--batch_size', type=int, default=1) 347 | parser.add_argument("--weights", type=str, default='', help="Stored weights to restore") 348 | parser.add_argument("--training_folders", type=str, nargs="+", default=["train"]) 349 | parser.add_argument("--validation_folder", type=str, default="val") 350 | parser.add_argument("--val_loader_id", type=int, default=-1, help=""" 351 | Kinda hackish at the moment. When we have several train loaders (for hybrid training 352 | for instance), we want only one validation loader. The way the dataloader creation is 353 | written at the moment, it would create several validation loaders on the same topfolder (val), 354 | but with different folders/bounds; which would basically duplicate the evaluation.
355 | """) 356 | 357 | args = parser.parse_args() 358 | if args.metric_axis == []: 359 | args.metric_axis = list(range(args.n_class)) 360 | print("\n", args) 361 | 362 | return args 363 | 364 | 365 | if __name__ == '__main__': 366 | run(get_args()) 367 | -------------------------------------------------------------------------------- /metrics_overtime.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.8 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import re 26 | import pickle 27 | import argparse 28 | from pathlib import Path 29 | from functools import partial 30 | from multiprocessing import cpu_count, Pool 31 | from typing import Dict, List, Match, Optional, Pattern, Tuple 32 | 33 | import torch 34 | import numpy as np 35 | from tqdm import tqdm 36 | from torch import Tensor, einsum 37 | from torch.utils.data import DataLoader 38 | from medpy.metric.binary import hd, hd95 39 | 40 | from utils import map_, starmmap_ 41 | from utils import dice_batch, hausdorff 42 | from dataloader import SliceDataset, PatientSampler, custom_collate 43 | from dataloader import png_transform, gt_transform, dist_map_transform 44 | 45 | 46 | def get_args() -> argparse.Namespace: 47 | parser = argparse.ArgumentParser(description='Compute metrics over time on saved predictions') 48 | parser.add_argument('--basefolder', type=str, required=True, help="The folder containing the predicted epochs") 49 | parser.add_argument('--gt_folder', type=str) 50 | parser.add_argument('--spacing', type=str, default='') 51 | parser.add_argument('--metrics', type=str, nargs='+', required=True, 52 | choices=['3d_dsc', '3d_hausdorff', '3d_hd95', 'hausdorff', 'boundary']) 53 | parser.add_argument("--grp_regex", type=str, required=True) 54 | parser.add_argument("--resolution_regex", type=str, default=None) 55 | parser.add_argument('--num_classes', type=int, required=True) 56 | parser.add_argument("--debug", action="store_true", help="Dummy for compatibility") 57 | parser.add_argument("--cpu", action="store_true") 58 | 59 | parser.add_argument("--n_epoch", type=int, default=-1) 60 | args = parser.parse_args() 61 | 62 | print(args) 63 | 64 | return args 65 | 66 | 67 | def main() -> None: 68 | args = get_args() 69 | 70 | cpu: bool = args.cpu or not torch.cuda.is_available() 71 | device = torch.device("cpu") if cpu else 
torch.device("cuda") 72 | 73 | base_path: Path = Path(args.basefolder) 74 | 75 | iterations_paths: List[Path] = sorted(base_path.glob("iter*")) 76 | # print(iterations_paths) 77 | print(f">>> Found {len(iterations_paths)} epoch folders") 78 | 79 | # Handle gracefully if not all folders are there (early stop) 80 | EPC: int = args.n_epoch if args.n_epoch >= 0 else len(iterations_paths) 81 | K: int = args.num_classes 82 | 83 | # Get the patient number, and image names, from the GT folder 84 | gt_path: Path = Path(args.gt_folder) 85 | names: List[str] = map_(lambda p: str(p.name), gt_path.glob("*")) 86 | n_img: int = len(names) 87 | 88 | resolution_regex: Pattern = re.compile(args.resolution_regex if args.resolution_regex else args.grp_regex) 89 | spacing_dict: Dict[str, Tuple[float, float, float]] 90 | spacing_dict = pickle.load(open(args.spacing, 'rb')) if args.spacing else None 91 | 92 | grouping_regex: Pattern = re.compile(args.grp_regex) 93 | stems: List[str] = [Path(filename).stem for filename in names] # avoid matching the extension 94 | matches: List[Match] = map_(grouping_regex.match, stems) # type: ignore 95 | patients: List[str] = [match.group(1) for match in matches] 96 | 97 | unique_patients: List[str] = list(set(patients)) 98 | n_patients: int = len(unique_patients) 99 | 100 | print(f">>> Found {len(unique_patients)} unique patients out of {n_img} images ; regex: {args.grp_regex}") 101 | # from pprint import pprint 102 | # pprint(unique_patients) 103 | 104 | # First, quickly assert all folders have the same number of predicted images 105 | n_img_epoc: List[int] = [len(list((p / "val").glob("*.png"))) for p in iterations_paths] 106 | assert len(set(n_img_epoc)) == 1 107 | assert all(len(list((p / "val").glob("*.png"))) == n_img for p in iterations_paths) 108 | 109 | metrics: Dict[str, Tensor] = {} 110 | if '3d_dsc' in args.metrics: 111 | metrics['3d_dsc'] = torch.zeros((EPC, n_patients, K), dtype=torch.float32) 112 | print(f">> Will compute {'3d_dsc'} metric") 113 | if '3d_hausdorff' in args.metrics: 114 | metrics['3d_hausdorff'] = torch.zeros((EPC, n_patients, K), dtype=torch.float32) 115 | print(f">> Will compute {'3d_hausdorff'} metric") 116 | if '3d_hd95' in args.metrics: 117 | metrics['3d_hd95'] = torch.zeros((EPC, n_patients, K), dtype=torch.float32) 118 | print(f">> Will compute {'3d_hd95'} metric") 119 | if 'hausdorff' in args.metrics: 120 | metrics['hausdorff'] = torch.zeros((EPC, n_img, K), dtype=torch.float32) 121 | print(f">> Will compute {'hausdorff'} metric") 122 | if 'boundary' in args.metrics: 123 | metrics['boundary'] = torch.zeros((EPC, n_img, K), dtype=torch.float32) 124 | print(f">> Will compute {'boundary'} metric") 125 | 126 | gen_dataset = partial(SliceDataset, 127 | transforms=[png_transform, gt_transform, gt_transform, dist_map_transform], 128 | are_hots=[False, True, True, False], 129 | K=K, 130 | in_memory=False, 131 | dimensions=2) 132 | data_loader = partial(DataLoader, 133 | num_workers=cpu_count(), 134 | pin_memory=False, 135 | collate_fn=custom_collate) 136 | 137 | # Will replace dataset.folders live and call load_images again to update dataset.files 138 | print(gt_path, gt_path, Path(iterations_paths[0], 'val')) 139 | dataset: SliceDataset = gen_dataset(names, [gt_path, gt_path, Path(iterations_paths[0], 'val'), gt_path]) 140 | sampler: PatientSampler = PatientSampler(dataset, args.grp_regex, shuffle=False) 141 | dataloader: DataLoader = data_loader(dataset, batch_sampler=sampler) 142 | 143 | current_path: Path 144 | for e, current_path in
enumerate(iterations_paths): 145 | pool = Pool() 146 | dataset.folders = [gt_path, gt_path, Path(current_path, 'val'), gt_path] 147 | dataset.files = SliceDataset.load_images(dataset.folders, dataset.filenames, False) 148 | 149 | print(f">>> Doing epoch {str(current_path)}") 150 | 151 | done_img: int = 0 152 | for i, data in enumerate(tqdm(dataloader, leave=None)): 153 | target: Tensor = data["gt"] 154 | prediction: Tensor = data["labels"][0] 155 | B, *_ = target.shape 156 | # slice_names: Tensor = data['filenames'] 157 | 158 | # assert len(slice_names) == target.shape[0] 159 | # print(slice_names) 160 | 161 | if (match := resolution_regex.match(data['filenames'][0])): 162 | pid: str = match.group(1) 163 | else: 164 | raise ValueError 165 | 166 | voxelspacing: Optional[Tuple[float, float, float]] 167 | if spacing_dict: 168 | voxelspacing = spacing_dict[pid] 169 | # Need to go from (dx, dy, dz) to (dz, dx, dy) (z is on the batch axis now) 170 | voxelspacing = (voxelspacing[2], voxelspacing[0], voxelspacing[1]) 171 | assert len(voxelspacing) == 3 172 | else: 173 | voxelspacing = None 174 | # print(f"{pid=} {voxelspacing=}") 175 | 176 | assert target.shape == prediction.shape 177 | 178 | if 'hausdorff' in args.metrics: 179 | hausdorff_res: Tensor = hausdorff(prediction, target, 180 | data["spacings"]) 181 | assert hausdorff_res.shape == (B, K) 182 | metrics['hausdorff'][e, done_img:done_img + B, ...] = hausdorff_res[...] 183 | 184 | if 'boundary' in args.metrics: 185 | distmap: Tensor = data["labels"][1] 186 | bd: Tensor = einsum("bkwh,bkwh->bk", prediction.type(torch.float32), distmap) 187 | 188 | metrics['boundary'][e, done_img:done_img + B, ...] = bd 189 | 190 | if '3d_dsc' in args.metrics: 191 | dsc: Tensor = dice_batch(target.to(device), prediction.to(device)) 192 | assert dsc.shape == (K,) 193 | 194 | metrics['3d_dsc'][e, i, :] = dsc.cpu() 195 | 196 | np_pred: np.ndarray 197 | np_target: np.ndarray 198 | if '3d_hausdorff' in args.metrics or '3d_hd95' in args.metrics: 199 | np_pred = prediction.numpy().astype(np.uint8) 200 | np_target = target.numpy().astype(np.uint8) 201 | 202 | list_float: List[float] 203 | if '3d_hausdorff' in args.metrics: 204 | def cb_1(r): 205 | metrics["3d_hausdorff"][e, i, 1:] = torch.tensor(r) 206 | pool.starmap_async(partial(get_hd_thing, 207 | fn=hd, 208 | voxelspacing=voxelspacing), 209 | ((np_pred[:, k, :, :], np_target[:, k, :, :]) 210 | for k in range(1, K)), 211 | callback=cb_1) 212 | if '3d_hd95' in args.metrics: 213 | def cb_2(r): 214 | metrics["3d_hd95"][e, i, 1:] = torch.tensor(r) 215 | pool.starmap_async(partial(get_hd_thing, 216 | fn=hd95, 217 | voxelspacing=voxelspacing), 218 | ((np_pred[:, k, :, :], np_target[:, k, :, :]) 219 | for k in range(1, K)), 220 | callback=cb_2) 221 | 222 | pool.close() 223 | pool.join() 224 | 225 | for metric in args.metrics: 226 | # For now, hardcode the fact we care about class 1 only 227 | print(f">> {metric}: {metrics[metric][e].mean(dim=0)[1]:.04f}") 228 | 229 | key: str 230 | el: Tensor 231 | for key, el in metrics.items(): 232 | np.save(Path(args.basefolder, f"val_{key}.npy"), el.cpu().numpy()) 233 | 234 | 235 | def get_hd_thing(np_pred: np.ndarray, np_target: np.ndarray, fn, voxelspacing): 236 | hd_thing: float 237 | if np_pred.sum() > 0: 238 | hd_thing = fn(np_pred, np_target, voxelspacing=voxelspacing) 239 | else: 240 | x, y, z = np_pred.shape 241 | dx, dy, dz = voxelspacing if voxelspacing else (1, 1, 1) 242 | hd_thing = ((dx * x)**2 + (dy * y)**2 + (dz * z)**2)**0.5 243 | 244 | return hd_thing 245 | 246 | 247 | if __name__
== '__main__': 248 | main() 249 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LIVIAETS/boundary-loss/171c32d88a4ce59af8be46fb88b96d3637b9515b/models/__init__.py -------------------------------------------------------------------------------- /models/enet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.7 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | 30 | def conv_block_1(in_dim, out_dim): 31 | model = nn.Sequential( 32 | nn.Conv2d(in_dim, out_dim, kernel_size=1), 33 | nn.BatchNorm2d(out_dim), 34 | nn.PReLU(), 35 | ) 36 | return model 37 | 38 | 39 | def conv_block_3_3(in_dim, out_dim): 40 | model = nn.Sequential( 41 | nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(out_dim), 43 | nn.PReLU(), 44 | ) 45 | return model 46 | 47 | 48 | def conv_block_Asym(in_dim, out_dim, kernelSize): 49 | model = nn.Sequential( 50 | nn.Conv2d(in_dim, out_dim, kernel_size=[kernelSize, 1], padding=tuple([2, 0])), 51 | nn.Conv2d(out_dim, out_dim, kernel_size=[1, kernelSize], padding=tuple([0, 2])), 52 | nn.BatchNorm2d(out_dim), 53 | nn.PReLU(), 54 | ) 55 | return model 56 | 57 | 58 | def convBatch(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv2d, dilation=1): 59 | return nn.Sequential( 60 | layer(nin, nout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation), 61 | nn.BatchNorm2d(nout), 62 | nn.PReLU() 63 | ) 64 | 65 | 66 | def upSampleConv(nin, nout, kernel_size=3, upscale=2, padding=1, bias=False): 67 | return nn.Sequential( 68 | # nn.Upsample(scale_factor=upscale), 69 | interpolate(mode='nearest', scale_factor=upscale), 70 | convBatch(nin, nout, kernel_size=kernel_size, stride=1, padding=padding, bias=bias), 71 | convBatch(nout, nout, kernel_size=3, stride=1, padding=1, bias=bias), 72 | ) 73 | 74 | 75 | class interpolate(nn.Module): 76 | def __init__(self, scale_factor, mode='nearest'): 77 | super().__init__() 78 | 79 | self.scale_factor = scale_factor 80 | self.mode = mode 81 | 82 | def forward(self, cin): 83 | return F.interpolate(cin, mode=self.mode, 
scale_factor=self.scale_factor) 84 | 85 | 86 | class BottleNeckDownSampling(nn.Module): 87 | def __init__(self, in_dim, projectionFactor, out_dim): 88 | super().__init__() 89 | 90 | # Main branch 91 | self.maxpool0 = nn.MaxPool2d(2, return_indices=True) 92 | # Secondary branch 93 | self.conv0 = nn.Conv2d(in_dim, int(in_dim / projectionFactor), kernel_size=2, stride=2) 94 | self.bn0 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 95 | self.PReLU0 = nn.PReLU() 96 | 97 | self.conv1 = nn.Conv2d(int(in_dim / projectionFactor), int(in_dim / projectionFactor), kernel_size=3, padding=1) 98 | self.bn1 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 99 | self.PReLU1 = nn.PReLU() 100 | 101 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 102 | 103 | self.do = nn.Dropout(p=0.01) 104 | self.PReLU3 = nn.PReLU() 105 | 106 | def forward(self, input): 107 | # Main branch 108 | maxpool_output, indices = self.maxpool0(input) 109 | 110 | # Secondary branch 111 | c0 = self.conv0(input) 112 | b0 = self.bn0(c0) 113 | p0 = self.PReLU0(b0) 114 | 115 | c1 = self.conv1(p0) 116 | b1 = self.bn1(c1) 117 | p1 = self.PReLU1(b1) 118 | 119 | p2 = self.block2(p1) 120 | 121 | do = self.do(p2) 122 | 123 | # Zero padding the feature maps from the main branch 124 | depth_to_pad = abs(maxpool_output.shape[1] - do.shape[1]) 125 | padding = torch.zeros(maxpool_output.shape[0], depth_to_pad, maxpool_output.shape[2], 126 | maxpool_output.shape[3], device=maxpool_output.device) 127 | maxpool_output_pad = torch.cat((maxpool_output, padding), 1) 128 | output = maxpool_output_pad + do 129 | 130 | # _, c, _, _ = maxpool_output.shape 131 | # output = do 132 | # output[:, :c, :, :] += maxpool_output 133 | 134 | final_output = self.PReLU3(output) 135 | 136 | return final_output, indices 137 | 138 | 139 | class BottleNeckNormal(nn.Module): 140 | def __init__(self, in_dim, out_dim, projectionFactor, dropoutRate): 141 | super(BottleNeckNormal, self).__init__() 142 | self.in_dim = in_dim 143 | self.out_dim = out_dim 144 | # Main branch 145 | 146 | # Secondary branch 147 | self.block0 = conv_block_1(in_dim, int(in_dim / projectionFactor)) 148 | self.block1 = conv_block_3_3(int(in_dim / projectionFactor), int(in_dim / projectionFactor)) 149 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 150 | 151 | self.do = nn.Dropout(p=dropoutRate) 152 | self.PReLU_out = nn.PReLU() 153 | 154 | if in_dim > out_dim: 155 | self.conv_out = conv_block_1(in_dim, out_dim) 156 | 157 | def forward(self, input): 158 | # Main branch 159 | # Secondary branch 160 | b0 = self.block0(input) 161 | b1 = self.block1(b0) 162 | b2 = self.block2(b1) 163 | do = self.do(b2) 164 | 165 | if self.in_dim > self.out_dim: 166 | output = self.conv_out(input) + do 167 | else: 168 | output = input + do 169 | output = self.PReLU_out(output) 170 | 171 | return output 172 | 173 | 174 | class BottleNeckDownSamplingDilatedConv(nn.Module): 175 | def __init__(self, in_dim, projectionFactor, out_dim, dilation): 176 | super(BottleNeckDownSamplingDilatedConv, self).__init__() 177 | # Main branch 178 | 179 | # Secondary branch 180 | self.block0 = conv_block_1(in_dim, int(in_dim / projectionFactor)) 181 | 182 | self.conv1 = nn.Conv2d(int(in_dim / projectionFactor), int(in_dim / projectionFactor), kernel_size=3, 183 | padding=dilation, dilation=dilation) 184 | self.bn1 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 185 | self.PReLU1 = nn.PReLU() 186 | 187 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 188 | 189 | self.do = 
nn.Dropout(p=0.01) 190 | self.PReLU3 = nn.PReLU() 191 | 192 | def forward(self, input): 193 | # Secondary branch 194 | b0 = self.block0(input) 195 | 196 | c1 = self.conv1(b0) 197 | b1 = self.bn1(c1) 198 | p1 = self.PReLU1(b1) 199 | 200 | b2 = self.block2(p1) 201 | 202 | do = self.do(b2) 203 | 204 | output = input + do 205 | output = self.PReLU3(output) 206 | 207 | return output 208 | 209 | 210 | class BottleNeckNormal_Asym(nn.Module): 211 | def __init__(self, in_dim, out_dim, projectionFactor, dropoutRate): 212 | super(BottleNeckNormal_Asym, self).__init__() 213 | self.in_dim = in_dim 214 | self.out_dim = out_dim 215 | # Main branch 216 | 217 | # Secondary branch 218 | self.block0 = conv_block_1(in_dim, int(in_dim / projectionFactor)) 219 | self.block1 = conv_block_Asym(int(in_dim / projectionFactor), int(in_dim / projectionFactor), 5) 220 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 221 | 222 | self.do = nn.Dropout(p=dropoutRate) 223 | self.PReLU_out = nn.PReLU() 224 | 225 | if in_dim > out_dim: 226 | self.conv_out = conv_block_1(in_dim, out_dim) 227 | 228 | def forward(self, input): 229 | # Main branch 230 | # Secondary branch 231 | b0 = self.block0(input) 232 | b1 = self.block1(b0) 233 | b2 = self.block2(b1) 234 | do = self.do(b2) 235 | 236 | if self.in_dim > self.out_dim: 237 | output = self.conv_out(input) + do 238 | else: 239 | output = input + do 240 | output = self.PReLU_out(output) 241 | 242 | return output 243 | 244 | 245 | class BottleNeckDownSamplingDilatedConvLast(nn.Module): 246 | def __init__(self, in_dim, projectionFactor, out_dim, dilation): 247 | super(BottleNeckDownSamplingDilatedConvLast, self).__init__() 248 | # Main branch 249 | 250 | # Secondary branch 251 | self.block0 = conv_block_1(in_dim, int(in_dim / projectionFactor)) 252 | 253 | self.conv1 = nn.Conv2d(int(in_dim / projectionFactor), int(in_dim / projectionFactor), kernel_size=3, 254 | padding=dilation, dilation=dilation) 255 | self.bn1 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 256 | self.PReLU1 = nn.PReLU() 257 | 258 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 259 | 260 | self.do = nn.Dropout(p=0.01) 261 | self.conv_out = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 262 | self.PReLU3 = nn.PReLU() 263 | 264 | def forward(self, input): 265 | 266 | # Secondary branch 267 | b0 = self.block0(input) 268 | 269 | c1 = self.conv1(b0) 270 | b1 = self.bn1(c1) 271 | p1 = self.PReLU1(b1) 272 | 273 | b2 = self.block2(p1) 274 | 275 | do = self.do(b2) 276 | 277 | output = self.conv_out(input) + do 278 | output = self.PReLU3(output) 279 | 280 | return output 281 | 282 | 283 | class BottleNeckUpSampling(nn.Module): 284 | def __init__(self, in_dim, projectionFactor, out_dim): 285 | super(BottleNeckUpSampling, self).__init__() 286 | # Main branch 287 | self.conv0 = nn.Conv2d(in_dim, int(in_dim / projectionFactor), kernel_size=3, padding=1) 288 | self.bn0 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 289 | self.PReLU0 = nn.PReLU() 290 | 291 | self.conv1 = nn.Conv2d(int(in_dim / projectionFactor), int(in_dim / projectionFactor), kernel_size=3, padding=1) 292 | self.bn1 = nn.BatchNorm2d(int(in_dim / projectionFactor)) 293 | self.PReLU1 = nn.PReLU() 294 | 295 | self.block2 = conv_block_1(int(in_dim / projectionFactor), out_dim) 296 | 297 | self.do = nn.Dropout(p=0.01) 298 | self.PReLU3 = nn.PReLU() 299 | 300 | def forward(self, input): 301 | # Secondary branch 302 | c0 = self.conv0(input) 303 | b0 = self.bn0(c0) 304 | p0 = self.PReLU0(b0) 305 | 306 | c1 = 
self.conv1(p0) 307 | b1 = self.bn1(c1) 308 | p1 = self.PReLU1(b1) 309 | 310 | p2 = self.block2(p1) 311 | 312 | do = self.do(p2) 313 | 314 | return do -------------------------------------------------------------------------------- /models/residualunet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.8 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import torch.nn as nn 26 | 27 | 28 | def maxpool(): 29 | pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 30 | return pool 31 | 32 | 33 | def conv_block(in_dim, out_dim, act_fn, kernel_size=3, stride=1, padding=1, dilation=1): 34 | model = nn.Sequential( 35 | nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation), 36 | nn.BatchNorm2d(out_dim), 37 | act_fn, 38 | ) 39 | return model 40 | 41 | 42 | def conv_block_3(in_dim, out_dim, act_fn): 43 | model = nn.Sequential( 44 | conv_block(in_dim, out_dim, act_fn), 45 | conv_block(out_dim, out_dim, act_fn), 46 | nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1), 47 | nn.BatchNorm2d(out_dim), 48 | ) 49 | return model 50 | 51 | 52 | # TODO: Change order of block: BN + Activation + Conv 53 | def conv_decod_block(in_dim, out_dim, act_fn): 54 | model = nn.Sequential( 55 | nn.ConvTranspose2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1), 56 | nn.BatchNorm2d(out_dim), 57 | act_fn, 58 | ) 59 | return model 60 | 61 | 62 | class Conv_residual_conv(nn.Module): 63 | def __init__(self, in_dim, out_dim, act_fn): 64 | super().__init__() 65 | self.in_dim = in_dim 66 | self.out_dim = out_dim 67 | act_fn = act_fn 68 | 69 | self.conv_1 = conv_block(self.in_dim, self.out_dim, act_fn) 70 | self.conv_2 = conv_block_3(self.out_dim, self.out_dim, act_fn) 71 | self.conv_3 = conv_block(self.out_dim, self.out_dim, act_fn) 72 | 73 | def forward(self, input): 74 | conv_1 = self.conv_1(input) 75 | conv_2 = self.conv_2(conv_1) 76 | res = conv_1 + conv_2 77 | conv_3 = self.conv_3(res) 78 | return conv_3 79 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.7 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any 
person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | 29 | def convBatch(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv2d, dilation=1): 30 | return nn.Sequential( 31 | layer(nin, nout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation), 32 | nn.BatchNorm2d(nout), 33 | nn.PReLU() 34 | ) 35 | 36 | 37 | def upSampleConv(nin, nout, kernel_size=3, upscale=2, padding=1, bias=False): 38 | return nn.Sequential( 39 | # nn.Upsample(scale_factor=upscale), 40 | interpolate(mode='nearest', scale_factor=upscale), 41 | convBatch(nin, nout, kernel_size=kernel_size, stride=1, padding=padding, bias=bias), 42 | convBatch(nout, nout, kernel_size=3, stride=1, padding=1, bias=bias), 43 | ) 44 | 45 | 46 | class interpolate(nn.Module): 47 | def __init__(self, scale_factor, mode='nearest'): 48 | super().__init__() 49 | 50 | self.scale_factor = scale_factor 51 | self.mode = mode 52 | 53 | def forward(self, cin): 54 | return F.interpolate(cin, mode=self.mode, scale_factor=self.scale_factor) 55 | 56 | 57 | class residualConv(nn.Module): 58 | def __init__(self, nin, nout): 59 | super(residualConv, self).__init__() 60 | self.convs = nn.Sequential( 61 | convBatch(nin, nout), 62 | nn.Conv2d(nout, nout, kernel_size=3, stride=1, padding=1), 63 | nn.BatchNorm2d(nout) 64 | ) 65 | self.res = nn.Sequential() 66 | if nin != nout: 67 | self.res = nn.Sequential( 68 | nn.Conv2d(nin, nout, kernel_size=1, bias=False), 69 | nn.BatchNorm2d(nout) 70 | ) 71 | 72 | def forward(self, input): 73 | out = self.convs(input) 74 | return F.leaky_relu(out + self.res(input), 0.2) 75 | -------------------------------------------------------------------------------- /models/unet_3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.7 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in 
all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | 26 | 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | 31 | def convBatch(nin, nout, kernel_size=3, stride=1, padding=1, bias=False, layer=nn.Conv3d, dilation=1): 32 | return nn.Sequential( 33 | layer(nin, nout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias, dilation=dilation), 34 | nn.BatchNorm3d(nout), 35 | nn.PReLU() 36 | ) 37 | 38 | 39 | def upSampleConv(nin, nout, kernel_size=3, upscale=2, padding=1, bias=False): 40 | return nn.Sequential( 41 | # nn.Upsample(scale_factor=upscale), 42 | interpolate(mode='nearest', scale_factor=upscale), 43 | convBatch(nin, nout, kernel_size=kernel_size, stride=1, padding=padding, bias=bias), 44 | convBatch(nout, nout, kernel_size=3, stride=1, padding=1, bias=bias), 45 | ) 46 | 47 | 48 | class interpolate(nn.Module): 49 | def __init__(self, scale_factor, mode='nearest'): 50 | super().__init__() 51 | 52 | self.scale_factor = scale_factor 53 | self.mode = mode 54 | 55 | def forward(self, cin): 56 | return F.interpolate(cin, mode=self.mode, scale_factor=self.scale_factor) 57 | 58 | 59 | class residualConv(nn.Module): 60 | def __init__(self, nin, nout): 61 | super(residualConv, self).__init__() 62 | self.convs = nn.Sequential( 63 | convBatch(nin, nout), 64 | nn.Conv3d(nout, nout, kernel_size=3, stride=1, padding=1), 65 | nn.BatchNorm3d(nout) 66 | ) 67 | self.res = nn.Sequential() 68 | if nin != nout: 69 | self.res = nn.Sequential( 70 | nn.Conv3d(nin, nout, kernel_size=1, bias=False), 71 | nn.BatchNorm3d(nout) 72 | ) 73 | 74 | def forward(self, input): 75 | out = self.convs(input) 76 | return F.leaky_relu(out + self.res(input), 0.2) 77 | -------------------------------------------------------------------------------- /moustache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import argparse 26 | from typing import List 27 | from pathlib import Path 28 | 29 | import numpy as np 30 | import matplotlib 31 | matplotlib.use("Agg") 32 | import matplotlib.pyplot as plt 33 | 34 | from utils import map_ 35 | 36 | 37 | def run(args: argparse.Namespace) -> None: 38 | plt.rc('font', size=args.fontsize) 39 | # if len(args.columns) > 1: 40 | # raise NotImplementedError("Only 1 columns at a time is handled for now") 41 | 42 | paths: List[Path] = [Path(f, args.filename) for f in args.folders] 43 | arrays: List[np.ndarray] = map_(np.load, paths) 44 | metric_name: str = paths[0].stem 45 | 46 | assert len(set(a.shape for a in arrays)) == 1 # All arrays should have the same shape 47 | if len(arrays[0].shape) == 2: 48 | arrays = map_(lambda a: a[..., np.newaxis], arrays) # Add an extra dimension for column selection 49 | 50 | fig = plt.figure(figsize=args.figsize) 51 | ax = fig.gca() 52 | 53 | ymin, ymax = args.ylim # Tuple[int, int] 54 | ax.set_ylim(ymin, ymax) 55 | yrange: int = ymax - ymin 56 | ystep: float = yrange / 10 57 | ax.set_yticks(np.mgrid[ymin:ymax + ystep:ystep]) 58 | 59 | if not args.xlabel: 60 | ax.set_xlabel(metric_name) 61 | else: 62 | ax.set_xlabel(args.xlabel) 63 | 64 | if not args.ylabel: 65 | ax.set_ylabel("Percentage") 66 | else: 67 | ax.set_ylabel(args.ylabel) 68 | 69 | ax.grid(True, axis='y') 70 | if not args.title: 71 | ax.set_title(f"{metric_name} moustaches") 72 | else: 73 | ax.set_title(args.title) 74 | 75 | # bins = np.linspace(0, 1, args.nbins) 76 | pos = 0 77 | for i, (a, p) in enumerate(zip(arrays, paths)): 78 | for k in args.columns: 79 | mean_a = a[..., k].mean(axis=1) 80 | best_epoch: int = np.argmax(mean_a) 81 | 82 | # values = a[args.epc, :, k] 83 | values = a[best_epoch, :, k] 84 | 85 | ax.boxplot(values, positions=[pos + 1], manage_ticks=False, showmeans=True, meanline=True, whis=[5, 95]) 86 | print(f"{p.parent.stem:10}: min {values.min():.03f} 25{np.percentile(values, 25):.03f} " 87 | + f"avg {values.mean():.03f} 75 {np.percentile(values, 75):.03f} max {values.max():.03f} at epc {best_epoch}") 88 | 89 | pos += 1 90 | # ax.legend() 91 | 92 | if not args.labels: 93 | ax.set_xticklabels([""] + [f"{p.parent.stem}-{k}" for p in paths for k in range(len(args.columns))], 94 | rotation=60) 95 | else: 96 | if len(args.columns): 97 | ax.set_xticklabels([""] + [f"{l}-{k}" for l in args.labels for k in range(len(args.columns))], 98 | rotation=60) 99 | else: 100 | ax.set_xticklabels([""] + [f"{l}" for l in args.labels], 101 | rotation=60) 102 | 103 | ax.set_xticks(np.mgrid[0:len(args.folders) * len(args.columns) + 1]) 104 | 105 | fig.tight_layout() 106 | if args.savefig: 107 | fig.savefig(args.savefig) 108 | 109 | if not args.headless: 110 | plt.show() 111 | 112 | 113 | def get_args() -> argparse.Namespace: 114 | parser = argparse.ArgumentParser(description='Plot data over time') 115 | parser.add_argument('--folders', type=str, required=True, nargs='+', help="The folders containing the file") 116 | parser.add_argument('--filename', type=str, required=True) 117 | parser.add_argument('--columns', type=int, nargs='+', default=0, help="Which columns of the third axis to plot") 118 | parser.add_argument("--savefig", type=str, default=None) 119 | 
parser.add_argument("--headless", action="store_true") 120 | parser.add_argument("--nbins", type=int, default=100) 121 | parser.add_argument("--epc", type=int, required=True) 122 | 123 | parser.add_argument("--ylim", type=float, nargs=2, default=[0, 1]) 124 | 125 | parser.add_argument("--xlabel", type=str, default='') 126 | parser.add_argument("--ylabel", type=str, default='') 127 | parser.add_argument("--labels", type=str, nargs='*') 128 | parser.add_argument("--title", type=str, default=None) 129 | parser.add_argument("--loc", type=str, default=None) 130 | parser.add_argument("--figsize", type=int, nargs='*', default=[14, 9]) 131 | parser.add_argument("--fontsize", type=int, default=10, help="Dummy opt for compatibility") 132 | 133 | # Dummies 134 | parser.add_argument("--debug", action="store_true", help="Dummy for compatibility") 135 | parser.add_argument("--cpu", action="store_true", help="Dummy for compatibility") 136 | parser.add_argument("--save_csv", action="store_true", help="Dummy for compatibility") 137 | args = parser.parse_args() 138 | 139 | print(args) 140 | 141 | return args 142 | 143 | 144 | if __name__ == "__main__": 145 | run(get_args()) 146 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | import math 26 | 27 | import torch 28 | from torch import nn 29 | from torch import Tensor 30 | 31 | 32 | def random_weights_init(m): 33 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 34 | nn.init.xavier_normal_(m.weight.data) 35 | elif type(m) == nn.BatchNorm2d: 36 | m.weight.data.normal_(1.0, 0.02) 37 | m.bias.data.fill_(0) 38 | 39 | 40 | class Dummy(nn.Module): 41 | def __init__(self, in_dim: int, out_dim: int) -> None: 42 | super().__init__() 43 | 44 | self.down = nn.Conv2d(in_dim, 10, kernel_size=2, stride=2) 45 | self.up = nn.ConvTranspose2d(10, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1) 46 | 47 | def forward(self, input: Tensor) -> Tensor: 48 | return self.up(self.down(input)) 49 | 50 | def init_weights(self, *args, **kwargs): 51 | self.apply(random_weights_init) 52 | 53 | 54 | Dimwit = Dummy 55 | 56 | 57 | class Dummy3D(nn.Module): 58 | def __init__(self, in_dim: int, out_dim: int) -> None: 59 | super().__init__() 60 | 61 | self.down = nn.Conv3d(in_dim, 10, kernel_size=2, stride=2) 62 | self.up = nn.ConvTranspose3d(10, out_dim, kernel_size=3, stride=2, padding=1, output_padding=1) 63 | 64 | def forward(self, input: Tensor) -> Tensor: 65 | return self.up(self.down(input)) 66 | 67 | def init_weights(self, *args, **kwargs): 68 | self.apply(random_weights_init) 69 | 70 | 71 | Dimwit3D = Dummy3D 72 | 73 | 74 | class UNet(nn.Module): 75 | def __init__(self, in_dim: int, out_dim: int, nG=64) -> None: 76 | super().__init__() 77 | 78 | from models.unet import (convBatch, 79 | residualConv, 80 | upSampleConv) 81 | 82 | self.conv0 = nn.Sequential(convBatch(in_dim, nG), 83 | convBatch(nG, nG)) 84 | self.conv1 = nn.Sequential(convBatch(nG * 1, nG * 2, stride=2), 85 | convBatch(nG * 2, nG * 2)) 86 | self.conv2 = nn.Sequential(convBatch(nG * 2, nG * 4, stride=2), 87 | convBatch(nG * 4, nG * 4)) 88 | 89 | self.bridge = nn.Sequential(convBatch(nG * 4, nG * 8, stride=2), 90 | residualConv(nG * 8, nG * 8), 91 | convBatch(nG * 8, nG * 8)) 92 | 93 | self.deconv1 = upSampleConv(nG * 8, nG * 8) 94 | self.conv5 = nn.Sequential(convBatch(nG * 12, nG * 4), 95 | convBatch(nG * 4, nG * 4)) 96 | self.deconv2 = upSampleConv(nG * 4, nG * 4) 97 | self.conv6 = nn.Sequential(convBatch(nG * 6, nG * 2), 98 | convBatch(nG * 2, nG * 2)) 99 | self.deconv3 = upSampleConv(nG * 2, nG * 2) 100 | self.conv7 = nn.Sequential(convBatch(nG * 3, nG * 1), 101 | convBatch(nG * 1, nG * 1)) 102 | self.final = nn.Conv2d(nG, out_dim, kernel_size=1) 103 | 104 | def forward(self, input): 105 | input = input.float() 106 | x0 = self.conv0(input) 107 | x1 = self.conv1(x0) 108 | x2 = self.conv2(x1) 109 | 110 | bridge = self.bridge(x2) 111 | 112 | y0 = self.deconv1(bridge) 113 | y1 = self.deconv2(self.conv5(torch.cat((y0, x2), dim=1))) 114 | y2 = self.deconv3(self.conv6(torch.cat((y1, x1), dim=1))) 115 | y3 = self.conv7(torch.cat((y2, x0), dim=1)) 116 | 117 | return self.final(y3) 118 | 119 | def init_weights(self, *args, **kwargs): 120 | self.apply(random_weights_init) 121 | 122 | 123 | class UNet3D(nn.Module): 124 | def __init__(self, nin: int, nout: int, nG=64): 125 | super().__init__() 126 | 127 | from models.unet_3d import (convBatch, 128 | residualConv, 129 | upSampleConv) 130 | 131 | self.conv0 = nn.Sequential(convBatch(nin, nG), 132 | convBatch(nG, nG)) 133 | self.conv1 = nn.Sequential(convBatch(nG * 1, nG * 2, stride=2), 134 | convBatch(nG * 2, nG * 2)) 135 | self.conv2 = nn.Sequential(convBatch(nG * 2, nG * 4, stride=2), 136 | convBatch(nG * 4, nG * 4)) 137 | 138 | self.bridge = 
nn.Sequential(convBatch(nG * 4, nG * 8, stride=2), 139 | residualConv(nG * 8, nG * 8), 140 | convBatch(nG * 8, nG * 8)) 141 | 142 | self.deconv1 = upSampleConv(nG * 8, nG * 8) 143 | self.conv5 = nn.Sequential(convBatch(nG * 12, nG * 4), 144 | convBatch(nG * 4, nG * 4)) 145 | self.deconv2 = upSampleConv(nG * 4, nG * 4) 146 | self.conv6 = nn.Sequential(convBatch(nG * 6, nG * 2), 147 | convBatch(nG * 2, nG * 2)) 148 | self.deconv3 = upSampleConv(nG * 2, nG * 2) 149 | self.conv7 = nn.Sequential(convBatch(nG * 3, nG * 1), 150 | convBatch(nG * 1, nG * 1)) 151 | self.final = nn.Conv3d(nG, nout, kernel_size=1) 152 | 153 | def forward(self, input): 154 | input = input.float() 155 | x0 = self.conv0(input) 156 | x1 = self.conv1(x0) 157 | x2 = self.conv2(x1) 158 | 159 | bridge = self.bridge(x2) 160 | 161 | y0 = self.deconv1(bridge) 162 | # print(f"{x0.shape=} {x1.shape=}") 163 | # print(f"{y0.shape=} {x2.shape=}") 164 | y1 = self.deconv2(self.conv5(torch.cat((y0, x2), dim=1))) 165 | y2 = self.deconv3(self.conv6(torch.cat((y1, x1), dim=1))) 166 | y3 = self.conv7(torch.cat((y2, x0), dim=1)) 167 | 168 | return self.final(y3) 169 | 170 | def init_weights(self, *args, **kwargs): 171 | self.apply(random_weights_init) 172 | 173 | 174 | class ENet(nn.Module): 175 | def __init__(self, in_dim: int, out_dim: int): 176 | super().__init__() 177 | self.projecting_factor = 4 178 | self.n_kernels = 16 179 | 180 | from models.enet import (BottleNeckDownSampling, 181 | BottleNeckNormal, 182 | BottleNeckDownSamplingDilatedConv, 183 | BottleNeckNormal_Asym, 184 | BottleNeckDownSamplingDilatedConvLast, 185 | BottleNeckUpSampling, 186 | upSampleConv) 187 | 188 | # Initial 189 | self.conv0 = nn.Conv2d(in_dim, 15, kernel_size=3, stride=2, padding=1) 190 | self.maxpool0 = nn.MaxPool2d(2, return_indices=True) 191 | 192 | # First group 193 | self.bottleNeck1_0 = BottleNeckDownSampling(self.n_kernels, self.projecting_factor, self.n_kernels * 4) 194 | self.bottleNeck1_1 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels * 4, self.projecting_factor, 0.01) 195 | self.bottleNeck1_2 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels * 4, self.projecting_factor, 0.01) 196 | self.bottleNeck1_3 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels * 4, self.projecting_factor, 0.01) 197 | self.bottleNeck1_4 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels * 4, self.projecting_factor, 0.01) 198 | 199 | # Second group 200 | self.bottleNeck2_0 = BottleNeckDownSampling(self.n_kernels * 4, self.projecting_factor, self.n_kernels * 8) 201 | self.bottleNeck2_1 = BottleNeckNormal(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 0.1) 202 | self.bottleNeck2_2 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 203 | self.n_kernels * 8, 2) 204 | self.bottleNeck2_3 = BottleNeckNormal_Asym(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 205 | 0.1) 206 | self.bottleNeck2_4 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 207 | self.n_kernels * 8, 4) 208 | self.bottleNeck2_5 = BottleNeckNormal(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 0.1) 209 | self.bottleNeck2_6 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 210 | self.n_kernels * 8, 8) 211 | self.bottleNeck2_7 = BottleNeckNormal_Asym(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 212 | 0.1) 213 | self.bottleNeck2_8 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 214 | self.n_kernels * 8, 16) 215 | 
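        # (The dilated bottlenecks in this group and the next follow the usual ENet
        #  schedule: the dilation grows as 2, 4, 8, 16, interleaved with plain and
        #  asymmetric 5x1/1x5 bottlenecks, while the spatial resolution stays at 1/8
        #  of the input, as set by the downsampling stages above.)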
216 | # Third group 217 | self.bottleNeck3_1 = BottleNeckNormal(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 0.1) 218 | self.bottleNeck3_2 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 219 | self.n_kernels * 8, 2) 220 | self.bottleNeck3_3 = BottleNeckNormal_Asym(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 221 | 0.1) 222 | self.bottleNeck3_4 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 223 | self.n_kernels * 8, 4) 224 | self.bottleNeck3_5 = BottleNeckNormal(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 0.1) 225 | self.bottleNeck3_6 = BottleNeckDownSamplingDilatedConv(self.n_kernels * 8, self.projecting_factor, 226 | self.n_kernels * 8, 8) 227 | self.bottleNeck3_7 = BottleNeckNormal_Asym(self.n_kernels * 8, self.n_kernels * 8, self.projecting_factor, 228 | 0.1) 229 | self.bottleNeck3_8 = BottleNeckDownSamplingDilatedConvLast(self.n_kernels * 8, self.projecting_factor, 230 | self.n_kernels * 4, 16) 231 | 232 | # ### Decoding path #### 233 | # Unpooling 1 234 | self.unpool_0 = nn.MaxUnpool2d(2) 235 | 236 | self.bottleNeck_Up_1_0 = BottleNeckUpSampling(self.n_kernels * 8, self.projecting_factor, 237 | self.n_kernels * 4) 238 | self.PReLU_Up_1 = nn.PReLU() 239 | 240 | self.bottleNeck_Up_1_1 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels * 4, self.projecting_factor, 241 | 0.1) 242 | self.bottleNeck_Up_1_2 = BottleNeckNormal(self.n_kernels * 4, self.n_kernels, self.projecting_factor, 0.1) 243 | 244 | # Unpooling 2 245 | self.unpool_1 = nn.MaxUnpool2d(2) 246 | self.bottleNeck_Up_2_1 = BottleNeckUpSampling(self.n_kernels * 2, self.projecting_factor, self.n_kernels) 247 | self.bottleNeck_Up_2_2 = BottleNeckNormal(self.n_kernels, self.n_kernels, self.projecting_factor, 0.1) 248 | self.PReLU_Up_2 = nn.PReLU() 249 | 250 | # Unpooling Last 251 | self.deconv3 = upSampleConv(self.n_kernels, self.n_kernels) 252 | 253 | self.out_025 = nn.Conv2d(self.n_kernels * 8, out_dim, kernel_size=3, stride=1, padding=1) 254 | self.out_05 = nn.Conv2d(self.n_kernels, out_dim, kernel_size=3, stride=1, padding=1) 255 | self.final = nn.Conv2d(self.n_kernels, out_dim, kernel_size=1) 256 | 257 | def forward(self, input): 258 | conv_0 = self.conv0(input) # This will go as res in deconv path 259 | maxpool_0, indices_0 = self.maxpool0(input) 260 | outputInitial = torch.cat((conv_0, maxpool_0), dim=1) 261 | 262 | # First group 263 | bn1_0, indices_1 = self.bottleNeck1_0(outputInitial) 264 | bn1_1 = self.bottleNeck1_1(bn1_0) 265 | bn1_2 = self.bottleNeck1_2(bn1_1) 266 | bn1_3 = self.bottleNeck1_3(bn1_2) 267 | bn1_4 = self.bottleNeck1_4(bn1_3) 268 | 269 | # Second group 270 | bn2_0, indices_2 = self.bottleNeck2_0(bn1_4) 271 | bn2_1 = self.bottleNeck2_1(bn2_0) 272 | bn2_2 = self.bottleNeck2_2(bn2_1) 273 | bn2_3 = self.bottleNeck2_3(bn2_2) 274 | bn2_4 = self.bottleNeck2_4(bn2_3) 275 | bn2_5 = self.bottleNeck2_5(bn2_4) 276 | bn2_6 = self.bottleNeck2_6(bn2_5) 277 | bn2_7 = self.bottleNeck2_7(bn2_6) 278 | bn2_8 = self.bottleNeck2_8(bn2_7) 279 | 280 | # Third group 281 | bn3_1 = self.bottleNeck3_1(bn2_8) 282 | bn3_2 = self.bottleNeck3_2(bn3_1) 283 | bn3_3 = self.bottleNeck3_3(bn3_2) 284 | bn3_4 = self.bottleNeck3_4(bn3_3) 285 | bn3_5 = self.bottleNeck3_5(bn3_4) 286 | bn3_6 = self.bottleNeck3_6(bn3_5) 287 | bn3_7 = self.bottleNeck3_7(bn3_6) 288 | bn3_8 = self.bottleNeck3_8(bn3_7) 289 | 290 | # #### Deconvolution Path #### 291 | # First block # 292 | unpool_0 = self.unpool_0(bn3_8, indices_2) 293 | 294 | # 
bn_up_1_0 = self.bottleNeck_Up_1_0(unpool_0) # Not concatenate 295 | bn_up_1_0 = self.bottleNeck_Up_1_0(torch.cat((unpool_0, bn1_4), dim=1)) # concatenate 296 | 297 | up_block_1 = self.PReLU_Up_1(unpool_0 + bn_up_1_0) 298 | 299 | bn_up_1_1 = self.bottleNeck_Up_1_1(up_block_1) 300 | bn_up_1_2 = self.bottleNeck_Up_1_2(bn_up_1_1) 301 | 302 | # Second block # 303 | unpool_1 = self.unpool_1(bn_up_1_2, indices_1) 304 | 305 | # bn_up_2_1 = self.bottleNeck_Up_2_1(unpool_1) # Not concatenate 306 | bn_up_2_1 = self.bottleNeck_Up_2_1(torch.cat((unpool_1, outputInitial), dim=1)) # concatenate 307 | 308 | bn_up_2_2 = self.bottleNeck_Up_2_2(bn_up_2_1) 309 | 310 | up_block_1 = self.PReLU_Up_2(unpool_1 + bn_up_2_2) 311 | 312 | unpool_12 = self.deconv3(up_block_1) 313 | 314 | return self.final(unpool_12) 315 | 316 | def init_weights(self, *args, **kwargs): 317 | self.apply(random_weights_init) 318 | 319 | 320 | class ResidualUNet(nn.Module): 321 | # def __init__(self, output_nc, ngf=32): 322 | def __init__(self, in_dim: int, out_dim: int): 323 | super().__init__() 324 | self.in_dim = in_dim 325 | ngf = 32 326 | self.out_dim = ngf 327 | self.final_out_dim = out_dim 328 | act_fn = nn.LeakyReLU(0.2, inplace=True) 329 | act_fn_2 = nn.ReLU() 330 | 331 | from models.residualunet import (Conv_residual_conv, 332 | maxpool, 333 | conv_decod_block) 334 | 335 | # Encoder 336 | self.down_1 = Conv_residual_conv(self.in_dim, self.out_dim, act_fn) 337 | self.pool_1 = maxpool() 338 | self.down_2 = Conv_residual_conv(self.out_dim, self.out_dim * 2, act_fn) 339 | self.pool_2 = maxpool() 340 | self.down_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 4, act_fn) 341 | self.pool_3 = maxpool() 342 | self.down_4 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 8, act_fn) 343 | self.pool_4 = maxpool() 344 | 345 | # Bridge between Encoder-Decoder 346 | self.bridge = Conv_residual_conv(self.out_dim * 8, self.out_dim * 16, act_fn) 347 | 348 | # Decoder 349 | self.deconv_1 = conv_decod_block(self.out_dim * 16, self.out_dim * 8, act_fn_2) 350 | self.up_1 = Conv_residual_conv(self.out_dim * 8, self.out_dim * 8, act_fn_2) 351 | self.deconv_2 = conv_decod_block(self.out_dim * 8, self.out_dim * 4, act_fn_2) 352 | self.up_2 = Conv_residual_conv(self.out_dim * 4, self.out_dim * 4, act_fn_2) 353 | self.deconv_3 = conv_decod_block(self.out_dim * 4, self.out_dim * 2, act_fn_2) 354 | self.up_3 = Conv_residual_conv(self.out_dim * 2, self.out_dim * 2, act_fn_2) 355 | self.deconv_4 = conv_decod_block(self.out_dim * 2, self.out_dim, act_fn_2) 356 | self.up_4 = Conv_residual_conv(self.out_dim, self.out_dim, act_fn_2) 357 | 358 | self.out = nn.Conv2d(self.out_dim, self.final_out_dim, kernel_size=3, stride=1, padding=1) 359 | 360 | print(f"Initialized {self.__class__.__name__} succesfully") 361 | 362 | def forward(self, input): 363 | # Encoding path 364 | 365 | down_1 = self.down_1(input) # This will go as res in deconv path 366 | down_2 = self.down_2(self.pool_1(down_1)) 367 | down_3 = self.down_3(self.pool_2(down_2)) 368 | down_4 = self.down_4(self.pool_3(down_3)) 369 | 370 | bridge = self.bridge(self.pool_4(down_4)) 371 | 372 | # Decoding path 373 | deconv_1 = self.deconv_1(bridge) 374 | skip_1 = (deconv_1 + down_4) / 2 # Residual connection 375 | up_1 = self.up_1(skip_1) 376 | 377 | deconv_2 = self.deconv_2(up_1) 378 | skip_2 = (deconv_2 + down_3) / 2 # Residual connection 379 | up_2 = self.up_2(skip_2) 380 | 381 | deconv_3 = self.deconv_3(up_2) 382 | skip_3 = (deconv_3 + down_2) / 2 # Residual connection 383 | up_3 = 
self.up_3(skip_3) 384 | 385 | deconv_4 = self.deconv_4(up_3) 386 | skip_4 = (deconv_4 + down_1) / 2 # Residual connection 387 | up_4 = self.up_4(skip_4) 388 | 389 | return self.out(up_4) 390 | 391 | def init_weights(self, *args, **kwargs): 392 | for m in self.modules(): 393 | if isinstance(m, nn.Conv2d): 394 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 395 | m.weight.data.normal_(0, math.sqrt(2. / n)) 396 | elif isinstance(m, nn.BatchNorm2d): 397 | m.weight.data.fill_(1) 398 | m.bias.data.zero_() 399 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
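# Expected input: one <filename>.npy per folder, shaped (epochs, samples, classes);
# a 2-D array is treated as having a single class column. The curves are the mean
# over axis 1 for each requested class, optionally with an extra mean-of-columns line.
#
# Hypothetical invocation (folder and metric names are examples only):
#     python plot.py --folders results/acdc/ce results/acdc/boundary \
#         --filename val_dice.npy --columns 1 --ylim 0 1 --savefig dice.png --headless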
24 | 25 | import argparse 26 | from typing import List 27 | from pathlib import Path 28 | from subprocess import call 29 | from itertools import cycle 30 | 31 | import numpy as np 32 | import pandas as pd 33 | import matplotlib 34 | matplotlib.use("Agg") 35 | import matplotlib.pyplot as plt 36 | from scipy.interpolate import interp1d 37 | 38 | from utils import map_, colors as util_colors 39 | 40 | 41 | def run(args: argparse.Namespace) -> None: 42 | plt.rc('font', size=args.fontsize) 43 | 44 | colors: List[str] = args.colors if args.colors else util_colors 45 | 46 | styles = ['--', '-.', ':', '-'] 47 | if len(args.folders) > len(colors): 48 | print("Warning: more folders than colors") 49 | assert len(args.columns) <= len(styles) 50 | 51 | paths: List[Path] = [Path(f, args.filename) for f in args.folders] 52 | arrays: List[np.ndarray] = map_(np.load, paths) 53 | 54 | if len(arrays[0].shape) == 2: 55 | arrays = map_(lambda a: a[..., np.newaxis], arrays) 56 | epoch, _, class_ = arrays[0].shape 57 | if args.n_epoch: 58 | epoch = min(epoch, args.n_epoch) 59 | for a in arrays[1:]: 60 | ea, _, ca = a.shape 61 | assert ea <= epoch, (epoch, class_, a.shape) 62 | 63 | if not args.dynamic_third_axis: # Useful for when trainings don't have same number of losses 64 | assert class_ == ca, (epoch, class_, a.shape) 65 | 66 | n_epoch = arrays[0].shape[0] if not args.n_epoch else args.n_epoch 67 | 68 | fig = plt.figure(figsize=args.figsize) 69 | ax = fig.gca() 70 | ax.set_xlim([0, n_epoch - 2]) 71 | 72 | ymin, ymax = args.ylim # Tuple[int, int] 73 | ax.set_ylim(ymin, ymax) 74 | yrange: int = ymax - ymin 75 | ystep: float = yrange / 10 76 | yticks = np.mgrid[ymin:ymax + ystep:ystep] if not args.yticks else args.yticks 77 | 78 | ax.set_yticks(yticks) 79 | 80 | ax.set_xlabel("Epoch") 81 | if args.ylabel: 82 | ax.set_ylabel(args.ylabel) 83 | else: 84 | ax.set_ylabel(Path(args.filename).stem) 85 | ax.grid(True, axis='y') 86 | if args.title: 87 | ax.set_title(args.title) 88 | else: 89 | ax.set_title(f"{paths[0].stem} over epochs") 90 | 91 | if args.labels: 92 | labels = args.labels 93 | else: 94 | labels = [p.parent.name for p in paths] 95 | 96 | epcs = np.arange(n_epoch) 97 | if args.save_csv: 98 | df = pd.DataFrame() 99 | 100 | xnew = np.linspace(0, n_epoch - 1, int(n_epoch * args.sampling_factor)) 101 | for i, (a, c, p, l) in enumerate(zip(arrays, cycle(colors), paths, labels)): 102 | mean_a = a.mean(axis=1) 103 | 104 | _, n_col = mean_a.shape 105 | # For when more args.columns than columns (weird case with varying multiple losses) 106 | allowed_cols: List[int] = list(set(args.columns).intersection(set(range(n_col)))) 107 | 108 | if len(allowed_cols) > 1 and not args.no_mean: 109 | mean_column = mean_a[:, allowed_cols].mean(axis=1) 110 | lw: float = 3 if not args.only_mean else 1.5 111 | lab: str = f"{l}-mean" if not args.only_mean else l 112 | ax.plot(epcs, mean_column[:n_epoch], color=c, linestyle='-', label=lab, linewidth=lw) 113 | 114 | if args.save_csv: 115 | df[lab] = mean_column[:n_epoch] 116 | 117 | if not args.only_mean: 118 | for k, s in zip(allowed_cols, styles): 119 | values = mean_a[..., k] 120 | 121 | if args.smooth: 122 | # smoothed = spline(epcs, values, xnew) 123 | inter_fn = interp1d(epcs, values[:n_epoch], kind='slinear') 124 | smoothed = inter_fn(xnew) 125 | x, y = xnew, smoothed 126 | else: 127 | x, y = epcs, values[:n_epoch] 128 | 129 | lab = l if len(args.columns) == 1 else f"{l}-{k}" 130 | 131 | sty: str 132 | if len(args.columns) == 1: 133 | if args.curves_styles: 134 | sty = 
args.curves_styles[i][1:] # Have to remove the extra space 135 | else: 136 | sty = '-' 137 | else: 138 | sty = s 139 | 140 | ax.plot(x, y, linestyle=sty, color=c, label=lab, linewidth=1.5) 141 | if args.save_csv: 142 | df[lab] = y 143 | 144 | if args.min: 145 | print(f"{Path(p).parents[0]}, class {k}: {values.min():.04f}") 146 | else: 147 | print(f"{Path(p).parents[0]}, class {k}: {values.max():.04f}") 148 | 149 | if args.hline: 150 | for v, l, s in zip(args.hline, args.l_line, styles): 151 | ax.plot([0, n_epoch], [v, v], linestyle=s, linewidth=1, color='green', label=l) 152 | 153 | ax.legend(loc=args.loc) 154 | 155 | fig.tight_layout() 156 | if args.savefig: 157 | fig.savefig(args.savefig) 158 | if args.trim: 159 | call(["mogrify", "-trim", args.savefig]) 160 | 161 | if not args.headless: 162 | plt.show() 163 | 164 | if args.save_csv and args.savefig: 165 | df.to_csv(Path(args.savefig).with_suffix(".csv"), float_format="%.4f", index_label="epoch") 166 | 167 | 168 | def get_args() -> argparse.Namespace: 169 | parser = argparse.ArgumentParser(description='Plot data over time') 170 | parser.add_argument('--folders', type=str, required=True, nargs='+', help="The folders containing the file") 171 | parser.add_argument('--filename', type=str, required=True) 172 | 173 | parser.add_argument("--headless", action="store_true") 174 | parser.add_argument("--smooth", action="store_true") 175 | parser.add_argument("--trim", action="store_true", help="Remove the whitespaces around the figure") 176 | parser.add_argument("--min", action="store_true", help="Display the min of each file instead of maximum") 177 | parser.add_argument("--debug", action="store_true", help="Dummy for compatibility") 178 | parser.add_argument("--cpu", action="store_true", help="Dummy for compatibility") 179 | parser.add_argument("--no_mean", action="store_true", help="Don't plot the mean line") 180 | parser.add_argument("--only_mean", action="store_true", help="Plot only the mean line") 181 | parser.add_argument("--dynamic_third_axis", action="store_true", 182 | help="Allow the third axis of the arguments to be of varying size") 183 | 184 | parser.add_argument("--savefig", type=str, default=None) 185 | parser.add_argument('--columns', type=int, nargs='+', default=0, help="Which columns of the third axis to plot") 186 | parser.add_argument("--hline", type=float, nargs='*') 187 | parser.add_argument("--ylim", type=float, nargs=2, default=[0, 1]) 188 | 189 | parser.add_argument("--l_line", type=str, nargs='*') 190 | parser.add_argument("--title", type=str, default='') 191 | parser.add_argument("--ylabel", type=str, default='') 192 | parser.add_argument("--labels", type=str, nargs='*') 193 | parser.add_argument("--colors", type=str, nargs='*') 194 | parser.add_argument("--figsize", type=int, nargs='*', default=[14, 9]) 195 | parser.add_argument("--yticks", type=float, nargs='*') 196 | parser.add_argument("--fontsize", type=int, default=10) 197 | parser.add_argument("--sampling_factor", type=float, default=4) 198 | parser.add_argument("--n_epoch", type=int, default=None) 199 | parser.add_argument("--curves_styles", type=str, nargs='*', choices=[' -', ' --', ' -.', ' :'], 200 | help="Careful: put an extra space at the beginning of the string, to avoid a parsing error.") 201 | parser.add_argument("--loc", type=str, default=None, choices=matplotlib.legend.Legend.codes.copy()) 202 | parser.add_argument("--epc", type=int, help="Dummy to maintain call compatibility with hist.py and moustache.py") 203 | 204 | 
parser.add_argument("--save_csv", action='store_true', help="Save the data used for the plot a a csv.") 205 | 206 | args = parser.parse_args() 207 | 208 | print(args) 209 | 210 | return args 211 | 212 | 213 | if __name__ == "__main__": 214 | run(get_args()) 215 | -------------------------------------------------------------------------------- /preprocess/slice_acdc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import re 26 | import random 27 | import argparse 28 | import warnings 29 | from pathlib import Path 30 | from pprint import pprint 31 | from functools import partial 32 | from typing import Any, Callable, List, Tuple 33 | 34 | import numpy as np 35 | import nibabel as nib 36 | import matplotlib 37 | matplotlib.use("Agg") 38 | import matplotlib.pyplot as plt 39 | from numpy import unique as uniq 40 | from skimage.io import imsave 41 | from skimage.transform import resize 42 | # from PIL import Image 43 | 44 | from utils import mmap_, uc_, map_, augment, flatten_ 45 | 46 | 47 | def norm_arr(img: np.ndarray) -> np.ndarray: 48 | casted = img.astype(np.float32) 49 | shifted = casted - casted.min() 50 | norm = shifted / shifted.max() 51 | res = 255 * norm 52 | 53 | return res.astype(np.uint8) 54 | 55 | 56 | def get_frame(filename: str, regex: str = ".*_frame(\d+)(_gt)?\.nii.*") -> str: 57 | matched = re.match(regex, filename) 58 | 59 | if matched: 60 | return matched.group(1) 61 | raise ValueError(regex, filename) 62 | 63 | 64 | def get_p_id(path: Path) -> str: 65 | ''' 66 | The patient ID, for the ACDC dataset, is the folder containing the data. 
67 | ''' 68 | res = path.parent.name 69 | 70 | assert "patient" in res, res 71 | return res 72 | 73 | 74 | def save_slices(img_p: Path, gt_p: Path, 75 | dest_dir: Path, shape: Tuple[int, int], n_augment: int, 76 | img_dir: str = "img", gt_dir: str = "gt") -> Tuple[Any, Any, Any, Any]: 77 | p_id: str = get_p_id(img_p) 78 | assert "patient" in p_id 79 | assert p_id == get_p_id(gt_p) 80 | 81 | f_id: str = get_frame(img_p.name) 82 | assert f_id == get_frame(gt_p.name) 83 | 84 | # Load the data 85 | dx, dy, dz = nib.load(str(img_p)).header.get_zooms() 86 | assert dz in [5, 6.5, 7, 10], dz 87 | img = np.asarray(nib.load(str(img_p)).dataobj) 88 | gt = np.asarray(nib.load(str(gt_p)).dataobj) 89 | 90 | nx, ny = shape 91 | fx = nx / img.shape[0] 92 | fy = ny / img.shape[1] 93 | # print(f"Before dx {dx:.04f}, dy {dy:.04f}") 94 | dx /= fx 95 | dy /= fy 96 | # print(f"After dx {dx:.04f}, dy {dy:.04f}") 97 | 98 | # print(dx, dy, dz) 99 | pixel_surface: float = dx * dy 100 | voxel_volume: float = dx * dy * dz 101 | 102 | assert img.shape == gt.shape 103 | # assert img.shape[:-1] == shape 104 | assert img.dtype in [np.uint8, np.int16, np.float32] 105 | 106 | # Normalize and check data content 107 | norm_img = norm_arr(img) # We need to normalize the whole 3d img, not 2d slices 108 | assert 0 == norm_img.min() and norm_img.max() == 255, (norm_img.min(), norm_img.max()) 109 | assert gt.dtype == norm_img.dtype == np.uint8 110 | 111 | resize_: Callable = partial(resize, mode="constant", preserve_range=True, anti_aliasing=False) 112 | 113 | save_dir_img: Path = Path(dest_dir, img_dir) 114 | save_dir_gt: Path = Path(dest_dir, gt_dir) 115 | sizes_2d: np.ndarray = np.zeros(img.shape[-1]) 116 | for j in range(img.shape[-1]): 117 | img_s = norm_img[:, :, j] 118 | gt_s = gt[:, :, j] 119 | assert img_s.shape == gt_s.shape 120 | assert gt_s.dtype == np.uint8 121 | 122 | # Resize and check the data are still what we expect 123 | r_img: np.ndarray = resize_(img_s, shape).astype(np.uint8) 124 | r_gt: np.ndarray = resize_(gt_s, shape, order=0) 125 | # r_gt: np.ndarray = np.array(Image.fromarray(gt_s, mode='L').resize(shape)) 126 | assert set(uniq(r_gt)).issubset(set(uniq(gt))), (r_gt.dtype, uniq(r_gt)) 127 | r_gt = r_gt.astype(np.uint8) 128 | assert r_img.dtype == r_gt.dtype == np.uint8 129 | assert 0 <= r_img.min() and r_img.max() <= 255 # The range might be smaller 130 | sizes_2d[j] = (r_gt == 3).astype(np.int64).sum() 131 | 132 | for k in range(n_augment + 1): 133 | if k == 0: 134 | a_img, a_gt = r_img, r_gt 135 | else: 136 | a_img, a_gt = map_(np.asarray, augment(r_img, r_gt)) 137 | 138 | for save_dir, data in zip([save_dir_img, save_dir_gt], [a_img, a_gt]): 139 | filename = f"{p_id}_{f_id}_{k}_{j}.png" 140 | save_dir.mkdir(parents=True, exist_ok=True) 141 | 142 | with warnings.catch_warnings(): 143 | warnings.filterwarnings("ignore", category=UserWarning) 144 | imsave(str(Path(save_dir, filename)), data) 145 | 146 | lv_gt = (gt == 3).astype(np.uint8) 147 | assert set(np.unique(lv_gt)) <= set([0, 1]) 148 | assert lv_gt.shape == gt.shape 149 | 150 | lv_gt = resize_(lv_gt, (*shape, img.shape[-1]), order=0) 151 | assert set(np.unique(lv_gt)) <= set([0, 1]) 152 | 153 | slices_sizes_px = np.einsum("xyz->z", lv_gt.astype(np.int64)) 154 | assert np.array_equal(slices_sizes_px, sizes_2d), (slices_sizes_px, sizes_2d) 155 | # slices_sizes_px = sizes_2d[...] 
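    # Keep only the slices where the left ventricle (class 3) actually appears, then
    # convert pixel counts to mm^2 / mm^3 using the spacings rescaled for the resize.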
156 | slices_sizes_px = slices_sizes_px[slices_sizes_px > 0] 157 | slices_sizes_mm2 = slices_sizes_px * pixel_surface 158 | 159 | # volume_size_px = np.einsum("xyz->", lv_gt) 160 | volume_size_px = slices_sizes_px.sum() 161 | volume_size_mm3 = volume_size_px * voxel_volume 162 | 163 | # print(f"{slices_sizes_px.mean():.0f}, {volume_size_px}") 164 | 165 | return slices_sizes_px, slices_sizes_mm2, volume_size_px, volume_size_mm3 166 | 167 | 168 | def main(args: argparse.Namespace): 169 | src_path: Path = Path(args.source_dir) 170 | dest_path: Path = Path(args.dest_dir) 171 | 172 | # Assume the cleaning up is done before calling the script 173 | assert src_path.exists() 174 | assert not dest_path.exists() 175 | 176 | # Get all the file names, avoid the temporal ones 177 | nii_paths: List[Path] = [p for p in src_path.rglob('*.nii.gz') if "_4d" not in str(p)] 178 | assert len(nii_paths) % 2 == 0, "Uneven number of .nii, one+ pair is broken" 179 | 180 | # We sort now, but also id matching is checked while iterating later on 181 | img_nii_paths: List[Path] = sorted(p for p in nii_paths if "_gt" not in str(p)) 182 | gt_nii_paths: List[Path] = sorted(p for p in nii_paths if "_gt" in str(p)) 183 | assert len(img_nii_paths) == len(gt_nii_paths) == 200 184 | paths: List[Tuple[Path, Path]] = list(zip(img_nii_paths, gt_nii_paths)) 185 | 186 | print(f"Found {len(img_nii_paths)} pairs in total") 187 | pprint(paths[:5]) 188 | 189 | pids: List[str] = sorted(set(map_(get_p_id, img_nii_paths))) 190 | assert len(pids) == (len(img_nii_paths) // 2), (len(pids), len(img_nii_paths)) 191 | 192 | # validation_pids: List[str] = random.sample(pids, args.retains) 193 | random.shuffle(pids) # Shuffle before to avoid any problem if the patients are sorted in any way 194 | validation_slice = slice(args.fold * args.retains, (args.fold + 1) * args.retains) 195 | validation_pids: List[str] = pids[validation_slice] 196 | assert len(validation_pids) == args.retains 197 | 198 | validation_paths: List[Tuple[Path, Path]] = [p for p in paths if get_p_id(p[0]) in validation_pids] 199 | training_paths: List[Tuple[Path, Path]] = [p for p in paths if get_p_id(p[0]) not in validation_pids] 200 | assert set(validation_paths).isdisjoint(set(training_paths)) 201 | assert len(paths) == (len(validation_paths) + len(training_paths)) 202 | assert len(validation_paths) == 2 * args.retains 203 | assert len(training_paths) == (len(paths) - 2 * args.retains) 204 | 205 | for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]): 206 | img_paths, gt_paths = zip(*_paths) # type: Tuple[Any, Any] 207 | 208 | dest_dir = Path(dest_path, mode) 209 | print(f"Slicing {len(img_paths)} pairs to {dest_dir}") 210 | assert len(img_paths) == len(gt_paths) 211 | 212 | pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment) 213 | all_sizes = mmap_(uc_(pfun), zip(img_paths, gt_paths)) 214 | # for paths in tqdm(list(zip(img_paths, gt_paths)), ncols=50): 215 | # uc_(pfun)(paths) 216 | 217 | all_slices_sizes_px, all_slices_sizes_mm2, all_volume_size_px, all_volume_size_mm3 = zip(*all_sizes) 218 | 219 | flat_sizes_px = flatten_(all_slices_sizes_px) 220 | flat_sizes_mm2 = flatten_(all_slices_sizes_mm2) 221 | print("px", len(flat_sizes_px), min(flat_sizes_px), max(flat_sizes_px)) 222 | print('\t', "px 5/95", np.percentile(flat_sizes_px, 5), np.percentile(flat_sizes_px, 95)) 223 | print('\t', "mm2", f"{min(flat_sizes_mm2):.02f}", f"{max(flat_sizes_mm2):.02f}") 224 | 225 | _, axes = 
plt.subplots(nrows=2, ncols=2) 226 | axes = axes.flatten() 227 | 228 | axes[0].set_title("Slice surface (pixel)") 229 | axes[0].boxplot(all_slices_sizes_px, whis=[0, 100]) 230 | 231 | axes[1].set_title("Slice surface (mm2)") 232 | axes[1].boxplot(all_slices_sizes_mm2, whis=[0, 100]) 233 | 234 | axes[2].set_title("LV volume (pixel)") 235 | axes[2].hist(all_volume_size_px, bins=len(all_volume_size_px) // 2) 236 | 237 | axes[3].set_title("LV volume (mm3)") 238 | axes[3].hist(all_volume_size_mm3, bins=len(all_volume_size_px) // 2) 239 | 240 | # plt.show() 241 | 242 | 243 | def get_args() -> argparse.Namespace: 244 | parser = argparse.ArgumentParser(description='Slicing parameters') 245 | parser.add_argument('--source_dir', type=str, required=True) 246 | parser.add_argument('--dest_dir', type=str, required=True) 247 | 248 | parser.add_argument('--img_dir', type=str, default="IMG") 249 | parser.add_argument('--gt_dir', type=str, default="GT") 250 | parser.add_argument('--shape', type=int, nargs="+", default=[256, 256]) 251 | parser.add_argument('--retains', type=int, default=25, help="Number of retained patient for the validation data") 252 | parser.add_argument('--seed', type=int, default=0) 253 | parser.add_argument('--fold', type=int, default=0) 254 | parser.add_argument('--n_augment', type=int, default=0, 255 | help="Number of augmentation to create per image, only for the training set") 256 | args = parser.parse_args() 257 | random.seed(args.seed) 258 | 259 | print(args) 260 | 261 | return args 262 | 263 | 264 | if __name__ == "__main__": 265 | args = get_args() 266 | random.seed(args.seed) 267 | 268 | main(args) 269 | -------------------------------------------------------------------------------- /preprocess/slice_isles.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
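# Slices the ISLES perfusion volumes (CT, CBF, CBV, MTT, Tmax) and their lesion
# ground truth into per-slice .png images, stacked multi-modal .npy inputs, and
# 2-D slices of the case-level 3-D signed distance map (one_hot2dist), and writes
# spacing.pkl / spacing_3d.pkl dictionaries recording the voxel sizes.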
24 | 25 | import pickle 26 | import random 27 | import argparse 28 | import warnings 29 | from pathlib import Path 30 | from pprint import pprint 31 | from functools import partial 32 | from typing import Dict, List, Tuple 33 | 34 | import torch 35 | import numpy as np 36 | import nibabel as nib 37 | import matplotlib.pyplot as plt 38 | from torch import Tensor 39 | from tqdm import tqdm 40 | from skimage.io import imsave 41 | 42 | from utils import mmap_, uc_, map_, augment_arr 43 | from utils import class2one_hot, one_hot2dist 44 | 45 | 46 | def norm_arr(img: np.ndarray) -> np.ndarray: 47 | casted = img.astype(np.float32) 48 | shifted = casted - casted.min() 49 | norm = shifted / shifted.max() 50 | res = norm 51 | 52 | return res 53 | 54 | 55 | def get_p_id(path: Path) -> str: 56 | ''' 57 | The patient ID, for the ACDC dataset, is the folder containing the data. 58 | ''' 59 | res = path.parents[1].name 60 | 61 | assert "case_" in res, res 62 | return res 63 | 64 | 65 | def save_slices(ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths, 66 | dest_dir: Path, shape: Tuple[int], n_augment: int, 67 | ct_dir: str = "ct", cbf_dir="cbf", cbv_dir="cbv", mtt_dir="mtt", tmax_dir="tmax", 68 | gt_dir: str = "gt", in_npy_dir="in_npy", gt_npy_dir='gt_npy') -> Tuple[int, 69 | int, 70 | Dict, 71 | Tuple[float, float, float]]: 72 | p_id: str = get_p_id(ct_paths) 73 | assert len(set(map_(get_p_id, [ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths]))) == 1 74 | 75 | space_dict: Dict[str, Tuple[float, float]] = {} 76 | 77 | # Load the data 78 | dx, dy, dz = nib.load(str(ct_paths)).header.get_zooms() 79 | assert dx == dy 80 | ct = np.asarray(nib.load(str(ct_paths)).dataobj) 81 | cbf = np.asarray(nib.load(str(cbf_paths)).dataobj) 82 | cbv = np.asarray(nib.load(str(cbv_paths)).dataobj) 83 | mtt = np.asarray(nib.load(str(mtt_paths)).dataobj) 84 | tmax = np.asarray(nib.load(str(tmax_paths)).dataobj) 85 | gt = np.asarray(nib.load(str(gt_paths)).dataobj) 86 | 87 | assert len(set(map_(np.shape, [ct, cbf, cbv, mtt, tmax, gt]))) == 1 88 | assert ct.dtype in [np.int32], ct.dtype 89 | assert cbf.dtype in [np.uint16], cbf.dtype 90 | assert cbv.dtype in [np.uint16], cbv.dtype 91 | assert mtt.dtype in [np.float64], mtt.dtype 92 | assert tmax.dtype in [np.float64], tmax.dtype 93 | assert gt.dtype in [np.uint8], gt.dtype 94 | 95 | pos: int = (gt == 1).sum() 96 | neg: int = (gt == 0).sum() 97 | 98 | x, y, z = ct.shape 99 | 100 | # Normalize and check data content 101 | norm_ct = norm_arr(ct) # We need to normalize the whole 3d img, not 2d slices 102 | norm_cbf = norm_arr(cbf) 103 | norm_cbv = norm_arr(cbv) 104 | norm_mtt = norm_arr(mtt) 105 | norm_tmax = norm_arr(tmax) 106 | assert 0 == norm_ct.min() and norm_ct.max() == 1, (norm_ct.min(), norm_ct.max()) 107 | assert 0 == norm_cbf.min() and norm_cbf.max() == 1, (norm_cbf.min(), norm_cbf.max()) 108 | assert 0 == norm_cbv.min() and norm_cbv.max() == 1, (norm_cbv.min(), norm_cbv.max()) 109 | assert 0 == norm_mtt.min() and norm_mtt.max() == 1, (norm_mtt.min(), norm_mtt.max()) 110 | assert 0 == norm_tmax.min() and norm_tmax.max() == 1, (norm_tmax.min(), norm_tmax.max()) 111 | 112 | one_hot_gt: Tensor = class2one_hot(torch.tensor(gt[None, ...], dtype=torch.int64), K=2)[0] 113 | assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape 114 | distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(), 115 | resolution=(dx, dy, dz), 116 | dtype=np.float32) 117 | 118 | save_dir_ct: Path = Path(dest_dir, ct_dir) 119 | save_dir_cbf: Path = Path(dest_dir, 
cbf_dir) 120 | save_dir_cbv: Path = Path(dest_dir, cbv_dir) 121 | save_dir_mtt: Path = Path(dest_dir, mtt_dir) 122 | save_dir_tmax: Path = Path(dest_dir, tmax_dir) 123 | save_dir_gt: Path = Path(dest_dir, gt_dir) 124 | save_dir_in_npy: Path = Path(dest_dir, in_npy_dir) 125 | save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir) 126 | save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap") 127 | save_dirs = [save_dir_ct, save_dir_cbf, save_dir_cbv, save_dir_mtt, save_dir_tmax, save_dir_gt] 128 | 129 | for j in range(ct.shape[-1]): 130 | ct_s = norm_ct[:, :, j] 131 | cbf_s = norm_cbf[:, :, j] 132 | cbv_s = norm_cbv[:, :, j] 133 | mtt_s = norm_mtt[:, :, j] 134 | tmax_s = norm_tmax[:, :, j] 135 | gt_s = gt[:, :, j] 136 | dist_s = distmap[:, :, :, j] 137 | slices = [ct_s, cbf_s, cbv_s, mtt_s, tmax_s, gt_s] 138 | assert ct_s.shape == cbf_s.shape == cbv_s.shape, mtt_s.shape == tmax_s.shape == gt_s.shape 139 | assert gt_s.shape == dist_s[0, ...].shape, ((x, y, z), gt_s.shape, dist_s.shape) 140 | assert set(np.unique(gt_s)).issubset([0, 1]) 141 | 142 | # if gt_s.sum() > 0: 143 | # print(f"{dist_s[1].min()=} {dist_s[1].max()=}") 144 | # _, axes = plt.subplots(nrows=1, ncols=3) 145 | # axes[0].imshow(gt_s) 146 | # axes[0].set_title("GT") 147 | 148 | # tmp = axes[1].imshow(dist_s[1, ...], cmap='rainbow') 149 | # axes[1].set_title("Signed distance map") 150 | # plt.colorbar(tmp, ax=axes[1]) 151 | 152 | # tmp = axes[2].imshow(np.abs(dist_s[1, ...]), cmap='rainbow') 153 | # axes[2].set_title("Abs distance map") 154 | # plt.colorbar(tmp, ax=axes[2]) 155 | # plt.show() 156 | 157 | for k in range(n_augment + 1): 158 | if k == 0: 159 | to_save = slices 160 | else: 161 | to_save = map_(np.asarray, augment_arr(*slices)) 162 | assert to_save[0].shape == slices[0].shape, (to_save[0].shape, slices[0].shape) 163 | 164 | filename = f"{p_id}_{k:02d}_{j:04d}" 165 | space_dict[filename] = (dx, dy) 166 | for save_dir, data in zip(save_dirs, to_save): 167 | save_dir.mkdir(parents=True, exist_ok=True) 168 | 169 | if "gt" not in str(save_dir): 170 | img = (data * 255).astype(np.uint8) 171 | else: 172 | img = data.astype(np.uint8) 173 | 174 | with warnings.catch_warnings(): 175 | warnings.filterwarnings("ignore", category=UserWarning) 176 | imsave(str(Path(save_dir, filename).with_suffix(".png")), img) 177 | 178 | multimodal = np.stack(to_save[:-1]) # Do not include the ground truth 179 | assert 0 <= multimodal.min() and multimodal.max() <= 1 180 | save_dir_in_npy.mkdir(parents=True, exist_ok=True) 181 | save_dir_gt_npy.mkdir(parents=True, exist_ok=True) 182 | np.save(Path(save_dir_in_npy, filename).with_suffix(".npy"), multimodal) 183 | np.save(Path(save_dir_gt_npy, filename).with_suffix(".npy"), to_save[-1]) 184 | 185 | save_dir_distmap_npy.mkdir(parents=True, exist_ok=True) 186 | np.save(Path(save_dir_distmap_npy, filename).with_suffix(".npy"), dist_s) 187 | 188 | return neg, pos, space_dict, (dx, dy, dz) 189 | 190 | 191 | def main(args: argparse.Namespace): 192 | src_path: Path = Path(args.source_dir) 193 | dest_path: Path = Path(args.dest_dir) 194 | 195 | # Assume the cleaning up is done before calling the script 196 | assert src_path.exists() 197 | assert not dest_path.exists() 198 | 199 | # Get all the file names, avoid the temporal ones 200 | all_paths: List[Path] = list(src_path.rglob('*.nii')) 201 | nii_paths: List[Path] = [p for p in all_paths if "_4D" not in str(p)] 202 | assert len(nii_paths) % 6 == 0, "Number of .nii not multiple of 6, some pairs are broken" 203 | 204 | # We sort now, but also id matching is 
checked while iterating later on 205 | CT_nii_paths: List[Path] = sorted(p for p in nii_paths if "CT." in str(p)) 206 | CBF_nii_paths: List[Path] = sorted(p for p in nii_paths if "CT_CBF" in str(p)) 207 | CBV_nii_paths: List[Path] = sorted(p for p in nii_paths if "CT_CBV" in str(p)) 208 | MTT_nii_paths: List[Path] = sorted(p for p in nii_paths if "CT_MTT" in str(p)) 209 | Tmax_nii_paths: List[Path] = sorted(p for p in nii_paths if "CT_Tmax" in str(p)) 210 | gt_nii_paths: List[Path] = sorted(p for p in nii_paths if "OT" in str(p)) 211 | assert len(CT_nii_paths) == len(CBF_nii_paths) == len(CBV_nii_paths) == len(MTT_nii_paths) \ 212 | == len(Tmax_nii_paths) == len(gt_nii_paths) 213 | paths: List[Tuple[Path, ...]] = list(zip(CT_nii_paths, CBF_nii_paths, CBV_nii_paths, MTT_nii_paths, 214 | Tmax_nii_paths, gt_nii_paths)) 215 | 216 | print(f"Found {len(CT_nii_paths)} pairs in total") 217 | pprint(paths[:2]) 218 | 219 | resolution_dict: Dict[str, Tuple[float, float, float]] = {} 220 | 221 | validation_paths: List[Tuple[Path, ...]] = random.sample(paths, args.retain) 222 | training_paths: List[Tuple[Path, ...]] = [p for p in paths if p not in validation_paths] 223 | assert set(validation_paths).isdisjoint(set(training_paths)) 224 | assert len(paths) == (len(validation_paths) + len(training_paths)) 225 | 226 | for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]): 227 | # ct_paths, cbf_paths, cbv_paths, mtt_paths, tmax_paths, gt_paths = zip(*_paths) 228 | six_paths = list(zip(*_paths)) 229 | 230 | dest_dir = Path(dest_path, mode) 231 | print(f"Slicing {len(six_paths[0])} pairs to {dest_dir}") 232 | assert len(set(map_(len, six_paths))) == 1 233 | 234 | pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment) 235 | resolutions: List[Tuple[float, float, float]] 236 | all_neg, all_pos, space_dicts, resolutions = zip(*mmap_(uc_(pfun), zip(*six_paths))) 237 | neg, pos = sum(all_neg), sum(all_pos) 238 | ratio = pos / neg 239 | print(f"Ratio between pos/neg: {ratio} ({pos}/{neg})") 240 | # space_dicts, resolutions = zip(*map_(uc_(pfun), zip(*six_paths))) 241 | # for case_paths in tqdm(list(zip(*six_paths)), ncols=50): 242 | # uc_(pfun)(case_paths) 243 | 244 | final_dict = {k: v for space_dict in space_dicts for k, v in space_dict.items()} 245 | 246 | for key, val in zip(map_(get_p_id, six_paths[0]), resolutions): 247 | resolution_dict[key] = val 248 | 249 | with open(Path(dest_dir, "spacing.pkl"), 'wb') as f: 250 | pickle.dump(final_dict, f, pickle.HIGHEST_PROTOCOL) 251 | print(f"Saved spacing dictionnary to {f}") 252 | 253 | assert len(resolution_dict.keys()) == len(CT_nii_paths) 254 | pprint(resolution_dict) 255 | 256 | with open(dest_path / "spacing_3d.pkl", 'wb') as f: 257 | pickle.dump(resolution_dict, f, pickle.HIGHEST_PROTOCOL) 258 | print(f"Saved spacing dictionnary to {f}") 259 | 260 | 261 | def get_args() -> argparse.Namespace: 262 | parser = argparse.ArgumentParser(description='Slicing parameters') 263 | parser.add_argument('--source_dir', type=str, required=True) 264 | parser.add_argument('--dest_dir', type=str, required=True) 265 | parser.add_argument('--img_dir', type=str, default="IMG") 266 | parser.add_argument('--gt_dir', type=str, default="GT") 267 | parser.add_argument('--shape', type=int, nargs="+", default=[256, 256]) 268 | parser.add_argument('--retain', type=int, default=25, help="Number of retained patient for the validation data") 269 | parser.add_argument('--seed', type=int, default=0) 270 | 
parser.add_argument('--n_augment', type=int, default=0, 271 | help="Number of augmentation to create per image, only for the training set") 272 | args = parser.parse_args() 273 | random.seed(args.seed) 274 | 275 | print(args) 276 | 277 | return args 278 | 279 | 280 | if __name__ == "__main__": 281 | args = get_args() 282 | random.seed(args.seed) 283 | 284 | main(args) 285 | -------------------------------------------------------------------------------- /preprocess/slice_wmh.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.6 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
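# Example invocation: wmh.make drives this script essentially as below, slicing the raw
# data/wmh folder into a temporary destination that is then renamed to data/WMH:
#   python3 preprocess/slice_wmh.py --source_dir data/wmh --dest_dir data/WMH_tmp \
#       --n_augment=0 --retain=10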
24 | 25 | import random 26 | import pickle 27 | import argparse 28 | import warnings 29 | from pathlib import Path 30 | from pprint import pprint 31 | from functools import partial 32 | from typing import Dict, List, Tuple 33 | 34 | import torch 35 | import numpy as np 36 | import nibabel as nib 37 | from tqdm import tqdm 38 | from torch import Tensor 39 | from skimage.io import imsave 40 | from skimage.transform import resize 41 | import matplotlib.pyplot as plt 42 | 43 | from utils import mmap_, uc_, map_, augment_arr 44 | from utils import class2one_hot, one_hot2dist 45 | 46 | 47 | def norm_arr(img: np.ndarray) -> np.ndarray: 48 | casted = img.astype(np.float32) 49 | shifted = casted - casted.min() 50 | norm = shifted / shifted.max() 51 | res = norm 52 | 53 | return res 54 | 55 | 56 | def get_p_id(path: Path) -> str: 57 | res = list(path.parents)[-4].name 58 | 59 | return res 60 | 61 | 62 | def save_slices(flair_path, t1_path, gt_path, 63 | dest_dir: Path, shape: Tuple[int], n_augment: int, discard_negatives: bool, 64 | flair_dir: str = "flair", t1_dir="t1", 65 | gt_dir: str = "gt", in_npy_dir="in_npy", gt_npy_dir='gt_npy') -> Tuple[int, 66 | int, 67 | Dict, 68 | Tuple[float, float, float]]: 69 | p_id: str = get_p_id(flair_path) 70 | assert len(set(map_(get_p_id, [flair_path, t1_path, gt_path]))) == 1 71 | print(p_id) 72 | 73 | space_dict: Dict[str, Tuple[float, float]] = {} 74 | 75 | # Load the data 76 | dx, dy, dz = nib.load(str(flair_path)).header.get_zooms() 77 | # assert dx == dy, (dx, dy) 78 | flair = np.asarray(nib.load(str(flair_path)).dataobj) 79 | w, h, _ = flair.shape 80 | x, y, z = flair.shape 81 | t1 = np.asarray(nib.load(str(t1_path)).dataobj) 82 | gt = np.asarray(nib.load(str(gt_path)).dataobj) 83 | assert set(np.unique(gt)) <= set([0., 1., 2.]) 84 | 85 | pos: int = (gt == 1).sum() 86 | neg: int = ((gt == 0) | (gt == 2)).sum() 87 | 88 | assert len(set(map_(np.shape, [flair, t1, gt]))) == 1 89 | assert flair.dtype in [np.float32], flair.dtype 90 | assert t1.dtype in [np.uint16], t1.dtype 91 | assert gt.dtype in [np.float32], gt.dtype 92 | 93 | # Normalize and check data content 94 | norm_flair = norm_arr(flair) # We need to normalize the whole 3d img, not 2d slices 95 | norm_t1 = norm_arr(t1) 96 | norm_gt = gt.astype(np.uint8) 97 | assert 0 == norm_flair.min() and norm_flair.max() == 1, (norm_flair.min(), norm_flair.max()) 98 | assert 0 == norm_t1.min() and norm_t1.max() == 1, (norm_t1.min(), norm_t1.max()) 99 | assert np.array_equal(np.unique(gt), np.unique(norm_gt)) 100 | 101 | resized_flair = resize(norm_flair, (256, 256, z), 102 | mode='constant', preserve_range=True, anti_aliasing=False).astype(np.float32) 103 | resized_t1 = resize(norm_t1, (256, 256, z), 104 | mode='constant', preserve_range=True, anti_aliasing=False).astype(np.float32) 105 | resized_gt = resize(norm_gt, (256, 256, z), 106 | mode='constant', preserve_range=True, anti_aliasing=False, order=0).astype(np.uint8) 107 | resized_gt[np.where(resized_gt == 2)] = 0 # Count those labels as background 108 | 109 | # Pre-compute the 3d distance map 110 | rx = dx * w / 256 111 | ry = dy * h / 256 112 | rz = dz 113 | # print(f"{flair.shape=}") 114 | # print(f"{(dx,dy,dz)=} {(rx,ry,rz)=}") 115 | 116 | one_hot_gt: Tensor = class2one_hot(torch.tensor(resized_gt[None, ...], dtype=torch.int64), K=2)[0] 117 | assert one_hot_gt.shape == (2, 256, 256, z), one_hot_gt.shape 118 | distmap: np.ndarray = one_hot2dist(one_hot_gt.numpy(), 119 | resolution=(rx, ry, rz), 120 | dtype=np.float32) 121 | 122 | save_dir_flair: Path = 
Path(dest_dir, flair_dir) 123 | save_dir_t1: Path = Path(dest_dir, t1_dir) 124 | save_dir_gt: Path = Path(dest_dir, gt_dir) 125 | save_dir_in_npy: Path = Path(dest_dir, in_npy_dir) 126 | save_dir_gt_npy: Path = Path(dest_dir, gt_npy_dir) 127 | save_dir_distmap_npy: Path = Path(dest_dir, "3d_distmap") 128 | save_dirs = [save_dir_flair, save_dir_t1, save_dir_gt] 129 | 130 | for j in range(flair.shape[-1]): 131 | flair_s = resized_flair[:, :, j] 132 | t1_s = resized_t1[:, :, j] 133 | gt_s = resized_gt[:, :, j] 134 | 135 | dist_s = distmap[:, :, :, j] 136 | # if gt_s.sum() > 0: 137 | # print(f"{dist_s.min()=} {dist_s.max()=}") 138 | # _, axes = plt.subplots(nrows=1, ncols=2) 139 | # axes[0].imshow(gt_s) 140 | # axes[0].set_title("GT") 141 | 142 | # tmp = axes[1].imshow(dist_s[1, ...]) 143 | # axes[1].set_title("Distance map") 144 | # plt.colorbar(tmp, ax=axes[1]) 145 | # plt.show() 146 | 147 | slices = [flair_s, t1_s, gt_s] 148 | assert flair_s.shape == t1_s.shape == gt_s.shape == dist_s[0, ...].shape, ((x, y, z), flair_s.shape, dist_s.shape) 149 | # gt_s[np.where(gt_s == 2)] = 0 # Now do that part earlier 150 | assert set(np.unique(gt_s)).issubset([0, 1]), np.unique(gt_s) 151 | 152 | if discard_negatives and (gt_s.sum() == 0): 153 | continue 154 | 155 | for k in range(n_augment + 1): 156 | if k == 0: 157 | to_save = slices 158 | else: 159 | to_save = map_(np.asarray, augment_arr(*slices)) 160 | assert to_save[0].shape == slices[0].shape, (to_save[0].shape, slices[0].shape) 161 | 162 | filename = f"{p_id}_{k:02d}_{j:04d}" 163 | space_dict[filename] = (rx, ry) 164 | for save_dir, data in zip(save_dirs, to_save): 165 | save_dir.mkdir(parents=True, exist_ok=True) 166 | 167 | if "gt" not in str(save_dir): 168 | img = (data * 255).astype(np.uint8) 169 | else: 170 | img = data.astype(np.uint8) 171 | 172 | with warnings.catch_warnings(): 173 | warnings.filterwarnings("ignore", category=UserWarning) 174 | imsave(str(Path(save_dir, filename).with_suffix(".png")), img) 175 | 176 | multimodal = np.stack(to_save[:-1]) # Do not include the ground truth 177 | assert 0 <= multimodal.min(), multimodal.min() 178 | assert multimodal.max() <= 1, multimodal.max() 179 | save_dir_in_npy.mkdir(parents=True, exist_ok=True) 180 | save_dir_gt_npy.mkdir(parents=True, exist_ok=True) 181 | np.save(Path(save_dir_in_npy, filename).with_suffix(".npy"), multimodal) 182 | np.save(Path(save_dir_gt_npy, filename).with_suffix(".npy"), to_save[-1]) 183 | 184 | save_dir_distmap_npy.mkdir(parents=True, exist_ok=True) 185 | np.save(Path(save_dir_distmap_npy, filename).with_suffix(".npy"), dist_s) 186 | 187 | return neg, pos, space_dict, (rx, ry, rz) 188 | 189 | 190 | def main(args: argparse.Namespace): 191 | src_path: Path = Path(args.source_dir) 192 | dest_path: Path = Path(args.dest_dir) 193 | 194 | # Assume the cleaning up is done before calling the script 195 | assert src_path.exists() 196 | assert not dest_path.exists() 197 | 198 | # Get all the file names, avoid the temporal ones 199 | all_paths: List[Path] = list(src_path.rglob('*.nii.gz')) 200 | nii_paths: List[Path] = [p for p in all_paths if "_4D" not in str(p)] 201 | assert len(nii_paths) % 3 == 0, "Number of .nii not multiple of 6, some pairs are broken" 202 | 203 | # We sort now, but also id matching is checked while iterating later on 204 | flair_nii_paths: List[Path] = sorted(p for p in nii_paths if "FLAIR" in str(p)) 205 | t1_nii_paths: List[Path] = sorted(p for p in nii_paths if "T1" in str(p)) 206 | gt_nii_paths: List[Path] = sorted(p for p in nii_paths if "wmh.nii" 
in str(p)) 207 | assert len(flair_nii_paths) == len(t1_nii_paths) == len(gt_nii_paths) 208 | paths: List[Tuple[Path, ...]] = list(zip(flair_nii_paths, t1_nii_paths, gt_nii_paths)) 209 | 210 | print(f"Found {len(flair_nii_paths)} pairs in total") 211 | pprint(paths[:2]) 212 | 213 | resolution_dict: Dict[str, Tuple[float, float, float]] = {} 214 | 215 | validation_paths: List[Tuple[Path, ...]] = random.sample(paths, args.retain) 216 | training_paths: List[Tuple[Path, ...]] = [p for p in paths if p not in validation_paths] 217 | assert set(validation_paths).isdisjoint(set(training_paths)) 218 | assert len(paths) == (len(validation_paths) + len(training_paths)) 219 | 220 | for mode, _paths, n_augment in zip(["train", "val"], [training_paths, validation_paths], [args.n_augment, 0]): 221 | three_paths = list(zip(*_paths)) 222 | 223 | dest_dir = Path(dest_path, mode) 224 | print(f"Slicing {len(three_paths[0])} pairs to {dest_dir}") 225 | assert len(set(map_(len, three_paths))) == 1 226 | 227 | pfun = partial(save_slices, dest_dir=dest_dir, shape=args.shape, n_augment=n_augment, 228 | discard_negatives=args.discard_negatives) 229 | sizes = mmap_(uc_(pfun), zip(*three_paths)) 230 | # sizes = map_(uc_(pfun), zip(*three_paths)) 231 | resolutions: List[Tuple[float, float, float]] 232 | all_neg, all_pos, space_dicts, resolutions = zip(*sizes) 233 | neg, pos = sum(all_neg), sum(all_pos) 234 | ratio = pos / neg 235 | print(f"Ratio between pos/neg: {ratio} ({pos}/{neg})") 236 | 237 | final_dict = {k: v for space_dict in space_dicts for k, v in space_dict.items()} 238 | 239 | for key, val in zip(map_(get_p_id, three_paths[0]), resolutions): 240 | resolution_dict[key] = val 241 | 242 | with open(Path(dest_dir, "spacing.pkl"), 'wb') as f: 243 | pickle.dump(final_dict, f, pickle.HIGHEST_PROTOCOL) 244 | print(f"Saved spacing dictionnary to {f}") 245 | 246 | # for case_paths in tqdm(list(zip(*three_paths)), ncols=50): 247 | # uc_(pfun)(case_paths) 248 | 249 | # from pprint import pprint 250 | assert len(resolution_dict.keys()) == len(flair_nii_paths) 251 | pprint(resolution_dict) 252 | 253 | with open(dest_path / "spacing_3d.pkl", 'wb') as f: 254 | pickle.dump(resolution_dict, f, pickle.HIGHEST_PROTOCOL) 255 | print(f"Saved spacing dictionnary to {f}") 256 | 257 | 258 | def get_args() -> argparse.Namespace: 259 | parser = argparse.ArgumentParser(description='Slicing parameters') 260 | parser.add_argument('--source_dir', type=str, required=True) 261 | parser.add_argument('--dest_dir', type=str, required=True) 262 | parser.add_argument('--img_dir', type=str, default="IMG") 263 | parser.add_argument('--gt_dir', type=str, default="GT") 264 | parser.add_argument('--shape', type=int, nargs="+", default=[256, 256]) 265 | parser.add_argument('--retain', type=int, default=25, help="Number of retained patient for the validation data") 266 | parser.add_argument('--seed', type=int, default=0) 267 | parser.add_argument('--n_augment', type=int, default=0, 268 | help="Number of augmentation to create per image, only for the training set") 269 | parser.add_argument('--discard_negatives', action='store_true') 270 | args = parser.parse_args() 271 | random.seed(args.seed) 272 | 273 | print(args) 274 | 275 | return args 276 | 277 | 278 | if __name__ == "__main__": 279 | args = get_args() 280 | random.seed(args.seed) 281 | 282 | main(args) 283 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Boundary 
loss 2 | Official repository for [Boundary loss for highly unbalanced segmentation](http://proceedings.mlr.press/v102/kervadec19a.html), _runner-up for best paper award_ at [MIDL 2019](https://2019.midl.io). Recording of the talk is available [on the MIDL YouTube channel](https://www.youtube.com/watch?v=_z6gmFlD_qE). 3 | 4 | A journal extension has been published in [Medical Image Analysis (MedIA), volume 67](https://doi.org/10.1016/j.media.2020.101851). 5 | 6 | * [MIDL 2019 Proceedings](http://proceedings.mlr.press/v102/kervadec19a.html) 7 | * [MedIA volume](https://doi.org/10.1016/j.media.2020.101851) 8 | * [arXiv preprint](https://arxiv.org/abs/1812.07032) 9 | 10 | ![Visual comparison](resources/readme_comparison.png) 11 | 12 | The code has been simplified and updated to the latest Python and Pytorch release. On top of the original ISLES and WMH datasets, we also include a working example in a multi-class setting (ACDC dataset), *where the boundary loss can work as a stand-alone loss*. 13 | 14 | ## Table of contents 15 | * [Table of contents](#table-of-contents) 16 | * [License](#license) 17 | * [Requirements (PyTorch)](#requirements-pytorch) 18 | * [Other frameworks](#other-frameworks) 19 | * [Keras/tensorflow](#kerastensorflow) 20 | * [Others](#others) 21 | * [Usage](#usage) 22 | * [Extension to 3D](#extension-to-3d) 23 | * [Automation](#automation) 24 | * [Data scheme](#data-scheme) 25 | * [dataset](#dataset) 26 | * [results](#results) 27 | * [Cool tricks](#cool-tricks) 28 | * [Multi-class setting](#multi-class-setting) 29 | * [Frequently asked questions](#frequently-asked-question) 30 | * [Can the loss be negative?](#can-the-loss-be-negative) 31 | * [Do I need to normalize the distance map?](#do-i-need-to-normalize-the-distance-map) 32 | * [Other papers using boundary loss](#other-papers-using-boundary-loss) 33 | 34 | ## License 35 | This code is under [MIT license](LICENSE), which permits re-use both in open and closed-source software. 36 | 37 | ## Requirements (PyTorch) 38 | Core implementation (to integrate the boundary loss into your own code): 39 | * python3.5+ 40 | * pytorch 1.0+ 41 | * scipy (any version) 42 | 43 | To reproduce our experiments: 44 | * python3.9+ 45 | * Pytorch 1.7+ 46 | * nibabel (only when slicing 3D volumes) 47 | * Scipy 48 | * NumPy 49 | * Matplotlib 50 | * Scikit-image 51 | * zsh 52 | 53 | ## Other frameworks 54 | ### Keras/Tensorflow 55 | @akamojo and @marcinkaczor proposed a Keras/Tensorflow implementation (I am very grateful for that), available in [keras_loss.py](keras_loss.py). 56 | The discussion is available in the [related github issue](https://github.com/LIVIAETS/surface-loss/issues/14). 57 | 58 | ### Others 59 | People willing to contribute other implementations can create a new pull-request, for their favorite framework. 60 | 61 | ## Usage 62 | The boundary loss, at its core, is a pixel-wise multiplication between the network predictions (the _softmaxes_), and a pre-computed distance map. Henceforth, a big chunk of the implementation happens at the data-loader level. 63 | 64 | The implementation has three key functions: 65 | * the boundary loss itself (`BoundaryLoss` in [losses.py#L76](losses.py#L76)); 66 | * the distance map function (`one_hot2dist` in [utils.py#L260](utils.py#L260)); 67 | * the dataloader transforms (`dist_map_transform` in [dataloader.py#L105](dataloader.py#L105)). 
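To make that first item concrete, here is a minimal sketch of the core computation, assuming the `bkwh` (batch, class, width, height) layout used in the snippets below; the reduction to a scalar (a plain mean here) and the sanity checks are simplified compared to the actual `BoundaryLoss` in [losses.py](losses.py):
```python
import torch
from torch import Tensor, einsum


class BoundaryLossSketch:
    """Illustrative sketch only; see losses.py for the maintained implementation."""
    def __init__(self, idc: list[int]):
        self.idc = idc  # indices of the classes supervised by the boundary loss

    def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
        # probs: softmax output of the network, shape (b, k, w, h)
        # dist_maps: pre-computed signed distance maps, same shape
        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = einsum("bkwh,bkwh->bkwh", pc, dc)  # pixel-wise product

        return multipled.mean()
```
The `idc` argument simply selects which classes are supervised, as in the multi-class example further down.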
68 | 
69 | This codebase computes the distance map at the dataloader level, taking as input the label file (stored as a `.png`), putting it through the `dist_map_transform`, and then returning it with the corresponding input image. A higher-order view of the process:
70 | ```python
71 | class SliceSet(Dataset):
72 |     def __init__(self):
73 |         self.filenames: list[str]  # You get the list as you would usually
74 |         self.dataset_root: Path  # Path to the root of the data
75 | 
76 |         self.disttransform = dist_map_transform([1, 1], 2)
77 | 
78 |     def __getitem__(self, index: int) -> dict[str, Tensor]:
79 |         filename: str = self.filenames[index]
80 | 
81 |         image = Image.open(self.dataset_root / "img" / filename)
82 |         label = Image.open(self.dataset_root / "gt" / filename)
83 | 
84 |         image_tensor: Tensor  # usual transform for an image
85 |         one_hot_tensor: Tensor  # usual transform from png to one-hot encoding
86 |         dist_map_tensor: Tensor = self.disttransform(label)
87 | 
88 |         return {"images": image_tensor,
89 |                 "gt": one_hot_tensor,
90 |                 "dist_map": dist_map_tensor}
91 | ```
92 | 
93 | In the main loop (when iterating over the dataloader), this gives the following pseudo-code:
94 | ```python
95 | dice_loss = GeneralizedDiceLoss(idc=[0, 1])  # add here the extra params for the losses
96 | boundary_loss = BoundaryLoss(idc=[1])
97 | 
98 | α = 0.01
99 | for data in loader:
100 |     image: Tensor = data["images"]
101 |     target: Tensor = data["gt"]
102 |     dist_map_label: Tensor = data["dist_map"]
103 | 
104 |     pred_logits: Tensor = net(image)
105 |     pred_probs: Tensor = F.softmax(pred_logits, dim=1)
106 | 
107 |     gdl_loss = dice_loss(pred_probs, target)
108 |     bl_loss = boundary_loss(pred_probs, dist_map_label)  # Notice we do not give the same input to that loss
109 |     total_loss = gdl_loss + α * bl_loss
110 | 
111 |     total_loss.backward()
112 |     optimizer.step()
113 | ```
114 | 
115 | Special care has to be taken when the spatial resolution varies across axes; this is especially true in 3D. There is an optional argument in the `one_hot2dist` function, and the `dist_map_transform` is parametrized with the resolution and number of classes. For instance:
116 | ```python
117 | disttransform = dist_map_transform([1, 1], 2)
118 | ```
119 | will instantiate a distance transform for a binary setting, with 1 mm on each axis, while
120 | ```python
121 | disttransform = dist_map_transform([0.97, 1, 2.5], 5)
122 | ```
123 | would be for a 5-class setting, with a much coarser resolution along the `z` axis than along `x` or `y`.
124 | 
125 | When dealing with a distance map in 3D, it is easiest to compute it while slicing the 3D volume into 2D images. An example of such processing is done in [preprocess/slice_wmh.py#L94](preprocess/slice_wmh.py#L94).
126 | 
127 | ## Extension to 3D
128 | Extension to a 3D-CNN is trivial: one only needs to pre-compute the 3D distance map, and to sub-patch it in the traditional fashion.
129 | 
130 | The code of the boundary loss remains the same, except for the einsum (line #89) that accounts for the extra axis (`xyz` in place of `wh`):
131 | ```python
132 | multipled = einsum("bkxyz,bkxyz->bkxyz", pc, dc)
133 | ```
134 | 
135 | ## Automation
136 | Experiments are handled by [GNU Make](https://en.wikipedia.org/wiki/Make_(software)). It should be installed on pretty much any machine.
137 | 
138 | Instructions to download the data are contained in the lineage files [ISLES.lineage](data/ISLES.lineage) and [wmh.lineage](data/wmh.lineage). They are text files containing the md5sums of the original zip files.
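If you want to check an archive by hand before launching `make` (the makefiles already do it with `md5sum -c`), a rough Python equivalent could look like the following; it assumes the lineage files use the standard `md5sum` format of one `<hash>  <path>` entry per line, with paths relative to the repository root:
```python
import hashlib
from pathlib import Path

for line in Path("data/wmh.lineage").read_text().splitlines():
    if not line.strip():
        continue
    expected, _, zip_path = line.strip().partition("  ")
    actual = hashlib.md5(Path(zip_path).read_bytes()).hexdigest()
    print(zip_path, "OK" if actual == expected else "MISMATCH")
```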
139 | 
140 | Once the zip is in place, everything should be automatic:
141 | ```sh
142 | make -f isles.make
143 | make -f wmh.make
144 | ```
145 | Each makefile usually takes a little more than a day to run.
146 | 
147 | This performs the following steps, in order:
148 | * unpacking of the data;
149 | * removal of unwanted big files;
150 | * normalization and slicing of the data;
151 | * training with the different methods;
152 | * plotting of the metrics curves;
153 | * display of a report;
154 | * archiving of the results in a .tar.gz stored in the `archives` folder.
155 | 
156 | Make will handle the dependencies between the different parts by itself. For instance, once the data has been pre-processed, it won't do it another time, even if you delete the training results. It is also a good way to avoid overwriting existing results by accident.
157 | 
158 | Of course, parts can be launched separately:
159 | ```sh
160 | make -f isles.make data/isles # Unpack only
161 | make -f isles.make data/ISLES # unpack if needed, then slice the data
162 | make -f isles.make results/isles/gdl # train only with the GDL. Create the data if needed
163 | make -f isles.make results/isles/val_dice.png # Create only this plot. Do the trainings if needed
164 | ```
165 | There are many options for the main script, because I use the same code-base for other projects. You can safely ignore most of them, and the different recipes in the makefiles should give you an idea of how to modify the training settings and create new targets. In case of questions, feel free to contact me.
166 | 
167 | ### Data scheme
168 | #### dataset
169 | For instance:
170 | ```
171 | ISLES/
172 |     train/
173 |         cbf/
174 |             case_10_0_0.png
175 |             ...
176 |         cbv/
177 |         gt/
178 |         in_npy/
179 |             case_10_0_0.npy
180 |             ...
181 |         gt_npy/
182 |         ...
183 |     val/
184 |         cbf/
185 |             case_10_0_0.png
186 |             ...
187 |         cbv/
188 |         gt/
189 |         in_npy/
190 |             case_10_0_0.npy
191 |             ...
192 |         gt_npy/
193 |         ...
194 | ```
195 | The network takes npy files as an input (there are multiple modalities), but images for each modality are saved for convenience. The gt folder contains gray-scale images of the ground truth, where the gray-scale levels are the class numbers (namely, 0 and 1). This is because I often use my [segmentation viewer](https://github.com/HKervadec/segmentation_viewer) to visualize the results, so that does not really matter. If you want to see it directly in an image viewer, you can either use the remap script, or use ImageMagick:
196 | ```
197 | mogrify -normalize data/ISLES/val/gt/*.png
198 | ```
199 | 
200 | #### results
201 | ```
202 | results/
203 |     isles/
204 |         gdl/
205 |             best_epoch/
206 |                 val/
207 |                     case_10_0_0.png
208 |                     ...
209 |             iter000/
210 |                 val/
211 |             ...
212 |             best.pkl # best model saved
213 |             metrics.csv # metrics over time, csv
214 |             best_epoch.txt # number of the best epoch
215 |             val_dice.npy # log of all the metric over time for each image and class
216 |         gdl_surface_steal/
217 |             ...
218 |         val_dice.png # Plot over time comparing different methods
219 |         ...
220 |     wmh/
221 |         ...
222 | archives/
223 |     $(REPO)-$(DATE)-$(HASH)-$(HOSTNAME)-isles.tar.gz
224 |     $(REPO)-$(DATE)-$(HASH)-$(HOSTNAME)-wmh.tar.gz
225 | ```
226 | 
227 | ### Cool tricks
228 | Remove all assertions from the code.
Usually done after making sure it does not crash for one complete epoch:
229 | ```sh
230 | make -f isles.make CFLAGS=-O
231 | ```
232 | 
233 | Use a specific python executable:
234 | ```sh
235 | make -f isles.make CC=/path/to/the/executable
236 | ```
237 | 
238 | Train for only 5 epochs, with a dummy network, and only 10 images per data loader. Useful for debugging:
239 | ```sh
240 | make -f isles.make NET=Dimwit EPC=5 DEBUG=--debug
241 | ```
242 | 
243 | Rebuild everything even if it already exists:
244 | ```sh
245 | make -f isles.make -B
246 | ```
247 | 
248 | Only print the commands that will be run:
249 | ```sh
250 | make -f isles.make -n
251 | ```
252 | 
253 | Create a gif for the predictions over time of a specific patient:
254 | ```
255 | cd results/isles/gdl
256 | convert iter*/val/case_14_0_0.png case_14_0_0.gif
257 | mogrify -normalize case_14_0_0.gif
258 | ```
259 | 
260 | ## Multi-class setting
261 | The implementation for the multi-class setting is trivial and requires no modification: one only needs to change the `idc` parameter of the boundary loss to supervise all classes. In the case of ACDC (4 classes), we have:
262 | ```python
263 | boundary_loss = BoundaryLoss(idc=[0, 1, 2, 3])
264 | ```
265 | 
266 | ![Boundary loss as a stand-alone loss](resources/acdc_bl.png)
267 | 
268 | ## Frequently asked question
269 | ### Can the loss be negative?
270 | Yes. As the distance map is signed (meaning that inside the object, the distance is negative), a perfect prediction will sum only negative distances, leading to a negative value. As we are in a minimization setting, this is not an issue.
271 | 
272 | ### Do I need to normalize the distance map?
273 | Possibly; it is dataset dependent. In our experiments we did not have to, but several people reported that normalization helped in their respective applications.
274 | 
275 | ## Other papers using boundary loss
276 | If your paper uses the boundary loss, and you want to be added, feel free to drop us a message.
277 | 
278 | * How Distance Transform Maps Boost Segmentation CNNs: An Empirical Study, MIDL 2020, [Conference link](https://2020.midl.io/papers/ma20a.html), [code](https://github.com/JunMa11/SegWithDistMap)
279 | * Multi-modal U-Nets with Boundary Loss and Pre-training for Brain Tumor Segmentation, MICCAI Brainlesion Workshop 2020, [proceedings](https://link.springer.com/chapter/10.1007/978-3-030-46643-5_13)
280 | * Deep learning approach to left ventricular non-compaction measurement, pre-print 2020, [arXiv](https://arxiv.org/abs/2011.14773)
281 | * Esophageal Tumor Segmentation in CT Images using Dilated Dense Attention Unet (DDAUnet), pre-print 2020, [arXiv](https://arxiv.org/abs/2012.03242), [code](https://github.com/yousefis/DenseUnet_Esophagus_Segmentation)
282 | * Vehicle lane markings segmentation and keypoint determination using deep convolutional neural networks, Multimedia Tools and Applications (2021), [journal link](https://link.springer.com/article/10.1007/s11042-020-10248-2)
283 | * A global method to identify trees outside of closed-canopy forests with medium-resolution satellite imagery, International Journal of Remote Sensing (2021), [DOI](https://doi.org/10.1080/01431161.2020.1841324), [code](https://github.com/wri/restoration-mapper)
-------------------------------------------------------------------------------- /release.md: -------------------------------------------------------------------------------- 1 | # 1.0 release
2 | 
3 | Differences:
4 | * 3D computation of the distance maps, done at pre-processing time
5 | * Other loss implementations, for comparison
6 | * Example of a multi-class setting (ACDC dataset)
7 | * Improved readme and instructions
8 | * Updated for the latest python and pytorch releases
9 | * Removed all remnants from our constrained-cnn works (shared codebase)
10 | * Added colors to the makefiles, for improved readability
11 | * More flexible makefiles, that allow separate results folders (through the RD environment variable)
12 | * Submodule for the viewer, pointing to https://github.com/HKervadec/segmentation_viewer
13 | * The training recipes now store the commit hash and diff, to track the exact code version
-------------------------------------------------------------------------------- /report.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9
2 | 
3 | # MIT License
4 | 
5 | # Copyright (c) 2023 Hoel Kervadec
6 | 
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | 
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | 
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | import argparse 26 | from pathlib import Path 27 | 28 | import numpy as np 29 | 30 | 31 | def main(args) -> None: 32 | print(f"Reporting on {len(args.folders)} folders.") 33 | 34 | main_metric: str = args.metrics[0] 35 | 36 | best_epoch: list[int] = display_metric(args, main_metric, args.folders, args.axises) 37 | for metric in args.metrics[1:]: 38 | display_metric(args, metric, args.folders, args.axises, best_epoch) 39 | 40 | 41 | def display_metric(args, metric: str, folders: list[str], axises: tuple[int], best_epoch: list[int] = None): 42 | print(f"{metric} (classes {axises})") 43 | 44 | if not best_epoch: 45 | get_epoch = True 46 | best_epoch = [0] * len(folders) 47 | else: 48 | get_epoch = False 49 | 50 | for i, folder in enumerate(folders): 51 | file: Path = Path(folder, metric).with_suffix(".npy") 52 | data: np.ndarray = np.load(file)[:, :, axises] # Epoch, sample, classes 53 | averages: np.ndarray = data.mean(axis=(1, 2)) 54 | stds: np.ndarray = data.std(axis=(1, 2)) 55 | 56 | class_wise_avg: np.ndarray = data.mean(axis=1) 57 | class_wise_std: np.ndarray = data.std(axis=1) 58 | 59 | if get_epoch: 60 | if args.mode == "max": 61 | best_epoch[i] = np.argmax(averages) 62 | elif args.mode == "min": 63 | best_epoch[i] = np.argmin(averages) 64 | 65 | val: float 66 | val_std: float 67 | if args.mode in ['max', 'min']: 68 | val = averages[best_epoch[i]] 69 | val_std = stds[best_epoch[i]] 70 | val_class_wise = class_wise_avg[best_epoch[i]] 71 | else: 72 | val = averages[-args.last_n_epc:].mean() 73 | val_std = averages[-args.last_n_epc:].std() 74 | val_class_wise = class_wise_avg[-args.last_n_epc:].mean(axis=0) 75 | 76 | assert val_class_wise.shape == (len(axises),) 77 | 78 | precision: int = args.precision 79 | print(f"\t{Path(folder).name}: {val:.{precision}f} ({val_std:.{precision}f}) at epoch {best_epoch[i]}") 80 | if len(axises) > 1 and args.detail_axises: 81 | val_cw_std = class_wise_std[best_epoch[i]] 82 | assert val_cw_std.shape == (len(axises),) 83 | 84 | # print(f"\t\t {' '.join(f'{a}={val_class_wise[j]:.{precision}f}' for j,a in enumerate(axises))}") 85 | print(f"\t\t {' '.join(f'{a}={val_class_wise[j]:.{precision}f} ({val_cw_std[j]:.{precision}f})' for j,a in enumerate(axises))}") 86 | 87 | return best_epoch 88 | 89 | 90 | def get_args() -> argparse.Namespace: 91 | parser = argparse.ArgumentParser(description='Plot data over time') 92 | parser.add_argument('--folders', type=str, required=True, nargs='+', help="The folders containing the file") 93 | parser.add_argument('--metrics', type=str, required=True, nargs='+') 94 | parser.add_argument('--axises', type=int, required=True, nargs='+') 95 | parser.add_argument('--mode', type=str, default='max', choices=['max', 'min', 'avg']) 96 | parser.add_argument('--last_n_epc', type=int, default=1) 97 | parser.add_argument('--precision', type=int, default=4) 98 | parser.add_argument('--debug', action='store_true', help="Dummy for compatibility.") 99 | 100 | parser.add_argument('--detail_axises', action='store_true', 101 | help="Print each axis value on top of the mean") 102 | 103 | args = parser.parse_args() 104 | 105 | print(args) 106 | 107 | return args 108 | 109 | 110 | if __name__ == "__main__": 111 | main(get_args()) 112 | 
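# Example invocation (illustrative: the folder and metric names follow the results
# layout described in readme.md, where each run directory contains e.g. val_dice.npy):
#   python3.9 report.py --folders results/isles/gdl results/isles/gdl_surface_steal \
#       --metrics val_dice --axises 1 --mode max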
-------------------------------------------------------------------------------- /resources/acdc_bl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LIVIAETS/boundary-loss/171c32d88a4ce59af8be46fb88b96d3637b9515b/resources/acdc_bl.png -------------------------------------------------------------------------------- /resources/readme_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LIVIAETS/boundary-loss/171c32d88a4ce59af8be46fb88b96d3637b9515b/resources/readme_comparison.png -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
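# The schedulers below rebalance the per-loss weights between epochs. As a worked
# example (numbers illustrative): StealWeight(0.01) applied to loss_weights=[[1.0, 0.01]]
# gives [[0.99, 0.02]] after one epoch, [[0.98, 0.03]] after two, and so on, with the
# first weight floored at 0.1; AddWeightLoss instead adds fixed per-loss increments.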
24 | 25 | from typing import Any, Callable, Tuple 26 | from operator import add 27 | 28 | from utils import map_, uc_ 29 | 30 | 31 | class DummyScheduler(object): 32 | def __call__(self, epoch: int, optimizer: Any, loss_fns: list[list[Callable]], loss_weights: list[list[float]]) \ 33 | -> Tuple[float, list[list[Callable]], list[list[float]]]: 34 | return optimizer, loss_fns, loss_weights 35 | 36 | 37 | class AddWeightLoss(): 38 | def __init__(self, to_add: list[float]): 39 | self.to_add: list[float] = to_add 40 | 41 | def __call__(self, epoch: int, optimizer: Any, loss_fns: list[list[Callable]], loss_weights: list[list[float]]) \ 42 | -> Tuple[float, list[list[Callable]], list[list[float]]]: 43 | assert len(self.to_add) == len(loss_weights[0]) 44 | if len(loss_weights) > 1: 45 | raise NotImplementedError 46 | new_weights: list[list[float]] = map_(lambda w: map_(uc_(add), zip(w, self.to_add)), loss_weights) 47 | 48 | print(f"Loss weights went from {loss_weights} to {new_weights}") 49 | 50 | return optimizer, loss_fns, new_weights 51 | 52 | 53 | class StealWeight(): 54 | def __init__(self, to_steal: float): 55 | self.to_steal: float = to_steal 56 | 57 | def __call__(self, epoch: int, optimizer: Any, loss_fns: list[list[Callable]], loss_weights: list[list[float]]) \ 58 | -> Tuple[float, list[list[Callable]], list[list[float]]]: 59 | new_weights: list[list[float]] = [[max(0.1, a - self.to_steal), b + self.to_steal] for a, b in loss_weights] 60 | 61 | print(f"Loss weights went from {loss_weights} to {new_weights}") 62 | 63 | return optimizer, loss_fns, new_weights 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
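# These unit tests only exercise the helpers in utils.py (dice, Hausdorff distance,
# signed distance maps); run them from the repository root, e.g. with `python3 test.py`
# or `python3 -m unittest test`.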
24 | 25 | import unittest 26 | 27 | import torch 28 | import numpy as np 29 | 30 | import utils 31 | 32 | class TestDice(unittest.TestCase): 33 | def test_equal(self): 34 | t = torch.zeros((1, 100, 100), dtype=torch.int64) 35 | t[0, 40:60, 40:60] = 1 36 | 37 | c = utils.class2one_hot(t, K=2) 38 | 39 | self.assertEqual(utils.dice_coef(c, c)[0, 0], 1) 40 | 41 | def test_empty(self): 42 | t = torch.zeros((1, 100, 100), dtype=torch.int64) 43 | t[0, 40:60, 40:60] = 1 44 | 45 | c = utils.class2one_hot(t, K=2) 46 | 47 | self.assertEqual(utils.dice_coef(c, c)[0, 0], 1) 48 | 49 | def test_caca(self): 50 | t = torch.zeros((1, 100, 100), dtype=torch.int64) 51 | t[0, 40:60, 40:60] = 1 52 | 53 | c = utils.class2one_hot(t, K=2) 54 | z = torch.zeros_like(c) 55 | z[0, 1, ...] = 1 56 | 57 | self.assertEqual(utils.dice_coef(c, z, smooth=0)[0, 0], 0) # Annoying to deal with the almost equal thing 58 | 59 | 60 | class TestHausdorff(unittest.TestCase): 61 | def test_closure(self): 62 | t = torch.zeros((1, 256, 256), dtype=torch.int64) 63 | t[0, 50:60, :] = 1 64 | 65 | t2 = utils.class2one_hot(t, K=2) 66 | self.assertEqual(tuple(t2.shape), (1, 2, 256, 256)) 67 | 68 | self.assertTrue(torch.equal(utils.hausdorff(t2, t2), torch.zeros((1, 2)))) 69 | 70 | def test_empty(self): 71 | t = torch.zeros((1, 256, 256), dtype=torch.int64) 72 | 73 | t2 = utils.class2one_hot(t, K=2) 74 | self.assertEqual(tuple(t2.shape), (1, 2, 256, 256)) 75 | 76 | self.assertTrue(torch.equal(utils.hausdorff(t2, t2), torch.zeros((1, 2)))) 77 | 78 | def test_caca(self): 79 | t = torch.zeros((1, 256, 256), dtype=torch.int64) 80 | t[0, 50:60, :] = 1 81 | 82 | t2 = utils.class2one_hot(t, K=2) 83 | self.assertEqual(tuple(t2.shape), (1, 2, 256, 256)) 84 | 85 | z = torch.zeros_like(t) 86 | z2 = utils.class2one_hot(z, K=2) 87 | 88 | diag = (256**2 + 256**2) ** 0.5 89 | # print(f"{diag=}") 90 | # print(f"{utils.hausdorff(z2, t2)=}") 91 | 92 | self.assertTrue(torch.equal(utils.hausdorff(z2, t2), 93 | torch.tensor([[60, diag]], dtype=torch.float32))) 94 | 95 | def test_proper(self): 96 | t = torch.zeros((1, 256, 256), dtype=torch.int64) 97 | t[0, 50:60, :] = 1 98 | 99 | t2 = utils.class2one_hot(t, K=2) 100 | self.assertEqual(tuple(t2.shape), (1, 2, 256, 256)) 101 | 102 | z = torch.zeros_like(t) 103 | z[0, 80:90, :] = 1 104 | z2 = utils.class2one_hot(z, K=2) 105 | 106 | self.assertTrue(torch.equal(utils.hausdorff(z2, t2), 107 | torch.tensor([[30, 30]], dtype=torch.float32))) 108 | 109 | 110 | class TestDistMap(unittest.TestCase): 111 | def test_closure(self): 112 | a = np.zeros((1, 256, 256)) 113 | a[:, 50:60, :] = 1 114 | 115 | o = utils.class2one_hot(torch.Tensor(a).type(torch.int64), K=2).numpy() 116 | res = utils.one_hot2dist(o[0]) 117 | self.assertEqual(res.shape, (2, 256, 256)) 118 | 119 | neg = (res <= 0) * res 120 | 121 | self.assertEqual(neg.sum(), (o * res).sum()) 122 | 123 | def test_full_coverage(self): 124 | a = np.zeros((1, 256, 256)) 125 | a[:, 50:60, :] = 1 126 | 127 | o = utils.class2one_hot(torch.Tensor(a).type(torch.int64), K=2).numpy() 128 | res = utils.one_hot2dist(o[0]) 129 | self.assertEqual(res.shape, (2, 256, 256)) 130 | 131 | self.assertEqual((res[1] <= 0).sum(), a.sum()) 132 | self.assertEqual((res[1] > 0).sum(), (1 - a).sum()) 133 | 134 | def test_empty(self): 135 | a = np.zeros((1, 256, 256)) 136 | 137 | o = utils.class2one_hot(torch.Tensor(a).type(torch.int64), K=2).numpy() 138 | res = utils.one_hot2dist(o[0]) 139 | self.assertEqual(res.shape, (2, 256, 256)) 140 | 141 | self.assertEqual(res[1].sum(), 0) 142 | 
self.assertEqual((res[0] <= 0).sum(), a.size) 143 | 144 | def test_max_dist(self): 145 | """ 146 | The max dist for a box should be at the midle of the object, +-1 147 | """ 148 | a = np.zeros((1, 256, 256)) 149 | a[:, 1:254, 1:254] = 1 150 | 151 | o = utils.class2one_hot(torch.Tensor(a).type(torch.int64), K=2).numpy() 152 | res = utils.one_hot2dist(o[0]) 153 | self.assertEqual(res.shape, (2, 256, 256)) 154 | 155 | self.assertEqual(res[0].max(), 127) 156 | self.assertEqual(np.unravel_index(res[0].argmax(), (256, 256)), (127, 127)) 157 | 158 | self.assertEqual(res[1].min(), -126) 159 | self.assertEqual(np.unravel_index(res[1].argmin(), (256, 256)), (127, 127)) 160 | 161 | def test_border(self): 162 | """ 163 | Make sure the border inside the object is 0 in the distance map 164 | """ 165 | 166 | for l in range(3, 5): 167 | a = np.zeros((1, 25, 25)) 168 | a[:, 3:3 + l, 3:3 + l] = 1 169 | 170 | o = utils.class2one_hot(torch.Tensor(a).type(torch.int64), K=2).numpy() 171 | res = utils.one_hot2dist(o[0]) 172 | self.assertEqual(res.shape, (2, 25, 25)) 173 | 174 | border = (res[1] == 0) 175 | 176 | self.assertEqual(border.sum(), 4 * (l - 1)) 177 | 178 | 179 | if __name__ == "__main__": 180 | unittest.main() 181 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3.9 2 | 3 | # MIT License 4 | 5 | # Copyright (c) 2023 Hoel Kervadec 6 | 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
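# Typical round trip through the helpers below (mirroring test.py): start from a tensor
# of class indices, one-hot encode it, then derive the signed distance map consumed by
# the boundary loss.
#   seg = torch.zeros((1, 256, 256), dtype=torch.int64)   # class indices, shape (b, w, h)
#   one_hot_seg = class2one_hot(seg, K=2)                  # -> (1, 2, 256, 256)
#   dist = one_hot2dist(one_hot_seg[0].numpy())            # -> signed distances, (2, 256, 256)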
24 | 25 | import argparse 26 | from pathlib import Path 27 | from operator import add 28 | from multiprocessing.pool import Pool 29 | from random import random, uniform, randint 30 | from functools import partial 31 | 32 | from typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union, cast 33 | 34 | import torch 35 | import numpy as np 36 | import torch.sparse 37 | from tqdm import tqdm 38 | from torch import einsum 39 | from torch import Tensor 40 | from skimage.io import imsave 41 | from PIL import Image, ImageOps 42 | from medpy.metric.binary import hd 43 | from scipy.ndimage import distance_transform_edt as eucl_distance 44 | 45 | 46 | colors = ["c", "r", "g", "b", "m", 'y', 'k', 'chartreuse', 'coral', 'gold', 'lavender', 47 | 'silver', 'tan', 'teal', 'wheat', 'orchid', 'orange', 'tomato'] 48 | 49 | # functions redefinitions 50 | tqdm_ = partial(tqdm, dynamic_ncols=True, 51 | leave=False, 52 | bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [' '{rate_fmt}{postfix}]') 53 | 54 | A = TypeVar("A") 55 | B = TypeVar("B") 56 | T = TypeVar("T", Tensor, np.ndarray) 57 | 58 | 59 | def str2bool(v): 60 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 61 | return True 62 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 63 | return False 64 | else: 65 | raise argparse.ArgumentTypeError('Boolean value expected.') 66 | 67 | 68 | def map_(fn: Callable[[A], B], iter: Iterable[A]) -> List[B]: 69 | return list(map(fn, iter)) 70 | 71 | 72 | def mmap_(fn: Callable[[A], B], iter: Iterable[A]) -> List[B]: 73 | return Pool().map(fn, iter) 74 | 75 | 76 | def starmmap_(fn: Callable[[Tuple[A]], B], iter: Iterable[Tuple[A]]) -> List[B]: 77 | return Pool().starmap(fn, iter) 78 | 79 | 80 | def uc_(fn: Callable) -> Callable: 81 | return partial(uncurry, fn) 82 | 83 | 84 | def uncurry(fn: Callable, args: List[Any]) -> Any: 85 | return fn(*args) 86 | 87 | 88 | def id_(x): 89 | return x 90 | 91 | 92 | def flatten_(to_flat: Iterable[Iterable[A]]) -> List[A]: 93 | return [e for l in to_flat for e in l] 94 | 95 | 96 | def flatten__(to_flat): 97 | if type(to_flat) != list: 98 | return [to_flat] 99 | 100 | return [e for l in to_flat for e in flatten__(l)] 101 | 102 | 103 | def depth(e: List) -> int: 104 | """ 105 | Compute the depth of nested lists 106 | """ 107 | if type(e) == list and e: 108 | return 1 + depth(e[0]) 109 | 110 | return 0 111 | 112 | 113 | # fns 114 | def soft_size(a: Tensor) -> Tensor: 115 | return torch.einsum("bk...->bk", a)[..., None] 116 | 117 | 118 | def batch_soft_size(a: Tensor) -> Tensor: 119 | return torch.einsum("bk...->k", a)[..., None] 120 | 121 | 122 | # Assert utils 123 | def uniq(a: Tensor) -> Set: 124 | return set(torch.unique(a.cpu()).numpy()) 125 | 126 | 127 | def sset(a: Tensor, sub: Iterable) -> bool: 128 | return uniq(a).issubset(sub) 129 | 130 | 131 | def eq(a: Tensor, b) -> bool: 132 | return torch.eq(a, b).all() 133 | 134 | 135 | def simplex(t: Tensor, axis=1) -> bool: 136 | _sum = cast(Tensor, t.sum(axis).type(torch.float32)) 137 | _ones = torch.ones_like(_sum, dtype=torch.float32) 138 | return torch.allclose(_sum, _ones) 139 | 140 | 141 | def one_hot(t: Tensor, axis=1) -> bool: 142 | return simplex(t, axis) and sset(t, [0, 1]) 143 | 144 | 145 | # # Metrics and shitz 146 | def meta_dice(sum_str: str, label: Tensor, pred: Tensor, smooth: float = 1e-8) -> Tensor: 147 | assert label.shape == pred.shape 148 | assert one_hot(label) 149 | assert one_hot(pred) 150 | 151 | inter_size: Tensor = einsum(sum_str, [intersection(label, pred)]).type(torch.float32) 152 | sum_sizes: 
Tensor = (einsum(sum_str, [label]) + einsum(sum_str, [pred])).type(torch.float32) 153 | 154 | dices: Tensor = (2 * inter_size + smooth) / (sum_sizes + smooth) 155 | 156 | return dices 157 | 158 | 159 | dice_coef = partial(meta_dice, "bk...->bk") 160 | dice_batch = partial(meta_dice, "bk...->k") # used for 3d dice 161 | 162 | 163 | def intersection(a: Tensor, b: Tensor) -> Tensor: 164 | assert a.shape == b.shape 165 | assert sset(a, [0, 1]) 166 | assert sset(b, [0, 1]) 167 | 168 | res = a & b 169 | assert sset(res, [0, 1]) 170 | 171 | return res 172 | 173 | 174 | def union(a: Tensor, b: Tensor) -> Tensor: 175 | assert a.shape == b.shape 176 | assert sset(a, [0, 1]) 177 | assert sset(b, [0, 1]) 178 | 179 | res = a | b 180 | assert sset(res, [0, 1]) 181 | 182 | return res 183 | 184 | 185 | def inter_sum(a: Tensor, b: Tensor) -> Tensor: 186 | return einsum("bk...->bk", intersection(a, b).type(torch.float32)) 187 | 188 | 189 | def union_sum(a: Tensor, b: Tensor) -> Tensor: 190 | return einsum("bk...->bk", union(a, b).type(torch.float32)) 191 | 192 | 193 | def hausdorff(preds: Tensor, target: Tensor, spacing: Tensor = None) -> Tensor: 194 | assert preds.shape == target.shape 195 | assert one_hot(preds) 196 | assert one_hot(target) 197 | 198 | B, K, *img_shape = preds.shape 199 | 200 | if spacing is None: 201 | D: int = len(img_shape) 202 | spacing = torch.ones((B, D), dtype=torch.float32) 203 | 204 | assert spacing.shape == (B, len(img_shape)) 205 | 206 | res = torch.zeros((B, K), dtype=torch.float32, device=preds.device) 207 | n_pred = preds.cpu().numpy() 208 | n_target = target.cpu().numpy() 209 | n_spacing = spacing.cpu().numpy() 210 | 211 | for b in range(B): 212 | # print(spacing[b]) 213 | # if K == 2: 214 | # res[b, :] = hd(n_pred[b, 1], n_target[b, 1], voxelspacing=n_spacing[b]) 215 | # continue 216 | 217 | for k in range(K): 218 | if not n_target[b, k].any(): # No object to predict 219 | if n_pred[b, k].any(): # Predicted something nonetheless 220 | res[b, k] = sum((dd * d)**2 for (dd, d) in zip(n_spacing[b], img_shape)) ** 0.5 221 | continue 222 | else: 223 | res[b, k] = 0 224 | continue 225 | if not n_pred[b, k].any(): 226 | if n_target[b, k].any(): 227 | res[b, k] = sum((dd * d)**2 for (dd, d) in zip(n_spacing[b], img_shape)) ** 0.5 228 | continue 229 | else: 230 | res[b, k] = 0 231 | continue 232 | 233 | res[b, k] = hd(n_pred[b, k], n_target[b, k], voxelspacing=n_spacing[b]) 234 | 235 | return res 236 | 237 | 238 | # switch between representations 239 | def probs2class(probs: Tensor) -> Tensor: 240 | b, _, *img_shape = probs.shape 241 | assert simplex(probs) 242 | 243 | res = probs.argmax(dim=1) 244 | assert res.shape == (b, *img_shape) 245 | 246 | return res 247 | 248 | 249 | def class2one_hot(seg: Tensor, K: int) -> Tensor: 250 | # Breaking change but otherwise can't deal with both 2d and 3d 251 | # if len(seg.shape) == 3: # Only w, h, d, used by the dataloader 252 | # return class2one_hot(seg.unsqueeze(dim=0), K)[0] 253 | 254 | assert sset(seg, list(range(K))), (uniq(seg), K) 255 | 256 | b, *img_shape = seg.shape # type: Tuple[int, ...] 257 | 258 | device = seg.device 259 | res = torch.zeros((b, K, *img_shape), dtype=torch.int32, device=device).scatter_(1, seg[:, None, ...], 1) 260 | 261 | assert res.shape == (b, K, *img_shape) 262 | assert one_hot(res) 263 | 264 | return res 265 | 266 | 267 | def np_class2one_hot(seg: np.ndarray, K: int) -> np.ndarray: 268 | # print("Np enters") 269 | """ 270 | Seems to be blocking here when using multi-processing. 
271 | Don't know why, so for now I'll re-implement the same function in numpy 272 | which should be faster anyhow, but can introduce inconsistencies in the code 273 | so need to be careful. 274 | """ 275 | b, w, h = seg.shape 276 | res = np.zeros((b, K, w, h), dtype=np.int64) 277 | np.put_along_axis(res, seg[:, None, :, :], 1, axis=1) 278 | 279 | return res 280 | # return class2one_hot(torch.from_numpy(seg.copy()).type(torch.int64), K).numpy() 281 | 282 | 283 | def probs2one_hot(probs: Tensor) -> Tensor: 284 | _, K, *_ = probs.shape 285 | assert simplex(probs) 286 | 287 | res = class2one_hot(probs2class(probs), K) 288 | assert res.shape == probs.shape 289 | assert one_hot(res) 290 | 291 | return res 292 | 293 | 294 | def one_hot2dist(seg: np.ndarray, resolution: Tuple[float, float, float] = None, 295 | dtype=None) -> np.ndarray: 296 | assert one_hot(torch.tensor(seg), axis=0) 297 | K: int = len(seg) 298 | 299 | res = np.zeros_like(seg, dtype=dtype) 300 | for k in range(K): 301 | posmask = seg[k].astype(np.bool) 302 | 303 | if posmask.any(): 304 | negmask = ~posmask 305 | res[k] = eucl_distance(negmask, sampling=resolution) * negmask \ 306 | - (eucl_distance(posmask, sampling=resolution) - 1) * posmask 307 | # The idea is to leave blank the negative classes 308 | # since this is one-hot encoded, another class will supervise that pixel 309 | 310 | return res 311 | 312 | 313 | def one_hot2hd_dist(seg: np.ndarray, resolution: Tuple[float, float, float] = None, 314 | dtype=None) -> np.ndarray: 315 | """ 316 | Used for https://arxiv.org/pdf/1904.10030.pdf, 317 | implementation from https://github.com/JunMa11/SegWithDistMap 318 | """ 319 | # Relasx the assertion to allow computation live on only a 320 | # subset of the classes 321 | # assert one_hot(torch.tensor(seg), axis=0) 322 | K: int = len(seg) 323 | 324 | res = np.zeros_like(seg, dtype=dtype) 325 | for k in range(K): 326 | posmask = seg[k].astype(np.bool) 327 | 328 | if posmask.any(): 329 | res[k] = eucl_distance(posmask, sampling=resolution) 330 | 331 | return res 332 | 333 | 334 | # Misc utils 335 | def save_images(segs: Tensor, names: Iterable[str], root: str, mode: str, iter: int) -> None: 336 | for seg, name in zip(segs, names): 337 | save_path = Path(root, f"iter{iter:03d}", mode, name).with_suffix(".png") 338 | save_path.parent.mkdir(parents=True, exist_ok=True) 339 | 340 | if len(seg.shape) == 2: 341 | imsave(str(save_path), seg.detach().cpu().numpy().astype(np.uint8)) 342 | elif len(seg.shape) == 3: 343 | np.save(str(save_path), seg.detach().cpu().numpy()) 344 | else: 345 | raise ValueError("How did you get here") 346 | 347 | 348 | def augment(*arrs: Union[np.ndarray, Image.Image], rotate_angle: float = 45, 349 | flip: bool = True, mirror: bool = True, 350 | rotate: bool = True, scale: bool = False) -> List[Image.Image]: 351 | imgs: List[Image.Image] = map_(Image.fromarray, arrs) if isinstance(arrs[0], np.ndarray) else list(arrs) 352 | 353 | if flip and random() > 0.5: 354 | imgs = map_(ImageOps.flip, imgs) 355 | if mirror and random() > 0.5: 356 | imgs = map_(ImageOps.mirror, imgs) 357 | if rotate and random() > 0.5: 358 | angle: float = uniform(-rotate_angle, rotate_angle) 359 | imgs = map_(lambda e: e.rotate(angle), imgs) 360 | if scale and random() > 0.5: 361 | scale_factor: float = uniform(1, 1.2) 362 | w, h = imgs[0].size # Tuple[int, int] 363 | nw, nh = int(w * scale_factor), int(h * scale_factor) # Tuple[int, int] 364 | 365 | # Resize 366 | imgs = map_(lambda i: i.resize((nw, nh)), imgs) 367 | 368 | # Now need to crop to 
369 |         bw, bh = randint(0, nw - w), randint(0, nh - h)  # Tuple[int, int]
370 | 
371 |         imgs = map_(lambda i: i.crop((bw, bh, bw + w, bh + h)), imgs)
372 |         assert all(i.size == (w, h) for i in imgs)
373 | 
374 |     return imgs
375 | 
376 | 
377 | def augment_arr(*arrs_a: np.ndarray, rotate_angle: float = 45,
378 |                 flip: bool = True, mirror: bool = True,
379 |                 rotate: bool = True, scale: bool = False,
380 |                 noise: bool = False, noise_loc: float = 0.5, noise_lambda: float = 0.1) -> List[np.ndarray]:
381 |     arrs = list(arrs_a)  # hacky list conversion to appease the type checker
382 | 
383 |     if flip and random() > 0.5:
384 |         arrs = map_(np.flip, arrs)
385 |     if mirror and random() > 0.5:
386 |         arrs = map_(np.fliplr, arrs)
387 |     if noise and random() > 0.5:
388 |         mask: np.ndarray = np.random.laplace(noise_loc, noise_lambda, arrs[0].shape)
389 |         arrs = map_(partial(add, mask), arrs)
390 |         arrs = map_(lambda e: (e - e.min()) / (e.max() - e.min()), arrs)
391 |     # if random() > 0.5:
392 |     #     orig_shape = arrs[0].shape
393 | 
394 |     #     angle = random() * 90 - 45
395 |     #     arrs = map_(lambda e: sp.ndimage.rotate(e, angle, order=1), arrs)
396 | 
397 |     #     arrs = get_center(orig_shape, *arrs)
398 | 
399 |     return arrs
400 | 
401 | 
402 | def get_center(shape: Tuple, *arrs: np.ndarray) -> List[np.ndarray]:
403 |     """ center cropping """
404 |     def g_center(arr):
405 |         if arr.shape == shape:
406 |             return arr
407 | 
408 |         offsets: List[int] = [(dim - s) // 2 for (dim, s) in zip(arr.shape, shape)]
409 | 
410 |         if 0 in offsets:
411 |             return arr[tuple(slice(0, s) for s in shape)]
412 | 
413 |         res = arr[tuple(slice(d, -d) for d in offsets)][tuple(slice(0, s) for s in shape)]  # Deal with off-by-one errors
414 |         assert res.shape == shape, (res.shape, shape, offsets)
415 | 
416 |         return res
417 | 
418 |     return [g_center(arr) for arr in arrs]
419 | 
420 | 
421 | def center_pad(arr: np.ndarray, target_shape: Tuple[int, ...]) -> np.ndarray:
422 |     assert len(arr.shape) == len(target_shape)
423 | 
424 |     diff: List[int] = [(nx - x) for (x, nx) in zip(arr.shape, target_shape)]
425 |     pad_width: List[Tuple[int, int]] = [(w // 2, w - (w // 2)) for w in diff]
426 | 
427 |     res = np.pad(arr, pad_width)
428 |     assert res.shape == target_shape, (res.shape, target_shape)
429 | 
430 |     return res
431 | 
--------------------------------------------------------------------------------
/wmh.make:
--------------------------------------------------------------------------------
1 | # MIT License
2 | 
3 | # Copyright (c) 2023 Hoel Kervadec
4 | 
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | 
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | 
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | CC = python3 24 | SHELL = /usr/bin/zsh 25 | PP = PYTHONPATH="$(PYTHONPATH):." 26 | 27 | # RD stands for Result DIR -- useful way to report from extracted archive 28 | RD = results/wmh 29 | 30 | .PHONY = all boundary plot train metrics hausdorff pack 31 | 32 | red:=$(shell tput bold ; tput setaf 1) 33 | green:=$(shell tput bold ; tput setaf 2) 34 | yellow:=$(shell tput bold ; tput setaf 3) 35 | blue:=$(shell tput bold ; tput setaf 4) 36 | reset:=$(shell tput sgr0) 37 | 38 | # CFLAGS = -O 39 | # DEBUG = --debug 40 | EPC = 100 41 | # EPC = 5 42 | 43 | K = 2 44 | BS = 8 45 | G_RGX = (\d+_\d+)_\d+ 46 | P_RGX = (\d+)_\d+_\d+ 47 | NET = UNet 48 | B_DATA = [('in_npy', tensor_transform, False), ('gt_npy', gt_transform, True)] 49 | 50 | TRN = $(RD)/gdl $(RD)/gdl_surface_steal $(RD)/gdl_3d_surface_steal $(RD)/gdl_hausdorff_w 51 | 52 | GRAPH = $(RD)/tra_loss.png $(RD)/val_loss.png \ 53 | $(RD)/val_dice.png $(RD)/tra_dice.png \ 54 | $(RD)/val_3d_hausdorff.png \ 55 | $(RD)/val_3d_hd95.png 56 | BOXPLOT = $(RD)/val_dice_boxplot.png 57 | PLT = $(GRAPH) $(BOXPLOT) 58 | 59 | REPO = $(shell basename `git rev-parse --show-toplevel`) 60 | DATE = $(shell date +"%y%m%d") 61 | HASH = $(shell git rev-parse --short HEAD) 62 | HOSTNAME = $(shell hostname) 63 | PBASE = archives 64 | PACK = $(PBASE)/$(REPO)-$(DATE)-$(HASH)-$(HOSTNAME)-wmh.tar.gz 65 | 66 | all: $(PACK) 67 | 68 | plot: $(PLT) 69 | 70 | train: $(TRN) 71 | 72 | pack: report $(PACK) 73 | $(PACK): $(PLT) $(TRN) 74 | $(info $(red)tar cf $@$(reset)) 75 | mkdir -p $(@D) 76 | # tar -zc -f $@ $^ # Use if pigz is not available 77 | tar cf - $^ | pigz > $@ 78 | chmod -w $@ 79 | 80 | 81 | # Extraction and slicing 82 | data/WMH/train/in_npy data/WMH/val/in_npy: data/WMH 83 | data/WMH: data/wmh 84 | $(info $(yellow)$(CC) $(CFLAGS) preprocess/slice_wmh.py$(reset)) 85 | rm -rf $@_tmp 86 | $(PP) $(CC) $(CFLAGS) preprocess/slice_wmh.py --source_dir $< --dest_dir $@_tmp --n_augment=0 --retain=10 87 | mv $@_tmp $@ 88 | 89 | data/wmh: data/wmh.lineage data/Amsterdam_GE3T.zip data/Singapore.zip data/Utrecht.zip 90 | $(info $(yellow)unzip data/Amsterdam_GE3T.zip data/Singapore.zip data/Utrecht.zip$(reset)) 91 | md5sum -c $< 92 | rm -rf $@_tmp $@ 93 | unzip -q $(word 2, $^) -d $@_tmp 94 | unzip -q $(word 3, $^) -d $@_tmp 95 | unzip -q $(word 4, $^) -d $@_tmp 96 | mv $@_tmp/*/* $@_tmp && rmdir $@_tmp/GE3T $@_tmp/Singapore $@_tmp/Utrecht 97 | rm -r $@_tmp/*/orig # Do not care about that part 98 | rm -r $@_tmp/*/pre/3DT1.nii.gz # Cannot align to the rest 99 | mv $@_tmp $@ 100 | 101 | 102 | # Training 103 | $(RD)/gdl: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1)]" 104 | $(RD)/gdl: data/WMH/train/in_npy data/WMH/val/in_npy 105 | $(RD)/gdl: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True)]" 106 | 107 | $(RD)/gdl_surface_w: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 108 | ('SurfaceLoss', {'idc': [1]}, 0.1)]" 109 | $(RD)/gdl_surface_w: data/WMH/train/in_npy data/WMH/val/in_npy 110 | $(RD)/gdl_surface_w: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 111 | ('gt_npy', dist_map_transform, False)]" 112 | 113 | $(RD)/gdl_hausdorff_w: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 114 | ('HausdorffLoss', {'idc': [1]}, 0.1)]" 115 | 
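# Note: each $(RD)/<run> target in this section only defines run-specific OPT
# and DATA target variables (loss mixture, extra input folders, optional
# scheduler); the generic $(RD)/% rule further down launches the actual
# main.py training run.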
$(RD)/gdl_hausdorff_w: data/WMH/train/in_npy data/WMH/val/in_npy 116 | $(RD)/gdl_hausdorff_w: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 117 | ('gt_npy', gt_transform, True)]" 118 | 119 | 120 | $(RD)/hausdorff: OPT = --losses="[('HausdorffLoss', {'idc': [1]}, 0.1)]" 121 | $(RD)/hausdorff: data/WMH/train/in_npy data/WMH/val/in_npy 122 | $(RD)/hausdorff: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True)]" 123 | 124 | 125 | $(RD)/gdl_surface_add: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 126 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 127 | $(RD)/gdl_surface_add: data/WMH/train/in_npy data/WMH/val/in_npy 128 | $(RD)/gdl_surface_add: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 129 | ('gt_npy', dist_map_transform, False)]" \ 130 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 131 | 132 | $(RD)/gdl_surface_steal: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 133 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 134 | $(RD)/gdl_surface_steal: data/WMH/train/in_npy data/WMH/val/in_npy 135 | $(RD)/gdl_surface_steal: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 136 | ('gt_npy', dist_map_transform, False)]" \ 137 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 138 | 139 | 140 | $(RD)/gdl_3d_surface_steal: OPT = --losses="[('GeneralizedDice', {'idc': [0, 1]}, 1), \ 141 | ('SurfaceLoss', {'idc': [1]}, 0.01)]" 142 | $(RD)/gdl_3d_surface_steal: data/WMH/train/in_npy data/WMH/val/in_npy 143 | $(RD)/gdl_3d_surface_steal: DATA = --folders="$(B_DATA)+[('gt_npy', gt_transform, True), \ 144 | ('3d_distmap', raw_npy_transform, False)]" \ 145 | --scheduler=StealWeight --scheduler_params="{'to_steal': 0.01}" 146 | 147 | $(RD)/surface: OPT = --losses="[('SurfaceLoss', {'idc': [1]}, 0.1)]" 148 | $(RD)/surface: data/WMH/train/in_npy data/WMH/val/in_npy 149 | $(RD)/surface: DATA = --folders="$(B_DATA)+[('gt_npy', dist_map_transform, False)]" 150 | 151 | 152 | $(RD)/%: 153 | $(info $(green)$(CC) $(CFLAGS) main.py $@$(reset)) 154 | rm -rf $@_tmp 155 | mkdir -p $@_tmp 156 | printenv > $@_tmp/env.txt 157 | git diff > $@_tmp/repo.diff 158 | git rev-parse --short HEAD > $@_tmp/commit_hash 159 | $(CC) $(CFLAGS) main.py --dataset=$(dir $(