├── utils ├── __init__.py ├── misc.py ├── layers.py ├── instructions.py ├── loss_functions.py ├── volume.py ├── slices.py ├── samples.py ├── centers.py ├── patch.py ├── metrics.py ├── visualization.py └── medpy_hausdorff.py ├── dataset ├── __init__.py ├── Dataset.py ├── Isles15_SISS.py ├── Isles15_SPES.py └── isles2017.py ├── log ├── ISLES2017 │ └── .gitkeep ├── ISLES15_SISS │ └── .gitkeep └── ISLES15_SPES │ └── .gitkeep ├── workflow ├── __init__.py ├── learning.py ├── in_out.py ├── filenames.py ├── network.py ├── generators.py └── evaluate.py ├── architecture ├── __init__.py ├── Architecture.py ├── Ronneberger.py ├── Guerrero.py ├── Cicek.py └── SUNETx4.py ├── results ├── ISLES2017 │ └── .gitkeep ├── ISLES15_SISS │ └── .gitkeep └── ISLES15_SPES │ └── .gitkeep ├── checkpoints ├── ISLES15_SISS │ ├── .gitkeep │ ├── 3Dunet_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_0_to_5.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_12_to_17.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_18_to_23.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_24_to_27.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_6_to_11.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_24_to_27.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_6_to_11.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_24_to_27.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_6_to_11.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_24_to_27.h5 │ └── SUNETx4_3D_(24,24,8)_Adadelta_6_to_11.h5 ├── ISLES15_SPES │ ├── .gitkeep │ ├── 3Dunet_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_0_to_5.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_12_to_17.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_18_to_23.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_24_to_29.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_6_to_11.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_24_to_29.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_6_to_11.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_24_to_29.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_6_to_11.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_0_to_5.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_12_to_17.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_18_to_23.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_24_to_29.h5 │ └── SUNETx4_3D_(24,24,8)_Adadelta_6_to_11.h5 └── ISLES2017 │ ├── .gitkeep │ ├── 2Dunet_2D_(48,48,1)_Adadelta_0_to_8.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_18_to_26.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_27_to_35.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_36_to_42.h5 │ ├── 2Dunet_2D_(48,48,1)_Adadelta_9_to_17.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_0_to_8.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_18_to_26.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_27_to_35.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_36_to_42.h5 │ ├── 3Dunet_3D_(24,24,8)_Adadelta_9_to_17.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_0_to_8.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_0_to_8.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_18_to_26.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_27_to_35.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_36_to_42.h5 │ ├── SUNETx4_3D_(24,24,8)_Adadelta_9_to_17.h5 │ ├── 
3Duresnet_3D_(24,24,8)_Adadelta_18_to_26.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_27_to_35.h5 │ ├── 3Duresnet_3D_(24,24,8)_Adadelta_36_to_42.h5 │ └── 3Duresnet_3D_(24,24,8)_Adadelta_9_to_17.h5 ├── .gitignore ├── .gitattributes ├── requirements.txt ├── README.md ├── main.py └── configuration.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log/ISLES2017/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /architecture/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log/ISLES15_SISS/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log/ISLES15_SPES/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/ISLES2017/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/ISLES15_SISS/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/ISLES15_SPES/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .fuse* 3 | 4 | .idea/ -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.h5 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Dunet_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c4ba0910f85f9baf7ea578dded466ed1950180f8b41f26b2e724b6fbea404159 3 | size 22563272 
4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Dunet_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d113e9a53443d16997594b17181ed110bb8d3b087086460d04863300332515b8 3 | size 22573640 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/2Dunet_2D_(48,48,1)_Adadelta_0_to_8.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:548a385641caa0136e4be3186628971cd6fd1562825b14fb31f43e37f537a476 3 | size 112395048 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/2Dunet_2D_(48,48,1)_Adadelta_18_to_26.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:db8cc48743c2248196fe456065c0414a2b280a47e5e5eb511655f5e0a3346445 3 | size 112395376 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/2Dunet_2D_(48,48,1)_Adadelta_27_to_35.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ee80d4432f82a6561b7313f1d2b156c2157b3e88eae277416a3b88d38f3ca72d 3 | size 112395384 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/2Dunet_2D_(48,48,1)_Adadelta_36_to_42.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1729a1c4933e3f2fd4a6469421e0653a358dbc4d64acb85aeb8739d7abae0ec9 3 | size 112395384 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/2Dunet_2D_(48,48,1)_Adadelta_9_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:596cec62acc645cedc57060ea86544aa90cdfaa2a5366697c2d95726a681921e 3 | size 112395424 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Dunet_3D_(24,24,8)_Adadelta_0_to_8.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b36cb70c97becea3c9557319c3adc2183d9c60a5210e7e26a8eed4897722380c 3 | size 22570144 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Dunet_3D_(24,24,8)_Adadelta_18_to_26.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:31ddb2a638d365842f25cc430bc6bedb2e63dacf9d878d765407afeec832de3f 3 | size 22570120 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Dunet_3D_(24,24,8)_Adadelta_27_to_35.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ac52a164d4d1398aebe98f6939388dac94e1f40f7b2197383222cfe8be5cce8d 3 | size 22570136 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Dunet_3D_(24,24,8)_Adadelta_36_to_42.h5: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:55842eecf93fbf4232bfce92f9fdfb98dc698e280567392b3110d8fbc3ddef91 3 | size 22570144 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Dunet_3D_(24,24,8)_Adadelta_9_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2125eb5bdcffa6347866f100c126da23b674f13797dce384a53d6ce5e841aee0 3 | size 22570120 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Duresnet_3D_(24,24,8)_Adadelta_0_to_8.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a9474b18d40aefb5b86c2f39fbf9fe704c6c7bf345faebd42a341966fd0349a5 3 | size 10695624 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/SUNETx4_3D_(24,24,8)_Adadelta_0_to_8.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6ef0c7154ab03592e7e3e56894b9699f3a6dfb48ad9f75893ebc6dc167a7b4b0 3 | size 28321592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/SUNETx4_3D_(24,24,8)_Adadelta_18_to_26.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3aa5b13720f38b27139c639be2b8f1ae0382e83a40a81045963be2d8207877ba 3 | size 28321592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/SUNETx4_3D_(24,24,8)_Adadelta_27_to_35.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9ea55dcdcf3db22a1c8db464cd1eeec3ba3c874290045c235c050d758e1d9725 3 | size 28321592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/SUNETx4_3D_(24,24,8)_Adadelta_36_to_42.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4218b8dcf10e20d91c619670c88565646b2ab015dbb9d428e7df2864b103dbcc 3 | size 28321592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/SUNETx4_3D_(24,24,8)_Adadelta_9_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f57022db1e3731a90f5b35cc216db371fe4d3e0ef25874cf5e5730ca93c9b701 3 | size 28321592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/2Dunet_2D_(48,48,1)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:21b7dcbe202020d8cbd1b1bc762852dd329cf93579b19e3799fc21c672a077bf 3 | size 112390664 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/2Dunet_2D_(48,48,1)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:8429466538f7b2c82bfe6b6f7d1967934a336e86a0e576cae6cdb08758b5db64 3 | size 112390776 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/2Dunet_2D_(48,48,1)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:666556bdf3b4b952bd9a4cc2a947627279e10db7fea0653e0986d511d26f6578 3 | size 112390800 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/2Dunet_2D_(48,48,1)_Adadelta_24_to_27.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c855a355e4e620c3ffa586d791d85e1181d4cf6968d95834583387161f96a816 3 | size 112390808 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/2Dunet_2D_(48,48,1)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d4f80b9742da69ee2a9480a4fd6937b1eba14e605e779481250e65fdc606eaa4 3 | size 112390672 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Dunet_3D_(24,24,8)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b8aef7a66630685111978f08158acf1e29c4355f180b048f83848a1dbb4838b5 3 | size 22563272 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Dunet_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fd81857b645bfeec06b20e4ac1a7bb7e1121a4fb132f61578c986c798ed9cb67 3 | size 22563552 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Dunet_3D_(24,24,8)_Adadelta_24_to_27.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a140aee25a451ca7939fc6a664f03912ecf8251cb108e5517502d1327f180375 3 | size 22563224 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Dunet_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5f2ec41d3e8beb01d085a965c1d896c0be3a60ee59ddb3279f8803aa1609d9c7 3 | size 22563272 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Duresnet_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:33c1634133c77b9265e8129a224f0e742d7f8529d9c52de2b4c456de1e29c5c9 3 | size 10690568 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Duresnet_3D_(24,24,8)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4fe319395e7e0e5d4a7fc0601f76649d420e8bb89d488552d67adf21be4cb678 3 | size 10690952 4 | 
-------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Duresnet_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9d299f4f13ce71d0f472f2ad8f3dab5fb261966f33476110c8e7399a756c36d9 3 | size 10690952 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Duresnet_3D_(24,24,8)_Adadelta_24_to_27.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:77126d20176caf85506ed75959a1719d4a11439afdb2ca0a65ddb1728c216175 3 | size 10690952 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/3Duresnet_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a415daac019a99ee78edd9fbd4f823c80ddf34c82b164a2676c80acb95483db4 3 | size 10691224 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/SUNETx4_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c7f9e1611259c9f14217df454360d9fc1da3c1996683c7b8fe98d429c8a564dd 3 | size 28311792 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/SUNETx4_3D_(24,24,8)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:19650e94504b18870c70ab3fe450120f5f5a7b50e96151e0d07559ed84ac788a 3 | size 28311408 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/SUNETx4_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c33aedeeba0f685e98c6016373deab9e2501e4c6a1a8488bea1202a017993bd7 3 | size 28311408 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/SUNETx4_3D_(24,24,8)_Adadelta_24_to_27.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fb9f21a2180083570342f33cc30ea66447e2ae3b49c56ef73a84d6c5c44aaee2 3 | size 28311408 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SISS/SUNETx4_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a42861dc5bbe6575b95cec07a8e564d9536417935d5f1a582d24defa8b959543 3 | size 28311464 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d8bcffad6be325bcfe77f01819a340c12f81e851f06b2639709ce674cfaf91c5 3 | size 112397248 4 | -------------------------------------------------------------------------------- 
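The `.h5` entries in this `checkpoints/` listing are Git LFS pointer stubs (a `version` line, a sha256 `oid` and a byte `size`) rather than the network weights themselves; `.gitattributes` routes `*.h5` through LFS. As a minimal sketch under that assumption (the helper below is hypothetical and not part of the repository), a run can verify that the real HDF5 weights have been fetched before `load_weights` is called:

```
# Hypothetical helper, not part of the repository: detect a checkpoint that is
# still a Git LFS pointer stub (the three-line text format listed here) instead
# of real HDF5 weights, which would make Keras' load_weights() fail.
def is_lfs_pointer(path):
    with open(path, 'rb') as f:
        return f.read(40) == b'version https://git-lfs.github.com/spec/'

checkpoint = 'checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_0_to_5.h5'
if is_lfs_pointer(checkpoint):
    raise RuntimeError('{} is only an LFS pointer; fetch the weights with git-lfs first.'.format(checkpoint))
```

Fetched HDF5 checkpoints begin with the binary signature `\x89HDF`, so the check cannot misfire once the weights are actually present.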
/checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ef14274a79ea6cbb5044ecbab387c4b055f3363e32d9034d0e5ea4d8ef35ab6b 3 | size 112397688 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:786ba9bd340a474225388ba20c2ef7c0d320d7ea9f6a431952b1e4d12f118901 3 | size 112397712 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_24_to_29.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:de4c38747b5c5f4a66e2167aef62ffaf1c962c42da37ea1da68e6d93436de2eb 3 | size 112397712 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/2Dunet_2D_(48,48,1)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a931a54347d90a13c000e0274377977b25869d8b80bc0749323e8a7a76c74fc8 3 | size 112397584 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Dunet_3D_(24,24,8)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0a313982c802651fad948faa8b886e861938033f32ff3c63991def413ff46062 3 | size 22573640 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Dunet_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:350652ceeaa9460efda97f39b3abe1146c35ddd56c778e29796f3dd041f0fb73 3 | size 22573592 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Dunet_3D_(24,24,8)_Adadelta_24_to_29.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:22e6bfd89fd82227451c847a7b4a2a3be2f253b8b7a214de3343a8a6ea460088 3 | size 22573248 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Dunet_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b5fa933f1b9db9be8d1a5fedad9496b33826438a21da1bb22bbac70653e5defb 3 | size 22573640 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Duresnet_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dbd25915b7764c86570efc8cdc85288b6eeb5e1b0708bef7ad3d5e0488ce27eb 3 | size 10700072 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Duresnet_3D_(24,24,8)_Adadelta_12_to_17.h5: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1b223ad7130199c4b6996c12f22971902cf2f6f7da9978763e75ccb15f01afc2 3 | size 10700080 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Duresnet_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bcf436aaa0788b4f4201afcb1dacb03a2a2992936975d109b0df5e99f6876b56 3 | size 10701496 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Duresnet_3D_(24,24,8)_Adadelta_24_to_29.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:febf46191c0728890a6a37e39b3eadb56370bc068c026f4fc36e06da9188e230 3 | size 10701816 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/3Duresnet_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0ad93628bfb2522abfba10ddca8197654c73d10ca8c0e3f30ce3cff7aabc39e2 3 | size 10700072 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/SUNETx4_3D_(24,24,8)_Adadelta_0_to_5.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:67502390a5f79a886f03f8938053427b3ccbfe94dfd8b7c51e698d90a204a35e 3 | size 28321624 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/SUNETx4_3D_(24,24,8)_Adadelta_12_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ed9cc387d4971677b44199f426bcd24394b8a5af78163a3506253c521be6f6d6 3 | size 28321744 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/SUNETx4_3D_(24,24,8)_Adadelta_18_to_23.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8bd3a191fe9c7547e055f2357301aa238b0acd8e163ed0494ce745af857138da 3 | size 28322072 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/SUNETx4_3D_(24,24,8)_Adadelta_24_to_29.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:cacd24117f2983463f60c504f086a0ed7dd4d06cfc1a362b31d90ad7a333ae35 3 | size 28322072 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES15_SPES/SUNETx4_3D_(24,24,8)_Adadelta_6_to_11.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:67e0637d47c862927dad409c5295c873b62f1e72ffa42ebe3f1c58faa9127a25 3 | size 28322024 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Duresnet_3D_(24,24,8)_Adadelta_18_to_26.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | 
oid sha256:8e51639f6ce881cf1d27f967036e260553576732bd4cc2c241d6e064f7e4f413 3 | size 10695296 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Duresnet_3D_(24,24,8)_Adadelta_27_to_35.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b7a2e2b32010b0dab8d08836c094e1a32c97dde413264920c697ad16afad486a 3 | size 10695624 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Duresnet_3D_(24,24,8)_Adadelta_36_to_42.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:722f8c0ecfce17d682eb2e8c82edb116d9a37d8ef7fb66fa5239c9fd2bfbdd39 3 | size 10695624 4 | -------------------------------------------------------------------------------- /checkpoints/ISLES2017/3Duresnet_3D_(24,24,8)_Adadelta_9_to_17.h5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:39573d6686ee5c51dbef15cc3b11f87b3020617f2cd1ff11ae63376fbf5c955c 3 | size 10695624 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras==2.2.0 2 | matplotlib==2.2.2 3 | MedPy==0.3.0 4 | nibabel==2.3.0 5 | numpy==1.14.5 6 | h5py==2.7.1 7 | scipy==1.1.0 8 | scikit_learn==0.19.1 9 | opencv-python==3.4.1.15 10 | protobuf==3.5.1 11 | six==1.11.0 12 | tensorboard==1.8.0 13 | tensorflow==1.8.0 14 | tensorflow-gpu==1.8.0 15 | tensorflow-tensorboard==1.5.0 16 | -------------------------------------------------------------------------------- /workflow/learning.py: -------------------------------------------------------------------------------- 1 | from utils.instructions import * 2 | 3 | from workflow.generators import TrainPatchGenerator, TestPatchGenerator 4 | 5 | def build_learning_generators(config, dataset): 6 | log.info("Building training generator...") 7 | train_generator = TrainPatchGenerator(config, dataset.train, augment=True, sampling=config.train.sampling) 8 | 9 | log.info("Building validation generator...") 10 | val_generator = TrainPatchGenerator(config, dataset.val, is_validation=True, sampling=config.train.sampling) 11 | 12 | return train_generator, val_generator 13 | 14 | 15 | def build_training_generator(config, sample_in, is_validation=False, augment=False): 16 | return TrainPatchGenerator(config, sample_in, augment=augment, is_validation=is_validation) 17 | 18 | 19 | def build_testing_generator(config, sample_in): 20 | return TestPatchGenerator(config, sample_in) 21 | 22 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def get_resampling_indexes(num_patches, num_resampled): 4 | assert num_patches > 0 5 | 6 | resampled_idxs = list() 7 | sampling_left = num_resampled 8 | 9 | # Repeat all patches until sampling_left is smaller than num_patches 10 | if num_patches < num_resampled: 11 | while sampling_left >= num_patches: 12 | resampled_idxs += range(0, num_patches) 13 | sampling_left -= num_patches 14 | 15 | # Fill rest of indexes with uniform undersampling 16 | if sampling_left > 0: 17 | sampling_step = float(num_patches) / sampling_left 18 | sampling_point = 0.0 19 
| for i in range(sampling_left): 20 | resampled_idxs.append(int(math.floor(sampling_point))) 21 | sampling_point += sampling_step 22 | 23 | assert len(resampled_idxs) == num_resampled 24 | return resampled_idxs 25 | -------------------------------------------------------------------------------- /utils/layers.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.engine.topology import Layer 3 | import numpy as np 4 | 5 | 6 | class ZeroPaddingChannel(Layer): 7 | def __init__(self, padding, bs=32, **kwargs): 8 | self.padding = padding 9 | self.batch_size = bs 10 | super(ZeroPaddingChannel, self).__init__(**kwargs) 11 | 12 | def build(self, input_shape): 13 | super(ZeroPaddingChannel, self).build(input_shape) # Be sure to call this at the end 14 | 15 | def call(self, x): 16 | shape_padding = list(x.shape) 17 | shape_padding[1] = self.padding 18 | y = K.zeros([self.batch_size] + shape_padding[1:]) 19 | return K.concatenate([x, y], axis=1) 20 | 21 | def compute_output_shape(self, input_shape): 22 | shape = list(input_shape) 23 | assert len(shape) == 5 24 | shape[1] += self.padding 25 | return tuple(shape) 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (DEPRECATED) updated version at (https://github.com/NIC-VICOROB/stroke-core-ct-segmentation) 2 | 3 | ## (DEPRECATED) SUNet: a deep learning architecture for acute stroke lesion segmentation and outcome prediction in multimodal MRI 4 | Development framework for evaluating the deep learning architectures in the paper (https://arxiv.org/abs/1810.13304) 5 | 6 | ## Installation 7 | 8 | The method makes use of Keras and TensorFlow. If running on a GPU, please make sure CUDA 9.X is correctly installed. Then, in the base directory run: 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ## Running the code 14 | 15 | 1. Read the [ISLES challenge registration instructions](https://www.smir.ch/ISLES/Start2017) in the 'How to join' section and register. 16 | 17 | 2. Download and extract the [ISLES2015](https://www.smir.ch/ISLES/Start2015) (SISS and SPES) and [ISLES2017](https://www.smir.ch/ISLES/Start2017) datasets. 18 | 19 | 3. Update the dataset dictionary with the path to each dataset in `configuration.py` (line 137); an illustrative sketch of this mapping is shown below. 20 | 21 | 4. Reproduce the cross-validation results in the paper by running: 22 | 23 | ``` 24 | python main.py 25 | ``` 26 | 27 | For each performed cross-validation: 28 | 1. The included pre-trained models from `checkpoints/` will be loaded for the corresponding fold. 29 | 30 | 2. The corresponding validation images of the training set will be segmented. 31 | 32 | 3. Finally, the computed evaluation metrics will be written to a spreadsheet file. 33 | 34 | 5. Accessing the results: 35 | + The resulting binary segmentations can be found in the `results/` folder. 36 | + A spreadsheet with the evaluation metrics for each cross-validation will be in the `metrics/` folder.
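Step 3 of the README refers to a dataset dictionary inside `configuration.py`, which is not reproduced in this listing; the snippet below is therefore only an illustrative sketch, not the actual code. The keys follow the dataset names registered in `workflow/in_out.py`, while the variable name and the paths are placeholders/assumptions to be replaced with the locations of the extracted challenge data.

```
# Illustrative sketch only: configuration.py is not shown in this listing, so the
# variable name and exact structure are assumptions. The keys are the dataset names
# used in workflow/in_out.py; point each entry to the extracted challenge data.
dataset_paths = {
    'ISLES15_SISS': '/path/to/ISLES2015/SISS/',
    'ISLES15_SPES': '/path/to/ISLES2015/SPES/',
    'ISLES2017': '/path/to/ISLES2017/',
}
```

Once these paths are set, `python main.py` runs the cross-validation workflow described above.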
37 | -------------------------------------------------------------------------------- /dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import abc # abstract base classes 3 | import logging as log 4 | 5 | class Sample: 6 | def __init__(self, id=None, nib=None, mask=None, data=None, labels=None): 7 | self.id = id 8 | self.nib = nib 9 | self.mask = mask # Brain mask (NOT lesion mask) 10 | self.data = data 11 | self.labels = labels 12 | 13 | class Dataset: 14 | __metaclass__ = abc.ABCMeta 15 | 16 | def __init__(self): 17 | self.train = [] 18 | self.val = [] 19 | self.test = [] 20 | 21 | def load(self, config): 22 | self.load_samples(config.dataset) 23 | self.as_float32() # Convert loaded data to float32 24 | 25 | log.debug("Loaded {} training, {} validation, {} testing".format( 26 | len(self.train), len(self.val), len(self.test))) 27 | 28 | @abc.abstractmethod 29 | def load_samples(self, dataset_info): 30 | pass 31 | 32 | def __is_sample(self, sample_in): 33 | assert isinstance(sample_in, Sample), "Elements added to dataset must be instances of Sample" 34 | 35 | def add_train(self, sample_in): 36 | self.__is_sample(sample_in) 37 | self.train.append(sample_in) 38 | 39 | def add_val(self, sample_in): 40 | self.__is_sample(sample_in) 41 | self.val.append(sample_in) 42 | 43 | def add_test(self, sample_in): 44 | self.__is_sample(sample_in) 45 | self.test.append(sample_in) 46 | 47 | def as_float32(self): 48 | for idx in range(len(self.train)): 49 | self.train[idx].data = self.train[idx].data.astype('float32') 50 | 51 | for idx in range(len(self.val)): 52 | self.val[idx].data = self.val[idx].data.astype('float32') 53 | 54 | for idx in range(len(self.test)): 55 | self.test[idx].data = self.test[idx].data.astype('float32') 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | 4 | import logging as log 5 | from configuration import Configuration 6 | 7 | config = Configuration() 8 | 9 | #print("Using GPU {} on {}".format(config.train.cuda_device_id[socket.gethostname()], socket.gethostname())) 10 | # os.environ["CUDA_VISIBLE_DEVICES"] = str(config.train.cuda_device_id[socket.gethostname()]) 11 | 12 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 13 | os.environ["CUDA_VISIBLE_DEVICES"] = str(config.train.cuda_device_id) #str(config.train.cuda_device_id[socket.gethostname()]) 14 | #os.environ["TF_CPP_MIN_LOG_LEVEL"]="2" # Silence tensorflow initialization messages 15 | 16 | from workflow.evaluate import * 17 | 18 | if config.eval.verbose == 0: 19 | log.basicConfig(format="%(levelname)s: %(message)s", level=log.WARN) 20 | log.warn("logging level WARN") 21 | elif config.eval.verbose == 2: 22 | log.basicConfig(format="%(levelname)s: %(message)s", level=log.DEBUG) 23 | log.debug("logging level DEBUG") 24 | else: 25 | log.basicConfig(format="%(levelname)s: %(message)s", level=log.INFO) 26 | log.info("logging level INFO") 27 | 28 | if config.eval.verbose in [1, 2]: 29 | parse_configuration(config, return_csv=True, print_terminal=True, num_columns=2, exclude=None) 30 | 31 | if config.train.dynamic_gpu_memory: 32 | from keras import backend as K 33 | import tensorflow as tf 34 | config_tf = tf.ConfigProto() 35 | config_tf.gpu_options.allow_growth=True 36 | K.set_session(tf.Session(config=config_tf)) 37 | 38 | evaluation_type = config.eval.evaluation_type 
39 | if evaluation_type is 'crossval': 40 | run_crossvalidation(config) 41 | elif evaluation_type is 'eval': 42 | run_evaluation(config) 43 | elif evaluation_type is 'test': 44 | run_testing(config) 45 | elif evaluation_type is 'multi': 46 | run_multi(config) 47 | elif evaluation_type is 'metrics': 48 | run_metrics(config) 49 | elif evaluation_type is 'metrics_search': 50 | run_metrics_search(config) 51 | else: 52 | raise NotImplementedError("\'{}\' evaluation type is not valid".format(evaluation_type)) 53 | 54 | 55 | -------------------------------------------------------------------------------- /workflow/in_out.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import logging as log 3 | import numpy as np 4 | import string 5 | import os 6 | import datetime 7 | 8 | import cv2 9 | 10 | from workflow.filenames import * 11 | 12 | from dataset.isles2017 import Isles2017 13 | from dataset.Isles15_SISS import Isles15_SISS 14 | from dataset.Isles15_SPES import Isles15_SPES 15 | 16 | from architecture.Cicek import Cicek 17 | from architecture.Guerrero import Guerrero 18 | from architecture.Ronneberger import Ronneberger 19 | from architecture.SUNETx4 import SUNETx4 20 | 21 | def get_architecture_instance(config): 22 | arch_name = config.train.architecture_name 23 | 24 | arch = get_architecture_class(arch_name) 25 | 26 | return arch() 27 | 28 | def get_architecture_class(arch): 29 | # DONT FORGET TO IMPORT THE CLASS 30 | arch_dict = { 31 | '3Dunet': Cicek, 32 | '3Duresnet': Guerrero, 33 | '2Dunet': Ronneberger, 34 | 'SUNETx4': SUNETx4, 35 | 'SUNETx4_f32_25k_bs16': SUNETx4 36 | } 37 | 38 | if arch in arch_dict: 39 | return arch_dict[arch] 40 | raise NotImplementedError("Architecture name not linked to object, add entry to dictionary.") 41 | 42 | 43 | def get_dataset_instance(config, dataset=None) : 44 | dataset_name = config.train.dataset_name if dataset is None else dataset 45 | 46 | # DONT FORGET TO IMPORT THE CLASS 47 | dataset_dict = { 48 | 'ISLES15_SISS': Isles15_SISS, 49 | 'ISLES15_SPES': Isles15_SPES, 50 | 'ISLES2017': Isles2017, 51 | } 52 | 53 | assert dataset_name in dataset_dict 54 | dataset_instance = dataset_dict[dataset_name] 55 | return dataset_instance() 56 | 57 | 58 | # Load pre-trained model 59 | def read_model(config, model_def, crossval_id=None) : 60 | model_filename = generate_model_filename(config, crossval_id) 61 | log.debug("Reading model weights from {}".format(os.path.basename(model_filename))) 62 | model_def.load_weights(model_filename) 63 | return model_def 64 | 65 | def save_result_sample(config, sample, result_vol, params=(), file_format='nii.gz', asuint16=False) : 66 | out_filename = generate_result_filename(config, (sample.id, ) + params, file_format) 67 | log.debug("Saving volume {}: {}".format(result_vol.shape, out_filename)) 68 | 69 | # Remove originally empty voxels by multiplying by brain mask 70 | result_vol = np.multiply(result_vol, sample.mask) 71 | if asuint16: 72 | result_vol = result_vol.astype('uint16') 73 | 74 | img = nib.Nifti1Image(result_vol, sample.nib.affine, sample.nib.header) 75 | nib.save(img, out_filename) -------------------------------------------------------------------------------- /utils/instructions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import math 4 | import numpy as np 5 | import logging as log 6 | import scipy.misc 7 | import itertools as iter 8 | 9 | from keras.utils import np_utils, Sequence 
10 | import keras.backend as K 11 | 12 | from utils.patch import * 13 | from utils.samples import * 14 | from utils.visualization import * 15 | from utils.misc import * 16 | 17 | class PatchExtractInstruction: 18 | def __init__(self, sample_idx=-1, data_patch_slice=None, label_patch_slice=None, augment_func=None): 19 | self.sample_idx = sample_idx 20 | 21 | self.data_patch_slice = data_patch_slice 22 | self.label_patch_slice = label_patch_slice 23 | 24 | self.augment_func = augment_func 25 | 26 | 27 | def extract_patch_with_instruction(samples, instruction): 28 | assert isinstance(instruction, PatchExtractInstruction) 29 | sample = samples[instruction.sample_idx] if isinstance(samples, list) else samples 30 | 31 | extract_label = instruction.label_patch_slice is not None and sample.labels is not None 32 | 33 | # Extract patches 34 | data_patch = extract_patch_at(sample.data, instruction.data_patch_slice) 35 | label_patch = extract_patch_at(sample.labels, instruction.label_patch_slice) if extract_label else None 36 | 37 | # Augment patches 38 | if instruction.augment_func is not None: 39 | augment_func = get_augment_functions(x_axis=1, y_axis=2)[instruction.augment_func] 40 | data_patch = augment_func(data_patch) 41 | label_patch = augment_func(label_patch) if extract_label else None 42 | 43 | return data_patch, label_patch 44 | 45 | 46 | def augment_instructions(original_instructions, goal_num_instructions): 47 | augment_funcs = get_augment_functions(x_axis=1, y_axis=2) # (modality, x, y, z) 48 | 49 | num_patches_in = len(original_instructions) 50 | num_augments_per_patch = np.minimum( int(math.ceil(goal_num_instructions / num_patches_in)), len(augment_funcs)) 51 | goal_num_augmented_patches = int(goal_num_instructions - num_patches_in) 52 | 53 | # Augment and add remaining copies 54 | sampling_idxs = get_resampling_indexes(num_patches_in, goal_num_augmented_patches // num_augments_per_patch) 55 | func_idxs = get_resampling_indexes(len(augment_funcs), num_augments_per_patch) 56 | 57 | augmented_instructions = list() 58 | for sampling_idx, func_idx in iter.product(sampling_idxs, func_idxs): 59 | aug_instr = copy.copy(original_instructions[sampling_idx]) 60 | aug_instr.augment_func = func_idx 61 | augmented_instructions.append(aug_instr) 62 | 63 | final_instructions = original_instructions + augmented_instructions 64 | 65 | return final_instructions -------------------------------------------------------------------------------- /utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | 3 | def jaccard(y_true, y_pred, smooth=100.0): 4 | """ 5 | Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|) 6 | = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|)) 7 | 8 | The jaccard distance loss is usefull for unbalanced datasets. This has been 9 | shifted so it converges on 0 and is smoothed to avoid exploding or disapearing 10 | gradient. 11 | 12 | Ref: https://en.wikipedia.org/wiki/Jaccard_index 13 | 14 | @url: https://gist.githuAdadeltab.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96 15 | @author: wassname 16 | """ 17 | intersection = K.sum(K.abs(y_true * y_pred), axis=-1) 18 | sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) 19 | jac = (intersection + smooth) / (sum_ - intersection + smooth) 20 | return (1 - jac) * smooth 21 | 22 | def dice(ground_truth, prediction): 23 | """ 24 | Computing mean-class Dice similarity. 
25 | This function assumes one-hot encoded ground truth (CATEGORICAL :D) 26 | 27 | :param prediction: last dimension should have ``num_classes`` 28 | :param ground_truth: segmentation ground truth (encoded as a binary matrix) 29 | last dimension should be ``num_classes`` 30 | :param weight_map: 31 | :return: ``1.0 - mean(Dice similarity per class)`` 32 | """ 33 | 34 | prediction = K.cast(prediction, dtype='float32') 35 | ground_truth = K.cast(ground_truth, dtype='float32') 36 | 37 | # computing Dice over the spatial dimensions 38 | reduce_axes = list(range(len(prediction.shape) - 1)) 39 | dice_numerator = 2.0 * K.sum(prediction * ground_truth, axis=reduce_axes) 40 | dice_denominator = K.sum(prediction, axis=reduce_axes) + K.sum(ground_truth, axis=reduce_axes) 41 | 42 | epsilon_denominator = 0.0001 43 | dice_score = dice_numerator / (dice_denominator + epsilon_denominator) 44 | return 1.0 - K.mean(dice_score) 45 | 46 | def ss(y_true, y_pred,weight_map=None,r=0.05): 47 | """ 48 | Function to calculate a multiple-ground_truth version of 49 | the sensitivity-specificity loss defined in "Deep Convolutional 50 | Encoder Networks for Multiple Sclerosis Lesion Segmentation", 51 | Brosch et al, MICCAI 2015, 52 | https://link.springer.com/chapter/10.1007/978-3-319-24574-4_1 53 | 54 | error is the sum of r(specificity part) and (1-r)(sensitivity part) 55 | 56 | :param prediction: the logits 57 | :param ground_truth: segmentation ground_truth. 58 | :param r: the 'sensitivity ratio' 59 | (authors suggest values from 0.01-0.10 will have similar effects) 60 | :return: the loss 61 | """ 62 | 63 | # chosen region may contain no voxels of a given label. Prevents nans. 64 | eps = 1e-5 65 | 66 | y_true_f = K.flatten(y_true) 67 | y_pred_f = K.flatten(y_pred) 68 | 69 | sq_error = K.square(y_true_f - y_pred_f) 70 | 71 | spec_part = K.sum(sq_error * y_true_f) / (K.sum(y_true_f) + eps) 72 | sens_part = K.sum(sq_error * (1 - y_true_f)) / (K.sum(1 - y_true_f) + eps) 73 | 74 | return r*spec_part + (1.0 - r)*sens_part -------------------------------------------------------------------------------- /utils/volume.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | import logging as log 3 | import os 4 | import itertools 5 | import numpy as np 6 | import math 7 | 8 | from utils.patch import zeropad_patches 9 | from utils.centers import sample_uniform_patch_centers 10 | 11 | def read_volume(filename): 12 | 13 | try: 14 | nib_file = load_nib_file(filename) 15 | except Exception as e: 16 | filename = filename[:-3] if filename.endswith('.gz') else filename + '.gz' 17 | nib_file = load_nib_file(filename) 18 | 19 | return nib_file.get_data() 20 | 21 | def load_nib_file(filename): 22 | return nib.load(filename) 23 | 24 | def remove_zeropad_volume(volume, patch_shape): 25 | # Get padding amount per each dimension 26 | selection = [] 27 | for dim_size in patch_shape: 28 | slice_start = dim_size // 2 29 | slice_stop = -slice_start if slice_start != 0 else None 30 | selection += [slice(slice_start, slice_stop)] 31 | volume = volume[selection] 32 | return volume 33 | 34 | def pad_volume(volume, patch_shape): 35 | assert len(patch_shape) == (len(volume.shape) - 1) 36 | pad_size = [patch_dim // 2 for patch_dim in patch_shape] 37 | padding = ((0, 0), (pad_size[0], pad_size[0]), (pad_size[1], pad_size[1]), (pad_size[2], pad_size[2])) 38 | return np.pad(volume, padding, 'constant', constant_values=0).astype(np.float32) 39 | 40 | def get_brain_mask(vol): 41 | if len(vol.shape) 
> 3: 42 | vol = vol[0] # Take only first modality 43 | return np.sum(vol, axis=0) > 0 44 | 45 | 46 | def reconstruct_volume(patches, patch_shape, original_vol, original_centers, extraction_step, num_classes) : 47 | expected_shape = original_vol.shape[1:] 48 | assert len(expected_shape) == len(extraction_step) == len(patch_shape) == 3 49 | 50 | if len(patches.shape) != 5: 51 | patches = np.expand_dims(patches, axis=3) # 2D/3D compatibility (None, x, y, new:z, c) 52 | 53 | ### Pad predicted patches to match input patch size (just as they where extracted from original volume) 54 | output_patch_shape = patches.shape[1:] 55 | 56 | ### Compute patch sides for slicing 57 | half_sizes = [[dim // 2, dim // 2] for dim in output_patch_shape] 58 | for i in range(len(half_sizes)): # If even dimension, subtract 1 to account for assymetry 59 | if output_patch_shape[i] % 2 == 0: half_sizes[i][1] -= 1 60 | 61 | ### Voting space 62 | vote_img = np.zeros(expected_shape + (num_classes,), dtype=np.float32) 63 | count_img = np.zeros(expected_shape + (num_classes,), dtype=np.float32) 64 | 65 | # Create counting patch 66 | counting_patch = np.ones(output_patch_shape) 67 | for count, coord in enumerate(original_centers): 68 | selection = [slice(coord[0] - half_sizes[0][0], coord[0] + half_sizes[0][1] + 1), # x 69 | slice(coord[1] - half_sizes[1][0], coord[1] + half_sizes[1][1] + 1), # y 70 | slice(coord[2] - half_sizes[2][0], coord[2] + half_sizes[2][1] + 1), # z 71 | slice(None)] #selects all classes 72 | 73 | vote_img[selection] += patches[count] 74 | count_img[selection] += counting_patch 75 | 76 | count_img[count_img == 0.0] = 1.0 # Avoid division by 0 77 | volume_probs = np.divide(vote_img, count_img) 78 | 79 | lesion_probs = volume_probs[:, :, :, 1] # Get probability of lesion 80 | return lesion_probs -------------------------------------------------------------------------------- /architecture/Architecture.py: -------------------------------------------------------------------------------- 1 | import abc #abstract base classes 2 | from utils.loss_functions import * 3 | import keras.optimizers as Koptimizers 4 | 5 | import itertools as iter 6 | 7 | 8 | from utils.patch import * 9 | from utils.slices import * 10 | from utils.centers import * 11 | 12 | from utils.instructions import * 13 | 14 | class Architecture: 15 | __metaclass__ = abc.ABCMeta 16 | 17 | def __init__(self): 18 | pass 19 | 20 | def generate_compiled_model(self, config, crossval_id=None): 21 | """ 22 | Builds and compiles the keras model associated with this architecture 23 | """ 24 | model = self.get_model(config, crossval_id=crossval_id) 25 | 26 | model.compile( 27 | loss=Architecture.get_loss_func_object(config), 28 | optimizer=Architecture.get_optimizer_object(config), 29 | metrics=config.train.metrics) 30 | 31 | return model 32 | 33 | @staticmethod 34 | @abc.abstractmethod 35 | def get_model(config, crossval_id=None): 36 | """ 37 | Builds the keras model associated with this architecture 38 | """ 39 | pass 40 | 41 | @staticmethod 42 | def generate_global_volume(image_vol): 43 | return None 44 | 45 | @staticmethod 46 | def get_patch_extraction_instructions(config, sample_idx, centers, image_vol, labels_vol, augment=False): 47 | """ 48 | Returns patch extraction instructions as required by the architecture 49 | (Specially useful for 2.5D and multipatch architectures) 50 | """ 51 | data_slices = get_patch_slices(centers, config.arch.patch_shape) 52 | label_slices = get_patch_slices(centers, config.arch.output_shape) 53 | 54 | 
sample_instructions = list() 55 | for data_slice, label_slice in iter.izip(data_slices, label_slices): 56 | instruction = PatchExtractInstruction( 57 | sample_idx=sample_idx, data_patch_slice=data_slice, label_patch_slice=label_slice) 58 | sample_instructions.append(instruction) 59 | 60 | if augment: 61 | goal_num_patches = math.ceil((1 - config.train.uniform_ratio) * config.train.num_patches) 62 | sample_instructions = augment_instructions(sample_instructions, goal_num_patches) 63 | 64 | return sample_instructions 65 | 66 | @staticmethod 67 | def get_optimizer_object(config): 68 | """ 69 | Returns configured optimizer object (i.e. with appropiate learning_rate, batch_size...) 70 | :param config: 71 | :return: Optimizer object with appropiate configuration 72 | """ 73 | 74 | if config.train.optimizer is 'Adam': 75 | optimizer = Koptimizers.Adam(lr=config.train.learning_rate, 76 | decay=config.train.lr_decay) 77 | elif config.train.optimizer is 'Adadelta': 78 | optimizer = Koptimizers.Adadelta(lr=config.train.learning_rate) 79 | elif config.train.optimizer is 'Adagrad': 80 | optimizer = Koptimizers.Adagrad(lr=config.train.learning_rate) 81 | else: 82 | # If specified optimizer not implemented return string for default parameters 83 | optimizer = config.train.optimizer 84 | 85 | return optimizer 86 | 87 | @staticmethod 88 | def get_loss_func_object(config): 89 | loss_dictionary = { 90 | 'jaccard' : jaccard, 91 | 'categorical_dice' : dice, 92 | 'sensitivity_specificity' : ss 93 | } 94 | 95 | if config.train.loss in loss_dictionary: 96 | loss_func = loss_dictionary[config.train.loss] 97 | else: 98 | loss_func = config.train.loss 99 | 100 | return loss_func -------------------------------------------------------------------------------- /utils/slices.py: -------------------------------------------------------------------------------- 1 | import logging as log 2 | import numpy as np 3 | import math 4 | import itertools as iter 5 | import scipy.ndimage as ndimage 6 | import copy 7 | import sys 8 | 9 | from sklearn.feature_extraction.image import extract_patches as sk_extract_patches 10 | from utils.visualization import normalize_and_save_patch 11 | 12 | from utils.misc import * 13 | 14 | def get_patch_slices(centers, patch_shape, step=None): 15 | assert len(patch_shape) == centers.shape[1] == 3 16 | 17 | ### Compute patch sides for slicing 18 | half_sizes = [[dim // 2, dim // 2] for dim in patch_shape] 19 | for i in range(len(half_sizes)): # If even dimension, subtract 1 to account for assymetry 20 | if patch_shape[i] % 2 == 0: half_sizes[i][1] -= 1 21 | 22 | # Actually create slices 23 | patch_locations = [] 24 | for count, center in enumerate(centers): 25 | patch_slice = [slice(None), # slice(None) selects all modalities 26 | slice(center[0] - half_sizes[0][0], center[0] + half_sizes[0][1] + 1, step), 27 | slice(center[1] - half_sizes[1][0], center[1] + half_sizes[1][1] + 1, step), 28 | slice(center[2] - half_sizes[2][0], center[2] + half_sizes[2][1] + 1, step)] 29 | patch_locations.append(patch_slice) 30 | 31 | return patch_locations 32 | 33 | 34 | def filter_slices_out_of_bounds(slice_sets, vol_shape): 35 | unpack_return = False 36 | if isinstance(slice_sets, list): 37 | slice_sets = (slice_sets, ) 38 | unpack_return = True 39 | 40 | assert type(slice_sets) is tuple 41 | 42 | # First slice set in tuple is limit case for out of bounds (typically biggest patch shape) 43 | slice_reference_set = slice_sets[0] 44 | 45 | valid_slices = [] 46 | for slice_count, patch_slice in 
enumerate(slice_reference_set): 47 | for i, dim_slice in enumerate(patch_slice): 48 | if i > 0 and (dim_slice.start < 0 or dim_slice.stop > vol_shape[i]): 49 | break 50 | else: 51 | valid_slices.append(slice_count) 52 | 53 | slice_sets_out = () 54 | for slice_set_in in slice_sets: 55 | slice_set_out = [slice_set_in[valid_slice] for valid_slice in valid_slices] 56 | slice_sets_out += (slice_set_out,) 57 | 58 | return slice_sets_out if not unpack_return else slice_sets_out[0] 59 | 60 | def get_in_bounds_slices_index(slices, volume_shape): 61 | valid_slices_idx = [] 62 | for slice_count, patch_slice in enumerate(slices): 63 | for i, dim_slice in enumerate(patch_slice): 64 | if i > 0 and (dim_slice.start < 0 or dim_slice.stop > volume_shape[i]): 65 | break 66 | else: 67 | valid_slices_idx.append(int(slice_count)) 68 | 69 | return valid_slices_idx 70 | 71 | def get_foreground_slices_index(slices, volume_in, min_fg_percentage): 72 | volume = volume_in if volume_in.ndim == 3 else volume_in[0] # Take modality 0 73 | 74 | volume_foreground = (volume > np.min(volume)) 75 | if np.all(volume_foreground) or np.all(np.invert(volume_foreground)): 76 | raise ValueError, "Detected all foreground/background volume (there should be a bit of both right?)" 77 | 78 | valid_idxs = list() 79 | for idx, s in enumerate(slices): 80 | vol_slice_foreground = volume_foreground[s[1:]] # Slices are 4D 81 | 82 | if np.any(vol_slice_foreground): # At least one foreground voxel 83 | if min_fg_percentage == 0.0: 84 | valid_idxs.append(idx) 85 | else: 86 | # A background voxel has been detected, check for ratio 87 | min_foreground_vox = min_fg_percentage * float(vol_slice_foreground.size) 88 | if np.count_nonzero(vol_slice_foreground) > min_foreground_vox: 89 | valid_idxs.append(idx) 90 | 91 | #log.debug("Filter BG: filtering {} slices out, leaving {}".format(len(slices) - len(valid_idxs), len(valid_idxs))) 92 | return valid_idxs -------------------------------------------------------------------------------- /architecture/Ronneberger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | from keras import backend as K 5 | from keras.layers import Activation, Input 6 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, MaxPooling2D 7 | from keras.layers.convolutional import Conv3D, Conv3DTranspose, MaxPooling3D 8 | from keras.layers.core import Permute, Reshape 9 | from keras.layers.merge import concatenate 10 | from keras.layers.normalization import BatchNormalization 11 | from keras.models import Model 12 | from .Architecture import Architecture 13 | 14 | K.set_image_dim_ordering('th') 15 | 16 | class Ronneberger(Architecture): 17 | @staticmethod 18 | def get_model(config, crossval_id=None) : 19 | assert config.arch.num_dimensions is 2 20 | assert config.arch.patch_shape[0] % 16 == 0, "Invalid patch shape" 21 | 22 | num_modalities = len(config.dataset.modalities) 23 | input_layer_shape = (num_modalities, ) + config.arch.patch_shape[:config.arch.num_dimensions] 24 | output_layer_shape = (config.train.num_classes, np.prod(config.arch.patch_shape[:config.arch.num_dimensions])) 25 | 26 | model = generate_unet_model( 27 | config.arch.num_dimensions, 28 | config.train.num_classes, 29 | input_layer_shape, 30 | output_layer_shape, 31 | config.train.activation, 32 | downsize_factor=2) 33 | 34 | return model 35 | 36 | def generate_unet_model( 37 | dimension, num_classes, input_shape, output_shape, activation, downsize_factor=2): 38 | input = 
Input(shape=input_shape) 39 | 40 | L1_out = get_conv_core(input, num_filters=64) 41 | 42 | L2_in = get_max_pooling_layer(L1_out) 43 | L2_out = get_conv_core(L2_in, num_filters=128) 44 | 45 | L3_in = get_max_pooling_layer(L2_out) 46 | L3_out = get_conv_core(L3_in, num_filters=256) 47 | 48 | L4_in = get_max_pooling_layer(L3_out) 49 | L4_out = get_conv_core(L4_in, num_filters=512) 50 | 51 | C5_in = get_max_pooling_layer(L4_out) 52 | C5_out = get_conv_core(C5_in, num_filters=1024) 53 | 54 | R4_in = get_deconv_layer(C5_out, num_filters=256) 55 | R4_in = concatenate([R4_in, L4_out], axis=1) 56 | R4_out = get_conv_core(R4_in, num_filters=512) 57 | 58 | R3_in = get_deconv_layer(R4_out, num_filters=128) 59 | R3_in = concatenate([R3_in, L3_out], axis=1) 60 | R3_out = get_conv_core(R3_in, num_filters=256) 61 | 62 | R2_in = get_deconv_layer(R3_out, num_filters=64) 63 | R2_in = concatenate([R2_in, L2_out], axis=1) 64 | R2_out = get_conv_core(R2_in, num_filters=128) 65 | 66 | R1_in = get_deconv_layer(R2_out, num_filters=32) 67 | R1_in = concatenate([R1_in, L1_out], axis=1) 68 | R1_out = get_conv_core(R1_in, num_filters=64) 69 | 70 | out = get_conv_fc(R1_out, num_classes) 71 | pred = organise_output(out, output_shape, 'softmax') 72 | 73 | return Model(inputs=[input], outputs=[pred]) 74 | 75 | 76 | def get_conv_core(input, num_filters, batch_norm=False) : 77 | kernel_size = (3, 3) 78 | 79 | x = Conv2D(num_filters, kernel_size=kernel_size, padding='same')(input) 80 | x = Activation('relu')(x) 81 | if batch_norm: x = BatchNormalization(axis=1)(x) 82 | x = Conv2D(num_filters, kernel_size=kernel_size, padding='same')(x) 83 | x = Activation('relu')(x) 84 | if batch_norm: x = BatchNormalization(axis=1)(x) 85 | 86 | return x 87 | 88 | def get_max_pooling_layer(input) : 89 | return MaxPooling2D(pool_size=(2, 2))(input) 90 | 91 | def get_deconv_layer(input, num_filters) : 92 | return Conv2DTranspose(num_filters, kernel_size=(2, 2), strides=(2, 2))(input) 93 | 94 | def get_conv_fc(input, num_classes): 95 | return Conv2D(num_classes, kernel_size=(1, 1))(input) 96 | 97 | def organise_output(input, output_shape, activation) : 98 | pred = Reshape(output_shape)(input) 99 | pred = Permute((2, 1))(pred) 100 | return Activation(activation)(pred) -------------------------------------------------------------------------------- /utils/samples.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import logging as log 4 | import itertools as iter 5 | 6 | from dataset.Dataset import Sample 7 | from utils.volume import pad_volume 8 | 9 | def split_train_val(dataset, validation_split=None, idxs=None): 10 | assert validation_split is not None or idxs is not None # One is not None 11 | assert validation_split is None or idxs is None # Other is None 12 | if validation_split is not None: 13 | assert isinstance(validation_split, float) and 0.0 < validation_split < 1.0 14 | if idxs is not None: 15 | assert isinstance(idxs, list) 16 | 17 | if len(dataset.val) > 0: 18 | log.debug("Existing validation set found -> num_train={}, num_val={}" \ 19 | .format(len(dataset.train), len(dataset.val))) 20 | return dataset 21 | 22 | # Last volumes will be relocated to val set 23 | N = len(dataset.train) 24 | if validation_split is not None: 25 | num_val_volumes = np.int(np.ceil(N * validation_split)) 26 | val_idxs = np.asarray(range(N - num_val_volumes, N), dtype=int) 27 | else: # idxs is not None 28 | val_idxs = np.asarray(idxs, dtype=int) 29 | val_idxs = val_idxs[val_idxs < 
len(dataset.train)]
30 | 
31 | train_idxs = np.delete(np.arange(N, dtype=int), val_idxs)
32 | 
33 | 
34 | dataset.val = [dataset.train[idx] for idx in val_idxs]
35 | dataset.train = [dataset.train[idx] for idx in train_idxs]
36 | 
37 | crossval_id = '{}_to_{}'.format(val_idxs[0], val_idxs[-1])
38 | 
39 | disp_msg = "Split train into train+val -> num_train={}, num_val={}, num_test={}, crossval_id={}"
40 | log.debug(disp_msg.format(len(dataset.train), len(dataset.val), len(dataset.test), crossval_id))
41 | 
42 | return dataset, crossval_id
43 | 
44 | 
45 | def compute_set_statistics(set_in):
46 | num_samples = len(set_in)
47 | num_modalities = set_in[0].data.shape[0]
48 | 
49 | mean = np.zeros((num_samples, num_modalities,))
50 | std = np.zeros((num_samples, num_modalities,))
51 | 
52 | for idx in range(len(set_in)):
53 | mean[idx], std[idx] = compute_sample_statistics(set_in[idx])
54 | 
55 | return mean, std
56 | 
57 | def normalise_set(set_in, mean=None, std=None):
58 | return_stats = False
59 | if mean is None and std is None:
60 | mean, std = compute_set_statistics(set_in)
61 | return_stats = True
62 | 
63 | for i in range(len(set_in)):
64 | set_in[i] = normalise_sample(set_in[i], mean=mean[i], std=std[i])
65 | 
66 | return (set_in, mean, std) if return_stats else set_in # Only return the statistics when they were computed here
67 | 
68 | def compute_sample_statistics(sample_in):
69 | assert isinstance(sample_in, Sample)
70 | num_modalities = sample_in.data.shape[0]
71 | 
72 | mean = np.zeros((num_modalities,))
73 | std = np.zeros((num_modalities,))
74 | for modality in range(num_modalities):
75 | mean[modality] = np.mean(sample_in.data[modality])
76 | std[modality] = np.std(sample_in.data[modality])
77 | 
78 | return mean, std
79 | 
80 | def normalise_sample(sample_in, mean=None, std=None):
81 | if mean is None and std is None:
82 | mean, std = compute_sample_statistics(sample_in)
83 | 
84 | for modality in range(sample_in.data.shape[0]):
85 | sample_in.data[modality] -= mean[modality]
86 | sample_in.data[modality] /= std[modality]
87 | 
88 | return sample_in
89 | 
90 | def zeropad_set(samples_in, patch_shape):
91 | samples_out = list()
92 | for sample in samples_in:
93 | samples_out.append(zeropad_sample(sample, patch_shape))
94 | return samples_out
95 | 
96 | 
97 | def zeropad_sample(sample_in, patch_shape):
98 | sample_in.data = pad_volume(sample_in.data, patch_shape)
99 | try: sample_in.labels = pad_volume(sample_in.labels, patch_shape)
100 | except AttributeError: pass # Some test samples don't have labels
101 | return sample_in
102 | 
103 | 
104 | 
105 | 
106 | 
107 | 
108 | 
--------------------------------------------------------------------------------
/dataset/Isles15_SISS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import logging as log
4 | 
5 | from .Dataset import Dataset, Sample
6 | from utils.volume import get_brain_mask, read_volume, load_nib_file
7 | 
8 | 
9 | class Isles15_SISS(Dataset):
10 | def load_samples(self, config_dataset):
11 | log.info("Loading Isles15 SISS dataset...")
12 | 
13 | num_volumes = config_dataset.num_volumes
14 | dataset_path = config_dataset.path
15 | pattern = config_dataset.general_pattern
16 | modalities = config_dataset.modalities
17 | 
18 | # Training loading
19 | for case_num in range(1, num_volumes[0] + 1):
20 | filepaths = [None] * (len(modalities) + 1)
21 | 
22 | for root, subdirs, files in os.walk(os.path.join(dataset_path, pattern[0].format(case_num))):
23 | for idx_mod, modality in enumerate(modalities):
24 | 
25 | modality_file = [file_idx
for file_idx, filename in enumerate(files) if modality in filename] 26 | assert len(modality_file) in [0,1], "Found more than one file for the same modality" 27 | if modality_file: 28 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 29 | 30 | ground_truth_file = [file_idx for file_idx, filename in enumerate(files) if 'OT' in filename] 31 | if ground_truth_file: 32 | filepaths[-1] = os.path.join(root, files[ground_truth_file[0]]) 33 | 34 | 35 | # Check folder exists (some samples missing) 36 | if any([filepath is None for filepath in filepaths]): 37 | raise ValueError, "Didn't find all expected modalities" 38 | 39 | sample = Sample(id=case_num) 40 | 41 | # Load volume to check dimensions (not the same for all train samples) 42 | sample.nib = load_nib_file(filepaths[0]) 43 | vol = sample.nib.get_data() 44 | 45 | sample.data = np.zeros((len(modalities),) + vol.shape) 46 | sample.labels = np.zeros((1,) + vol.shape) 47 | 48 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 49 | sample.data[0] = vol 50 | for i, filepath in enumerate(filepaths[:-1]): 51 | sample.data[i] = read_volume(filepath) 52 | 53 | sample.mask = get_brain_mask(sample.data) 54 | 55 | # Ground truth loading 56 | sample.labels[0] = read_volume(filepaths[-1]) 57 | 58 | self.add_train(sample) 59 | 60 | # Testing loading 61 | for case_num in range(1, num_volumes[1] + 1): 62 | filepaths = [None] * (len(modalities)) 63 | 64 | for root, subdirs, files in os.walk(os.path.join(dataset_path, pattern[1].format(case_num))): 65 | for idx_mod, modality in enumerate(modalities): 66 | 67 | modality_file = [file_idx for file_idx, filename in enumerate(files) if modality in filename] 68 | assert len(modality_file) in [0,1], "Found more than one file for the same modality" 69 | if modality_file: 70 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 71 | 72 | # Check folder exists (some samples missing) 73 | if any([filepath is None for filepath in filepaths]): 74 | raise ValueError, "Didn't find all expected modalities" 75 | 76 | sample = Sample(id=case_num) 77 | 78 | # Load volume to check dimensions (not the same for all train samples) 79 | sample.nib = load_nib_file(filepaths[0]) 80 | vol = sample.nib.get_data() 81 | 82 | sample.data = np.zeros((len(modalities),) + vol.shape) 83 | 84 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 85 | sample.data[0] = vol 86 | for i, filepath in enumerate(filepaths): 87 | sample.data[i] = read_volume(filepath) 88 | 89 | sample.mask = get_brain_mask(sample.data) 90 | 91 | self.add_test(sample) 92 | 93 | 94 | -------------------------------------------------------------------------------- /dataset/Isles15_SPES.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import logging as log 4 | 5 | from .Dataset import Dataset, Sample 6 | from utils.volume import get_brain_mask, read_volume, load_nib_file 7 | 8 | 9 | class Isles15_SPES(Dataset): 10 | def load_samples(self, config_dataset): 11 | log.info("Loading Isles15 SPESS dataset...") 12 | 13 | num_volumes = config_dataset.num_volumes 14 | dataset_path = config_dataset.path 15 | pattern = config_dataset.general_pattern 16 | modalities = config_dataset.modalities 17 | 18 | # Training loading 19 | for case_num in range(1, num_volumes[0] + 1): 20 | filepaths = [None] * (len(modalities) + 1) 21 | 22 | for root, subdirs, files in os.walk(os.path.join(dataset_path, 
pattern[0].format(case_num))): 23 | if 'corelabel' in root or 'mergedlabels' in root: 24 | continue 25 | 26 | for idx_mod, modality in enumerate(modalities): 27 | 28 | modality_file = [file_idx for file_idx, filename in enumerate(files) if modality in filename] 29 | assert len(modality_file) in [0,1], "Found more than one file for the same modality" 30 | if modality_file: 31 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 32 | 33 | ground_truth_file = [file_idx for file_idx, filename in enumerate(files) if 'OT' in filename] 34 | if ground_truth_file: 35 | filepaths[-1] = os.path.join(root, files[ground_truth_file[0]]) 36 | 37 | 38 | # Check folder exists (some samples missing) 39 | if any([filepath is None for filepath in filepaths]): 40 | raise ValueError, "Didn't find all expected modalities" 41 | 42 | sample = Sample(id=case_num) 43 | 44 | # Load volume to check dimensions (not the same for all train samples) 45 | sample.nib = load_nib_file(filepaths[0]) 46 | vol = sample.nib.get_data() 47 | 48 | sample.data = np.zeros((len(modalities),) + vol.shape) 49 | sample.labels = np.zeros((1,) + vol.shape) 50 | 51 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 52 | sample.data[0] = vol 53 | for i, filepath in enumerate(filepaths[:-1]): 54 | sample.data[i] = read_volume(filepath) 55 | 56 | sample.mask = get_brain_mask(sample.data) 57 | 58 | # Ground truth loading 59 | sample.labels[0] = read_volume(filepaths[-1]) 60 | 61 | self.add_train(sample) 62 | 63 | # Testing loading 64 | for case_num in range(1, num_volumes[1] + 1): 65 | filepaths = [None] * (len(modalities)) 66 | 67 | for root, subdirs, files in os.walk(os.path.join(dataset_path, pattern[1].format(case_num))): 68 | if 'corelabel' in root or 'mergedlabels' in root: 69 | continue 70 | 71 | for idx_mod, modality in enumerate(modalities): 72 | 73 | modality_file = [file_idx for file_idx, filename in enumerate(files) if modality in filename] 74 | assert len(modality_file) in [0,1], "Found more than one file for the same modality" 75 | if modality_file: 76 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 77 | 78 | 79 | # Check folder exists (some samples missing) 80 | if any([filepath is None for filepath in filepaths]): 81 | raise ValueError, "Didn't find all expected modalities" 82 | 83 | sample = Sample(id=case_num) 84 | 85 | # Load volume to check dimensions (not the same for all train samples) 86 | sample.nib = load_nib_file(filepaths[0]) 87 | vol = sample.nib.get_data() 88 | 89 | sample.data = np.zeros((len(modalities),) + vol.shape) 90 | 91 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 92 | sample.data[0] = vol 93 | for i, filepath in enumerate(filepaths): 94 | sample.data[i] = read_volume(filepath) 95 | 96 | sample.mask = get_brain_mask(sample.data) 97 | 98 | 99 | self.add_test(sample) 100 | 101 | 102 | -------------------------------------------------------------------------------- /dataset/isles2017.py: -------------------------------------------------------------------------------- 1 | import os 2 | import string 3 | 4 | import numpy as np 5 | import logging as log 6 | 7 | from .Dataset import Dataset, Sample 8 | from utils.volume import get_brain_mask, read_volume, load_nib_file 9 | 10 | 11 | class Isles2017(Dataset): 12 | def load_samples(self, config_dataset): 13 | log.info("Loading Isles2017 dataset...") 14 | 15 | num_volumes = config_dataset.num_volumes 16 | dataset_path = config_dataset.path 17 | pattern = 
config_dataset.general_pattern 18 | modalities = config_dataset.modalities 19 | 20 | # Training loading 21 | next_id = 1 22 | for case_num in range(1, num_volumes[0] + 1): 23 | filepaths = [None] * (len(modalities) + 1) 24 | 25 | for root, subdirs, files in os.walk(os.path.join(dataset_path, pattern[0].format(case_num))): 26 | for idx_mod, modality in enumerate(modalities): 27 | modality_file = [file_idx for file_idx, filename in enumerate(files) if modality in filename and 'nii' in filename] 28 | assert len(modality_file) in [0, 1], "Found more than one file for the same modality" 29 | if modality_file: 30 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 31 | 32 | ground_truth_file = [file_idx for file_idx, filename in enumerate(files) if 'OT' in filename and 'nii' in filename] 33 | if ground_truth_file: 34 | filepaths[-1] = os.path.join(root, files[ground_truth_file[0]]) 35 | 36 | # Check folder exists (some samples missing) 37 | if all([fp is None for fp in filepaths]): #CASE NUMBERS NOT SEQUENTIAL 38 | continue 39 | else: 40 | if any([filepath is None for filepath in filepaths]): 41 | raise ValueError, "Didn't find all expected modalities\n {}".format(filepaths) 42 | 43 | sample = Sample(id=next_id) 44 | next_id += 1 45 | 46 | # Load volume to check dimensions (not the same for all train samples) 47 | sample.nib = load_nib_file(filepaths[0]) 48 | vol = sample.nib.get_data() 49 | 50 | sample.data = np.zeros((len(modalities),) + vol.shape) 51 | sample.labels = np.zeros((1,) + vol.shape) 52 | 53 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 54 | sample.data[0] = vol 55 | for i, filepath in enumerate(filepaths[:-1]): 56 | sample.data[i] = read_volume(filepath) 57 | 58 | sample.mask = get_brain_mask(sample.data) 59 | 60 | # Ground truth loading 61 | sample.labels[0] = read_volume(filepaths[-1]) 62 | 63 | self.add_train(sample) 64 | 65 | # Testing loading 66 | next_id = 1 67 | for case_num in range(1, num_volumes[1] + 1): 68 | filepaths = [None] * (len(modalities)) 69 | 70 | for root, subdirs, files in os.walk(os.path.join(dataset_path, pattern[1].format(case_num))): 71 | for idx_mod, modality in enumerate(modalities): 72 | 73 | modality_file = [file_idx for file_idx, filename in enumerate(files) if modality in filename and 'nii' in filename] 74 | assert len(modality_file) in [0, 1], "Found more than one file for the same modality" 75 | if modality_file: 76 | filepaths[idx_mod] = os.path.join(root, files[modality_file[0]]) 77 | 78 | # Check folder exists (some samples missing) 79 | if all([fp is None for fp in filepaths]): #CASE NUMBERS NOT SEQUENTIAL 80 | continue 81 | else: 82 | if any([filepath is None for filepath in filepaths]): 83 | raise ValueError, "Didn't find all expected modalities\n {}".format(filepaths) 84 | 85 | sample = Sample(id=next_id) 86 | next_id += 1 87 | 88 | # Load volume to check dimensions (not the same for all train samples) 89 | sample.nib = load_nib_file(filepaths[0]) 90 | vol = sample.nib.get_data() 91 | 92 | sample.data = np.zeros((len(modalities),) + vol.shape) 93 | 94 | # Load all modalities (except last which is gt segmentation) into last appended ndarray 95 | sample.data[0] = vol 96 | for i, filepath in enumerate(filepaths): 97 | sample.data[i] = read_volume(filepath) 98 | 99 | sample.mask = get_brain_mask(sample.data) 100 | 101 | self.add_test(sample) 102 | -------------------------------------------------------------------------------- /workflow/filenames.py: 
--------------------------------------------------------------------------------
1 | import nibabel as nib
2 | import logging as log
3 | import numpy as np
4 | import string
5 | import os
6 | import datetime
7 | 
8 | import cv2
9 | 
10 | def generate_output_filename(path, dataset, params, extension) :
11 | base_path = os.path.join(path, dataset) if dataset is not None else os.path.join(path)
12 | params_filename = string.join(['{}'.format(param) for param in params], '_')
13 | extension_name = '.{}'.format(extension) if extension[0] != '.' else extension
14 | params_filename = params_filename.replace('.', ',') # Avoid decimal points that could be confused with the extension separator
15 | params_filename = params_filename.replace(' ', '') # Avoid spaces
16 | return os.path.join(base_path, params_filename + extension_name)
17 | 
18 | def generate_model_filename_old(config, crossval_id=None):
19 | model_params = (
20 | config.train.architecture_name,
21 | "{}D".format(config.arch.num_dimensions),
22 | config.arch.patch_shape,
23 | config.train.optimizer,
24 | "{:.0e}".format(config.train.learning_rate),
25 | "{:.0e}".format(config.train.lr_decay))
26 | 
27 | if crossval_id is not None:
28 | model_params += (crossval_id,)
29 | 
30 | return generate_output_filename(config.model_path, config.train.dataset_name, model_params, 'h5')
31 | 
32 | def generate_model_filename(config, crossval_id=None):
33 | model_params = (
34 | config.train.architecture_name,
35 | "{}D".format(config.arch.num_dimensions),
36 | config.arch.patch_shape,
37 | config.train.optimizer)
38 | 
39 | if crossval_id is not None:
40 | model_params += (crossval_id,)
41 | 
42 | return generate_output_filename(config.model_path, config.train.dataset_name, model_params, 'h5')
43 | 
44 | def generate_result_filename(config, additional_params, extension):
45 | result_params = (
46 | config.dataset.evaluation_set,
47 | config.train.architecture_name,
48 | "{}D".format(config.arch.num_dimensions),
49 | config.arch.patch_shape,
50 | config.train.extraction_step_test)
51 | result_params = additional_params + result_params
52 | 
53 | return generate_output_filename(config.results_path, config.train.dataset_name, result_params, extension)
54 | 
55 | def generate_log_filename(config):
56 | log_params = (
57 | datetime.datetime.now().isoformat('_').replace(':', '-').split(".")[0],
58 | config.train.architecture_name,
59 | config.arch.num_dimensions,
60 | config.train.optimizer,
61 | config.train.batch_size,
62 | "{:.0e}".format(config.train.learning_rate),
63 | "{:.0e}".format(config.train.lr_decay),
64 | config.train.num_patches,
65 | config.train.uniform_ratio)
66 | return generate_output_filename(config.log_path, config.train.dataset_name, log_params, 'csv')
67 | 
68 | def generate_metrics_filename(config):
69 | metrics_params = ()
70 | 
71 | if config.eval.evaluation_type == 'crossval':
72 | metrics_params += (
73 | str(config.eval.crossval_fold) + 'fold',)
74 | elif config.eval.evaluation_type in ['val', 'test']:
75 | pass
76 | elif config.eval.evaluation_type == 'grid':
77 | metrics_params += (
78 | config.eval.metrics_eval_set,
79 | config.eval.metrics_thresh_values,
80 | config.eval.metrics_lesion_sizes)
81 | 
82 | metrics_params += (
83 | datetime.datetime.now().isoformat('_').replace(':', '-').split(".")[0],
84 | config.train.dataset_name,
85 | config.train.architecture_name,
86 | "{}D".format(config.arch.num_dimensions),
87 | config.train.optimizer,
88 | config.train.batch_size,
89 | "{:.0e}".format(config.train.learning_rate),
90 |
"{:.0e}".format(config.train.lr_decay), 91 | config.train.num_patches, 92 | config.dataset.min_lesion_voxels, 93 | config.dataset.lesion_threshold) 94 | 95 | if config.eval.evaluation_type is 'grid': 96 | datetime_string = datetime.datetime.now().isoformat('_').replace(':', '-').split(".")[0] 97 | metrics_params = (datetime_string,) + metrics_params 98 | 99 | temp_filename = generate_output_filename(config.metrics_path, None, ('grid_temp',) + metrics_params, 'csv') 100 | tables_filename = generate_output_filename(config.metrics_path, None, ('grid_tables',) + metrics_params, 'csv') 101 | return temp_filename, tables_filename 102 | else: 103 | return generate_output_filename(config.metrics_path, None, (config.eval.evaluation_type, ) + metrics_params, 'csv') 104 | -------------------------------------------------------------------------------- /utils/centers.py: -------------------------------------------------------------------------------- 1 | import logging as log 2 | import numpy as np 3 | import math 4 | import itertools as iter 5 | import scipy.ndimage as ndimage 6 | import copy 7 | import sys 8 | 9 | from sklearn.feature_extraction.image import extract_patches as sk_extract_patches 10 | from utils.visualization import normalize_and_save_patch 11 | from utils.misc import * 12 | 13 | from utils.slices import get_in_bounds_slices_index, get_patch_slices 14 | 15 | def sample_uniform_patch_centers(patch_shape, extraction_step, volume_foreground): 16 | """ 17 | Generate patch center indexes for 3D volume uniform sampling 18 | patch_shape, extraction step and expected shape should be in 3D! 19 | 20 | volume_foreground = sample.data, is used to avoid adding centers for which the center of the patch is empty 21 | """ 22 | expected_shape = volume_foreground.shape[1:] 23 | 24 | assert len(patch_shape) == len(extraction_step) == len(expected_shape) == 3 25 | 26 | # Get patch half size 27 | half_sizes = [[dim // 2, dim // 2] for dim in patch_shape] 28 | for i in range(len(half_sizes)): # If even dimension, subtract 1 to account for assymetry 29 | if patch_shape[i] % 2 == 0: half_sizes[i][1] -= 1 30 | 31 | idxs = [] 32 | for dim in range(len(expected_shape)): 33 | idxs.append(range(half_sizes[dim][0], expected_shape[dim] - half_sizes[dim][1], extraction_step[dim])) 34 | 35 | # Put in ndarray format 36 | centers = np.zeros((np.prod([len(dim_idxs) for dim_idxs in idxs]), len(idxs)), dtype=int) 37 | 38 | # Make sure added center is foreground 39 | added_centers = 0 40 | 41 | for count, center_coords in enumerate(iter.product(*idxs)): 42 | if volume_foreground[0, center_coords[0], center_coords[1], center_coords[2]] != 0: 43 | centers[added_centers] = np.asarray(center_coords) 44 | added_centers += 1 45 | 46 | centers = centers[:added_centers] # Get filled centers 47 | 48 | centers = clip_centers_out_of_bounds(centers, patch_shape, volume_foreground) 49 | 50 | return centers 51 | 52 | 53 | def sample_positive_patch_centers(labels_volume): 54 | if len(labels_volume.shape) > 3: 55 | labels_volume = labels_volume[0] # Select modality 0 56 | 57 | pos_centers = np.where(labels_volume == 1) 58 | 59 | # Put in ndarray format 60 | centers = np.zeros((len(pos_centers[0]), len(pos_centers)), dtype=int) 61 | for j, dim_coords in enumerate(pos_centers): 62 | for i, dim_coord in enumerate(dim_coords): 63 | centers[i][j] = dim_coord 64 | 65 | return centers 66 | 67 | 68 | def randomly_offset_centers(centers, offset_shape, patch_shape, original_vol): 69 | ### Offset patches to avoid location bias 70 | 71 | # Create 
offset matrix 72 | center_offsets = 2.0 * (np.random.rand(centers.shape[0], centers.shape[1]) - 0.5) # generate uniform sampling between -1 and 1 73 | offset_ranges = np.stack([(offset_shape[0] // 2) * np.ones((centers.shape[0],)), 74 | (offset_shape[1] // 2) * np.ones((centers.shape[0],)), 75 | (offset_shape[2] // 2) * np.ones((centers.shape[0],))], axis=1) 76 | center_offsets = np.multiply(center_offsets, offset_ranges).astype(int) 77 | 78 | # Apply offset 79 | centers += center_offsets 80 | 81 | # Check centers in-bounds 82 | centers = clip_centers_out_of_bounds(centers, patch_shape, original_vol) 83 | 84 | return centers 85 | 86 | 87 | 88 | def resample_centers(centers, min_samples=None, max_samples=None): 89 | resampled_idxs = None 90 | if min_samples is not None: 91 | if centers.shape[0] < min_samples: 92 | resampled_idxs = get_resampling_indexes(centers.shape[0], min_samples) 93 | 94 | if max_samples is not None: 95 | if centers.shape[0] > max_samples: 96 | resampled_idxs = get_resampling_indexes(centers.shape[0], max_samples) 97 | 98 | if resampled_idxs is not None: # Perform resampling 99 | resampled_idxs = np.asarray(resampled_idxs, dtype=np.int) 100 | centers = centers[resampled_idxs] 101 | 102 | return centers 103 | 104 | def clip_centers_out_of_bounds(centers, patch_shape, original_vol): 105 | vol_shape = original_vol.shape[1:] # Omit modality dimension 106 | center_range = [[(patch_shape[0] // 2) + 1, vol_shape[0] - (patch_shape[0] // 2) - 1], 107 | [(patch_shape[1] // 2) + 1, vol_shape[1] - (patch_shape[1] // 2) - 1], 108 | [(patch_shape[2] // 2) + 1, vol_shape[2] - (patch_shape[2] // 2) - 1]] 109 | center_range = np.asarray(center_range, dtype='int') 110 | 111 | for i in [0,1,2]: 112 | centers[:, i] = np.clip(centers[:, i], a_min=center_range[i][0], a_max=center_range[i][1]) 113 | 114 | return centers -------------------------------------------------------------------------------- /architecture/Guerrero.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from keras import backend as K 4 | from keras.layers import Activation, Input 5 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, MaxPooling2D 6 | from keras.layers.convolutional import Conv3D, Conv3DTranspose, MaxPooling3D 7 | from keras.layers.core import Permute, Reshape 8 | from keras.layers.merge import add 9 | from keras.layers.normalization import BatchNormalization 10 | from keras.models import Model 11 | from .Architecture import Architecture 12 | 13 | K.set_image_dim_ordering('th') 14 | 15 | class Guerrero (Architecture): 16 | @staticmethod 17 | def get_model(config, crossval_id=None): 18 | assert config.arch.num_dimensions in [2, 3] 19 | 20 | num_modalities = len(config.dataset.modalities) 21 | input_layer_shape = (num_modalities,) + config.arch.patch_shape[:config.arch.num_dimensions] 22 | output_layer_shape = (config.train.num_classes, np.prod(config.arch.patch_shape[:config.arch.num_dimensions])) 23 | 24 | model = generate_uresnet_model( 25 | config.arch.num_dimensions, 26 | config.train.num_classes, 27 | input_layer_shape, 28 | output_layer_shape, 29 | config.train.activation) 30 | 31 | return model 32 | 33 | 34 | def generate_uresnet_model( 35 | dimension, num_classes, input_shape, output_shape, activation): 36 | input_uresnet = Input(shape=input_shape) 37 | 38 | conv1 = get_res_conv_core(dimension, input_uresnet, 32) 39 | pool1 = get_max_pooling_layer(dimension, conv1) 40 | 41 | conv2 = get_res_conv_core(dimension, pool1, 64) 42 
| pool2 = get_max_pooling_layer(dimension, conv2) 43 | 44 | conv3 = get_res_conv_core(dimension, pool2, 128) 45 | pool3 = get_max_pooling_layer(dimension, conv3) 46 | 47 | conv4 = get_res_conv_core(dimension, pool3, 256) 48 | up1 = get_deconv_layer(dimension, conv4, 128) 49 | conv5 = get_res_conv_core(dimension, up1, 128) 50 | 51 | add35 = add([conv3, conv5]) 52 | add35 = BatchNormalization(axis=1)(add35) 53 | add35 = Activation('relu')(add35) 54 | conv6 = get_res_conv_core(dimension, add35, 128) 55 | up2 = get_deconv_layer(dimension, conv6, 64) 56 | 57 | add22 = add([conv2, up2]) 58 | add22 = BatchNormalization(axis=1)(add22) 59 | add22 = Activation('relu')(add22) 60 | conv7 = get_res_conv_core(dimension, add22, 64) 61 | up3 = get_deconv_layer(dimension, conv7, 32) 62 | 63 | add13 = add([conv1, up3]) 64 | add13 = BatchNormalization(axis=1)(add13) 65 | add13 = Activation('relu')(add13) 66 | conv8 = get_res_conv_core(dimension, add13, 32) 67 | 68 | pred = get_conv_fc(dimension, conv8, num_classes) 69 | pred = organise_output(pred, output_shape, activation) 70 | 71 | return Model(inputs=[input_uresnet], outputs=[pred]) 72 | 73 | def get_res_conv_core(dimension, input, num_filters) : 74 | kernel_size_a = (3, 3) if dimension == 2 else (3, 3, 3) 75 | kernel_size_b = (1, 1) if dimension == 2 else (1, 1, 1) 76 | 77 | if dimension == 2 : 78 | a = Conv2D(num_filters, kernel_size=kernel_size_a, padding='same')(input) 79 | a = BatchNormalization(axis=1)(a) 80 | b = Conv2D(num_filters, kernel_size=kernel_size_b, padding='same')(input) 81 | b = BatchNormalization(axis=1)(b) 82 | else : 83 | a = Conv3D(num_filters, kernel_size=kernel_size_a, padding='same')(input) 84 | a = BatchNormalization(axis=1)(a) 85 | b = Conv3D(num_filters, kernel_size=kernel_size_b, padding='same')(input) 86 | b = BatchNormalization(axis=1)(b) 87 | 88 | c = add([a, b]) 89 | c = BatchNormalization(axis=1)(c) 90 | return Activation('relu')(c) 91 | 92 | def get_max_pooling_layer(dimension, input) : 93 | pool_size = (2, 2) if dimension == 2 else (2, 2, 2) 94 | 95 | if dimension == 2: 96 | return MaxPooling2D(pool_size=pool_size)(input) 97 | else : 98 | return MaxPooling3D(pool_size=pool_size)(input) 99 | 100 | def get_deconv_layer(dimension, input, num_filters) : 101 | kernel_size = (2, 2) if dimension == 2 else (2, 2, 2) 102 | strides = (2, 2) if dimension == 2 else (2, 2, 2) 103 | 104 | if dimension == 2: 105 | return Conv2DTranspose(num_filters, kernel_size=kernel_size, strides=strides)(input) 106 | else : 107 | return Conv3DTranspose(num_filters, kernel_size=kernel_size, strides=strides)(input) 108 | 109 | def get_conv_fc(dimension, input, num_filters) : 110 | kernel_size = (1, 1) if dimension == 2 else (1, 1, 1) 111 | 112 | if dimension == 2 : 113 | fc = Conv2D(num_filters, kernel_size=kernel_size)(input) 114 | else : 115 | fc = Conv3D(num_filters, kernel_size=kernel_size)(input) 116 | 117 | return Activation('relu')(fc) 118 | 119 | def organise_output(input, output_shape, activation) : 120 | pred = Reshape(output_shape)(input) 121 | pred = Permute((2, 1))(pred) 122 | return Activation(activation)(pred) -------------------------------------------------------------------------------- /architecture/Cicek.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | from keras import backend as K 5 | from keras.layers import Activation, Input 6 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, MaxPooling2D 7 | from keras.layers.convolutional 
import Conv3D, Conv3DTranspose, MaxPooling3D 8 | from keras.layers.core import Permute, Reshape 9 | from keras.layers.merge import concatenate 10 | from keras.layers.normalization import BatchNormalization 11 | from keras.models import Model 12 | from .Architecture import Architecture 13 | 14 | K.set_image_dim_ordering('th') 15 | 16 | class Cicek (Architecture): 17 | @staticmethod 18 | def get_model(config, crossval_id=None) : 19 | assert config.arch.num_dimensions in [2, 3] 20 | 21 | num_modalities = len(config.dataset.modalities) 22 | input_layer_shape = (num_modalities, ) + config.arch.patch_shape[:config.arch.num_dimensions] 23 | output_layer_shape = (config.train.num_classes, np.prod(config.arch.patch_shape[:config.arch.num_dimensions])) 24 | 25 | model = generate_unet_model( 26 | config.arch.num_dimensions, 27 | config.train.num_classes, 28 | input_layer_shape, 29 | output_layer_shape, 30 | config.train.activation, 31 | downsize_factor=2) 32 | 33 | return model 34 | 35 | def generate_unet_model( 36 | dimension, num_classes, input_shape, output_shape, activation, downsize_factor=2): 37 | input = Input(shape=input_shape) 38 | 39 | conv1 = get_conv_core(dimension, input, int(math.floor(64/downsize_factor))) 40 | pool1 = get_max_pooling_layer(dimension, conv1) 41 | 42 | conv2 = get_conv_core(dimension, pool1, int(math.floor(128/downsize_factor))) 43 | pool2 = get_max_pooling_layer(dimension, conv2) 44 | 45 | conv3 = get_conv_core(dimension, pool2, int(math.floor(256/downsize_factor))) 46 | pool3 = get_max_pooling_layer(dimension, conv3) 47 | 48 | conv4 = get_conv_core(dimension, pool3, int(math.floor(512/downsize_factor))) 49 | 50 | up5 = get_deconv_layer(dimension, conv4, int(math.floor(256/downsize_factor))) 51 | up5 = concatenate([up5, conv3], axis=1) 52 | 53 | conv5 = get_conv_core(dimension, up5, int(math.floor(256/downsize_factor))) 54 | 55 | up6 = get_deconv_layer(dimension, conv5, int(math.floor(128/downsize_factor))) 56 | up6 = concatenate([up6, conv2], axis=1) 57 | 58 | conv6 = get_conv_core(dimension, up6, int(math.floor(128/downsize_factor))) 59 | 60 | up7 = get_deconv_layer(dimension, conv6, int(math.floor(64/downsize_factor))) 61 | up7 = concatenate([up7, conv1], axis=1) 62 | 63 | conv7 = get_conv_core(dimension, up7, int(math.floor(64/downsize_factor))) 64 | 65 | pred = get_conv_fc(dimension, conv7, num_classes) 66 | pred = organise_output(pred, output_shape, activation) 67 | 68 | return Model(inputs=[input], outputs=[pred]) 69 | 70 | 71 | def get_conv_core(dimension, input, num_filters) : 72 | x = None 73 | kernel_size = (3, 3) if dimension == 2 else (3, 3, 3) 74 | 75 | if dimension == 2 : 76 | x = Conv2D(num_filters, kernel_size=kernel_size, padding='same')(input) 77 | x = Activation('relu')(x) 78 | x = BatchNormalization(axis=1)(x) 79 | x = Conv2D(num_filters, kernel_size=kernel_size, padding='same')(x) 80 | x = Activation('relu')(x) 81 | x = BatchNormalization(axis=1)(x) 82 | else : 83 | x = Conv3D(num_filters, kernel_size=kernel_size, padding='same')(input) 84 | x = Activation('relu')(x) 85 | x = BatchNormalization(axis=1)(x) 86 | x = Conv3D(num_filters, kernel_size=kernel_size, padding='same')(x) 87 | x = Activation('relu')(x) 88 | x = BatchNormalization(axis=1)(x) 89 | 90 | return x 91 | 92 | def get_max_pooling_layer(dimension, input) : 93 | pool_size = (2, 2) if dimension == 2 else (2, 2, 2) 94 | 95 | if dimension == 2: 96 | return MaxPooling2D(pool_size=pool_size)(input) 97 | else : 98 | return MaxPooling3D(pool_size=pool_size)(input) 99 | 100 | def 
get_deconv_layer(dimension, input, num_filters) : 101 | strides = (2, 2) if dimension == 2 else (2, 2, 2) 102 | kernel_size = (2, 2) if dimension == 2 else (2, 2, 2) 103 | 104 | if dimension == 2: 105 | return Conv2DTranspose(num_filters, kernel_size=kernel_size, strides=strides)(input) 106 | else : 107 | return Conv3DTranspose(num_filters, kernel_size=kernel_size, strides=strides)(input) 108 | 109 | def get_conv_fc(dimension, input, num_filters) : 110 | fc = None 111 | kernel_size = (1, 1) if dimension == 2 else (1, 1, 1) 112 | 113 | if dimension == 2 : 114 | fc = Conv2D(num_filters, kernel_size=kernel_size)(input) 115 | else : 116 | fc = Conv3D(num_filters, kernel_size=kernel_size)(input) 117 | 118 | return Activation('relu')(fc) 119 | 120 | def organise_output(input, output_shape, activation) : 121 | pred = Reshape(output_shape)(input) 122 | pred = Permute((2, 1))(pred) 123 | return Activation(activation)(pred) -------------------------------------------------------------------------------- /utils/patch.py: -------------------------------------------------------------------------------- 1 | import logging as log 2 | import numpy as np 3 | import math 4 | import itertools as iter 5 | import scipy.ndimage as ndimage 6 | import copy 7 | import sys 8 | 9 | 10 | from utils.misc import * 11 | 12 | from sklearn.feature_extraction.image import extract_patches as sk_extract_patches 13 | from utils.visualization import normalize_and_save_patch 14 | 15 | def extract_patch_at(volume, patch_slice): 16 | assert volume.ndim == len(patch_slice) 17 | 18 | patch_out = copy.copy(volume[patch_slice]) 19 | if patch_out.shape[-1] == 1: # 2D/3D compatibility 20 | patch_out = patch_out.squeeze(axis=-1) 21 | 22 | return patch_out 23 | 24 | def extract_patches_at_locations(volume, patch_slices): 25 | # Infer patch shape from the patch_slices 26 | patch_shape = volume[patch_slices[0]].shape 27 | 28 | patches = np.zeros((len(patch_slices),) + patch_shape) 29 | for i, patch_slice in enumerate(patch_slices): 30 | patches[i] = copy.copy(volume[patch_slice]) 31 | 32 | # If last dimension is 1 means 2D case (remove dimension) 33 | if patches.shape[-1] == 1: 34 | patches = patches.squeeze(axis=-1) 35 | 36 | return patches 37 | 38 | 39 | def zeropad_patches(patches, desired_shape): 40 | """ 41 | Zeropad patches to match desired_shape 42 | 43 | :patches: as (#, x, y, z, class) 44 | :desired_shape: 3D tuple desired x,y,z 45 | """ 46 | 47 | #log.debug("Zeropad patches: patches shape {}, desired_shape {}".format(patches.shape, desired_shape)) 48 | 49 | assert patches.ndim == 5 50 | patch_shape = patches.shape[1:-1] # Extract only x,y,z 51 | 52 | pad = [] 53 | for dim in range(len(desired_shape)): 54 | pad.append([int(math.ceil((desired_shape[dim] - patch_shape[dim]) / 2.0)), 55 | int(math.floor((desired_shape[dim] - patch_shape[dim]) / 2.0))]) 56 | 57 | # Check if we need to do padding >.< 58 | if np.sum(pad) > 0: 59 | # Pre-append 0,0 to not pad batch dimension and postappend 0,0 to not pad class dimension (only x,y,z) 60 | padding = ((0, 0), (pad[0][0], pad[0][1]), (pad[1][0], pad[1][1]), (pad[2][0], pad[2][1]), (0, 0)) 61 | patches = np.pad(patches, padding, 'constant', constant_values=0) 62 | 63 | #log.debug("out: patches shape {}".format(patches.shape)) 64 | 65 | return patches 66 | 67 | def augment_patches(patch_sets_in, goal_num_patches): 68 | patch_set_ref = patch_sets_in[0] # Use first patch set for meta-computations 69 | if goal_num_patches <= patch_set_ref.shape[0]: 70 | return patch_sets_in 71 | 72 | 
num_patches_in = len(patch_set_ref)
73 | num_augments_per_patch = int(math.ceil(np.minimum(goal_num_patches / num_patches_in, 5)))
74 | num_patches_augment = int(goal_num_patches - num_patches_in)
75 | 
76 | # Allocate space for output array and copy non augmented patches
77 | num_augmented_so_far = 0
78 | patch_sets_augmented = []
79 | for patch_set in patch_sets_in:
80 | patch_set_aug = np.zeros((int(goal_num_patches),) + patch_set.shape[1:])
81 | patch_set_aug[:num_patches_in] = patch_set
82 | patch_sets_augmented += [patch_set_aug, ]
83 | num_augmented_so_far += num_patches_in
84 | 
85 | # Augment and add remaining copies
86 | augment_funcs = get_augment_functions(x_axis=1, y_axis=2) # (nmods, x, y, z)
87 | sampling_idxs = get_resampling_indexes(num_patches_in, num_patches_augment // num_augments_per_patch)
88 | func_idxs = get_resampling_indexes(len(augment_funcs), num_augments_per_patch)
89 | for idx in sampling_idxs:
90 | for func_idx in func_idxs:
91 | augment_func = augment_funcs[func_idx]
92 | for set_idx in range(len(patch_sets_augmented)):
93 | patch_sets_augmented[set_idx][num_augmented_so_far] = augment_func(patch_sets_in[set_idx][idx])
94 | num_augmented_so_far += 1
95 | 
96 | # log.debug("AUGMENTED Data patches: {}, Label patches: {}".format(patch_set_reference.shape, label_patches.shape))
97 | for set_idx in range(len(patch_sets_augmented)):
98 | patch_sets_augmented[set_idx] = patch_sets_augmented[set_idx][:num_augmented_so_far]
99 | 
100 | patch_sets_augmented = tuple(patch_sets_augmented)
101 | return patch_sets_augmented
102 | 
103 | 
104 | 
105 | def normalise_patches(patch_sets, mean, std):
106 | unpack_on_return = False
107 | if type(patch_sets) is not tuple:
108 | patch_sets = (patch_sets,)
109 | unpack_on_return = True
110 | 
111 | assert patch_sets[0].shape[1] == len(mean) == len(std)
112 | 
113 | for i, patch_set in enumerate(patch_sets):
114 | for modality in range(patch_set.shape[1]):
115 | patch_sets[i][:, modality] -= mean[modality] # Normalise along the modality axis (axis 1), not the patch axis
116 | patch_sets[i][:, modality] /= std[modality]
117 | 
118 | return patch_sets if not unpack_on_return else patch_sets[0]
119 | 
120 | 
121 | def normalise_patch(patch, mean, std):
122 | assert patch.shape[0] == len(mean) == len(std)
123 | for modality in range(patch.shape[0]):
124 | patch[modality] -= mean[modality]
125 | patch[modality] /= std[modality]
126 | return patch
127 | 
128 | 
129 | def get_augment_functions(x_axis=1, y_axis=2):
130 | augment_funcs = {
131 | 0: lambda patch: np.rot90(patch.astype(np.float32), k=1, axes=(x_axis, y_axis)),
132 | 1: lambda patch: np.rot90(patch.astype(np.float32), k=2, axes=(x_axis, y_axis)),
133 | 2: lambda patch: np.rot90(patch.astype(np.float32), k=3, axes=(x_axis, y_axis)),
134 | 3: lambda patch: np.flip(patch.astype(np.float32), axis=x_axis),
135 | 4: lambda patch: np.flip(patch.astype(np.float32), axis=y_axis)
136 | }
137 | 
138 | return augment_funcs
139 | 
--------------------------------------------------------------------------------
/architecture/SUNETx4.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging as log
3 | 
4 | from keras import backend as K
5 | from keras.layers import Input, Dropout, Add, Softmax, PReLU, Concatenate, Maximum
6 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, MaxPooling2D, UpSampling2D
7 | from keras.layers.convolutional import Conv3D, Conv3DTranspose, MaxPooling3D, UpSampling3D
8 | 
9 | from keras.layers.core import Permute, Reshape
10 | from keras.layers.normalization import
BatchNormalization 11 | from keras.models import Model 12 | from .Architecture import Architecture 13 | 14 | K.set_image_dim_ordering('th') 15 | 16 | 17 | class SUNETx4(Architecture): 18 | @staticmethod 19 | def get_model(config, crossval_id=None): 20 | assert config.arch.num_dimensions in [2, 3] 21 | 22 | num_modalities = len(config.dataset.modalities) 23 | input_layer_shape = (num_modalities,) + config.arch.patch_shape[:config.arch.num_dimensions] 24 | output_layer_shape = (config.train.num_classes, np.prod(config.arch.patch_shape[:config.arch.num_dimensions])) 25 | 26 | model = generate_uresnet_model( 27 | config.arch.num_dimensions, 28 | config.train.num_classes, 29 | input_layer_shape, 30 | output_layer_shape, 31 | config.arch.dropout_rate, 32 | config.arch.base_filters 33 | ) 34 | 35 | trainable_count = int(np.sum([K.count_params(p) for p in set(model.trainable_weights)])) 36 | log.debug("SUNETx4 model with {} trainable parameters".format(trainable_count)) 37 | 38 | return model 39 | 40 | 41 | def generate_uresnet_model( 42 | ndims, num_classes, input_shape, output_shape, dropout_rate=0.2, filts=32): 43 | 44 | l1_input = Input(shape=input_shape) 45 | l1_start = get_conv_input(ndims, l1_input, 1*filts) 46 | 47 | l1_end = get_dual_res_layer(ndims, l1_start, 1*filts) 48 | l2_start = get_downconvolution_layer(ndims, l1_end, 1*filts) 49 | 50 | l2_end = get_dual_res_layer(ndims, l2_start, 2*filts) 51 | l3_start = get_downconvolution_layer(ndims, l2_end, 2*filts) 52 | 53 | l3_end = get_dual_res_layer(ndims, l3_start, 4*filts) 54 | l4_start = get_downconvolution_layer(ndims, l3_end, 4*filts) 55 | 56 | l4_latent = get_dual_res_layer(ndims, l4_start, 8*filts, dropout_rate) 57 | r4_upped = get_upconvolution_layer(ndims, l4_latent, 4*filts) 58 | 59 | r3_start = Add()([l3_end, r4_upped]) 60 | r3_end = get_mono_res_layer(ndims, r3_start, num_filters=4*filts) 61 | r3_upped = get_upconvolution_layer(ndims, r3_end, num_filters=2*filts) 62 | 63 | r2_start = Add()([l2_end, r3_upped]) 64 | r2_end = get_mono_res_layer(ndims, r2_start, num_filters=2*filts) 65 | r2_upped = get_upconvolution_layer(ndims, r2_end, num_filters=1*filts) 66 | 67 | r1_start = Add()([l1_end, r2_upped]) 68 | r1_end = get_mono_res_layer(ndims, r1_start, num_filters=1*filts) 69 | pred = get_conv_output(ndims, r1_end, num_classes, output_shape) 70 | 71 | return Model(inputs=[l1_input], outputs=[pred]) 72 | 73 | 74 | def get_dual_res_layer(ndims, layer_in, num_filters, dropout_rate=0.2): 75 | Conv = Conv2D if ndims == 2 else Conv3D 76 | kernel_size_a = (3, 3) if ndims == 2 else (3, 3, 3) 77 | 78 | a = BatchNormalization(axis=1)(layer_in) 79 | a = PReLU(shared_axes=(2,3) if ndims is 2 else (2,3,4))(a) 80 | a = Conv(num_filters, kernel_size=kernel_size_a, padding='same')(a) 81 | 82 | a = Dropout(dropout_rate)(a) 83 | 84 | a = BatchNormalization(axis=1)(a) 85 | a = PReLU(shared_axes=(2,3) if ndims is 2 else (2,3,4))(a) 86 | a = Conv(num_filters, kernel_size=kernel_size_a, padding='same')(a) 87 | 88 | layer_out = Add()([a, layer_in]) 89 | return layer_out 90 | 91 | 92 | def get_mono_res_layer(ndims, layer_in, num_filters, dropout_rate=0.0): 93 | Conv = Conv2D if ndims == 2 else Conv3D 94 | kernel_size_a = (3, 3) if ndims == 2 else (3, 3, 3) 95 | 96 | a = BatchNormalization(axis=1)(layer_in) 97 | a = PReLU(shared_axes=(2,3) if ndims is 2 else (2,3,4))(a) 98 | a = Conv(num_filters, kernel_size=kernel_size_a, padding='same')(a) 99 | a = Dropout(dropout_rate)(a) 100 | 101 | layer_out = Add()([a, layer_in]) 102 | 103 | return layer_out 104 | 
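# Note on spatial sizes: the encoder applies three stride-2 downsamplings (get_downconvolution_layer),
# which the three transposed convolutions below (get_upconvolution_layer) must undo exactly before
# each Add() skip connection, so every spatial dimension of the input patch has to be divisible by
# 2**3 = 8; the default patch_shape (24, 24, 8) from configuration.py satisfies this.
# Minimal standalone sketch of building the model (hypothetical values, e.g. assuming two input modalities):
#   model = generate_uresnet_model(ndims=3, num_classes=2,
#                                  input_shape=(2, 24, 24, 8),
#                                  output_shape=(2, 24 * 24 * 8),
#                                  dropout_rate=0.2, filts=32)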
105 | 106 | def get_upconvolution_layer(ndims, layer_in, num_filters): 107 | kernel_size = (3, 3) if ndims == 2 else (3, 3, 3) 108 | strides = (2, 2) if ndims == 2 else (2, 2, 2) 109 | ConvTranspose = Conv2DTranspose if ndims == 2 else Conv3DTranspose 110 | 111 | layer_out = ConvTranspose(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(layer_in) 112 | 113 | return layer_out 114 | 115 | 116 | def get_downconvolution_layer(ndims, layer_in, num_filters, dropout_rate=0.0): 117 | Conv = Conv2D if ndims == 2 else Conv3D 118 | MaxPooling = MaxPooling2D if ndims == 2 else MaxPooling3D 119 | kernel_size_a = (3, 3) if ndims == 2 else (3, 3, 3) 120 | 121 | # Low Res Branch 122 | a_halved = BatchNormalization(axis=1)(layer_in) 123 | a_halved = PReLU(shared_axes=(2,3) if ndims is 2 else (2,3,4))(a_halved) 124 | a_halved = Conv(num_filters, kernel_size=kernel_size_a, padding='same', strides=2)(a_halved) 125 | a_halved = Dropout(dropout_rate)(a_halved) 126 | 127 | # MP branch 128 | b_mp = MaxPooling(pool_size=(2, 2) if ndims == 2 else (2, 2, 2))(layer_in) 129 | 130 | layer_out_halved = Concatenate(axis=1)([a_halved, b_mp]) 131 | 132 | return layer_out_halved 133 | 134 | 135 | def get_conv_input(ndims, layer_in, num_filters): 136 | Conv = Conv2D if ndims == 2 else Conv3D 137 | kernel_size = (3, 3) if ndims == 2 else (3, 3, 3) 138 | 139 | layer_out = Conv(num_filters, kernel_size=kernel_size, padding='same')(layer_in) 140 | 141 | return layer_out 142 | 143 | 144 | def get_conv_output(ndims, layer_in, num_filters, output_shape): 145 | Conv = Conv2D if ndims == 2 else Conv3D 146 | kernel_size = (1, 1) if ndims == 2 else (3, 1, 1) 147 | 148 | pred = Conv(num_filters, kernel_size=kernel_size, padding='same')(layer_in) 149 | pred = Reshape(output_shape)(pred) 150 | pred = Permute((2, 1))(pred) # Put classes in last dimension 151 | pred = Softmax(axis=-1)(pred) # Apply softmax on last dimension 152 | 153 | return pred -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | import numpy as np 3 | import logging as log 4 | import time 5 | import math 6 | 7 | import itertools as it 8 | 9 | from scipy import ndimage 10 | from utils.medpy_hausdorff import hd as haussdorf_dist 11 | 12 | import copy 13 | import nibabel as nib 14 | 15 | 16 | def class_from_probabilities(y_prob_in, thresh, min_lesion_vox): 17 | """ 18 | Generates final class prediction by thresholding according to threshold and filtering by minimum lesion size 19 | """ 20 | 21 | # Apply threshold 22 | y_prob = y_prob_in > thresh 23 | 24 | # Get connected components information 25 | y_prob_labelled, nlesions = ndimage.label(y_prob) 26 | if nlesions > 0: 27 | label_list = np.arange(1, nlesions + 1) 28 | lesion_volumes = ndimage.labeled_comprehension(y_prob, y_prob_labelled, label_list, np.sum, float, 0) 29 | 30 | # Set to 0 invalid lesions 31 | lesions_to_ignore = [idx + 1 for idx, lesion_vol in enumerate(lesion_volumes) if lesion_vol < min_lesion_vox] 32 | y_prob_labelled[np.isin(y_prob_labelled, lesions_to_ignore)] = 0 33 | 34 | # Generate binary mask and return 35 | y_pred = (y_prob_labelled > 0).astype('uint8') 36 | 37 | return y_pred 38 | 39 | 40 | def compute_confusion_matrix(y_true, y_pred): 41 | """ 42 | Returns tuple tp, tn, fp, fn 43 | """ 44 | 45 | assert y_true.size == y_pred.size 46 | 47 | true_pos = np.sum(np.logical_and(y_true, y_pred)) 48 | true_neg = 
np.sum(np.logical_and(y_true == 0, y_pred == 0)) 49 | 50 | false_pos = np.sum(np.logical_and(y_true == 0, y_pred)) 51 | false_neg = np.sum(np.logical_and(y_true, y_pred == 0)) 52 | 53 | return true_pos, true_neg, false_pos, false_neg 54 | 55 | def compute_segmentation_metrics(y_true, y_pred, lesion=False, exclude=None): 56 | metrics = {} 57 | eps = K.epsilon() 58 | 59 | tp, tn, fp, fn = compute_confusion_matrix(y_true, y_pred) 60 | 61 | #Sensitivity and specificity 62 | metrics['sens'] = tp / (tp + fn + eps) 63 | metrics['spec'] = tn / (tn + fp + eps) 64 | 65 | # Voxel Fractions 66 | #metrics['tpf'] = metrics['sens'] 67 | metrics['fpf'] = 1 - metrics['spec'] 68 | 69 | # Lesion metrics 70 | if lesion: 71 | tpl, fpl, num_lesions_true, num_lesions_pred = compute_lesion_confusion_matrix(y_true, y_pred) 72 | metrics['l_tpf'] = tpl / num_lesions_true if num_lesions_true > 0 else np.nan 73 | metrics['l_fpf'] = fpl / num_lesions_pred if num_lesions_pred > 0 else np.nan 74 | 75 | metrics['l_ppv'] = tpl / (tpl + fpl + eps) 76 | metrics['l_f1'] = (2.0 * metrics['l_ppv'] * metrics['l_tpf']) / (metrics['l_ppv'] + metrics['l_tpf'] + eps) 77 | 78 | #Dice coefficient 79 | metrics['dsc'] = dice_coef(y_true, y_pred) 80 | # RELATIVE volume difference 81 | metrics['avd'] = 2.0 * np.abs(np.sum(y_pred) - np.sum(y_true))/(np.sum(y_pred) + np.sum(y_true) + eps) 82 | 83 | # Haussdorf distance 84 | try: 85 | metrics['hd'] = haussdorf_dist(y_pred.astype(np.float32), y_true.astype(np.float32), connectivity=3) # Why connectivity 3? 86 | except Exception as e: 87 | print(e) 88 | metrics['hd'] = np.nan 89 | 90 | if exclude is not None: 91 | [metrics.pop(metric, None) for metric in exclude] 92 | 93 | return metrics 94 | 95 | def add_metrics_avg_std(metrics_in): 96 | assert type(metrics_in) is list 97 | metrics_list = copy.deepcopy(metrics_in) 98 | 99 | metrics_all = copy.deepcopy(metrics_list[0]) 100 | for k, v in sorted(metrics_all.items()): 101 | metrics_all[k] = list() 102 | 103 | for metrics in metrics_list: 104 | for k, v in sorted(metrics.items()): 105 | metrics_all[k].append(v) 106 | 107 | metrics_avg = copy.deepcopy(metrics_list[0]) 108 | metrics_std = copy.deepcopy(metrics_list[0]) 109 | for k, v in sorted(metrics_all.items()): 110 | metrics_avg[k] = np.nanmean(metrics_all[k]) 111 | metrics_std[k] = np.nanstd(metrics_all[k]) 112 | 113 | # Avoid nan's corrupting average 114 | metrics_list.append(metrics_avg) 115 | metrics_list.append(metrics_std) 116 | 117 | return metrics_list 118 | 119 | # TODO avoid duplicated code! 
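# get_metrics_avg_std() below is the non-appending variant of add_metrics_avg_std(): it gathers every
# key across the per-case metric dicts and reduces them with np.nanmean / np.nanstd, so cases where a
# metric is undefined (e.g. 'hd' set to NaN when the Hausdorff distance computation fails) do not
# corrupt the aggregates. Usage sketch (hypothetical variable names):
#   per_case = [compute_segmentation_metrics(gt, seg, lesion=True) for gt, seg in case_pairs]
#   avg_metrics, std_metrics = get_metrics_avg_std(per_case)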
120 | def get_metrics_avg_std(metrics_in): 121 | assert type(metrics_in) is list 122 | metrics_list = copy.deepcopy(metrics_in) 123 | 124 | metrics_all = copy.deepcopy(metrics_list[0]) 125 | for k, v in sorted(metrics_all.items()): 126 | metrics_all[k] = list() 127 | 128 | for metrics in metrics_list: 129 | for k, v in sorted(metrics.items()): 130 | metrics_all[k].append(v) 131 | 132 | metrics_avg = copy.deepcopy(metrics_list[0]) 133 | metrics_std = copy.deepcopy(metrics_list[0]) 134 | for k, v in sorted(metrics_all.items()): 135 | metrics_avg[k] = np.nanmean(metrics_all[k]) 136 | metrics_std[k] = np.nanstd(metrics_all[k]) 137 | 138 | return [metrics_avg, metrics_std] 139 | 140 | 141 | def compute_lesion_confusion_matrix(y_true, y_pred): 142 | # True positives 143 | lesions_true, num_lesions_true = ndimage.label(y_true) 144 | lesions_pred, num_lesions_pred = ndimage.label(y_pred) 145 | 146 | true_pos = 0.0 147 | for i in range(num_lesions_true): 148 | lesion_detected = np.logical_and(y_pred, lesions_true == (i + 1)).any() 149 | if lesion_detected: true_pos += 1 150 | true_pos = np.min([true_pos, num_lesions_pred]) 151 | 152 | # False positives 153 | tp_labels = np.unique(y_true * lesions_pred) 154 | fp_labels = np.unique(np.logical_not(y_true) * lesions_pred) 155 | 156 | # [label for label in fp_labels if label not in tp_labels] 157 | false_pos = 0.0 158 | for fp_label in fp_labels: 159 | if fp_label not in tp_labels: false_pos += 1 160 | 161 | return true_pos, false_pos, num_lesions_true, num_lesions_pred 162 | 163 | 164 | def dice_coef(y_true, y_pred, smooth = 0.01): 165 | intersection = np.sum(np.logical_and(y_true, y_pred)) 166 | 167 | if intersection > 0: 168 | return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth) 169 | else: 170 | return 0.0 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /workflow/network.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint 4 | 5 | from utils.metrics import * 6 | from utils.patch import * 7 | from utils.samples import * 8 | from utils.visualization import * 9 | from utils.volume import * 10 | from workflow.in_out import * 11 | from workflow.learning import build_testing_generator, build_learning_generators 12 | 13 | 14 | def generate_callbacks(config, crossval_id=None): 15 | model_filename = generate_model_filename(config, crossval_id=crossval_id) 16 | csv_filename = generate_log_filename(config) 17 | 18 | stopper = EarlyStopping( 19 | mode=config.train.train_monitor_mode, 20 | monitor=config.train.train_monitor, 21 | min_delta=1e-5, 22 | patience=config.train.patience) 23 | 24 | checkpointer = ModelCheckpoint( 25 | filepath=model_filename, 26 | verbose=0, 27 | monitor=config.train.train_monitor, 28 | save_best_only=True, 29 | save_weights_only=True) 30 | 31 | csv_logger = CSVLogger( 32 | csv_filename, 33 | separator=',') 34 | 35 | return [stopper, checkpointer, csv_logger] 36 | 37 | def train_model_on_dataset(config, model_def, dataset, crossval_id=None, load_trained=False): 38 | trainable_count = int(np.sum([K.count_params(p) for p in set(model_def.trainable_weights)])) 39 | 40 | if load_trained: 41 | try: 42 | model_def = read_model(config, model_def, crossval_id) 43 | log.debug("Loaded weights BEFORE 
training") 44 | except IOError: 45 | log.debug("Failed to load weights BEFORE training") 46 | 47 | if trainable_count == 0: 48 | weight_savefile = generate_model_filename(config, crossval_id) 49 | log.warn("Detected untrainable network, storing weights...\n{}".format(weight_savefile)) 50 | model_def.save_weights(weight_savefile) 51 | elif config.train.num_epochs > 0: 52 | train_gen, val_gen = build_learning_generators(config, dataset) 53 | train_model(config, model_def, train_gen, val_gen, crossval_id) 54 | del train_gen, val_gen 55 | 56 | model = read_model(config, model_def, crossval_id) 57 | return model 58 | 59 | 60 | def train_model(config, model_def, train_data, validation_data, crossval_id=None): 61 | train_generator, validation_generator = train_data, validation_data 62 | 63 | model_def.fit_generator( 64 | train_generator, 65 | epochs=config.train.num_epochs, 66 | validation_data=validation_generator, 67 | max_queue_size=10, 68 | shuffle=True, 69 | verbose=config.eval.keras_verbose, 70 | callbacks=generate_callbacks(config, crossval_id)) 71 | 72 | 73 | def test_samples(config, model, samples_in, save_segmentations=True): 74 | log.info("Testing {} samples...".format(len(samples_in))) 75 | 76 | for i, sample_test in enumerate(samples_in): 77 | lesion_probs = predict_sample(config, model, sample_test) 78 | 79 | rec_vol = class_from_probabilities(lesion_probs, config.dataset.lesion_threshold, config.dataset.min_lesion_voxels) 80 | 81 | if save_segmentations: 82 | save_result_sample(config, samples_in[i], rec_vol, params=('seg',), file_format='nii', asuint16=True) 83 | 84 | 85 | def evaluate_samples(config, model, samples_in, save_probabilities=True, save_segmentations=False): 86 | # TODO add thresh and ls search 87 | 88 | log.info("Evaluating with {} samples...".format(len(samples_in))) 89 | 90 | metrics_list = list() 91 | for i, sample_test in enumerate(samples_in): 92 | lesion_probs = predict_sample(config, model, sample_test) 93 | 94 | if save_probabilities: 95 | save_result_sample(config, samples_in[i], lesion_probs, params=('probs',)) 96 | 97 | 98 | rec_vol = class_from_probabilities(lesion_probs, config.dataset.lesion_threshold, config.dataset.min_lesion_voxels) 99 | 100 | if save_segmentations: 101 | save_result_sample(config, samples_in[i], rec_vol, params=('seg',)) 102 | 103 | if sample_test.labels is not None: 104 | true_vol = samples_in[i].labels[0] 105 | metrics_list += [compute_segmentation_metrics(true_vol, rec_vol, lesion=config.dataset.lesion_metrics)] 106 | print_metrics_list(metrics_list[-1], case_names=[sample_test.id]) 107 | 108 | 109 | return metrics_list if len(metrics_list) > 0 else None 110 | 111 | def predict_sample(config, model, sample_in): 112 | """ 113 | Given the volumes and the trained model -> Outputs positive lesion probabilities 114 | """ 115 | 116 | assert isinstance(sample_in, Sample) 117 | sample_seg = copy.deepcopy(sample_in) 118 | 119 | log.info("Segmenting case {} {}".format(sample_seg.id, sample_seg.data.shape)) 120 | sample_seg = zeropad_sample(sample_seg, config.arch.patch_shape) 121 | 122 | test_generator = build_testing_generator(config, sample_seg) 123 | pred = predict_model_generator(config, model, test_generator) 124 | 125 | lesion_probs = reconstruct_volume( 126 | patches=pred, 127 | patch_shape = config.arch.patch_shape, 128 | original_vol = sample_seg.data, 129 | original_centers=test_generator.centers, 130 | extraction_step=config.train.extraction_step_test, 131 | num_classes=config.train.num_classes) 132 | 133 | lesion_probs = 
remove_zeropad_volume(lesion_probs, config.arch.patch_shape) 134 | 135 | assert lesion_probs.shape == sample_in.data.shape[1:], str(lesion_probs.shape) + ", " + str(sample_in.data.shape[1:]) 136 | return lesion_probs 137 | 138 | def predict_model_generator(config, model, test_generator): 139 | """ 140 | Performs prediction on a set of patches and returns the reshaped prediction 141 | """ 142 | 143 | num_classes = config.train.num_classes 144 | output_shape = config.arch.output_shape[:config.arch.num_dimensions] 145 | 146 | pred = model.predict_generator(test_generator, verbose=config.eval.keras_verbose, workers=5) 147 | pred = pred.reshape((len(pred),) + output_shape + (num_classes,)) 148 | 149 | if len(pred.shape) != 5: 150 | pred = np.expand_dims(pred, axis=3) # 2D/3D compatibility (0 None, 1 x, 2 y, 3 new:z, 4 classes) 151 | log.debug("predict_model 2D detected, expanding to {}".format(pred.shape)) 152 | 153 | return pred 154 | 155 | -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | from utils.loss_functions import * 4 | from math import floor 5 | 6 | 7 | class TrainConfiguration: 8 | def __init__(self): 9 | # GPU management 10 | self.cuda_device_id = 0 11 | self.dynamic_gpu_memory = False 12 | 13 | # Basic 14 | self.num_classes = 2 15 | self.dataset_name = 'ISLES2017' # Overrided in multi evaluation 16 | self.architecture_name = 'SUNETx4' # Overrided in multi evaluation 17 | 18 | # Patch sampling config 19 | self.extraction_step = (6, 6, 3) 20 | self.extraction_step_test = (6, 6, 1) 21 | 22 | self.sampling = 'hybrid' # hybrid, Kamnitsas, Guerrero 23 | self.num_patches = 10000 # Max number of patches extracted PER CASE (pos + unif) 24 | self.uniform_ratio = 0.5 # Percentage of num_patches that will be uniformly sampled 25 | self.min_fg_percentage = 0.0 26 | 27 | # Filled from Configuration.on_change() 28 | self.min_pos_patches, self.max_pos_patches, self.max_unif_patches = None, None, None 29 | 30 | # Learning parameters 31 | self.activation = 'softmax' 32 | self.loss = 'categorical_crossentropy' # 'categorical_crossentropy' 33 | self.metrics = ['acc', dice] 34 | self.train_monitor, self.train_monitor_mode = 'val_dice', 'min' 35 | 36 | self.optimizer = 'Adadelta' # Check Arquitecture.py to add more optimizers 37 | self.learning_rate, self.lr_decay = 1.0, 1e-3 # Only used for 'Adam' optimizer # 0.03 38 | 39 | self.batch_size = 32 # default: 32 40 | self.patience = 6 41 | self.num_epochs = 0 42 | 43 | 44 | class EvaluationConfiguration: 45 | def __init__(self): 46 | self.verbose = 1 # 0 warn, 1 info, 2 debug 47 | self.keras_verbose = 1 48 | self.evaluation_type = 'multi' # multi, crossval, eval, metrics, metrics_search, debug, test 49 | 50 | # Val or test 51 | self.save_probabilities = False 52 | self.save_segmentations = True 53 | 54 | # Crossval 55 | self.crossval_fold = 9 56 | self.crossval_start = 0 # INDEX OF SAMPLE to start with, useful for resuming crossvalidations 57 | self.crossval_stop = None 58 | 59 | # Metrics 60 | self.metrics_eval_set = 'all' # all, val, test, 61 | self.metrics_thresh_values = [.1, .2, .3, .4, .5, .6, .7, .8] 62 | self.metrics_lesion_sizes = [1, 10, 20, 50, 100, 200, 300, 500, 750, 1000] 63 | 64 | # Multievaluation 65 | self.multieval_params = [ 66 | ('crossval', 'SUNETx4', 'ISLES15_SISS', 10000, 6, (24, 24, 8), (0.6, 200)), 67 | ('crossval', 'SUNETx4', 'ISLES15_SPES', 10000, 6, (24, 
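# Self-contained illustration of the reshape performed in predict_model_generator()
# above: flattened per-patch network outputs are reshaped to
# (num_patches,) + output_shape + (num_classes,) and, for 2D architectures, padded
# with a singleton z axis so 2D and 3D predictions share the same reconstruction code.
# The array contents are dummy values; shapes follow the 2Dunet entry (48, 48, 1).
import numpy as np

num_classes = 2
output_shape = (48, 48)                            # output_shape[:num_dimensions] for 2Dunet
pred = np.random.rand(10, 48 * 48 * num_classes)   # 10 patches of flattened softmax output

pred = pred.reshape((len(pred),) + output_shape + (num_classes,))
if len(pred.shape) != 5:
    pred = np.expand_dims(pred, axis=3)            # -> (10, 48, 48, 1, 2)
print(pred.shape)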
24, 8), (0.4, 500)), 68 | ('crossval', 'SUNETx4', 'ISLES2017', 7500, 9, (24, 24, 8), (0.2, 100)), 69 | 70 | ('crossval', '2Dunet', 'ISLES15_SISS', 10000, 6, (48, 48, 1), (0.8, 200)), 71 | ('crossval', '2Dunet', 'ISLES15_SPES', 10000, 6, (48, 48, 1), (0.1, 750)), 72 | ('crossval', '2Dunet', 'ISLES2017', 7500, 9, (48, 48, 1), (0.1, 100)), 73 | 74 | ('crossval', '3Dunet', 'ISLES15_SISS', 10000, 6, (24, 24, 8), (0.6, 50)), 75 | ('crossval', '3Dunet', 'ISLES15_SPES', 10000, 6, (24, 24, 8), (0.2, 1000)), 76 | ('crossval', '3Dunet', 'ISLES2017', 7500, 9, (24, 24, 8), (0.1, 200)), 77 | 78 | ('crossval', '3Duresnet', 'ISLES15_SISS', 10000, 6, (24, 24, 8), (0.7, 50)), 79 | ('crossval', '3Duresnet', 'ISLES15_SPES', 10000, 6, (24, 24, 8), (0.3, 1000)), 80 | ('crossval', '3Duresnet', 'ISLES2017', 7500, 9, (24, 24, 8), (0.2, 100)), 81 | ] 82 | 83 | def multieval_update_config(config, params): 84 | assert isinstance(config, Configuration) 85 | config.eval.evaluation_type = params[0] 86 | config.change_architecture(params[1]) 87 | config.change_dataset(params[2]) 88 | config.change_num_patches(params[3]) 89 | config.eval.crossval_fold = params[4] 90 | config.arch.patch_shape = params[5] if params[5] is not None else config.arch.patch_shape 91 | config.arch.output_shape = params[5] if params[5] is not None else config.arch.output_shape 92 | config.dataset.lesion_threshold = params[6][0] 93 | config.dataset.min_lesion_voxels = params[6][1] 94 | return config 95 | self.multieval_update_config = multieval_update_config 96 | 97 | 98 | """ 99 | Mandatory fields for architecture entry: 100 | -> 'multires' : Bool # If architecture is multiple patch input 101 | -> 'patch_shape' : 3-element tuple extraction shape (normally the biggest of em all) 102 | -> 'num_dimensions' 103 | """ 104 | architecture_dict = { 105 | '2Dunet': { 106 | 'num_dimensions': 2, 107 | 'patch_shape': (48, 48, 1), 108 | 'output_shape': (48, 48, 1) 109 | }, 110 | '3Dunet': { 111 | 'num_dimensions' : 3, 112 | 'patch_shape': (24, 24, 8), 113 | 'output_shape': (24, 24, 8) 114 | }, 115 | '3Duresnet': { 116 | 'num_dimensions' : 3, 117 | 'patch_shape': (24, 24, 8), 118 | 'output_shape': (24, 24, 8) 119 | }, 120 | 'SUNETx4': { 121 | 'num_dimensions' : 3, 122 | 'patch_shape': (24, 24, 8), 123 | 'output_shape': (24, 24, 8), 124 | 'dropout_rate': 0.2, 125 | 'base_filters': 32 126 | }, 127 | 128 | } 129 | 130 | """ 131 | Mandatory fields! 132 | 'lesion_metrics' 133 | 'modalities': ['', ...] 
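# How a single multieval_params tuple above is consumed by multieval_update_config():
# (evaluation_type, architecture, dataset, num_patches, crossval_fold,
#  patch/output shape, (lesion_threshold, min_lesion_voxels)).
# The values below are the SUNETx4 / ISLES2017 entry already present in the list.
params = ('crossval', 'SUNETx4', 'ISLES2017', 7500, 9, (24, 24, 8), (0.2, 100))
# config = multieval_update_config(copy.deepcopy(config), params) then yields:
#   config.eval.evaluation_type    == 'crossval'
#   config.train.architecture_name == 'SUNETx4',  config.train.dataset_name == 'ISLES2017'
#   config.train.num_patches       == 7500,       config.eval.crossval_fold == 9
#   config.arch.patch_shape        == config.arch.output_shape == (24, 24, 8)
#   config.dataset.lesion_threshold == 0.2,       config.dataset.min_lesion_voxels == 100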
134 | 'evaluation_set' 135 | """ 136 | 137 | dataset_dict = { # Each dataset uses its own configuration keys, no need to preserve consistency 138 | 'ISLES15_SISS': { 139 | 'path': '/path/to/SISS/dataset/', 140 | 'lesion_metrics': False, 141 | 'evaluation_set': 'val', 142 | 'validation_split': 0.2, 143 | 'lesion_threshold': 0.6, 144 | 'min_lesion_voxels': 200, 145 | 'general_pattern': ['training/{}/', 'testing/{}/'], 146 | 'num_volumes': [28, 36], 147 | 'modalities': ['Flair', 'T1', 'T2', 'DWI'], 148 | }, 149 | 'ISLES15_SPES': { 150 | 'path': '/path/to/SPES/dataset/', 151 | 'lesion_metrics': False, 152 | 'evaluation_set': 'val', 153 | 'validation_split': 0.2, 154 | 'lesion_threshold': 0.4, 155 | 'min_lesion_voxels': 500, 156 | 'general_pattern': ['training/{}/', 'testing/Nr{}/'], 157 | 'num_volumes': [30, 20], 158 | 'modalities': ['DWI', 'CBF', 'CBV', 'T1c', 'T2', 'Tmax', 'TTP'], 159 | }, 160 | 'ISLES2017': { 161 | 'path': '/path/to/ISLES2017/dataset/', 162 | 'lesion_metrics': False, 163 | 'validation_split': 0.2, 164 | 'lesion_threshold': 0.2, 165 | 'min_lesion_voxels': 100, 166 | 'general_pattern': ['training/training_{}/', 'testing/test_{}/'], 167 | 'num_volumes': [48, 40], #[43, 32], 168 | 'modalities': ['ADC', 'CBF', 'CBV', 'MTT', 'Tmax', 'TTP'], 169 | 'evaluation_set': 'val', 170 | }, 171 | } 172 | 173 | """ 174 | Assertions and main config class 175 | """ 176 | def check_dictionaries(dataset, architecture): 177 | for dataset_name, dataset_config in dataset.items(): 178 | assert 'lesion_metrics' in dataset_config, 'Missing field \'lesion_metrics\' for database entry' 179 | assert 'modalities' in dataset_config, 'Missing field \'lesion_metrics\' for database entry' 180 | assert 'evaluation_set' in dataset_config, 'Missing field \'evaluation_set\' for database entry' 181 | 182 | for arch_name, arch_config in architecture.items(): 183 | assert all([len(patch_shape) == 3 for key, patch_shape in arch_config.items() if 'shape' in key]), arch_name 184 | 185 | check_dictionaries(dataset_dict, architecture_dict) 186 | 187 | import copy 188 | class DictionaryDot(dict): 189 | """ 190 | Wrapper for dictionary that allows for dot notation access to entries 191 | 192 | Example: 193 | m = DictDot({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer']) 194 | print(m.first_name) 195 | """ 196 | 197 | def __init__(self, *args, **kwargs): 198 | super(DictionaryDot, self).__init__(*args, **kwargs) 199 | for arg in args: 200 | if isinstance(arg, dict): 201 | for k, v in arg.iteritems(): 202 | self[k] = v 203 | 204 | if kwargs: 205 | for k, v in kwargs.iteritems(): 206 | self[k] = v 207 | 208 | def __getattr__(self, attr): 209 | return self.get(attr) 210 | 211 | def __setattr__(self, key, value): 212 | self.__setitem__(key, value) 213 | 214 | def __setitem__(self, key, value): 215 | super(DictionaryDot, self).__setitem__(key, value) 216 | self.__dict__.update({key: value}) 217 | 218 | def __delattr__(self, item): 219 | self.__delitem__(item) 220 | 221 | def __delitem__(self, key): 222 | super(DictionaryDot, self).__delitem__(key) 223 | del self.__dict__[key] 224 | 225 | def __deepcopy__(self, memo): 226 | return DictionaryDot([(copy.deepcopy(k, memo), copy.deepcopy(v, memo)) for k, v in self.items()]) 227 | 228 | class Configuration: 229 | def __init__(self): 230 | self.log_path = 'log/' 231 | self.model_path = 'checkpoints/' 232 | self.results_path = 'results/' 233 | self.metrics_path = 'metrics/' 234 | 235 | self.eval = EvaluationConfiguration() 236 | self.train = TrainConfiguration() 237 
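# Minimal usage sketch of DictionaryDot defined above: ordinary dict behaviour plus
# attribute-style access, which is how config.dataset / config.arch entries are read
# throughout the code base (e.g. config.dataset.lesion_threshold). Note that __init__
# relies on dict.iteritems(), i.e. the class as written targets Python 2.
d = DictionaryDot({'lesion_threshold': 0.2, 'min_lesion_voxels': 100})
print(d.lesion_threshold)        # 0.2, via dot notation
print(d['min_lesion_voxels'])    # 100, via normal dict access
d.validation_split = 0.2         # attribute assignment also updates the dict
print(d['validation_split'])     # 0.2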
| self.dataset = DictionaryDot(dataset_dict[self.train.dataset_name]) 238 | self.arch = DictionaryDot(architecture_dict[self.train.architecture_name]) 239 | 240 | self.on_change() 241 | 242 | def on_change(self): 243 | self.train.min_pos_patches = int(((1.0 - self.train.uniform_ratio) * self.train.num_patches) // 6) 244 | self.train.max_pos_patches = int(floor((1.0 - self.train.uniform_ratio) * self.train.num_patches)) 245 | self.train.max_unif_patches = int(floor(self.train.uniform_ratio * self.train.num_patches)) 246 | 247 | # Other operations 248 | self.dataset.path = os.path.expanduser(self.dataset.path) 249 | if self.dataset.path[-1] is not '/': 250 | self.dataset.path += '/' 251 | 252 | if self.arch.num_dimensions is 2: 253 | assert all([shape[-1] == 1 for key, shape in self.arch.items() if 'shape' in key]) 254 | 255 | def change_num_patches(self, new_num_patches): 256 | self.train.num_patches = new_num_patches 257 | self.on_change() 258 | 259 | def change_architecture(self, new_arch_name): 260 | self.train.architecture_name = new_arch_name 261 | self.arch = DictionaryDot(architecture_dict[self.train.architecture_name]) 262 | self.on_change() 263 | 264 | def change_dataset(self, new_dataset_name): 265 | self.train.dataset_name = new_dataset_name 266 | self.dataset = DictionaryDot(dataset_dict[self.train.dataset_name]) 267 | self.on_change() 268 | -------------------------------------------------------------------------------- /workflow/generators.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import math 4 | import numpy as np 5 | import logging as log 6 | import nibabel as nib 7 | import scipy.misc 8 | import itertools as iter 9 | 10 | from keras.utils import Sequence 11 | import keras.backend as K 12 | 13 | from utils.patch import * 14 | from utils.samples import * 15 | from utils.visualization import * 16 | from utils.centers import * 17 | from utils.instructions import * 18 | from utils.slices import * 19 | from utils.volume import * 20 | 21 | from random import shuffle 22 | 23 | from workflow.in_out import get_architecture_instance 24 | 25 | class TestPatchGenerator(Sequence): 26 | def __init__(self, config, sample_in): 27 | # Important parameters 28 | self.config = config 29 | self.stats = compute_sample_statistics(sample_in) 30 | self.batch_size = config.train.batch_size 31 | 32 | # Store volumes and patch extraction instructions 33 | self.sample = sample_in 34 | self.centers = sample_uniform_patch_centers(config.arch.patch_shape, config.train.extraction_step_test, self.sample.data) 35 | 36 | self.sample = normalise_sample(sample_in, self.stats[0], self.stats[1]) 37 | self.instructions = get_architecture_instance(config).get_patch_extraction_instructions( 38 | config, 0, self.centers, self.sample.data, self.sample.labels, augment=False) 39 | 40 | log.debug("Test generator: extracted {} instructions, making {} batches".format( 41 | len(self.instructions), int(math.ceil(len(self.instructions) / float(self.batch_size))))) 42 | 43 | def __len__(self): 44 | return int(math.ceil(len(self.instructions) / float(self.batch_size))) 45 | 46 | def __getitem__(self, idx_global): 47 | batch_start = idx_global * self.batch_size 48 | batch_end = np.minimum((idx_global + 1) * self.batch_size, len(self.instructions)) 49 | batch_slice = slice(batch_start, batch_end) 50 | 51 | # Get instructions for current batch creation 52 | batch_instructions = self.instructions[batch_slice] 53 | batch_length = len(batch_instructions) 54 
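# Worked example of the patch budgeting done in Configuration.on_change() above,
# using the TrainConfiguration defaults num_patches=10000 and uniform_ratio=0.5:
from math import floor
num_patches, uniform_ratio = 10000, 0.5
min_pos_patches = int(((1.0 - uniform_ratio) * num_patches) // 6)   # 833
max_pos_patches = int(floor((1.0 - uniform_ratio) * num_patches))   # 5000
max_unif_patches = int(floor(uniform_ratio * num_patches))          # 5000
# i.e. per case at most half of the patches are uniformly sampled, the other half are
# lesion-centred, and the sampler appears to enforce roughly 1/6 of the positive
# budget as a minimum.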
| assert batch_length <= self.batch_size, "Returning bigger batch than expected" 55 | 56 | # Allocate space for batch data 57 | ndims = self.config.arch.num_dimensions 58 | nmodalities = len(self.config.dataset.modalities) 59 | x = np.zeros((batch_length, nmodalities) + self.config.arch.patch_shape[:ndims], dtype=np.float32) 60 | 61 | # Extract patches as stated by instructions 62 | for idx, instruction in enumerate(batch_instructions): 63 | data_patch, _ = extract_patch_with_instruction(self.sample, instruction) 64 | x[idx] = data_patch 65 | 66 | return x 67 | 68 | class TrainPatchGenerator(Sequence): 69 | def __init__(self, config, samples_in, augment=False, is_validation=False, sampling='hybrid'): 70 | # Important parameters 71 | self.config = config 72 | self.stats = compute_set_statistics(samples_in) 73 | self.batch_size = config.train.batch_size 74 | self.is_validation = is_validation 75 | 76 | # Store volumes and patch extraction instructions 77 | self.samples = zeropad_set(copy.deepcopy(samples_in), config.arch.patch_shape) 78 | self.instructions = build_set_extraction_instructions(config, self.samples, augment=augment, sampling=sampling) 79 | shuffle(self.instructions) 80 | 81 | log.debug("Train generator: extracted {} instructions, making {} batches".format( 82 | len(self.instructions), int(math.ceil(len(self.instructions) / float(self.batch_size))))) 83 | 84 | def on_epoch_end(self): 85 | shuffle(self.instructions) 86 | 87 | def __len__(self): 88 | num_batches = len(self.instructions) / float(self.batch_size) 89 | return int(math.floor(num_batches)) if self.is_validation else int(math.ceil(num_batches)) 90 | 91 | def __getitem__(self, idx_global): 92 | batch_start = idx_global * self.batch_size 93 | batch_end = np.minimum((idx_global + 1) * self.batch_size, len(self.instructions)) 94 | batch_slice = slice(batch_start, batch_end) 95 | 96 | # Get instructions for current batch creation 97 | batch_instructions = self.instructions[batch_slice] 98 | batch_length = len(batch_instructions) 99 | 100 | if self.is_validation: assert batch_length == self.batch_size, "In validation cannot return partial batches" 101 | else: assert batch_length <= self.batch_size, "Training generator is returning bigger batch than it should" 102 | 103 | # Allocate space for batch data 104 | ndims = self.config.arch.num_dimensions 105 | nmodalities = len(self.config.dataset.modalities) 106 | 107 | x = np.zeros((batch_length, nmodalities) + self.config.arch.patch_shape[:ndims], dtype=np.float32) 108 | y = np.zeros((batch_length, np.prod(self.config.arch.output_shape[:ndims]), self.config.train.num_classes)) 109 | 110 | # Extract patches as stated by instructions 111 | for idx, instruction in enumerate(batch_instructions): 112 | data_patch, label_patch = extract_patch_with_instruction(self.samples, instruction) 113 | 114 | sample_mean = self.stats[0][instruction.sample_idx] 115 | sample_std = self.stats[1][instruction.sample_idx] 116 | 117 | # Store in batch 118 | try: 119 | x[idx] = normalise_patch(data_patch, sample_mean, sample_std) 120 | y[idx] = np_utils.to_categorical(label_patch.flatten(), self.config.train.num_classes) 121 | except Exception as e: 122 | print(e) 123 | print(instruction.sample_idx, instruction.data_patch_slice, instruction.label_patch_slice) 124 | 125 | 126 | return x, y 127 | 128 | 129 | def build_set_extraction_instructions(config, samples_in, augment=False, sampling='hybrid'): 130 | set_instructions = list() 131 | for idx, sample in enumerate(samples_in): 132 | printProgressBar(idx, 
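# Batch-count behaviour of TrainPatchGenerator.__len__() above, with dummy numbers:
# a training generator keeps the final partial batch (ceil), while a validation
# generator drops it (floor), matching the assertions in __getitem__ that validation
# batches must always be complete.
import math
n_instructions, batch_size = 1000, 32
print(int(math.ceil(n_instructions / float(batch_size))))   # 32 batches for training
print(int(math.floor(n_instructions / float(batch_size))))  # 31 batches for validation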
len(samples_in), suffix='samples processed') 133 | set_instructions += build_sample_patch_extraction_instructions(config, idx, sample, augment=augment, sampling=sampling) 134 | printProgressBar(len(samples_in), len(samples_in), suffix='samples processed') 135 | 136 | return set_instructions 137 | 138 | 139 | def build_sample_patch_extraction_instructions(config, sample_idx, sample, augment=False, sampling='hybrid'): 140 | assert sampling in {'hybrid', 'Guerrero', 'Kamnitsas'} 141 | 142 | #log.debug("Using {} sampling strategy".format(sampling)) 143 | 144 | if sampling is 'hybrid': 145 | ### -------------------------- 146 | ### 1. Centers 147 | ### -------------------------- 148 | 149 | # Positive 150 | pos_centers = sample_positive_patch_centers(sample.labels) 151 | pos_centers = resample_centers( 152 | pos_centers, min_samples=config.train.min_pos_patches, max_samples=config.train.max_pos_patches) 153 | 154 | offset_patch_shape = config.arch.patch_shape if not config.arch.offset_shape else config.arch.offset_shape 155 | 156 | pos_centers = randomly_offset_centers( 157 | pos_centers, offset_shape=offset_patch_shape, patch_shape=config.arch.patch_shape, original_vol=sample.data) 158 | 159 | # Uniform 160 | extraction_step = config.train.extraction_step 161 | while True: 162 | unif_centers = sample_uniform_patch_centers(config.arch.patch_shape, extraction_step, sample.data) 163 | unif_centers = resample_centers(unif_centers, max_samples=config.train.max_unif_patches) 164 | 165 | if unif_centers.shape[0] >= config.train.max_unif_patches - 1: 166 | break 167 | 168 | extraction_step = tuple([np.maximum(1, dim_step - 1) for dim_step in extraction_step]) 169 | if np.array_equal(extraction_step, (1, 1, 1)): 170 | raise ValueError, "Cannot extract enough uniform patches, please decrease number of patches" 171 | 172 | ### ----------------------------------- 173 | ### 2. Patch extraction instructions 174 | ### ----------------------------------- 175 | 176 | arch = get_architecture_instance(config) 177 | pos_instructions = arch.get_patch_extraction_instructions( 178 | config, sample_idx, pos_centers, sample.data, sample.labels, augment=augment) 179 | unif_instructions = arch.get_patch_extraction_instructions( 180 | config, sample_idx, unif_centers, sample.data, sample.labels) 181 | 182 | sample_instructions = pos_instructions + unif_instructions 183 | 184 | elif sampling is 'Guerrero': 185 | ### -------------------------- 186 | ### 1. Centers 187 | ### -------------------------- 188 | 189 | # Positive 190 | pos_centers = sample_positive_patch_centers(sample.labels) 191 | pos_centers = resample_centers( 192 | pos_centers, min_samples=config.train.num_patches, max_samples=config.train.num_patches) 193 | 194 | offset_patch_shape = config.arch.patch_shape if not config.arch.offset_shape else config.arch.offset_shape 195 | pos_centers = randomly_offset_centers( 196 | pos_centers, offset_shape=offset_patch_shape, patch_shape=config.arch.patch_shape, original_vol=sample.data) 197 | 198 | ### ----------------------------------- 199 | ### 2. Patch extraction instructions 200 | ### ----------------------------------- 201 | 202 | arch = get_architecture_instance(config) 203 | pos_instructions = arch.get_patch_extraction_instructions( 204 | config, sample_idx, pos_centers, sample.data, sample.labels) 205 | 206 | sample_instructions = pos_instructions 207 | elif sampling is 'Kamnitsas': 208 | ### -------------------------- 209 | ### 1. 
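# Stripped-down version of the uniform-sampling control flow used above: the
# extraction step is shrunk until enough uniform centers can be produced, and the
# search gives up once the step reaches (1, 1, 1). count_uniform_centers() is only a
# stand-in for the real sample_uniform_patch_centers()/resample_centers() pair.
import numpy as np

def count_uniform_centers(step, vol_shape=(192, 192, 19), patch_shape=(24, 24, 8)):
    # Stand-in: number of grid positions a given step would produce (illustrative).
    return int(np.prod([max(1, (v - p) // s + 1) for v, p, s in zip(vol_shape, patch_shape, step)]))

extraction_step, max_unif_patches = (6, 6, 3), 5000
while count_uniform_centers(extraction_step) < max_unif_patches - 1:
    extraction_step = tuple(max(1, s - 1) for s in extraction_step)
    if extraction_step == (1, 1, 1):
        raise ValueError("Cannot extract enough uniform patches, please decrease number of patches")
print(extraction_step)   # (5, 5, 2) for these illustrative shapes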
Centers 210 | ### -------------------------- 211 | 212 | # Positive 213 | pos_centers = sample_positive_patch_centers(sample.labels) 214 | pos_centers = resample_centers( 215 | pos_centers, max_samples=config.train.max_pos_patches) 216 | 217 | # Uniform 218 | extraction_step = config.train.extraction_step 219 | while True: 220 | unif_centers = sample_uniform_patch_centers(config.arch.patch_shape, extraction_step, sample.data) 221 | unif_centers = resample_centers(unif_centers, max_samples=config.train.max_unif_patches) 222 | 223 | if unif_centers.shape[0] >= config.train.max_unif_patches - 1: 224 | break 225 | 226 | extraction_step = tuple([np.maximum(1, dim_step - 1) for dim_step in extraction_step]) 227 | if np.array_equal(extraction_step, (1, 1, 1)): 228 | raise ValueError, "Cannot extract enough uniform patches, please decrease number of patches" 229 | 230 | ### ----------------------------------- 231 | ### 2. Patch extraction instructions 232 | ### ----------------------------------- 233 | 234 | arch = get_architecture_instance(config) 235 | pos_instructions = arch.get_patch_extraction_instructions( 236 | config, sample_idx, pos_centers, sample.data, sample.labels, augment=augment) 237 | unif_instructions = arch.get_patch_extraction_instructions( 238 | config, sample_idx, unif_centers, sample.data, sample.labels) 239 | 240 | sample_instructions = pos_instructions + unif_instructions 241 | else: 242 | raise ValueError 243 | 244 | 245 | return sample_instructions -------------------------------------------------------------------------------- /workflow/evaluate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import scipy.misc 4 | import copy 5 | import time 6 | import datetime 7 | import numpy as np 8 | import logging as log 9 | import itertools as iter 10 | 11 | from workflow.learning import * 12 | from workflow.network import * 13 | from workflow.filenames import * 14 | 15 | import multiprocessing 16 | 17 | 18 | def run_multi(config): 19 | log.info("\n\n" + "="*75 + "\n Running multieval on {} \n".format(config.dataset.evaluation_set) + "="*75 + "\n") 20 | 21 | multieval_params = config.eval.multieval_params 22 | multieval_update_config = config.eval.multieval_update_config #function 23 | assert callable(multieval_update_config) 24 | 25 | # Set evaluation type as eval for naming consistency 26 | for eval_num, params in enumerate(multieval_params): 27 | log.info("\n Running multieval {}/{} - {}\n".format( 28 | eval_num, len(multieval_params), params)) 29 | 30 | config_param = multieval_update_config(copy.deepcopy(config), params) 31 | 32 | if config.eval.verbose in [1, 2]: 33 | parse_configuration(config_param, return_csv=False, print_terminal=False, num_columns=2, exclude=None) 34 | 35 | evaluation_type = config_param.eval.evaluation_type 36 | if evaluation_type == 'crossval': 37 | run_crossvalidation(config_param) 38 | elif evaluation_type == 'eval': 39 | run_evaluation(config_param) 40 | elif evaluation_type == 'metrics': 41 | run_metrics(config_param) 42 | elif evaluation_type == 'metrics_search': 43 | run_metrics_search(config_param) 44 | else: 45 | raise NotImplementedError("\'{}\' evaluation type is not valid".format(evaluation_type)) 46 | 47 | 48 | def run_evaluation(config): 49 | log.info("\n\n" + "="*75 + "\n Running evaluation on {} set \n".format(config.dataset.evaluation_set) + "="*75 + "\n") 50 | 51 | # Get instances of specified architecture and dataset in config 52 | arch = 
get_architecture_instance(config) 53 | _ = arch.get_model(config) # Test model before loading dataset 54 | 55 | # Load database samples and compile keras model 56 | dataset = get_dataset_instance(config) 57 | dataset.load(config) 58 | 59 | # Normalise and split train into train and val set (val will be normalised since train is already normalised) 60 | dataset, crossval_id = split_train_val(dataset, validation_split=config.dataset.validation_split) 61 | 62 | # Train model 63 | model_def = arch.generate_compiled_model(config, crossval_id) 64 | model = train_model_on_dataset(config, model_def, dataset, crossval_id, load_trained=True) 65 | log.info("Finished training") 66 | 67 | # Test model 68 | samples_eval = dataset.val if config.dataset.evaluation_set is 'val' else dataset.test 69 | metrics_sample = evaluate_samples( 70 | config, model, samples_eval, config.eval.save_probabilities, config.eval.save_segmentations) 71 | 72 | if len(metrics_sample) > 0: 73 | # Print results (only when there are labels to compare) 74 | metrics_final = add_metrics_avg_std(metrics_sample) 75 | case_ids = [s.id for s in samples_eval] + ['avg', 'std'] 76 | 77 | print_metrics_list(metrics_final, case_names=case_ids) 78 | save_metrics_csv(config, metrics_final, case_names=case_ids, filename=generate_metrics_filename(config)) 79 | 80 | 81 | def run_testing(config): 82 | log.info("\n\n" + "="*75 + "\n Running evaluation on test set \n" + "="*75 + "\n") 83 | 84 | # Get instances of specified architecture and dataset in config 85 | arch = get_architecture_instance(config) 86 | #_ = arch.get_model(config) # Test model before loading dataset 87 | 88 | # Load database samples and compile keras model 89 | dataset = get_dataset_instance(config) 90 | dataset.load(config) 91 | assert dataset.test, "Testing set not loaded" 92 | 93 | # Normalise and split train into train and val set (val will be normalised since train is already normalised) 94 | dataset, crossval_id = split_train_val(dataset, validation_split=config.dataset.validation_split) 95 | 96 | # Train model 97 | model_def = arch.generate_compiled_model(config, crossval_id) 98 | model = train_model_on_dataset(config, model_def, dataset, crossval_id, load_trained=False) # WONT DO ANYTHING -> No trainable weights 99 | log.info("Finished training") 100 | 101 | # Test model 102 | test_samples(config, model, dataset.test, save_segmentations=True) 103 | 104 | 105 | def run_crossvalidation(config): 106 | fold_factor = config.eval.crossval_fold 107 | 108 | log.info("\n\n" + "="*75 + "\n Running {}-fold crossvalidation on {} set \n".format( 109 | fold_factor, config.dataset.evaluation_set) + "="*75 + "\n") 110 | 111 | # Get instances of specified architecture and dataset in config 112 | arch = get_architecture_instance(config) 113 | #_ = arch.get_model(config) # Test model is correct before loading dataset # BROKEN if loading models inside arch 114 | 115 | # Load database samples and compile keras model 116 | dataset = get_dataset_instance(config) 117 | dataset.load(config) 118 | 119 | # Prepare variables 120 | dataset_original = copy.deepcopy(dataset) 121 | csv_filename = generate_metrics_filename(config) 122 | metrics_all, case_ids_all = list(), list() 123 | 124 | # BEGIN CROSS VALIDATION 125 | for i in range(config.eval.crossval_start, len(dataset_original.train), fold_factor): 126 | if config.eval.crossval_stop is not None: 127 | if i >= config.eval.crossval_stop: 128 | log.info("Reached crossval stop index, finishing...") 129 | break 130 | 131 | log.info("\n\n Running 
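# Index arithmetic behind the cross-validation loop above: a window of size
# crossval_fold slides over the training cases and each window becomes the held-out
# fold. Illustrative numbers: 48 training cases (the ISLES2017 num_volumes entry)
# and the 9-fold setting used in multieval_params.
fold_factor, num_train, crossval_start = 9, 48, 0
for i in range(crossval_start, num_train, fold_factor):
    print(list(range(i, min(i + fold_factor, num_train))))
# -> [0..8], [9..17], [18..26], [27..35], [36..44], [45, 46, 47]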
crossval iteration {} to {}\n\n".format(i, i + fold_factor)) 132 | 133 | dataset, crossval_id = split_train_val(copy.deepcopy(dataset_original), idxs=range(i, i + fold_factor)) 134 | 135 | # Train model 136 | model_def = arch.generate_compiled_model(config, crossval_id) 137 | model = train_model_on_dataset(config, model_def, dataset, crossval_id) 138 | log.info("Finished training") 139 | 140 | samples_eval = dataset.val if config.dataset.evaluation_set is 'val' else dataset.test 141 | metrics_fold = evaluate_samples( 142 | config, model, samples_eval, 143 | save_probabilities=config.eval.save_probabilities, 144 | save_segmentations=config.eval.save_segmentations) 145 | 146 | #Save metrics_fold 147 | case_ids_all += [s.id for s in samples_eval] 148 | metrics_all += metrics_fold 149 | 150 | # Print stuff 151 | metrics_final = add_metrics_avg_std(metrics_all) 152 | case_ids_final = case_ids_all + ['avg', 'std'] 153 | 154 | print_metrics_list(metrics_final, case_names=case_ids_final) 155 | save_metrics_csv(config, metrics_final, case_names=case_ids_final, filename=csv_filename) 156 | 157 | 158 | def run_metrics(config): 159 | # Get instances of specified architecture and dataset in config 160 | arch = get_architecture_instance(config) 161 | 162 | # Load database samples and compile keras model 163 | dataset = get_dataset_instance(config) 164 | dataset.load(config) 165 | dataset, crossval_id = split_train_val(dataset, validation_split=config.dataset.validation_split) 166 | 167 | # Pick evaluation samples 168 | evaluation_set = \ 169 | config.dataset.evaluation_set if config.eval.metrics_eval_set in [None, ''] else config.eval.metrics_eval_set 170 | if evaluation_set is 'all': 171 | samples_eval = dataset.train + dataset.val 172 | else: 173 | samples_eval = dataset.val if config.dataset.evaluation_set is 'val' else dataset.test 174 | 175 | # Generate result filename and try to load_samples results 176 | metrics_list = list() 177 | for i, sample in enumerate(samples_eval): 178 | if '031950' in str(sample.id): 179 | continue 180 | 181 | result_filename = generate_result_filename(config, (sample.id, 'probs'), 'nii.gz') 182 | log.info("Loading {}".format(result_filename)) 183 | lesion_probs = nib.load(result_filename).get_data() 184 | 185 | true_vol = sample.labels[0] 186 | rec_vol = class_from_probabilities(lesion_probs, config.dataset.lesion_threshold, config.dataset.min_lesion_voxels) 187 | 188 | metrics_list.append(compute_segmentation_metrics(true_vol, rec_vol, lesion=config.dataset.lesion_metrics)) 189 | 190 | metrics_final = add_metrics_avg_std(metrics_list) 191 | case_ids_final = [s.id for s in samples_eval] + ['avg', 'std'] 192 | 193 | #print_metrics_list(metrics_final, case_names=case_ids_final) 194 | print("Saving metrics...") 195 | save_metrics_csv(config, metrics_final, case_names=case_ids_final, filename=generate_metrics_filename(config)) 196 | 197 | def get_metrics(params): 198 | metrics_list_shared, job_num, thresh, lesion_size, prob_vols, true_vols, lesion_metrics = params #tuple unpacking 199 | 200 | metrics_iter = list() 201 | for lesion_probs, true_vol in iter.izip(prob_vols, true_vols): 202 | rec_vol = class_from_probabilities(lesion_probs, thresh, lesion_size) 203 | metrics_iter.append(compute_segmentation_metrics(true_vol, rec_vol, lesion=lesion_metrics)) 204 | 205 | m_avg, m_std = get_metrics_avg_std(metrics_iter) 206 | for k, v in m_std.items(): 207 | m_avg['{}_std'.format(k)] = v 208 | 209 | metrics_list_shared[job_num] = m_avg 210 | log.debug("Evaluated tresh={}, 
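# get_metrics() above is written against a shared results container
# (metrics_list_shared[job_num] = m_avg), which suggests parallel execution; the
# driver below is only a plausible sketch using multiprocessing (imported at the top
# of this file), not code taken from the repository. run_metrics_search() further
# down performs the same threshold / lesion-size grid serially.
import itertools
import multiprocessing

def parallel_threshold_search(prob_vols, true_vols, thresh_values, lesion_sizes, lesion_metrics):
    manager = multiprocessing.Manager()
    shared = manager.dict()
    jobs = [(shared, job_num, thresh, lesion_size, prob_vols, true_vols, lesion_metrics)
            for job_num, (thresh, lesion_size) in enumerate(itertools.product(thresh_values, lesion_sizes))]
    pool = multiprocessing.Pool(processes=4)
    pool.map(get_metrics, jobs)       # each call fills shared[job_num] with averaged metrics
    pool.close()
    pool.join()
    return [shared[job_num] for job_num in range(len(jobs))]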
lesion_size={}".format(thresh, lesion_size)) 211 | 212 | def run_metrics_search(config): 213 | thresh_values = config.eval.metrics_thresh_values 214 | lesion_sizes = config.eval.metrics_lesion_sizes 215 | 216 | # Get instances of specified architecture and dataset in config 217 | #arch = get_architecture_instance(config) 218 | 219 | # Load database samples and compile keras model 220 | dataset = get_dataset_instance(config) 221 | dataset.load(config) 222 | dataset, crossval_id = split_train_val(dataset, validation_split=config.dataset.validation_split) 223 | 224 | evaluation_set = \ 225 | config.dataset.evaluation_set if config.eval.metrics_eval_set in [None, ''] else config.eval.metrics_eval_set 226 | if evaluation_set is 'all': 227 | samples_eval = dataset.train + dataset.val 228 | print("Getting ALL samples") 229 | else: 230 | print("Getting specific samples", config.dataset.evaluation_set) 231 | samples_eval = dataset.val if config.dataset.evaluation_set is 'val' else dataset.test 232 | 233 | 234 | true_vols, prob_vols = list(), list() 235 | for i, sample in enumerate(samples_eval): 236 | result_filename = generate_result_filename(config, (sample.id, 'probs'), 'nii.gz') 237 | log.info("Loading {}".format(result_filename)) 238 | prob_vols.append(nib.load(result_filename).get_data()) 239 | true_vols.append(sample.labels[0]) 240 | 241 | # Generate result filename and try to load_samples results 242 | metrics_list = list() 243 | metrics_names = list() 244 | 245 | for thresh, lesion_size in iter.product(thresh_values, lesion_sizes): 246 | log.debug("Evaluating tresh={}, lesion_size={}".format(thresh, lesion_size)) 247 | 248 | metrics_iter = list() 249 | for lesion_probs, true_vol in iter.izip(prob_vols, true_vols): 250 | rec_vol = class_from_probabilities(lesion_probs, thresh, lesion_size) 251 | metrics_iter.append(compute_segmentation_metrics(true_vol, rec_vol, lesion=config.dataset.lesion_metrics)) 252 | 253 | m_avg, m_std = get_metrics_avg_std(metrics_iter) 254 | for k, v in m_std.items(): 255 | m_avg['{}_std'.format(k)] = v 256 | 257 | metrics_list.append(m_avg) 258 | metrics_names.append("th={}_ls={}".format(thresh, lesion_size)) 259 | 260 | #metrics_list.append(get_metrics_avg_std(metrics_iter)[0]) 261 | #metrics_names.append("th={}_ls={}".format(thresh, lesion_size)) 262 | 263 | config.eval.evaluation_type = 'grid' 264 | result_filename, tables_filename = generate_metrics_filename(config) 265 | print_metrics_list(metrics_list, case_names=metrics_names) 266 | save_metrics_csv(config, metrics_list, case_names=metrics_names, filename=result_filename) 267 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | import sys 4 | import time 5 | import logging as log 6 | import csv 7 | import itertools as iter 8 | import string 9 | import numpy as np 10 | import cv2 11 | 12 | import copy 13 | import nibabel as nib 14 | 15 | # Print iterations progress 16 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 25, fill = '='): 17 | """ 18 | Call in a loop to create terminal progress bar 19 | @params: 20 | iteration - Required : current iteration (Int) 21 | total - Required : total iterations (Int) 22 | prefix - Optional : prefix string (Str) 23 | suffix - Optional : suffix string (Str) 24 | decimals - Optional : positive number of decimals in percent complete (Int) 25 | length - 
Optional : character length of bar (Int) 26 | fill - Optional : bar fill character (Str) 27 | """ 28 | 29 | total = total 30 | 31 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) 32 | filledLength = int(length * iteration // total) 33 | 34 | bar = fill * filledLength + '>' * min(length - filledLength, 1) + '.' * (length - filledLength - 1) 35 | 36 | print('\r{} [{}] {}% {}'.format(prefix, bar, percent, suffix), end='\r') 37 | sys.stdout.flush() 38 | 39 | # Print New Line on Complete 40 | if iteration == total: print() 41 | 42 | def print_metrics_list(metrics_list, case_names=None): 43 | if type(metrics_list) is not list: 44 | metrics_list = [metrics_list] 45 | 46 | if case_names is None: 47 | case_names = range(len(metrics_list)) 48 | 49 | #log.debug("{:^40}".format("Metrics")) 50 | print("{:<16}".format('Sample'), end='') 51 | for metric_name, metric_value in sorted(metrics_list[0].items()): 52 | print(" {:>8}".format(metric_name), end='') 53 | print("") 54 | 55 | for i, metrics in enumerate(metrics_list): 56 | print("{:<16}".format(case_names[i]), end='') 57 | for metric_name, metric_value in sorted(metrics.items()): 58 | print(" {:<08.3}".format(metric_value), end='') 59 | print("") 60 | print("") 61 | 62 | def save_metrics_csv(config, metrics_list, filename=None, case_names=None, print_config=False): 63 | if type(metrics_list) is not list: 64 | metrics_list = [metrics_list] 65 | if case_names is None: 66 | case_names = range(len(metrics_list)) 67 | 68 | with open(filename, 'wb') as csvfile: 69 | row = "{},".format("#sample") 70 | for metric_name, metric_value in sorted(metrics_list[0].items()): 71 | row += "{},".format(metric_name) 72 | csvfile.write(row + '\n') 73 | 74 | for i, metrics in enumerate(metrics_list): 75 | row = "{},".format(case_names[i]) 76 | for metric_name, metric_value in sorted(metrics.items()): 77 | row += "{},".format(metric_value) 78 | 79 | csvfile.write(row + '\n') 80 | 81 | csvfile.write(parse_configuration(config, return_csv=True, num_columns=1)) 82 | 83 | def append_grid_search_result(metrics_avg, param_names, param_values, filename): 84 | with open(filename, 'ab') as csvfile: 85 | row = " , ," # Add two empty entries to compensate for param names 86 | for metric_name, metric_value in sorted(metrics_avg.items()): 87 | row += "{},".format(metric_name) 88 | csvfile.write(row + '\n') 89 | 90 | for param_name, param_value in iter.izip(param_names, param_values): 91 | row += "{}={},".format(param_name, param_value) 92 | csvfile.write(row) 93 | 94 | row = "" 95 | for metric_name, metric_value in sorted(metrics_avg.items()): 96 | row += "{},".format(metric_value) 97 | csvfile.write(row + '\n') 98 | 99 | def save_grid_search_tables(config, metrics_avg_list, param_names, param_values, filename): 100 | assert len(metrics_avg_list) == np.prod([len(p) for p in param_values]) 101 | 102 | empty_table = [[None] * (len(param_values[0]) + 1) for i in range(len(param_values[1]) + 1)] 103 | 104 | metric_names = metrics_avg_list[0].keys() 105 | for metric_name in sorted(metric_names): 106 | # Make a copy of empty table 107 | metric_table = copy.deepcopy(empty_table) 108 | 109 | # Metric name 110 | metric_table[0][0] = "{},".format(metric_name) 111 | # Labels 1st column - param0 112 | for idx in range(len(param_values[0])): 113 | metric_table[idx+1][0] = "{},".format(param_names[0] + "=" + str(param_values[0][idx])) 114 | # Labels 1st row - param1 115 | for idx in range(len(param_values[1])): 116 | metric_table[0][idx+1] = 
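# Typical use of printProgressBar() defined above (the same pattern appears in
# build_set_extraction_instructions() in workflow/generators.py): call it once per
# processed item, then once more with iteration == total so the bar completes.
import time

items = range(40)
for idx, _ in enumerate(items):
    printProgressBar(idx, len(items), prefix='Loading', suffix='samples processed')
    time.sleep(0.01)   # stand-in for real work
printProgressBar(len(items), len(items), prefix='Loading', suffix='samples processed')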
"{},".format(param_names[1] + "=" + str(param_values[1][idx])) 117 | # Average metric values 118 | for count, idxs in enumerate(iter.product(range(len(param_values[0])), range(len(param_values[1])))): 119 | metric_table[idxs[0]+1][idxs[1]+1] = "{},".format(metrics_avg_list[count][metric_name]) 120 | 121 | with open(filename, 'ab') as csvfile: 122 | for row in metric_table: 123 | csvfile.write(string.join(row) + "\n") 124 | csvfile.write("\n") 125 | csvfile.write(parse_configuration(config, return_csv=True, num_columns=1)) 126 | 127 | def parse_configuration(config, return_csv=False, print_terminal=False, num_columns=1, exclude=list()): 128 | width_name, width_value = 20, 13 129 | total_width = num_columns*(width_name + width_value + 6) 130 | 131 | template_supertitle = '{:^' + str(total_width-6) + '}' 132 | template_title = '{}' 133 | template_name = ' {:.<' + str(width_name) + '}' 134 | template_value = ' {:<' + str(width_value) + '}' 135 | 136 | if exclude is not None: 137 | exclude += ['keras_verbose', 'num_classes', 'extraction_step', 'bg_discard_percentage', 'verbose', 'general_pattern', 138 | 'min_pos_patches', 'max_pos_patches', 'max_unif_patches', 'loss', 'train_monitor', 'metrics', 'format', 'path'] 139 | else: 140 | exclude = [] 141 | 142 | settings = list() 143 | 144 | settings += ['', 'Evaluation configuration'] 145 | settings.append(['evaluation_type', config.eval.evaluation_type]) 146 | settings.append(['verbose', config.eval.verbose]) 147 | settings.append(['save_probabilities', config.eval.save_probabilities]) 148 | settings.append(['save_segmentations', config.eval.save_segmentations]) 149 | if config.eval.evaluation_type is 'crossval': 150 | settings.append(['crossval_fold', str(config.eval.crossval_fold) + 'fold']) 151 | elif config.eval.evaluation_type in ['val', 'test']: 152 | pass 153 | elif config.eval.evaluation_type is 'grid': 154 | pass 155 | 156 | settings += ['', 'Dataset'] 157 | settings.append(['name', config.train.dataset_name]) 158 | for name, value in sorted(vars(config.dataset).items()): 159 | if name not in exclude: settings.append([name, value]) 160 | 161 | settings += ['', 'Architecture'] 162 | settings.append(['name', config.train.architecture_name]) 163 | for name, value in sorted(vars(config.arch).items()): 164 | if name not in exclude: settings.append([name, value]) 165 | 166 | settings += ['','Training'] 167 | for name, value in sorted(vars(config.train).items()): 168 | if name not in exclude + ['dataset_name', 'architecture_name']: settings.append([name, value]) 169 | 170 | if print_terminal: 171 | log.info( 172 | "\n\n" + "-" * total_width + "\n" + template_supertitle.format('Configuration') + "\n" + "-" * total_width) 173 | 174 | row = '' 175 | num_settings_section = 0 176 | for i, setting in enumerate(settings): 177 | if isinstance(setting, list): 178 | row += template_name.format(setting[0]) + template_value.format(setting[1]) 179 | num_settings_section += 1 180 | 181 | if num_settings_section > 0 and num_settings_section % num_columns == 0: # Normal column break 182 | print(row) 183 | num_settings_section, row = 0, '' 184 | else: 185 | if num_settings_section > 0: 186 | print(row) 187 | num_settings_section, row = 0, '' 188 | print(template_title.format(setting)) 189 | print('') 190 | 191 | csv_string = '' 192 | if return_csv: 193 | row = '' 194 | num_settings_section = 0 195 | for i, setting in enumerate(settings): 196 | if isinstance(setting, list): 197 | row += str(setting[0]) + ',' + str(setting[1]) + ',' 198 | num_settings_section += 1 
199 | 200 | if num_settings_section > 0 and num_settings_section % num_columns == 0: # Normal column break 201 | csv_string += row + '\n' 202 | num_settings_section, row = 0, '' 203 | else: 204 | if num_settings_section > 0: 205 | csv_string += row + '\n' 206 | num_settings_section, row = 0, '' 207 | csv_string += str(setting) + '\n' 208 | 209 | return csv_string 210 | 211 | def store_batch(filename, x, y_in): 212 | if y_in.ndim != x[0].ndim: 213 | log.debug("Detected categorical y {}".format(y_in.shape)) 214 | y = np.reshape(y_in[:, :, 1], (len(y_in), 1, 24, 24, 24)) 215 | else: 216 | y = y_in 217 | 218 | cell_shape = (6*30, 6*30, 24) 219 | 220 | x_save = np.zeros(cell_shape) 221 | y_save = np.zeros(cell_shape) 222 | 223 | patch_shape = (24, 24, 24) 224 | modality_idx = 0 225 | 226 | # 6*6 227 | for i, j in iter.product(range(0, 6), range(6)): 228 | idx = np.ravel_multi_index((i,j), (6, 6)) 229 | 230 | selection = [slice(i*30, i*30 + patch_shape[0]), slice(j*30, j*30 + patch_shape[1]), slice(None)] 231 | x_save[selection] = x[0][idx, modality_idx] 232 | 233 | selection = [slice(i* 30, i * 30 + patch_shape[0]), slice(j * 30, j * 30 + patch_shape[1]), slice(None)] 234 | y_save[selection] = y[idx, 0] 235 | 236 | x_save = x_save - np.min(x_save) 237 | 238 | log.debug("Storing batch {}".format(filename.format('x'))) 239 | img_x = nib.Nifti1Image(x_save, np.eye(4)) 240 | nib.save(img_x, filename.format('x')) 241 | 242 | log.debug("Storing batch {}".format(filename.format('y'))) 243 | img_y = nib.Nifti1Image(y_save, np.eye(4)) 244 | nib.save(img_y, filename.format('y')) 245 | 246 | def store_sample_patches(filename, x, y_in): 247 | if y_in.ndim != x.ndim: 248 | log.debug("Detected categorical y {}".format(y_in.shape)) 249 | y = np.reshape(y_in[:, 1], (1, 24, 24, 8)) 250 | else: 251 | y = y_in 252 | 253 | num_patches = x.shape[0] + 1 254 | cell_shape = (num_patches*30, 24, 8) 255 | patch_shape = (24, 24, 8) 256 | 257 | x_save = np.zeros(cell_shape) 258 | 259 | for i in range(x.shape[0] + 1): 260 | selection = [slice(i*30, i*30 + patch_shape[0]), slice(None), slice(None)] 261 | 262 | if i < x.shape[0]: 263 | x_save[selection] = normalize_image(x[i]) 264 | else: 265 | x_save[selection] = normalize_image(y[0]) 266 | 267 | log.debug("Storing batch {}".format(filename.format('x'))) 268 | img_x = nib.Nifti1Image(x_save, np.eye(4)) 269 | nib.save(img_x, filename.format('x')) 270 | 271 | def normalize_image(img): 272 | image = np.round(255.0*((img - np.min(img)) / (np.max(img) - np.min(img))), decimals=0) 273 | return image.astype('uint8') 274 | 275 | def normalize_and_save_patch(image_patch, label_patch, filename): 276 | save_size = (128, 128) 277 | 278 | img = normalize_image(image_patch) 279 | img_big = cv2.resize(img, save_size, interpolation=cv2.INTER_NEAREST) 280 | 281 | if label_patch is not None: 282 | lbl = normalize_image(label_patch) 283 | lbl_big = cv2.resize(lbl, save_size, interpolation=cv2.INTER_NEAREST) 284 | patch = np.concatenate((img_big, lbl_big), axis=1) 285 | else: 286 | patch = img_big 287 | 288 | log.debug("Saving {} with max:{}, min{}".format(filename, np.max(patch), np.min(patch))) 289 | cv2.imwrite(filename, patch) 290 | 291 | def normalize_and_save_multires_patch(global_patch, local_patch, label_patch, filename): 292 | 293 | local_patch = np.pad(local_patch, (16,), 'constant', constant_values=0) 294 | label_patch = np.pad(label_patch, (24,), 'constant', constant_values=0) 295 | 296 | save_size = (128, 128) 297 | img_g_big = cv2.resize(global_patch, save_size, 
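# Quick numeric check of normalize_image() above, which min-max rescales any array to
# the uint8 range (dummy values):
import numpy as np
img = np.array([[-1.0, 0.0], [1.0, 3.0]])
print(normalize_image(img))
# [[  0  64]
#  [128 255]]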
interpolation=cv2.INTER_NEAREST) 298 | img_l_big = cv2.resize(local_patch, save_size, interpolation=cv2.INTER_NEAREST) 299 | lbl_big = cv2.resize(label_patch, save_size, interpolation=cv2.INTER_NEAREST) 300 | 301 | patch = np.concatenate((img_g_big, img_l_big), axis=1) 302 | patch = normalize_image(patch) 303 | patch = np.concatenate((patch, normalize_image(lbl_big)), axis=1).astype('uint8') 304 | 305 | 306 | log.debug("Saving {}".format(filename)) 307 | cv2.imwrite(filename, patch) 308 | 309 | def normalize_and_save_patches(patches, filename): 310 | save_size = (128, 128) 311 | 312 | patch_out = None 313 | for patch in patches: 314 | p = normalize_image(patch) 315 | p_big = cv2.resize(np.copy(p), save_size, interpolation=cv2.INTER_NEAREST) 316 | patch_out = p_big if patch_out is None else np.concatenate((patch_out, p_big), axis=1) 317 | 318 | log.debug("Saving {}".format(filename)) 319 | cv2.imwrite(filename, patch_out) 320 | -------------------------------------------------------------------------------- /utils/medpy_hausdorff.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2013 Oskar Maier 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | # 16 | # author Oskar Maier 17 | # version r0.1.1 18 | # since 2014-03-13 19 | # status Release 20 | 21 | # build-in modules 22 | 23 | # third-party modules 24 | import numpy 25 | from scipy.ndimage import _ni_support 26 | from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\ 27 | generate_binary_structure 28 | from scipy.ndimage.measurements import label, find_objects 29 | from scipy.stats import pearsonr 30 | 31 | # own modules 32 | 33 | # code 34 | def dc(result, reference): 35 | r""" 36 | Dice coefficient 37 | 38 | Computes the Dice coefficient (also known as Sorensen index) between the binary 39 | objects in two images. 40 | 41 | The metric is defined as 42 | 43 | .. math:: 44 | 45 | DC=\frac{2|A\cap B|}{|A|+|B|} 46 | 47 | , where :math:`A` is the first and :math:`B` the second set of samples (here: binary objects). 48 | 49 | Parameters 50 | ---------- 51 | result : array_like 52 | Input data containing objects. Can be any type but will be converted 53 | into binary: background where 0, object everywhere else. 54 | reference : array_like 55 | Input data containing objects. Can be any type but will be converted 56 | into binary: background where 0, object everywhere else. 57 | 58 | Returns 59 | ------- 60 | dc : float 61 | The Dice coefficient between the object(s) in ```result``` and the 62 | object(s) in ```reference```. It ranges from 0 (no overlap) to 1 (perfect overlap). 63 | 64 | Notes 65 | ----- 66 | This is a real metric. The binary images can therefore be supplied in any order. 
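Examples
--------
A small worked case with illustrative values: masks of 3 and 4 pixels that overlap
in 2 pixels give DC = 2*2 / (3 + 4) ~= 0.571.

>>> import numpy
>>> a = numpy.zeros((4, 4), dtype=bool); a[1, 1:4] = True    # 3 pixels
>>> b = numpy.zeros((4, 4), dtype=bool); b[1:3, 2:4] = True  # 4 pixels
>>> round(dc(a, b), 3)
0.571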
67 | """ 68 | result = numpy.atleast_1d(result.astype(numpy.bool)) 69 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 70 | 71 | intersection = numpy.count_nonzero(result & reference) 72 | 73 | size_i1 = numpy.count_nonzero(result) 74 | size_i2 = numpy.count_nonzero(reference) 75 | 76 | try: 77 | dc = 2. * intersection / float(size_i1 + size_i2) 78 | except ZeroDivisionError: 79 | dc = 0.0 80 | 81 | return dc 82 | 83 | def jc(result, reference): 84 | """ 85 | Jaccard coefficient 86 | 87 | Computes the Jaccard coefficient between the binary objects in two images. 88 | 89 | Parameters 90 | ---------- 91 | result: array_like 92 | Input data containing objects. Can be any type but will be converted 93 | into binary: background where 0, object everywhere else. 94 | reference: array_like 95 | Input data containing objects. Can be any type but will be converted 96 | into binary: background where 0, object everywhere else. 97 | Returns 98 | ------- 99 | jc: float 100 | The Jaccard coefficient between the object(s) in `result` and the 101 | object(s) in `reference`. It ranges from 0 (no overlap) to 1 (perfect overlap). 102 | 103 | Notes 104 | ----- 105 | This is a real metric. The binary images can therefore be supplied in any order. 106 | """ 107 | result = numpy.atleast_1d(result.astype(numpy.bool)) 108 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 109 | 110 | intersection = numpy.count_nonzero(result & reference) 111 | union = numpy.count_nonzero(result | reference) 112 | 113 | jc = float(intersection) / float(union) 114 | 115 | return jc 116 | 117 | def precision(result, reference): 118 | """ 119 | Precison. 120 | 121 | Parameters 122 | ---------- 123 | result : array_like 124 | Input data containing objects. Can be any type but will be converted 125 | into binary: background where 0, object everywhere else. 126 | reference : array_like 127 | Input data containing objects. Can be any type but will be converted 128 | into binary: background where 0, object everywhere else. 129 | 130 | Returns 131 | ------- 132 | precision : float 133 | The precision between two binary datasets, here mostly binary objects in images, 134 | which is defined as the fraction of retrieved instances that are relevant. The 135 | precision is not symmetric. 136 | 137 | See also 138 | -------- 139 | :func:`recall` 140 | 141 | Notes 142 | ----- 143 | Not symmetric. The inverse of the precision is :func:`recall`. 144 | High precision means that an algorithm returned substantially more relevant results than irrelevant. 145 | 146 | References 147 | ---------- 148 | .. [1] http://en.wikipedia.org/wiki/Precision_and_recall 149 | .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 150 | """ 151 | result = numpy.atleast_1d(result.astype(numpy.bool)) 152 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 153 | 154 | tp = numpy.count_nonzero(result & reference) 155 | fp = numpy.count_nonzero(result & ~reference) 156 | 157 | try: 158 | precision = tp / float(tp + fp) 159 | except ZeroDivisionError: 160 | precision = 0.0 161 | 162 | return precision 163 | 164 | def recall(result, reference): 165 | """ 166 | Recall. 167 | 168 | Parameters 169 | ---------- 170 | result : array_like 171 | Input data containing objects. Can be any type but will be converted 172 | into binary: background where 0, object everywhere else. 173 | reference : array_like 174 | Input data containing objects. 
Can be any type but will be converted 175 | into binary: background where 0, object everywhere else. 176 | 177 | Returns 178 | ------- 179 | recall : float 180 | The recall between two binary datasets, here mostly binary objects in images, 181 | which is defined as the fraction of relevant instances that are retrieved. The 182 | recall is not symmetric. 183 | 184 | See also 185 | -------- 186 | :func:`precision` 187 | 188 | Notes 189 | ----- 190 | Not symmetric. The inverse of the recall is :func:`precision`. 191 | High recall means that an algorithm returned most of the relevant results. 192 | 193 | References 194 | ---------- 195 | .. [1] http://en.wikipedia.org/wiki/Precision_and_recall 196 | .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 197 | """ 198 | result = numpy.atleast_1d(result.astype(numpy.bool)) 199 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 200 | 201 | tp = numpy.count_nonzero(result & reference) 202 | fn = numpy.count_nonzero(~result & reference) 203 | 204 | try: 205 | recall = tp / float(tp + fn) 206 | except ZeroDivisionError: 207 | recall = 0.0 208 | 209 | return recall 210 | 211 | def sensitivity(result, reference): 212 | """ 213 | Sensitivity. 214 | Same as :func:`recall`, see there for a detailed description. 215 | 216 | See also 217 | -------- 218 | :func:`specificity` 219 | """ 220 | return recall(result, reference) 221 | 222 | def specificity(result, reference): 223 | """ 224 | Specificity. 225 | 226 | Parameters 227 | ---------- 228 | result : array_like 229 | Input data containing objects. Can be any type but will be converted 230 | into binary: background where 0, object everywhere else. 231 | reference : array_like 232 | Input data containing objects. Can be any type but will be converted 233 | into binary: background where 0, object everywhere else. 234 | 235 | Returns 236 | ------- 237 | specificity : float 238 | The specificity between two binary datasets, here mostly binary objects in images, 239 | which denotes the fraction of correctly returned negatives. The 240 | specificity is not symmetric. 241 | 242 | See also 243 | -------- 244 | :func:`sensitivity` 245 | 246 | Notes 247 | ----- 248 | Not symmetric. The completment of the specificity is :func:`sensitivity`. 249 | High recall means that an algorithm returned most of the irrelevant results. 250 | 251 | References 252 | ---------- 253 | .. [1] https://en.wikipedia.org/wiki/Sensitivity_and_specificity 254 | .. [2] http://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion 255 | """ 256 | result = numpy.atleast_1d(result.astype(numpy.bool)) 257 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 258 | 259 | tn = numpy.count_nonzero(~result & ~reference) 260 | fp = numpy.count_nonzero(result & ~reference) 261 | 262 | try: 263 | specificity = tn / float(tn + fp) 264 | except ZeroDivisionError: 265 | specificity = 0.0 266 | 267 | return specificity 268 | 269 | def true_negative_rate(result, reference): 270 | """ 271 | True negative rate. 272 | Same as :func:`sensitivity`, see there for a detailed description. 273 | 274 | See also 275 | -------- 276 | :func:`true_positive_rate` 277 | :func:`positive_predictive_value` 278 | """ 279 | return sensitivity(result, reference) 280 | 281 | def true_positive_rate(result, reference): 282 | """ 283 | True positive rate. 284 | Same as :func:`recall`, see there for a detailed description. 
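A small worked example tying together the counts used by precision(),
recall()/sensitivity() and specificity() above (illustrative values:
TP=8, FP=2, FN=4, TN=86):

>>> import numpy
>>> pred = numpy.zeros(100, dtype=bool); pred[:10] = True    # 10 predicted positives
>>> ref = numpy.zeros(100, dtype=bool); ref[2:14] = True     # 12 reference positives
>>> round(precision(pred, ref), 3), round(recall(pred, ref), 3), round(specificity(pred, ref), 3)
(0.8, 0.667, 0.977)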
285 | 286 | See also 287 | -------- 288 | :func:`positive_predictive_value` 289 | :func:`true_negative_rate` 290 | """ 291 | return recall(result, reference) 292 | 293 | def positive_predictive_value(result, reference): 294 | """ 295 | Positive predictive value. 296 | Same as :func:`precision`, see there for a detailed description. 297 | 298 | See also 299 | -------- 300 | :func:`true_positive_rate` 301 | :func:`true_negative_rate` 302 | """ 303 | return precision(result, reference) 304 | 305 | def hd(result, reference, voxelspacing=None, connectivity=1): 306 | """ 307 | Hausdorff Distance. 308 | 309 | Computes the (symmetric) Hausdorff Distance (HD) between the binary objects in two 310 | images. It is defined as the maximum surface distance between the objects. 311 | 312 | Parameters 313 | ---------- 314 | result : array_like 315 | Input data containing objects. Can be any type but will be converted 316 | into binary: background where 0, object everywhere else. 317 | reference : array_like 318 | Input data containing objects. Can be any type but will be converted 319 | into binary: background where 0, object everywhere else. 320 | voxelspacing : float or sequence of floats, optional 321 | The voxelspacing in a distance unit i.e. spacing of elements 322 | along each dimension. If a sequence, must be of length equal to 323 | the input rank; if a single number, this is used for all axes. If 324 | not specified, a grid spacing of unity is implied. 325 | connectivity : int 326 | The neighbourhood/connectivity considered when determining the surface 327 | of the binary objects. This value is passed to 328 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 329 | Note that the connectivity influences the result in the case of the Hausdorff distance. 330 | 331 | Returns 332 | ------- 333 | hd : float 334 | The symmetric Hausdorff Distance between the object(s) in ```result``` and the 335 | object(s) in ```reference```. The distance unit is the same as for the spacing of 336 | elements along each dimension, which is usually given in mm. 337 | 338 | See also 339 | -------- 340 | :func:`assd` 341 | :func:`asd` 342 | 343 | Notes 344 | ----- 345 | This is a real metric. The binary images can therefore be supplied in any order. 346 | """ 347 | hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max() 348 | hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max() 349 | hd = max(hd1, hd2) 350 | return hd 351 | 352 | 353 | def hd95(result, reference, voxelspacing=None, connectivity=1): 354 | """ 355 | 95th percentile of the Hausdorff Distance. 356 | Computes the 95th percentile of the (symmetric) Hausdorff Distance (HD) between the binary objects in two 357 | images. Compared to the Hausdorff Distance, this metric is slightly more stable to small outliers and is 358 | commonly used in Biomedical Segmentation challenges. 359 | Parameters 360 | ---------- 361 | result : array_like 362 | Input data containing objects. Can be any type but will be converted 363 | into binary: background where 0, object everywhere else. 364 | reference : array_like 365 | Input data containing objects. Can be any type but will be converted 366 | into binary: background where 0, object everywhere else. 367 | voxelspacing : float or sequence of floats, optional 368 | The voxelspacing in a distance unit i.e. spacing of elements 369 | along each dimension. 
If a sequence, must be of length equal to 370 | the input rank; if a single number, this is used for all axes. If 371 | not specified, a grid spacing of unity is implied. 372 | connectivity : int 373 | The neighbourhood/connectivity considered when determining the surface 374 | of the binary objects. This value is passed to 375 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 376 | Note that the connectivity influences the result in the case of the Hausdorff distance. 377 | Returns 378 | ------- 379 | hd : float 380 | The symmetric Hausdorff Distance between the object(s) in ```result``` and the 381 | object(s) in ```reference```. The distance unit is the same as for the spacing of 382 | elements along each dimension, which is usually given in mm. 383 | See also 384 | -------- 385 | :func:`hd` 386 | Notes 387 | ----- 388 | This is a real metric. The binary images can therefore be supplied in any order. 389 | """ 390 | hd1 = __surface_distances(result, reference, voxelspacing, connectivity) 391 | hd2 = __surface_distances(reference, result, voxelspacing, connectivity) 392 | hd95 = numpy.percentile(numpy.hstack((hd1, hd2)), 95) 393 | return hd95 394 | 395 | 396 | def assd(result, reference, voxelspacing=None, connectivity=1): 397 | """ 398 | Average symmetric surface distance. 399 | 400 | Computes the average symmetric surface distance (ASD) between the binary objects in 401 | two images. 402 | 403 | Parameters 404 | ---------- 405 | result : array_like 406 | Input data containing objects. Can be any type but will be converted 407 | into binary: background where 0, object everywhere else. 408 | reference : array_like 409 | Input data containing objects. Can be any type but will be converted 410 | into binary: background where 0, object everywhere else. 411 | voxelspacing : float or sequence of floats, optional 412 | The voxelspacing in a distance unit i.e. spacing of elements 413 | along each dimension. If a sequence, must be of length equal to 414 | the input rank; if a single number, this is used for all axes. If 415 | not specified, a grid spacing of unity is implied. 416 | connectivity : int 417 | The neighbourhood/connectivity considered when determining the surface 418 | of the binary objects. This value is passed to 419 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 420 | The decision on the connectivity is important, as it can influence the results 421 | strongly. If in doubt, leave it as it is. 422 | 423 | Returns 424 | ------- 425 | assd : float 426 | The average symmetric surface distance between the object(s) in ``result`` and the 427 | object(s) in ``reference``. The distance unit is the same as for the spacing of 428 | elements along each dimension, which is usually given in mm. 429 | 430 | See also 431 | -------- 432 | :func:`asd` 433 | :func:`hd` 434 | 435 | Notes 436 | ----- 437 | This is a real metric, obtained by calling and averaging 438 | 439 | >>> asd(result, reference) 440 | 441 | and 442 | 443 | >>> asd(reference, result) 444 | 445 | The binary images can therefore be supplied in any order. 446 | """ 447 | assd = numpy.mean( (asd(result, reference, voxelspacing, connectivity), asd(reference, result, voxelspacing, connectivity)) ) 448 | return assd 449 | 450 | def asd(result, reference, voxelspacing=None, connectivity=1): 451 | """ 452 | Average surface distance metric. 453 | 454 | Computes the average surface distance (ASD) between the binary objects in two images. 
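# Editorial sketch, not part of the original module: two hypothetical single-voxel objects
# illustrating the hd()/hd95() functions defined just above and the effect of `voxelspacing`.
# The import path is an assumption; the expected values were worked out by hand for these tiny masks.
import numpy

from utils.medpy_hausdorff import hd, hd95  # assumed import path

result = numpy.zeros((1, 5), dtype=bool)
reference = numpy.zeros((1, 5), dtype=bool)
result[0, 0] = True      # single surface voxel at column 0
reference[0, 3] = True   # single surface voxel at column 3

print(hd(result, reference))                        # 3.0 -> three voxels apart
print(hd(result, reference, voxelspacing=(1, 2)))   # 6.0 -> same gap, but columns spaced 2 mm apart
print(hd95(result, reference))                      # 3.0 -> equals hd() here, only one surface voxel per object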
455 | 456 | Parameters 457 | ---------- 458 | result : array_like 459 | Input data containing objects. Can be any type but will be converted 460 | into binary: background where 0, object everywhere else. 461 | reference : array_like 462 | Input data containing objects. Can be any type but will be converted 463 | into binary: background where 0, object everywhere else. 464 | voxelspacing : float or sequence of floats, optional 465 | The voxelspacing in a distance unit i.e. spacing of elements 466 | along each dimension. If a sequence, must be of length equal to 467 | the input rank; if a single number, this is used for all axes. If 468 | not specified, a grid spacing of unity is implied. 469 | connectivity : int 470 | The neighbourhood/connectivity considered when determining the surface 471 | of the binary objects. This value is passed to 472 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 473 | The decision on the connectivity is important, as it can influence the results 474 | strongly. If in doubt, leave it as it is. 475 | 476 | Returns 477 | ------- 478 | asd : float 479 | The average surface distance between the object(s) in ``result`` and the 480 | object(s) in ``reference``. The distance unit is the same as for the spacing 481 | of elements along each dimension, which is usually given in mm. 482 | 483 | See also 484 | -------- 485 | :func:`assd` 486 | :func:`hd` 487 | 488 | 489 | Notes 490 | ----- 491 | This is not a real metric, as it is directed. See `assd` for a real metric of this. 492 | 493 | The method is implemented making use of distance images and simple binary morphology 494 | to achieve high computational speed. 495 | 496 | Examples 497 | -------- 498 | The `connectivity` determines what pixels/voxels are considered the surface of a 499 | binary object. Take the following binary image showing a cross 500 | 501 | >>> from scipy.ndimage.morphology import generate_binary_structure 502 | >>> cross = generate_binary_structure(2, 1) 503 | array([[0, 1, 0], 504 | [1, 1, 1], 505 | [0, 1, 0]]) 506 | 507 | With `connectivity` set to `1` a 4-neighbourhood is considered when determining the 508 | object surface, resulting in the surface 509 | 510 | .. code-block:: python 511 | 512 | array([[0, 1, 0], 513 | [1, 0, 1], 514 | [0, 1, 0]]) 515 | 516 | Changing `connectivity` to `2`, an 8-neighbourhood is considered and we get: 517 | 518 | .. code-block:: python 519 | 520 | array([[0, 1, 0], 521 | [1, 1, 1], 522 | [0, 1, 0]]) 523 | 524 | , as the centre pixel's background diagonal neighbours now prevent it from counting as an interior pixel. 525 | 526 | This influences the results `asd` returns. Imagine we want to compute the surface 527 | distance of our cross to a cube-like object: 528 | 529 | >>> cube = generate_binary_structure(2, 2) 530 | array([[1, 1, 1], 531 | [1, 1, 1], 532 | [1, 1, 1]]) 533 | 534 | , whose surface, independent of the `connectivity` value set, is always 535 | 536 | .. code-block:: python 537 | 538 | array([[1, 1, 1], 539 | [1, 0, 1], 540 | [1, 1, 1]]) 541 | 542 | Using a `connectivity` of `1` we get 543 | 544 | >>> asd(cross, cube, connectivity=1) 545 | 0.0 546 | 547 | while a value of `2` returns us 548 | 549 | >>> asd(cross, cube, connectivity=2) 550 | 0.20000000000000001 551 | 552 | due to the center of the cross being considered surface as well.
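# Editorial sketch, not part of the original module: how the "surface" in the example above can be
# made explicit. This mirrors the border extraction performed by the private helper at the bottom of
# this file (mask XOR its binary erosion); scipy.ndimage exposes the same functions the docstring
# refers to.
import numpy
from scipy.ndimage import binary_erosion, generate_binary_structure

cross = generate_binary_structure(2, 1)   # the 3x3 cross from the example

for connectivity in (1, 2):
    footprint = generate_binary_structure(2, connectivity)
    surface = cross ^ binary_erosion(cross, structure=footprint)
    print(connectivity)
    print(surface.astype(int))
# connectivity=1: the centre survives the erosion, so only the 4 arm pixels are surface
# connectivity=2: nothing survives the erosion, so all 5 cross pixels are surface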
553 | 554 | """ 555 | sds = __surface_distances(result, reference, voxelspacing, connectivity) 556 | asd = sds.mean() 557 | return asd 558 | 559 | def ravd(result, reference): 560 | """ 561 | Relative absolute volume difference. 562 | 563 | Compute the relative absolute volume difference between the (joined) binary objects 564 | in the two images. 565 | 566 | Parameters 567 | ---------- 568 | result : array_like 569 | Input data containing objects. Can be any type but will be converted 570 | into binary: background where 0, object everywhere else. 571 | reference : array_like 572 | Input data containing objects. Can be any type but will be converted 573 | into binary: background where 0, object everywhere else. 574 | 575 | Returns 576 | ------- 577 | ravd : float 578 | The relative absolute volume difference between the object(s) in ``result`` 579 | and the object(s) in ``reference``. This is a percentage value in the range 580 | :math:`[-1.0, +inf]` for which a :math:`0` denotes an ideal score. 581 | 582 | Raises 583 | ------ 584 | RuntimeError 585 | If the reference object is empty. 586 | 587 | See also 588 | -------- 589 | :func:`dc` 590 | :func:`precision` 591 | :func:`recall` 592 | 593 | Notes 594 | ----- 595 | This is not a real metric, as it is directed. Negative values denote a smaller 596 | and positive values a larger volume than the reference. 597 | This implementation does not check whether the two supplied arrays are of the same 598 | size. 599 | 600 | Examples 601 | -------- 602 | Considering the following inputs 603 | 604 | >>> import numpy 605 | >>> arr1 = numpy.asarray([[0,1,0],[1,1,1],[0,1,0]]) 606 | >>> arr1 607 | array([[0, 1, 0], 608 | [1, 1, 1], 609 | [0, 1, 0]]) 610 | >>> arr2 = numpy.asarray([[0,1,0],[1,0,1],[0,1,0]]) 611 | >>> arr2 612 | array([[0, 1, 0], 613 | [1, 0, 1], 614 | [0, 1, 0]]) 615 | 616 | comparing `arr1` to `arr2` we get 617 | 618 | >>> ravd(arr1, arr2) 619 | 0.25 620 | 621 | and reversing the inputs the directedness of the metric becomes evident 622 | 623 | >>> ravd(arr2, arr1) 624 | -0.2 625 | 626 | It is important to keep in mind that a perfect score of `0` does not mean that the 627 | binary objects fit exactly, as only the volumes are compared: 628 | 629 | >>> arr1 = numpy.asarray([1,0,0]) 630 | >>> arr2 = numpy.asarray([0,0,1]) 631 | >>> ravd(arr1, arr2) 632 | 0.0 633 | 634 | """ 635 | result = numpy.atleast_1d(result.astype(numpy.bool)) 636 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 637 | 638 | vol1 = numpy.count_nonzero(result) 639 | vol2 = numpy.count_nonzero(reference) 640 | 641 | if 0 == vol2: 642 | raise RuntimeError('The second supplied array does not contain any binary object.') 643 | 644 | return (vol1 - vol2) / float(vol2) 645 | 646 | def volume_correlation(results, references): 647 | r""" 648 | Volume correlation. 649 | 650 | Computes the linear correlation in binary object volume between the 651 | contents of the successive binary images supplied. Measured through 652 | the Pearson product-moment correlation coefficient. 653 | 654 | Parameters 655 | ---------- 656 | results : sequence of array_like 657 | Ordered list of input data containing objects. Each array_like will be 658 | converted into binary: background where 0, object everywhere else. 659 | references : sequence of array_like 660 | Ordered list of input data containing objects. Each array_like will be 661 | converted into binary: background where 0, object everywhere else. 662 | The order must be the same as for ``results``.
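# Editorial sketch, not part of the original module: the ravd() example above rerun to highlight the
# sign convention, plus the asd()/assd() relationship described in the assd() docstring. The import
# path is an assumption.
import numpy

from utils.medpy_hausdorff import asd, assd, ravd  # assumed import path

arr1 = numpy.asarray([[0, 1, 0], [1, 1, 1], [0, 1, 0]])   # volume 5
arr2 = numpy.asarray([[0, 1, 0], [1, 0, 1], [0, 1, 0]])   # volume 4

print(ravd(arr1, arr2))   # (5 - 4) / 4 = 0.25 -> result is larger than the reference
print(ravd(arr2, arr1))   # (4 - 5) / 5 = -0.2 -> result is smaller than the reference

# assd() is exactly the mean of the two directed asd() calls:
print(assd(arr1, arr2) == numpy.mean((asd(arr1, arr2), asd(arr2, arr1))))   # True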
663 | 664 | Returns 665 | ------- 666 | r : float 667 | The correlation coefficient between -1 and 1. 668 | p : float 669 | The two-side p value. 670 | 671 | """ 672 | results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool)) 673 | references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool)) 674 | 675 | results_volumes = [numpy.count_nonzero(r) for r in results] 676 | references_volumes = [numpy.count_nonzero(r) for r in references] 677 | 678 | return pearsonr(results_volumes, references_volumes) # returns (Pearson' 679 | 680 | def volume_change_correlation(results, references): 681 | r""" 682 | Volume change correlation. 683 | 684 | Computes the linear correlation of change in binary object volume between 685 | the contents of the successive binary images supplied. Measured through 686 | the Pearson product-moment correlation coefficient. 687 | 688 | Parameters 689 | ---------- 690 | results : sequence of array_like 691 | Ordered list of input data containing objects. Each array_like will be 692 | converted into binary: background where 0, object everywhere else. 693 | references : sequence of array_like 694 | Ordered list of input data containing objects. Each array_like will be 695 | converted into binary: background where 0, object everywhere else. 696 | The order must be the same as for ``results``. 697 | 698 | Returns 699 | ------- 700 | r : float 701 | The correlation coefficient between -1 and 1. 702 | p : float 703 | The two-side p value. 704 | 705 | """ 706 | results = numpy.atleast_2d(numpy.array(results).astype(numpy.bool)) 707 | references = numpy.atleast_2d(numpy.array(references).astype(numpy.bool)) 708 | 709 | results_volumes = numpy.asarray([numpy.count_nonzero(r) for r in results]) 710 | references_volumes = numpy.asarray([numpy.count_nonzero(r) for r in references]) 711 | 712 | results_volumes_changes = results_volumes[1:] - results_volumes[:-1] 713 | references_volumes_changes = references_volumes[1:] - references_volumes[:-1] 714 | 715 | return pearsonr(results_volumes_changes, references_volumes_changes) # returns (Pearson's correlation coefficient, 2-tailed p-value) 716 | 717 | def obj_assd(result, reference, voxelspacing=None, connectivity=1): 718 | """ 719 | Average symmetric surface distance. 720 | 721 | Computes the average symmetric surface distance (ASSD) between the binary objects in 722 | two images. 723 | 724 | Parameters 725 | ---------- 726 | result : array_like 727 | Input data containing objects. Can be any type but will be converted 728 | into binary: background where 0, object everywhere else. 729 | reference : array_like 730 | Input data containing objects. Can be any type but will be converted 731 | into binary: background where 0, object everywhere else. 732 | voxelspacing : float or sequence of floats, optional 733 | The voxelspacing in a distance unit i.e. spacing of elements 734 | along each dimension. If a sequence, must be of length equal to 735 | the input rank; if a single number, this is used for all axes. If 736 | not specified, a grid spacing of unity is implied. 737 | connectivity : int 738 | The neighbourhood/connectivity considered when determining what accounts 739 | for a distinct binary object as well as when determining the surface 740 | of the binary objects. This value is passed to 741 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 742 | The decision on the connectivity is important, as it can influence the results 743 | strongly. If in doubt, leave it as it is. 
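# Editorial sketch, not part of the original module: hypothetical follow-up masks whose object grows
# over time, fed to the volume correlation helpers above. volume_correlation() correlates the volumes
# themselves, while volume_change_correlation() correlates their differences between successive time
# points. The import path is an assumption.
import numpy

from utils.medpy_hausdorff import volume_correlation, volume_change_correlation  # assumed import path

def toy_mask(n_voxels, size=8):
    m = numpy.zeros(size, dtype=bool)
    m[:n_voxels] = True
    return m

results = [toy_mask(n) for n in (1, 2, 4, 8)]      # segmented volumes 1, 2, 4, 8
references = [toy_mask(n) for n in (1, 3, 4, 7)]   # reference volumes 1, 3, 4, 7

r, p = volume_correlation(results, references)            # Pearson r of the volumes
dr, dp = volume_change_correlation(results, references)   # Pearson r of the volume changes
print(r, dr)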
744 | 745 | Returns 746 | ------- 747 | assd : float 748 | The average symmetric surface distance between all mutually existing distinct 749 | binary object(s) in ``result`` and ``reference``. The distance unit is the same as for 750 | the spacing of elements along each dimension, which is usually given in mm. 751 | 752 | See also 753 | -------- 754 | :func:`obj_asd` 755 | 756 | Notes 757 | ----- 758 | This is a real metric, obtained by calling and averaging 759 | 760 | >>> obj_asd(result, reference) 761 | 762 | and 763 | 764 | >>> obj_asd(reference, result) 765 | 766 | The binary images can therefore be supplied in any order. 767 | """ 768 | assd = numpy.mean( (obj_asd(result, reference, voxelspacing, connectivity), obj_asd(reference, result, voxelspacing, connectivity)) ) 769 | return assd 770 | 771 | 772 | def obj_asd(result, reference, voxelspacing=None, connectivity=1): 773 | """ 774 | Average surface distance between objects. 775 | 776 | First correspondences between distinct binary objects in reference and result are 777 | established. Then the average surface distance is only computed between corresponding 778 | objects. Correspondence is defined as unique and at least one voxel overlap. 779 | 780 | Parameters 781 | ---------- 782 | result : array_like 783 | Input data containing objects. Can be any type but will be converted 784 | into binary: background where 0, object everywhere else. 785 | reference : array_like 786 | Input data containing objects. Can be any type but will be converted 787 | into binary: background where 0, object everywhere else. 788 | voxelspacing : float or sequence of floats, optional 789 | The voxelspacing in a distance unit i.e. spacing of elements 790 | along each dimension. If a sequence, must be of length equal to 791 | the input rank; if a single number, this is used for all axes. If 792 | not specified, a grid spacing of unity is implied. 793 | connectivity : int 794 | The neighbourhood/connectivity considered when determining what accounts 795 | for a distinct binary object as well as when determining the surface 796 | of the binary objects. This value is passed to 797 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 798 | The decision on the connectivity is important, as it can influence the results 799 | strongly. If in doubt, leave it as it is. 800 | 801 | Returns 802 | ------- 803 | asd : float 804 | The average surface distance between all mutually existing distinct binary 805 | object(s) in ``result`` and ``reference``. The distance unit is the same as for the 806 | spacing of elements along each dimension, which is usually given in mm. 807 | 808 | See also 809 | -------- 810 | :func:`obj_assd` 811 | :func:`obj_tpr` 812 | :func:`obj_fpr` 813 | 814 | Notes 815 | ----- 816 | This is not a real metric, as it is directed. See `obj_assd` for a real metric of this. 817 | 818 | For the understanding of this metric, both the notions of connectedness and surface 819 | distance are essential. Please see :func:`obj_tpr` and :func:`obj_fpr` for more 820 | information on the first and :func:`asd` on the second. 
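# Editorial sketch, not part of the original module: why the obj_* variants differ from plain
# asd()/assd(). A hypothetical result that matches the reference object exactly but also contains a
# far-away false-positive blob: asd() is penalised by the stray voxel, whereas obj_asd() only measures
# distances between objects that actually overlap (the correspondence rule described above). The
# import path is an assumption.
import numpy

from utils.medpy_hausdorff import asd, obj_asd  # assumed import path

reference = numpy.zeros((1, 10), dtype=bool)
reference[0, 0:3] = True                 # one reference object (columns 0-2)

result = reference.copy()
result[0, 8] = True                      # perfect match plus one stray voxel at column 8

print(asd(result, reference))            # 1.5 -> (0 + 0 + 0 + 6) / 4, the stray voxel costs 6 voxels
print(obj_asd(result, reference))        # 0.0 -> only the corresponding (overlapping) objects are compared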
821 | 822 | Examples 823 | -------- 824 | >>> arr1 = numpy.asarray([[1,1,1],[1,1,1],[1,1,1]]) 825 | >>> arr2 = numpy.asarray([[0,1,0],[0,1,0],[0,1,0]]) 826 | >>> arr1 827 | array([[1, 1, 1], 828 | [1, 1, 1], 829 | [1, 1, 1]]) 830 | >>> arr2 831 | array([[0, 1, 0], 832 | [0, 1, 0], 833 | [0, 1, 0]]) 834 | >>> obj_asd(arr1, arr2) 835 | 1.5 836 | >>> obj_asd(arr2, arr1) 837 | 0.333333333333 838 | 839 | With the `voxelspacing` parameter, the distances between the voxels can be set for 840 | each dimension separately: 841 | 842 | >>> obj_asd(arr1, arr2, voxelspacing=(1,2)) 843 | 1.5 844 | >>> obj_asd(arr2, arr1, voxelspacing=(1,2)) 845 | 0.333333333333 846 | 847 | More examples depicting the notion of object connectedness: 848 | 849 | >>> arr1 = numpy.asarray([[1,0,1],[1,0,0],[0,0,0]]) 850 | >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]]) 851 | >>> arr1 852 | array([[1, 0, 1], 853 | [1, 0, 0], 854 | [0, 0, 0]]) 855 | >>> arr2 856 | array([[1, 0, 1], 857 | [1, 0, 0], 858 | [0, 0, 1]]) 859 | >>> obj_asd(arr1, arr2) 860 | 0.0 861 | >>> obj_asd(arr2, arr1) 862 | 0.0 863 | 864 | >>> arr1 = numpy.asarray([[1,0,1],[1,0,1],[0,0,1]]) 865 | >>> arr2 = numpy.asarray([[1,0,1],[1,0,0],[0,0,1]]) 866 | >>> arr1 867 | array([[1, 0, 1], 868 | [1, 0, 1], 869 | [0, 0, 1]]) 870 | >>> arr2 871 | array([[1, 0, 1], 872 | [1, 0, 0], 873 | [0, 0, 1]]) 874 | >>> obj_asd(arr1, arr2) 875 | 0.6 876 | >>> obj_asd(arr2, arr1) 877 | 0.0 878 | 879 | Influence of `connectivity` parameter can be seen in the following example, where 880 | with the (default) connectivity of `1` the first array is considered to contain two 881 | objects, while with an increase connectivity of `2`, just one large object is 882 | detected. 883 | 884 | >>> arr1 = numpy.asarray([[1,0,0],[0,1,1],[0,1,1]]) 885 | >>> arr2 = numpy.asarray([[1,0,0],[0,0,0],[0,0,0]]) 886 | >>> arr1 887 | array([[1, 0, 0], 888 | [0, 1, 1], 889 | [0, 1, 1]]) 890 | >>> arr2 891 | array([[1, 0, 0], 892 | [0, 0, 0], 893 | [0, 0, 0]]) 894 | >>> obj_asd(arr1, arr2) 895 | 0.0 896 | >>> obj_asd(arr1, arr2, connectivity=2) 897 | 1.742955328 898 | 899 | Note that the connectivity also influence the notion of what is considered an object 900 | surface voxels. 901 | """ 902 | sds = list() 903 | labelmap1, labelmap2, _a, _b, mapping = __distinct_binary_object_correspondences(result, reference, connectivity) 904 | slicers1 = find_objects(labelmap1) 905 | slicers2 = find_objects(labelmap2) 906 | for lid2, lid1 in mapping.items(): 907 | window = __combine_windows(slicers1[lid1 - 1], slicers2[lid2 - 1]) 908 | object1 = labelmap1[window] == lid1 909 | object2 = labelmap2[window] == lid2 910 | sds.extend(__surface_distances(object1, object2, voxelspacing, connectivity)) 911 | asd = numpy.mean(sds) 912 | return asd 913 | 914 | def obj_fpr(result, reference, connectivity=1): 915 | """ 916 | The false positive rate of distinct binary object detection. 917 | 918 | The false positive rates gives a percentage measure of how many distinct binary 919 | objects in the second array do not exists in the first array. A partial overlap 920 | (of minimum one voxel) is here considered sufficient. 921 | 922 | In cases where two distinct binary object in the second array overlap with a single 923 | distinct object in the first array, only one is considered to have been detected 924 | successfully and the other is added to the count of false positives. 925 | 926 | Parameters 927 | ---------- 928 | result : array_like 929 | Input data containing objects. 
Can be any type but will be converted 930 | into binary: background where 0, object everywhere else. 931 | reference : array_like 932 | Input data containing objects. Can be any type but will be converted 933 | into binary: background where 0, object everywhere else. 934 | connectivity : int 935 | The neighbourhood/connectivity considered when determining what accounts 936 | for a distinct binary object. This value is passed to 937 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 938 | The decision on the connectivity is important, as it can influence the results 939 | strongly. If in doubt, leave it as it is. 940 | 941 | Returns 942 | ------- 943 | tpr : float 944 | A percentage measure of how many distinct binary objects in ``results`` have no 945 | corresponding binary object in ``reference``. It has the range :math:`[0, 1]`, where a :math:`0` 946 | denotes an ideal score. 947 | 948 | Raises 949 | ------ 950 | RuntimeError 951 | If the second array is empty. 952 | 953 | See also 954 | -------- 955 | :func:`obj_tpr` 956 | 957 | Notes 958 | ----- 959 | This is not a real metric, as it is directed. Whatever array is considered as 960 | reference should be passed second. A perfect score of :math:`0` tells that there are no 961 | distinct binary objects in the second array that do not exists also in the reference 962 | array, but does not reveal anything about objects in the reference array also 963 | existing in the second array (use :func:`obj_tpr` for this). 964 | 965 | Examples 966 | -------- 967 | >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]]) 968 | >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]]) 969 | >>> arr2 970 | array([[1, 0, 0], 971 | [1, 0, 1], 972 | [0, 0, 1]]) 973 | >>> arr1 974 | array([[0, 0, 1], 975 | [1, 0, 1], 976 | [0, 0, 1]]) 977 | >>> obj_fpr(arr1, arr2) 978 | 0.0 979 | >>> obj_fpr(arr2, arr1) 980 | 0.0 981 | 982 | Example of directedness: 983 | 984 | >>> arr2 = numpy.asarray([1,0,1,0,1]) 985 | >>> arr1 = numpy.asarray([1,0,1,0,0]) 986 | >>> obj_fpr(arr1, arr2) 987 | 0.0 988 | >>> obj_fpr(arr2, arr1) 989 | 0.3333333333333333 990 | 991 | Examples of multiple overlap treatment: 992 | 993 | >>> arr2 = numpy.asarray([1,0,1,0,1,1,1]) 994 | >>> arr1 = numpy.asarray([1,1,1,0,1,0,1]) 995 | >>> obj_fpr(arr1, arr2) 996 | 0.3333333333333333 997 | >>> obj_fpr(arr2, arr1) 998 | 0.3333333333333333 999 | 1000 | >>> arr2 = numpy.asarray([1,0,1,1,1,0,1]) 1001 | >>> arr1 = numpy.asarray([1,1,1,0,1,1,1]) 1002 | >>> obj_fpr(arr1, arr2) 1003 | 0.0 1004 | >>> obj_fpr(arr2, arr1) 1005 | 0.3333333333333333 1006 | 1007 | >>> arr2 = numpy.asarray([[1,0,1,0,0], 1008 | [1,0,0,0,0], 1009 | [1,0,1,1,1], 1010 | [0,0,0,0,0], 1011 | [1,0,1,0,0]]) 1012 | >>> arr1 = numpy.asarray([[1,1,1,0,0], 1013 | [0,0,0,0,0], 1014 | [1,1,1,0,1], 1015 | [0,0,0,0,0], 1016 | [1,1,1,0,0]]) 1017 | >>> obj_fpr(arr1, arr2) 1018 | 0.0 1019 | >>> obj_fpr(arr2, arr1) 1020 | 0.2 1021 | """ 1022 | _, _, _, n_obj_reference, mapping = __distinct_binary_object_correspondences(reference, result, connectivity) 1023 | return (n_obj_reference - len(mapping)) / float(n_obj_reference) 1024 | 1025 | def obj_tpr(result, reference, connectivity=1): 1026 | """ 1027 | The true positive rate of distinct binary object detection. 1028 | 1029 | The true positive rates gives a percentage measure of how many distinct binary 1030 | objects in the first array also exists in the second array. A partial overlap 1031 | (of minimum one voxel) is here considered sufficient. 
1032 | 1033 | In cases where two distinct binary object in the first array overlaps with a single 1034 | distinct object in the second array, only one is considered to have been detected 1035 | successfully. 1036 | 1037 | Parameters 1038 | ---------- 1039 | result : array_like 1040 | Input data containing objects. Can be any type but will be converted 1041 | into binary: background where 0, object everywhere else. 1042 | reference : array_like 1043 | Input data containing objects. Can be any type but will be converted 1044 | into binary: background where 0, object everywhere else. 1045 | connectivity : int 1046 | The neighbourhood/connectivity considered when determining what accounts 1047 | for a distinct binary object. This value is passed to 1048 | `scipy.ndimage.morphology.generate_binary_structure` and should usually be :math:`> 1`. 1049 | The decision on the connectivity is important, as it can influence the results 1050 | strongly. If in doubt, leave it as it is. 1051 | 1052 | Returns 1053 | ------- 1054 | tpr : float 1055 | A percentage measure of how many distinct binary objects in ``result`` also exists 1056 | in ``reference``. It has the range :math:`[0, 1]`, where a :math:`1` denotes an ideal score. 1057 | 1058 | Raises 1059 | ------ 1060 | RuntimeError 1061 | If the reference object is empty. 1062 | 1063 | See also 1064 | -------- 1065 | :func:`obj_fpr` 1066 | 1067 | Notes 1068 | ----- 1069 | This is not a real metric, as it is directed. Whatever array is considered as 1070 | reference should be passed second. A perfect score of :math:`1` tells that all distinct 1071 | binary objects in the reference array also exist in the result array, but does not 1072 | reveal anything about additional binary objects in the result array 1073 | (use :func:`obj_fpr` for this). 
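# Editorial sketch, not part of the original module: a hypothetical lesion-detection example for
# obj_tpr()/obj_fpr() above, which treat every connected component as one object and count matches
# via the one-voxel-overlap correspondence rule. The import path is an assumption.
import numpy

from utils.medpy_hausdorff import obj_fpr, obj_tpr  # assumed import path

reference = numpy.array([[1, 1, 0, 0, 0],
                         [0, 0, 0, 0, 1]], dtype=bool)   # two distinct reference lesions
result = numpy.array([[1, 0, 0, 1, 0],
                      [0, 0, 0, 1, 0]], dtype=bool)      # overlaps the first lesion, adds a stray blob

print(obj_tpr(result, reference))   # 0.5 -> one of the two reference lesions is hit
print(obj_fpr(result, reference))   # 0.5 -> one of the two result blobs matches nothing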
1074 | 1075 | Examples 1076 | -------- 1077 | >>> arr2 = numpy.asarray([[1,0,0],[1,0,1],[0,0,1]]) 1078 | >>> arr1 = numpy.asarray([[0,0,1],[1,0,1],[0,0,1]]) 1079 | >>> arr2 1080 | array([[1, 0, 0], 1081 | [1, 0, 1], 1082 | [0, 0, 1]]) 1083 | >>> arr1 1084 | array([[0, 0, 1], 1085 | [1, 0, 1], 1086 | [0, 0, 1]]) 1087 | >>> obj_tpr(arr1, arr2) 1088 | 1.0 1089 | >>> obj_tpr(arr2, arr1) 1090 | 1.0 1091 | 1092 | Example of directedness: 1093 | 1094 | >>> arr2 = numpy.asarray([1,0,1,0,1]) 1095 | >>> arr1 = numpy.asarray([1,0,1,0,0]) 1096 | >>> obj_tpr(arr1, arr2) 1097 | 0.6666666666666666 1098 | >>> obj_tpr(arr2, arr1) 1099 | 1.0 1100 | 1101 | Examples of multiple overlap treatment: 1102 | 1103 | >>> arr2 = numpy.asarray([1,0,1,0,1,1,1]) 1104 | >>> arr1 = numpy.asarray([1,1,1,0,1,0,1]) 1105 | >>> obj_tpr(arr1, arr2) 1106 | 0.6666666666666666 1107 | >>> obj_tpr(arr2, arr1) 1108 | 0.6666666666666666 1109 | 1110 | >>> arr2 = numpy.asarray([1,0,1,1,1,0,1]) 1111 | >>> arr1 = numpy.asarray([1,1,1,0,1,1,1]) 1112 | >>> obj_tpr(arr1, arr2) 1113 | 0.6666666666666666 1114 | >>> obj_tpr(arr2, arr1) 1115 | 1.0 1116 | 1117 | >>> arr2 = numpy.asarray([[1,0,1,0,0], 1118 | [1,0,0,0,0], 1119 | [1,0,1,1,1], 1120 | [0,0,0,0,0], 1121 | [1,0,1,0,0]]) 1122 | >>> arr1 = numpy.asarray([[1,1,1,0,0], 1123 | [0,0,0,0,0], 1124 | [1,1,1,0,1], 1125 | [0,0,0,0,0], 1126 | [1,1,1,0,0]]) 1127 | >>> obj_tpr(arr1, arr2) 1128 | 0.8 1129 | >>> obj_tpr(arr2, arr1) 1130 | 1.0 1131 | """ 1132 | _, _, n_obj_result, _, mapping = __distinct_binary_object_correspondences(reference, result, connectivity) 1133 | return len(mapping) / float(n_obj_result) 1134 | 1135 | def __distinct_binary_object_correspondences(reference, result, connectivity=1): 1136 | """ 1137 | Determines all distinct (where connectivity is defined by the connectivity parameter 1138 | passed to scipy's `generate_binary_structure`) binary objects in both of the input 1139 | parameters and returns a 1to1 mapping from the labelled objects in reference to the 1140 | corresponding (whereas a one-voxel overlap suffices for correspondence) objects in 1141 | result. 1142 | 1143 | All stems from the problem, that the relationship is non-surjective many-to-many. 
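# Editorial sketch, not part of the original module: the core of the correspondence helper described
# above is scipy's connected-component labelling plus an overlap test. A stripped-down, hypothetical
# version of the one-to-one matching (ignoring the one-to-many tie-breaking handled below) for two
# toy masks:
import numpy
from scipy.ndimage import find_objects, generate_binary_structure, label

reference = numpy.array([[1, 1, 0, 0, 1]], dtype=bool)   # two reference components
result = numpy.array([[0, 1, 0, 1, 0]], dtype=bool)      # overlaps only the first one

footprint = generate_binary_structure(reference.ndim, 1)
labelmap_ref, n_ref = label(reference, footprint)
labelmap_res, n_res = label(result, footprint)

mapping = {}
for ref_id, window in enumerate(find_objects(labelmap_ref), start=1):
    overlap = labelmap_res[window][labelmap_ref[window] == ref_id]
    partners = numpy.unique(overlap[overlap != 0])
    if len(partners) == 1 and partners[0] not in mapping.values():
        mapping[ref_id] = int(partners[0])

print(mapping)   # {1: 1} -> reference component 1 corresponds to result component 1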
1144 | 1145 | @return (labelmap1, labelmap2, n_lables1, n_labels2, labelmapping2to1) 1146 | """ 1147 | result = numpy.atleast_1d(result.astype(numpy.bool)) 1148 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 1149 | 1150 | # binary structure 1151 | footprint = generate_binary_structure(result.ndim, connectivity) 1152 | 1153 | # label distinct binary objects 1154 | labelmap1, n_obj_result = label(result, footprint) 1155 | labelmap2, n_obj_reference = label(reference, footprint) 1156 | 1157 | # find all overlaps from labelmap2 to labelmap1; collect one-to-one relationships and store all one-two-many for later processing 1158 | slicers = find_objects(labelmap2) # get windows of labelled objects 1159 | mapping = dict() # mappings from labels in labelmap2 to corresponding object labels in labelmap1 1160 | used_labels = set() # set to collect all already used labels from labelmap2 1161 | one_to_many = list() # list to collect all one-to-many mappings 1162 | for l1id, slicer in enumerate(slicers): # iterate over object in labelmap2 and their windows 1163 | l1id += 1 # labelled objects have ids sarting from 1 1164 | bobj = (l1id) == labelmap2[slicer] # find binary object corresponding to the label1 id in the segmentation 1165 | l2ids = numpy.unique(labelmap1[slicer][bobj]) # extract all unique object identifiers at the corresponding positions in the reference (i.e. the mapping) 1166 | l2ids = l2ids[0 != l2ids] # remove background identifiers (=0) 1167 | if 1 == len(l2ids): # one-to-one mapping: if target label not already used, add to final list of object-to-object mappings and mark target label as used 1168 | l2id = l2ids[0] 1169 | if not l2id in used_labels: 1170 | mapping[l1id] = l2id 1171 | used_labels.add(l2id) 1172 | elif 1 < len(l2ids): # one-to-many mapping: store relationship for later processing 1173 | one_to_many.append((l1id, set(l2ids))) 1174 | 1175 | # process one-to-many mappings, always choosing the one with the least labelmap2 correspondences first 1176 | while True: 1177 | one_to_many = [(l1id, l2ids - used_labels) for l1id, l2ids in one_to_many] # remove already used ids from all sets 1178 | one_to_many = [x for x in one_to_many if x[1]] # remove empty sets 1179 | one_to_many = sorted(one_to_many, key=lambda x: len(x[1])) # sort by set length 1180 | if 0 == len(one_to_many): 1181 | break 1182 | l2id = one_to_many[0][1].pop() # select an arbitrary target label id from the shortest set 1183 | mapping[one_to_many[0][0]] = l2id # add to one-to-one mappings 1184 | used_labels.add(l2id) # mark target label as used 1185 | one_to_many = one_to_many[1:] # delete the processed set from all sets 1186 | 1187 | return labelmap1, labelmap2, n_obj_result, n_obj_reference, mapping 1188 | 1189 | def __surface_distances(result, reference, voxelspacing=None, connectivity=1): 1190 | """ 1191 | The distances between the surface voxel of binary objects in result and their 1192 | nearest partner surface voxel of a binary object in reference. 
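# Editorial sketch, not part of the original module: the directed surface distances described above,
# recomputed step by step for two hypothetical nested squares (border = mask XOR its erosion, then a
# Euclidean distance transform towards the other border, mirroring the function body below minus the
# voxelspacing handling and emptiness checks).
import numpy
from scipy.ndimage import binary_erosion, distance_transform_edt, generate_binary_structure

result = numpy.zeros((7, 7), dtype=bool)
reference = numpy.zeros((7, 7), dtype=bool)
result[2:5, 2:5] = True       # 3x3 block
reference[1:6, 1:6] = True    # 5x5 block surrounding it

footprint = generate_binary_structure(2, 1)
result_border = result ^ binary_erosion(result, structure=footprint)
reference_border = reference ^ binary_erosion(reference, structure=footprint)

# distance of every result border voxel to the nearest reference border voxel
sds = distance_transform_edt(~reference_border)[result_border]
print(sds.max(), sds.mean())   # 1.0 1.0 -> every result border voxel is one voxel from the reference border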
1193 | """ 1194 | result = numpy.atleast_1d(result.astype(numpy.bool)) 1195 | reference = numpy.atleast_1d(reference.astype(numpy.bool)) 1196 | if voxelspacing is not None: 1197 | voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) 1198 | voxelspacing = numpy.asarray(voxelspacing, dtype=numpy.float64) 1199 | if not voxelspacing.flags.contiguous: 1200 | voxelspacing = voxelspacing.copy() 1201 | 1202 | # binary structure 1203 | footprint = generate_binary_structure(result.ndim, connectivity) 1204 | 1205 | # test for emptiness 1206 | if 0 == numpy.count_nonzero(result): 1207 | raise RuntimeError('The first supplied array does not contain any binary object.') 1208 | if 0 == numpy.count_nonzero(reference): 1209 | raise RuntimeError('The second supplied array does not contain any binary object.') 1210 | 1211 | # extract only 1-pixel border line of objects 1212 | result_border = result ^ binary_erosion(result, structure=footprint, iterations=1) 1213 | reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1) 1214 | 1215 | # compute the one-directional surface distances 1216 | # Note: scipy's distance transform gives, for every non-zero voxel, the distance to the nearest 1217 | # zero voxel, therefore the reference border has to be inverted before the transform 1218 | dt = distance_transform_edt(~reference_border, sampling=voxelspacing) 1219 | sds = dt[result_border] 1220 | 1221 | return sds 1222 | 1223 | def __combine_windows(w1, w2): 1224 | """ 1225 | Joins two windows (defined by tuple of slices) such that their maximum 1226 | combined extent is covered by the new returned window. 1227 | """ 1228 | res = [] 1229 | for s1, s2 in zip(w1, w2): 1230 | res.append(slice(min(s1.start, s2.start), max(s1.stop, s2.stop))) 1231 | return tuple(res) --------------------------------------------------------------------------------
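# Editorial sketch, not part of the original module: a hypothetical end-to-end use of several of the
# metrics defined in this file on one toy case. All names, shapes and spacings below are illustrative,
# and the import path is an assumption.
import numpy

from utils.medpy_hausdorff import assd, hd, hd95, ravd, recall, specificity  # assumed import path

reference = numpy.zeros((32, 32, 16), dtype=bool)
reference[10:20, 10:20, 4:12] = True                    # toy ground-truth "lesion"
prediction = numpy.roll(reference, shift=2, axis=0)     # slightly shifted prediction

spacing = (1.0, 1.0, 2.0)                               # hypothetical voxel spacing in mm
case_metrics = {
    'HD': hd(prediction, reference, voxelspacing=spacing),
    'HD95': hd95(prediction, reference, voxelspacing=spacing),
    'ASSD': assd(prediction, reference, voxelspacing=spacing),
    'RAVD': ravd(prediction, reference),
    'Recall': recall(prediction, reference),
    'Specificity': specificity(prediction, reference),
}
for name, value in sorted(case_metrics.items()):
    print('{0}: {1:.3f}'.format(name, value))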