├── LICENSE
├── README.md
├── ThirdParty
│   ├── diffusion.py
│   └── layers.py
├── apply_script_diffusion.py
├── data_samples
│   ├── experiment_3D
│   │   ├── Diffusion_3DCTC_CE_epoch=4999.ckpt
│   │   ├── pred_0_sketch3D_sim_CTCCE_0.tif
│   │   ├── pred_1_sketch3D_sim_CTCCE_0.tif
│   │   ├── pred_2_sketch3D_sim_CTCCE_0.tif
│   │   └── pred_3_sketch3D_sim_CTCCE_0.tif
│   ├── image3D_sim_CTCCE_0.h5
│   ├── image3D_sim_CTCCE_0.tif
│   ├── image_files_3D.csv
│   ├── sketch3D_sim_CTCCE_0.h5
│   ├── sketch3D_sim_CTCCE_0.tif
│   └── sketch_files_3D.csv
├── dataloader
│   ├── augmenter.py
│   └── h5_dataloader.py
├── figures
│   ├── example_data.png
│   ├── multi-channel.png
│   ├── overlapping_cells.png
│   └── timeseries.gif
├── models
│   ├── DiffusionModel2D.py
│   ├── DiffusionModel3D.py
│   ├── module_UNet2D_pixelshuffle_inject.py
│   └── module_UNet3D_pixelshuffle_inject.py
├── notebooks
│   ├── README.md
│   ├── jupyter_apply_script.ipynb
│   ├── jupyter_preparation_script.ipynb
│   └── jupyter_train_script.ipynb
├── requirements.txt
├── train_script.py
└── utils
    ├── PNAS_sampling.csv
    ├── csv_generator.py
    ├── h5_converter.py
    ├── harmonics.py
    ├── jupyter_widgets.py
    ├── synthetic_cell_membrane_masks.py
    ├── synthetic_cell_nuclei_masks.py
    ├── theta_phi_sampling_5000points_10000iter.npy
    └── utils.py
/LICENSE: --------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Denoising Diffusion Probabilistic Models for Generation of Realistic Fully-Annotated Microscopy Image Data Sets 2 |
3 | This repository contains code to simulate 2D/3D cellular structures and to synthesize corresponding microscopy image data based on Denoising Diffusion Probabilistic Models (DDPM). 4 | Sketches are generated to indicate cell shapes and structural characteristics, and they serve as the basis for the diffusion process, ultimately allowing the generation of fully-annotated microscopy image data sets without any human annotation effort. 5 | Generated data sets are available at OSF, and the article is available at PLOS Computational Biology. 6 | To access the trained models and see a showcase of the fully-simulated data sets, please visit our website (work in progress).
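At its core, the generation process noises an initial patch up to the backward starting point $t_\mathrm{start}$ and then denoises it conditioned on the sketch, as implemented in `ThirdParty/diffusion.py` and used by `apply_script_diffusion.py`. The following self-contained sketch illustrates the interplay of the two classes; the dummy network, patch size, and single-channel layout are illustrative placeholders and not the trained models shipped with this repository:

```python
import torch
from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
# run from the repository root so that ThirdParty is importable

# Stand-in for a trained noise-prediction backbone (e.g. a DiffusionModel2D/3D
# checkpoint); it receives the noisy patch concatenated with the sketch
# condition and must return a noise estimate for the noised channel.
class DummyNet(torch.nn.Module):
    def forward(self, x, t):
        return torch.zeros_like(x[:, :1])

T, t_start = 1000, 400  # total timesteps and the proposed backward starting point
trainer = GaussianDiffusionTrainer(T, schedule='cosine')
sampler = GaussianDiffusionSampler(DummyNet(), T, t_start=t_start,
                                   schedule='cosine', mean_type='epsilon',
                                   var_type='fixedlarge')

sketch = torch.rand(1, 1, 64, 64)              # [batch, channel, y, x] structure sketch
x_t, noise, t = trainer(sketch, t_start - 1)   # forward-diffuse up to t_start-1
x_0, intermediates = sampler(x_t, sketch)      # conditional backward diffusion
```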

7 | Diverse exemplary synthetic data samples.
8 | Synthetic multi-channel data sample. Synthetic data sample of overlapping cells. Synthetic time series data sample.

9 | Exemplary synthetic samples from our experiments.


10 | 11 | 12 | If you use this code or data, please cite the following work: 13 | ``` 14 | @article{eschweiler2024celldiffusion, 15 | title={Denoising diffusion probabilistic models for generation of realistic fully-annotated microscopy image datasets}, 16 | author={Eschweiler, Dennis and Yilmaz, R{\"u}veyda and Baumann, Matisse and Laube, Ina and Roy, Rijo and Jose, Abin and Br{\"u}ckner, Daniel and Stegmaier, Johannes}, 17 | journal={PLOS Computational Biology}, 18 | volume={20}, 19 | number={2}, 20 | pages={e1011890}, 21 | year={2024} 22 | } 23 | ``` 24 |


25 | We provide Jupyter notebooks that give an overview of how to preprocess your data and how to train and apply the image generation pipeline. The following gives a very brief overview of the general functionality; for more detailed examples, we refer to the notebooks. 26 | 27 | ## Data Preparation 28 | The presented pipelines require the hdf5 file format for processing. Therefore, each image file has to be converted to hdf5, which can be done using `utils.h5_converter.prepare_images`. Once all files have been converted, a list of those files has to be stored as a csv file to make them accessible to the processing pipelines, which can be done using `utils.csv_generator.create_csv`. A more detailed explanation is given in `jupyter_preparation_script.ipynb`, and a brief sketch of both steps is shown at the end of this README. 29 | 30 | 31 | ## Diffusion Model Training and Application 32 | To train or apply your models with the proposed pipeline, make sure to adapt all parameters in the pipeline files `models/DiffusionModel3D.py` or `models/DiffusionModel2D.py`, and in the training script `train_script.py` or application script `apply_script_diffusion.py`. Alternatively, all parameters can be provided as command line arguments (see the invocation sketch at the end of this README). A more detailed explanation is given in `jupyter_train_script.ipynb` and `jupyter_apply_script.ipynb`. 33 | 34 | 35 | ## Simulation of Cellular Structures and Sketch Generation 36 | Since the proposed approach only requires intuitive structure sketches as input, sketches can generally be created in any arbitrary way. We mainly focused on using and adapting the simulation techniques proposed in "3D fluorescence microscopy data synthesis for segmentation and benchmarking". The functionality used in this work can be found in `utils.synthetic_cell_membrane_masks` and `utils.synthetic_cell_nuclei_masks` for cellular membranes and nuclei, respectively. 37 | 38 | ## Processing Times and Hardware Requirements 39 | Due to the layout and working principle of the computation cluster we used, determining precise hardware requirements was difficult. 40 | The cluster offered varying hardware capabilities, including GPUs ranging from a GTX 1080 to an RTX 6000. 41 | Nevertheless, the following should give a brief indication of the specifications needed to run the presented pipelines. 42 | 43 | Training times ranged from one day for 2D models to roughly a week for 3D models, requiring 2-8 GB of RAM and 4-8 GB of GPU memory. 44 | Prediction times varied widely, as they are influenced by the choice of the backward starting point $t_\mathrm{start}$ and the total image size. 45 | For the proposed starting point $t_\mathrm{start}=400$, image generation times ranged from less than a minute for 2D image data of size 1024x1024 pixels to roughly one hour for 3D image data of size 512x512x180 or 1700x2450x13 voxels. 46 | Memory requirements varied similarly, ranging between 2-80 GB of RAM and 4-8 GB of GPU memory.
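As referenced in the Data Preparation section, the conversion and listing steps combine as follows. This is a minimal sketch only: `prepare_images` and `create_csv` exist in this repository, but the argument layout shown here is an assumption for illustration; `jupyter_preparation_script.ipynb` documents the actual signatures.

```python
from pathlib import Path

from utils.csv_generator import create_csv      # writes the csv file list
from utils.h5_converter import prepare_images   # converts images to hdf5

# Hypothetical argument layout -- consult the preparation notebook for the
# exact signatures before running this.
raw_files = sorted(str(p) for p in Path('data_samples').glob('*.tif'))
prepare_images(raw_files)                                  # creates *.h5 files
create_csv([f.replace('.tif', '.h5') for f in raw_files],  # list of converted files
           'data_samples/image_files_3D.csv')              # csv consumed by the pipelines
```

Similarly, a typical application call might look like the following; the flag names are taken from `apply_script_diffusion.py`, while the output directory is a placeholder:

```
python apply_script_diffusion.py --pipeline DiffusionModel3D \
    --ckpt_path data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt \
    --output_path results/experiment_3D \
    --timesteps_start 400 --num_files 1
```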
47 | -------------------------------------------------------------------------------- /ThirdParty/diffusion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Code adapted from https://github.com/w86763777/pytorch-ddpm/blob/master/diffusion.py and https://huggingface.co/blog/annotated-diffusion 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | #%% 11 | # Definitions 12 | 13 | class SinusoidalPositionEmbeddings(nn.Module): 14 | def __init__(self, dim): 15 | super().__init__() 16 | self.dim = dim 17 | 18 | def forward(self, t): 19 | device = t.device 20 | half_dim = self.dim // 2 21 | embeddings = math.log(10000) / (half_dim - 1) 22 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 23 | embeddings = t[:, None] * embeddings[None, :] 24 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 25 | return embeddings 26 | 27 | 28 | def beta_schedule(timesteps, schedule='linear'): 29 | 30 | if schedule=='cosine': 31 | """ 32 | cosine schedule as proposed in https://arxiv.org/abs/2102.09672 33 | """ 34 | s = 0.008 35 | steps = timesteps + 1 36 | x = torch.linspace(0, timesteps, steps) 37 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 38 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 39 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 40 | betas = torch.clip(betas, 0.0001, 0.9999) 41 | 42 | if schedule=='linear': 43 | beta_start = 0.0001 44 | beta_end = 0.02 45 | betas = torch.linspace(beta_start, beta_end, timesteps) 46 | 47 | if schedule=='quadratic': 48 | beta_start = 0.0001 49 | beta_end = 0.02 50 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 51 | 52 | if schedule=='sigmoid': 53 | beta_start = 0.0001 54 | beta_end = 0.02 55 | betas = torch.linspace(-6, 6, timesteps) 56 | betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 57 | 58 | return betas.double() 59 | 60 | 61 | def extract(v, t, x_shape): 62 | """ 63 | Extract some coefficients at specified timesteps, then reshape to 64 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 65 | """ 66 | out = torch.gather(v, index=t, dim=0).float() 67 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 68 | 69 | 70 | #%% 71 | # Training 72 | 73 | class GaussianDiffusionTrainer(nn.Module): 74 | def __init__(self, T, schedule='cosine'): 75 | assert schedule in ['cosine', 'linear', 'quadratic', 'sigmoid'], 'Unknown schedule "{0}"'.format(schedule) 76 | super().__init__() 77 | 78 | self.T = T 79 | 80 | 81 | self.register_buffer('betas', beta_schedule(T, schedule=schedule)) 82 | alphas = 1. - self.betas 83 | alphas_bar = torch.cumprod(alphas, dim=0) 84 | 85 | # calculations for diffusion q(x_t | x_{t-1}) and others 86 | self.register_buffer( 87 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) 88 | self.register_buffer( 89 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) 90 | 91 | def forward(self, x_0, t=None): 92 | """ 93 | Algorithm 1. 
94 | """ 95 | if t==None: 96 | t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) 97 | else: 98 | t = x_0.new_ones([x_0.shape[0], ], dtype=torch.long) * t 99 | noise = torch.randn_like(x_0) 100 | x_t = ( 101 | extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + 102 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise) 103 | return x_t, noise, t 104 | 105 | #%% 106 | # Testing 107 | 108 | class GaussianDiffusionSampler(nn.Module): 109 | def __init__(self, model, T, t_start=1000, t_save=100, t_step=1,\ 110 | schedule='linear', mean_type='epsilon', var_type='fixedlarge'): 111 | assert mean_type in ['xprev' 'xstart', 'epsilon'], 'Unknown mean_type "{0}"'.format(mean_type) 112 | assert var_type in ['fixedlarge', 'fixedsmall'], 'Unknown var_type "{0}"'.format(var_type) 113 | assert schedule in ['cosine', 'linear', 'quadratic', 'sigmoid'], 'Unknown schedule "{0}"'.format(schedule) 114 | super().__init__() 115 | 116 | self.model = model 117 | self.T = T 118 | self.t_start = t_start 119 | self.t_save = t_save 120 | self.t_step = t_step 121 | self.schedule = schedule 122 | self.mean_type = mean_type 123 | self.var_type = var_type 124 | 125 | self.register_buffer('betas', beta_schedule(T, schedule=self.schedule)) 126 | alphas = 1. - self.betas 127 | alphas_bar = torch.cumprod(alphas, dim=0) 128 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] 129 | 130 | # calculations for diffusion q(x_t | x_{t-1}) and others 131 | self.register_buffer( 132 | 'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar)) 133 | self.register_buffer( 134 | 'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1)) 135 | 136 | # calculations for posterior q(x_{t-1} | x_t, x_0) 137 | self.register_buffer( 138 | 'posterior_var', 139 | self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar)) 140 | # below: log calculation clipped because the posterior variance is 0 at 141 | # the beginning of the diffusion chain 142 | self.register_buffer( 143 | 'posterior_log_var_clipped', 144 | torch.log( 145 | torch.cat([self.posterior_var[1:2], self.posterior_var[1:]]))) 146 | self.register_buffer( 147 | 'posterior_mean_coef1', 148 | torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar)) 149 | self.register_buffer( 150 | 'posterior_mean_coef2', 151 | torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar)) 152 | 153 | def q_mean_variance(self, x_0, x_t, t): 154 | """ 155 | Compute the mean and variance of the diffusion posterior 156 | q(x_{t-1} | x_t, x_0) 157 | """ 158 | assert x_0.shape == x_t.shape 159 | posterior_mean = ( 160 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 + 161 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 162 | ) 163 | posterior_log_var_clipped = extract( 164 | self.posterior_log_var_clipped, t, x_t.shape) 165 | return posterior_mean, posterior_log_var_clipped 166 | 167 | def predict_xstart_from_eps(self, x_t, t, eps): 168 | assert x_t.shape == eps.shape 169 | return ( 170 | extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t - 171 | extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps 172 | ) 173 | 174 | def predict_xstart_from_xprev(self, x_t, t, xprev): 175 | assert x_t.shape == xprev.shape 176 | return ( # (xprev - coef2*x_t) / coef1 177 | extract( 178 | 1. 
/ self.posterior_mean_coef1, t, x_t.shape) * xprev - 179 | extract( 180 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t, 181 | x_t.shape) * x_t 182 | ) 183 | 184 | def p_mean_variance(self, x_t, cond, t): 185 | # below: only log_variance is used in the KL computations 186 | model_log_var = { 187 | # for fixedlarge, we set the initial (log-)variance like so to 188 | # get a better decoder log likelihood 189 | 'fixedlarge': torch.log(torch.cat([self.posterior_var[1:2], 190 | self.betas[1:]])), 191 | 'fixedsmall': self.posterior_log_var_clipped, 192 | }[self.var_type] 193 | model_log_var = extract(model_log_var, t, x_t.shape) 194 | 195 | # Mean parameterization 196 | if self.mean_type == 'xprev': # the model predicts x_{t-1} 197 | x_prev = self.model(torch.cat((x_t,cond), axis=1), t) 198 | x_0 = self.predict_xstart_from_xprev(x_t, t, xprev=x_prev) 199 | model_mean = x_prev 200 | elif self.mean_type == 'xstart': # the model predicts x_0 201 | x_0 = self.model(torch.cat((x_t,cond), axis=1), t) 202 | model_mean, _ = self.q_mean_variance(x_0, x_t, t) 203 | elif self.mean_type == 'epsilon': # the model predicts epsilon 204 | eps = self.model(torch.cat((x_t,cond), axis=1), t) 205 | x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps) 206 | model_mean, _ = self.q_mean_variance(x_0, x_t, t) 207 | else: 208 | raise NotImplementedError(self.mean_type) 209 | x_0 = torch.clip(x_0, -1., 1.) 210 | 211 | return model_mean, model_log_var 212 | 213 | def __len__(self): 214 | iterations = torch.linspace(0,self.t_start-1,self.t_start//self.t_step, dtype=int) % self.t_save 215 | return torch.count_nonzero(iterations==0) 216 | 217 | def forward(self, x_T, cond): 218 | """ 219 | Algorithm 2. 220 | """ 221 | x_t = x_T 222 | x_intermediate = [] 223 | for time_step in reversed(torch.linspace(0,self.t_start-1,self.t_start//self.t_step, dtype=int)): 224 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step 225 | mean, log_var = self.p_mean_variance(x_t=x_t, cond=cond, t=t) 226 | # no noise when t == 0 227 | if time_step > 0: 228 | noise = torch.randn_like(x_t) 229 | else: 230 | noise = 0 231 | x_t = mean + torch.exp(0.5 * log_var) * noise 232 | 233 | if time_step%self.t_save == 0: 234 | x_intermediate.append(x_t.cpu()) 235 | print('Processing timestep {0}'.format(time_step)) 236 | x_0 = x_t 237 | return x_0, x_intermediate 238 | 239 | 240 | -------------------------------------------------------------------------------- /ThirdParty/layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 20 18:09:03 2020 4 | 5 | """ 6 | 7 | 8 | import math 9 | import numbers 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | 15 | 16 | class PixelShuffle3d(nn.Module): 17 | ''' 18 | reference: https://github.com/gap370/pixelshuffle3d 19 | ''' 20 | ''' 21 | This class is a 3d version of pixelshuffle. 
22 | ''' 23 | def __init__(self, scale): 24 | ''' 25 | :param scale: upsample scale 26 | ''' 27 | super().__init__() 28 | self.scale = scale 29 | 30 | def forward(self, input): 31 | batch_size, channels, in_depth, in_height, in_width = input.size() 32 | nOut = channels // self.scale ** 3 33 | 34 | out_depth = in_depth * self.scale 35 | out_height = in_height * self.scale 36 | out_width = in_width * self.scale 37 | 38 | input_view = input.contiguous().view(batch_size, nOut, self.scale, self.scale, self.scale, in_depth, in_height, in_width) 39 | 40 | output = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() 41 | 42 | return output.view(batch_size, nOut, out_depth, out_height, out_width) 43 | 44 | 45 | -------------------------------------------------------------------------------- /apply_script_diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jan 14 13:46:20 2020 5 | 6 | @author: eschweiler 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | import csv 13 | from skimage import io 14 | from scipy.ndimage import gaussian_filter 15 | from argparse import ArgumentParser 16 | from torch.autograd import Variable 17 | 18 | from dataloader.h5_dataloader import MeristemH5Tiler as Tiler 19 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler 20 | from utils.utils import print_timestamp 21 | 22 | SEED = 1337 23 | torch.manual_seed(SEED) 24 | np.random.seed(SEED) 25 | 26 | 27 | def main(hparams): 28 | 29 | 30 | """ 31 | Main testing routine specific for this project 32 | :param hparams: 33 | """ 34 | 35 | # ------------------------ 36 | # 0 SANITY CHECKS 37 | # ------------------------ 38 | if not isinstance(hparams.overlap, (tuple, list)): 39 | hparams.overlap = (hparams.overlap,) * len(hparams.patch_size) 40 | if not isinstance(hparams.crop, (tuple, list)): 41 | hparams.crop = (hparams.crop,) * len(hparams.patch_size) 42 | assert all([p-2*o-2*c>0 for p,o,c in zip(hparams.patch_size, hparams.overlap, hparams.crop)]), 'Invalid combination of patch size, overlap and crop size.' 
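# The sanity check above guarantees a positive patch core after the overlap
# and crop are removed on both sides (p - 2*o - 2*c > 0 along every axis).
# For example, patch_size=(32,128,128) with overlap=(4,16,16) and crop=(2,8,8)
# keeps a (20,80,80) core per patch, which is later blended into the full
# image via the fading map computed below.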
43 | 44 | # ------------------------ 45 | # 1 INIT LIGHTNING MODEL 46 | # ------------------------ 47 | model = network(hparams=hparams) 48 | model = model.load_from_checkpoint(hparams.ckpt_path) 49 | model = model.cuda() 50 | 51 | # ------------------------ 52 | # 2 INIT DIFFUSION PARAMETERS 53 | # ------------------------ 54 | device="cuda" if torch.cuda.is_available() else "cpu" 55 | 56 | if 'diffusionmodel' in hparams.pipeline.lower(): 57 | DiffusionTrainer = GaussianDiffusionTrainer(hparams.num_timesteps, schedule=hparams.diffusion_schedule).to(device) 58 | DiffusionSampler = GaussianDiffusionSampler(model, hparams.num_timesteps, t_start=hparams.timesteps_start,\ 59 | t_save=hparams.timesteps_save, t_step=hparams.timesteps_step,\ 60 | schedule=hparams.diffusion_schedule,\ 61 | mean_type='epsilon', var_type='fixedlarge').to(device) 62 | else : 63 | raise NotImplementedError() 64 | 65 | hparams.out_channels *= DiffusionSampler.__len__() 66 | 67 | # ------------------------ 68 | # 3 INIT DATA TILER 69 | # ------------------------ 70 | tiler = Tiler(hparams.test_list, no_mask=hparams.input_batch=='image', no_img=hparams.input_batch=='mask',\ 71 | boundary_handling='none', reduce_dim='2d' in hparams.pipeline.lower(), **vars(hparams)) 72 | fading_map = tiler.get_fading_map() 73 | fading_map = np.repeat(fading_map[np.newaxis,...], hparams.out_channels, axis=0) 74 | 75 | # ------------------------ 76 | # 4 FILE AND FOLDER CHECKS 77 | # ------------------------ 78 | os.makedirs(hparams.output_path, exist_ok=True) 79 | file_checklist = [] 80 | 81 | # ------------------------ 82 | # 5 PROCESS EACH IMAGE 83 | # ------------------------ 84 | if hparams.num_files is None or hparams.num_files < 0: 85 | hparams.num_files = len(tiler.data_list) 86 | else: 87 | hparams.num_files = np.minimum(len(tiler.data_list), hparams.num_files) 88 | 89 | with torch.no_grad(): 90 | 91 | for image_idx in range(hparams.num_files): 92 | 93 | # Get the current state of the processed files 94 | if os.path.isfile(os.path.join(hparams.output_path, 'tmp_file_checklist.csv')): 95 | file_checklist = [] 96 | with open(os.path.join(hparams.output_path, 'tmp_file_checklist.csv'), 'r') as f: 97 | reader = csv.reader(f, delimiter=';') 98 | for row in reader: 99 | if not len(row)==0: 100 | file_checklist.append(row[0]) 101 | 102 | # Check if current file has already been processed 103 | if not any([f==tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1] for f in file_checklist]): 104 | 105 | print_timestamp('_'*20) 106 | print_timestamp('Processing file {0}', [tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1]]) 107 | 108 | # Initialize current file 109 | tiler.set_data_idx(image_idx) 110 | 111 | # Determine if the patch size exceeds the image size 112 | working_size = tuple(np.max(np.array(tiler.locations), axis=0) - np.min(np.array(tiler.locations), axis=0) + np.array(hparams.patch_size)) 113 | 114 | # Initialize maps 115 | predicted_img = np.full((hparams.out_channels,)+working_size, 0, dtype=np.float32) 116 | norm_map = np.full((hparams.out_channels,)+working_size, 0, dtype=np.float32) 117 | 118 | for patch_idx in range(tiler.__len__()): 119 | 120 | print_timestamp('Generating patch {0}/{1}...',(patch_idx+1, tiler.__len__())) 121 | 122 | # Get the input 123 | sample = tiler.__getitem__(patch_idx) 124 | 125 | # Apply gaussian blur to image data 126 | if hparams.blur_sigma>0: 127 | for ndim in range(sample[hparams.input_batch].shape[0]-1): 128 | sample[hparams.input_batch][ndim,...] 
= gaussian_filter(sample[hparams.input_batch][ndim,...], hparams.blur_sigma, order=0) 129 | 130 | data = Variable(torch.from_numpy(sample[hparams.input_batch][np.newaxis,...]).cuda()) 131 | data = data.float() 132 | 133 | pred_patch = data[:,:-1,...].clone() 134 | pred_patch,_,_ = DiffusionTrainer(pred_patch, hparams.timesteps_start-1) 135 | 136 | cond = data[:,-1:,...].clone() 137 | 138 | _,pred_patch = DiffusionSampler(pred_patch, cond) 139 | 140 | # Convert final patch to numpy for saving 141 | pred_patch = [p.cpu().data.numpy() for p in pred_patch] 142 | pred_patch = np.array(pred_patch) 143 | #if '2d' in hparams.pipeline.lower(): 144 | # pred_patch = pred_patch[:,0,...] #remove batch dimension 145 | #else: 146 | # pred_patch = np.squeeze(pred_patch) 147 | pred_patch = pred_patch[:,0,...] #remove batch dimension 148 | pred_patch = np.reshape(pred_patch, (pred_patch.shape[0]*pred_patch.shape[1],)+pred_patch.shape[2:]) 149 | 150 | #pred_patch = np.squeeze(pred_patch) 151 | pred_patch = np.clip(pred_patch, hparams.clip[0], hparams.clip[1]) 152 | 153 | # Get the current slice position 154 | slicing = tuple(map(slice, (0,)+tuple(tiler.patch_start+tiler.global_crop_before), (hparams.out_channels,)+tuple(tiler.patch_end+tiler.global_crop_before))) 155 | 156 | # Add predicted patch and fading weights to the corresponding maps 157 | predicted_img[slicing] = predicted_img[slicing]+pred_patch*fading_map 158 | norm_map[slicing] = norm_map[slicing]+fading_map 159 | 160 | # Normalize the predicted image 161 | norm_map = np.clip(norm_map, 1e-5, np.inf) 162 | predicted_img = predicted_img / norm_map 163 | 164 | # Crop the predicted image to its original size 165 | slicing = tuple(map(slice, (0,)+tuple(tiler.global_crop_before), (hparams.out_channels,)+tuple(np.array(predicted_img.shape[1:])+np.array(tiler.global_crop_after)))) 166 | predicted_img = predicted_img[slicing] 167 | 168 | # Save the predicted image 169 | predicted_img = np.transpose(predicted_img, (1,2,3,0)) 170 | predicted_img = predicted_img.astype(np.float32) 171 | if hparams.out_channels > 1: 172 | for channel in range(hparams.out_channels): 173 | io.imsave(os.path.join(hparams.output_path, 'pred_'+str(channel)+'_'+os.path.split(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1])[-1][:-3]+'.tif'), predicted_img[...,channel]) 174 | else: 175 | io.imsave(os.path.join(hparams.output_path, 'pred_'+os.path.split(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1])[-1][:-3]+'.tif'), predicted_img[...,0]) 176 | 177 | # Mark current file as processed 178 | file_checklist.append(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1]) 179 | with open(os.path.join(hparams.output_path, 'tmp_file_checklist.csv'), 'w') as f: 180 | writer = csv.writer(f, delimiter=';') 181 | for check_file in file_checklist: 182 | writer.writerow([check_file]) 183 | 184 | else: 185 | print_timestamp('_'*20) 186 | print_timestamp('Skipping file {0}', [tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1]]) 187 | 188 | # Delete temporary checklist if everything has been processed 189 | os.remove(os.path.join(hparams.output_path, 'tmp_file_checklist.csv')) 190 | 191 | 192 | if __name__ == '__main__': 193 | # ------------------------ 194 | # TRAINING ARGUMENTS 195 | # ------------------------ 196 | # these are project-wide arguments 197 | 198 | parent_parser = ArgumentParser(add_help=False) 199 | 200 | parent_parser.add_argument( 201 | '--output_path', 202 | type=str, 203 | default=r'results/experiment1', 204 
| help='output path for test results' 205 | ) 206 | 207 | parent_parser.add_argument( 208 | '--ckpt_path', 209 | type=str, 210 | default=r'results/experiment1/checkpoint.ckpt', 211 | help='path to the trained model checkpoint' 212 | ) 213 | 214 | parent_parser.add_argument( 215 | '--gpus', 216 | type=int, 217 | default=1, 218 | help='number of GPUs to use' 219 | ) 220 | 221 | parent_parser.add_argument( 222 | '--overlap', 223 | type=int, 224 | default=(0,0,0), 225 | help='overlap of adjacent patches', 226 | nargs='+' 227 | ) 228 | 229 | parent_parser.add_argument( 230 | '--crop', 231 | type=int, 232 | default=(0,0,0), 233 | help='safety crop of patches', 234 | nargs='+' 235 | ) 236 | 237 | parent_parser.add_argument( 238 | '--input_batch', 239 | type=str, 240 | default='mask', 241 | help='which part of the batch is used as input (image | mask)' 242 | ) 243 | 244 | parent_parser.add_argument( 245 | '--clip', 246 | type=float, 247 | default=(-1000.0, 1000.0), 248 | help='clipping values for network outputs', 249 | nargs='+' 250 | ) 251 | 252 | parent_parser.add_argument( 253 | '--num_files', 254 | type=int, 255 | default=1, 256 | help='number of files to process' 257 | ) 258 | 259 | parent_parser.add_argument( 260 | '--timesteps_start', 261 | type=int, 262 | default=400, 263 | help='diffusion timestep at which the backward process starts' 264 | ) 265 | 266 | parent_parser.add_argument( 267 | '--timesteps_save', 268 | type=int, 269 | default=100, 270 | help='number of steps between saves' 271 | ) 272 | 273 | parent_parser.add_argument( 274 | '--blur_sigma', 275 | type=int, 276 | default=1, 277 | help='sigma of gaussian blur used on input data' 278 | ) 279 | 280 | parent_parser.add_argument( 281 | '--timesteps_step', 282 | type=int, 283 | default=1, 284 | help='timesteps skipped between iterations' 285 | ) 286 | 287 | parent_parser.add_argument( 288 | '--pipeline', 289 | type=str, 290 | default='DiffusionModel3D', 291 | help='which pipeline to load (DiffusionModel3D | DiffusionModel2D)' 292 | ) 293 | 294 | parent_args = parent_parser.parse_known_args()[0] 295 | 296 | # load the desired network architecture 297 | 298 | if parent_args.pipeline.lower() == 'diffusionmodel3d': 299 | from models.DiffusionModel3D import DiffusionModel3D as network 300 | elif parent_args.pipeline.lower() == 'diffusionmodel2d': 301 | from models.DiffusionModel2D import DiffusionModel2D as network 302 | else: 303 | raise ValueError('Unknown pipeline "{0}".'.format(parent_args.pipeline)) 304 | 305 | # each LightningModule defines arguments relevant to it 306 | parser = network.add_model_specific_args(parent_parser) 307 | hyperparams = parser.parse_args() 308 | 309 | # --------------------- 310 | # RUN APPLICATION 311 | # --------------------- 312 | main(hyperparams) 313 | -------------------------------------------------------------------------------- /data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt -------------------------------------------------------------------------------- /data_samples/experiment_3D/pred_0_sketch3D_sim_CTCCE_0.tif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_0_sketch3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/experiment_3D/pred_1_sketch3D_sim_CTCCE_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_1_sketch3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/experiment_3D/pred_2_sketch3D_sim_CTCCE_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_2_sketch3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/experiment_3D/pred_3_sketch3D_sim_CTCCE_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_3_sketch3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/image3D_sim_CTCCE_0.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/image3D_sim_CTCCE_0.h5 -------------------------------------------------------------------------------- /data_samples/image3D_sim_CTCCE_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/image3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/image_files_3D.csv: -------------------------------------------------------------------------------- 1 | image3D_sim_CTCCE_0.h5 2 | -------------------------------------------------------------------------------- /data_samples/sketch3D_sim_CTCCE_0.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/sketch3D_sim_CTCCE_0.h5 -------------------------------------------------------------------------------- /data_samples/sketch3D_sim_CTCCE_0.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/sketch3D_sim_CTCCE_0.tif -------------------------------------------------------------------------------- /data_samples/sketch_files_3D.csv: -------------------------------------------------------------------------------- 1 | sketch3D_sim_CTCCE_0.h5 2 | -------------------------------------------------------------------------------- /figures/example_data.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/example_data.png -------------------------------------------------------------------------------- /figures/multi-channel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/multi-channel.png -------------------------------------------------------------------------------- /figures/overlapping_cells.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/overlapping_cells.png -------------------------------------------------------------------------------- /figures/timeseries.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/timeseries.gif -------------------------------------------------------------------------------- /models/DiffusionModel2D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import json 5 | import torch 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | import torchvision 9 | 10 | from argparse import ArgumentParser, Namespace 11 | from torch.utils.data import DataLoader 12 | from dataloader.h5_dataloader import MeristemH5Dataset 13 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler 14 | 15 | 16 | 17 | class DiffusionModel2D(pl.LightningModule): 18 | 19 | def __init__(self, hparams): 20 | super(DiffusionModel2D, self).__init__() 21 | 22 | if type(hparams) is dict: 23 | hparams = Namespace(**hparams) 24 | self.save_hyperparameters(hparams) 25 | self.augmentation_dict = {} 26 | 27 | # load the backbone network architecture 28 | if self.hparams.backbone.lower() == 'unet2d_pixelshuffle_inject': 29 | from models.module_UNet2D_pixelshuffle_inject import module_UNet2D_pixelshuffle_inject as backbone 30 | else: 31 | raise ValueError('Unknown backbone architecture {0}!'.format(self.hparams.backbone)) 32 | 33 | self.network = backbone(patch_size=self.hparams.patch_size, in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels,\ 34 | feat_channels=self.hparams.feat_channels, t_channels=self.hparams.t_channels,\ 35 | out_activation=self.hparams.out_activation, layer_norm=self.hparams.layer_norm) 36 | # cache for generated images 37 | self.last_predictions = None 38 | self.last_imgs = None 39 | 40 | # set up diffusion parameters 41 | device="cuda" if torch.cuda.is_available() else "cpu" 42 | self.DiffusionTrainer = GaussianDiffusionTrainer(self.hparams.num_timesteps, schedule=self.hparams.diffusion_schedule).to(device) 43 | self.DiffusionSampler = GaussianDiffusionSampler(self.network, self.hparams.num_timesteps, t_start=self.hparams.num_timesteps,\ 44 | t_save=self.hparams.num_timesteps, t_step=1, schedule=self.hparams.diffusion_schedule,\ 45 | mean_type='epsilon', var_type='fixedlarge').to(device) 46 | 47 | def forward(self, z, t): 48 | return self.network(z, t) 49 | 50 | 51 | def load_pretrained(self, pretrained_file, strict=True, verbose=True): 52 | 53 | # Load the state dict 54 | state_dict = torch.load(pretrained_file)['state_dict'] 
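# Note: Lightning checkpoints save the backbone weights under keys prefixed
# with 'network.', which is why the lookups below prepend 'network.' to each
# parameter name of the current model.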
55 | 56 | # Make sure to have a weight dict 57 | if not isinstance(state_dict, dict): 58 | state_dict = dict(state_dict) 59 | 60 | # Get parameter dict of current model 61 | param_dict = dict(self.network.named_parameters()) 62 | 63 | layers = [] 64 | for layer in param_dict: 65 | if strict and not 'network.'+layer in state_dict: 66 | if verbose: 67 | print('Could not find weights for layer "{0}"'.format(layer)) 68 | continue 69 | try: 70 | param_dict[layer].data.copy_(state_dict['network.'+layer].data) 71 | layers.append(layer) 72 | except (RuntimeError, KeyError) as e: 73 | print('Error at layer {0}:\n{1}'.format(layer, e)) 74 | 75 | self.network.load_state_dict(param_dict) 76 | 77 | if verbose: 78 | print('Loaded weights for the following layers:\n{0}'.format(layers)) 79 | 80 | 81 | def denoise_loss(self, y_hat, y): 82 | return F.mse_loss(y_hat, y) 83 | 84 | 85 | def training_step(self, batch, batch_idx): 86 | 87 | # Get image and mask of current batch 88 | self.last_imgs = batch['image'].float() 89 | 90 | # get x_t, noise for a random t 91 | self.x_t, noise, t = self.DiffusionTrainer(self.last_imgs[:,0:1,...]) 92 | self.x_t.requires_grad = True 93 | 94 | # generate prediction 95 | self.generated_noise = self.forward(torch.cat((self.x_t, self.last_imgs[:,1:,...]), axis=1), t) 96 | 97 | # get the losses 98 | loss_denoise = self.denoise_loss(self.generated_noise, noise) 99 | 100 | self.logger.experiment.add_scalar('loss_denoise', loss_denoise, self.current_epoch) 101 | 102 | return loss_denoise 103 | 104 | 105 | def test_step(self, batch, batch_idx): 106 | x = batch['image'] 107 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index)) 108 | return {'test_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])} 109 | 110 | def test_end(self, outputs): 111 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 112 | tensorboard_logs = {'test_loss': avg_loss} 113 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs} 114 | 115 | def validation_step(self, batch, batch_idx): 116 | x = batch['image'] 117 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index)) 118 | return {'val_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])} 119 | 120 | def validation_end(self, outputs): 121 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 122 | tensorboard_logs = {'val_loss': avg_loss} 123 | return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} 124 | 125 | def configure_optimizers(self): 126 | opt = torch.optim.RAdam(self.network.parameters(), lr=self.hparams.learning_rate) 127 | return [opt], [] 128 | 129 | def train_dataloader(self): 130 | if self.hparams.train_list is None: 131 | return None 132 | else: 133 | dataset = MeristemH5Dataset(self.hparams.train_list, self.hparams.data_root, patch_size=self.hparams.patch_size,\ 134 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, reduce_dim=True,\ 135 | augmentation_dict=self.augmentation_dict, samples_per_epoch=self.hparams.samples_per_epoch,\ 136 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none', \ 137 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 138 | return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, drop_last=True) 139 | 140 | def test_dataloader(self): 141 | if self.hparams.test_list is None: 142 | return None 143 | else: 144 | dataset = MeristemH5Dataset(self.hparams.test_list, self.hparams.data_root,
patch_size=self.hparams.patch_size, reduce_dim=True,\ 145 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\ 146 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\ 147 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 148 | return DataLoader(dataset, batch_size=self.hparams.batch_size) 149 | 150 | def val_dataloader(self): 151 | if self.hparams.val_list is None: 152 | return None 153 | else: 154 | dataset = MeristemH5Dataset(self.hparams.val_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=True,\ 155 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\ 156 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\ 157 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 158 | return DataLoader(dataset, batch_size=self.hparams.batch_size) 159 | 160 | 161 | def on_train_epoch_end(self): 162 | 163 | 164 | self.DiffusionSampler.model = self.network 165 | 166 | with torch.no_grad(): 167 | 168 | input_patch = self.last_imgs 169 | 170 | # get x_0 171 | x_0,_ = self.DiffusionSampler(input_patch[:,0:1,...], input_patch[:,1:,...]) 172 | 173 | # log sampled images 174 | prediction_grid = torchvision.utils.make_grid(x_0) 175 | self.logger.experiment.add_image('predicted_x_0', prediction_grid, self.current_epoch) 176 | 177 | img_grid = torchvision.utils.make_grid(input_patch) 178 | self.logger.experiment.add_image('raw_x_0', img_grid, self.current_epoch) 179 | 180 | 181 | def set_augmentations(self, augmentation_dict_file): 182 | if not augmentation_dict_file is None: 183 | self.augmentation_dict = json.load(open(augmentation_dict_file)) 184 | 185 | 186 | @staticmethod 187 | def add_model_specific_args(parent_parser): 188 | """ 189 | Parameters you define here will be available to your model through self.hparams 190 | """ 191 | parser = ArgumentParser(parents=[parent_parser]) 192 | 193 | # network params 194 | parser.add_argument('--backbone', default='UNet2D_PixelShuffle_inject', type=str, help='which model to load (UNet3D_PixelShuffle_inject)') 195 | parser.add_argument('--in_channels', default=2, type=int) 196 | parser.add_argument('--out_channels', default=1, type=int) 197 | parser.add_argument('--feat_channels', default=16, type=int) 198 | parser.add_argument('--t_channels', default=128, type=int) 199 | parser.add_argument('--patch_size', default=(1,256,256), type=int, nargs='+') 200 | parser.add_argument('--layer_norm', default='instance', type=str) 201 | parser.add_argument('--out_activation', default='none', type=str) 202 | 203 | # data 204 | parser.add_argument('--data_norm', default='minmax_shifted', type=str) 205 | parser.add_argument('--data_root', default='/data/root', type=str) 206 | parser.add_argument('--train_list', default='/path/to/training_data/split1_train.csv', type=str) 207 | parser.add_argument('--test_list', default='/path/to/testing_data/split1_test.csv', type=str) 208 | parser.add_argument('--val_list', default='/path/to/validation_data/split1_val.csv', type=str) 209 | parser.add_argument('--image_groups', default=('data/image',), type=str, nargs='+') 210 | parser.add_argument('--mask_groups', default=('data/diffusion_mask',), type=str, nargs='+') 211 | parser.add_argument('--image_noise_channel', default=-1, type=int) 212 
| parser.add_argument('--mask_noise_channel', default=-1, type=int) 213 | parser.add_argument('--noise_type', default='gaussian', type=str) 214 | 215 | # diffusion parameter 216 | parser.add_argument('--num_timesteps', default=1000, type=int) 217 | parser.add_argument('--diffusion_schedule', default='cosine', type=str) 218 | 219 | # training params 220 | parser.add_argument('--samples_per_epoch', default=-1, type=int) 221 | parser.add_argument('--batch_size', default=1, type=int) 222 | parser.add_argument('--learning_rate', default=0.001, type=float) 223 | 224 | return parser -------------------------------------------------------------------------------- /models/DiffusionModel3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import json 5 | import torch 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | import torchvision 9 | 10 | from argparse import ArgumentParser, Namespace 11 | from torch.utils.data import DataLoader 12 | from dataloader.h5_dataloader import MeristemH5Dataset 13 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler 14 | 15 | 16 | 17 | class DiffusionModel3D(pl.LightningModule): 18 | 19 | def __init__(self, hparams): 20 | super(DiffusionModel3D, self).__init__() 21 | 22 | if type(hparams) is dict: 23 | hparams = Namespace(**hparams) 24 | self.save_hyperparameters(hparams) 25 | self.augmentation_dict = {} 26 | 27 | # load the backbone network architecture 28 | if self.hparams.backbone.lower() == 'unet3d_pixelshuffle_inject': 29 | from models.module_UNet3D_pixelshuffle_inject import module_UNet3D_pixelshuffle_inject as backbone 30 | else: 31 | raise ValueError('Unknown backbone architecture {0}!'.format(self.hparams.backbone)) 32 | 33 | self.network = backbone(patch_size=self.hparams.patch_size, in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels,\ 34 | feat_channels=self.hparams.feat_channels, t_channels=self.hparams.t_channels,\ 35 | out_activation=self.hparams.out_activation, layer_norm=self.hparams.layer_norm) 36 | # cache for generated images 37 | self.last_predictions = None 38 | self.last_imgs = None 39 | 40 | # set up diffusion parameters 41 | device="cuda" if torch.cuda.is_available() else "cpu" 42 | self.DiffusionTrainer = GaussianDiffusionTrainer(self.hparams.num_timesteps, schedule=self.hparams.diffusion_schedule).to(device) 43 | self.DiffusionSampler = GaussianDiffusionSampler(self.network, self.hparams.num_timesteps, t_start=self.hparams.num_timesteps,\ 44 | t_save=self.hparams.num_timesteps, t_step=1, schedule=self.hparams.diffusion_schedule,\ 45 | mean_type='epsilon', var_type='fixedlarge').to(device) 46 | 47 | def forward(self, z, t): 48 | return self.network(z, t) 49 | 50 | 51 | def load_pretrained(self, pretrained_file, strict=True, verbose=True): 52 | 53 | # Load the state dict 54 | state_dict = torch.load(pretrained_file)['state_dict'] 55 | 56 | # Make sure to have a weight dict 57 | if not isinstance(state_dict, dict): 58 | state_dict = dict(state_dict) 59 | 60 | # Get parameter dict of current model 61 | param_dict = dict(self.network.named_parameters()) 62 | 63 | layers = [] 64 | for layer in param_dict: 65 | if strict and not 'network.'+layer in state_dict: 66 | if verbose: 67 | print('Could not find weights for layer "{0}"'.format(layer)) 68 | continue 69 | try: 70 | param_dict[layer].data.copy_(state_dict['network.'+layer].data) 71 | layers.append(layer) 72 | except (RuntimeError, 
KeyError) as e: 73 | print('Error at layer {0}:\n{1}'.format(layer, e)) 74 | 75 | self.network.load_state_dict(param_dict) 76 | 77 | if verbose: 78 | print('Loaded weights for the following layers:\n{0}'.format(layers)) 79 | 80 | 81 | def denoise_loss(self, y_hat, y): 82 | return F.mse_loss(y_hat, y) 83 | 84 | 85 | def training_step(self, batch, batch_idx): 86 | 87 | # Get image and mask of the current batch 88 | self.last_imgs = batch['image'].float() 89 | 90 | # get x_t, noise for a random t 91 | self.x_t, noise, t = self.DiffusionTrainer(self.last_imgs[:,0:1,...]) 92 | self.x_t.requires_grad = True 93 | 94 | # generate prediction 95 | self.generated_noise = self.forward(torch.cat((self.x_t, self.last_imgs[:,1:,...]), axis=1), t) 96 | 97 | # get the losses 98 | loss_denoise = self.denoise_loss(self.generated_noise, noise) 99 | 100 | self.logger.experiment.add_scalar('loss_denoise', loss_denoise, self.current_epoch) 101 | 102 | return loss_denoise 103 | 104 | 105 | def test_step(self, batch, batch_idx): 106 | x = batch['image'] 107 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index)) 108 | return {'test_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])} 109 | 110 | def test_epoch_end(self, outputs): 111 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() 112 | tensorboard_logs = {'test_loss': avg_loss} 113 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs} 114 | 115 | def validation_step(self, batch, batch_idx): 116 | x = batch['image'] 117 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index)) 118 | return {'val_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])} 119 | 120 | def validation_epoch_end(self, outputs): 121 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 122 | tensorboard_logs = {'val_loss': avg_loss} 123 | return {'avg_val_loss': avg_loss, 'log': tensorboard_logs} 124 | 125 | def configure_optimizers(self): 126 | opt = torch.optim.RAdam(self.network.parameters(), lr=self.hparams.learning_rate) 127 | return [opt], [] 128 | 129 | def train_dataloader(self): 130 | if self.hparams.train_list is None: 131 | return None 132 | else: 133 | dataset = MeristemH5Dataset(self.hparams.train_list, self.hparams.data_root, patch_size=self.hparams.patch_size,\ 134 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, reduce_dim=False,\ 135 | augmentation_dict=self.augmentation_dict, samples_per_epoch=self.hparams.samples_per_epoch,\ 136 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none', \ 137 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 138 | return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, drop_last=True) 139 | 140 | def test_dataloader(self): 141 | if self.hparams.test_list is None: 142 | return None 143 | else: 144 | dataset = MeristemH5Dataset(self.hparams.test_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=False,\ 145 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\ 146 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\ 147 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 148 | return DataLoader(dataset, batch_size=self.hparams.batch_size) 149 | 150 | def val_dataloader(self): 151 | if self.hparams.val_list is None: 152 | return None 153 | else: 154 | 
dataset = MeristemH5Dataset(self.hparams.val_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=False,\ 155 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\ 156 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\ 157 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type) 158 | return DataLoader(dataset, batch_size=self.hparams.batch_size) 159 | 160 | 161 | def on_train_epoch_end(self): 162 | 163 | 164 | self.DiffusionSampler.model = self.network 165 | 166 | with torch.no_grad(): 167 | 168 | input_patch = self.last_imgs 169 | 170 | # get x_0 171 | x_0,_ = self.DiffusionSampler(input_patch[:,0:1,...], input_patch[:,1:,...]) 172 | 173 | # log sampled images 174 | prediction_grid = torchvision.utils.make_grid(x_0[...,int(self.hparams.patch_size[0]//2),:,:]) 175 | self.logger.experiment.add_image('predicted_x_0', prediction_grid, self.current_epoch) 176 | 177 | img_grid = torchvision.utils.make_grid(input_patch[...,int(self.hparams.patch_size[0]//2),:,:]) 178 | self.logger.experiment.add_image('raw_x_0', img_grid, self.current_epoch) 179 | 180 | 181 | def set_augmentations(self, augmentation_dict_file): 182 | if not augmentation_dict_file is None: 183 | self.augmentation_dict = json.load(open(augmentation_dict_file)) 184 | 185 | 186 | @staticmethod 187 | def add_model_specific_args(parent_parser): 188 | """ 189 | Parameters you define here will be available to your model through self.hparams 190 | """ 191 | parser = ArgumentParser(parents=[parent_parser]) 192 | 193 | # network params 194 | parser.add_argument('--backbone', default='UNet3D_PixelShuffle_inject', type=str, help='which model to load (UNet3D_PixelShuffle_inject)') 195 | parser.add_argument('--in_channels', default=2, type=int) 196 | parser.add_argument('--out_channels', default=1, type=int) 197 | parser.add_argument('--feat_channels', default=16, type=int) 198 | parser.add_argument('--t_channels', default=128, type=int) 199 | parser.add_argument('--patch_size', default=(32,128,128), type=int, nargs='+') 200 | parser.add_argument('--layer_norm', default='instance', type=str) 201 | parser.add_argument('--out_activation', default='none', type=str) 202 | 203 | # data 204 | parser.add_argument('--data_norm', default='minmax_shifted', type=str) 205 | parser.add_argument('--data_root', default='../data_samples', type=str) 206 | parser.add_argument('--train_list', default='../data_samples/image_files_3D.csv', type=str) 207 | parser.add_argument('--test_list', default='../data_samples/image_files_3D.csv', type=str) 208 | parser.add_argument('--val_list', default='../data_samples/image_files_3D.csv', type=str) 209 | parser.add_argument('--image_groups', default=('data/image',), type=str, nargs='+') 210 | parser.add_argument('--mask_groups', default=('data/image',), type=str, nargs='+') 211 | parser.add_argument('--image_noise_channel', default=-1, type=int) 212 | parser.add_argument('--mask_noise_channel', default=-1, type=int) 213 | parser.add_argument('--noise_type', default='gaussian', type=str) 214 | 215 | # diffusion parameter 216 | parser.add_argument('--num_timesteps', default=1000, type=int) 217 | parser.add_argument('--diffusion_schedule', default='cosine', type=str) 218 | 219 | # training params 220 | parser.add_argument('--samples_per_epoch', default=-1, type=int) 221 | parser.add_argument('--batch_size', default=1, type=int) 222 | 
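        # Example: since --patch_size is declared with nargs='+', its values are passed as
        # separate integers on the command line and arrive as a list in hparams, e.g.
        #
        #     from argparse import ArgumentParser
        #     p = ArgumentParser()
        #     p.add_argument('--patch_size', default=(32,128,128), type=int, nargs='+')
        #     args = p.parse_args(['--patch_size', '32', '128', '128'])
        #     assert tuple(args.patch_size) == (32, 128, 128)
        #
        # For the 2D pipeline the same convention still expects three values with a
        # leading 1, i.e. '--patch_size 1 256 256' (cf. notebooks/README.md).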
parser.add_argument('--learning_rate', default=0.001, type=float) 223 | 224 | return parser -------------------------------------------------------------------------------- /models/module_UNet2D_pixelshuffle_inject.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implementation of the 2D UNet architecture with PixelShuffle upsampling. 4 | https://arxiv.org/pdf/1609.05158v2.pdf 5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | 11 | from ThirdParty.diffusion import SinusoidalPositionEmbeddings 12 | 13 | 14 | 15 | class module_UNet2D_pixelshuffle_inject(nn.Module): 16 | """Implementation of the 2D U-Net architecture. 17 | """ 18 | 19 | def __init__(self, patch_size, in_channels, out_channels, feat_channels=16, t_channels=128, out_activation='sigmoid', layer_norm='none', **kwargs): 20 | super(module_UNet2D_pixelshuffle_inject, self).__init__() 21 | 22 | self.patch_size = patch_size 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.feat_channels = feat_channels 26 | self.t_channels = t_channels 27 | self.layer_norm = layer_norm # instance | batch | none 28 | self.out_activation = out_activation # relu | leakyrelu | sigmoid | tanh | hardtanh | none 29 | 30 | self.norm_methods = { 31 | 'instance': nn.InstanceNorm2d, 32 | 'batch': nn.BatchNorm2d, 33 | 'none': nn.Identity 34 | } 35 | 36 | self.out_activations = nn.ModuleDict({ 37 | 'relu': nn.ReLU(), 38 | 'leakyrelu': nn.LeakyReLU(negative_slope=0.2, inplace=True), 39 | 'sigmoid': nn.Sigmoid(), 40 | 'tanh': nn.Tanh(), 41 | 'hardtanh': nn.Hardtanh(0,1), 42 | 'none': nn.Identity() 43 | }) 44 | 45 | 46 | # Define layer instances 47 | self.t1 = nn.Sequential( 48 | SinusoidalPositionEmbeddings(t_channels), 49 | nn.Linear(t_channels, feat_channels), 50 | nn.PReLU(feat_channels), 51 | nn.Linear(feat_channels, feat_channels) 52 | ) 53 | self.c1 = nn.Sequential( 54 | nn.Conv2d(in_channels, feat_channels//2, kernel_size=3, padding=1), 55 | nn.PReLU(feat_channels//2), 56 | self.norm_methods[self.layer_norm](feat_channels//2), 57 | nn.Conv2d(feat_channels//2, feat_channels, kernel_size=3, padding=1), 58 | nn.PReLU(feat_channels), 59 | self.norm_methods[self.layer_norm](feat_channels) 60 | ) 61 | self.d1 = nn.Sequential( 62 | nn.Conv2d(feat_channels, feat_channels, kernel_size=4, stride=2, padding=1), 63 | nn.PReLU(feat_channels), 64 | self.norm_methods[self.layer_norm](feat_channels) 65 | ) 66 | 67 | 68 | self.t2 = nn.Sequential( 69 | SinusoidalPositionEmbeddings(t_channels), 70 | nn.Linear(t_channels, feat_channels*2), 71 | nn.PReLU(feat_channels*2), 72 | nn.Linear(feat_channels*2, feat_channels*2) 73 | ) 74 | self.c2 = nn.Sequential( 75 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1), 76 | nn.PReLU(feat_channels), 77 | self.norm_methods[self.layer_norm](feat_channels), 78 | nn.Conv2d(feat_channels, feat_channels*2, kernel_size=3, padding=1), 79 | nn.PReLU(feat_channels*2), 80 | self.norm_methods[self.layer_norm](feat_channels*2) 81 | ) 82 | self.d2 = nn.Sequential( 83 | nn.Conv2d(feat_channels*2, feat_channels*2, kernel_size=4, stride=2, padding=1), 84 | nn.PReLU(feat_channels*2), 85 | self.norm_methods[self.layer_norm](feat_channels*2) 86 | ) 87 | 88 | 89 | self.t3 = nn.Sequential( 90 | SinusoidalPositionEmbeddings(t_channels), 91 | nn.Linear(t_channels, feat_channels*4), 92 | nn.PReLU(feat_channels*4), 93 | nn.Linear(feat_channels*4, feat_channels*4) 94 | ) 95 | self.c3 = nn.Sequential( 96 | 
nn.Conv2d(feat_channels*2, feat_channels*2, kernel_size=3, padding=1), 97 | nn.PReLU(feat_channels*2), 98 | self.norm_methods[self.layer_norm](feat_channels*2), 99 | nn.Conv2d(feat_channels*2, feat_channels*4, kernel_size=3, padding=1), 100 | nn.PReLU(feat_channels*4), 101 | self.norm_methods[self.layer_norm](feat_channels*4) 102 | ) 103 | self.d3 = nn.Sequential( 104 | nn.Conv2d(feat_channels*4, feat_channels*4, kernel_size=4, stride=2, padding=1), 105 | nn.PReLU(feat_channels*4), 106 | self.norm_methods[self.layer_norm](feat_channels*4) 107 | ) 108 | 109 | 110 | self.t4 = nn.Sequential( 111 | SinusoidalPositionEmbeddings(t_channels), 112 | nn.Linear(t_channels, feat_channels*8), 113 | nn.PReLU(feat_channels*8), 114 | nn.Linear(feat_channels*8, feat_channels*8) 115 | ) 116 | self.c4 = nn.Sequential( 117 | nn.Conv2d(feat_channels*4, feat_channels*4, kernel_size=3, padding=1), 118 | nn.PReLU(feat_channels*4), 119 | self.norm_methods[self.layer_norm](feat_channels*4), 120 | nn.Conv2d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1), 121 | nn.PReLU(feat_channels*8), 122 | self.norm_methods[self.layer_norm](feat_channels*8) 123 | ) 124 | 125 | 126 | self.u1 = nn.Sequential( 127 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1), 128 | nn.PReLU(feat_channels*8), 129 | self.norm_methods[self.layer_norm](feat_channels*8), 130 | nn.PixelShuffle(2) 131 | ) 132 | self.t5 = nn.Sequential( 133 | SinusoidalPositionEmbeddings(t_channels), 134 | nn.Linear(t_channels, feat_channels*8), 135 | nn.PReLU(feat_channels*8), 136 | nn.Linear(feat_channels*8, feat_channels*8) 137 | ) 138 | self.c5 = nn.Sequential( 139 | nn.Conv2d(feat_channels*6, feat_channels*8, kernel_size=3, padding=1), 140 | nn.PReLU(feat_channels*8), 141 | self.norm_methods[self.layer_norm](feat_channels*8), 142 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1), 143 | nn.PReLU(feat_channels*8), 144 | self.norm_methods[self.layer_norm](feat_channels*8) 145 | ) 146 | 147 | 148 | self.u2 = nn.Sequential( 149 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1), 150 | nn.PReLU(feat_channels*8), 151 | self.norm_methods[self.layer_norm](feat_channels*8), 152 | nn.PixelShuffle(2) 153 | ) 154 | self.t6 = nn.Sequential( 155 | SinusoidalPositionEmbeddings(t_channels), 156 | nn.Linear(t_channels, feat_channels*8), 157 | nn.PReLU(feat_channels*8), 158 | nn.Linear(feat_channels*8, feat_channels*8) 159 | ) 160 | self.c6 = nn.Sequential( 161 | nn.Conv2d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1), 162 | nn.PReLU(feat_channels*8), 163 | self.norm_methods[self.layer_norm](feat_channels*8), 164 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1), 165 | nn.PReLU(feat_channels*8), 166 | self.norm_methods[self.layer_norm](feat_channels*8) 167 | ) 168 | 169 | 170 | self.u3 = nn.Sequential( 171 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1), 172 | nn.PReLU(feat_channels*8), 173 | self.norm_methods[self.layer_norm](feat_channels*8), 174 | nn.PixelShuffle(2) 175 | ) 176 | self.t7 = nn.Sequential( 177 | SinusoidalPositionEmbeddings(t_channels), 178 | nn.Linear(t_channels, feat_channels), 179 | nn.PReLU(feat_channels), 180 | nn.Linear(feat_channels, feat_channels) 181 | ) 182 | self.c7 = nn.Sequential( 183 | nn.Conv2d(feat_channels*3, feat_channels, kernel_size=3, padding=1), 184 | nn.PReLU(feat_channels), 185 | self.norm_methods[self.layer_norm](feat_channels), 186 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1), 187 | nn.PReLU(feat_channels), 188 | 
self.norm_methods[self.layer_norm](feat_channels) 189 | ) 190 | 191 | 192 | self.out = nn.Sequential( 193 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1), 194 | nn.PReLU(feat_channels), 195 | self.norm_methods[self.layer_norm](feat_channels), 196 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1), 197 | nn.PReLU(feat_channels), 198 | self.norm_methods[self.layer_norm](feat_channels), 199 | nn.Conv2d(feat_channels, out_channels, kernel_size=1), 200 | self.out_activations[self.out_activation] 201 | ) 202 | 203 | 204 | def forward(self, img, t): 205 | 206 | t1 = self.t1(t) 207 | c1 = self.c1(img)+t1[...,None,None] 208 | d1 = self.d1(c1) 209 | 210 | t2 = self.t2(t) 211 | c2 = self.c2(d1)+t2[...,None,None] 212 | d2 = self.d2(c2) 213 | 214 | t3 = self.t3(t) 215 | c3 = self.c3(d2)+t3[...,None,None] 216 | d3 = self.d3(c3) 217 | 218 | t4 = self.t4(t) 219 | c4 = self.c4(d3)+t4[...,None,None] 220 | 221 | u1 = self.u1(c4) 222 | t5 = self.t5(t) 223 | c5 = self.c5(torch.cat((u1,c3),1))+t5[...,None,None] 224 | 225 | u2 = self.u2(c5) 226 | t6 = self.t6(t) 227 | c6 = self.c6(torch.cat((u2,c2),1))+t6[...,None,None] 228 | 229 | u3 = self.u3(c6) 230 | t7 = self.t7(t) 231 | c7 = self.c7(torch.cat((u3,c1),1))+t7[...,None,None] 232 | 233 | out = self.out(c7) 234 | 235 | return out 236 | 237 | -------------------------------------------------------------------------------- /models/module_UNet3D_pixelshuffle_inject.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implementation of the 3D UNet architecture with PixelShuffle upsampling. 4 | https://arxiv.org/pdf/1609.05158v2.pdf 5 | """ 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | 11 | from ThirdParty.layers import PixelShuffle3d 12 | from ThirdParty.diffusion import SinusoidalPositionEmbeddings 13 | 14 | 15 | 16 | class module_UNet3D_pixelshuffle_inject(nn.Module): 17 | """Implementation of the 3D U-Net architecture. 
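    Upsampling uses PixelShuffle3d(2), which rearranges a (B, 8*C, D, H, W) tensor into
    (B, C, 2*D, 2*H, 2*W); the channel count shrinks by the factor r**3 = 8. After
    concatenation with the skip connections, this yields the feat_channels*5, feat_channels*3
    and feat_channels*2 input channels expected by the decoder blocks c5, c6 and c7 below.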
18 | """ 19 | 20 | def __init__(self, patch_size, in_channels, out_channels, feat_channels=16, t_channels=128, out_activation='sigmoid', layer_norm='none', **kwargs): 21 | super(module_UNet3D_pixelshuffle_inject, self).__init__() 22 | 23 | self.patch_size = patch_size 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.feat_channels = feat_channels 27 | self.t_channels = t_channels 28 | self.layer_norm = layer_norm # instance | batch | none 29 | self.out_activation = out_activation # relu | leakyrelu | sigmoid | tanh | hardtanh | none 30 | 31 | self.norm_methods = { 32 | 'instance': nn.InstanceNorm3d, 33 | 'batch': nn.BatchNorm3d, 34 | 'none': nn.Identity 35 | } 36 | 37 | self.out_activations = nn.ModuleDict({ 38 | 'relu': nn.ReLU(), 39 | 'leakyrelu': nn.LeakyReLU(negative_slope=0.2, inplace=True), 40 | 'sigmoid': nn.Sigmoid(), 41 | 'tanh': nn.Tanh(), 42 | 'hardtanh': nn.Hardtanh(0,1), 43 | 'none': nn.Identity() 44 | }) 45 | 46 | 47 | # Define layer instances 48 | self.t1 = nn.Sequential( 49 | SinusoidalPositionEmbeddings(t_channels), 50 | nn.Linear(t_channels, feat_channels), 51 | nn.PReLU(feat_channels), 52 | nn.Linear(feat_channels, feat_channels) 53 | ) 54 | self.c1 = nn.Sequential( 55 | nn.Conv3d(in_channels, feat_channels//2, kernel_size=3, padding=1), 56 | nn.PReLU(feat_channels//2), 57 | self.norm_methods[self.layer_norm](feat_channels//2), 58 | nn.Conv3d(feat_channels//2, feat_channels, kernel_size=3, padding=1), 59 | nn.PReLU(feat_channels), 60 | self.norm_methods[self.layer_norm](feat_channels) 61 | ) 62 | self.d1 = nn.Sequential( 63 | nn.Conv3d(feat_channels, feat_channels, kernel_size=4, stride=2, padding=1), 64 | nn.PReLU(feat_channels), 65 | self.norm_methods[self.layer_norm](feat_channels) 66 | ) 67 | 68 | 69 | self.t2 = nn.Sequential( 70 | SinusoidalPositionEmbeddings(t_channels), 71 | nn.Linear(t_channels, feat_channels*2), 72 | nn.PReLU(feat_channels*2), 73 | nn.Linear(feat_channels*2, feat_channels*2) 74 | ) 75 | self.c2 = nn.Sequential( 76 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1), 77 | nn.PReLU(feat_channels), 78 | self.norm_methods[self.layer_norm](feat_channels), 79 | nn.Conv3d(feat_channels, feat_channels*2, kernel_size=3, padding=1), 80 | nn.PReLU(feat_channels*2), 81 | self.norm_methods[self.layer_norm](feat_channels*2) 82 | ) 83 | self.d2 = nn.Sequential( 84 | nn.Conv3d(feat_channels*2, feat_channels*2, kernel_size=4, stride=2, padding=1), 85 | nn.PReLU(feat_channels*2), 86 | self.norm_methods[self.layer_norm](feat_channels*2) 87 | ) 88 | 89 | 90 | self.t3 = nn.Sequential( 91 | SinusoidalPositionEmbeddings(t_channels), 92 | nn.Linear(t_channels, feat_channels*4), 93 | nn.PReLU(feat_channels*4), 94 | nn.Linear(feat_channels*4, feat_channels*4) 95 | ) 96 | self.c3 = nn.Sequential( 97 | nn.Conv3d(feat_channels*2, feat_channels*2, kernel_size=3, padding=1), 98 | nn.PReLU(feat_channels*2), 99 | self.norm_methods[self.layer_norm](feat_channels*2), 100 | nn.Conv3d(feat_channels*2, feat_channels*4, kernel_size=3, padding=1), 101 | nn.PReLU(feat_channels*4), 102 | self.norm_methods[self.layer_norm](feat_channels*4) 103 | ) 104 | self.d3 = nn.Sequential( 105 | nn.Conv3d(feat_channels*4, feat_channels*4, kernel_size=4, stride=2, padding=1), 106 | nn.PReLU(feat_channels*4), 107 | self.norm_methods[self.layer_norm](feat_channels*4) 108 | ) 109 | 110 | 111 | self.t4 = nn.Sequential( 112 | SinusoidalPositionEmbeddings(t_channels), 113 | nn.Linear(t_channels, feat_channels*8), 114 | nn.PReLU(feat_channels*8), 115 | 
nn.Linear(feat_channels*8, feat_channels*8) 116 | ) 117 | self.c4 = nn.Sequential( 118 | nn.Conv3d(feat_channels*4, feat_channels*4, kernel_size=3, padding=1), 119 | nn.PReLU(feat_channels*4), 120 | self.norm_methods[self.layer_norm](feat_channels*4), 121 | nn.Conv3d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1), 122 | nn.PReLU(feat_channels*8), 123 | self.norm_methods[self.layer_norm](feat_channels*8) 124 | ) 125 | 126 | 127 | self.u1 = nn.Sequential( 128 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1), 129 | nn.PReLU(feat_channels*8), 130 | self.norm_methods[self.layer_norm](feat_channels*8), 131 | PixelShuffle3d(2) 132 | ) 133 | self.t5 = nn.Sequential( 134 | SinusoidalPositionEmbeddings(t_channels), 135 | nn.Linear(t_channels, feat_channels*8), 136 | nn.PReLU(feat_channels*8), 137 | nn.Linear(feat_channels*8, feat_channels*8) 138 | ) 139 | self.c5 = nn.Sequential( 140 | nn.Conv3d(feat_channels*5, feat_channels*8, kernel_size=3, padding=1), 141 | nn.PReLU(feat_channels*8), 142 | self.norm_methods[self.layer_norm](feat_channels*8), 143 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1), 144 | nn.PReLU(feat_channels*8), 145 | self.norm_methods[self.layer_norm](feat_channels*8) 146 | ) 147 | 148 | 149 | self.u2 = nn.Sequential( 150 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1), 151 | nn.PReLU(feat_channels*8), 152 | self.norm_methods[self.layer_norm](feat_channels*8), 153 | PixelShuffle3d(2) 154 | ) 155 | self.t6 = nn.Sequential( 156 | SinusoidalPositionEmbeddings(t_channels), 157 | nn.Linear(t_channels, feat_channels*8), 158 | nn.PReLU(feat_channels*8), 159 | nn.Linear(feat_channels*8, feat_channels*8) 160 | ) 161 | self.c6 = nn.Sequential( 162 | nn.Conv3d(feat_channels*3, feat_channels*8, kernel_size=3, padding=1), 163 | nn.PReLU(feat_channels*8), 164 | self.norm_methods[self.layer_norm](feat_channels*8), 165 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1), 166 | nn.PReLU(feat_channels*8), 167 | self.norm_methods[self.layer_norm](feat_channels*8) 168 | ) 169 | 170 | 171 | self.u3 = nn.Sequential( 172 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1), 173 | nn.PReLU(feat_channels*8), 174 | self.norm_methods[self.layer_norm](feat_channels*8), 175 | PixelShuffle3d(2) 176 | ) 177 | self.t7 = nn.Sequential( 178 | SinusoidalPositionEmbeddings(t_channels), 179 | nn.Linear(t_channels, feat_channels), 180 | nn.PReLU(feat_channels), 181 | nn.Linear(feat_channels, feat_channels) 182 | ) 183 | self.c7 = nn.Sequential( 184 | nn.Conv3d(feat_channels*2, feat_channels, kernel_size=3, padding=1), 185 | nn.PReLU(feat_channels), 186 | self.norm_methods[self.layer_norm](feat_channels), 187 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1), 188 | nn.PReLU(feat_channels), 189 | self.norm_methods[self.layer_norm](feat_channels) 190 | ) 191 | 192 | 193 | self.out = nn.Sequential( 194 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1), 195 | nn.PReLU(feat_channels), 196 | self.norm_methods[self.layer_norm](feat_channels), 197 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1), 198 | nn.PReLU(feat_channels), 199 | self.norm_methods[self.layer_norm](feat_channels), 200 | nn.Conv3d(feat_channels, out_channels, kernel_size=1), 201 | self.out_activations[self.out_activation] 202 | ) 203 | 204 | 205 | def forward(self, img, t): 206 | 207 | t1 = self.t1(t) 208 | c1 = self.c1(img)+t1[...,None,None,None] 209 | d1 = self.d1(c1) 210 | 211 | t2 = self.t2(t) 212 | c2 = 
self.c2(d1)+t2[...,None,None,None] 213 | d2 = self.d2(c2) 214 | 215 | t3 = self.t3(t) 216 | c3 = self.c3(d2)+t3[...,None,None,None] 217 | d3 = self.d3(c3) 218 | 219 | t4 = self.t4(t) 220 | c4 = self.c4(d3)+t4[...,None,None,None] 221 | 222 | u1 = self.u1(c4) 223 | t5 = self.t5(t) 224 | c5 = self.c5(torch.cat((u1,c3),1))+t5[...,None,None,None] 225 | 226 | u2 = self.u2(c5) 227 | t6 = self.t6(t) 228 | c6 = self.c6(torch.cat((u2,c2),1))+t6[...,None,None,None] 229 | 230 | u3 = self.u3(c6) 231 | t7 = self.t7(t) 232 | c7 = self.c7(torch.cat((u3,c1),1))+t7[...,None,None,None] 233 | 234 | out = self.out(c7) 235 | 236 | return out 237 | -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | The provided jupyter notebooks give a clear overview of how the data preparation, training and application pipelines work. 2 | Parameters that are important to these pipelines and might require further clarification are listed and explained in the following. 3 | 4 | #### General Network Parameters 5 | - `in_channels`: (int) Number of input channels, including an additional noise channel 6 | - `out_channels`: (int) Number of output channels 7 | - `feat_channels`: (int) Number of feature channels in the first block of the UNet backbone. Other channel numbers are derived from this 8 | - `t_channels`: (int) Number of channels used for timestep encoding 9 | - `patch_size`: (int+) Size of the patches. For 2D, make sure to still provide a 3D patch size with a leading 1, i.e. (1,y,x) 10 | - `data_root`: (str) Directory of the image data 11 | - `train_list`, `test_list`, `val_list`: (str) Path to the csv file listing all image data that should be used for training/testing/validation. The concatenation of the _data_root_ parameter and the entries of this list should give the full path to each file. 12 | - `image_groups`: (str+) List of the hdf5 group names that should be used 13 | - `num_timesteps`: (int) Number of total diffusion timesteps 14 | - `diffusion_schedule`: (str) Noise schedule used during the diffusion forward process 15 | 16 | #### Training-specific Parameters 17 | - `output_path`: (str) Directory for saving the model 18 | - `log_path`: (str) Directory for saving the log files 19 | - `no_resume`: (bool) Flag to not resume training from an existing checkpoint with the same name as the current training experiment 20 | - `pretrained`: (str) Explicit definition of a pretrained model for further training 21 | - `epochs`: (int) Number of training epochs 22 | - `samples_per_epoch`: (int) Number of samples used in one training epoch. 
Set to -1 to use all available samples 23 | 24 | #### Application-specific Parameters 25 | - `output_path`: (str) Directory for saving the generated image data 26 | - `ckpt_path`: (str) Path to the trained model checkpoint file 27 | - `overlap`: (int+) Tuple of overlap of neighbouring patches during the patch-based application, defined in z,y,x 28 | - `crop`: (int+) Tuple of cropped patch borders to avoid potential border artifacts, defined in z,y,x 29 | - `clip`: (int+) Intensity range the output gets clipped to 30 | - `timesteps_start`: (int) Timestep for initiating the diffusion backward process 31 | - `timesteps_save`: (int) Interval of timesteps after which (intermediate) results are saved 32 | - `timesteps_step`: (int) Number of timesteps skipped between consecutive iterations of the diffusion backward process 33 | - `blur_sigma`: (int) Sigma of the Gaussian blurring applied before starting the diffusion forward process 34 | - `num_files`: (int) Number of files that should be processed. Set to -1 to process all available files 35 | -------------------------------------------------------------------------------- /notebooks/jupyter_apply_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3f58caaa", 6 | "metadata": {}, 7 | "source": [ 8 | "# Application Pipeline" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "5be99b70", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Imports\n", 19 | "\n", 20 | "import os\n", 21 | "import sys\n", 22 | "import torch\n", 23 | "import ipywidgets as wg\n", 24 | "from IPython.display import display, Javascript \n", 25 | "from argparse import ArgumentParser, Namespace\n", 26 | "\n", 27 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n", 28 | "from utils.jupyter_widgets import get_pipelin_widget, get_apply_parameter_widgets, get_execution_widgets" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "4a674686", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# Check for GPU\n", 39 | "\n", 40 | "use_cuda = torch.cuda.is_available()\n", 41 | "if use_cuda:\n", 42 | " print('The following GPU was found:\\n')\n", 43 | " print('CUDNN VERSION:', torch.backends.cudnn.version())\n", 44 | " print('Number CUDA Devices:', torch.cuda.device_count())\n", 45 | " print('CUDA Device Name:',torch.cuda.get_device_name(0))\n", 46 | " print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)\n", 47 | "else:\n", 48 | " print('No GPU was found. CPU will be used.')\n", 49 | "\n", 50 | "# Select a pipeline \n", 51 | "pipeline = get_pipelin_widget()\n", 52 | "display(pipeline)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "8162637c", 58 | "metadata": {}, 59 | "source": [ 60 | "---\n", 61 | "After executing the next block, please adapt all parameters accordingly.\n", 62 | "The pipeline expects a list of files that should be used for testing. \n", 63 | "Absolute paths to each file are automatically obtained by concatenating the provided data root and each entry of the file lists. When changing the selected pipeline, please execute the following block again.
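To make the role of `overlap` and `crop` concrete, the following is a minimal sketch of patch-based tiling along one axis (the actual logic lives in apply_script_diffusion.py and may differ in detail; the function and variable names here are illustrative only):

    def patch_starts(length, patch, overlap):
        # Start indices along one axis; consecutive patches share `overlap` voxels.
        step = max(1, patch - overlap)
        last = max(length - patch, 0)
        starts = list(range(0, last + 1, step))
        if starts[-1] != last:
            starts.append(last)  # shift the final patch so it ends flush with the image
        return starts

    print(patch_starts(512, 256, 20))  # [0, 236, 256]

After prediction, `crop` voxels are trimmed from every patch border before stitching, and the stitched result is clipped to the `clip` range.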
\n", 64 | "A pretrained model is already provided with the repository for demonstration purposes. Further pretrained models can be downloaded from our website." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "22365964", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Define general inference parameters\n", 75 | "params = {'output_path': '../data_samples/experiment_3D',\n", 76 | " 'ckpt_path': '../data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt',\n", 77 | " 'gpus': use_cuda,\n", 78 | " 'overlap': (0,20,20),\n", 79 | " 'crop': (0,20,20),\n", 80 | " 'input_batch': 'image',\n", 81 | " 'clip': (-1.0, 1.0),\n", 82 | " 'num_files':-1,\n", 83 | " 'add_noise_channel': False,\n", 84 | " 'pipeline': pipeline.value,\n", 85 | " }\n", 86 | "\n", 87 | "params_diff = {'timesteps_start': 400,\n", 88 | " 'timesteps_save': 100,\n", 89 | " 'timesteps_step': 1,\n", 90 | " 'blur_sigma': 1\n", 91 | " }\n", 92 | "\n", 93 | "\n", 94 | "# Load selected pipeline\n", 95 | "if params['pipeline'].lower() == 'diffusionmodel3d':\n", 96 | " from models.DiffusionModel3D import DiffusionModel3D as network\n", 97 | " params.update(params_diff)\n", 98 | "elif params['pipeline'].lower() == 'diffusionmodel2d':\n", 99 | " from models.DiffusionModel2D import DiffusionModel2D as network\n", 100 | " params.update(params_diff)\n", 101 | "else:\n", 102 | " raise ValueError('Pipeline {0} unknown.'.format(params['pipeline']))\n", 103 | "\n", 104 | " \n", 105 | "# Get and show corresponding parameters\n", 106 | "pipeline_args = ArgumentParser(add_help=False)\n", 107 | "pipeline_args = network.add_model_specific_args(pipeline_args)\n", 108 | "pipeline_args = vars(pipeline_args.parse_known_args()[0])\n", 109 | "params = {**params, **pipeline_args}\n", 110 | "\n", 111 | "print('-'*60+'\\nPARAMETER FOR PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n", 112 | "param_names, widget_list = get_apply_parameter_widgets(params)\n", 113 | "for widget in widget_list: \n", 114 | " display(widget)\n", 115 | " \n", 116 | "print('-'*60+'\\nEXECUTION SETTINGS\\n'+'-'*60)\n", 117 | "wg_execute, wg_arguments = get_execution_widgets()\n", 118 | "display(wg_arguments)\n", 119 | "display(wg_execute)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "19adcedb", 125 | "metadata": {}, 126 | "source": [ 127 | "---\n", 128 | "Finish preparations and start processing by executing the next block. The outputs are expected to be in the value range (-1,1)." 
129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "ec280ea3", 135 | "metadata": { 136 | "scrolled": false 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "# Get parameters\n", 141 | "param_names = [p for p,w in zip(param_names, widget_list) if w.value!=False and w.value!='']\n", 142 | "widget_list = [w for w in widget_list if w.value!=False and w.value!='']\n", 143 | "command_line_args = ' '.join(['--pipeline {0}'.format(pipeline.value)]+\\\n", 144 | " [n+' '+str(w.value) if not type(w.value)==bool else n\\\n", 145 | " for n,w in zip(param_names, widget_list)])\n", 146 | "\n", 147 | "# Show the command line arguments\n", 148 | "if wg_arguments.value:\n", 149 | " print('_'*90+'\\nCOMMAND LINE ARGUMENTS FOR apply_script_diffusion.py WITH PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*90)\n", 150 | " print(command_line_args)\n", 151 | " print('\\n')\n", 152 | " \n", 153 | "# Execute the pipeline\n", 154 | "if wg_execute.value:\n", 155 | " print('_'*60+'\\nEXECUTING PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n", 156 | " %run \"../apply_script_diffusion.py\" {command_line_args}" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3 (ipykernel)", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.7.13" 177 | }, 178 | "varInspector": { 179 | "cols": { 180 | "lenName": 16, 181 | "lenType": 16, 182 | "lenVar": 40 183 | }, 184 | "kernels_config": { 185 | "python": { 186 | "delete_cmd_postfix": "", 187 | "delete_cmd_prefix": "del ", 188 | "library": "var_list.py", 189 | "varRefreshCmd": "print(var_dic_list())" 190 | }, 191 | "r": { 192 | "delete_cmd_postfix": ") ", 193 | "delete_cmd_prefix": "rm(", 194 | "library": "var_list.r", 195 | "varRefreshCmd": "cat(var_dic_list()) " 196 | } 197 | }, 198 | "types_to_exclude": [ 199 | "module", 200 | "function", 201 | "builtin_function_or_method", 202 | "instance", 203 | "_Feature" 204 | ], 205 | "window_display": false 206 | }, 207 | "vscode": { 208 | "interpreter": { 209 | "hash": "3d6afa663d3b7d8b7c28e0e5bf1fc62360d26f74485b03653b2cb99921ca431b" 210 | } 211 | } 212 | }, 213 | "nbformat": 4, 214 | "nbformat_minor": 5 215 | } 216 | -------------------------------------------------------------------------------- /notebooks/jupyter_preparation_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e877057c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Data Preparation Pipeline " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "722a8a44", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Imports\n", 19 | "\n", 20 | "import os\n", 21 | "import glob\n", 22 | "from skimage import io\n", 23 | "\n", 24 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n", 25 | "from utils.h5_converter import prepare_images\n", 26 | "from utils.csv_generator import create_csv" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "66084504", 32 | "metadata": {}, 33 | "source": [ 34 | "Both, the training and application scripts are designed to use the hdf5 data format. 
Therefore, all data samples need to be converted into this format first; the function `prepare_images` can be used for that purpose.
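The converted files store each image as a gzip-compressed dataset under the hdf5 group `data/image` (see `h5_writer`/`h5_reader` in utils/h5_converter.py). A converted sample can be inspected like this:

    import h5py

    with h5py.File('../data_samples/image3D_sim_CTCCE_0.h5', 'r') as f:
        image = f['data/image'][:]
    print(image.shape, image.dtype)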
\n", 35 | "Two example data samples are availabe in the _data_samples_ folder and the following function saves the converted data samples to the same folder. If another directory or folder is desired, the parameters _save_path_ and _save_folders_ can additionally be adapted accordingly." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "44dbe16b", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Convert images to h5\n", 46 | "prepare_images(data_path=r'../', folders=['data_samples',], identifier='*.tif')" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "a3eb2027", 52 | "metadata": {}, 53 | "source": [ 54 | "As a next step, the converted data samples need to be listed in a csv file to make the accessible to the training and application pipelines.
\n", 55 | "For creating the file lists, the function `create_csv` can be used. In case a test or valdation split is desired, the parameters _test_split_ and _val_split_ can be adapted to provide percentages of the respective splits." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "c7864ed0", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# Get a list of all files\n", 66 | "image_files = glob.glob(r'../data_samples/image3D*.h5')\n", 67 | "sketch_files = glob.glob(r'../data_samples/sketch3D*.h5')\n", 68 | "\n", 69 | "# Save the lists as csv file\n", 70 | "create_csv([[os.path.split(i)[-1],] for i in image_files],\\\n", 71 | " save_path=r'../data_samples/image_files_3D', test_split=0, val_split=0)\n", 72 | "create_csv([[os.path.split(s)[-1]] for s in sketch_files],\\\n", 73 | " save_path=r'../data_samples/sketch_files_3D', test_split=0, val_split=0)" 74 | ] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "Python 3 (ipykernel)", 80 | "language": "python", 81 | "name": "python3" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.7.13" 94 | }, 95 | "varInspector": { 96 | "cols": { 97 | "lenName": 16, 98 | "lenType": 16, 99 | "lenVar": 40 100 | }, 101 | "kernels_config": { 102 | "python": { 103 | "delete_cmd_postfix": "", 104 | "delete_cmd_prefix": "del ", 105 | "library": "var_list.py", 106 | "varRefreshCmd": "print(var_dic_list())" 107 | }, 108 | "r": { 109 | "delete_cmd_postfix": ") ", 110 | "delete_cmd_prefix": "rm(", 111 | "library": "var_list.r", 112 | "varRefreshCmd": "cat(var_dic_list()) " 113 | } 114 | }, 115 | "types_to_exclude": [ 116 | "module", 117 | "function", 118 | "builtin_function_or_method", 119 | "instance", 120 | "_Feature" 121 | ], 122 | "window_display": false 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 5 127 | } 128 | -------------------------------------------------------------------------------- /notebooks/jupyter_train_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c5a553ab", 6 | "metadata": {}, 7 | "source": [ 8 | "# Training Pipeline " 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "9914fa37", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Imports\n", 19 | "\n", 20 | "import os\n", 21 | "import sys\n", 22 | "import torch\n", 23 | "import ipywidgets as wg\n", 24 | "from IPython.display import display, Javascript \n", 25 | "from argparse import ArgumentParser, Namespace\n", 26 | "\n", 27 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n", 28 | "from utils.jupyter_widgets import get_pipelin_widget, get_parameter_widgets, get_execution_widgets" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "6dd050db", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# Check for GPU\n", 39 | "\n", 40 | "use_cuda = torch.cuda.is_available()\n", 41 | "if use_cuda:\n", 42 | " print('The following GPU was found:\\n')\n", 43 | " print('CUDNN VERSION:', torch.backends.cudnn.version())\n", 44 | " print('Number CUDA Devices:', torch.cuda.device_count())\n", 45 | " print('CUDA Device 
Name:',torch.cuda.get_device_name(0))\n", 46 | "    print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)\n", 47 | "else:\n", 48 | "    print('No GPU was found. CPU will be used.')\n", 49 | "\n", 50 | "# Select a pipeline \n", 51 | "pipeline = get_pipelin_widget()\n", 52 | "display(pipeline)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "fb2c9954", 58 | "metadata": {}, 59 | "source": [ 60 | "---\n", 61 | "After executing the next block, please adapt all parameters accordingly.\n", 62 | "The pipeline expects lists of files that should be used for training, testing and validation. \n", 63 | "Absolute paths to each file are automatically obtained by concatenating the provided data root and each entry of the file lists.\n", 64 | "When changing the selected pipeline, please execute the following block again." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "d5642dad", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Define general training parameters\n", 75 | "params = {'output_path': '../data_samples/experiment_3D',\n", 76 | "          'log_path': '../data_samples/experiment_3D/logs',\n", 77 | "          'gpus': use_cuda,\n", 78 | "          'no_resume': False,\n", 79 | "          'pretrained': None,\n", 80 | "          'augmentations': None,\n", 81 | "          'epochs': 5000,\n", 82 | "          'pipeline': pipeline.value}\n", 83 | "\n", 84 | "# Load selected pipeline\n", 85 | "if params['pipeline'].lower() == 'diffusionmodel3d':\n", 86 | "    from models.DiffusionModel3D import DiffusionModel3D as network\n", 87 | "elif params['pipeline'].lower() == 'diffusionmodel2d':\n", 88 | "    from models.DiffusionModel2D import DiffusionModel2D as network\n", 89 | "else:\n", 90 | "    raise ValueError('Pipeline {0} unknown.'.format(params['pipeline']))\n", 91 | "# Get and show corresponding parameters\n", 92 | "pipeline_args = ArgumentParser(add_help=False)\n", 93 | "pipeline_args = network.add_model_specific_args(pipeline_args)\n", 94 | "pipeline_args = vars(pipeline_args.parse_known_args()[0])\n", 95 | "params = {**params, **pipeline_args}\n", 96 | "\n", 97 | "print('-'*60+'\\nPARAMETER FOR PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n", 98 | "param_names, widget_list, _ = get_parameter_widgets(params)\n", 99 | "for widget in widget_list: \n", 100 | "    display(widget)\n", 101 | "    \n", 102 | "print('-'*60+'\\nEXECUTION SETTINGS\\n'+'-'*60)\n", 103 | "wg_execute, wg_arguments = get_execution_widgets()\n", 104 | "display(wg_arguments)\n", 105 | "display(wg_execute)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "1efbe8ae", 111 | "metadata": {}, 112 | "source": [ 113 | "---\n", 114 | "Finish preparations and start processing by executing the next block."
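For reference, the command line assembled by this notebook corresponds to a direct invocation of train_script.py; a typical example (the values are illustrative, all flags are defined in train_script.py and in the model's `add_model_specific_args`):

    %run "../train_script.py" --pipeline DiffusionModel3D --epochs 5000 --output_path ../data_samples/experiment_3D --log_path ../data_samples/experiment_3D/logs --data_root ../data_samples --train_list ../data_samples/image_files_3D.csv --gpus 1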
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "2daffbc0", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Get parameters\n", 125 | "param_names = [p for p,w in zip(param_names, widget_list) if w.value!=False and w.value!='']\n", 126 | "widget_list = [w for w in widget_list if w.value!=False and w.value!='']\n", 127 | "command_line_args = ' '.join(['--pipeline {0}'.format(pipeline.value)]+\\\n", 128 | " [n+' '+str(w.value) if not type(w.value)==bool else n\\\n", 129 | " for n,w in zip(param_names, widget_list)])\n", 130 | "\n", 131 | "# Show the command line arguments\n", 132 | "if wg_arguments.value:\n", 133 | " print('_'*90+'\\nCOMMAND LINE ARGUMENTS FOR train_script.py WITH PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*90)\n", 134 | " print(command_line_args)\n", 135 | " print('\\n')\n", 136 | " \n", 137 | "# Execute the pipeline\n", 138 | "if wg_execute.value:\n", 139 | " print('_'*60+'\\nEXECUTING PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n", 140 | " %run \"../train_script.py\" {command_line_args}" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 3 (ipykernel)", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.7.13" 161 | }, 162 | "varInspector": { 163 | "cols": { 164 | "lenName": 16, 165 | "lenType": 16, 166 | "lenVar": 40 167 | }, 168 | "kernels_config": { 169 | "python": { 170 | "delete_cmd_postfix": "", 171 | "delete_cmd_prefix": "del ", 172 | "library": "var_list.py", 173 | "varRefreshCmd": "print(var_dic_list())" 174 | }, 175 | "r": { 176 | "delete_cmd_postfix": ") ", 177 | "delete_cmd_prefix": "rm(", 178 | "library": "var_list.r", 179 | "varRefreshCmd": "cat(var_dic_list()) " 180 | } 181 | }, 182 | "types_to_exclude": [ 183 | "module", 184 | "function", 185 | "builtin_function_or_method", 186 | "instance", 187 | "_Feature" 188 | ], 189 | "window_display": false 190 | }, 191 | "vscode": { 192 | "interpreter": { 193 | "hash": "3d6afa663d3b7d8b7c28e0e5bf1fc62360d26f74485b03653b2cb99921ca431b" 194 | } 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 5 199 | } 200 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | alabaster==0.7.12 5 | appdirs==1.4.4 6 | argon2-cffi==21.3.0 7 | argon2-cffi-bindings==21.2.0 8 | arrow==1.2.2 9 | astroid==2.11.5 10 | astropy==4.3.1 11 | async-timeout==4.0.2 12 | asynctest==0.13.0 13 | atomicwrites==1.4.0 14 | attrs==21.4.0 15 | autopep8==1.6.0 16 | Babel==2.10.1 17 | backcall==0.2.0 18 | beautifulsoup4==4.11.1 19 | binaryornot==0.4.4 20 | black==22.3.0 21 | bleach==5.0.0 22 | cachetools==5.2.0 23 | certifi==2022.5.18.1 24 | cffi==1.15.0 25 | chardet==4.0.0 26 | charset-normalizer==2.0.12 27 | click==8.1.3 28 | cloudpickle==2.1.0 29 | cookiecutter==2.1.1 30 | cryptography==37.0.2 31 | cycler==0.11.0 32 | Cython==0.29.30 33 | debugpy==1.6.0 34 | decorator==5.1.1 35 | defusedxml==0.7.1 36 | diff-match-patch==20200713 37 | dill==0.3.5.1 38 | dipy==1.5.0 39 | docutils==0.17.1 40 | entrypoints==0.4 41 | fastjsonschema==2.15.3 42 | 
flake8==4.0.1 43 | fonttools==4.33.3 44 | frozenlist==1.3.0 45 | fsspec==2022.5.0 46 | google-auth==2.7.0 47 | google-auth-oauthlib==0.4.6 48 | grpcio==1.46.3 49 | gryds @ git+https://github.com/tueimage/gryds@cda4bac8f71e8bb47fc632b8cdea010904ae5cf1 50 | h5py==3.7.0 51 | hdbscan==0.8.28 52 | idna==3.3 53 | imagecodecs==2021.11.20 54 | imageio==2.19.3 55 | imagesize==1.3.0 56 | importlib-metadata==4.2.0 57 | importlib-resources==5.7.1 58 | inflection==0.5.1 59 | intervaltree==3.1.0 60 | ipykernel==6.13.1 61 | ipython==7.34.0 62 | ipython-genutils==0.2.0 63 | ipywidgets==7.7.0 64 | isort==5.10.1 65 | jedi==0.18.1 66 | jeepney==0.8.0 67 | jellyfish==0.9.0 68 | Jinja2==3.1.2 69 | jinja2-time==0.2.0 70 | joblib==1.1.0 71 | jsonschema==4.6.0 72 | jupyter==1.0.0 73 | jupyter-client==7.3.4 74 | jupyter-console==6.4.3 75 | jupyter-core==4.10.0 76 | jupyterlab-pygments==0.2.2 77 | jupyterlab-widgets==1.1.0 78 | keyring==23.6.0 79 | kiwisolver==1.4.3 80 | lazy-object-proxy==1.7.1 81 | Markdown==3.3.7 82 | MarkupSafe==2.1.1 83 | matplotlib==3.5.2 84 | matplotlib-inline==0.1.3 85 | mccabe==0.6.1 86 | mistune==0.8.4 87 | multidict==6.0.2 88 | mypy-extensions==0.4.3 89 | nbclient==0.6.4 90 | nbconvert==6.5.0 91 | nbformat==5.4.0 92 | nest-asyncio==1.5.5 93 | networkx==2.6.3 94 | nibabel==4.0.1 95 | notebook==6.4.12 96 | numpy==1.21.6 97 | numpydoc==1.3.1 98 | oauthlib==3.2.0 99 | opencv-python==4.7.0.72 100 | packaging==21.3 101 | pandas==1.3.5 102 | pandocfilters==1.5.0 103 | parso==0.8.3 104 | pathspec==0.9.0 105 | pexpect==4.8.0 106 | pickleshare==0.7.5 107 | Pillow==9.1.1 108 | platformdirs==2.5.2 109 | pluggy==1.0.0 110 | pooch==1.6.0 111 | prometheus-client==0.14.1 112 | prompt-toolkit==3.0.29 113 | protobuf==3.19.4 114 | psutil==5.9.1 115 | ptyprocess==0.7.0 116 | pyasn1==0.4.8 117 | pyasn1-modules==0.2.8 118 | pycodestyle==2.8.0 119 | pycparser==2.21 120 | pyDeprecate==0.3.2 121 | pydocstyle==6.1.1 122 | pyerfa==2.0.0.1 123 | pyflakes==2.4.0 124 | Pygments==2.12.0 125 | pylint==2.14.1 126 | pyls-spyder==0.4.0 127 | pyparsing==3.0.9 128 | PyQt5==5.15.6 129 | PyQt5-Qt5==5.15.2 130 | PyQt5-sip==12.10.1 131 | PyQtWebEngine==5.15.5 132 | PyQtWebEngine-Qt5==5.15.2 133 | pyquaternion==0.9.9 134 | pyrsistent==0.18.1 135 | pyshtools==4.10 136 | python-dateutil==2.8.2 137 | python-lsp-black==1.2.1 138 | python-lsp-jsonrpc==1.0.0 139 | python-lsp-server==1.4.1 140 | python-slugify==6.1.2 141 | pytorch-lightning==1.6.4 142 | pytz==2022.1 143 | PyWavelets==1.3.0 144 | pyxdg==0.28 145 | PyYAML==6.0 146 | pyzmq==23.1.0 147 | QDarkStyle==3.0.3 148 | qstylizer==0.2.1 149 | QtAwesome==1.1.1 150 | qtconsole==5.3.1 151 | QtPy==2.1.0 152 | requests==2.27.1 153 | requests-oauthlib==1.3.1 154 | rope==1.1.1 155 | rsa==4.8 156 | Rtree==1.0.0 157 | scikit-image==0.19.3 158 | scikit-learn==1.0.2 159 | scipy==1.7.3 160 | SecretStorage==3.3.2 161 | Send2Trash==1.8.0 162 | six==1.16.0 163 | snowballstemmer==2.2.0 164 | sortedcontainers==2.4.0 165 | soupsieve==2.3.2.post1 166 | Sphinx==4.3.2 167 | sphinxcontrib-applehelp==1.0.2 168 | sphinxcontrib-devhelp==1.0.2 169 | sphinxcontrib-htmlhelp==2.0.0 170 | sphinxcontrib-jsmath==1.0.1 171 | sphinxcontrib-qthelp==1.0.3 172 | sphinxcontrib-serializinghtml==1.1.5 173 | spyder==5.3.1 174 | spyder-kernels==2.3.1 175 | tensorboard==2.9.1 176 | tensorboard-data-server==0.6.1 177 | tensorboard-plugin-wit==1.8.1 178 | terminado==0.15.0 179 | text-unidecode==1.3 180 | textdistance==4.2.2 181 | threadpoolctl==3.1.0 182 | three-merge==0.1.1 183 | tifffile==2021.11.2 184 | tinycss2==1.1.1 185 
| toml==0.10.2 186 | tomli==2.0.1 187 | tomlkit==0.11.0 188 | torch==1.11.0+cu113 189 | torchaudio==0.11.0+cu113 190 | torchmetrics==0.9.1 191 | torchvision==0.12.0 192 | tornado==6.1 193 | tqdm==4.64.0 194 | traitlets==5.2.2.post1 195 | typed-ast==1.5.4 196 | typing_extensions==4.2.0 197 | ujson==5.3.0 198 | urllib3==1.26.9 199 | watchdog==2.1.8 200 | wcwidth==0.2.5 201 | webencodings==0.5.1 202 | Werkzeug==2.1.2 203 | widgetsnbextension==3.6.0 204 | wrapt==1.14.1 205 | wurlitzer==3.0.2 206 | xarray==0.20.2 207 | yapf==0.32.0 208 | yarl==1.7.2 209 | zipp==3.8.0 210 | -------------------------------------------------------------------------------- /train_script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jan 14 13:46:20 2020 5 | 6 | @author: eschweiler 7 | """ 8 | 9 | from argparse import ArgumentParser 10 | 11 | import numpy as np 12 | import torch 13 | import glob 14 | import os 15 | 16 | from pytorch_lightning import Trainer 17 | from pytorch_lightning.loggers import TensorBoardLogger 18 | from pytorch_lightning.callbacks import ModelCheckpoint 19 | 20 | SEED = 1337 21 | torch.manual_seed(SEED) 22 | np.random.seed(SEED) 23 | 24 | 25 | def main(hparams): 26 | 27 | """ 28 | Main training routine specific for this project 29 | :param hparams: 30 | """ 31 | 32 | # ------------------------ 33 | # 1 INIT LIGHTNING MODEL 34 | # ------------------------ 35 | model = network(hparams=hparams) 36 | os.makedirs(hparams.output_path, exist_ok=True) 37 | 38 | # Load pretrained weights if available 39 | if not hparams.pretrained is None: 40 | model.load_pretrained(hparams.pretrained) 41 | 42 | # Resume from checkpoint if available 43 | resume_ckpt = None 44 | if hparams.resume: 45 | checkpoints = glob.glob(os.path.join(hparams.output_path,'*.ckpt')) 46 | checkpoints.sort(key=os.path.getmtime) 47 | if len(checkpoints)>0: 48 | resume_ckpt = checkpoints[-1] 49 | print('Resuming from checkpoint: {0}'.format(resume_ckpt)) 50 | 51 | # Set the augmentations if available 52 | model.set_augmentations(hparams.augmentations) 53 | 54 | # Save a few samples for sanity checks 55 | print('Saving 20 data samples for sanity checks...') 56 | model.train_dataloader().dataset.test(os.path.join(hparams.output_path, 'samples'), num_files=20) 57 | 58 | # ------------------------ 59 | # 2 INIT TRAINER 60 | # ------------------------ 61 | 62 | checkpoint_callback = ModelCheckpoint( 63 | dirpath=hparams.output_path, 64 | filename=hparams.pipeline+'-{epoch:03d}-{step}', 65 | save_top_k=1, 66 | monitor='step', 67 | mode='max', 68 | verbose=True, 69 | every_n_epochs=1 70 | ) 71 | 72 | logger = TensorBoardLogger( 73 | save_dir=hparams.log_path, 74 | name='lightning_logs_'+hparams.pipeline.lower() 75 | ) 76 | 77 | trainer = Trainer( 78 | logger=logger, 79 | enable_checkpointing=True, 80 | callbacks=[checkpoint_callback], 81 | gpus=hparams.gpus, 82 | min_epochs=hparams.epochs, 83 | max_epochs=hparams.epochs, 84 | resume_from_checkpoint=resume_ckpt 85 | ) 86 | 87 | # ------------------------ 88 | # 3 START TRAINING 89 | # ------------------------ 90 | trainer.fit(model) 91 | 92 | 93 | 94 | if __name__ == '__main__': 95 | # ------------------------ 96 | # TRAINING ARGUMENTS 97 | # ------------------------ 98 | # these are project-wide arguments 99 | 100 | parent_parser = ArgumentParser(add_help=False) 101 | 102 | # gpu args 103 | parent_parser.add_argument( 104 | '--output_path', 105 | type=str, 106 | 
default=r'results/experiment1', 107 | help='output path for test results' 108 | ) 109 | 110 | parent_parser.add_argument( 111 | '--log_path', 112 | type=str, 113 | default=r'logs/logs_experiment1', 114 | help='output path for test results' 115 | ) 116 | 117 | parent_parser.add_argument( 118 | '--gpus', 119 | type=int, 120 | default=1, 121 | help='number of GPUs to use' 122 | ) 123 | 124 | parent_parser.add_argument( 125 | '--no_resume', 126 | dest='resume', 127 | action='store_false', 128 | default=True, 129 | help='Do not resume training from latest checkpoint' 130 | ) 131 | 132 | parent_parser.add_argument( 133 | '--pretrained', 134 | type=str, 135 | default=None, 136 | nargs='+', 137 | help='path to pretrained model weights' 138 | ) 139 | 140 | parent_parser.add_argument( 141 | '--augmentations', 142 | type=str, 143 | default=None, 144 | help='path to augmentation dict file' 145 | ) 146 | 147 | parent_parser.add_argument( 148 | '--epochs', 149 | type=int, 150 | default=5000, 151 | help='number of epochs' 152 | ) 153 | 154 | parent_parser.add_argument( 155 | '--pipeline', 156 | type=str, 157 | default='DiffusionModel3D', 158 | help='which pipeline to load (DiffusionModel3D | DiffusionModel2D)' 159 | ) 160 | 161 | parent_args = parent_parser.parse_known_args()[0] 162 | 163 | # load the desired network architecture 164 | if parent_args.pipeline.lower() == 'diffusionmodel3d': 165 | from models.DiffusionModel3D import DiffusionModel3D as network 166 | elif parent_args.pipeline.lower() == 'diffusionmodel2d': 167 | from models.DiffusionModel2D import DiffusionModel2D as network 168 | else: 169 | raise ValueError('Pipeline {0} unknown.'.format(parent_args.pipeline)) 170 | 171 | # each LightningModule defines arguments relevant to it 172 | parser = network.add_model_specific_args(parent_parser) 173 | hyperparams = parser.parse_args() 174 | 175 | # --------------------- 176 | # RUN TRAINING 177 | # --------------------- 178 | main(hyperparams) 179 | -------------------------------------------------------------------------------- /utils/csv_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 19 15:55:13 2020 4 | 5 | @author: Nutzer 6 | """ 7 | 8 | import os 9 | import glob 10 | import csv 11 | import numpy as np 12 | 13 | 14 | def get_files(folders, data_root='', descriptor='', filetype='tif'): 15 | 16 | filelist = [] 17 | 18 | for folder in folders: 19 | files = glob.glob(os.path.join(data_root, folder, '*'+descriptor+'*.'+filetype)) 20 | filelist.extend([os.path.join(folder, os.path.split(f)[-1]) for f in files]) 21 | 22 | return filelist 23 | 24 | 25 | 26 | def read_csv(list_path, data_root=''): 27 | 28 | filelist = [] 29 | 30 | with open(list_path, 'r') as f: 31 | reader = csv.reader(f, delimiter=';') 32 | for row in reader: 33 | if len(row)==0 or np.sum([len(r) for r in row])==0: continue 34 | row = [os.path.join(data_root, r) for r in row] 35 | filelist.append(row) 36 | 37 | return filelist 38 | 39 | 40 | 41 | def create_csv(data_list, save_path='list_folder/experiment_name', test_split=0.2, val_split=0.1, shuffle=False): 42 | 43 | if shuffle: 44 | np.random.shuffle(data_list) 45 | 46 | # Get number of files for each split 47 | num_files = len(data_list) 48 | num_test_files = int(test_split*num_files) 49 | num_val_files = int((num_files-num_test_files)*val_split) 50 | num_train_files = num_files - num_test_files - num_val_files 51 | 52 | # Adjust file identifier if there is no split 53 | if 
test_split>0 or val_split>0: 54 | train_identifier='_train.csv' 55 | else: 56 | train_identifier='.csv' 57 | 58 | # Get file indices 59 | file_idx = np.arange(num_files) 60 | 61 | # Save csv files 62 | if num_test_files > 0: 63 | test_idx = sorted(np.random.choice(file_idx, size=num_test_files, replace=False)) 64 | with open(save_path+'_test.csv', 'w', newline='') as fh: 65 | writer = csv.writer(fh, delimiter=';') 66 | for idx in test_idx: 67 | writer.writerow(data_list[idx]) 68 | else: 69 | test_idx = [] 70 | 71 | if num_val_files > 0: 72 | val_idx = sorted(np.random.choice(list(set(file_idx)-set(test_idx)), size=num_val_files, replace=False)) 73 | with open(save_path+'_val.csv', 'w', newline='') as fh: 74 | writer = csv.writer(fh, delimiter=';') 75 | for idx in val_idx: 76 | writer.writerow(data_list[idx]) 77 | else: 78 | val_idx = [] 79 | 80 | if num_train_files > 0: 81 | train_idx = sorted(list(set(file_idx) - set(test_idx) - set(val_idx))) 82 | with open(save_path+train_identifier, 'w', newline='') as fh: 83 | writer = csv.writer(fh, delimiter=';') 84 | for idx in train_idx: 85 | writer.writerow(data_list[idx]) 86 | 87 | -------------------------------------------------------------------------------- /utils/h5_converter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import h5py 6 | import glob 7 | import numpy as np 8 | 9 | from skimage import io, morphology, measure, filters 10 | from scipy.ndimage import distance_transform_edt, zoom, generic_filter 11 | from scipy.spatial import ConvexHull, Delaunay 12 | 13 | from utils.utils import print_timestamp 14 | 15 | 16 | 17 | def h5_writer(data_list, save_path, group_root='data', group_names=['image']): 18 | 19 | save_path = os.path.abspath(save_path) 20 | 21 | assert(len(data_list)==len(group_names)), 'Each data matrix needs a group name' 22 | 23 | with h5py.File(save_path, 'w') as f_handle: 24 | grp = f_handle.create_group(group_root) 25 | for data, group_name in zip(data_list, group_names): 26 | grp.create_dataset(group_name, data=data, chunks=True, compression='gzip') 27 | 28 | 29 | 30 | def h5_reader(file, source_group='data/image'): 31 | 32 | with h5py.File(file, 'r') as file_handle: 33 | data = file_handle[source_group][:] 34 | 35 | return data 36 | 37 | 38 | 39 | 40 | def h52tif(file_dir='', identifier='*', group_names=['data/image']): 41 | 42 | # Get all files within the given directory 43 | filelist = glob.glob(os.path.join(file_dir, identifier+'.h5')) 44 | 45 | # Create saving folders 46 | for group in group_names: 47 | os.makedirs(os.path.join(file_dir, ''.join(s for s in group if s.isalnum())), exist_ok=True) 48 | 49 | # Save each desired group 50 | for num_file,file in enumerate(filelist): 51 | print_timestamp('Processing file {0}/{1}', (num_file+1, len(filelist))) 52 | with h5py.File(file, 'r') as file_handle: 53 | for group in group_names: 54 | data = file_handle[group][:] 55 | io.imsave(os.path.join(file_dir, ''.join(s for s in group if s.isalnum()), os.path.split(file)[-1][:-2]+'tif'), data) 56 | 57 | 58 | 59 | def replace_h5_group(source_list, target_list, source_group='data/image', target_group=None): 60 | 61 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list)) 62 | if target_group is None: target_group=source_group 63 | 64 | for num_pair, pair in enumerate(zip(source_list, target_list)): 65 | print_timestamp('Processing file 
{0}/{1}...', [num_pair+1, len(target_list)]) 66 | 67 | # Load the source mask 68 | with h5py.File(pair[0], 'r') as source_handle: 69 | source_data = source_handle[source_group][...] 70 | 71 | # Save the data to the target file 72 | with h5py.File(pair[1], 'r+') as target_handle: 73 | target_data = target_handle[target_group] 74 | target_data[...] = source_data 75 | 76 | 77 | 78 | 79 | def add_group(file, data, target_group='data/image'): 80 | 81 | with h5py.File(file, 'a') as file_handle: 82 | file_handle.create_dataset(target_group, data=data, chunks=True, compression='gzip') 83 | 84 | 85 | 86 | 87 | def add_h5_group(source_list, target_list, source_group='data/distance', target_group=None): 88 | 89 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list)) 90 | if target_group is None: target_group=source_group 91 | 92 | for num_pair, pair in enumerate(zip(source_list, target_list)): 93 | 94 | print_timestamp('Processing file {0}/{1}...', [num_pair+1, len(source_list)]) 95 | 96 | # Get the data from the source file 97 | with h5py.File(pair[0], 'r') as source_handle: 98 | source_data = source_handle[source_group][...] 99 | 100 | # Save the data to the target file 101 | try: 102 | with h5py.File(pair[1], 'a') as target_handle: 103 | target_handle.create_dataset(target_group, data=source_data, chunks=True, compression='gzip') 104 | except Exception: 105 | print_timestamp('Skipping file "{0}"...', [os.path.split(pair[1])[-1]]) 106 | 107 | 108 | 109 | 110 | def add_tiff_group(source_list, target_list, target_group='data/newgroup'): 111 | 112 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list)) 113 | assert target_group is not None, 'There needs to be a target group name!' 
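    # Usage sketch (hypothetical file names, purely illustrative):
    #   add_tiff_group(['mask_000.tif'], ['mask_000.h5'], target_group='data/newgroup')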
114 | 115 | for num_pair, pair in enumerate(zip(source_list, target_list)): 116 | 117 | print_timestamp('Processing file {0}/{1}...', [num_pair+1, len(source_list)]) 118 | 119 | # Get the data from the source file 120 | source_data = io.imread(pair[0]) 121 | 122 | # Save the data to the target file 123 | with h5py.File(pair[1], 'a') as target_handle: 124 | target_handle.create_dataset(target_group, data=source_data-np.min(source_data), chunks=True, compression='gzip') 125 | 126 | 127 | 128 | def remove_h5_group(file_list, source_group='data/nuclei'): 129 | 130 | for num_file, file in enumerate(file_list): 131 | 132 | print_timestamp('Processing file {0}/{1}...', [num_file+1, len(file_list)]) 133 | 134 | with h5py.File(file, 'a') as file_handle: 135 | del file_handle[source_group] 136 | 137 | 138 | 139 | def flood_fill_hull(image): 140 | 141 | # Credits: https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays 142 | points = np.transpose(np.where(image)) 143 | hull = ConvexHull(points) 144 | deln = Delaunay(points[hull.vertices]) 145 | idx = np.stack(np.indices(image.shape), axis = -1) 146 | out_idx = np.nonzero(deln.find_simplex(idx) + 1) 147 | out_img = np.zeros(image.shape) 148 | out_img[out_idx] = 1 149 | 150 | return out_img 151 | 152 | 153 | 154 | def calculate_flows(instance_mask, bg_label=0): 155 | 156 | flow_x = np.zeros(instance_mask.shape, dtype=np.float32) 157 | flow_y = np.zeros(instance_mask.shape, dtype=np.float32) 158 | flow_z = np.zeros(instance_mask.shape, dtype=np.float32) 159 | regions = measure.regionprops(instance_mask) 160 | for props in regions: 161 | 162 | if props.label == bg_label: 163 | continue 164 | 165 | # get all coordinates within instance 166 | c = props.centroid 167 | coords = np.where(instance_mask==props.label) 168 | 169 | # calculate minimum extent in all spatial directions 170 | norm_x = np.maximum(1, np.minimum(np.abs(c[0]-props.bbox[0]),np.abs(c[0]-props.bbox[3]))/3) 171 | norm_y = np.maximum(1, np.minimum(np.abs(c[1]-props.bbox[1]),np.abs(c[1]-props.bbox[4]))/3) 172 | norm_z = np.maximum(1, np.minimum(np.abs(c[2]-props.bbox[2]),np.abs(c[2]-props.bbox[5]))/3) 173 | 174 | # calculate flows 175 | flow_x[coords] = np.tanh((coords[0]-c[0])/norm_x) 176 | flow_y[coords] = np.tanh((coords[1]-c[1])/norm_y) 177 | flow_z[coords] = np.tanh((coords[2]-c[2])/norm_z) 178 | 179 | return flow_x, flow_y, flow_z 180 | 181 | 182 | 183 | def rescale_data(data, zoom_factor, order=0): 184 | 185 | if any([zf!=1 for zf in zoom_factor]): 186 | data_shape = data.shape 187 | data = zoom(data, zoom_factor, order=order) 188 | print_timestamp('Rescaled image from size {0} to {1}'.format(data_shape, data.shape)) 189 | 190 | return data 191 | 192 | 193 | 194 | def prepare_images(data_path='', folders=[''], identifier='*.tif', descriptor='', normalize=[0,100],\ 195 | get_distance=False, get_illumination=False, get_variance=False, variance_size=(5,5,5),\ 196 | fg_selem_size=5, zoom_factor=(1,1,1), channel=0, clip=(-99999,99999),\ 197 | save_path=None, save_folders=None): 198 | 199 | data_path = os.path.abspath(data_path) 200 | 201 | if save_path is None: save_path = data_path 202 | if save_folders is None: save_folders = folders 203 | if len(save_folders)==1: save_folders = save_folders*len(folders) 204 | elif len(save_folders)!=len(folders): 205 | save_folders=folders 206 | print_timestamp('Could not save into the given folders! 
Number of save folders and folders must be the same.') 207 | 208 | 209 | for num_folder, (image_folder, save_folder) in enumerate(zip(folders, save_folders)): 210 | os.makedirs(os.path.join(save_path, save_folder), exist_ok=True) 211 | image_list = glob.glob(os.path.join(data_path, image_folder, identifier)) 212 | for num_file,file in enumerate(image_list): 213 | 214 | print_timestamp('Processing image {0}/{1} in folder {2}/{3} {4}', (num_file+1, len(image_list), num_folder+1, len(folders), image_folder)) 215 | 216 | # load the image 217 | processed_img = io.imread(file) 218 | processed_img = processed_img.astype(np.float32) 219 | processed_img = np.clip(processed_img, *clip) 220 | 221 | # get the desired channel, if the image is a multichannel image 222 | if processed_img.ndim == 4: 223 | processed_img = processed_img[...,channel] 224 | 225 | # get the desired image dimensionality, if image is only 2D 226 | if processed_img.ndim==2: 227 | processed_img = processed_img[np.newaxis,...] 228 | 229 | # rescale the image 230 | processed_img = rescale_data(processed_img, zoom_factor, order=3) 231 | 232 | # normalize the image 233 | perc1, perc2 = np.percentile(processed_img, list(normalize)) 234 | processed_img -= perc1 235 | processed_img /= (perc2-perc1) 236 | processed_img = np.clip(processed_img, 0, 1) 237 | processed_img = processed_img.astype(np.float32) 238 | 239 | save_imgs = [processed_img,] 240 | save_groups = ['image',] 241 | 242 | if get_illumination: 243 | 244 | print_timestamp('Extracting illumination image...') 245 | 246 | # create downscaled image for computationally intensive processing 247 | small_img = processed_img[::2,::2,::2] 248 | 249 | # create an illuminance image (downscale for faster processing) 250 | illu_img = morphology.closing(small_img, selem=morphology.ball(7)) 251 | illu_img = filters.gaussian(illu_img, 2).astype(np.float32) 252 | 253 | # rescale illuminance image 254 | illu_img = np.repeat(illu_img, 2, axis=0) 255 | illu_img = np.repeat(illu_img, 2, axis=1) 256 | illu_img = np.repeat(illu_img, 2, axis=2) 257 | dim_missmatch = np.array(processed_img.shape)-np.array(illu_img.shape) 258 | if dim_missmatch[0]<0: illu_img = illu_img[:dim_missmatch[0],...] 259 | if dim_missmatch[1]<0: illu_img = illu_img[:,:dim_missmatch[1],:] 260 | if dim_missmatch[2]<0: illu_img = illu_img[...,:dim_missmatch[2]] 261 | 262 | save_imgs.append(illu_img.astype(np.float32)) 263 | save_groups.append('illumination') 264 | 265 | if get_distance: 266 | 267 | print_timestamp('Extracting distance image...') 268 | 269 | # create downscaled image for computationally intensive processing 270 | small_img = processed_img[::4,::4,::4] 271 | 272 | # find suitable threshold 273 | thresh = filters.threshold_otsu(small_img) 274 | fg_img = small_img > thresh 275 | 276 | # remove noise and fill holes 277 | fg_img = morphology.binary_closing(fg_img, selem=morphology.ball(fg_selem_size)) 278 | fg_img = morphology.binary_opening(fg_img, selem=morphology.ball(fg_selem_size)) 279 | fg_img = flood_fill_hull(fg_img) 280 | fg_img = fg_img.astype(bool) 281 | 282 | # create distance transform 283 | fg_img = distance_transform_edt(fg_img) - distance_transform_edt(~fg_img) 284 | 285 | # rescale distance image 286 | fg_img = np.repeat(fg_img, 4, axis=0) 287 | fg_img = np.repeat(fg_img, 4, axis=1) 288 | fg_img = np.repeat(fg_img, 4, axis=2) 289 | dim_missmatch = np.array(processed_img.shape)-np.array(fg_img.shape) 290 | if dim_missmatch[0]<0: fg_img = fg_img[:dim_missmatch[0],...] 
291 | if dim_missmatch[1]<0: fg_img = fg_img[:,:dim_missmatch[1],:] 292 | if dim_missmatch[2]<0: fg_img = fg_img[...,:dim_missmatch[2]] 293 | 294 | save_imgs.append(fg_img.astype(np.float32)) 295 | save_groups.append('distance') 296 | 297 | if get_variance: 298 | 299 | print_timestamp('Extracting variance image...') 300 | 301 | # create downscaled image for computationally intensive processing 302 | small_img = processed_img[::4,::4,::4] 303 | 304 | # create variance image 305 | std_img = generic_filter(small_img, np.std, size=variance_size) 306 | 307 | # rescale variance image 308 | std_img = np.repeat(std_img, 4, axis=0) 309 | std_img = np.repeat(std_img, 4, axis=1) 310 | std_img = np.repeat(std_img, 4, axis=2) 311 | dim_missmatch = np.array(processed_img.shape)-np.array(std_img.shape) 312 | if dim_missmatch[0]<0: std_img = std_img[:dim_missmatch[0],...] 313 | if dim_missmatch[1]<0: std_img = std_img[:,:dim_missmatch[1],:] 314 | if dim_missmatch[2]<0: std_img = std_img[...,:dim_missmatch[2]] 315 | 316 | save_imgs.append(std_img.astype(np.float32)) 317 | save_groups.append('variance') 318 | 319 | # save the data 320 | save_name = os.path.split(file)[-1] 321 | save_name = os.path.join(save_path, save_folder, descriptor+save_name[:-4]+'.h5') 322 | h5_writer(save_imgs, save_name, group_root='data', group_names=save_groups) 323 | 324 | 325 | 326 | 327 | 328 | def prepare_masks(data_path='', folders=[''], identifier='*.tif', descriptor='',\ 329 | bg_label=0, get_flows=False, get_boundary=True, get_seeds=False, get_distance=True,\ 330 | corrupt_prob=0.0, zoom_factor=(1,1,1), convex_hull=False,\ 331 | save_path=None, save_folders=None): 332 | 333 | data_path = os.path.abspath(data_path) 334 | 335 | if save_path is None: save_path = data_path 336 | if save_folders is None: save_folders = folders 337 | if len(save_folders)==1: save_folders = save_folders*len(folders) 338 | elif len(save_folders)!=len(folders): 339 | save_folders=folders 340 | print_timestamp('Could not save into the given folders! Number of save folders and folders must be the same.') 341 | 342 | for num_folder, (mask_folder,save_folder) in enumerate(zip(folders,save_folders)): 343 | mask_list = glob.glob(os.path.join(data_path, mask_folder, identifier)) 344 | experiment_identifier = 'corrupt'+str(corrupt_prob).replace('.','') if corrupt_prob > 0 else '' 345 | os.makedirs(os.path.join(data_path, save_folder, experiment_identifier), exist_ok=True) 346 | for num_file,file in enumerate(mask_list): 347 | 348 | print_timestamp('Processing mask {0}/{1} in folder {2}/{3} {4}', (num_file+1, len(mask_list), num_folder+1, len(folders), mask_folder)) 349 | 350 | # load the mask 351 | instance_mask = io.imread(file) 352 | instance_mask = instance_mask.astype(np.uint16) 353 | instance_mask[instance_mask==bg_label] = 0 354 | 355 | # get the desired image dimensionality, if image is only 2D 356 | if instance_mask.ndim==2: 357 | instance_mask = instance_mask[np.newaxis,...] 
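            # Masks are rescaled with order=0 (nearest neighbour) below, unlike the
            # order=3 interpolation used in prepare_images, so that interpolation
            # cannot introduce new, non-integer instance labels.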
358 | 359 | # rescale the mask 360 | instance_mask = rescale_data(instance_mask, zoom_factor, order=0) 361 | 362 | if corrupt_prob > 0: 363 | # Randomly merge neighbouring instances 364 | labels = list(set(np.unique(instance_mask))-set([bg_label])) 365 | instance_mask_eroded = morphology.erosion(instance_mask, selem=morphology.ball(3)) 366 | instance_mask_dilated = morphology.dilation(instance_mask, selem=morphology.ball(3)) 367 | for label in labels: 368 | if np.random.rand() < corrupt_prob: 369 | neighbour_labels = list(instance_mask_eroded[instance_mask==label]) + list(instance_mask_dilated[instance_mask==label]) 370 | neighbour_labels = list(set(neighbour_labels)-set([label,])) 371 | if len(neighbour_labels) > 0: 372 | replace_label = np.random.choice(neighbour_labels) 373 | instance_mask[instance_mask==label] = replace_label 374 | 375 | save_groups = ['instance',] 376 | save_masks = [instance_mask,] 377 | 378 | # get the boundary mask 379 | if get_boundary: 380 | membrane_mask = morphology.dilation(instance_mask, selem=morphology.ball(2)) - instance_mask 381 | membrane_mask = membrane_mask != 0 382 | membrane_mask = membrane_mask.astype(np.float32) 383 | save_groups.append('boundary') 384 | save_masks.append(membrane_mask) 385 | 386 | # get the distance mask 387 | if get_distance: 388 | fg_img = instance_mask[::4,::4,::4]>0 389 | if convex_hull: fg_img = flood_fill_hull(fg_img) 390 | fg_img = fg_img.astype(bool) 391 | distance_mask = distance_transform_edt(fg_img) - distance_transform_edt(~fg_img) 392 | distance_mask = distance_mask.astype(np.float32) 393 | distance_mask = np.repeat(distance_mask, 4, axis=0) 394 | distance_mask = np.repeat(distance_mask, 4, axis=1) 395 | distance_mask = np.repeat(distance_mask, 4, axis=2) 396 | dim_missmatch = np.array(instance_mask.shape)-np.array(distance_mask.shape) 397 | if dim_missmatch[0]<0: distance_mask = distance_mask[:dim_missmatch[0],...] 
398 | if dim_missmatch[1]<0: distance_mask = distance_mask[:,:dim_missmatch[1],:] 399 | if dim_missmatch[2]<0: distance_mask = distance_mask[...,:dim_missmatch[2]] 400 | save_groups.append('distance') 401 | save_masks.append(distance_mask) 402 | 403 | # get the centroid mask 404 | if get_seeds: 405 | centroid_mask = np.zeros(instance_mask.shape, dtype=np.float32) 406 | regions = measure.regionprops(instance_mask) 407 | 408 | for props in regions: 409 | 410 | if props.label == bg_label: 411 | continue 412 | 413 | c = props.centroid 414 | centroid_mask[int(c[0]), int(c[1]), int(c[2])] = 1 415 | 416 | save_groups.append('seeds') 417 | save_masks.append(centroid_mask) 418 | 419 | # calculate the flow field 420 | if get_flows: 421 | 422 | flow_x, flow_y, flow_z = calculate_flows(instance_mask, bg_label=bg_label) 423 | 424 | save_groups.extend(['flow_x','flow_y', 'flow_z']) 425 | save_masks.extend([flow_x, flow_y, flow_z]) 426 | 427 | # save the data 428 | save_name = os.path.split(file)[-1] 429 | save_name = os.path.join(save_path, save_folder, experiment_identifier, descriptor+save_name[:-4]+'.h5') 430 | h5_writer(save_masks, save_name, group_root='data', group_names=save_groups) 431 | -------------------------------------------------------------------------------- /utils/harmonics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 5 11:28:53 2020 4 | 5 | @author: Nutzer 6 | """ 7 | 8 | import numpy as np 9 | import multiprocessing as mp 10 | 11 | from functools import partial 12 | from skimage import morphology 13 | from scipy.spatial import cKDTree, Delaunay 14 | from scipy.special import sph_harm 15 | from matplotlib import pyplot as plt 16 | from mpl_toolkits.mplot3d import Axes3D 17 | from dipy.core.geometry import sphere2cart, cart2sphere 18 | 19 | from utils.utils import print_timestamp 20 | 21 | 22 | 23 | def scatter_3d(coords1, coords2, coords3, cartesian=True): 24 | # (x, y, z) or (r, theta, phi) 25 | 26 | if not cartesian: 27 | coords1, coords2, coords3 = sphere2cart(coords1, coords2, coords3) 28 | 29 | fig = plt.figure() 30 | ax = Axes3D(fig) 31 | ax.scatter(coords1, coords2, coords3, depthshade=True) 32 | plt.show() 33 | 34 | 35 | 36 | def render_coeffs(coeffs, theta_phi_sampling_file, sh_order=5, radius=30): 37 | 38 | assert len(coeffs)==(sh_order+1)**2, 'SH order and number of coefficients do not match.' 39 | 40 | coeffs = np.array(coeffs) 41 | coeffs = coeffs/coeffs[0] 42 | coeffs *= radius 43 | 44 | img_shape = np.array(3*(radius*3,)) 45 | 46 | theta_phi_sampling = np.load(theta_phi_sampling_file) 47 | h2s = harmonics2sampling(sh_order, theta_phi_sampling) 48 | r_sampling = h2s.convert(coeffs[np.newaxis,:]) 49 | 50 | instance_mask = sampling2instance([img_shape//2,], r_sampling, theta_phi_sampling, img_shape, verbose=True) 51 | 52 | return instance_mask 53 | 54 | 55 | 56 | def instance2sampling(instances, theta_phi_sampling, bg_label=1, centroids=None, verbose=True): 57 | 58 | # Get labels 59 | labels = np.unique(instances) 60 | labels = np.array(list(set(labels)-set([bg_label]))) 61 | 62 | sampling = np.zeros((len(labels),theta_phi_sampling.shape[0])) 63 | 64 | if centroids is None: 65 | get_centroids = True 66 | centroids = np.zeros((len(labels),3)) 67 | else: 68 | get_centroids = False 69 | assert len(labels) == len(centroids), 'There needs to be a centroid for each label!' 
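    # For each instance below, the one-voxel boundary shell (erosion + XOR) is
    # converted to spherical coordinates around the centroid, and the radius per
    # sampling direction is averaged over the 3 nearest boundary voxels (k=3).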
70 | 71 | for num_label, label in enumerate(labels): 72 | 73 | if verbose: print_timestamp('Sampling instance {0}/{1}...', [num_label+1, len(labels)]) 74 | 75 | if np.count_nonzero(instances==label) < 9: 76 | print_timestamp('Skipping label {0} due to its tiny size.', [label]) 77 | continue 78 | 79 | # get inner boundary of the current instance 80 | instance_inner = morphology.binary_erosion(instances==label, selem=morphology.ball(1)) 81 | instance_inner = np.logical_xor(instance_inner, instances==label) 82 | 83 | # get binary instance mask 84 | x,y,z = np.where(instance_inner) 85 | if get_centroids: 86 | centroids[num_label,:] = [x.mean(),y.mean(),z.mean()] 87 | x -= int(centroids[num_label,0]) 88 | y -= int(centroids[num_label,1]) 89 | z -= int(centroids[num_label,2]) 90 | r,theta,phi = cart2sphere(x, y, z) 91 | 92 | # find closest sampling angles 93 | sampling_tree = cKDTree(np.array([theta,phi]).T) 94 | _, assignments = sampling_tree.query(theta_phi_sampling, k=3) 95 | 96 | # get sampling 97 | sampling[num_label,:] = np.mean(r[assignments], axis=1) 98 | 99 | return labels, centroids, sampling 100 | 101 | 102 | 103 | def sampling2instance(centroids, sampling, theta_phi_sampling, shape, verbose=True): 104 | 105 | instances = np.full(shape, 0, dtype=np.uint16) 106 | idx = np.reshape(np.indices(shape), (3,-1)) 107 | 108 | label = 0 109 | for centroid, r in zip(centroids, sampling): 110 | label += 1 111 | if verbose: print_timestamp('Reconstructing instance {0}/{1}...', [label, sampling.shape[0]]) 112 | 113 | x,y,z = sphere2cart(r, theta_phi_sampling[:,0], theta_phi_sampling[:,1]) 114 | x += centroid[0] 115 | y += centroid[1] 116 | z += centroid[2] 117 | 118 | delaunay_tri = Delaunay(np.array([x,y,z]).T) 119 | 120 | voxel_idx = delaunay_tri.find_simplex(idx.T).reshape(shape)>=0 121 | 122 | instances[voxel_idx] = label 123 | 124 | return instances 125 | 126 | 127 | 128 | 129 | class sampling2harmonics(): 130 | 131 | def __init__(self, sh_order, theta_phi_sampling, lb_lambda=0.006): 132 | super(sampling2harmonics, self).__init__() 133 | self.sh_order = sh_order 134 | self.theta_phi_sampling = theta_phi_sampling 135 | self.lb_lambda = lb_lambda 136 | self.num_samples = len(theta_phi_sampling) 137 | self.num_coefficients = int((self.sh_order+1)**2) 138 | 139 | b = np.zeros((self.num_samples, self.num_coefficients)) 140 | l = np.zeros((self.num_coefficients, self.num_coefficients)) 141 | 142 | for num_sample in range(self.num_samples): 143 | num_coefficient = 0 144 | for num_order in range(self.sh_order+1): 145 | for num_degree in range(-num_order, num_order+1): 146 | 147 | theta = theta_phi_sampling[num_sample][0] 148 | phi = theta_phi_sampling[num_sample][1] 149 | 150 | y = sph_harm(np.abs(num_degree), num_order, phi, theta) 151 | 152 | if num_degree < 0: 153 | b[num_sample, num_coefficient] = np.real(y) * np.sqrt(2) 154 | elif num_degree == 0: 155 | b[num_sample, num_coefficient] = np.real(y) 156 | elif num_degree > 0: 157 | b[num_sample, num_coefficient] = np.imag(y) * np.sqrt(2) 158 | 159 | l[num_coefficient, num_coefficient] = self.lb_lambda * num_order ** 2 * (num_order + 1) ** 2 160 | num_coefficient += 1 161 | 162 | b_inv = np.linalg.pinv(np.matmul(b.transpose(), b) + l) 163 | self.convert_mat = np.matmul(b_inv, b.transpose()).transpose() 164 | 
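    # convert() projects radius samplings onto the spherical harmonics basis via
    # the regularised pseudo-inverse assembled above (the diagonal penalty l
    # grows with the SH order, damping high-frequency coefficients).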
165 | def convert(self, r_sampling): 166 | converted_samples = np.zeros((r_sampling.shape[0],self.num_coefficients)) 167 | for num_sample, r_sample in enumerate(r_sampling): 168 | r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat) 169 | converted_samples[num_sample] = np.squeeze(r_converted) 170 | return converted_samples 171 | 172 | 173 | 174 | 175 | class harmonics2sampling(): 176 | 177 | def __init__(self, sh_order, theta_phi_sampling): 178 | super(harmonics2sampling, self).__init__() 179 | self.sh_order = sh_order 180 | self.theta_phi_sampling = theta_phi_sampling 181 | self.num_samples = len(theta_phi_sampling) 182 | self.num_coefficients = int((self.sh_order+1)**2) 183 | 184 | convert_mat = np.zeros((self.num_coefficients, self.num_samples)) 185 | 186 | for num_sample in range(self.num_samples): 187 | num_coefficient = 0 188 | for num_order in range(self.sh_order+1): 189 | for num_degree in range(-num_order, num_order+1): 190 | 191 | theta = theta_phi_sampling[num_sample][0] 192 | phi = theta_phi_sampling[num_sample][1] 193 | 194 | y = sph_harm(np.abs(num_degree), num_order, phi, theta) 195 | 196 | if num_degree < 0: 197 | convert_mat[num_coefficient, num_sample] = np.real(y) * np.sqrt(2) 198 | elif num_degree == 0: 199 | convert_mat[num_coefficient, num_sample] = np.real(y) 200 | elif num_degree > 0: 201 | convert_mat[num_coefficient, num_sample] = np.imag(y) * np.sqrt(2) 202 | 203 | num_coefficient += 1 204 | 205 | self.convert_mat = convert_mat 206 | 207 | def convert(self, r_harmonic): 208 | converted_harmonics = np.zeros((r_harmonic.shape[0],self.theta_phi_sampling.shape[0])) 209 | for num_sample, r_sample in enumerate(r_harmonic): 210 | r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat) 211 | converted_harmonics[num_sample] = np.squeeze(r_converted) 212 | return converted_harmonics 213 | 214 | 215 | 216 | 217 | def sphere_intersection_poolhelper(instance_indices, point_coords=None, radii=None): 218 | 219 | # get radii, positions and distance 220 | r1 = radii[instance_indices[0]] 221 | r2 = radii[instance_indices[1]] 222 | p1 = point_coords[instance_indices[0]] 223 | p2 = point_coords[instance_indices[1]] 224 | d = np.sqrt(np.sum((np.array(p1)-np.array(p2))**2)) 225 | 226 | # calculate individual volumes 227 | vol1 = 4/3*np.pi*r1**3 228 | vol2 = 4/3*np.pi*r2**3 229 | 230 | # calculate intersection of volumes 231 | 232 | # Smaller sphere inside the bigger sphere 233 | if d <= np.abs(r1-r2): 234 | intersect_vol = 4/3*np.pi*np.minimum(r1,r2)**3 235 | # No intersection at all 236 | elif d > r1+r2: 237 | intersect_vol = 0 238 | # Partially intersecting spheres 239 | else: 240 | intersect_vol = np.pi * (r1 + r2 - d)**2 * (d**2 + 2*d*r2 - 3*r2**2 + 2*d*r1 + 6*r2*r1 - 3*r1**2) / (12*d) 241 | 242 | return (intersect_vol, vol1, vol2) 243 | 244 | 245 | 246 | def harmonic_non_max_suppression(point_coords, point_probs, shape_descriptors, overlap_thresh=0.5, dim_scale=(1,1,1), num_kernel=1, **kwargs): 247 | 248 | if len(point_coords)>3000: 249 | 250 | print_timestamp('Too many points, aborting NMS') 251 | nms_coords = point_coords[:3000] 252 | nms_probs = point_probs[:3000] 253 | nms_shapes = shape_descriptors[:3000] 254 | 255 | elif len(point_coords)>1: 256 | 257 | dim_scale = [d/np.min(dim_scale) for d in dim_scale] 258 | point_coords_uniform = [] 259 | for point_coord in point_coords: 260 | point_coords_uniform.append(tuple([p*d for p,d in zip(point_coord,dim_scale)])) 261 | 262 | # calculate upper and lower bounding radii 263 | r_upper = [r.max() for r in shape_descriptors] 264 | r_lower = [r.min() for r in shape_descriptors] 265 | 266 | # Calculate intersections of lower and upper spheres 267 | #instance_indices = list(itertools.combinations(range(len(point_coords)), r=2)) 268 | r_max = np.max(r_upper) 
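        # Pre-filter candidate pairs by centre distance: two spheres with radii
        # bounded by r_max can only overlap if their centres are closer than 2*r_max.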
269 | instance_indices = [ (i, j) for i in range(len(point_coords)) 270 | for j in range(i+1, len(point_coords)) 271 | if np.sqrt(np.sum((np.array(point_coords[i])-np.array(point_coords[j]))**2)) < r_max*2 ] 272 | with mp.Pool(processes=num_kernel) as p: 273 | vol_upper = p.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_upper), instance_indices) 274 | vol_lower = p.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_lower), instance_indices) 275 | 276 | instances_keep = np.ones((len(point_coords),), dtype=bool) 277 | 278 | # calculate overlap measure 279 | for inst_idx, v_up, v_low in zip(instance_indices, vol_upper, vol_lower): 280 | 281 | # average intersection with smaller sphere 282 | overlap_measure_up = v_up[0] / np.minimum(v_up[1],v_up[2]) 283 | overlap_measure_low = v_low[0] / np.minimum(v_low[1],v_low[2]) 284 | overlap_measure = (overlap_measure_up+overlap_measure_low)/2 285 | 286 | if overlap_measure > overlap_thresh: 287 | # Get the indices of the least and the most probable instance 288 | inst_min = inst_idx[np.argmin([point_probs[i] for i in inst_idx])] 289 | inst_max = inst_idx[np.argmax([point_probs[i] for i in inst_idx])] 290 | 291 | # If there already was an instance with higher probability, don't add the current "winner" 292 | if instances_keep[inst_max] == 0: 293 | # Mark both as excluded 294 | instances_keep[inst_min] = 0 295 | instances_keep[inst_max] = 0 296 | else: 297 | # Exclude the loser 298 | instances_keep[inst_min] = 0 299 | #instances_keep[inst_max] = 1 300 | 301 | # Mark remaining indices for keeping 302 | #instances_keep = instances_keep != -1 303 | 304 | nms_coords = [point_coords[i] for i,v in enumerate(instances_keep) if v] 305 | nms_probs = [point_probs[i] for i,v in enumerate(instances_keep) if v] 306 | nms_shapes = [shape_descriptors[i] for i,v in enumerate(instances_keep) if v] 307 | 308 | else: 309 | nms_coords = point_coords 310 | nms_probs = point_probs 311 | nms_shapes = shape_descriptors 312 | 313 | return nms_coords, nms_shapes, nms_probs 314 | 315 | 316 | -------------------------------------------------------------------------------- /utils/jupyter_widgets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import os 6 | import ipywidgets as wg 7 | from IPython.display import display 8 | style={'description_width': 'initial'} 9 | 10 | 11 | 12 | def get_pipelin_widget(): 13 | 14 | pipeline = wg.Dropdown( 15 | options=['DiffusionModel3d', 'DiffusionModel2d'], 16 | description='Select Pipeline:', style=style) 17 | 18 | return pipeline 19 | 20 | 21 | def get_execution_widgets(): 22 | 23 | wg_execute = wg.Checkbox(description='Execute Now!', value=True, style=style) 24 | wg_arguments = wg.Checkbox(description='Get Command Line Arguments', value=True, style=style) 25 | 26 | return [wg_execute, wg_arguments] 27 | 28 | 29 | def get_synthesizer_widget(): 30 | 31 | pipeline = wg.Dropdown( 32 | options=['SyntheticTRIC', 'SyntheticCE', 'Synthetic2DGOWT1', 'Synthetic2DHeLa', 'SyntheticMeristem'], 33 | description='Select Synthesizer:', style=style) 34 | 35 | return pipeline 36 | 37 | 38 | 39 | def get_parameter_widgets(param_dict): 40 | 41 | param_names = [] 42 | widget_list = [] 43 | test_related = [] 44 | 45 | for key in param_dict.keys(): 46 | ### Script Parameter 47 | 48 | if key == 'output_path': 49 | widget_list.append(wg.Text(description='Output Path:', value=param_dict[key], style=style)) 50 | 
param_names.append('--'+key) 51 | test_related.append(True) 52 | if key == 'log_path': 53 | widget_list.append(wg.Text(description='Log Path:', value=param_dict[key], style=style)) 54 | param_names.append('--'+key) 55 | test_related.append(False) 56 | if key == 'gpus': 57 | widget_list.append(wg.IntSlider(description = 'Use GPU:', min=0, max=1, value=param_dict[key], style=style)) 58 | param_names.append('--'+key) 59 | test_related.append(True) 60 | if key == 'no_resume': 61 | widget_list.append(wg.Checkbox(description='Resume:', value=not param_dict[key], style=style)) 62 | param_names.append('--'+key) 63 | test_related.append(False) 64 | if key == 'pretrained': 65 | widget_list.append(wg.Text(description='Path To The Pretrained Model:', value=param_dict[key], style=style)) 66 | param_names.append('--'+key) 67 | test_related.append(False) 68 | if key == 'augmentations': 69 | widget_list.append(wg.Text(description='Augmentation Dictionary File:', value=param_dict[key], style=style)) 70 | param_names.append('--'+key) 71 | test_related.append(False) 72 | if key == 'epochs': 73 | widget_list.append(wg.BoundedIntText(description='Epochs:', min=1, max=10000, value=param_dict[key], style=style)) 74 | param_names.append('--'+key) 75 | test_related.append(False) 76 | 77 | 78 | ### Network Parameter 79 | 80 | if key == 'backbone': 81 | widget_list.append(wg.Dropdown(description='Network Architecture:', options=['UNet3D_PixelShuffle_inject', 'UNet2D_PixelShuffle_inject'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 82 | param_names.append('--'+key) 83 | test_related.append(True) 84 | if key == 'in_channels': 85 | widget_list.append(wg.BoundedIntText(description='Input Channels:', value=param_dict[key], min=1, max=10000, style=style)) 86 | param_names.append('--'+key) 87 | test_related.append(True) 88 | if key == 'out_channels': 89 | widget_list.append(wg.BoundedIntText(description='Output Channels:', value=param_dict[key], min=1, max=10000, style=style)) 90 | param_names.append('--'+key) 91 | test_related.append(True) 92 | if key == 'feat_channels': 93 | widget_list.append(wg.BoundedIntText(description='Feature Channels:', value=param_dict[key], min=2, max=10000, style=style)) 94 | param_names.append('--'+key) 95 | test_related.append(True) 96 | if key == 'patch_size': 97 | if not isinstance(param_dict[key], str): 98 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 99 | widget_list.append(wg.Text(description='Patch Size (z,y,x):', value=param_dict[key], style=style)) 100 | param_names.append('--'+key) 101 | test_related.append(True) 102 | if key == 'out_activation': 103 | widget_list.append(wg.Dropdown(description='Output Activation:', options=['tanh', 'sigmoid', 'hardtanh', 'relu', 'leakyrelu', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 104 | param_names.append('--'+key) 105 | test_related.append(True) 106 | if key == 'layer_norm': 107 | widget_list.append(wg.Dropdown(description='Layer Normalization:', options=['instance', 'batch', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 108 | param_names.append('--'+key) 109 | test_related.append(True) 110 | if key == 't_channels': 111 | widget_list.append(wg.BoundedIntText(description='T Channels:', value=param_dict[key], min=1, max=10000, style=style)) 112 | param_names.append('--'+key) 113 | test_related.append(True) 114 | 
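        # Each recognised key contributes three parallel entries: its command-line
        # flag (param_names), its widget (widget_list), and whether the parameter
        # is also relevant at test time (test_related).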
115 | ### Data Parameter 116 | 117 | if key == 'data_norm': 118 | widget_list.append(wg.Dropdown(description='Data Normalization:', options=['percentile', 'minmax', 'meanstd', 'minmax_shifted', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 119 | param_names.append('--'+key) 120 | test_related.append(True) 121 | if key == 'data_root': 122 | widget_list.append(wg.Text(value=param_dict[key], description='Data Root:', style=style)) 123 | param_names.append('--'+key) 124 | test_related.append(True) 125 | if key == 'train_list': 126 | widget_list.append(wg.Text(value=param_dict[key], description='Train List:', style=style)) 127 | param_names.append('--'+key) 128 | test_related.append(True) 129 | if key == 'test_list': 130 | widget_list.append(wg.Text(value=param_dict[key], description='Test List:', style=style)) 131 | param_names.append('--'+key) 132 | test_related.append(True) 133 | if key == 'val_list': 134 | widget_list.append(wg.Text(value=param_dict[key], description='Validation List:', style=style)) 135 | param_names.append('--'+key) 136 | test_related.append(True) 137 | if key == 'image_groups': 138 | if not isinstance(param_dict[key], str): 139 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 140 | widget_list.append(wg.Text(description='Image Groups:', value=param_dict[key], style=style)) 141 | param_names.append('--'+key) 142 | test_related.append(True) 143 | if key == 'mask_groups': 144 | if not isinstance(param_dict[key], str): 145 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 146 | widget_list.append(wg.Text(description='Mask Groups:', value=param_dict[key], style=style)) 147 | param_names.append('--'+key) 148 | test_related.append(True) 149 | if key == 'dist_handling': 150 | widget_list.append(wg.Dropdown(description='Distance Handling:', options=['float', 'bool', 'bool_inv', 'exp', 'tanh', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 151 | param_names.append('--'+key) 152 | test_related.append(True) 153 | if key == 'dist_scaling': 154 | if not isinstance(param_dict[key], str): 155 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 156 | widget_list.append(wg.Text(description='Distance Scaling:', value=param_dict[key], style=style)) 157 | param_names.append('--'+key) 158 | test_related.append(True) 159 | if key == 'seed_handling': 160 | widget_list.append(wg.Dropdown(description='Seed Handling:', options=['float', 'bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 161 | param_names.append('--'+key) 162 | test_related.append(True) 163 | if key == 'boundary_handling': 164 | widget_list.append(wg.Dropdown(description='Boundary Handling:', options=['bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 165 | param_names.append('--'+key) 166 | test_related.append(True) 167 | if key == 'instance_handling': 168 | widget_list.append(wg.Dropdown(description='Instance Handling:', options=['bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 169 | param_names.append('--'+key) 170 | test_related.append(True) 171 | if key == 'strides': 172 | if not isinstance(param_dict[key], str): 173 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 174 | widget_list.append(wg.Text(description='Strides:', value=param_dict[key], style=style)) 175 | param_names.append('--'+key) 176 | test_related.append(True) 177 | if key == 'sh_order': 178 | widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=0, max=10000, style=style)) 179 | param_names.append('--'+key) 180 | test_related.append(True) 181 | if key == 
'mean_low': 182 | widget_list.append(wg.BoundedFloatText(description='Sampling Mean Weight Low:', value=param_dict[key], min=0, max=1000000, style=style)) 183 | param_names.append('--'+key) 184 | test_related.append(True) 185 | if key == 'mean_high': 186 | widget_list.append(wg.BoundedFloatText(description='Sampling Mean Weight High:', value=param_dict[key], min=0, max=1000000, style=style)) 187 | param_names.append('--'+key) 188 | test_related.append(True) 189 | if key == 'var_high': 190 | widget_list.append(wg.BoundedFloatText(description='Sampling Variance Weight High:', value=param_dict[key], min=0, max=1000000, style=style)) 191 | param_names.append('--'+key) 192 | test_related.append(True) 193 | if key == 'variance_levels': 194 | widget_list.append(wg.BoundedIntText(description='Sampling Variance Levels:', value=param_dict[key], min=0, max=10000, style=style)) 195 | param_names.append('--'+key) 196 | test_related.append(True) 197 | if key == 'strategy': 198 | widget_list.append(wg.Dropdown(description='Sampling Variance Strategy:', options=['random', 'structured'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 199 | param_names.append('--'+key) 200 | test_related.append(True) 201 | if key == 'image_noise_channel': 202 | widget_list.append(wg.BoundedIntText(description='Image Noise Channel:', value=param_dict[key], min=-5, max=5, style=style)) 203 | param_names.append('--'+key) 204 | test_related.append(True) 205 | if key == 'mask_noise_channel': 206 | widget_list.append(wg.BoundedIntText(description='Mask Noise Channel:', value=param_dict[key], min=-5, max=5, style=style)) 207 | param_names.append('--'+key) 208 | test_related.append(True) 209 | if key == 'noise_type': 210 | widget_list.append(wg.Dropdown(description='Noise Type:', options=['gaussian', 'rayleigh', 'laplace'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 211 | param_names.append('--'+key) 212 | test_related.append(True) 213 | 214 | ### Training Parameter 215 | 216 | if key == 'samples_per_epoch': 217 | widget_list.append(wg.BoundedIntText(description='Samples Per Epoch:', value=param_dict[key], min=-1, max=1000000, style=style)) 218 | param_names.append('--'+key) 219 | test_related.append(True) 220 | if key == 'batch_size': 221 | widget_list.append(wg.BoundedIntText(description='Batch Size:', value=param_dict[key], min=1, max=1000000, style=style)) 222 | param_names.append('--'+key) 223 | test_related.append(True) 224 | if key == 'learning_rate': 225 | widget_list.append(wg.BoundedFloatText(description='Learning Rate:', value=param_dict[key], min=0, max=1000000, style=style)) 226 | param_names.append('--'+key) 227 | test_related.append(True) 228 | if key == 'background_weight': 229 | widget_list.append(wg.BoundedFloatText(description='Background Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 230 | param_names.append('--'+key) 231 | test_related.append(True) 232 | if key == 'seed_weight': 233 | widget_list.append(wg.BoundedFloatText(description='Seed Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 234 | param_names.append('--'+key) 235 | test_related.append(True) 236 | if key == 'boundary_weight': 237 | widget_list.append(wg.BoundedFloatText(description='Boundary Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 238 | param_names.append('--'+key) 239 | test_related.append(True) 240 | if key == 'flow_weight': 241 | widget_list.append(wg.BoundedFloatText(description='Flow Weight:', value=param_dict[key], min=0, 
max=1000000, style=style)) 242 | param_names.append('--'+key) 243 | test_related.append(True) 244 | if key == 'centroid_weight': 245 | widget_list.append(wg.BoundedFloatText(description='Centroid Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 246 | param_names.append('--'+key) 247 | test_related.append(True) 248 | if key == 'encoding_weight': 249 | widget_list.append(wg.BoundedFloatText(description='Encoding Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 250 | param_names.append('--'+key) 251 | test_related.append(True) 252 | if key == 'robustness_weight': 253 | widget_list.append(wg.BoundedFloatText(description='Robustness Weight:', value=param_dict[key], min=0, max=1000000, style=style)) 254 | param_names.append('--'+key) 255 | test_related.append(True) 256 | if key == 'variance_interval': 257 | widget_list.append(wg.BoundedIntText(description='Sampling Variance Interval:', value=param_dict[key], min=0, max=10000, style=style)) 258 | param_names.append('--'+key) 259 | test_related.append(True) 260 | if key == 'ada_update_period': 261 | widget_list.append(wg.BoundedIntText(description='ADA Update Period:', value=param_dict[key], min=0, max=10000, style=style)) 262 | param_names.append('--'+key) 263 | test_related.append(True) 264 | if key == 'ada_update': 265 | widget_list.append(wg.BoundedFloatText(description='ADA Update Step:', value=param_dict[key], min=0, max=1000000, style=style)) 266 | param_names.append('--'+key) 267 | test_related.append(True) 268 | if key == 'ada_target': 269 | widget_list.append(wg.BoundedFloatText(description='ADA Target:', value=param_dict[key], min=0, max=1000000, style=style)) 270 | param_names.append('--'+key) 271 | test_related.append(True) 272 | if key == 'num_samples': 273 | widget_list.append(wg.BoundedIntText(description='Number Of Samples:', value=param_dict[key], min=0, max=10000, style=style)) 274 | param_names.append('--'+key) 275 | test_related.append(True) 276 | 277 | # diffusion parameter 278 | if key == 'num_timesteps': 279 | widget_list.append(wg.BoundedIntText(description='Number Of Timesteps:', value=param_dict[key], min=0, max=1000, style=style)) 280 | param_names.append('--'+key) 281 | test_related.append(True) 282 | if key == 'diffusion_schedule': 283 | widget_list.append(wg.Dropdown(description='Diffusion Schedule:', options=['cosine', 'linear', 'quadratic', 'sigmoid'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 284 | param_names.append('--'+key) 285 | test_related.append(True) 286 | return param_names, widget_list, test_related 287 | 288 | 289 | 290 | def get_apply_parameter_widgets(param_dict): 291 | 292 | param_names = [] 293 | widget_list = [] 294 | 295 | for key in param_dict.keys(): 296 | 297 | 298 | ### Script Parameter 299 | 300 | if key == 'output_path': 301 | widget_list.append(wg.Text(description='Output Path:', value=param_dict[key], style=style)) 302 | param_names.append('--'+key) 303 | elif key == 'ckpt_path': 304 | widget_list.append(wg.Text(description='Checkpoint Path:', value=param_dict[key], style=style)) 305 | param_names.append('--'+key) 306 | elif key == 'gpus': 307 | widget_list.append(wg.IntSlider(description = 'Use GPU:', min=0, max=1, value=param_dict[key], style=style)) 308 | param_names.append('--'+key) 309 | elif key == 'distributed_backend': 310 | widget_list.append(wg.Dropdown(description='Distributed Backend:', options=['dp', 'ddp', 'ddp2'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 311 | 
param_names.append('--'+key) 312 | elif key == 'overlap': 313 | if not isinstance(param_dict[key], str): 314 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 315 | widget_list.append(wg.Text(description='Overlap (z,y,x):', value=param_dict[key], style=style)) 316 | param_names.append('--'+key) 317 | elif key == 'crop': 318 | if not isinstance(param_dict[key], str): 319 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 320 | widget_list.append(wg.Text(description='Crop (z,y,x):', value=param_dict[key], style=style)) 321 | param_names.append('--'+key) 322 | elif key == 'input_batch': 323 | widget_list.append(wg.Dropdown(description='Input Batch:', options=['image', 'mask'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 324 | param_names.append('--'+key) 325 | elif key == 'clip': 326 | if not isinstance(param_dict[key], str): 327 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 328 | widget_list.append(wg.Text(description='Clip (min, max):', value=param_dict[key], style=style)) 329 | param_names.append('--'+key) 330 | elif key == 'num_files': 331 | widget_list.append(wg.BoundedIntText(description='Number of Files:', value=param_dict[key], min=-1, max=10000, style=style)) 332 | param_names.append('--'+key) 333 | elif key == 'add_noise_channel': 334 | widget_list.append(wg.BoundedIntText(description='Noise Channel:', value=param_dict[key], min=-2, max=10000, style=style)) 335 | param_names.append('--'+key) 336 | elif key == 'theta_phi_sampling': 337 | widget_list.append(wg.Text(description='Angular Sampling File Path:', value=param_dict[key], style=style)) 338 | param_names.append('--'+key) 339 | 340 | ### Network Parameter 341 | elif key == 'out_channels': 342 | widget_list.append(wg.BoundedIntText(description='Output Channels:', value=param_dict[key], min=1, max=10000, style=style)) 343 | param_names.append('--'+key) 344 | elif key == 'patch_size': 345 | if not isinstance(param_dict[key], str): 346 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 347 | widget_list.append(wg.Text(description='Patch Size (z,y,x):', value=param_dict[key], style=style)) 348 | param_names.append('--'+key) 349 | elif key == 'resolution_weights': 350 | if not isinstance(param_dict[key], str): 351 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 352 | widget_list.append(wg.Text(description='Prediction Weights at Each Resolution:', value=param_dict[key], style=style)) 353 | param_names.append('--'+key) 354 | elif key == 'centroid_thresh': 355 | widget_list.append(wg.BoundedFloatText(description='Centroid Threshold:', value=param_dict[key], min=0, max=100, style=style)) 356 | param_names.append('--'+key) 357 | elif key == 'minsize': 358 | widget_list.append(wg.BoundedIntText(description='The Minimum Cell Size:', value=param_dict[key], min=0, max=100, style=style)) 359 | param_names.append('--'+key) 360 | elif key == 'maxsize': 361 | widget_list.append(wg.BoundedIntText(description='The Maximum Cell Size:', value=param_dict[key], min=0, max=100, style=style)) 362 | param_names.append('--'+key) 363 | elif key == 'use_watershed': 364 | widget_list.append(wg.Checkbox(description='Use Watershed:', value=param_dict[key], style=style)) 365 | param_names.append('--'+key) 366 | elif key == 'use_sizefilter': 367 | widget_list.append(wg.Checkbox(description='Use Size Filter:', value=param_dict[key], style=style)) 368 | param_names.append('--'+key) 369 | elif key == 'use_nms': 370 | widget_list.append(wg.Checkbox(description='Use Non-Maximum Suppression:', value=param_dict[key], 
style=style)) 371 | param_names.append('--'+key) 372 | ### Data Parameter 373 | elif key == 'data_root': 374 | widget_list.append(wg.Text(value=param_dict[key], description='Data Root:', style=style)) 375 | param_names.append('--'+key) 376 | elif key == 'test_list': 377 | widget_list.append(wg.Text(value=param_dict[key], description='Test List:', style=style)) 378 | param_names.append('--'+key) 379 | elif key == 'image_groups': 380 | if not isinstance(param_dict[key], str): 381 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 382 | widget_list.append(wg.Text(description='Image Groups:', value=param_dict[key], style=style)) 383 | param_names.append('--'+key) 384 | elif key == 'mask_groups': 385 | if not isinstance(param_dict[key], str): 386 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 387 | widget_list.append(wg.Text(description='Mask Groups:', value=param_dict[key], style=style)) 388 | param_names.append('--'+key) 389 | elif key == 'sh_order': 390 | widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=0, max=10000, style=style)) 391 | param_names.append('--'+key) 392 | 393 | # diffusion parameter 394 | elif key == 'timesteps_start': 395 | widget_list.append(wg.BoundedIntText(description='Timestep to Start the Reverse Process:', value=param_dict[key], min=0, max=1000, style=style)) 396 | param_names.append('--'+key) 397 | elif key == 'timesteps_save': 398 | widget_list.append(wg.BoundedIntText(description='Number of Timesteps Between Saves:', value=param_dict[key], min=0, max=1000, style=style)) 399 | param_names.append('--'+key) 400 | elif key == 'timesteps_step': 401 | widget_list.append(wg.BoundedIntText(description='Timesteps Skipped Between Iterations:', value=param_dict[key], min=0, max=1000, style=style)) 402 | param_names.append('--'+key) 403 | elif key == 'blur_sigma': 404 | widget_list.append(wg.BoundedIntText(description='Sigma for Gaussian Blurring of Inputs:', value=param_dict[key], min=0, max=5, style=style)) 405 | param_names.append('--'+key) 406 | elif key == 'num_timesteps': 407 | widget_list.append(wg.BoundedIntText(description='Total Number Of Training Timesteps:', value=param_dict[key], min=0, max=1000, style=style)) 408 | param_names.append('--'+key) 409 | elif key == 'diffusion_schedule': 410 | widget_list.append(wg.Dropdown(description='Diffusion Schedule:', options=['cosine', 'linear', 'quadratic', 'sigmoid'], value=param_dict[key], layout={'width': 'max-content'}, style=style)) 411 | param_names.append('--'+key) 412 | 413 | return param_names, widget_list 414 | 415 | 416 | def get_sim_parameter_widgets(param_dict): 417 | 418 | param_names = [] 419 | widget_list = [] 420 | 421 | for key in param_dict.keys(): 422 | if key == 'save_path': 423 | widget_list.append(wg.Text(description='Save Path:', value=param_dict[key], style=style)) 424 | param_names.append('--'+key) 425 | elif key == 'experiment_name': 426 | widget_list.append(wg.Text(description='Experiment Name:', value=param_dict[key], style=style)) 427 | param_names.append('--'+key) 428 | elif key == 'num_imgs' or key == 'img_count': 429 | widget_list.append(wg.BoundedIntText(description='Number of Images:', value=param_dict[key], min=1, max=1000, style=style)) 430 | param_names.append('--'+key) 431 | 432 | # Nuclei masks 433 | elif key == 'img_shape': 434 | if not isinstance(param_dict[key], str): 435 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 436 | widget_list.append(wg.Text(description='Image Shape:', value=param_dict[key], style=style)) 437 | 
param_names.append('--'+key) 438 | elif key == 'max_radius': 439 | widget_list.append(wg.BoundedIntText(description='Maximum Radius:', value=param_dict[key], min=1, max=1000, style=style)) 440 | param_names.append('--'+key) 441 | elif key == 'min_radius': 442 | widget_list.append(wg.BoundedIntText(description='Minimum Radius:', value=param_dict[key], min=1, max=100, style=style)) 443 | param_names.append('--'+key) 444 | elif key == 'radius_range': 445 | widget_list.append(wg.BoundedIntText(description='Radius Range:', value=param_dict[key], min=-100, max=100, style=style)) 446 | param_names.append('--'+key) 447 | elif key == 'sh_order': 448 | widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=1, max=100, style=style)) 449 | param_names.append('--'+key) 450 | elif key == 'smooth_std': 451 | widget_list.append(wg.BoundedFloatText(description='Smoothing Std:', value=param_dict[key], min=0, max=100, style=style)) 452 | param_names.append('--'+key) 453 | elif key == 'noise_std': 454 | widget_list.append(wg.BoundedFloatText(description='Noise Std:', value=param_dict[key], min=0, max=100, style=style)) 455 | param_names.append('--'+key) 456 | elif key == 'noise_mean': 457 | widget_list.append(wg.BoundedFloatText(description='Noise Mean:', value=param_dict[key], min=0, max=100, style=style)) 458 | param_names.append('--'+key) 459 | elif key == 'position_std': 460 | widget_list.append(wg.BoundedFloatText(description='Position Std:', value=param_dict[key], min=0, max=100, style=style)) 461 | param_names.append('--'+key) 462 | elif key == 'num_cells': 463 | widget_list.append(wg.BoundedIntText(description='Number of Cells:', value=param_dict[key], min=1, max=1000, style=style)) 464 | param_names.append('--'+key) 465 | elif key == 'num_cells_range': 466 | widget_list.append(wg.BoundedIntText(description='Number of Cells Range:', value=param_dict[key], min=1, max=1000, style=style)) 467 | param_names.append('--'+key) 468 | elif key == 'circularity': 469 | widget_list.append(wg.BoundedFloatText(description='Circularity:', value=param_dict[key], min=0, max=100, style=style)) 470 | param_names.append('--'+key) 471 | elif key == 'generate_images': 472 | widget_list.append(wg.Checkbox(description='Generate Images:', value=False, style=style)) 473 | param_names.append('--'+key) 474 | elif key == 'theta_phi_sampling_file': 475 | widget_list.append(wg.Text(value=param_dict[key], description='Theta Phi Sampling File:', style=style)) 476 | param_names.append('--'+key) 477 | elif key == 'cell_elongation': 478 | widget_list.append(wg.BoundedFloatText(description='Cell Elongation:', value=param_dict[key], min=0, max=100, style=style)) 479 | param_names.append('--'+key) 480 | elif key == 'z_anisotropy': 481 | widget_list.append(wg.BoundedFloatText(description='Z Anisotropy:', value=param_dict[key], min=0, max=100, style=style)) 482 | param_names.append('--'+key) 483 | elif key == 'irregularity_extend': 484 | widget_list.append(wg.BoundedFloatText(description='Irregularity Extent:', value=param_dict[key], min=0, max=100, style=style)) 485 | param_names.append('--'+key) 486 | 487 | # Membrane masks 488 | elif key == 'gridsize': 489 | if not isinstance(param_dict[key], str): 490 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]]) 491 | widget_list.append(wg.Text(description='Grid Size:', value=param_dict[key], style=style)) 492 | param_names.append('--'+key) 493 | elif key == 'distance_weight': 494 | widget_list.append(wg.BoundedFloatText(description='Distance Weight
:', value=param_dict[key], min=0, max=10, style=style)) 495 | param_names.append('--'+key) 496 | elif key == 'morph_radius': 497 | widget_list.append(wg.BoundedIntText(description='Morph Radius:', value=param_dict[key], min=1, max=1000, style=style)) 498 | param_names.append('--'+key) 499 | elif key == 'weights': 500 | widget_list.append(wg.Text(description='Cell Weights:', value=param_dict[key], style=style)) 501 | param_names.append('--'+key) 502 | elif key == 'cell_density': 503 | widget_list.append(wg.BoundedFloatText(description='Cell Density:', value=param_dict[key], min=0, max=10, style=style)) 504 | param_names.append('--'+key) 505 | elif key == 'cell_density_decay': 506 | widget_list.append(wg.BoundedFloatText(description='Cell Density Decay:', value=param_dict[key], min=0, max=10, style=style)) 507 | param_names.append('--'+key) 508 | elif key == 'cell_position_smoothness': 509 | widget_list.append(wg.BoundedIntText(description='Cell Position Smoothness:', value=param_dict[key], min=1, max=1000, style=style)) 510 | param_names.append('--'+key) 511 | elif key == 'ring_density': 512 | widget_list.append(wg.BoundedFloatText(description='Ring Density:', value=param_dict[key], min=0, max=10, style=style)) 513 | param_names.append('--'+key) 514 | elif key == 'ring_density_decay': 515 | widget_list.append(wg.BoundedFloatText(description='Ring Density Decay:', value=param_dict[key], min=0, max=10, style=style)) 516 | param_names.append('--'+key) 517 | elif key == 'angular_sampling_file': 518 | widget_list.append(wg.Text(description='Angular Sampling File:', value=param_dict[key], style=style)) 519 | param_names.append('--'+key) 520 | elif key == 'specimen_sampling_file': 521 | widget_list.append(wg.Text(description='Specimen Sampling File:', value=param_dict[key], style=style)) 522 | param_names.append('--'+key) 523 | return param_names, widget_list 524 | 525 | -------------------------------------------------------------------------------- /utils/synthetic_cell_membrane_masks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from skimage import morphology, measure, io 6 | from scipy.ndimage import distance_transform_edt 7 | from scipy.spatial import cKDTree 8 | from sklearn.decomposition import PCA 9 | from sklearn.cluster import AgglomerativeClustering 10 | 11 | from utils.h5_converter import h5_writer 12 | from utils.utils import print_timestamp 13 | 14 | 15 | 16 | def generate_data(syn_class, save_path='Segmentations_h5', experiment_name='synthetic_data', img_count=100, param_dict={}): 17 | 18 | synthesizer = syn_class(**param_dict) 19 | os.makedirs(save_path, exist_ok=True) 20 | 21 | for num_img in range(img_count): 22 | 23 | print_timestamp('_'*20) 24 | print_timestamp('Generating mask {0}/{1}', (num_img+1, img_count)) 25 | 26 | # Generate a new mask 27 | synthesizer.generate_instances() 28 | 29 | # Get and save the instance, boundary, centroid and distance masks 30 | print_timestamp('Saving...') 31 | instance_mask = synthesizer.get_instance_mask().astype(np.uint16) 32 | instance_mask = np.transpose(instance_mask, [2,1,0]) 33 | boundary_mask = synthesizer.get_boundary_mask().astype(np.uint8) 34 | boundary_mask = np.transpose(boundary_mask, [2,1,0]) 35 | distance_mask = synthesizer.get_distance_mask().astype(np.float32) 36 | distance_mask = np.transpose(distance_mask, [2,1,0]) 37 | centroid_mask = synthesizer.get_centroid_mask().astype(np.uint8) 38 | centroid_mask = 
np.transpose(centroid_mask, [2,1,0]) 39 | 40 | save_name = os.path.join(save_path, experiment_name+'_'+str(num_img)+'.h5') 41 | h5_writer([instance_mask, boundary_mask, distance_mask, centroid_mask], save_name, group_names=['instances', 'boundary', 'distance', 'seeds']) 42 | 43 | 44 | 45 | def agglomerative_clustering(x_samples, y_samples, z_samples, max_dist=10): 46 | 47 | samples = np.array([x_samples, y_samples, z_samples]).T 48 | 49 | clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=max_dist, linkage='complete').fit(samples) 50 | 51 | cluster_labels = clustering.labels_ 52 | 53 | cluster_samples_x = [] 54 | cluster_samples_y = [] 55 | cluster_samples_z = [] 56 | for label in np.unique(cluster_labels): 57 | 58 | cluster_samples_x.append(int(np.mean(x_samples[cluster_labels==label]))) 59 | cluster_samples_y.append(int(np.mean(y_samples[cluster_labels==label]))) 60 | cluster_samples_z.append(int(np.mean(z_samples[cluster_labels==label]))) 61 | 62 | cluster_samples_x = np.array(cluster_samples_x) 63 | cluster_samples_y = np.array(cluster_samples_y) 64 | cluster_samples_z = np.array(cluster_samples_z) 65 | 66 | return cluster_samples_x, cluster_samples_y, cluster_samples_z 67 | 68 | 69 | 70 | 71 | class SyntheticCellMembranes: 72 | """Parent class for generating synthetic cell membranes""" 73 | 74 | 75 | 76 | def __init__(self, gridsize=(128,128,128), distance_weight=0.25, cell_density=1/20): 77 | 78 | self.gridsize = gridsize 79 | self.distance_weight = distance_weight 80 | self.cell_density = cell_density 81 | 82 | self.instance_mask = None 83 | self.x_fg = [] 84 | self.y_fg = [] 85 | self.z_fg = [] 86 | self.x_cell = [] 87 | self.y_cell = [] 88 | self.z_cell = [] 89 | 90 | 91 | 92 | def _cart2sphere(self, x, y, z): 93 | 94 | n = x**2 + y**2 + z**2 95 | n = n.astype(float) 96 | n[n==0] += 1e-7 # prevent zero divisions 97 | r = np.sqrt(n) 98 | p = np.arctan2(y, x) 99 | t = np.arccos(z/np.sqrt(n)) 100 | 101 | return r, t, p 102 | 103 | 104 | 105 | def _sphere2cart(self, r, t, p): 106 | 107 | x = np.sin(t) * np.cos(p) * r 108 | y = np.sin(t) * np.sin(p) * r 109 | z = np.cos(t) * r 110 | 111 | return x, y, z 112 | 113 | 114 | 115 | def generate_instances(self): 116 | 117 | print_timestamp('Generating foreground region...') 118 | self._generate_foreground() 119 | print_timestamp('Placing random centroids...') 120 | self._place_centroids() 121 | print_timestamp('Creating Voronoi tessellation...') 122 | self._voronoi_tessellation() 123 | print_timestamp('Morphological postprocessing...') 124 | self._post_processing() 125 | 126 | 127 | 128 | # sphere generation 129 | def _generate_foreground(self): 130 | 131 | # Determine foreground region 132 | image_ind = np.indices(self.gridsize, dtype=int) 133 | x_fg = image_ind[0].flatten() 134 | y_fg = image_ind[1].flatten() 135 | z_fg = image_ind[2].flatten() 136 | 137 | self.x_fg = x_fg 138 | self.y_fg = y_fg 139 | self.z_fg = z_fg 140 | 141 | 142 | 143 | def _place_centroids(self): 144 | 145 | # calculate the volume 146 | vol = len(self.x_fg) 147 | 148 | # estimate the number of cells 149 | cell_count = vol*(self.cell_density**3) 150 | cell_count = int(cell_count) 151 | 152 | # select random points within the foreground region 153 | rnd_idx = np.random.choice(np.arange(0,len(self.x_fg)), size=cell_count, replace=False) 154 | 155 | x_cell = self.x_fg[rnd_idx].copy() 156 | y_cell = self.y_fg[rnd_idx].copy() 157 | z_cell = self.z_fg[rnd_idx].copy() 158 | 
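        # Seeds closer than 2/cell_density are merged by complete-linkage
        # clustering below, so that no two seeds produce implausibly small cells.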


class SyntheticCellMembranes:
    """Parent class for generating synthetic cell membranes"""

    def __init__(self, gridsize=(128,128,128), distance_weight=0.25, cell_density=1/20):

        self.gridsize = gridsize
        self.distance_weight = distance_weight
        self.cell_density = cell_density

        self.instance_mask = None
        self.x_fg = []
        self.y_fg = []
        self.z_fg = []
        self.x_cell = []
        self.y_cell = []
        self.z_cell = []


    def _cart2sphere(self, x, y, z):

        n = x**2 + y**2 + z**2
        n = n.astype(float)
        n[n==0] += 1e-7 # prevent zero divisions
        r = np.sqrt(n)
        p = np.arctan2(y, x)
        t = np.arccos(z/np.sqrt(n))

        return r, t, p


    def _sphere2cart(self, r, t, p):

        x = np.sin(t) * np.cos(p) * r
        y = np.sin(t) * np.sin(p) * r
        z = np.cos(t) * r

        return x, y, z


    def generate_instances(self):

        print_timestamp('Generating foreground region...')
        self._generate_foreground()
        print_timestamp('Placing random centroids...')
        self._place_centroids()
        print_timestamp('Creating Voronoi tessellation...')
        self._voronoi_tessellation()
        print_timestamp('Morphological postprocessing...')
        self._post_processing()


    # sphere generation
    def _generate_foreground(self):

        # Determine foreground region
        image_ind = np.indices(self.gridsize, dtype=int)
        x_fg = image_ind[0].flatten()
        y_fg = image_ind[1].flatten()
        z_fg = image_ind[2].flatten()

        self.x_fg = x_fg
        self.y_fg = y_fg
        self.z_fg = z_fg


    def _place_centroids(self):

        # calculate the volume
        vol = len(self.x_fg)

        # estimate the number of cells
        cell_count = vol*(self.cell_density**3)
        cell_count = int(cell_count)

        # select random points within the foreground region
        rnd_idx = np.random.choice(np.arange(0,len(self.x_fg)), size=cell_count, replace=False)

        x_cell = self.x_fg[rnd_idx].copy()
        y_cell = self.y_fg[rnd_idx].copy()
        z_cell = self.z_fg[rnd_idx].copy()

        # perform clustering
        x_cell = np.array(x_cell)
        y_cell = np.array(y_cell)
        z_cell = np.array(z_cell)
        x_cell, y_cell, z_cell = agglomerative_clustering(x_cell, y_cell, z_samples=z_cell, max_dist=2*self.cell_density**-1)

        self.x_cell = x_cell
        self.y_cell = y_cell
        self.z_cell = z_cell


    def _voronoi_tessellation(self):

        # get each foreground sample to be tessellated
        samples = list(zip(self.x_fg, self.y_fg, self.z_fg))

        # get each cell centroid
        cells = list(zip(self.x_cell, self.y_cell, self.z_cell))
        cells_tree = cKDTree(cells)

        # determine weights for each cell (improved roundness)
        closest_cell_dist, _ = cells_tree.query(cells, k=2)
        weights = closest_cell_dist[:,1]/np.max(closest_cell_dist[:,1])*self.distance_weight + 1

        # determine the closest cell candidates for each sampling point
        candidates_dist, candidates_idx = cells_tree.query(samples, k=np.minimum(5,len(self.x_cell)-1))
        candidates_dist = np.multiply(candidates_dist, weights[candidates_idx])
        candidates_closest = np.argmin(candidates_dist, axis=1)

        instance_mask = np.zeros(self.gridsize, dtype=np.uint16)
        for sample, cand_idx, closest_idx in zip(samples, candidates_idx, candidates_closest):
            instance_mask[sample] = cand_idx[closest_idx]+1

        self.instance_mask = instance_mask


    def _post_processing(self):
        pass


    def get_instance_mask(self):

        return self.instance_mask


    def get_centroid_mask(self, data=None):

        if data is None:
            if self.instance_mask is None:
                return None
            else:
                data = self.instance_mask

        # create centroid mask
        centroid_mask = np.zeros(data.shape, dtype=bool)

        # find and place centroids
        regions = measure.regionprops(data)
        for props in regions:
            c = props.centroid
            centroid_mask[int(c[0]), int(c[1]), int(c[2])] = True

        return centroid_mask


    def get_boundary_mask(self, data=None):

        # get instance mask
        if data is None:
            if self.instance_mask is None:
                return None
            else:
                data = self.instance_mask

        membrane_mask = morphology.dilation(data, selem=morphology.ball(3)) - data
        membrane_mask = membrane_mask != 0

        return membrane_mask


    def get_distance_mask(self, data=None):

        # get instance mask
        if data is None:
            if self.instance_mask is None:
                return None
            else:
                data = self.instance_mask

        # get foreground distance
        distance_encoding = distance_transform_edt(data>0)

        # get background distance
        distance_encoding = distance_encoding - distance_transform_edt(data<=0)

        return distance_encoding
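

# Editor's sketch (addition, not in the original file): the parent class on its
# own already yields a dense, distance-weighted Voronoi labelling of the full
# grid; parameters below are illustrative only:
#
#     syn = SyntheticCellMembranes(gridsize=(64,64,64), cell_density=1/8)
#     syn.generate_instances()
#     instances = syn.get_instance_mask()   # uint16 label image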


# Class for generation of synthetic meristem data
class SyntheticMeristem(SyntheticCellMembranes):
    """Child class for generating synthetic meristem membranes"""

    def __init__(self, gridsize=(120,512,512), distance_weight=0.25, # general params
                 morph_radius=3, weights=None, # foreground params
                 cell_density=1/23, cell_density_decay=0.9, cell_position_smoothness=10, # cell params
                 ring_density=1/23, ring_density_decay=0.9, # ring params
                 angular_sampling_file=r'utils/theta_phi_sampling_5000points_10000iter.npy',
                 specimen_sampling_file=r'utils/PNAS_sampling.csv'):

        super().__init__(gridsize=gridsize, distance_weight=distance_weight, cell_density=cell_density)

        # foreground params
        self.gridsize_max = np.max(gridsize)
        self.morph_radius = morph_radius
        self.weights = weights

        # cell params
        self.cell_density_decay = cell_density_decay
        self.cell_position_smoothness = cell_position_smoothness

        # ring params
        self.ring_density = ring_density
        self.ring_density_decay = ring_density_decay

        self.angular_sampling_file = angular_sampling_file
        self.specimen_sampling_file = specimen_sampling_file

        # initialize the statistical shape model
        print_timestamp('Initializing the statistical shape model...')
        self.sampling_angles = np.load(angular_sampling_file)
        specimen_sampling = pd.read_csv(specimen_sampling_file, sep=';').to_numpy()
        self.specimen_mean = np.mean(specimen_sampling, axis=1)

        # calculate the PCA
        specimen_cov = np.cov(specimen_sampling)
        specimen_pca = PCA(n_components=3)
        specimen_pca.fit(specimen_cov)
        self.specimen_pca = specimen_pca

        if self.weights is not None:
            assert len(self.weights) == len(self.specimen_pca.singular_values_), 'Number of weights ({0}) does not match the number of eigenvectors ({1}).'.format(len(self.weights), len(self.specimen_pca.singular_values_))
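
    # Editor's note (addition, not in the original file): _generate_foreground
    # below draws a random specimen from the statistical shape model as
    #
    #     r = r_mean + components.T @ (sqrt(singular_values) * w),
    #
    # i.e. the mean radius profile is perturbed along the principal modes of
    # variation fitted in __init__, with w either the user-supplied weight
    # vector or a standard-normal draw.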

    # sphere generation
    def _generate_foreground(self):

        # Construct the sampling tree
        sampling_tree = cKDTree(self.sampling_angles)

        if self.weights is None:
            # Generate random sampling
            weights = np.random.randn(len(self.specimen_pca.singular_values_))
        else:
            weights = np.array(self.weights)
        specimen_rnd = self.specimen_mean + np.matmul(self.specimen_pca.components_.T, np.sqrt(self.specimen_pca.singular_values_) * weights)

        # Normalize the radii
        specimen_rnd /= specimen_rnd.max()

        # Calculate the image size based on the shape model and the desired gridsize
        specimen_x, specimen_y, specimen_z = self._sphere2cart(specimen_rnd, self.sampling_angles[:,0], self.sampling_angles[:,1])
        specimen_dim_ratio = np.array([2*np.abs(specimen_x.max()), 2*np.abs(specimen_y.max()), np.abs(specimen_z.max())]) # *2 on x and y, since it's a hemisphere starting at the center
        specimen_dim_ratio /= specimen_dim_ratio.max()
        self.gridsize = np.array(specimen_dim_ratio*self.gridsize_max, dtype=int)

        # Adjust to the desired grid size
        specimen_rnd *= self.gridsize_max/2*0.95
        self.sampling_radii = specimen_rnd.copy()

        # Determine the candidate sampling grid
        image_ind = np.indices(self.gridsize, dtype=int)
        x = image_ind[0].flatten()-self.gridsize[0]/2
        y = image_ind[1].flatten()-self.gridsize[1]/2
        z = image_ind[2].flatten()

        r,t,p = self._cart2sphere(x,y,z)

        # Determine the nearest sampling angles
        _, assignments = sampling_tree.query(np.array([t,p]).T, k=3)
        specimen_rnd[specimen_rnd==0] = np.nan

        # Determine the foreground region (all samples within the interpolated specimen radius)
        foreground_ind = r <= np.nanmean(specimen_rnd[assignments], axis=1)
        x_fg = x[foreground_ind]+self.gridsize[0]/2
        x_fg = x_fg.astype(int)
        y_fg = y[foreground_ind]+self.gridsize[1]/2
        y_fg = y_fg.astype(int)
        z_fg = z[foreground_ind]
        z_fg = z_fg.astype(int)

        self.x_fg = x_fg
        self.y_fg = y_fg
        self.z_fg = z_fg


    def _place_centroids(self):

        cell_density = self.cell_density
        ring_density = self.ring_density
        cluster_density = cell_density

        # initialize the first radius (small offset to place the first ring near the boundary)
        ring_radii = self.sampling_radii + (self.ring_density**-1)/2

        x_cell = []
        y_cell = []
        z_cell = []
        ring_count = 0

        while ring_radii.max() - ring_density**-1 > 0:

            # set off the new ring
            ring_radii = np.maximum(0, ring_radii - ring_density**-1)

            # get the spherical coordinates of the new ring
            t = self.sampling_angles[ring_radii != 0, 0]
            p = self.sampling_angles[ring_radii != 0, 1]
            r = ring_radii[ring_radii != 0]

            # apply small offsets to the radii, depending on the depth of the current ring
            r = r + np.random.normal(loc=0, scale=ring_count/cell_density/self.cell_position_smoothness, size=len(r))
            r = np.maximum(0, r)

            # convert to cartesian coordinates
            x,y,z = self._sphere2cart(r,t,p)
            x = np.array(x)
            y = np.array(y)
            z = np.array(z)

            # cluster cell centroid candidates to reduce computation afterwards
            if len(x) > 1:
                x, y, z = agglomerative_clustering(x, y, z_samples=z, max_dist=(cell_density**-1))

            # extend the centroid list
            x_cell.extend(x+self.gridsize[0]//2)
            y_cell.extend(y+self.gridsize[1]//2)
            z_cell.extend(z)

            cluster_density = np.max([cluster_density, cell_density, ring_density])
            cell_density = cell_density*self.cell_density_decay
            ring_density = ring_density*self.ring_density_decay
            ring_count += 1

        # perform clustering
        x_cell = np.array(x_cell)
        y_cell = np.array(y_cell)
        z_cell = np.array(z_cell)
        x_cell, y_cell, z_cell = agglomerative_clustering(x_cell, y_cell, z_samples=z_cell, max_dist=cluster_density**-1)

        self.x_cell = x_cell
        self.y_cell = y_cell
        self.z_cell = z_cell


    def _post_processing(self):

        # get the membrane mask
        membrane_mask = self.get_boundary_mask()

        # open the inner parts of the cells
        opened_fg = morphology.binary_opening(~membrane_mask, selem=morphology.ball(self.morph_radius))
        opened_fg[self.instance_mask==0] = False

        # erode the foreground region
        eroded_fg = morphology.binary_erosion(self.instance_mask>0, selem=morphology.ball(int(self.morph_radius*1.5)))

        # generate the enhanced foreground mask
        enhanced_fg = np.logical_or(opened_fg, eroded_fg)

        # generate the instance mask
        self.instance_mask[~enhanced_fg] = 0
        self.x_fg, self.y_fg, self.z_fg = np.nonzero(self.instance_mask)
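

# Editor's sketch (addition, not in the original file): a typical call, assuming
# the interpreter is started from the repository root so that the two sampling
# files referenced by the default arguments are found (parameters illustrative):
#
#     generate_data(SyntheticMeristem, save_path='Segmentations_h5',
#                   experiment_name='synthetic_meristem', img_count=1,
#                   param_dict={'gridsize': (64,256,256)})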

--------------------------------------------------------------------------------
/utils/theta_phi_sampling_5000points_10000iter.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/utils/theta_phi_sampling_5000points_10000iter.npy
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu May 20 13:53:11 2021

@author: Nutzer
"""

import time


def print_timestamp(msg, args=None):

    now = time.localtime()
    print('[{0:02d}:{1:02d}:{2:02d}] '.format(now.tm_hour, now.tm_min, now.tm_sec) +
          (msg.format(*args) if args is not None else msg))
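

# Editor's example (addition, not in the original file):
#     print_timestamp('Processed file {0}/{1}', (3, 10))
#     # -> e.g. '[14:05:23] Processed file 3/10'
--------------------------------------------------------------------------------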