├── .gitignore
├── LICENSE
├── README.md
├── load_LIDC_data.py
├── probabilistic_unet.py
├── train_model.py
├── unet.py
├── unet_blocks.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.
      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted.
      If You institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty.
      Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Probabilistic UNet in PyTorch
A Probabilistic U-Net for segmentation of ambiguous images, implemented in PyTorch. This is a PyTorch implementation of the paper https://arxiv.org/abs/1806.05034; the authors' original implementation can be found at https://github.com/SimonKohl/probabilistic_unet.

## Adding KL divergence for the Independent distribution
In order to implement a Gaussian distribution with an axis-aligned covariance matrix in PyTorch, I needed to wrap a Normal distribution in an Independent distribution. Therefore you need to add the following to the PyTorch source code at torch/distributions/kl.py (source: https://github.com/pytorch/pytorch/issues/13545); alternatively, see the runtime registration snippet below, which avoids patching PyTorch.

```
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
```

## Training
In order to train your own Probabilistic UNet in PyTorch, you should first write your own data loader. Then you can use the following code snippet to train the network:

```
train_loader = ...  # define this yourself: a DataLoader yielding (patch, mask) batches
net = ProbabilisticUnet(input_channels, num_classes, num_filters, latent_dim, no_convs_fcomb, beta)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
for epoch in range(epochs):
    for step, (patch, mask) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

## Train on LIDC Dataset
One of the datasets used in the original paper is the [LIDC dataset](https://wiki.cancerimagingarchive.net). I've preprocessed this data and stored it in pickle files, which you can [download here](https://drive.google.com/drive/folders/1xKfKCQo8qa6SAr3u7qWNtQjIphIrvmd5?usp=sharing). After downloading the files you should place them in a folder called 'data'. After that, you can train your own Probabilistic UNet on the LIDC dataset using the simple train script provided in train_model.py.
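
## Registering the KL divergence at runtime
As an alternative to editing the PyTorch source tree, the same function can be registered from your own code through the public `register_kl` decorator. This is a minimal sketch, not part of the original repository: it reuses the function from the issue above and imports `_sum_rightmost`, which is a private helper in `torch.distributions.utils`. Note that newer PyTorch releases already ship an `Independent`/`Independent` KL registration, in which case neither this snippet nor the source patch is needed.

```
import torch
from torch.distributions import Independent, Normal, kl_divergence
from torch.distributions.kl import register_kl
from torch.distributions.utils import _sum_rightmost

@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    # Both distributions must reinterpret the same number of batch dims as event dims
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    # KL of the factorised base Normals, summed over the event dimensions
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)

# Quick check: KL between two diagonal Gaussians over a 6-dimensional latent space
p = Independent(Normal(torch.zeros(4, 6), torch.ones(4, 6)), 1)
q = Independent(Normal(torch.ones(4, 6), torch.ones(4, 6)), 1)
print(kl_divergence(p, q).shape)  # torch.Size([4]), one KL value per batch element
```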
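
## Sampling segmentations at test time
Once trained, the model can produce multiple plausible segmentations for the same input by decoding different samples from the prior latent space. The following is a minimal sketch, not part of the original scripts; it assumes `net`, `device`, and `test_loader` are set up as in train_model.py, and applies `torch.sigmoid` because `elbo` trains on logits via `BCEWithLogitsLoss`.

```
net.eval()
with torch.no_grad():
    for patch, mask, _ in test_loader:
        patch = patch.to(device)
        # Encode the patch once: this builds the prior latent space and the UNet feature map
        net.forward(patch, None, training=False)
        # Decode several prior samples into distinct segmentation hypotheses
        samples = [torch.sigmoid(net.sample(testing=True)) for _ in range(4)]
        break  # one batch is enough for a quick qualitative check
```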

--------------------------------------------------------------------------------
/load_LIDC_data.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.dataset import Dataset
import numpy as np
import os
import random
import pickle

class LIDC_IDRI(Dataset):
    def __init__(self, dataset_location, transform=None):
        self.transform = transform
        # Instance-level containers; defining these as class attributes would
        # share a single list across all instances of the dataset.
        self.images = []
        self.labels = []
        self.series_uid = []

        max_bytes = 2**31 - 1
        data = {}
        for file in os.listdir(dataset_location):
            filename = os.fsdecode(file)
            if '.pickle' in filename:
                print("Loading file", filename)
                file_path = os.path.join(dataset_location, filename)
                bytes_in = bytearray(0)
                input_size = os.path.getsize(file_path)
                # Read the pickle in chunks to work around the 2 GB limit on
                # single reads on some platforms.
                with open(file_path, 'rb') as f_in:
                    for _ in range(0, input_size, max_bytes):
                        bytes_in += f_in.read(max_bytes)
                new_data = pickle.loads(bytes_in)
                data.update(new_data)

        for key, value in data.items():
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])

        assert (len(self.images) == len(self.labels) == len(self.series_uid))

        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0

        del data

    def __getitem__(self, index):
        image = np.expand_dims(self.images[index], axis=0)

        # Randomly select one of the four expert annotations for this image
        label = self.labels[index][random.randint(0, 3)].astype(float)
        if self.transform is not None:
            image = self.transform(image)

        series_uid = self.series_uid[index]

        # Convert image and label to torch tensors
        image = torch.from_numpy(image)
        label = torch.from_numpy(label)

        # Convert uint8 to float tensors
        image = image.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)

        return image, label, series_uid

    # Override to give PyTorch the size of the dataset
    def __len__(self):
        return len(self.images)

--------------------------------------------------------------------------------
/probabilistic_unet.py:
--------------------------------------------------------------------------------
#This code is based on: https://github.com/SimonKohl/probabilistic_unet

from unet_blocks import *
from unet import Unet
from utils import init_weights, init_weights_orthogonal_normal, l2_regularisation
import torch.nn.functional as F
from torch.distributions import Normal, Independent, kl

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Encoder(nn.Module):
    """
    A convolutional neural network consisting of len(num_filters) blocks of no_convs_per_block convolutional layers each.
    After each block a pooling operation is performed, and after each convolutional layer a non-linear (ReLU) activation function is applied.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False):
        super(Encoder, self).__init__()
        self.contracting_path = nn.ModuleList()
        self.input_channels = input_channels
        self.num_filters = num_filters

        if posterior:
            # To accommodate the mask that is concatenated along the channel axis, we increase input_channels.
            self.input_channels += 1

        layers = []
        for i in range(len(self.num_filters)):
            """
            Determine input_dim and output_dim of the conv layers in this block. The first layer is input x output;
            all subsequent layers are output x output.
            """
            input_dim = self.input_channels if i == 0 else output_dim
            output_dim = num_filters[i]

            if i != 0:
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

            layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding)))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_per_block-1):
                layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding)))
                layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

        self.layers.apply(init_weights)

    def forward(self, input):
        output = self.layers(input)
        return output

class AxisAlignedConvGaussian(nn.Module):
    """
    A convolutional net that parametrizes a Gaussian distribution with an axis-aligned covariance matrix.
    """
    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False):
        super(AxisAlignedConvGaussian, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        if self.posterior:
            self.name = 'Posterior'
        else:
            self.name = 'Prior'
        self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers, posterior=self.posterior)
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1)
        self.show_img = 0
        self.show_seg = 0
        self.show_concat = 0
        self.show_enc = 0
        self.sum_input = 0

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input, segm=None):

        # If segmentation is not None, concatenate the mask to the channel axis of the input
        if segm is not None:
            self.show_img = input
            self.show_seg = segm
            input = torch.cat((input, segm), dim=1)
            self.show_concat = input
            self.sum_input = torch.sum(input)

        encoding = self.encoder(input)
        self.show_enc = encoding

        # We only want the mean of the resulting h x w feature map
        encoding = torch.mean(encoding, dim=2, keepdim=True)
        encoding = torch.mean(encoding, dim=3, keepdim=True)

        # Convert encoding to 2 x latent_dim and split it into mu and log_sigma
        mu_log_sigma = self.conv_layer(encoding)

        # We squeeze the spatial dimensions one at a time, since squeezing them together fails when batch size is 1
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)
        mu_log_sigma = torch.squeeze(mu_log_sigma, dim=2)

        mu = mu_log_sigma[:,:self.latent_dim]
        log_sigma = mu_log_sigma[:,self.latent_dim:]
        # This is a multivariate Normal with diagonal covariance matrix sigma
        # https://github.com/pytorch/pytorch/pull/11178
        dist = Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
        return dist

class Fcomb(nn.Module):
    """
    A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space
    and the output of the UNet (the feature map) by concatenating them along their channel axis.
    """
    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True):
        super(Fcomb, self).__init__()
        self.num_channels = num_output_channels  # output channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2,3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if self.use_tile:
            layers = []

            # A decoder of no_convs_fcomb 1x1 convolutions, each followed by a ReLU activation, except for the last layer
            layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))

            for _ in range(no_convs_fcomb-2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))

            self.layers = nn.Sequential(*layers)

            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
        Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
        """
        init_dim = a.size(dim)
        repeat_idx = [1] * a.dim()
        repeat_idx[dim] = n_tile
        a = a.repeat(*(repeat_idx))
        order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
        return torch.index_select(a, dim, order_index)

    def forward(self, feature_map, z):
        """
        z is (batch_size x latent_dim) and feature_map is (batch_size x no_channels x H x W).
        So broadcast z to (batch_size x latent_dim x H x W). Behavior is exactly the same as tf.tile (verified).
        """
        if self.use_tile:
            z = torch.unsqueeze(z, 2)
            z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
            z = torch.unsqueeze(z, 3)
            z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

            # Concatenate the feature map (output of the UNet) and the sample taken from the latent space
            feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
            output = self.layers(feature_map)
            return self.last_layer(output)


class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: a list with the number of filters per layer
    latent_dim: dimension of the latent space
    no_convs_per_block: number of convs per block in the (convolutional) encoder of prior and posterior
    """

    def __init__(self, input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=6, no_convs_fcomb=4, beta=10.0):
        super(ProbabilisticUnet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w':'he_normal', 'b':'normal'}
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels, self.num_classes, self.num_filters, self.initializers, apply_last_layer=False, padding=True).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers).to(device)
        self.posterior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, self.no_convs_per_block, self.latent_dim, self.initializers, posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.input_channels, self.num_classes, self.no_convs_fcomb, {'w':'orthogonal', 'b':'normal'}, use_tile=True).to(device)

    def forward(self, patch, segm, training=True):
        """
        Construct the prior latent space for the patch and run the patch through the UNet;
        in case training is True, also construct the posterior latent space.
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        Sample a segmentation by reconstructing from a prior sample
        and combining this with the UNet features.
        """
        if testing == False:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here. For the GED it is important to take a sample.
            #z_prior = self.prior_latent_space.base_dist.loc
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)


    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and the UNet feature map.
        use_posterior_mean: use the posterior mean instead of sampling z_q
        calculate_posterior: sample from the posterior latent space instead of using a provided sample
        """
        if use_posterior_mean:
            # Mean of the underlying Normal (the Independent wrapper itself does not expose .loc)
            z_posterior = self.posterior_latent_space.base_dist.loc
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and prior: KL(Q||P)
        analytic: calculate the KL analytically or via sampling from the posterior
        calculate_posterior: if we use sampling to approximate the KL, we can sample here or supply a sample
        """
        if analytic:
            # Need to add this to the torch source code, see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X)
        """

        # Per-pixel binary cross-entropy on the logits; summed and averaged below
        criterion = nn.BCEWithLogitsLoss(reduction='none')
        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior))

        # Here we use the posterior sample sampled above
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calculate_posterior=False, z_posterior=z_posterior)

        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)

--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from load_LIDC_data import LIDC_IDRI
from probabilistic_unet import ProbabilisticUnet
from utils import l2_regularisation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = LIDC_IDRI(dataset_location='data/')
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.1 * dataset_size))
np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = DataLoader(dataset, batch_size=5, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler)
print("Number of training/test patches:", (len(train_indices), len(test_indices)))

net = ProbabilisticUnet(input_channels=1, num_classes=1, num_filters=[32,64,128,192], latent_dim=2, no_convs_fcomb=4, beta=10.0)
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=0)
epochs = 10
for epoch in range(epochs):
    for step, (patch, mask, _) in enumerate(train_loader):
        patch = patch.to(device)
        mask = mask.to(device)
        mask = torch.unsqueeze(mask, 1)
        net.forward(patch, mask, training=True)
        elbo = net.elbo(mask)
        reg_loss = l2_regularisation(net.posterior) + l2_regularisation(net.prior) + l2_regularisation(net.fcomb.layers)
        loss = -elbo + 1e-5 * reg_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
from unet_blocks import *
import torch.nn.functional as F

class Unet(nn.Module):
    """
    A UNet (https://arxiv.org/abs/1505.04597) implementation.
    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per layer
    apply_last_layer: boolean to apply the last layer or not (not used in the Probabilistic UNet)
    padding: boolean; if True we pad the images with 1 so that we keep the same dimensions
    """

    def __init__(self, input_channels, num_classes, num_filters, initializers, apply_last_layer=True, padding=True):
        super(Unet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.padding = padding
        self.activation_maps = []
        self.apply_last_layer = apply_last_layer
        self.contracting_path = nn.ModuleList()

        for i in range(len(self.num_filters)):
            input = self.input_channels if i == 0 else output
            output = self.num_filters[i]

            if i == 0:
                pool = False
            else:
                pool = True

            self.contracting_path.append(DownConvBlock(input, output, initializers, padding, pool=pool))

        self.upsampling_path = nn.ModuleList()

        n = len(self.num_filters) - 2
        for i in range(n, -1, -1):
            input = output + self.num_filters[i]
            output = self.num_filters[i]
            self.upsampling_path.append(UpConvBlock(input, output, initializers, padding))

        if self.apply_last_layer:
            self.last_layer = nn.Conv2d(output, num_classes, kernel_size=1)
            #nn.init.kaiming_normal_(self.last_layer.weight, mode='fan_in', nonlinearity='relu')
            #nn.init.normal_(self.last_layer.bias)


    def forward(self, x, val):
        blocks = []
        for i, down in enumerate(self.contracting_path):
            x = down(x)
            if i != len(self.contracting_path)-1:
                blocks.append(x)

        for i, up in enumerate(self.upsampling_path):
            x = up(x, blocks[-i-1])

        del blocks

        # Used for saving the activations and plotting
        if val:
            self.activation_maps.append(x)

        if self.apply_last_layer:
            x = self.last_layer(x)

        return x
--------------------------------------------------------------------------------
/unet_blocks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from utils import init_weights

class DownConvBlock(nn.Module):
    """
    A block of three convolutional layers, where each layer is followed by a non-linear activation function.
    Between blocks we add a pooling operation.
    """
    def __init__(self, input_dim, output_dim, initializers, padding, pool=True):
        super(DownConvBlock, self).__init__()
        layers = []

        if pool:
            layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))

        layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding)))
        layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

        self.layers.apply(init_weights)

    def forward(self, patch):
        return self.layers(patch)


class UpConvBlock(nn.Module):
    """
    A block consisting of an upsampling layer, followed by a convolutional layer to reduce the number of channels, and then a DownConvBlock.
    If bilinear is set to False, we do a transposed convolution instead of upsampling.
    """
    def __init__(self, input_dim, output_dim, initializers, padding, bilinear=True):
        super(UpConvBlock, self).__init__()
        self.bilinear = bilinear

        if not self.bilinear:
            self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2)
            self.upconv_layer.apply(init_weights)

        self.conv_block = DownConvBlock(input_dim, output_dim, initializers, padding, pool=False)

    def forward(self, x, bridge):
        if self.bilinear:
            up = nn.functional.interpolate(x, mode='bilinear', scale_factor=2, align_corners=True)
        else:
            up = self.upconv_layer(x)

        assert up.shape[3] == bridge.shape[3]
        out = torch.cat([up, bridge], 1)
        out = self.conv_block(out)

        return out

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    # Draw four candidates per element and keep one that lies within two standard deviations
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

def init_weights(m):
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        #nn.init.normal_(m.weight, std=0.001)
        #nn.init.normal_(m.bias, std=0.001)
        truncated_normal_(m.bias, mean=0, std=0.001)

def init_weights_orthogonal_normal(m):
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.orthogonal_(m.weight)
        truncated_normal_(m.bias, mean=0, std=0.001)
        #nn.init.normal_(m.bias, std=0.001)

def l2_regularisation(m):
    # Sum of the L2 norms of all parameters of the module
    l2_reg = None

    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg

def save_mask_prediction_example(mask, pred, iter):
    plt.imshow(pred[0,:,:], cmap='Greys')
    plt.savefig('images/' + str(iter) + "_prediction.png")
    plt.imshow(mask[0,:,:], cmap='Greys')
    plt.savefig('images/' + str(iter) + "_mask.png")
--------------------------------------------------------------------------------