├── .idea
│ ├── .gitignore
│ ├── inspectionProfiles
│ │ ├── Project_Default.xml
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── python4deepimagej.iml
│ └── vcs.xml
├── LICENSE
├── exportFRUNet_from_keras.ipynb
├── export_StarDist_to_TensorFlow_SavedModel.ipynb
├── keras_for_deepimagej.ipynb
├── requirements.txt
├── unet
│ ├── data
│ │ ├── processed.zip
│ │ └── raw.zip
│ ├── py_files
│ │ ├── convert_to_pb.py
│ │ ├── data_loading.py
│ │ ├── fit_model.py
│ │ ├── helpers.py
│ │ ├── model.py
│ │ ├── prep_data.py
│ │ ├── unet_dm.py
│ │ └── unet_weights.py
│ └── train_and_test_unet.ipynb
└── xml
  ├── config_template.xml
  └── create_config.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/.idea/python4deepimagej.iml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
[XML content stripped in export]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2019, DeepImageJ
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/exportFRUNet_from_keras.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "rQfu09ZWP3RH"
8 | },
9 | "source": [
10 | "**Use this code to export FRU-Net model to proto buffer and use it in DeepImageJ**\n",
11 | "\n",
12 | "FRU-Net: https://cbia.fi.muni.cz/research/segmentation/fru-net\n",
13 | "\n",
14 | "DeepImageJ: https://deepimagej.github.io/deepimagej/index.html\n"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "colab_type": "text",
21 | "id": "Paui0JfkQ8Lt"
22 | },
23 | "source": [
24 | "Mount your Google Drive"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 10,
30 | "metadata": {
31 | "colab": {
32 | "base_uri": "https://localhost:8080/",
33 | "height": 54
34 | },
35 | "colab_type": "code",
36 | "executionInfo": {
37 | "elapsed": 840,
38 | "status": "ok",
39 | "timestamp": 1564585146016,
40 | "user": {
41 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL",
42 | "photoUrl": "",
43 | "userId": "04592796515262324641"
44 | },
45 | "user_tz": -120
46 | },
47 | "id": "8FGAyk73Q7nR",
48 | "outputId": "a41a584b-7aa7-4b49-df83-b10ee0c5e9de"
49 | },
50 | "outputs": [
51 | {
52 | "name": "stdout",
53 | "output_type": "stream",
54 | "text": [
55 | "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "from google.colab import drive\n",
61 | "drive.mount('/content/drive')"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {
67 | "colab_type": "text",
68 | "id": "nfkCAMIkROF2"
69 | },
70 | "source": [
71 | "Install a compatible version of Keras and Tensorflow"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 11,
77 | "metadata": {
78 | "colab": {
79 | "base_uri": "https://localhost:8080/",
80 | "height": 530
81 | },
82 | "colab_type": "code",
83 | "executionInfo": {
84 | "elapsed": 5862,
85 | "status": "ok",
86 | "timestamp": 1564585153377,
87 | "user": {
88 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL",
89 | "photoUrl": "",
90 | "userId": "04592796515262324641"
91 | },
92 | "user_tz": -120
93 | },
94 | "id": "tM-GjE2TRMXO",
95 | "outputId": "9972d447-6ca5-4b63-aa20-74029b2c8f91"
96 | },
97 | "outputs": [
98 | {
99 | "name": "stdout",
100 | "output_type": "stream",
101 | "text": [
102 | "Cloning into 'python4deepimagej'...\n",
103 | "remote: Enumerating objects: 7, done.\u001b[K\n",
104 | "remote: Counting objects: 100% (7/7), done.\u001b[K\n",
105 | "remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
106 | "remote: Total 7 (delta 2), reused 0 (delta 0), pack-reused 0\u001b[K\n",
107 | "Unpacking objects: 100% (7/7), done.\n",
108 | "Requirement already satisfied: keras==1.2.2 in /usr/local/lib/python3.6/dist-packages (1.2.2)\n",
109 | "Requirement already satisfied: tensorflow in /usr/local/lib/python3.6/dist-packages (1.13.1)\n",
110 | "Requirement already satisfied: theano in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (1.0.4)\n",
111 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (3.13)\n",
112 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (1.12.0)\n",
113 | "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.0)\n",
114 | "Requirement already satisfied: tensorboard<1.14.0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.13.1)\n",
115 | "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.0)\n",
116 | "Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.2.2)\n",
117 | "Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.8.0)\n",
118 | "Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.0.8)\n",
119 | "Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.16.4)\n",
120 | "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.33.4)\n",
121 | "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.15.0)\n",
122 | "Requirement already satisfied: tensorflow-estimator<1.14.0rc0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.13.0)\n",
123 | "Requirement already satisfied: absl-py>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.7.1)\n",
124 | "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (3.7.1)\n",
125 | "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from theano->keras==1.2.2) (1.3.0)\n",
126 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow) (0.15.5)\n",
127 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow) (3.1.1)\n",
128 | "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow) (2.8.0)\n",
129 | "Requirement already satisfied: mock>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-estimator<1.14.0rc0,>=1.13.0->tensorflow) (3.0.5)\n",
130 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow) (41.0.1)\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "%pip install keras==1.2.2 tensorflow\n"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {
141 | "colab_type": "text",
142 | "id": "DRHahE4lRXWs"
143 | },
144 | "source": [
145 | "Import dependencies"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 0,
151 | "metadata": {
152 | "colab": {},
153 | "colab_type": "code",
154 | "id": "6wgEqbg1RTLg"
155 | },
156 | "outputs": [],
157 | "source": [
158 | "import tensorflow as tf\n",
159 | "import keras\n",
160 | "from keras import backend as K"
161 | ]
162 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "colab_type": "text",
180 | "id": "a6jENPIkjmYG"
181 | },
182 | "source": [
183 | "Download the ZIP file containing all the information about FRU-Net from https://cbia.fi.muni.cz/research/segmentation/fru-net.\n",
184 | "\n",
185 | "Unzip the file and load one of the trained models (.h5)\n",
186 | "\n"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": 0,
192 | "metadata": {
193 | "colab": {},
194 | "colab_type": "code",
195 | "id": "GSIQ8TUUWZBe"
196 | },
197 | "outputs": [],
198 | "source": [
199 | "#Fill the path to your keras network\n",
200 | "path2network='/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/fully_residual_dropout_segmentation.h5'\n",
201 | "\n",
202 | "# Set the learning phase to convert properly the model\n",
203 | "# The learning phase flag is a bool tensor (0 = test, 1 = train) to be passed as\n",
204 | "# input to any Keras function that uses a different behavior at train time and \n",
205 | "# test time.\n",
206 | "\n",
207 | "K.set_learning_phase(1)\n",
208 | "\n",
209 | "# Load the model\n",
210 | "model = keras.models.load_model(path2network)"
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {
216 | "colab_type": "text",
217 | "id": "vstqSMe1XOri"
218 | },
219 | "source": [
220 | "Save your keras model as proto buffer"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 9,
226 | "metadata": {
227 | "colab": {
228 | "base_uri": "https://localhost:8080/",
229 | "height": 156
230 | },
231 | "colab_type": "code",
232 | "executionInfo": {
233 | "elapsed": 4303,
234 | "status": "ok",
235 | "timestamp": 1564584430801,
236 | "user": {
237 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL",
238 | "photoUrl": "",
239 | "userId": "04592796515262324641"
240 | },
241 | "user_tz": -120
242 | },
243 | "id": "GNmQlysgWvBf",
244 | "outputId": "61a34f37-107a-41a2-8728-e16a41dde315"
245 | },
246 | "outputs": [
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:205: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n",
252 | "Instructions for updating:\n",
253 | "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n",
254 | "INFO:tensorflow:No assets to save.\n",
255 | "INFO:tensorflow:No assets to write.\n",
256 | "INFO:tensorflow:SavedModel written to: /content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet/saved_model.pb\n"
257 | ]
258 | },
259 | {
260 | "data": {
261 | "text/plain": [
262 | "b'/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet/saved_model.pb'"
263 | ]
264 | },
265 | "execution_count": 9,
266 | "metadata": {
267 | "tags": []
268 | },
269 | "output_type": "execute_result"
270 | }
271 | ],
272 | "source": [
273 | "OUTPUT_DIR = \"/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet\"\n",
274 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n",
275 | "\n",
276 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n",
277 | " inputs = {'input': model.input},\n",
278 | " outputs = {'output': model.output})\n",
279 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n",
280 | "\n",
281 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n",
282 | " signature_def_map=signature_def_map)\n",
283 | "builder.save()"
284 | ]
285 | }
286 | ],
287 | "metadata": {
288 | "colab": {
289 | "name": "exportFRUNet_from_keras.ipynb",
290 | "provenance": [],
291 | "version": "0.3.2"
292 | },
293 | "kernelspec": {
294 | "display_name": "Python 3",
295 | "language": "python",
296 | "name": "python3"
297 | },
298 | "language_info": {
299 | "codemirror_mode": {
300 | "name": "ipython",
301 | "version": 3
302 | },
303 | "file_extension": ".py",
304 | "mimetype": "text/x-python",
305 | "name": "python",
306 | "nbconvert_exporter": "python",
307 | "pygments_lexer": "ipython3",
308 | "version": "3.6.8"
309 | }
310 | },
311 | "nbformat": 4,
312 | "nbformat_minor": 1
313 | }
314 |
--------------------------------------------------------------------------------
/export_StarDist_to_TensorFlow_SavedModel.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Untitled0.ipynb",
7 | "provenance": []
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | }
13 | },
14 | "cells": [
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "id": "FtkHVlkGZuz3",
19 | "colab_type": "text"
20 | },
21 | "source": [
22 | "# **This is a genertic python code to export StarDist trained models and use them with DeepImageJ plugin**\n",
23 | "\n",
24 | "https://deepimagej.github.io/deepimagej/index.html\n"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "MklgDjVtbEKz",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "If you are using Google Colab, mount your Google Drive. Otherwise, skip this step"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "metadata": {
40 | "id": "EzTJ9dClsagp",
41 | "colab_type": "code",
42 | "colab": {}
43 | },
44 | "source": [
45 | "from google.colab import drive\n",
46 | "drive.mount('/content/drive')"
47 | ],
48 | "execution_count": 0,
49 | "outputs": []
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {
54 | "id": "OPfCnAKAZ7XI",
55 | "colab_type": "text"
56 | },
57 | "source": [
58 | "\n",
59 | "Install the following packages: \n",
60 | "- A compatible version of Tensorflow <= 1.13.\n",
61 | "- stardist python package. Here we used StarDist 0.3.6"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "metadata": {
67 | "id": "y_3YIsZxs39M",
68 | "colab_type": "code",
69 | "colab": {}
70 | },
71 | "source": [
72 | "% pip install tensorflow==1.13.1\n",
73 | "% pip install stardist"
74 | ],
75 | "execution_count": 0,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {
81 | "id": "wI2iQC_ttE_q",
82 | "colab_type": "text"
83 | },
84 | "source": [
85 | "# Load the StarDist trained model from your repository"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {
91 | "id": "MklgDjVtbEKz",
92 | "colab_type": "text"
93 | },
94 | "source": [
95 | "Verify input and output sizes of your model. They can be different when the parameter grid is not (1,1). A different output size can lead to errors in DeepImageJ. Take it also into account if you want to perform shape measurements using the output image."
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "metadata": {
101 | "id": "wQGQdksis3_o",
102 | "colab_type": "code",
103 | "colab": {}
104 | },
105 | "source": [
106 | "from stardist.models import StarDist2D\n",
107 | "# Without shape completion\n",
108 | "model_paper = StarDist2D(None, name='name_of_your_model', basedir='/content/drive/My Drive/the_path_to_your_model/folde_containing_the_model')\n",
109 | "# Indicate which weights you want to use\n",
110 | "model_paper.load_weights('weights_best.h5')"
111 | ],
112 | "execution_count": 0,
113 | "outputs": []
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {
118 | "id": "tEj1Oj9FubL0",
119 | "colab_type": "text"
120 | },
121 | "source": [
122 | "# Save as a TensorFlow SavedModel"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "metadata": {
128 | "id": "7B5veKEEuhq7",
129 | "colab_type": "code",
130 | "colab": {}
131 | },
132 | "source": [
133 | "import keras\n",
134 | "import keras.backend as K\n",
135 | "from keras.layers import concatenate\n",
136 | "import tensorflow as tf\n",
137 | "#Write the path where you would like to save the model. \n",
138 | "# The code will automatically create a new folder called \"new_folder\", where the\n",
139 | "# TensorFlow model will be saved\n",
140 | "OUTPUT_DIR = \"/content/drive/My Drive/the_path_where_you_want_to_save_your_model/new_folder\"\n",
141 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n",
142 | "\n",
143 | "# StarDist has two different outputs. DeepImageJ can only read one of them, so \n",
144 | "# we concatenate them as different channels in order to used them in ImageJ.\n",
145 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n",
146 | " inputs = {'input': model_paper.keras_model.input[0]},\n",
147 | " # concatenate the output of StarDist\n",
148 | " outputs = {'output': concatenate([model_paper.keras_model.output[0],model_paper.keras_model.output[1]], axis = 3)})\n",
149 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n",
150 | "\n",
151 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n",
152 | " signature_def_map=signature_def_map)\n",
153 | "builder.save()"
154 | ],
155 | "execution_count": 0,
156 | "outputs": []
157 | }
158 | ]
159 | }
160 |
--------------------------------------------------------------------------------
/keras_for_deepimagej.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "rQfu09ZWP3RH"
8 | },
9 | "source": [
10 | "**This is a genertic python code to export Keras models and use them with DeepImageJ plugin**\n",
11 | "\n",
12 | "\n",
13 | "https://deepimagej.github.io/deepimagej/index.html\n"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {
19 | "colab_type": "text",
20 | "id": "Paui0JfkQ8Lt"
21 | },
22 | "source": [
23 | "Mount your Google Drive"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 0,
29 | "metadata": {
30 | "colab": {},
31 | "colab_type": "code",
32 | "id": "8FGAyk73Q7nR"
33 | },
34 | "outputs": [],
35 | "source": [
36 | "from google.colab import drive\n",
37 | "drive.mount('/content/drive')"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "colab_type": "text",
44 | "id": "nfkCAMIkROF2"
45 | },
46 | "source": [
47 | "Install a compatible version of Tensorflow <= 1.13"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 0,
53 | "metadata": {
54 | "colab": {},
55 | "colab_type": "code",
56 | "id": "tM-GjE2TRMXO"
57 | },
58 | "outputs": [],
59 | "source": [
60 | "%pip install tensorflow==1.13.1\n",
61 | "%pip install keras==2.2.4\n"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {
67 | "colab_type": "text",
68 | "id": "DRHahE4lRXWs"
69 | },
70 | "source": [
71 | "Import dependencies"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 0,
77 | "metadata": {
78 | "colab": {},
79 | "colab_type": "code",
80 | "id": "6wgEqbg1RTLg"
81 | },
82 | "outputs": [],
83 | "source": [
84 | "import tensorflow as tf\n",
85 | "import keras\n",
86 | "from keras import backend as K"
87 | ]
88 | },
89 | {
 90 |    "cell_type": "markdown",
 91 |    "metadata": {
 92 |     "colab_type": "text",
 93 |     "id": "jTQ8MQJjWdY7"
 94 |    },
 95 |    "source": [
 96 |     "Load a Keras network"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 0,
105 | "metadata": {
106 | "colab": {},
107 | "colab_type": "code",
108 | "id": "GSIQ8TUUWZBe"
109 | },
110 | "outputs": [],
111 | "source": [
112 | "#Fill the path to your keras network\n",
113 | "path2network='/path2yournetwork/your_network.hdf5'\n",
114 | "model = keras.models.load_model(path2network)"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {
120 | "colab_type": "text",
121 | "id": "vstqSMe1XOri"
122 | },
123 | "source": [
124 | "Save your keras model as proto buffer"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 0,
130 | "metadata": {
131 | "colab": {},
132 | "colab_type": "code",
133 | "id": "GNmQlysgWvBf"
134 | },
135 | "outputs": [],
136 | "source": [
137 | "#If the model has only one input it can be converted\n",
138 | "OUTPUT_DIR = \"/your/output/directory/new_folder_name\"\n",
139 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n",
140 | "\n",
141 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n",
142 | " inputs = {'input': model.input},\n",
143 | " outputs = {'output': model.output})\n",
144 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n",
145 | "\n",
146 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n",
147 | " signature_def_map=signature_def_map)\n",
148 | "builder.save()"
149 | ]
150 | }
151 | ],
152 | "metadata": {
153 | "colab": {
154 | "name": "keras_for_deepimagej.ipynb",
155 | "provenance": [],
156 | "version": "0.3.2"
157 | },
158 | "kernelspec": {
159 | "display_name": "Python 3",
160 | "language": "python",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.6.8"
174 | }
175 | },
176 | "nbformat": 4,
177 | "nbformat_minor": 1
178 | }
179 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | matplotlib
4 | scikit-image
5 | opencv-python
6 | Pillow
7 | tifffile
8 | keras
9 | tensorflow<=2.2.1
10 | 
--------------------------------------------------------------------------------
/unet/data/processed.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepimagej/python4deepimagej/535fc4061f9ae93878d70c68f23536233bb74562/unet/data/processed.zip
--------------------------------------------------------------------------------
/unet/data/raw.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepimagej/python4deepimagej/535fc4061f9ae93878d70c68f23536233bb74562/unet/data/raw.zip
--------------------------------------------------------------------------------
/unet/py_files/convert_to_pb.py:
--------------------------------------------------------------------------------
1 | # Important libraries.
2 |
3 | import keras
4 | from keras import backend as K
5 | from keras.models import load_model
6 |
7 | import tensorflow as tf
8 | from tensorflow.compat.v1 import graph_util
9 | from tensorflow.python.framework import graph_io
10 |
11 | # -----------------------------------------------------------------------------
12 |
13 | def convert_to_pb(name_project):
14 | """
15 |     Converts a Keras model into a frozen TensorFlow .pb file.
16 | 
17 |     The string 'name_project' represents the name of the DeepPix Workflow project
18 | given by the user.
19 | """
20 |
21 | # Define paths.
22 | path_to_model = '/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=name_project)
23 | path_output = '/content/drive/My Drive/unser_project/'
24 |
25 | # Load model.
26 | model = load_model(path_to_model)
27 |
28 | # Get node names.
29 | node_names = [node.op.name for node in model.outputs]
30 |
31 | # Get Keras session.
32 | session = K.get_session()
33 |
34 | # Convert Keras variables to Tensorflow constants.
35 | graph_to_constant = graph_util.convert_variables_to_constants(session, session.graph.as_graph_def(), node_names)
36 |
37 | # Write graph as .pb file.
38 | graph_io.write_graph(graph_to_constant, path_output, name_project + ".pb", as_text=False)
39 |
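40 | # Example usage (sketch, not part of the original file; "my_project" is an
41 | # illustrative project name, assuming 'my_project.hdf5' exists under the
42 | # models path above):
43 | # convert_to_pb("my_project")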
--------------------------------------------------------------------------------
/unet/py_files/data_loading.py:
--------------------------------------------------------------------------------
1 | # Important libraries.
2 |
3 | from __future__ import print_function
4 | from keras.preprocessing.image import ImageDataGenerator
5 | import numpy as np
6 | import glob
7 |
8 | # -----------------------------------------------------------------------------
9 |
10 | # Important py files.
11 |
12 | from helpers import *
13 |
14 | # -----------------------------------------------------------------------------
15 |
16 | def dataGenerator(path, batch_size = 2, subset = 'train', target_size = (256,256), seed = 1):
17 |
18 | """
19 | Builds generators for the U-Net. The generators can be built for
20 | training, testing and validation purposes.
21 |
22 | The string "subset" is used to specify which type of data we are dealing
23 | with (train, test or validation). Default value is set to 'train'.
24 |
25 | The string "path" represents a path that should lead to images and labels
26 | folders named 'image' and 'label' respectively.
27 |
 28 |     The tuple "target_size" is used to specify the final size of the images
 29 |     and labels after augmentation. If the given size does not correspond to the
 30 |     original size of the images and labels, the data will be resized to the
 31 |     given size. Default value is set to (256, 256) (image size of 256x256 pixels).
32 |
 33 |     The variable "seed" is needed to ensure that images and labels are augmented
 34 |     together in the same order. Default value set to 1.
35 | """
36 |
37 | # Builds generator for training set.
38 | if subset == "train":
39 |
40 | # Preprocessing arguments.
41 | aug_arg = dict(rotation_range = 40,
42 | width_shift_range = 0.2,
43 | height_shift_range = 0.2,
44 | shear_range = 0.2,
45 | horizontal_flip = True,
46 | vertical_flip = True,
47 | fill_mode='nearest')
48 |
49 | # Generates tensor images and labels with augmentations provided above.
50 | image_datagen = ImageDataGenerator(**aug_arg)
51 | label_datagen = ImageDataGenerator(**aug_arg)
52 |
53 | # Generator for images.
54 | image_generator = image_datagen.flow_from_directory(
55 | path,
56 | classes = ['image'],
57 | class_mode = None,
58 | color_mode = "grayscale",
59 | target_size = target_size,
60 | batch_size = batch_size,
61 | save_to_dir = None,
62 | seed = seed,
63 | shuffle = True)
64 |
65 | # Generator for labels.
66 | label_generator = label_datagen.flow_from_directory(
67 | path,
68 | classes = ['label'],
69 | class_mode = None,
70 | color_mode = "grayscale",
71 | target_size = target_size,
72 | batch_size = batch_size,
73 | save_to_dir = None,
74 | seed = seed,
75 | shuffle = True)
76 |
77 | # Builds generator for the training set.
78 | train_generator = zip(image_generator, label_generator)
79 |
80 | for (img,label) in train_generator:
81 | img, label = adjustData(img, label = label)
82 |
83 | yield (img, label)
84 |
85 | # Builds generator for validation set.
86 | elif subset == "validation":
87 |
88 | # Generates tensor images and labels with no augmentations
89 | # (validation set should not have any augmentation and does not
90 | # have to be shuffled).
91 | image_datagen = ImageDataGenerator()
92 | label_datagen = ImageDataGenerator()
93 |
94 | # Generator for images.
95 | image_generator = image_datagen.flow_from_directory(
96 | path,
97 | classes = ['image'],
98 | class_mode = None,
99 | color_mode = "grayscale",
100 | target_size = target_size,
101 | batch_size = batch_size,
102 | save_to_dir = None,
103 | seed = seed,
104 | shuffle = False)
105 |
106 | # Generator for labels.
107 | label_generator = label_datagen.flow_from_directory(
108 | path,
109 | classes = ['label'],
110 | class_mode = None,
111 | color_mode = "grayscale",
112 | target_size = target_size,
113 | batch_size = batch_size,
114 | save_to_dir = None,
115 | seed = seed,
116 | shuffle = False)
117 |
118 | # Builds generator for the validation set.
119 | validation_generator = zip(image_generator, label_generator)
120 |
121 | for (img,label) in validation_generator:
122 | img, label = adjustData(img, label = label)
123 |
124 | yield (img, label)
125 |
126 | # Builds generator for testing set.
127 | elif subset == "test":
128 |
129 | # Generates tensor images only with no augmentations (testing data
130 | # does not have to have labels and we do not shuffle the data
131 | # as it is not necessary).
132 | image_datagen = ImageDataGenerator()
133 |
134 | # Generator for images.
135 | image_generator = image_datagen.flow_from_directory(
136 | path,
137 | classes = ['image'],
138 | class_mode = None,
139 | color_mode = "grayscale",
140 | target_size = target_size,
141 | batch_size = batch_size,
142 | save_to_dir = None,
143 | seed = seed,
144 | shuffle = False)
145 |
146 | # Builds generator for the testing set.
147 | for img in image_generator:
148 |
149 | img = adjustData(img, False)
150 |
151 | yield img
152 |
153 | else:
154 | raise RuntimeError("Subset name not recognized")
155 |
156 | # -----------------------------------------------------------------------------
157 |
158 | def weightGen(path, batch_size = 2, subset = 'train', target_size = (256,256), seed = 1):
159 |
160 | """
161 | Builds generators for the weighted U-Net. The generators are built the
162 | same way as the dataGenerator function, only weight-maps are combined
163 | with the images. The generators can be built for training and
164 | validation purposes.
165 |
166 | The string "subset" is used to specify which type of data we are dealing
167 | with (train, test or validation). Default value is set to 'train'.
168 |
169 | The string "path" represents a path that should lead to images and labels
170 | folders named 'image' and 'label' respectively.
171 |
172 |     The tuple "target_size" is used to specify the final size of the images
173 |     and labels after augmentation. If the given size does not correspond to the
174 |     original size of the images and labels, the data will be resized to the
175 |     given size. Default value is set to (256, 256) (image size of 256x256 pixels).
176 |
177 |     The variable "seed" is needed to ensure that images and labels are augmented
178 |     together in the same order. Default value set to 1.
179 | """
180 |
181 | # Builds generator for training set.
182 | if subset == "train":
183 |
184 | # Preprocessing arguments.
185 | aug_arg = dict(rotation_range = 40,
186 | width_shift_range = 0.2,
187 | height_shift_range = 0.2,
188 | shear_range = 0.2,
189 | horizontal_flip = True,
190 | vertical_flip = True,
191 | fill_mode='nearest')
192 |
193 | # Generates tensor images, weight-maps and labels with augmentations provided above.
194 | image_datagen = ImageDataGenerator(**aug_arg)
195 | label_datagen = ImageDataGenerator(**aug_arg)
196 | weight_datagen = ImageDataGenerator(**aug_arg)
197 |
198 | # Generator for images.
199 | image_generator = image_datagen.flow_from_directory(
200 | path,
201 | classes = ['image'],
202 | class_mode = None,
203 | color_mode = 'grayscale',
204 | target_size = target_size,
205 | batch_size = batch_size,
206 | save_to_dir = None,
207 | seed = seed,
208 | shuffle = True)
209 |
210 | # Generator for labels.
211 | label_generator = label_datagen.flow_from_directory(
212 | path,
213 | classes = ['label'],
214 | class_mode = None,
215 | color_mode = 'grayscale',
216 | target_size = target_size,
217 | batch_size = batch_size,
218 | save_to_dir = None,
219 | seed = seed,
220 | shuffle = True)
221 |
222 | # Retrieve weight-maps.
223 | filelist = glob.glob(path + "/weight/*.npy")
224 | filelist.sort(key=natural_keys)
225 |
226 | # Loads all weight-map images in a list.
227 | weights = [np.load(fname) for fname in filelist]
228 | weights = np.array(weights)
229 | weights = weights.reshape((len(weights),256,256,1))
230 |
231 | # Creates the weight generator.
232 | weight_generator = weight_datagen.flow(
233 | x = weights,
234 | y = None,
235 | batch_size = batch_size,
236 | seed = seed)
237 |
238 | # Builds generator for the training set.
239 | train_generator = zip(image_generator, label_generator, weight_generator)
240 |
241 | for (img, label, weight) in train_generator:
242 | img, label = adjustData(img, label = label)
243 |
244 | # This is the final generator.
245 | yield ([img, weight], label)
246 |
247 | elif subset == "validation":
248 |
249 |         # Generates tensor images, weight maps and labels with no augmentation
250 |         # and no shuffling (since this is the validation set).
251 | image_datagen = ImageDataGenerator()
252 | label_datagen = ImageDataGenerator()
253 | weight_datagen = ImageDataGenerator()
254 |
255 | # Generator for images.
256 | image_generator = image_datagen.flow_from_directory(
257 | path,
258 | classes = ['image'],
259 | class_mode = None,
260 | color_mode = 'grayscale',
261 | target_size = target_size,
262 | batch_size = batch_size,
263 | save_to_dir = None,
264 | seed = seed,
265 | shuffle = False)
266 |
267 | # Generator for labels.
268 | label_generator = label_datagen.flow_from_directory(
269 | path,
270 | classes = ['label'],
271 | class_mode = None,
272 | color_mode = 'grayscale',
273 | target_size = target_size,
274 | batch_size = batch_size,
275 | save_to_dir = None,
276 | seed = seed,
277 | shuffle = False)
278 |
279 | # Retrieve weight maps.
280 | filelist = glob.glob(path + "/weight/*.npy")
281 | filelist.sort(key=natural_keys)
282 |
283 | # Loads all weight map images in a list.
284 | weights = [np.load(fname) for fname in filelist]
285 | weights = np.array(weights)
286 | weights = weights.reshape((len(weights),256,256,1))
287 |
288 | # Creates the weight generator.
289 | weight_generator = weight_datagen.flow(
290 | x = weights,
291 | y = None,
292 | batch_size = batch_size,
293 | seed = seed,
294 | shuffle = False)
295 |
296 |         # Builds generator for the validation set.
297 |         validation_generator = zip(image_generator, label_generator, weight_generator)
298 | 
299 |         for (img, label, weight) in validation_generator:
300 | img, label = adjustData(img, label = label)
301 |
302 | # This is the final generator.
303 | yield ([img, weight], label)
304 |
305 | else:
306 | raise RuntimeError("Subset name not recognized")
307 |
308 | # -----------------------------------------------------------------------------
309 |
310 | def adjustData(image, adjust_lab = True, dist = False, label = None):
311 | """
312 | Normalizes the data such that images are in the interval [0,1] and labels
313 |     are binary values in {0,1}. This step is important because augmentations with
314 |     Keras' ImageDataGenerator() change the pixel values of the images and labels;
315 |     most notably, images are no longer normalized and labels are no longer
316 |     binary.
317 |
318 | The numpy array 'image' represents the input image.
319 |
320 | The numpy array 'label' represents the label image. If adjust_lab is set to
321 | False, label should be set to None. Default value set to None.
322 |
323 | The boolean value 'adjust_lab' specifies if we need to process labels or not
324 |     for training and validation purposes. Default value set to True.
325 | """
326 |
327 | # Checks if the images are already between 0 and 1, otherwise
328 | # does the normalization.
329 | if(np.max(image) > 1):
330 | image = image / 255
331 |
332 | if adjust_lab:
333 |
334 | # Checks if the labels are already binary, otherwise
335 | # does the binarization.
336 | if (np.max(label) > 1):
337 | label = label / 255
338 | label[label > 0.5] = 1
339 | label[label <= 0.5] = 0
340 |
341 | return (image, label)
342 |
343 | if dist:
344 | return (image, label)
345 |
346 | else:
347 | return image
348 |
349 | # -----------------------------------------------------------------------------
350 |
351 | def loadGenerator(name_project, model_type, batch_size = 2, target_size = (256,256)):
352 | """
353 | Loads generators for the training of a model automatically.
354 |
355 |     The string 'name_project' represents the name of the DeepPix Workflow project
356 | given by the user.
357 |
358 | The string 'model_type' represents the type of model desired by the user.
359 |
360 | The integer value 'batch_size' represents the size of the batches for the
361 | training. Default value set to 2.
362 |
363 |     The tuple "target_size" is used to specify the final size of the images
364 |     and labels after augmentation. If the given size does not correspond to the
365 |     original size of the images and labels, the data will be resized to the
366 |     given size. Default value is set to (256, 256) (image size of 256x256 pixels).
367 | """
368 |
369 | # Path to data folder.
370 | data_path = "/content/drive/My Drive/unser_project/data/processed/"
371 |
372 | # Paths for train and test folders.
373 | train_path = data_path + name_project + "/train/"
374 | test_path = data_path + name_project + "/test/"
375 |
376 | # Create generators depending on model type.
377 | if model_type == "unet_simple":
378 | trainGen = dataGenerator(path = train_path, batch_size = batch_size, subset = "train", target_size = target_size)
379 | validGen = dataGenerator(path = test_path, batch_size = batch_size, subset = "validation", target_size = target_size)
380 |
381 | elif model_type == "unet_weighted":
382 | trainGen = weightGen(path = train_path, batch_size = batch_size, subset = "train", target_size = target_size)
383 | validGen = weightGen(path = test_path, batch_size = batch_size, subset = "validation", target_size = target_size)
384 |
385 | else:
386 | raise RuntimeError("Model not recognised.")
387 |
388 | return trainGen, validGen
389 |
390 | # -----------------------------------------------------------------------------
391 |
392 | #def distanceGen(path, n_classes, batch_size = 2, subset = 'train', image_folder = 'image', label_folder = 'distance',
393 | # image_col = "grayscale", label_col = "grayscale", target_size = (256,256), seed = 1):
394 | #
395 | # """Builds generators for the weighted U-Net. The generators are built
396 | # the same way as the dataGenerator function, only the weights are combined
397 | # with the images."""
398 | #
399 | # # Builds generator for training set
400 | # if subset == "train":
401 | #
402 | # # Preprocessing arguments
403 | # aug_arg = dict(rotation_range = 40,
404 | # width_shift_range = 0.2,
405 | # height_shift_range = 0.2,
406 | # shear_range = 0.2,
407 | # horizontal_flip = True,
408 | # vertical_flip = True,
409 | # fill_mode='nearest')
410 | #
411 | # # Generates tensor images and labels with augmentations provided above
412 | # image_datagen = ImageDataGenerator(**aug_arg)
413 | # label_datagen = ImageDataGenerator(**aug_arg)
414 | #
415 | # # Generator for images
416 | # image_generator = image_datagen.flow_from_directory(
417 | # path,
418 | # classes = [image_folder],
419 | # class_mode = None,
420 | # color_mode = 'grayscale',
421 | # target_size = target_size,
422 | # batch_size = batch_size,
423 | # save_to_dir = None,
424 | # seed = seed,
425 | # shuffle = True)
426 | #
427 | # # Generator for labels
428 | # label_generator = label_datagen.flow_from_directory(
429 | # path,
430 | # classes = [label_folder],
431 | # class_mode = 'categorical',
432 | # color_mode = 'rgb',
433 | # target_size = target_size,
434 | # batch_size = batch_size,
435 | # save_to_dir = None,
436 | # seed = seed,
437 | # shuffle = True)
438 | #
439 | # # Builds generator for the training set
440 | # train_generator = zip(image_generator, label_generator)
441 | #
442 | # for (img, label) in train_generator:
443 | # img, label = adjustData(img, False, True, label = label)
444 | #
445 | # print(label)
446 | # # This is the final generator
447 | # yield (img, label)
448 | #
449 | # elif subset == "test":
450 | #
451 | # # Generates tensor images and labels with no augmnetations and shuffling
452 | # # (since we are in the test set)
453 | # image_datagen = ImageDataGenerator()
454 | # label_datagen = ImageDataGenerator()
455 | #
456 | # # Generator for images
457 | # image_generator = image_datagen.flow_from_directory(
458 | # path,
459 | # classes = [image_folder],
460 | # class_mode = None,
461 | # color_mode = 'grayscale',
462 | # target_size = target_size,
463 | # batch_size = batch_size,
464 | # save_to_dir = None,
465 | # seed = seed,
466 | # shuffle = False)
467 | #
468 | # # Generator for labels
469 | # label_generator = label_datagen.flow_from_directory(
470 | # path,
471 | # classes = [label_folder],
472 | # class_mode = 'categorical',
473 | # color_mode = 'rgb',
474 | # target_size = target_size,
475 | # batch_size = batch_size,
476 | # save_to_dir = None,
477 | # seed = seed,
478 | # shuffle = False)
479 | #
480 | # # Builds generator for the test set
481 | # test_generator = zip(image_generator, label_generator)
482 | #
483 | # for (img, label) in test_generator:
484 | # img, label = adjustData(img, False, True, label = label)
485 | #
486 | # # This is the final generator
487 | # yield (img, label)
488 | #
489 | # else:
490 | # print("Subset name not recognized")
491 | # return None
492 |
493 | # -----------------------------------------------------------------------------
494 |
495 | #def label_to_cat(label, n_classes):
496 | #
497 | # label = np.rint(label / (255 / (n_classes - 1)))
498 | #
499 | # n, rows, cols, _ = label.shape
500 | # output = np.zeros((n, rows, cols, n_classes))
501 | #
502 | # for i in range(n_classes):
503 | # tmp = (label[...,0] == i).astype(int)
504 | # output[...,i] = tmp
505 | #
506 | # output = np.reshape(output, (n, rows*cols, n_classes))
507 | #
508 | # return output
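509 | 
510 | # Example usage (sketch, not part of the original file; "my_project" is an
511 | # illustrative project name, assuming the data layout described in loadGenerator):
512 | # trainGen, validGen = loadGenerator("my_project", "unet_simple", batch_size=2)
513 | # img_batch, label_batch = next(trainGen)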
--------------------------------------------------------------------------------
/unet/py_files/fit_model.py:
--------------------------------------------------------------------------------
1 | # Important libraries
2 |
3 | import pickle
4 | import cv2 as cv
5 | import sys
 6 | import glob
 7 | from keras.callbacks import ModelCheckpoint, EarlyStopping
7 |
8 | # -----------------------------------------------------------------------------
9 |
10 | # Important py files
11 | sys.path.append("/content/drive/My Drive/unser_project/py_files")
12 | from model import *
13 | from data_loading import *
14 | from helpers import *
15 | from unet_weights import *
16 | from unet_dm import *
17 |
18 | # -----------------------------------------------------------------------------
19 |
20 | def fit_model(trainGen, validGen, model_type, model_name, input_size = (256, 256, 1), loss_ = 'binary_crossentropy',
21 | lr = 1e-4, w_decay = 5e-7, steps = 500, epoch_num = 10, val_steps = 15, save_history = True):
22 | """
23 | This function selects a model and fits the given generators with the given arguments.
24 | Then the history of the model and the model itself are saved.
25 |
26 | The generators 'trainGen' and 'validGen' represent the training and validation
27 | generators to fit the model.
28 |
29 | The string 'model_type' refers to the type of U-Net to use.
30 |
31 | The string 'model_name' refers to the name with which the model shall be saved.
32 |
33 | The tuple 'input_size' corresponds to the size of the input images and labels.
34 | Default value set to (256, 256, 1) (input images size is 256x256).
35 |
36 | The string 'loss_' represents the name of the loss that should be used.
37 | Default value set to 'binary_crossentropy'.
38 |
39 | The float 'lr' corresponds to the learning rate value for the training.
 40 |     Default value set to 1e-4.
41 |
42 | The float 'w_decay' corresponds to the weight decay value for the training.
43 | Default value set to 5e-7.
44 |
 45 |     The integer 'steps' refers to the number of steps per epoch. This
 46 |     number should be large enough to allow for many augmentations.
 47 |     Default value set to 500.
48 |
49 | The integer 'epoch_num' refers to the number of epochs to be used for the training.
50 | Default value set to 10.
51 |
 52 |     The integer 'val_steps' refers to the number of validation steps at the end
 53 |     of each epoch. This number should be equal to the number of validation images.
54 | Default value set to 15.
55 |
56 | The boolean 'save_history' refers to whether or not the history of the training
57 | should be saved.
58 | """
59 |
60 | # Load a model.
61 | if model_type == "unet_simple":
62 | model = unet(input_size = input_size, loss_ = loss_, learning_rate = lr, weight_decay = w_decay)
63 |
64 | elif model_type == "unet_weighted":
65 | model = unet_weights(input_size = input_size, learning_rate = lr, weight_decay = w_decay)
66 |
67 | # elif model_type == "unet_dm":
68 | # model = unet_distance(learning_rate = lr, weight_decay = w_decay)
69 |
70 | else:
71 | raise RuntimeError("Model type not recognized")
72 |
73 | # Callbacks.
74 | model_checkpoint = ModelCheckpoint('/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=model_name), monitor='val_loss', verbose=1, save_best_only=True)
75 | early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, mode='auto', restore_best_weights=True)
76 |
77 | # Fit.
78 | history = model.fit_generator(trainGen,
79 | steps_per_epoch=steps,
80 | epochs=epoch_num,
81 | callbacks=[model_checkpoint, early_stopping],
82 | validation_data = validGen,
83 | validation_steps = val_steps)
84 |
85 | if save_history:
86 |
87 | # Saving the history for plotting.
88 | pickle.dump(history.history, open('/content/drive/My Drive/unser_project/histories/{b}.p'.format(b=model_name), "wb" ))
89 |
90 | return None
91 |
92 | # -----------------------------------------------------------------------------
93 |
94 | def show_predictions(model, name_project, target_size = (256, 256)):
95 | """
96 | Shows one image with its ground truth and the prediction of the model (as
97 | a binary image and a probability map).
98 |
99 | The string 'model' corresponds to the type of model used.
100 |
101 |     The string 'name_project' refers to the name of the DeepPix Workflow project
102 | given by the user.
103 |
104 |     The tuple "target_size" is used to specify the final size of the images
105 |     and labels. If the given size does not correspond to the original size of the
106 |     images and labels, the data will be resized to the given size.
107 | Default value is set to (256, 256) (image size of 256x256 pixels).
108 | """
109 |
110 | # Path of the test set
111 | test_path = "/content/drive/My Drive/unser_project/data/processed/" + name_project + "/test/"
112 |
113 | # List of files
114 | list_file = glob.glob(test_path + 'image/*.png')
115 |
116 | # Number of files (important for number of leading zeros)
117 | n_file = len(list_file)
118 |
119 | # Count number of digits in n_file. This is important for the number
120 | # of leading zeros in the name of the images and labels.
121 | n_digits = len(str(n_file))
122 |
123 | # Creates title depending on model type and prepares test generator
124 | # depending on model type.
125 | if model == "unet_simple":
126 | title = "Simple U-Net"
127 | testGen = dataGenerator(batch_size = 1, subset = "test", path = test_path)
128 | mdl = unet(input_size = (256,256,1))
129 |
130 | elif model == "unet_weighted":
131 | title = "Weighted U-Net"
132 | testGen = weightGen(batch_size = 1, subset = "test", path = test_path)
133 | mdl = unet_weights(input_size = (256,256,1))
134 |
135 | else:
136 | raise RuntimeError("Model not recognised.")
137 |
138 | # Loads one image and label.
139 | img_path = test_path + "image/{b:0" + str(n_digits) + "d}.png"
140 | lbl_path = test_path + "label/{b:0" + str(n_digits) + "d}.png"
141 |
142 | img = cv.imread(img_path.format(b=0))
143 | label = cv.imread(lbl_path.format(b=0))
144 |
145 | # Resizes to target size.
146 | img = cv.resize(img, target_size)
147 | label = cv.resize(label, target_size)
148 |
149 | # Load model and perform predictions.
150 | mdl.load_weights('/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=name_project))
151 | prediction = mdl.predict_generator(testGen, 2, verbose=1, workers=1)
152 |
153 | # Binarizes one prediction.
154 | pred_binarized = convertLabel(prediction[0])
155 |
156 | # Perform plot.
157 | fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=((15,15)))
158 |
159 | ax[0,0].grid(False)
160 | ax[0,1].grid(False)
161 | ax[1,0].grid(False)
162 | ax[1,1].grid(False)
163 |
164 | ax[0,0].imshow(img, cmap = 'gray', aspect="auto")
165 | ax[0,1].imshow(label, cmap = 'gray', aspect="auto")
166 | ax[1,0].imshow(pred_binarized, cmap = 'gray', aspect="auto")
167 | ax[1,1].imshow(prediction[0,...,0], cmap = 'gray', aspect="auto", vmin=0, vmax=1)
168 |
169 | ax[0,0].set_title("Input", fontsize = 17.5)
170 | ax[0,1].set_title("Ground truth", fontsize = 17.5)
171 | ax[1,0].set_title(title + " - Binarized", fontsize = 17.5)
172 | ax[1,1].set_title(title + " - Probability map", fontsize = 17.5)
173 |
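174 | # Example usage (sketch, not part of the original file; names are illustrative):
175 | # trainGen, validGen = loadGenerator("my_project", "unet_simple")
176 | # fit_model(trainGen, validGen, "unet_simple", "my_model")
177 | # show_predictions("unet_simple", "my_project")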
--------------------------------------------------------------------------------
/unet/py_files/helpers.py:
--------------------------------------------------------------------------------
1 | # Important libraries.
2 |
3 | from PIL import Image
4 | import glob
5 | import numpy as np
6 | import re
7 | import matplotlib.pyplot as plt
8 | from skimage import measure
9 | import scipy.ndimage
10 | import os
11 | import cv2
12 | import pickle
13 | import copy
14 | from tifffile import imsave
15 |
16 | # -----------------------------------------------------------------------------
17 |
18 | def prepare_standardplot(title, xlabel):
19 | """
20 | Prepares the layout and axis for the plotting of the history from the training.
21 |
22 | The string 'title' refers to the title of the plot.
23 |
24 | The string 'xlabel' refers to the name of the x-axis.
25 | """
26 |
27 | fig, (ax1, ax2) = plt.subplots(1, 2)
28 | fig.suptitle(title)
29 |
30 | ax1.set_ylabel('Binary cross-entropy')
31 | ax1.set_xlabel(xlabel)
32 | ax1.set_yscale('log')
33 |
34 | ax2.set_ylabel('Accuracy')
35 | ax2.set_xlabel(xlabel)
36 |
37 | return fig, ax1, ax2
38 |
39 | # -----------------------------------------------------------------------------
40 |
41 | def finalize_standardplot(fig, ax1, ax2):
42 | """
43 | Finalizes the layout of the plotting of the history from the training.
44 |
45 | The variable 'fig' refers to the created figure of the plot.
46 |
47 | The variables 'ax1' and 'ax2' refer to the axes of the plot.
48 | """
49 |
50 | ax1handles, ax1labels = ax1.get_legend_handles_labels()
51 | if len(ax1labels) > 0:
52 | ax1.legend(ax1handles, ax1labels)
53 |
54 | ax2handles, ax2labels = ax2.get_legend_handles_labels()
55 | if len(ax2labels) > 0:
56 | ax2.legend(ax2handles, ax2labels)
57 |
58 | fig.tight_layout()
59 |
60 | plt.subplots_adjust(top=0.9)
61 |
62 | # -----------------------------------------------------------------------------
63 |
64 | def plot_history(history, title):
65 | """
66 | Plots the history from the training of a model. More precisely, this function
67 | plots the training loss, the validation loss, the training accuracy and
68 | the validation accuracy of a model training.
69 |
70 | The variable 'history' refers to the history file that was saved after
71 | the training of the model.
72 |
73 | The string 'title' represents the title that the plot will have.
74 | """
75 |
76 | if title == "unet_simple":
77 | title = "Simple U-Net"
78 |
79 | elif title == "unet_weighted":
80 | title = "Weighted U-Net"
81 |
82 | fig, ax1, ax2 = prepare_standardplot(title, 'Epoch')
83 |
84 | ax1.plot(history['loss'], label = "Training")
85 | ax1.plot(history['val_loss'], label = "Validation")
86 |
87 | ax2.plot(history['acc'], label = "Training")
88 | ax2.plot(history['val_acc'], label = "Validation")
89 |
90 | finalize_standardplot(fig, ax1, ax2)
91 |
92 | return fig
93 |
94 | # -----------------------------------------------------------------------------
95 |
96 | def natural_keys(text):
97 | """
 98 |     Sort key used to order file names in a "human" (natural) order.
 99 | 
100 |     The variable 'text' represents a single file name from a list imported
101 |     with the glob library.
102 | """
103 |
104 | def atoi(text):
105 | return int(text) if text.isdigit() else text
106 |
107 |     return [atoi(c) for c in re.split(r'(\d+)', text)]
108 |
109 | # -----------------------------------------------------------------------------
110 |
111 | def load_data(path_images, path_labels):
112 | """
113 | Loads and returns images and labels.
114 |
115 | The variables 'path_images' and 'path_labels' refer to the paths of the
116 | folders containing the images and labels, respectively.
117 | """
118 |
119 | # Creates a list of file names in the data directory.
120 | filelist = glob.glob(path_images)
121 | filelist.sort(key=natural_keys)
122 |
123 | # Loads all data images in a list.
124 | data = [Image.open(fname) for fname in filelist]
125 |
126 | # Creates a list of file names in the labels directory.
127 | filelist = glob.glob(path_labels)
128 | filelist.sort(key=natural_keys)
129 |
130 | # Loads all labels images in a list.
131 | labels = [Image.open(fname) for fname in filelist]
132 |
133 | return data, labels
134 |
135 | # -----------------------------------------------------------------------------
136 |
137 | def check_binary(labels):
138 | """
139 | Checks if the given labels are binary or not.
140 |
141 | The variable "labels" correspond to a list of label images.
142 | """
143 |
144 | # Initialize output variable.
145 | binary = True
146 |
147 | # Check every label.
148 | for k in range(len(labels)):
149 |
150 | # Number of unique values (should be = 2 for binary labels or > 2 for
151 | # categorical or non-binary data).
152 | n_unique = len(np.unique(np.array(labels[k])))
153 |
154 | if n_unique > 2:
155 | binary = False
156 |
157 | # Raise exception if labels are constant images or not recognised.
158 | elif n_unique < 2:
159 | raise RuntimeError("Labels are neither binary or categorical.")
160 |
161 | return binary
162 |
163 | # -----------------------------------------------------------------------------
164 |
165 | def make_binary(labels):
166 | """
167 | Makes the given labels binary.
168 |
169 | The variable "labels" correspond to a list of label images.
170 | """
171 |
172 |     # For each label, converts the image to a numpy array, binarizes the
173 |     # array and converts the array back to an image.
174 | for i in range(len(labels)):
175 | tmp = np.array(labels[i])
176 | tmp[tmp > 0] = 255
177 | tmp[tmp == 0] = 0
178 | tmp = tmp.astype('uint8')
179 | tmp = Image.fromarray(tmp, 'L')
180 | labels[i] = tmp
181 |
182 | return labels
183 |
184 | # -----------------------------------------------------------------------------
185 |
186 | def save_data(data, labels, path):
187 | """
188 |     Saves images and labels to disk.
189 | 
190 |     The variables 'data' and 'labels' refer to the processed images and labels.
191 | 
192 |     The string 'path' corresponds to the directory (including a trailing
193 |     separator) where the images and labels will be saved.
194 |     """
195 |
196 | # Number of images.
197 | n_data = len(data)
198 |
199 | # Count number of digits in n_data. This is important for the number
200 | # of leading zeros in the name of the images and labels.
201 | n_digits = len(str(n_data))
202 |
203 | # These represent the paths for the final label and images with the right
204 | # number of leading zeros given by n_digits.
205 | direc_d = path + "image/{b:0" + str(n_digits) + "d}.png"
206 | direc_l = path + "label/{b:0" + str(n_digits) + "d}.png"
207 |
208 | # Saves data and labels in the right folder.
209 | for i in range(len(data)):
210 | data[i].save(direc_d.format(b=i))
211 | labels[i].save(direc_l.format(b=i))
212 |
213 | return None
214 |
215 | # -----------------------------------------------------------------------------
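The zero-padded naming scheme can be checked in isolation (a standalone sketch; the counts are made up):

    n_data = 148                                          # e.g. 148 images
    n_digits = len(str(n_data))                           # 3
    template = "image/{b:0" + str(n_digits) + "d}.png"
    print(template.format(b=7))                           # image/007.png
    print(template.format(b=120))                         # image/120.png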
216 |
217 | def split_data(X, y, ratio=0.8, seed=1):
218 | """
219 |     Randomly shuffles the data and splits it into separate training and
220 |     test sets.
221 |
222 | The input 'X' is a list of images.
223 |
224 | The input 'y' is a list of images with each image corresponding to the label
225 | of the corresponding sample in X.
226 |
227 | The 'ratio' variable is a float that sets the train set fraction of
228 | the entire dataset to this ratio and keeps the other part for test set.
229 | Default value set to 0.8.
230 |
231 | The 'seed' variable represents the seed value for the randomization of the
232 | process. Default value set to 1.
233 | """
234 |
235 | # Set seed.
236 | np.random.seed(seed)
237 |
238 | # Perform shuffling.
239 | idx_shuffled = np.random.permutation(len(y))
240 |
241 | # Return shuffled X and y.
242 | X_shuff = [X[i] for i in idx_shuffled]
243 | y_shuff = [y[i] for i in idx_shuffled]
244 |
245 | # Cut the data set into train and test.
246 | train_num = round(len(y) * ratio)
247 | X_train = X_shuff[:train_num]
248 | y_train = y_shuff[:train_num]
249 | X_test = X_shuff[train_num:]
250 | y_test = y_shuff[train_num:]
251 |
252 | return X_train, y_train, X_test, y_test
253 |
254 | # -----------------------------------------------------------------------------
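A minimal usage sketch, assuming 'data' and 'labels' are the equal-length lists returned by load_data; the split is reproducible because the seed is fixed:

    X_train, y_train, X_test, y_test = split_data(data, labels, ratio=0.8, seed=1)
    assert len(X_train) == round(0.8 * len(data))
    assert len(X_train) + len(X_test) == len(data)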
255 |
256 | def convertLabel(lab, threshold = 0.5):
257 | """
258 | Converts the given label probability maps to a binary images using a specific
259 | threshold.
260 |
261 | The numpy array 'lab' correspond to label probability maps.
262 |
263 | The float 'threshold' corresponds to the threshold at which we binarize
264 | the probability map. Default value set to 0.5.
265 | """
266 |
267 | # Converts the labels into boolean values using a threshold.
268 | label = lab[...,0] > threshold
269 |
270 | # Converts the boolean values into 0 and 1.
271 | label = label.astype(int)
272 |
273 | # Converts the labels to have values 0 and 255.
274 | label[label == 1] = 255
275 |
276 | return label
277 |
278 | # -----------------------------------------------------------------------------
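For example, a 2x2 probability map (hypothetical values) with the default threshold:

    import numpy as np

    lab = np.array([[[0.1], [0.6]],
                    [[0.9], [0.4]]])
    print(convertLabel(lab))
    # [[  0 255]
    #  [255   0]]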
279 |
280 | def pred_accuracy(y_true, y_pred):
281 | """
282 | Computes the prediction accuracy.
283 |
284 | The numpy array 'y_true' corresponds to the true label.
285 |
286 | The numpy array 'y_pred' corresponds to the predicted label.
287 | """
288 |
289 | # Compares both the predictions and labels.
290 | compare = (y_true == y_pred)
291 |
292 | # Convert the resulting boolean values into 0 and 1.
293 | compare = compare.astype(int)
294 |
295 |     # Computes the fraction of correctly classified pixels.
296 |     accuracy = np.sum(compare) / compare.size
297 |
298 | return accuracy
299 |
300 | # -----------------------------------------------------------------------------
301 |
302 | def saveResults(save_path, results, convert = True, threshold = 0.5):
303 | """
304 | Save the predicted arrays into a folder.
305 |
306 |     The string 'save_path' corresponds to the path where the predicted images
307 |     will be saved.
308 |
309 | The numpy array 'results' corresponds to the probability maps that were
310 | predicted with the model.
311 |
312 |     The boolean 'convert' refers to whether or not the probability maps
313 |     should be converted to binary arrays. Default value set to True.
314 |
315 | The float 'threshold' corresponds to the threshold at which we binarize
316 | the probability map. Default value set to 0.5.
317 | """
318 |
319 | # Number of predictions.
320 | n_result = len(results)
321 |
322 | # Count number of digits in n_result. This is important for the number
323 | # of leading zeros in the name of the predictions.
324 | n_digits = len(str(n_result))
325 |
326 | # These represent the paths for the predictions (binary or not) with the right
327 | # number of leading zeros given by n_digits.
328 | if convert:
329 |         # Selects the path for the binarized results.
330 | direc_r = save_path + "result/{b:0" + str(n_digits) + "d}.tif"
331 | else:
332 | direc_r = save_path + "result_prob/{b:0" + str(n_digits) + "d}.tif"
333 |
334 |
335 | for i, lab in enumerate(results):
336 |
337 | if convert:
338 | # Converts the given label with a threshold.
339 | label = convertLabel(lab, threshold)
340 |
341 | else:
342 | label = lab[...,0]
343 |
344 | label = label.astype('float32')
345 |
346 | # Saves the label.
347 | imsave(direc_r.format(b=i), label)
348 |
349 | return None
350 |
351 | # -----------------------------------------------------------------------------
352 |
353 | def make_weight_map(label, binary = True, w0 = 10, sigma = 5):
354 | """
355 |     Generates a weight map in order to make the U-Net better learn the
356 |     borders of cells and distinguish individual cells that are tightly packed.
357 |     These weight maps follow the methodology of the original U-Net paper.
358 |
359 | The variable 'label' corresponds to a label image.
360 |
361 | The boolean 'binary' corresponds to whether or not the labels are
362 | binary. Default value set to True.
363 |
364 |     The float 'w0' controls the importance of separating tightly associated
365 |     entities. Default value set to 10.
366 |
367 | The float 'sigma' represents the standard deviation of the Gaussian used
368 | for the weight map. Default value set to 5.
369 | """
370 |
371 |     # Initialization. 'lab_multi' is a copy so that the in-place
372 |     # binarization of 'lab' below keeps per-object values intact.
373 |     lab = np.array(label)
374 |     lab_multi = lab.copy()
375 | # Get shape of label.
376 | rows, cols = lab.shape
377 |
378 | if binary:
379 |
380 | # Converts the label into a binary image with background = 0
381 | # and cells = 1.
382 | lab[lab == 255] = 1
383 |
384 |
385 |         # Builds w_c, the class-balancing map: cells get weight 1 and the
386 |         # background 0.5, so cells count twice as much as the background.
387 | w_c = np.array(lab, dtype=float)
388 | w_c[w_c == 1] = 1
389 | w_c[w_c == 0] = 0.5
390 |
391 | # Converts the labels to have one class per object (cell).
392 |         lab_multi = measure.label(lab, connectivity = 2, background = 0)  # 8-connectivity
393 |
394 | else:
395 |
396 | # Converts the label into a binary image with background = 0.
397 | # and cells = 1.
398 | lab[lab > 0] = 1
399 |
400 |
401 |         # Builds w_c, the class-balancing map: cells get weight 1 and the
402 |         # background 0.5, so cells count twice as much as the background.
403 | w_c = np.array(lab, dtype=float)
404 | w_c[w_c == 1] = 1
405 | w_c[w_c == 0] = 0.5
406 |
407 | components = np.unique(lab_multi)
408 |
409 | n_comp = len(components)-1
410 |
411 | maps = np.zeros((n_comp, rows, cols))
412 |
413 | map_weight = np.zeros((rows, cols))
414 |
415 | if n_comp >= 2:
416 | for i in range(n_comp):
417 |
418 | # Only keeps current object.
419 | tmp = (lab_multi == components[i+1])
420 |
421 | # Invert tmp so that it can have the correct distance.
422 | # transform
423 | tmp = ~tmp
424 |
425 | # For each pixel, computes the distance transform to
426 | # each object.
427 | maps[i][:][:] = scipy.ndimage.distance_transform_edt(tmp)
428 |
429 | maps = np.sort(maps, axis=0)
430 |
431 | # Get distance to the closest object (d1) and the distance to the second
432 | # object (d2).
433 | d1 = maps[0][:][:]
434 | d2 = maps[1][:][:]
435 |
436 |         map_weight = w0*np.exp(-((d1+d2)**2)/(2*(sigma**2))) * (lab==0).astype(int)
437 |
438 | map_weight += w_c
439 |
440 | return map_weight
441 |
442 | # -----------------------------------------------------------------------------
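For reference, the map above implements the weighting of the original U-Net paper, w(x) = w_c(x) + w0 * exp(-(d1(x) + d2(x))^2 / (2 * sigma^2)), where d1 and d2 are the distances to the nearest and second-nearest cell. A toy check on two separated squares (a standalone sketch):

    import numpy as np
    from PIL import Image

    toy = np.zeros((10, 10), dtype='uint8')
    toy[2:4, 2:4] = 255   # first 'cell'
    toy[6:8, 6:8] = 255   # second 'cell'
    wm = make_weight_map(Image.fromarray(toy, 'L'))
    print(wm.shape)             # (10, 10)
    print(wm[5, 5] > wm[2, 2])  # True: the gap between the cells outweighs cell interiors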
443 |
444 | def do_save_wm(labels, path, binary = True, w0 = 10, sigma = 5):
445 | """
446 |     Retrieves the label images, applies the weight-map algorithm and saves
447 |     the weight maps in a folder.
448 |
449 | The variable 'labels' corresponds to given label images.
450 |
451 | The string 'path' refers to the path where the weight maps should be saved.
452 |
453 | The boolean 'binary' corresponds to whether or not the labels are
454 | binary. Default value set to True.
455 |
456 | The float 'w0' controls for the importance of separating tightly associated
457 | entities. Default value set to 10.
458 |
459 | The float 'sigma' represents the standard deviation of the Gaussian used
460 | for the weight map. Default value set to 5.
461 | """
462 |
463 | # Copy labels.
464 | labels_ = copy.deepcopy(labels)
465 |
466 | # Perform weight maps.
467 | for i in range(len(labels_)):
468 | labels_[i] = make_weight_map(labels[i].copy(), binary, w0, sigma)
469 |
470 | maps = np.array(labels_)
471 |
472 | n, rows, cols = maps.shape
473 |
474 |     # Add a channel axis so the stacked maps match the model's input shape.
475 | maps = maps.reshape((n, rows, cols, 1))
476 |
477 | # Count number of digits in n. This is important for the number
478 | # of leading zeros in the name of the maps.
479 | n_digits = len(str(n))
480 |
481 | # Save path with correct leading zeros.
482 | path_to_save = path + "weight/{b:0" + str(n_digits) + "d}.npy"
483 |
484 | # Saving files as .npy files.
485 | for i in range(len(labels_)):
486 | np.save(path_to_save.format(b=i), labels_[i])
487 |
488 | return None
489 |
490 | # -----------------------------------------------------------------------------
491 |
492 | #def make_distance_map(label):
493 | # """Generates a distance map from labels in order to test distance-map-based
494 | # U-Net training."""
495 | #
496 | # lab = np.array(label)
497 | #
498 | # # Converts the label into a binary image with background = 0
499 | # # and cells = 1.
500 | # lab[lab == 255] = 1
501 | #
502 | # # Applies distance transform
503 | # output = cv2.distanceTransform(lab, cv2.DIST_C, 3)
504 | #
505 | # # Finds minimal cell size
506 | # size = 0
507 | # all_dist = np.unique(output)
508 | # blobbed_lab = measure.label(lab, neighbors = 8, background = 0)
509 | # number_blobs = np.max(blobbed_lab)
510 | # for i in all_dist[1:]:
511 | # tmp = (output >= i).astype(int)
512 | # blobbed_lab = measure.label(tmp, neighbors = 8, background = 0)
513 | # if number_blobs <= np.max(blobbed_lab):
514 | # size = i
515 | #
516 | # return output, size
517 | #
518 | ## -----------------------------------------------------------------------------
519 | #
520 | #def do_make_dm(path):
521 | # """Retrieves the label images, applies the distance transform and save the
522 | # maps in the right folder."""
523 | #
524 | # path_to_labels = path + "/label/*.png"
525 | #
526 | # # Creates a list of file names in the labels directory
527 | # filelist = glob.glob(path_to_labels)
528 | # filelist.sort(key=natural_keys)
529 | #
530 | # # Loads all data images in a list
531 | # labels = [Image.open(fname).resize((256,256)) for fname in filelist]
532 | #
533 | # # Copy labels
534 | # labels_ = labels
535 | #
536 | # # Vector of sizes
537 | # sizes = []
538 | #
539 | # # Do maps
540 | # print("Doing distance maps")
541 | # for i in range(len(labels_)):
542 | # labels_[i], size = make_distance_map(labels_[i])
543 | # sizes.append(size)
544 | # print("Maps done")
545 | #
546 | # min_size = np.min(np.array(sizes))
547 | # print("Min size : {b}".format(b=min_size))
548 | # print(sizes)
549 | #
550 | # maps = np.array(labels_)
551 | #
552 | # maps[maps >= min_size] = min_size
553 | #
554 | # n, rows, cols = maps.shape
555 | #
556 | # # Makes sure the data is saved with one leading zero.
557 | # if (n < 100):
558 | #
559 | # # Selects path for data and labels
560 | # direc_r = path + "/distance/{b:02d}.png"
561 | #
562 | # # If we have more than 100 images, we would have 2 leading zeros.
563 | # # We have 148 images, so there is no point doing other cases.
564 | # else:
565 | #
566 | # # Selects path for data and labels
567 | # direc_r = path + "/distance/{b:03d}.png"
568 | #
569 | # for i, lab in enumerate(maps):
570 | #
571 | # label = lab.astype('uint8')
572 | # label = Image.fromarray(label, 'L')
573 | #
574 | # # Saves the label
575 | # label.save(direc_r.format(b=i))
576 | #
577 | # return None
578 | #
579 | ## -----------------------------------------------------------------------------
580 | #
581 | #def make_three_class (label):
582 | #
583 | # lab = np.array(label)
584 | #
585 | # # Get shape of label
586 | # rows, cols = lab.shape
587 | #
588 | # components = np.unique(lab)
589 | #
590 | # n_comp = len(components)-1
591 | #
592 | # output = np.zeros((rows, cols))
593 | #
594 | # for i in range(n_comp):
595 | #
596 | # # Only keeps current object
597 | # tmp = (lab == components[i+1]).astype('float32')
598 | #
599 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5))
600 | #
601 | # eroded_tmp = cv2.erode(tmp, kernel, iterations = 1)
602 | #
603 | # border = tmp - eroded_tmp
604 | #
605 | # output[border > 0] = 1
606 | # output[eroded_tmp > 0] = 2
607 | #
608 | # output = output.astype('uint8')
609 | # output = Image.fromarray(output, 'L')
610 | #
611 | # return output
--------------------------------------------------------------------------------
/unet/py_files/model.py:
--------------------------------------------------------------------------------
1 | # Important libraries
2 |
3 | import numpy as np
4 | import os
5 | import skimage.io as io
6 | import skimage.transform as trans
7 | from keras.models import *
8 | from keras.layers import *
9 | from keras.optimizers import *
10 | from keras.callbacks import *
11 | from keras import backend as K
12 | 
13 |
14 | # -----------------------------------------------------------------------------
15 |
16 | def jaccard_distance(y_true, y_pred, smooth=100):
17 | """Intersection-over-union loss (Jaccard distance)."""
18 | intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
19 | sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
20 | jac = (intersection + smooth) / (sum_ - intersection + smooth)
21 | return (1 - jac) * smooth
22 |
23 | # -----------------------------------------------------------------------------
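A numpy re-implementation (illustration only, not part of the module) makes the behaviour easy to check: identical masks give distance 0, disjoint masks approach the smoothed maximum.

    import numpy as np

    def jaccard_distance_np(y_true, y_pred, smooth=100):
        intersection = np.sum(np.abs(y_true * y_pred), axis=-1)
        sum_ = np.sum(np.abs(y_true) + np.abs(y_pred), axis=-1)
        jac = (intersection + smooth) / (sum_ - intersection + smooth)
        return (1 - jac) * smooth

    y = np.ones((1, 4))
    print(jaccard_distance_np(y, y))                 # [0.]
    print(jaccard_distance_np(y, np.zeros((1, 4))))  # [3.846...]: no overlap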
24 |
25 | def unet(input_size = (256,256,1), loss_ = 'binary_crossentropy', learning_rate = 1e-4, weight_decay = 5e-7):
26 | """
27 | Simple U-net architecture.
28 |
29 | The tuple 'input_size' corresponds to the size of the input images and labels.
30 | Default value set to (256, 256, 1) (input images size is 256x256).
31 |
32 | The string 'loss_' represents the name of the loss that should be used.
33 | Default value set to 'binary_crossentropy'.
34 |
35 | The float 'learning_rate' corresponds to the learning rate value for the training.
36 |     Default value set to 1e-4.
37 |
38 | The float 'weight_decay' corresponds to the weight decay value for the training.
39 | Default value set to 5e-7.
40 | """
41 |
42 | # Get input.
43 | input_img = Input(input_size)
44 |
45 | # Layer 1.
46 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(input_img)
47 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
48 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
49 |
50 | # Layer 2.
51 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
52 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
53 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
54 |
55 | # Layer 3.
56 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
57 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
58 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
59 |
60 | # Layer 4.
61 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
62 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
63 | drop4 = Dropout(0.5)(conv4)
64 | pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
65 |
66 | # layer 5.
67 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
68 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
69 | drop5 = Dropout(0.5)(conv5)
70 |
71 | # Layer 6.
72 | up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
73 | merge6 = concatenate([drop4,up6], axis = 3)
74 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
75 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
76 |
77 | # Layer 7.
78 | up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
79 | merge7 = concatenate([conv3,up7], axis = 3)
80 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
81 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
82 |
83 | # Layer 8.
84 | up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
85 | merge8 = concatenate([conv2,up8], axis = 3)
86 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
87 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
88 |
89 | # Layer 9.
90 | up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
91 | merge9 = concatenate([conv1,up9], axis = 3)
92 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
93 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
94 | conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
95 |
96 | # Final layer (output).
97 | conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)
98 |
99 | # Specify input and output.
100 | model = Model(inputs = input_img, outputs = conv10)
101 |
102 |     # Use the Adam optimizer with the requested loss and track accuracy.
103 | model.compile(optimizer = Adam(lr = learning_rate, decay = weight_decay), loss = loss_, metrics = ['accuracy'])
104 |
105 | return model
106 |
107 |
108 |
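A minimal usage sketch; the custom Jaccard loss defined above can be passed in place of the default:

    model = unet(input_size=(256, 256, 1), loss_=jaccard_distance)
    model.summary()   # roughly 31 million parameters for the 64-to-1024 filter progression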
--------------------------------------------------------------------------------
/unet/py_files/prep_data.py:
--------------------------------------------------------------------------------
1 | # Important libraries
2 |
3 | import numpy
4 | import sys
5 | import os
6 | import pandas as pd
7 |
8 | # -----------------------------------------------------------------------------
9 |
10 | # Important py.files
11 | sys.path.append("/content/drive/My Drive/unser_project/py_files")
12 | from helpers import *
13 | from data_loading import *
14 | from fit_model import *
15 |
16 | # -----------------------------------------------------------------------------
17 |
18 | def read_config(name_project, name_config):
19 | """
20 | Reads configuration file from DeepPix Workflow plug-in.
21 |
22 |     The string 'name_project' refers to the name of the DeepPix Workflow project
23 | given by the user.
24 |
25 | The string 'name_config' refers to the name of the configuration file given
26 | by the user.
27 | """
28 |
29 | # Builds path to the configuration file.
30 | path_to_config = '/content/drive/My Drive/unser_project/data/raw/' + name_project + '/' + name_config + '-training-settings.txt'
31 |
32 | # Load configuration file
33 | df = pd.read_table(path_to_config, header = None, delimiter = '=', dtype = str, skiprows = 5)
34 |
35 | input_array = []
36 |
37 | # Process input dataframe.
38 | for i in range(df.shape[0]):
39 | input_array.append(df[1][i][1:])
40 |
41 | # The following code allocates the input configurations to variables that
42 | # will be used for the rest of the program.
43 |
44 | label_type = input_array[3]
45 |
46 | size = input_array[4]
47 | target_size = ()
48 |
49 | if size == "256x256":
50 | target_size = (256, 256)
51 |
52 | elif size == "512x512":
53 | target_size = (512, 512)
54 |
55 | elif size == "1024x1024":
56 | target_size = (1024, 1024)
57 |
58 | else:
59 | raise RuntimeError("Input size unknown")
60 |
61 | model = input_array[5]
62 | model_type = ""
63 |
64 | if model == "Simple U-Net":
65 | model_type = "unet_simple"
66 |
67 | elif model == "Weighted U-Net":
68 | model_type = "unet_weighted"
69 |
70 | split_ratio = float(input_array[6])/100
71 |
72 | batch_size = int(input_array[7])
73 |
74 | learning_rate = float(input_array[8])*1e-5
75 |
76 | return label_type, target_size, model_type, split_ratio, batch_size, learning_rate
77 |
78 | # -----------------------------------------------------------------------------
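For illustration, a settings file that the parsing above would accept might look like this (hypothetical keys and values; the first five lines are skipped, each remaining line is split on '=', and the leading space of every value is dropped):

    # lines 1-5: free-form header, ignored (skiprows = 5)
    # ...three project-specific pairs occupy input_array[0..2]...
    # label type = binary          -> label_type
    # input size = 256x256         -> target_size = (256, 256)
    # model = Weighted U-Net       -> model_type = "unet_weighted"
    # train split = 80             -> split_ratio = 0.8
    # batch size = 2               -> batch_size = 2
    # learning rate = 10           -> learning_rate = 10 * 1e-5 = 1e-4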
79 |
80 | def prep_data(name_project, model, label_type, split_ratio, w0 = None, sigma = None):
81 | """
82 | Prepares the data by randomizing the images and binarizing them if needed.
83 |
84 |     The string 'name_project' refers to the name of the DeepPix Workflow project
85 | given by the user.
86 |
87 | The string 'model' refers to the type of model that will be used.
88 |
89 | The string 'label_type' corresponds to the type of model used, either
90 | categorical or binary.
91 |
92 | The float 'split_ratio' corresponds to the splitting ratio for the
93 | training and testing set.
94 |
95 | The float 'w0' corresponds to a constant used for the weighted U-Net.
96 | Default value set to None.
97 |
98 | The float 'sigma' corresponds to a constant used for the weighted U-Net.
99 | Default value set to None.
100 | """
101 |
102 | print("Initialization of preparation of data.")
103 |
104 | # Constructs useful paths.
105 |
106 | # Path for data.
107 | data_path = "/content/drive/My Drive/unser_project/data/"
108 |
109 | # Paths for raw data and labels.
110 | path_data = data_path + "raw/" + name_project + "/image/*.tif"
111 | path_labels = data_path + "raw/" + name_project + "/label/*.tif"
112 |
113 | # Paths for train and test directories.
114 | train_path = data_path + "processed/" + name_project + "/train/"
115 | test_path = data_path + "processed/" + name_project + "/test/"
116 |
117 | # Load data and labels.
118 | print("Loading data and labels.")
119 | data, labels = load_data(path_data, path_labels)
120 | print("Loading successful.")
121 |
122 | print("Label type check and binarization if needed.")
123 | # Checks if labels are binary or categorical.
124 | binary = check_binary(labels)
125 |
126 | # Check which model is desired and binarizes labels or not depending on the model.
127 | if model == "unet_simple":
128 |
129 | if not binary:
130 | labels = make_binary(labels)
131 |
132 | elif model == "unet_weighted":
133 |
134 | if label_type == "categorical":
135 |
136 | if binary:
137 |                 raise RuntimeError("Labels are declared as categorical but appear to be binary.")
138 |
139 | elif label_type == "binary":
140 |
141 | if not binary:
142 | labels = make_binary(labels)
143 |
144 | else:
145 |             raise RuntimeError("Labels are neither categorical nor binary.")
146 |
147 | else:
148 | raise RuntimeError("Model type not recognised.")
149 |
150 | print("Splitting data")
151 | X_train, y_train, X_test, y_test = split_data(data, labels, ratio = split_ratio)
152 |
153 | if not os.path.exists(train_path + 'image'):
154 | os.makedirs(train_path + 'image')
155 |
156 | if not os.path.exists(train_path + 'label'):
157 | os.makedirs(train_path + 'label')
158 |
159 | if not os.path.exists(test_path + 'image'):
160 | os.makedirs(test_path + 'image')
161 |
162 | if not os.path.exists(test_path + 'label'):
163 | os.makedirs(test_path + 'label')
164 |
165 | if model == "unet_weighted":
166 |
167 | if not os.path.exists(train_path + 'weight'):
168 | os.makedirs(train_path + 'weight')
169 |
170 | if not os.path.exists(test_path + 'weight'):
171 | os.makedirs(test_path + 'weight')
172 |
173 |         # make_weight_map must label connected components itself when the
174 |         # masks are binary; categorical masks already separate objects,
175 |         # which is what the 'binary' flag of do_save_wm encodes.
176 |         binary_labels = (label_type != "categorical")
177 | 
178 |         print("Constructing weight maps.")
179 |         do_save_wm(y_train, train_path, binary = binary_labels, w0 = w0, sigma = sigma)
180 |         do_save_wm(y_test, test_path, binary = binary_labels, w0 = w0, sigma = sigma)
181 |         print("Weight maps saved.")
182 |
183 | print("Saving data.")
184 | save_data(X_train, y_train, train_path)
185 | save_data(X_test, y_test, test_path)
186 | print("Preparation of data completed.")
--------------------------------------------------------------------------------
/unet/py_files/unet_dm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import skimage.io as io
4 | import skimage.transform as trans
5 | from keras.models import *
6 | from keras.layers import *
7 | from keras.optimizers import *
8 | from keras.callbacks import *
9 | from keras import backend as K
10 | 
11 |
12 | def unet_distance(input_size = (256,256,1), learning_rate = 1e-4, weight_decay = 5e-7):
13 | """Simple U-net architecture with distance maps. """
14 |
15 | # Get input
16 | input_img = Input(input_size)
17 |
18 | # Layer 1
19 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(input_img)
20 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
21 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
22 |
23 | # Layer 2
24 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
25 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
26 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
27 |
28 | # Layer 3
29 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
30 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
31 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
32 |
33 | # Layer 4
34 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
35 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
36 | drop4 = Dropout(0.5)(conv4)
37 | pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
38 |
39 | # layer 5
40 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
41 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
42 | drop5 = Dropout(0.5)(conv5)
43 |
44 | # Layer 6
45 | up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
46 | merge6 = concatenate([drop4,up6], axis = 3)
47 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
48 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
49 |
50 | # Layer 7
51 | up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
52 | merge7 = concatenate([conv3,up7], axis = 3)
53 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
54 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
55 |
56 | # Layer 8
57 | up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
58 | merge8 = concatenate([conv2,up8], axis = 3)
59 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
60 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
61 |
62 | # Layer 9
63 | up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
64 | merge9 = concatenate([conv1,up9], axis = 3)
65 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
66 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
67 | conv9 = Conv2D(3, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
68 |
69 | reshape = Reshape((3, 256 * 256), input_shape = (3, 256, 256))(conv9)
70 |
71 | permute = Permute((2,1))(reshape)
72 |
73 | activation = Activation('softmax')(permute)
74 |
75 | # Specify input and output
76 | model = Model(inputs = input_img, outputs = activation)
77 |
78 |     # Use the Adam optimizer, categorical cross-entropy loss and track accuracy.
79 | model.compile(optimizer = Adam(lr = learning_rate, decay = weight_decay), loss = 'categorical_crossentropy', metrics = ['accuracy'])
80 |
81 | return model
82 |
83 |
84 |
--------------------------------------------------------------------------------
/unet/py_files/unet_weights.py:
--------------------------------------------------------------------------------
1 | # Important libraries
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | import numpy as np
7 | import os
8 | import skimage.io as io
9 | import skimage.transform as trans
10 | 
11 | from keras.models import *
12 | from keras.layers import *
13 | from keras.optimizers import *
14 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard
15 | from keras import backend as K
16 | import tensorflow as tf
17 |
18 | # -----------------------------------------------------------------------------
19 |
20 | def binary_crossentropy_weighted(weights):
21 | """
22 | Custom binary cross entropy loss. The weights are used to multiply
23 | the results of the usual cross-entropy loss in order to give more weight
24 | to areas between cells close to one another.
25 |
26 | The variable 'weights' refers to input weight-maps.
27 | """
28 |
29 | def loss(y_true, y_pred):
30 |
31 | return K.mean(weights * K.binary_crossentropy(y_true, y_pred), axis=-1)
32 |
33 | return loss
34 |
35 | # -----------------------------------------------------------------------------
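Per pixel this is just the usual binary cross-entropy rescaled by the weight map before averaging; a scalar sanity check with hypothetical values:

    import numpy as np

    y_true, y_pred, w = 1.0, 0.9, 2.0
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    print(bce)       # ~0.105
    print(w * bce)   # ~0.211: the weight doubles this pixel's contribution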
36 |
37 | def unet_weights(input_size = (256,256,1), learning_rate = 1e-4, weight_decay = 5e-7):
38 | """
39 | Weighted U-net architecture.
40 |
41 | The tuple 'input_size' corresponds to the size of the input images and labels.
42 | Default value set to (256, 256, 1) (input images size is 256x256).
43 |
44 | The float 'learning_rate' corresponds to the learning rate value for the training.
45 |     Default value set to 1e-4.
46 |
47 | The float 'weight_decay' corresponds to the weight decay value for the training.
48 | Default value set to 5e-7.
49 | """
50 |
51 | # Get input.
52 | input_img = Input(input_size)
53 |
54 | # Get weights.
55 | weights = Input(input_size)
56 |
57 | # Layer 1.
58 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(input_img)
59 | conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
60 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
61 |
62 | # Layer 2.
63 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
64 | conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
65 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
66 |
67 | # Layer 3.
68 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
69 | conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
70 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
71 |
72 | # Layer 4.
73 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
74 | conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
75 | drop4 = Dropout(0.5)(conv4)
76 | pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
77 |
78 | # layer 5.
79 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
80 | conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
81 | drop5 = Dropout(0.5)(conv5)
82 |
83 | # Layer 6.
84 | up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
85 | merge6 = concatenate([drop4,up6], axis = 3)
86 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
87 | conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
88 |
89 | # Layer 7.
90 | up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
91 | merge7 = concatenate([conv3,up7], axis = 3)
92 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
93 | conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
94 |
95 | # Layer 8.
96 | up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
97 | merge8 = concatenate([conv2,up8], axis = 3)
98 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
99 | conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
100 |
101 | # Layer 9.
102 | up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
103 | merge9 = concatenate([conv1,up9], axis = 3)
104 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
105 | conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
106 | conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
107 |
108 | # Final layer (output).
109 | conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)
110 |
111 | # Specify input (image + weights) and output.
112 | model = Model(inputs = [input_img, weights], outputs = conv10)
113 |
114 | # Use Adam optimizer, custom weighted binary cross-entropy loss and specify metrics
115 | # Also use weights inside the loss function.
116 | model.compile(optimizer = Adam(lr = learning_rate, decay = weight_decay), loss = binary_crossentropy_weighted(weights), metrics = ['accuracy'])
117 |
118 | return model
119 |
120 |
121 |
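Because the weight map is a second model input, training batches must be fed as a pair; a minimal sketch, assuming images, weight_maps and masks are numpy arrays of shape (n, 256, 256, 1) with masks scaled to [0, 1]:

    model = unet_weights(input_size=(256, 256, 1))
    model.fit([images, weight_maps], masks, batch_size=2, epochs=5)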
--------------------------------------------------------------------------------
/unet/train_and_test_unet.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"unet_simple_github.ipynb","version":"0.3.2","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"colab_type":"text","id":"RSSaG4-n1qcO"},"source":["# **Mounts your drive**"]},{"cell_type":"code","metadata":{"colab_type":"code","id":"BAEopvk_l7wg","colab":{}},"source":["from google.colab import drive\n","drive.mount('/content/drive')"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NL0Ji_OcLD-z","colab_type":"text"},"source":["Prepare data\n","==="]},{"cell_type":"code","metadata":{"id":"rhK06MkxLDTP","colab_type":"code","colab":{}},"source":["!pip install tifffile"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"kcdpETWxLDo7","colab_type":"code","colab":{}},"source":["# Functions\n","import numpy\n","import sys\n","sys.path.append(\"/content/drive/My Drive/.../unet_segmentation/py_files\") # path to py_files folder\n","from helpers import *\n","from data_loading import *\n","\n","# Autoreload\n","%load_ext autoreload\n","%autoreload 2\n","\n","# Set random seed\n","np.random.seed(1)\n","\n","# Load raw data from the Cell Tracking Challenge http://celltrackingchallenge.net/2d-datasets/. Download first the data.\n","data, labels = load_data(\"/content/drive/My Drive/.../unet_segmentation/data/raw/hela/image/*.tif\",\n"," \"/content/drive/My Drive/.../unet_segmentation/data/raw/hela/label/*.tif\") # Set correct paths\n","for i in range(len(labels)):\n"," tmp = np.array(labels[i])\n"," tmp[tmp > 0] = 255\n"," tmp[tmp == 0] = 0\n"," tmp = tmp.astype('uint8')\n"," tmp = Image.fromarray(tmp, 'L')\n"," labels[i] = tmp\n"," \n","# Split the data into train and test\n","X_train, y_train, X_test, y_test = split_data(data, labels, ratio = 0.5)\n","\n","# Set the paths and create the folders to save preprocessed data as .png\n","TRAIN_DIR=\"/content/drive/My Drive/unser_project/data/processed/hela/train/\"\n","TEST_DIR=\"/content/drive/My Drive/unser_project/data/processed/hela/test/\"\n","\n","if not os.path.exists(TRAIN_DIR):\n"," os.makedirs(TRAIN_DIR+\"/image/\")\n"," os.makedirs(TRAIN_DIR+\"/label/\")\n"," \n","if not os.path.exists(TEST_DIR):\n"," os.makedirs(TEST_DIR+\"/image/\")\n"," os.makedirs(TEST_DIR+\"/label/\")\n"," \n","# Save train and test files\n","save_data(X_train, y_train, TRAIN_DIR)\n","save_data(X_test, y_test, TEST_DIR)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"1_qBzyHg1ohr"},"source":["# **Imports modules**"]},{"cell_type":"code","metadata":{"colab_type":"code","id":"sk4cEzphpv4_","colab":{}},"source":["import sys\n","sys.path.append(\"/content/drive/My Drive/.../unet_segmentation/py_files\") # path to py_files folder\n","!pip install tifffile\n","!pip install --upgrade tensorflow\n","!pip install --upgrade keras\n","from model import *\n","from convert_to_pb import *\n","from data_loading import *\n","from helpers import *\n","from unet_weights import *\n","from fit_model import *\n","%matplotlib inline\n","import matplotlib.pyplot as plt\n","import matplotlib\n","from PIL import Image, ImageOps, ImageFilter\n","import pickle\n","from test import *\n","import cv2 as cv\n","\n","# Autoreload\n","%load_ext autoreload\n","%autoreload 2\n","%reload_ext autoreload\n","\n","# Set random 
seed\n","np.random.seed(1)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ifxpZTDtmyPn","colab_type":"text"},"source":["Train U-Net\n","==="]},{"cell_type":"markdown","metadata":{"id":"fqOD0FTkJMCT","colab_type":"text"},"source":["## Hela cells"]},{"cell_type":"code","metadata":{"id":"Kb0aNeIKy6cj","colab_type":"code","colab":{}},"source":["# Load training and validation data\n","# Note that the subset of generator used for the training generator is \"validation\" because we don't want to augment our data\n","# Specify paths where inside there are \"image\" and \"label\" folder\n","trainGen = dataGenerator(batch_size = 2, subset = \"train\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/train')\n","validGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"6A5LXMcPQUV9","colab_type":"code","colab":{}},"source":["model = unet()\n","\n","# Callbacks\n","model_checkpoint = ModelCheckpoint('/content/drive/My Drive/.../unet_segmentation/models/{b}.hdf5'.format(b=\"unet_hela\"), monitor='val_loss', verbose=1, save_best_only=True)\n","\n","# Fit\n","history = model.fit_generator(trainGen,\n"," steps_per_epoch=500,\n"," epochs=1,\n"," callbacks=[model_checkpoint], \n"," validation_data = validGen, \n"," validation_steps = 9)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bWigI7T0Mlyp","colab_type":"text"},"source":["Results\n","==="]},{"cell_type":"markdown","metadata":{"id":"59BUzz8bc_GN","colab_type":"text"},"source":["## Hela cells"]},{"cell_type":"code","metadata":{"id":"xf_rj9545A2j","colab_type":"code","colab":{}},"source":["# Define paths.\n"," path_to_model = '/content/drive/My Drive/.../unet_segmentation/models/unet_hela.hdf5'\n"," \n"," # Load model.\n"," model = load_model(path_to_model)\n"," # Load training and validation data\n"," # Note that the subset of generator used for the training generator is \"validation\" because we don't want to augment our data\n"," print(\"Validation\")\n"," validGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')\n"," \n"," accuracies = model.evaluate_generator(validGen, steps=9, verbose=1) \n"," print(accuracies)\n"," \n"," print(\"Training\")\n"," trainGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/train')\n"," \n"," accuracies = model.evaluate_generator(validGen, steps=8, verbose=1) \n"," print(accuracies)\n"," "],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UoRhKGShIvke","colab_type":"text"},"source":["Prediction\n","==="]},{"cell_type":"code","metadata":{"id":"7zq_yID3jP8G","colab_type":"code","colab":{}},"source":["import tensorflow as tf"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"vccfXzFJIu6z","colab_type":"code","colab":{}},"source":["from tensorflow.contrib.saved_model import save_keras_model\n","import tensorflow.keras\n","from keras.models import load_model\n","testGen = dataGenerator(batch_size = 1, subset = \"test\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')\n","model = unet()\n","model = load_model('/content/drive/My Drive/.../unet_segmentation/models/unet_hela.hdf5')\n","results = 
model.predict_generator(testGen,9,verbose=1, workers=1)\n","#saveResults('/content/drive/My Drive/.../unet_segmentation/data/hela/test/', results, convert = True)\n","#saveResults('/content/drive/My Drive/.../unet_segmentation/data/hela/test/', results, convert = False)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"wIRWpNL0XC_Q","colab_type":"code","colab":{}},"source":["from sklearn.metrics import jaccard_similarity_score\n"," \n","acc_tot = [];\n","\n","for i in range(9):\n"," label = cv.imread('/content/drive/My Drive/.../unet_segmentation/data/hela/test/label/0{b}.png'.format(b=i))\n"," label = cv.resize(label, (256,256))\n"," acc = jaccard_similarity_score(label[...,0].flatten(), convertLabel(results[i]).flatten())\n"," acc_tot.append(acc)\n"," \n","print(\"Jaccard average : {b}\".format(b=np.mean(acc_tot)))"],"execution_count":0,"outputs":[]}]}
--------------------------------------------------------------------------------
/xml/config_template.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/xml/create_config.py:
--------------------------------------------------------------------------------
1 | """
2 | DeepImageJ
3 |
4 | https://deepimagej.github.io/deepimagej/
5 |
6 | Conditions of use:
7 |
8 | DeepImageJ is an open source software (OSS): you can redistribute it and/or modify it under
9 | the terms of the BSD 2-Clause License.
10 |
11 | In addition, we strongly encourage you to include adequate citations and acknowledgments
12 | whenever you present or publish results that are based on it.
13 |
14 | DeepImageJ is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
15 | without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
16 |
17 | You should have received a copy of the BSD 2-Clause License along with DeepImageJ.
18 | If not, see .
19 |
20 |
21 | Reference:
22 |
23 | DeepImageJ: A user-friendly plugin to run deep learning models in ImageJ
24 | E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, L. Donati, M. Unser, A. Munoz-Barrutia, D. Sage.
25 | Submitted 2019.
26 |
27 | Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain
28 | Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland
29 |
30 | Corresponding authors: mamunozb@ing.uc3m.es, daniel.sage@epfl.ch
31 |
32 | Copyright 2019. Universidad Carlos III, Madrid, Spain and EPFL, Lausanne, Switzerland.
33 |
34 | """
35 |
36 | import os
37 | import xml.etree.ElementTree as ET
38 | import time
39 | import numpy as np
40 | import urllib
41 | import shutil
42 | from skimage import io
43 |
44 | """
45 | Download the template from this link:
46 | https://raw.githubusercontent.com/esgomezm/python4deepimagej/yaml/yaml/config_template.xml
47 | TensorFlow library is needed. It is imported later to save the model as a SavedModel protobuffer
48 |
49 | Try to check TensorFlow version and read DeepImageJ's compatibility requirements.
50 |
51 | import tensorflow as tf
52 | tf.__version__
53 | ----------------------------------------------------
54 | Example:
55 | ----------------------------------------------------
56 | dij_config = DeepImageJConfig(model)
57 | # Update model information
58 | dij_config.Authors = authors
59 | dij_config.Credits = credits
60 |
61 | # Add info about the minimum size in case it is not fixed.
62 | pooling_steps = 0
63 | for keras_layer in model.layers:
64 | if keras_layer.name.startswith('max') or "pool" in keras_layer.name:
65 | pooling_steps += 1
66 | dij_config.MinimumSize = np.str(2**(pooling_steps))
67 |
68 | # Add the information about the test image
69 | dij_config.add_test_info(test_img, test_prediction, PixelSize)
70 |
71 | ## Prepare preprocessing file
72 | path_preprocessing = "PercentileNormalization.ijm"
73 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/PercentileNormalization.ijm", path_preprocessing )
74 | # Include the info about the preprocessing
75 | dij_config.add_preprocessing(path_preprocessing, "preprocessing")
76 |
77 | ## Prepare postprocessing file
78 | path_postprocessing = "8bitBinarize.ijm"
79 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm", path_postprocessing )
80 | # Include the info about the postprocessing
81 | post_processing_name = "postprocessing_LocalMaximaSMLM"
82 | dij_config.add_postprocessing(path_postprocessing, post_processing_name)
83 |
84 | ## EXPORT THE MODEL
85 | deepimagej_model_path = os.path.join(QC_model_folder, 'deepimagej')
86 | dij_config.export_model(model, deepimagej_model_path)
87 | ----------------------------------------------------
88 | Example: change one line in an ImageJ macro
89 | ----------------------------------------------------
90 | ## Prepare postprocessing file
91 | path_postprocessing = "8bitBinarize.ijm"
92 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm", path_postprocessing )
93 | # Modify the threshold in the macro to the chosen threshold
94 | ijmacro = open(path_postprocessing,"r")
95 | list_of_lines = ijmacro.readlines()
96 | # The line at index 21 holds the optimal threshold
97 | list_of_lines[21] = "optimalThreshold = {}\n".format(128)
98 | ijmacro.close()
99 | ijmacro = open(path_postprocessing,"w")
100 | ijmacro.writelines(list_of_lines)
101 | ijmacro.close()
102 | """
103 |
104 | class DeepImageJConfig:
105 | def __init__(self, tf_model):
106 | # ModelInformation
107 | self.Name = 'null'
108 | self.Authors = 'null'
109 | self.URL = 'null'
110 | self.Credits = 'null'
111 | self.Version = 'null'
112 | self.References = 'null'
113 | self.Date = time.ctime()
114 | # Same value as 2**pooling_steps
115 |         # (related to encoder-decoder architectures) when the input size is not
116 | # fixed
117 | self.MinimumSize = '8'
118 | self.get_dimensions(tf_model)
119 | # Receptive field of the network to process input
120 | self.Padding = np.str(self._pixel_half_receptive_field(tf_model))
121 | self.Preprocessing = list()
122 | self.Postprocessing = list()
123 | self.Preprocessing_files = list()
124 | self.Postprocessing_files = list()
125 |
126 | def get_dimensions(self, tf_model):
127 | """
128 | Calculates the array organization and shapes of inputs and outputs.
129 | """
130 | input_dim = tf_model.input_shape
131 | output_dim = tf_model.output_shape
132 | # Deal with the order of the dimensions and whether the size is fixed
133 | # or not
134 | if input_dim[2] is None:
135 | self.FixedPatch = 'false'
136 | self.PatchSize = self.MinimumSize
137 | if input_dim[-1] is None:
138 | self.InputOrganization0 = 'NCHW'
139 | self.Channels = np.str(input_dim[1])
140 | else:
141 | self.InputOrganization0 = 'NHWC'
142 | self.Channels = np.str(input_dim[-1])
143 |
144 | if output_dim[-1] is None:
145 | self.OutputOrganization0 = 'NCHW'
146 | else:
147 | self.OutputOrganization0 = 'NHWC'
148 | else:
149 | self.FixedPatch = 'true'
150 | self.PatchSize = np.str(input_dim[2])
151 |
152 | if input_dim[-1] < input_dim[-2] and input_dim[-1] < input_dim[-3]:
153 | self.InputOrganization0 = 'NHWC'
154 | self.Channels = np.str(input_dim[-1])
155 | else:
156 | self.InputOrganization0 = 'NCHW'
157 | self.Channels = np.str(input_dim[1])
158 |
159 | if output_dim[-1] < output_dim[-2] and output_dim[-1] < output_dim[-3]:
160 | self.OutputOrganization0 = 'NHWC'
161 | else:
162 | self.OutputOrganization0 = 'NCHW'
163 |
164 | # Adapt the format from brackets to parenthesis
165 | input_dim = np.str(input_dim)
166 | input_dim = input_dim.replace('(', ',')
167 | input_dim = input_dim.replace(')', ',')
168 | input_dim = input_dim.replace('None', '-1')
169 | input_dim = input_dim.replace(' ', "")
170 | self.InputTensorDimensions = input_dim
171 |
172 | def _pixel_half_receptive_field(self, tf_model):
173 | """
174 | The halo is equivalent to the receptive field of one pixel. This value
175 |         is used for image reconstruction when an entire image is processed.
176 | """
177 | input_shape = tf_model.input_shape
178 |
179 | if self.FixedPatch == 'false':
180 | min_size = 50*np.int(self.MinimumSize)
181 |
182 | if self.InputOrganization0 == 'NHWC':
183 |                 null_im = np.zeros((1, min_size, min_size, input_shape[-1]),
184 |                                    dtype=np.float32)
185 |             else:
186 |                 null_im = np.zeros((1, input_shape[1], min_size, min_size),
187 |                                    dtype=np.float32)
188 |         else:
189 |             null_im = np.zeros(input_shape[1:],
190 |                                dtype=np.float32)
191 | null_im = np.expand_dims(null_im, axis=0)
192 | min_size = np.int(self.PatchSize)
193 |
194 | point_im = np.zeros_like(null_im)
195 | min_size = np.int(min_size/2)
196 |
197 | if self.InputOrganization0 == 'NHWC':
198 | point_im[0,min_size,min_size] = 1
199 | else:
200 | point_im[0,:,min_size,min_size] = 1
201 |
202 | result_unit = tf_model.predict(np.concatenate((null_im, point_im)))
203 |
204 | D = np.abs(result_unit[0]-result_unit[1])>0
205 |
206 | if self.InputOrganization0 == 'NHWC':
207 | D = D[:,:,0]
208 | else:
209 | D = D[0,:,:]
210 |
211 | ind = np.where(D[:min_size,:min_size]==1)
212 | halo = np.min(ind[1])
213 | halo = min_size-halo+1
214 |
215 | return halo
216 |
217 | class TestImage:
218 | def __add__(self, input_im, output_im, pixel_size):
219 | """
220 | pixel size must be given in microns
221 | """
222 | self.Input_shape = '{0}x{1}'.format(input_im.shape[0], input_im.shape[1])
223 | self.InputImage = input_im
224 | self.Output_shape = '{0}x{1}'.format(output_im.shape[0], output_im.shape[1])
225 | self.OutputImage = output_im
226 | self.MemoryPeak = 'null'
227 | self.Runtime = 'null'
228 | self.PixelSize = '{0}µmx{1}µm'.format(pixel_size, pixel_size)
229 |
230 | def add_test_info(self, input_im, output_im, pixel_size):
231 | self.test_info = self.TestImage()
232 | self.test_info.__add__(input_im, output_im, pixel_size)
233 |
234 | def add_preprocessing(self, file, name):
235 | file_extension = file.split('.')[-1]
236 | name = name + '.' + file_extension
237 | if name.startswith('preprocessing'):
238 | self.Preprocessing.insert(len(self.Preprocessing),name)
239 | else:
240 | name = "preprocessing_"+name
241 | self.Preprocessing.insert(len(self.Preprocessing),name)
242 | self.Preprocessing_files.insert(len(self.Preprocessing_files), file)
243 |
244 | def add_postprocessing(self, file, name):
245 | file_extension = file.split('.')[-1]
246 | name = name + '.' + file_extension
247 | if name.startswith('postprocessing'):
248 | self.Postprocessing.insert(len(self.Postprocessing), name)
249 | else:
250 | name = "postprocessing_" + name
251 | self.Postprocessing.insert(len(self.Postprocessing), name)
252 | self.Postprocessing_files.insert(len(self.Postprocessing_files), file)
253 |
254 |
255 | def export_model(self, tf_model,deepimagej_model_path, **kwargs):
256 | """
257 | Main function to export the model as a bundled model of DeepImageJ
258 | tf_model: tensorflow/keras model
259 | deepimagej_model_path: directory where DeepImageJ model is stored.
260 | """
261 |         # Save the model as a protobuf.
262 | self.save_tensorflow_pb(tf_model, deepimagej_model_path)
263 |
264 | # extract the information about the testing image
265 | test_info = self.test_info
266 | io.imsave(os.path.join(deepimagej_model_path,'exampleImage.tiff'), self.test_info.InputImage)
267 | io.imsave(os.path.join(deepimagej_model_path,'resultImage.tiff'), self.test_info.OutputImage)
268 | print("Example images stored.")
269 |
270 | # write the DeepImageJ configuration as an xml file
271 | write_config(self, test_info, deepimagej_model_path)
272 |
273 | # Add preprocessing and postprocessing macros.
274 | # More than one is available, but the first one is set by default.
275 | for i in range(len(self.Preprocessing)):
276 | shutil.copy2(self.Preprocessing_files[i], os.path.join(deepimagej_model_path, self.Preprocessing[i]))
277 | print("ImageJ macro {} included in the bundled model.".format(self.Preprocessing[i]))
278 |
279 | for i in range(len(self.Postprocessing)):
280 | shutil.copy2(self.Postprocessing_files[i], os.path.join(deepimagej_model_path, self.Postprocessing[i]))
281 | print("ImageJ macro {} included in the bundled model.".format(self.Postprocessing[i]))
282 |
283 | # Zip the bundled model to download
284 | shutil.make_archive(deepimagej_model_path, 'zip', deepimagej_model_path)
285 | print("DeepImageJ model was successfully exported as {0}.zip. You can download and start using it in DeepImageJ.".format(deepimagej_model_path))
286 |
287 |
288 | def save_tensorflow_pb(self,tf_model, deepimagej_model_path):
289 | # Check whether the folder to save the DeepImageJ bundled model exists.
290 | # If so, it needs to be removed (TensorFlow requirements)
291 | # -------------- Other definitions -----------
292 | W = '\033[0m' # white (normal)
293 | R = '\033[31m' # red
294 | if os.path.exists(deepimagej_model_path):
295 | print(R+'!! WARNING: DeepImageJ model folder already existed and has been removed !!'+W)
296 | shutil.rmtree(deepimagej_model_path)
297 |
298 | import tensorflow as tf
299 | TF_VERSION = tf.__version__
300 | print("DeepImageJ model will be exported using TensorFlow version {0}".format(TF_VERSION))
301 | if TF_VERSION[:3] == "2.3":
302 |             print(R+"DeepImageJ plugin is only compatible with TensorFlow version 1.x, 2.0.0, 2.1.0 and 2.2.0. Later versions are not supported in DeepImageJ."+W)
303 |
304 | def _save_model():
305 |             if tf_version == 2:
306 |                 # TODO: change this once TF 2.3.0 is available in the Java bindings
307 | from tensorflow.compat.v1 import saved_model
308 | from tensorflow.compat.v1.keras.backend import get_session
309 | else:
310 | from tensorflow import saved_model
311 | from keras.backend import get_session
312 |
313 | builder = saved_model.builder.SavedModelBuilder(deepimagej_model_path)
314 |
315 |             signature = saved_model.signature_def_utils.predict_signature_def(
316 |                 inputs={'input': tf_model.input},
317 |                 outputs={'output': tf_model.output})
318 | 
319 |             signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
320 | 
321 |             builder.add_meta_graph_and_variables(get_session(),
322 |                                                  [saved_model.tag_constants.SERVING],
323 |                                                  signature_def_map=signature_def_map)
324 | builder.save()
325 | print("TensorFlow model exported to {0}".format(deepimagej_model_path))
326 |
327 | if TF_VERSION[0] == '1':
328 | tf_version = 1
329 | _save_model()
330 | else:
331 | tf_version = 2
332 | """TODO: change it once TF 2.3.0 is available in JAVA"""
333 | from tensorflow.keras.models import clone_model
334 |             _weights = tf_model.get_weights()
335 | with tf.Graph().as_default():
336 | # clone model in new graph and set weights
337 | _model = clone_model(tf_model)
338 | _model.set_weights(_weights)
339 | _save_model()
340 |
341 |
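# Optional sanity check for the exported SavedModel (a sketch, run from a
# shell; 'MyModelDeepImageJ' is a hypothetical model folder). saved_model_cli
# ships with TensorFlow:
#     saved_model_cli show --dir MyModelDeepImageJ --tag_set serve \
#         --signature_def serving_default
# It should report a single input tensor named 'input' and a single output
# named 'output', matching the signature built in save_tensorflow_pb().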
342 | def write_config(Config, TestInfo, config_path):
343 |     """
344 |     - Config: class holding all the information about the model's architecture and pre/post-processing
345 |     - TestInfo: metadata of the image provided as an example
346 |     - config_path: directory where the resulting config.xml is written.
347 |     The configuration template is downloaded from:
348 |     https://raw.githubusercontent.com/deepimagej/python4deepimagej/master/xml/config_template.xml
349 |     The function fills in the template's fields with the information about
350 |     the model and the example image.
351 |     """
352 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/python4deepimagej/master/xml/config_template.xml", "config_template.xml")
353 |     try:
354 |         tree = ET.parse('config_template.xml')
355 |         root = tree.getroot()
356 |     except (FileNotFoundError, ET.ParseError):
357 |         raise RuntimeError("config_template.xml could not be found or parsed.")
358 |
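    # NOTE: the numeric indices below mirror the child order of the XML
    # template; if config_template.xml is reorganized, these assignments
    # must be updated to match.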
359 |     # ModelInformation: name, authors, URL, credits, version, date, references
360 | root[0][0].text = Config.Name
361 | root[0][1].text = Config.Authors
362 | root[0][2].text = Config.URL
363 | root[0][3].text = Config.Credits
364 | root[0][4].text = Config.Version
365 | root[0][5].text = Config.Date
366 | root[0][6].text = Config.References
367 |
368 | # ExampleImage
369 | root[1][0].text = TestInfo.Input_shape
370 | root[1][1].text = TestInfo.Output_shape
371 | root[1][2].text = TestInfo.MemoryPeak
372 | root[1][3].text = TestInfo.Runtime
373 | root[1][4].text = TestInfo.PixelSize
374 |
375 | # ModelArchitecture
376 | root[2][0].text = 'tf.saved_model.tag_constants.SERVING'
377 | root[2][1].text = 'tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY'
378 | root[2][2].text = Config.InputTensorDimensions
379 | root[2][3].text = '1'
380 | root[2][4].text = 'input'
381 | root[2][5].text = Config.InputOrganization0
382 | root[2][6].text = '1'
383 | root[2][7].text = 'output'
384 | root[2][8].text = Config.OutputOrganization0
385 | root[2][9].text = Config.Channels
386 | root[2][10].text = Config.FixedPatch
387 | root[2][11].text = Config.MinimumSize
388 | root[2][12].text = Config.PatchSize
389 | root[2][13].text = 'true'
390 | root[2][14].text = Config.Padding
391 | root[2][15].text = Config.Preprocessing[0]
392 | print("Preprocessing macro '{}' set by default".format(Config.Preprocessing[0]))
393 | root[2][16].text = Config.Postprocessing[0]
394 | print("Postprocessing macro '{}' set by default".format(Config.Postprocessing[0]))
395 | root[2][17].text = '1'
396 |     try:
397 |         tree.write(os.path.join(config_path, 'config.xml'), encoding="UTF-8", xml_declaration=True)
398 |         print("DeepImageJ configuration file exported.")
399 |     except OSError:
400 |         print("Could not write config.xml: the directory {} does not exist or is not writable.".format(config_path))
401 |
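# Quick round-trip check for the generated configuration (a sketch; assumes
# write_config() has already run and 'config_path' points at the model folder):
#     import os
#     import xml.etree.ElementTree as ET
#     cfg = ET.parse(os.path.join(config_path, 'config.xml')).getroot()
#     print(cfg[0][0].text)  # should echo Config.Name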
--------------------------------------------------------------------------------