├── .gitignore ├── BAND.md ├── INSTALLATION.md ├── LICENSE ├── README.md ├── cell_classification ├── pytorch │ └── train_infection_classifier.ipynb └── torch_em │ ├── apply_infection_classifier.ipynb │ └── train_infection_classifier.ipynb ├── cell_segmentation ├── cellpose │ └── pretrained_segmentation.ipynb ├── pytorch │ └── train_cell_segmentation.ipynb └── torch_em │ ├── apply_cell_segmentation.ipynb │ └── train_cell_segmentation.ipynb ├── data-visualization-matplotlib.ipynb ├── data-visualization-napari.ipynb ├── data_annotation ├── class_annotation.ipynb └── segment_anything.ipynb ├── environment_cpu.yaml ├── environment_gpu.yaml ├── nucleus_segmentation ├── bioimageio │ └── pretrained_segmentation.ipynb └── stardist │ └── pretrained_segmentation.ipynb └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .ipynb_checkpoints/ 3 | logs/ 4 | checkpoints/ 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /BAND.md: -------------------------------------------------------------------------------- 1 | # Using BAND for the course 2 | 3 | Here is a step-by-step guide for running the exercises on BAND. 4 | - Go to [https://band.embl.de](https://band.embl.de) 5 | - Agree to the terms of service and privacy note and click on `LOGIN`. 6 | - You can then log in with your Google account. If you log in for the first time this may take a while. 7 | - After you have logged in, start an online desktop: 8 | - Choose the number of CPUs (recommended: 4), memory (recommended: 24GB) and the number of GPUs (1) and click launch. See 1. in the overview figure below. 9 | - You can also set a time limit (default: 2 days). After this time your desktop will be shut down (you can then start a new one). 10 | - After a few seconds an item should appear in `Running desktops`. Click `GO TO DESKTOP` to connect to it. 11 | - If you already have a desktop running you don't need to start a new one, just connect to the existing one. 12 | - If you start BAND for the first time (or use it for this course for the first time) then you have to download the exercises via git: 13 | - Open a terminal by clicking the terminal symbol in the top left corner, see 2. in the overview. This will open a terminal window. 14 | - Then enter `git clone https://github.com/computational-cell-analytics/dl-for-micro` in it and press enter. 15 | - This will download all materials for the course. 16 | - Now you can start jupyter, which we use to run the exercises: 17 | - Click on the `Applications` button in the top left corner and then select `Programming->JupyterLab`. 18 | - This will open a new window (it may take up to a minute) with the jupyter environment. 19 | - You can now open any of the exercises with it. They are in the folder `dl-for-micro` (which you downloaded earlier). See 3. in the overview. 20 | - We recommend that you start with the notebook `data-visualization-napari`. 21 | - To run it (and the other exercises), you need to select the kernel `micro_sam`. See 4. in the overview. 22 | 23 | ![band-overview](https://github.com/computational-cell-analytics/dl-for-micro/assets/4263537/00ee541b-5a96-456d-ab23-1b7bae639433) 24 | 25 | ## Tips & Tricks 26 | 27 | Because BAND runs in the browser, the shortcuts you know for copying and pasting text do not work quite as usual.
It is, however, very convenient to copy and paste text between your laptop and BAND, and to copy and paste inside of jupyter, so here is how this works in BAND: 28 | 29 | **Copy and paste between your laptop and BAND:** 30 | - Press `Ctrl + Shift + Alt` in your browser window with BAND. This will open a text field that you can use to copy and paste into from your laptop and from BAND. (See the screenshot below.) 31 | - To copy from your laptop to BAND, first copy the text on your laptop (e.g. via `Ctrl + C`), then paste it into the window (opened via `Ctrl + Shift + Alt`), then select the text again there and copy it (`Ctrl + C`). Now you can close the window (by pressing `Ctrl + Shift + Alt` again) and paste the text within BAND. 32 | 33 | ![band-copy-paste](https://github.com/computational-cell-analytics/dl-for-micro/assets/4263537/4d3c95c2-8443-4aab-bc56-2bfef5b74b86) 34 | 35 | **Copy and paste in jupyter:** 36 | - It is often very convenient to copy and paste text or code when working in jupyter. You can do this via the following shortcuts: 37 | - Copy text via `Ctrl + C` (as usual) 38 | - Paste text via `Ctrl + Shift + V` (different from the usual shortcut!) 39 | 40 | You can find more information on how to use BAND: 41 | - In the [user guide](https://docs.google.com/document/d/1TZBUsNIciGMH_g4aFj2Lu_upISxh5TV9FBMrvNDWmc8/edit?usp=sharing) 42 | - In the [video tutorial](https://drive.google.com/file/d/11pbF70auGyWF-1ir2XUGM8fgiY7ETxP8/view?usp=sharing) 43 | -------------------------------------------------------------------------------- /INSTALLATION.md: -------------------------------------------------------------------------------- 1 | # Installation instructions 2 | 3 | You can set up an environment that contains all dependencies using `mamba / conda`. If you don't have `mamba` or `conda` installed yet, see [here](https://github.com/mamba-org/mamba) for installation instructions. 4 | 5 | You can then create an environment with all necessary dependencies. You have the choice between a CPU and a GPU version: 6 | - CPU version: [environment_cpu.yaml](https://github.com/computational-cell-analytics/dl-for-micro/blob/main/environment_cpu.yaml) 7 | Create the environment via 8 | ``` 9 | mamba env create -f environment_cpu.yaml 10 | ``` 11 | - GPU version: [environment_gpu.yaml](https://github.com/computational-cell-analytics/dl-for-micro/blob/main/environment_gpu.yaml) 12 | Create the environment via 13 | ``` 14 | mamba env create -f environment_gpu.yaml 15 | ``` 16 | 17 | Note: you may need to change the CUDA version to match your system [here](https://github.com/computational-cell-analytics/dl-for-micro/blob/main/environment_gpu.yaml#L15). 18 | 19 | This will install the environment `dl-for-micro` with all necessary dependencies.
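If you want to use the environment from a JupyterLab or Jupyter installation that lives outside of it, you may additionally need to register it as a Jupyter kernel. This is only a minimal sketch, assuming `ipykernel` is available in the environment (it is pulled in via the `jupyter` dependency); the kernel name `dl-for-micro` is just a suggestion:
```
mamba activate dl-for-micro
python -m ipykernel install --user --name dl-for-micro
```
Afterwards you can select the `dl-for-micro` kernel when opening the exercise notebooks.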
20 | After setting up the environment the following should work (activate the environment first with `mamba/conda activate dl-for-micro`): 21 | ``` 22 | $ python -c "import torch_em" 23 | $ python -c "import micro_sam" 24 | ``` 25 | 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 computational-cell-analytics 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction to Deep Learning for Microscopy 2 | 3 | Do you want to learn how to use deep learning for solving your microscopy image analysis tasks? 4 | Then you have come to the right spot! 5 | 6 | We offer materials for an introductory course on the topic, containing video lectures, exercises that show step by step how to build deep learning models for microscopy with [PyTorch](https://pytorch.org/) and advanced examples that explain how to use several of the most popular deep learning based tools. 7 | 8 | Sparked your interest? 9 | - Check out the [content](#content)! 10 | - See how you can work on the exercises [on publicly available resources](#on-band) or how to [set them up on your own computer](#on-your-own-computer). 11 | - Check out the [recommended prior knowledge and further reading](#recommended-knowledge-and-further-reading). 12 | 13 | 14 | ## Content 15 | 16 | This course consists of lectures and exercises that teach the background of deep learning for image analysis and show applications to classification and segmentation analysis problems. 17 | 18 | ### Video Lectures 19 | 20 | Coming soon. 21 | 22 | ### PyTorch Exercises 23 | 24 | These exercises demonstrate how to use PyTorch for image analysis for several tasks on a microscopy image analysis problem. 25 | We recommend that you work on these exercises in this order: 26 | 1. `data-visualization-napari`: will give you an overview of the data used in the exercise and give an introduction to [napari](https://napari.org/stable/), which is the image viewer we will use throughout the exercises. 27 | 2. `cell_classification/pytorch/train_infection_classifier`: will teach you how to train and apply a neural network for cell classification. 28 | 3. 
`nucleus_segmentation/bioimageio/pretrained_segmentation`: will teach you how to apply a pretrained segmentation network from [bioimageio.io](https://bioimage.io/#/). 29 | 4. `cell_segmentation/pytorch/train_cell_segmentation`: will teach you how to train and apply a U-Net for cell segmentation. 30 | 31 | We provide two additional exercises that show you how to annotate your own data, so that you can create training data for applying what you have learned to your own analysis problems: 32 | - `data_annotation/segment_anything`: will teach you how to use [Segment Anything for Microscopy](https://www.biorxiv.org/content/10.1101/2023.08.21.554208v1.abstract) to interactively annotate training data for segmentation. 33 | - `data_annotation/class_annotation`: will teach you how to annotate data for classification with napari. 34 | 35 | ### Exercises for Other Tools 36 | 37 | Coming soon. 38 | 39 | 40 | ## Getting Started 41 | 42 | ### On BAND 43 | 44 | [BAND](https://band.embl.de/#) is an online service for image analysis. It is free of charge and offers a pre-installed environment for the [exercises](pytorch-exercises). Follow [this link](https://github.com/computational-cell-analytics/dl-for-micro/blob/main/BAND.md) for a step-by-step guide on how to run the course on BAND. 45 | 46 | ### On your own Computer 47 | 48 | Coming soon. 49 | 50 | ### On Kaggle 51 | 52 | Coming soon. 53 | 54 | 69 | 70 | ## Recommended Knowledge and Further Reading 71 | 72 | Coming soon. 73 | 74 | 75 | 119 | -------------------------------------------------------------------------------- /cell_classification/torch_em/apply_infection_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "6871de10", 7 | "metadata": {}, 8 | "source": [ 9 | "## Apply Infection Classifier\n", 10 | "\n", 11 | "Finally, we apply the trained infection classifier to the test data, also using the cell segmentation we predicted instead of the ground-truth. We will also evaluate the accuracy of the predictions." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "42766aa9", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# General imports.\n", 22 | "import torch_em\n", 23 | "\n", 24 | "import os\n", 25 | "from glob import glob\n", 26 | "\n", 27 | "import h5py\n", 28 | "import napari\n", 29 | "import numpy as np\n", 30 | "from skimage.measure import regionprops" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "3a5ad6d5", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# Define the paths to folders with the data and predictions.\n", 41 | "# If you store the data somewhere else just change the 'data_folder' variable.\n", 42 | "\n", 43 | "data_folder = \"../data\"\n", 44 | "output_folder = os.path.join(data_folder, \"predictions\")" 45 | ] 46 | }, 47 | { 48 | "attachments": {}, 49 | "cell_type": "markdown", 50 | "id": "80df5add", 51 | "metadata": {}, 52 | "source": [ 53 | "### 1. Test Data Extraction\n", 54 | "\n", 55 | "We first extract the input patches and labels for the test images. We reuse the functions from the previous notebook, with the difference that we do not skip cells that could not be assigned a label here, but instead set their label to -1."
56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "b5a2b705", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# Function to extract the label (infected vs. not infected) for each cell in an image.\n", 66 | "def extract_labels_for_cells(cells, infected_labels):\n", 67 | " # First we get all non-background cell ids for this image.\n", 68 | " cell_ids = np.unique(cells)[1:]\n", 69 | " cell_labels = {}\n", 70 | " \n", 71 | " # We iterate over the ids.\n", 72 | " for cell_id in cell_ids:\n", 73 | " # Compute the cell mask and get the infection labels inside of it.\n", 74 | " cell_mask = cells == cell_id\n", 75 | " infected_labels_cell = infected_labels[cell_mask]\n", 76 | " # Zero means no infection label.\n", 77 | " infected_labels_cell = infected_labels_cell[infected_labels_cell != 0]\n", 78 | "\n", 79 | " # If we only have zeros then mark this cell with -1.\n", 80 | " if infected_labels_cell.size == 0:\n", 81 | " cell_labels[cell_id] = -1\n", 82 | " continue\n", 83 | " \n", 84 | " # The label values mean the following: 1 = infected, 2 = not infected.\n", 85 | " # If there is more than one label we need to check which of the two is more prevalent.\n", 86 | " label_ids, counts = np.unique(infected_labels_cell, return_counts=True)\n", 87 | " # We map the label id to 0, 1 (infected, not infected) because pytorch / torch_em expects zero-based indexing.\n", 88 | " if len(label_ids) == 1:\n", 89 | " assert label_ids[0] in (1, 2)\n", 90 | " label = label_ids[0] - 1\n", 91 | " else:\n", 92 | " assert label_ids.tolist() == [1, 2], str(label_ids)\n", 93 | " label = 0 if counts[0] > counts[1] else 1\n", 94 | " cell_labels[cell_id] = label\n", 95 | "\n", 96 | " return cell_labels" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "12a0146e", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "# Function to extract the training patches and labels for one image.\n", 107 | "def image_to_training_data(cells, marker, nucleus_image, infected_labels, apply_cell_mask=True):\n", 108 | " # Compute the infection labels with the previously defined function and the region properties.\n", 109 | " cell_infection_labels = extract_labels_for_cells(cells, infected_labels)\n", 110 | " props = regionprops(cells)\n", 111 | " \n", 112 | " # Iterate over all cells in the image and extract the training patch.\n", 113 | " train_image_data, train_labels = [], []\n", 114 | " for prop in props:\n", 115 | " cell_id = prop.label\n", 116 | " \n", 117 | " # Get the infection label for this cell (cells without a label were marked with -1 above).\n", 118 | " label = cell_infection_labels[cell_id]\n", 119 | " \n", 120 | " # Get the bounding box from the properties for this cell.\n", 121 | " bbox = prop.bbox\n", 122 | " bbox = np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]\n", 123 | " \n", 124 | " # Cut out mask, nucleus image and virus marker for this cell.\n", 125 | " cell_mask = cells[bbox] == cell_id\n", 126 | " nuc_im = nucleus_image[bbox].astype(\"float32\")\n", 127 | " marker_im = marker[bbox].astype(\"float32\")\n", 128 | " # And set the image values outside of the cell to 0.\n", 129 | " if apply_cell_mask:\n", 130 | " nuc_im[~cell_mask] = 0.0\n", 131 | " marker_im[~cell_mask] = 0.0\n", 132 | " \n", 133 | " # Stack the 3 channels into one image and append to the training patches and labels.\n", 134 | " image_data = np.stack([nuc_im, marker_im, cell_mask.astype(\"float32\")])\n", 135 | " train_image_data.append(image_data)\n", 136 | " "
train_labels.append(label)\n", 137 | " \n", 138 | " return train_image_data, train_labels" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "d380b3e4", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "# Get the test image and test prediction paths.\n", 149 | "test_images = glob(os.path.join(data_folder, \"test\", \"*.h5\"))\n", 150 | "test_images.sort()\n", 151 | "test_predictions = glob(os.path.join(output_folder, \"*.h5\"))\n", 152 | "test_predictions.sort()\n", 153 | "assert len(test_images) == len(test_predictions)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "2d2974b3", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "# Load the inputs and labels for the test images.\n", 164 | "classification_inputs, classification_labels = [], []\n", 165 | "for test_image, test_prediction in zip(test_images, test_predictions):\n", 166 | " with h5py.File(test_image, \"r\") as f:\n", 167 | " marker = f[\"raw/marker/s0\"][:]\n", 168 | " nucleus_image = f[\"raw/nuclei/s0\"][:]\n", 169 | " infected_labels = f[\"labels/infected/nuclei/s0\"][:]\n", 170 | " with h5py.File(test_prediction, \"r\") as f:\n", 171 | " cells = f[\"segmentations/cells/watershed_based\"][:]\n", 172 | " inputs, labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)\n", 173 | " classification_inputs.append(inputs)\n", 174 | " classification_labels.append(labels)" 175 | ] 176 | }, 177 | { 178 | "attachments": {}, 179 | "cell_type": "markdown", 180 | "id": "8ee7c3c9", 181 | "metadata": {}, 182 | "source": [ 183 | "### 2. Prediction and Visualization for a Test Image\n", 184 | "\n", 185 | "We run prediction for one of the test images and visualize the results in napari." 
186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "46dc51ca", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "# torch and model imports\n", 196 | "import torch\n", 197 | "from torch_em.classification import default_classification_loader\n", 198 | "from torchvision.models.resnet import resnet34" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "559bcc90", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# Use GPU if available, otherwise the CPU.\n", 209 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "d271742d", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "# Load the model from the best checkpoint.\n", 220 | "model_path = \"checkpoints/infection-classifier/best.pt\"\n", 221 | "model = resnet34(num_classes=2)\n", 222 | "model_state = torch.load(model_path)[\"model_state\"]\n", 223 | "model.load_state_dict(model_state)\n", 224 | "model.eval()\n", 225 | "model = model.to(device)" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "id": "16ce7b01", 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "# Function to run prediction and to return the corresponding labels in a format that can be \n", 236 | "# evaluated by sklearn.metrics (see below).\n", 237 | "def predict_infection(model, inputs, labels, batch_size=128):\n", 238 | " loader = default_classification_loader(\n", 239 | " inputs, labels, batch_size=batch_size, image_shape=(64, 64),\n", 240 | " )\n", 241 | " y_pred, y_true = [], []\n", 242 | " \n", 243 | " with torch.no_grad():\n", 244 | " for x, y in loader:\n", 245 | " x = x.to(device)\n", 246 | " pred = model(x).cpu().numpy()\n", 247 | " class_pred = np.argmax(pred, axis=1)\n", 248 | " y_pred.append(class_pred)\n", 249 | " y_true.append(y.numpy().squeeze())\n", 250 | " \n", 251 | " y_pred = np.concatenate(y_pred)\n", 252 | " y_true = np.concatenate(y_true)\n", 253 | " return y_pred, y_true" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "d0b4dc35", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "# Get the infection predictions for the first input.\n", 264 | "infection_predictions, _ = predict_infection(model, classification_inputs[0], classification_labels[0])" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "id": "2790c52a", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "# Load the images and segmentation for the first test image again.\n", 275 | "with h5py.File(test_images[0], \"r\") as f:\n", 276 | " marker = f[\"raw/marker/s0\"][:]\n", 277 | " nucleus_image = f[\"raw/nuclei/s0\"][:]\n", 278 | " infected_labels = f[\"labels/infected/nuclei/s0\"][:]\n", 279 | " \n", 280 | "with h5py.File(test_predictions[0], \"r\") as f:\n", 281 | " cells = f[\"segmentations/cells/watershed_based\"][:]" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "951701d1", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "# Visualize the predictions in napari.\n", 292 | "props = regionprops(cells)\n", 293 | "\n", 294 | "points = [prop.centroid for prop in props]\n", 295 | "infected_points = [\"infected\" if pred == 0 else \"not-infected\" for pred in infection_predictions]\n", 296 
| "\n", 297 | "viewer = napari.Viewer()\n", 298 | "viewer.add_image(marker, colormap=\"red\", blending=\"additive\")\n", 299 | "viewer.add_image(nucleus_image, colormap=\"blue\", blending=\"additive\")\n", 300 | "point_layer = viewer.add_points(\n", 301 | " points, properties={\"infected\": infected_points}, face_color=\"infected\", face_color_cycle=[\"orange\", \"cyan\"],\n", 302 | ")\n", 303 | "point_layer.face_color_mode = \"cycle\"" 304 | ] 305 | }, 306 | { 307 | "attachments": {}, 308 | "cell_type": "markdown", 309 | "id": "51dcde77", 310 | "metadata": {}, 311 | "source": [ 312 | "### 3. Prediction and Evaluation for the Test Set\n", 313 | "\n", 314 | "Run prediction for all test images and evaluate the accuracy of the result." 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "619f93bc", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "from sklearn.metrics import accuracy_score\n", 325 | "from tqdm import tqdm" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "ae58e700", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "# Get the prediction and labels for all images.\n", 336 | "y_pred, y_true = [], []\n", 337 | "for inputs, labels in tqdm(zip(classification_inputs, classification_labels), total=len(classification_inputs)):\n", 338 | " pred, true = predict_infection(model, inputs, labels)\n", 339 | " y_pred.append(pred)\n", 340 | " y_true.append(true)\n", 341 | "y_pred = np.concatenate(y_pred)\n", 342 | "y_true = np.concatenate(y_true)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "id": "701b8912", 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "# Exclude the labels and predictions for which labels are -1 (could not be mapped to either of the two labels).\n", 353 | "valid_labels = y_true != -1\n", 354 | "y_pred, y_true = y_pred[valid_labels], y_true[valid_labels]" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "12eee241", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# Compute the accuracy.\n", 365 | "accuracy = accuracy_score(y_true, y_pred)\n", 366 | "print(\"The overall accuracy is:\", accuracy)" 367 | ] 368 | }, 369 | { 370 | "attachments": {}, 371 | "cell_type": "markdown", 372 | "id": "74270d69", 373 | "metadata": {}, 374 | "source": [ 375 | "### Exercises\n", 376 | "\n", 377 | "- If you have trained any other models in the previous notebook then evaluate them as well and compare the performance between the different models.\n", 378 | "- Use other metrics form [sklearn.metrics](https://scikit-learn.org/stable/modules/model_evaluation.html) to evaluate other aspects of the results. In particular check if there are differences in the precision vs. recall and think about what this implies experimentally.\n", 379 | "- Check if there are any systematic differences in the scores between the different test images. If yes, check the corresponding image data and see if you can find a reason for this visually." 
380 | ] 381 | } 382 | ], 383 | "metadata": { 384 | "kernelspec": { 385 | "display_name": "Python 3 (ipykernel)", 386 | "language": "python", 387 | "name": "python3" 388 | }, 389 | "language_info": { 390 | "codemirror_mode": { 391 | "name": "ipython", 392 | "version": 3 393 | }, 394 | "file_extension": ".py", 395 | "mimetype": "text/x-python", 396 | "name": "python", 397 | "nbconvert_exporter": "python", 398 | "pygments_lexer": "ipython3", 399 | "version": "3.10.10" 400 | } 401 | }, 402 | "nbformat": 4, 403 | "nbformat_minor": 5 404 | } 405 | -------------------------------------------------------------------------------- /cell_classification/torch_em/train_infection_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "4c45f7a1", 7 | "metadata": {}, 8 | "source": [ 9 | "## Train Infection Classifier\n", 10 | "\n", 11 | "In the previous lessons we have built a method for cell instance segmentation and applied it to our dataset. Now we turn to classifying the cells into infected vs. non-infected cells, based on the virus marker channel, the nucleus image channel and the segmentation mask for each individual cell. We will use a ResNet for this task.\n", 12 | "\n", 13 | "The goal of this lesson is to learn how to train a classification model with `torch_em`." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "57794921", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# General imports.\n", 24 | "import torch_em\n", 25 | "\n", 26 | "import os\n", 27 | "from glob import glob\n", 28 | "\n", 29 | "import h5py\n", 30 | "import napari\n", 31 | "import numpy as np\n", 32 | "\n", 33 | "from skimage.measure import regionprops" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "8e725790", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# Define the paths to folders with the data and train/val splits.\n", 44 | "# If you store the data somewhere else just change the 'data_folder' variable.\n", 45 | "\n", 46 | "data_folder = \"../data\"\n", 47 | "train_data_folder = os.path.join(data_folder, \"train\")\n", 48 | "val_data_folder = os.path.join(data_folder, \"val\")" 49 | ] 50 | }, 51 | { 52 | "attachments": {}, 53 | "cell_type": "markdown", 54 | "id": "eec67a4b", 55 | "metadata": {}, 56 | "source": [ 57 | "### 1. Inspect Training Data\n", 58 | "\n", 59 | "First, we visually check all the relevant training data. We will use it to construct image patches for training the classification model as follows:\n", 60 | "- Compute the bounding box around each cell.\n", 61 | "- Cut out the nucleus image, virus marker and segmentation mask for the bounding box.\n", 62 | "- Set all values outside the mask to zero.\n", 63 | "- Derive the label (infected or not infected) for the given patch from the infection label image."
64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "12ce06fa", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Load all necessary data for one training image.\n", 74 | "image_path = os.path.join(train_data_folder, \"gt_image_000.h5\")\n", 75 | "with h5py.File(image_path, \"r\") as f:\n", 76 | " marker = f[\"raw/marker/s0\"][:]\n", 77 | " nucleus_image = f[\"raw/nuclei/s0\"][:]\n", 78 | " cells = f[\"labels/cells/s0\"][:]\n", 79 | " infected_labels = f[\"labels/infected/nuclei/s0\"][:]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "04a15f69", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# Check it visually.\n", 90 | "viewer = napari.Viewer()\n", 91 | "viewer.add_image(marker, colormap=\"red\", blending=\"additive\")\n", 92 | "viewer.add_image(nucleus_image, colormap=\"blue\", blending=\"additive\")\n", 93 | "viewer.add_labels(cells)\n", 94 | "viewer.add_labels(infected_labels)" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "abe7d6b4", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# Function to extract the label (infected vs. not infected) for each cell in an image.\n", 105 | "def extract_labels_for_cells(cells, infected_labels):\n", 106 | " # First we get all non-background cell ids for this image.\n", 107 | " cell_ids = np.unique(cells)[1:]\n", 108 | " cell_labels = {}\n", 109 | " \n", 110 | " # We iterate over the ids.\n", 111 | " for cell_id in cell_ids:\n", 112 | " # Compute the cell mask and get the infection labels inside of it.\n", 113 | " cell_mask = cells == cell_id\n", 114 | " infected_labels_cell = infected_labels[cell_mask]\n", 115 | " # Zero means no infection label.\n", 116 | " infected_labels_cell = infected_labels_cell[infected_labels_cell != 0]\n", 117 | " # If we only have zeros then skip this cell.\n", 118 | " if infected_labels_cell.size == 0:\n", 119 | " cell_labels[cell_id] = None\n", 120 | " continue\n", 121 | " \n", 122 | " # The label values mean the following: 1 = infected, 2 = not infected.\n", 123 | " # If there is more than one label we need to check which of the two is more prevalent.\n", 124 | " label_ids, counts = np.unique(infected_labels_cell, return_counts=True)\n", 125 | " # We map the label id to 0, 1 (infected, not infected) because pytorch / torch_em expects zero-based indexing.\n", 126 | " if len(label_ids) == 1:\n", 127 | " assert label_ids[0] in (1, 2)\n", 128 | " label = label_ids[0] - 1\n", 129 | " else:\n", 130 | " assert label_ids.tolist() == [1, 2], str(label_ids)\n", 131 | " label = 0 if counts[0] > counts[1] else 1\n", 132 | " cell_labels[cell_id] = label\n", 133 | " return cell_labels" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "1f8b7cb0", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# We apply the function to get the infection labels for the cells in our current image.\n", 144 | "cell_infection_labels = extract_labels_for_cells(cells, infected_labels)\n", 145 | "\n", 146 | "# And use skimage regionprops to compute other properties for all cells in the image.\n", 147 | "props = regionprops(cells)\n", 148 | "\n", 149 | "# Now we visualize the infected labels as points, by putting a point at each cell centroid and coloring it\n", 150 | "# according to its label using a napari points layer (see below).\n", 151 | "points = [prop.centroid for prop in props]\n", 152 | "infected_points = 
[\"infected\" if label == 0 else \"not-infected\" for label in cell_infection_labels.values()]\n", 153 | "\n", 154 | "viewer = napari.Viewer()\n", 155 | "viewer.add_image(marker, colormap=\"red\", blending=\"additive\")\n", 156 | "viewer.add_image(nucleus_image, colormap=\"blue\", blending=\"additive\")\n", 157 | "point_layer = viewer.add_points(\n", 158 | " points, properties={\"infected\": infected_points}, face_color=\"infected\", face_color_cycle=[\"orange\", \"cyan\"],\n", 159 | ")\n", 160 | "point_layer.face_color_mode = \"cycle\"" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "00e0cd3f", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# Function to extract the training patches and labels for one image.\n", 171 | "def image_to_training_data(cells, marker, nucleus_image, infected_labels, apply_cell_mask=True):\n", 172 | " # Compute the infection labels with the previously defined function and the region properties.\n", 173 | " cell_infection_labels = extract_labels_for_cells(cells, infected_labels)\n", 174 | " props = regionprops(cells)\n", 175 | " \n", 176 | " # Iterate over all cells in the image and extract the training patch.\n", 177 | " train_image_data, train_labels = [], []\n", 178 | " for prop in props:\n", 179 | " cell_id = prop.label\n", 180 | " \n", 181 | " # Get the infection label and skip the cell if it doesn't have one.\n", 182 | " label = cell_infection_labels[cell_id]\n", 183 | " if label is None:\n", 184 | " continue\n", 185 | " \n", 186 | " # Get the bounding box from the properties for this cell.\n", 187 | " bbox = prop.bbox\n", 188 | " bbox = np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]\n", 189 | " \n", 190 | " # Cut out mask, nucleus image and virus marker for this cell.\n", 191 | " cell_mask = cells[bbox] == cell_id\n", 192 | " nuc_im = nucleus_image[bbox].copy().astype(\"float32\")\n", 193 | " marker_im = marker[bbox].copy().astype(\"float32\")\n", 194 | " # And se the image values outsied of the cell to 0.\n", 195 | " if apply_cell_mask:\n", 196 | " nuc_im[~cell_mask] = 0.0\n", 197 | " marker_im[~cell_mask] = 0.0\n", 198 | " \n", 199 | " # Stack the 3 channels into one image and append to the training patches and labels.\n", 200 | " image_data = np.stack([nuc_im, marker_im, cell_mask.astype(\"float32\")])\n", 201 | " train_image_data.append(image_data)\n", 202 | " train_labels.append(label)\n", 203 | " \n", 204 | " return train_image_data, train_labels" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "eb9939bb", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "# Apply the function to our current image.\n", 215 | "train_image_data, train_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "ca5aa47a", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# Visualize 5 of the training patches.\n", 226 | "for i in range(25, 30):\n", 227 | " im_data = train_image_data[i]\n", 228 | " label = train_labels[i]\n", 229 | " viewer = napari.Viewer()\n", 230 | " viewer.add_image(im_data[0], name=\"nucleus-channel\", colormap=\"blue\", blending=\"additive\") \n", 231 | " viewer.add_image(im_data[1], name=\"marker-channel\", colormap=\"red\", blending=\"additive\")\n", 232 | " viewer.add_labels(im_data[2].astype(\"uint8\"), name=\"cell-mask\")\n", 233 | " viewer.title = f\"Label: {label}\"" 234 | ] 235 | }, 236 | 
{ 237 | "attachments": {}, 238 | "cell_type": "markdown", 239 | "id": "20af4026", 240 | "metadata": {}, 241 | "source": [ 242 | "### 2. Prepare Training Data\n", 243 | "\n", 244 | "Now we apply the function we just defined to all training and validation data to build the training and validation sets for our classification model." 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "fff5de25", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "from tqdm import tqdm\n", 255 | "\n", 256 | "# Function that extracts the patches and labels for all images in a folder.\n", 257 | "def prepare_classification_data(root):\n", 258 | " images = glob(os.path.join(root, \"*.h5\"))\n", 259 | " images.sort()\n", 260 | "\n", 261 | " image_data, labels = [], []\n", 262 | " for path in tqdm(images, desc=\"Prepare classification data\"):\n", 263 | " with h5py.File(path, \"r\") as f:\n", 264 | " marker = f[\"raw/marker/s0\"][:]\n", 265 | " nucleus_image = f[\"raw/nuclei/s0\"][:]\n", 266 | " cells = f[\"labels/cells/s0\"][:]\n", 267 | " infected_labels = f[\"labels/infected/nuclei/s0\"][:]\n", 268 | " \n", 269 | " this_data, this_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)\n", 270 | " image_data.extend(this_data)\n", 271 | " labels.extend(this_labels)\n", 272 | " \n", 273 | " assert len(image_data) == len(labels)\n", 274 | " return image_data, labels" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "id": "65165594", 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "# Build the training and validation set.\n", 285 | "train_data, train_labels = prepare_classification_data(train_data_folder)\n", 286 | "print(\"We have\", len(train_data), \"samples for training\")\n", 287 | "\n", 288 | "val_data, val_labels = prepare_classification_data(val_data_folder)\n", 289 | "print(\"We have\", len(val_data), \"samples for validation\")" 290 | ] 291 | }, 292 | { 293 | "attachments": {}, 294 | "cell_type": "markdown", 295 | "id": "e5302bb3", 296 | "metadata": {}, 297 | "source": [ 298 | "### 3. Train the Infection Classifier\n", 299 | "\n", 300 | "And use the training and validation set to train a ResNet34 for infection classification, using the classification functionality from `torch_em`." 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "e4745f29", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "# Import classification functionality.\n", 311 | "import torch\n", 312 | "import torch.nn as nn\n", 313 | "from torch_em.classification import default_classification_loader, default_classification_trainer\n", 314 | "from torchvision.models.resnet import resnet34\n", 315 | "from sklearn.metrics import accuracy_score" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "id": "1fde0408", 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "# Find the mean shape of all training and validation patches.\n", 326 | "shapes = np.stack([np.array(im.shape[1:]) for im in (train_data + val_data)])\n", 327 | "mean_shape = np.mean(shapes, axis=0)\n", 328 | "print(\"Mean image shape:\", mean_shape)" 329 | ] 330 | }, 331 | { 332 | "attachments": {}, 333 | "cell_type": "markdown", 334 | "id": "f4c1ba87", 335 | "metadata": {}, 336 | "source": [ 337 | "You should see that the mean image shape is roughly 52 x 52 pixels. 
We determine this shape to choose a suitable shape that all patches will be resized to for training the model. This is necessary to stack the patches across the batch dimensions and train the model with a batch size that is larger than 1.\n", 338 | "We choose the closest multiple of 16 as common patch shape, which is 64 x 64." 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "id": "075ed385", 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "# Build the training and validation loader.\n", 349 | "batch_size = 32 # The batch size used for training.\n", 350 | "image_shape = (64, 64) # The common shape all patches will be resized to before stacking them in a batch.\n", 351 | "num_workers = 4 if torch.cuda.is_available() else 1\n", 352 | "# Build the training and validation loader.\n", 353 | "train_loader = default_classification_loader(\n", 354 | " train_data, train_labels, batch_size=batch_size, image_shape=image_shape, num_workers=num_workers,\n", 355 | ")\n", 356 | "val_loader = default_classification_loader(\n", 357 | " val_data, val_labels, batch_size=batch_size, image_shape=image_shape, num_workers=num_workers,\n", 358 | ")" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "701b8f57", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "# Define the model (a resnet 34 with two output channels).\n", 369 | "model = resnet34(num_classes=2)\n", 370 | "# And build the trainer class. Here, we use the cross entropy as loss function and the accuracy error as metric.\n", 371 | "trainer = default_classification_trainer(\n", 372 | " name=\"infection-classifier\", model=model,\n", 373 | " train_loader=train_loader, val_loader=val_loader,\n", 374 | " loss=nn.CrossEntropyLoss(),\n", 375 | " metric=lambda a, b: 1.0 - accuracy_score(a, b),\n", 376 | " compile_model=False,\n", 377 | ")" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "id": "2de3e0d4", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "# Train the model for 10.000 iterations.\n", 388 | "trainer.fit(10000)" 389 | ] 390 | }, 391 | { 392 | "attachments": {}, 393 | "cell_type": "markdown", 394 | "id": "87373720", 395 | "metadata": {}, 396 | "source": [ 397 | "As before you can open the tensorboard to monitor the progress while training via\n", 398 | "```\n", 399 | "tensorboard --logdir=logs\n", 400 | "```\n", 401 | "See `2_cell_segmentation/torchem-train-cell-membrane-segmentation` for details." 402 | ] 403 | }, 404 | { 405 | "attachments": {}, 406 | "cell_type": "markdown", 407 | "id": "78be34a4", 408 | "metadata": {}, 409 | "source": [ 410 | "### Exercises\n", 411 | "\n", 412 | "Train different architectures for this task, for example a `resnet18` and a `resnet50`. Also export these models to the bioimage.io format, make sure to choose different file paths for the export so that you do not overwrite the previous exported models.\n", 413 | "You can also compare to training this network using only PyTorch in the `pytorch_train-infection-classifier` notebook (work in progress)." 
414 | ] 415 | }, 416 | { 417 | "attachments": {}, 418 | "cell_type": "markdown", 419 | "id": "ad99b917", 420 | "metadata": {}, 421 | "source": [ 422 | "### What's next?\n", 423 | "\n", 424 | "Now we can apply the trained classification model to the test images in `apply_infection_classifier`.\n", 425 | "\n", 426 | "\n", 427 | "\n" 428 | ] 429 | }, 430 | { 431 | "attachments": {}, 432 | "cell_type": "markdown", 433 | "id": "538daf2f", 434 | "metadata": {}, 435 | "source": [ 436 | "\n", 437 | "\n", 438 | "\n", 439 | "\n", 440 | "\n", 441 | "\n", 442 | "\n" 443 | ] 444 | }, 445 | { 446 | "attachments": {}, 447 | "cell_type": "markdown", 448 | "id": "fecba392", 449 | "metadata": {}, 450 | "source": [ 451 | "**This is not working yet!**\n", 452 | "\n", 453 | "**Skip the cells below!**\n", 454 | "\n", 455 | "#### Export the model to bioimage.io\n", 456 | "\n", 457 | "Now we also export the model to the bioimage.io format to import it in other tools that support this format.\n", 458 | "See the notebook `2_cell_segmentation/torchem-train-cell-membrane-segmentation` for details." 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "id": "b3d734bb", 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "import h5py\n", 469 | "from torch_em.util.modelzoo import export_bioimageio_model" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "id": "fc9202b2", 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "model_root = os.path.join(data_folder, \"trained_models\")\n", 480 | "model_folder = os.path.join(model_root, \"infection-classification\")\n", 481 | "os.makedirs(model_folder, exist_ok=True)" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "id": "c5981c5a", 488 | "metadata": {}, 489 | "outputs": [], 490 | "source": [ 491 | "input_, _ = next(iter(val_loader))\n", 492 | "input_ = input_[0:1].detach().cpu().numpy()" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "id": "b412cc6f", 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "doc = \"\"\"#ResNet for Covid Cell Infection Classification\n", 503 | "\n", 504 | "A model for classifying cells into infected vs. 
non-infected.\n", 505 | "\"\"\"\n", 506 | "\n", 507 | "citations = [{\"text\": \"Pape et al.\", \"doi\": \"https://doi.org/10.1002/bies.202000257\"}]" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "id": "671be5b7", 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [ 517 | "export_bioimageio_model(\n", 518 | " checkpoint=\"checkpoints/infection-classifier\",\n", 519 | " export_folder=model_folder,\n", 520 | " input_data=input_,\n", 521 | " name=\"infection_classification_model\",\n", 522 | " authors=[{\"name\": \"Your Name\", \"affiliation\": \"Your Affiliation\"}],\n", 523 | " tags=[\"uner\", \"cells\", \"2d\", \"immunofluorescence\", \"classification\"],\n", 524 | " license=\"CC-BY-4.0\",\n", 525 | " documentation=doc,\n", 526 | " description=\"Classify cell membranes in IF images\",\n", 527 | " cite=citations,\n", 528 | " input_optional_parameters=False,\n", 529 | " maintainers=[{\"github_user\": \"Your Github Handle\"}] # alternatively you can also give your mail address\n", 530 | ")" 531 | ] 532 | } 533 | ], 534 | "metadata": { 535 | "kernelspec": { 536 | "display_name": "Python 3 (ipykernel)", 537 | "language": "python", 538 | "name": "python3" 539 | }, 540 | "language_info": { 541 | "codemirror_mode": { 542 | "name": "ipython", 543 | "version": 3 544 | }, 545 | "file_extension": ".py", 546 | "mimetype": "text/x-python", 547 | "name": "python", 548 | "nbconvert_exporter": "python", 549 | "pygments_lexer": "ipython3", 550 | "version": "3.11.7" 551 | } 552 | }, 553 | "nbformat": 4, 554 | "nbformat_minor": 5 555 | } 556 | -------------------------------------------------------------------------------- /cell_segmentation/cellpose/pretrained_segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6fdd2e3e", 6 | "metadata": {}, 7 | "source": [ 8 | "Under construction" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "3d28f63d", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [] 18 | } 19 | ], 20 | "metadata": { 21 | "kernelspec": { 22 | "display_name": "Python 3 (ipykernel)", 23 | "language": "python", 24 | "name": "python3" 25 | }, 26 | "language_info": { 27 | "codemirror_mode": { 28 | "name": "ipython", 29 | "version": 3 30 | }, 31 | "file_extension": ".py", 32 | "mimetype": "text/x-python", 33 | "name": "python", 34 | "nbconvert_exporter": "python", 35 | "pygments_lexer": "ipython3", 36 | "version": "3.9.9" 37 | } 38 | }, 39 | "nbformat": 4, 40 | "nbformat_minor": 5 41 | } 42 | -------------------------------------------------------------------------------- /cell_segmentation/torch_em/apply_cell_segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "id": "4f1166ee", 7 | "metadata": {}, 8 | "source": [ 9 | "## Cell Segmentation\n", 10 | "\n", 11 | "Now we bring the nucleus segmentation, and cell foreground and boundary predictions together in order to obtain the complete cell instance segmentation. Here, we use a seeded watershed, where we use the nucleus instances as seeds, use the cell boundary predictions as height map for the watershed and the cell foreground prediction as mask. 
We use the watershed functionality from [skimage](https://scikit-image.org/) for this.\n", 12 | "\n", 13 | "The goal of this lesson is to further explore post-processing for instance segmentation and to also learn how to quantitatively evaluate segmentation results." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "32cc7a61", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# General imports and functionality for network prediction and watershed.\n", 24 | "import os\n", 25 | "\n", 26 | "import bioimageio.core\n", 27 | "import h5py\n", 28 | "import napari\n", 29 | "\n", 30 | "from skimage.segmentation import watershed\n", 31 | "from xarray import DataArray" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "8af9e539", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# Define the paths to folders with the data and predictions.\n", 42 | "# If you store the data somewhere else just change the 'data_folder' variable.\n", 43 | "\n", 44 | "data_folder = \"../data\"\n", 45 | "output_folder = os.path.join(data_folder, \"predictions\")" 46 | ] 47 | }, 48 | { 49 | "attachments": {}, 50 | "cell_type": "markdown", 51 | "id": "712a0ba2", 52 | "metadata": {}, 53 | "source": [ 54 | "### 1. Implement Cell Segmentation\n", 55 | "\n", 56 | "First, we implement the watershed based cell segmentation and visually check it for a test image." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "b375c3d8", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# We load the model we have trained in the previous notebook.\n", 67 | "model_path = os.path.join(data_folder, \"trained_models/boundary-segmentation/boundary_segmentation_model.zip\")\n", 68 | "model = bioimageio.core.load_resource_description(model_path)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "07d0b3b2", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# And load the serum channel as well as the nucleus segmentation for one of the test images.\n", 79 | "image_path = os.path.join(data_folder, \"test/gt_image_048.h5\")\n", 80 | "prediction_path = os.path.join(data_folder, \"predictions/gt_image_048.h5\")\n", 81 | "\n", 82 | "with h5py.File(image_path, \"r\") as f:\n", 83 | " image = f[\"raw/serum_IgG/s0\"][:]\n", 84 | " \n", 85 | "with h5py.File(prediction_path, \"r\") as f:\n", 86 | " nuclei = f[\"/segmentations/nuclei/watershed_based\"][:]" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "5427affa", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Next, we run prediction with the cell segmentation network.\n", 97 | "# For details on the bioimageio functionality see the previous notebook on nucleus segmentation.\n", 98 | "with bioimageio.core.create_prediction_pipeline(model) as pp:\n", 99 | " input_ = DataArray(image[None, None], dims=tuple(\"bcyx\"))\n", 100 | " pred = bioimageio.core.predict_with_padding(pp, input_, padding={\"x\": 16, \"y\": 16})[0].values.squeeze()" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "dab4ea39", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Check the predictions visually.\n", 111 | "viewer = napari.Viewer()\n", 112 | "viewer.add_image(image)\n", 113 | "viewer.add_image(pred)\n", 114 | "viewer.add_labels(nuclei)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | 
"execution_count": null, 120 | "id": "96db9233", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Run watershed to get the cell instance segmentation.\n", 125 | "foreground, boundaries = pred\n", 126 | "foreground = foreground > 0.5\n", 127 | "cells = watershed(boundaries, markers=nuclei, mask=foreground)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "b5aa6946", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# And check the result.\n", 138 | "viewer = napari.Viewer()\n", 139 | "viewer.add_image(image)\n", 140 | "viewer.add_labels(cells)" 141 | ] 142 | }, 143 | { 144 | "attachments": {}, 145 | "cell_type": "markdown", 146 | "id": "76ff57d0", 147 | "metadata": {}, 148 | "source": [ 149 | "### 2. Apply to Test Images\n", 150 | "\n", 151 | "Now we apply this segmentation approach to all test images." 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "939dff36", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from glob import glob\n", 162 | "from tqdm import tqdm\n", 163 | "\n", 164 | "test_images = glob(os.path.join(data_folder, \"test/*.h5\"))\n", 165 | "test_images.sort()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "f3412039", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "# Combine the prediction and watershed in a function.\n", 176 | "def segment_cells(pp, image, nuclei):\n", 177 | " input_ = DataArray(image[None, None], dims=tuple(\"bcyx\"))\n", 178 | " pred = bioimageio.core.predict_with_padding(pp, input_, padding={\"x\": 16, \"y\": 16})[0].values.squeeze()\n", 179 | " foreground, boundaries = pred\n", 180 | " foreground = foreground > 0.5\n", 181 | " cells = watershed(boundaries, markers=nuclei, mask=foreground)\n", 182 | " return cells" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "1c9260e4", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "# And run this function for all test images, saving the results to hdf5.\n", 193 | "with bioimageio.core.create_prediction_pipeline(model) as pp:\n", 194 | " for path in tqdm(test_images):\n", 195 | " out_path = os.path.join(output_folder, os.path.basename(path))\n", 196 | " with h5py.File(path, \"r\") as f:\n", 197 | " image = f[\"raw/serum_IgG/s0\"][:]\n", 198 | " with h5py.File(out_path, \"r\") as f:\n", 199 | " nuclei = f[\"segmentations/nuclei/watershed_based\"][:]\n", 200 | " cells = segment_cells(pp, image, nuclei)\n", 201 | " with h5py.File(out_path, \"a\") as f:\n", 202 | " f.create_dataset(\"segmentations/cells/watershed_based\", data=cells, compression=\"gzip\")" 203 | ] 204 | }, 205 | { 206 | "attachments": {}, 207 | "cell_type": "markdown", 208 | "id": "fcd3a91e", 209 | "metadata": {}, 210 | "source": [ 211 | "### 3. Evaluate Cell Segmentation\n", 212 | "\n", 213 | "We can now also quantitatively evaluate the cell segementation. We use the AP50 evaluation metric for it. It measures the [precision](https://en.wikipedia.org/wiki/Precision_and_recall) of the matches between the predicted segmentation and ground-truth segmentation. This is a standard evaluation metric for instance segmentations, and we use the implementation from [elf](https://github.com/constantinpape/elf)." 
214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "1e3febff", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "import numpy as np\n", 224 | "from elf.evaluation import matching" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "17c9b283", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "predictions = glob(os.path.join(output_folder, \"*.h5\"))\n", 235 | "predictions.sort()\n", 236 | "assert len(predictions) == len(test_images)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "012f8f3b", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "evaluation_scores = []\n", 247 | "for image_path, pred_path in zip(test_images, predictions):\n", 248 | " with h5py.File(image_path, \"r\") as f:\n", 249 | " ground_truth = f[\"labels/cells/s0\"][:]\n", 250 | " with h5py.File(pred_path, \"r\") as f:\n", 251 | " segmentation = f[\"segmentations/cells/watershed_based\"][:]\n", 252 | " evaluation_scores.append(matching(segmentation, ground_truth)[\"precision\"])\n", 253 | "evaluation_score = np.mean(evaluation_scores)\n", 254 | "print(\"The AP50 score for the cell segmentation is\", evaluation_score)" 255 | ] 256 | }, 257 | { 258 | "attachments": {}, 259 | "cell_type": "markdown", 260 | "id": "612f418a", 261 | "metadata": {}, 262 | "source": [ 263 | "### Exercises\n", 264 | "\n", 265 | "- If you have trained different segmentation models in the previous notebook `torchem-train-cell-membrane-segmentation`, then compare the evaluation results between them.\n", 266 | "- [Cellpose](https://github.com/MouseLand/cellpose) is a generalist method for cell segmentation that can directly be applied to our data. Run segmentation for the test images with it and compare the evaluation scores.\n", 267 | " - We are also working on adding a notebook that shows how to apply Cellpose to this data `cellpose_pretrained-cell-segmentation`, but this is work in progress." 268 | ] 269 | }, 270 | { 271 | "attachments": {}, 272 | "cell_type": "markdown", 273 | "id": "5d62d52d", 274 | "metadata": {}, 275 | "source": [ 276 | "### What's next\n", 277 | "\n", 278 | "Now that we have obtained a cell classification we turn to classifying the cells into infected vs. non-infected in `3_cell_classification/pytorch_train-infection-classifier.ipynb`." 279 | ] 280 | } 281 | ], 282 | "metadata": { 283 | "kernelspec": { 284 | "display_name": "Python 3 (ipykernel)", 285 | "language": "python", 286 | "name": "python3" 287 | }, 288 | "language_info": { 289 | "codemirror_mode": { 290 | "name": "ipython", 291 | "version": 3 292 | }, 293 | "file_extension": ".py", 294 | "mimetype": "text/x-python", 295 | "name": "python", 296 | "nbconvert_exporter": "python", 297 | "pygments_lexer": "ipython3", 298 | "version": "3.9.9" 299 | } 300 | }, 301 | "nbformat": 4, 302 | "nbformat_minor": 5 303 | } 304 | -------------------------------------------------------------------------------- /data_annotation/segment_anything.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1be4d22d", 6 | "metadata": {}, 7 | "source": [ 8 | "## Segment Anything for Data Annotation\n", 9 | "\n", 10 | "[Segment Anything](https://segment-anything.com/) is a recent model for interactive segmentation published by Meta.AI. 
It can be used to generate annotations for segmentation much faster than painting by hand. \n", 11 | "\n", 12 | "We have built napari-based tools around it in https://github.com/computational-cell-analytics/micro-sam, which we will use here to annotate some of the cells for our example data. You can find more information on this tool and further extensions to Segment Anything in [our preprint](https://www.biorxiv.org/content/10.1101/2023.08.21.554208v1.abstract)." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "b423d46d", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# General Imports.\n", 23 | "import os\n", 24 | "import sys\n", 25 | "import imageio.v3 as imageio\n", 26 | "\n", 27 | "# Load the function to start the Segment Anything annotation tool in napari.\n", 28 | "# (Note: the tool can also be started as a napari plugin.)\n", 29 | "from micro_sam.sam_annotator import annotator_2d\n", 30 | "\n", 31 | "sys.path.append(\"..\")\n", 32 | "import utils" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "e22ccc22-1c40-4c30-a3a7-9b4d1804131e", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# This function will download and unpack the data and do some further data preparation.\n", 43 | "# It will only be executed if the data has not been downloaded yet.\n", 44 | "data_dir = \"../data\"\n", 45 | "if os.path.exists(data_dir):\n", 46 | " print(\"The data is downloaded already.\")\n", 47 | "else:\n", 48 | " utils.prepare_data(data_dir)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "8459f217", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Load an example image.\n", 59 | "image_path = os.path.join(data_dir, \"train\", \"gt_image_030\", \"gt_image_030_serum_image.tif\")\n", 60 | "image = imageio.imread(image_path)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "a55a747f-bc90-4677-8d86-a85235888199", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "# Create a cutout so that we can focus on annotating fewer cells.\n", 71 | "image = image[:600, :600]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "bdb1ebf8", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# Start the 2d annotation tool.\n", 82 | "# Here, we use a model that was fine-tuned by us on microscopy data.\n", 83 | "# We select it with the 'model_type' argument.\n", 84 | "# \"vit_b\" stands for the size of the model and \"_lm\" means that it is the model finetuned on light microscopy data.\n", 85 | "annotator_2d(image, model_type=\"vit_b_lm\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "6d9da249", 91 | "metadata": {}, 92 | "source": [ 93 | "Annotate some of the cells using the interactive annotation functionality. Try out different ways to annotate the cells, using box prompts, point prompts and combinations thereof. 
Which of these combinations works best?\n", 94 | "\n", 95 | "After you are done, save the annotations as a tif file, then load the saved tif and make sure that your annotations were saved correctly.\n", 96 | "- You can save the annotations by selecting the corresponding layer `committed_objects`, and then saving it via `File->Save Selected Layer(s)...`\n", 97 | "\n", 98 | "You can find explanations for how to use the annotation tool [here](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#annotation-tools) and can watch [the video tutorial](https://www.youtube.com/watch?v=ket7bDUP9tI&list=PLwYZXQJ3f36GQPpKCrSbHjGiH39X4XjSO&index=1) to see a live demonstration." 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "737314a1-436f-45e0-9e78-26217eeddd68", 104 | "metadata": {}, 105 | "source": [ 106 | "### More Data Annotation Tools\n", 107 | "\n", 108 | "We also offer other napari-based annotation tools building on Segment Anything for:\n", 109 | "- [Annotating volumetric data](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#annotator-3d)\n", 110 | "- [Annotating (2D image) timeseries for object tracking](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#annotator-tracking)\n", 111 | "- [Annotating 2D image data in bulk](https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/image_series_annotator.py)" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python 3 (ipykernel)", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.11.7" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 5 136 | } 137 | -------------------------------------------------------------------------------- /environment_cpu.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - conda-forge 4 | name: 5 | dl-for-micro 6 | dependencies: 7 | - cpuonly 8 | - bioimageio.core >=0.5.0 9 | - jupyter 10 | - napari 11 | - pyqt 12 | - pycocotools 13 | - python-elf 14 | - pytorch 15 | - tensorboard 16 | - tifffile 17 | - torchvision 18 | - torch_em 19 | - micro_sam 20 | - tqdm 21 | -------------------------------------------------------------------------------- /environment_gpu.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - nvidia 4 | - conda-forge 5 | name: 6 | dl-for-micro 7 | dependencies: 8 | - bioimageio.core >=0.5.0 9 | - jupyter 10 | - napari 11 | - pyqt 12 | - pycocotools 13 | - python-elf 14 | - pytorch 15 | - pytorch-cuda>=11.7 # you may need to change the CUDA version to match your system 16 | - tensorboard 17 | - tifffile 18 | - torchvision 19 | - torch_em 20 | - micro_sam 21 | - tqdm 22 | -------------------------------------------------------------------------------- /nucleus_segmentation/stardist/pretrained_segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5e1ec7d1", 6 | "metadata": {}, 7 | "source": [ 8 | "Under construction." 
9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "32d0fd75", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [] 18 | } 19 | ], 20 | "metadata": { 21 | "kernelspec": { 22 | "display_name": "Python 3 (ipykernel)", 23 | "language": "python", 24 | "name": "python3" 25 | }, 26 | "language_info": { 27 | "codemirror_mode": { 28 | "name": "ipython", 29 | "version": 3 30 | }, 31 | "file_extension": ".py", 32 | "mimetype": "text/x-python", 33 | "name": "python", 34 | "nbconvert_exporter": "python", 35 | "pygments_lexer": "ipython3", 36 | "version": "3.9.9" 37 | } 38 | }, 39 | "nbformat": 4, 40 | "nbformat_minor": 5 41 | } 42 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # General imports. 2 | import os 3 | import zipfile 4 | from glob import glob 5 | from shutil import copyfileobj 6 | from shutil import move 7 | import h5py 8 | from skimage.measure import regionprops 9 | import requests 10 | from tqdm import tqdm 11 | # saving images to tif formats 12 | import imageio.v3 as imageio 13 | # from numpyencoder import NumpyEncoder 14 | import json 15 | import numpy as np 16 | import torch 17 | 18 | 19 | # download the data using the requests library 20 | # this function is for downloading the data. 21 | # you don't need to understand what's going on here. 22 | def download_url(url, path): 23 | # If the file to be downloaded already exists, quit here. 24 | if os.path.isfile(path): 25 | return 26 | with requests.get(url, stream=True) as r: 27 | if r.status_code != 200: 28 | r.raise_for_status() 29 | raise RuntimeError(f"Request to {url} returned status code {r.status_code}") 30 | file_size = int(r.headers.get("Content-Length", 0)) 31 | desc = f"Download {url} to {path}" 32 | if file_size == 0: 33 | desc += " (unknown file size)" 34 | with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw, open(path, "wb") as f: 35 | copyfileobj(r_raw, f) 36 | 37 | 38 | # unzip the data using the zipfile library 39 | # function for unzipping the archive we have just downloaded and then removing the zip 40 | def unzip(zip_path, dst, remove=True): 41 | with zipfile.ZipFile(zip_path, "r") as f: 42 | f.extractall(dst) 43 | if remove: 44 | os.remove(zip_path) 45 | 46 | 47 | # We use a visitor pattern to check out the contents of the file. 48 | # the 'inspector' function will be called for each element in the file hierarchy. 49 | def inspector(name, node): 50 | # hdf5 files contain 'Dataset' that hold the actual data. 
With the function below we print 51 | # the name and shape if the inspector function encounters a dataset 52 | if isinstance(node, h5py.Dataset): 53 | print("The h5 file contains a dataset @", name, "with shape", node.shape) 54 | 55 | 56 | class CustomEncoder(json.JSONEncoder): 57 | def default(self, obj): 58 | if isinstance(obj, np.ndarray): 59 | return obj.tolist() 60 | elif isinstance(obj, np.int64): 61 | return int(obj) 62 | elif isinstance(obj, np.int32): 63 | return int(obj) 64 | return super(CustomEncoder, self).default(obj) 65 | 66 | 67 | # convert images from hdf5 to tif: 68 | # extract the six image and label arrays from each hdf5 file and 69 | # put them in a designated directory 70 | def convert_hdf5_to_tif(paths, file_folders, data_folder): 71 | count = 0 72 | for file_path in tqdm(paths): 73 | with h5py.File(file_path, 'r') as f: 74 | # get the image and label datasets 75 | marker = f["raw/marker/s0"][:] 76 | nucleus_image = f["raw/nuclei/s0"][:] 77 | cells = f["labels/cells/s0"][:] 78 | infected_labels = f["labels/infected/nuclei/s0"][:] 79 | serum_image = f["raw/serum_IgG/s0"][:] 80 | nuclei_labels = f["labels/nuclei/s0"][:] 81 | table_infected = f["tables/infected_labels/cells"][:] 82 | # print(f'Image Dataset info: Shape={marker.shape},Dtype={marker.dtype}') 83 | img_ds = { 84 | "marker_image": marker, 85 | "nucleus_image": nucleus_image, 86 | "cell_labels": cells, 87 | "infected_labels": infected_labels, 88 | "serum_image": serum_image, 89 | "nucleus_labels": nuclei_labels 90 | } 91 | # create a subdirectory for each hdf5 file 92 | folder_name = f"gt_image_{count:03}" 93 | file_folders.append(folder_name) 94 | img_dir = os.path.join(data_folder, folder_name) 95 | os.makedirs(img_dir, exist_ok=True) 96 | bboxes = {} 97 | for key, value in img_ds.items(): 98 | img_name = f"{img_dir}_{key}.tif" 99 | imageio.imwrite(img_name, value, compression="zlib") 100 | if "cell" in img_name: 101 | im = imageio.imread(img_name) 102 | regions = regionprops(im) 103 | for props in regions: 104 | bboxes[props.label] = props.bbox 105 | # move to the subdirectory 106 | move(img_name, os.path.join(img_dir, os.path.basename(img_name))) 107 | # convert from byte to float, then to int 108 | table_infected = (table_infected.astype(float)).astype(int) 109 | labels = { 110 | "cells": [ 111 | { 112 | "cell_id": id_value, 113 | "infected_label": status, 114 | "bbox": bboxes[id_value] if id_value in bboxes.keys() else None 115 | } for id_value, status in zip(table_infected[:, 0], table_infected[:, 1]) 116 | ] 117 | } 118 | # Assuming labels is a dictionary containing various types including NumPy arrays and int64 119 | labels_serializable = {key: value for key, value in labels.items()} 120 | 121 | # Write to JSON file using the custom encoder 122 | with open("labels.json", "w") as f: 123 | json.dump(labels_serializable, f, ensure_ascii=False, cls=CustomEncoder) 124 | 125 | move("labels.json", os.path.join(img_dir, "labels.json")) 126 | count += 1 127 | return file_folders 128 | 129 | 130 | def divide_data(n_train, n_val, file_folders, data_folder): 131 | 132 | train_folder = os.path.join(data_folder, "train") 133 | os.makedirs(train_folder, exist_ok=True) 134 | for train_image_dir in file_folders[:n_train]: 135 | if (not os.path.exists(os.path.join(train_folder, train_image_dir))): 136 | move( 137 | os.path.join(data_folder, train_image_dir), 138 | os.path.join(train_folder, os.path.basename(train_image_dir)) 139 | ) 140 | 141 | val_folder = os.path.join(data_folder, "val") 142 | os.makedirs(val_folder, exist_ok=True) 143 | for val_image_dir in 
file_folders[n_train:n_train+n_val]: 144 | if (not os.path.exists(os.path.join(val_folder, val_image_dir))): 145 | move(os.path.join(data_folder, val_image_dir), os.path.join(val_folder, os.path.basename(val_image_dir))) 146 | 147 | test_folder = os.path.join(data_folder, "test") 148 | os.makedirs(test_folder, exist_ok=True) 149 | for test_image_dir in file_folders[n_train+n_val:]: 150 | if (not os.path.exists(os.path.join(test_folder, test_image_dir))): 151 | move(os.path.join(data_folder, test_image_dir), os.path.join(test_folder, os.path.basename(test_image_dir))) 152 | return train_folder, val_folder, test_folder 153 | 154 | 155 | # This function takes a multi-channel input tensor and flattens it into a 2D tensor 156 | # by moving the channel axis to the first position and then flattening all other axes. 157 | def flatten_samples(input_): 158 | # Get number of channels 159 | num_channels = input_.size(1) 160 | # Permute the channel axis to first 161 | permute_axes = list(range(input_.dim())) 162 | permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0] 163 | # For input shape (say) NCHW, this should have the shape CNHW 164 | permuted = input_.permute(*permute_axes).contiguous() 165 | # Now flatten out all but the first axis and return 166 | flattened = permuted.view(num_channels, -1) 167 | return flattened 168 | 169 | # This function computes the Dice similarity coefficient between the predicted input and the target. 170 | # It's commonly used in evaluating the performance of segmentation models. 171 | def dice_score(input_, target, eps=1e-7): 172 | assert input_.shape == target.shape, f"{input_.shape}, {target.shape}" 173 | # Flatten input and target to have the shape (C, N), 174 | # where N is the number of samples 175 | input_ = flatten_samples(torch.sigmoid(input_)) 176 | target = flatten_samples(target) 177 | # Compute numerator and denominator (by summing over samples and 178 | # leaving the channels intact) 179 | numerator = (input_ * target).sum(-1) 180 | denominator = (input_ * input_).sum(-1) + (target * target).sum(-1) 181 | channelwise_score = 2 * (numerator / denominator.clamp(min=eps)) 182 | # take the average score over the channels 183 | score = channelwise_score.mean() 184 | 185 | return score 186 | 187 | # function to combine all data preparation steps 188 | def prepare_data(data_folder="data", remove_h5=True): 189 | """ 190 | :param string data_folder: folder for saving the data 191 | :param remove_h5: remove the h5 files after converting to tif 192 | :returns: 193 | - train_folder - path to folder with subfolders for training data 194 | - val_folder - path to folder with subfolders for validation data 195 | - test_folder - path to folder with subfolders for testing data 196 | """ 197 | os.makedirs(data_folder, exist_ok=True) 198 | data_url = "https://zenodo.org/record/5092850/files/covid-if-groundtruth.zip?download=1" 199 | download_url(data_url, os.path.join(data_folder, "data.zip")) 200 | unzip(os.path.join(data_folder, "data.zip"), data_folder, remove=True) 201 | file_paths = glob(os.path.join(data_folder, "*.h5")) 202 | file_folders = [] 203 | file_folders = convert_hdf5_to_tif(file_paths, file_folders, data_folder) 204 | train_folder, val_folder, test_folder = divide_data(35, 5, file_folders, data_folder) 205 | if remove_h5: 206 | for h5_file in file_paths: 207 | os.remove(h5_file) 208 | # double check that we have the correct number of images in the split folders 209 | print("We have", len(os.listdir(train_folder)), "training images in", 
train_folder) 210 | print("We have", len(os.listdir(val_folder)), "validation images in", val_folder) 211 | print("We have", len(os.listdir(test_folder)), "test images in", test_folder) 212 | return train_folder, val_folder, test_folder 213 | 214 | 215 | if __name__ == "__main__": 216 | prepare_data() 217 | --------------------------------------------------------------------------------