├── LICENSE
├── README.md
├── ThirdParty
│   ├── diffusion.py
│   └── layers.py
├── apply_script_diffusion.py
├── data_samples
│   ├── experiment_3D
│   │   ├── Diffusion_3DCTC_CE_epoch=4999.ckpt
│   │   ├── pred_0_sketch3D_sim_CTCCE_0.tif
│   │   ├── pred_1_sketch3D_sim_CTCCE_0.tif
│   │   ├── pred_2_sketch3D_sim_CTCCE_0.tif
│   │   └── pred_3_sketch3D_sim_CTCCE_0.tif
│   ├── image3D_sim_CTCCE_0.h5
│   ├── image3D_sim_CTCCE_0.tif
│   ├── image_files_3D.csv
│   ├── sketch3D_sim_CTCCE_0.h5
│   ├── sketch3D_sim_CTCCE_0.tif
│   └── sketch_files_3D.csv
├── dataloader
│   ├── augmenter.py
│   └── h5_dataloader.py
├── figures
│   ├── example_data.png
│   ├── multi-channel.png
│   ├── overlapping_cells.png
│   └── timeseries.gif
├── models
│   ├── DiffusionModel2D.py
│   ├── DiffusionModel3D.py
│   ├── module_UNet2D_pixelshuffle_inject.py
│   └── module_UNet3D_pixelshuffle_inject.py
├── notebooks
│   ├── README.md
│   ├── jupyter_apply_script.ipynb
│   ├── jupyter_preparation_script.ipynb
│   └── jupyter_train_script.ipynb
├── requirements.txt
├── train_script.py
└── utils
    ├── PNAS_sampling.csv
    ├── csv_generator.py
    ├── h5_converter.py
    ├── harmonics.py
    ├── jupyter_widgets.py
    ├── synthetic_cell_membrane_masks.py
    ├── synthetic_cell_nuclei_masks.py
    ├── theta_phi_sampling_5000points_10000iter.npy
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Denoising Diffusion Probabilistic Models for Generation of Realistic Fully-Annotated Microscopy Image Data Sets
2 |
3 | This repository contains code to simulate 2D/3D cellular structures and to synthesize corresponding microscopy image data based on Denoising Diffusion Probabilistic Models (DDPM).
4 | Sketches are generated to indicate cell shapes and structural characteristics, and they serve as a basis for the diffusion process to ultimately allow for the generation of fully-annotated microscopy image data sets without the need for human annotation effort.
5 | Generated data sets are available at OSF and the article is available at PLOS CB.
6 | To access the trained models and get a showcase of the fully-simulated data sets, please visit our website (work in progress).
7 | 
8 | 


9 | Exemplary synthetic samples from our experiments
10 |
11 |
12 | If you use this code or data, please cite the following work:
13 | ```
14 | @article{eschweiler2024celldiffusion,
15 | title={Denoising diffusion probabilistic models for generation of realistic fully-annotated microscopy image datasets},
16 | author={Eschweiler, Dennis and Yilmaz, R{\"u}veyda and Baumann, Matisse and Laube, Ina and Roy, Rijo and Jose, Abin and Br{\"u}ckner, Daniel and Stegmaier, Johannes},
17 | journal={PLOS Computational Biology},
18 | volume={20},
19 | number={2},
20 | pages={e1011890},
21 | year={2024}
22 | }
23 | ```
24 |
25 | We provide Jupyter Notebooks that give an overview of how to preprocess your data and how to train and apply the image generation pipeline. The following gives a very brief overview of the general functionality; for more detailed examples, we refer to the notebooks.
26 |
27 | ## Data Preparation
28 | The presented pipelines require the hdf5 file format for processing. Therefore, each image file has to be converted to hdf5, which can be done using `utils.h5_converter.prepare_images`. Once all files have been converted, a list of those files has to be stored as a csv file to make them accessible to the processing pipelines. This can be done using `utils.csv_generator.create_csv`. A more detailed explanation is given in `jupyter_preparation_script.ipynb`, and a minimal sketch is shown below.
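The following is a minimal sketch of this two-step preparation, assuming that `prepare_images` accepts a list of image paths and that `create_csv` accepts a list of file entries plus an output name; the notebook documents the actual signatures.

```python
from glob import glob

# Hypothetical usage sketch; see jupyter_preparation_script.ipynb for the
# actual signatures of both helpers.
from utils.h5_converter import prepare_images
from utils.csv_generator import create_csv

image_paths = sorted(glob('raw_data/*.tif'))  # placeholder input folder
prepare_images(image_paths)                   # convert each image to hdf5

h5_paths = sorted(glob('raw_data/*.h5'))
create_csv(h5_paths, 'image_files.csv')       # file list used by the pipelines
```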
29 |
30 |
31 | ## Diffusion Model Training and Application
32 | To use the proposed pipeline to either train or apply your models, make sure to adapt all parameters in the pipeline files `models/DiffusionModel3D` or `models/DiffusionModel2D`, and in the training script `train_script.py` or application script `apply_script_diffusion.py`. Alternatively, all parameters can be provided as command line arguments. A more detailed explanation is given in `jupyter_train_script.ipynb` and `jupyter_apply_script.ipynb`.
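As an example, applying a trained 3D model could look as follows; the paths are placeholders, and the flags correspond to the arguments defined in `apply_script_diffusion.py` and the pipeline files:

```
python apply_script_diffusion.py --pipeline DiffusionModel3D \
    --ckpt_path results/experiment1/checkpoint.ckpt \
    --output_path results/experiment1 \
    --timesteps_start 400
```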
33 |
34 |
35 | ## Simulation of Cellular Structures and Sketch Generation
36 | Since the proposed approach works in a very intuitive manner, sketches can generally be created in any arbitrary way. We mainly focused on using and adapting the simulation techniques proposed in *3D fluorescence microscopy data synthesis for segmentation and benchmarking*. Nevertheless, the functionality used in this work can be found in `utils.synthetic_cell_membrane_masks` and `utils.synthetic_cell_nuclei_masks` for cellular membranes and nuclei, respectively.
37 |
38 | ## Processing Times and Hardware Requirements
39 | Due to the layout and working principle of the computation cluster we used, determining precise hardware requirements was difficult.
40 | The cluster offered varying hardware capabilities, including GPUs ranging from a GTX 1080 to an RTX 6000.
41 | Nevertheless, the following should give a brief indication of the specifications needed to run the presented pipelines.
42 |
43 | Training times ranged from about one day for 2D models to roughly a week for 3D models, requiring 2-8 GB of RAM and 4-8 GB of GPU memory.
44 | Prediction times varied widely, as they are influenced by the choice of the backward starting point $t_\mathrm{start}$ and the total image size.
45 | For the proposed starting point $t_\mathrm{start}=400$, image generation times ranged from less than a minute for 2D image data of size 1024x1024 pixels to roughly one hour for 3D image data of size 512x512x180 or 1700x2450x13 voxels.
46 | Memory requirements varied similarly, ranging between 2-80 GB of RAM and 4-8 GB of GPU memory.
47 |
--------------------------------------------------------------------------------
/ThirdParty/diffusion.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Code adapted from https://github.com/w86763777/pytorch-ddpm/blob/master/diffusion.py and https://huggingface.co/blog/annotated-diffusion
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | #%%
11 | # Definitions
12 |
13 | class SinusoidalPositionEmbeddings(nn.Module):
14 | def __init__(self, dim):
15 | super().__init__()
16 | self.dim = dim
17 |
18 | def forward(self, t):
19 | device = t.device
20 | half_dim = self.dim // 2
21 | embeddings = math.log(10000) / (half_dim - 1)
22 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
23 | embeddings = t[:, None] * embeddings[None, :]
24 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
25 | return embeddings
26 |
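# Note on SinusoidalPositionEmbeddings: for a timestep t and embedding size d,
# the first half of the output is sin(t * 10000^(-i/(d/2-1))) for i = 0..d/2-1
# and the second half the matching cosines, i.e. the standard transformer-style
# sinusoidal timestep embedding.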
27 |
28 | def beta_schedule(timesteps, schedule='linear'):
29 |
30 |     if schedule == 'cosine':
31 |         # cosine schedule as proposed in https://arxiv.org/abs/2102.09672
32 |         s = 0.008
33 |         steps = timesteps + 1
34 |         x = torch.linspace(0, timesteps, steps)
35 |         alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
36 |         alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
37 |         betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
38 |         betas = torch.clip(betas, 0.0001, 0.9999)
39 |
40 |     elif schedule == 'linear':
41 |         beta_start = 0.0001
42 |         beta_end = 0.02
43 |         betas = torch.linspace(beta_start, beta_end, timesteps)
44 |
45 |     elif schedule == 'quadratic':
46 |         beta_start = 0.0001
47 |         beta_end = 0.02
48 |         betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
49 |
50 |     elif schedule == 'sigmoid':
51 |         beta_start = 0.0001
52 |         beta_end = 0.02
53 |         betas = torch.linspace(-6, 6, timesteps)
54 |         betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
55 |
56 |     else:
57 |         raise ValueError('Unknown schedule "{0}"'.format(schedule))
58 |     return betas.double()
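# Quick sanity check (a usage sketch, not part of the original module):
#   betas = beta_schedule(1000, schedule='cosine')
#   assert betas.shape == (1000,) and (betas > 0).all() and (betas < 1).all()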
59 |
60 |
61 | def extract(v, t, x_shape):
62 | """
63 | Extract some coefficients at specified timesteps, then reshape to
64 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
65 | """
66 | out = torch.gather(v, index=t, dim=0).float()
67 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
68 |
69 |
70 | #%%
71 | # Training
72 |
73 | class GaussianDiffusionTrainer(nn.Module):
74 | def __init__(self, T, schedule='cosine'):
75 | assert schedule in ['cosine', 'linear', 'quadratic', 'sigmoid'], 'Unknown schedule "{0}"'.format(schedule)
76 | super().__init__()
77 |
78 | self.T = T
79 |
80 |
81 | self.register_buffer('betas', beta_schedule(T, schedule=schedule))
82 | alphas = 1. - self.betas
83 | alphas_bar = torch.cumprod(alphas, dim=0)
84 |
85 | # calculations for diffusion q(x_t | x_{t-1}) and others
86 | self.register_buffer(
87 | 'sqrt_alphas_bar', torch.sqrt(alphas_bar))
88 | self.register_buffer(
89 | 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
90 |
91 |     def forward(self, x_0, t=None):
92 |         """
93 |         Algorithm 1 (DDPM training): mix x_0 with Gaussian noise at a random
94 |         (or given) timestep t and return (x_t, noise, t).
95 |         """
96 |         if t is None:
97 |             t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
98 |         else:
99 |             t = x_0.new_ones([x_0.shape[0], ], dtype=torch.long) * t
99 | noise = torch.randn_like(x_0)
100 | x_t = (
101 | extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
102 | extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
103 | return x_t, noise, t
104 |
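# Usage sketch: given clean images x_0 scaled to [-1, 1],
#   x_t, noise, t = trainer(x_0)
# and the denoising network is trained to regress `noise` from (x_t, t),
# as done in the training_step of the DiffusionModel pipelines.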
105 | #%%
106 | # Testing
107 |
108 | class GaussianDiffusionSampler(nn.Module):
109 | def __init__(self, model, T, t_start=1000, t_save=100, t_step=1,\
110 | schedule='linear', mean_type='epsilon', var_type='fixedlarge'):
111 |         assert mean_type in ['xprev', 'xstart', 'epsilon'], 'Unknown mean_type "{0}"'.format(mean_type)
112 | assert var_type in ['fixedlarge', 'fixedsmall'], 'Unknown var_type "{0}"'.format(var_type)
113 | assert schedule in ['cosine', 'linear', 'quadratic', 'sigmoid'], 'Unknown schedule "{0}"'.format(schedule)
114 | super().__init__()
115 |
116 | self.model = model
117 | self.T = T
118 | self.t_start = t_start
119 | self.t_save = t_save
120 | self.t_step = t_step
121 | self.schedule = schedule
122 | self.mean_type = mean_type
123 | self.var_type = var_type
124 |
125 | self.register_buffer('betas', beta_schedule(T, schedule=self.schedule))
126 | alphas = 1. - self.betas
127 | alphas_bar = torch.cumprod(alphas, dim=0)
128 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
129 |
130 | # calculations for diffusion q(x_t | x_{t-1}) and others
131 | self.register_buffer(
132 | 'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
133 | self.register_buffer(
134 | 'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))
135 |
136 | # calculations for posterior q(x_{t-1} | x_t, x_0)
137 | self.register_buffer(
138 | 'posterior_var',
139 | self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
140 | # below: log calculation clipped because the posterior variance is 0 at
141 | # the beginning of the diffusion chain
142 | self.register_buffer(
143 | 'posterior_log_var_clipped',
144 | torch.log(
145 | torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))
146 | self.register_buffer(
147 | 'posterior_mean_coef1',
148 | torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))
149 | self.register_buffer(
150 | 'posterior_mean_coef2',
151 | torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))
152 |
153 | def q_mean_variance(self, x_0, x_t, t):
154 | """
155 | Compute the mean and variance of the diffusion posterior
156 | q(x_{t-1} | x_t, x_0)
157 | """
158 | assert x_0.shape == x_t.shape
159 | posterior_mean = (
160 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
161 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
162 | )
163 | posterior_log_var_clipped = extract(
164 | self.posterior_log_var_clipped, t, x_t.shape)
165 | return posterior_mean, posterior_log_var_clipped
166 |
167 | def predict_xstart_from_eps(self, x_t, t, eps):
168 | assert x_t.shape == eps.shape
169 | return (
170 | extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t -
171 | extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps
172 | )
173 |
174 | def predict_xstart_from_xprev(self, x_t, t, xprev):
175 | assert x_t.shape == xprev.shape
176 | return ( # (xprev - coef2*x_t) / coef1
177 | extract(
178 | 1. / self.posterior_mean_coef1, t, x_t.shape) * xprev -
179 | extract(
180 | self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
181 | x_t.shape) * x_t
182 | )
183 |
184 | def p_mean_variance(self, x_t, cond, t):
185 | # below: only log_variance is used in the KL computations
186 | model_log_var = {
187 | # for fixedlarge, we set the initial (log-)variance like so to
188 | # get a better decoder log likelihood
189 | 'fixedlarge': torch.log(torch.cat([self.posterior_var[1:2],
190 | self.betas[1:]])),
191 | 'fixedsmall': self.posterior_log_var_clipped,
192 | }[self.var_type]
193 | model_log_var = extract(model_log_var, t, x_t.shape)
194 |
195 | # Mean parameterization
196 | if self.mean_type == 'xprev': # the model predicts x_{t-1}
197 | x_prev = self.model(torch.cat((x_t,cond), axis=1), t)
198 | x_0 = self.predict_xstart_from_xprev(x_t, t, xprev=x_prev)
199 | model_mean = x_prev
200 | elif self.mean_type == 'xstart': # the model predicts x_0
201 | x_0 = self.model(torch.cat((x_t,cond), axis=1), t)
202 | model_mean, _ = self.q_mean_variance(x_0, x_t, t)
203 | elif self.mean_type == 'epsilon': # the model predicts epsilon
204 | eps = self.model(torch.cat((x_t,cond), axis=1), t)
205 | x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps)
206 | model_mean, _ = self.q_mean_variance(x_0, x_t, t)
207 | else:
208 | raise NotImplementedError(self.mean_type)
209 | x_0 = torch.clip(x_0, -1., 1.)
210 |
211 | return model_mean, model_log_var
212 |
213 | def __len__(self):
214 | iterations = torch.linspace(0,self.t_start-1,self.t_start//self.t_step, dtype=int) % self.t_save
215 | return torch.count_nonzero(iterations==0)
216 |
217 | def forward(self, x_T, cond):
218 |         """
219 |         Algorithm 2 (DDPM sampling): denoise x_T step by step down to x_0.
220 |         """
221 | x_t = x_T
222 | x_intermediate = []
223 | for time_step in reversed(torch.linspace(0,self.t_start-1,self.t_start//self.t_step, dtype=int)):
224 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
225 | mean, log_var = self.p_mean_variance(x_t=x_t, cond=cond, t=t)
226 | # no noise when t == 0
227 | if time_step > 0:
228 | noise = torch.randn_like(x_t)
229 | else:
230 | noise = 0
231 | x_t = mean + torch.exp(0.5 * log_var) * noise
232 |
233 | if time_step%self.t_save == 0:
234 | x_intermediate.append(x_t.cpu())
235 | print('Processing timestep {0}'.format(time_step))
236 | x_0 = x_t
237 | return x_0, x_intermediate
238 |
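# Usage sketch (mirroring apply_script_diffusion.py): start from a noised
# input x_T and a conditioning sketch `cond` with matching spatial size:
#   sampler = GaussianDiffusionSampler(model, T=1000, t_start=400)
#   x_0, intermediates = sampler(x_T, cond)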
239 |
240 |
--------------------------------------------------------------------------------
/ThirdParty/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Apr 20 18:09:03 2020
4 |
5 | """
6 |
7 |
8 | import math
9 | import numbers
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 |
14 |
15 |
16 | class PixelShuffle3d(nn.Module):
17 |     '''
18 |     3D version of PixelShuffle.
19 |     Reference: https://github.com/gap370/pixelshuffle3d
20 |     '''
23 | def __init__(self, scale):
24 | '''
25 | :param scale: upsample scale
26 | '''
27 | super().__init__()
28 | self.scale = scale
29 |
30 | def forward(self, input):
31 | batch_size, channels, in_depth, in_height, in_width = input.size()
32 | nOut = channels // self.scale ** 3
33 |
34 | out_depth = in_depth * self.scale
35 | out_height = in_height * self.scale
36 | out_width = in_width * self.scale
37 |
38 | input_view = input.contiguous().view(batch_size, nOut, self.scale, self.scale, self.scale, in_depth, in_height, in_width)
39 |
40 | output = input_view.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
41 |
42 | return output.view(batch_size, nOut, out_depth, out_height, out_width)
43 |
44 |
45 |
--------------------------------------------------------------------------------
/apply_script_diffusion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Jan 14 13:46:20 2020
5 |
6 | @author: eschweiler
7 | """
8 |
9 | import os
10 | import numpy as np
11 | import torch
12 | import csv
13 | from skimage import io
14 | from scipy.ndimage import gaussian_filter
15 | from argparse import ArgumentParser
16 | from torch.autograd import Variable
17 |
18 | from dataloader.h5_dataloader import MeristemH5Tiler as Tiler
19 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
20 | from utils.utils import print_timestamp
21 |
22 | SEED = 1337
23 | torch.manual_seed(SEED)
24 | np.random.seed(SEED)
25 |
26 |
27 | def main(hparams):
28 |
29 |
30 | """
31 | Main testing routine specific for this project
32 | :param hparams:
33 | """
34 |
35 | # ------------------------
36 | # 0 SANITY CHECKS
37 | # ------------------------
38 | if not isinstance(hparams.overlap, (tuple, list)):
39 | hparams.overlap = (hparams.overlap,) * len(hparams.patch_size)
40 | if not isinstance(hparams.crop, (tuple, list)):
41 | hparams.crop = (hparams.crop,) * len(hparams.patch_size)
42 | assert all([p-2*o-2*c>0 for p,o,c in zip(hparams.patch_size, hparams.overlap, hparams.crop)]), 'Invalid combination of patch size, overlap and crop size.'
43 |
44 | # ------------------------
45 | # 1 INIT LIGHTNING MODEL
46 | # ------------------------
47 | model = network(hparams=hparams)
48 | model = model.load_from_checkpoint(hparams.ckpt_path)
49 | model = model.cuda()
50 |
51 | # ------------------------
52 | # 2 INIT DIFFUSION PARAMETERS
53 | # ------------------------
54 | device="cuda" if torch.cuda.is_available() else "cpu"
55 |
56 | if 'diffusionmodel' in hparams.pipeline.lower():
57 | DiffusionTrainer = GaussianDiffusionTrainer(hparams.num_timesteps, schedule=hparams.diffusion_schedule).to(device)
58 | DiffusionSampler = GaussianDiffusionSampler(model, hparams.num_timesteps, t_start=hparams.timesteps_start,\
59 | t_save=hparams.timesteps_save, t_step=hparams.timesteps_step,\
60 | schedule=hparams.diffusion_schedule,\
61 | mean_type='epsilon', var_type='fixedlarge').to(device)
62 |     else:
63 | raise NotImplementedError()
64 |
65 |     hparams.out_channels *= len(DiffusionSampler)
66 |
67 | # ------------------------
68 | # 3 INIT DATA TILER
69 | # ------------------------
70 | tiler = Tiler(hparams.test_list, no_mask=hparams.input_batch=='image', no_img=hparams.input_batch=='mask',\
71 | boundary_handling='none', reduce_dim='2d' in hparams.pipeline.lower(), **vars(hparams))
72 | fading_map = tiler.get_fading_map()
73 | fading_map = np.repeat(fading_map[np.newaxis,...], hparams.out_channels, axis=0)
74 |
75 | # ------------------------
76 | # 4 FILE AND FOLDER CHECKS
77 | # ------------------------
78 | os.makedirs(hparams.output_path, exist_ok=True)
79 | file_checklist = []
80 |
81 | # ------------------------
82 | # 5 PROCESS EACH IMAGE
83 | # ------------------------
84 | if hparams.num_files is None or hparams.num_files < 0:
85 | hparams.num_files = len(tiler.data_list)
86 | else:
87 | hparams.num_files = np.minimum(len(tiler.data_list), hparams.num_files)
88 |
89 | with torch.no_grad():
90 |
91 | for image_idx in range(hparams.num_files):
92 |
93 | # Get the current state of the processed files
94 | if os.path.isfile(os.path.join(hparams.output_path, 'tmp_file_checklist.csv')):
95 | file_checklist = []
96 | with open(os.path.join(hparams.output_path, 'tmp_file_checklist.csv'), 'r') as f:
97 | reader = csv.reader(f, delimiter=';')
98 | for row in reader:
99 | if not len(row)==0:
100 | file_checklist.append(row[0])
101 |
102 | # Check if current file has already been processed
103 | if not any([f==tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1] for f in file_checklist]):
104 |
105 | print_timestamp('_'*20)
106 | print_timestamp('Processing file {0}', [tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1]])
107 |
108 | # Initialize current file
109 | tiler.set_data_idx(image_idx)
110 |
111 | # Determine if the patch size exceeds the image size
112 | working_size = tuple(np.max(np.array(tiler.locations), axis=0) - np.min(np.array(tiler.locations), axis=0) + np.array(hparams.patch_size))
113 |
114 | # Initialize maps
115 | predicted_img = np.full((hparams.out_channels,)+working_size, 0, dtype=np.float32)
116 | norm_map = np.full((hparams.out_channels,)+working_size, 0, dtype=np.float32)
117 |
118 |                 for patch_idx in range(len(tiler)):
119 |
120 |                     print_timestamp('Generating patch {0}/{1}...', (patch_idx+1, len(tiler)))
121 |
122 |                     # Get the input
123 |                     sample = tiler[patch_idx]
124 |
125 | # Apply gaussian blur to image data
126 | if hparams.blur_sigma>0:
127 | for ndim in range(sample[hparams.input_batch].shape[0]-1):
128 | sample[hparams.input_batch][ndim,...] = gaussian_filter(sample[hparams.input_batch][ndim,...], hparams.blur_sigma, order=0)
129 |
130 | data = Variable(torch.from_numpy(sample[hparams.input_batch][np.newaxis,...]).cuda())
131 | data = data.float()
132 |
133 | pred_patch = data[:,:-1,...].clone()
134 | pred_patch,_,_ = DiffusionTrainer(pred_patch, hparams.timesteps_start-1)
135 |
136 | cond = data[:,-1:,...].clone()
137 |
138 | _,pred_patch = DiffusionSampler(pred_patch, cond)
139 |
140 | # Convert final patch to numpy for saving
141 | pred_patch = [p.cpu().data.numpy() for p in pred_patch]
142 | pred_patch = np.array(pred_patch)
147 | pred_patch = pred_patch[:,0,...] #remove batch dimension
148 | pred_patch = np.reshape(pred_patch, (pred_patch.shape[0]*pred_patch.shape[1],)+pred_patch.shape[2:])
149 |
151 | pred_patch = np.clip(pred_patch, hparams.clip[0], hparams.clip[1])
152 |
153 | # Get the current slice position
154 | slicing = tuple(map(slice, (0,)+tuple(tiler.patch_start+tiler.global_crop_before), (hparams.out_channels,)+tuple(tiler.patch_end+tiler.global_crop_before)))
155 |
156 | # Add predicted patch and fading weights to the corresponding maps
157 | predicted_img[slicing] = predicted_img[slicing]+pred_patch*fading_map
158 | norm_map[slicing] = norm_map[slicing]+fading_map
159 |
160 | # Normalize the predicted image
161 | norm_map = np.clip(norm_map, 1e-5, np.inf)
162 | predicted_img = predicted_img / norm_map
163 |
164 | # Crop the predicted image to its original size
165 | slicing = tuple(map(slice, (0,)+tuple(tiler.global_crop_before), (hparams.out_channels,)+tuple(np.array(predicted_img.shape[1:])+np.array(tiler.global_crop_after))))
166 | predicted_img = predicted_img[slicing]
167 |
168 | # Save the predicted image
169 | predicted_img = np.transpose(predicted_img, (1,2,3,0))
170 | predicted_img = predicted_img.astype(np.float32)
171 | if hparams.out_channels > 1:
172 | for channel in range(hparams.out_channels):
173 | io.imsave(os.path.join(hparams.output_path, 'pred_'+str(channel)+'_'+os.path.split(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1])[-1][:-3]+'.tif'), predicted_img[...,channel])
174 | else:
175 | io.imsave(os.path.join(hparams.output_path, 'pred_'+os.path.split(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1])[-1][:-3]+'.tif'), predicted_img[...,0])
176 |
177 | # Mark current file as processed
178 | file_checklist.append(tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1])
179 | with open(os.path.join(hparams.output_path, 'tmp_file_checklist.csv'), 'w') as f:
180 | writer = csv.writer(f, delimiter=';')
181 | for check_file in file_checklist:
182 | writer.writerow([check_file])
183 |
184 | else:
185 | print_timestamp('_'*20)
186 | print_timestamp('Skipping file {0}', [tiler.data_list[image_idx][0 if hparams.input_batch=='image' else 1]])
187 |
188 | # Delete temporary checklist if everything has been processed
189 |     if os.path.isfile(os.path.join(hparams.output_path, 'tmp_file_checklist.csv')):
190 |         os.remove(os.path.join(hparams.output_path, 'tmp_file_checklist.csv'))
191 |
192 | if __name__ == '__main__':
193 | # ------------------------
194 |     # APPLICATION ARGUMENTS
195 | # ------------------------
196 | # these are project-wide arguments
197 |
198 | parent_parser = ArgumentParser(add_help=False)
199 |
200 | parent_parser.add_argument(
201 | '--output_path',
202 | type=str,
203 | default=r'results/experiment1',
204 | help='output path for test results'
205 | )
206 |
207 | parent_parser.add_argument(
208 | '--ckpt_path',
209 | type=str,
210 | default=r'results/experiment1/checkpoint.ckpt',
211 |         help='path to the trained model checkpoint'
212 | )
213 |
214 | parent_parser.add_argument(
215 | '--gpus',
216 | type=int,
217 | default=1,
218 | help='number of GPUs to use'
219 | )
220 |
221 | parent_parser.add_argument(
222 | '--overlap',
223 | type=int,
224 | default=(0,0,0),
225 | help='overlap of adjacent patches',
226 | nargs='+'
227 | )
228 |
229 | parent_parser.add_argument(
230 | '--crop',
231 | type=int,
232 | default=(0,0,0),
233 | help='safety crop of patches',
234 | nargs='+'
235 | )
236 |
237 | parent_parser.add_argument(
238 | '--input_batch',
239 | type=str,
240 | default='mask',
241 | help='which part of the batch is used as input (image | mask)'
242 | )
243 |
244 | parent_parser.add_argument(
245 | '--clip',
246 | type=float,
247 | default=(-1000.0, 1000.0),
248 | help='clipping values for network outputs',
249 | nargs='+'
250 | )
251 |
252 | parent_parser.add_argument(
253 | '--num_files',
254 | type=int,
255 | default=1,
256 | help='number of files to process'
257 | )
258 |
259 | parent_parser.add_argument(
260 | '--timesteps_start',
261 | type=int,
262 | default=400,
263 |         help='diffusion timestep at which the backward (sampling) process starts'
264 | )
265 |
266 | parent_parser.add_argument(
267 | '--timesteps_save',
268 | type=int,
269 | default=100,
270 | help='number of steps between saves'
271 | )
272 |
273 | parent_parser.add_argument(
274 | '--blur_sigma',
275 | type=int,
276 | default=1,
277 | help='sigma of gaussian blur used on input data'
278 | )
279 |
280 | parent_parser.add_argument(
281 | '--timesteps_step',
282 | type=int,
283 | default=1,
284 | help='timesteps skipped between iterations'
285 | )
286 |
287 | parent_parser.add_argument(
288 | '--pipeline',
289 | type=str,
290 | default='DiffusionModel3D',
291 | help='which pipeline to load (DiffusionModel3D | DiffusionModel2D)'
292 | )
293 |
294 | parent_args = parent_parser.parse_known_args()[0]
295 |
296 | # load the desired network architecture
297 |
298 | if parent_args.pipeline.lower() == 'diffusionmodel3d':
299 | from models.DiffusionModel3D import DiffusionModel3D as network
300 | elif parent_args.pipeline.lower() == 'diffusionmodel2d':
301 | from models.DiffusionModel2D import DiffusionModel2D as network
302 | else:
303 | raise ValueError('Unknown pipeline "{0}".'.format(parent_args.pipeline))
304 |
305 | # each LightningModule defines arguments relevant to it
306 | parser = network.add_model_specific_args(parent_parser)
307 | hyperparams = parser.parse_args()
308 |
309 | # ---------------------
310 |     # RUN PREDICTION
311 | # ---------------------
312 | main(hyperparams)
313 |
--------------------------------------------------------------------------------
/data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt
--------------------------------------------------------------------------------
/data_samples/experiment_3D/pred_0_sketch3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_0_sketch3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/experiment_3D/pred_1_sketch3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_1_sketch3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/experiment_3D/pred_2_sketch3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_2_sketch3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/experiment_3D/pred_3_sketch3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/experiment_3D/pred_3_sketch3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/image3D_sim_CTCCE_0.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/image3D_sim_CTCCE_0.h5
--------------------------------------------------------------------------------
/data_samples/image3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/image3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/image_files_3D.csv:
--------------------------------------------------------------------------------
1 | image3D_sim_CTCCE_0.h5
2 |
--------------------------------------------------------------------------------
/data_samples/sketch3D_sim_CTCCE_0.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/sketch3D_sim_CTCCE_0.h5
--------------------------------------------------------------------------------
/data_samples/sketch3D_sim_CTCCE_0.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/data_samples/sketch3D_sim_CTCCE_0.tif
--------------------------------------------------------------------------------
/data_samples/sketch_files_3D.csv:
--------------------------------------------------------------------------------
1 | sketch3D_sim_CTCCE_0.h5
2 |
--------------------------------------------------------------------------------
/figures/example_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/example_data.png
--------------------------------------------------------------------------------
/figures/multi-channel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/multi-channel.png
--------------------------------------------------------------------------------
/figures/overlapping_cells.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/overlapping_cells.png
--------------------------------------------------------------------------------
/figures/timeseries.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/figures/timeseries.gif
--------------------------------------------------------------------------------
/models/DiffusionModel2D.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import json
5 | import torch
6 | import torch.nn.functional as F
7 | import pytorch_lightning as pl
8 | import torchvision
9 |
10 | from argparse import ArgumentParser, Namespace
11 | from torch.utils.data import DataLoader
12 | from dataloader.h5_dataloader import MeristemH5Dataset
13 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
14 |
15 |
16 |
17 | class DiffusionModel2D(pl.LightningModule):
18 |
19 | def __init__(self, hparams):
20 | super(DiffusionModel2D, self).__init__()
21 |
22 | if type(hparams) is dict:
23 | hparams = Namespace(**hparams)
24 | self.save_hyperparameters(hparams)
25 | self.augmentation_dict = {}
26 |
27 | # load the backbone network architecture
28 | if self.hparams.backbone.lower() == 'unet2d_pixelshuffle_inject':
29 | from models.module_UNet2D_pixelshuffle_inject import module_UNet2D_pixelshuffle_inject as backbone
30 | else:
31 | raise ValueError('Unknown backbone architecture {0}!'.format(self.hparams.backbone))
32 |
33 | self.network = backbone(patch_size=self.hparams.patch_size, in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels,\
34 | feat_channels=self.hparams.feat_channels, t_channels=self.hparams.t_channels,\
35 | out_activation=self.hparams.out_activation, layer_norm=self.hparams.layer_norm)
36 | # cache for generated images
37 | self.last_predictions = None
38 | self.last_imgs = None
39 |
40 | # set up diffusion parameters
41 | device="cuda" if torch.cuda.is_available() else "cpu"
42 | self.DiffusionTrainer = GaussianDiffusionTrainer(self.hparams.num_timesteps, schedule=self.hparams.diffusion_schedule).to(device)
43 | self.DiffusionSampler = GaussianDiffusionSampler(self.network, self.hparams.num_timesteps, t_start=self.hparams.num_timesteps,\
44 | t_save=self.hparams.num_timesteps, t_step=1, schedule=self.hparams.diffusion_schedule,\
45 | mean_type='epsilon', var_type='fixedlarge').to(device)
46 |
47 | def forward(self, z, t):
48 | return self.network(z, t)
49 |
50 |
51 | def load_pretrained(self, pretrained_file, strict=True, verbose=True):
52 |
53 | # Load the state dict
54 | state_dict = torch.load(pretrained_file)['state_dict']
55 |
56 | # Make sure to have a weight dict
57 | if not isinstance(state_dict, dict):
58 | state_dict = dict(state_dict)
59 |
60 | # Get parameter dict of current model
61 | param_dict = dict(self.network.named_parameters())
62 |
63 | layers = []
64 | for layer in param_dict:
65 | if strict and not 'network.'+layer in state_dict:
66 | if verbose:
67 | print('Could not find weights for layer "{0}"'.format(layer))
68 | continue
69 | try:
70 | param_dict[layer].data.copy_(state_dict['network.'+layer].data)
71 | layers.append(layer)
72 | except (RuntimeError, KeyError) as e:
73 | print('Error at layer {0}:\n{1}'.format(layer, e))
74 |
75 | self.network.load_state_dict(param_dict)
76 |
77 | if verbose:
78 | print('Loaded weights for the following layers:\n{0}'.format(layers))
79 |
80 |
81 | def denoise_loss(self, y_hat, y):
82 | return F.mse_loss(y_hat, y)
83 |
84 |
85 | def training_step(self, batch, batch_idx):
86 |
87 |         # Get image and mask of current batch
88 | self.last_imgs = batch['image'].float()
89 |
90 | # get x_t, noise for a random t
91 | self.x_t, noise, t = self.DiffusionTrainer(self.last_imgs[:,0:1,...])
92 | self.x_t.requires_grad = True
93 |
94 | # generate prediction
95 | self.generated_noise = self.forward(torch.cat((self.x_t, self.last_imgs[:,1:,...]), axis=1), t)
96 |
97 | # get the losses
98 | loss_denoise = self.denoise_loss(self.generated_noise, noise)
99 |
100 | self.logger.experiment.add_scalar('loss_denoise', loss_denoise, self.current_epoch)
101 |
102 | return loss_denoise
103 |
104 |
105 | def test_step(self, batch, batch_idx):
106 | x = batch['image']
107 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index))
108 | return {'test_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])}
109 |
110 | def test_end(self, outputs):
111 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
112 | tensorboard_logs = {'test_loss': avg_loss}
113 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
114 |
115 | def validation_step(self, batch, batch_idx):
116 | x = batch['image']
117 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index))
118 | return {'val_loss': F.l1_loss(x[:,0:1,...]-x_hat, x[:,0:1,...])}
119 |
120 | def validation_end(self, outputs):
121 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
122 | tensorboard_logs = {'val_loss': avg_loss}
123 | return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
124 |
125 | def configure_optimizers(self):
126 | opt = torch.optim.RAdam(self.network.parameters(), lr=self.hparams.learning_rate)
127 | return [opt], []
128 |
129 | def train_dataloader(self):
130 | if self.hparams.train_list is None:
131 | return None
132 | else:
133 | dataset = MeristemH5Dataset(self.hparams.train_list, self.hparams.data_root, patch_size=self.hparams.patch_size,\
134 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, reduce_dim=True,\
135 | augmentation_dict=self.augmentation_dict, samples_per_epoch=self.hparams.samples_per_epoch,\
136 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none', \
137 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
138 | return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, drop_last=True)
139 |
140 | def test_dataloader(self):
141 | if self.hparams.test_list is None:
142 | return None
143 | else:
144 | dataset = MeristemH5Dataset(self.hparams.test_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=True,\
145 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\
146 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\
147 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
148 | return DataLoader(dataset, batch_size=self.hparams.batch_size)
149 |
150 | def val_dataloader(self):
151 | if self.hparams.val_list is None:
152 | return None
153 | else:
154 | dataset = MeristemH5Dataset(self.hparams.val_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=True,\
155 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\
156 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\
157 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
158 | return DataLoader(dataset, batch_size=self.hparams.batch_size)
159 |
160 |
161 | def on_train_epoch_end(self):
162 |
163 |
164 | self.DiffusionSampler.model = self.network
165 |
166 | with torch.no_grad():
167 |
168 | input_patch = self.last_imgs
169 |
170 | # get x_0
171 | x_0,_ = self.DiffusionSampler(input_patch[:,0:1,...], input_patch[:,1:,...])
172 |
173 | # log sampled images
174 | prediction_grid = torchvision.utils.make_grid(x_0)
175 | self.logger.experiment.add_image('predicted_x_0', prediction_grid, self.current_epoch)
176 |
177 | img_grid = torchvision.utils.make_grid(input_patch)
178 | self.logger.experiment.add_image('raw_x_0', img_grid, self.current_epoch)
179 |
180 |
181 |     def set_augmentations(self, augmentation_dict_file):
182 |         if augmentation_dict_file is not None:
183 |             with open(augmentation_dict_file) as f:
184 |                 self.augmentation_dict = json.load(f)
185 |
186 | @staticmethod
187 | def add_model_specific_args(parent_parser):
188 | """
189 | Parameters you define here will be available to your model through self.hparams
190 | """
191 | parser = ArgumentParser(parents=[parent_parser])
192 |
193 | # network params
194 | parser.add_argument('--backbone', default='UNet2D_PixelShuffle_inject', type=str, help='which model to load (UNet3D_PixelShuffle_inject)')
195 | parser.add_argument('--in_channels', default=2, type=int)
196 | parser.add_argument('--out_channels', default=1, type=int)
197 | parser.add_argument('--feat_channels', default=16, type=int)
198 | parser.add_argument('--t_channels', default=128, type=int)
199 | parser.add_argument('--patch_size', default=(1,256,256), type=int, nargs='+')
200 | parser.add_argument('--layer_norm', default='instance', type=str)
201 | parser.add_argument('--out_activation', default='none', type=str)
202 |
203 | # data
204 | parser.add_argument('--data_norm', default='minmax_shifted', type=str)
205 | parser.add_argument('--data_root', default='/data/root', type=str)
206 | parser.add_argument('--train_list', default='/path/to/training_data/split1_train.csv', type=str)
207 | parser.add_argument('--test_list', default='/path/to/testing_data/split1_test.csv', type=str)
208 | parser.add_argument('--val_list', default='/path/to/validation_data/split1_val.csv', type=str)
209 | parser.add_argument('--image_groups', default=('data/image',), type=str, nargs='+')
210 | parser.add_argument('--mask_groups', default=('data/diffusion_mask',), type=str, nargs='+')
211 | parser.add_argument('--image_noise_channel', default=-1, type=int)
212 | parser.add_argument('--mask_noise_channel', default=-1, type=int)
213 | parser.add_argument('--noise_type', default='gaussian', type=str)
214 |
215 | # diffusion parameter
216 | parser.add_argument('--num_timesteps', default=1000, type=int)
217 | parser.add_argument('--diffusion_schedule', default='cosine', type=str)
218 |
219 | # training params
220 | parser.add_argument('--samples_per_epoch', default=-1, type=int)
221 | parser.add_argument('--batch_size', default=1, type=int)
222 | parser.add_argument('--learning_rate', default=0.001, type=float)
223 |
224 | return parser
--------------------------------------------------------------------------------
/models/DiffusionModel3D.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import json
5 | import torch
6 | import torch.nn.functional as F
7 | import pytorch_lightning as pl
8 | import torchvision
9 |
10 | from argparse import ArgumentParser, Namespace
11 | from torch.utils.data import DataLoader
12 | from dataloader.h5_dataloader import MeristemH5Dataset
13 | from ThirdParty.diffusion import GaussianDiffusionTrainer, GaussianDiffusionSampler
14 |
15 |
16 |
17 | class DiffusionModel3D(pl.LightningModule):
18 |
19 | def __init__(self, hparams):
20 | super(DiffusionModel3D, self).__init__()
21 |
22 | if type(hparams) is dict:
23 | hparams = Namespace(**hparams)
24 | self.save_hyperparameters(hparams)
25 | self.augmentation_dict = {}
26 |
27 | # load the backbone network architecture
28 | if self.hparams.backbone.lower() == 'unet3d_pixelshuffle_inject':
29 | from models.module_UNet3D_pixelshuffle_inject import module_UNet3D_pixelshuffle_inject as backbone
30 | else:
31 | raise ValueError('Unknown backbone architecture {0}!'.format(self.hparams.backbone))
32 |
33 | self.network = backbone(patch_size=self.hparams.patch_size, in_channels=self.hparams.in_channels, out_channels=self.hparams.out_channels,\
34 | feat_channels=self.hparams.feat_channels, t_channels=self.hparams.t_channels,\
35 | out_activation=self.hparams.out_activation, layer_norm=self.hparams.layer_norm)
36 | # cache for generated images
37 | self.last_predictions = None
38 | self.last_imgs = None
39 |
40 | # set up diffusion parameters
41 | device="cuda" if torch.cuda.is_available() else "cpu"
42 | self.DiffusionTrainer = GaussianDiffusionTrainer(self.hparams.num_timesteps, schedule=self.hparams.diffusion_schedule).to(device)
43 | self.DiffusionSampler = GaussianDiffusionSampler(self.network, self.hparams.num_timesteps, t_start=self.hparams.num_timesteps,\
44 | t_save=self.hparams.num_timesteps, t_step=1, schedule=self.hparams.diffusion_schedule,\
45 | mean_type='epsilon', var_type='fixedlarge').to(device)
46 |
47 | def forward(self, z, t):
48 | return self.network(z, t)
49 |
50 |
51 | def load_pretrained(self, pretrained_file, strict=True, verbose=True):
52 |
53 | # Load the state dict
54 | state_dict = torch.load(pretrained_file)['state_dict']
55 |
56 | # Make sure to have a weight dict
57 | if not isinstance(state_dict, dict):
58 | state_dict = dict(state_dict)
59 |
60 | # Get parameter dict of current model
61 | param_dict = dict(self.network.named_parameters())
62 |
63 | layers = []
64 | for layer in param_dict:
65 | if strict and not 'network.'+layer in state_dict:
66 | if verbose:
67 | print('Could not find weights for layer "{0}"'.format(layer))
68 | continue
69 | try:
70 | param_dict[layer].data.copy_(state_dict['network.'+layer].data)
71 | layers.append(layer)
72 | except (RuntimeError, KeyError) as e:
73 | print('Error at layer {0}:\n{1}'.format(layer, e))
74 |
75 | self.network.load_state_dict(param_dict)
76 |
77 | if verbose:
78 | print('Loaded weights for the following layers:\n{0}'.format(layers))
79 |
80 |
81 | def denoise_loss(self, y_hat, y):
82 | return F.mse_loss(y_hat, y)
83 |
84 |
85 | def training_step(self, batch, batch_idx):
86 |
87 | # Get the image and conditioning channels of the current batch
88 | self.last_imgs = batch['image'].float()
89 |
90 | # get x_t, noise for a random t
91 | self.x_t, noise, t = self.DiffusionTrainer(self.last_imgs[:,0:1,...])
92 | self.x_t.requires_grad = True
93 |
94 | # generate prediction
95 | self.generated_noise = self.forward(torch.cat((self.x_t, self.last_imgs[:,1:,...]), axis=1), t)
96 |
97 | # get the losses
98 | loss_denoise = self.denoise_loss(self.generated_noise, noise)
99 |
100 | self.logger.experiment.add_scalar('loss_denoise', loss_denoise, self.current_epoch)
101 |
102 | return loss_denoise
103 |
104 |
105 | def test_step(self, batch, batch_idx):
106 | x = batch['image']
107 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index))
108 | return {'test_loss': F.l1_loss(x_hat, x[:,0:1,...])}
109 |
110 | def test_epoch_end(self, outputs):
111 | avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
112 | tensorboard_logs = {'test_loss': avg_loss}
113 | return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
114 |
115 | def validation_step(self, batch, batch_idx):
116 | x = batch['image']
117 | x_hat = self.forward(x, torch.tensor([0,], device=x.device.index))
118 | return {'val_loss': F.l1_loss(x_hat, x[:,0:1,...])}
119 |
120 | def validation_epoch_end(self, outputs):
121 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
122 | tensorboard_logs = {'val_loss': avg_loss}
123 | return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
124 |
125 | def configure_optimizers(self):
126 | opt = torch.optim.RAdam(self.network.parameters(), lr=self.hparams.learning_rate)
127 | return [opt], []
128 |
129 | def train_dataloader(self):
130 | if self.hparams.train_list is None:
131 | return None
132 | else:
133 | dataset = MeristemH5Dataset(self.hparams.train_list, self.hparams.data_root, patch_size=self.hparams.patch_size,\
134 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, reduce_dim=False,\
135 | augmentation_dict=self.augmentation_dict, samples_per_epoch=self.hparams.samples_per_epoch,\
136 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none', \
137 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
138 | return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True, drop_last=True)
139 |
140 | def test_dataloader(self):
141 | if self.hparams.test_list is None:
142 | return None
143 | else:
144 | dataset = MeristemH5Dataset(self.hparams.test_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=False,\
145 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\
146 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\
147 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
148 | return DataLoader(dataset, batch_size=self.hparams.batch_size)
149 |
150 | def val_dataloader(self):
151 | if self.hparams.val_list is None:
152 | return None
153 | else:
154 | dataset = MeristemH5Dataset(self.hparams.val_list, self.hparams.data_root, patch_size=self.hparams.patch_size, reduce_dim=False,\
155 | image_groups=self.hparams.image_groups, mask_groups=self.hparams.mask_groups, augmentation_dict={},\
156 | data_norm=self.hparams.data_norm, no_mask=True, boundary_handling='none',\
157 | image_noise_channel=self.hparams.image_noise_channel, mask_noise_channel=self.hparams.mask_noise_channel, noise_type=self.hparams.noise_type)
158 | return DataLoader(dataset, batch_size=self.hparams.batch_size)
159 |
160 |
161 | def on_train_epoch_end(self):
162 | 
163 | # Use the most recent network weights for sampling
164 | self.DiffusionSampler.model = self.network
165 |
166 | with torch.no_grad():
167 |
168 | input_patch = self.last_imgs
169 |
170 | # get x_0
171 | x_0,_ = self.DiffusionSampler(input_patch[:,0:1,...], input_patch[:,1:,...])
172 |
173 | # log sampled images
174 | prediction_grid = torchvision.utils.make_grid(x_0[...,int(self.hparams.patch_size[0]//2),:,:])
175 | self.logger.experiment.add_image('predicted_x_0', prediction_grid, self.current_epoch)
176 |
177 | img_grid = torchvision.utils.make_grid(input_patch[...,int(self.hparams.patch_size[0]//2),:,:])
178 | self.logger.experiment.add_image('raw_x_0', img_grid, self.current_epoch)
179 |
180 |
181 | def set_augmentations(self, augmentation_dict_file):
182 | if augmentation_dict_file is not None:
183 | with open(augmentation_dict_file) as file_handle:
184 | self.augmentation_dict = json.load(file_handle)
185 |
186 | @staticmethod
187 | def add_model_specific_args(parent_parser):
188 | """
189 | Parameters you define here will be available to your model through self.hparams
190 | """
191 | parser = ArgumentParser(parents=[parent_parser])
192 |
193 | # network params
194 | parser.add_argument('--backbone', default='UNet3D_PixelShuffle_inject', type=str, help='which model to load (UNet3D_PixelShuffle_inject)')
195 | parser.add_argument('--in_channels', default=2, type=int)
196 | parser.add_argument('--out_channels', default=1, type=int)
197 | parser.add_argument('--feat_channels', default=16, type=int)
198 | parser.add_argument('--t_channels', default=128, type=int)
199 | parser.add_argument('--patch_size', default=(32,128,128), type=int, nargs='+')
200 | parser.add_argument('--layer_norm', default='instance', type=str)
201 | parser.add_argument('--out_activation', default='none', type=str)
202 |
203 | # data
204 | parser.add_argument('--data_norm', default='minmax_shifted', type=str)
205 | parser.add_argument('--data_root', default='../data_samples', type=str)
206 | parser.add_argument('--train_list', default='../data_samples/image_files_3D.csv', type=str)
207 | parser.add_argument('--test_list', default='../data_samples/image_files_3D.csv', type=str)
208 | parser.add_argument('--val_list', default='../data_samples/image_files_3D.csv', type=str)
209 | parser.add_argument('--image_groups', default=('data/image',), type=str, nargs='+')
210 | parser.add_argument('--mask_groups', default=('data/image',), type=str, nargs='+')
211 | parser.add_argument('--image_noise_channel', default=-1, type=int)
212 | parser.add_argument('--mask_noise_channel', default=-1, type=int)
213 | parser.add_argument('--noise_type', default='gaussian', type=str)
214 |
215 | # diffusion parameters
216 | parser.add_argument('--num_timesteps', default=1000, type=int)
217 | parser.add_argument('--diffusion_schedule', default='cosine', type=str)
218 |
219 | # training params
220 | parser.add_argument('--samples_per_epoch', default=-1, type=int)
221 | parser.add_argument('--batch_size', default=1, type=int)
222 | parser.add_argument('--learning_rate', default=0.001, type=float)
223 |
224 | return parser
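225 | 
226 | 
227 | if __name__ == '__main__':
228 |     from argparse import ArgumentParser
229 |     # Hedged usage sketch (illustrative, not part of the pipeline): build the
230 |     # model from its default hyperparameters, mirroring what train_script.py
231 |     # does via argparse. Assumes the repository root is on the Python path.
232 |     parser = DiffusionModel3D.add_model_specific_args(ArgumentParser(add_help=False))
233 |     model = DiffusionModel3D(parser.parse_known_args()[0])
234 |     print(model.hparams.backbone, model.hparams.patch_size)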
--------------------------------------------------------------------------------
/models/module_UNet2D_pixelshuffle_inject.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Implementation of the 2D UNet architecture with PixelShuffle upsampling.
4 | https://arxiv.org/pdf/1609.05158v2.pdf
5 | """
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 |
11 | from ThirdParty.diffusion import SinusoidalPositionEmbeddings
12 |
13 |
14 |
15 | class module_UNet2D_pixelshuffle_inject(nn.Module):
16 | """Implementation of the 3D U-Net architecture.
17 | """
18 |
19 | def __init__(self, patch_size, in_channels, out_channels, feat_channels=16, t_channels=128, out_activation='sigmoid', layer_norm='none', **kwargs):
20 | super(module_UNet2D_pixelshuffle_inject, self).__init__()
21 |
22 | self.patch_size = patch_size
23 | self.in_channels = in_channels
24 | self.out_channels = out_channels
25 | self.feat_channels = feat_channels
26 | self.t_channels = t_channels
27 | self.layer_norm = layer_norm # instance | batch | none
28 | self.out_activation = out_activation # relu | leakyrelu | sigmoid | tanh | hardtanh | none
29 |
30 | self.norm_methods = {
31 | 'instance': nn.InstanceNorm2d,
32 | 'batch': nn.BatchNorm2d,
33 | 'none': nn.Identity
34 | }
35 |
36 | self.out_activations = nn.ModuleDict({
37 | 'relu': nn.ReLU(),
38 | 'leakyrelu': nn.LeakyReLU(negative_slope=0.2, inplace=True),
39 | 'sigmoid': nn.Sigmoid(),
40 | 'tanh': nn.Tanh(),
41 | 'hardtanh': nn.Hardtanh(0,1),
42 | 'none': nn.Identity()
43 | })
44 |
45 |
46 | # Define layer instances
47 | self.t1 = nn.Sequential(
48 | SinusoidalPositionEmbeddings(t_channels),
49 | nn.Linear(t_channels, feat_channels),
50 | nn.PReLU(feat_channels),
51 | nn.Linear(feat_channels, feat_channels)
52 | )
53 | self.c1 = nn.Sequential(
54 | nn.Conv2d(in_channels, feat_channels//2, kernel_size=3, padding=1),
55 | nn.PReLU(feat_channels//2),
56 | self.norm_methods[self.layer_norm](feat_channels//2),
57 | nn.Conv2d(feat_channels//2, feat_channels, kernel_size=3, padding=1),
58 | nn.PReLU(feat_channels),
59 | self.norm_methods[self.layer_norm](feat_channels)
60 | )
61 | self.d1 = nn.Sequential(
62 | nn.Conv2d(feat_channels, feat_channels, kernel_size=4, stride=2, padding=1),
63 | nn.PReLU(feat_channels),
64 | self.norm_methods[self.layer_norm](feat_channels)
65 | )
66 |
67 |
68 | self.t2 = nn.Sequential(
69 | SinusoidalPositionEmbeddings(t_channels),
70 | nn.Linear(t_channels, feat_channels*2),
71 | nn.PReLU(feat_channels*2),
72 | nn.Linear(feat_channels*2, feat_channels*2)
73 | )
74 | self.c2 = nn.Sequential(
75 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1),
76 | nn.PReLU(feat_channels),
77 | self.norm_methods[self.layer_norm](feat_channels),
78 | nn.Conv2d(feat_channels, feat_channels*2, kernel_size=3, padding=1),
79 | nn.PReLU(feat_channels*2),
80 | self.norm_methods[self.layer_norm](feat_channels*2)
81 | )
82 | self.d2 = nn.Sequential(
83 | nn.Conv2d(feat_channels*2, feat_channels*2, kernel_size=4, stride=2, padding=1),
84 | nn.PReLU(feat_channels*2),
85 | self.norm_methods[self.layer_norm](feat_channels*2)
86 | )
87 |
88 |
89 | self.t3 = nn.Sequential(
90 | SinusoidalPositionEmbeddings(t_channels),
91 | nn.Linear(t_channels, feat_channels*4),
92 | nn.PReLU(feat_channels*4),
93 | nn.Linear(feat_channels*4, feat_channels*4)
94 | )
95 | self.c3 = nn.Sequential(
96 | nn.Conv2d(feat_channels*2, feat_channels*2, kernel_size=3, padding=1),
97 | nn.PReLU(feat_channels*2),
98 | self.norm_methods[self.layer_norm](feat_channels*2),
99 | nn.Conv2d(feat_channels*2, feat_channels*4, kernel_size=3, padding=1),
100 | nn.PReLU(feat_channels*4),
101 | self.norm_methods[self.layer_norm](feat_channels*4)
102 | )
103 | self.d3 = nn.Sequential(
104 | nn.Conv2d(feat_channels*4, feat_channels*4, kernel_size=4, stride=2, padding=1),
105 | nn.PReLU(feat_channels*4),
106 | self.norm_methods[self.layer_norm](feat_channels*4)
107 | )
108 |
109 |
110 | self.t4 = nn.Sequential(
111 | SinusoidalPositionEmbeddings(t_channels),
112 | nn.Linear(t_channels, feat_channels*8),
113 | nn.PReLU(feat_channels*8),
114 | nn.Linear(feat_channels*8, feat_channels*8)
115 | )
116 | self.c4 = nn.Sequential(
117 | nn.Conv2d(feat_channels*4, feat_channels*4, kernel_size=3, padding=1),
118 | nn.PReLU(feat_channels*4),
119 | self.norm_methods[self.layer_norm](feat_channels*4),
120 | nn.Conv2d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1),
121 | nn.PReLU(feat_channels*8),
122 | self.norm_methods[self.layer_norm](feat_channels*8)
123 | )
124 |
125 |
126 | self.u1 = nn.Sequential(
127 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1),
128 | nn.PReLU(feat_channels*8),
129 | self.norm_methods[self.layer_norm](feat_channels*8),
130 | nn.PixelShuffle(2)
131 | )
132 | self.t5 = nn.Sequential(
133 | SinusoidalPositionEmbeddings(t_channels),
134 | nn.Linear(t_channels, feat_channels*8),
135 | nn.PReLU(feat_channels*8),
136 | nn.Linear(feat_channels*8, feat_channels*8)
137 | )
138 | self.c5 = nn.Sequential(
139 | nn.Conv2d(feat_channels*6, feat_channels*8, kernel_size=3, padding=1),
140 | nn.PReLU(feat_channels*8),
141 | self.norm_methods[self.layer_norm](feat_channels*8),
142 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1),
143 | nn.PReLU(feat_channels*8),
144 | self.norm_methods[self.layer_norm](feat_channels*8)
145 | )
146 |
147 |
148 | self.u2 = nn.Sequential(
149 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1),
150 | nn.PReLU(feat_channels*8),
151 | self.norm_methods[self.layer_norm](feat_channels*8),
152 | nn.PixelShuffle(2)
153 | )
154 | self.t6 = nn.Sequential(
155 | SinusoidalPositionEmbeddings(t_channels),
156 | nn.Linear(t_channels, feat_channels*8),
157 | nn.PReLU(feat_channels*8),
158 | nn.Linear(feat_channels*8, feat_channels*8)
159 | )
160 | self.c6 = nn.Sequential(
161 | nn.Conv2d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1),
162 | nn.PReLU(feat_channels*8),
163 | self.norm_methods[self.layer_norm](feat_channels*8),
164 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1),
165 | nn.PReLU(feat_channels*8),
166 | self.norm_methods[self.layer_norm](feat_channels*8)
167 | )
168 |
169 |
170 | self.u3 = nn.Sequential(
171 | nn.Conv2d(feat_channels*8, feat_channels*8, kernel_size=1),
172 | nn.PReLU(feat_channels*8),
173 | self.norm_methods[self.layer_norm](feat_channels*8),
174 | nn.PixelShuffle(2)
175 | )
176 | self.t7 = nn.Sequential(
177 | SinusoidalPositionEmbeddings(t_channels),
178 | nn.Linear(t_channels, feat_channels),
179 | nn.PReLU(feat_channels),
180 | nn.Linear(feat_channels, feat_channels)
181 | )
182 | self.c7 = nn.Sequential(
183 | nn.Conv2d(feat_channels*3, feat_channels, kernel_size=3, padding=1),
184 | nn.PReLU(feat_channels),
185 | self.norm_methods[self.layer_norm](feat_channels),
186 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1),
187 | nn.PReLU(feat_channels),
188 | self.norm_methods[self.layer_norm](feat_channels)
189 | )
190 |
191 |
192 | self.out = nn.Sequential(
193 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1),
194 | nn.PReLU(feat_channels),
195 | self.norm_methods[self.layer_norm](feat_channels),
196 | nn.Conv2d(feat_channels, feat_channels, kernel_size=3, padding=1),
197 | nn.PReLU(feat_channels),
198 | self.norm_methods[self.layer_norm](feat_channels),
199 | nn.Conv2d(feat_channels, out_channels, kernel_size=1),
200 | self.out_activations[self.out_activation]
201 | )
202 |
203 |
204 | def forward(self, img, t):
205 |
206 | t1 = self.t1(t)
207 | c1 = self.c1(img)+t1[...,None,None]
208 | d1 = self.d1(c1)
209 |
210 | t2 = self.t2(t)
211 | c2 = self.c2(d1)+t2[...,None,None]
212 | d2 = self.d2(c2)
213 |
214 | t3 = self.t3(t)
215 | c3 = self.c3(d2)+t3[...,None,None]
216 | d3 = self.d3(c3)
217 |
218 | t4 = self.t4(t)
219 | c4 = self.c4(d3)+t4[...,None,None]
220 |
221 | u1 = self.u1(c4)
222 | t5 = self.t5(t)
223 | c5 = self.c5(torch.cat((u1,c3),1))+t5[...,None,None]
224 |
225 | u2 = self.u2(c5)
226 | t6 = self.t6(t)
227 | c6 = self.c6(torch.cat((u2,c2),1))+t6[...,None,None]
228 |
229 | u3 = self.u3(c6)
230 | t7 = self.t7(t)
231 | c7 = self.c7(torch.cat((u3,c1),1))+t7[...,None,None]
232 |
233 | out = self.out(c7)
234 |
235 | return out
236 |
237 |
--------------------------------------------------------------------------------
/models/module_UNet3D_pixelshuffle_inject.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Implementation of the 3D UNet architecture with PixelShuffle upsampling.
4 | https://arxiv.org/pdf/1609.05158v2.pdf
5 | """
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 |
11 | from ThirdParty.layers import PixelShuffle3d
12 | from ThirdParty.diffusion import SinusoidalPositionEmbeddings
13 |
14 |
15 |
16 | class module_UNet3D_pixelshuffle_inject(nn.Module):
17 | """Implementation of the 3D U-Net architecture.
18 | """
19 |
20 | def __init__(self, patch_size, in_channels, out_channels, feat_channels=16, t_channels=128, out_activation='sigmoid', layer_norm='none', **kwargs):
21 | super(module_UNet3D_pixelshuffle_inject, self).__init__()
22 |
23 | self.patch_size = patch_size
24 | self.in_channels = in_channels
25 | self.out_channels = out_channels
26 | self.feat_channels = feat_channels
27 | self.t_channels = t_channels
28 | self.layer_norm = layer_norm # instance | batch | none
29 | self.out_activation = out_activation # relu | leakyrelu | sigmoid | tanh | hardtanh | none
30 |
31 | self.norm_methods = {
32 | 'instance': nn.InstanceNorm3d,
33 | 'batch': nn.BatchNorm3d,
34 | 'none': nn.Identity
35 | }
36 |
37 | self.out_activations = nn.ModuleDict({
38 | 'relu': nn.ReLU(),
39 | 'leakyrelu': nn.LeakyReLU(negative_slope=0.2, inplace=True),
40 | 'sigmoid': nn.Sigmoid(),
41 | 'tanh': nn.Tanh(),
42 | 'hardtanh': nn.Hardtanh(0,1),
43 | 'none': nn.Identity()
44 | })
45 |
46 |
47 | # Define layer instances
48 | self.t1 = nn.Sequential(
49 | SinusoidalPositionEmbeddings(t_channels),
50 | nn.Linear(t_channels, feat_channels),
51 | nn.PReLU(feat_channels),
52 | nn.Linear(feat_channels, feat_channels)
53 | )
54 | self.c1 = nn.Sequential(
55 | nn.Conv3d(in_channels, feat_channels//2, kernel_size=3, padding=1),
56 | nn.PReLU(feat_channels//2),
57 | self.norm_methods[self.layer_norm](feat_channels//2),
58 | nn.Conv3d(feat_channels//2, feat_channels, kernel_size=3, padding=1),
59 | nn.PReLU(feat_channels),
60 | self.norm_methods[self.layer_norm](feat_channels)
61 | )
62 | self.d1 = nn.Sequential(
63 | nn.Conv3d(feat_channels, feat_channels, kernel_size=4, stride=2, padding=1),
64 | nn.PReLU(feat_channels),
65 | self.norm_methods[self.layer_norm](feat_channels)
66 | )
67 |
68 |
69 | self.t2 = nn.Sequential(
70 | SinusoidalPositionEmbeddings(t_channels),
71 | nn.Linear(t_channels, feat_channels*2),
72 | nn.PReLU(feat_channels*2),
73 | nn.Linear(feat_channels*2, feat_channels*2)
74 | )
75 | self.c2 = nn.Sequential(
76 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1),
77 | nn.PReLU(feat_channels),
78 | self.norm_methods[self.layer_norm](feat_channels),
79 | nn.Conv3d(feat_channels, feat_channels*2, kernel_size=3, padding=1),
80 | nn.PReLU(feat_channels*2),
81 | self.norm_methods[self.layer_norm](feat_channels*2)
82 | )
83 | self.d2 = nn.Sequential(
84 | nn.Conv3d(feat_channels*2, feat_channels*2, kernel_size=4, stride=2, padding=1),
85 | nn.PReLU(feat_channels*2),
86 | self.norm_methods[self.layer_norm](feat_channels*2)
87 | )
88 |
89 |
90 | self.t3 = nn.Sequential(
91 | SinusoidalPositionEmbeddings(t_channels),
92 | nn.Linear(t_channels, feat_channels*4),
93 | nn.PReLU(feat_channels*4),
94 | nn.Linear(feat_channels*4, feat_channels*4)
95 | )
96 | self.c3 = nn.Sequential(
97 | nn.Conv3d(feat_channels*2, feat_channels*2, kernel_size=3, padding=1),
98 | nn.PReLU(feat_channels*2),
99 | self.norm_methods[self.layer_norm](feat_channels*2),
100 | nn.Conv3d(feat_channels*2, feat_channels*4, kernel_size=3, padding=1),
101 | nn.PReLU(feat_channels*4),
102 | self.norm_methods[self.layer_norm](feat_channels*4)
103 | )
104 | self.d3 = nn.Sequential(
105 | nn.Conv3d(feat_channels*4, feat_channels*4, kernel_size=4, stride=2, padding=1),
106 | nn.PReLU(feat_channels*4),
107 | self.norm_methods[self.layer_norm](feat_channels*4)
108 | )
109 |
110 |
111 | self.t4 = nn.Sequential(
112 | SinusoidalPositionEmbeddings(t_channels),
113 | nn.Linear(t_channels, feat_channels*8),
114 | nn.PReLU(feat_channels*8),
115 | nn.Linear(feat_channels*8, feat_channels*8)
116 | )
117 | self.c4 = nn.Sequential(
118 | nn.Conv3d(feat_channels*4, feat_channels*4, kernel_size=3, padding=1),
119 | nn.PReLU(feat_channels*4),
120 | self.norm_methods[self.layer_norm](feat_channels*4),
121 | nn.Conv3d(feat_channels*4, feat_channels*8, kernel_size=3, padding=1),
122 | nn.PReLU(feat_channels*8),
123 | self.norm_methods[self.layer_norm](feat_channels*8)
124 | )
125 |
126 |
127 | self.u1 = nn.Sequential(
128 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1),
129 | nn.PReLU(feat_channels*8),
130 | self.norm_methods[self.layer_norm](feat_channels*8),
131 | PixelShuffle3d(2)
132 | )
133 | self.t5 = nn.Sequential(
134 | SinusoidalPositionEmbeddings(t_channels),
135 | nn.Linear(t_channels, feat_channels*8),
136 | nn.PReLU(feat_channels*8),
137 | nn.Linear(feat_channels*8, feat_channels*8)
138 | )
139 | self.c5 = nn.Sequential(
140 | nn.Conv3d(feat_channels*5, feat_channels*8, kernel_size=3, padding=1),
141 | nn.PReLU(feat_channels*8),
142 | self.norm_methods[self.layer_norm](feat_channels*8),
143 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1),
144 | nn.PReLU(feat_channels*8),
145 | self.norm_methods[self.layer_norm](feat_channels*8)
146 | )
147 |
148 |
149 | self.u2 = nn.Sequential(
150 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1),
151 | nn.PReLU(feat_channels*8),
152 | self.norm_methods[self.layer_norm](feat_channels*8),
153 | PixelShuffle3d(2)
154 | )
155 | self.t6 = nn.Sequential(
156 | SinusoidalPositionEmbeddings(t_channels),
157 | nn.Linear(t_channels, feat_channels*8),
158 | nn.PReLU(feat_channels*8),
159 | nn.Linear(feat_channels*8, feat_channels*8)
160 | )
161 | self.c6 = nn.Sequential(
162 | nn.Conv3d(feat_channels*3, feat_channels*8, kernel_size=3, padding=1),
163 | nn.PReLU(feat_channels*8),
164 | self.norm_methods[self.layer_norm](feat_channels*8),
165 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=3, padding=1),
166 | nn.PReLU(feat_channels*8),
167 | self.norm_methods[self.layer_norm](feat_channels*8)
168 | )
169 |
170 |
171 | self.u3 = nn.Sequential(
172 | nn.Conv3d(feat_channels*8, feat_channels*8, kernel_size=1),
173 | nn.PReLU(feat_channels*8),
174 | self.norm_methods[self.layer_norm](feat_channels*8),
175 | PixelShuffle3d(2)
176 | )
177 | self.t7 = nn.Sequential(
178 | SinusoidalPositionEmbeddings(t_channels),
179 | nn.Linear(t_channels, feat_channels),
180 | nn.PReLU(feat_channels),
181 | nn.Linear(feat_channels, feat_channels)
182 | )
183 | self.c7 = nn.Sequential(
184 | nn.Conv3d(feat_channels*2, feat_channels, kernel_size=3, padding=1),
185 | nn.PReLU(feat_channels),
186 | self.norm_methods[self.layer_norm](feat_channels),
187 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1),
188 | nn.PReLU(feat_channels),
189 | self.norm_methods[self.layer_norm](feat_channels)
190 | )
191 |
192 |
193 | self.out = nn.Sequential(
194 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1),
195 | nn.PReLU(feat_channels),
196 | self.norm_methods[self.layer_norm](feat_channels),
197 | nn.Conv3d(feat_channels, feat_channels, kernel_size=3, padding=1),
198 | nn.PReLU(feat_channels),
199 | self.norm_methods[self.layer_norm](feat_channels),
200 | nn.Conv3d(feat_channels, out_channels, kernel_size=1),
201 | self.out_activations[self.out_activation]
202 | )
203 |
204 |
205 | def forward(self, img, t):
206 |
207 | t1 = self.t1(t)
208 | c1 = self.c1(img)+t1[...,None,None,None]
209 | d1 = self.d1(c1)
210 |
211 | t2 = self.t2(t)
212 | c2 = self.c2(d1)+t2[...,None,None,None]
213 | d2 = self.d2(c2)
214 |
215 | t3 = self.t3(t)
216 | c3 = self.c3(d2)+t3[...,None,None,None]
217 | d3 = self.d3(c3)
218 |
219 | t4 = self.t4(t)
220 | c4 = self.c4(d3)+t4[...,None,None,None]
221 |
222 | u1 = self.u1(c4)
223 | t5 = self.t5(t)
224 | c5 = self.c5(torch.cat((u1,c3),1))+t5[...,None,None,None]
225 |
226 | u2 = self.u2(c5)
227 | t6 = self.t6(t)
228 | c6 = self.c6(torch.cat((u2,c2),1))+t6[...,None,None,None]
229 |
230 | u3 = self.u3(c6)
231 | t7 = self.t7(t)
232 | c7 = self.c7(torch.cat((u3,c1),1))+t7[...,None,None,None]
233 |
234 | out = self.out(c7)
235 |
236 | return out
237 |
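238 | 
239 | if __name__ == '__main__':
240 |     # Hedged smoke test (illustrative only): pass a random input through the
241 |     # network and check that the spatial patch size is preserved. Each patch
242 |     # dimension must be divisible by 8 (three stride-2 downsampling steps).
243 |     net = module_UNet3D_pixelshuffle_inject(patch_size=(32,64,64), in_channels=2, out_channels=1)
244 |     x = torch.randn(1, 2, 32, 64, 64)  # (batch, channels, z, y, x)
245 |     t = torch.randint(0, 1000, (1,))   # one diffusion timestep per batch item
246 |     print(net(x, t).shape)             # expected: torch.Size([1, 1, 32, 64, 64])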
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | The provided Jupyter notebooks give an overview of how the data preparation, training and application pipelines work.
2 | Parameters that are important to these pipelines and might require further clarification are listed and explained below.
3 |
4 | #### General Network Parameters
5 | - `in_channels`: (int) Number of input channels, including an additional noise channel
6 | - `out_channels`: (int) Number of output channels
7 | - `feat_channels`: (int) Number of feature channels in the first block of the UNet backbone. Other channel numbers are derived from this
8 | - `t_channels`: (int) Number of channels used for timestep encoding
9 | - `patch_size`: (int+) Size of the patches. For 2D data, still provide a 3D patch size with a leading 1, i.e. (1,y,x)
10 | - `data_root`: (str) Directory of the image data
11 | - `train_list`, `test_list`, `val_list`: (str) Path to the csv file listing all image data that should be used for training/testing/validation. The concatenation of the _data_root_ parameter and the entries of this list should give the full path to each file.
12 | - `image_groups`: (str+) List of group names in the hdf5 files that should be used
13 | - `num_timesteps`: (int) Number of total diffusion timesteps
14 | - `diffusion_schedule`: (str) Noise schedule used during the diffusion forward process
15 |
16 | #### Training-specific Parameters
17 | - `output_path`: (str) Directory for saving the model
18 | - `log_path`: (str) Directory for saving the log files
19 | - `no_resume`: (bool) Flag to not resume training from an existing checkpoint with the same name as the current training experiment
20 | - `pretrained`: (str) Explicit definition of a pretrained model for further training
21 | - `epochs`: (int) Number of training epochs
22 | - `samples_per_epoch`: (int) Number of samples used in one training epoch. Set to -1 to use all available samples
23 |
24 | #### Application-specific Parameters
25 | - `output_path`: (str) Directory for saving the generated image data
26 | - `ckpt_path`: (str) Path to the trained model checkpoint file
27 | - `overlap`: (int+) Tuple of overlap of neighbouring patches during the patch-based application, defined in z,y,x
28 | - `crop`: (int+) Tuple of cropped patch borders to avoid potential border artifacts, defined in z,y,x
29 | - `clip`: (int+) Intensity range the output gets clipped to
30 | - `timesteps_start`: (int) Timestep for initiating the diffusion backward process
31 | - `timesteps_save`: (int) Interval of timesteps after which (intermediate) results are saved
32 | - `timesteps_step`: (int) Number of timesteps skipped between consecutive iterations of the diffusion backward process
33 | - `blur_sigma`: (int) Sigma of the Gaussian blurring applied before starting the diffusion forward process
34 | - `num_files`: (int) Number of files that should be processed. Set to -1 to process all available files
35 |
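36 | 
37 | As an example of how the timestep parameters interact, the hypothetical sketch below lists the timesteps at which intermediate results would be saved (the actual loop lives in `apply_script_diffusion.py`):
38 | 
39 | ```python
40 | # Assumed loop structure of the backward process: start at t=timesteps_start,
41 | # step down by timesteps_step, save a result every timesteps_save steps.
42 | timesteps_start, timesteps_step, timesteps_save = 400, 1, 100
43 | save_points = [t for t in range(timesteps_start, 0, -timesteps_step)
44 |                if t % timesteps_save == 0]
45 | print(save_points)  # [400, 300, 200, 100]
46 | ```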
--------------------------------------------------------------------------------
/notebooks/jupyter_apply_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "3f58caaa",
6 | "metadata": {},
7 | "source": [
8 | "# Application Pipeline"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "5be99b70",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Imports\n",
19 | "\n",
20 | "import os\n",
21 | "import sys\n",
22 | "import torch\n",
23 | "import ipywidgets as wg\n",
24 | "from IPython.display import display, Javascript \n",
25 | "from argparse import ArgumentParser, Namespace\n",
26 | "\n",
27 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n",
28 | "from utils.jupyter_widgets import get_pipelin_widget, get_apply_parameter_widgets, get_execution_widgets"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "4a674686",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# Check for GPU\n",
39 | "\n",
40 | "use_cuda = torch.cuda.is_available()\n",
41 | "if use_cuda:\n",
42 | " print('The following GPU was found:\\n')\n",
43 | " print('CUDNN VERSION:', torch.backends.cudnn.version())\n",
44 | " print('Number CUDA Devices:', torch.cuda.device_count())\n",
45 | " print('CUDA Device Name:',torch.cuda.get_device_name(0))\n",
46 | " print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)\n",
47 | "else:\n",
48 | " print('No GPU was found. CPU will be used.')\n",
49 | "\n",
50 | "# Select a pipeline \n",
51 | "pipeline = get_pipelin_widget()\n",
52 | "display(pipeline)"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "id": "8162637c",
58 | "metadata": {},
59 | "source": [
60 | "---\n",
61 | "After executing the next block, please adapt all parameters accordingly.\n",
62 | "The pipeline expects a list of files that should be used for testing. \n",
63 | "Absolute paths to each files are automatically obtained by concatenating the provided data root and each entry of the file lists. When changing the selected pipeline, please again execute the following block.
\n",
64 | "A pretrained model is already provided with the repository for demonstration purposes. Further pretrained models can be downloaded from our website."
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "id": "22365964",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# Define general inference parameters\n",
75 | "params = {'output_path': '../data_samples/experiment_3D',\n",
76 | " 'ckpt_path': '../data_samples/experiment_3D/Diffusion_3DCTC_CE_epoch=4999.ckpt',\n",
77 | " 'gpus': use_cuda,\n",
78 | " 'overlap': (0,20,20),\n",
79 | " 'crop': (0,20,20),\n",
80 | " 'input_batch': 'image',\n",
81 | " 'clip': (-1.0, 1.0),\n",
82 | " 'num_files':-1,\n",
83 | " 'add_noise_channel': False,\n",
84 | " 'pipeline': pipeline.value,\n",
85 | " }\n",
86 | "\n",
87 | "params_diff = {'timesteps_start': 400,\n",
88 | " 'timesteps_save': 100,\n",
89 | " 'timesteps_step': 1,\n",
90 | " 'blur_sigma': 1\n",
91 | " }\n",
92 | "\n",
93 | "\n",
94 | "# Load selected pipeline\n",
95 | "if params['pipeline'].lower() == 'diffusionmodel3d':\n",
96 | " from models.DiffusionModel3D import DiffusionModel3D as network\n",
97 | " params.update(params_diff)\n",
98 | "elif params['pipeline'].lower() == 'diffusionmodel2d':\n",
99 | " from models.DiffusionModel2D import DiffusionModel2D as network\n",
100 | " params.update(params_diff)\n",
101 | "else:\n",
102 | " raise ValueError('Pipeline {0} unknown.'.format(params['pipeline']))\n",
103 | "\n",
104 | " \n",
105 | "# Get and show corresponding parameters\n",
106 | "pipeline_args = ArgumentParser(add_help=False)\n",
107 | "pipeline_args = network.add_model_specific_args(pipeline_args)\n",
108 | "pipeline_args = vars(pipeline_args.parse_known_args()[0])\n",
109 | "params = {**params, **pipeline_args}\n",
110 | "\n",
111 | "print('-'*60+'\\nPARAMETER FOR PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n",
112 | "param_names, widget_list = get_apply_parameter_widgets(params)\n",
113 | "for widget in widget_list: \n",
114 | " display(widget)\n",
115 | " \n",
116 | "print('-'*60+'\\nEXECUTION SETTINGS\\n'+'-'*60)\n",
117 | "wg_execute, wg_arguments = get_execution_widgets()\n",
118 | "display(wg_arguments)\n",
119 | "display(wg_execute)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "id": "19adcedb",
125 | "metadata": {},
126 | "source": [
127 | "---\n",
128 | "Finish preparations and start processing by executing the next block. The outputs are expected to be in the value range (-1,1)."
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "ec280ea3",
135 | "metadata": {
136 | "scrolled": false
137 | },
138 | "outputs": [],
139 | "source": [
140 | "# Get parameters\n",
141 | "param_names = [p for p,w in zip(param_names, widget_list) if w.value!=False and w.value!='']\n",
142 | "widget_list = [w for w in widget_list if w.value!=False and w.value!='']\n",
143 | "command_line_args = ' '.join(['--pipeline {0}'.format(pipeline.value)]+\\\n",
144 | " [n+' '+str(w.value) if not type(w.value)==bool else n\\\n",
145 | " for n,w in zip(param_names, widget_list)])\n",
146 | "\n",
147 | "# Show the command line arguments\n",
148 | "if wg_arguments.value:\n",
149 | " print('_'*90+'\\nCOMMAND LINE ARGUMENTS FOR apply_script_diffusion.py WITH PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*90)\n",
150 | " print(command_line_args)\n",
151 | " print('\\n')\n",
152 | " \n",
153 | "# Execute the pipeline\n",
154 | "if wg_execute.value:\n",
155 | " print('_'*60+'\\nEXECUTING PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n",
156 | " %run \"../apply_script_diffusion.py\" {command_line_args}"
157 | ]
158 | }
159 | ],
160 | "metadata": {
161 | "kernelspec": {
162 | "display_name": "Python 3 (ipykernel)",
163 | "language": "python",
164 | "name": "python3"
165 | },
166 | "language_info": {
167 | "codemirror_mode": {
168 | "name": "ipython",
169 | "version": 3
170 | },
171 | "file_extension": ".py",
172 | "mimetype": "text/x-python",
173 | "name": "python",
174 | "nbconvert_exporter": "python",
175 | "pygments_lexer": "ipython3",
176 | "version": "3.7.13"
177 | },
178 | "varInspector": {
179 | "cols": {
180 | "lenName": 16,
181 | "lenType": 16,
182 | "lenVar": 40
183 | },
184 | "kernels_config": {
185 | "python": {
186 | "delete_cmd_postfix": "",
187 | "delete_cmd_prefix": "del ",
188 | "library": "var_list.py",
189 | "varRefreshCmd": "print(var_dic_list())"
190 | },
191 | "r": {
192 | "delete_cmd_postfix": ") ",
193 | "delete_cmd_prefix": "rm(",
194 | "library": "var_list.r",
195 | "varRefreshCmd": "cat(var_dic_list()) "
196 | }
197 | },
198 | "types_to_exclude": [
199 | "module",
200 | "function",
201 | "builtin_function_or_method",
202 | "instance",
203 | "_Feature"
204 | ],
205 | "window_display": false
206 | },
207 | "vscode": {
208 | "interpreter": {
209 | "hash": "3d6afa663d3b7d8b7c28e0e5bf1fc62360d26f74485b03653b2cb99921ca431b"
210 | }
211 | }
212 | },
213 | "nbformat": 4,
214 | "nbformat_minor": 5
215 | }
216 |
--------------------------------------------------------------------------------
/notebooks/jupyter_preparation_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "e877057c",
6 | "metadata": {},
7 | "source": [
8 | "# Data Preparation Pipeline "
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "722a8a44",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Imports\n",
19 | "\n",
20 | "import os\n",
21 | "import glob\n",
22 | "from skimage import io\n",
23 | "\n",
24 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n",
25 | "from utils.h5_converter import prepare_images\n",
26 | "from utils.csv_generator import create_csv"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "id": "66084504",
32 | "metadata": {},
33 | "source": [
34 | "Both, the training and application scripts are designed to use the hdf5 data format. Therefore, all data samples need to be converted into this data format first and the function `prepare_images` can be used for that purpose.
\n",
35 | "Two example data samples are availabe in the _data_samples_ folder and the following function saves the converted data samples to the same folder. If another directory or folder is desired, the parameters _save_path_ and _save_folders_ can additionally be adapted accordingly."
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "id": "44dbe16b",
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "# Convert images to h5\n",
46 | "prepare_images(data_path=r'../', folders=['data_samples',], identifier='*.tif')"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "a3eb2027",
52 | "metadata": {},
53 | "source": [
54 | "As a next step, the converted data samples need to be listed in a csv file to make the accessible to the training and application pipelines.
\n",
55 | "For creating the file lists, the function `create_csv` can be used. In case a test or valdation split is desired, the parameters _test_split_ and _val_split_ can be adapted to provide percentages of the respective splits."
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "c7864ed0",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "# Get a list of all files\n",
66 | "image_files = glob.glob(r'../data_samples/image3D*.h5')\n",
67 | "sketch_files = glob.glob(r'../data_samples/sketch3D*.h5')\n",
68 | "\n",
69 | "# Save the lists as csv file\n",
70 | "create_csv([[os.path.split(i)[-1],] for i in image_files],\\\n",
71 | " save_path=r'../data_samples/image_files_3D', test_split=0, val_split=0)\n",
72 | "create_csv([[os.path.split(s)[-1]] for s in sketch_files],\\\n",
73 | " save_path=r'../data_samples/sketch_files_3D', test_split=0, val_split=0)"
74 | ]
75 | }
76 | ],
77 | "metadata": {
78 | "kernelspec": {
79 | "display_name": "Python 3 (ipykernel)",
80 | "language": "python",
81 | "name": "python3"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.7.13"
94 | },
95 | "varInspector": {
96 | "cols": {
97 | "lenName": 16,
98 | "lenType": 16,
99 | "lenVar": 40
100 | },
101 | "kernels_config": {
102 | "python": {
103 | "delete_cmd_postfix": "",
104 | "delete_cmd_prefix": "del ",
105 | "library": "var_list.py",
106 | "varRefreshCmd": "print(var_dic_list())"
107 | },
108 | "r": {
109 | "delete_cmd_postfix": ") ",
110 | "delete_cmd_prefix": "rm(",
111 | "library": "var_list.r",
112 | "varRefreshCmd": "cat(var_dic_list()) "
113 | }
114 | },
115 | "types_to_exclude": [
116 | "module",
117 | "function",
118 | "builtin_function_or_method",
119 | "instance",
120 | "_Feature"
121 | ],
122 | "window_display": false
123 | }
124 | },
125 | "nbformat": 4,
126 | "nbformat_minor": 5
127 | }
128 |
--------------------------------------------------------------------------------
/notebooks/jupyter_train_script.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c5a553ab",
6 | "metadata": {},
7 | "source": [
8 | "# Training Pipeline "
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "9914fa37",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Imports\n",
19 | "\n",
20 | "import os\n",
21 | "import sys\n",
22 | "import torch\n",
23 | "import ipywidgets as wg\n",
24 | "from IPython.display import display, Javascript \n",
25 | "from argparse import ArgumentParser, Namespace\n",
26 | "\n",
27 | "os.sys.path.append(os.path.dirname(os.path.abspath('.')))\n",
28 | "from utils.jupyter_widgets import get_pipelin_widget, get_parameter_widgets, get_execution_widgets"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "6dd050db",
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "# Check for GPU\n",
39 | "\n",
40 | "use_cuda = torch.cuda.is_available()\n",
41 | "if use_cuda:\n",
42 | " print('The following GPU was found:\\n')\n",
43 | " print('CUDNN VERSION:', torch.backends.cudnn.version())\n",
44 | " print('Number CUDA Devices:', torch.cuda.device_count())\n",
45 | " print('CUDA Device Name:',torch.cuda.get_device_name(0))\n",
46 | " print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)\n",
47 | "else:\n",
48 | " print('No GPU was found. CPU will be used.')\n",
49 | "\n",
50 | "# Select a pipeline \n",
51 | "pipeline = get_pipelin_widget()\n",
52 | "display(pipeline)"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "id": "fb2c9954",
58 | "metadata": {},
59 | "source": [
60 | "---\n",
61 | "After executing the next block, please adapt all parameters accordingly.\n",
62 | "The pipeline expects lists of files that should be used for training, testing and validation. \n",
63 | "Absolute paths to each files are automatically obtained by concatenating the provided data root and each entry of the file lists.\n",
64 | "When changing the selected pipeline, please again execute the following block."
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "id": "d5642dad",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# Define general training parameters\n",
75 | "params = {'output_path': '../data_samples/experiment_3D',\n",
76 | " 'log_path': '../data_samples/experiment_3D/logs',\n",
77 | " 'gpus': use_cuda,\n",
78 | " 'no_resume': False,\n",
79 | " 'pretrained': None,\n",
80 | " 'augmentations': None,\n",
81 | " 'epochs': 5000,\n",
82 | " 'pipeline': pipeline.value}\n",
83 | "\n",
84 | "# Load selected pipeline\n",
85 | "if params['pipeline'].lower() == 'diffusionmodel3d':\n",
86 | " from models.DiffusionModel3D import DiffusionModel3D as network\n",
87 | "elif params['pipeline'].lower() == 'diffusionmodel2d':\n",
88 | " from models.DiffusionModel2D import DiffusionModel2D as network\n",
89 | "else:\n",
90 | " raise ValueError('Pipeline {0} unknown.'.format(params['pipeline']))\n",
91 | "# Get and show corresponding parameters\n",
92 | "pipeline_args = ArgumentParser(add_help=False)\n",
93 | "pipeline_args = network.add_model_specific_args(pipeline_args)\n",
94 | "pipeline_args = vars(pipeline_args.parse_known_args()[0])\n",
95 | "params = {**params, **pipeline_args}\n",
96 | "\n",
97 | "print('-'*60+'\\nPARAMETER FOR PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n",
98 | "param_names, widget_list, _ = get_parameter_widgets(params)\n",
99 | "for widget in widget_list: \n",
100 | " display(widget)\n",
101 | " \n",
102 | "print('-'*60+'\\nEXECUTION SETTINGS\\n'+'-'*60)\n",
103 | "wg_execute, wg_arguments = get_execution_widgets()\n",
104 | "display(wg_arguments)\n",
105 | "display(wg_execute)"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "id": "1efbe8ae",
111 | "metadata": {},
112 | "source": [
113 | "---\n",
114 | "Finish preparations and start processing by executing the next block."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "2daffbc0",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# Get parameters\n",
125 | "param_names = [p for p,w in zip(param_names, widget_list) if w.value!=False and w.value!='']\n",
126 | "widget_list = [w for w in widget_list if w.value!=False and w.value!='']\n",
127 | "command_line_args = ' '.join(['--pipeline {0}'.format(pipeline.value)]+\\\n",
128 | " [n+' '+str(w.value) if not type(w.value)==bool else n\\\n",
129 | " for n,w in zip(param_names, widget_list)])\n",
130 | "\n",
131 | "# Show the command line arguments\n",
132 | "if wg_arguments.value:\n",
133 | " print('_'*90+'\\nCOMMAND LINE ARGUMENTS FOR train_script.py WITH PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*90)\n",
134 | " print(command_line_args)\n",
135 | " print('\\n')\n",
136 | " \n",
137 | "# Execute the pipeline\n",
138 | "if wg_execute.value:\n",
139 | " print('_'*60+'\\nEXECUTING PIPELINE \"{0}\"\\n'.format(pipeline.value)+'-'*60)\n",
140 | " %run \"../train_script.py\" {command_line_args}"
141 | ]
142 | }
143 | ],
144 | "metadata": {
145 | "kernelspec": {
146 | "display_name": "Python 3 (ipykernel)",
147 | "language": "python",
148 | "name": "python3"
149 | },
150 | "language_info": {
151 | "codemirror_mode": {
152 | "name": "ipython",
153 | "version": 3
154 | },
155 | "file_extension": ".py",
156 | "mimetype": "text/x-python",
157 | "name": "python",
158 | "nbconvert_exporter": "python",
159 | "pygments_lexer": "ipython3",
160 | "version": "3.7.13"
161 | },
162 | "varInspector": {
163 | "cols": {
164 | "lenName": 16,
165 | "lenType": 16,
166 | "lenVar": 40
167 | },
168 | "kernels_config": {
169 | "python": {
170 | "delete_cmd_postfix": "",
171 | "delete_cmd_prefix": "del ",
172 | "library": "var_list.py",
173 | "varRefreshCmd": "print(var_dic_list())"
174 | },
175 | "r": {
176 | "delete_cmd_postfix": ") ",
177 | "delete_cmd_prefix": "rm(",
178 | "library": "var_list.r",
179 | "varRefreshCmd": "cat(var_dic_list()) "
180 | }
181 | },
182 | "types_to_exclude": [
183 | "module",
184 | "function",
185 | "builtin_function_or_method",
186 | "instance",
187 | "_Feature"
188 | ],
189 | "window_display": false
190 | },
191 | "vscode": {
192 | "interpreter": {
193 | "hash": "3d6afa663d3b7d8b7c28e0e5bf1fc62360d26f74485b03653b2cb99921ca431b"
194 | }
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 5
199 | }
200 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.1.0
2 | aiohttp==3.8.1
3 | aiosignal==1.2.0
4 | alabaster==0.7.12
5 | appdirs==1.4.4
6 | argon2-cffi==21.3.0
7 | argon2-cffi-bindings==21.2.0
8 | arrow==1.2.2
9 | astroid==2.11.5
10 | astropy==4.3.1
11 | async-timeout==4.0.2
12 | asynctest==0.13.0
13 | atomicwrites==1.4.0
14 | attrs==21.4.0
15 | autopep8==1.6.0
16 | Babel==2.10.1
17 | backcall==0.2.0
18 | beautifulsoup4==4.11.1
19 | binaryornot==0.4.4
20 | black==22.3.0
21 | bleach==5.0.0
22 | cachetools==5.2.0
23 | certifi==2022.5.18.1
24 | cffi==1.15.0
25 | chardet==4.0.0
26 | charset-normalizer==2.0.12
27 | click==8.1.3
28 | cloudpickle==2.1.0
29 | cookiecutter==2.1.1
30 | cryptography==37.0.2
31 | cycler==0.11.0
32 | Cython==0.29.30
33 | debugpy==1.6.0
34 | decorator==5.1.1
35 | defusedxml==0.7.1
36 | diff-match-patch==20200713
37 | dill==0.3.5.1
38 | dipy==1.5.0
39 | docutils==0.17.1
40 | entrypoints==0.4
41 | fastjsonschema==2.15.3
42 | flake8==4.0.1
43 | fonttools==4.33.3
44 | frozenlist==1.3.0
45 | fsspec==2022.5.0
46 | google-auth==2.7.0
47 | google-auth-oauthlib==0.4.6
48 | grpcio==1.46.3
49 | gryds @ git+https://github.com/tueimage/gryds@cda4bac8f71e8bb47fc632b8cdea010904ae5cf1
50 | h5py==3.7.0
51 | hdbscan==0.8.28
52 | idna==3.3
53 | imagecodecs==2021.11.20
54 | imageio==2.19.3
55 | imagesize==1.3.0
56 | importlib-metadata==4.2.0
57 | importlib-resources==5.7.1
58 | inflection==0.5.1
59 | intervaltree==3.1.0
60 | ipykernel==6.13.1
61 | ipython==7.34.0
62 | ipython-genutils==0.2.0
63 | ipywidgets==7.7.0
64 | isort==5.10.1
65 | jedi==0.18.1
66 | jeepney==0.8.0
67 | jellyfish==0.9.0
68 | Jinja2==3.1.2
69 | jinja2-time==0.2.0
70 | joblib==1.1.0
71 | jsonschema==4.6.0
72 | jupyter==1.0.0
73 | jupyter-client==7.3.4
74 | jupyter-console==6.4.3
75 | jupyter-core==4.10.0
76 | jupyterlab-pygments==0.2.2
77 | jupyterlab-widgets==1.1.0
78 | keyring==23.6.0
79 | kiwisolver==1.4.3
80 | lazy-object-proxy==1.7.1
81 | Markdown==3.3.7
82 | MarkupSafe==2.1.1
83 | matplotlib==3.5.2
84 | matplotlib-inline==0.1.3
85 | mccabe==0.6.1
86 | mistune==0.8.4
87 | multidict==6.0.2
88 | mypy-extensions==0.4.3
89 | nbclient==0.6.4
90 | nbconvert==6.5.0
91 | nbformat==5.4.0
92 | nest-asyncio==1.5.5
93 | networkx==2.6.3
94 | nibabel==4.0.1
95 | notebook==6.4.12
96 | numpy==1.21.6
97 | numpydoc==1.3.1
98 | oauthlib==3.2.0
99 | opencv-python==4.7.0.72
100 | packaging==21.3
101 | pandas==1.3.5
102 | pandocfilters==1.5.0
103 | parso==0.8.3
104 | pathspec==0.9.0
105 | pexpect==4.8.0
106 | pickleshare==0.7.5
107 | Pillow==9.1.1
108 | platformdirs==2.5.2
109 | pluggy==1.0.0
110 | pooch==1.6.0
111 | prometheus-client==0.14.1
112 | prompt-toolkit==3.0.29
113 | protobuf==3.19.4
114 | psutil==5.9.1
115 | ptyprocess==0.7.0
116 | pyasn1==0.4.8
117 | pyasn1-modules==0.2.8
118 | pycodestyle==2.8.0
119 | pycparser==2.21
120 | pyDeprecate==0.3.2
121 | pydocstyle==6.1.1
122 | pyerfa==2.0.0.1
123 | pyflakes==2.4.0
124 | Pygments==2.12.0
125 | pylint==2.14.1
126 | pyls-spyder==0.4.0
127 | pyparsing==3.0.9
128 | PyQt5==5.15.6
129 | PyQt5-Qt5==5.15.2
130 | PyQt5-sip==12.10.1
131 | PyQtWebEngine==5.15.5
132 | PyQtWebEngine-Qt5==5.15.2
133 | pyquaternion==0.9.9
134 | pyrsistent==0.18.1
135 | pyshtools==4.10
136 | python-dateutil==2.8.2
137 | python-lsp-black==1.2.1
138 | python-lsp-jsonrpc==1.0.0
139 | python-lsp-server==1.4.1
140 | python-slugify==6.1.2
141 | pytorch-lightning==1.6.4
142 | pytz==2022.1
143 | PyWavelets==1.3.0
144 | pyxdg==0.28
145 | PyYAML==6.0
146 | pyzmq==23.1.0
147 | QDarkStyle==3.0.3
148 | qstylizer==0.2.1
149 | QtAwesome==1.1.1
150 | qtconsole==5.3.1
151 | QtPy==2.1.0
152 | requests==2.27.1
153 | requests-oauthlib==1.3.1
154 | rope==1.1.1
155 | rsa==4.8
156 | Rtree==1.0.0
157 | scikit-image==0.19.3
158 | scikit-learn==1.0.2
159 | scipy==1.7.3
160 | SecretStorage==3.3.2
161 | Send2Trash==1.8.0
162 | six==1.16.0
163 | snowballstemmer==2.2.0
164 | sortedcontainers==2.4.0
165 | soupsieve==2.3.2.post1
166 | Sphinx==4.3.2
167 | sphinxcontrib-applehelp==1.0.2
168 | sphinxcontrib-devhelp==1.0.2
169 | sphinxcontrib-htmlhelp==2.0.0
170 | sphinxcontrib-jsmath==1.0.1
171 | sphinxcontrib-qthelp==1.0.3
172 | sphinxcontrib-serializinghtml==1.1.5
173 | spyder==5.3.1
174 | spyder-kernels==2.3.1
175 | tensorboard==2.9.1
176 | tensorboard-data-server==0.6.1
177 | tensorboard-plugin-wit==1.8.1
178 | terminado==0.15.0
179 | text-unidecode==1.3
180 | textdistance==4.2.2
181 | threadpoolctl==3.1.0
182 | three-merge==0.1.1
183 | tifffile==2021.11.2
184 | tinycss2==1.1.1
185 | toml==0.10.2
186 | tomli==2.0.1
187 | tomlkit==0.11.0
188 | torch==1.11.0+cu113
189 | torchaudio==0.11.0+cu113
190 | torchmetrics==0.9.1
191 | torchvision==0.12.0
192 | tornado==6.1
193 | tqdm==4.64.0
194 | traitlets==5.2.2.post1
195 | typed-ast==1.5.4
196 | typing_extensions==4.2.0
197 | ujson==5.3.0
198 | urllib3==1.26.9
199 | watchdog==2.1.8
200 | wcwidth==0.2.5
201 | webencodings==0.5.1
202 | Werkzeug==2.1.2
203 | widgetsnbextension==3.6.0
204 | wrapt==1.14.1
205 | wurlitzer==3.0.2
206 | xarray==0.20.2
207 | yapf==0.32.0
208 | yarl==1.7.2
209 | zipp==3.8.0
210 |
--------------------------------------------------------------------------------
/train_script.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Jan 14 13:46:20 2020
5 |
6 | @author: eschweiler
7 | """
8 |
9 | from argparse import ArgumentParser
10 |
11 | import numpy as np
12 | import torch
13 | import glob
14 | import os
15 |
16 | from pytorch_lightning import Trainer
17 | from pytorch_lightning.loggers import TensorBoardLogger
18 | from pytorch_lightning.callbacks import ModelCheckpoint
19 |
20 | SEED = 1337
21 | torch.manual_seed(SEED)
22 | np.random.seed(SEED)
23 |
24 |
25 | def main(hparams):
26 |
27 | """
28 | Main training routine specific for this project
29 | :param hparams:
30 | """
31 |
32 | # ------------------------
33 | # 1 INIT LIGHTNING MODEL
34 | # ------------------------
35 | model = network(hparams=hparams)
36 | os.makedirs(hparams.output_path, exist_ok=True)
37 |
38 | # Load pretrained weights if available
39 | if hparams.pretrained is not None:
40 | model.load_pretrained(hparams.pretrained)
41 |
42 | # Resume from checkpoint if available
43 | resume_ckpt = None
44 | if hparams.resume:
45 | checkpoints = glob.glob(os.path.join(hparams.output_path,'*.ckpt'))
46 | checkpoints.sort(key=os.path.getmtime)
47 | if len(checkpoints)>0:
48 | resume_ckpt = checkpoints[-1]
49 | print('Resuming from checkpoint: {0}'.format(resume_ckpt))
50 |
51 | # Set the augmentations if available
52 | model.set_augmentations(hparams.augmentations)
53 |
54 | # Save a few samples for sanity checks
55 | print('Saving 20 data samples for sanity checks...')
56 | model.train_dataloader().dataset.test(os.path.join(hparams.output_path, 'samples'), num_files=20)
57 |
58 | # ------------------------
59 | # 2 INIT TRAINER
60 | # ------------------------
61 |
62 | checkpoint_callback = ModelCheckpoint(
63 | dirpath=hparams.output_path,
64 | filename=hparams.pipeline+'-{epoch:03d}-{step}',
65 | save_top_k=1,
66 | monitor='step',
67 | mode='max',
68 | verbose=True,
69 | every_n_epochs=1
70 | )
71 |
72 | logger = TensorBoardLogger(
73 | save_dir=hparams.log_path,
74 | name='lightning_logs_'+hparams.pipeline.lower()
75 | )
76 |
77 | trainer = Trainer(
78 | logger=logger,
79 | enable_checkpointing=True,
80 | callbacks=[checkpoint_callback],
81 | gpus=hparams.gpus,
82 | min_epochs=hparams.epochs,
83 | max_epochs=hparams.epochs,
84 | resume_from_checkpoint=resume_ckpt
85 | )
86 |
87 | # ------------------------
88 | # 3 START TRAINING
89 | # ------------------------
90 | trainer.fit(model)
91 |
92 |
93 |
94 | if __name__ == '__main__':
95 | # ------------------------
96 | # TRAINING ARGUMENTS
97 | # ------------------------
98 | # these are project-wide arguments
99 |
100 | parent_parser = ArgumentParser(add_help=False)
101 |
102 | # gpu args
103 | parent_parser.add_argument(
104 | '--output_path',
105 | type=str,
106 | default=r'results/experiment1',
107 | help='output path for checkpoints and sample images'
108 | )
109 |
110 | parent_parser.add_argument(
111 | '--log_path',
112 | type=str,
113 | default=r'logs/logs_experiment1',
114 | help='output path for log files'
115 | )
116 |
117 | parent_parser.add_argument(
118 | '--gpus',
119 | type=int,
120 | default=1,
121 | help='number of GPUs to use'
122 | )
123 |
124 | parent_parser.add_argument(
125 | '--no_resume',
126 | dest='resume',
127 | action='store_false',
128 | default=True,
129 | help='Do not resume training from latest checkpoint'
130 | )
131 |
132 | parent_parser.add_argument(
133 | '--pretrained',
134 | type=str,
135 | default=None,
136 | help='path to pretrained model weights '
137 | '(a single checkpoint file)'
138 | )
139 |
140 | parent_parser.add_argument(
141 | '--augmentations',
142 | type=str,
143 | default=None,
144 | help='path to augmentation dict file'
145 | )
146 |
147 | parent_parser.add_argument(
148 | '--epochs',
149 | type=int,
150 | default=5000,
151 | help='number of epochs'
152 | )
153 |
154 | parent_parser.add_argument(
155 | '--pipeline',
156 | type=str,
157 | default='DiffusionModel3D',
158 | help='which pipeline to load (DiffusionModel3D | DiffusionModel2D)'
159 | )
160 |
161 | parent_args = parent_parser.parse_known_args()[0]
162 |
163 | # load the desired network architecture
164 | if parent_args.pipeline.lower() == 'diffusionmodel3d':
165 | from models.DiffusionModel3D import DiffusionModel3D as network
166 | elif parent_args.pipeline.lower() == 'diffusionmodel2d':
167 | from models.DiffusionModel2D import DiffusionModel2D as network
168 | else:
169 | raise ValueError('Pipeline {0} unknown.'.format(parent_args.pipeline))
170 |
171 | # each LightningModule defines arguments relevant to it
172 | parser = network.add_model_specific_args(parent_parser)
173 | hyperparams = parser.parse_args()
174 |
175 | # ---------------------
176 | # RUN TRAINING
177 | # ---------------------
178 | main(hyperparams)
179 |
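180 | 
181 | # Example invocation (paths are illustrative):
182 | #   python train_script.py --pipeline DiffusionModel3D \
183 | #       --output_path results/experiment1 --log_path logs/logs_experiment1 \
184 | #       --data_root data_samples --train_list data_samples/image_files_3D.csv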
--------------------------------------------------------------------------------
/utils/csv_generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 19 15:55:13 2020
4 |
5 | @author: Nutzer
6 | """
7 |
8 | import os
9 | import glob
10 | import csv
11 | import numpy as np
12 |
13 |
14 | def get_files(folders, data_root='', descriptor='', filetype='tif'):
15 |
16 | filelist = []
17 |
18 | for folder in folders:
19 | files = glob.glob(os.path.join(data_root, folder, '*'+descriptor+'*.'+filetype))
20 | filelist.extend([os.path.join(folder, os.path.split(f)[-1]) for f in files])
21 |
22 | return filelist
23 |
24 |
25 |
26 | def read_csv(list_path, data_root=''):
27 |
28 | filelist = []
29 |
30 | with open(list_path, 'r') as f:
31 | reader = csv.reader(f, delimiter=';')
32 | for row in reader:
33 | if len(row)==0 or np.sum([len(r) for r in row])==0: continue
34 | row = [os.path.join(data_root, r) for r in row]
35 | filelist.append(row)
36 |
37 | return filelist
38 |
39 |
40 |
41 | def create_csv(data_list, save_path='list_folder/experiment_name', test_split=0.2, val_split=0.1, shuffle=False):
42 |
43 | if shuffle:
44 | np.random.shuffle(data_list)
45 |
46 | # Get number of files for each split
47 | num_files = len(data_list)
48 | num_test_files = int(test_split*num_files)
49 | num_val_files = int((num_files-num_test_files)*val_split)
50 | num_train_files = num_files - num_test_files - num_val_files
51 |
52 | # Use a '_train' suffix only if a test or validation split exists
53 | if test_split>0 or val_split>0:
54 | train_identifier='_train.csv'
55 | else:
56 | train_identifier='.csv'
57 |
58 | # Get file indices
59 | file_idx = np.arange(num_files)
60 |
61 | # Save csv files
62 | if num_test_files > 0:
63 | test_idx = sorted(np.random.choice(file_idx, size=num_test_files, replace=False))
64 | with open(save_path+'_test.csv', 'w', newline='') as fh:
65 | writer = csv.writer(fh, delimiter=';')
66 | for idx in test_idx:
67 | writer.writerow(data_list[idx])
68 | else:
69 | test_idx = []
70 |
71 | if num_val_files > 0:
72 | val_idx = sorted(np.random.choice(list(set(file_idx)-set(test_idx)), size=num_val_files, replace=False))
73 | with open(save_path+'_val.csv', 'w', newline='') as fh:
74 | writer = csv.writer(fh, delimiter=';')
75 | for idx in val_idx:
76 | writer.writerow(data_list[idx])
77 | else:
78 | val_idx = []
79 |
80 | if num_train_files > 0:
81 | train_idx = sorted(list(set(file_idx) - set(test_idx) - set(val_idx)))
82 | with open(save_path+train_identifier, 'w', newline='') as fh:
83 | writer = csv.writer(fh, delimiter=';')
84 | for idx in train_idx:
85 | writer.writerow(data_list[idx])
86 |
87 |
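# Minimal usage sketch (hedged; the data root and folder names are
# placeholders): pair image and mask files row-wise and write the
# train/val/test csv lists consumed by read_csv and the dataloader.
if __name__ == '__main__':
    image_files = get_files(['images'], data_root='/path/to/data', filetype='tif')
    mask_files = get_files(['masks'], data_root='/path/to/data', filetype='tif')
    os.makedirs('list_folder', exist_ok=True)
    create_csv(list(zip(image_files, mask_files)), save_path='list_folder/experiment_name',
               test_split=0.2, val_split=0.1, shuffle=True)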
--------------------------------------------------------------------------------
/utils/h5_converter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import os
5 | import h5py
6 | import glob
7 | import numpy as np
8 |
9 | from skimage import io, morphology, measure, filters
10 | from scipy.ndimage import distance_transform_edt, zoom, generic_filter
11 | from scipy.spatial import ConvexHull, Delaunay
12 |
13 | from utils.utils import print_timestamp
14 |
15 |
16 |
17 | def h5_writer(data_list, save_path, group_root='data', group_names=['image']):
18 |
19 | save_path = os.path.abspath(save_path)
20 |
21 | assert(len(data_list)==len(group_names)), 'Each data matrix needs a group name'
22 |
23 | with h5py.File(save_path, 'w') as f_handle:
24 | grp = f_handle.create_group(group_root)
25 | for data, group_name in zip(data_list, group_names):
26 | grp.create_dataset(group_name, data=data, chunks=True, compression='gzip')
27 |
28 |
29 |
30 | def h5_reader(file, source_group='data/image'):
31 |
32 | with h5py.File(file, 'r') as file_handle:
33 | data = file_handle[source_group][:]
34 |
35 | return data
36 |
37 |
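# Hedged round-trip sketch (hypothetical helper, not part of the original API):
# write a random volume using the 'data/<group>' layout assumed throughout this
# module and read it back.
def _h5_roundtrip_demo(save_path='demo.h5'):
    vol = np.random.rand(8, 64, 64).astype(np.float32)
    h5_writer([vol, (vol > 0.5).astype(np.float32)], save_path, group_names=['image', 'mask'])
    return h5_reader(save_path, source_group='data/image')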
38 |
39 |
40 | def h52tif(file_dir='', identifier='*', group_names=['data/image']):
41 |
42 | # Get all files within the given directory
43 | filelist = glob.glob(os.path.join(file_dir, identifier+'.h5'))
44 |
45 | # Create saving folders
46 | for group in group_names:
47 | os.makedirs(os.path.join(file_dir, ''.join(s for s in group if s.isalnum())), exist_ok=True)
48 |
49 | # Save each desired group
50 | for num_file,file in enumerate(filelist):
51 | print_timestamp('Processing file {0}/{1}', (num_file+1, len(filelist)))
52 | with h5py.File(file, 'r') as file_handle:
53 | for group in group_names:
54 | data = file_handle[group][:]
55 | io.imsave(os.path.join(file_dir, ''.join(s for s in group if s.isalnum()), os.path.split(file)[-1][:-2]+'tif'), data)
56 |
57 |
58 |
59 | def replace_h5_group(source_list, target_list, source_group='data/image', target_group=None):
60 |
61 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list))
62 | if target_group is None: target_group=source_group
63 |
64 | for num_pair, pair in enumerate(zip(source_list, target_list)):
65 | print_timestamp('Processing file {0}/{1}...', [num_pair+1, len(target_list)])
66 |
67 | # Load the source mask
68 | with h5py.File(pair[0], 'r') as source_handle:
69 | source_data = source_handle[source_group][...]
70 |
71 | # Save the data to the target file
72 | with h5py.File(pair[1], 'r+') as target_handle:
73 | target_data = target_handle[target_group]
74 | target_data[...] = source_data
75 |
76 |
77 |
78 |
79 | def add_group(file, data, target_group='data/image'):
80 |
81 | with h5py.File(file, 'a') as file_handle:
82 | file_handle.create_dataset(target_group, data=data, chunks=True, compression='gzip')
83 |
84 |
85 |
86 |
87 | def add_h5_group(source_list, target_list, source_group='data/distance', target_group=None):
88 |
89 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list))
90 | if target_group is None: target_group=source_group
91 |
92 | for num_pair, pair in enumerate(zip(source_list, target_list)):
93 |
94 | print_timestamp('Processing file {0}/{1}...', [num_pair+1, len(source_list)])
95 |
96 | # Get the data from the source file
97 | with h5py.File(pair[0], 'r') as source_handle:
98 | source_data = source_handle[source_group][...]
99 |
100 | # Save the data to the target file
101 | try:
102 | with h5py.File(pair[1], 'a') as target_handle:
103 | target_handle.create_dataset(target_group, data=source_data, chunks=True, compression='gzip')
104 |         except Exception:  # e.g. the target group already exists
105 | print_timestamp('Skipping file "{0}"...', [os.path.split(pair[1])[-1]])
106 |
107 |
108 |
109 |
110 | def add_tiff_group(source_list, target_list, target_group='data/newgroup'):
111 |
112 | assert len(target_list)==len(source_list), 'There needs to be one target ({0}) for each source ({1})!'.format(len(target_list), len(source_list))
113 | assert target_group is not None, 'There needs to be a target group name!'
114 |
115 | for num_pair, pair in enumerate(zip(source_list, target_list)):
116 |
117 | print_timestamp('Processing file {0}/{1}...', [num_pair+1, len(source_list)])
118 |
119 | # Get the data from the source file
120 | source_data = io.imread(pair[0])
121 |
122 | # Save the data to the target file
123 | with h5py.File(pair[1], 'a') as target_handle:
124 | target_handle.create_dataset(target_group, data=source_data-np.min(source_data), chunks=True, compression='gzip')
125 |
126 |
127 |
128 | def remove_h5_group(file_list, source_group='data/nuclei'):
129 |
130 | for num_file, file in enumerate(file_list):
131 |
132 | print_timestamp('Processing file {0}/{1}...', [num_file+1, len(file_list)])
133 |
134 | with h5py.File(file, 'a') as file_handle:
135 | del file_handle[source_group]
136 |
137 |
138 |
139 | def flood_fill_hull(image):
140 |
141 | # Credits: https://stackoverflow.com/questions/46310603/how-to-compute-convex-hull-image-volume-in-3d-numpy-arrays
142 | points = np.transpose(np.where(image))
143 | hull = ConvexHull(points)
144 | deln = Delaunay(points[hull.vertices])
145 | idx = np.stack(np.indices(image.shape), axis = -1)
146 | out_idx = np.nonzero(deln.find_simplex(idx) + 1)
147 | out_img = np.zeros(image.shape)
148 | out_img[out_idx] = 1
149 |
150 | return out_img
151 |
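# Hedged sketch (hypothetical helper): flood_fill_hull turns sparse foreground
# voxels into their solid convex hull; four non-coplanar seeds give a filled
# tetrahedron.
def _hull_demo():
    img = np.zeros((16, 16, 16), dtype=bool)
    img[1, 1, 1] = img[1, 14, 14] = img[14, 1, 14] = img[14, 14, 1] = True
    return flood_fill_hull(img)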
152 |
153 |
154 | def calculate_flows(instance_mask, bg_label=0):
155 |
156 | flow_x = np.zeros(instance_mask.shape, dtype=np.float32)
157 | flow_y = np.zeros(instance_mask.shape, dtype=np.float32)
158 | flow_z = np.zeros(instance_mask.shape, dtype=np.float32)
159 | regions = measure.regionprops(instance_mask)
160 | for props in regions:
161 |
162 | if props.label == bg_label:
163 | continue
164 |
165 | # get all coordinates within instance
166 | c = props.centroid
167 | coords = np.where(instance_mask==props.label)
168 |
169 |         # calculate the minimum extent in all spatial directions
170 | norm_x = np.maximum(1, np.minimum(np.abs(c[0]-props.bbox[0]),np.abs(c[0]-props.bbox[3]))/3)
171 | norm_y = np.maximum(1, np.minimum(np.abs(c[1]-props.bbox[1]),np.abs(c[1]-props.bbox[4]))/3)
172 | norm_z = np.maximum(1, np.minimum(np.abs(c[2]-props.bbox[2]),np.abs(c[2]-props.bbox[5]))/3)
173 |
174 | # calculate flows
175 | flow_x[coords] = np.tanh((coords[0]-c[0])/norm_x)
176 | flow_y[coords] = np.tanh((coords[1]-c[1])/norm_y)
177 | flow_z[coords] = np.tanh((coords[2]-c[2])/norm_z)
178 |
179 | return flow_x, flow_y, flow_z
180 |
181 |
182 |
183 | def rescale_data(data, zoom_factor, order=0):
184 |
185 | if any([zf!=1 for zf in zoom_factor]):
186 | data_shape = data.shape
187 | data = zoom(data, zoom_factor, order=order)
188 | print_timestamp('Rescaled image from size {0} to {1}'.format(data_shape, data.shape))
189 |
190 | return data
191 |
192 |
193 |
194 | def prepare_images(data_path='', folders=[''], identifier='*.tif', descriptor='', normalize=[0,100],\
195 | get_distance=False, get_illumination=False, get_variance=False, variance_size=(5,5,5),\
196 | fg_selem_size=5, zoom_factor=(1,1,1), channel=0, clip=(-99999,99999),\
197 | save_path=None, save_folders=None):
198 |
199 | data_path = os.path.abspath(data_path)
200 |
201 | if save_path is None: save_path = data_path
202 | if save_folders is None: save_folders = folders
203 | if len(save_folders)==1: save_folders = save_folders*len(folders)
204 | elif len(save_folders)!=len(folders):
205 | save_folders=folders
206 | print_timestamp('Could not save into the given folders! Number of save folders and folders must be the same.')
207 |
208 |
209 | for num_folder, (image_folder, save_folder) in enumerate(zip(folders, save_folders)):
210 | os.makedirs(os.path.join(save_path, save_folder), exist_ok=True)
211 | image_list = glob.glob(os.path.join(data_path, image_folder, identifier))
212 | for num_file,file in enumerate(image_list):
213 |
214 | print_timestamp('Processing image {0}/{1} in folder {2}/{3} {4}', (num_file+1, len(image_list), num_folder+1, len(folders), image_folder))
215 |
216 | # load the image
217 | processed_img = io.imread(file)
218 | processed_img = processed_img.astype(np.float32)
219 | processed_img = np.clip(processed_img, *clip)
220 |
221 | # get the desired channel, if the image is a multichannel image
222 | if processed_img.ndim == 4:
223 | processed_img = processed_img[...,channel]
224 |
225 | # get the desired image dimensionality, if image is only 2D
226 | if processed_img.ndim==2:
227 | processed_img = processed_img[np.newaxis,...]
228 |
229 | # rescale the image
230 | processed_img = rescale_data(processed_img, zoom_factor, order=3)
231 |
232 | # normalize the image
233 | perc1, perc2 = np.percentile(processed_img, list(normalize))
234 | processed_img -= perc1
235 | processed_img /= (perc2-perc1)
236 | processed_img = np.clip(processed_img, 0, 1)
237 | processed_img = processed_img.astype(np.float32)
238 |
239 | save_imgs = [processed_img,]
240 | save_groups = ['image',]
241 |
242 | if get_illumination:
243 |
244 | print_timestamp('Extracting illumination image...')
245 |
246 |                 # create a downscaled image for computationally intensive processing
247 | small_img = processed_img[::2,::2,::2]
248 |
249 |                 # create an illumination image from the downscaled data
250 | illu_img = morphology.closing(small_img, selem=morphology.ball(7))
251 | illu_img = filters.gaussian(illu_img, 2).astype(np.float32)
252 |
253 |                 # rescale the illumination image to the original size
254 | illu_img = np.repeat(illu_img, 2, axis=0)
255 | illu_img = np.repeat(illu_img, 2, axis=1)
256 | illu_img = np.repeat(illu_img, 2, axis=2)
257 |                 dim_mismatch = np.array(processed_img.shape)-np.array(illu_img.shape)
258 |                 if dim_mismatch[0]<0: illu_img = illu_img[:dim_mismatch[0],...]
259 |                 if dim_mismatch[1]<0: illu_img = illu_img[:,:dim_mismatch[1],:]
260 |                 if dim_mismatch[2]<0: illu_img = illu_img[...,:dim_mismatch[2]]
261 |
262 | save_imgs.append(illu_img.astype(np.float32))
263 | save_groups.append('illumination')
264 |
265 | if get_distance:
266 |
267 | print_timestamp('Extracting distance image...')
268 |
269 |                 # create a downscaled image for computationally intensive processing
270 | small_img = processed_img[::4,::4,::4]
271 |
272 | # find suitable threshold
273 | thresh = filters.threshold_otsu(small_img)
274 | fg_img = small_img > thresh
275 |
276 | # remove noise and fill holes
277 | fg_img = morphology.binary_closing(fg_img, selem=morphology.ball(fg_selem_size))
278 | fg_img = morphology.binary_opening(fg_img, selem=morphology.ball(fg_selem_size))
279 | fg_img = flood_fill_hull(fg_img)
280 |                 fg_img = fg_img.astype(bool)
281 |
282 | # create distance transform
283 | fg_img = distance_transform_edt(fg_img) - distance_transform_edt(~fg_img)
284 |
285 | # rescale distance image
286 | fg_img = np.repeat(fg_img, 4, axis=0)
287 | fg_img = np.repeat(fg_img, 4, axis=1)
288 | fg_img = np.repeat(fg_img, 4, axis=2)
289 |                 dim_mismatch = np.array(processed_img.shape)-np.array(fg_img.shape)
290 |                 if dim_mismatch[0]<0: fg_img = fg_img[:dim_mismatch[0],...]
291 |                 if dim_mismatch[1]<0: fg_img = fg_img[:,:dim_mismatch[1],:]
292 |                 if dim_mismatch[2]<0: fg_img = fg_img[...,:dim_mismatch[2]]
293 |
294 | save_imgs.append(fg_img.astype(np.float32))
295 | save_groups.append('distance')
296 |
297 | if get_variance:
298 |
299 | print_timestamp('Extracting variance image...')
300 |
301 |                 # create a downscaled image for computationally intensive processing
302 | small_img = processed_img[::4,::4,::4]
303 |
304 | # create variance image
305 | std_img = generic_filter(small_img, np.std, size=variance_size)
306 |
307 | # rescale variance image
308 | std_img = np.repeat(std_img, 4, axis=0)
309 | std_img = np.repeat(std_img, 4, axis=1)
310 | std_img = np.repeat(std_img, 4, axis=2)
311 |                 dim_mismatch = np.array(processed_img.shape)-np.array(std_img.shape)
312 |                 if dim_mismatch[0]<0: std_img = std_img[:dim_mismatch[0],...]
313 |                 if dim_mismatch[1]<0: std_img = std_img[:,:dim_mismatch[1],:]
314 |                 if dim_mismatch[2]<0: std_img = std_img[...,:dim_mismatch[2]]
315 |
316 | save_imgs.append(std_img.astype(np.float32))
317 | save_groups.append('variance')
318 |
319 | # save the data
320 | save_name = os.path.split(file)[-1]
321 | save_name = os.path.join(save_path, save_folder, descriptor+save_name[:-4]+'.h5')
322 | h5_writer(save_imgs, save_name, group_root='data', group_names=save_groups)
323 |
324 |
325 |
326 |
327 |
328 | def prepare_masks(data_path='', folders=[''], identifier='*.tif', descriptor='',\
329 | bg_label=0, get_flows=False, get_boundary=True, get_seeds=False, get_distance=True,\
330 | corrupt_prob=0.0, zoom_factor=(1,1,1), convex_hull=False,\
331 | save_path=None, save_folders=None):
332 |
333 | data_path = os.path.abspath(data_path)
334 |
335 | if save_path is None: save_path = data_path
336 | if save_folders is None: save_folders = folders
337 | if len(save_folders)==1: save_folders = save_folders*len(folders)
338 | elif len(save_folders)!=len(folders):
339 | save_folders=folders
340 | print_timestamp('Could not save into the given folders! Number of save folders and folders must be the same.')
341 |
342 | for num_folder, (mask_folder,save_folder) in enumerate(zip(folders,save_folders)):
343 | mask_list = glob.glob(os.path.join(data_path, mask_folder, identifier))
344 | experiment_identifier = 'corrupt'+str(corrupt_prob).replace('.','') if corrupt_prob > 0 else ''
345 |         os.makedirs(os.path.join(save_path, save_folder, experiment_identifier), exist_ok=True)
346 | for num_file,file in enumerate(mask_list):
347 |
348 | print_timestamp('Processing mask {0}/{1} in folder {2}/{3} {4}', (num_file+1, len(mask_list), num_folder+1, len(folders), mask_folder))
349 |
350 | # load the mask
351 | instance_mask = io.imread(file)
352 | instance_mask = instance_mask.astype(np.uint16)
353 | instance_mask[instance_mask==bg_label] = 0
354 |
355 | # get the desired image dimensionality, if image is only 2D
356 | if instance_mask.ndim==2:
357 | instance_mask = instance_mask[np.newaxis,...]
358 |
359 | # rescale the mask
360 | instance_mask = rescale_data(instance_mask, zoom_factor, order=0)
361 |
362 | if corrupt_prob > 0:
363 | # Randomly merge neighbouring instances
364 | labels = list(set(np.unique(instance_mask))-set([bg_label]))
365 | instance_mask_eroded = morphology.erosion(instance_mask, selem=morphology.ball(3))
366 | instance_mask_dilated = morphology.dilation(instance_mask, selem=morphology.ball(3))
367 | for label in labels:
368 | if np.random.rand() < corrupt_prob:
369 | neighbour_labels = list(instance_mask_eroded[instance_mask==label]) + list(instance_mask_dilated[instance_mask==label])
370 | neighbour_labels = list(set(neighbour_labels)-set([label,]))
371 | if len(neighbour_labels) > 0:
372 | replace_label = np.random.choice(neighbour_labels)
373 | instance_mask[instance_mask==label] = replace_label
374 |
375 | save_groups = ['instance',]
376 | save_masks = [instance_mask,]
377 |
378 | # get the boundary mask
379 | if get_boundary:
380 | membrane_mask = morphology.dilation(instance_mask, selem=morphology.ball(2)) - instance_mask
381 | membrane_mask = membrane_mask != 0
382 | membrane_mask = membrane_mask.astype(np.float32)
383 | save_groups.append('boundary')
384 | save_masks.append(membrane_mask)
385 |
386 | # get the distance mask
387 | if get_distance:
388 | fg_img = instance_mask[::4,::4,::4]>0
389 | if convex_hull: fg_img = flood_fill_hull(fg_img)
390 |                 fg_img = fg_img.astype(bool)
391 | distance_mask = distance_transform_edt(fg_img) - distance_transform_edt(~fg_img)
392 | distance_mask = distance_mask.astype(np.float32)
393 | distance_mask = np.repeat(distance_mask, 4, axis=0)
394 | distance_mask = np.repeat(distance_mask, 4, axis=1)
395 | distance_mask = np.repeat(distance_mask, 4, axis=2)
396 |                 dim_mismatch = np.array(instance_mask.shape)-np.array(distance_mask.shape)
397 |                 if dim_mismatch[0]<0: distance_mask = distance_mask[:dim_mismatch[0],...]
398 |                 if dim_mismatch[1]<0: distance_mask = distance_mask[:,:dim_mismatch[1],:]
399 |                 if dim_mismatch[2]<0: distance_mask = distance_mask[...,:dim_mismatch[2]]
400 | save_groups.append('distance')
401 | save_masks.append(distance_mask)
402 |
403 | # get the centroid mask
404 | if get_seeds:
405 | centroid_mask = np.zeros(instance_mask.shape, dtype=np.float32)
406 | regions = measure.regionprops(instance_mask)
407 |
408 | for props in regions:
409 |
410 | if props.label == bg_label:
411 | continue
412 |
413 | c = props.centroid
414 |                     centroid_mask[int(c[0]), int(c[1]), int(c[2])] = 1
415 |
416 | save_groups.append('seeds')
417 | save_masks.append(centroid_mask)
418 |
419 | # calculate the flow field
420 | if get_flows:
421 |
422 | flow_x, flow_y, flow_z = calculate_flows(instance_mask, bg_label=bg_label)
423 |
424 | save_groups.extend(['flow_x','flow_y', 'flow_z'])
425 | save_masks.extend([flow_x, flow_y, flow_z])
426 |
427 | # save the data
428 | save_name = os.path.split(file)[-1]
429 | save_name = os.path.join(save_path, save_folder, experiment_identifier, descriptor+save_name[:-4]+'.h5')
430 | h5_writer(save_masks, save_name, group_root='data', group_names=save_groups)
431 |
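# Hedged end-to-end sketch: all paths and folder names below are placeholders.
if __name__ == '__main__':
    prepare_images(data_path='/path/to/raw', folders=['images'], identifier='*.tif',
                   get_distance=True, save_path='/path/to/data_h5')
    prepare_masks(data_path='/path/to/raw', folders=['masks'], identifier='*.tif',
                  get_boundary=True, get_distance=True, save_path='/path/to/data_h5')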
--------------------------------------------------------------------------------
/utils/harmonics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Aug 5 11:28:53 2020
4 |
5 | @author: Nutzer
6 | """
7 |
8 | import numpy as np
9 | import multiprocessing as mp
10 |
11 | from functools import partial
12 | from skimage import morphology
13 | from scipy.spatial import cKDTree, Delaunay
14 | from scipy.special import sph_harm
15 | from matplotlib import pyplot as plt
16 | from mpl_toolkits.mplot3d import Axes3D
17 | from dipy.core.geometry import sphere2cart, cart2sphere
18 |
19 | from utils.utils import print_timestamp
20 |
21 |
22 |
23 | def scatter_3d(coords1, coords2, coords3, cartesian=True):
24 | # (x, y, z) or (r, theta, phi)
25 |
26 | if not cartesian:
27 | coords1, coords2, coords3 = sphere2cart(coords1, coords2, coords3)
28 |
29 | fig = plt.figure()
30 |     ax = fig.add_subplot(projection='3d')
31 | ax.scatter(coords1, coords2, coords3, depthshade=True)
32 | plt.show()
33 |
34 |
35 |
36 | def render_coeffs(coeffs, theta_phi_sampling_file, sh_order=5, radius=30):
37 |
38 | assert len(coeffs)==(sh_order+1)**2, 'SH order and number of coefficients do not match.'
39 |
40 | coeffs = np.array(coeffs)
41 | coeffs = coeffs/coeffs[0]
42 | coeffs *= radius
43 |
44 | img_shape = np.array(3*(radius*3,))
45 |
46 | theta_phi_sampling = np.load(theta_phi_sampling_file)
47 | h2s = harmonics2sampling(sh_order, theta_phi_sampling)
48 | r_sampling = h2s.convert(coeffs[np.newaxis,:])
49 |
50 | instance_mask = sampling2instance([img_shape//2,], r_sampling, theta_phi_sampling, img_shape, verbose=True)
51 |
52 | return instance_mask
53 |
54 |
55 |
56 | def instance2sampling(instances, theta_phi_sampling, bg_label=1, centroids=None, verbose=True):
57 |
58 | # Get labels
59 | labels = np.unique(instances)
60 | labels = np.array(list(set(labels)-set([bg_label])))
61 |
62 | sampling = np.zeros((len(labels),theta_phi_sampling.shape[0]))
63 |
64 | if centroids is None:
65 | get_centroids = True
66 | centroids = np.zeros((len(labels),3))
67 | else:
68 | get_centroids = False
69 | assert len(labels) == len(centroids), 'There needs to be a centroid for each label!'
70 |
71 | for num_label, label in enumerate(labels):
72 |
73 | if verbose: print_timestamp('Sampling instance {0}/{1}...', [num_label+1, len(labels)])
74 |
75 | if np.count_nonzero(instances==label) < 9:
76 | print_timestamp('Skipping label {0} due to its tiny size.', [label])
77 | continue
78 |
79 | # get inner boundary of the current instance
80 | instance_inner = morphology.binary_erosion(instances==label, selem=morphology.ball(1))
81 | instance_inner = np.logical_xor(instance_inner, instances==label)
82 |
83 | # get binary instance mask
84 | x,y,z = np.where(instance_inner)
85 | if get_centroids:
86 | centroids[num_label,:] = [x.mean(),y.mean(),z.mean()]
87 | x -= int(centroids[num_label,0])
88 | y -= int(centroids[num_label,1])
89 | z -= int(centroids[num_label,2])
90 | r,theta,phi = cart2sphere(x, y, z)
91 |
92 | # find closest sampling angles
93 | sampling_tree = cKDTree(np.array([theta,phi]).T)
94 | _, assignments = sampling_tree.query(theta_phi_sampling, k=3)
95 |
96 | # get sampling
97 | sampling[num_label,:] = np.mean(r[assignments], axis=1)
98 |
99 | return labels, centroids, sampling
100 |
101 |
102 |
103 | def sampling2instance(centroids, sampling, theta_phi_sampling, shape, verbose=True):
104 |
105 | instances = np.full(shape, 0, dtype=np.uint16)
106 | idx = np.reshape(np.indices(shape), (3,-1))
107 |
108 | label = 0
109 | for centroid, r in zip(centroids, sampling):
110 | label += 1
111 | if verbose: print_timestamp('Reconstructing instance {0}/{1}...', [label, sampling.shape[0]])
112 |
113 | x,y,z = sphere2cart(r, theta_phi_sampling[:,0], theta_phi_sampling[:,1])
114 | x += centroid[0]
115 | y += centroid[1]
116 | z += centroid[2]
117 |
118 | delaunay_tri = Delaunay(np.array([x,y,z]).T)
119 |
120 |         voxel_idx = delaunay_tri.find_simplex(idx.T).reshape(shape) >= 0  # find_simplex returns -1 outside the hull
121 |
122 | instances[voxel_idx] = label
123 |
124 | return instances
125 |
126 |
127 |
128 |
129 | class sampling2harmonics():
130 |
131 | def __init__(self, sh_order, theta_phi_sampling, lb_lambda=0.006):
132 | super(sampling2harmonics, self).__init__()
133 | self.sh_order = sh_order
134 | self.theta_phi_sampling = theta_phi_sampling
135 | self.lb_lambda = lb_lambda
136 | self.num_samples = len(theta_phi_sampling)
137 |         self.num_coefficients = int((self.sh_order+1)**2)
138 |
139 | b = np.zeros((self.num_samples, self.num_coefficients))
140 | l = np.zeros((self.num_coefficients, self.num_coefficients))
141 |
142 | for num_sample in range(self.num_samples):
143 | num_coefficient = 0
144 | for num_order in range(self.sh_order+1):
145 | for num_degree in range(-num_order, num_order+1):
146 |
147 | theta = theta_phi_sampling[num_sample][0]
148 | phi = theta_phi_sampling[num_sample][1]
149 |
150 | y = sph_harm(np.abs(num_degree), num_order, phi, theta)
151 |
152 | if num_degree < 0:
153 | b[num_sample, num_coefficient] = np.real(y) * np.sqrt(2)
154 | elif num_degree == 0:
155 | b[num_sample, num_coefficient] = np.real(y)
156 | elif num_degree > 0:
157 | b[num_sample, num_coefficient] = np.imag(y) * np.sqrt(2)
158 |
159 | l[num_coefficient, num_coefficient] = self.lb_lambda * num_order ** 2 * (num_order + 1) ** 2
160 | num_coefficient += 1
161 |
162 | b_inv = np.linalg.pinv(np.matmul(b.transpose(), b) + l)
163 | self.convert_mat = np.matmul(b_inv, b.transpose()).transpose()
164 |
165 | def convert(self, r_sampling):
166 | converted_samples = np.zeros((r_sampling.shape[0],self.num_coefficients))
167 | for num_sample, r_sample in enumerate(r_sampling):
168 | r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat)
169 | converted_samples[num_sample] = np.squeeze(r_converted)
170 | return converted_samples
171 |
172 |
173 |
174 |
175 | class harmonics2sampling():
176 |
177 | def __init__(self, sh_order, theta_phi_sampling):
178 | super(harmonics2sampling, self).__init__()
179 | self.sh_order = sh_order
180 | self.theta_phi_sampling = theta_phi_sampling
181 | self.num_samples = len(theta_phi_sampling)
182 |         self.num_coefficients = int((self.sh_order+1)**2)
183 |
184 | convert_mat = np.zeros((self.num_coefficients, self.num_samples))
185 |
186 | for num_sample in range(self.num_samples):
187 | num_coefficient = 0
188 | for num_order in range(self.sh_order+1):
189 | for num_degree in range(-num_order, num_order+1):
190 |
191 | theta = theta_phi_sampling[num_sample][0]
192 | phi = theta_phi_sampling[num_sample][1]
193 |
194 | y = sph_harm(np.abs(num_degree), num_order, phi, theta)
195 |
196 | if num_degree < 0:
197 | convert_mat[num_coefficient, num_sample] = np.real(y) * np.sqrt(2)
198 | elif num_degree == 0:
199 | convert_mat[num_coefficient, num_sample] = np.real(y)
200 | elif num_degree > 0:
201 | convert_mat[num_coefficient, num_sample] = np.imag(y) * np.sqrt(2)
202 |
203 | num_coefficient += 1
204 |
205 | self.convert_mat = convert_mat
206 |
207 | def convert(self, r_harmonic):
208 | converted_harmonics = np.zeros((r_harmonic.shape[0],self.theta_phi_sampling.shape[0]))
209 | for num_sample, r_sample in enumerate(r_harmonic):
210 | r_converted = np.matmul(r_sample[np.newaxis], self.convert_mat)
211 | converted_harmonics[num_sample] = np.squeeze(r_converted)
212 | return converted_harmonics
213 |
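# Hedged round-trip sketch (hypothetical helper): project a per-angle radius
# sampling onto SH coefficients and reconstruct it; the .npy path refers to
# the sampling file shipped in utils/.
def _sh_roundtrip_demo(sampling_file='theta_phi_sampling_5000points_10000iter.npy', sh_order=5):
    theta_phi = np.load(sampling_file)
    s2h = sampling2harmonics(sh_order, theta_phi)
    h2s = harmonics2sampling(sh_order, theta_phi)
    r_sampling = 30 + np.random.randn(1, theta_phi.shape[0])  # one noisy sphere
    coeffs = s2h.convert(r_sampling)
    return h2s.convert(coeffs)  # low-pass reconstruction of r_sampling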
214 |
215 |
216 |
217 | def sphere_intersection_poolhelper(instance_indices, point_coords=None, radii=None):
218 |
219 |     # get radii, positions, and the centre-to-centre distance
220 | r1 = radii[instance_indices[0]]
221 | r2 = radii[instance_indices[1]]
222 | p1 = point_coords[instance_indices[0]]
223 | p2 = point_coords[instance_indices[1]]
224 | d = np.sqrt(np.sum((np.array(p1)-np.array(p2))**2))
225 |
226 | # calculate individual volumes
227 | vol1 = 4/3*np.pi*r1**3
228 | vol2 = 4/3*np.pi*r2**3
229 |
230 | # calculate intersection of volumes
231 |
232 | # Smaller sphere inside the bigger sphere
233 | if d <= np.abs(r1-r2):
234 | intersect_vol = 4/3*np.pi*np.minimum(r1,r2)**3
235 | # No intersection at all
236 | elif d > r1+r2:
237 | intersect_vol = 0
238 | # Partially intersecting spheres
239 | else:
240 | intersect_vol = np.pi * (r1 + r2 - d)**2 * (d**2 + 2*d*r2 - 3*r2**2 + 2*d*r1 + 6*r2*r1 - 3*r1**2) / (12*d)
241 |
242 | return (intersect_vol, vol1, vol2)
243 |
244 |
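# Hedged sanity check (kept commented so the module stays import-safe):
# concentric unit spheres overlap completely, so the intersection equals the
# volume of a single sphere, 4/3*pi.
# sphere_intersection_poolhelper((0, 1), point_coords=[(0, 0, 0), (0, 0, 0)],
#                                radii=[1.0, 1.0])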
245 |
246 | def harmonic_non_max_suppression(point_coords, point_probs, shape_descriptors, overlap_thresh=0.5, dim_scale=(1,1,1), num_kernel=1, **kwargs):
247 |
248 | if len(point_coords)>3000:
249 |
250 | print_timestamp('Too many points, aborting NMS')
251 | nms_coords = point_coords[:3000]
252 | nms_probs = point_probs[:3000]
253 | nms_shapes = shape_descriptors[:3000]
254 |
255 | elif len(point_coords)>1:
256 |
257 | dim_scale = [d/np.min(dim_scale) for d in dim_scale]
258 | point_coords_uniform = []
259 | for point_coord in point_coords:
260 | point_coords_uniform.append(tuple([p*d for p,d in zip(point_coord,dim_scale)]))
261 |
262 | # calculate upper and lower volumes
263 | r_upper = [r.max() for r in shape_descriptors]
264 | r_lower = [r.min() for r in shape_descriptors]
265 |
266 | # Calculate intersections of lower and upper spheres
267 | #instance_indices = list(itertools.combinations(range(len(point_coords)), r=2))
268 | r_max = np.max(r_upper)
269 | instance_indices = [ (i, j) for i in range(len(point_coords))
270 | for j in range(i+1, len(point_coords))
271 |                              if np.sqrt(np.sum((np.array(point_coords[i])-np.array(point_coords[j]))**2)) < r_max*2 ]
272 | with mp.Pool(processes=num_kernel) as p:
273 | vol_upper = p.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_upper), instance_indices)
274 | vol_lower = p.map(partial(sphere_intersection_poolhelper, point_coords=point_coords_uniform, radii=r_lower), instance_indices)
275 |
276 |         instances_keep = np.ones((len(point_coords),), dtype=bool)
277 |
278 | # calculate overlap measure
279 | for inst_idx, v_up, v_low in zip(instance_indices, vol_upper, vol_lower):
280 |
281 | # average intersection with smaller sphere
282 | overlap_measure_up = v_up[0] / np.minimum(v_up[1],v_up[2])
283 | overlap_measure_low = v_low[0] / np.minimum(v_low[1],v_low[2])
284 | overlap_measure = (overlap_measure_up+overlap_measure_low)/2
285 |
286 | if overlap_measure > overlap_thresh:
287 |                 # Get the indices of the least and most probable instances
288 | inst_min = inst_idx[np.argmin([point_probs[i] for i in inst_idx])]
289 | inst_max = inst_idx[np.argmax([point_probs[i] for i in inst_idx])]
290 |
291 | # If there already was an instance with higher probability, don't add the current "winner"
292 | if instances_keep[inst_max] == 0:
293 | # Mark both as excluded
294 | instances_keep[inst_min] = 0
295 | instances_keep[inst_max] = 0
296 | else:
297 | # Exclude the loser
298 | instances_keep[inst_min] = 0
299 | #instances_keep[inst_max] = 1
300 |
301 | # Mark remaining indices for keeping
302 | #instances_keep = instances_keep != -1
303 |
304 | nms_coords = [point_coords[i] for i,v in enumerate(instances_keep) if v]
305 | nms_probs = [point_probs[i] for i,v in enumerate(instances_keep) if v]
306 | nms_shapes = [shape_descriptors[i] for i,v in enumerate(instances_keep) if v]
307 |
308 | else:
309 | nms_coords = point_coords
310 | nms_probs = point_probs
311 | nms_shapes = shape_descriptors
312 |
313 | return nms_coords, nms_shapes, nms_probs
314 |
315 |
316 |
--------------------------------------------------------------------------------
/utils/jupyter_widgets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import os
6 | import ipywidgets as wg
7 | from IPython.display import display
8 | style={'description_width': 'initial'}
9 |
10 |
11 |
12 | def get_pipelin_widget():
13 |
14 | pipeline = wg.Dropdown(
15 |         options=['DiffusionModel3D', 'DiffusionModel2D'],
16 | description='Select Pipeline:', style=style)
17 |
18 | return pipeline
19 |
20 |
21 | def get_execution_widgets():
22 |
23 | wg_execute = wg.Checkbox(description='Execute Now!', value=True, style=style)
24 | wg_arguments = wg.Checkbox(description='Get Command Line Arguments', value=True, style=style)
25 |
26 | return [wg_execute, wg_arguments]
27 |
28 |
29 | def get_synthesizer_widget():
30 |
31 | pipeline = wg.Dropdown(
32 | options=['SyntheticTRIC', 'SyntheticCE', 'Synthetic2DGOWT1', 'Synthetic2DHeLa', 'SyntheticMeristem'],
33 | description='Select Synthesizer:', style=style)
34 |
35 | return pipeline
36 |
37 |
38 |
39 | def get_parameter_widgets(param_dict):
40 |
41 | param_names = []
42 | widget_list = []
43 | test_related = []
44 |
45 | for key in param_dict.keys():
46 | ### Script Parameter
47 |
48 | if key == 'output_path':
49 | widget_list.append(wg.Text(description='Output Path:', value=param_dict[key], style=style))
50 | param_names.append('--'+key)
51 | test_related.append(True)
52 | if key == 'log_path':
53 | widget_list.append(wg.Text(description='Log Path:', value=param_dict[key], style=style))
54 | param_names.append('--'+key)
55 | test_related.append(False)
56 | if key == 'gpus':
57 | widget_list.append(wg.IntSlider(description = 'Use GPU:', min=0, max=1, value=param_dict[key], style=style))
58 | param_names.append('--'+key)
59 | test_related.append(True)
60 | if key == 'no_resume':
61 | widget_list.append(wg.Checkbox(description='Resume:', value=not param_dict[key], style=style))
62 | param_names.append('--'+key)
63 | test_related.append(False)
64 | if key == 'pretrained':
65 | widget_list.append(wg.Text(description='Path To The Pretrained Model:', value=param_dict[key], style=style))
66 | param_names.append('--'+key)
67 | test_related.append(False)
68 | if key == 'augmentations':
69 | widget_list.append(wg.Text(description='Augmentation Dictionary File:', value=param_dict[key], style=style))
70 | param_names.append('--'+key)
71 | test_related.append(False)
72 | if key == 'epochs':
73 | widget_list.append(wg.BoundedIntText(description='Epochs:', min=1, max=10000, value=param_dict[key], style=style))
74 | param_names.append('--'+key)
75 | test_related.append(False)
76 |
77 |
78 | ### Network Parameter
79 |
80 | if key == 'backbone':
81 | widget_list.append(wg.Dropdown(description='Network Architecture:', options=['UNet3D_PixelShuffle_inject', 'UNet2D_PixelShuffle_inject'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
82 | param_names.append('--'+key)
83 | test_related.append(True)
84 | if key == 'in_channels':
85 | widget_list.append(wg.BoundedIntText(description='Input Channels:', value=param_dict[key], min=1, max=10000, style=style))
86 | param_names.append('--'+key)
87 | test_related.append(True)
88 | if key == 'out_channels':
89 | widget_list.append(wg.BoundedIntText(description='Output Channels:', value=param_dict[key], min=1, max=10000, style=style))
90 | param_names.append('--'+key)
91 | test_related.append(True)
92 | if key == 'feat_channels':
93 | widget_list.append(wg.BoundedIntText(description='Feature Channels:', value=param_dict[key], min=2, max=10000, style=style))
94 | param_names.append('--'+key)
95 | test_related.append(True)
96 | if key == 'patch_size':
97 |             if not isinstance(param_dict[key], str):
98 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
99 | widget_list.append(wg.Text(description='Patch Size (z,y,x):', value=param_dict[key], style=style))
100 | param_names.append('--'+key)
101 | test_related.append(True)
102 | if key == 'out_activation':
103 | widget_list.append(wg.Dropdown(description='Output Activation:', options=['tanh', 'sigmoid', 'hardtanh', 'relu', 'leakyrelu', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
104 | param_names.append('--'+key)
105 | test_related.append(True)
106 | if key == 'layer_norm':
107 | widget_list.append(wg.Dropdown(description='Layer Normalization:', options=['instance', 'batch', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
108 | param_names.append('--'+key)
109 | test_related.append(True)
110 | if key == 't_channels':
111 | widget_list.append(wg.BoundedIntText(description='T Channels:', value=param_dict[key], min=1, max=10000, style=style))
112 | param_names.append('--'+key)
113 | test_related.append(True)
114 |
115 | ### Data Parameter
116 |
117 | if key == 'data_norm':
118 | widget_list.append(wg.Dropdown(description='Data Normalization:', options=['percentile', 'minmax', 'meanstd', 'minmax_shifted', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
119 | param_names.append('--'+key)
120 | test_related.append(True)
121 | if key == 'data_root':
122 | widget_list.append(wg.Text(value=param_dict[key], description='Data Root:', style=style))
123 | param_names.append('--'+key)
124 | test_related.append(True)
125 | if key == 'train_list':
126 | widget_list.append(wg.Text(value=param_dict[key], description='Train List:', style=style))
127 | param_names.append('--'+key)
128 | test_related.append(True)
129 | if key == 'test_list':
130 | widget_list.append(wg.Text(value=param_dict[key], description='Test List:', style=style))
131 | param_names.append('--'+key)
132 | test_related.append(True)
133 | if key == 'val_list':
134 | widget_list.append(wg.Text(value=param_dict[key], description='Validation List:', style=style))
135 | param_names.append('--'+key)
136 | test_related.append(True)
137 | if key == 'image_groups':
138 |             if not isinstance(param_dict[key], str):
139 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
140 | widget_list.append(wg.Text(description='Image Groups:', value=param_dict[key], style=style))
141 | param_names.append('--'+key)
142 | test_related.append(True)
143 | if key == 'mask_groups':
144 |             if not isinstance(param_dict[key], str):
145 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
146 | widget_list.append(wg.Text(description='Mask Groups:', value=param_dict[key], style=style))
147 | param_names.append('--'+key)
148 | test_related.append(True)
149 | if key == 'dist_handling':
150 | widget_list.append(wg.Dropdown(description='Distance Handling:', options=['float', 'bool', 'bool_inv', 'exp', 'tanh', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
151 | param_names.append('--'+key)
152 | test_related.append(True)
153 | if key == 'dist_scaling':
154 |             if not isinstance(param_dict[key], str):
155 |                 param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
156 |             widget_list.append(wg.Text(description='Distance Scaling:', value=param_dict[key], style=style))
157 | param_names.append('--'+key)
158 | test_related.append(True)
159 | if key == 'seed_handling':
160 | widget_list.append(wg.Dropdown(description='Seed Handling:', options=['float', 'bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
161 | param_names.append('--'+key)
162 | test_related.append(True)
163 | if key == 'boundary_handling':
164 | widget_list.append(wg.Dropdown(description='Boundary Handling:', options=['bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
165 | param_names.append('--'+key)
166 | test_related.append(True)
167 | if key == 'instance_handling':
168 | widget_list.append(wg.Dropdown(description='Instance Handling:', options=['bool', 'none'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
169 | param_names.append('--'+key)
170 | test_related.append(True)
171 | if key == 'strides':
172 |             if not isinstance(param_dict[key], str):
173 |                 param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
174 |             widget_list.append(wg.Text(description='Strides:', value=param_dict[key], style=style))
175 | param_names.append('--'+key)
176 | test_related.append(True)
177 | if key == 'sh_order':
178 | widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=0, max=10000, style=style))
179 | param_names.append('--'+key)
180 | test_related.append(True)
181 | if key == 'mean_low':
182 | widget_list.append(wg.BoundedFloatText(description='Sampling Mean Weight Low:', value=param_dict[key], min=0, max=1000000, style=style))
183 | param_names.append('--'+key)
184 | test_related.append(True)
185 | if key == 'mean_high':
186 | widget_list.append(wg.BoundedFloatText(description='Sampling Mean Weight High:', value=param_dict[key], min=0, max=1000000, style=style))
187 | param_names.append('--'+key)
188 | test_related.append(True)
189 | if key == 'var_high':
190 | widget_list.append(wg.BoundedFloatText(description='Sampling Variance Weight High:', value=param_dict[key], min=0, max=1000000, style=style))
191 | param_names.append('--'+key)
192 | test_related.append(True)
193 | if key == 'variance_levels':
194 | widget_list.append(wg.BoundedIntText(description='Sampling Variance Levels:', value=param_dict[key], min=0, max=10000, style=style))
195 | param_names.append('--'+key)
196 | test_related.append(True)
197 | if key == 'strategy':
198 | widget_list.append(wg.Dropdown(description='Sampling Variance Strategy:', options=['random', 'structured'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
199 | param_names.append('--'+key)
200 | test_related.append(True)
201 | if key == 'image_noise_channel':
202 | widget_list.append(wg.BoundedIntText(description='Image Noise Channel:', value=param_dict[key], min=-5, max=5, style=style))
203 | param_names.append('--'+key)
204 | test_related.append(True)
205 | if key == 'mask_noise_channel':
206 | widget_list.append(wg.BoundedIntText(description='Mask Noise Channel:', value=param_dict[key], min=-5, max=5, style=style))
207 | param_names.append('--'+key)
208 | test_related.append(True)
209 | if key == 'noise_type':
210 | widget_list.append(wg.Dropdown(description='Noise Type:', options=['gaussian', 'rayleigh', 'laplace'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
211 | param_names.append('--'+key)
212 | test_related.append(True)
213 |
214 | ### Training Parameter
215 |
216 | if key == 'samples_per_epoch':
217 | widget_list.append(wg.BoundedIntText(description='Samples Per Epoch:', value=param_dict[key], min=-1, max=1000000, style=style))
218 | param_names.append('--'+key)
219 | test_related.append(True)
220 | if key == 'batch_size':
221 | widget_list.append(wg.BoundedIntText(description='Batch Size:', value=param_dict[key], min=1, max=1000000, style=style))
222 | param_names.append('--'+key)
223 | test_related.append(True)
224 | if key == 'learning_rate':
225 | widget_list.append(wg.BoundedFloatText(description='Learning Rate:', value=param_dict[key], min=0, max=1000000, style=style))
226 | param_names.append('--'+key)
227 | test_related.append(True)
228 | if key == 'background_weight':
229 | widget_list.append(wg.BoundedFloatText(description='Background Weight:', value=param_dict[key], min=0, max=1000000, style=style))
230 | param_names.append('--'+key)
231 | test_related.append(True)
232 | if key == 'seed_weight':
233 | widget_list.append(wg.BoundedFloatText(description='Seed Weight:', value=param_dict[key], min=0, max=1000000, style=style))
234 | param_names.append('--'+key)
235 | test_related.append(True)
236 | if key == 'boundary_weight':
237 | widget_list.append(wg.BoundedFloatText(description='Boundary Weight:', value=param_dict[key], min=0, max=1000000, style=style))
238 | param_names.append('--'+key)
239 | test_related.append(True)
240 | if key == 'flow_weight':
241 | widget_list.append(wg.BoundedFloatText(description='Flow Weight:', value=param_dict[key], min=0, max=1000000, style=style))
242 | param_names.append('--'+key)
243 | test_related.append(True)
244 | if key == 'centroid_weight':
245 | widget_list.append(wg.BoundedFloatText(description='Centroid Weight:', value=param_dict[key], min=0, max=1000000, style=style))
246 | param_names.append('--'+key)
247 | test_related.append(True)
248 | if key == 'encoding_weight':
249 | widget_list.append(wg.BoundedFloatText(description='Encoding Weight:', value=param_dict[key], min=0, max=1000000, style=style))
250 | param_names.append('--'+key)
251 | test_related.append(True)
252 | if key == 'robustness_weight':
253 | widget_list.append(wg.BoundedFloatText(description='Robustness Weight:', value=param_dict[key], min=0, max=1000000, style=style))
254 | param_names.append('--'+key)
255 | test_related.append(True)
256 | if key == 'variance_interval':
257 | widget_list.append(wg.BoundedIntText(description='Sampling Variance Interval:', value=param_dict[key], min=0, max=10000, style=style))
258 | param_names.append('--'+key)
259 | test_related.append(True)
260 | if key == 'ada_update_period':
261 | widget_list.append(wg.BoundedIntText(description='ADA Update Period:', value=param_dict[key], min=0, max=10000, style=style))
262 | param_names.append('--'+key)
263 | test_related.append(True)
264 | if key == 'ada_update':
265 | widget_list.append(wg.BoundedFloatText(description='ADA Update Step:', value=param_dict[key], min=0, max=1000000, style=style))
266 | param_names.append('--'+key)
267 | test_related.append(True)
268 | if key == 'ada_target':
269 | widget_list.append(wg.BoundedFloatText(description='ADA Target:', value=param_dict[key], min=0, max=1000000, style=style))
270 | param_names.append('--'+key)
271 | test_related.append(True)
272 | if key == 'num_samples':
273 | widget_list.append(wg.BoundedIntText(description='Number Of Samples:', value=param_dict[key], min=0, max=10000, style=style))
274 | param_names.append('--'+key)
275 | test_related.append(True)
276 |
277 | # diffusion parameter
278 | if key == 'num_timesteps':
279 | widget_list.append(wg.BoundedIntText(description='Number Of Timesteps:', value=param_dict[key], min=0, max=1000, style=style))
280 | param_names.append('--'+key)
281 | test_related.append(True)
282 | if key == 'diffusion_schedule':
283 | widget_list.append(wg.Dropdown(description='Diffusion Schedule:', options=['cosine', 'linear', 'quadratic', 'sigmoid'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
284 | param_names.append('--'+key)
285 | test_related.append(True)
286 | return param_names, widget_list, test_related
287 |
288 |
289 |
290 | def get_apply_parameter_widgets(param_dict):
291 |
292 | param_names = []
293 | widget_list = []
294 |
295 | for key in param_dict.keys():
296 |
297 |
298 | ### Script Parameter
299 |
300 | if key == 'output_path':
301 | widget_list.append(wg.Text(description='Output Path:', value=param_dict[key], style=style))
302 | param_names.append('--'+key)
303 | elif key == 'ckpt_path':
304 | widget_list.append(wg.Text(description='Checkpoint Path:', value=param_dict[key], style=style))
305 | param_names.append('--'+key)
306 | elif key == 'gpus':
307 | widget_list.append(wg.IntSlider(description = 'Use GPU:', min=0, max=1, value=param_dict[key], style=style))
308 | param_names.append('--'+key)
309 | elif key == 'distributed_backend':
310 | widget_list.append(wg.Dropdown(description='Distributed Backend:', options=['dp', 'ddp', 'ddp2'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
311 | param_names.append('--'+key)
312 | elif key == 'overlap':
313 |             if not isinstance(param_dict[key], str):
314 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
315 | widget_list.append(wg.Text(description='Overlap (z,y,x):', value=param_dict[key], style=style))
316 | param_names.append('--'+key)
317 | elif key == 'crop':
318 |             if not isinstance(param_dict[key], str):
319 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
320 | widget_list.append(wg.Text(description='Crop (z,y,x):', value=param_dict[key], style=style))
321 | param_names.append('--'+key)
322 | elif key == 'input_batch':
323 | widget_list.append(wg.Dropdown(description='Input Batch:', options=['image', 'mask'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
324 | param_names.append('--'+key)
325 | elif key == 'clip':
326 |             if not isinstance(param_dict[key], str):
327 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
328 | widget_list.append(wg.Text(description='Clip (min, max):', value=param_dict[key], style=style))
329 | param_names.append('--'+key)
330 | elif key == 'num_files':
331 | widget_list.append(wg.BoundedIntText(description='Number of Files:', value=param_dict[key], min=-1, max=10000, style=style))
332 | param_names.append('--'+key)
333 | elif key == 'add_noise_channel':
334 | widget_list.append(wg.BoundedIntText(description='Noise Channel:', value=param_dict[key], min=-2, max=10000, style=style))
335 | param_names.append('--'+key)
336 | elif key == 'theta_phi_sampling':
337 | widget_list.append(wg.Text(description='Angular Sampling File Path:', value=param_dict[key], style=style))
338 | param_names.append('--'+key)
339 |
340 | ### Network Parameter
341 | elif key == 'out_channels':
342 | widget_list.append(wg.BoundedIntText(description='Output Channels:', value=param_dict[key], min=1, max=10000, style=style))
343 | param_names.append('--'+key)
344 | elif key == 'patch_size':
345 |             if not isinstance(param_dict[key], str):
346 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
347 | widget_list.append(wg.Text(description='Patch Size (z,y,x):', value=param_dict[key], style=style))
348 | param_names.append('--'+key)
349 | elif key == 'resolution_weights':
350 |             if not isinstance(param_dict[key], str):
351 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
352 | widget_list.append(wg.Text(description='Prediction Weights at Each Resolution:', value=param_dict[key], style=style))
353 | param_names.append('--'+key)
354 | elif key == 'centroid_thresh':
355 | widget_list.append(wg.BoundedFloatText(description='Centroid Threshold:', value=param_dict[key], min=0, max=100, style=style))
356 | param_names.append('--'+key)
357 | elif key == 'minsize':
358 |             widget_list.append(wg.BoundedIntText(description='Minimum Cell Size:', value=param_dict[key], min=0, max=100, style=style))
359 | param_names.append('--'+key)
360 | elif key == 'maxsize':
361 |             widget_list.append(wg.BoundedIntText(description='Maximum Cell Size:', value=param_dict[key], min=0, max=100, style=style))
362 | param_names.append('--'+key)
363 | elif key == 'use_watershed':
364 | widget_list.append(wg.Checkbox(description='Use Watershed:', value=param_dict[key], style=style))
365 | param_names.append('--'+key)
366 | elif key == 'use_sizefilter':
367 | widget_list.append(wg.Checkbox(description='Use Size Filter:', value=param_dict[key], style=style))
368 | param_names.append('--'+key)
369 | elif key == 'use_nms':
370 |             widget_list.append(wg.Checkbox(description='Use Non-Maximum Suppression:', value=param_dict[key], style=style))
371 | param_names.append('--'+key)
372 | ### Data Parameter
373 | elif key == 'data_root':
374 | widget_list.append(wg.Text(value=param_dict[key], description='Data Root:', style=style))
375 | param_names.append('--'+key)
376 | elif key == 'test_list':
377 | widget_list.append(wg.Text(value=param_dict[key], description='Test List:', style=style))
378 | param_names.append('--'+key)
379 | elif key == 'image_groups':
380 |             if not isinstance(param_dict[key], str):
381 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
382 | widget_list.append(wg.Text(description='Image Groups:', value=param_dict[key], style=style))
383 | param_names.append('--'+key)
384 | elif key == 'mask_groups':
385 |             if not isinstance(param_dict[key], str):
386 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
387 | widget_list.append(wg.Text(description='Mask Groups:', value=param_dict[key], style=style))
388 | param_names.append('--'+key)
389 | elif key == 'sh_order':
390 | widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=0, max=10000, style=style))
391 | param_names.append('--'+key)
392 |
393 | # diffusion parameter
394 | elif key == 'timesteps_start':
395 | widget_list.append(wg.BoundedIntText(description='Timestep to Start the Reverse Process:', value=param_dict[key], min=0, max=1000, style=style))
396 | param_names.append('--'+key)
397 | elif key == 'timesteps_save':
398 | widget_list.append(wg.BoundedIntText(description='Number of Timesteps Between Saves:', value=param_dict[key], min=0, max=1000, style=style))
399 | param_names.append('--'+key)
400 | elif key == 'timesteps_step':
401 | widget_list.append(wg.BoundedIntText(description='Timesteps Skipped Between Iterations:', value=param_dict[key], min=0, max=1000, style=style))
402 | param_names.append('--'+key)
403 | elif key == 'blur_sigma':
404 | widget_list.append(wg.BoundedIntText(description='Sigma for Gaussian Blurring of Inputs:', value=param_dict[key], min=0, max=5, style=style))
405 | param_names.append('--'+key)
406 | elif key == 'num_timesteps':
407 | widget_list.append(wg.BoundedIntText(description='Total Number Of Training Timesteps:', value=param_dict[key], min=0, max=1000, style=style))
408 | param_names.append('--'+key)
409 | elif key == 'diffusion_schedule':
410 | widget_list.append(wg.Dropdown(description='Diffusion Schedule:', options=['cosine', 'linear', 'quadratic', 'sigmoid'], value=param_dict[key], layout={'width': 'max-content'}, style=style))
411 | param_names.append('--'+key)
412 |
413 | return param_names, widget_list
414 |
415 |
416 | def get_sim_parameter_widgets(param_dict):
417 |
418 | param_names = []
419 | widget_list = []
420 |
421 | for key in param_dict.keys():
422 | if key == 'save_path':
423 | widget_list.append(wg.Text(description='Save Path:', value=param_dict[key], style=style))
424 | param_names.append('--'+key)
425 | elif key == 'experiment_name':
426 | widget_list.append(wg.Text(description='Experiment Name:', value=param_dict[key], style=style))
427 | param_names.append('--'+key)
428 | elif key == 'num_imgs' or key == 'img_count':
429 | widget_list.append(wg.BoundedIntText(description='Number of Images:', value=param_dict[key], min=1, max=1000, style=style))
430 | param_names.append('--'+key)
431 |
432 | # Nuclei masks
433 | elif key == 'img_shape':
434 |             if not isinstance(param_dict[key], str):
435 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
436 | widget_list.append(wg.Text(description='Image Shape:', value=param_dict[key], style=style))
437 | param_names.append('--'+key)
438 | elif key == 'max_radius':
439 | widget_list.append(wg.BoundedIntText(description='Maximum Radius:', value=param_dict[key], min=1, max=1000, style=style))
440 | param_names.append('--'+key)
441 | elif key == 'min_radius':
442 | widget_list.append(wg.BoundedIntText(description='Minimum Radius:', value=param_dict[key], min=1, max=100, style=style))
443 | param_names.append('--'+key)
444 | elif key == 'radius_range':
445 | widget_list.append(wg.BoundedIntText(description='Radius Range:', value=param_dict[key], min=-100, max=100, style=style))
446 | param_names.append('--'+key)
447 | elif key == 'sh_order':
448 |             widget_list.append(wg.BoundedIntText(description='SH Order:', value=param_dict[key], min=1, max=100, style=style))
449 | param_names.append('--'+key)
450 | elif key == 'smooth_std':
451 |             widget_list.append(wg.BoundedFloatText(description='Smoothing Std:', value=param_dict[key], min=0, max=100, style=style))
452 | param_names.append('--'+key)
453 | elif key == 'noise_std':
454 |             widget_list.append(wg.BoundedFloatText(description='Noise Std:', value=param_dict[key], min=0, max=100, style=style))
455 | param_names.append('--'+key)
456 | elif key == 'noise_mean':
457 |             widget_list.append(wg.BoundedFloatText(description='Noise Mean:', value=param_dict[key], min=0, max=100, style=style))
458 | param_names.append('--'+key)
459 | elif key == 'position_std':
460 |             widget_list.append(wg.BoundedFloatText(description='Position Std:', value=param_dict[key], min=0, max=100, style=style))
461 | param_names.append('--'+key)
462 | elif key == 'num_cells':
463 |             widget_list.append(wg.BoundedIntText(description='Number of Cells:', value=param_dict[key], min=1, max=1000, style=style))
464 | param_names.append('--'+key)
465 | elif key == 'num_cells_range':
466 | widget_list.append(wg.BoundedIntText(description='Number of Cells Range:', value=param_dict[key], min=1, max=1000, style=style))
467 | param_names.append('--'+key)
468 | elif key == 'circularity':
469 |             widget_list.append(wg.BoundedFloatText(description='Circularity:', value=param_dict[key], min=0, max=100, style=style))
470 | param_names.append('--'+key)
471 | elif key == 'generate_images':
472 |             widget_list.append(wg.Checkbox(description='Generate Images:', value=False, style=style))
473 | param_names.append('--'+key)
474 | elif key == 'theta_phi_sampling_file':
475 | widget_list.append(wg.Text(value=param_dict[key], description='Theta Phi Sampling File:', style=style))
476 | param_names.append('--'+key)
477 | elif key == 'cell_elongation':
478 |             widget_list.append(wg.BoundedFloatText(description='Cell Elongation:', value=param_dict[key], min=0, max=100, style=style))
479 | param_names.append('--'+key)
480 | elif key == 'z_anisotropy':
481 |             widget_list.append(wg.BoundedFloatText(description='Z Anisotropy:', value=param_dict[key], min=0, max=100, style=style))
482 | param_names.append('--'+key)
483 | elif key == 'irregularity_extend':
484 |             widget_list.append(wg.BoundedFloatText(description='Irregularity Extent:', value=param_dict[key], min=0, max=100, style=style))
485 | param_names.append('--'+key)
486 |
487 | # Membrane masks
488 | elif key == 'gridsize':
489 |             if not isinstance(param_dict[key], str):
490 | param_dict[key] = ' '.join([str(p) for p in param_dict[key]])
491 | widget_list.append(wg.Text(description='Grid Size:', value=param_dict[key], style=style))
492 | param_names.append('--'+key)
493 | elif key == 'distance_weight':
494 | widget_list.append(wg.BoundedFloatText(description='Distance Weight :', value=param_dict[key], min=0, max=10, style=style))
495 | param_names.append('--'+key)
496 | elif key == 'morph_radius':
497 | widget_list.append(wg.BoundedIntText(description='Morph Radius:', value=param_dict[key], min=1, max=1000, style=style))
498 | param_names.append('--'+key)
499 | elif key == 'weights':
500 | widget_list.append(wg.Text(description='Cell Weights:', value=param_dict[key], style=style))
501 | param_names.append('--'+key)
502 | elif key == 'cell_density':
503 | widget_list.append(wg.BoundedFloatText(description='Cell Density:', value=param_dict[key], min=0, max=10, style=style))
504 | param_names.append('--'+key)
505 | elif key == 'cell_density_decay':
506 | widget_list.append(wg.BoundedFloatText(description='Cell Density Decay:', value=param_dict[key], min=0, max=10, style=style))
507 | param_names.append('--'+key)
508 | elif key == 'cell_position_smoothness':
509 | widget_list.append(wg.BoundedIntText(description='Cell Position Smoothness:', value=param_dict[key], min=1, max=1000, style=style))
510 | param_names.append('--'+key)
511 | elif key == 'ring_density':
512 | widget_list.append(wg.BoundedFloatText(description='Ring Density:', value=param_dict[key], min=0, max=10, style=style))
513 | param_names.append('--'+key)
514 | elif key == 'ring_density_decay':
515 | widget_list.append(wg.BoundedFloatText(description='Ring Density Decay:', value=param_dict[key], min=0, max=10, style=style))
516 | param_names.append('--'+key)
517 | elif key == 'angular_sampling_file':
518 | widget_list.append(wg.Text(description='Angular Sampling File:', value=param_dict[key], style=style))
519 | param_names.append('--'+key)
520 | elif key == 'specimen_sampling_file':
521 | widget_list.append(wg.Text(description='Specimen Sampling File:', value=param_dict[key], style=style))
522 | param_names.append('--'+key)
523 | return param_names, widget_list
524 |
525 |
--------------------------------------------------------------------------------
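The builder above returns two parallel lists: '--key' flag names and the ipywidgets controls holding the corresponding values. A minimal sketch of turning that pair into an argv-style argument list; the helper name widgets2argv is hypothetical and not part of this repository:

import ipywidgets as wg

def widgets2argv(param_names, widget_list):
    # Pair each '--flag' with the current value of its widget. Checkboxes
    # act as boolean switches: emit the bare flag only when checked.
    argv = []
    for name, widget in zip(param_names, widget_list):
        if isinstance(widget, wg.Checkbox):
            if widget.value:
                argv.append(name)
        else:
            argv.extend([name, str(widget.value)])
    return argv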
/utils/synthetic_cell_membrane_masks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 |
5 | from skimage import morphology, measure, io
6 | from scipy.ndimage import distance_transform_edt
7 | from scipy.spatial import cKDTree
8 | from sklearn.decomposition import PCA
9 | from sklearn.cluster import AgglomerativeClustering
10 |
11 | from utils.h5_converter import h5_writer
12 | from utils.utils import print_timestamp
13 |
14 |
15 |
16 | def generate_data(syn_class, save_path='Segmentations_h5', experiment_name='synthetic_data', img_count=100, param_dict=None):
17 |
18 |     synthesizer = syn_class(**(param_dict if param_dict is not None else {}))
19 | os.makedirs(save_path, exist_ok=True)
20 |
21 | for num_img in range(img_count):
22 |
23 | print_timestamp('_'*20)
24 | print_timestamp('Generating mask {0}/{1}', (num_img+1, img_count))
25 |
26 | # Generate a new mask
27 | synthesizer.generate_instances()
28 |
29 | # Get and save the instance, boundary, centroid and distance masks
30 | print_timestamp('Saving...')
31 | instance_mask = synthesizer.get_instance_mask().astype(np.uint16)
32 | instance_mask = np.transpose(instance_mask, [2,1,0])
33 | boundary_mask = synthesizer.get_boundary_mask().astype(np.uint8)
34 | boundary_mask = np.transpose(boundary_mask, [2,1,0])
35 | distance_mask = synthesizer.get_distance_mask().astype(np.float32)
36 | distance_mask = np.transpose(distance_mask, [2,1,0])
37 | centroid_mask = synthesizer.get_centroid_mask().astype(np.uint8)
38 | centroid_mask = np.transpose(centroid_mask, [2,1,0])
39 |
40 | save_name = os.path.join(save_path, experiment_name+'_'+str(num_img)+'.h5')
41 | h5_writer([instance_mask, boundary_mask, distance_mask, centroid_mask], save_name, group_names=['instances', 'boundary', 'distance', 'seeds'])
42 |
43 |
44 |
45 | def agglomerative_clustering(x_samples, y_samples, z_samples, max_dist=10):
46 |
47 | samples = np.array([x_samples, y_samples, z_samples]).T
48 |
49 | clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=max_dist, linkage='complete').fit(samples)
50 |
51 | cluster_labels = clustering.labels_
52 |
53 | cluster_samples_x = []
54 | cluster_samples_y = []
55 | cluster_samples_z = []
56 | for label in np.unique(cluster_labels):
57 |
58 | cluster_samples_x.append(int(np.mean(x_samples[cluster_labels==label])))
59 | cluster_samples_y.append(int(np.mean(y_samples[cluster_labels==label])))
60 | cluster_samples_z.append(int(np.mean(z_samples[cluster_labels==label])))
61 |
62 | cluster_samples_x = np.array(cluster_samples_x)
63 | cluster_samples_y = np.array(cluster_samples_y)
64 | cluster_samples_z = np.array(cluster_samples_z)
65 |
66 | return cluster_samples_x, cluster_samples_y, cluster_samples_z
67 |
68 |
69 |
70 |
71 | class SyntheticCellMembranes:
72 | """Parent class for generating synthetic cell membranes"""
73 |
74 |
75 |
76 | def __init__(self, gridsize=(128,128,128), distance_weight=0.25, cell_density=1/20):
77 |
78 | self.gridsize = gridsize
79 | self.distance_weight = distance_weight
80 | self.cell_density = cell_density
81 |
82 | self.instance_mask = None
83 | self.x_fg = []
84 | self.y_fg = []
85 | self.z_fg = []
86 | self.x_cell = []
87 | self.y_cell = []
88 | self.z_cell = []
89 |
90 |
91 |
92 | def _cart2sphere(self, x, y, z):
93 |
94 | n = x**2 + y**2 + z**2
95 |         n = n.astype(np.float64)  # the np.float alias was removed in NumPy 1.24
96 | n[n==0] += 1e-7 # prevent zero divisions
97 | r = np.sqrt(n)
98 | p = np.arctan2(y, x)
99 | t = np.arccos(z/np.sqrt(n))
100 |
101 | return r, t, p
102 |
103 |
104 |
105 | def _sphere2cart(self, r, t, p):
106 |
107 | x = np.sin(t) * np.cos(p) * r
108 | y = np.sin(t) * np.sin(p) * r
109 | z = np.cos(t) * r
110 |
111 | return x, y, z
112 |
113 |
114 |
115 | def generate_instances(self):
116 |
117 | print_timestamp('Generating foreground region...')
118 | self._generate_foreground()
119 | print_timestamp('Placing random centroids...')
120 | self._place_centroids()
121 | print_timestamp('Creating Voronoi tessellation...')
122 | self._voronoi_tessellation()
123 | print_timestamp('Morphological postprocessing...')
124 | self._post_processing()
125 |
126 |
127 |
128 | # sphere generation
129 | def _generate_foreground(self):
130 |
131 | # Determine foreground region
132 |         image_ind = np.indices(self.gridsize, dtype=int)
133 | x_fg = image_ind[0].flatten()
134 | y_fg = image_ind[1].flatten()
135 | z_fg = image_ind[2].flatten()
136 |
137 | self.x_fg = x_fg
138 | self.y_fg = y_fg
139 | self.z_fg = z_fg
140 |
141 |
142 |
143 | def _place_centroids(self):
144 |
145 | # calculate the volume
146 | vol = len(self.x_fg)
147 |
148 | # estimate the number of cells
149 | cell_count = vol*(self.cell_density**3)
150 | cell_count = int(cell_count)
151 |
152 |         # select random points within the foreground region
153 | rnd_idx = np.random.choice(np.arange(0,len(self.x_fg)), size=cell_count, replace=False)
154 |
155 | x_cell = self.x_fg[rnd_idx].copy()
156 | y_cell = self.y_fg[rnd_idx].copy()
157 | z_cell = self.z_fg[rnd_idx].copy()
158 |
159 | # perform clustering
160 | x_cell = np.array(x_cell)
161 | y_cell = np.array(y_cell)
162 | z_cell = np.array(z_cell)
163 | x_cell, y_cell, z_cell = agglomerative_clustering(x_cell, y_cell, z_samples=z_cell, max_dist=2*self.cell_density**-1)
164 |
165 | self.x_cell = x_cell
166 | self.y_cell = y_cell
167 | self.z_cell = z_cell
168 |
169 |
170 |
171 | def _voronoi_tessellation(self):
172 |
173 |         # get each foreground sample to be tessellated
174 | samples = list(zip(self.x_fg, self.y_fg, self.z_fg))
175 |
176 | # get each cell centroid
177 | cells = list(zip(self.x_cell, self.y_cell, self.z_cell))
178 | cells_tree = cKDTree(cells)
179 |
180 | # determine weights for each cell (improved roundness)
181 | closest_cell_dist, _ = cells_tree.query(cells, k=2)
182 | weights = closest_cell_dist[:,1]/np.max(closest_cell_dist[:,1])*self.distance_weight + 1
183 |
184 | # determine the closest cell candidates for each sampling point
185 | candidates_dist, candidates_idx = cells_tree.query(samples, k=np.minimum(5,len(self.x_cell)-1))
186 | candidates_dist = np.multiply(candidates_dist, weights[candidates_idx])
187 | candidates_closest = np.argmin(candidates_dist, axis=1)
188 |
189 | instance_mask = np.zeros(self.gridsize, dtype=np.uint16)
190 | for sample, cand_idx, closest_idx in zip(samples, candidates_idx, candidates_closest):
191 | instance_mask[sample] = cand_idx[closest_idx]+1
192 |
193 | self.instance_mask = instance_mask
194 |
195 |
196 |
197 | def _post_processing(self):
198 | pass
199 |
200 |
201 |
202 | def get_instance_mask(self):
203 |
204 | return self.instance_mask
205 |
206 |
207 |
208 | def get_centroid_mask(self, data=None):
209 |
210 | if data is None:
211 | if self.instance_mask is None:
212 | return None
213 | else:
214 | data = self.instance_mask
215 |
216 | # create centroid mask
217 |         centroid_mask = np.zeros(data.shape, dtype=bool)
218 |
219 | # find and place centroids
220 | regions = measure.regionprops(data)
221 | for props in regions:
222 | c = props.centroid
223 |             centroid_mask[int(c[0]), int(c[1]), int(c[2])] = True
224 |
225 | return centroid_mask
226 |
227 |
228 |
229 | def get_boundary_mask(self, data=None):
230 |
231 | # get instance mask
232 | if data is None:
233 | if self.instance_mask is None:
234 | return None
235 | else:
236 | data = self.instance_mask
237 |
238 |         membrane_mask = morphology.dilation(data, footprint=morphology.ball(3)) - data  # 'selem' was renamed to 'footprint' in scikit-image 0.19
239 | membrane_mask = membrane_mask != 0
240 |
241 | return membrane_mask
242 |
243 |
244 | def get_distance_mask(self, data=None):
245 |
246 | # get instance mask
247 | if data is None:
248 | if self.instance_mask is None:
249 | return None
250 | else:
251 | data = self.instance_mask
252 |
253 |         # compute a signed Euclidean distance transform
254 |
255 |         # foreground distance (positive inside cells)
256 |         distance_encoding = distance_transform_edt(data>0)
257 |
258 |         # subtract background distance (negative outside cells)
259 |         distance_encoding = distance_encoding - distance_transform_edt(data<=0)
260 |
261 | return distance_encoding
262 |
263 |
264 |
265 |
266 |
267 |
268 | # Class for generation of synthetic meristem data
269 | class SyntheticMeristem(SyntheticCellMembranes):
270 | """Child class for generating synthetic meristem membranes"""
271 |
272 |
273 |
274 | def __init__(self, gridsize=(120,512,512), distance_weight=0.25, # general params
275 | morph_radius=3, weights=None, # foreground params
276 | cell_density=1/23, cell_density_decay=0.9, cell_position_smoothness=10, # cell params
277 | ring_density=1/23, ring_density_decay=0.9, # ring params
278 | angular_sampling_file=r'utils/theta_phi_sampling_5000points_10000iter.npy',
279 | specimen_sampling_file=r'utils/PNAS_sampling.csv'):
280 |
281 | super().__init__(gridsize=gridsize, distance_weight=distance_weight, cell_density=cell_density)
282 |
283 |         # foreground params
284 | self.gridsize_max = np.max(gridsize)
285 | self.morph_radius = morph_radius
286 | self.weights = weights
287 |
288 | # cell params
289 | self.cell_density_decay = cell_density_decay
290 | self.cell_position_smoothness = cell_position_smoothness
291 |
292 | # ring params
293 | self.ring_density = ring_density
294 | self.ring_density_decay = ring_density_decay
295 |
296 | self.angular_sampling_file = angular_sampling_file
297 | self.specimen_sampling_file = specimen_sampling_file
298 |
299 | # initialize the statistical shape model
300 | print_timestamp('Initializing the statistical shape model...')
301 | self.sampling_angles = np.load(angular_sampling_file)
302 | specimen_sampling = pd.read_csv(specimen_sampling_file, sep=';').to_numpy()
303 | self.specimen_mean = np.mean(specimen_sampling, axis=1)
304 |
305 | # calculate the PCA
306 | specimen_cov = np.cov(specimen_sampling)
307 | specimen_pca = PCA(n_components=3)
308 | specimen_pca.fit(specimen_cov)
309 | self.specimen_pca = specimen_pca
310 |
311 |         if self.weights is not None:
312 | assert len(self.weights) == len(self.specimen_pca.singular_values_), 'Number of weights ({0}) does not match the number of eigenvectors {1}'.format(len(self.weights), len(self.specimen_pca.singular_values_))
313 |
314 |
315 | # sphere generation
316 | def _generate_foreground(self):
317 |
318 | # Construct the sampling tree
319 | sampling_tree = cKDTree(self.sampling_angles)
320 |
321 | if self.weights is None:
322 | # Generate random sampling
323 | weights = np.random.randn(len(self.specimen_pca.singular_values_))
324 | else:
325 | weights = np.array(self.weights)
326 | specimen_rnd = self.specimen_mean + np.matmul(self.specimen_pca.components_.T, np.sqrt(self.specimen_pca.singular_values_) * weights)
327 |
328 | # Normalize the radii
329 | specimen_rnd /= specimen_rnd.max()
330 |
331 | # Calculate the image size based on the shape model and the desired gridsize
332 | specimen_x, specimen_y, specimen_z = self._sphere2cart(specimen_rnd,self.sampling_angles[:,0],self.sampling_angles[:,1])
333 | specimen_dim_ratio = np.array([2*np.abs(specimen_x.max()), 2*np.abs(specimen_y.max()), np.abs(specimen_z.max())]) # *2 on x and y, since it's a hemisphere starting at the center
334 | specimen_dim_ratio /= specimen_dim_ratio.max()
335 |         self.gridsize = np.array(specimen_dim_ratio*self.gridsize_max, dtype=int)
336 |
337 | # Adjust to desired grid size
338 | specimen_rnd *= self.gridsize_max/2*0.95
339 | self.sampling_radii = specimen_rnd.copy()
340 |
341 | # Determine foreground region
342 |         image_ind = np.indices(self.gridsize, dtype=int)
343 | x = image_ind[0].flatten()-self.gridsize[0]/2
344 | y = image_ind[1].flatten()-self.gridsize[1]/2
345 | z = image_ind[2].flatten()
346 |
347 | r,t,p = self._cart2sphere(x,y,z)
348 |
349 | # Determine nearest sampling angles
350 | _, assignments = sampling_tree.query(np.array([t,p]).T, k=3)
351 | specimen_rnd[specimen_rnd==0] = np.nan
352 |
353 | # Determine foreground region
354 | foreground_ind = r <= np.nanmean(specimen_rnd[assignments], axis=1)
355 | x_fg = x[foreground_ind]+self.gridsize[0]/2
356 |         x_fg = x_fg.astype(int)
357 | y_fg = y[foreground_ind]+self.gridsize[1]/2
358 |         y_fg = y_fg.astype(int)
359 | z_fg = z[foreground_ind]
360 |         z_fg = z_fg.astype(int)
361 |
362 | self.x_fg = x_fg
363 | self.y_fg = y_fg
364 | self.z_fg = z_fg
365 |
366 |
367 |
368 | def _place_centroids(self):
369 |
370 | cell_density = self.cell_density
371 | ring_density = self.ring_density
372 | cluster_density = cell_density
373 |
374 | # initialize the first radius (small offset to place the first ring near the boundary)
375 | ring_radii = self.sampling_radii + (self.ring_density**-1)/2
376 |
377 | x_cell = []
378 | y_cell = []
379 | z_cell = []
380 | ring_count = 0
381 |
382 | while ring_radii.max() - ring_density**-1 > 0:
383 |
384 | # set off the new ring
385 | ring_radii = np.maximum(0, ring_radii - ring_density**-1)
386 |
387 |         # get the spherical coordinates of the new ring
388 | t = self.sampling_angles[ring_radii != 0, 0]
389 | p = self.sampling_angles[ring_radii != 0, 1]
390 | r = ring_radii[ring_radii != 0]
391 |
392 | # apply small offsets to the radii, depending on the depth of the current ring
393 | r = r + np.random.normal(loc=0, scale=ring_count/cell_density/self.cell_position_smoothness, size=len(r))
394 | r = np.maximum(0, r)
395 |
396 | # convert to cartesian coordinates
397 | x,y,z = self._sphere2cart(r,t,p)
398 | x = np.array(x)
399 | y = np.array(y)
400 | z = np.array(z)
401 |
402 | # cluster cell centroid candidates to reduce computation afterwards
403 | if len(x) > 1:
404 | x, y, z = agglomerative_clustering(x, y, z_samples=z, max_dist=(cell_density**-1))
405 |
406 | # extend the centroid list
407 | x_cell.extend(x+self.gridsize[0]//2)
408 | y_cell.extend(y+self.gridsize[1]//2)
409 | z_cell.extend(z)
410 |
411 | cluster_density = np.max([cluster_density, cell_density, ring_density])
412 | cell_density = cell_density*self.cell_density_decay
413 | ring_density = ring_density*self.ring_density_decay
414 | ring_count += 1
415 |
416 | # perform clustering
417 | x_cell = np.array(x_cell)
418 | y_cell = np.array(y_cell)
419 | z_cell = np.array(z_cell)
420 | x_cell, y_cell, z_cell = agglomerative_clustering(x_cell, y_cell, z_samples=z_cell, max_dist=cluster_density**-1)
421 |
422 | self.x_cell = x_cell
423 | self.y_cell = y_cell
424 | self.z_cell = z_cell
425 |
426 |
427 |
428 | def _post_processing(self):
429 |
430 |         # get the membrane mask
431 | membrane_mask = self.get_boundary_mask()
432 |
433 | # open the inner parts of the cells
434 |         opened_fg = morphology.binary_opening(~membrane_mask, footprint=morphology.ball(self.morph_radius))
435 | opened_fg[self.instance_mask==0] = False
436 |
437 | # erode the fg region
438 |         eroded_fg = morphology.binary_erosion(self.instance_mask>0, footprint=morphology.ball(int(self.morph_radius*1.5)))
439 |
440 | # generate the enhanced foreground mask
441 | enhanced_fg = np.logical_or(opened_fg, eroded_fg)
442 |
443 | # generate the instance mask
444 | self.instance_mask[~enhanced_fg] = 0
445 | self.x_fg, self.y_fg, self.z_fg = np.nonzero(self.instance_mask)
--------------------------------------------------------------------------------
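generate_data at the top of this file ties the pieces together: it instantiates the given synthesizer class, calls generate_instances once per image, and writes the instance/boundary/distance/seed masks of each image to its own HDF5 file. A minimal driver sketch with illustrative, not prescriptive, parameter values:

from utils.synthetic_cell_membrane_masks import generate_data, SyntheticMeristem

generate_data(
    SyntheticMeristem,
    save_path='Segmentations_h5',
    experiment_name='synthetic_meristem',
    img_count=5,                              # number of masks to generate
    param_dict={'gridsize': (120, 512, 512),  # forwarded to SyntheticMeristem.__init__
                'cell_density': 1/23},
)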
/utils/theta_phi_sampling_5000points_10000iter.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stegmaierj/DiffusionModelsForImageSynthesis/f60d09b9a002a9786f0d7b19f7974b215609278e/utils/theta_phi_sampling_5000points_10000iter.npy
--------------------------------------------------------------------------------
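This binary file holds the precomputed angular sampling consumed by SyntheticMeristem: theta in column 0 and phi in column 1, as indexed in _generate_foreground. A quick inspection sketch; the expected shape is inferred from the file name and the indexing above, not verified here:

import numpy as np

angles = np.load('utils/theta_phi_sampling_5000points_10000iter.npy')
print(angles.shape)  # expected (5000, 2): one (theta, phi) pair per sampled direction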
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu May 20 13:53:11 2021
4 |
5 | @author: Nutzer
6 | """
7 |
8 | import time
9 | import os
10 | import glob
11 | from skimage import io
12 |
13 |
14 | def print_timestamp(msg, args=None):
15 |
16 |     now = time.localtime()
17 |     print('[{0:02.0f}:{1:02.0f}:{2:02.0f}] '.format(now.tm_hour, now.tm_min, now.tm_sec) + msg.format(*(args if args is not None else ())))
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
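print_timestamp prefixes a zero-padded HH:MM:SS wall-clock stamp and forwards the optional args tuple into str.format placeholders, e.g.:

from utils.utils import print_timestamp

print_timestamp('Saving...')                          # -> [14:05:09] Saving...
print_timestamp('Generating mask {0}/{1}', (3, 100))  # -> [14:05:09] Generating mask 3/100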