├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── python4deepimagej.iml └── vcs.xml ├── LICENSE ├── exportFRUNet_from_keras.ipynb ├── export_StarDist_to_TensorFlow_SavedModel.ipynb ├── keras_for_deepimagej.ipynb ├── requirements.txt ├── unet ├── data │ ├── processed.zip │ └── raw.zip ├── py_files │ ├── convert_to_pb.py │ ├── data_loading.py │ ├── fit_model.py │ ├── helpers.py │ ├── model.py │ ├── prep_data.py │ ├── unet_dm.py │ └── unet_weights.py └── train_and_test_unet.ipynb └── xml ├── config_template.xml └── create_config.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 122 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/python4deepimagej.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/vcs.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2019, DeepImageJ 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /exportFRUNet_from_keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "rQfu09ZWP3RH" 8 | }, 9 | "source": [ 10 | "**Use this code to export FRU-Net model to proto buffer and use it in DeepImageJ**\n", 11 | "\n", 12 | "FRU-Net: https://cbia.fi.muni.cz/research/segmentation/fru-net\n", 13 | "\n", 14 | "DeepImageJ: https://deepimagej.github.io/deepimagej/index.html\n" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "colab_type": "text", 21 | "id": "Paui0JfkQ8Lt" 22 | }, 23 | "source": [ 24 | "Mount your Google Drive" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 10, 30 | "metadata": { 31 | "colab": { 32 | "base_uri": "https://localhost:8080/", 33 | "height": 54 34 | }, 35 | "colab_type": "code", 36 | "executionInfo": { 37 | "elapsed": 840, 38 | "status": "ok", 39 | "timestamp": 1564585146016, 40 | "user": { 41 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL", 42 | "photoUrl": "", 43 | "userId": "04592796515262324641" 44 | }, 45 | "user_tz": -120 46 | }, 47 | "id": "8FGAyk73Q7nR", 48 | "outputId": "a41a584b-7aa7-4b49-df83-b10ee0c5e9de" 49 | }, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "from google.colab import drive\n", 61 | "drive.mount('/content/drive')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "colab_type": "text", 68 | "id": "nfkCAMIkROF2" 69 | }, 70 | "source": [ 71 | "Install a compatible version of Keras and Tensorflow" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 11, 77 | "metadata": 
{ 78 | "colab": { 79 | "base_uri": "https://localhost:8080/", 80 | "height": 530 81 | }, 82 | "colab_type": "code", 83 | "executionInfo": { 84 | "elapsed": 5862, 85 | "status": "ok", 86 | "timestamp": 1564585153377, 87 | "user": { 88 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL", 89 | "photoUrl": "", 90 | "userId": "04592796515262324641" 91 | }, 92 | "user_tz": -120 93 | }, 94 | "id": "tM-GjE2TRMXO", 95 | "outputId": "9972d447-6ca5-4b63-aa20-74029b2c8f91" 96 | }, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Cloning into 'python4deepimagej'...\n", 103 | "remote: Enumerating objects: 7, done.\u001b[K\n", 104 | "remote: Counting objects: 100% (7/7), done.\u001b[K\n", 105 | "remote: Compressing objects: 100% (7/7), done.\u001b[K\n", 106 | "remote: Total 7 (delta 2), reused 0 (delta 0), pack-reused 0\u001b[K\n", 107 | "Unpacking objects: 100% (7/7), done.\n", 108 | "Requirement already satisfied: keras==1.2.2 in /usr/local/lib/python3.6/dist-packages (1.2.2)\n", 109 | "Requirement already satisfied: tensorflow in /usr/local/lib/python3.6/dist-packages (1.13.1)\n", 110 | "Requirement already satisfied: theano in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (1.0.4)\n", 111 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (3.13)\n", 112 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from keras==1.2.2) (1.12.0)\n", 113 | "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.0)\n", 114 | "Requirement already satisfied: tensorboard<1.14.0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.13.1)\n", 115 | "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.1.0)\n", 116 | "Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from 
tensorflow) (0.2.2)\n", 117 | "Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.8.0)\n", 118 | "Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.0.8)\n", 119 | "Requirement already satisfied: numpy>=1.13.3 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.16.4)\n", 120 | "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.33.4)\n", 121 | "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.15.0)\n", 122 | "Requirement already satisfied: tensorflow-estimator<1.14.0rc0,>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (1.13.0)\n", 123 | "Requirement already satisfied: absl-py>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (0.7.1)\n", 124 | "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow) (3.7.1)\n", 125 | "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python3.6/dist-packages (from theano->keras==1.2.2) (1.3.0)\n", 126 | "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow) (0.15.5)\n", 127 | "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard<1.14.0,>=1.13.0->tensorflow) (3.1.1)\n", 128 | "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow) (2.8.0)\n", 129 | "Requirement already satisfied: mock>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-estimator<1.14.0rc0,>=1.13.0->tensorflow) (3.0.5)\n", 130 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow) (41.0.1)\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | 
"%pip install keras==1.2.2 tensorflow\n" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "colab_type": "text", 142 | "id": "DRHahE4lRXWs" 143 | }, 144 | "source": [ 145 | "Import dependencies" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 0, 151 | "metadata": { 152 | "colab": {}, 153 | "colab_type": "code", 154 | "id": "6wgEqbg1RTLg" 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "import tensorflow as tf\n", 159 | "import keras\n", 160 | "from keras import backend as K" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 0, 166 | "metadata": { 167 | "colab": {}, 168 | "colab_type": "code", 169 | "id": "jTQ8MQJjWdY7" 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "Download a trained FRU-Net model from FRU-Net: https://cbia.fi.muni.cz/research/segmentation/fru-net" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "colab_type": "text", 180 | "id": "a6jENPIkjmYG" 181 | }, 182 | "source": [ 183 | "Download the ZIP file containing all the information about FRU-Net from https://cbia.fi.muni.cz/research/segmentation/fru-net.\n", 184 | "\n", 185 | "Unzip the file and load one of the trained models (.h5)\n", 186 | "\n" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 0, 192 | "metadata": { 193 | "colab": {}, 194 | "colab_type": "code", 195 | "id": "GSIQ8TUUWZBe" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "#Fill the path to your keras network\n", 200 | "path2network='/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/fully_residual_dropout_segmentation.h5'\n", 201 | "\n", 202 | "# Set the learning phase to convert properly the model\n", 203 | "# The learning phase flag is a bool tensor (0 = test, 1 = train) to be passed as\n", 204 | "# input to any Keras function that uses a different behavior at train time and \n", 205 | "# test time.\n", 206 | "\n", 207 | 
"K.set_learning_phase(1)\n", 208 | "\n", 209 | "# Load the model\n", 210 | "model = keras.models.load_model(path2network)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": { 216 | "colab_type": "text", 217 | "id": "vstqSMe1XOri" 218 | }, 219 | "source": [ 220 | "Save your keras model as proto buffer" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 9, 226 | "metadata": { 227 | "colab": { 228 | "base_uri": "https://localhost:8080/", 229 | "height": 156 230 | }, 231 | "colab_type": "code", 232 | "executionInfo": { 233 | "elapsed": 4303, 234 | "status": "ok", 235 | "timestamp": 1564584430801, 236 | "user": { 237 | "displayName": "ESTIBALIZ GOMEZ DE MARISCAL", 238 | "photoUrl": "", 239 | "userId": "04592796515262324641" 240 | }, 241 | "user_tz": -120 242 | }, 243 | "id": "GNmQlysgWvBf", 244 | "outputId": "61a34f37-107a-41a2-8728-e16a41dde315" 245 | }, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:205: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.\n", 252 | "Instructions for updating:\n", 253 | "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.\n", 254 | "INFO:tensorflow:No assets to save.\n", 255 | "INFO:tensorflow:No assets to write.\n", 256 | "INFO:tensorflow:SavedModel written to: /content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet/saved_model.pb\n" 257 | ] 258 | }, 259 | { 260 | "data": { 261 | "text/plain": [ 262 | "b'/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet/saved_model.pb'" 263 | ] 264 | }, 265 | "execution_count": 9, 266 | "metadata": { 267 | "tags": [] 268 | }, 269 | 
"output_type": "execute_result" 270 | } 271 | ], 272 | "source": [ 273 | "OUTPUT_DIR = \"/content/drive/My Drive/Projectos/DEEP-IMAGEJ/examples_of_models/frunet/FRUNet\"\n", 274 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n", 275 | "\n", 276 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n", 277 | " inputs = {'input': model.input},\n", 278 | " outputs = {'output': model.output})\n", 279 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n", 280 | "\n", 281 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n", 282 | " signature_def_map=signature_def_map)\n", 283 | "builder.save()" 284 | ] 285 | } 286 | ], 287 | "metadata": { 288 | "colab": { 289 | "name": "exportFRUNet_from_keras.ipynb", 290 | "provenance": [], 291 | "version": "0.3.2" 292 | }, 293 | "kernelspec": { 294 | "display_name": "Python 3", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.6.8" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 1 313 | } 314 | -------------------------------------------------------------------------------- /export_StarDist_to_TensorFlow_SavedModel.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Untitled0.ipynb", 7 | "provenance": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | } 13 | }, 14 | "cells": [ 15 | { 16 | "cell_type": "markdown", 17 | "metadata": { 18 | "id": "FtkHVlkGZuz3", 19 | "colab_type": "text" 20 | }, 21 | "source": [ 22 | "# 
**This is a genertic python code to export StarDist trained models and use them with DeepImageJ plugin**\n", 23 | "\n", 24 | "https://deepimagej.github.io/deepimagej/index.html\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "MklgDjVtbEKz", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "If you are using Google Colab, mount your Google Drive. Otherwise, skip this step" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "EzTJ9dClsagp", 41 | "colab_type": "code", 42 | "colab": {} 43 | }, 44 | "source": [ 45 | "from google.colab import drive\n", 46 | "drive.mount('/content/drive')" 47 | ], 48 | "execution_count": 0, 49 | "outputs": [] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "id": "OPfCnAKAZ7XI", 55 | "colab_type": "text" 56 | }, 57 | "source": [ 58 | "\n", 59 | "Install the following packages: \n", 60 | "- A compatible version of Tensorflow <= 1.13.\n", 61 | "- stardist python package. Here we used StarDist 0.3.6" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "id": "y_3YIsZxs39M", 68 | "colab_type": "code", 69 | "colab": {} 70 | }, 71 | "source": [ 72 | "% pip install tensorflow==1.13.1\n", 73 | "% pip install stardist" 74 | ], 75 | "execution_count": 0, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "wI2iQC_ttE_q", 82 | "colab_type": "text" 83 | }, 84 | "source": [ 85 | "# Load the StarDist trained model from your repository" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": { 91 | "id": "MklgDjVtbEKz", 92 | "colab_type": "text" 93 | }, 94 | "source": [ 95 | "Verify input and output sizes of your model. They can be different when the parameter grid is not (1,1). A different output size can lead to errors in DeepImageJ. Take it also into account if you want to perform shape measurements using the output image." 
96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "metadata": { 101 | "id": "wQGQdksis3_o", 102 | "colab_type": "code", 103 | "colab": {} 104 | }, 105 | "source": [ 106 | "from stardist.models import StarDist2D\n", 107 | "# Without shape completion\n", 108 | "model_paper = StarDist2D(None, name='name_of_your_model', basedir='/content/drive/My Drive/the_path_to_your_model/folde_containing_the_model')\n", 109 | "# Indicate which weights you want to use\n", 110 | "model_paper.load_weights('weights_best.h5')" 111 | ], 112 | "execution_count": 0, 113 | "outputs": [] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": { 118 | "id": "tEj1Oj9FubL0", 119 | "colab_type": "text" 120 | }, 121 | "source": [ 122 | "# Save as a TensorFlow SavedModel" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "7B5veKEEuhq7", 129 | "colab_type": "code", 130 | "colab": {} 131 | }, 132 | "source": [ 133 | "import keras\n", 134 | "import keras.backend as K\n", 135 | "from keras.layers import concatenate\n", 136 | "import tensorflow as tf\n", 137 | "#Write the path where you would like to save the model. \n", 138 | "# The code will automatically create a new folder called \"new_folder\", where the\n", 139 | "# TensorFlow model will be saved\n", 140 | "OUTPUT_DIR = \"/content/drive/My Drive/the_path_where_you_want_to_save_your_model/new_folder\"\n", 141 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n", 142 | "\n", 143 | "# StarDist has two different outputs. 
DeepImageJ can only read one of them, so \n", 144 | "# we concatenate them as different channels in order to used them in ImageJ.\n", 145 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n", 146 | " inputs = {'input': model_paper.keras_model.input[0]},\n", 147 | " # concatenate the output of StarDist\n", 148 | " outputs = {'output': concatenate([model_paper.keras_model.output[0],model_paper.keras_model.output[1]], axis = 3)})\n", 149 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n", 150 | "\n", 151 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n", 152 | " signature_def_map=signature_def_map)\n", 153 | "builder.save()" 154 | ], 155 | "execution_count": 0, 156 | "outputs": [] 157 | } 158 | ] 159 | } 160 | -------------------------------------------------------------------------------- /keras_for_deepimagej.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "rQfu09ZWP3RH" 8 | }, 9 | "source": [ 10 | "**This is a genertic python code to export Keras models and use them with DeepImageJ plugin**\n", 11 | "\n", 12 | "\n", 13 | "https://deepimagej.github.io/deepimagej/index.html\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "colab_type": "text", 20 | "id": "Paui0JfkQ8Lt" 21 | }, 22 | "source": [ 23 | "Mount your Google Drive" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 0, 29 | "metadata": { 30 | "colab": {}, 31 | "colab_type": "code", 32 | "id": "8FGAyk73Q7nR" 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "from google.colab import drive\n", 37 | "drive.mount('/content/drive')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "colab_type": "text", 44 | "id": "nfkCAMIkROF2" 45 | }, 46 | "source": [ 47 | 
"Install a compatible version of Tensorflow <= 1.13" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 0, 53 | "metadata": { 54 | "colab": {}, 55 | "colab_type": "code", 56 | "id": "tM-GjE2TRMXO" 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "%pip install tensorflow==1.13.1\n", 61 | "%pip install keras==2.2.4\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": { 67 | "colab_type": "text", 68 | "id": "DRHahE4lRXWs" 69 | }, 70 | "source": [ 71 | "Import dependencies" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 0, 77 | "metadata": { 78 | "colab": {}, 79 | "colab_type": "code", 80 | "id": "6wgEqbg1RTLg" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "import tensorflow as tf\n", 85 | "import keras\n", 86 | "from keras import backend as K" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 0, 92 | "metadata": { 93 | "colab": {}, 94 | "colab_type": "code", 95 | "id": "jTQ8MQJjWdY7" 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "Load a keras network" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 0, 105 | "metadata": { 106 | "colab": {}, 107 | "colab_type": "code", 108 | "id": "GSIQ8TUUWZBe" 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "#Fill the path to your keras network\n", 113 | "path2network='/path2yournetwork/your_network.hdf5'\n", 114 | "model = keras.models.load_model(path2network)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": { 120 | "colab_type": "text", 121 | "id": "vstqSMe1XOri" 122 | }, 123 | "source": [ 124 | "Save your keras model as proto buffer" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 0, 130 | "metadata": { 131 | "colab": {}, 132 | "colab_type": "code", 133 | "id": "GNmQlysgWvBf" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "#If the model has only one input it can be converted\n", 138 | "OUTPUT_DIR = 
\"/your/output/directory/new_folder_name\"\n", 139 | "builder = tf.saved_model.builder.SavedModelBuilder(OUTPUT_DIR)\n", 140 | "\n", 141 | "signature = tf.saved_model.signature_def_utils.predict_signature_def(\n", 142 | " inputs = {'input': model.input},\n", 143 | " outputs = {'output': model.output})\n", 144 | "signature_def_map = { tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature }\n", 145 | "\n", 146 | "builder.add_meta_graph_and_variables(K.get_session(), [tf.saved_model.tag_constants.SERVING],\n", 147 | " signature_def_map=signature_def_map)\n", 148 | "builder.save()" 149 | ] 150 | } 151 | ], 152 | "metadata": { 153 | "colab": { 154 | "name": "keras_for_deepimagej.ipynb", 155 | "provenance": [], 156 | "version": "0.3.2" 157 | }, 158 | "kernelspec": { 159 | "display_name": "Python 3", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.6.8" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 1 178 | } 179 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | xml 3 | time 4 | urllib 5 | shutil 6 | skimage 7 | tensorflow<=2.2.1 8 | -------------------------------------------------------------------------------- /unet/data/processed.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepimagej/python4deepimagej/535fc4061f9ae93878d70c68f23536233bb74562/unet/data/processed.zip -------------------------------------------------------------------------------- /unet/data/raw.zip: 
def convert_to_pb(name_project, path_to_model=None, path_output=None):
    """
    Converts a Keras model into a Tensorflow .pb file.

    The string 'name_project' represents the name of the DeepPix Workflow
    project given by the user. It is used as the base name of the written
    file ('<name_project>.pb') and, when 'path_to_model' is not given, of the
    model file that is loaded ('<name_project>.hdf5').

    The string 'path_to_model' is the path of the Keras model file to load.
    Default value is None, which selects the original Google Drive location
    '/content/drive/My Drive/unser_project/models/<name_project>.hdf5'.

    The string 'path_output' is the directory in which the .pb file is
    written. Default value is None, which selects the original Google Drive
    location '/content/drive/My Drive/unser_project/'.

    The function returns nothing; its result is the file written to disk.
    """

    # Fall back to the historical hard-coded locations so that existing
    # callers passing only the project name keep the exact same behaviour.
    if path_to_model is None:
        path_to_model = '/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=name_project)
    if path_output is None:
        path_output = '/content/drive/My Drive/unser_project/'

    # Load model.
    model = load_model(path_to_model)

    # Get the names of the graph operations producing the model outputs;
    # they tell the conversion step which part of the graph must be kept.
    node_names = [node.op.name for node in model.outputs]

    # Get Keras session (the loaded weights are read from it).
    session = K.get_session()

    # Convert Keras variables to Tensorflow constants.
    graph_to_constant = graph_util.convert_variables_to_constants(
        session,
        session.graph.as_graph_def(),
        node_names)

    # Write graph as a binary (as_text=False) .pb file.
    graph_io.write_graph(graph_to_constant, path_output, name_project + ".pb", as_text=False)
def _flowFromFolder(datagen, path, folder, batch_size, target_size, seed, shuffle):
    """
    Builds the directory iterator that reads the sub-folder "folder" of "path".

    All subsets share the same configuration: grayscale images, no class
    labels (class_mode = None), nothing saved to disk. Only the folder name
    and the shuffling flag differ between calls.
    """
    return datagen.flow_from_directory(
        path,
        classes = [folder],
        class_mode = None,
        color_mode = "grayscale",
        target_size = target_size,
        batch_size = batch_size,
        save_to_dir = None,
        seed = seed,
        shuffle = shuffle)


def _iterateSubset(path, batch_size, subset, target_size, seed):
    """
    Yields the batches of an already validated subset ('train', 'validation'
    or 'test'). Nothing is read from disk before the first batch is requested.
    """

    is_train = subset == "train"

    # Augmentations are applied to the training set only; validation and
    # testing data are read as they are and are not shuffled.
    if is_train:
        aug_arg = dict(rotation_range = 40,
                       width_shift_range = 0.2,
                       height_shift_range = 0.2,
                       shear_range = 0.2,
                       horizontal_flip = True,
                       vertical_flip = True,
                       fill_mode='nearest')
    else:
        aug_arg = {}

    image_datagen = ImageDataGenerator(**aug_arg)

    # Testing data does not have to have labels: only images are generated.
    if subset == "test":
        image_generator = _flowFromFolder(image_datagen, path, 'image',
                                          batch_size, target_size, seed, False)
        for img in image_generator:
            yield adjustData(img, False)
        return

    label_datagen = ImageDataGenerator(**aug_arg)

    # Both iterators receive the same seed and the same shuffling flag so
    # that images and labels are drawn (and augmented) in the same order.
    image_generator = _flowFromFolder(image_datagen, path, 'image',
                                      batch_size, target_size, seed, is_train)
    label_generator = _flowFromFolder(label_datagen, path, 'label',
                                      batch_size, target_size, seed, is_train)

    for (img, label) in zip(image_generator, label_generator):
        # adjustData comes from helpers.py (not visible here) -- presumably
        # it normalizes the batch; confirm against helpers.py.
        img, label = adjustData(img, label = label)

        yield (img, label)


def dataGenerator(path, batch_size = 2, subset = 'train', target_size = (256,256), seed = 1):

    """
    Builds generators for the U-Net. The generators can be built for
    training, testing and validation purposes.

    The string "path" represents a path that should lead to images and labels
    folders named 'image' and 'label' respectively.

    The integer "batch_size" is the number of images per generated batch.
    Default value is set to 2.

    The string "subset" is used to specify which type of data we are dealing
    with ('train', 'validation' or 'test'). Default value is set to 'train'.
    The training subset is augmented and shuffled; the two others are not.

    The tuple "target_size" is used to specify the final sizes of the images
    and labels after augmentation. If the given size does not correspond to
    original size of the images and labels, the data will be resized with the
    given size. Default value is set to (256, 256) (image size of 256x256 pixels).

    The variable seed is needed to ensure that images and labels will be augmented
    together in the right orders. Default value set to 1.

    Returns a generator yielding (img, label) tuples for the 'train' and
    'validation' subsets and img batches only for the 'test' subset, each
    batch being passed through adjustData first.

    Raises a RuntimeError when "subset" is not one of the three names above.
    The check is done when the function is called, not when the first batch
    is requested, so that a misspelled subset is reported immediately.
    """

    if subset not in ("train", "validation", "test"):
        raise RuntimeError("Subset name not recognized")

    return _iterateSubset(path, batch_size, subset, target_size, seed)
If the given size does not correspond to 174 | original size of the images and labels, the data will be resized with the 175 | given size. Default value is set to (256, 256) (image size of 256x256 pixels). 176 | 177 | The variable seed is needed to ensure that images and labels will be augmented 178 | together in the right orders. Default value set to 1. 179 | """ 180 | 181 | # Builds generator for training set. 182 | if subset == "train": 183 | 184 | # Preprocessing arguments. 185 | aug_arg = dict(rotation_range = 40, 186 | width_shift_range = 0.2, 187 | height_shift_range = 0.2, 188 | shear_range = 0.2, 189 | horizontal_flip = True, 190 | vertical_flip = True, 191 | fill_mode='nearest') 192 | 193 | # Generates tensor images, weight-maps and labels with augmentations provided above. 194 | image_datagen = ImageDataGenerator(**aug_arg) 195 | label_datagen = ImageDataGenerator(**aug_arg) 196 | weight_datagen = ImageDataGenerator(**aug_arg) 197 | 198 | # Generator for images. 199 | image_generator = image_datagen.flow_from_directory( 200 | path, 201 | classes = ['image'], 202 | class_mode = None, 203 | color_mode = 'grayscale', 204 | target_size = target_size, 205 | batch_size = batch_size, 206 | save_to_dir = None, 207 | seed = seed, 208 | shuffle = True) 209 | 210 | # Generator for labels. 211 | label_generator = label_datagen.flow_from_directory( 212 | path, 213 | classes = ['label'], 214 | class_mode = None, 215 | color_mode = 'grayscale', 216 | target_size = target_size, 217 | batch_size = batch_size, 218 | save_to_dir = None, 219 | seed = seed, 220 | shuffle = True) 221 | 222 | # Retrieve weight-maps. 223 | filelist = glob.glob(path + "/weight/*.npy") 224 | filelist.sort(key=natural_keys) 225 | 226 | # Loads all weight-map images in a list. 227 | weights = [np.load(fname) for fname in filelist] 228 | weights = np.array(weights) 229 | weights = weights.reshape((len(weights),256,256,1)) 230 | 231 | # Creates the weight generator. 
232 | weight_generator = weight_datagen.flow( 233 | x = weights, 234 | y = None, 235 | batch_size = batch_size, 236 | seed = seed) 237 | 238 | # Builds generator for the training set. 239 | train_generator = zip(image_generator, label_generator, weight_generator) 240 | 241 | for (img, label, weight) in train_generator: 242 | img, label = adjustData(img, label = label) 243 | 244 | # This is the final generator. 245 | yield ([img, weight], label) 246 | 247 | elif subset == "validation": 248 | 249 | # Generates tensor images, weight maps and labels with no augmentations 250 | # and shuffling (since we are in the test set). 251 | image_datagen = ImageDataGenerator() 252 | label_datagen = ImageDataGenerator() 253 | weight_datagen = ImageDataGenerator() 254 | 255 | # Generator for images. 256 | image_generator = image_datagen.flow_from_directory( 257 | path, 258 | classes = ['image'], 259 | class_mode = None, 260 | color_mode = 'grayscale', 261 | target_size = target_size, 262 | batch_size = batch_size, 263 | save_to_dir = None, 264 | seed = seed, 265 | shuffle = False) 266 | 267 | # Generator for labels. 268 | label_generator = label_datagen.flow_from_directory( 269 | path, 270 | classes = ['label'], 271 | class_mode = None, 272 | color_mode = 'grayscale', 273 | target_size = target_size, 274 | batch_size = batch_size, 275 | save_to_dir = None, 276 | seed = seed, 277 | shuffle = False) 278 | 279 | # Retrieve weight maps. 280 | filelist = glob.glob(path + "/weight/*.npy") 281 | filelist.sort(key=natural_keys) 282 | 283 | # Loads all weight map images in a list. 284 | weights = [np.load(fname) for fname in filelist] 285 | weights = np.array(weights) 286 | weights = weights.reshape((len(weights),256,256,1)) 287 | 288 | # Creates the weight generator. 289 | weight_generator = weight_datagen.flow( 290 | x = weights, 291 | y = None, 292 | batch_size = batch_size, 293 | seed = seed, 294 | shuffle = False) 295 | 296 | # Builds generator for the test set. 
def adjustData(image, adjust_lab = True, dist = False, label = None):
    """
    Normalizes the data such that images are in the interval [0,1] and labels
    are binary values in {0,1}. This step is needed because augmentations
    with Keras' ImageDataGenerator() change the pixel values, so images are
    not normalized anymore and labels are not binary anymore.

    The numpy array 'image' represents the input image. It is divided by 255
    only when its maximum is larger than 1.

    The boolean value 'adjust_lab' specifies if labels have to be processed
    (training and validation). Default value set to True.

    The boolean value 'dist' is only used when 'adjust_lab' is False: if
    True, the untouched label is returned together with the image.
    Default value set to False.

    The numpy array 'label' represents the label image. It is required when
    'adjust_lab' is True and may be None otherwise. Default value set to None.

    Returns (image, label) when 'adjust_lab' or 'dist' is True, otherwise
    the image alone. Raises a TypeError when 'adjust_lab' is True but no
    label is given.
    """

    # Fails early with an explicit message instead of an obscure comparison
    # error on None further down.
    if adjust_lab and label is None:
        raise TypeError("adjustData: 'label' must be provided when adjust_lab is True")

    # Checks if the images are already between 0 and 1, otherwise
    # does the normalization.
    if np.max(image) > 1:
        image = image / 255

    if adjust_lab:

        # Checks if the labels are already in [0,1], otherwise rescales
        # and binarizes them with a 0.5 threshold.
        if np.max(label) > 1:
            label = label / 255
            label[label > 0.5] = 1
            label[label <= 0.5] = 0

        return (image, label)

    if dist:
        return (image, label)

    return image
def loadGenerator(name_project, model_type, batch_size = 2, target_size = (256,256)):
    """
    Creates the training and validation generators matching a model type.

    The string 'name_project' represents the name of the DeepPix Workflow
    project given by the user; it selects the sub-folder of the processed
    data folder that is read.

    The string 'model_type' is either "unet_simple" or "unet_weighted".
    Any other value raises a RuntimeError.

    The integer value 'batch_size' represents the size of the batches.
    Default value set to 2.

    The tuple "target_size" is forwarded to the generators and specifies the
    final size of the images and labels. Default value is set to (256, 256).

    Returns the tuple (training generator, validation generator). The
    validation generator reads the 'test' folder of the project.
    """

    # Folder of the project inside the processed data folder.
    project_path = "/content/drive/My Drive/unser_project/data/processed/" + name_project

    train_path = project_path + "/train/"
    test_path = project_path + "/test/"

    # Picks the generator factory matching the model type.
    if model_type == "unet_simple":
        make_generator = dataGenerator

    elif model_type == "unet_weighted":
        make_generator = weightGen

    else:
        raise RuntimeError("Model not recognised.")

    trainGen = make_generator(path = train_path, batch_size = batch_size,
                              subset = "train", target_size = target_size)
    validGen = make_generator(path = test_path, batch_size = batch_size,
                              subset = "validation", target_size = target_size)

    return trainGen, validGen
The generators are built 396 | # the same way as the dataGenerator function, only the weights are combined 397 | # with the images.""" 398 | # 399 | # # Builds generator for training set 400 | # if subset == "train": 401 | # 402 | # # Preprocessing arguments 403 | # aug_arg = dict(rotation_range = 40, 404 | # width_shift_range = 0.2, 405 | # height_shift_range = 0.2, 406 | # shear_range = 0.2, 407 | # horizontal_flip = True, 408 | # vertical_flip = True, 409 | # fill_mode='nearest') 410 | # 411 | # # Generates tensor images and labels with augmentations provided above 412 | # image_datagen = ImageDataGenerator(**aug_arg) 413 | # label_datagen = ImageDataGenerator(**aug_arg) 414 | # 415 | # # Generator for images 416 | # image_generator = image_datagen.flow_from_directory( 417 | # path, 418 | # classes = [image_folder], 419 | # class_mode = None, 420 | # color_mode = 'grayscale', 421 | # target_size = target_size, 422 | # batch_size = batch_size, 423 | # save_to_dir = None, 424 | # seed = seed, 425 | # shuffle = True) 426 | # 427 | # # Generator for labels 428 | # label_generator = label_datagen.flow_from_directory( 429 | # path, 430 | # classes = [label_folder], 431 | # class_mode = 'categorical', 432 | # color_mode = 'rgb', 433 | # target_size = target_size, 434 | # batch_size = batch_size, 435 | # save_to_dir = None, 436 | # seed = seed, 437 | # shuffle = True) 438 | # 439 | # # Builds generator for the training set 440 | # train_generator = zip(image_generator, label_generator) 441 | # 442 | # for (img, label) in train_generator: 443 | # img, label = adjustData(img, False, True, label = label) 444 | # 445 | # print(label) 446 | # # This is the final generator 447 | # yield (img, label) 448 | # 449 | # elif subset == "test": 450 | # 451 | # # Generates tensor images and labels with no augmnetations and shuffling 452 | # # (since we are in the test set) 453 | # image_datagen = ImageDataGenerator() 454 | # label_datagen = ImageDataGenerator() 455 | # 456 | # # 
Generator for images 457 | # image_generator = image_datagen.flow_from_directory( 458 | # path, 459 | # classes = [image_folder], 460 | # class_mode = None, 461 | # color_mode = 'grayscale', 462 | # target_size = target_size, 463 | # batch_size = batch_size, 464 | # save_to_dir = None, 465 | # seed = seed, 466 | # shuffle = False) 467 | # 468 | # # Generator for labels 469 | # label_generator = label_datagen.flow_from_directory( 470 | # path, 471 | # classes = [label_folder], 472 | # class_mode = 'categorical', 473 | # color_mode = 'rgb', 474 | # target_size = target_size, 475 | # batch_size = batch_size, 476 | # save_to_dir = None, 477 | # seed = seed, 478 | # shuffle = False) 479 | # 480 | # # Builds generator for the test set 481 | # test_generator = zip(image_generator, label_generator) 482 | # 483 | # for (img, label) in test_generator: 484 | # img, label = adjustData(img, False, True, label = label) 485 | # 486 | # # This is the final generator 487 | # yield (img, label) 488 | # 489 | # else: 490 | # print("Subset name not recognized") 491 | # return None 492 | 493 | # ----------------------------------------------------------------------------- 494 | 495 | #def label_to_cat(label, n_classes): 496 | # 497 | # label = np.rint(label / (255 / (n_classes - 1))) 498 | # 499 | # n, rows, cols, _ = label.shape 500 | # output = np.zeros((n, rows, cols, n_classes)) 501 | # 502 | # for i in range(n_classes): 503 | # tmp = (label[...,0] == i).astype(int) 504 | # output[...,i] = tmp 505 | # 506 | # output = np.reshape(output, (n, rows*cols, n_classes)) 507 | # 508 | # return output -------------------------------------------------------------------------------- /unet/py_files/fit_model.py: -------------------------------------------------------------------------------- 1 | # Important librairies 2 | 3 | import pickle 4 | import cv2 as cv 5 | import sys 6 | import glob 7 | 8 | # ----------------------------------------------------------------------------- 9 | 10 | # 
def fit_model(trainGen, validGen, model_type, model_name, input_size = (256, 256, 1), loss_ = 'binary_crossentropy',
              lr = 1e-4, w_decay = 5e-7, steps = 500, epoch_num = 10, val_steps = 15, save_history = True):
    """
    Selects a model and fits it with the given generators and arguments.
    The best model (lowest validation loss) is saved during training and,
    optionally, the training history is pickled afterwards.

    The generators 'trainGen' and 'validGen' represent the training and
    validation generators to fit the model.

    The string 'model_type' refers to the type of U-Net to use
    ("unet_simple" or "unet_weighted"); any other value raises a RuntimeError.

    The string 'model_name' refers to the name with which the model and its
    history are saved.

    The tuple 'input_size' corresponds to the size of the input images and
    labels. Default value set to (256, 256, 1).

    The string 'loss_' represents the name of the loss that should be used
    (only forwarded to the simple U-Net). Default value set to
    'binary_crossentropy'.

    The float 'lr' corresponds to the learning rate. Default value set to 1e-4.

    The float 'w_decay' corresponds to the weight decay. Default value set
    to 5e-7.

    The integer 'steps' refers to the number of steps per epoch.
    Default value set to 500.

    The integer 'epoch_num' refers to the maximal number of epochs (training
    stops early after 3 epochs without validation improvement).
    Default value set to 10.

    The integer 'val_steps' refers to the number of validation steps of each
    epoch. Default value set to 15.

    The boolean 'save_history' refers to whether or not the history of the
    training should be saved. Default value set to True.
    """

    # Load a model.
    if model_type == "unet_simple":
        model = unet(input_size = input_size, loss_ = loss_, learning_rate = lr, weight_decay = w_decay)

    elif model_type == "unet_weighted":
        model = unet_weights(input_size = input_size, learning_rate = lr, weight_decay = w_decay)

    else:
        raise RuntimeError("Model type not recognized")

    # Callbacks: keep the best model on disk and stop when the validation
    # loss stops improving.
    model_checkpoint = ModelCheckpoint('/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=model_name),
                                       monitor='val_loss', verbose=1, save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=1, mode='auto', restore_best_weights=True)

    # Fit.
    history = model.fit_generator(trainGen,
                                  steps_per_epoch=steps,
                                  epochs=epoch_num,
                                  callbacks=[model_checkpoint, early_stopping],
                                  validation_data = validGen,
                                  validation_steps = val_steps)

    if save_history:

        # Saving the history for plotting. The context manager guarantees
        # that the file is flushed and closed.
        history_path = '/content/drive/My Drive/unser_project/histories/{b}.p'.format(b=model_name)
        with open(history_path, "wb") as history_file:
            pickle.dump(history.history, history_file)

    return None
def show_predictions(model, name_project, target_size = (256, 256)):
    """
    Shows the first test image with its ground truth and the prediction of
    the model (as a binary image and as a probability map).

    The string 'model' corresponds to the type of model used ("unet_simple"
    or "unet_weighted"); any other value raises a RuntimeError.

    The string 'name_project' refers to the name of the DeepPix Workflow
    project given by the user. It selects the test folder and is also used
    as the file name of the weights that are loaded.
    NOTE(review): fit_model saves weights under 'model_name'; confirm that
    both names are meant to be identical.

    The tuple "target_size" (rows, columns) is the size used for the
    generator, the model input and the displayed image and label.
    Default value is set to (256, 256).

    Raises a RuntimeError if the first test image or label cannot be read.
    """

    # Path of the test set.
    test_path = "/content/drive/My Drive/unser_project/data/processed/" + name_project + "/test/"

    # Number of files; its number of digits gives the number of leading
    # zeros in the names of the images and labels.
    n_file = len(glob.glob(test_path + 'image/*.png'))
    n_digits = len(str(n_file))

    # Model input size: one grayscale channel.
    input_size = tuple(target_size) + (1,)

    # Creates title, test generator and model depending on model type.
    if model == "unet_simple":
        title = "Simple U-Net"
        testGen = dataGenerator(batch_size = 1, subset = "test", path = test_path,
                                target_size = target_size)
        mdl = unet(input_size = input_size)

    elif model == "unet_weighted":
        title = "Weighted U-Net"
        # weightGen only knows the 'train' and 'validation' subsets; the
        # validation one yields unshuffled, non-augmented inputs (the
        # targets are ignored at prediction time).
        testGen = weightGen(batch_size = 1, subset = "validation", path = test_path,
                            target_size = target_size)
        mdl = unet_weights(input_size = input_size)

    else:
        raise RuntimeError("Model not recognised.")

    # Loads the first image and label.
    img_path = test_path + "image/{b:0" + str(n_digits) + "d}.png"
    lbl_path = test_path + "label/{b:0" + str(n_digits) + "d}.png"

    img = cv.imread(img_path.format(b=0))
    label = cv.imread(lbl_path.format(b=0))

    # cv.imread returns None instead of raising when a file is unreadable.
    if img is None or label is None:
        raise RuntimeError("Could not read the first test image or label in " + test_path)

    # Resizes to target size (OpenCV expects (width, height)).
    cv_size = (target_size[1], target_size[0])
    img = cv.resize(img, cv_size)
    label = cv.resize(label, cv_size)

    # Load model and perform predictions.
    mdl.load_weights('/content/drive/My Drive/unser_project/models/{b}.hdf5'.format(b=name_project))
    prediction = mdl.predict_generator(testGen, 2, verbose=1, workers=1)

    # Binarizes one prediction.
    pred_binarized = convertLabel(prediction[0])

    # Perform plot.
    fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=((15,15)))

    for axis in ax.flat:
        axis.grid(False)

    ax[0,0].imshow(img, cmap = 'gray', aspect="auto")
    ax[0,1].imshow(label, cmap = 'gray', aspect="auto")
    ax[1,0].imshow(pred_binarized, cmap = 'gray', aspect="auto")
    ax[1,1].imshow(prediction[0,...,0], cmap = 'gray', aspect="auto", vmin=0, vmax=1)

    ax[0,0].set_title("Input", fontsize = 17.5)
    ax[0,1].set_title("Ground truth", fontsize = 17.5)
    ax[1,0].set_title(title + " - Binarized", fontsize = 17.5)
    ax[1,1].set_title(title + " - Probability map", fontsize = 17.5)
def prepare_standardplot(title, xlabel):
    """
    Creates the figure used to plot a training history: a left axis for the
    loss (logarithmic scale) and a right axis for the accuracy.

    The string 'title' refers to the title of the figure.

    The string 'xlabel' refers to the name of the x-axis of both plots.

    Returns the figure, the loss axis and the accuracy axis.
    """

    fig, axes = plt.subplots(1, 2)
    loss_axis, accuracy_axis = axes

    fig.suptitle(title)

    # Both axes share the same x-label and only differ by their y-label.
    for axis, ylabel in ((loss_axis, 'Binary cross-entropy'), (accuracy_axis, 'Accuracy')):
        axis.set_ylabel(ylabel)
        axis.set_xlabel(xlabel)

    # Only the loss is displayed on a logarithmic scale.
    loss_axis.set_yscale('log')

    return fig, loss_axis, accuracy_axis
def finalize_standardplot(fig, ax1, ax2):
    """
    Finalizes the layout of the plotting of the history from the training:
    adds a legend to every axis that has labelled curves and tightens the
    layout while leaving room for the figure title.

    The variable 'fig' refers to the created figure of the plot.

    The variables 'ax1' and 'ax2' refer to the axes of the plot.
    """

    for axis in (ax1, ax2):
        handles, labels = axis.get_legend_handles_labels()

        # A legend is only drawn when there is something to show.
        if len(labels) > 0:
            axis.legend(handles, labels)

    fig.tight_layout()

    # Leaves some space at the top for the figure title.
    plt.subplots_adjust(top=0.9)

# -----------------------------------------------------------------------------

def plot_history(history, title):
    """
    Plots the history from the training of a model: the training and
    validation loss on the first axis and the training and validation
    accuracy on the second one.

    The variable 'history' refers to the history dictionary that was saved
    after the training of the model. The accuracy may be stored either under
    'acc'/'val_acc' or under 'accuracy'/'val_accuracy'.

    The string 'title' represents the title of the plot. The model types
    "unet_simple" and "unet_weighted" are replaced by a readable name.

    Returns the created figure.
    """

    # Replaces known model types by a readable title.
    readable_titles = {"unet_simple": "Simple U-Net",
                       "unet_weighted": "Weighted U-Net"}
    title = readable_titles.get(title, title)

    fig, ax1, ax2 = prepare_standardplot(title, 'Epoch')

    ax1.plot(history['loss'], label = "Training")
    ax1.plot(history['val_loss'], label = "Validation")

    # The key of the accuracy depends on the metric name used when the
    # model was compiled ('acc' or 'accuracy').
    acc_key = 'acc' if 'acc' in history else 'accuracy'

    ax2.plot(history[acc_key], label = "Training")
    ax2.plot(history['val_' + acc_key], label = "Validation")

    finalize_standardplot(fig, ax1, ax2)

    return fig
def natural_keys(text):
    """
    Sort key that orders file names in a more "human" order, i.e. "2.png"
    comes before "10.png".

    The string 'text' is a single file name (or path). It is split into
    digit and non-digit parts and the digit parts are converted to integers.

    Returns the list of parts, to be used as 'key' when sorting.
    """

    def atoi(part):
        # isdecimal() only accepts characters that int() can convert
        # (isdigit() also accepts e.g. superscript digits).
        return int(part) if part.isdecimal() else part

    # Raw string: '\d' is not a valid escape sequence in a normal string.
    return [atoi(c) for c in re.split(r'(\d+)', text)]

# -----------------------------------------------------------------------------

def load_data(path_images, path_labels):
    """
    Loads and returns images and labels as two lists of PIL images, both
    sorted in natural order.

    The variables 'path_images' and 'path_labels' are glob patterns matching
    the image files and the label files, respectively.
    """

    def load_all(pattern):
        # Creates a list of file names matching the pattern.
        filelist = glob.glob(pattern)
        filelist.sort(key=natural_keys)

        loaded = []
        for fname in filelist:
            picture = Image.open(fname)
            # Image.open is lazy and keeps the file open; loading the pixels
            # right away lets Pillow release the file handle.
            picture.load()
            loaded.append(picture)

        return loaded

    data = load_all(path_images)
    labels = load_all(path_labels)

    return data, labels
def check_binary(labels):
    """
    Checks if the given labels are binary or not.

    The variable "labels" corresponds to a list of label images.

    Returns True if every label has exactly two distinct values and False if
    at least one label has more. Raises a RuntimeError if a label is a
    constant image (fewer than two distinct values).
    """

    binary = True

    # Every label is checked, even after a non-binary one has been found,
    # so that constant labels are always reported.
    for label in labels:

        # Number of unique values (should be = 2 for binary labels or > 2
        # for categorical or non-binary data).
        n_unique = len(np.unique(np.array(label)))

        if n_unique < 2:
            raise RuntimeError("Labels are neither binary or categorical.")

        if n_unique > 2:
            binary = False

    return binary

# -----------------------------------------------------------------------------

def make_binary(labels):
    """
    Makes the given labels binary: every strictly positive pixel becomes 255
    and every other pixel becomes 0.

    The variable "labels" corresponds to a list of label images. The list is
    modified in place and also returned; its elements are replaced by new
    8-bit grayscale PIL images.
    """

    for i, label in enumerate(labels):
        # The mask is computed first so that the result is 0/255 for every
        # input dtype (a boolean array cannot hold 255 itself).
        foreground = np.array(label) > 0
        binarized = np.where(foreground, 255, 0).astype('uint8')
        labels[i] = Image.fromarray(binarized, 'L')

    return labels

# -----------------------------------------------------------------------------

def save_data(data, labels, path):
    """
    Saves images and labels as PNG files named by their zero-padded index.

    The variables 'data' and 'labels' refer to the processed images and
    labels (objects with a 'save' method, e.g. PIL images).

    The string 'path' corresponds to the folder (ending with a separator)
    that contains the sub-folders 'image' and 'label' in which the files are
    written.
    """

    # The number of digits of the number of images gives the number of
    # leading zeros in the file names.
    n_digits = len(str(len(data)))

    direc_d = path + "image/{b:0" + str(n_digits) + "d}.png"
    direc_l = path + "label/{b:0" + str(n_digits) + "d}.png"

    # Saves data and labels in the right folder.
    for i, image in enumerate(data):
        image.save(direc_d.format(b=i))
        labels[i].save(direc_l.format(b=i))

    return None
def split_data(X, y, ratio=0.8, seed=1):
    """
    Shuffles the data randomly and splits it into a training set and a
    test set.

    The input 'X' is a list of images.

    The input 'y' is a list of images with each image corresponding to the
    label of the corresponding sample in X.

    The 'ratio' variable is a float that sets the fraction of the dataset
    used for the training set; the rest is kept for the test set.
    Default value set to 0.8.

    The 'seed' variable represents the seed value for the randomization of
    the process. Default value set to 1.

    Returns X_train, y_train, X_test, y_test as lists.
    """

    # A dedicated generator gives the same permutation as seeding the global
    # NumPy generator, without changing the global random state.
    rng = np.random.RandomState(seed)
    idx_shuffled = rng.permutation(len(y))

    # Shuffled X and y (same order for both).
    X_shuff = [X[i] for i in idx_shuffled]
    y_shuff = [y[i] for i in idx_shuffled]

    # Cut the data set into train and test.
    train_num = round(len(y) * ratio)

    return (X_shuff[:train_num], y_shuff[:train_num],
            X_shuff[train_num:], y_shuff[train_num:])
def convertLabel(lab, threshold = 0.5):
    """
    Converts the given label probability map to a binary image using a
    specific threshold.

    The numpy array 'lab' corresponds to a label probability map whose last
    axis is the channel axis; only the first channel is used.

    The float 'threshold' corresponds to the threshold at which the
    probability map is binarized. Default value set to 0.5.

    Returns an integer array with value 255 where the probability is
    strictly larger than the threshold and 0 elsewhere.
    """

    # Boolean mask of the first channel, converted into 0 and 1.
    label = (lab[...,0] > threshold).astype(int)

    # Converts the labels to have values 0 and 255.
    label[label == 1] = 255

    return label

# -----------------------------------------------------------------------------

def pred_accuracy(y_true, y_pred):
    """
    Computes the prediction accuracy, i.e. the fraction of pixels that are
    identical in both arrays.

    The numpy array 'y_true' corresponds to the true label.

    The numpy array 'y_pred' corresponds to the predicted label.
    """

    # Compares both the predictions and labels.
    compare = (np.asarray(y_true) == np.asarray(y_pred))

    # The mean of the boolean comparison is the fraction of correct pixels,
    # whatever the shape of the arrays (not only square images).
    return np.mean(compare)

# -----------------------------------------------------------------------------

def saveResults(save_path, results, convert = True, threshold = 0.5):
    """
    Saves the predicted arrays as float32 .tif files named by their
    zero-padded index.

    The string 'save_path' corresponds to the folder (ending with a
    separator) in which the predictions are saved: binary predictions go to
    the sub-folder 'result', probability maps to 'result_prob'.

    The numpy array 'results' corresponds to the probability maps that were
    predicted with the model.

    The boolean 'convert' refers to whether or not the probability maps
    should be converted to binary arrays. Default value set to True.

    The float 'threshold' corresponds to the threshold at which the
    probability map is binarized. Default value set to 0.5.
    """

    # The number of digits of the number of predictions gives the number of
    # leading zeros in the file names.
    n_digits = len(str(len(results)))

    folder = "result/" if convert else "result_prob/"
    direc_r = save_path + folder + "{b:0" + str(n_digits) + "d}.tif"

    for i, lab in enumerate(results):

        if convert:
            # Converts the given label with a threshold.
            label = convertLabel(lab, threshold)
        else:
            label = lab[...,0]

        # Saves the label.
        imsave(direc_r.format(b=i), label.astype('float32'))

    return None
def make_weight_map(label, binary = True, w0 = 10, sigma = 5):
    """
    Generates a weight map in order to make the U-Net learn better the
    borders of cells and distinguish individual cells that are tightly
    packed. These weight maps follow the methodology of the original U-Net
    paper: a class balancing map plus a border term
    w0 * exp(-(d1 + d2)^2 / (2 * sigma^2)) on background pixels, where d1
    and d2 are the distances to the closest and second closest objects.

    The variable 'label' corresponds to a 2D label image. It is not modified.

    The boolean 'binary' corresponds to whether or not the labels are
    binary (0 and 255). If True, objects are found as 8-connected
    components; if False, every distinct non-zero value of the label is
    taken as one object. Default value set to True.

    The float 'w0' controls the importance of separating tightly associated
    entities. Default value set to 10.

    The float 'sigma' represents the standard deviation of the Gaussian used
    for the weight map. Default value set to 5.

    Returns a float array with the same shape as the label.
    """

    # Works on a copy of the label.
    lab = np.array(label)

    # Get shape of label.
    rows, cols = lab.shape

    if binary:

        # Converts the label into a binary image with background = 0
        # and cells = 1.
        lab[lab == 255] = 1

        # Converts the labels to have one class per object (cell).
        # connectivity = 2 is the 8-neighbourhood of a 2D image.
        lab_multi = measure.label(lab, connectivity = 2, background = 0)

    else:

        # Keeps the instance labels before binarizing; without the copy both
        # names would refer to the same (binarized) array.
        lab_multi = lab.copy()

        # Converts the label into a binary image with background = 0
        # and cells = 1.
        lab[lab > 0] = 1

    # Builds w_c which is the class balancing map: cells have weight 1 and
    # background has weight 0.5.
    w_c = np.array(lab, dtype=float)
    w_c[w_c == 0] = 0.5

    # Identifiers of the objects (background excluded).
    components = np.unique(lab_multi)
    components = components[components != 0]

    map_weight = np.zeros((rows, cols))

    if len(components) >= 2:

        # Distance to the closest object (d1) and to the second closest
        # object (d2), updated one object at a time so that only a few
        # arrays are kept in memory.
        d1 = np.full((rows, cols), np.inf)
        d2 = np.full((rows, cols), np.inf)

        for component in components:

            # Distance of every pixel to the current object (the transform
            # measures the distance to the nearest zero, hence the inverted
            # mask).
            dist = scipy.ndimage.distance_transform_edt(lab_multi != component)

            d2 = np.minimum(d2, np.maximum(d1, dist))
            d1 = np.minimum(d1, dist)

        # The border term only applies to background pixels.
        map_weight = w0*np.exp(-((d1+d2)**2)/(2*(sigma**2))) * (lab == 0).astype(int)

    map_weight += w_c

    return map_weight
422 | # transform 423 | tmp = ~tmp 424 | 425 | # For each pixel, computes the distance transform to 426 | # each object. 427 | maps[i][:][:] = scipy.ndimage.distance_transform_edt(tmp) 428 | 429 | maps = np.sort(maps, axis=0) 430 | 431 | # Get distance to the closest object (d1) and the distance to the second 432 | # object (d2). 433 | d1 = maps[0][:][:] 434 | d2 = maps[1][:][:] 435 | 436 | map_weight = w0*np.exp(-((d1+d2)**2)/(2*(sigma**2)) ) * (lab==0).astype(int); 437 | 438 | map_weight += w_c 439 | 440 | return map_weight 441 | 442 | # ----------------------------------------------------------------------------- 443 | 444 | def do_save_wm(labels, path, binary = True, w0 = 10, sigma = 5): 445 | """ 446 | Retrieves the label images, applies the weight-map algorithm and save the 447 | weight maps in a folder. 448 | 449 | The variable 'labels' corresponds to given label images. 450 | 451 | The string 'path' refers to the path where the weight maps should be saved. 452 | 453 | The boolean 'binary' corresponds to whether or not the labels are 454 | binary. Default value set to True. 455 | 456 | The float 'w0' controls for the importance of separating tightly associated 457 | entities. Default value set to 10. 458 | 459 | The float 'sigma' represents the standard deviation of the Gaussian used 460 | for the weight map. Default value set to 5. 461 | """ 462 | 463 | # Copy labels. 464 | labels_ = copy.deepcopy(labels) 465 | 466 | # Perform weight maps. 467 | for i in range(len(labels_)): 468 | labels_[i] = make_weight_map(labels[i].copy(), binary, w0, sigma) 469 | 470 | maps = np.array(labels_) 471 | 472 | n, rows, cols = maps.shape 473 | 474 | # Resize correctly the maps so that it can be used in the model. 475 | maps = maps.reshape((n, rows, cols, 1)) 476 | 477 | # Count number of digits in n. This is important for the number 478 | # of leading zeros in the name of the maps. 479 | n_digits = len(str(n)) 480 | 481 | # Save path with correct leading zeros. 
482 | path_to_save = path + "weight/{b:0" + str(n_digits) + "d}.npy" 483 | 484 | # Saving files as .npy files. 485 | for i in range(len(labels_)): 486 | np.save(path_to_save.format(b=i), labels_[i]) 487 | 488 | return None 489 | 490 | # ----------------------------------------------------------------------------- 491 | 492 | #def make_distance_map(label): 493 | # """Generates a distance map from labels in order to test distance-map-based 494 | # U-Net training.""" 495 | # 496 | # lab = np.array(label) 497 | # 498 | # # Converts the label into a binary image with background = 0 499 | # # and cells = 1. 500 | # lab[lab == 255] = 1 501 | # 502 | # # Applies distance transform 503 | # output = cv2.distanceTransform(lab, cv2.DIST_C, 3) 504 | # 505 | # # Finds minimal cell size 506 | # size = 0 507 | # all_dist = np.unique(output) 508 | # blobbed_lab = measure.label(lab, neighbors = 8, background = 0) 509 | # number_blobs = np.max(blobbed_lab) 510 | # for i in all_dist[1:]: 511 | # tmp = (output >= i).astype(int) 512 | # blobbed_lab = measure.label(tmp, neighbors = 8, background = 0) 513 | # if number_blobs <= np.max(blobbed_lab): 514 | # size = i 515 | # 516 | # return output, size 517 | # 518 | ## ----------------------------------------------------------------------------- 519 | # 520 | #def do_make_dm(path): 521 | # """Retrieves the label images, applies the distance transform and save the 522 | # maps in the right folder.""" 523 | # 524 | # path_to_labels = path + "/label/*.png" 525 | # 526 | # # Creates a list of file names in the labels directory 527 | # filelist = glob.glob(path_to_labels) 528 | # filelist.sort(key=natural_keys) 529 | # 530 | # # Loads all data images in a list 531 | # labels = [Image.open(fname).resize((256,256)) for fname in filelist] 532 | # 533 | # # Copy labels 534 | # labels_ = labels 535 | # 536 | # # Vector of sizes 537 | # sizes = [] 538 | # 539 | # # Do maps 540 | # print("Doing distance maps") 541 | # for i in range(len(labels_)): 542 
| # labels_[i], size = make_distance_map(labels_[i]) 543 | # sizes.append(size) 544 | # print("Maps done") 545 | # 546 | # min_size = np.min(np.array(sizes)) 547 | # print("Min size : {b}".format(b=min_size)) 548 | # print(sizes) 549 | # 550 | # maps = np.array(labels_) 551 | # 552 | # maps[maps >= min_size] = min_size 553 | # 554 | # n, rows, cols = maps.shape 555 | # 556 | # # Makes sure the data is saved with one leading zero. 557 | # if (n < 100): 558 | # 559 | # # Selects path for data and labels 560 | # direc_r = path + "/distance/{b:02d}.png" 561 | # 562 | # # If we have more than 100 images, we would have 2 leading zeros. 563 | # # We have 148 images, so there is no point doing other cases. 564 | # else: 565 | # 566 | # # Selects path for data and labels 567 | # direc_r = path + "/distance/{b:03d}.png" 568 | # 569 | # for i, lab in enumerate(maps): 570 | # 571 | # label = lab.astype('uint8') 572 | # label = Image.fromarray(label, 'L') 573 | # 574 | # # Saves the label 575 | # label.save(direc_r.format(b=i)) 576 | # 577 | # return None 578 | # 579 | ## ----------------------------------------------------------------------------- 580 | # 581 | #def make_three_class (label): 582 | # 583 | # lab = np.array(label) 584 | # 585 | # # Get shape of label 586 | # rows, cols = lab.shape 587 | # 588 | # components = np.unique(lab) 589 | # 590 | # n_comp = len(components)-1 591 | # 592 | # output = np.zeros((rows, cols)) 593 | # 594 | # for i in range(n_comp): 595 | # 596 | # # Only keeps current object 597 | # tmp = (lab == components[i+1]).astype('float32') 598 | # 599 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)) 600 | # 601 | # eroded_tmp = cv2.erode(tmp, kernel, iterations = 1) 602 | # 603 | # border = tmp - eroded_tmp 604 | # 605 | # output[border > 0] = 1 606 | # output[eroded_tmp > 0] = 2 607 | # 608 | # output = output.astype('uint8') 609 | # output = Image.fromarray(output, 'L') 610 | # 611 | # return output 
def jaccard_distance(y_true, y_pred, smooth=100):
    """
    Intersection-over-union loss (Jaccard distance).

    The tensors 'y_true' and 'y_pred' correspond to the labels and to the
    predictions. The sums are taken over the last axis.

    The float 'smooth' is added to numerator and denominator and also scales
    the returned loss. Default value set to 100.
    """

    # This module binds the Keras backend under the name 'keras', so 'K'
    # is bound explicitly here instead of relying on a star import.
    from keras import backend as K

    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
    sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return (1 - jac) * smooth

# -----------------------------------------------------------------------------

def unet(input_size = (256,256,1), loss_ = 'binary_crossentropy', learning_rate = 1e-4, weight_decay = 5e-7):
    """
    Simple U-net architecture.

    The tuple 'input_size' corresponds to the size of the input images and labels.
    Default value set to (256, 256, 1) (input images size is 256x256).

    The string 'loss_' represents the name of the loss that should be used.
    Default value set to 'binary_crossentropy'.

    The float 'learning_rate' corresponds to the learning rate value for the training.
    Default value set to 1e-4.

    The float 'weight_decay' is passed to the Adam optimizer as its 'decay'
    argument. Default value set to 5e-7.

    Returns the compiled Keras model.
    """

    # Arguments shared by every ReLU convolution of the network.
    conv_args = dict(activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')

    def double_conv(tensor, filters):
        # Two successive 3x3 convolutions with the same number of filters.
        tensor = Conv2D(filters, 3, **conv_args)(tensor)
        return Conv2D(filters, 3, **conv_args)(tensor)

    def up_and_merge(tensor, skip, filters):
        # Upsampling by 2, 2x2 convolution, then concatenation with the
        # skip connection on the channel axis.
        up = Conv2D(filters, 2, **conv_args)(UpSampling2D(size = (2,2))(tensor))
        return concatenate([skip, up], axis = 3)

    # Get input.
    input_img = Input(input_size)

    # Contracting path (layers 1 to 4).
    conv1 = double_conv(input_img, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck (layer 5).
    conv5 = double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path (layers 6 to 9).
    conv6 = double_conv(up_and_merge(drop5, drop4, 512), 512)
    conv7 = double_conv(up_and_merge(conv6, conv3, 256), 256)
    conv8 = double_conv(up_and_merge(conv7, conv2, 128), 128)
    conv9 = double_conv(up_and_merge(conv8, conv1, 64), 64)
    conv9 = Conv2D(2, 3, **conv_args)(conv9)

    # Final layer (output).
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    # Specify input and output.
    model = Model(inputs = input_img, outputs = conv10)

    # Use Adam optimizer, the requested loss and specify metrics.
    model.compile(optimizer = Adam(lr = learning_rate, decay = weight_decay), loss = loss_, metrics = ['accuracy'])

    return model
43 | 44 | label_type = input_array[3] 45 | 46 | size = input_array[4] 47 | target_size = () 48 | 49 | if size == "256x256": 50 | target_size = (256, 256) 51 | 52 | elif size == "512x512": 53 | target_size = (512, 512) 54 | 55 | elif size == "1024x1024": 56 | target_size = (1024, 1024) 57 | 58 | else: 59 | raise RuntimeError("Input size unknown") 60 | 61 | model = input_array[5] 62 | model_type = "" 63 | 64 | if model == "Simple U-Net": 65 | model_type = "unet_simple" 66 | 67 | elif model == "Weighted U-Net": 68 | model_type = "unet_weighted" 69 | 70 | split_ratio = float(input_array[6])/100 71 | 72 | batch_size = int(input_array[7]) 73 | 74 | learning_rate = float(input_array[8])*1e-5 75 | 76 | return label_type, target_size, model_type, split_ratio, batch_size, learning_rate 77 | 78 | # ----------------------------------------------------------------------------- 79 | 80 | def prep_data(name_project, model, label_type, split_ratio, w0 = None, sigma = None): 81 | """ 82 | Prepares the data by randomizing the images and binarizing them if needed. 83 | 84 | The string 'name_project' refers to the name of the DeepPix Worflow project 85 | given by the user. 86 | 87 | The string 'model' refers to the type of model that will be used. 88 | 89 | The string 'label_type' corresponds to the type of model used, either 90 | categorical or binary. 91 | 92 | The float 'split_ratio' corresponds to the splitting ratio for the 93 | training and testing set. 94 | 95 | The float 'w0' corresponds to a constant used for the weighted U-Net. 96 | Default value set to None. 97 | 98 | The float 'sigma' corresponds to a constant used for the weighted U-Net. 99 | Default value set to None. 100 | """ 101 | 102 | print("Initialization of preparation of data.") 103 | 104 | # Constructs useful paths. 105 | 106 | # Path for data. 107 | data_path = "/content/drive/My Drive/unser_project/data/" 108 | 109 | # Paths for raw data and labels. 
110 | path_data = data_path + "raw/" + name_project + "/image/*.tif" 111 | path_labels = data_path + "raw/" + name_project + "/label/*.tif" 112 | 113 | # Paths for train and test directories. 114 | train_path = data_path + "processed/" + name_project + "/train/" 115 | test_path = data_path + "processed/" + name_project + "/test/" 116 | 117 | # Load data and labels. 118 | print("Loading data and labels.") 119 | data, labels = load_data(path_data, path_labels) 120 | print("Loading successful.") 121 | 122 | print("Label type check and binarization if needed.") 123 | # Checks if labels are binary or categorical. 124 | binary = check_binary(labels) 125 | 126 | # Check which model is desired and binarizes labels or not depending on the model. 127 | if model == "unet_simple": 128 | 129 | if not binary: 130 | labels = make_binary(labels) 131 | 132 | elif model == "unet_weighted": 133 | 134 | if label_type == "categorical": 135 | 136 | if binary: 137 | raise RuntimeError("Labels are said to be categorical but they are not categorical.") 138 | 139 | elif label_type == "binary": 140 | 141 | if not binary: 142 | labels = make_binary(labels) 143 | 144 | else: 145 | raise RuntimeError("Labels are neither categorical or binary.") 146 | 147 | else: 148 | raise RuntimeError("Model type not recognised.") 149 | 150 | print("Splitting data") 151 | X_train, y_train, X_test, y_test = split_data(data, labels, ratio = split_ratio) 152 | 153 | if not os.path.exists(train_path + 'image'): 154 | os.makedirs(train_path + 'image') 155 | 156 | if not os.path.exists(train_path + 'label'): 157 | os.makedirs(train_path + 'label') 158 | 159 | if not os.path.exists(test_path + 'image'): 160 | os.makedirs(test_path + 'image') 161 | 162 | if not os.path.exists(test_path + 'label'): 163 | os.makedirs(test_path + 'label') 164 | 165 | if model == "unet_weighted": 166 | 167 | if not os.path.exists(train_path + 'weight'): 168 | os.makedirs(train_path + 'weight') 169 | 170 | if not os.path.exists(test_path 
def unet_distance(input_size = (256,256,1), learning_rate = 1e-4, weight_decay = 5e-7):
    """
    Simple U-net architecture with distance maps (three-class output).

    The tuple 'input_size' corresponds to the size of the input images.
    Default value set to (256, 256, 1).

    The float 'learning_rate' corresponds to the learning rate value for the training.
    Default value set to 1e-4.

    The float 'weight_decay' is passed to the Adam optimizer as its 'decay'
    argument. Default value set to 5e-7.

    Returns the compiled Keras model. Its output has one row per pixel and
    one column per class (shape (rows * cols, 3) for each image).
    """

    # Arguments shared by every ReLU convolution of the network.
    conv_args = dict(activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')

    def double_conv(tensor, filters):
        # Two successive 3x3 convolutions with the same number of filters.
        tensor = Conv2D(filters, 3, **conv_args)(tensor)
        return Conv2D(filters, 3, **conv_args)(tensor)

    def up_and_merge(tensor, skip, filters):
        # Upsampling by 2, 2x2 convolution, then concatenation with the
        # skip connection on the channel axis.
        up = Conv2D(filters, 2, **conv_args)(UpSampling2D(size = (2,2))(tensor))
        return concatenate([skip, up], axis = 3)

    # Get input
    input_img = Input(input_size)

    # Contracting path (layers 1 to 4)
    conv1 = double_conv(input_img, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck (layer 5)
    conv5 = double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path (layers 6 to 9)
    conv6 = double_conv(up_and_merge(drop5, drop4, 512), 512)
    conv7 = double_conv(up_and_merge(conv6, conv3, 256), 256)
    conv8 = double_conv(up_and_merge(conv7, conv2, 128), 128)
    conv9 = double_conv(up_and_merge(conv8, conv1, 64), 64)
    conv9 = Conv2D(3, 3, **conv_args)(conv9)

    # conv9 is channels-last (rows, cols, 3), as the concatenations on axis 3
    # show. Flattening only the spatial axes keeps the 3 class scores of each
    # pixel together; -1 lets Keras infer rows * cols from 'input_size'.
    reshape = Reshape((-1, 3))(conv9)

    # Softmax over the class axis (last axis), one distribution per pixel.
    activation = Activation('softmax')(reshape)

    # Specify input and output
    model = Model(inputs = input_img, outputs = activation)

    # Use Adam optimizer, categorical cross-entropy loss and specify metrics
    model.compile(optimizer = Adam(lr = learning_rate, decay = weight_decay), loss = 'categorical_crossentropy', metrics = ['accuracy'])

    return model
def binary_crossentropy_weighted(weights):
    """
    Builds a weighted binary cross-entropy loss. The usual cross-entropy
    is multiplied element-wise by the weights before being averaged over
    the last axis, in order to give more weight to areas between cells
    close to one another.

    The variable 'weights' refers to input weight-maps.

    Returns the loss function taking (y_true, y_pred).
    """

    def loss(y_true, y_pred):

        weighted_entropy = weights * K.binary_crossentropy(y_true, y_pred)

        return K.mean(weighted_entropy, axis=-1)

    return loss

# -----------------------------------------------------------------------------

def unet_weights(input_size = (256,256,1), learning_rate = 1e-4, weight_decay = 5e-7):
    """
    Weighted U-net architecture. The model takes two inputs, the image and
    its weight map, and the weight map is only used inside the loss.

    The tuple 'input_size' corresponds to the size of the input images and labels.
    Default value set to (256, 256, 1) (input images size is 256x256).

    The float 'learning_rate' corresponds to the learning rate value for the training.
    Default value set to 1e-4.

    The float 'weight_decay' is passed to the Adam optimizer as its 'decay'
    argument. Default value set to 5e-7.
    """

    # Arguments shared by every ReLU convolution of the network.
    conv_args = dict(activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')

    def double_conv(tensor, filters):
        # Two successive 3x3 convolutions with the same number of filters.
        tensor = Conv2D(filters, 3, **conv_args)(tensor)
        return Conv2D(filters, 3, **conv_args)(tensor)

    def up_and_merge(tensor, skip, filters):
        # Upsampling by 2, 2x2 convolution, then concatenation with the
        # skip connection on the channel axis.
        up = Conv2D(filters, 2, **conv_args)(UpSampling2D(size = (2,2))(tensor))
        return concatenate([skip, up], axis = 3)

    # Get input.
    input_img = Input(input_size)

    # Get weights.
    weights = Input(input_size)

    # Contracting path (layers 1 to 4).
    conv1 = double_conv(input_img, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck (layer 5).
    conv5 = double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Expanding path (layers 6 to 9).
    conv6 = double_conv(up_and_merge(drop5, drop4, 512), 512)
    conv7 = double_conv(up_and_merge(conv6, conv3, 256), 256)
    conv8 = double_conv(up_and_merge(conv7, conv2, 128), 128)
    conv9 = double_conv(up_and_merge(conv8, conv1, 64), 64)
    conv9 = Conv2D(2, 3, **conv_args)(conv9)

    # Final layer (output).
    conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)

    # Specify input (image + weights) and output.
    model = Model(inputs = [input_img, weights], outputs = conv10)

    # Use Adam optimizer, custom weighted binary cross-entropy loss and specify metrics.
    # The weights input tensor is captured by the loss function.
    weighted_loss = binary_crossentropy_weighted(weights)
    optimizer = Adam(lr = learning_rate, decay = weight_decay)
    model.compile(optimizer = optimizer, loss = weighted_loss, metrics = ['accuracy'])

    return model
Download first the data.\n","data, labels = load_data(\"/content/drive/My Drive/.../unet_segmentation/data/raw/hela/image/*.tif\",\n"," \"/content/drive/My Drive/.../unet_segmentation/data/raw/hela/label/*.tif\") # Set correct paths\n","for i in range(len(labels)):\n"," tmp = np.array(labels[i])\n"," tmp[tmp > 0] = 255\n"," tmp[tmp == 0] = 0\n"," tmp = tmp.astype('uint8')\n"," tmp = Image.fromarray(tmp, 'L')\n"," labels[i] = tmp\n"," \n","# Split the data into train and test\n","X_train, y_train, X_test, y_test = split_data(data, labels, ratio = 0.5)\n","\n","# Set the paths and create the folders to save preprocessed data as .png\n","TRAIN_DIR=\"/content/drive/My Drive/unser_project/data/processed/hela/train/\"\n","TEST_DIR=\"/content/drive/My Drive/unser_project/data/processed/hela/test/\"\n","\n","if not os.path.exists(TRAIN_DIR):\n"," os.makedirs(TRAIN_DIR+\"/image/\")\n"," os.makedirs(TRAIN_DIR+\"/label/\")\n"," \n","if not os.path.exists(TEST_DIR):\n"," os.makedirs(TEST_DIR+\"/image/\")\n"," os.makedirs(TEST_DIR+\"/label/\")\n"," \n","# Save train and test files\n","save_data(X_train, y_train, TRAIN_DIR)\n","save_data(X_test, y_test, TEST_DIR)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"colab_type":"text","id":"1_qBzyHg1ohr"},"source":["# **Imports modules**"]},{"cell_type":"code","metadata":{"colab_type":"code","id":"sk4cEzphpv4_","colab":{}},"source":["import sys\n","sys.path.append(\"/content/drive/My Drive/.../unet_segmentation/py_files\") # path to py_files folder\n","!pip install tifffile\n","!pip install --upgrade tensorflow\n","!pip install --upgrade keras\n","from model import *\n","from convert_to_pb import *\n","from data_loading import *\n","from helpers import *\n","from unet_weights import *\n","from fit_model import *\n","%matplotlib inline\n","import matplotlib.pyplot as plt\n","import matplotlib\n","from PIL import Image, ImageOps, ImageFilter\n","import pickle\n","from test import *\n","import cv2 as 
cv\n","\n","# Autoreload\n","%load_ext autoreload\n","%autoreload 2\n","%reload_ext autoreload\n","\n","# Set random seed\n","np.random.seed(1)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ifxpZTDtmyPn","colab_type":"text"},"source":["Train U-Net\n","==="]},{"cell_type":"markdown","metadata":{"id":"fqOD0FTkJMCT","colab_type":"text"},"source":["## Hela cells"]},{"cell_type":"code","metadata":{"id":"Kb0aNeIKy6cj","colab_type":"code","colab":{}},"source":["# Load training and validation data\n","# Note that the subset of generator used for the training generator is \"validation\" because we don't want to augment our data\n","# Specify paths where inside there are \"image\" and \"label\" folder\n","trainGen = dataGenerator(batch_size = 2, subset = \"train\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/train')\n","validGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"6A5LXMcPQUV9","colab_type":"code","colab":{}},"source":["model = unet()\n","\n","# Callbacks\n","model_checkpoint = ModelCheckpoint('/content/drive/My Drive/.../unet_segmentation/models/{b}.hdf5'.format(b=\"unet_hela\"), monitor='val_loss', verbose=1, save_best_only=True)\n","\n","# Fit\n","history = model.fit_generator(trainGen,\n"," steps_per_epoch=500,\n"," epochs=1,\n"," callbacks=[model_checkpoint], \n"," validation_data = validGen, \n"," validation_steps = 9)"],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"bWigI7T0Mlyp","colab_type":"text"},"source":["Results\n","==="]},{"cell_type":"markdown","metadata":{"id":"59BUzz8bc_GN","colab_type":"text"},"source":["## Hela cells"]},{"cell_type":"code","metadata":{"id":"xf_rj9545A2j","colab_type":"code","colab":{}},"source":["# Define paths.\n"," path_to_model = '/content/drive/My 
Drive/.../unet_segmentation/models/unet_hela.hdf5'\n"," \n"," # Load model.\n"," model = load_model(path_to_model)\n"," # Load training and validation data\n"," # Note that the subset of generator used for the training generator is \"validation\" because we don't want to augment our data\n"," print(\"Validation\")\n"," validGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')\n"," \n"," accuracies = model.evaluate_generator(validGen, steps=9, verbose=1) \n"," print(accuracies)\n"," \n"," print(\"Training\")\n"," trainGen = dataGenerator(batch_size = 1, subset = \"validation\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/train')\n"," \n"," accuracies = model.evaluate_generator(validGen, steps=8, verbose=1) \n"," print(accuracies)\n"," "],"execution_count":0,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"UoRhKGShIvke","colab_type":"text"},"source":["Prediction\n","==="]},{"cell_type":"code","metadata":{"id":"7zq_yID3jP8G","colab_type":"code","colab":{}},"source":["import tensorflow as tf"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"vccfXzFJIu6z","colab_type":"code","colab":{}},"source":["from tensorflow.contrib.saved_model import save_keras_model\n","import tensorflow.keras\n","from keras.models import load_model\n","testGen = dataGenerator(batch_size = 1, subset = \"test\", path = '/content/drive/My Drive/.../unet_segmentation/data/processed/hela/test')\n","model = unet()\n","model = load_model('/content/drive/My Drive/.../unet_segmentation/models/unet_hela.hdf5')\n","results = model.predict_generator(testGen,9,verbose=1, workers=1)\n","#saveResults('/content/drive/My Drive/.../unet_segmentation/data/hela/test/', results, convert = True)\n","#saveResults('/content/drive/My Drive/.../unet_segmentation/data/hela/test/', results, convert = 
False)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"wIRWpNL0XC_Q","colab_type":"code","colab":{}},"source":["from sklearn.metrics import jaccard_similarity_score\n"," \n","acc_tot = [];\n","\n","for i in range(9):\n"," label = cv.imread('/content/drive/My Drive/.../unet_segmentation/data/hela/test/label/0{b}.png'.format(b=i))\n"," label = cv.resize(label, (256,256))\n"," acc = jaccard_similarity_score(label[...,0].flatten(), convertLabel(results[i]).flatten())\n"," acc_tot.append(acc)\n"," \n","print(\"Jaccard average : {b}\".format(b=np.mean(acc_tot)))"],"execution_count":0,"outputs":[]}]} -------------------------------------------------------------------------------- /xml/config_template.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /xml/create_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeepImageJ 3 | 4 | https://deepimagej.github.io/deepimagej/ 5 | 6 | Conditions of use: 7 | 8 | DeepImageJ is an open source software (OSS): you can redistribute it and/or modify it under 9 | the terms of the BSD 2-Clause License. 10 | 11 | In addition, we strongly encourage you to include adequate citations and acknowledgments 12 | whenever you present or publish results that are based on it. 13 | 14 | DeepImageJ is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; 15 | without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 16 | 17 | You should have received a copy of the BSD 2-Clause License along with DeepImageJ. 18 | If not, see . 
19 | 20 | 21 | Reference: 22 | 23 | DeepImageJ: A user-friendly plugin to run deep learning models in ImageJ 24 | E. Gomez-de-Mariscal, C. Garcia-Lopez-de-Haro, L. Donati, M. Unser, A. Munoz-Barrutia, D. Sage. 25 | Submitted 2019. 26 | 27 | Bioengineering and Aerospace Engineering Department, Universidad Carlos III de Madrid, Spain 28 | Biomedical Imaging Group, Ecole polytechnique federale de Lausanne (EPFL), Switzerland 29 | 30 | Corresponding authors: mamunozb@ing.uc3m.es, daniel.sage@epfl.ch 31 | 32 | Copyright 2019. Universidad Carlos III, Madrid, Spain and EPFL, Lausanne, Switzerland. 33 | 34 | """ 35 | 36 | import os 37 | import xml.etree.ElementTree as ET 38 | import time 39 | import numpy as np 40 | import urllib 41 | import shutil 42 | from skimage import io 43 | 44 | """ 45 | Download the template from this link: 46 | https://raw.githubusercontent.com/esgomezm/python4deepimagej/yaml/yaml/config_template.xml 47 | TensorFlow library is needed. It is imported later to save the model as a SavedModel protobuffer 48 | 49 | Try to check TensorFlow version and read DeepImageJ's compatibility requirements. 50 | 51 | import tensorflow as tf 52 | tf.__version__ 53 | ---------------------------------------------------- 54 | Example: 55 | ---------------------------------------------------- 56 | dij_config = DeepImageJConfig(model) 57 | # Update model information 58 | dij_config.Authors = authors 59 | dij_config.Credits = credits 60 | 61 | # Add info about the minimum size in case it is not fixed. 
62 | pooling_steps = 0 63 | for keras_layer in model.layers: 64 | if keras_layer.name.startswith('max') or "pool" in keras_layer.name: 65 | pooling_steps += 1 66 | dij_config.MinimumSize = str(2**(pooling_steps)) 67 | 68 | # Add the information about the test image 69 | dij_config.add_test_info(test_img, test_prediction, PixelSize) 70 | 71 | ## Prepare preprocessing file 72 | path_preprocessing = "PercentileNormalization.ijm" 73 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/PercentileNormalization.ijm", path_preprocessing ) 74 | # Include the info about the preprocessing 75 | dij_config.add_preprocessing(path_preprocessing, "preprocessing") 76 | 77 | ## Prepare postprocessing file 78 | path_postprocessing = "8bitBinarize.ijm" 79 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm", path_postprocessing ) 80 | # Include the info about the postprocessing 81 | post_processing_name = "postprocessing_LocalMaximaSMLM" 82 | dij_config.add_postprocessing(path_postprocessing, post_processing_name) 83 | 84 | ## EXPORT THE MODEL 85 | deepimagej_model_path = os.path.join(QC_model_folder, 'deepimagej') 86 | dij_config.export_model(model, deepimagej_model_path) 87 | ---------------------------------------------------- 88 | Example: change one line in an ImageJ macro 89 | ---------------------------------------------------- 90 | ## Prepare postprocessing file 91 | path_postprocessing = "8bitBinarize.ijm" 92 | urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/imagej-macros/master/8bitBinarize.ijm", path_postprocessing ) 93 | # Modify the threshold in the macro to the chosen threshold 94 | ijmacro = open(path_postprocessing,"r") 95 | list_of_lines = ijmacro. 
class DeepImageJConfig:
    """Metadata container and exporter for a DeepImageJ bundled model.

    The constructor inspects ``tf_model`` (its ``input_shape``,
    ``output_shape`` and ``predict``) to fill in the tensor organisation,
    patch size and padding. The remaining attributes (``Name``,
    ``Authors``, ``URL``, ``Credits``, ``Version``, ``References``,
    ``MinimumSize``) default to placeholders and are meant to be
    overwritten by the caller before calling :meth:`export_model`.

    Example: change one line in an ImageJ macro before registering it::

        path_postprocessing = "8bitBinarize.ijm"
        urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/deepimagej/imagej-macros/"
            "master/8bitBinarize.ijm", path_postprocessing)
        # Line 21 is the one corresponding to the optimal threshold
        with open(path_postprocessing, "r") as ijmacro:
            list_of_lines = ijmacro.readlines()
        list_of_lines[21] = "optimalThreshold = {}\n".format(128)
        with open(path_postprocessing, "w") as ijmacro:
            ijmacro.writelines(list_of_lines)
        dij_config.add_postprocessing(path_postprocessing, "postprocessing")
    """

    def __init__(self, tf_model):
        """Initialise the configuration from a tensorflow/keras model.

        tf_model: object exposing ``input_shape``, ``output_shape`` and
            ``predict``; ``predict`` is called once to estimate the padding.
        """
        # ModelInformation
        self.Name = 'null'
        self.Authors = 'null'
        self.URL = 'null'
        self.Credits = 'null'
        self.Version = 'null'
        self.References = 'null'
        self.Date = time.ctime()
        # Same value as 2**pooling_steps
        # (related to encoder-decoder architectures) when the input size is
        # not fixed
        self.MinimumSize = '8'
        # Must run after MinimumSize is set: it is used as the patch size
        # of models without a fixed input size.
        self.get_dimensions(tf_model)
        # Receptive field of the network to process input
        self.Padding = str(self._pixel_half_receptive_field(tf_model))
        self.Preprocessing = []
        self.Postprocessing = []
        self.Preprocessing_files = []
        self.Postprocessing_files = []

    def get_dimensions(self, tf_model):
        """
        Calculates the array organization and shapes of inputs and outputs.

        Sets ``FixedPatch``, ``PatchSize``, ``InputOrganization0``,
        ``OutputOrganization0``, ``Channels`` and ``InputTensorDimensions``
        (all stored as strings).
        """
        input_dim = tf_model.input_shape
        output_dim = tf_model.output_shape
        # Deal with the order of the dimensions and whether the size is fixed
        # or not
        if input_dim[2] is None:
            self.FixedPatch = 'false'
            self.PatchSize = self.MinimumSize
            # With free spatial sizes, a known last axis is taken to be the
            # channel axis.
            if input_dim[-1] is None:
                self.InputOrganization0 = 'NCHW'
                self.Channels = str(input_dim[1])
            else:
                self.InputOrganization0 = 'NHWC'
                self.Channels = str(input_dim[-1])

            if output_dim[-1] is None:
                self.OutputOrganization0 = 'NCHW'
            else:
                self.OutputOrganization0 = 'NHWC'
        else:
            self.FixedPatch = 'true'
            self.PatchSize = str(input_dim[2])

            # With fixed sizes, the last axis is taken to be the channel
            # axis when it is smaller than the two axes before it.
            if input_dim[-1] < input_dim[-2] and input_dim[-1] < input_dim[-3]:
                self.InputOrganization0 = 'NHWC'
                self.Channels = str(input_dim[-1])
            else:
                self.InputOrganization0 = 'NCHW'
                self.Channels = str(input_dim[1])

            if output_dim[-1] < output_dim[-2] and output_dim[-1] < output_dim[-3]:
                self.OutputOrganization0 = 'NHWC'
            else:
                self.OutputOrganization0 = 'NCHW'

        # Serialise the shape: both parentheses become commas, None becomes
        # -1 and blanks are dropped, e.g. (None, 64, 64, 3) -> ',-1,64,64,3,'
        input_dim = str(input_dim)
        input_dim = input_dim.replace('(', ',')
        input_dim = input_dim.replace(')', ',')
        input_dim = input_dim.replace('None', '-1')
        input_dim = input_dim.replace(' ', "")
        self.InputTensorDimensions = input_dim

    def _pixel_half_receptive_field(self, tf_model):
        """
        The halo is equivalent to the receptive field of one pixel. This value
        is used for image reconstruction when a entire image is processed.

        It is estimated by predicting on an all-zero image and on an image
        with a single non-zero pixel at the centre, then measuring how far
        the two outputs differ from that centre.

        Returns the halo as an int.
        """
        input_shape = tf_model.input_shape

        if self.FixedPatch == 'false':
            min_size = 50 * int(self.MinimumSize)

            if self.InputOrganization0 == 'NHWC':
                null_im = np.zeros((1, min_size, min_size, input_shape[-1]),
                                   dtype=np.float32)
            else:
                null_im = np.zeros((1, input_shape[1], min_size, min_size),
                                   dtype=np.float32)
        else:
            null_im = np.zeros(input_shape[1:], dtype=np.float32)
            null_im = np.expand_dims(null_im, axis=0)
            min_size = int(self.PatchSize)

        point_im = np.zeros_like(null_im)
        center = int(min_size / 2)

        if self.InputOrganization0 == 'NHWC':
            point_im[0, center, center] = 1
        else:
            point_im[0, :, center, center] = 1

        result_unit = tf_model.predict(np.concatenate((null_im, point_im)))

        # Output pixels that changed because of the single input pixel.
        changed = np.abs(result_unit[0] - result_unit[1]) > 0

        if self.InputOrganization0 == 'NHWC':
            changed = changed[:, :, 0]
        else:
            changed = changed[0, :, :]

        # Only the quadrant above and to the left of the centre is inspected.
        columns = np.where(changed[:center, :center])[1]
        if columns.size == 0:
            # Nothing changed in that quadrant (e.g. a pixel-wise model);
            # np.min would raise on the empty array. 1 is what the formula
            # below yields when the response starts at the centre column.
            return 1
        return int(center - np.min(columns) + 1)

    class TestImage:
        """Example input/output image pair and its metadata."""

        def __add__(self, input_im, output_im, pixel_size):
            """
            pixel size must be given in microns
            """
            self.Input_shape = '{0}x{1}'.format(input_im.shape[0], input_im.shape[1])
            self.InputImage = input_im
            self.Output_shape = '{0}x{1}'.format(output_im.shape[0], output_im.shape[1])
            self.OutputImage = output_im
            self.MemoryPeak = 'null'
            self.Runtime = 'null'
            self.PixelSize = '{0}µmx{1}µm'.format(pixel_size, pixel_size)

    def add_test_info(self, input_im, output_im, pixel_size):
        """Register the example image pair; stored as ``self.test_info``."""
        self.test_info = self.TestImage()
        self.test_info.__add__(input_im, output_im, pixel_size)

    def add_preprocessing(self, file, name):
        """Register a preprocessing macro.

        file: path of the macro to copy into the bundled model.
        name: name it gets inside the bundle; the extension of ``file`` is
            appended and a ``preprocessing_`` prefix is added when the name
            does not already start with ``preprocessing``.
        """
        file_extension = file.split('.')[-1]
        name = name + '.' + file_extension
        if not name.startswith('preprocessing'):
            name = "preprocessing_" + name
        self.Preprocessing.append(name)
        self.Preprocessing_files.append(file)

    def add_postprocessing(self, file, name):
        """Register a postprocessing macro.

        file: path of the macro to copy into the bundled model.
        name: name it gets inside the bundle; the extension of ``file`` is
            appended and a ``postprocessing_`` prefix is added when the name
            does not already start with ``postprocessing``.
        """
        file_extension = file.split('.')[-1]
        name = name + '.' + file_extension
        if not name.startswith('postprocessing'):
            name = "postprocessing_" + name
        self.Postprocessing.append(name)
        self.Postprocessing_files.append(file)

    def export_model(self, tf_model, deepimagej_model_path, **kwargs):
        """
        Main function to export the model as a bundled model of DeepImageJ
        tf_model: tensorflow/keras model
        deepimagej_model_path: directory where DeepImageJ model is stored.
            Any existing folder at this path is removed first.
        **kwargs: accepted for compatibility; not used.

        Raises AttributeError when add_test_info() has not been called and
        IndexError when no pre- or post-processing macro is registered.
        """
        # Check the prerequisites before save_tensorflow_pb deletes an
        # existing export folder; the exception types match what the later
        # steps would raise anyway.
        test_info = getattr(self, 'test_info', None)
        if test_info is None:
            raise AttributeError(
                "No test image registered: call add_test_info() before export_model().")
        if not self.Preprocessing or not self.Postprocessing:
            raise IndexError(
                "At least one preprocessing and one postprocessing macro must be "
                "registered with add_preprocessing()/add_postprocessing().")

        # Save the model as protobuffer
        self.save_tensorflow_pb(tf_model, deepimagej_model_path)

        # extract the information about the testing image
        io.imsave(os.path.join(deepimagej_model_path, 'exampleImage.tiff'), test_info.InputImage)
        io.imsave(os.path.join(deepimagej_model_path, 'resultImage.tiff'), test_info.OutputImage)
        print("Example images stored.")

        # write the DeepImageJ configuration as an xml file
        write_config(self, test_info, deepimagej_model_path)

        # Add preprocessing and postprocessing macros.
        # More than one is available, but the first one is set by default.
        for macro_file, macro_name in zip(self.Preprocessing_files, self.Preprocessing):
            shutil.copy2(macro_file, os.path.join(deepimagej_model_path, macro_name))
            print("ImageJ macro {} included in the bundled model.".format(macro_name))

        for macro_file, macro_name in zip(self.Postprocessing_files, self.Postprocessing):
            shutil.copy2(macro_file, os.path.join(deepimagej_model_path, macro_name))
            print("ImageJ macro {} included in the bundled model.".format(macro_name))

        # Zip the bundled model to download
        shutil.make_archive(deepimagej_model_path, 'zip', deepimagej_model_path)
        print("DeepImageJ model was successfully exported as {0}.zip. You can download and start using it in DeepImageJ.".format(deepimagej_model_path))

    def save_tensorflow_pb(self, tf_model, deepimagej_model_path):
        """Save ``tf_model`` as a TensorFlow SavedModel in ``deepimagej_model_path``.

        An existing folder at that path is deleted first. TensorFlow is
        imported here rather than at module level.
        """
        # Check whether the folder to save the DeepImageJ bundled model exists.
        # If so, it needs to be removed (TensorFlow requirements)
        # -------------- Other definitions -----------
        W = '\033[0m'  # white (normal)
        R = '\033[31m'  # red
        if os.path.exists(deepimagej_model_path):
            print(R + '!! WARNING: DeepImageJ model folder already existed and has been removed !!' + W)
            shutil.rmtree(deepimagej_model_path)

        import tensorflow as tf
        TF_VERSION = tf.__version__
        print("DeepImageJ model will be exported using TensorFlow version {0}".format(TF_VERSION))
        # Warn for 2.3 and anything newer, not only for the literal "2.3".
        try:
            major_minor = tuple(int(part) for part in TF_VERSION.split('.')[:2])
        except ValueError:
            major_minor = None  # unexpected version format: skip the warning
        if major_minor is not None and major_minor >= (2, 3):
            print(R + "DeepImageJ plugin is only compatible with TensorFlow version 1.x, 2.0.0, 2.1.0 and 2.2.0. Later versions are not suported in DeepImageJ." + W)

        is_tf1 = TF_VERSION[0] == '1'

        def _save_model(model):
            """Write ``model`` with SavedModelBuilder and a predict signature."""
            if is_tf1:
                from tensorflow import saved_model
                from keras.backend import get_session
            else:
                # TODO: change it once TF 2.3.0 is available in JAVA
                from tensorflow.compat.v1 import saved_model
                from tensorflow.compat.v1.keras.backend import get_session

            builder = saved_model.builder.SavedModelBuilder(deepimagej_model_path)

            signature = saved_model.signature_def_utils.predict_signature_def(
                inputs={'input': model.input},
                outputs={'output': model.output})

            signature_def_map = {
                saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}

            builder.add_meta_graph_and_variables(
                get_session(),
                [saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map)
            builder.save()
            print("TensorFlow model exported to {0}".format(deepimagej_model_path))

        if is_tf1:
            _save_model(tf_model)
        else:
            # TODO: change it once TF 2.3.0 is available in JAVA
            from tensorflow.keras.models import clone_model
            _weights = tf_model.get_weights()
            with tf.Graph().as_default():
                # clone model in new graph and set weights; the clone (not
                # the original model) must provide the signature tensors
                _model = clone_model(tf_model)
                _model.set_weights(_weights)
                _save_model(_model)


def write_config(Config, TestInfo, config_path):
    """
    - Config: Class with all the information about the model's architecture and pre/post-processing
    - TestInfo: Metadata of the image provided as an example
    - config_path: directory in which config.xml is written.

    The template is downloaded to "config_template.xml" in the current
    working directory on every call, then its fields are updated with the
    information about the model and the example image.

    Raises IndexError when Config has no preprocessing or postprocessing
    macro, and re-raises the error when the template cannot be read or parsed.
    """
    # `import urllib` alone does not guarantee the `request` submodule.
    import urllib.request

    urllib.request.urlretrieve("https://raw.githubusercontent.com/deepimagej/python4deepimagej/master/xml/config_template.xml", "config_template.xml")
    try:
        tree = ET.parse('config_template.xml')
    except (OSError, ET.ParseError):
        print("config_template.xml not found.")
        # Without a parsed template nothing below can work; surface the
        # real error instead of a NameError on `root`.
        raise
    root = tree.getroot()

    # WorkCitation-Credits
    root[0][0].text = Config.Name
    root[0][1].text = Config.Authors
    root[0][2].text = Config.URL
    root[0][3].text = Config.Credits
    root[0][4].text = Config.Version
    root[0][5].text = Config.Date
    root[0][6].text = Config.References

    # ExampleImage
    root[1][0].text = TestInfo.Input_shape
    root[1][1].text = TestInfo.Output_shape
    root[1][2].text = TestInfo.MemoryPeak
    root[1][3].text = TestInfo.Runtime
    root[1][4].text = TestInfo.PixelSize

    # ModelArchitecture
    root[2][0].text = 'tf.saved_model.tag_constants.SERVING'
    root[2][1].text = 'tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY'
    root[2][2].text = Config.InputTensorDimensions
    root[2][3].text = '1'
    root[2][4].text = 'input'
    root[2][5].text = Config.InputOrganization0
    root[2][6].text = '1'
    root[2][7].text = 'output'
    root[2][8].text = Config.OutputOrganization0
    root[2][9].text = Config.Channels
    root[2][10].text = Config.FixedPatch
    root[2][11].text = Config.MinimumSize
    root[2][12].text = Config.PatchSize
    root[2][13].text = 'true'
    root[2][14].text = Config.Padding
    root[2][15].text = Config.Preprocessing[0]
    print("Preprocessing macro '{}' set by default".format(Config.Preprocessing[0]))
    root[2][16].text = Config.Postprocessing[0]
    print("Postprocessing macro '{}' set by default".format(Config.Postprocessing[0]))
    root[2][17].text = '1'
    try:
        tree.write(os.path.join(config_path, 'config.xml'), encoding="UTF-8", xml_declaration=True, )
        print("DeepImageJ configuration file exported.")
    except OSError:
        # Best effort, as before: report the problem and carry on.
        print("The directory {} does not exist.".format(config_path))