├── .gitignore ├── LICENSE ├── README.md ├── SGE_scripts └── run_on_host.sh ├── config ├── __init__.py └── system.py ├── data ├── __init__.py ├── batch_provider.py ├── data_switch.py ├── lidc_data.py └── lidc_data_loader.py ├── eval_dice_plot.py ├── eval_ged_plot.py ├── eval_ncc_plot.py ├── figures ├── graphical_model.png ├── gt_id165.gif └── samples_id165.gif ├── phiseg ├── __init__.py ├── experiments │ ├── __init__.py │ ├── detunet.py │ ├── phiseg_7_1.py │ ├── phiseg_7_1_1annot.py │ ├── phiseg_7_5.py │ ├── phiseg_7_5_1annot.py │ ├── probunet.py │ └── probunet_1annot.py ├── model_zoo │ ├── __init__.py │ ├── likelihoods.py │ ├── posteriors.py │ └── priors.py └── phiseg_model.py ├── phiseg_generate_samples.py ├── phiseg_makegif_samples.py ├── phiseg_sample_construction.py ├── phiseg_test_predictions.py ├── phiseg_test_quantitative.py ├── phiseg_train.py ├── requirements.txt ├── tfwrapper ├── __init__.py ├── activations.py ├── layers.py ├── losses.py ├── normalisation.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | .idea 4 | *.pyc 5 | *.swp 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PHiSeg Code 2 | 3 | Public tensorflow implementation for our paper [PHiSeg: Capturing Uncertainty in 4 | Medical Image Segmentation](https://arxiv.org/abs/1906.04045) method, 5 | which was accepted for presentation at [MICCAI 2019](https://www.miccai2019.org/). 
6 | 
7 | If you find this code helpful in your research, please cite the following paper:
8 | 
9 | ```
10 | @article{PHiSeg2019Baumgartner,
11 | author={Baumgartner, Christian F. and Tezcan, Kerem C. and
12 | Chaitanya, Krishna and H{\"o}tker, Andreas M. and
13 | Muehlematter, Urs J. and Schawkat, Khoschy and Becker, Anton S. and
14 | Donati, Olivio and Konukoglu, Ender},
15 | title={{PHiSeg}: Capturing Uncertainty in Medical Image Segmentation},
16 | journal={arXiv:1906.04045},
17 | year={2019},
18 | }
19 | ```
20 | 
21 | ## Method overview
22 | 
23 | Many medical image segmentation tasks are inherently ambiguous. For example, if six radiologists are asked
24 | to segment two regions of the prostate in an MR image, you will get six different answers:
25 | 
26 | ![Method overview](figures/gt_id165.gif)
27 | 
28 | We address this problem by developing a hierarchical probabilistic model that - unlike
29 | most conventional segmentation techniques - does not produce a single segmentation,
30 | but rather produces **samples from the distribution of probable segmentations for a specific image**.
31 | 
32 | Here is an example of our method's output for the same test image.
33 | 
34 | ![Method overview](figures/samples_id165.gif)
35 | 
36 | It can be seen that the samples are very similar to those generated by the six experts above.
37 | 
38 | Having access to such samples allows us to give a human user of this tool several options to choose from.
39 | The model can also be used to determine the most likely sample, and to **visualize
40 | areas of high uncertainty**.
41 | 
42 | The method works by constructing a hierarchical probabilistic model that assumes
43 | a generative process for the segmentation s, given the image x, in which the segmentation is constructed
44 | one resolution level at a time (similar to Laplacian pyramids). The generation of each resolution
45 | level is assumed to be governed by a hidden, low-dimensional variable z_l. Here is an image of the graphical model (right)
46 | along with an example of the generative process (left).
47 | 
48 | ![Method overview](figures/graphical_model.png)
49 | 
50 | In our [paper](https://arxiv.org/abs/1906.04045), we show that inference in this probabilistic model can be performed
51 | using a variation of the well-known autoencoding variational Bayes framework.
52 | 
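In simplified notation (the paper gives the exact factorisation, the form of the individual distributions, and the weighting of the terms), the hierarchical model and the resulting training objective look roughly as follows, with L latent levels:

```latex
% Sketch of the hierarchical generative model: the segmentation s is
% generated from latents z_1, ..., z_L, each conditioned on the
% next-coarser level and on the image x.
p(s \mid x) = \int p(s \mid z_1, \ldots, z_L)\, p(z_L \mid x)
              \prod_{l=1}^{L-1} p(z_l \mid z_{l+1}, x) \; dz_1 \cdots dz_L

% Training maximises an evidence lower bound with one KL term per level,
% using a variational posterior q that additionally conditions on s:
\mathcal{L} = \mathbb{E}_{q}\big[\log p(s \mid z_1, \ldots, z_L)\big]
            - \sum_{l=1}^{L} \mathrm{KL}\big[\, q(z_l \mid z_{l+1}, x, s) \,\big\|\, p(z_l \mid z_{l+1}, x) \,\big]
```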
53 | In the paper, we show that we outperform the related probabilistic U-NET on two relevant metrics. We furthermore
54 | show that **taking the probabilistic perspective on the segmentation problem comes at no cost for
55 | segmentation accuracy!** Our method performs just as well (in fact slightly better) than a deterministic
56 | U-NET on the conventional Dice Score metric.
57 | 
58 | ## Virtual Environment Setup
59 | 
60 | The code is implemented in Python 3.5 using the tensorflow library. We only tested the code
61 | with tensorflow 1.12. One way of getting all the requirements is using virtualenv and the `requirements.txt` file.
62 | 
63 | * Set up a virtual environment (e.g. conda or virtualenv) with Python 3.5
64 | * Install all non-tensorflow requirements using:
65 | 
66 | ````pip install -r requirements.txt````
67 | 
68 | * Install the GPU version of tensorflow using
69 | 
70 | ````pip install tensorflow-gpu==1.12````
71 | 
72 | ## Running the code
73 | 
74 | Before running the code, a number of parameters must be configured:
75 | 
76 | * Open `config/system.py` and change the settings to match your system.
77 | 
78 | * Download the LIDC data from [Stefan Knegt's GitHub page](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch).
79 | 
80 | * In the experiment files under `phiseg/experiments` adapt the paths for the source data and
81 | the target path for the preprocessed data.
82 | 
83 | Then the code can be run as follows:
84 | 
85 | * Start training a model by running `phiseg_train.py` with the corresponding experiment file. For example:
86 | 
87 | ```python phiseg_train.py phiseg/experiments/phiseg_7_5.py```
88 | 
89 | * The easiest way to monitor training is using tensorboard.
90 | 
91 | * After training has finished, you can use `phiseg_generate_samples.py` to sample from
92 | the learned distribution, and `phiseg_test_quantitative.py` to obtain the quantitative results
93 | reported in the paper.
94 | 
95 | ## Data
96 | 
97 | The public implementation of our code currently trains and evaluates on the publicly available
98 | [LIDC-IDRI lung lesion dataset](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI). We used the
99 | preprocessed data available on [Stefan Knegt's GitHub page](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch)
100 | (see the link at the very bottom of that page).
101 | 
102 | The prostate dataset collected at University Hospital Zurich is unfortunately not yet
103 | publicly available.
104 | 
105 | ## Code structure
106 | 
107 | Under `phiseg/experiments` you will find a number of experiment config files, where the
108 | architecture and training details of the method are specified. Modify these files to explore different settings.
109 | 
110 | ## Using the code with your own data
111 | 
112 | If you want to use your own data, you need to create an appropriate data loader (see `data/lidc_data_loader.py`)
113 | and a data provider (see `data/lidc_data.py`). You also need to add
114 | the dataset to `data/data_switch.py`. A rough sketch of these pieces is shown below.
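As an outline (all names here are placeholders, not part of the repository), a hypothetical new dataset `mydata` would consist of a provider class mirroring `data/lidc_data.py`:

```python
# data/mydata_data.py -- hypothetical skeleton mirroring data/lidc_data.py
import numpy as np
import h5py

from data.batch_provider import BatchProvider


class mydata_data():

    def __init__(self, exp_config):

        # Assumption: exp_config.data_root points to an HDF5 file with
        # 'train'/'test'/'val' groups, each containing 'images' and 'labels'.
        data = h5py.File(exp_config.data_root, 'r')
        self.data = data

        # One index array per subset, used by the BatchProviders below
        indices = {tt: np.arange(data[tt]['images'].shape[0]) for tt in data}

        self.train = BatchProvider(data['train']['images'], data['train']['labels'], indices['train'],
                                   add_dummy_dimension=True,
                                   do_augmentations=True,
                                   augmentation_options=exp_config.augmentation_options,
                                   num_labels_per_subject=exp_config.num_labels_per_subject,
                                   annotator_range=exp_config.annotator_range)
        self.validation = BatchProvider(data['val']['images'], data['val']['labels'], indices['val'],
                                        add_dummy_dimension=True,
                                        num_labels_per_subject=exp_config.num_labels_per_subject,
                                        annotator_range=exp_config.annotator_range)
        self.test = BatchProvider(data['test']['images'], data['test']['labels'], indices['test'],
                                  add_dummy_dimension=True,
                                  num_labels_per_subject=exp_config.num_labels_per_subject,
                                  annotator_range=exp_config.annotator_range)
```

plus a corresponding branch in `data/data_switch.py`:

```python
elif data_identifier == 'mydata':
    from data.mydata_data import mydata_data as data_loader
```

The experiment config would then set `data_identifier = 'mydata'` and point `data_root` at your data.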
--------------------------------------------------------------------------------
/SGE_scripts/run_on_host.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Script to send job to BIWI clusters using qsub.
4 | # Usage: qsub run_on_host.sh PATH-TO-PYTHON-SCRIPT [SCRIPT-ARGUMENT]
5 | 
6 | # The script also requires changing the paths of the CUDA and python environments
7 | # and of the code to their local equivalents on your machine.
8 | # Author: Christian F. Baumgartner (c.f.baumgartner@gmail.com)
9 | 
10 | ## SET THE FOLLOWING VARIABLES ACCORDING TO YOUR SYSTEM ##
11 | CUDA_HOME=/scratch_net/bmicdl03/libs/cuda-9.0
12 | PROJECT_HOME=/scratch_net/bmicdl03/code/python/phiseg_public/
13 | VIRTUAL_ENV_PATH=/scratch_net/bmicdl03/code/python/environments/tensorflow1.12-gpu/
14 | 
15 | ## SGE Variables:
16 | #
17 | ## otherwise the default shell would be used
18 | #$ -S /bin/bash
19 | #
20 | ## <= 2h is short queue, <= 24h is middle queue, <= 120h is long queue
21 | #$ -l h_rt=48:00:00
22 | 
23 | ## the maximum memory usage of this job (below 4G does not make much sense)
24 | #$ -l h_vmem=40G # Less RAM is required for evaluating than for training
25 | 
26 | # Host and gpu settings
27 | #$ -l gpu
28 | ##$ -l hostname=bmicgpu04 ## <-------------- Comment in or out to force a specific machine
29 | 
30 | ## stderr and stdout are merged together to stdout
31 | #$ -j y
32 | #
33 | # logging directory, preferably on your scratch
34 | #$ -o /scratch_net/bmicdl03/logs/phiseg/ ## <---------------- CHANGE TO MATCH YOUR SYSTEM
35 | #
36 | ## send mail on job abort
37 | #$ -m a
38 | 
39 | ## LOCAL PATHS
40 | # Note: .bashrc is not executed on the remote host when using qsub, so all paths
41 | # and environment variables must be set before executing the python code.
42 | 
43 | # cuda paths
44 | export PATH=$CUDA_HOME/bin:$PATH
45 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
46 | 
47 | # for pyenv
48 | export PATH="/home/baumgach/.pyenv/bin:$PATH"
49 | eval "$(pyenv init -)"
50 | eval "$(pyenv virtualenv-init -)"
51 | 
52 | # activate virtual environment
53 | source $VIRTUAL_ENV_PATH/bin/activate
54 | 
55 | ## EXECUTION OF PYTHON CODE:
56 | python $PROJECT_HOME/$1 $2
57 | 
58 | echo "Hostname was: `hostname`"
59 | echo "Reached end of job file."
60 | 
61 | 
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/config/__init__.py
--------------------------------------------------------------------------------
/config/system.py:
--------------------------------------------------------------------------------
1 | # Authors:
2 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com)
3 | 
4 | import os
5 | import socket
6 | import logging
7 | 
8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
9 | 
10 | ### SET THESE PATHS MANUALLY #####################################################
11 | # Full paths are required because otherwise the code will not know where to look
12 | # when it is executed on one of the clusters.
13 | 
14 | at_biwi = True  # Are you running this code from the ETH Computer Vision Lab (Biwi)?
15 | 
16 | project_root = '/scratch_net/bmicdl03/code/python/phiseg_public'
17 | local_hostnames = ['bmicdl03']  # used to check if on cluster or not
18 | log_root = '/itet-stor/baumgach/net_scratch/logs/phiseg_public'
19 | 
20 | ##################################################################################
21 | 
22 | running_on_gpu_host = socket.gethostname() not in local_hostnames
23 | 
24 | 
25 | def setup_GPU_environment():
26 | 
27 |     if at_biwi:
28 | 
29 |         hostname = socket.gethostname()
30 |         print('Running on %s' % hostname)
31 |         if hostname not in local_hostnames:
32 |             logging.info('Setting CUDA_VISIBLE_DEVICES variable...')
33 | 
34 |             # This command is multi GPU compatible:
35 |             os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(os.environ["SGE_GPU"].split('\n'))
36 |             logging.info('SGE_GPU is %s' % os.environ['SGE_GPU'])
37 |             logging.info('CUDA_VISIBLE_DEVICES is %s' % os.environ['CUDA_VISIBLE_DEVICES'])
38 | 
39 |     else:
40 |         logging.warning('!! No GPU setup defined. Perhaps you need to set CUDA_VISIBLE_DEVICES etc...?')
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/data/__init__.py
--------------------------------------------------------------------------------
/data/batch_provider.py:
--------------------------------------------------------------------------------
1 | # Authors:
2 | # Christian F.
Baumgartner (c.f.baumgartner@gmail.com) 3 | 4 | import numpy as np 5 | 6 | from scipy.ndimage import zoom 7 | import utils 8 | 9 | import logging 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 11 | 12 | def resize_batch(imgs, target_size): 13 | 14 | sx = imgs.shape[1] 15 | sy = imgs.shape[2] 16 | return zoom(imgs, (1,float(target_size[0])/sx,float(target_size[1]/sy),1), order=0) 17 | 18 | class BatchProvider(): 19 | """ 20 | This is a helper class to conveniently access mini batches of training, testing and validation data 21 | """ 22 | 23 | def __init__(self, X, y, indices, add_dummy_dimension=False, **kwargs): # indices don't always cover all of X and Y (e.g. in the case of val set) 24 | 25 | self.X = X 26 | self.y = y 27 | self.indices = indices 28 | self.unused_indices = indices.copy() 29 | self.add_dummy_dimension = add_dummy_dimension 30 | 31 | self.num_labels_per_subject = kwargs.get('num_labels_per_subject', 1) 32 | if self.num_labels_per_subject > 1: 33 | self.annotator_range = kwargs.get('annotator_range', range(self.num_labels_per_subject)) 34 | 35 | self.resize_to = kwargs.get('resize_to', None) 36 | 37 | self.do_augmentations = kwargs.get('do_augmentations', False) 38 | self.augmentation_options = kwargs.get('augmentation_options', None) 39 | self.rescale_range = kwargs.get('rescale_range', None) 40 | self.rescale_rgb = kwargs.get('rescale_rgb', None) 41 | self.normalise_images = True if not self.rescale_range else False # normalise if not rescale 42 | 43 | def next_batch(self, batch_size): 44 | """ 45 | Get a single random batch. This implements sampling without replacement (not just on a batch level), this means 46 | all the data gets sampled eventually. 47 | """ 48 | 49 | if len(self.unused_indices) < batch_size: 50 | self.unused_indices = self.indices 51 | 52 | batch_indices = np.random.choice(self.unused_indices, batch_size, replace=False) 53 | self.unused_indices = np.setdiff1d(self.unused_indices, batch_indices) 54 | 55 | # HDF5 requires indices to be in increasing order 56 | batch_indices = np.sort(batch_indices) 57 | 58 | X_batch = self.X[batch_indices, ...] 59 | y_batch = self.y[batch_indices, ...] 60 | 61 | 62 | if self.num_labels_per_subject > 1: 63 | y_batch = self._select_random_label(y_batch, self.annotator_range) 64 | 65 | X_batch, y_batch = self._post_process_batch(X_batch, y_batch) 66 | 67 | return X_batch, y_batch 68 | 69 | def iterate_batches(self, batch_size, shuffle=True): 70 | """ 71 | Get a range of batches. Use as argument of a for loop like you would normally use 72 | the range() function. 73 | """ 74 | 75 | if shuffle: 76 | np.random.shuffle(self.indices) 77 | 78 | N = self.indices.shape[0] 79 | 80 | for b_i in range(0, N, batch_size): 81 | 82 | # if b_i + batch_size > N: 83 | # continue 84 | 85 | # HDF5 requires indices to be in increasing order 86 | batch_indices = np.sort(self.indices[b_i:b_i + batch_size]) 87 | 88 | X_batch = self.X[batch_indices, ...] 89 | y_batch = self.y[batch_indices, ...] 
90 | 
91 |             if self.num_labels_per_subject > 1:
92 |                 y_batch = self._select_random_label(y_batch, self.annotator_range)
93 | 
94 |             X_batch, y_batch = self._post_process_batch(X_batch, y_batch)
95 | 
96 |             yield X_batch, y_batch
97 | 
98 | 
99 |     def _post_process_batch(self, X_batch, y_batch):
100 | 
101 |         if self.resize_to:
102 |             X_batch = resize_batch(X_batch, self.resize_to)
103 |             y_batch = resize_batch(y_batch, self.resize_to) if y_batch.ndim > 1 else y_batch
104 | 
105 |         # logging.info('@@@ Shape start')
106 |         # logging.info(X_batch.shape)
107 |         # logging.info(y_batch.shape)
108 | 
109 | 
110 |         if self.do_augmentations:
111 |             X_batch, y_batch = self._augmentation_function(X_batch, y_batch)
112 | 
113 |         # logging.info('@@@ Shape after aug')
114 |         # logging.info(X_batch.shape)
115 |         # logging.info(y_batch.shape)
116 | 
117 |         if self.normalise_images:
118 |             X_batch = utils.normalise_images(np.float32(X_batch))  # keep the returned, normalised batch
119 | 
120 |         if self.rescale_rgb:
121 |             X_batch = X_batch.astype(np.float32) / 127.5 - 1
122 | 
123 |         if self.rescale_range is not None:
124 |             X_batch = utils.map_images_to_intensity_range(np.float32(X_batch), self.rescale_range[0], self.rescale_range[1], percentiles=0.0)
125 | 
126 |         if self.add_dummy_dimension:
127 |             X_batch = np.expand_dims(X_batch, axis=-1)
128 | 
129 |         return X_batch, y_batch
130 | 
131 |     def _select_random_label(self, labels, annotator_range):
132 | 
133 |         y_tmp_list = []
134 |         for ii in range(labels.shape[0]):
135 |             # print('random annotator: %d' % np.random.choice(annotator_range))
136 |             y_tmp_list.append(labels[ii, ..., np.random.choice(annotator_range)])
137 |         return np.asarray(y_tmp_list)
138 | 
139 | 
140 |     def _augmentation_function(self, images, labels):
141 |         '''
142 |         Function for augmentation of minibatches. It will transform a set of images and corresponding labels
143 |         by a number of optional transformations. Each image/mask pair in the minibatch will be separately transformed
144 |         with random parameters.
145 |         :param images: A numpy array of shape [minibatch, X, Y, (Z), nchannels]
146 |         :param labels: A numpy array containing a corresponding label mask
147 |         :param do_rotations: Rotate the input images by a random angle (by default between -10 and 10 degrees).
148 |         :param do_scaleaug: Do scale augmentation by sampling one length of a square, then cropping and upsampling the image
149 |         back to the original size.
150 |         :param do_fliplr: Perform random flips with a 50% chance in the left-right direction (do_flipud does the same in the up-down direction).
151 |         :return: A mini batch of the same size but with transformed images and masks.
152 |         '''
153 | 
154 |         def get_option(name, default):
155 |             return self.augmentation_options[name] if name in self.augmentation_options else default
156 | 
157 |         try:
158 |             import cv2
159 |         except ImportError:
160 |             raise ImportError('Data augmentation requires OpenCV (cv2) to be installed.')
161 |         else:
162 | 
163 |             if images.ndim > 4:
164 |                 raise AssertionError('Augmentation will only work with 2D images')
165 | 
166 |             # If segmentation labels are provided, augment them as well
167 |             augment_labels = labels.ndim > 1
168 | 
169 |             do_rotations = get_option('do_rotations', False)
170 |             do_scaleaug = get_option('do_scaleaug', False)
171 |             do_fliplr = get_option('do_fliplr', get_option('do_flip_lr', False))  # the experiment configs use the 'do_flip_lr' spelling
172 |             do_flipud = get_option('do_flipud', get_option('do_flip_ud', False))  # likewise 'do_flip_ud'
173 |             do_elasticaug = get_option('do_elasticaug', False)
174 |             augment_every_nth = get_option('augment_every_nth', 2)  # 2 means augment half of the images
175 |                                                                     # 1 means augment every image
176 | 
177 |             if do_rotations or do_scaleaug or do_elasticaug:
178 |                 nlabels = get_option('nlabels', None)
179 |                 if not nlabels:
180 |                     raise AssertionError("When doing augmentations with rotations, scaling, or elastic transformations "
181 |                                          "the parameter 'nlabels' must be provided.")
182 | 
183 | 
184 |             new_images = []
185 |             new_labels = []
186 |             num_images = images.shape[0]
187 | 
188 |             for ii in range(num_images):
189 | 
190 |                 img = np.squeeze(images[ii, ...])
191 |                 lbl = np.squeeze(labels[ii, ...])
192 | 
193 |                 coin_flip = np.random.randint(augment_every_nth)
194 |                 if coin_flip == 0:
195 | 
196 |                     # ROTATE
197 |                     if do_rotations:
198 | 
199 |                         angles = get_option('rot_degrees', 10.0)
200 |                         random_angle = np.random.uniform(-angles, angles)
201 |                         img = utils.rotate_image(img, random_angle)
202 | 
203 |                         if augment_labels:
204 |                             if nlabels <= 4:
205 |                                 lbl = utils.rotate_image_as_onehot(lbl, random_angle, nlabels=nlabels)
206 |                             else:
207 |                                 # If there are more than 4 labels OpenCV can no longer handle one-hot interpolation
208 |                                 lbl = utils.rotate_image(lbl, random_angle, interp=cv2.INTER_NEAREST)
209 | 
210 |                     # RANDOM CROP SCALE
211 |                     if do_scaleaug:
212 | 
213 |                         offset = get_option('offset', 30)
214 |                         n_x, n_y = img.shape
215 |                         r_y = np.random.randint(n_y - offset, n_y + 1)  # randint's upper bound is exclusive (the deprecated random_integers was inclusive)
216 |                         p_x = np.random.randint(0, n_x - r_y + 1)
217 |                         p_y = np.random.randint(0, n_y - r_y + 1)
218 | 
219 |                         img = utils.resize_image(img[p_y:(p_y + r_y), p_x:(p_x + r_y)], (n_x, n_y))
220 |                         if augment_labels:
221 |                             if nlabels <= 4:
222 |                                 lbl = utils.resize_image_as_onehot(lbl[p_y:(p_y + r_y), p_x:(p_x + r_y)], (n_x, n_y), nlabels=nlabels)
223 |                             else:
224 |                                 lbl = utils.resize_image(lbl[p_y:(p_y + r_y), p_x:(p_x + r_y)], (n_x, n_y), interp=cv2.INTER_NEAREST)
225 | 
226 |                     # RANDOM ELASTIC DEFORMATIONS (like in U-NET)
227 |                     if do_elasticaug:
228 | 
229 |                         mu = 0
230 |                         sigma = 10
231 |                         n_x, n_y = img.shape
232 | 
233 |                         dx = np.random.normal(mu, sigma, 9)
234 |                         dx_mat = np.reshape(dx, (3, 3))
235 |                         dx_img = utils.resize_image(dx_mat, (n_x, n_y), interp=cv2.INTER_CUBIC)
236 | 
237 |                         dy = np.random.normal(mu, sigma, 9)
238 |                         dy_mat = np.reshape(dy, (3, 3))
239 |                         dy_img = utils.resize_image(dy_mat, (n_x, n_y), interp=cv2.INTER_CUBIC)
240 | 
241 |                         img = utils.dense_image_warp(img, dx_img, dy_img)
242 | 
243 |                         if augment_labels:
244 | 
245 |                             if nlabels <= 4:
246 |                                 lbl = utils.dense_image_warp_as_onehot(lbl, dx_img, dy_img, nlabels=nlabels)
247 |                             else:
248 |                                 lbl = utils.dense_image_warp(lbl, dx_img, dy_img, interp=cv2.INTER_NEAREST, do_optimisation=False)
249 | 
250 | 
251 |                 # RANDOM FLIP
252 |                 if do_fliplr:
253 |                     coin_flip = np.random.randint(max(2, augment_every_nth))  # Flipping wouldn't make sense if you do it always
254 |                     if coin_flip == 0:
255 |                         img = np.fliplr(img)
256 |                         if augment_labels:
257 |                             lbl = np.fliplr(lbl)
258 | 
259 |                 if do_flipud:
260 |                     coin_flip = np.random.randint(max(2, augment_every_nth))
261 |                     if coin_flip == 0:
262 |                         img = np.flipud(img)
263 |                         if augment_labels:
264 |                             lbl = np.flipud(lbl)
265 | 
266 |                 new_images.append(img[...])
267 |                 new_labels.append(lbl[...])
268 | 
269 |             sampled_image_batch = np.asarray(new_images)
270 |             sampled_label_batch = np.asarray(new_labels)
271 | 
272 |             return sampled_image_batch, sampled_label_batch
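For orientation, this is how the class above is typically driven; a minimal sketch (illustrative only, not part of the repository), where `data` stands for one of the dataset wrappers defined below, e.g. `lidc_data(exp_config)`:

```python
# Minimal BatchProvider usage sketch (illustrative, not part of the repository).
# `data` is assumed to be a dataset wrapper such as lidc_data(exp_config), which
# exposes BatchProvider instances as data.train, data.validation and data.test.

num_training_steps = 100  # example value

# Random batches for training; sampling is without replacement until all
# indices have been used, after which the index pool is refilled.
for step in range(num_training_steps):
    x_batch, y_batch = data.train.next_batch(batch_size=12)
    # ... run one training step on (x_batch, y_batch) ...

# One ordered pass over the validation set, e.g. for computing metrics:
for x_batch, y_batch in data.validation.iterate_batches(batch_size=12, shuffle=False):
    pass  # ... accumulate validation metrics ...
```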
--------------------------------------------------------------------------------
/data/data_switch.py:
--------------------------------------------------------------------------------
1 | 
2 | def data_switch(data_identifier):
3 | 
4 |     if data_identifier == 'acdc':
5 |         from data.acdc_data import acdc_data as data_loader
6 |     elif data_identifier == 'lidc':
7 |         from data.lidc_data import lidc_data as data_loader
8 |     elif data_identifier == 'uzh_prostate':
9 |         from data.uzh_prostate_data import uzh_prostate_data as data_loader
10 |     else:
11 |         raise ValueError('Unknown data identifier: %s' % data_identifier)
12 | 
13 |     return data_loader
--------------------------------------------------------------------------------
/data/lidc_data.py:
--------------------------------------------------------------------------------
1 | # Authors:
2 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com)
3 | 
4 | import numpy as np
5 | from data import lidc_data_loader
6 | from data.batch_provider import BatchProvider
7 | 
8 | class lidc_data():
9 | 
10 |     def __init__(self, exp_config):
11 | 
12 |         data = lidc_data_loader.load_and_maybe_process_data(
13 |             input_file=exp_config.data_root,
14 |             preprocessing_folder=exp_config.preproc_folder,
15 |             force_overwrite=False,
16 |         )
17 | 
18 |         self.data = data
19 | 
20 |         # Extract the number of training and testing points
21 |         indices = {}
22 |         for tt in data:
23 |             N = data[tt]['images'].shape[0]
24 |             indices[tt] = np.arange(N)
25 | 
26 |         # Create the batch providers
27 |         augmentation_options = exp_config.augmentation_options
28 | 
29 |         # Backwards compatibility, TODO remove for final version
30 |         if not hasattr(exp_config, 'annotator_range'):
31 |             exp_config.annotator_range = range(exp_config.num_labels_per_subject)
32 | 
33 |         self.train = BatchProvider(data['train']['images'], data['train']['labels'], indices['train'],
34 |                                    add_dummy_dimension=True,
35 |                                    do_augmentations=True,
36 |                                    augmentation_options=augmentation_options,
37 |                                    num_labels_per_subject=exp_config.num_labels_per_subject,
38 |                                    annotator_range=exp_config.annotator_range)
39 |         self.validation = BatchProvider(data['val']['images'], data['val']['labels'], indices['val'],
40 |                                         add_dummy_dimension=True,
41 |                                         num_labels_per_subject=exp_config.num_labels_per_subject,
42 |                                         annotator_range=exp_config.annotator_range)
43 |         self.test = BatchProvider(data['test']['images'], data['test']['labels'], indices['test'],
44 |                                   add_dummy_dimension=True,
45 |                                   num_labels_per_subject=exp_config.num_labels_per_subject,
46 |                                   annotator_range=exp_config.annotator_range)
47 | 
48 |         self.test.images = data['test']['images']
49 |         self.test.labels = data['test']['labels']
50 | 
51 |         self.validation.images = data['val']['images']
52 |         self.validation.labels = data['val']['labels']
53 | 
54 | 
55 | if __name__ == '__main__':
56 | 
57 |     # If the program is called as main, perform some debugging operations
58 |     from phiseg.experiments import phiseg_7_5 as exp_config
59 |     data = lidc_data(exp_config)
60
| 61 | print(data.validation.images.shape) 62 | 63 | print(data.data['val']['images'].shape[0]) 64 | print(data.data['test']['images'].shape[0]) 65 | print(data.data['train']['images'].shape[0]) 66 | print(data.data['train']['images'].shape[0]+data.data['test']['images'].shape[0]+data.data['val']['images'].shape[0]) 67 | 68 | print('DEBUGGING OUTPUT') 69 | print('training') 70 | for ii in range(2): 71 | X_tr, Y_tr = data.train.next_batch(10) 72 | print(np.mean(X_tr)) 73 | print(Y_tr.shape) 74 | print('--') 75 | 76 | print('test') 77 | for ii in range(2): 78 | X_te, Y_te = data.test.next_batch(10) 79 | print(np.mean(X_te)) 80 | print(Y_te.shape) 81 | print('--') 82 | 83 | print('validation') 84 | for ii in range(2): 85 | X_va, Y_va = data.validation.next_batch(10) 86 | print(np.mean(X_va)) 87 | print(Y_va.shape) 88 | print('--') 89 | -------------------------------------------------------------------------------- /data/lidc_data_loader.py: -------------------------------------------------------------------------------- 1 | # Authors: 2 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com) 3 | # Lisa M. Koch (lisa.margret.koch@gmail.com) 4 | 5 | import os 6 | import numpy as np 7 | import logging 8 | import h5py 9 | import pickle 10 | from sklearn.model_selection import train_test_split 11 | 12 | import utils 13 | 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 15 | 16 | def crop_or_pad_slice_to_size(slice, nx, ny): 17 | x, y = slice.shape 18 | 19 | x_s = (x - nx) // 2 20 | y_s = (y - ny) // 2 21 | x_c = (nx - x) // 2 22 | y_c = (ny - y) // 2 23 | 24 | if x > nx and y > ny: 25 | slice_cropped = slice[x_s:x_s + nx, y_s:y_s + ny] 26 | else: 27 | slice_cropped = np.zeros((nx, ny)) 28 | if x <= nx and y > ny: 29 | slice_cropped[x_c:x_c + x, :] = slice[:, y_s:y_s + ny] 30 | elif x > nx and y <= ny: 31 | slice_cropped[:, y_c:y_c + y] = slice[x_s:x_s + nx, :] 32 | else: 33 | slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice[:, :] 34 | 35 | return slice_cropped 36 | 37 | 38 | def find_subset_for_id(ids_dict, id): 39 | 40 | for tt in ['test', 'train', 'val']: 41 | if id in ids_dict[tt]: 42 | return tt 43 | raise ValueError('id was not found in any of the train/test/val subsets.') 44 | 45 | 46 | def prepare_data(input_file, output_file): 47 | ''' 48 | Main function that prepares a dataset from the raw challenge data to an hdf5 dataset 49 | ''' 50 | 51 | hdf5_file = h5py.File(output_file, "w") 52 | max_bytes = 2 ** 31 - 1 53 | 54 | data = {} 55 | file_path = os.fsdecode(input_file) 56 | bytes_in = bytearray(0) 57 | input_size = os.path.getsize(file_path) 58 | with open(file_path, 'rb') as f_in: 59 | for _ in range(0, input_size, max_bytes): 60 | bytes_in += f_in.read(max_bytes) 61 | new_data = pickle.loads(bytes_in) 62 | data.update(new_data) 63 | 64 | series_uid = [] 65 | 66 | for key, value in data.items(): 67 | series_uid.append(value['series_uid']) 68 | 69 | unique_subjects = np.unique(series_uid) 70 | 71 | split_ids = {} 72 | train_and_val_ids, split_ids['test'] = train_test_split(unique_subjects, test_size=0.2) 73 | split_ids['train'], split_ids['val'] = train_test_split(train_and_val_ids, test_size=0.2) 74 | 75 | images = {} 76 | labels = {} 77 | uids = {} 78 | groups = {} 79 | 80 | for tt in ['train', 'test', 'val']: 81 | images[tt] = [] 82 | labels[tt] = [] 83 | uids[tt] = [] 84 | groups[tt] = hdf5_file.create_group(tt) 85 | 86 | for key, value in data.items(): 87 | 88 | s_id = value['series_uid'] 89 | 90 | tt = find_subset_for_id(split_ids, s_id) 91 | 92 | 
images[tt].append(value['image'].astype(float)-0.5)
93 | 
94 |         lbl = np.asarray(value['masks'])  # this will be of shape 4 x 128 x 128
95 |         lbl = lbl.transpose((1,2,0))
96 | 
97 |         labels[tt].append(lbl)
98 |         uids[tt].append(hash(s_id))  # Checked manually that there are no collisions
99 | 
100 |     for tt in ['test', 'train', 'val']:
101 | 
102 |         groups[tt].create_dataset('uids', data=np.asarray(uids[tt], dtype=np.int))
103 |         groups[tt].create_dataset('labels', data=np.asarray(labels[tt], dtype=np.uint8))
104 |         groups[tt].create_dataset('images', data=np.asarray(images[tt], dtype=np.float))
105 | 
106 |     hdf5_file.close()
107 | 
108 | 
109 | def load_and_maybe_process_data(input_file,
110 |                                 preprocessing_folder,
111 |                                 force_overwrite=False):
112 |     '''
113 |     This function is used to load and, if necessary, preprocess the LIDC data
114 | 
115 |     :param input_file: Path of the pickle file containing the raw LIDC data
116 |     :param preprocessing_folder: Folder where the preprocessed data should be written to
117 |     :param force_overwrite: Set this to True if you want to overwrite already preprocessed data [default: False]
118 | 
119 |     :return: Returns an h5py.File handle to the dataset
120 |     '''
121 | 
122 |     data_file_name = 'data_lidc.hdf5'
123 | 
124 |     data_file_path = os.path.join(preprocessing_folder, data_file_name)
125 | 
126 |     utils.makefolder(preprocessing_folder)
127 | 
128 |     if not os.path.exists(data_file_path) or force_overwrite:
129 |         logging.info('This configuration has not yet been preprocessed')
130 |         logging.info('Preprocessing now!')
131 |         prepare_data(input_file, data_file_path)
132 |     else:
133 |         logging.info('Already preprocessed this configuration. Loading now!')
134 | 
135 |     return h5py.File(data_file_path, 'r')
136 | 
137 | 
138 | if __name__ == '__main__':
139 | 
140 |     input_file = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle'
141 |     preprocessing_folder = '/srv/glusterfs/baumgach/preproc_data/lidc'
142 | 
143 |     d = load_and_maybe_process_data(input_file, preprocessing_folder, force_overwrite=True)
144 | 
145 | 
--------------------------------------------------------------------------------
/eval_dice_plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import seaborn as sns
5 | import matplotlib.pyplot as plt
6 | from scipy import stats
7 | 
8 | 
9 | # LIDC RERUN
10 | experiment_base_folder = '/itet-stor/baumgach/net_scratch/logs/phiseg/lidc/'
11 | experiment_list = ['detunet',
12 |                    'probunet_1annot',
13 |                    'phiseg_7_1_1annot',
14 |                    'phiseg_7_5_1annot']
15 | experiment_names = ['detUNET', 'ProbUNET_1annot', 'SegVAE_1lvls_1annot', 'SegVAE_5lvls_1annot']
16 | file_list = ['dice_best_dice.npz']*len(experiment_list)
17 | 
18 | 
19 | ged_list = []
20 | 
21 | for folder, exp_name, file in zip(experiment_list, experiment_names, file_list):
22 | 
23 |     experiment_path = os.path.join(experiment_base_folder, folder, file)
24 | 
25 |     ged_arr = np.load(experiment_path)['arr_0']
26 | 
27 |     print(ged_arr.shape)
28 | 
29 |     ged_list.append(np.mean(ged_arr[:,1:], axis=-1))
30 | 
31 | ged_tot_arr = np.asarray(ged_list).T
32 | 
33 | print('significance')
34 | print('REMINDER: are you checking the right methods?')
35 | print(stats.ttest_rel(ged_list[0], ged_list[3]))
36 | 
37 | print('Results summary')
38 | # means = np.median(ged_tot_arr, axis=0)
39 | means = ged_tot_arr.mean(axis=0)
40 | stds = ged_tot_arr.std(axis=0)
41 | 
42 | print(means.shape)
43
| print(stds.shape) 44 | 45 | for i in range(means.shape[0]): 46 | print('Exp. name: %s \t %.4f +- %.4f' % (experiment_names[i], means[i], stds[i])) 47 | 48 | df = pd.DataFrame(ged_tot_arr, columns=experiment_names) 49 | df = df.melt(var_name='experiments', value_name='vals') 50 | 51 | sns.boxplot(x='experiments', y='vals', data=df) 52 | plt.show() -------------------------------------------------------------------------------- /eval_ged_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | from scipy import stats 7 | 8 | experiment_base_folder = '/itet-stor/baumgach/net_scratch/logs/phiseg/lidc/' 9 | experiment_list = ['probunet', 10 | 'phiseg_7_1', 11 | 'phiseg_7_5', 12 | 'probunet_1annot', 13 | 'phiseg_7_1_1annot', 14 | 'phiseg_7_5_1annot'] 15 | experiment_names = ['probunet','phiseg_7_1', 'phiseg_7_5', 'probunet_1annot', 'phiseg_7_1_1annot', 'phiseg_7_5_1annot'] 16 | file_list = ['ged100_best_ged.npz']*len(experiment_list) 17 | 18 | 19 | ged_list = [] 20 | 21 | for folder, exp_name, file in zip(experiment_list, experiment_names, file_list): 22 | 23 | experiment_path = os.path.join(experiment_base_folder, folder, file) 24 | 25 | ged_arr = np.load(experiment_path)['arr_0'] 26 | 27 | ged_list.append(ged_arr) 28 | 29 | ged_tot_arr = np.asarray(ged_list).T 30 | 31 | print('significance') 32 | print('REMINDER: are you checking the right methods?') 33 | print(stats.ttest_rel(ged_list[0], ged_list[1])) 34 | 35 | print('Results summary') 36 | means = ged_tot_arr.mean(axis=0) 37 | stds= ged_tot_arr.std(axis=0) 38 | 39 | for i in range(means.shape[0]): 40 | print('Exp. name: %s \t %.4f +- %.4f' % (experiment_names[i], means[i], stds[i])) 41 | 42 | df = pd.DataFrame(ged_tot_arr, columns=experiment_names) 43 | df = df.melt(var_name='experiments', value_name='vals') 44 | 45 | sns.boxplot(x='experiments', y='vals', data=df) 46 | plt.show() -------------------------------------------------------------------------------- /eval_ncc_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | from scipy import stats 7 | 8 | experiment_base_folder = '/itet-stor/baumgach/net_scratch/logs/phiseg/lidc/' 9 | experiment_list = ['probunet', 10 | 'phiseg_7_1', 11 | 'phiseg_7_5', 12 | 'probunet_1annot', 13 | 'phiseg_7_1_1annot', 14 | 'phiseg_7_5_1annot'] 15 | experiment_names = ['probunet','phiseg_7_1', 'phiseg_7_5', 'probunet_1annot', 'phiseg_7_1_1annot', 'phiseg_7_5_1annot'] 16 | file_list = ['ncc100_best_loss.npz']*len(experiment_list) 17 | 18 | 19 | ged_list = [] 20 | 21 | for folder, exp_name, file in zip(experiment_list, experiment_names, file_list): 22 | 23 | experiment_path = os.path.join(experiment_base_folder, folder, file) 24 | 25 | ged_arr = np.squeeze(np.load(experiment_path)['arr_0']) 26 | 27 | ged_list.append(ged_arr) 28 | 29 | ged_tot_arr = np.asarray(ged_list).T 30 | 31 | print('significance') 32 | print('REMINDER: are you checking the right methods?') 33 | print(stats.ttest_rel(ged_list[2], ged_list[3])) 34 | 35 | print('Results summary') 36 | means = ged_tot_arr.mean(axis=0) 37 | stds= ged_tot_arr.std(axis=0) 38 | 39 | print(ged_tot_arr.shape) 40 | 41 | for i in range(means.shape[0]): 42 | print('Exp. 
name: %s \t %.4f +- %.4f' % (experiment_names[i], means[i], stds[i])) 43 | 44 | df = pd.DataFrame(ged_tot_arr, columns=experiment_names) 45 | df = df.melt(var_name='experiments', value_name='vals') 46 | 47 | sns.boxplot(x='experiments', y='vals', data=df) 48 | plt.show() -------------------------------------------------------------------------------- /figures/graphical_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/figures/graphical_model.png -------------------------------------------------------------------------------- /figures/gt_id165.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/figures/gt_id165.gif -------------------------------------------------------------------------------- /figures/samples_id165.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/figures/samples_id165.gif -------------------------------------------------------------------------------- /phiseg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/phiseg/__init__.py -------------------------------------------------------------------------------- /phiseg/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/phiseg/experiments/__init__.py -------------------------------------------------------------------------------- /phiseg/experiments/detunet.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'detunet' 6 | log_dir_name = 'lidc2' 7 | 8 | # architecture 9 | posterior = posteriors.dummy 10 | likelihood = likelihoods.det_unet2D 11 | prior = priors.dummy 12 | layer_norm = tfnorm.batch_norm 13 | 14 | latent_levels = 1 15 | resolution_levels = 7 16 | n0 = 32 17 | zdim0 = 6 18 | max_channel_power = 4 # max number of channels will be n0*2**max_channel_power 19 | 20 | # Data settings 21 | data_identifier = 'lidc' 22 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 23 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 24 | dimensionality_mode = '2D' 25 | image_size = (128, 128, 1) 26 | nlabels = 2 27 | num_labels_per_subject = 4 28 | 29 | augmentation_options = {'do_flip_lr': True, 30 | 'do_flip_ud': True, 31 | 'do_rotations': True, 32 | 'do_scaleaug': True, 33 | 'nlabels': nlabels} 34 | 35 | # training 36 | optimizer = tf.train.AdamOptimizer 37 | lr_schedule_dict = {0: 1e-3} 38 | deep_supervision = True 39 | batch_size = 12 40 | num_iter = 5000000 41 | annotator_range = [0] # which annotators to actually use for training 42 | 43 | # losses 44 | KL_divergence_loss_weight = None 45 | exponential_weighting = True 46 | 47 | residual_multinoulli_loss_weight = 1.0 48 | 49 | # monitoring 50 | do_image_summaries = True 51 | rescale_RGB = False 52 | validation_frequency = 500 53 | 
validation_samples = 16 54 | num_validation_images = 100 #'all' 55 | tensorboard_update_frequency = 100 56 | 57 | -------------------------------------------------------------------------------- /phiseg/experiments/phiseg_7_1.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'phiseg_7_1' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.phiseg 10 | likelihood = likelihoods.phiseg 11 | prior = priors.phiseg 12 | layer_norm = tfnorm.batch_norm 13 | use_logistic_transform = False 14 | 15 | latent_levels = 1 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 2 19 | max_channel_power = 4 # max number of channels will be n0*2**max_channel_power 20 | 21 | # Data settings 22 | data_identifier = 'lidc' 23 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 24 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 25 | dimensionality_mode = '2D' 26 | image_size = (128, 128, 1) 27 | nlabels = 2 28 | num_labels_per_subject = 4 29 | 30 | augmentation_options = {'do_flip_lr': True, 31 | 'do_flip_ud': True, 32 | 'do_rotations': True, 33 | 'do_scaleaug': True, 34 | 'nlabels': nlabels} 35 | 36 | # training 37 | optimizer = tf.train.AdamOptimizer 38 | lr_schedule_dict = {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = range(num_labels_per_subject) # which annotators to actually use for training 43 | 44 | # losses 45 | KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/experiments/phiseg_7_1_1annot.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'phiseg_7_1_1annot' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.phiseg 10 | likelihood = likelihoods.phiseg 11 | prior = priors.phiseg 12 | layer_norm = tfnorm.batch_norm 13 | use_logistic_transform = False 14 | 15 | latent_levels = 1 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 2 19 | max_channel_power = 4 # max number of channels will be n0*2**max_channel_power 20 | 21 | # Data settings 22 | data_identifier = 'lidc' 23 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 24 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 25 | dimensionality_mode = '2D' 26 | image_size = (128, 128, 1) 27 | nlabels = 2 28 | num_labels_per_subject = 4 29 | 30 | augmentation_options = {'do_flip_lr': True, 31 | 'do_flip_ud': True, 32 | 'do_rotations': True, 33 | 'do_scaleaug': True, 34 | 'nlabels': nlabels} 35 | 36 | # training 37 | optimizer = tf.train.AdamOptimizer 38 | lr_schedule_dict = {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = [0] # which annotators to actually use for training 43 | 44 | # losses 45 | 
KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/experiments/phiseg_7_5.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'phiseg_7_5' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.phiseg 10 | likelihood = likelihoods.phiseg 11 | prior = priors.phiseg 12 | layer_norm = tfnorm.batch_norm 13 | use_logistic_transform = False 14 | 15 | latent_levels = 5 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 2 19 | max_channel_power = 4 # max number of channels will be n0*2**max_channel_power 20 | 21 | # Data settings 22 | data_identifier = 'lidc' 23 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 24 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 25 | dimensionality_mode = '2D' 26 | image_size = (128, 128, 1) 27 | nlabels = 2 28 | num_labels_per_subject = 4 29 | 30 | augmentation_options = {'do_flip_lr': True, 31 | 'do_flip_ud': True, 32 | 'do_rotations': True, 33 | 'do_scaleaug': True, 34 | 'nlabels': nlabels} 35 | 36 | # training 37 | optimizer = tf.train.AdamOptimizer 38 | lr_schedule_dict = {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = range(num_labels_per_subject) # which annotators to actually use for training 43 | 44 | # losses 45 | KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/experiments/phiseg_7_5_1annot.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'phiseg_7_5_1annot' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.phiseg 10 | likelihood = likelihoods.phiseg 11 | prior = priors.phiseg 12 | layer_norm = tfnorm.batch_norm 13 | use_logistic_transform = False 14 | 15 | latent_levels = 5 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 2 19 | max_channel_power = 4 # max number of channels will be n0*2**max_channel_power 20 | 21 | # Data settings 22 | data_identifier = 'lidc' 23 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 24 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 25 | dimensionality_mode = '2D' 26 | image_size = (128, 128, 1) 27 | nlabels = 2 28 | num_labels_per_subject = 4 29 | 30 | augmentation_options = {'do_flip_lr': True, 31 | 'do_flip_ud': True, 32 | 'do_rotations': True, 33 | 'do_scaleaug': True, 34 | 'nlabels': nlabels} 35 | 36 | # training 37 | optimizer = 
tf.train.AdamOptimizer 38 | lr_schedule_dict = {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = [0] # which annotators to actually use for training 43 | 44 | # losses 45 | KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/experiments/probunet.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'probunet' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.prob_unet2D 10 | likelihood = likelihoods.prob_unet2D 11 | prior = priors.prob_unet2D 12 | layer_norm = tfnorm.batch_norm # No layer normalisation! 13 | use_logistic_transform = False 14 | 15 | latent_levels = 1 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 6 19 | 20 | # Data settings 21 | data_identifier = 'lidc' 22 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 23 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 24 | dimensionality_mode = '2D' 25 | image_size = (128, 128, 1) 26 | nlabels = 2 27 | num_labels_per_subject = 4 28 | 29 | augmentation_options = {'do_flip_lr': True, 30 | 'do_flip_ud': True, 31 | 'do_rotations': True, 32 | 'do_scaleaug': True, 33 | 'nlabels': nlabels} 34 | 35 | # training 36 | optimizer = tf.train.AdamOptimizer 37 | lr_schedule_dict = {0: 1e-3} 38 | # lr_schedule_dict = {0: 1e-4, 80000: 0.5e-4, 160000: 1e-5, 240000: 0.5e-6} # {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = range(num_labels_per_subject) # which annotators to actually use for training 43 | 44 | # losses 45 | KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/experiments/probunet_1annot.py: -------------------------------------------------------------------------------- 1 | from phiseg.model_zoo import likelihoods, posteriors, priors 2 | import tensorflow as tf 3 | from tfwrapper import normalisation as tfnorm 4 | 5 | experiment_name = 'probunet_1annot' 6 | log_dir_name = 'lidc' 7 | 8 | # architecture 9 | posterior = posteriors.prob_unet2D 10 | likelihood = likelihoods.prob_unet2D 11 | prior = priors.prob_unet2D 12 | layer_norm = tfnorm.batch_norm # No layer normalisation! 
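# A minimal usage sketch: these experiment files are consumed as plain Python modules whose
# attributes the training/evaluation scripts read. The repo-root-relative path below is an
# assumption; phiseg_generate_samples.py in this repo uses the same SourceFileLoader pattern.
#
#     from importlib.machinery import SourceFileLoader
#     exp_config = SourceFileLoader('exp_config', 'phiseg/experiments/probunet_1annot.py').load_module()
#     model = phiseg(exp_config=exp_config)  # phiseg_model.phiseg reads posterior, prior, likelihood, ...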
13 | use_logistic_transform = False 14 | 15 | latent_levels = 1 16 | resolution_levels = 7 17 | n0 = 32 18 | zdim0 = 6 19 | 20 | # Data settings 21 | data_identifier = 'lidc' 22 | preproc_folder = '/srv/glusterfs/baumgach/preproc_data/lidc' 23 | data_root = '/itet-stor/baumgach/bmicdatasets-originals/Originals/LIDC-IDRI/data_lidc.pickle' 24 | dimensionality_mode = '2D' 25 | image_size = (128, 128, 1) 26 | nlabels = 2 27 | num_labels_per_subject = 4 28 | 29 | augmentation_options = {'do_flip_lr': True, 30 | 'do_flip_ud': True, 31 | 'do_rotations': True, 32 | 'do_scaleaug': True, 33 | 'nlabels': nlabels} 34 | 35 | # training 36 | optimizer = tf.train.AdamOptimizer 37 | lr_schedule_dict = {0: 1e-3} 38 | # lr_schedule_dict = {0: 1e-4, 80000: 0.5e-4, 160000: 1e-5, 240000: 0.5e-6} # {0: 1e-3} 39 | deep_supervision = True 40 | batch_size = 12 41 | num_iter = 5000000 42 | annotator_range = [0] # which annotators to actually use for training 43 | 44 | # losses 45 | KL_divergence_loss_weight = 1.0 46 | exponential_weighting = True 47 | 48 | residual_multinoulli_loss_weight = 1.0 49 | 50 | # monitoring 51 | do_image_summaries = True 52 | rescale_RGB = False 53 | validation_frequency = 500 54 | validation_samples = 16 55 | num_validation_images = 100 #'all' 56 | tensorboard_update_frequency = 100 57 | 58 | -------------------------------------------------------------------------------- /phiseg/model_zoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/phiseg/model_zoo/__init__.py -------------------------------------------------------------------------------- /phiseg/model_zoo/likelihoods.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tfwrapper import layers 3 | from tfwrapper import normalisation as tfnorm 4 | import numpy as np 5 | 6 | import logging 7 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 8 | 9 | 10 | def det_unet2D(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 11 | 12 | x = kwargs.get('x') 13 | 14 | resolution_levels = kwargs.get('resolution_levels', 7) 15 | n0 = kwargs.get('n0', 32) 16 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 17 | 18 | conv_unit = layers.conv2D 19 | deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2) 20 | 21 | with tf.variable_scope('likelihood') as scope: 22 | 23 | if scope_reuse: 24 | scope.reuse_variables() 25 | 26 | add_bias = False if norm == tfnorm.batch_norm else True 27 | 28 | enc = [] 29 | 30 | with tf.variable_scope('encoder'): 31 | 32 | for ii in range(resolution_levels): 33 | 34 | enc.append([]) 35 | 36 | # In first layer set input to x rather than max pooling 37 | if ii == 0: 38 | enc[ii].append(x) 39 | else: 40 | enc[ii].append(layers.averagepool2D(enc[ii-1][-1])) 41 | 42 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_1' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 43 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_2' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 44 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_3' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 45 | 46 | dec = [] 47 | 48 | with tf.variable_scope('decoder'): 49 | 50 | for jj in range(resolution_levels-1): 51 | 52 | ii 
= resolution_levels - jj - 1 # used to index the encoder again 53 | 54 | dec.append([]) 55 | 56 | if jj == 0: 57 | next_inp = enc[ii][-1] 58 | else: 59 | next_inp = dec[jj-1][-1] 60 | 61 | 62 | dec[jj].append(deconv_unit(next_inp)) 63 | 64 | # skip connection 65 | dec[jj].append(layers.crop_and_concat([dec[jj][-1], enc[ii-1][-1]], axis=3)) 66 | 67 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_1' % jj, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) # projection True to make it work with res units. 68 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_2' % jj, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 69 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_3' % jj, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 70 | 71 | net = dec[-1][-1] 72 | 73 | recomb = conv_unit(net, 'recomb_0', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 74 | recomb = conv_unit(recomb, 'recomb_1', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 75 | recomb = conv_unit(recomb, 'recomb_2', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 76 | 77 | s = [layers.conv2D(recomb, 'prediction', num_filters=n_classes, kernel_size=(1, 1), activation=tf.identity)] 78 | 79 | return s 80 | 81 | def prob_unet2D(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 82 | 83 | x = kwargs.get('x') 84 | 85 | z = z_list[0] 86 | 87 | resolution_levels = kwargs.get('resolution_levels', 7) 88 | n0 = kwargs.get('n0', 32) 89 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 90 | 91 | conv_unit = layers.conv2D 92 | deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2) 93 | 94 | bs = tf.shape(x)[0] 95 | zdim = z.get_shape().as_list()[-1] 96 | 97 | with tf.variable_scope('likelihood') as scope: 98 | 99 | if scope_reuse: 100 | scope.reuse_variables() 101 | 102 | add_bias = False if norm == tfnorm.batch_norm else True 103 | 104 | enc = [] 105 | 106 | with tf.variable_scope('encoder'): 107 | 108 | for ii in range(resolution_levels): 109 | 110 | enc.append([]) 111 | 112 | # In first layer set input to x rather than max pooling 113 | if ii == 0: 114 | enc[ii].append(x) 115 | else: 116 | enc[ii].append(layers.averagepool2D(enc[ii-1][-1])) 117 | 118 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_1' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 119 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_2' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 120 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_3' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 121 | 122 | dec = [] 123 | 124 | with tf.variable_scope('decoder'): 125 | 126 | for jj in range(resolution_levels-1): 127 | 128 | ii = resolution_levels - jj - 1 # used to index the encoder again 129 | 130 | dec.append([]) 131 | 132 | if jj == 0: 133 | next_inp = enc[ii][-1] 134 | else: 135 | next_inp = dec[jj-1][-1] 136 | 137 | 138 | dec[jj].append(deconv_unit(next_inp)) 139 | 140 | # skip connection 141 | dec[jj].append(layers.crop_and_concat([dec[jj][-1], enc[ii-1][-1]], axis=3)) 142 | 143 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_1' % jj, num_filters=num_channels[ii], training=training, 
normalisation=norm, add_bias=add_bias)) # projection True to make it work with res units. 144 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_2' % jj, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 145 | dec[jj].append(conv_unit(dec[jj][-1], 'conv_%d_3' % jj, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 146 | 147 | z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim))) 148 | 149 | broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1)) 150 | 151 | net = tf.concat([dec[-1][-1], broadcast_z], axis=-1) 152 | 153 | recomb = conv_unit(net, 'recomb_0', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 154 | recomb = conv_unit(recomb, 'recomb_1', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 155 | recomb = conv_unit(recomb, 'recomb_2', num_filters=num_channels[0], kernel_size=(1,1), training=training, normalisation=norm, add_bias=add_bias) 156 | 157 | s = [layers.conv2D(recomb, 'prediction', num_filters=n_classes, kernel_size=(1, 1), activation=tf.identity)] 158 | 159 | return s 160 | 161 | 162 | def phiseg(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 163 | """ 164 | This is a U-NET like arch with skips before and after latent space and a rather simple decoder 165 | """ 166 | 167 | n0 = kwargs.get('n0', 32) 168 | num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0] 169 | 170 | def increase_resolution(x, times, num_filters, name): 171 | 172 | with tf.variable_scope(name): 173 | nett = x 174 | 175 | for i in range(times): 176 | nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2) 177 | nett = layers.conv2D(nett, 'z%d_post' % i, num_filters=num_filters, normalisation=norm, training=training) 178 | 179 | return nett 180 | 181 | with tf.variable_scope('likelihood') as scope: 182 | 183 | if scope_reuse: 184 | scope.reuse_variables() 185 | 186 | resolution_levels = kwargs.get('resolution_levels', 7) 187 | latent_levels = kwargs.get('latent_levels', 5) 188 | lvl_diff = resolution_levels - latent_levels 189 | 190 | post_z = [None] * latent_levels 191 | post_c = [None] * latent_levels 192 | 193 | s = [None] * latent_levels 194 | 195 | # Generate post_z 196 | for i in range(latent_levels): 197 | net = layers.conv2D(z_list[i], 'z%d_post_1' % i, num_filters=num_channels[i], normalisation=norm, training=training) 198 | net = layers.conv2D(net, 'z%d_post_2' % i, num_filters=num_channels[i], normalisation=norm, training=training) 199 | net = increase_resolution(net, resolution_levels - latent_levels, num_filters=num_channels[i], name='preups_%d' % i) 200 | 201 | post_z[i] = net 202 | 203 | # Upstream path 204 | post_c[latent_levels - 1] = post_z[latent_levels - 1] 205 | 206 | for i in reversed(range(latent_levels - 1)): 207 | ups_below = layers.bilinear_upsample2D(post_c[i + 1], name='post_z%d_ups' % (i + 1), factor=2) 208 | ups_below = layers.conv2D(ups_below, 'post_z%d_ups_c' % (i + 1), num_filters=num_channels[i], normalisation=norm, training=training) 209 | 210 | concat = tf.concat([post_z[i], ups_below], axis=3, name='concat_%d' % i) 211 | 212 | net = layers.conv2D(concat, 'post_c_%d_1' % i, num_filters=num_channels[i+lvl_diff], normalisation=norm, training=training) 213 | net = layers.conv2D(net, 'post_c_%d_2' % i, num_filters=num_channels[i+lvl_diff], normalisation=norm, training=training) 214 | 215 | post_c[i] = net 216 | 217 | # 
Outputs 218 | for i in range(latent_levels): 219 | 220 | s_in = layers.conv2D(post_c[i], 'y_lvl%d' % i, num_filters=n_classes, kernel_size=(1, 1), activation=tf.identity) 221 | s[i] = tf.image.resize_images(s_in, image_size[0:2], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 222 | 223 | return s 224 | 225 | -------------------------------------------------------------------------------- /phiseg/model_zoo/posteriors.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tfwrapper import layers, utils 3 | from tfwrapper import normalisation as tfnorm 4 | import numpy as np 5 | import logging 6 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 7 | 8 | 9 | def prob_unet2D(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 10 | 11 | resolution_levels = kwargs.get('resolution_levels', 7) 12 | n0 = kwargs.get('n0', 32) 13 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 14 | 15 | conv_unit = layers.conv2D 16 | 17 | 18 | with tf.variable_scope('posterior') as scope: 19 | 20 | if scope_reuse: 21 | scope.reuse_variables() 22 | 23 | add_bias = False if norm == tfnorm.batch_norm else True 24 | 25 | enc = [] 26 | 27 | for ii in range(resolution_levels): 28 | 29 | enc.append([]) 30 | 31 | # In first layer set input to x rather than max pooling 32 | if ii == 0: 33 | enc[ii].append(tf.concat([x, s_oh-0.5], axis=-1)) 34 | else: 35 | enc[ii].append(layers.averagepool2D(enc[ii-1][-1])) 36 | 37 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_1' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 38 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_2' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 39 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_3' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 40 | 41 | mu_p = conv_unit(enc[-1][-1], 'pre_mu', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.identity) 42 | mu = [layers.global_averagepool2D(mu_p)] 43 | 44 | sigma_p = conv_unit(enc[-1][-1], 'pre_sigma', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.nn.softplus) 45 | sigma = [layers.global_averagepool2D(sigma_p)] 46 | 47 | z = [mu[0] + sigma[0] * tf.random_normal(tf.shape(mu[0]), 0, 1, dtype=tf.float32)] 48 | 49 | print('@@ z shape in posterior') 50 | print(z[0].get_shape().as_list()) 51 | 52 | return z, mu, sigma 53 | 54 | 55 | 56 | def phiseg(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 57 | 58 | n0 = kwargs.get('n0', 32) 59 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 60 | 61 | with tf.variable_scope('posterior') as scope: 62 | 63 | if scope_reuse: 64 | scope.reuse_variables() 65 | 66 | full_cov_list = kwargs.get('full_cov_list', None) 67 | 68 | n0 = kwargs.get('n0', 32) 69 | latent_levels = kwargs.get('latent_levels', 5) 70 | resolution_levels = kwargs.get('resolution_levels', 7) 71 | 72 | spatial_xdim = x.get_shape().as_list()[1:3] 73 | 74 | pre_z = [None] * resolution_levels 75 | 76 | mu = [None] * latent_levels 77 | sigma = [None] * latent_levels 78 | z = [None] * latent_levels 79 | 80 | z_ups_mat = [] 81 | for i in range(latent_levels): z_ups_mat.append([None]*latent_levels) # encoding [original resolution][upsampled to] 82 | 83 | # Generate pre_z's 84 | for i in range(resolution_levels): 85 | 86 | if i == 0: 87 | net = tf.concat([x, s_oh-0.5], axis=-1) 88 | else: 89 | net = 
layers.averagepool2D(pre_z[i-1]) 90 | 91 | net = layers.conv2D(net, 'z%d_pre_1' % i, num_filters=num_channels[i], normalisation=norm, training=training) 92 | net = layers.conv2D(net, 'z%d_pre_2' % i, num_filters=num_channels[i], normalisation=norm, training=training) 93 | net = layers.conv2D(net, 'z%d_pre_3' % i, num_filters=num_channels[i], normalisation=norm, training=training) 94 | 95 | pre_z[i] = net 96 | 97 | # Generate z's 98 | for i in reversed(range(latent_levels)): 99 | 100 | spatial_zdim = [d // 2 ** (i + resolution_levels - latent_levels) for d in spatial_xdim] 101 | spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1] 102 | 103 | if i == latent_levels - 1: 104 | 105 | mu[i] = layers.conv2D(pre_z[i+resolution_levels-latent_levels], 'z%d_mu' % i, num_filters=zdim_0, activation=tf.identity) 106 | 107 | sigma[i] = layers.conv2D(pre_z[i+resolution_levels-latent_levels], 'z%d_sigma' % i, num_filters=zdim_0, activation=tf.nn.softplus, kernel_size=(1, 1)) 108 | z[i] = mu[i] + sigma[i] * tf.random_normal(tf.shape(mu[i]), 0, 1, dtype=tf.float32) 109 | 110 | else: 111 | 112 | for j in reversed(range(0, i+1)): 113 | 114 | z_below_ups = layers.bilinear_upsample2D(z_ups_mat[j+1][i+1], factor=2, name='ups') 115 | z_below_ups = layers.conv2D(z_below_ups, name='z%d_ups_to_%d_c_1' % ((i+1), (j+1)), num_filters=zdim_0*n0, normalisation=norm, training=training) 116 | z_below_ups = layers.conv2D(z_below_ups, name='z%d_ups_to_%d_c_2' % ((i+1), (j+1)), num_filters=zdim_0*n0, normalisation=norm, training=training) 117 | 118 | z_ups_mat[j][i + 1] = z_below_ups 119 | 120 | z_input = tf.concat([pre_z[i+resolution_levels-latent_levels], z_ups_mat[i][i+1]], axis=3, name='concat_%d' % i) 121 | 122 | z_input = layers.conv2D(z_input, 'z%d_input_1' % i, num_filters=num_channels[i], normalisation=norm, training=training) 123 | z_input = layers.conv2D(z_input, 'z%d_input_2' % i, num_filters=num_channels[i], normalisation=norm, training=training) 124 | 125 | mu[i] = layers.conv2D(z_input, 'z%d_mu' % i, num_filters=zdim_0, activation=tf.identity, kernel_size=(1,1)) 126 | 127 | sigma[i] = layers.conv2D(z_input, 'z%d_sigma' % i, num_filters=zdim_0, activation=tf.nn.softplus, kernel_size=(1, 1)) 128 | z[i] = mu[i] + sigma[i] * tf.random_normal(tf.shape(mu[i]), 0, 1, dtype=tf.float32) 129 | 130 | z_ups_mat[i][i] = z[i] 131 | 132 | return z, mu, sigma 133 | 134 | 135 | def dummy(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 136 | latent_levels = kwargs.get('latent_levels', 5) 137 | z = mu = sigma = [tf.constant(0)]*latent_levels 138 | return [z, mu, sigma] -------------------------------------------------------------------------------- /phiseg/model_zoo/priors.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tfwrapper import layers 3 | from tfwrapper import normalisation as tfnorm 4 | import logging 5 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 6 | 7 | 8 | def prob_unet2D(z_list, x, zdim_0, n_classes, generation_mode, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 9 | 10 | resolution_levels = kwargs.get('resolution_levels', 7) 11 | 12 | n0 = kwargs.get('n0', 32) 13 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 14 | 15 | conv_unit = layers.conv2D 16 | 17 | with tf.variable_scope('prior') as scope: 18 | 19 | if scope_reuse: 20 | scope.reuse_variables() 21 | 22 | add_bias = False if norm == tfnorm.batch_norm else True 23 | 24 | enc = [] 25 | 26 | for ii in 
range(resolution_levels): 27 | 28 | enc.append([]) 29 | 30 | # In first layer set input to x rather than max pooling 31 | if ii == 0: 32 | enc[ii].append(x) 33 | else: 34 | enc[ii].append(layers.averagepool2D(enc[ii-1][-1])) 35 | 36 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_1' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 37 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_2' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 38 | enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_3' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias)) 39 | 40 | mu_p = conv_unit(enc[-1][-1], 'pre_mu', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.identity) 41 | mu = [layers.global_averagepool2D(mu_p)] 42 | 43 | sigma_p = conv_unit(enc[-1][-1], 'pre_sigma', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.nn.softplus) 44 | sigma = [layers.global_averagepool2D(sigma_p)] 45 | 46 | z = [mu[0] + sigma[0] * tf.random_normal(tf.shape(mu[0]), 0, 1, dtype=tf.float32)] 47 | 48 | return z, mu, sigma 49 | 50 | 51 | def phiseg(z_list, x, zdim_0, n_classes, generation_mode, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 52 | 53 | n0 = kwargs.get('n0', 32) 54 | num_channels = [n0, 2*n0, 4*n0,6*n0, 6*n0, 6*n0, 6*n0] 55 | 56 | with tf.variable_scope('prior') as scope: 57 | 58 | if scope_reuse: 59 | scope.reuse_variables() 60 | 61 | n0 = kwargs.get('n0', 32) 62 | latent_levels = kwargs.get('latent_levels', 5) 63 | resolution_levels = kwargs.get('resolution_levels', 7) 64 | 65 | spatial_xdim = x.get_shape().as_list()[1:3] 66 | 67 | pre_z = [None] * resolution_levels 68 | 69 | mu = [None] * latent_levels 70 | sigma = [None] * latent_levels 71 | z = [None] * latent_levels 72 | 73 | z_ups_mat = [] 74 | for i in range(latent_levels): z_ups_mat.append([None]*latent_levels) # encoding [original resolution][upsampled to] 75 | 76 | # Generate pre_z's 77 | for i in range(resolution_levels): 78 | 79 | if i == 0: 80 | net = x 81 | else: 82 | net = layers.averagepool2D(pre_z[i-1]) 83 | 84 | net = layers.conv2D(net, 'z%d_pre_1' % i, num_filters=num_channels[i], normalisation=norm, training=training) 85 | net = layers.conv2D(net, 'z%d_pre_2' % i, num_filters=num_channels[i], normalisation=norm, training=training) 86 | net = layers.conv2D(net, 'z%d_pre_3' % i, num_filters=num_channels[i], normalisation=norm, training=training) 87 | 88 | pre_z[i] = net 89 | 90 | # Generate z's 91 | for i in reversed(range(latent_levels)): 92 | 93 | spatial_zdim = [d // 2 ** (i + resolution_levels - latent_levels) for d in spatial_xdim] 94 | 95 | if i == latent_levels - 1: 96 | 97 | mu[i] = layers.conv2D(pre_z[i+resolution_levels-latent_levels], 'z%d_mu' % i, num_filters=zdim_0, activation=tf.identity) 98 | 99 | sigma[i] = layers.conv2D(pre_z[i+resolution_levels-latent_levels], 'z%d_sigma' % i, num_filters=zdim_0, activation=tf.nn.softplus, kernel_size=(1, 1)) 100 | z[i] = mu[i] + sigma[i] * tf.random_normal(tf.shape(mu[i]), 0, 1, dtype=tf.float32) 101 | 102 | else: 103 | 104 | for j in reversed(range(0, i+1)): 105 | 106 | z_below_ups = layers.bilinear_upsample2D(z_ups_mat[j+1][i+1], factor=2, name='ups') 107 | z_below_ups = layers.conv2D(z_below_ups, name='z%d_ups_to_%d_c_1' % ((i+1), (j+1)), num_filters=zdim_0*n0, normalisation=norm, training=training) 108 | z_below_ups = layers.conv2D(z_below_ups, name='z%d_ups_to_%d_c_2' % ((i+1), (j+1)), num_filters=zdim_0*n0, normalisation=norm, 
training=training) 109 | 110 | z_ups_mat[j][i + 1] = z_below_ups 111 | 112 | z_input = tf.concat([pre_z[i+resolution_levels-latent_levels], z_ups_mat[i][i+1]], axis=3, name='concat_%d' % i) 113 | 114 | z_input = layers.conv2D(z_input, 'z%d_input_1' % i, num_filters=num_channels[i], normalisation=norm, training=training) 115 | z_input = layers.conv2D(z_input, 'z%d_input_2' % i, num_filters=num_channels[i], normalisation=norm, training=training) 116 | 117 | mu[i] = layers.conv2D(z_input, 'z%d_mu' % i, num_filters=zdim_0, activation=tf.identity, kernel_size=(1,1)) 118 | 119 | sigma[i] = layers.conv2D(z_input, 'z%d_sigma' % i, num_filters=zdim_0, activation=tf.nn.softplus, kernel_size=(1, 1)) 120 | z[i] = mu[i] + sigma[i] * tf.random_normal(tf.shape(mu[i]), 0, 1, dtype=tf.float32) 121 | 122 | # While training use posterior samples, when generating use prior samples here 123 | if generation_mode: 124 | z_ups_mat[i][i] = z[i] 125 | else: 126 | z_ups_mat[i][i] = z_list[i] 127 | 128 | return z, mu, sigma 129 | 130 | def dummy(z_list, x, zdim_0, n_classes, generation_mode, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs): 131 | latent_levels = kwargs.get('latent_levels', 5) 132 | z = mu = sigma = [tf.constant(1)]*latent_levels 133 | return [z, mu, sigma] -------------------------------------------------------------------------------- /phiseg/phiseg_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tfwrapper import utils as tfutils 3 | 4 | import utils 5 | 6 | import numpy as np 7 | import os 8 | import time 9 | from medpy.metric import dc 10 | 11 | from config import system as sys_config 12 | 13 | sys_config.setup_GPU_environment() 14 | 15 | import logging 16 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 17 | 18 | class phiseg(): 19 | 20 | def __init__(self, exp_config): 21 | 22 | self.exp_config = exp_config 23 | self.checks() 24 | 25 | # define the input placeholders 26 | self.x_inp = tf.placeholder(tf.float32, shape=[None] + list(self.exp_config.image_size), name='x_input') 27 | self.s_inp = tf.placeholder(tf.uint8, shape=[None] + list(self.exp_config.image_size[0:2]), name='s_input') 28 | 29 | self.s_inp_oh = tf.one_hot(self.s_inp, depth=exp_config.nlabels) 30 | 31 | self.training_pl = tf.placeholder(tf.bool, shape=[], name='training_time') 32 | self.lr_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate') 33 | 34 | 35 | # CREATE NETWORKS 36 | 37 | self.z_list, self.mu_list, self.sigma_list = self.exp_config.posterior( 38 | self.x_inp, 39 | self.s_inp_oh, 40 | exp_config.zdim0, 41 | training=self.training_pl, 42 | n0=exp_config.n0, 43 | resolution_levels=exp_config.resolution_levels, 44 | latent_levels=exp_config.latent_levels, 45 | norm=exp_config.layer_norm 46 | ) 47 | 48 | self.prior_z_list, self.prior_mu_list, self.prior_sigma_list = self.exp_config.prior( 49 | self.z_list, 50 | self.x_inp, 51 | zdim_0=exp_config.zdim0, 52 | n_classes=self.exp_config.nlabels, 53 | training=self.training_pl, 54 | n0=exp_config.n0, 55 | generation_mode=False, 56 | resolution_levels=exp_config.resolution_levels, 57 | latent_levels=exp_config.latent_levels, 58 | norm=exp_config.layer_norm 59 | ) 60 | 61 | self.prior_z_list_gen, self.prior_mu_list_gen, self.prior_sigma_list_gen = self.exp_config.prior( 62 | self.z_list, 63 | self.x_inp, 64 | zdim_0=exp_config.zdim0, 65 | n_classes=self.exp_config.nlabels, 66 | training=self.training_pl, 67 | n0=exp_config.n0, 68 | 
generation_mode=True, 69 | scope_reuse=True, 70 | resolution_levels=exp_config.resolution_levels, 71 | latent_levels=exp_config.latent_levels, 72 | norm=exp_config.layer_norm 73 | ) 74 | 75 | self.s_out_list = self.exp_config.likelihood(self.z_list, 76 | self.training_pl, 77 | n0=exp_config.n0, 78 | n_classes=exp_config.nlabels, 79 | resolution_levels=exp_config.resolution_levels, 80 | latent_levels=exp_config.latent_levels, 81 | image_size=exp_config.image_size, 82 | norm=exp_config.layer_norm, 83 | x=self.x_inp) # This is only needed for probUNET! 84 | 85 | self.s_out_sm_list = [None]*self.exp_config.latent_levels 86 | for ii in range(self.exp_config.latent_levels): 87 | self.s_out_sm_list[ii] = tf.nn.softmax(self.s_out_list[ii]) 88 | 89 | self.s_out_eval_list = self.exp_config.likelihood(self.prior_z_list_gen, 90 | self.training_pl, 91 | scope_reuse=True, 92 | n0=exp_config.n0, 93 | n_classes=exp_config.nlabels, 94 | resolution_levels=exp_config.resolution_levels, 95 | latent_levels=exp_config.latent_levels, 96 | image_size=exp_config.image_size, 97 | norm=exp_config.layer_norm, 98 | x=self.x_inp) # This is only needed for probUNET! 99 | 100 | self.s_out_eval_sm_list = [None]*self.exp_config.latent_levels 101 | for ii in range(self.exp_config.latent_levels): 102 | self.s_out_eval_sm_list[ii] = tf.nn.softmax(self.s_out_eval_list[ii]) 103 | 104 | 105 | # Create final output from output list 106 | self.s_out = self._aggregate_output_list(self.s_out_list, use_softmax=False) 107 | self.s_out_eval = self._aggregate_output_list(self.s_out_eval_list, use_softmax=False) 108 | 109 | self.s_out_eval_sm = tf.nn.softmax(self.s_out_eval) 110 | 111 | self.eval_xent = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.s_inp_oh, logits=self.s_out_eval) 112 | 113 | # Add losses 114 | self.loss_dict = {} 115 | self.loss_tot = 0 116 | 117 | logging.info('ADDING LOSSES') 118 | if hasattr(self.exp_config, 'residual_multinoulli_loss_weight') and exp_config.residual_multinoulli_loss_weight is not None: 119 | logging.info(' - Adding residual multinoulli loss') 120 | self.add_residual_multinoulli_loss() 121 | 122 | if hasattr(self.exp_config, 'KL_divergence_loss_weight') and exp_config.KL_divergence_loss_weight is not None: 123 | logging.info(' - Adding hierarchical KL loss') 124 | self.add_hierarchical_KL_div_loss() 125 | 126 | if hasattr(self.exp_config, 'weight_decay_weight') and exp_config.weight_decay_weight is not None: 127 | logging.info(' - Adding weight decay') 128 | self.add_weight_decay() 129 | 130 | self.loss_dict['total_loss'] = self.loss_tot 131 | 132 | self.global_step = tf.train.get_or_create_global_step() 133 | 134 | # Create Update Operation 135 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 136 | with tf.control_dependencies(update_ops): 137 | if exp_config.optimizer == tf.train.MomentumOptimizer: 138 | optimizer = exp_config.optimizer(learning_rate=self.lr_pl, momentum=0.9, use_nesterov=True) 139 | else: 140 | optimizer = exp_config.optimizer(learning_rate=self.lr_pl) 141 | self.train_step = optimizer.minimize(self.loss_tot, global_step=self.global_step) 142 | 143 | # Create a saver for writing training checkpoints. 
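# In addition to the rolling checkpoint saver below, separate savers keep the checkpoints with
# the best validation loss, Dice, GED and NCC. A minimal restore sketch, assuming a trained
# log_dir (load_weights is defined further down in this class):
#
#     model = phiseg(exp_config)
#     model.load_weights(log_dir, type='best_ged')  # or 'latest', 'best_dice', 'best_loss'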
144 | self.saver = tf.train.Saver(max_to_keep=1, keep_checkpoint_every_n_hours=3) 145 | self.saver_best_loss = tf.train.Saver(max_to_keep=2) 146 | self.saver_best_dice = tf.train.Saver(max_to_keep=2) 147 | self.saver_best_ged = tf.train.Saver(max_to_keep=2) 148 | self.saver_best_ncc = tf.train.Saver(max_to_keep=2) 149 | 150 | # Settings to optimize GPU memory usage 151 | config = tf.ConfigProto() 152 | config.gpu_options.allow_growth = True 153 | config.allow_soft_placement = True 154 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 155 | 156 | # Create a session for running Ops on the Graph. 157 | self.sess = tf.Session(config=config) 158 | 159 | 160 | def checks(self): 161 | 162 | pass 163 | # if hasattr(self.exp_config, 'residual_multinoulli_loss_weight') and not self.exp_config.discrete_data: 164 | # raise ValueError('Invalid settings in exp_config: residual_multinoulli_loss requires discrete_data to be True.') 165 | 166 | def train(self, data): 167 | 168 | # Sort out proper logging 169 | self._setup_log_dir_and_continue_mode() 170 | 171 | # Create tensorboard summaries 172 | self._make_tensorboard_summaries() 173 | 174 | # Initialise variables 175 | self.sess.run(tf.global_variables_initializer()) 176 | 177 | # Restore session if there is one 178 | if self.continue_run: 179 | self.saver.restore(self.sess, self.init_checkpoint_path) 180 | 181 | self.best_dice = -1 182 | self.best_loss = np.inf 183 | self.best_ged = np.inf 184 | self.best_ncc = -1 185 | 186 | for step in range(self.init_step, self.exp_config.num_iter): 187 | 188 | # Get current learning rate from lr_dict 189 | lr_key, _ = utils.find_floor_in_list(self.exp_config.lr_schedule_dict.keys(), step) 190 | lr = self.exp_config.lr_schedule_dict[lr_key] 191 | 192 | 193 | x_b, s_b = data.train.next_batch(self.exp_config.batch_size) 194 | _, loss_tot_eval = self.sess.run([self.train_step, self.loss_tot], feed_dict={self.x_inp: x_b, 195 | self.s_inp: s_b, 196 | self.training_pl: True, 197 | self.lr_pl: lr}) 198 | 199 | if step % self.exp_config.tensorboard_update_frequency == 0: 200 | 201 | summary_str = self.sess.run(self.summary, feed_dict={self.x_inp: x_b, self.s_inp: s_b, self.training_pl: False, self.lr_pl: lr}) 202 | self.summary_writer.add_summary(summary_str, step) 203 | self.summary_writer.flush() 204 | 205 | if step % self.exp_config.validation_frequency == 0: 206 | 207 | self._do_validation(data) 208 | 209 | 210 | def KL_two_gauss_with_diag_cov(self, mu0, sigma0, mu1, sigma1): 211 | 212 | sigma0_fs = tf.square(tfutils.flatten(sigma0)) 213 | sigma1_fs = tf.square(tfutils.flatten(sigma1)) 214 | 215 | logsigma0_fs = tf.log(sigma0_fs + 1e-10) 216 | logsigma1_fs = tf.log(sigma1_fs + 1e-10) 217 | 218 | mu0_f = tfutils.flatten(mu0) 219 | mu1_f = tfutils.flatten(mu1) 220 | 221 | return tf.reduce_mean( 222 | 0.5*tf.reduce_sum(tf.divide(sigma0_fs + tf.square(mu1_f - mu0_f), sigma1_fs + 1e-10) 223 | + logsigma1_fs 224 | - logsigma0_fs 225 | - 1, axis=1) 226 | ) 227 | 228 | 229 | def multinoulli_loss_with_logits(self, x_gt, y_target): 230 | 231 | bs = tf.shape(x_gt)[0] 232 | 233 | x_f = tf.reshape(x_gt, tf.stack([bs, -1, self.exp_config.nlabels])) 234 | y_f = tf.reshape(y_target, tf.stack([bs, -1, self.exp_config.nlabels])) 235 | 236 | return tf.reduce_mean( 237 | tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=x_f, logits=y_f), axis=1) 238 | ) 239 | 240 | 241 | def add_residual_multinoulli_loss(self): 242 | 243 | # TODO: move s_accum outside of this function 244 | 245 | self.s_accum = [None] * 
self.exp_config.latent_levels 246 | 247 | for ii, s_ii in zip(reversed(range(self.exp_config.latent_levels)), 248 | reversed(self.s_out_list)): 249 | 250 | if ii == self.exp_config.latent_levels-1: 251 | 252 | self.s_accum[ii] = s_ii 253 | self.loss_dict['residual_multinoulli_loss_lvl%d' % ii] = self.multinoulli_loss_with_logits(self.s_inp_oh, self.s_accum[ii]) 254 | 255 | else: 256 | 257 | self.s_accum[ii] = self.s_accum[ii+1] + s_ii 258 | self.loss_dict['residual_multinoulli_loss_lvl%d' % ii] = self.multinoulli_loss_with_logits(self.s_inp_oh, self.s_accum[ii]) 259 | 260 | logging.info(' -- Added residual multinoulli loss at level %d' % (ii)) 261 | 262 | self.loss_tot += self.exp_config.residual_multinoulli_loss_weight * self.loss_dict['residual_multinoulli_loss_lvl%d' % ii] 263 | 264 | 265 | def add_hierarchical_KL_div_loss(self): 266 | 267 | prior_sigma_list = self.prior_sigma_list 268 | prior_mu_list = self.prior_mu_list 269 | 270 | if self.exp_config.exponential_weighting: 271 | level_weights = [4**i for i in list(range(self.exp_config.latent_levels))] 272 | else: 273 | level_weights = [1]*self.exp_config.latent_levels 274 | 275 | for ii, mu_i, sigma_i in zip(reversed(range(self.exp_config.latent_levels)), 276 | reversed(self.mu_list), 277 | reversed(self.sigma_list)): 278 | 279 | self.loss_dict['KL_divergence_loss_lvl%d' % ii] = level_weights[ii]*self.KL_two_gauss_with_diag_cov( 280 | mu_i, 281 | sigma_i, 282 | prior_mu_list[ii], 283 | prior_sigma_list[ii]) 284 | 285 | logging.info(' -- Added hierarchical loss with at level %d with alpha_%d=%d' % (ii,ii, level_weights[ii])) 286 | 287 | self.loss_tot += self.exp_config.KL_divergence_loss_weight * self.loss_dict['KL_divergence_loss_lvl%d' % ii] 288 | 289 | 290 | def add_weight_decay(self): 291 | 292 | weights_norm = tf.reduce_sum( 293 | input_tensor= tf.stack( 294 | [tf.nn.l2_loss(ii) for ii in tf.get_collection('weight_variables')] 295 | ), 296 | name='weights_norm' 297 | ) 298 | 299 | self.loss_dict['weight_decay'] = self.exp_config.weight_decay_weight*weights_norm 300 | self.loss_tot += self.loss_dict['weight_decay'] 301 | 302 | 303 | 304 | def _aggregate_output_list(self, output_list, use_softmax=True): 305 | 306 | s_accum = output_list[-1] 307 | for i in range(len(output_list) - 1): 308 | s_accum += output_list[i] 309 | if use_softmax: 310 | return tf.nn.softmax(s_accum) 311 | return s_accum 312 | 313 | def generate_samples_from_z(self, z_list, x_in, output_all_levels=False): 314 | 315 | feed_dict = { i: d for i, d in zip(self.z_list, z_list)} 316 | feed_dict[self.training_pl] = False 317 | feed_dict[self.x_inp] = x_in 318 | 319 | if output_all_levels: 320 | return self.sess.run(self.s_out_list, feed_dict=feed_dict) 321 | else: 322 | return self.sess.run(self.s_out, feed_dict=feed_dict) 323 | 324 | 325 | def generate_prior_samples(self, x_in, return_params=False): 326 | 327 | z_samples = self.sess.run(self.prior_z_list_gen, feed_dict={self.training_pl: False, self.x_inp: x_in}) 328 | 329 | if return_params: 330 | prior_mu_list = self.sess.run(self.prior_mu_list_gen, feed_dict={self.training_pl: False, self.x_inp: x_in}) 331 | prior_sigma_list = self.sess.run(self.prior_sigma_list_gen, feed_dict={self.training_pl: False, self.x_inp: x_in}) 332 | return z_samples, prior_mu_list, prior_sigma_list 333 | else: 334 | return z_samples 335 | 336 | 337 | def predict(self, x_in, num_samples=50, return_softmax=False): 338 | 339 | feed_dict = {} 340 | feed_dict[self.training_pl] = False 341 | feed_dict[self.x_inp] = x_in 342 | # 
feed_dict[self.s_inp] = np.zeros([x_in.shape[0]] + self.s_inp.get_shape().as_list()[1:]) # dummy 343 | 344 | cumsum_sm = self.sess.run(self.s_out_eval_sm, feed_dict=feed_dict) 345 | 346 | for i in range(num_samples-1): 347 | # print(' - sample: %d' % (i+1)) 348 | cumsum_sm += self.sess.run(self.s_out_eval_sm, feed_dict=feed_dict) 349 | 350 | if return_softmax: 351 | return np.argmax(cumsum_sm, axis=-1), cumsum_sm / num_samples 352 | 353 | return np.argmax(cumsum_sm, axis=-1) 354 | 355 | 356 | def predict_segmentation_sample(self, x_in, return_softmax=False): 357 | 358 | feed_dict = {} 359 | feed_dict[self.training_pl] = False 360 | feed_dict[self.x_inp] = x_in 361 | 362 | if return_softmax: 363 | return self.sess.run(self.s_out_eval_sm, feed_dict=feed_dict) 364 | return np.argmax(self.sess.run(self.s_out_eval, feed_dict=feed_dict), axis=-1) 365 | 366 | 367 | def predict_segmentation_sample_levels(self, x_in, return_softmax=False): 368 | 369 | feed_dict = {} 370 | feed_dict[self.training_pl] = False 371 | feed_dict[self.x_inp] = x_in 372 | 373 | if return_softmax: 374 | return self.sess.run(self.s_out_eval_sm_list, feed_dict=feed_dict) 375 | return self.sess.run(self.s_out_eval_list, feed_dict=feed_dict) 376 | 377 | 378 | def predict_segmentation_sample_variance_sm_cov(self, x_in, num_samples): 379 | 380 | # assert self.exp_config.use_logistic_transform is True, 'predict variance is only implemented for logistic transform nets' 381 | 382 | feed_dict = {} 383 | feed_dict[self.training_pl] = False 384 | feed_dict[self.x_inp] = x_in 385 | 386 | segms = [] 387 | for _ in range(num_samples): 388 | segms.append(self.sess.run(self.s_out_eval, feed_dict=feed_dict)) 389 | 390 | segm_arr = np.squeeze(np.asarray(segms)) # num_samples x size_1 x size_2 x n_classes - 1 391 | segm_arr = segm_arr[...,:-1] 392 | segm_arr = segm_arr.transpose((1,2,3,0)) 393 | segm_arr = np.clip(segm_arr, 1e-5, 1 - (1e-5)) 394 | 395 | corr_mat = np.einsum('ghij,ghkj->ghik', segm_arr, segm_arr) / num_samples 396 | mu_mat = np.mean(segm_arr, axis=-1) 397 | outer_mu = np.einsum('ghi,ghj->ghij', mu_mat, mu_mat) 398 | cov_mat = corr_mat - outer_mu # 128, 128, nlabels -1 x nlabels -1 399 | # det_cov_mat = np.linalg.det(cov_mat) 400 | det_cov_mat, _ = np.linalg.eig(cov_mat) # 128, 128, nlabels -1 401 | 402 | 403 | return np.sum(det_cov_mat, axis=-1) 404 | 405 | 406 | def predict_segmentation_sample_variance_sm_cov_bf(self, x_in, num_samples): 407 | 408 | # assert self.exp_config.use_logistic_transform is True, 'predict variance is only implemented for logistic transform nets' 409 | 410 | feed_dict = {} 411 | feed_dict[self.training_pl] = False 412 | feed_dict[self.x_inp] = x_in 413 | 414 | segms = [] 415 | for _ in range(num_samples): 416 | segms.append(self.sess.run(self.s_out_eval_sm, feed_dict=feed_dict)) 417 | 418 | segm_arr = np.squeeze(np.asarray(segms)) # num_samples x size_1 x size_2 x n_classes - 1 419 | # segm_arr = segm_arr[..., :-1] 420 | 421 | segm_arr = segm_arr.transpose((1, 2, 3, 0)) # size_1, size_2, n_classes-1, num_samples 422 | 423 | out = np.zeros((self.exp_config.image_size[0:2])) 424 | 425 | for ii in range(self.exp_config.image_size[0]): 426 | for jj in range(self.exp_config.image_size[1]): 427 | cov = np.cov(segm_arr[ii, jj, :, :]) 428 | out[ii, jj] = np.linalg.det(cov) 429 | 430 | return out 431 | 432 | 433 | def get_crossentropy_error_map(self, s_gt, x_in, num_samples=100): 434 | 435 | feed_dict = {} 436 | feed_dict[self.training_pl] = False 437 | feed_dict[self.x_inp] = x_in 438 | feed_dict[self.s_inp] = 
s_gt 439 | 440 | err_maps = [] 441 | for _ in range(num_samples): 442 | err_maps.append(self.sess.run(self.eval_xent, feed_dict=feed_dict)) 443 | 444 | err_maps_arr = np.asarray(err_maps) 445 | 446 | return np.mean(err_maps_arr, axis=0) 447 | 448 | 449 | def predict_mean_variance_and_error_maps(self, s_gt, x_in, num_samples): 450 | 451 | feed_dict = {} 452 | feed_dict[self.training_pl] = False 453 | feed_dict[self.x_inp] = x_in 454 | feed_dict[self.s_inp] = s_gt 455 | 456 | err_maps = [] 457 | segms = [] 458 | 459 | for _ in range(num_samples): 460 | seg, err = self.sess.run([self.s_out_eval_sm, self.eval_xent], feed_dict=feed_dict) 461 | 462 | err_maps.append(err) 463 | segms.append(seg) 464 | 465 | err_maps_arr = np.squeeze(np.asarray(err_maps)) 466 | segm_arr = np.squeeze(np.asarray(segms)) 467 | 468 | vars = np.std(segm_arr, axis=0) # per-pixel standard deviation across samples (despite the function name) 469 | vars = np.mean(vars, axis=-1) 470 | 471 | means = np.argmax(np.mean(segm_arr, 0), axis=-1) 472 | 473 | errs = np.mean(err_maps_arr, axis=0) 474 | 475 | return means, vars, errs 476 | 477 | 478 | def generate_samples_from_prior(self, x_in, output_all_levels=False): 479 | 480 | z_samples = self.generate_prior_samples(x_in) 481 | return self.generate_samples_from_z(z_samples, x_in, output_all_levels) # generate_samples_from_z also requires x_in 482 | 483 | 484 | def generate_posterior_samples(self, x_in, s_in, return_params=False): 485 | 486 | z_samples = self.sess.run(self.z_list, feed_dict={self.training_pl: False, 487 | self.x_inp: x_in, 488 | self.s_inp: s_in}) 489 | 490 | if return_params: 491 | mu_list = self.sess.run(self.mu_list, feed_dict={self.training_pl: False, self.x_inp: x_in, self.s_inp: s_in}) 492 | sigma_list = self.sess.run(self.sigma_list, feed_dict={self.training_pl: False, self.x_inp: x_in, self.s_inp: s_in}) 493 | return z_samples, mu_list, sigma_list 494 | else: 495 | return z_samples 496 | 497 | 498 | def generate_all_output_levels(self, x_in): 499 | 500 | y_list = self.sess.run(self.s_out_list, feed_dict={self.x_inp: x_in, 501 | self.training_pl: False}) 502 | return y_list 503 | 504 | 505 | def load_weights(self, log_dir=None, type='latest', **kwargs): 506 | 507 | if not log_dir: 508 | log_dir = self.log_dir 509 | 510 | if type == 'latest': 511 | init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt') 512 | elif type == 'best_dice': 513 | init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(log_dir, 'model_best_dice.ckpt') 514 | elif type == 'best_loss': 515 | init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(log_dir, 'model_best_loss.ckpt') 516 | elif type == 'best_ged': 517 | init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(log_dir, 'model_best_ged.ckpt') 518 | elif type == 'iter': 519 | assert 'iteration' in kwargs, "argument 'iteration' must be provided for type='iter'" 520 | iteration = kwargs['iteration'] 521 | init_checkpoint_path = os.path.join(log_dir, 'model.ckpt-%d' % iteration) 522 | else: 523 | raise ValueError('Argument type=%s is unknown. type can be latest/best_dice/best_loss/best_ged/iter.' 
% type) 524 | 525 | self.saver.restore(self.sess, init_checkpoint_path) 526 | 527 | 528 | ### HELPER FUNCTIONS ################# 529 | 530 | def _do_validation(self, data): 531 | 532 | global_step = self.sess.run(self.global_step) - 1 533 | 534 | checkpoint_file = os.path.join(self.log_dir, 'model.ckpt') 535 | self.saver.save(self.sess, checkpoint_file, global_step=global_step) 536 | 537 | val_x, val_s = data.validation.next_batch(self.exp_config.batch_size) 538 | val_losses_out = self.sess.run(list(self.loss_dict.values()), 539 | feed_dict={self.x_inp: val_x, self.s_inp: val_s, self.training_pl: False} 540 | ) 541 | 542 | # Note that val_losses_out are now sorted in the same way as loss_dict, 543 | tot_loss_index = list(self.loss_dict.keys()).index('total_loss') 544 | val_loss_tot = val_losses_out[tot_loss_index] 545 | 546 | train_x, train_s = data.train.next_batch(self.exp_config.batch_size) 547 | train_losses_out = self.sess.run(list(self.loss_dict.values()), 548 | feed_dict={self.x_inp: train_x, self.s_inp: train_s, self.training_pl: False} 549 | ) 550 | 551 | 552 | logging.info('----- Step: %d ------' % global_step) 553 | logging.info('BATCH VALIDATION:') 554 | for ii, loss_name in enumerate(self.loss_dict.keys()): 555 | logging.info('%s | training: %f | validation: %f' % (loss_name, train_losses_out[ii], val_losses_out[ii])) 556 | 557 | # Evaluate validation Dice: 558 | 559 | start_dice_val = time.time() 560 | num_batches = 0 561 | 562 | dice_list = [] 563 | elbo_list = [] 564 | ged_list = [] 565 | ncc_list = [] 566 | 567 | N = data.validation.images.shape[0] if self.exp_config.num_validation_images == 'all' else self.exp_config.num_validation_images 568 | 569 | for ii in range(N): 570 | 571 | # logging.info(ii) 572 | 573 | x = data.validation.images[ii, ...].reshape([1] + list(self.exp_config.image_size)) 574 | s_gt_arr = data.validation.labels[ii, ...] 
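# data.validation.labels[ii] has shape (X, Y, num_labels_per_subject): a single annotator
# mask is drawn at random below for the Dice computation, while all annotator masks
# (s_gt_arr) are kept for the GED and NCC metrics further down.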
575 | s = s_gt_arr[:,:,np.random.choice(self.exp_config.annotator_range)] 576 | 577 | x_b = np.tile(x, [self.exp_config.validation_samples, 1, 1, 1]) 578 | s_b = np.tile(s, [self.exp_config.validation_samples, 1, 1]) 579 | 580 | feed_dict = {} 581 | feed_dict[self.training_pl] = False 582 | feed_dict[self.x_inp] = x_b 583 | feed_dict[self.s_inp] = s_b 584 | 585 | s_pred_sm_arr, elbo = self.sess.run([self.s_out_eval_sm, self.loss_tot], feed_dict=feed_dict) 586 | 587 | s_pred_sm_mean_ = np.mean(s_pred_sm_arr, axis=0) 588 | 589 | s_pred_arr = np.argmax(s_pred_sm_arr, axis=-1) 590 | s_gt_arr_r = s_gt_arr.transpose((2, 0, 1)) # num gts x X x Y 591 | 592 | s_gt_arr_r_sm = utils.convert_batch_to_onehot(s_gt_arr_r, self.exp_config.nlabels) # num gts x X x Y x nlabels 593 | 594 | ged = utils.generalised_energy_distance(s_pred_arr, s_gt_arr_r, 595 | nlabels=self.exp_config.nlabels-1, 596 | label_range=range(1, self.exp_config.nlabels)) 597 | 598 | ncc = utils.variance_ncc_dist(s_pred_sm_arr, s_gt_arr_r_sm) 599 | 600 | s_ = np.argmax(s_pred_sm_mean_, axis=-1) 601 | 602 | # Write losses to list 603 | per_lbl_dice = [] 604 | for lbl in range(self.exp_config.nlabels): 605 | binary_pred = (s_ == lbl) * 1 606 | binary_gt = (s == lbl) * 1 607 | 608 | if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0: 609 | per_lbl_dice.append(1) 610 | elif np.sum(binary_pred) > 0 and np.sum(binary_gt) == 0 or np.sum(binary_pred) == 0 and np.sum(binary_gt) > 0: 611 | per_lbl_dice.append(0) 612 | else: 613 | per_lbl_dice.append(dc(binary_pred, binary_gt)) 614 | 615 | num_batches += 1 616 | 617 | dice_list.append(per_lbl_dice) 618 | elbo_list.append(elbo) 619 | ged_list.append(ged) 620 | ncc_list.append(ncc) 621 | 622 | dice_arr = np.asarray(dice_list) 623 | per_structure_dice = dice_arr.mean(axis=0) 624 | 625 | avg_dice = np.mean(dice_arr) 626 | avg_elbo = utils.list_mean(elbo_list) 627 | avg_ged = utils.list_mean(ged_list) 628 | avg_ncc = utils.list_mean(ncc_list) 629 | 630 | logging.info('FULL VALIDATION (%d images):' % N) 631 | logging.info(' - Mean foreground dice: %.4f' % np.mean(per_structure_dice)) 632 | logging.info(' - Mean (neg.) ELBO: %.4f' % avg_elbo) 633 | logging.info(' - Mean GED: %.4f' % avg_ged) 634 | logging.info(' - Mean NCC: %.4f' % avg_ncc) 635 | 636 | logging.info('@ Running through validation set took: %.2f secs' % (time.time() - start_dice_val)) 637 | 638 | if np.mean(per_structure_dice) >= self.best_dice: 639 | self.best_dice = np.mean(per_structure_dice) 640 | logging.info('New best validation Dice! (%.3f)' % self.best_dice) 641 | best_file = os.path.join(self.log_dir, 'model_best_dice.ckpt') 642 | self.saver_best_dice.save(self.sess, best_file, global_step=global_step) 643 | 644 | if avg_elbo <= self.best_loss: 645 | self.best_loss = avg_elbo 646 | logging.info('New best validation loss! (%.3f)' % self.best_loss) 647 | best_file = os.path.join(self.log_dir, 'model_best_loss.ckpt') 648 | self.saver_best_loss.save(self.sess, best_file, global_step=global_step) 649 | 650 | if avg_ged <= self.best_ged: 651 | self.best_ged = avg_ged 652 | logging.info('New best GED score! (%.3f)' % self.best_ged) 653 | best_file = os.path.join(self.log_dir, 'model_best_ged.ckpt') 654 | self.saver_best_ged.save(self.sess, best_file, global_step=global_step) 655 | 656 | if avg_ncc >= self.best_ncc: 657 | self.best_ncc = avg_ncc 658 | logging.info('New best NCC score! 
(%.3f)' % self.best_ncc) 659 | best_file = os.path.join(self.log_dir, 'model_best_ncc.ckpt') 660 | self.saver_best_ncc.save(self.sess, best_file, global_step=global_step) 661 | 662 | # Create Validation Summary feed dict 663 | z_prior_list = self.generate_prior_samples(x_in=val_x) 664 | val_summary_feed_dict = {i: d for i, d in zip(self.z_list_gen, z_prior_list)} # this is for prior samples 665 | val_summary_feed_dict[self.x_for_gen] = val_x 666 | 667 | # Fill placeholders for all losses 668 | for loss_key, loss_val in zip(self.loss_dict.keys(), val_losses_out): 669 | # The detour over loss_dict.keys() is necessary because val_losses_out is sorted in the same 670 | # way as loss_dict. Same for the training below. 671 | loss_pl = self.validation_loss_pl_dict[loss_key] 672 | val_summary_feed_dict[loss_pl] = loss_val 673 | 674 | # Fill placeholders for validation Dice 675 | val_summary_feed_dict[self.val_tot_dice_score] = avg_dice 676 | val_summary_feed_dict[self.val_mean_dice_score] = np.mean(per_structure_dice) 677 | val_summary_feed_dict[self.val_elbo] = avg_elbo 678 | val_summary_feed_dict[self.val_ged] = avg_ged 679 | val_summary_feed_dict[self.val_ncc] = np.squeeze(avg_ncc) 680 | 681 | for ii in range(self.exp_config.nlabels): 682 | val_summary_feed_dict[self.val_lbl_dice_scores[ii]] = per_structure_dice[ii] 683 | 684 | val_summary_feed_dict[self.x_inp] = val_x 685 | val_summary_feed_dict[self.s_inp] = val_s 686 | val_summary_feed_dict[self.training_pl] = False 687 | 688 | val_summary_msg = self.sess.run(self.val_summary, feed_dict=val_summary_feed_dict) 689 | self.summary_writer.add_summary(val_summary_msg, global_step) 690 | 691 | # Create train Summary feed dict 692 | train_summary_feed_dict = {} 693 | for loss_key, loss_val in zip(self.loss_dict.keys(), train_losses_out): 694 | loss_pl = self.train_loss_pl_dict[loss_key] 695 | train_summary_feed_dict[loss_pl] = loss_val 696 | train_summary_feed_dict[self.training_pl] = False 697 | 698 | train_summary_msg = self.sess.run(self.train_summary, 699 | feed_dict=train_summary_feed_dict 700 | ) 701 | self.summary_writer.add_summary(train_summary_msg, global_step) 702 | 703 | 704 | def _make_tensorboard_summaries(self): 705 | 706 | def create_im_summary(img, name, rescale_mode, batch_size=self.exp_config.batch_size): 707 | 708 | if tfutils.tfndims(img) == 3: 709 | img_disp = tf.expand_dims(img, axis=-1) 710 | elif tfutils.tfndims(img) == 4: 711 | img_disp = img 712 | else: 713 | raise ValueError("Unexpected tensor ndim: %d" % tfutils.tfndims(img)) 714 | 715 | nlabels = self.exp_config.nlabels if rescale_mode == 'labelmap' else None 716 | return tf.summary.image(name, tfutils.put_kernels_on_grid(img_disp, batch_size=batch_size, rescale_mode=rescale_mode, nlabels=nlabels)) 717 | 718 | 719 | tf.summary.scalar('batch_total_loss', self.loss_tot) 720 | tf.summary.scalar('learning_rate', self.lr_pl) 721 | 722 | for ii, (mu, sigma) in enumerate(zip(self.mu_list, self.sigma_list)): 723 | tf.summary.scalar('average_mu_lvl%d' % ii, tf.reduce_mean(mu)) 724 | tf.summary.scalar('average_sigma_lvl%d' % ii, tf.reduce_mean(sigma)) 725 | tf.summary.scalar('average_prior_mu_lvl%d' % ii, tf.reduce_mean(self.prior_mu_list[ii])) 726 | tf.summary.scalar('average_prior_sigma_lvl%d' % ii, tf.reduce_mean(self.prior_sigma_list[ii])) 727 | 728 | if self.exp_config.do_image_summaries: 729 | 730 | create_im_summary(self.x_inp, 'train_x_inp', rescale_mode='standardize') 731 | create_im_summary(self.s_inp, 'train_s_inp', rescale_mode='labelmap') 732 | 
create_im_summary(tf.argmax(self.s_out, axis=-1), 'train_s_out', rescale_mode='labelmap') 733 | 734 | for ii in range(self.exp_config.latent_levels): 735 | create_im_summary(tf.argmax(self.s_out_list[ii], axis=-1), 'train_s_out_list_%d' % ii, rescale_mode='labelmap') 736 | create_im_summary(tf.argmax(self.s_accum[ii], axis=-1), 'train_s_accum_list_%d' % ii, rescale_mode='labelmap') 737 | 738 | 739 | # Build the summary Tensor based on the TF collection of Summaries. 740 | self.summary = tf.summary.merge_all() 741 | 742 | # Validation summaries 743 | self.validation_loss_pl_dict= {} 744 | val_summary_list = [] 745 | for loss_name in self.loss_dict.keys(): 746 | self.validation_loss_pl_dict[loss_name] = tf.placeholder(tf.float32, shape=[], name='val_%s' % loss_name) 747 | val_summary_list.append(tf.summary.scalar('val_batch_%s' % loss_name, self.validation_loss_pl_dict[loss_name])) 748 | 749 | self.val_summary = tf.summary.merge(val_summary_list) 750 | 751 | if self.exp_config.do_image_summaries: 752 | 753 | # Validation reconstructions 754 | 755 | val_img_sum = [] 756 | 757 | val_img_sum.append(create_im_summary(self.x_inp, 'val_x_inp', rescale_mode='standardize')) 758 | val_img_sum.append(create_im_summary(self.s_inp, 'val_s_inp', rescale_mode='labelmap')) 759 | val_img_sum.append(create_im_summary(tf.argmax(self.s_out, axis=-1), 'val_s_out', rescale_mode='labelmap')) 760 | 761 | for ii in range(self.exp_config.latent_levels): 762 | 763 | val_img_sum.append(create_im_summary(tf.argmax(self.s_out_list[ii], axis=-1), 'val_s_out_list_%d' % ii, rescale_mode='labelmap')) 764 | val_img_sum.append(create_im_summary(tf.argmax(self.s_accum[ii], axis=-1), 'val_s_accum_list_%d' % ii, rescale_mode='labelmap')) 765 | 766 | self.x_for_gen = tf.placeholder(tf.float32, shape=self.x_inp.shape) 767 | self.z_list_gen = [] 768 | for z in self.z_list: 769 | self.z_list_gen.append(tf.placeholder(tf.float32, shape=z.shape)) 770 | 771 | s_from_prior = tf.argmax(self.s_out_eval, axis=-1) 772 | 773 | val_img_sum.append(create_im_summary(s_from_prior, 'generated_seg', rescale_mode='labelmap')) 774 | val_img_sum.append(create_im_summary(self.x_for_gen, 'generated_x_in', rescale_mode='standardize')) 775 | 776 | self.val_summary = tf.summary.merge([self.val_summary, val_img_sum]) 777 | 778 | # Val Dice summaries 779 | 780 | self.val_tot_dice_score = tf.placeholder(tf.float32, shape=[], name='val_dice_total_score') 781 | val_tot_dice_summary = tf.summary.scalar('validation_dice_tot_score', self.val_tot_dice_score) 782 | 783 | self.val_mean_dice_score = tf.placeholder(tf.float32, shape=[], name='val_dice_mean_score') 784 | val_mean_dice_summary = tf.summary.scalar('validation_dice_mean_score', self.val_mean_dice_score) 785 | 786 | self.val_elbo = tf.placeholder(tf.float32, shape=[], name='val_elbo') 787 | val_elbo_summary = tf.summary.scalar('validation_neg_elbo', self.val_elbo) 788 | 789 | self.val_ged = tf.placeholder(tf.float32, shape=[], name='val_ged') 790 | val_ged_summary = tf.summary.scalar('validation_GED', self.val_ged) 791 | 792 | self.val_ncc = tf.placeholder(tf.float32, shape=[], name='val_ncc') 793 | val_ncc_summary = tf.summary.scalar('validation_NCC', self.val_ncc) 794 | 795 | 796 | self.val_lbl_dice_scores = [] 797 | val_lbl_dice_summaries = [] 798 | for ii in range(self.exp_config.nlabels): 799 | curr_pl = tf.placeholder(tf.float32, shape=[], name='validation_dice_lbl_%d' % ii) 800 | self.val_lbl_dice_scores.append(curr_pl) 801 | val_lbl_dice_summaries.append(tf.summary.scalar('validation_dice_lbl_%d' % 
ii, curr_pl)) 802 | 803 | self.val_summary = tf.summary.merge([self.val_summary, 804 | val_tot_dice_summary, 805 | val_mean_dice_summary, 806 | val_lbl_dice_summaries, 807 | val_elbo_summary, 808 | val_ged_summary, 809 | val_ncc_summary]) 810 | 811 | # Train summaries 812 | self.train_loss_pl_dict= {} 813 | train_summary_list = [] 814 | for loss_name in self.loss_dict.keys(): 815 | self.train_loss_pl_dict[loss_name] = tf.placeholder(tf.float32, shape=[], name='val_%s' % loss_name) 816 | train_summary_list.append(tf.summary.scalar('train_batch_%s' % loss_name, self.train_loss_pl_dict[loss_name])) 817 | 818 | self.train_summary = tf.summary.merge(train_summary_list) 819 | 820 | 821 | def _setup_log_dir_and_continue_mode(self): 822 | 823 | # Default values 824 | self.log_dir = os.path.join(sys_config.log_root, self.exp_config.log_dir_name, self.exp_config.experiment_name) 825 | self.init_checkpoint_path = None 826 | self.continue_run = False 827 | self.init_step = 0 828 | 829 | # If a checkpoint file already exists enable continue mode 830 | if tf.gfile.Exists(self.log_dir): 831 | init_checkpoint_path = tfutils.get_latest_model_checkpoint_path(self.log_dir, 'model.ckpt') 832 | if init_checkpoint_path is not False: 833 | 834 | self.init_checkpoint_path = init_checkpoint_path 835 | self.continue_run = True 836 | self.init_step = int(self.init_checkpoint_path.split('/')[-1].split('-')[-1]) 837 | self.log_dir += '_cont' 838 | 839 | logging.info('--------------------------- Continuing previous run --------------------------------') 840 | logging.info('Checkpoint path: %s' % self.init_checkpoint_path) 841 | logging.info('Latest step was: %d' % self.init_step) 842 | logging.info('------------------------------------------------------------------------------------') 843 | 844 | tf.gfile.MakeDirs(self.log_dir) 845 | self.summary_writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) 846 | 847 | # Copy experiment config file to log_dir for future reference 848 | # shutil.copy(self.exp_config.__file__, self.log_dir) 849 | # logging.info('!!!! Copied exp_config file to experiment folder !!!!') -------------------------------------------------------------------------------- /phiseg_generate_samples.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import logging 4 | import os 5 | from importlib.machinery import SourceFileLoader 6 | import cv2 7 | import argparse 8 | 9 | import numpy as np 10 | 11 | import config.system as sys_config 12 | import utils 13 | from data.data_switch import data_switch 14 | from phiseg.phiseg_model import phiseg 15 | 16 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 17 | 18 | # import matplotlib.pyplot as plt 19 | 20 | import matplotlib as mpl 21 | mpl.use('Agg') 22 | import matplotlib.pyplot as plt 23 | 24 | model_selection = 'best_ged' 25 | 26 | def preproc_image(x, nlabels=None): 27 | 28 | x_b = np.squeeze(x) 29 | 30 | ims = x_b.shape[:2] 31 | 32 | if nlabels: 33 | x_b = np.uint8((x_b / (nlabels)) * 255) # not nlabels - 1 because I prefer gray over white 34 | else: 35 | x_b = utils.convert_to_uint8(x_b) 36 | 37 | # x_b = cv2.cvtColor(np.squeeze(x_b), cv2.COLOR_GRAY2BGR) 38 | # x_b = utils.histogram_equalization(x_b) 39 | x_b = utils.resize_image(x_b, (2 * ims[0], 2 * ims[1]), interp=cv2.INTER_NEAREST) 40 | 41 | # ims_n = x_b.shape[:2] 42 | # x_b = x_b[ims_n[0]//4:3*ims_n[0]//4, ims_n[1]//4: 3*ims_n[1]//4,...] 
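    # (Worked example of the label rescaling above, assuming nlabels=4:
    # dividing by nlabels rather than nlabels - 1 maps labels {0, 1, 2, 3}
    # to uint8 values {0, 63, 127, 191}, so the highest label renders as
    # light gray rather than saturated white.)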
43 |     return x_b
44 | 
45 | 
46 | def generate_error_maps(sample_arr, gt_arr):
47 | 
48 |     def pixel_wise_xent(m_samp, m_gt, eps=1e-8):
49 | 
50 | 
51 |         log_samples = np.log(m_samp + eps)
52 |         return -1.0*np.sum(m_gt*log_samples, axis=-1)
53 | 
54 |     mean_seg = np.mean(sample_arr, axis=0)
55 | 
56 |     N = sample_arr.shape[0]
57 |     M = gt_arr.shape[0]
58 | 
59 |     sX = sample_arr.shape[1]
60 |     sY = sample_arr.shape[2]
61 | 
62 |     E_ss_arr = np.zeros((N,sX,sY))
63 |     for i in range(N):
64 |         E_ss_arr[i,...] = pixel_wise_xent(sample_arr[i,...], mean_seg)
65 | 
66 |     E_ss = np.mean(E_ss_arr, axis=0)
67 | 
68 |     E_sy_arr = np.zeros((M,N, sX, sY))
69 |     for j in range(M):
70 |         for i in range(N):
71 |             E_sy_arr[j,i, ...] = pixel_wise_xent(sample_arr[i,...], gt_arr[j,...])
72 | 
73 |     E_sy_avg = np.mean(np.mean(E_sy_arr, axis=1), axis=0)
74 | 
75 |     E_yy_arr = np.zeros((M,M, sX, sY))
76 |     for j in range(M):
77 |         for i in range(M):
78 |             E_yy_arr[j,i, ...] = pixel_wise_xent(gt_arr[i,...], gt_arr[j,...])  # distance between pairs of ground-truth annotations
79 | 
80 |     E_yy_avg = np.mean(np.mean(E_yy_arr, axis=1), axis=0)
81 | 
82 |     return E_ss, E_sy_avg, E_yy_avg
83 | 
84 | 
85 | 
86 | def main(model_path, exp_config):
87 | 
88 |     # Make and restore phiseg model
89 |     phiseg_model = phiseg(exp_config=exp_config)
90 |     phiseg_model.load_weights(model_path, type=model_selection)
91 | 
92 |     data_loader = data_switch(exp_config.data_identifier)
93 |     data = data_loader(exp_config)
94 | 
95 |     N = data.test.images.shape[0]
96 | 
97 |     n_images = 16
98 |     n_samples = 16
99 | 
100 |     # indices = np.arange(N)
101 |     # sample_inds = np.random.choice(indices, n_images)
102 |     sample_inds = [165, 280, 213] # <-- prostate
103 |     # sample_inds = [1551] #[907, 1296, 1551] # <-- LIDC
104 | 
105 |     for ii in sample_inds:
106 | 
107 |         print('------- Processing image %d -------' % ii)
108 | 
109 |         outfolder = os.path.join(model_path, 'samples_%s' % model_selection, str(ii))
110 |         utils.makefolder(outfolder)
111 | 
112 |         x_b = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
113 |         s_b = data.test.labels[ii, ...]
114 | 
115 |         if np.sum(s_b) < 10:
116 |             print('WARNING: skipping cases with no structures')
117 |             continue
118 | 
119 |         s_b_r = utils.convert_batch_to_onehot(s_b.transpose((2, 0, 1)), exp_config.nlabels)
120 | 
121 |         print('Plotting input image')
122 |         plt.figure()
123 |         x_b_d = preproc_image(x_b)
124 |         plt.imshow(x_b_d, cmap='gray')
125 |         plt.axis('off')
126 |         plt.savefig(os.path.join(outfolder, 'input_img_%d.png' % ii), bbox_inches='tight')
127 | 
128 |         print('Generating 100 samples')
129 |         s_p_list = []
130 |         for kk in range(100):
131 |             s_p_list.append(phiseg_model.predict_segmentation_sample(x_b, return_softmax=True))
132 |         s_p_arr = np.squeeze(np.asarray(s_p_list))
133 | 
134 | 
135 |         print('Plotting %d of those samples' % n_samples)
136 |         for jj in range(n_samples):
137 | 
138 |             s_p_sm = s_p_arr[jj,...]
139 |             s_p_am = np.argmax(s_p_sm, axis=-1)
140 | 
141 |             plt.figure()
142 |             s_p_d = preproc_image(s_p_am, nlabels=exp_config.nlabels)
143 |             plt.imshow(s_p_d, cmap='gray')
144 |             plt.axis('off')
145 |             plt.savefig(os.path.join(outfolder, 'sample_img_%d_samp_%d.png' % (ii,jj)), bbox_inches='tight')
146 | 
147 |         print('Plotting ground-truth masks')
148 |         for jj in range(s_b_r.shape[0]):
149 | 
150 |             s_b_sm = s_b_r[jj,...]
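            # (After the one-hot conversion above, s_b_r has shape
            # [num_annotators, X, Y, nlabels], so s_b_sm is the one-hot mask
            # of annotator jj and the argmax below recovers the integer label
            # map for plotting.)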
151 | s_b_am = np.argmax(s_b_sm, axis=-1) 152 | 153 | plt.figure() 154 | s_p_d = preproc_image(s_b_am, nlabels=exp_config.nlabels) 155 | plt.imshow(s_p_d, cmap='gray') 156 | plt.axis('off') 157 | plt.savefig(os.path.join(outfolder, 'gt_img_%d_samp_%d.png' % (ii,jj)),bbox_inches='tight') 158 | 159 | print('Generating error masks') 160 | E_ss, E_sy_avg, E_yy_avg = generate_error_maps(s_p_arr, s_b_r) 161 | 162 | print('Plotting them') 163 | plt.figure() 164 | plt.imshow(preproc_image(E_ss)) 165 | plt.axis('off') 166 | plt.savefig(os.path.join(outfolder, 'E_ss_%d.png' % ii), bbox_inches='tight') 167 | 168 | print('Plotting them') 169 | plt.figure() 170 | plt.imshow(preproc_image(np.log(E_ss))) 171 | plt.axis('off') 172 | plt.savefig(os.path.join(outfolder, 'log_E_ss_%d.png' % ii), bbox_inches='tight') 173 | 174 | 175 | plt.figure() 176 | plt.imshow(preproc_image(E_sy_avg)) 177 | plt.axis('off') 178 | plt.savefig(os.path.join(outfolder, 'E_sy_avg_%d_.png' % ii), bbox_inches='tight') 179 | 180 | plt.figure() 181 | plt.imshow(preproc_image(E_yy_avg)) 182 | plt.axis('off') 183 | plt.savefig(os.path.join(outfolder, 'E_yy_avg_%d_.png' % ii), bbox_inches='tight') 184 | 185 | plt.close('all') 186 | 187 | # plt.show() 188 | 189 | if __name__ == '__main__': 190 | 191 | parser = argparse.ArgumentParser( 192 | description="Script for a simple test loop evaluating a network on the test dataset") 193 | parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)") 194 | args = parser.parse_args() 195 | 196 | base_path = sys_config.project_root 197 | 198 | model_path = args.EXP_PATH 199 | config_file = glob.glob(model_path + '/*py')[0] 200 | config_module = config_file.split('/')[-1].rstrip('.py') 201 | 202 | exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module() 203 | 204 | main(model_path, exp_config=exp_config) 205 | -------------------------------------------------------------------------------- /phiseg_makegif_samples.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import logging 4 | import os 5 | from importlib.machinery import SourceFileLoader 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | import config.system as sys_config 11 | import utils 12 | from data.data_switch import data_switch 13 | from phiseg.phiseg_model import phiseg 14 | # import scipy.misc 15 | from PIL import Image 16 | 17 | 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 19 | 20 | def softmax(x): 21 | """Compute softmax values for each sets of scores in x.""" 22 | return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True) 23 | 24 | 25 | SAVE_VIDEO = False 26 | SAVE_GIF = True 27 | DISPLAY_VIDEO = True 28 | 29 | video_target_size = (256, 256) 30 | 31 | 32 | def histogram_equalization(img): 33 | 34 | lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 35 | 36 | # -----Splitting the LAB image to different channels------------------------- 37 | l, a, b = cv2.split(lab) 38 | 39 | # -----Applying CLAHE to L-channel------------------------------------------- 40 | clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)) 41 | cl = clahe.apply(l) 42 | 43 | # -----Merge the CLAHE enhanced L-channel with the a and b channel----------- 44 | limg = cv2.merge((cl, a, b)) 45 | 46 | # -----Converting image from LAB Color model to RGB model-------------------- 47 | final = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR) 48 | 49 | return final 50 | 51 | 52 | def main(model_path, 
exp_config): 53 | 54 | # Make and restore vagan model 55 | phiseg_model = phiseg(exp_config=exp_config) 56 | phiseg_model.load_weights(model_path, type='best_ged') 57 | 58 | data_loader = data_switch(exp_config.data_identifier) 59 | data = data_loader(exp_config) 60 | 61 | lat_lvls = exp_config.latent_levels 62 | 63 | # RANDOM IMAGE 64 | # x_b, s_b = data.test.next_batch(1) 65 | 66 | # FIXED IMAGE 67 | # Cardiac: 100 normal image 68 | # LIDC: 200 large lesion, 203, 1757 complicated lesion 69 | # Prostate: 165 nice slice, 170 is a challenging and interesting slice 70 | index = 165 # # 71 | 72 | if SAVE_GIF: 73 | outfolder_gif = os.path.join(model_path, 'model_samples_id%d_gif' % index) 74 | utils.makefolder(outfolder_gif) 75 | 76 | x_b = data.test.images[index,...].reshape([1]+list(exp_config.image_size)) 77 | 78 | x_b_d = utils.convert_to_uint8(np.squeeze(x_b)) 79 | x_b_d = utils.resize_image(x_b_d, video_target_size) 80 | 81 | if exp_config.data_identifier == 'uzh_prostate': 82 | # rotate 83 | rows, cols = x_b_d.shape 84 | M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 270, 1) 85 | x_b_d = cv2.warpAffine(x_b_d, M, (cols, rows)) 86 | 87 | if SAVE_VIDEO: 88 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 89 | outfile = os.path.join(model_path, 'model_samples_id%d.avi' % index) 90 | out = cv2.VideoWriter(outfile, fourcc, 5.0, (2*video_target_size[1], video_target_size[0])) 91 | 92 | samps = 20 93 | for ii in range(samps): 94 | 95 | # fix all below current level (the correct implementation) 96 | feed_dict = {} 97 | feed_dict[phiseg_model.training_pl] = False 98 | feed_dict[phiseg_model.x_inp] = x_b 99 | 100 | s_p, s_p_list = phiseg_model.sess.run([phiseg_model.s_out_eval, phiseg_model.s_out_eval_list], feed_dict=feed_dict) 101 | s_p = np.argmax(s_p, axis=-1) 102 | 103 | # s_p_d = utils.convert_to_uint8(np.squeeze(s_p)) 104 | s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels)*255)) 105 | s_p_d = utils.resize_image(s_p_d, video_target_size, interp=cv2.INTER_NEAREST) 106 | 107 | if exp_config.data_identifier == 'uzh_prostate': 108 | #rotate 109 | s_p_d = cv2.warpAffine(s_p_d, M, (cols, rows)) 110 | 111 | img = np.concatenate([x_b_d, s_p_d], axis=1) 112 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 113 | 114 | img = histogram_equalization(img) 115 | 116 | if exp_config.data_identifier == 'acdc': 117 | # labels (0 85 170 255) 118 | rv = cv2.inRange(s_p_d, 84, 86) 119 | my = cv2.inRange(s_p_d, 169, 171) 120 | rv_cnt, hierarchy = cv2.findContours(rv, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 121 | my_cnt, hierarchy = cv2.findContours(my, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 122 | 123 | cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1) 124 | cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1) 125 | if exp_config.data_identifier == 'uzh_prostate': 126 | 127 | print(np.unique(s_p_d)) 128 | s1 = cv2.inRange(s_p_d, 84, 86) 129 | s2 = cv2.inRange(s_p_d, 169, 171) 130 | # s3 = cv2.inRange(s_p_d, 190, 192) 131 | s1_cnt, hierarchy = cv2.findContours(s1, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 132 | s2_cnt, hierarchy = cv2.findContours(s2, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 133 | # s3_cnt, hierarchy = cv2.findContours(s3, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 134 | 135 | cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1) 136 | cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1) 137 | # cv2.drawContours(img, s3_cnt, -1, (255, 0, 255), 1) 138 | elif exp_config.data_identifier == 'lidc': 139 | thresh = cv2.inRange(s_p_d, 127, 255) 140 | lesion, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, 
cv2.CHAIN_APPROX_SIMPLE) 141 | cv2.drawContours(img, lesion, -1, (0, 255, 0), 1) 142 | 143 | 144 | if SAVE_VIDEO: 145 | out.write(img) 146 | 147 | if SAVE_GIF: 148 | outfile_gif = os.path.join(outfolder_gif, 'frame_%s.png' % str(ii).zfill(3)) 149 | img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 150 | # scipy.misc.imsave(outfile_gif, img_rgb) 151 | im = Image.fromarray(img_rgb) 152 | im = im.resize((im.size[0]*2, im.size[1]*2), Image.ANTIALIAS) 153 | 154 | im.save(outfile_gif) 155 | 156 | if DISPLAY_VIDEO: 157 | cv2.imshow('frame', img) 158 | if cv2.waitKey(1) & 0xFF == ord('q'): 159 | break 160 | 161 | if SAVE_VIDEO: 162 | out.release() 163 | cv2.destroyAllWindows() 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | base_path = sys_config.project_root 169 | 170 | # Code for selecting experiment from command line 171 | # parser = argparse.ArgumentParser( 172 | # description="Script for a simple test loop evaluating a network on the test dataset") 173 | # parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)") 174 | # args = parser.parse_args() 175 | 176 | 177 | # exp_path = args.EXP_PATH 178 | 179 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/lidc/segvae_7_5' 180 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/lidc/probunet' 181 | # 182 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/segvae_7_5_1annot' 183 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/segvae_7_5' 184 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/probUNET_1annotator_2' 185 | exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/segvae_7_5_batchnorm_rerun' 186 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/segvae_7_5_batchnorm_schedule' 187 | # exp_path = '/itet-stor/baumgach/net_scratch/logs/segvae/uzh_prostate_afterpaper/probUNET' 188 | 189 | 190 | 191 | 192 | model_path = exp_path 193 | config_file = glob.glob(model_path + '/*py')[0] 194 | config_module = config_file.split('/')[-1].rstrip('.py') 195 | 196 | exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module() 197 | 198 | main(model_path, exp_config=exp_config) 199 | -------------------------------------------------------------------------------- /phiseg_sample_construction.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import logging 4 | import os 5 | from importlib.machinery import SourceFileLoader 6 | import cv2 7 | 8 | import numpy as np 9 | 10 | import config.system as sys_config 11 | import utils 12 | from data.data_switch import data_switch 13 | from phiseg.phiseg_model import phiseg 14 | 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | import itertools 20 | def findsubsets(S,m): 21 | return list(itertools.combinations(S, m)) 22 | 23 | def main(model_path, exp_config): 24 | 25 | # Make and restore vagan model 26 | phiseg_model = phiseg(exp_config=exp_config) 27 | phiseg_model.load_weights(model_path, type='best_dice') 28 | 29 | data_loader = data_switch(exp_config.data_identifier) 30 | data = data_loader(exp_config) 31 | 32 | outfolder = '/home/baumgach/Reports/ETH/MICCAI2019_segvae/raw_figures' 33 | 34 | ims = exp_config.image_size 35 | 36 | # x_b, s_b = data.test.next_batch(1) 37 | 38 | # heart 100 39 | # prostate 165 40 | 
index = 165 # 100 is a normal image, 15 is a very good slice 41 | x_b = data.test.images[index, ...].reshape([1] + list(exp_config.image_size)) 42 | if exp_config.data_identifier == 'lidc': 43 | s_b = data.test.labels[index, ...] 44 | if np.sum(s_b[..., 0]) > 0: 45 | s_b = s_b[..., 0] 46 | elif np.sum(s_b[..., 1]) > 0: 47 | s_b = s_b[..., 1] 48 | elif np.sum(s_b[..., 2]) > 0: 49 | s_b = s_b[..., 2] 50 | else: 51 | s_b = s_b[..., 3] 52 | 53 | s_b = s_b.reshape([1] + list(exp_config.image_size[0:2])) 54 | elif exp_config.data_identifier == 'uzh_prostate': 55 | s_b = data.test.labels[index, ...] 56 | s_b = s_b[..., 0] 57 | s_b = s_b.reshape([1] + list(exp_config.image_size[0:2])) 58 | else: 59 | s_b = data.test.labels[index, ...].reshape([1] + list(exp_config.image_size[0:2])) 60 | 61 | 62 | 63 | x_b_for_cnt = utils.convert_to_uint8(np.squeeze(x_b.copy())) 64 | x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_GRAY2BGR) 65 | 66 | x_b_for_cnt = utils.resize_image(x_b_for_cnt, (2*ims[0], 2*ims[1]), interp=cv2.INTER_NEAREST) 67 | x_b_for_cnt = utils.histogram_equalization(x_b_for_cnt) 68 | 69 | for ss in range(3): 70 | 71 | print(ss) 72 | 73 | s_p_list = phiseg_model.predict_segmentation_sample_levels(x_b, return_softmax=False) 74 | 75 | accum_list = [None]*exp_config.latent_levels 76 | accum_list[exp_config.latent_levels-1] = s_p_list[-1] 77 | for lvl in reversed(range(exp_config.latent_levels-1)): 78 | accum_list[lvl] = accum_list[lvl+1] + s_p_list[lvl] 79 | 80 | print('Plotting accum_list') 81 | for ii, img in enumerate(accum_list): 82 | 83 | plt.figure() 84 | img = utils.resize_image(np.squeeze(np.argmax(img, axis=-1)), (2*ims[0], 2*ims[1]), interp=cv2.INTER_NEAREST) 85 | plt.imshow(img[2*30:2*192-2*30,2*30:2*192-2*30], cmap='gray') 86 | plt.axis('off') 87 | plt.savefig(os.path.join(outfolder, 'segm_lvl_%d_samp_%d.png' % (ii, ss)),bbox_inches='tight') 88 | 89 | print('Plotting s_p_list') 90 | for ii, img in enumerate(s_p_list): 91 | 92 | img = utils.softmax(img) 93 | 94 | plt.figure() 95 | img = utils.resize_image(np.squeeze(img[...,1]), (2*ims[0], 2*ims[1]), interp=cv2.INTER_NEAREST) 96 | plt.imshow(img[2*30:2*192-2*30,2*30:2*192-2*30], cmap='gray') 97 | plt.axis('off') 98 | plt.savefig(os.path.join(outfolder, 'residual_lvl_%d_samp_%d.png' % (ii, ss)),bbox_inches='tight') 99 | 100 | s_p_d = np.uint8((np.squeeze(np.argmax(accum_list[0], axis=-1)) / (exp_config.nlabels-1)) * 255) 101 | s_p_d = utils.resize_image(s_p_d, (2*ims[0], 2*ims[1]), interp=cv2.INTER_NEAREST) 102 | 103 | print('Calculating contours') 104 | print(np.unique(s_p_d)) 105 | rv = cv2.inRange(s_p_d, 84, 86) 106 | my = cv2.inRange(s_p_d, 169, 171) 107 | rv_cnt, hierarchy = cv2.findContours(rv, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 108 | my_cnt, hierarchy = cv2.findContours(my, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 109 | 110 | x_b_for_cnt = cv2.drawContours(x_b_for_cnt, rv_cnt, -1, (0, 255, 0), 1) 111 | x_b_for_cnt = cv2.drawContours(x_b_for_cnt, my_cnt, -1, (0, 0, 255), 1) 112 | 113 | x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_BGR2RGB) 114 | 115 | print('Plotting final images...') 116 | plt.figure() 117 | plt.imshow(x_b_for_cnt[2*30:2*192-2*30,2*30:2*192-2*30,:], cmap='gray') 118 | plt.axis('off') 119 | plt.savefig(os.path.join(outfolder, 'input_img_cnts.png'),bbox_inches='tight') 120 | 121 | plt.figure() 122 | x_b = utils.convert_to_uint8(x_b) 123 | x_b = cv2.cvtColor(np.squeeze(x_b), cv2.COLOR_GRAY2BGR) 124 | x_b = utils.histogram_equalization(x_b) 125 | x_b = utils.resize_image(x_b, (2*ims[0], 2*ims[1]), 
interp=cv2.INTER_NEAREST) 126 | plt.imshow(x_b[2*30:2*192-2*30,2*30:2*192-2*30], cmap='gray') 127 | plt.axis('off') 128 | plt.savefig(os.path.join(outfolder, 'input_img.png'),bbox_inches='tight') 129 | 130 | plt.figure() 131 | s_b = utils.resize_image(np.squeeze(s_b), (2*ims[0], 2*ims[1]), interp=cv2.INTER_NEAREST) 132 | plt.imshow(s_b[2*30:2*192-2*30,2*30:2*192-2*30], cmap='gray') 133 | plt.axis('off') 134 | plt.savefig(os.path.join(outfolder, 'gt_seg.png'),bbox_inches='tight') 135 | 136 | 137 | # plt.show() 138 | 139 | 140 | if __name__ == '__main__': 141 | 142 | base_path = sys_config.project_root 143 | 144 | # Code for selecting experiment from command line 145 | # parser = argparse.ArgumentParser( 146 | # description="Script for a simple test loop evaluating a network on the test dataset") 147 | # parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)") 148 | # args = parser.parse_args() 149 | 150 | 151 | # exp_path = args.EXP_PATH 152 | 153 | # 154 | exp_path = '/itet-stor/baumgach/net_scratch/logs/phiseg/uzh_prostate/phiseg_7_5' 155 | 156 | model_path = os.path.join(base_path, exp_path) 157 | config_file = glob.glob(model_path + '/*py')[0] 158 | config_module = config_file.split('/')[-1].rstrip('.py') 159 | 160 | exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module() 161 | 162 | main(model_path, exp_config=exp_config) 163 | -------------------------------------------------------------------------------- /phiseg_test_predictions.py: -------------------------------------------------------------------------------- 1 | # Get classification metrics for a trained classifier model 2 | # Authors: 3 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com) 4 | 5 | from phiseg.model_zoo import likelihoods 6 | import numpy as np 7 | import os 8 | import glob 9 | from importlib.machinery import SourceFileLoader 10 | import argparse 11 | from medpy.metric import dc 12 | 13 | import config.system as sys_config 14 | from phiseg.phiseg_model import phiseg 15 | import utils 16 | 17 | if not sys_config.running_on_gpu_host: 18 | import matplotlib.pyplot as plt 19 | 20 | import logging 21 | from data.data_switch import data_switch 22 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 23 | 24 | # structures_dict = {1: 'RV', 2: 'Myo', 3: 'LV'} 25 | 26 | model_selection = 'best_dice' 27 | 28 | def main(model_path, exp_config, do_plots=False): 29 | 30 | # Get Data 31 | phiseg_model = phiseg(exp_config=exp_config) 32 | phiseg_model.load_weights(model_path, type=model_selection) 33 | 34 | data_loader = data_switch(exp_config.data_identifier) 35 | data = data_loader(exp_config) 36 | 37 | # Run predictions in an endless loop 38 | dice_list = [] 39 | 40 | num_samples = 1 if exp_config.likelihood is likelihoods.det_unet2D else 100 41 | 42 | for ii, batch in enumerate(data.test.iterate_batches(1)): 43 | 44 | if ii % 10 == 0: 45 | logging.info("Progress: %d" % ii) 46 | 47 | # print(ii) 48 | 49 | x, y = batch 50 | 51 | y_ = np.squeeze(phiseg_model.predict(x, num_samples=num_samples)) 52 | 53 | per_lbl_dice = [] 54 | per_pixel_preds = [] 55 | per_pixel_gts = [] 56 | 57 | if do_plots and not sys_config.running_on_gpu_host: 58 | fig = plt.figure() 59 | fig.add_subplot(131) 60 | plt.imshow(np.squeeze(x), cmap='gray') 61 | fig.add_subplot(132) 62 | plt.imshow(np.squeeze(y_)) 63 | fig.add_subplot(133) 64 | plt.imshow(np.squeeze(y)) 65 | plt.show() 66 | 67 | for lbl in range(exp_config.nlabels): 68 | 69 | 
binary_pred = (y_ == lbl) * 1 70 | binary_gt = (y == lbl) * 1 71 | 72 | if np.sum(binary_gt) == 0 and np.sum(binary_pred) == 0: 73 | per_lbl_dice.append(1) 74 | elif np.sum(binary_pred) > 0 and np.sum(binary_gt) == 0 or np.sum(binary_pred) == 0 and np.sum(binary_gt) > 0: 75 | logging.warning('Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.') 76 | per_lbl_dice.append(0) 77 | else: 78 | per_lbl_dice.append(dc(binary_pred, binary_gt)) 79 | 80 | dice_list.append(per_lbl_dice) 81 | 82 | per_pixel_preds.append(y_.flatten()) 83 | per_pixel_gts.append(y.flatten()) 84 | 85 | dice_arr = np.asarray(dice_list) 86 | 87 | mean_per_lbl_dice = dice_arr.mean(axis=0) 88 | 89 | logging.info('Dice') 90 | logging.info(mean_per_lbl_dice) 91 | logging.info(np.mean(mean_per_lbl_dice)) 92 | logging.info('foreground mean: %f' % (np.mean(mean_per_lbl_dice[1:]))) 93 | 94 | np.savez(os.path.join(model_path, 'dice_%s.npz' % model_selection), dice_arr) 95 | 96 | 97 | if __name__ == '__main__': 98 | 99 | parser = argparse.ArgumentParser( 100 | description="Script for a simple test loop evaluating a network on the test dataset") 101 | parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)") 102 | args = parser.parse_args() 103 | 104 | base_path = sys_config.project_root 105 | 106 | model_path = args.EXP_PATH 107 | config_file = glob.glob(model_path + '/*py')[0] 108 | config_module = config_file.split('/')[-1].rstrip('.py') 109 | 110 | exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module() 111 | 112 | main(model_path, exp_config=exp_config, do_plots=False) 113 | 114 | -------------------------------------------------------------------------------- /phiseg_test_quantitative.py: -------------------------------------------------------------------------------- 1 | # Get classification metrics for a trained classifier model 2 | # Authors: 3 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com) 4 | 5 | 6 | import numpy as np 7 | import os 8 | import glob 9 | from importlib.machinery import SourceFileLoader 10 | import argparse 11 | 12 | import config.system as sys_config 13 | from phiseg.phiseg_model import phiseg 14 | import utils 15 | 16 | import logging 17 | from data.data_switch import data_switch 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 19 | 20 | structures_dict = {1: 'RV', 2: 'Myo', 3: 'LV'} 21 | 22 | def main(model_path, exp_config, do_plots=False): 23 | 24 | n_samples = 50 25 | model_selection = 'best_ged' 26 | 27 | # Get Data 28 | phiseg_model = phiseg(exp_config=exp_config) 29 | phiseg_model.load_weights(model_path, type=model_selection) 30 | 31 | data_loader = data_switch(exp_config.data_identifier) 32 | data = data_loader(exp_config) 33 | 34 | N = data.test.images.shape[0] 35 | 36 | ged_list = [] 37 | ncc_list = [] 38 | 39 | for ii in range(N): 40 | 41 | if ii % 10 == 0: 42 | logging.info("Progress: %d" % ii) 43 | 44 | x_b = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size)) 45 | s_b = data.test.labels[ii, ...] 
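        # (A minimal shape sketch of the batching trick below, with an assumed
        # 128x128 single-channel image size:
        #     np.tile(np.zeros((1, 128, 128, 1)), [50, 1, 1, 1]).shape
        #     # -> (50, 128, 128, 1)
        # Tiling x_b along the batch axis lets a single sess.run draw all
        # n_samples segmentation samples in one batched forward pass instead
        # of looping over individual samples.)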
46 | 47 | x_b_stacked = np.tile(x_b, [n_samples, 1, 1, 1]) 48 | 49 | feed_dict = {} 50 | feed_dict[phiseg_model.training_pl] = False 51 | feed_dict[phiseg_model.x_inp] = x_b_stacked 52 | 53 | 54 | s_arr_sm = phiseg_model.sess.run(phiseg_model.s_out_eval_sm, feed_dict=feed_dict) 55 | s_arr = np.argmax(s_arr_sm, axis=-1) 56 | 57 | # s_arr = np.squeeze(np.asarray(s_list)) # num samples x X x Y 58 | s_b_r = s_b.transpose((2,0,1)) # num gts x X x Y 59 | s_b_r_sm = utils.convert_batch_to_onehot(s_b_r, exp_config.nlabels) # num gts x X x Y x nlabels 60 | 61 | ged = utils.generalised_energy_distance(s_arr, s_b_r, nlabels=exp_config.nlabels-1, label_range=range(1,exp_config.nlabels)) 62 | ged_list.append(ged) 63 | 64 | ncc = utils.variance_ncc_dist(s_arr_sm, s_b_r_sm) 65 | ncc_list.append(ncc) 66 | 67 | 68 | 69 | ged_arr = np.asarray(ged_list) 70 | ncc_arr = np.asarray(ncc_list) 71 | 72 | logging.info('-- GED: --') 73 | logging.info(np.mean(ged_arr)) 74 | logging.info(np.std(ged_arr)) 75 | 76 | logging.info('-- NCC: --') 77 | logging.info(np.mean(ncc_arr)) 78 | logging.info(np.std(ncc_arr)) 79 | 80 | np.savez(os.path.join(model_path, 'ged%s_%s.npz' % (str(n_samples), model_selection)), ged_arr) 81 | np.savez(os.path.join(model_path, 'ncc%s_%s.npz' % (str(n_samples), model_selection)), ncc_arr) 82 | 83 | 84 | if __name__ == '__main__': 85 | 86 | parser = argparse.ArgumentParser( 87 | description="Script for a simple test loop evaluating a network on the test dataset") 88 | parser.add_argument("EXP_PATH", type=str, help="Path to experiment folder (assuming you are in the working directory)") 89 | args = parser.parse_args() 90 | 91 | base_path = sys_config.project_root 92 | 93 | model_path = args.EXP_PATH 94 | config_file = glob.glob(model_path + '/*py')[0] 95 | config_module = config_file.split('/')[-1].rstrip('.py') 96 | 97 | exp_config = SourceFileLoader(config_module, os.path.join(config_file)).load_module() 98 | 99 | main(model_path, exp_config=exp_config, do_plots=False) 100 | 101 | -------------------------------------------------------------------------------- /phiseg_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from importlib.machinery import SourceFileLoader 4 | import argparse 5 | 6 | from data.data_switch import data_switch 7 | import os 8 | import config.system as sys_config 9 | import shutil 10 | import utils 11 | 12 | from phiseg import phiseg_model 13 | 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') 15 | 16 | def main(exp_config): 17 | 18 | logging.info('**************************************************************') 19 | logging.info(' *** Running Experiment: %s', exp_config.experiment_name) 20 | logging.info('**************************************************************') 21 | 22 | # Get Data 23 | data_loader = data_switch(exp_config.data_identifier) 24 | data = data_loader(exp_config) 25 | 26 | # Create Model 27 | phiseg = phiseg_model.phiseg(exp_config) 28 | 29 | # Fit model to data 30 | phiseg.train(data) 31 | 32 | if __name__ == '__main__': 33 | 34 | parser = argparse.ArgumentParser( 35 | description="Script for training") 36 | parser.add_argument("EXP_PATH", type=str, help="Path to experiment config file") 37 | args = parser.parse_args() 38 | 39 | config_file = args.EXP_PATH 40 | config_module = config_file.split('/')[-1].rstrip('.py') 41 | 42 | exp_config = SourceFileLoader(config_module, config_file).load_module() 43 | 44 | log_dir = os.path.join(sys_config.log_root, 
exp_config.log_dir_name, exp_config.experiment_name) 45 | utils.makefolder(log_dir) 46 | 47 | shutil.copy(exp_config.__file__, log_dir) 48 | logging.info('!!!! Copied exp_config file to experiment folder !!!!') 49 | 50 | main(exp_config=exp_config) 51 | 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.2.2 2 | astor==0.6.2 3 | backcall==0.1.0 4 | backports-abc==0.5 5 | bleach==1.5.0 6 | cloudpickle==0.5.3 7 | cycler==0.10.0 8 | Cython==0.28.3 9 | dask==0.18.1 10 | decorator==4.3.0 11 | entrypoints==0.2.3 12 | gast==0.2.0 13 | grpcio==1.13.0 14 | h5py==2.8.0 15 | html5lib==0.9999999 16 | ipykernel==4.9.0 17 | ipython==6.4.0 18 | ipython-genutils==0.2.0 19 | ipywidgets==7.4.1 20 | jedi==0.12.0 21 | Jinja2==2.10 22 | jsonschema==2.6.0 23 | jupyter==1.0.0 24 | jupyter-client==5.2.3 25 | jupyter-console==5.2.0 26 | jupyter-core==4.4.0 27 | kiwisolver==1.0.1 28 | Markdown==2.6.11 29 | MarkupSafe==1.0 30 | matplotlib==2.2.2 31 | -e git+https://github.com/lmkoch/medpy/@b06b6decf41c63489e746f6a83e8fa5ff509adfa#egg=MedPy 32 | mistune==0.8.3 33 | nbconvert==5.3.1 34 | nbformat==4.4.0 35 | networkx==2.1 36 | nibabel==2.3.0 37 | notebook==5.6.0 38 | numpy==1.14.5 39 | opencv-contrib-python==3.4.3.18 40 | pandas==0.20.3 41 | pandocfilters==1.4.2 42 | parso==0.2.1 43 | pexpect==4.6.0 44 | pickleshare==0.7.4 45 | Pillow==5.1.0 46 | prometheus-client==0.3.1 47 | prompt-toolkit==1.0.15 48 | protobuf==3.6.0 49 | ptyprocess==0.6.0 50 | pydicom==1.1.0 51 | Pygments==2.2.0 52 | pynrrd==0.2.4 53 | pyparsing==2.2.0 54 | python-dateutil==2.7.3 55 | pytz==2018.5 56 | PyWavelets==0.5.2 57 | pyzmq==17.1.2 58 | qtconsole==4.4.1 59 | scikit-image==0.14.0 60 | scikit-learn==0.19.1 61 | scipy==1.1.0 62 | seaborn==0.9.0 63 | Send2Trash==1.5.0 64 | simplegeneric==0.8.1 65 | six==1.11.0 66 | termcolor==1.1.0 67 | terminado==0.8.1 68 | testpath==0.3.1 69 | toolz==0.9.0 70 | tornado==5.1 71 | tqdm==4.23.4 72 | traitlets==4.3.2 73 | typing==3.6.4 74 | wcwidth==0.1.7 75 | Werkzeug==0.14.1 76 | widgetsnbextension==3.4.1 77 | -------------------------------------------------------------------------------- /tfwrapper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baumgach/PHiSeg-code/c43f3b32e1f434aecba936ff994b6f743ba7a5f8/tfwrapper/__init__.py -------------------------------------------------------------------------------- /tfwrapper/activations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def leaky_relu(x, alpha=0.01): 4 | return tf.maximum(x, alpha * x) 5 | -------------------------------------------------------------------------------- /tfwrapper/layers.py: -------------------------------------------------------------------------------- 1 | # Authors: 2 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com) 3 | # Lisa M. 
Koch (lisa.margret.koch@gmail.com)
4 | 
5 | import tensorflow as tf
6 | import numpy as np
7 | import logging
8 | from tfwrapper import utils
9 | from tfwrapper import normalisation as tfnorm
10 | 
11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
12 | 
13 | # Will be used as default in all the layers below
14 | STANDARD_NONLINEARITY = tf.nn.relu
15 | 
16 | ## Pooling layers ##
17 | 
18 | def maxpool2D(x, kernel_size=(2, 2), strides=(2, 2), padding="SAME"):
19 |     '''
20 |     nets2D max pooling layer with standard 2x2 pooling as default
21 |     '''
22 | 
23 |     kernel_size_aug = [1, kernel_size[0], kernel_size[1], 1]
24 |     strides_aug = [1, strides[0], strides[1], 1]
25 | 
26 |     op = tf.nn.max_pool(x, ksize=kernel_size_aug, strides=strides_aug, padding=padding)
27 | 
28 |     return op
29 | 
30 | 
31 | def maxpool3D(x, kernel_size=(2, 2, 2), strides=(2, 2, 2), padding="SAME"):
32 |     '''
33 |     nets3D max pooling layer with 2x2x2 pooling as default
34 |     '''
35 | 
36 |     kernel_size_aug = [1, kernel_size[0], kernel_size[1], kernel_size[2], 1]
37 |     strides_aug = [1, strides[0], strides[1], strides[2], 1]
38 | 
39 |     op = tf.nn.max_pool3d(x, ksize=kernel_size_aug, strides=strides_aug, padding=padding)
40 | 
41 |     return op
42 | 
43 | 
44 | def averagepool2D(x, kernel_size=(2, 2), strides=(2, 2), padding="SAME"):
45 |     '''
46 |     nets2D average pooling layer with standard 2x2 pooling as default
47 |     '''
48 | 
49 |     kernel_size_aug = [1, kernel_size[0], kernel_size[1], 1]
50 |     strides_aug = [1, strides[0], strides[1], 1]
51 | 
52 |     op = tf.nn.avg_pool(x, ksize=kernel_size_aug, strides=strides_aug, padding=padding)
53 | 
54 |     return op
55 | 
56 | 
57 | def reshape_pool2D_layer(x):
58 |     '''
59 |     nets2D pooling layer that downsamples by a factor of 2 by stacking the four 2x2 sub-grids along the channel axis, so no information is discarded
60 |     '''
61 | 
62 |     S0 = x[:,0::2,0::2, :]
63 |     S1 = x[:,1::2,0::2, :]
64 |     S2 = x[:,0::2,1::2, :]
65 |     S3 = x[:,1::2,1::2, :]
66 | 
67 |     return tf.concat([S0, S1, S2, S3], axis=3)
68 | 
69 | 
70 | def global_averagepool2D(x, name=None):
71 |     '''
72 |     nets2D global average pooling layer, averaging over both spatial dimensions
73 |     '''
74 | 
75 |     op = tf.reduce_mean(x, axis=(1,2), keepdims=False, name=name)
76 |     tf.summary.histogram(op.op.name + '/activations', op)
77 | 
78 |     return op
79 | 
80 | 
81 | def global_averagepool3D(x, name=None):
82 |     '''
83 |     nets3D global average pooling layer, averaging over all three spatial dimensions
84 |     '''
85 | 
86 |     op = tf.reduce_mean(x, axis=(1,2,3), keepdims=False, name=name)
87 |     tf.summary.histogram(op.op.name + '/activations', op)
88 | 
89 |     return op
90 | 
91 | 
92 | ## Standard feed-forward layers ##
93 | 
94 | def conv2D(x,
95 |            name,
96 |            kernel_size=(3,3),
97 |            num_filters=32,
98 |            strides=(1,1),
99 |            activation=STANDARD_NONLINEARITY,
100 |            normalisation=tf.identity,
101 |            normalise_post_activation=False,
102 |            dropout_p=None,
103 |            padding="SAME",
104 |            weight_init='he_normal',
105 |            add_bias=True,
106 |            **kwargs):
107 | 
108 |     '''
109 |     Standard nets2D convolutional layer
110 |     kwargs can have training, and potentially other normalisation parameters
111 |     '''
112 | 
113 |     bottom_num_filters = x.get_shape().as_list()[-1]
114 | 
115 |     weight_shape = [kernel_size[0], kernel_size[1], bottom_num_filters, num_filters]
116 |     bias_shape = [num_filters]
117 | 
118 |     strides_augm = [1, strides[0], strides[1], 1]
119 | 
120 |     with tf.variable_scope(name):
121 | 
122 |         weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True)
123 |         op = tf.nn.conv2d(x, filter=weights, strides=strides_augm, padding=padding)
124 | 
125 |         biases = None # so there is always something for summary
126 |         if add_bias and normalisation is tfnorm.batch_norm:
127 |             logging.info('Turning off bias because using batch norm.')
128 |             add_bias = False
129 | 
130 |         if add_bias:
131 |             biases = utils.get_bias_variable(bias_shape, name='b')
132 |             op = tf.nn.bias_add(op, biases)
133 | 
134 |         if not normalise_post_activation:
135 |             op = activation(normalisation(op, **kwargs))
136 |         else:
137 |             op = normalisation(activation(op), **kwargs)
138 | 
139 |         if dropout_p is not None:
140 |             op = dropout(op, keep_prob=dropout_p, **kwargs)
141 | 
142 |         # Add Tensorboard summaries
143 |         _add_summaries(op, weights, biases)
144 | 
145 |     return op
146 | 
147 | 
148 | def conv3D(x,
149 |            name,
150 |            kernel_size=(3,3,3),
151 |            num_filters=32,
152 |            strides=(1,1,1),
153 |            activation=STANDARD_NONLINEARITY,
154 |            normalisation=tf.identity,
155 |            normalise_post_activation=False,
156 |            dropout_p=None,
157 |            padding="SAME",
158 |            weight_init='he_normal',
159 |            add_bias=True,
160 |            **kwargs):
161 | 
162 |     '''
163 |     Standard nets3D convolutional layer
164 |     '''
165 | 
166 |     bottom_num_filters = x.get_shape().as_list()[-1]
167 | 
168 |     weight_shape = [kernel_size[0], kernel_size[1], kernel_size[2], bottom_num_filters, num_filters]
169 |     bias_shape = [num_filters]
170 | 
171 |     strides_augm = [1, strides[0], strides[1], strides[2], 1]
172 | 
173 |     with tf.variable_scope(name):
174 | 
175 |         weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True)
176 |         op = tf.nn.conv3d(x, filter=weights, strides=strides_augm, padding=padding)
177 | 
178 |         biases = None
179 |         if add_bias:
180 |             biases = utils.get_bias_variable(bias_shape, name='b')
181 |             op = tf.nn.bias_add(op, biases)
182 | 
183 |         if not normalise_post_activation:
184 |             op = activation(normalisation(op, **kwargs))
185 |         else:
186 |             op = normalisation(activation(op), **kwargs)
187 | 
188 |         if dropout_p is not None:
189 |             op = dropout(op, keep_prob=dropout_p, **kwargs)
190 | 
191 |         # Add Tensorboard summaries
192 |         _add_summaries(op, weights, biases)
193 | 
194 |     return op
195 | 
196 | 
197 | def transposed_conv2D(bottom,
198 |                       name,
199 |                       kernel_size=(4,4),
200 |                       num_filters=32,
201 |                       strides=(2,2),
202 |                       output_shape=None,
203 |                       activation=STANDARD_NONLINEARITY,
204 |                       normalisation=tf.identity,
205 |                       normalise_post_activation=False,
206 |                       dropout_p=None,
207 |                       padding="SAME",
208 |                       weight_init='he_normal',
209 |                       add_bias=True,
210 |                       **kwargs):
211 | 
212 |     '''
213 |     Standard nets2D transpose (also known as deconvolution) layer. Default behaviour upsamples the input by a
214 |     factor of 2.
215 |     '''
216 | 
217 |     bottom_shape = bottom.get_shape().as_list()
218 |     if output_shape is None:
219 |         batch_size = tf.shape(bottom)[0]
220 |         output_shape = tf.stack([batch_size, bottom_shape[1]*strides[0], bottom_shape[2]*strides[1], num_filters])
221 | 
222 |     bottom_num_filters = bottom_shape[3]
223 | 
224 |     weight_shape = [kernel_size[0], kernel_size[1], num_filters, bottom_num_filters]
225 |     bias_shape = [num_filters]
226 |     strides_augm = [1, strides[0], strides[1], 1]
227 | 
228 |     with tf.variable_scope(name):
229 | 
230 |         weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True)
231 | 
232 |         op = tf.nn.conv2d_transpose(bottom,
233 |                                     filter=weights,
234 |                                     output_shape=output_shape,
235 |                                     strides=strides_augm,
236 |                                     padding=padding)
237 | 
238 |         # The line below is a hack necessary to fix a bug with tensorflow. The same operation is not required
239 |         # for the 3D equivalent of this layer.
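        # (Likely background on the hack mentioned above: when output_shape is
        # a dynamically stacked tensor, tf.nn.conv2d_transpose can return a
        # tensor whose static shape is unknown; the tf.reshape below re-runs
        # shape inference so that downstream layers see defined spatial and
        # channel dimensions.)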
240 | op = tf.reshape(op, output_shape) 241 | 242 | biases = None 243 | if add_bias: 244 | biases = utils.get_bias_variable(bias_shape, name='b') 245 | op = tf.nn.bias_add(op, biases) 246 | 247 | if not normalise_post_activation: 248 | op = activation(normalisation(op, **kwargs)) 249 | else: 250 | op = normalisation(activation(op), **kwargs) 251 | 252 | if dropout_p is not None: 253 | op = dropout(op, keep_prob=dropout_p, **kwargs) 254 | 255 | # Add Tensorboard summaries 256 | _add_summaries(op, weights, biases) 257 | 258 | return op 259 | 260 | 261 | def transposed_conv3D(bottom, 262 | name, 263 | kernel_size=(4,4,4), 264 | num_filters=32, 265 | strides=(2,2,2), 266 | output_shape=None, 267 | activation=STANDARD_NONLINEARITY, 268 | normalisation=tf.identity, 269 | normalise_post_activation=False, 270 | dropout_p=None, 271 | padding="SAME", 272 | weight_init='he_normal', 273 | add_bias=True, 274 | **kwargs): 275 | 276 | ''' 277 | Standard nets2D transpose (also known as deconvolution) layer. Default behaviour upsamples the input by a 278 | factor of 2. 279 | ''' 280 | 281 | bottom_shape = bottom.get_shape().as_list() 282 | 283 | if output_shape is None: 284 | batch_size = tf.shape(bottom)[0] 285 | output_shape = tf.stack([batch_size, bottom_shape[1]*strides[0], bottom_shape[2]*strides[1], bottom_shape[3]*strides[2], num_filters]) 286 | 287 | bottom_num_filters = bottom_shape[4] 288 | 289 | weight_shape = [kernel_size[0], kernel_size[1], kernel_size[2], num_filters, bottom_num_filters] 290 | 291 | bias_shape = [num_filters] 292 | 293 | strides_augm = [1, strides[0], strides[1], strides[2], 1] 294 | 295 | with tf.variable_scope(name): 296 | 297 | weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True) 298 | 299 | op = tf.nn.conv3d_transpose(bottom, 300 | filter=weights, 301 | output_shape=output_shape, 302 | strides=strides_augm, 303 | padding=padding) 304 | 305 | # op = tf.reshape(op, output_shape) 306 | 307 | biases = None 308 | if add_bias: 309 | biases = utils.get_bias_variable(bias_shape, name='b') 310 | op = tf.nn.bias_add(op, biases) 311 | 312 | if not normalise_post_activation: 313 | op = activation(normalisation(op, **kwargs)) 314 | else: 315 | op = normalisation(activation(op), **kwargs) 316 | 317 | if dropout_p is not None: 318 | op = dropout(op, keep_prob=dropout_p, **kwargs) 319 | 320 | # Add Tensorboard summaries 321 | _add_summaries(op, weights, biases) 322 | 323 | return op 324 | 325 | 326 | def nearest_neighbour_upsample2D(x, factor): 327 | 328 | bottom_shape = x.get_shape().as_list() 329 | output_size = tf.stack([bottom_shape[1] * factor, bottom_shape[2] * factor]) 330 | 331 | op = tf.image.resize_images(x, output_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 332 | 333 | return op 334 | 335 | 336 | def bilinear_upsample2D(x, name, factor): 337 | 338 | bottom_shape = x.get_shape().as_list() 339 | output_size = tf.stack([bottom_shape[1] * factor, bottom_shape[2] * factor]) 340 | 341 | with tf.variable_scope(name): 342 | 343 | op = tf.image.resize_images(x, output_size) 344 | 345 | return op 346 | 347 | 348 | def bilinear_upsample3D(x, name, factor): 349 | 350 | # Taken from: https://niftynet.readthedocs.io/en/dev/_modules/niftynet/layer/linear_resize.html 351 | 352 | with tf.variable_scope(name): 353 | 354 | b_size, x_size, y_size, z_size, c_size = x.shape.as_list() 355 | 356 | x_size_new = x_size*factor 357 | y_size_new = y_size*factor 358 | z_size_new = z_size*factor 359 | 360 | # resize y-z 361 | squeeze_b_x = tf.reshape(x, 
[-1, y_size, z_size, c_size]) 362 | resize_b_x = tf.image.resize_bilinear( squeeze_b_x, [y_size_new, z_size_new]) 363 | resume_b_x = tf.reshape(resize_b_x, [b_size, x_size, y_size_new, z_size_new, c_size]) 364 | 365 | # resize x 366 | # first reorient 367 | reoriented = tf.transpose(resume_b_x, [0, 3, 2, 1, 4]) 368 | 369 | # squeeze and 2d resize 370 | squeeze_b_z = tf.reshape( reoriented, [-1, y_size_new, x_size, c_size]) 371 | resize_b_z = tf.image.resize_bilinear(squeeze_b_z, [y_size_new, x_size_new]) 372 | resume_b_z = tf.reshape(resize_b_z, [b_size, z_size_new, y_size_new, x_size_new, c_size]) 373 | 374 | output_tensor = tf.transpose(resume_b_z, [0, 3, 2, 1, 4]) 375 | 376 | return output_tensor 377 | 378 | def dilated_conv2D(bottom, 379 | name, 380 | kernel_size=(3,3), 381 | num_filters=32, 382 | rate=2, 383 | activation=STANDARD_NONLINEARITY, 384 | normalisation=tf.identity, 385 | normalise_post_activation=False, 386 | dropout_p=None, 387 | padding="SAME", 388 | weight_init='he_normal', 389 | add_bias=True, 390 | **kwargs): 391 | 392 | ''' 393 | nets2D dilated convolution layer. This layer can be used to increase the receptive field of a network. 394 | It is described in detail in this paper: Yu et al, Multi-Scale Context Aggregation by Dilated Convolutions, 395 | 2015 (https://arxiv.org/pdf/1511.07122.pdf) 396 | ''' 397 | 398 | bottom_num_filters = bottom.get_shape().as_list()[3] 399 | 400 | weight_shape = [kernel_size[0], kernel_size[1], bottom_num_filters, num_filters] 401 | bias_shape = [num_filters] 402 | 403 | with tf.variable_scope(name): 404 | 405 | weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True) 406 | 407 | op = tf.nn.atrous_conv2d(bottom, filters=weights, rate=rate, padding=padding) 408 | 409 | biases = None 410 | if add_bias: 411 | biases = utils.get_bias_variable(bias_shape, name='b') 412 | op = tf.nn.bias_add(op, biases) 413 | 414 | if not normalise_post_activation: 415 | op = activation(normalisation(op, **kwargs)) 416 | else: 417 | op = normalisation(activation(op), **kwargs) 418 | 419 | if dropout_p is not None: 420 | op = dropout(op, keep_prob=dropout_p, **kwargs) 421 | 422 | # Add Tensorboard summaries 423 | _add_summaries(op, weights, biases) 424 | 425 | return op 426 | 427 | 428 | def residual_unit2D(x, 429 | name, 430 | num_filters=32, 431 | down_sample=False, 432 | projection=False, 433 | activation=STANDARD_NONLINEARITY, 434 | normalisation=tfnorm.batch_norm, 435 | add_bias=True, 436 | **kwargs): 437 | 438 | """ 439 | See https://arxiv.org/abs/1512.03385 440 | """ 441 | 442 | bottom_num_filters = x.get_shape().as_list()[-1] 443 | 444 | adjust_nfilters = True if not bottom_num_filters == num_filters or down_sample else False 445 | 446 | if down_sample: 447 | first_strides = (2,2) 448 | else: 449 | first_strides = (1,1) 450 | 451 | with tf.variable_scope(name): 452 | 453 | conv1 = conv2D(x, 'conv1', num_filters=num_filters, strides=first_strides, activation=tf.identity, add_bias=add_bias) 454 | conv1 = normalisation(conv1, scope='bn1', **kwargs) 455 | conv1 = activation(conv1) 456 | 457 | conv2 = conv2D(conv1, 'conv2', num_filters=num_filters, activation=tf.identity, add_bias=add_bias) 458 | conv2 = normalisation(conv2, scope='bn2', **kwargs) 459 | # conv2 = activation(conv2) 460 | 461 | if adjust_nfilters: 462 | if projection: 463 | projection = conv2D(x, 'projection', num_filters=num_filters, strides=first_strides, kernel_size=(1, 1), activation=tf.identity, add_bias=add_bias) 464 | projection = 
normalisation(projection, scope='bn_projection', **kwargs) 465 | skip = activation(projection) 466 | else: 467 | pad_size = (num_filters - bottom_num_filters) // 2 468 | identity = tf.pad(x, paddings=[[0, 0], [0, 0], [0, 0], [pad_size, pad_size]]) 469 | if down_sample: 470 | identity = identity[:,::2,::2,:] 471 | skip = identity 472 | else: 473 | skip = x 474 | 475 | block = tf.add(skip, conv2, name='add') 476 | block = activation(block) 477 | 478 | return block 479 | 480 | 481 | def identity_residual_unit2D(x, 482 | name, 483 | num_filters, 484 | down_sample=False, 485 | projection=True, 486 | activation=STANDARD_NONLINEARITY, 487 | normalisation=tfnorm.batch_norm, 488 | add_bias=True, 489 | **kwargs): 490 | 491 | """ 492 | Better residual unit which should allow better gradient flow. 493 | See Identity Mappings in Deep Residual Networks (https://link.springer.com/chapter/10.1007/978-3-319-46493-0_38) 494 | """ 495 | 496 | bottom_num_filters = x.get_shape().as_list()[-1] 497 | 498 | if not projection: 499 | assert (bottom_num_filters == num_filters) or (bottom_num_filters*2 == num_filters), \ 500 | 'Number of filters must remain constant, or be increased by a ' \ 501 | 'factor of 2. In filters: %d, Out filters: %d' % (bottom_num_filters, num_filters) 502 | 503 | increase_nfilters = True if not bottom_num_filters == num_filters or down_sample else False 504 | 505 | if down_sample: 506 | first_strides = (2,2) 507 | else: 508 | first_strides = (1,1) 509 | 510 | with tf.variable_scope(name): 511 | 512 | op1 = normalisation(x, scope='bn1', **kwargs) 513 | op1 = activation(op1) 514 | op1 = conv2D(op1, 'conv1', num_filters=num_filters, strides=first_strides, activation=tf.identity, add_bias=add_bias) 515 | 516 | op2 = normalisation(op1, scope='bn2', **kwargs) 517 | op2 = activation(op2) 518 | op2 = conv2D(op2, 'conv2', num_filters=num_filters, activation=tf.identity, add_bias=add_bias) 519 | 520 | if increase_nfilters: 521 | if projection: 522 | projection = conv2D(x, 'projection', num_filters=num_filters, strides=first_strides, kernel_size=(1, 1), activation=tf.identity, add_bias=add_bias) 523 | projection = normalisation(projection, scope='bn_projection', **kwargs) 524 | skip = activation(projection) 525 | else: 526 | pad_size = (num_filters - bottom_num_filters) // 2 527 | identity = tf.pad(x, paddings=[[0, 0], [0, 0], [0, 0], [pad_size, pad_size]]) 528 | if down_sample: 529 | identity = identity[:,::2,::2,:] 530 | skip = identity 531 | else: 532 | skip = x 533 | 534 | block = tf.add(skip, op2, name='add') 535 | 536 | return block 537 | 538 | 539 | def dense_layer(bottom, 540 | name, 541 | hidden_units=512, 542 | activation=STANDARD_NONLINEARITY, 543 | normalisation=tfnorm.batch_norm, 544 | normalise_post_activation=False, 545 | dropout_p=None, 546 | weight_init='he_normal', 547 | add_bias=True, 548 | **kwargs): 549 | 550 | ''' 551 | Dense a.k.a. 
fully connected layer 552 | ''' 553 | 554 | bottom_flat = utils.flatten(bottom) 555 | bottom_rhs_dim = utils.get_rhs_dim(bottom_flat) 556 | 557 | weight_shape = [bottom_rhs_dim, hidden_units] 558 | bias_shape = [hidden_units] 559 | 560 | with tf.variable_scope(name): 561 | 562 | weights = utils.get_weight_variable(weight_shape, name='W', type=weight_init, regularize=True) 563 | 564 | op = tf.matmul(bottom_flat, weights) 565 | 566 | biases = None 567 | if add_bias: 568 | biases = utils.get_bias_variable(bias_shape, name='b') 569 | op = tf.nn.bias_add(op, biases) 570 | 571 | if not normalise_post_activation: 572 | op = activation(normalisation(op, **kwargs)) 573 | else: 574 | op = normalisation(activation(op), **kwargs) 575 | 576 | if dropout_p is not None: 577 | op = dropout(op, keep_prob=dropout_p, **kwargs) 578 | 579 | # Add Tensorboard summaries 580 | _add_summaries(op, weights, biases) 581 | 582 | return op 583 | 584 | ## Other layers ## 585 | 586 | def crop_and_concat(inputs, axis=-1): 587 | 588 | ''' 589 | Layer for cropping and stacking feature maps of different size along a different axis. 590 | Currently, the first feature map in the inputs list defines the output size. 591 | The feature maps can have different numbers of channels. 592 | :param inputs: A list of input tensors of the same dimensionality but can have different sizes 593 | :param axis: Axis along which to concatentate the inputs 594 | :return: The concatentated feature map tensor 595 | ''' 596 | 597 | output_size = inputs[0].get_shape().as_list()[1:] 598 | # output_size = tf.shape(inputs[0])[1:] 599 | concat_inputs = [inputs[0]] 600 | 601 | for ii in range(1, len(inputs)): 602 | 603 | larger_size = inputs[ii].get_shape().as_list()[1:] 604 | # larger_size = tf.shape(inputs[ii]) 605 | 606 | # Don't subtract over batch_size because it may be None 607 | start_crop = np.subtract(larger_size, output_size) // 2 608 | 609 | if len(output_size) == 4: # nets3D images 610 | cropped_tensor = tf.slice(inputs[ii], 611 | (0, start_crop[0], start_crop[1], start_crop[2], 0), 612 | (-1, output_size[0], output_size[1], output_size[2], -1)) 613 | elif len(output_size) == 3: # nets2D images 614 | cropped_tensor = tf.slice(inputs[ii], 615 | (0, start_crop[0], start_crop[1], 0), 616 | (-1, output_size[0], output_size[1], -1)) 617 | else: 618 | raise ValueError('Unexpected number of dimensions on tensor: %d' % len(output_size)) 619 | 620 | concat_inputs.append(cropped_tensor) 621 | 622 | return tf.concat(concat_inputs, axis=axis) 623 | 624 | 625 | def pad_to_size(bottom, output_size): 626 | 627 | ''' 628 | A layer used to pad the tensor bottom to output_size by padding zeros around it 629 | TODO: implement for nets3D data 630 | ''' 631 | 632 | input_size = bottom.get_shape().as_list() 633 | size_diff = np.subtract(output_size, input_size) 634 | 635 | pad_size = size_diff // 2 636 | odd_bit = np.mod(size_diff, 2) 637 | 638 | if len(input_size) == 4: 639 | 640 | padded = tf.pad(bottom, paddings=[[0, 0], 641 | [pad_size[1], pad_size[1] + odd_bit[1]], 642 | [pad_size[2], pad_size[2] + odd_bit[2]], 643 | [0, 0]]) 644 | 645 | return padded 646 | 647 | elif len(input_size) == 5: 648 | raise NotImplementedError('This layer has not yet been extended to nets3D') 649 | else: 650 | raise ValueError('Unexpected input size: %d' % input_size) 651 | 652 | 653 | def dropout(bottom, keep_prob, training): 654 | ''' 655 | Performs dropout on the activations of an input 656 | ''' 657 | 658 | with tf.variable_scope('dropout_layer'): 659 | keep_prob_pl = 
tf.cond(training, 660 | lambda: tf.constant(keep_prob, dtype=bottom.dtype), 661 | lambda: tf.constant(1.0, dtype=bottom.dtype)) 662 | 663 | # The tf.nn.dropout function takes care of all the scaling 664 | # (https://www.tensorflow.org/get_started/mnist/pros) 665 | return tf.nn.dropout(bottom, keep_prob=keep_prob_pl, name='dropout') 666 | 667 | 668 | 669 | ### HELPER FUNCTIONS #################################################################################### 670 | 671 | def _add_summaries(op, weights, biases): 672 | 673 | # Tensorboard variables 674 | tf.summary.histogram(weights.name, weights) 675 | if biases: 676 | tf.summary.histogram(biases.name, biases) 677 | tf.summary.histogram(op.op.name + '/activations', op) -------------------------------------------------------------------------------- /tfwrapper/losses.py: -------------------------------------------------------------------------------- 1 | # Authors: 2 | # Christian F. Baumgartner (c.f.baumgartner@gmail.com) 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | 8 | def get_dice(logits, labels, epsilon=1e-10, sum_over_labels=False, sum_over_batches=False, use_hard_pred=True): 9 | ''' 10 | Dice coefficient per subject per label 11 | :param logits: network output 12 | :param labels: groundtruth labels (one-hot) 13 | :param epsilon: for numerical stability 14 | :param sum_over_labels: Calculate IOU over all labels rather than for each label separately 15 | :param sum_over_batches: Calculate intersection and union over whole batch rather than single images 16 | :param use_hard_pred: If True calculates proper Dice, if False computes Dice based on softmax outputs directly which is differentiable. 17 | :return: tensor shaped (tf.shape(logits)[0], tf.shape(logits)[-1]) (except when sum_over_batches is on) 18 | ''' 19 | 20 | ndims = logits.get_shape().ndims 21 | 22 | prediction = tf.nn.softmax(logits) 23 | if use_hard_pred: 24 | # This casts the predictions to binary 0 or 1 25 | prediction = tf.one_hot(tf.argmax(prediction, axis=-1), depth=tf.shape(prediction)[-1]) 26 | 27 | intersection = tf.multiply(prediction, labels) 28 | 29 | if ndims == 5: 30 | reduction_axes = [1,2,3] 31 | else: 32 | reduction_axes = [1,2] 33 | 34 | if sum_over_batches: 35 | reduction_axes = [0] + reduction_axes 36 | 37 | if sum_over_labels: 38 | reduction_axes += [reduction_axes[-1] + 1] # also sum over the last axis 39 | 40 | # Reduce the maps over all dimensions except the batch and the label index 41 | i = tf.reduce_sum(intersection, axis=reduction_axes) 42 | l = tf.reduce_sum(prediction, axis=reduction_axes) 43 | r = tf.reduce_sum(labels, axis=reduction_axes) 44 | 45 | dice_per_img_per_lab = 2 * i / (l + r + epsilon) 46 | 47 | return dice_per_img_per_lab 48 | 49 | 50 | def dice_loss(logits, labels, epsilon=1e-10, **kwargs): 51 | ''' 52 | The dice loss is always 1 - dice, however, there are many ways to calculate the dice. Basically, there are 53 | three sums involved: 1) over the pixels, 2) over the labels, 3) over the images in a batch. These sums 54 | can be arranged differently to obtain different behaviour. The behaviour can be controlled either by providing 55 | the 'mode' variable, or by setting the parameters directly. 56 | 57 | Selecting the parameters directly: 58 | :param per_structure: If True the Dice is calculated for each label separately first and then averaged. 59 | :param sum_over_batches: If True the Dice is calculated for each batch separately then averaged. 
def dice_loss(logits, labels, epsilon=1e-10, **kwargs):
    '''
    The dice loss is always 1 - dice. However, there are many ways to calculate the dice. Basically, there are
    three sums involved: 1) over the pixels, 2) over the labels, 3) over the images in a batch. These sums
    can be arranged differently to obtain different behaviour. The behaviour can be controlled either by providing
    the 'mode' variable, or by setting the parameters directly.

    Selecting the parameters directly:
    :param per_structure: If True the Dice is calculated for each label separately first and then averaged.
    :param sum_over_batches: If True the intersections and unions are accumulated over the whole minibatch
                             rather than computed per image.

    Selecting the mode:
    :param mode: <'macro'|'macro_robust'|'micro'>
        macro: Calculate the Dice for each label separately, then average. This may cause problems
               if a structure is completely missing from an image: even if its absence is correctly
               predicted, the Dice evaluates to 0/epsilon = 0. However, this method automatically
               tackles class imbalance, because each structure contributes equally to the final Dice.
        macro_robust: The above calculation can be made more robust by summing over all images in a
               minibatch. If the label appears in at least one image in the batch and is perfectly
               predicted, the Dice will evaluate to 1 as expected.
        micro: Calculate the Dice for all labels together. This doesn't have the problem of macro with
               missing labels. However, it is sensitive to class imbalance, because each label contributes
               in proportion to how often it appears in the data.

    The above modes are equivalent to the F1 score in macro/micro mode
    (see http://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html)

    Other parameters:
    :param logits: Network output
    :param labels: Ground-truth labels
    :param only_foreground: Sometimes it can be beneficial to ignore label 0 for the optimisation
    :param epsilon: To avoid division by zero in the Dice calculation.
    '''

    only_foreground = kwargs.get('only_foreground', False)
    mode = kwargs.get('mode', None)
    if mode == 'macro':
        sum_over_labels = False
        sum_over_batches = False
    elif mode == 'macro_robust':
        sum_over_labels = False
        sum_over_batches = True
    elif mode == 'micro':
        sum_over_labels = True
        sum_over_batches = False
    elif mode is None:
        sum_over_labels = kwargs.get('per_structure')  # Intentionally no default value
        sum_over_batches = kwargs.get('sum_over_batches', False)
    else:
        raise ValueError("Encountered unexpected 'mode' in dice_loss: '%s'" % mode)

    with tf.name_scope('dice_loss'):

        dice_per_img_per_lab = get_dice(logits=logits,
                                        labels=labels,
                                        epsilon=epsilon,
                                        sum_over_labels=sum_over_labels,
                                        sum_over_batches=sum_over_batches,
                                        use_hard_pred=False)

        if only_foreground:
            if sum_over_batches:
                loss = 1 - tf.reduce_mean(dice_per_img_per_lab[1:])
            else:
                loss = 1 - tf.reduce_mean(dice_per_img_per_lab[:, 1:])
        else:
            loss = 1 - tf.reduce_mean(dice_per_img_per_lab)

    return loss


def cross_entropy_loss(logits, labels, use_sigmoid=False):
    '''
    Simple wrapper for the standard tensorflow cross entropy loss
    '''

    if use_sigmoid:
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))
    else:
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels))


def pixel_wise_cross_entropy_loss_weighted(logits, labels, class_weights):
    '''
    Weighted cross entropy loss, with a weight per class
    :param logits: Network output before softmax
    :param labels: Ground truth masks
    :param class_weights: A list of the weights for each class
    :return: weighted cross entropy loss
    '''

    n_class = len(class_weights)

    flat_logits = tf.reshape(logits, [-1, n_class])
    flat_labels = tf.reshape(labels, [-1, n_class])

    class_weights = tf.constant(np.array(class_weights, dtype=np.float32))

    # Per-pixel weight: the weight of the (one-hot) ground-truth class at that pixel
    weight_map = tf.multiply(flat_labels, class_weights)
    weight_map = tf.reduce_sum(weight_map, axis=1)

    loss_map = tf.nn.softmax_cross_entropy_with_logits(logits=flat_logits, labels=flat_labels)
    weighted_loss = tf.multiply(loss_map, weight_map)

    loss = tf.reduce_mean(weighted_loss)

    return loss
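
# Illustrative usage sketch (editor's addition, not part of the original module):
# contrasting the three dice_loss modes on the same prediction. Shapes are invented
# and the numeric losses depend on the random logits.
def _example_dice_loss_modes():
    logits = tf.random_normal([4, 32, 32, 2])
    labels = tf.one_hot(tf.random_uniform([4, 32, 32], 0, 2, dtype=tf.int32), depth=2)
    loss_macro = dice_loss(logits, labels, mode='macro')          # per-label Dice, averaged
    loss_robust = dice_loss(logits, labels, mode='macro_robust')  # per-label Dice over the whole batch
    loss_micro = dice_loss(logits, labels, mode='micro')          # all labels pooled together
    return loss_macro, loss_robust, loss_micro
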
--------------------------------------------------------------------------------
/tfwrapper/normalisation.py:
--------------------------------------------------------------------------------
import tensorflow as tf

def instance_norm2D(x, scope='instance_norm', **kwargs):

    with tf.variable_scope(scope):

        depth = x.get_shape()[3]
        scale = tf.get_variable('scale', [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        epsilon = 1e-5
        inv = tf.rsqrt(variance + epsilon)
        normalized = (x - mean) * inv
        return scale * normalized + offset


def group_norm2D(x, eps=1e-5, scope='group_norm', **kwargs):

    with tf.variable_scope(scope):

        N = tf.shape(x)[0]
        _, H, W, C = x.get_shape().as_list()

        # 16 channels per group gave good results in the paper, but use at least 2 groups
        G = kwargs.get('num_groups', max(2, C // 16))

        x = tf.reshape(x, tf.stack([N, H, W, G, C // G]))
        mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + eps)

        gamma = tf.get_variable('gamma', [1, 1, 1, C], initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable('beta', [1, 1, 1, C], initializer=tf.constant_initializer(0.0))

        x = tf.reshape(x, [N, H, W, C]) * gamma + beta

        return x


def layer_norm(x,
               gamma=None,
               beta=None,
               axes=(1, 2, 3),
               eps=1e-3,
               scope='layer_norm',
               **kwargs):
    """
    Collect mean and variance statistics of x over the given axes (everything except the batch
    dimension by default) and apply the normalisation x_ = gamma * (x - mean) / sqrt(var + eps).
    :param x: Input variable
    :param gamma: scaling parameter
    :param beta: bias parameter
    :param axes: which axes to collect the statistics over (default is correct for 2D conv)
    :param eps: Denominator bias
    :param scope: Name of the variable scope
    :return: Returns the normalised version of x
    """

    with tf.variable_scope(scope):
        mean, var = tf.nn.moments(x, axes, name='moments', keep_dims=True)
        normed = (x - mean) / tf.sqrt(eps + var)
        if gamma is not None:
            normed *= gamma
        if beta is not None:
            normed += beta
        normed = tf.identity(normed, name='normed')

    return normed
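
# Illustrative usage sketch (editor's addition, not part of the original module):
# group_norm2D with C=32 channels uses G = max(2, 32 // 16) = 2 groups by default;
# num_groups can override this. Shapes and scope names below are invented.
def _example_group_norm2D():
    x = tf.placeholder(tf.float32, shape=[None, 64, 64, 32])
    y = tf.identity(x)
    y = group_norm2D(y, scope='gn_default')                # 2 groups of 16 channels
    z = group_norm2D(x, scope='gn_8groups', num_groups=8)  # 8 groups of 4 channels
    return y, z
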
def batch_renorm(x, training, moving_average_decay=0.99, scope='batch_renorm', **kwargs):
    '''
    Batch renormalisation implementation using the tf batch normalisation function.
    :param x: Input layer (should be before activation)
    :param scope: A name for the computational graph
    :param training: A tf.bool specifying if the layer is executed at training or testing time
    :param moving_average_decay: Moving average decay of data set mean and std
    :return: Batch normalised activation
    '''

    def parametrize_variable(global_step, y_min, y_max, x_min, x_max):
        # Helper function to create a linear increase of a variable from (x_min, y_min) to (x_max, y_max),
        # parameterised by the global number of iterations (global_step), i.e.:
        #
        # if x < x_min:
        #     return y_min
        # elif x > x_max:
        #     return y_max
        # else:
        #     return (x - x_min) * (y_max - y_min) / (x_max - x_min) + y_min

        x = tf.to_float(global_step)

        def f1(): return tf.constant(y_min)

        def f2(): return tf.constant(y_max)

        def f3(): return ((x - x_min) * (y_max - y_min) / (x_max - x_min)) + y_min

        y = tf.case({tf.less(x, x_min): f1,
                     tf.greater(x, x_max): f2},
                    default=f3,
                    exclusive=True)

        return y

    ## End helper function

    rmin = 1.0
    rmax = 3.0

    dmin = 0.0
    dmax = 5.0

    # values /10 from the paper because training goes faster for us
    x_min_r = 5000.0 / 10
    x_max_r = 40000.0 / 10

    x_min_d = 5000.0 / 10
    x_max_d = 25000.0 / 10

    global_step = tf.train.get_or_create_global_step()

    clip_r = parametrize_variable(global_step, rmin, rmax, x_min_r, x_max_r)
    clip_d = parametrize_variable(global_step, dmin, dmax, x_min_d, x_max_d)

    tf.summary.scalar('rmax_clip', clip_r)
    tf.summary.scalar('dmax_clip', clip_d)

    with tf.variable_scope(scope):

        h_bn = tf.contrib.layers.batch_norm(inputs=x,
                                            renorm_decay=moving_average_decay,
                                            epsilon=1e-3,
                                            is_training=training,
                                            center=True,
                                            scale=True,
                                            renorm=True,
                                            renorm_clipping={'rmax': clip_r, 'dmax': clip_d})

    return h_bn


def batch_norm(x, training, moving_average_decay=0.99, scope='batch_norm', **kwargs):
    '''
    Wrapper for tensorflow's own batch normalisation function.
    :param x: Input layer (should be before activation)
    :param scope: A name for the computational graph
    :param training: A tf.bool specifying if the layer is executed at training or testing time
    :return: Batch normalised activation
    '''

    with tf.variable_scope(scope):

        h_bn = tf.contrib.layers.batch_norm(inputs=x,
                                            decay=moving_average_decay,
                                            epsilon=1e-3,
                                            is_training=training,
                                            center=True,
                                            scale=True)

    return h_bn


def identity(x, **kwargs):
    '''
    Wrapper for tf.identity, which allows passing extra arguments that might be needed for other
    normalisers via kwargs.
    '''
    return tf.identity(x)
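
# Illustrative usage sketch (editor's addition, not part of the original module):
# all of the wrappers above share the same calling convention, driven by a tf.bool
# placeholder so the same graph can switch between training and inference statistics.
# Names and shapes below are invented.
def _example_batch_norm():
    x = tf.placeholder(tf.float32, shape=[None, 32, 32, 16])
    training_pl = tf.placeholder(tf.bool, name='training_flag')
    h = batch_norm(x, training=training_pl, scope='bn_example')
    # feed {training_pl: True} during optimisation and {training_pl: False} at test time
    return h
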
--------------------------------------------------------------------------------
/tfwrapper/utils.py:
--------------------------------------------------------------------------------
# Authors:
# Christian F. Baumgartner (c.f.baumgartner@gmail.com)
# Lisa M. Koch (lisa.margret.koch@gmail.com)

import tensorflow as tf
import numpy as np
import math
import glob
import os
from tensorflow.contrib.layers import variance_scaling_initializer, xavier_initializer
import logging
from tensorflow.python import pywrap_tensorflow

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

def flatten(tensor):
    '''
    Flatten the last N-1 dimensions of a tensor, only keeping the first one, which is typically
    equal to the number of batches.
    Example: A tensor of shape [10, 200, 200, 32] becomes [10, 1280000]
    '''
    rhs_dim = get_rhs_dim(tensor)
    return tf.reshape(tensor, [-1, rhs_dim])

def get_rhs_dim(tensor):
    '''
    Get the product of the last N-1 dimensions of a tensor.
    I.e. an input tensor with shape [10, 200, 200, 32] leads to an output of 1280000
    '''
    shape = tensor.get_shape().as_list()
    return np.prod(shape[1:])


def tfndims(t):
    return len(t.get_shape().as_list())


def prepare_tensor_for_summary(img, mode, idx=0, nlabels=None, **kwargs):
    '''
    Format a tensor containing images or segmentation masks such that it can be used with
    tf.summary.image(...) and displayed in tensorboard.
    :param img: Input image or segmentation mask
    :param mode: Can be either 'image' or 'mask'. The two require slightly different slicing.
    :param idx: Which index of a minibatch to display. By default it's always the first
    :param nlabels: Used for the proper rescaling of the label values. If None it scales by the max label.
    :return: Tensor ready to be used with tf.summary.image(...)
    '''

    if mode == 'mask':

        if img.get_shape().ndims == 3:
            V = img[idx, ...]
        elif img.get_shape().ndims == 4:
            slice = kwargs.get('slice', 10)
            V = img[idx, ..., slice]
        elif img.get_shape().ndims == 5:
            slice = kwargs.get('slice', 10)
            V = img[idx, ..., slice, 0]
        else:
            raise ValueError("Don't know how to deal with input dimension %d" % (img.get_shape().ndims))

    elif mode == 'image':

        if img.get_shape().ndims in [3, 4]:
            V = img[idx, ...]
        elif img.get_shape().ndims == 5:
            slice = kwargs.get('slice', 10)
            V = img[idx, ..., slice, 0]
        else:
            raise ValueError("Don't know how to deal with input dimension %d" % (img.get_shape().ndims))

    else:
        raise ValueError('Unknown mode: %s. Must be image or mask' % mode)

    if mode == 'image' or not nlabels:
        V -= tf.reduce_min(V)
        V /= tf.reduce_max(V)
    else:
        V /= (nlabels - 1)  # The largest value in a label map is nlabels - 1.

    V *= 255
    V = tf.cast(V, dtype=tf.uint8)

    img_w = tf.shape(img)[1]
    img_h = tf.shape(img)[2]

    V = tf.reshape(V, tf.stack((1, img_w, img_h, -1)))
    return V
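
# Illustrative usage sketch (editor's addition, not part of the original module):
# flatten()/get_rhs_dim() are the helpers behind the dense layer in tfwrapper/layers.py.
# The shape below is invented.
def _example_flatten():
    x = tf.placeholder(tf.float32, shape=[10, 200, 200, 32])
    print(get_rhs_dim(x))   # 200 * 200 * 32 = 1280000
    x_flat = flatten(x)     # shape (10, 1280000), ready for tf.matmul
    return x_flat
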
def put_kernels_on_grid(images, batch_size, pad=1, min_int=None, max_int=None, **kwargs):
    '''
    Visualize conv. filters as an image (mostly for the 1st layer).
    Arranges filters into a grid, with some padding between adjacent filters.
    Args:
        images: [batch_size, X, Y, channels]
        pad: number of black pixels around each filter (between them)
    Return:
        Tensor of shape [1, (Y+2*pad)*grid_Y, (X+2*pad)*grid_X, NumChannels].
    '''

    mode = kwargs.get('mode', 'image')
    if mode == 'mask':
        nlabels = kwargs.get('nlabels')

    # get shape of the grid. NumKernels == grid_Y * grid_X
    def factorization(n):
        # Return the most square-like factorisation (i, n/i) of n
        for i in range(int(math.sqrt(float(n))), 0, -1):
            if n % i == 0:
                return (i, int(n / i))

    (grid_Y, grid_X) = factorization(batch_size)

    if mode == 'image':

        # 'is None' rather than truthiness checks, so that an explicit 0 is a valid intensity bound
        x_min = tf.reduce_min(images) if min_int is None else min_int
        x_max = tf.reduce_max(images) if max_int is None else max_int

        # map intensities to [0, 1] over the full range
        images = (images - x_min) / (x_max - x_min)

    elif mode == 'mask':
        images /= (nlabels - 1)
    else:
        raise ValueError("Unknown mode: '%s'" % mode)

    images *= 254.0  # scale to 254 rather than 255 to avoid intensities wrapping around
    images = tf.cast(images, tf.uint8)

    # pad X and Y
    x = tf.pad(images, tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]]), mode='CONSTANT')

    # X and Y dimensions, w.r.t. padding
    Y = images.get_shape().as_list()[1] + 2 * pad
    X = images.get_shape().as_list()[2] + 2 * pad

    channels = images.get_shape()[3]

    # organize grid on Y axis
    x = tf.reshape(x, tf.stack([grid_X, Y * grid_Y, X, channels]))

    # switch X and Y axes
    x = tf.transpose(x, (0, 2, 1, 3))

    # organize grid on X axis
    x = tf.reshape(x, tf.stack([1, X * grid_X, Y * grid_Y, channels]))

    # Transpose the image again
    x = tf.transpose(x, (0, 2, 1, 3))

    return x


def print_tensornames_in_checkpoint_file(file_name):
    """
    Print the names of all tensors stored in a checkpoint file.
    """

    reader = pywrap_tensorflow.NewCheckpointReader(file_name)

    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        print(" - tensor_name: ", key)

def get_checkpoint_weights(file_name):
    """
    Return a dictionary mapping tensor names to their values in a checkpoint file.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    return {n: reader.get_tensor(n) for n in reader.get_variable_to_shape_map()}


def get_latest_model_checkpoint_path(folder, name):
    '''
    Returns the checkpoint with the highest iteration number with a given name
    :param folder: Folder where the checkpoints are saved
    :param name: Name under which you saved the model
    :return: The path to the checkpoint with the latest iteration, or False if no checkpoint was found
    '''

    iteration_nums = []
    for file in glob.glob(os.path.join(folder, '%s*.meta' % name)):

        file = file.split('/')[-1]
        file_base, postfix_and_number, rest = file.split('.')[0:3]
        it_num = int(postfix_and_number.split('-')[-1])

        iteration_nums.append(it_num)

    if len(iteration_nums) == 0:
        return False

    latest_iteration = np.max(iteration_nums)
    return os.path.join(folder, name + '-' + str(latest_iteration))
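
# Illustrative usage sketch (editor's addition, not part of the original module):
# the folder layout and file names below are invented. With model.ckpt-500.meta and
# model.ckpt-1500.meta present in /tmp/logdir, the call returns '/tmp/logdir/model.ckpt-1500'.
def _example_latest_checkpoint():
    path = get_latest_model_checkpoint_path('/tmp/logdir', 'model.ckpt')
    if path is False:
        print('No checkpoint found')
    else:
        print('Restoring from %s' % path)
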
def get_weight_variable(shape, name=None, type='xavier_uniform', regularize=True, **kwargs):

    if 'init_weights' in kwargs and kwargs['init_weights'] is not None:
        type = 'pretrained'

    initialise_from_constant = False
    if type == 'xavier_uniform':
        initial = xavier_initializer(uniform=True, dtype=tf.float32)
    elif type == 'xavier_normal':
        initial = xavier_initializer(uniform=False, dtype=tf.float32)
    elif type == 'he_normal':
        initial = variance_scaling_initializer(uniform=False, factor=2.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'he_uniform':
        initial = variance_scaling_initializer(uniform=True, factor=2.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'caffe_uniform':
        initial = variance_scaling_initializer(uniform=True, factor=1.0, mode='FAN_IN', dtype=tf.float32)
    elif type == 'simple':
        stddev = kwargs.get('stddev', 0.02)
        initial = tf.truncated_normal(shape, stddev=stddev, dtype=tf.float32)
        initialise_from_constant = True
    elif type == 'bilinear':
        weights = _bilinear_upsample_weights(shape)
        initial = tf.constant(weights, shape=shape, dtype=tf.float32)
        initialise_from_constant = True
    elif type == 'pretrained':
        initial = kwargs.get('init_weights')
        initialise_from_constant = True
        logging.info('Using pretrained weights for layer: %s' % name)
    else:
        raise ValueError('Unknown initialisation requested: %s' % type)

    if name is None:  # This keeps the option open to use unnamed Variables
        weight = tf.Variable(initial)
    else:
        if initialise_from_constant:
            weight = tf.get_variable(name, initializer=initial)
        else:
            weight = tf.get_variable(name, shape=shape, initializer=initial)

    if regularize:
        tf.add_to_collection('weight_variables', weight)

    return weight


def get_bias_variable(shape, name=None, init_value=0.0, **kwargs):

    if 'init_biases' in kwargs and kwargs['init_biases'] is not None:
        initial = kwargs['init_biases']
        logging.info('Using pretrained biases for layer: %s' % name)
    else:
        initial = tf.constant(init_value, shape=shape, dtype=tf.float32)
    if name is None:
        return tf.Variable(initial)
    else:
        return tf.get_variable(name, initializer=initial)


def _upsample_filt(size):
    '''
    Make a nets2D bilinear kernel suitable for upsampling of the given (h, w) size.
    '''
    factor = (size + 1) // 2
    if size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:size, :size]
    return (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
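
# Worked example (editor's addition, not part of the original module):
# for size=4, _upsample_filt() produces the outer product of the 1D bilinear profile
# [0.25, 0.75, 0.75, 0.25] with itself.
def _example_upsample_filt():
    k = _upsample_filt(4)
    print(k[0])   # [0.0625 0.1875 0.1875 0.0625]  (= 0.25 * [0.25, 0.75, 0.75, 0.25])
    print(k[1])   # [0.1875 0.5625 0.5625 0.1875]  (= 0.75 * [0.25, 0.75, 0.75, 0.25])
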
def _bilinear_upsample_weights(shape):
    '''
    Create a weights matrix for transposed convolution with bilinear filter
    initialization.
    '''

    if not shape[0] == shape[1]: raise ValueError('kernel is not square')
    if not shape[2] == shape[3]: raise ValueError('input and output must have the same number of feature maps')

    kernel_size = shape[0]
    num_feature_maps = shape[2]

    weights = np.zeros(shape, dtype=np.float32)
    upsample_kernel = _upsample_filt(kernel_size)

    for i in range(num_feature_maps):
        weights[:, :, i, i] = upsample_kernel

    return weights
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Authors:
# Christian F. Baumgartner (c.f.baumgartner@gmail.com)
# Lisa M. Koch (lisa.margret.koch@gmail.com)

import nibabel as nib
import numpy as np
import os
import logging
from skimage import measure, transform
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
from medpy.metric import jc

try:
    import cv2
except ImportError:
    logging.warning('Could not import opencv. Augmentation functions will be unavailable.')
else:
    def rotate_image(img, angle, interp=cv2.INTER_LINEAR):

        rows, cols = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        return cv2.warpAffine(img, rotation_matrix, (cols, rows), flags=interp)

    def rotate_image_as_onehot(img, angle, nlabels, interp=cv2.INTER_LINEAR):

        onehot_output = rotate_image(convert_to_onehot(img, nlabels=nlabels), angle, interp)
        return np.argmax(onehot_output, axis=-1)

    def resize_image(im, size, interp=cv2.INTER_LINEAR):

        # swap sizes to account for OpenCV's (width, height) argument order
        im_resized = cv2.resize(im, (size[1], size[0]), interpolation=interp)
        return im_resized

    def resize_image_as_onehot(im, size, nlabels, interp=cv2.INTER_LINEAR):

        onehot_output = resize_image(convert_to_onehot(im, nlabels), size, interp=interp)
        return np.argmax(onehot_output, axis=-1)


def deformation_to_transformation(dx, dy):

    nx, ny = dx.shape

    # 'ij' indexing (Robin's change) makes this work with non-square images
    grid_y, grid_x = np.meshgrid(np.arange(nx), np.arange(ny), indexing="ij")

    map_x = (grid_x + dx).astype(np.float32)
    map_y = (grid_y + dy).astype(np.float32)

    return map_x, map_y

def dense_image_warp(im, dx, dy, interp=cv2.INTER_LINEAR, do_optimisation=True):

    map_x, map_y = deformation_to_transformation(dx, dy)

    # Converting the maps to a compact fixed-point representation gives a ~20% increase
    # in speed but could lead to accuracy losses; disable via do_optimisation=False.
    if do_optimisation:
        map_x, map_y = cv2.convertMaps(map_x, map_y, dstmap1type=cv2.CV_16SC2)
    return cv2.remap(im, map_x, map_y, interpolation=interp, borderMode=cv2.BORDER_REFLECT)


def dense_image_warp_as_onehot(im, dx, dy, nlabels, interp=cv2.INTER_LINEAR, do_optimisation=True):

    onehot_output = dense_image_warp(convert_to_onehot(im, nlabels), dx, dy, interp, do_optimisation=do_optimisation)
    return np.argmax(onehot_output, axis=-1)


def find_floor_in_list(l, t):
    '''
    Find the largest element of l that is smaller than or equal to t.
    Linear search, because this is not important enough to optimise.
    :return: (value, index) of that element
    '''

    max_smallest = -np.inf
    argmax_smallest = None

    for i, n in enumerate(l):
        if t >= n and n > max_smallest:
            max_smallest = n
            argmax_smallest = i

    if argmax_smallest is None:
        raise ValueError("All elements in list l are larger than t=%d" % t)

    return max_smallest, argmax_smallest
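
# Worked example (editor's addition, not part of the original module):
# find_floor_in_list() returns both the floor value and its index, e.g.
#   find_floor_in_list([10, 20, 30], 25) -> (20, 1)
# and raises a ValueError if every element exceeds t.
def _example_find_floor_in_list():
    value, index = find_floor_in_list([10, 20, 30], 25)
    assert (value, index) == (20, 1)
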
def convert_to_onehot(lblmap, nlabels):

    output = np.zeros((lblmap.shape[0], lblmap.shape[1], nlabels))
    for ii in range(nlabels):
        output[:, :, ii] = (lblmap == ii).astype(np.uint8)
    return output

def convert_batch_to_onehot(lblbatch, nlabels):

    out = []
    for ii in range(lblbatch.shape[0]):

        lbl = convert_to_onehot(lblbatch[ii, ...], nlabels)
        out.append(lbl)

    return np.asarray(out)

def ncc(a, v, zero_norm=True):

    a = a.flatten()
    v = v.flatten()

    if zero_norm:

        a = (a - np.mean(a)) / (np.std(a) * len(a))
        v = (v - np.mean(v)) / np.std(v)

    else:

        a = (a) / (np.std(a) * len(a))
        v = (v) / np.std(v)

    return np.correlate(a, v)


def norm_l2(a, v):

    a = a.flatten()
    v = v.flatten()

    a = (a - np.mean(a)) / (np.std(a) * len(a))
    v = (v - np.mean(v)) / np.std(v)

    return np.mean(np.sqrt(a**2 + v**2))


def all_argmax(arr, axis=None):

    return np.argwhere(arr == np.amax(arr, axis=axis))


def makefolder(folder):
    '''
    Helper function to make a new folder if it doesn't exist
    :param folder: path to new folder
    :return: True if the folder was created, False if it already existed
    '''
    if not os.path.exists(folder):
        os.makedirs(folder)
        return True
    return False

def load_nii(img_path):
    '''
    Shortcut to load a nifti file
    '''

    nimg = nib.load(img_path)
    return nimg.get_data(), nimg.affine, nimg.header

def save_nii(img_path, data, affine, header):
    '''
    Shortcut to save a nifti file
    '''

    nimg = nib.Nifti1Image(data, affine=affine, header=header)
    nimg.to_filename(img_path)


def create_and_save_nii(data, img_path):

    img = nib.Nifti1Image(data, np.eye(4))
    nib.save(img, img_path)


class Bunch:
    # Useful shortcut for making struct-like constructs
    # Example:
    # mystruct = Bunch(a=1, b=2)
    # print(mystruct.a)
    # >>> 1
    def __init__(self, **kwds):
        self.__dict__.update(kwds)


def convert_to_uint8(image):
    image = image - image.min()
    image = 255.0 * np.divide(image.astype(np.float32), image.max())
    return image.astype(np.uint8)

def convert_to_uint8_rgb_fixed(image):
    image = (image + 1) * 127.5
    image = np.clip(image, 0, 255)
    return image.astype(np.uint8)

def normalise_image(image):
    '''
    make image zero mean and unit standard deviation
    '''

    img_o = np.float32(image.copy())
    m = np.mean(img_o)
    s = np.std(img_o)
    return np.divide((img_o - m), s)


def map_image_to_intensity_range(image, min_o, max_o, percentiles=0):

    # If percentiles == 0 this uses the image min and max. Percentiles > 0 makes the
    # normalisation more robust to outliers.

    if image.dtype in [np.uint8, np.uint16, np.uint32]:
        assert min_o >= 0, 'Input image type is uintXX but you selected a negative min_o: %f' % min_o

    if image.dtype == np.uint8:
        assert max_o <= 255, 'Input image type is uint8 but you selected a max_o > 255: %f' % max_o

    min_i = np.percentile(image, 0 + percentiles)
    max_i = np.percentile(image, 100 - percentiles)

    image = (np.divide((image - min_i), max_i - min_i) * (max_o - min_o) + min_o).copy()

    image[image > max_o] = max_o
    image[image < min_o] = min_o

    return image
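
# Worked example (editor's addition, not part of the original module):
# map_image_to_intensity_range() linearly rescales an image to a target range, with
# optional percentile clipping. The array below is invented.
def _example_map_intensity_range():
    img = np.array([[0., 50.], [100., 200.]])
    out = map_image_to_intensity_range(img, -1.0, 1.0)
    print(out)   # [[-1.  -0.5], [ 0.   1. ]]
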
def map_images_to_intensity_range(X, min_o, max_o, percentiles=0):

    X_mapped = np.zeros(X.shape, dtype=np.float32)

    for ii in range(X.shape[0]):

        Xc = X[ii, ...]
        X_mapped[ii, ...] = map_image_to_intensity_range(Xc, min_o, max_o, percentiles)

    return X_mapped.astype(np.float32)


def normalise_images(X):
    '''
    Helper for making the images zero mean and unit standard deviation i.e. `white`
    '''

    X_white = np.zeros(X.shape, dtype=np.float32)

    for ii in range(X.shape[0]):

        Xc = X[ii, ...]
        X_white[ii, ...] = normalise_image(Xc)

    return X_white.astype(np.float32)

def jaccard_onehot(pred, gt):

    # assuming the last dimension is the class dimension

    intersection = np.sum(pred * gt)
    pred_count = np.sum(pred)
    gt_count = np.sum(gt)

    return intersection / (pred_count + gt_count - intersection)


def generalised_energy_distance(sample_arr, gt_arr, nlabels, **kwargs):
    """
    Generalised energy distance between a set of segmentation samples and a set of ground-truth
    annotations, using (1 - IoU) as the underlying distance between label maps.
    :param sample_arr: expected shape N x X x Y
    :param gt_arr: M x X x Y
    :return: the generalised energy distance
    """

    def dist_fct(m1, m2):

        label_range = kwargs.get('label_range', range(nlabels))

        per_label_iou = []
        for lbl in label_range:

            m1_bin = (m1 == lbl) * 1
            m2_bin = (m2 == lbl) * 1

            if np.sum(m1_bin) == 0 and np.sum(m2_bin) == 0:
                per_label_iou.append(1)
            elif np.sum(m1_bin) > 0 and np.sum(m2_bin) == 0 or np.sum(m1_bin) == 0 and np.sum(m2_bin) > 0:
                per_label_iou.append(0)
            else:
                per_label_iou.append(jc(m1_bin, m2_bin))

        # average over the labels that were actually evaluated
        return 1 - (sum(per_label_iou) / len(label_range))

    N = sample_arr.shape[0]
    M = gt_arr.shape[0]

    d_sy = []
    d_ss = []
    d_yy = []

    for i in range(N):
        for j in range(M):
            d_sy.append(dist_fct(sample_arr[i, ...], gt_arr[j, ...]))

    for i in range(N):
        for j in range(N):
            d_ss.append(dist_fct(sample_arr[i, ...], sample_arr[j, ...]))

    for i in range(M):
        for j in range(M):
            d_yy.append(dist_fct(gt_arr[i, ...], gt_arr[j, ...]))

    return (2. / (N * M)) * sum(d_sy) - (1. / N**2) * sum(d_ss) - (1. / M**2) * sum(d_yy)
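
# Worked example (editor's addition, not part of the original module):
# with identical samples and ground truths all pairwise distances are 0, so the
# generalised energy distance collapses to 0; diverse samples make the first term grow.
# The tiny 2x2 label maps below are invented.
def _example_ged():
    samples = np.zeros((4, 2, 2), dtype=np.int32)   # N=4 identical all-background samples
    gts = np.zeros((2, 2, 2), dtype=np.int32)       # M=2 identical annotations
    print(generalised_energy_distance(samples, gts, nlabels=2))   # 0.0
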
def variance_ncc_dist(sample_arr, gt_arr):
    """
    Normalised cross correlation between the pixel-wise 'uncertainty' map of the samples
    (expected cross entropy of the samples w.r.t. their mean segmentation) and the corresponding
    maps computed against each ground-truth annotation, averaged over the annotations.
    :param sample_arr: expected shape N x X x Y
    :param gt_arr: M x X x Y
    :return: the mean NCC over the M ground-truth annotations
    """

    def pixel_wise_xent(m_samp, m_gt, eps=1e-8):

        log_samples = np.log(m_samp + eps)

        return -1.0 * np.sum(m_gt * log_samples, axis=-1)

    mean_seg = np.mean(sample_arr, axis=0)

    N = sample_arr.shape[0]
    M = gt_arr.shape[0]

    sX = sample_arr.shape[1]
    sY = sample_arr.shape[2]

    E_ss_arr = np.zeros((N, sX, sY))
    for i in range(N):
        E_ss_arr[i, ...] = pixel_wise_xent(sample_arr[i, ...], mean_seg)

    E_ss = np.mean(E_ss_arr, axis=0)

    E_sy_arr = np.zeros((M, N, sX, sY))
    for j in range(M):
        for i in range(N):
            E_sy_arr[j, i, ...] = pixel_wise_xent(sample_arr[i, ...], gt_arr[j, ...])

    E_sy = np.mean(E_sy_arr, axis=1)

    ncc_list = []
    for j in range(M):

        ncc_list.append(ncc(E_ss, E_sy[j, ...]))

    return (1 / M) * sum(ncc_list)


def histogram_equalization(img):

    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

    # -----Splitting the LAB image into its channels-----------------------------
    l, a, b = cv2.split(lab)

    # -----Applying CLAHE to the L-channel---------------------------------------
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    cl = clahe.apply(l)

    # -----Merging the CLAHE enhanced L-channel with the a and b channels--------
    limg = cv2.merge((cl, a, b))

    # -----Converting the image from the LAB colour model back to BGR------------
    final = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)

    return final

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=-1, keepdims=True)

def list_mean(lst):

    N = len(lst)
    return (1. / N) * sum(lst)
--------------------------------------------------------------------------------